Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions aten/src/TH/THTensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ typedef struct THTensor
return storage->unsafe_data<T>() + storageOffset;
}

// NOTE: this returns the "old" TH dimension view where no dimensions represents an empty tensor.
// There will be a dim() function that gives the new view that supports 0-sized dimensions.
// [NOTE: _dim() vs dim()]
// _dim() returns the "old" TH dimension view where no dimensions represents an empty tensor.
// dim() returns the ATen view of the dimensionality, i.e. 0-sized dimensions are supported.
inline int64_t _dim() const {
return is_empty() ? 0 : dim_;
}

// NOTE: this is the ATen view of the dimensionality, i.e. 0-sized dimensions are supported.
inline int64_t dim() const {
return dim_;
}
Expand Down
9 changes: 7 additions & 2 deletions aten/src/TH/generic/THTensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,26 @@ ptrdiff_t THTensor_(storageOffset)(const THTensor *self)
}

int THTensor_(nDimension)(const THTensor *self)
{
return self->dim();
}

int THTensor_(_nDimension)(const THTensor *self)
{
return self->_dim();
}

int64_t THTensor_(size)(const THTensor *self, int dim)
{
THArgCheck((dim >= 0) && (dim < self->_dim()), 2, "dimension %d out of range of %dD tensor",
dim+TH_INDEX_BASE, THTensor_(nDimension)(self));
dim+TH_INDEX_BASE, THTensor_(_nDimension)(self));
return self->size[dim];
}

int64_t THTensor_(stride)(const THTensor *self, int dim)
{
THArgCheck((dim >= 0) && (dim < self->_dim()), 2, "dimension %d out of range of %dD tensor",
dim+TH_INDEX_BASE, THTensor_(nDimension)(self));
dim+TH_INDEX_BASE, THTensor_(_nDimension)(self));
return self->stride[dim];
}

Expand Down
3 changes: 3 additions & 0 deletions aten/src/TH/generic/THTensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ typedef struct THTensor THTensor;
/**** access methods ****/
TH_API THStorage* THTensor_(storage)(const THTensor *self);
TH_API ptrdiff_t THTensor_(storageOffset)(const THTensor *self);

// See [NOTE: _dim() vs dim()]; _nDimension corresponds to _dim(), nDimension corresponds to dim().
TH_API int THTensor_(nDimension)(const THTensor *self);
TH_API int THTensor_(_nDimension)(const THTensor *self);
TH_API int64_t THTensor_(size)(const THTensor *self, int dim);
TH_API int64_t THTensor_(stride)(const THTensor *self, int dim);
TH_API THLongStorage *THTensor_(newSizeOf)(THTensor *self);
Expand Down
2 changes: 1 addition & 1 deletion aten/src/TH/generic/THTensorCopy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
int THTensor_(copyTransposeValid)(THTensor *tensor, THTensor *src) {
const int MIN_SZ = 60 * 60;
return THTensor_(isContiguous)(tensor) &&
THTensor_(nDimension)(src) == 2 &&
THTensor_(_nDimension)(src) == 2 &&
THTensor_(stride)(src, 0) == 1 &&
THTensor_(stride)(src, 1) == THTensor_(size)(src, 0) &&
THTensor_(nElement)(tensor) >= MIN_SZ;
Expand Down
10 changes: 5 additions & 5 deletions aten/src/TH/generic/THTensorLapack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -948,7 +948,7 @@ void THTensor_(ormqr)(THTensor *ra_, THTensor *a, THTensor *tau, THTensor *c, co

void THTensor_(btrifact)(THTensor *ra_, THIntTensor *rpivots_, THIntTensor *rinfo_, int pivot, THTensor *a)
{
THArgCheck(THTensor_(nDimension)(a) == 3, 1, "expected 3D tensor, got %dD", THTensor_(nDimension)(a));
THArgCheck(THTensor_(_nDimension)(a) == 3, 1, "expected 3D tensor, got %dD", THTensor_(_nDimension)(a));
if (!pivot) {
THError("btrifact without pivoting is not implemented on the CPU");
}
Expand Down Expand Up @@ -1023,10 +1023,10 @@ void THTensor_(btrifact)(THTensor *ra_, THIntTensor *rpivots_, THIntTensor *rinf

void THTensor_(btrisolve)(THTensor *rb_, THTensor *b, THTensor *atf, THIntTensor *pivots)
{
THArgCheck(THTensor_(nDimension)(atf) == 3, 1, "expected 3D tensor, got %dD",
THTensor_(nDimension)(atf));
THArgCheck(THTensor_(nDimension)(b) == 3 ||
THTensor_(nDimension)(b) == 2, 4, "expected 2D or 3D tensor");
THArgCheck(THTensor_(_nDimension)(atf) == 3, 1, "expected 3D tensor, got %dD",
THTensor_(_nDimension)(atf));
THArgCheck(THTensor_(_nDimension)(b) == 3 ||
THTensor_(_nDimension)(b) == 2, 4, "expected 2D or 3D tensor");
THArgCheck(THTensor_(size)(atf, 0) ==
THTensor_(size)(b, 0), 3, "number of batches must be equal");
THArgCheck(THTensor_(size)(atf, 1) ==
Expand Down
Loading