Skip to content

Commit e90699b

Browse files
yongjik authored and soumith committed
Fixed double memory accesses of several pointwise operations. (#5068)
Because nvcc does not know that in/out pointers do not alias each other, if we assign a value to *out and then use *in again, the kernel has to emit a write to *out and then another read from *in. (Affected kernels become marginally faster after the fix.)
1 parent d515806 commit e90699b

File tree

1 file changed

+21
-15
lines changed

1 file changed

+21
-15
lines changed

torch/lib/THC/THCTensorMathPointwise.cuh

Lines changed: 21 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -273,13 +273,13 @@ struct TensorPowOp {
273273
} else if (StaticExp == 2) {
274274
*out = THCNumerics<T>::mul(*in, *in);
275275
} else if (StaticExp == 3) {
276-
*out = THCNumerics<T>::mul(*in, *in);
277-
*out = THCNumerics<T>::mul(*out, *in);
276+
T square = THCNumerics<T>::mul(*in, *in);
277+
*out = THCNumerics<T>::mul(square, *in);
278278
} else if (StaticExp == -1) {
279279
*out = THCNumerics<T>::cinv(*in);
280280
} else if (StaticExp == -2) {
281-
*out = THCNumerics<T>::mul(*in, *in);
282-
*out = THCNumerics<T>::cinv(*out);
281+
T square = THCNumerics<T>::mul(*in, *in);
282+
*out = THCNumerics<T>::cinv(square);
283283
} else {
284284
*out = THCNumerics<T>::pow(*in, val);
285285
}
@@ -295,8 +295,8 @@ struct TensorPowOp {
295295
} else if (StaticExp == -1) {
296296
*v = THCNumerics<T>::cinv(*v);
297297
} else if (StaticExp == -2) {
298-
*v = THCNumerics<T>::mul(*v, *v);
299-
*v = THCNumerics<T>::cinv(*v);
298+
T square = THCNumerics<T>::mul(*v, *v);
299+
*v = THCNumerics<T>::cinv(square);
300300
} else {
301301
*v = THCNumerics<T>::pow(*v, val);
302302
}
@@ -402,17 +402,19 @@ struct TensorDivOp<half> {
402402
template <typename T>
403403
struct TensorCRemainderOp {
404404
__device__ __forceinline__ void operator()(T* out, T* in) {
405-
*out = *out % *in;
406-
if ((*out * *in)<0){
407-
*out += *in;
405+
T val = *out % *in;
406+
if ((val * *in)<0){
407+
val += *in;
408408
}
409+
*out = val;
409410
}
410411

411412
__device__ __forceinline__ void operator()(T* out, T* in1, T* in2) {
412-
*out = *in1 % *in2;
413-
if ((*out * *in2)<0){
414-
*out += *in2;
413+
T val = *in1 % *in2;
414+
if ((val * *in2)<0){
415+
val += *in2;
415416
}
417+
*out = val;
416418
}
417419
};
418420

@@ -548,20 +550,24 @@ struct TensorCrossOp {
548550
TensorCrossOp(int64_t sx, int64_t sy, int64_t so) : sx(sx), sy(sy), so(so) {}
549551

550552
__device__ __forceinline__ void operator()(T* out, T* x, T*y) {
551-
out[0 * so] = THCNumerics<T>::sub(
553+
T val0 = THCNumerics<T>::sub(
552554
THCNumerics<T>::mul(x[1 * sx], y[2 * sy]),
553555
THCNumerics<T>::mul(x[2 * sx], y[1 * sy])
554556
);
555557

556-
out[1 * so] = THCNumerics<T>::sub(
558+
T val1 = THCNumerics<T>::sub(
557559
THCNumerics<T>::mul(x[2 * sx], y[0 * sy]),
558560
THCNumerics<T>::mul(x[0 * sx], y[2 * sy])
559561
);
560562

561-
out[2 * so] = THCNumerics<T>::sub(
563+
T val2 = THCNumerics<T>::sub(
562564
THCNumerics<T>::mul(x[0 * sx], y[1 * sy]),
563565
THCNumerics<T>::mul(x[1 * sx], y[0 * sy])
564566
);
567+
568+
out[0 * so] = val0;
569+
out[1 * so] = val1;
570+
out[2 * so] = val2;
565571
}
566572

567573
const int64_t sx, sy, so;

0 commit comments

Comments
 (0)