Skip to content

Commit f9c0a08

Browse files
David Riazatifacebook-github-bot
authored andcommitted
Fix len() for tensors (#13398)
Summary: Fixes #13376, `len(tensor)` was converting tensor to a 1 element list and returning 1 every time. Pull Request resolved: #13398 Differential Revision: D12867630 Pulled By: driazati fbshipit-source-id: 28f3580a072d763df0980b3149c49d1894842ec9
1 parent 9577811 commit f9c0a08

File tree

2 files changed

+21
-2
lines changed

2 files changed

+21
-2
lines changed

test/test_jit.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3511,6 +3511,12 @@ def bad_negative_index():
35113511
self.checkScriptRaisesRegex(bad_negative_index, (), IndexError,
35123512
"list index out of range")
35133513

3514+
def test_tensor_len(self):
3515+
def func(x):
3516+
return len(x)
3517+
3518+
self.checkScript(func, [torch.ones(4, 5, 6)])
3519+
35143520
def test_list_len(self):
35153521
def func():
35163522
a = [1, 2, 3]

torch/csrc/jit/register_prim_ops.cpp

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -768,7 +768,7 @@ RegisterOperators reg2({
768768

769769
#define DEFINE_STRING_OP(op_name, string_op, result) \
770770
Operator( \
771-
#op_name "(str a, str b) ->" #result, \
771+
#op_name "(str a, str b) ->" #result, \
772772
[](Node* node) { \
773773
return [=](Stack& stack) { \
774774
auto b = pop(stack).toStringRef(); \
@@ -781,8 +781,20 @@ Operator( \
781781
DEFINE_STRING_OP(aten::eq, a == b, bool)
782782
DEFINE_STRING_OP(aten::ne, a != b, bool)
783783
DEFINE_STRING_OP(aten::add, a + b, str)
784+
#undef DEFINE_STRING_OP
784785

785-
786+
// tensor length op (size of 1st dimension)
787+
Operator(
788+
"aten::len(Tensor t) -> int",
789+
[](Stack& stack) {
790+
at::Tensor t = pop(stack).toTensor();
791+
if (t.dim() == 0) {
792+
AT_ERROR("len() of a 0-d tensor");
793+
}
794+
push(stack, t.sizes()[0]);
795+
return 0;
796+
}
797+
),
786798
#define CREATE_LIST_OPS(decl_type, c_type) \
787799
Operator("aten::select(" decl_type "[] a, int b) -> " decl_type, listSelect<Shared<c_type>>), \
788800
Operator("aten::len(" decl_type "[] a) -> int", listLen<Shared<c_type>>), \
@@ -799,6 +811,7 @@ Operator( \
799811
CREATE_LIST_OPS("float", DoubleList)
800812
CREATE_LIST_OPS("Tensor", TensorList)
801813
CREATE_LIST_OPS("t", GenericList)
814+
#undef CREATE_LIST_OPS
802815

803816

804817
Operator("aten::eq(int[] a, int[] b) -> int", listEq<Shared<IntList>>),

0 commit comments

Comments
 (0)