Skip to content

Commit 00e91fd

Browse files
committed
[android] Use torch::from_blob instead of shareExternalPointer, nits
1 parent fc93d1a commit 00e91fd

File tree

1 file changed

+28
-30
lines changed

1 file changed

+28
-30
lines changed

android/pytorch_android/src/main/cpp/pytorch_jni.cpp

Lines changed: 28 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -41,53 +41,51 @@ struct JHashMap
4141
};
4242

4343
static at::Tensor newAtTensor(
44-
facebook::jni::alias_ref<facebook::jni::JBuffer> inputData,
45-
facebook::jni::alias_ref<jlongArray> inputDims,
44+
facebook::jni::alias_ref<facebook::jni::JBuffer> jbuffer,
45+
facebook::jni::alias_ref<jlongArray> jdims,
4646
jint typeCode) {
47-
const auto inputDimsRank = inputDims->size();
48-
const auto inputDimsArr = inputDims->getRegion(0, inputDimsRank);
49-
std::vector<int64_t> inputDimsVec;
50-
auto inputNumel = 1;
51-
for (auto i = 0; i < inputDimsRank; ++i) {
52-
inputDimsVec.push_back(inputDimsArr[i]);
53-
inputNumel *= inputDimsArr[i];
47+
const auto rank = jdims->size();
48+
const auto dimsArr = jdims->getRegion(0, rank);
49+
std::vector<int64_t> dimsVec{};
50+
dimsVec.reserve(rank);
51+
auto numel = 1;
52+
for (auto i = 0; i < rank; ++i) {
53+
dimsVec.push_back(dimsArr[i]);
54+
numel *= dimsArr[i];
5455
}
5556
JNIEnv* jni = facebook::jni::Environment::current();
56-
caffe2::TypeMeta inputTypeMeta{};
57-
int inputDataElementSizeBytes = 0;
57+
caffe2::TypeMeta typeMeta{};
58+
int dataElementSizeBytes = 0;
5859
if (kTensorTypeCodeFloat32 == typeCode) {
59-
inputDataElementSizeBytes = 4;
60-
inputTypeMeta = caffe2::TypeMeta::Make<float>();
60+
dataElementSizeBytes = 4;
61+
typeMeta = caffe2::TypeMeta::Make<float>();
6162
} else if (kTensorTypeCodeInt32 == typeCode) {
62-
inputDataElementSizeBytes = 4;
63-
inputTypeMeta = caffe2::TypeMeta::Make<int>();
63+
dataElementSizeBytes = 4;
64+
typeMeta = caffe2::TypeMeta::Make<int>();
6465
} else if (kTensorTypeCodeByte == typeCode) {
65-
inputDataElementSizeBytes = 1;
66-
inputTypeMeta = caffe2::TypeMeta::Make<uint8_t>();
66+
dataElementSizeBytes = 1;
67+
typeMeta = caffe2::TypeMeta::Make<uint8_t>();
6768
} else {
6869
facebook::jni::throwNewJavaException(
6970
facebook::jni::gJavaLangIllegalArgumentException,
7071
"Unknown Tensor typeCode %d",
7172
typeCode);
7273
}
73-
const auto inputDataCapacity = jni->GetDirectBufferCapacity(inputData.get());
74-
if (inputDataCapacity != inputNumel) {
74+
const auto dataCapacity = jni->GetDirectBufferCapacity(jbuffer.get());
75+
if (dataCapacity != numel) {
7576
facebook::jni::throwNewJavaException(
7677
facebook::jni::gJavaLangIllegalArgumentException,
7778
"Tensor dimensions(elements number:%d, element byte size:%d, total "
7879
"bytes:%d) inconsistent with buffer capacity(%d)",
79-
inputNumel,
80-
inputDataElementSizeBytes,
81-
inputNumel * inputDataElementSizeBytes,
82-
inputDataCapacity);
80+
numel,
81+
dataElementSizeBytes,
82+
numel * dataElementSizeBytes,
83+
dataCapacity);
8384
}
84-
85-
at::Tensor inputTensor = torch::empty(torch::IntArrayRef(inputDimsVec));
86-
inputTensor.unsafeGetTensorImpl()->ShareExternalPointer(
87-
{jni->GetDirectBufferAddress(inputData.get()), at::DeviceType::CPU},
88-
inputTypeMeta,
89-
inputDataCapacity);
90-
return inputTensor;
85+
return torch::from_blob(
86+
jni->GetDirectBufferAddress(jbuffer.get()),
87+
torch::IntArrayRef(dimsVec),
88+
at::TensorOptions(typeMeta));
9189
}
9290

9391
class JTensor : public facebook::jni::JavaClass<JTensor> {

0 commit comments

Comments
 (0)