Skip to content

Commit 695fd98

Browse files
authored
Compatibility: write nDimension/_nDimension corresponding to dim()/_dim(). (#8676)
Currently, THTensor_(nDimension) goes to _dim(), which makes it difficult to move individual usages over to the new API. Instead, let's create a THTensor_(_nDimension) going to _dim() and THTensor_(nDimension) going to _dim(). To do this, we will redirect all current calls and move them over as we did for _dim() and dim().
1 parent 6402a42 commit 695fd98

File tree

75 files changed

+362
-338
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

75 files changed

+362
-338
lines changed

aten/src/TH/THTensor.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,13 @@ typedef struct THTensor
3232
return storage->unsafe_data<T>() + storageOffset;
3333
}
3434

35-
// NOTE: this returns the "old" TH dimension view where no dimensions represents an empty tensor.
36-
// There will be a dim() function that gives the new view that supports 0-sized dimensions.
35+
// [NOTE: _dim() vs dim()]
36+
// _dim() returns the "old" TH dimension view where no dimensions represents an empty tensor.
37+
// dim() returns the ATen view of the dimensionality, i.e. 0-sized dimensions are supported.
3738
inline int64_t _dim() const {
3839
return is_empty() ? 0 : dim_;
3940
}
4041

41-
// NOTE: this is the ATen view of the dimensionality, i.e. 0-sized dimensions are supported.
4242
inline int64_t dim() const {
4343
return dim_;
4444
}

aten/src/TH/generic/THTensor.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,26 @@ ptrdiff_t THTensor_(storageOffset)(const THTensor *self)
1616
}
1717

1818
int THTensor_(nDimension)(const THTensor *self)
19+
{
20+
return self->dim();
21+
}
22+
23+
int THTensor_(_nDimension)(const THTensor *self)
1924
{
2025
return self->_dim();
2126
}
2227

2328
int64_t THTensor_(size)(const THTensor *self, int dim)
2429
{
2530
THArgCheck((dim >= 0) && (dim < self->_dim()), 2, "dimension %d out of range of %dD tensor",
26-
dim+TH_INDEX_BASE, THTensor_(nDimension)(self));
31+
dim+TH_INDEX_BASE, THTensor_(_nDimension)(self));
2732
return self->size[dim];
2833
}
2934

3035
int64_t THTensor_(stride)(const THTensor *self, int dim)
3136
{
3237
THArgCheck((dim >= 0) && (dim < self->_dim()), 2, "dimension %d out of range of %dD tensor",
33-
dim+TH_INDEX_BASE, THTensor_(nDimension)(self));
38+
dim+TH_INDEX_BASE, THTensor_(_nDimension)(self));
3439
return self->stride[dim];
3540
}
3641

aten/src/TH/generic/THTensor.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,10 @@ typedef struct THTensor THTensor;
2323
/**** access methods ****/
2424
TH_API THStorage* THTensor_(storage)(const THTensor *self);
2525
TH_API ptrdiff_t THTensor_(storageOffset)(const THTensor *self);
26+
27+
// See [NOTE: _dim() vs dim()]; _nDimension corresponds to _dim(), nDimension corresponds to dim().
2628
TH_API int THTensor_(nDimension)(const THTensor *self);
29+
TH_API int THTensor_(_nDimension)(const THTensor *self);
2730
TH_API int64_t THTensor_(size)(const THTensor *self, int dim);
2831
TH_API int64_t THTensor_(stride)(const THTensor *self, int dim);
2932
TH_API THLongStorage *THTensor_(newSizeOf)(THTensor *self);

aten/src/TH/generic/THTensorCopy.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
int THTensor_(copyTransposeValid)(THTensor *tensor, THTensor *src) {
1717
const int MIN_SZ = 60 * 60;
1818
return THTensor_(isContiguous)(tensor) &&
19-
THTensor_(nDimension)(src) == 2 &&
19+
THTensor_(_nDimension)(src) == 2 &&
2020
THTensor_(stride)(src, 0) == 1 &&
2121
THTensor_(stride)(src, 1) == THTensor_(size)(src, 0) &&
2222
THTensor_(nElement)(tensor) >= MIN_SZ;

aten/src/TH/generic/THTensorLapack.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -948,7 +948,7 @@ void THTensor_(ormqr)(THTensor *ra_, THTensor *a, THTensor *tau, THTensor *c, co
948948

949949
void THTensor_(btrifact)(THTensor *ra_, THIntTensor *rpivots_, THIntTensor *rinfo_, int pivot, THTensor *a)
950950
{
951-
THArgCheck(THTensor_(nDimension)(a) == 3, 1, "expected 3D tensor, got %dD", THTensor_(nDimension)(a));
951+
THArgCheck(THTensor_(_nDimension)(a) == 3, 1, "expected 3D tensor, got %dD", THTensor_(_nDimension)(a));
952952
if (!pivot) {
953953
THError("btrifact without pivoting is not implemented on the CPU");
954954
}
@@ -1023,10 +1023,10 @@ void THTensor_(btrifact)(THTensor *ra_, THIntTensor *rpivots_, THIntTensor *rinf
10231023

10241024
void THTensor_(btrisolve)(THTensor *rb_, THTensor *b, THTensor *atf, THIntTensor *pivots)
10251025
{
1026-
THArgCheck(THTensor_(nDimension)(atf) == 3, 1, "expected 3D tensor, got %dD",
1027-
THTensor_(nDimension)(atf));
1028-
THArgCheck(THTensor_(nDimension)(b) == 3 ||
1029-
THTensor_(nDimension)(b) == 2, 4, "expected 2D or 3D tensor");
1026+
THArgCheck(THTensor_(_nDimension)(atf) == 3, 1, "expected 3D tensor, got %dD",
1027+
THTensor_(_nDimension)(atf));
1028+
THArgCheck(THTensor_(_nDimension)(b) == 3 ||
1029+
THTensor_(_nDimension)(b) == 2, 4, "expected 2D or 3D tensor");
10301030
THArgCheck(THTensor_(size)(atf, 0) ==
10311031
THTensor_(size)(b, 0), 3, "number of batches must be equal");
10321032
THArgCheck(THTensor_(size)(atf, 1) ==

0 commit comments

Comments
 (0)