@@ -41,53 +41,51 @@ struct JHashMap
4141};
4242
4343static at::Tensor newAtTensor (
44- facebook::jni::alias_ref<facebook::jni::JBuffer> inputData ,
45- facebook::jni::alias_ref<jlongArray> inputDims ,
44+ facebook::jni::alias_ref<facebook::jni::JBuffer> jbuffer ,
45+ facebook::jni::alias_ref<jlongArray> jdims ,
4646 jint typeCode) {
47- const auto inputDimsRank = inputDims->size ();
48- const auto inputDimsArr = inputDims->getRegion (0 , inputDimsRank);
49- std::vector<int64_t > inputDimsVec;
50- auto inputNumel = 1 ;
51- for (auto i = 0 ; i < inputDimsRank; ++i) {
52- inputDimsVec.push_back (inputDimsArr[i]);
53- inputNumel *= inputDimsArr[i];
47+ const auto rank = jdims->size ();
48+ const auto dimsArr = jdims->getRegion (0 , rank);
49+ std::vector<int64_t > dimsVec{};
50+ dimsVec.reserve (rank);
51+ auto numel = 1 ;
52+ for (auto i = 0 ; i < rank; ++i) {
53+ dimsVec.push_back (dimsArr[i]);
54+ numel *= dimsArr[i];
5455 }
5556 JNIEnv* jni = facebook::jni::Environment::current ();
56- caffe2::TypeMeta inputTypeMeta {};
57- int inputDataElementSizeBytes = 0 ;
57+ caffe2::TypeMeta typeMeta {};
58+ int dataElementSizeBytes = 0 ;
5859 if (kTensorTypeCodeFloat32 == typeCode) {
59- inputDataElementSizeBytes = 4 ;
60- inputTypeMeta = caffe2::TypeMeta::Make<float >();
60+ dataElementSizeBytes = 4 ;
61+ typeMeta = caffe2::TypeMeta::Make<float >();
6162 } else if (kTensorTypeCodeInt32 == typeCode) {
62- inputDataElementSizeBytes = 4 ;
63- inputTypeMeta = caffe2::TypeMeta::Make<int >();
63+ dataElementSizeBytes = 4 ;
64+ typeMeta = caffe2::TypeMeta::Make<int >();
6465 } else if (kTensorTypeCodeByte == typeCode) {
65- inputDataElementSizeBytes = 1 ;
66- inputTypeMeta = caffe2::TypeMeta::Make<uint8_t >();
66+ dataElementSizeBytes = 1 ;
67+ typeMeta = caffe2::TypeMeta::Make<uint8_t >();
6768 } else {
6869 facebook::jni::throwNewJavaException (
6970 facebook::jni::gJavaLangIllegalArgumentException ,
7071 " Unknown Tensor typeCode %d" ,
7172 typeCode);
7273 }
73- const auto inputDataCapacity = jni->GetDirectBufferCapacity (inputData .get ());
74- if (inputDataCapacity != inputNumel ) {
74+ const auto dataCapacity = jni->GetDirectBufferCapacity (jbuffer .get ());
75+ if (dataCapacity != numel ) {
7576 facebook::jni::throwNewJavaException (
7677 facebook::jni::gJavaLangIllegalArgumentException ,
7778 " Tensor dimensions(elements number:%d, element byte size:%d, total "
7879 " bytes:%d) inconsistent with buffer capacity(%d)" ,
79- inputNumel ,
80- inputDataElementSizeBytes ,
81- inputNumel * inputDataElementSizeBytes ,
82- inputDataCapacity );
80+ numel ,
81+ dataElementSizeBytes ,
82+ numel * dataElementSizeBytes ,
83+ dataCapacity );
8384 }
84-
85- at::Tensor inputTensor = torch::empty (torch::IntArrayRef (inputDimsVec));
86- inputTensor.unsafeGetTensorImpl ()->ShareExternalPointer (
87- {jni->GetDirectBufferAddress (inputData.get ()), at::DeviceType::CPU},
88- inputTypeMeta,
89- inputDataCapacity);
90- return inputTensor;
85+ return torch::from_blob (
86+ jni->GetDirectBufferAddress (jbuffer.get ()),
87+ torch::IntArrayRef (dimsVec),
88+ at::TensorOptions (typeMeta));
9189}
9290
9391class JTensor : public facebook ::jni::JavaClass<JTensor> {
0 commit comments