Skip to content

Commit ffde23d

Browse files
seraveeezyang
authored andcommitted
use the correct datatype format (#8144)
1 parent e53fec0 commit ffde23d

File tree

2 files changed

+7
-1
lines changed

2 files changed

+7
-1
lines changed

caffe2/operators/shape_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ OPERATOR_SCHEMA(Shape)
2222
} else {
2323
out[0].add_dims(axes.size());
2424
}
25-
out[0].set_data_type(TensorProto::INT32);
25+
out[0].set_data_type(TensorProto::INT64);
2626
return out;
2727
})
2828
.SetDoc(R"DOC(

caffe2/python/operator_test/shape_inference_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,12 @@ def testInt8Conversion(self):
494494
workspace.FeedBlob('x', np.random.rand(100, 150).astype(np.float32))
495495
self.InferTensorRunAndCompare(model)
496496

497+
def testShapeOp(self):
498+
model = model_helper.ModelHelper(name="shape_op_test")
499+
model.Shape('x', 'y')
500+
workspace.FeedBlob('x', np.random.rand(100, 150).astype(np.float32))
501+
self.InferTensorRunAndCompare(model)
502+
497503
def InferTensorRunAndCompare(self, model, expected_uninferred_blobs=None):
498504
'''
499505
Runs shape inference, and then the model to check

0 commit comments

Comments
 (0)