Skip to content

Commit 6c00ae5

Browse files
ShadyBoukharyumar456
authored andcommitted
Created a function that creates new java objects in the native code.
This improves maintainablity as the code is less error prone and more compact.
1 parent b985338 commit 6c00ae5

File tree

5 files changed

+63
-36
lines changed

5 files changed

+63
-36
lines changed

examples/HelloWorld.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ public static void main(String[] args) {
3232
Data.randn(weights, new int[] { 5, 3 }, Array.DoubleType);
3333
forVar.print("forVar");
3434

35-
double abc = Statistics.mean(forVar, weights, Double.class);
35+
double abc = Statistics.var(forVar, weights, Double.class);
3636
System.out.println(String.format("Variance is: %f", abc));
3737

3838
System.out.println("Create a 2-by-3 matrix from host data");

src/data.cpp

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,6 @@ JNIEXPORT jobjectArray JNICALL DATA_FUNC(getFloatComplexFromArray)(JNIEnv *env,
8181
AF_CHECK(af_get_elements(&elements, ARRAY(ref)));
8282

8383
jclass cls = env->FindClass("com/arrayfire/FloatComplex");
84-
jmethodID id = env->GetMethodID(cls, "<init>", "(FF)V");
85-
if (id == NULL) return NULL;
86-
8784
result = env->NewObjectArray(elements, cls, NULL);
8885

8986
af_cfloat *tmp = (af_cfloat *)malloc(sizeof(af_cfloat) * elements);
@@ -92,7 +89,7 @@ JNIEXPORT jobjectArray JNICALL DATA_FUNC(getFloatComplexFromArray)(JNIEnv *env,
9289
for (int i = 0; i < elements; i++) {
9390
float re = tmp[i].real;
9491
float im = tmp[i].imag;
95-
jobject obj = env->NewObject(cls, id, re, im);
92+
jobject obj = java::createJavaObject(env, java::JavaObjects::FloatComplex, re, im);
9693
env->SetObjectArrayElement(result, i, obj);
9794
}
9895

@@ -107,9 +104,6 @@ DATA_FUNC(getDoubleComplexFromArray)(JNIEnv *env, jclass clazz, jlong ref) {
107104
AF_CHECK(af_get_elements(&elements, ARRAY(ref)));
108105

109106
jclass cls = env->FindClass("com/arrayfire/DoubleComplex");
110-
jmethodID id = env->GetMethodID(cls, "<init>", "(DD)V");
111-
if (id == NULL) return NULL;
112-
113107
result = env->NewObjectArray(elements, cls, NULL);
114108

115109
af_cdouble *tmp = (af_cdouble *)malloc(sizeof(af_cdouble) * elements);
@@ -118,7 +112,8 @@ DATA_FUNC(getDoubleComplexFromArray)(JNIEnv *env, jclass clazz, jlong ref) {
118112
for (int i = 0; i < elements; i++) {
119113
double re = tmp[i].real;
120114
double im = tmp[i].imag;
121-
jobject obj = env->NewObject(cls, id, re, im);
115+
jobject obj =
116+
java::createJavaObject(env, java::JavaObjects::DoubleComplex, re, im);
122117
env->SetObjectArrayElement(result, i, obj);
123118
}
124119

src/java/java.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,4 +71,37 @@ void throwArrayFireException(JNIEnv *env, const char *functionName,
7171
env->Throw(exception);
7272
env->DeleteLocalRef(exceptionClass);
7373
}
74+
75+
template <typename... Args>
76+
jobject createJavaObject(JNIEnv *env, JavaObjects objectType, Args... args) {
77+
switch (objectType) {
78+
case JavaObjects::FloatComplex: {
79+
80+
static jclass cls = env->FindClass("com/arrayfire/FloatComplex");
81+
static std::string sig = generateFunctionSignature(JavaType::Void,
82+
{JavaType::Float, JavaType::Float});
83+
static jmethodID id = env->GetMethodID(cls, "<init>", sig.c_str());
84+
jobject obj = env->NewObject(cls, id, args...);
85+
return obj;
86+
87+
} break;
88+
case JavaObjects::DoubleComplex: {
89+
90+
static jclass cls = env->FindClass("com/arrayfire/DoubleComplex");
91+
static std::string sig = generateFunctionSignature(
92+
JavaType::Void, {JavaType::Double, JavaType::Double});
93+
static jmethodID id = env->GetMethodID(cls, "<init>", sig.c_str());
94+
jobject obj = env->NewObject(cls, id, args...);
95+
return obj;
96+
} break;
97+
}
98+
}
99+
#define INSTANTIATE(type) \
100+
template jobject createJavaObject<type>(JNIEnv *, JavaObjects, type, type); \
101+
102+
INSTANTIATE(float)
103+
INSTANTIATE(double)
104+
105+
#undef INSTANTIATE
106+
74107
} // namespace java

src/java/java.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,14 @@
55

66
namespace java {
77

8+
enum class JavaObjects {
9+
FloatComplex,
10+
DoubleComplex
11+
};
12+
13+
template<typename... Args>
14+
jobject createJavaObject(JNIEnv *env, JavaObjects objectType, Args... args);
15+
816
void throwArrayFireException(JNIEnv *env, const char *functionName,
917
const char *file, const int line, const int code);
1018
} // namespace java

src/statistics.cpp

Lines changed: 18 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5,27 +5,21 @@ BEGIN_EXTERN_C
55

66
#define STATISTICS_FUNC(FUNC) AF_MANGLE(Statistics, FUNC)
77

8-
#define INSTANTIATE_MEAN(jtype, param) \
8+
#define INSTANTIATE_MEAN(jtype) \
99
JNIEXPORT jobject JNICALL STATISTICS_FUNC(afMeanAll##jtype)( \
1010
JNIEnv * env, jclass clazz, jlong ref) { \
1111
double real = 0, img = 0; \
1212
AF_CHECK(af_mean_all(&real, &img, ARRAY(ref))); \
13-
jclass cls = env->FindClass("com/arrayfire/" #jtype); \
14-
jmethodID id = env->GetMethodID(cls, "<init>", "(" #param ")V"); \
15-
jobject obj = env->NewObject(cls, id, real, img); \
16-
return obj; \
13+
return java::createJavaObject(env, java::JavaObjects::jtype, real, img); \
1714
}
1815

19-
#define INSTANTIATE_WEIGHTED(jtype, param, Name, name) \
16+
#define INSTANTIATE_WEIGHTED(jtype, Name, name) \
2017
JNIEXPORT jobject JNICALL STATISTICS_FUNC(af##Name##All##jtype##Weighted)( \
2118
JNIEnv * env, jclass clazz, jlong ref, jlong weightsRef) { \
2219
double real = 0, img = 0; \
2320
AF_CHECK( \
2421
af_##name##_all_weighted(&real, &img, ARRAY(ref), ARRAY(weightsRef))); \
25-
jclass cls = env->FindClass("com/arrayfire/" #jtype); \
26-
jmethodID id = env->GetMethodID(cls, "<init>", "(" #param ")V"); \
27-
jobject obj = env->NewObject(cls, id, real, img); \
28-
return obj; \
22+
return java::createJavaObject(env, java::JavaObjects::jtype, real, img); \
2923
}
3024

3125
#define INSTANTIATE_ALL_REAL_WEIGHTED(Name, name) \
@@ -45,15 +39,12 @@ BEGIN_EXTERN_C
4539
return JLONG(ret); \
4640
}
4741

48-
#define INSTANTIATE_VAR(jtype, param) \
49-
JNIEXPORT jobject JNICALL STATISTICS_FUNC(afVarAll##jtype)( \
50-
JNIEnv * env, jclass clazz, jlong ref, jboolean isBiased) { \
51-
double real = 0, img = 0; \
52-
AF_CHECK(af_var_all(&real, &img, ARRAY(ref), isBiased)); \
53-
jclass cls = env->FindClass("com/arrayfire/" #jtype); \
54-
jmethodID id = env->GetMethodID(cls, "<init>", "(" #param ")V"); \
55-
jobject obj = env->NewObject(cls, id, real, img); \
56-
return obj; \
42+
#define INSTANTIATE_VAR(jtype) \
43+
JNIEXPORT jobject JNICALL STATISTICS_FUNC(afVarAll##jtype)( \
44+
JNIEnv * env, jclass clazz, jlong ref, jboolean isBiased) { \
45+
double real = 0, img = 0; \
46+
AF_CHECK(af_var_all(&real, &img, ARRAY(ref), isBiased)); \
47+
return java::createJavaObject(env, java::JavaObjects::jtype, real, img); \
5748
}
5849

5950
JNIEXPORT jlong JNICALL STATISTICS_FUNC(afMean)(JNIEnv *env, jclass clazz,
@@ -70,12 +61,12 @@ JNIEXPORT jdouble JNICALL STATISTICS_FUNC(afMeanAll)(JNIEnv *env, jclass clazz,
7061
return (jdouble)ret;
7162
}
7263

73-
INSTANTIATE_MEAN(FloatComplex, FF)
74-
INSTANTIATE_MEAN(DoubleComplex, DD)
64+
INSTANTIATE_MEAN(FloatComplex)
65+
INSTANTIATE_MEAN(DoubleComplex)
7566
INSTANTIATE_ALL_REAL_WEIGHTED(Mean, mean)
7667
INSTANTIATE_REAL_WEIGHTED(Mean, mean)
77-
INSTANTIATE_WEIGHTED(FloatComplex, FF, Mean, mean)
78-
INSTANTIATE_WEIGHTED(DoubleComplex, DD, Mean, mean)
68+
INSTANTIATE_WEIGHTED(FloatComplex, Mean, mean)
69+
INSTANTIATE_WEIGHTED(DoubleComplex, Mean, mean)
7970

8071
#undef INSTANTIATE_MEAN
8172

@@ -95,12 +86,12 @@ JNIEXPORT jdouble JNICALL STATISTICS_FUNC(afVarAll)(JNIEnv *env, jclass clazz,
9586
return (jdouble)ret;
9687
}
9788

98-
INSTANTIATE_VAR(FloatComplex, FF)
99-
INSTANTIATE_VAR(DoubleComplex, DD)
89+
INSTANTIATE_VAR(FloatComplex)
90+
INSTANTIATE_VAR(DoubleComplex)
10091
INSTANTIATE_REAL_WEIGHTED(Var, var)
10192
INSTANTIATE_ALL_REAL_WEIGHTED(Var, var)
102-
INSTANTIATE_WEIGHTED(FloatComplex, FF, Var, var)
103-
INSTANTIATE_WEIGHTED(DoubleComplex, DD, Var, var)
93+
INSTANTIATE_WEIGHTED(FloatComplex, Var, var)
94+
INSTANTIATE_WEIGHTED(DoubleComplex, Var, var)
10495

10596
#undef INSTANTIATE_VAR
10697
#undef INSTANTIATE_MEAN

0 commit comments

Comments
 (0)