1010
1111namespace pytorch_jni {
1212
13- constexpr static int kTensorTypeCodeByte = 1 ;
14- constexpr static int kTensorTypeCodeInt32 = 2 ;
15- constexpr static int kTensorTypeCodeFloat32 = 3 ;
13+ constexpr static int kTensorDTypeByte = 1 ;
14+ constexpr static int kTensorDTypeInt32 = 2 ;
15+ constexpr static int kTensorDTypeFloat32 = 3 ;
16+ constexpr static int kTensorDTypeLong64 = 4 ;
17+ constexpr static int kTensorDTypeDouble64 = 5 ;
1618
1719template <typename K = jobject, typename V = jobject>
1820struct JHashMap
@@ -42,34 +44,40 @@ struct JHashMap
4244
4345static at::Tensor newAtTensor (
4446 facebook::jni::alias_ref<facebook::jni::JBuffer> jbuffer,
45- facebook::jni::alias_ref<jlongArray> jdims ,
46- jint typeCode ) {
47- const auto rank = jdims ->size ();
48- const auto dimsArr = jdims ->getRegion (0 , rank);
49- std::vector<int64_t > dimsVec {};
50- dimsVec .reserve (rank);
47+ facebook::jni::alias_ref<jlongArray> jshape ,
48+ jint jdtype ) {
49+ const auto rank = jshape ->size ();
50+ const auto shapeArr = jshape ->getRegion (0 , rank);
51+ std::vector<int64_t > shapeVec {};
52+ shapeVec .reserve (rank);
5153 auto numel = 1 ;
5254 for (auto i = 0 ; i < rank; ++i) {
53- dimsVec .push_back (dimsArr [i]);
54- numel *= dimsArr [i];
55+ shapeVec .push_back (shapeArr [i]);
56+ numel *= shapeArr [i];
5557 }
5658 JNIEnv* jni = facebook::jni::Environment::current ();
5759 caffe2::TypeMeta typeMeta{};
5860 int dataElementSizeBytes = 0 ;
59- if (kTensorTypeCodeFloat32 == typeCode ) {
61+ if (kTensorDTypeFloat32 == jdtype ) {
6062 dataElementSizeBytes = 4 ;
6163 typeMeta = caffe2::TypeMeta::Make<float >();
62- } else if (kTensorTypeCodeInt32 == typeCode ) {
64+ } else if (kTensorDTypeInt32 == jdtype ) {
6365 dataElementSizeBytes = 4 ;
64- typeMeta = caffe2::TypeMeta::Make<int >();
65- } else if (kTensorTypeCodeByte == typeCode ) {
66+ typeMeta = caffe2::TypeMeta::Make<int32_t >();
67+ } else if (kTensorDTypeByte == jdtype ) {
6668 dataElementSizeBytes = 1 ;
67- typeMeta = caffe2::TypeMeta::Make<uint8_t >();
69+ typeMeta = caffe2::TypeMeta::Make<int8_t >();
70+ } else if (kTensorDTypeLong64 == jdtype) {
71+ dataElementSizeBytes = 8 ;
72+ typeMeta = caffe2::TypeMeta::Make<int64_t >();
73+ } else if (kTensorDTypeDouble64 == jdtype) {
74+ dataElementSizeBytes = 8 ;
75+ typeMeta = caffe2::TypeMeta::Make<double >();
6876 } else {
6977 facebook::jni::throwNewJavaException (
7078 facebook::jni::gJavaLangIllegalArgumentException ,
71- " Unknown Tensor typeCode %d" ,
72- typeCode );
79+ " Unknown Tensor jdtype %d" ,
80+ jdtype );
7381 }
7482 const auto dataCapacity = jni->GetDirectBufferCapacity (jbuffer.get ());
7583 if (dataCapacity != numel) {
@@ -84,7 +92,7 @@ static at::Tensor newAtTensor(
8492 }
8593 return torch::from_blob (
8694 jni->GetDirectBufferAddress (jbuffer.get ()),
87- torch::IntArrayRef (dimsVec ),
95+ torch::IntArrayRef (shapeVec ),
8896 at::TensorOptions (typeMeta));
8997}
9098
@@ -94,44 +102,48 @@ class JTensor : public facebook::jni::JavaClass<JTensor> {
94102
95103 static facebook::jni::local_ref<JTensor> newJTensor (
96104 facebook::jni::alias_ref<facebook::jni::JByteBuffer> jBuffer,
97- facebook::jni::alias_ref<jlongArray> jDims ,
98- jint typeCode ) {
105+ facebook::jni::alias_ref<jlongArray> jShape ,
106+ jint jdtype ) {
99107 static auto jMethodNewTensor =
100108 JTensor::javaClassStatic ()
101109 ->getStaticMethod <facebook::jni::local_ref<JTensor>(
102110 facebook::jni::alias_ref<facebook::jni::JByteBuffer>,
103111 facebook::jni::alias_ref<jlongArray>,
104112 jint)>(" nativeNewTensor" );
105113 return jMethodNewTensor (
106- JTensor::javaClassStatic (), jBuffer, jDims, typeCode );
114+ JTensor::javaClassStatic (), jBuffer, jShape, jdtype );
107115 }
108116
109117 static facebook::jni::local_ref<JTensor> newJTensorFromAtTensor (
110118 const at::Tensor& tensor) {
111119 const auto scalarType = tensor.scalar_type ();
112- int typeCode = 0 ;
120+ int jdtype = 0 ;
113121 if (at::kFloat == scalarType) {
114- typeCode = kTensorTypeCodeFloat32 ;
122+ jdtype = kTensorDTypeFloat32 ;
115123 } else if (at::kInt == scalarType) {
116- typeCode = kTensorTypeCodeInt32 ;
124+ jdtype = kTensorDTypeInt32 ;
117125 } else if (at::kByte == scalarType) {
118- typeCode = kTensorTypeCodeByte ;
126+ jdtype = kTensorDTypeByte ;
127+ } else if (at::kLong == scalarType) {
128+ jdtype = kTensorDTypeLong64 ;
129+ } else if (at::kDouble == scalarType) {
130+ jdtype = kTensorDTypeDouble64 ;
119131 } else {
120132 facebook::jni::throwNewJavaException (
121133 facebook::jni::gJavaLangIllegalArgumentException ,
122134 " at::Tensor scalar type is not supported on java side" );
123135 }
124136
125- const auto & tensorDims = tensor.sizes ();
126- std::vector<int64_t > tensorDimsVec ;
127- for (const auto & dim : tensorDims ) {
128- tensorDimsVec .push_back (dim );
137+ const auto & tensorShape = tensor.sizes ();
138+ std::vector<int64_t > tensorShapeVec ;
139+ for (const auto & s : tensorShape ) {
140+ tensorShapeVec .push_back (s );
129141 }
130142
131- facebook::jni::local_ref<jlongArray> jTensorDims =
132- facebook::jni::make_long_array (tensorDimsVec .size ());
143+ facebook::jni::local_ref<jlongArray> jTensorShape =
144+ facebook::jni::make_long_array (tensorShapeVec .size ());
133145
134- jTensorDims ->setRegion (0 , tensorDimsVec .size (), tensorDimsVec .data ());
146+ jTensorShape ->setRegion (0 , tensorShapeVec .size (), tensorShapeVec .data ());
135147
136148 facebook::jni::local_ref<facebook::jni::JByteBuffer> jTensorBuffer =
137149 facebook::jni::JByteBuffer::allocateDirect (tensor.nbytes ());
@@ -140,18 +152,18 @@ class JTensor : public facebook::jni::JavaClass<JTensor> {
140152 jTensorBuffer->getDirectBytes (),
141153 tensor.storage ().data (),
142154 tensor.nbytes ());
143- return JTensor::newJTensor (jTensorBuffer, jTensorDims, typeCode );
155+ return JTensor::newJTensor (jTensorBuffer, jTensorShape, jdtype );
144156 }
145157
146158 static at::Tensor newAtTensorFromJTensor (
147159 facebook::jni::alias_ref<JTensor> jtensor) {
148- static const auto typeCodeMethod =
149- JTensor::javaClassStatic ()->getMethod <jint ()>(" getTypeCode " );
150- jint typeCode = typeCodeMethod (jtensor);
160+ static const auto dtypeMethod =
161+ JTensor::javaClassStatic ()->getMethod <jint ()>(" dtype " );
162+ jint jdtype = dtypeMethod (jtensor);
151163
152- static const auto dimsField =
153- JTensor::javaClassStatic ()->getField <jlongArray>(" dims " );
154- auto jdims = jtensor->getFieldValue (dimsField );
164+ static const auto shapeField =
165+ JTensor::javaClassStatic ()->getField <jlongArray>(" shape " );
166+ auto jshape = jtensor->getFieldValue (shapeField );
155167
156168 static auto dataBufferMethod =
157169 JTensor::javaClassStatic ()
@@ -160,7 +172,7 @@ class JTensor : public facebook::jni::JavaClass<JTensor> {
160172 " getRawDataBuffer" );
161173 facebook::jni::local_ref<facebook::jni::JBuffer> jbuffer =
162174 dataBufferMethod (jtensor);
163- return newAtTensor (jbuffer, jdims, typeCode );
175+ return newAtTensor (jbuffer, jshape, jdtype );
164176 }
165177};
166178
0 commit comments