
Commit 2453bc2

Implement clamp using ATen (#3739)
1 parent 23ca19a commit 2453bc2

8 files changed (+213 / -142 lines)


aten/src/ATen/Declarations.cwrap

Lines changed: 66 additions & 42 deletions
@@ -2669,56 +2669,80 @@
 ]]
 [[
   name: clamp
+  cname: clamp
   variants:
     - method
     - function
   return: argument 0
-  options:
-    - cname: clamp
-      arguments:
-        - arg: THTensor* destination
-          output: True
-        - THTensor* self
-        - real min
-        - real max
-    - cname: cmaxValue
-      arguments:
-        - arg: THTensor* result
-          output: True
-        - THTensor* self
-        - arg: real min
-          kwarg_only: True
-    - cname: cminValue
-      arguments:
-        - arg: THTensor* result
-          output: True
-        - THTensor* self
-        - arg: real max
-          kwarg_only: True
+  arguments:
+    - arg: THTensor* destination
+      output: True
+    - THTensor* self
+    - real min
+    - real max
 ]]
 [[
   name: clamp_
   cname: clamp
-  return: self
-  options:
-    - cname: clamp
-      arguments:
-        - THTensor* self
-        - THTensor* self
-        - real min
-        - real max
-    - cname: cmaxValue
-      arguments:
-        - THTensor* self
-        - THTensor* self
-        - arg: real min
-          kwarg_only: True
-    - cname: cminValue
-      arguments:
-        - THTensor* self
-        - THTensor* self
-        - arg: real max
-          kwarg_only: True
+  variants:
+    - method
+    - function
+  return: argument 0
+  arguments:
+    - THTensor* self
+    - THTensor* self
+    - real min
+    - real max
+]]
+[[
+  name: clamp_min
+  cname: cmaxValue
+  variants:
+    - method
+    - function
+  return: argument 0
+  arguments:
+    - arg: THTensor* result
+      output: True
+    - THTensor* self
+    - real min
+]]
+[[
+  name: clamp_min_
+  cname: cmaxValue
+  variants:
+    - method
+    - function
+  return: argument 0
+  arguments:
+    - THTensor* self
+    - THTensor* self
+    - real min
+]]
+[[
+  name: clamp_max
+  cname: cminValue
+  variants:
+    - method
+    - function
+  return: argument 0
+  arguments:
+    - arg: THTensor* result
+      output: True
+    - THTensor* self
+    - real max
+]]
+[[
+  name: clamp_max_
+  cname: cminValue
+  variants:
+    - method
+    - function
+  return: argument 0
+  arguments:
+    - THTensor* self
+    - THTensor* self
+    - real max
 ]]
 [[
   name: dot
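
With the options block gone, clamp maps onto a single TH kernel, and the one-sided cases become their own clamp_min / clamp_max declarations backed by cmaxValue / cminValue; cwrap emits each entry as both a Tensor method and a free function in ATen. The user-visible behaviour these declarations back looks roughly like this (a sketch written against today's torch API; torch.tensor is not part of this commit):

import torch

t = torch.tensor([-2.0, -0.5, 0.5, 2.0])

print(t.clamp(-1.0, 1.0))   # both bounds      -> [-1.0, -0.5,  0.5,  1.0]
print(t.clamp(min=0.0))     # lower bound only -> [ 0.0,  0.0,  0.5,  2.0]
print(t.clamp(max=0.0))     # upper bound only -> [-2.0, -0.5,  0.0,  0.0]

t.clamp_(-1.0, 1.0)         # in-place variant: mutates t and returns it
print(t)                    # [-1.0, -0.5,  0.5,  1.0]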

tools/autograd/derivatives.yaml

Lines changed: 10 additions & 1 deletion
@@ -136,6 +136,15 @@
 - name: ceil(Tensor self)
   self: zeros_like(grad)
 
+- name: clamp(Tensor self, Scalar min, Scalar max)
+  self: grad * (self > min).type_as(grad) * (self < max).type_as(grad)
+
+- name: clamp_min(Tensor self, Scalar min)
+  self: grad * (self > min).type_as(grad)
+
+- name: clamp_max(Tensor self, Scalar max)
+  self: grad * (self < max).type_as(grad)
+
 - name: clone(Tensor self)
   self: grad
 

@@ -465,7 +474,7 @@
   self: grad
 
 - name: renorm # TODO!
-
+
 - name: RoiPooling2d_forward(Tensor input, Tensor rois, int64_t pooledHeight, int64_t pooledWidth, double spatialScale)
   input: RoiPooling2d_backward(input, rois, pooledHeight, pooledWidth, spatialScale, grad, result1)
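
These entries hand the backward formulas straight to the autograd codegen: gradient flows only where self lies strictly between the bounds and is zeroed elsewhere. A quick sanity check of the clamp formula (a sketch using the Variable API of this era; the sample values are chosen away from the bounds):

import torch
from torch.autograd import Variable

x = Variable(torch.Tensor([-2.0, 0.0, 2.0]), requires_grad=True)
y = x.clamp(min=-1.0, max=1.0)
y.sum().backward()

# grad * (self > min).type_as(grad) * (self < max).type_as(grad):
# only the middle element lies strictly inside (-1, 1), so the
# expected gradient is [0, 1, 0].
print(x.grad)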

tools/autograd/gen_variable_type.py

Lines changed: 1 addition & 1 deletion
@@ -1019,7 +1019,7 @@ def should_generate_python_binding(declaration):
 
     # don't bind size or stride since the python signatures are different
     # exclude alias from Python bindings as well at least for now
-    if name in ['alias', 'size', 'stride']:
+    if name in ['alias', 'size', 'stride'] or name.startswith('clamp'):
         return False
 
     if name.endswith('_backward'):
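
Since the clamp family is now bound by hand in python_variable_methods.cpp (next file), the generator has to skip every clamp* declaration or it would emit a second, conflicting set of methods. A minimal sketch of the predicate's effect (simplified; the real function receives a full declaration and applies further filters):

def should_generate_python_binding(name):
    # These names are bound by hand in python_variable_methods.cpp,
    # so the code generator has to skip them.
    if name in ['alias', 'size', 'stride'] or name.startswith('clamp'):
        return False
    return True

assert should_generate_python_binding('add')
assert not should_generate_python_binding('clamp')
assert not should_generate_python_binding('clamp_min_')   # caught by startswith('clamp')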

tools/autograd/templates/python_variable_methods.cpp

Lines changed: 82 additions & 7 deletions
@@ -16,20 +16,77 @@ using namespace torch::autograd::utils;
 
 namespace torch { namespace autograd {
 
-static PyObject * THPVariable_detach(PyObject* self, PyObject* args)
+static Tensor dispatch_clamp(const Tensor & self, Scalar min, Scalar max) {
+  AutoNoGIL no_gil;
+  AutoGPU auto_gpu(self);
+  return self.clamp(min, max);
+}
+static Tensor dispatch_clamp_min(const Tensor & self, Scalar min) {
+  AutoNoGIL no_gil;
+  AutoGPU auto_gpu(self);
+  return self.clamp_min(min);
+}
+static Tensor dispatch_clamp_max(const Tensor & self, Scalar max) {
+  AutoNoGIL no_gil;
+  AutoGPU auto_gpu(self);
+  return self.clamp_max(max);
+}
+
+PyObject * THPVariable_clamp(PyObject* self, PyObject* args, PyObject* kwargs)
 {
   HANDLE_TH_ERRORS
+  static PythonArgParser parser({
+    "clamp(Scalar min=None, Scalar max=None)",
+  });
   auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
-  return THPVariable_Wrap(self_.detach());
+  PyObject* parsed_args[2];
+  auto r = parser.parse(args, kwargs, parsed_args);
+  if (!r.isNone(0) && !r.isNone(1)) {
+    return THPVariable_Wrap(dispatch_clamp(self_, r.scalar(0), r.scalar(1)));
+  } else if (!r.isNone(0)) {
+    return THPVariable_Wrap(dispatch_clamp_min(self_, r.scalar(0)));
+  } else if (!r.isNone(1)) {
+    return THPVariable_Wrap(dispatch_clamp_max(self_, r.scalar(1)));
+  } else {
+    throw std::runtime_error("At least one of 'min' or 'max' must not be None");
+  }
   END_HANDLE_TH_ERRORS
 }
 
-static PyObject * THPVariable_detach_(PyObject* self, PyObject* args)
+static Tensor & dispatch_clamp_(Tensor & self, Scalar min, Scalar max) {
+  AutoNoGIL no_gil;
+  AutoGPU auto_gpu(self);
+  return self.clamp_(min, max);
+}
+static Tensor & dispatch_clamp_min_(Tensor & self, Scalar min) {
+  AutoNoGIL no_gil;
+  AutoGPU auto_gpu(self);
+  return self.clamp_min_(min);
+}
+static Tensor & dispatch_clamp_max_(Tensor & self, Scalar max) {
+  AutoNoGIL no_gil;
+  AutoGPU auto_gpu(self);
+  return self.clamp_max_(max);
+}
+
+PyObject * THPVariable_clamp_(PyObject* self, PyObject* args, PyObject* kwargs)
 {
   HANDLE_TH_ERRORS
-  reinterpret_cast<THPVariable*>(self)->cdata.detach_();
-  Py_INCREF(self);
-  return self;
+  static PythonArgParser parser({
+    "clamp_(Scalar min=None, Scalar max=None)",
+  });
+  auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
+  PyObject* parsed_args[2];
+  auto r = parser.parse(args, kwargs, parsed_args);
+  if (!r.isNone(0) && !r.isNone(1)) {
+    return THPVariable_Wrap(dispatch_clamp_(self_, r.scalar(0), r.scalar(1)));
+  } else if (!r.isNone(0)) {
+    return THPVariable_Wrap(dispatch_clamp_min_(self_, r.scalar(0)));
+  } else if (!r.isNone(1)) {
+    return THPVariable_Wrap(dispatch_clamp_max_(self_, r.scalar(1)));
+  } else {
+    throw std::runtime_error("At least one of 'min' or 'max' must not be None");
+  }
   END_HANDLE_TH_ERRORS
 }

@@ -52,6 +109,23 @@ static PyObject * THPVariable_contiguous(PyObject* self, PyObject* args)
   END_HANDLE_TH_ERRORS
 }
 
+static PyObject * THPVariable_detach(PyObject* self, PyObject* args)
+{
+  HANDLE_TH_ERRORS
+  auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
+  return THPVariable_Wrap(self_.detach());
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject * THPVariable_detach_(PyObject* self, PyObject* args)
+{
+  HANDLE_TH_ERRORS
+  reinterpret_cast<THPVariable*>(self)->cdata.detach_();
+  Py_INCREF(self);
+  return self;
+  END_HANDLE_TH_ERRORS
+}
+
 static PyObject * THPVariable_element_size(PyObject* self, PyObject* args)
 {
   HANDLE_TH_ERRORS
@@ -61,7 +135,6 @@ static PyObject * THPVariable_element_size(PyObject* self, PyObject* args)
   END_HANDLE_TH_ERRORS
 }
 
-
 // generated methods start here
 
 ${py_methods}
@@ -79,6 +152,8 @@ PyMethodDef variable_methods[] = {
   {"__truediv__", (PyCFunction)THPVariable_div, METH_VARARGS | METH_KEYWORDS, NULL},
   {"__idiv__", (PyCFunction)THPVariable_div_, METH_VARARGS | METH_KEYWORDS, NULL},
   {"__mod__", (PyCFunction)THPVariable_remainder, METH_VARARGS | METH_KEYWORDS, NULL},
+  {"clamp", (PyCFunction)THPVariable_clamp, METH_VARARGS | METH_KEYWORDS, NULL},
+  {"clamp_", (PyCFunction)THPVariable_clamp_, METH_VARARGS | METH_KEYWORDS, NULL},
   {"contiguous", (PyCFunction)THPVariable_contiguous, METH_NOARGS, NULL},
   {"detach", (PyCFunction)THPVariable_detach, METH_NOARGS, NULL},
   {"detach_", (PyCFunction)THPVariable_detach_, METH_NOARGS, NULL},
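
The hand-written binding parses min/max with PythonArgParser and routes to the matching ATen call: both bounds go to clamp, a single bound goes to clamp_min or clamp_max, and omitting both raises (which HANDLE_TH_ERRORS should surface as a Python RuntimeError). Roughly, from the Variable side (a sketch, not taken from the commit):

import torch
from torch.autograd import Variable

v = Variable(torch.Tensor([-2.0, 0.5, 2.0]))

v.clamp(0.0, 1.0)    # both bounds      -> dispatch_clamp
v.clamp(min=0.0)     # lower bound only -> dispatch_clamp_min
v.clamp(max=1.0)     # upper bound only -> dispatch_clamp_max

try:
    v.clamp()        # neither bound given
except RuntimeError as exc:
    print(exc)       # "At least one of 'min' or 'max' must not be None"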

torch/autograd/_functions/pointwise.py

Lines changed: 0 additions & 39 deletions
@@ -92,19 +92,6 @@ def backward(ctx, grad_output):
         return grad_output * i.sign()
 
 
-class Clamp(Function):
-
-    @staticmethod
-    def forward(ctx, i, min_val, max_val):
-        ctx._mask = (i.ge(min_val) * i.le(max_val))
-        return i.clamp(min_val, max_val)
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        mask = Variable(ctx._mask.type_as(grad_output.data))
-        return grad_output * mask, None, None
-
-
 class Sqrt(Function):
 
     @staticmethod

@@ -269,19 +256,6 @@ def backward(ctx, grad_output):
         )
 
 
-class CmaxConstant(Function):
-
-    @staticmethod
-    def forward(ctx, i, constant):
-        ctx._mask = i.gt(constant)
-        return i.clamp(min=constant)
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        mask = Variable(ctx._mask.type_as(grad_output.data))
-        return grad_output * mask, None
-
-
 class Cmin(Function):
 
     @staticmethod

@@ -300,19 +274,6 @@ def backward(ctx, grad_output):
         )
 
 
-class CminConstant(Function):
-
-    @staticmethod
-    def forward(ctx, i, constant):
-        ctx._mask = i.lt(constant)
-        return i.clamp(max=constant)
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        mask = Variable(ctx._mask.type_as(grad_output.data))
-        return grad_output * mask, None
-
-
 class _ConstantGrad(Function):
     grad_value = 0
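
One behavioural detail worth noting: the removed Functions built their gradient masks with i.ge(min_val) / i.le(max_val), whereas the new derivatives.yaml formulas use strict comparisons, so elements lying exactly on a bound now receive zero gradient instead of passing it through. A tiny illustration of the two masks (illustrative only, not part of the commit):

import torch

x = torch.Tensor([0.0, 0.5, 1.0])    # bounds: min=0.0, max=1.0

old_mask = x.ge(0.0) * x.le(1.0)     # mask used by the removed Clamp.backward -> [1, 1, 1]
new_mask = (x > 0.0) * (x < 1.0)     # mask implied by derivatives.yaml        -> [0, 1, 0]

print(old_mask)
print(new_mask)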

torch/autograd/variable.py

Lines changed: 0 additions & 11 deletions
@@ -310,17 +310,6 @@ def char(self):
     def byte(self):
         return self.type(self._get_type('ByteTensor'))
 
-    def clamp(self, min=None, max=None):
-        if min is None and max is None:
-            raise ValueError("clamp requires specifying at least one of "
-                             "min and max arguments")
-        elif min is None and max is not None:
-            return CminConstant.apply(self, max)
-        elif min is not None and max is None:
-            return CmaxConstant.apply(self, min)
-        else:
-            return Clamp.apply(self, min, max)
-
     def prod(self, dim=None, keepdim=None):
         return Prod.apply(self, dim, keepdim)
