@@ -1,23 +1,26 @@
package org.pytorch;

import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;

import android.content.Context;
import androidx.test.ext.junit.runners.AndroidJUnit4;
import androidx.test.platform.app.InstrumentationRegistry;

import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.HashMap;
import java.util.Map;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;

import androidx.test.ext.junit.runners.AndroidJUnit4;
import androidx.test.platform.app.InstrumentationRegistry;

import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;

@RunWith(AndroidJUnit4.class)
public class PytorchInstrumentedTests {
@@ -33,7 +36,7 @@ public void setUp() {
public void testForwardNull() throws IOException {
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
final IValue input =
IValue.tensor(Tensor.newByteTensor(new long[] {1}, Tensor.allocateByteBuffer(1)));
IValue.tensor(Tensor.newTensor(new long[] {1}, Tensor.allocateByteBuffer(1)));
assertTrue(input.isTensor());
final IValue output = module.forward(input);
assertTrue(output.isNull());
@@ -94,13 +97,13 @@ public void testEqFloat() throws IOException {

@Test
public void testEqTensor() throws IOException {
final long[] inputTensorDims = new long[] {1, 3, 224, 224};
final long numElements = Tensor.numElements(inputTensorDims);
final long[] inputTensorShape = new long[] {1, 3, 224, 224};
final long numElements = Tensor.numel(inputTensorShape);
final float[] inputTensorData = new float[(int) numElements];
for (int i = 0; i < numElements; ++i) {
inputTensorData[i] = i;
}
final Tensor inputTensor = Tensor.newFloatTensor(inputTensorDims, inputTensorData);
final Tensor inputTensor = Tensor.newTensor(inputTensorShape, inputTensorData);

final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
final IValue input = IValue.tensor(inputTensor);
@@ -110,7 +113,7 @@ public void testEqTensor() throws IOException {
assertTrue(output.isTensor());
final Tensor outputTensor = output.getTensor();
assertNotNull(outputTensor);
assertArrayEquals(inputTensorDims, outputTensor.dims);
assertArrayEquals(inputTensorShape, outputTensor.shape);
float[] outputData = outputTensor.getDataAsFloatArray();
for (int i = 0; i < numElements; i++) {
assertTrue(inputTensorData[i] == outputData[i]);
@@ -216,8 +219,8 @@ public void testRunUndefinedMethod() throws IOException {

@Test
public void testTensorMethods() {
long[] dims = new long[] {1, 3, 224, 224};
final int numel = (int) Tensor.numElements(dims);
long[] shape = new long[] {1, 3, 224, 224};
final int numel = (int) Tensor.numel(shape);
int[] ints = new int[numel];
float[] floats = new float[numel];

@@ -228,16 +231,16 @@ public void testTensorMethods() {
floats[i] = i / 1000.f;
}

Tensor tensorBytes = Tensor.newByteTensor(dims, bytes);
assertTrue(tensorBytes.isByteTensor());
Tensor tensorBytes = Tensor.newTensor(shape, bytes);
assertTrue(tensorBytes.dtype() == Tensor.DTYPE_BYTE);
Contributor comment:
we don't actually have a Byte dtype in the python frontend (because numpy defines it differently). It's probably better to use DTYPE_UINT8.

assertArrayEquals(bytes, tensorBytes.getDataAsByteArray());

Tensor tensorInts = Tensor.newIntTensor(dims, ints);
assertTrue(tensorInts.isIntTensor());
Tensor tensorInts = Tensor.newTensor(shape, ints);
assertTrue(tensorInts.dtype() == Tensor.DTYPE_INT32);
assertArrayEquals(ints, tensorInts.getDataAsIntArray());

Tensor tensorFloats = Tensor.newFloatTensor(dims, floats);
assertTrue(tensorFloats.isFloatTensor());
Tensor tensorFloats = Tensor.newTensor(shape, floats);
assertTrue(tensorFloats.dtype() == Tensor.DTYPE_FLOAT32);
float[] floatsOut = tensorFloats.getDataAsFloatArray();
assertTrue(floatsOut.length == numel);
for (int i = 0; i < numel; i++) {
@@ -247,11 +250,11 @@ public void testTensorMethods() {

@Test(expected = IllegalStateException.class)
public void testTensorIllegalStateOnWrongType() {
long[] dims = new long[] {1, 3, 224, 224};
final int numel = (int) Tensor.numElements(dims);
long[] shape = new long[] {1, 3, 224, 224};
final int numel = (int) Tensor.numel(shape);
float[] floats = new float[numel];
Tensor tensorFloats = Tensor.newFloatTensor(dims, floats);
assertTrue(tensorFloats.isFloatTensor());
Tensor tensorFloats = Tensor.newTensor(shape, floats);
assertTrue(tensorFloats.dtype() == Tensor.DTYPE_FLOAT32);
tensorFloats.getDataAsByteArray();
}

94 changes: 53 additions & 41 deletions android/pytorch_android/src/main/cpp/pytorch_jni.cpp
@@ -10,9 +10,11 @@

namespace pytorch_jni {

constexpr static int kTensorTypeCodeByte = 1;
constexpr static int kTensorTypeCodeInt32 = 2;
constexpr static int kTensorTypeCodeFloat32 = 3;
constexpr static int kTensorDTypeByte = 1;
constexpr static int kTensorDTypeInt32 = 2;
constexpr static int kTensorDTypeFloat32 = 3;
constexpr static int kTensorDTypeLong64 = 4;
constexpr static int kTensorDTypeDouble64 = 5;
Contributor comment:
is this standard naming in java? It seems repetitive (is a double ever not 64 bits?). In python we call this "float64" (and the one above int64).
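
For illustration only, a hedged sketch of the alternative naming this comment hints at, mirroring the Python frontend's int64/float64; these constant names are hypothetical and not part of this diff:

constexpr static int kTensorDTypeInt64 = 4;   // hypothetical rename of kTensorDTypeLong64
constexpr static int kTensorDTypeFloat64 = 5; // hypothetical rename of kTensorDTypeDouble64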


template <typename K = jobject, typename V = jobject>
struct JHashMap
@@ -42,34 +44,40 @@ struct JHashMap

static at::Tensor newAtTensor(
facebook::jni::alias_ref<facebook::jni::JBuffer> jbuffer,
facebook::jni::alias_ref<jlongArray> jdims,
jint typeCode) {
const auto rank = jdims->size();
const auto dimsArr = jdims->getRegion(0, rank);
std::vector<int64_t> dimsVec{};
dimsVec.reserve(rank);
facebook::jni::alias_ref<jlongArray> jshape,
jint jdtype) {
const auto rank = jshape->size();
const auto shapeArr = jshape->getRegion(0, rank);
std::vector<int64_t> shapeVec{};
shapeVec.reserve(rank);
auto numel = 1;
for (auto i = 0; i < rank; ++i) {
dimsVec.push_back(dimsArr[i]);
numel *= dimsArr[i];
shapeVec.push_back(shapeArr[i]);
numel *= shapeArr[i];
}
JNIEnv* jni = facebook::jni::Environment::current();
caffe2::TypeMeta typeMeta{};
int dataElementSizeBytes = 0;
if (kTensorTypeCodeFloat32 == typeCode) {
if (kTensorDTypeFloat32 == jdtype) {
dataElementSizeBytes = 4;
Collaborator comment:

nit: you can do typeMeta.itemsize() to write it generically
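
To make the nit concrete, a hedged sketch (not part of the diff) of how the dtype branches in newAtTensor could drop the hard-coded sizes and derive the element size from the TypeMeta itself, assuming caffe2::TypeMeta::itemsize() as suggested above:

caffe2::TypeMeta typeMeta{};
if (kTensorDTypeFloat32 == jdtype) {
  typeMeta = caffe2::TypeMeta::Make<float>();
} else if (kTensorDTypeInt32 == jdtype) {
  typeMeta = caffe2::TypeMeta::Make<int32_t>();
} else if (kTensorDTypeByte == jdtype) {
  typeMeta = caffe2::TypeMeta::Make<int8_t>();
} else if (kTensorDTypeLong64 == jdtype) {
  typeMeta = caffe2::TypeMeta::Make<int64_t>();
} else if (kTensorDTypeDouble64 == jdtype) {
  typeMeta = caffe2::TypeMeta::Make<double>();
} else {
  facebook::jni::throwNewJavaException(
      facebook::jni::gJavaLangIllegalArgumentException,
      "Unknown Tensor jdtype %d",
      jdtype);
}
// itemsize() reports sizeof(T) for the registered TypeMeta, so the
// per-branch dataElementSizeBytes assignments collapse to a single line.
const auto dataElementSizeBytes = typeMeta.itemsize();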

typeMeta = caffe2::TypeMeta::Make<float>();
} else if (kTensorTypeCodeInt32 == typeCode) {
} else if (kTensorDTypeInt32 == jdtype) {
dataElementSizeBytes = 4;
typeMeta = caffe2::TypeMeta::Make<int>();
} else if (kTensorTypeCodeByte == typeCode) {
typeMeta = caffe2::TypeMeta::Make<int32_t>();
} else if (kTensorDTypeByte == jdtype) {
dataElementSizeBytes = 1;
typeMeta = caffe2::TypeMeta::Make<uint8_t>();
typeMeta = caffe2::TypeMeta::Make<int8_t>();
} else if (kTensorDTypeLong64 == jdtype) {
dataElementSizeBytes = 8;
typeMeta = caffe2::TypeMeta::Make<int64_t>();
} else if (kTensorDTypeDouble64 == jdtype) {
dataElementSizeBytes = 8;
typeMeta = caffe2::TypeMeta::Make<double>();
} else {
facebook::jni::throwNewJavaException(
facebook::jni::gJavaLangIllegalArgumentException,
"Unknown Tensor typeCode %d",
typeCode);
"Unknown Tensor jdtype %d",
jdtype);
}
const auto dataCapacity = jni->GetDirectBufferCapacity(jbuffer.get());
if (dataCapacity != numel) {
@@ -84,7 +92,7 @@ static at::Tensor newAtTensor(
}
return torch::from_blob(
jni->GetDirectBufferAddress(jbuffer.get()),
torch::IntArrayRef(dimsVec),
torch::IntArrayRef(shapeVec),
at::TensorOptions(typeMeta));
}

@@ -94,44 +102,48 @@ class JTensor : public facebook::jni::JavaClass<JTensor> {

static facebook::jni::local_ref<JTensor> newJTensor(
facebook::jni::alias_ref<facebook::jni::JByteBuffer> jBuffer,
facebook::jni::alias_ref<jlongArray> jDims,
jint typeCode) {
facebook::jni::alias_ref<jlongArray> jShape,
jint jdtype) {
static auto jMethodNewTensor =
JTensor::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JTensor>(
facebook::jni::alias_ref<facebook::jni::JByteBuffer>,
facebook::jni::alias_ref<jlongArray>,
jint)>("nativeNewTensor");
return jMethodNewTensor(
JTensor::javaClassStatic(), jBuffer, jDims, typeCode);
JTensor::javaClassStatic(), jBuffer, jShape, jdtype);
}

static facebook::jni::local_ref<JTensor> newJTensorFromAtTensor(
const at::Tensor& tensor) {
const auto scalarType = tensor.scalar_type();
int typeCode = 0;
int jdtype = 0;
if (at::kFloat == scalarType) {
typeCode = kTensorTypeCodeFloat32;
jdtype = kTensorDTypeFloat32;
} else if (at::kInt == scalarType) {
typeCode = kTensorTypeCodeInt32;
jdtype = kTensorDTypeInt32;
} else if (at::kByte == scalarType) {
typeCode = kTensorTypeCodeByte;
jdtype = kTensorDTypeByte;
} else if (at::kLong == scalarType) {
jdtype = kTensorDTypeLong64;
} else if (at::kDouble == scalarType) {
jdtype = kTensorDTypeDouble64;
} else {
facebook::jni::throwNewJavaException(
facebook::jni::gJavaLangIllegalArgumentException,
"at::Tensor scalar type is not supported on java side");
}

const auto& tensorDims = tensor.sizes();
std::vector<int64_t> tensorDimsVec;
for (const auto& dim : tensorDims) {
tensorDimsVec.push_back(dim);
const auto& tensorShape = tensor.sizes();
std::vector<int64_t> tensorShapeVec;
for (const auto& s : tensorShape) {
tensorShapeVec.push_back(s);
}

facebook::jni::local_ref<jlongArray> jTensorDims =
facebook::jni::make_long_array(tensorDimsVec.size());
facebook::jni::local_ref<jlongArray> jTensorShape =
facebook::jni::make_long_array(tensorShapeVec.size());

jTensorDims->setRegion(0, tensorDimsVec.size(), tensorDimsVec.data());
jTensorShape->setRegion(0, tensorShapeVec.size(), tensorShapeVec.data());

facebook::jni::local_ref<facebook::jni::JByteBuffer> jTensorBuffer =
facebook::jni::JByteBuffer::allocateDirect(tensor.nbytes());
@@ -140,18 +152,18 @@ class JTensor : public facebook::jni::JavaClass<JTensor> {
jTensorBuffer->getDirectBytes(),
tensor.storage().data(),
tensor.nbytes());
return JTensor::newJTensor(jTensorBuffer, jTensorDims, typeCode);
return JTensor::newJTensor(jTensorBuffer, jTensorShape, jdtype);
}

static at::Tensor newAtTensorFromJTensor(
facebook::jni::alias_ref<JTensor> jtensor) {
static const auto typeCodeMethod =
JTensor::javaClassStatic()->getMethod<jint()>("getTypeCode");
jint typeCode = typeCodeMethod(jtensor);
static const auto dtypeMethod =
JTensor::javaClassStatic()->getMethod<jint()>("dtype");
jint jdtype = dtypeMethod(jtensor);

static const auto dimsField =
JTensor::javaClassStatic()->getField<jlongArray>("dims");
auto jdims = jtensor->getFieldValue(dimsField);
static const auto shapeField =
JTensor::javaClassStatic()->getField<jlongArray>("shape");
auto jshape = jtensor->getFieldValue(shapeField);

static auto dataBufferMethod =
JTensor::javaClassStatic()
Expand All @@ -160,7 +172,7 @@ class JTensor : public facebook::jni::JavaClass<JTensor> {
"getRawDataBuffer");
facebook::jni::local_ref<facebook::jni::JBuffer> jbuffer =
dataBufferMethod(jtensor);
return newAtTensor(jbuffer, jdims, typeCode);
return newAtTensor(jbuffer, jshape, jdtype);
}
};

4 changes: 2 additions & 2 deletions android/pytorch_android/src/main/java/org/pytorch/Module.java
@@ -12,8 +12,8 @@ public static Module load(final String modelAbsolutePath) {
return new Module(modelAbsolutePath);
}

private Module(final String modelAbsolutePath) {
this.mNativePeer = new NativePeer(modelAbsolutePath);
private Module(final String moduleAbsolutePath) {
this.mNativePeer = new NativePeer(moduleAbsolutePath);
}

public IValue forward(IValue... inputs) {