Skip to content

Commit 0ea5978

Browse files
IvanKobzarevfacebook-github-bot
authored andcommitted
Use torch::from_blob instead of shareExternalPointer, nits (#25973)
Summary: The main part is to switch at::Tensor creation from usage of `torch::empty(torch::IntArrayRef(...))->ShareExternalPointer(...) to torch::from_blob(...)` Removed explicit set of `device CPU` as `at::TensorOptions` by default `device CPU` And renaming of local variables removing `input` prefix to make them shorter Pull Request resolved: #25973 Differential Revision: D17356837 Pulled By: IvanKobzarev fbshipit-source-id: 679e099b8aebd787dbf8ed422dae07a81243e18f
1 parent a3f0d98 commit 0ea5978

File tree

1 file changed

+28
-30
lines changed

1 file changed

+28
-30
lines changed

android/pytorch_android/src/main/cpp/pytorch_jni.cpp

Lines changed: 28 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -41,53 +41,51 @@ struct JHashMap
4141
};
4242

4343
static at::Tensor newAtTensor(
44-
facebook::jni::alias_ref<facebook::jni::JBuffer> inputData,
45-
facebook::jni::alias_ref<jlongArray> inputDims,
44+
facebook::jni::alias_ref<facebook::jni::JBuffer> jbuffer,
45+
facebook::jni::alias_ref<jlongArray> jdims,
4646
jint typeCode) {
47-
const auto inputDimsRank = inputDims->size();
48-
const auto inputDimsArr = inputDims->getRegion(0, inputDimsRank);
49-
std::vector<int64_t> inputDimsVec;
50-
auto inputNumel = 1;
51-
for (auto i = 0; i < inputDimsRank; ++i) {
52-
inputDimsVec.push_back(inputDimsArr[i]);
53-
inputNumel *= inputDimsArr[i];
47+
const auto rank = jdims->size();
48+
const auto dimsArr = jdims->getRegion(0, rank);
49+
std::vector<int64_t> dimsVec{};
50+
dimsVec.reserve(rank);
51+
auto numel = 1;
52+
for (auto i = 0; i < rank; ++i) {
53+
dimsVec.push_back(dimsArr[i]);
54+
numel *= dimsArr[i];
5455
}
5556
JNIEnv* jni = facebook::jni::Environment::current();
56-
caffe2::TypeMeta inputTypeMeta{};
57-
int inputDataElementSizeBytes = 0;
57+
caffe2::TypeMeta typeMeta{};
58+
int dataElementSizeBytes = 0;
5859
if (kTensorTypeCodeFloat32 == typeCode) {
59-
inputDataElementSizeBytes = 4;
60-
inputTypeMeta = caffe2::TypeMeta::Make<float>();
60+
dataElementSizeBytes = 4;
61+
typeMeta = caffe2::TypeMeta::Make<float>();
6162
} else if (kTensorTypeCodeInt32 == typeCode) {
62-
inputDataElementSizeBytes = 4;
63-
inputTypeMeta = caffe2::TypeMeta::Make<int>();
63+
dataElementSizeBytes = 4;
64+
typeMeta = caffe2::TypeMeta::Make<int>();
6465
} else if (kTensorTypeCodeByte == typeCode) {
65-
inputDataElementSizeBytes = 1;
66-
inputTypeMeta = caffe2::TypeMeta::Make<uint8_t>();
66+
dataElementSizeBytes = 1;
67+
typeMeta = caffe2::TypeMeta::Make<uint8_t>();
6768
} else {
6869
facebook::jni::throwNewJavaException(
6970
facebook::jni::gJavaLangIllegalArgumentException,
7071
"Unknown Tensor typeCode %d",
7172
typeCode);
7273
}
73-
const auto inputDataCapacity = jni->GetDirectBufferCapacity(inputData.get());
74-
if (inputDataCapacity != inputNumel) {
74+
const auto dataCapacity = jni->GetDirectBufferCapacity(jbuffer.get());
75+
if (dataCapacity != numel) {
7576
facebook::jni::throwNewJavaException(
7677
facebook::jni::gJavaLangIllegalArgumentException,
7778
"Tensor dimensions(elements number:%d, element byte size:%d, total "
7879
"bytes:%d) inconsistent with buffer capacity(%d)",
79-
inputNumel,
80-
inputDataElementSizeBytes,
81-
inputNumel * inputDataElementSizeBytes,
82-
inputDataCapacity);
80+
numel,
81+
dataElementSizeBytes,
82+
numel * dataElementSizeBytes,
83+
dataCapacity);
8384
}
84-
85-
at::Tensor inputTensor = torch::empty(torch::IntArrayRef(inputDimsVec));
86-
inputTensor.unsafeGetTensorImpl()->ShareExternalPointer(
87-
{jni->GetDirectBufferAddress(inputData.get()), at::DeviceType::CPU},
88-
inputTypeMeta,
89-
inputDataCapacity);
90-
return inputTensor;
85+
return torch::from_blob(
86+
jni->GetDirectBufferAddress(jbuffer.get()),
87+
torch::IntArrayRef(dimsVec),
88+
at::TensorOptions(typeMeta));
9189
}
9290

9391
class JTensor : public facebook::jni::JavaClass<JTensor> {

0 commit comments

Comments
 (0)