Skip to content

Commit d250f01

Browse files
IvanKobzarevfacebook-github-bot
authored andcommitted
Tensor renaming to dtype, shape; support long, double (#26183)
Summary: Applying dzhulgakov review comments org.pytorch.Tensor: - dims renamed to shape - typeCode to dtype - numElements to numel newFloatTensor, newIntTensor... to newTensor(...) Add support of dtype=long, double Resorted in code byte,int,float,long,double For if conditions order float,int,byte,long,double as I expect that float and int branches will be used more often Tensor.toString() does not have data, only numel (data buffer capacity) Pull Request resolved: #26183 Differential Revision: D17374332 Pulled By: IvanKobzarev fbshipit-source-id: ee93977d9c43c400b6c054b6286080321ccb81bc
1 parent 1114b05 commit d250f01

File tree

6 files changed

+293
-201
lines changed

6 files changed

+293
-201
lines changed

android/pytorch_android/src/androidTest/java/org/pytorch/PytorchInstrumentedTests.java

Lines changed: 30 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,26 @@
11
package org.pytorch;
22

3-
import static org.junit.Assert.assertArrayEquals;
4-
import static org.junit.Assert.assertFalse;
5-
import static org.junit.Assert.assertNotNull;
6-
import static org.junit.Assert.assertTrue;
7-
83
import android.content.Context;
9-
import androidx.test.ext.junit.runners.AndroidJUnit4;
10-
import androidx.test.platform.app.InstrumentationRegistry;
4+
5+
import org.junit.Before;
6+
import org.junit.Test;
7+
import org.junit.runner.RunWith;
8+
119
import java.io.File;
1210
import java.io.FileOutputStream;
1311
import java.io.IOException;
1412
import java.io.InputStream;
1513
import java.io.OutputStream;
1614
import java.util.HashMap;
1715
import java.util.Map;
18-
import org.junit.Before;
19-
import org.junit.Test;
20-
import org.junit.runner.RunWith;
16+
17+
import androidx.test.ext.junit.runners.AndroidJUnit4;
18+
import androidx.test.platform.app.InstrumentationRegistry;
19+
20+
import static org.junit.Assert.assertArrayEquals;
21+
import static org.junit.Assert.assertFalse;
22+
import static org.junit.Assert.assertNotNull;
23+
import static org.junit.Assert.assertTrue;
2124

2225
@RunWith(AndroidJUnit4.class)
2326
public class PytorchInstrumentedTests {
@@ -33,7 +36,7 @@ public void setUp() {
3336
public void testForwardNull() throws IOException {
3437
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
3538
final IValue input =
36-
IValue.tensor(Tensor.newByteTensor(new long[] {1}, Tensor.allocateByteBuffer(1)));
39+
IValue.tensor(Tensor.newTensor(new long[] {1}, Tensor.allocateByteBuffer(1)));
3740
assertTrue(input.isTensor());
3841
final IValue output = module.forward(input);
3942
assertTrue(output.isNull());
@@ -94,13 +97,13 @@ public void testEqFloat() throws IOException {
9497

9598
@Test
9699
public void testEqTensor() throws IOException {
97-
final long[] inputTensorDims = new long[] {1, 3, 224, 224};
98-
final long numElements = Tensor.numElements(inputTensorDims);
100+
final long[] inputTensorShape = new long[] {1, 3, 224, 224};
101+
final long numElements = Tensor.numel(inputTensorShape);
99102
final float[] inputTensorData = new float[(int) numElements];
100103
for (int i = 0; i < numElements; ++i) {
101104
inputTensorData[i] = i;
102105
}
103-
final Tensor inputTensor = Tensor.newFloatTensor(inputTensorDims, inputTensorData);
106+
final Tensor inputTensor = Tensor.newTensor(inputTensorShape, inputTensorData);
104107

105108
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
106109
final IValue input = IValue.tensor(inputTensor);
@@ -110,7 +113,7 @@ public void testEqTensor() throws IOException {
110113
assertTrue(output.isTensor());
111114
final Tensor outputTensor = output.getTensor();
112115
assertNotNull(outputTensor);
113-
assertArrayEquals(inputTensorDims, outputTensor.dims);
116+
assertArrayEquals(inputTensorShape, outputTensor.shape);
114117
float[] outputData = outputTensor.getDataAsFloatArray();
115118
for (int i = 0; i < numElements; i++) {
116119
assertTrue(inputTensorData[i] == outputData[i]);
@@ -216,8 +219,8 @@ public void testRunUndefinedMethod() throws IOException {
216219

217220
@Test
218221
public void testTensorMethods() {
219-
long[] dims = new long[] {1, 3, 224, 224};
220-
final int numel = (int) Tensor.numElements(dims);
222+
long[] shape = new long[] {1, 3, 224, 224};
223+
final int numel = (int) Tensor.numel(shape);
221224
int[] ints = new int[numel];
222225
float[] floats = new float[numel];
223226

@@ -228,16 +231,16 @@ public void testTensorMethods() {
228231
floats[i] = i / 1000.f;
229232
}
230233

231-
Tensor tensorBytes = Tensor.newByteTensor(dims, bytes);
232-
assertTrue(tensorBytes.isByteTensor());
234+
Tensor tensorBytes = Tensor.newTensor(shape, bytes);
235+
assertTrue(tensorBytes.dtype() == Tensor.DTYPE_BYTE);
233236
assertArrayEquals(bytes, tensorBytes.getDataAsByteArray());
234237

235-
Tensor tensorInts = Tensor.newIntTensor(dims, ints);
236-
assertTrue(tensorInts.isIntTensor());
238+
Tensor tensorInts = Tensor.newTensor(shape, ints);
239+
assertTrue(tensorInts.dtype() == Tensor.DTYPE_INT32);
237240
assertArrayEquals(ints, tensorInts.getDataAsIntArray());
238241

239-
Tensor tensorFloats = Tensor.newFloatTensor(dims, floats);
240-
assertTrue(tensorFloats.isFloatTensor());
242+
Tensor tensorFloats = Tensor.newTensor(shape, floats);
243+
assertTrue(tensorFloats.dtype() == Tensor.DTYPE_FLOAT32);
241244
float[] floatsOut = tensorFloats.getDataAsFloatArray();
242245
assertTrue(floatsOut.length == numel);
243246
for (int i = 0; i < numel; i++) {
@@ -247,11 +250,11 @@ public void testTensorMethods() {
247250

248251
@Test(expected = IllegalStateException.class)
249252
public void testTensorIllegalStateOnWrongType() {
250-
long[] dims = new long[] {1, 3, 224, 224};
251-
final int numel = (int) Tensor.numElements(dims);
253+
long[] shape = new long[] {1, 3, 224, 224};
254+
final int numel = (int) Tensor.numel(shape);
252255
float[] floats = new float[numel];
253-
Tensor tensorFloats = Tensor.newFloatTensor(dims, floats);
254-
assertTrue(tensorFloats.isFloatTensor());
256+
Tensor tensorFloats = Tensor.newTensor(shape, floats);
257+
assertTrue(tensorFloats.dtype() == Tensor.DTYPE_FLOAT32);
255258
tensorFloats.getDataAsByteArray();
256259
}
257260

android/pytorch_android/src/main/cpp/pytorch_jni.cpp

Lines changed: 53 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,11 @@
1010

1111
namespace pytorch_jni {
1212

13-
constexpr static int kTensorTypeCodeByte = 1;
14-
constexpr static int kTensorTypeCodeInt32 = 2;
15-
constexpr static int kTensorTypeCodeFloat32 = 3;
13+
constexpr static int kTensorDTypeByte = 1;
14+
constexpr static int kTensorDTypeInt32 = 2;
15+
constexpr static int kTensorDTypeFloat32 = 3;
16+
constexpr static int kTensorDTypeLong64 = 4;
17+
constexpr static int kTensorDTypeDouble64 = 5;
1618

1719
template <typename K = jobject, typename V = jobject>
1820
struct JHashMap
@@ -42,34 +44,40 @@ struct JHashMap
4244

4345
static at::Tensor newAtTensor(
4446
facebook::jni::alias_ref<facebook::jni::JBuffer> jbuffer,
45-
facebook::jni::alias_ref<jlongArray> jdims,
46-
jint typeCode) {
47-
const auto rank = jdims->size();
48-
const auto dimsArr = jdims->getRegion(0, rank);
49-
std::vector<int64_t> dimsVec{};
50-
dimsVec.reserve(rank);
47+
facebook::jni::alias_ref<jlongArray> jshape,
48+
jint jdtype) {
49+
const auto rank = jshape->size();
50+
const auto shapeArr = jshape->getRegion(0, rank);
51+
std::vector<int64_t> shapeVec{};
52+
shapeVec.reserve(rank);
5153
auto numel = 1;
5254
for (auto i = 0; i < rank; ++i) {
53-
dimsVec.push_back(dimsArr[i]);
54-
numel *= dimsArr[i];
55+
shapeVec.push_back(shapeArr[i]);
56+
numel *= shapeArr[i];
5557
}
5658
JNIEnv* jni = facebook::jni::Environment::current();
5759
caffe2::TypeMeta typeMeta{};
5860
int dataElementSizeBytes = 0;
59-
if (kTensorTypeCodeFloat32 == typeCode) {
61+
if (kTensorDTypeFloat32 == jdtype) {
6062
dataElementSizeBytes = 4;
6163
typeMeta = caffe2::TypeMeta::Make<float>();
62-
} else if (kTensorTypeCodeInt32 == typeCode) {
64+
} else if (kTensorDTypeInt32 == jdtype) {
6365
dataElementSizeBytes = 4;
64-
typeMeta = caffe2::TypeMeta::Make<int>();
65-
} else if (kTensorTypeCodeByte == typeCode) {
66+
typeMeta = caffe2::TypeMeta::Make<int32_t>();
67+
} else if (kTensorDTypeByte == jdtype) {
6668
dataElementSizeBytes = 1;
67-
typeMeta = caffe2::TypeMeta::Make<uint8_t>();
69+
typeMeta = caffe2::TypeMeta::Make<int8_t>();
70+
} else if (kTensorDTypeLong64 == jdtype) {
71+
dataElementSizeBytes = 8;
72+
typeMeta = caffe2::TypeMeta::Make<int64_t>();
73+
} else if (kTensorDTypeDouble64 == jdtype) {
74+
dataElementSizeBytes = 8;
75+
typeMeta = caffe2::TypeMeta::Make<double>();
6876
} else {
6977
facebook::jni::throwNewJavaException(
7078
facebook::jni::gJavaLangIllegalArgumentException,
71-
"Unknown Tensor typeCode %d",
72-
typeCode);
79+
"Unknown Tensor jdtype %d",
80+
jdtype);
7381
}
7482
const auto dataCapacity = jni->GetDirectBufferCapacity(jbuffer.get());
7583
if (dataCapacity != numel) {
@@ -84,7 +92,7 @@ static at::Tensor newAtTensor(
8492
}
8593
return torch::from_blob(
8694
jni->GetDirectBufferAddress(jbuffer.get()),
87-
torch::IntArrayRef(dimsVec),
95+
torch::IntArrayRef(shapeVec),
8896
at::TensorOptions(typeMeta));
8997
}
9098

@@ -94,44 +102,48 @@ class JTensor : public facebook::jni::JavaClass<JTensor> {
94102

95103
static facebook::jni::local_ref<JTensor> newJTensor(
96104
facebook::jni::alias_ref<facebook::jni::JByteBuffer> jBuffer,
97-
facebook::jni::alias_ref<jlongArray> jDims,
98-
jint typeCode) {
105+
facebook::jni::alias_ref<jlongArray> jShape,
106+
jint jdtype) {
99107
static auto jMethodNewTensor =
100108
JTensor::javaClassStatic()
101109
->getStaticMethod<facebook::jni::local_ref<JTensor>(
102110
facebook::jni::alias_ref<facebook::jni::JByteBuffer>,
103111
facebook::jni::alias_ref<jlongArray>,
104112
jint)>("nativeNewTensor");
105113
return jMethodNewTensor(
106-
JTensor::javaClassStatic(), jBuffer, jDims, typeCode);
114+
JTensor::javaClassStatic(), jBuffer, jShape, jdtype);
107115
}
108116

109117
static facebook::jni::local_ref<JTensor> newJTensorFromAtTensor(
110118
const at::Tensor& tensor) {
111119
const auto scalarType = tensor.scalar_type();
112-
int typeCode = 0;
120+
int jdtype = 0;
113121
if (at::kFloat == scalarType) {
114-
typeCode = kTensorTypeCodeFloat32;
122+
jdtype = kTensorDTypeFloat32;
115123
} else if (at::kInt == scalarType) {
116-
typeCode = kTensorTypeCodeInt32;
124+
jdtype = kTensorDTypeInt32;
117125
} else if (at::kByte == scalarType) {
118-
typeCode = kTensorTypeCodeByte;
126+
jdtype = kTensorDTypeByte;
127+
} else if (at::kLong == scalarType) {
128+
jdtype = kTensorDTypeLong64;
129+
} else if (at::kDouble == scalarType) {
130+
jdtype = kTensorDTypeDouble64;
119131
} else {
120132
facebook::jni::throwNewJavaException(
121133
facebook::jni::gJavaLangIllegalArgumentException,
122134
"at::Tensor scalar type is not supported on java side");
123135
}
124136

125-
const auto& tensorDims = tensor.sizes();
126-
std::vector<int64_t> tensorDimsVec;
127-
for (const auto& dim : tensorDims) {
128-
tensorDimsVec.push_back(dim);
137+
const auto& tensorShape = tensor.sizes();
138+
std::vector<int64_t> tensorShapeVec;
139+
for (const auto& s : tensorShape) {
140+
tensorShapeVec.push_back(s);
129141
}
130142

131-
facebook::jni::local_ref<jlongArray> jTensorDims =
132-
facebook::jni::make_long_array(tensorDimsVec.size());
143+
facebook::jni::local_ref<jlongArray> jTensorShape =
144+
facebook::jni::make_long_array(tensorShapeVec.size());
133145

134-
jTensorDims->setRegion(0, tensorDimsVec.size(), tensorDimsVec.data());
146+
jTensorShape->setRegion(0, tensorShapeVec.size(), tensorShapeVec.data());
135147

136148
facebook::jni::local_ref<facebook::jni::JByteBuffer> jTensorBuffer =
137149
facebook::jni::JByteBuffer::allocateDirect(tensor.nbytes());
@@ -140,18 +152,18 @@ class JTensor : public facebook::jni::JavaClass<JTensor> {
140152
jTensorBuffer->getDirectBytes(),
141153
tensor.storage().data(),
142154
tensor.nbytes());
143-
return JTensor::newJTensor(jTensorBuffer, jTensorDims, typeCode);
155+
return JTensor::newJTensor(jTensorBuffer, jTensorShape, jdtype);
144156
}
145157

146158
static at::Tensor newAtTensorFromJTensor(
147159
facebook::jni::alias_ref<JTensor> jtensor) {
148-
static const auto typeCodeMethod =
149-
JTensor::javaClassStatic()->getMethod<jint()>("getTypeCode");
150-
jint typeCode = typeCodeMethod(jtensor);
160+
static const auto dtypeMethod =
161+
JTensor::javaClassStatic()->getMethod<jint()>("dtype");
162+
jint jdtype = dtypeMethod(jtensor);
151163

152-
static const auto dimsField =
153-
JTensor::javaClassStatic()->getField<jlongArray>("dims");
154-
auto jdims = jtensor->getFieldValue(dimsField);
164+
static const auto shapeField =
165+
JTensor::javaClassStatic()->getField<jlongArray>("shape");
166+
auto jshape = jtensor->getFieldValue(shapeField);
155167

156168
static auto dataBufferMethod =
157169
JTensor::javaClassStatic()
@@ -160,7 +172,7 @@ class JTensor : public facebook::jni::JavaClass<JTensor> {
160172
"getRawDataBuffer");
161173
facebook::jni::local_ref<facebook::jni::JBuffer> jbuffer =
162174
dataBufferMethod(jtensor);
163-
return newAtTensor(jbuffer, jdims, typeCode);
175+
return newAtTensor(jbuffer, jshape, jdtype);
164176
}
165177
};
166178

android/pytorch_android/src/main/java/org/pytorch/Module.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ public static Module load(final String modelAbsolutePath) {
1212
return new Module(modelAbsolutePath);
1313
}
1414

15-
private Module(final String modelAbsolutePath) {
16-
this.mNativePeer = new NativePeer(modelAbsolutePath);
15+
private Module(final String moduleAbsolutePath) {
16+
this.mNativePeer = new NativePeer(moduleAbsolutePath);
1717
}
1818

1919
public IValue forward(IValue... inputs) {

0 commit comments

Comments
 (0)