-
Notifications
You must be signed in to change notification settings - Fork 26.3k
[android] Tensor renaming to dtype, shape; support long, double #26183
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,9 +10,11 @@ | |
|
|
||
| namespace pytorch_jni { | ||
|
|
||
| constexpr static int kTensorTypeCodeByte = 1; | ||
| constexpr static int kTensorTypeCodeInt32 = 2; | ||
| constexpr static int kTensorTypeCodeFloat32 = 3; | ||
| constexpr static int kTensorDTypeByte = 1; | ||
| constexpr static int kTensorDTypeInt32 = 2; | ||
| constexpr static int kTensorDTypeFloat32 = 3; | ||
| constexpr static int kTensorDTypeLong64 = 4; | ||
| constexpr static int kTensorDTypeDouble64 = 5; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is this standard naming in java? It seems repetitive (is a double ever not 64 bits?). In python we call this "float64" (and the one above int64). |
||
|
|
||
| template <typename K = jobject, typename V = jobject> | ||
| struct JHashMap | ||
|
|
@@ -42,34 +44,40 @@ struct JHashMap | |
|
|
||
| static at::Tensor newAtTensor( | ||
| facebook::jni::alias_ref<facebook::jni::JBuffer> jbuffer, | ||
| facebook::jni::alias_ref<jlongArray> jdims, | ||
| jint typeCode) { | ||
| const auto rank = jdims->size(); | ||
| const auto dimsArr = jdims->getRegion(0, rank); | ||
| std::vector<int64_t> dimsVec{}; | ||
| dimsVec.reserve(rank); | ||
| facebook::jni::alias_ref<jlongArray> jshape, | ||
| jint jdtype) { | ||
| const auto rank = jshape->size(); | ||
| const auto shapeArr = jshape->getRegion(0, rank); | ||
| std::vector<int64_t> shapeVec{}; | ||
| shapeVec.reserve(rank); | ||
| auto numel = 1; | ||
| for (auto i = 0; i < rank; ++i) { | ||
| dimsVec.push_back(dimsArr[i]); | ||
| numel *= dimsArr[i]; | ||
| shapeVec.push_back(shapeArr[i]); | ||
| numel *= shapeArr[i]; | ||
| } | ||
| JNIEnv* jni = facebook::jni::Environment::current(); | ||
| caffe2::TypeMeta typeMeta{}; | ||
| int dataElementSizeBytes = 0; | ||
| if (kTensorTypeCodeFloat32 == typeCode) { | ||
| if (kTensorDTypeFloat32 == jdtype) { | ||
| dataElementSizeBytes = 4; | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: you can do |
||
| typeMeta = caffe2::TypeMeta::Make<float>(); | ||
| } else if (kTensorTypeCodeInt32 == typeCode) { | ||
| } else if (kTensorDTypeInt32 == jdtype) { | ||
| dataElementSizeBytes = 4; | ||
| typeMeta = caffe2::TypeMeta::Make<int>(); | ||
| } else if (kTensorTypeCodeByte == typeCode) { | ||
| typeMeta = caffe2::TypeMeta::Make<int32_t>(); | ||
| } else if (kTensorDTypeByte == jdtype) { | ||
| dataElementSizeBytes = 1; | ||
| typeMeta = caffe2::TypeMeta::Make<uint8_t>(); | ||
| typeMeta = caffe2::TypeMeta::Make<int8_t>(); | ||
| } else if (kTensorDTypeLong64 == jdtype) { | ||
| dataElementSizeBytes = 8; | ||
| typeMeta = caffe2::TypeMeta::Make<int64_t>(); | ||
| } else if (kTensorDTypeDouble64 == jdtype) { | ||
| dataElementSizeBytes = 8; | ||
| typeMeta = caffe2::TypeMeta::Make<double>(); | ||
| } else { | ||
| facebook::jni::throwNewJavaException( | ||
| facebook::jni::gJavaLangIllegalArgumentException, | ||
| "Unknown Tensor typeCode %d", | ||
| typeCode); | ||
| "Unknown Tensor jdtype %d", | ||
| jdtype); | ||
| } | ||
| const auto dataCapacity = jni->GetDirectBufferCapacity(jbuffer.get()); | ||
| if (dataCapacity != numel) { | ||
|
|
@@ -84,7 +92,7 @@ static at::Tensor newAtTensor( | |
| } | ||
| return torch::from_blob( | ||
| jni->GetDirectBufferAddress(jbuffer.get()), | ||
| torch::IntArrayRef(dimsVec), | ||
| torch::IntArrayRef(shapeVec), | ||
| at::TensorOptions(typeMeta)); | ||
| } | ||
|
|
||
|
|
@@ -94,44 +102,48 @@ class JTensor : public facebook::jni::JavaClass<JTensor> { | |
|
|
||
| static facebook::jni::local_ref<JTensor> newJTensor( | ||
| facebook::jni::alias_ref<facebook::jni::JByteBuffer> jBuffer, | ||
| facebook::jni::alias_ref<jlongArray> jDims, | ||
| jint typeCode) { | ||
| facebook::jni::alias_ref<jlongArray> jShape, | ||
| jint jdtype) { | ||
| static auto jMethodNewTensor = | ||
| JTensor::javaClassStatic() | ||
| ->getStaticMethod<facebook::jni::local_ref<JTensor>( | ||
| facebook::jni::alias_ref<facebook::jni::JByteBuffer>, | ||
| facebook::jni::alias_ref<jlongArray>, | ||
| jint)>("nativeNewTensor"); | ||
| return jMethodNewTensor( | ||
| JTensor::javaClassStatic(), jBuffer, jDims, typeCode); | ||
| JTensor::javaClassStatic(), jBuffer, jShape, jdtype); | ||
| } | ||
|
|
||
| static facebook::jni::local_ref<JTensor> newJTensorFromAtTensor( | ||
| const at::Tensor& tensor) { | ||
| const auto scalarType = tensor.scalar_type(); | ||
| int typeCode = 0; | ||
| int jdtype = 0; | ||
| if (at::kFloat == scalarType) { | ||
| typeCode = kTensorTypeCodeFloat32; | ||
| jdtype = kTensorDTypeFloat32; | ||
| } else if (at::kInt == scalarType) { | ||
| typeCode = kTensorTypeCodeInt32; | ||
| jdtype = kTensorDTypeInt32; | ||
| } else if (at::kByte == scalarType) { | ||
| typeCode = kTensorTypeCodeByte; | ||
| jdtype = kTensorDTypeByte; | ||
| } else if (at::kLong == scalarType) { | ||
| jdtype = kTensorDTypeLong64; | ||
| } else if (at::kDouble == scalarType) { | ||
| jdtype = kTensorDTypeDouble64; | ||
| } else { | ||
| facebook::jni::throwNewJavaException( | ||
| facebook::jni::gJavaLangIllegalArgumentException, | ||
| "at::Tensor scalar type is not supported on java side"); | ||
| } | ||
|
|
||
| const auto& tensorDims = tensor.sizes(); | ||
| std::vector<int64_t> tensorDimsVec; | ||
| for (const auto& dim : tensorDims) { | ||
| tensorDimsVec.push_back(dim); | ||
| const auto& tensorShape = tensor.sizes(); | ||
| std::vector<int64_t> tensorShapeVec; | ||
| for (const auto& s : tensorShape) { | ||
| tensorShapeVec.push_back(s); | ||
| } | ||
|
|
||
| facebook::jni::local_ref<jlongArray> jTensorDims = | ||
| facebook::jni::make_long_array(tensorDimsVec.size()); | ||
| facebook::jni::local_ref<jlongArray> jTensorShape = | ||
| facebook::jni::make_long_array(tensorShapeVec.size()); | ||
|
|
||
| jTensorDims->setRegion(0, tensorDimsVec.size(), tensorDimsVec.data()); | ||
| jTensorShape->setRegion(0, tensorShapeVec.size(), tensorShapeVec.data()); | ||
|
|
||
| facebook::jni::local_ref<facebook::jni::JByteBuffer> jTensorBuffer = | ||
| facebook::jni::JByteBuffer::allocateDirect(tensor.nbytes()); | ||
|
|
@@ -140,18 +152,18 @@ class JTensor : public facebook::jni::JavaClass<JTensor> { | |
| jTensorBuffer->getDirectBytes(), | ||
| tensor.storage().data(), | ||
| tensor.nbytes()); | ||
| return JTensor::newJTensor(jTensorBuffer, jTensorDims, typeCode); | ||
| return JTensor::newJTensor(jTensorBuffer, jTensorShape, jdtype); | ||
| } | ||
|
|
||
| static at::Tensor newAtTensorFromJTensor( | ||
| facebook::jni::alias_ref<JTensor> jtensor) { | ||
| static const auto typeCodeMethod = | ||
| JTensor::javaClassStatic()->getMethod<jint()>("getTypeCode"); | ||
| jint typeCode = typeCodeMethod(jtensor); | ||
| static const auto dtypeMethod = | ||
| JTensor::javaClassStatic()->getMethod<jint()>("dtype"); | ||
| jint jdtype = dtypeMethod(jtensor); | ||
|
|
||
| static const auto dimsField = | ||
| JTensor::javaClassStatic()->getField<jlongArray>("dims"); | ||
| auto jdims = jtensor->getFieldValue(dimsField); | ||
| static const auto shapeField = | ||
| JTensor::javaClassStatic()->getField<jlongArray>("shape"); | ||
| auto jshape = jtensor->getFieldValue(shapeField); | ||
|
|
||
| static auto dataBufferMethod = | ||
| JTensor::javaClassStatic() | ||
|
|
@@ -160,7 +172,7 @@ class JTensor : public facebook::jni::JavaClass<JTensor> { | |
| "getRawDataBuffer"); | ||
| facebook::jni::local_ref<facebook::jni::JBuffer> jbuffer = | ||
| dataBufferMethod(jtensor); | ||
| return newAtTensor(jbuffer, jdims, typeCode); | ||
| return newAtTensor(jbuffer, jshape, jdtype); | ||
| } | ||
| }; | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we don't actually have a Byte dtype in the python frontend (because numpy defines it differently). It's probably better to use
DTYPE_UINT8.