Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 28 additions & 30 deletions android/pytorch_android/src/main/cpp/pytorch_jni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,53 +41,51 @@ struct JHashMap
};

static at::Tensor newAtTensor(
facebook::jni::alias_ref<facebook::jni::JBuffer> inputData,
facebook::jni::alias_ref<jlongArray> inputDims,
facebook::jni::alias_ref<facebook::jni::JBuffer> jbuffer,
facebook::jni::alias_ref<jlongArray> jdims,
jint typeCode) {
const auto inputDimsRank = inputDims->size();
const auto inputDimsArr = inputDims->getRegion(0, inputDimsRank);
std::vector<int64_t> inputDimsVec;
auto inputNumel = 1;
for (auto i = 0; i < inputDimsRank; ++i) {
inputDimsVec.push_back(inputDimsArr[i]);
inputNumel *= inputDimsArr[i];
const auto rank = jdims->size();
const auto dimsArr = jdims->getRegion(0, rank);
std::vector<int64_t> dimsVec{};
dimsVec.reserve(rank);
auto numel = 1;
for (auto i = 0; i < rank; ++i) {
dimsVec.push_back(dimsArr[i]);
numel *= dimsArr[i];
}
JNIEnv* jni = facebook::jni::Environment::current();
caffe2::TypeMeta inputTypeMeta{};
int inputDataElementSizeBytes = 0;
caffe2::TypeMeta typeMeta{};
int dataElementSizeBytes = 0;
if (kTensorTypeCodeFloat32 == typeCode) {
inputDataElementSizeBytes = 4;
inputTypeMeta = caffe2::TypeMeta::Make<float>();
dataElementSizeBytes = 4;
typeMeta = caffe2::TypeMeta::Make<float>();
} else if (kTensorTypeCodeInt32 == typeCode) {
inputDataElementSizeBytes = 4;
inputTypeMeta = caffe2::TypeMeta::Make<int>();
dataElementSizeBytes = 4;
typeMeta = caffe2::TypeMeta::Make<int>();
} else if (kTensorTypeCodeByte == typeCode) {
inputDataElementSizeBytes = 1;
inputTypeMeta = caffe2::TypeMeta::Make<uint8_t>();
dataElementSizeBytes = 1;
typeMeta = caffe2::TypeMeta::Make<uint8_t>();
} else {
facebook::jni::throwNewJavaException(
facebook::jni::gJavaLangIllegalArgumentException,
"Unknown Tensor typeCode %d",
typeCode);
}
const auto inputDataCapacity = jni->GetDirectBufferCapacity(inputData.get());
if (inputDataCapacity != inputNumel) {
const auto dataCapacity = jni->GetDirectBufferCapacity(jbuffer.get());
if (dataCapacity != numel) {
facebook::jni::throwNewJavaException(
facebook::jni::gJavaLangIllegalArgumentException,
"Tensor dimensions(elements number:%d, element byte size:%d, total "
"bytes:%d) inconsistent with buffer capacity(%d)",
inputNumel,
inputDataElementSizeBytes,
inputNumel * inputDataElementSizeBytes,
inputDataCapacity);
numel,
dataElementSizeBytes,
numel * dataElementSizeBytes,
dataCapacity);
}

at::Tensor inputTensor = torch::empty(torch::IntArrayRef(inputDimsVec));
inputTensor.unsafeGetTensorImpl()->ShareExternalPointer(
{jni->GetDirectBufferAddress(inputData.get()), at::DeviceType::CPU},
inputTypeMeta,
inputDataCapacity);
return inputTensor;
return torch::from_blob(
jni->GetDirectBufferAddress(jbuffer.get()),
torch::IntArrayRef(dimsVec),
at::TensorOptions(typeMeta));
}

class JTensor : public facebook::jni::JavaClass<JTensor> {
Expand Down