Skip to content

Commit 20b27f9

Browse files
ShadyBoukharyumar456
authored andcommitted
Added wrapper around mean function.
1 parent f51a396 commit 20b27f9

5 files changed

Lines changed: 111 additions & 0 deletions

File tree

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ ADD_JAR(${AF_JAR}
2929
com/arrayfire/Image.java
3030
com/arrayfire/Signal.java
3131
com/arrayfire/Util.java
32+
com/arrayfire/Statistics.java
3233
com/util/JNIException.java
3334
com/arrayfire/ArrayFireException.java
3435
)

com/arrayfire/Statistics.java

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
package com.arrayfire;
2+
3+
public class Statistics extends ArrayFire {
4+
static private native long afMean(long ref, int dim);
5+
6+
static private native long afMeanWeighted(long ref, long weightsRef, int dim);
7+
8+
static private native double afMeanAll(long ref, int[] dims);
9+
10+
static private native FloatComplex afMeanAllFloatComplex(long ref, int[] dims);
11+
static private native DoubleComplex afMeanAllDoubleComplex(long ref, int[] dims);
12+
13+
//static private native jlong afMeanAllWeighted(long ref, long weightsRef);
14+
15+
static public Array mean(final Array in, int dim) {
16+
return new Array(afMean(in.ref, dim));
17+
}
18+
19+
static public Array mean(final Array in, final Array weights, int dim) {
20+
return new Array(afMeanWeighted(in.ref, weights.ref, dim));
21+
}
22+
23+
static public <T> T mean(final Array in, Class<T> type) throws Exception {
24+
if (type == FloatComplex.class) {
25+
FloatComplex res = (FloatComplex)afMeanAllFloatComplex(in.ref, in.dims());
26+
return type.cast(res);
27+
} else if (type == DoubleComplex.class) {
28+
DoubleComplex res = (DoubleComplex)afMeanAllDoubleComplex(in.ref, in.dims());
29+
return type.cast(res);
30+
}
31+
32+
double res = afMeanAll(in.ref, in.dims());
33+
if (type == Float.class) {
34+
return type.cast(Float.valueOf((float)res));
35+
} else if (type == Double.class) {
36+
return type.cast(Double.valueOf((double) res));
37+
} else if (type == Integer.class) {
38+
return type.cast(Integer.valueOf((int) res));
39+
}
40+
throw new Exception("Unknown type");
41+
}
42+
}
43+

examples/HelloWorld.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,18 @@ public static void main(String[] args) {
2525
Arith.mul(d, b, c);
2626
d.print("d");
2727

28+
29+
Array forMean = new Array();
30+
Array weights = new Array();
31+
Data.randn(forMean, new int[] {1}, Array.FloatComplexType);
32+
Data.randn(weights, new int[] {1}, Array.FloatComplexType);
33+
forMean.print("forMean");
34+
//Array mean = Statistics.mean(forMean, weights, 0);
35+
//mean.print("mean");
36+
37+
DoubleComplex abc = Statistics.mean(forMean, DoubleComplex.class);
38+
System.out.println(String.format("Mean is: %f and %f", abc.real(), abc.imag()));
39+
2840
System.out.println("Create a 2-by-3 matrix from host data");
2941
int[] dims = new int[] { 2, 3 };
3042
int total = 1;

src/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ ADD_LIBRARY(${AF_LIB} SHARED
2020
data.cpp
2121
image.cpp
2222
signal.cpp
23+
statistics.cpp
2324
util.cpp
2425
)
2526

src/statistics.cpp

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
2+
#include "jni_helper.h"
3+
4+
BEGIN_EXTERN_C
5+
6+
#define STATISTICS_FUNC(FUNC) AF_MANGLE(Statistics, FUNC)
7+
8+
JNIEXPORT jlong JNICALL STATISTICS_FUNC(afMean)(JNIEnv *env, jclass clazz,
9+
jlong ref, jint dim) {
10+
af_array ret = 0;
11+
AF_CHECK(af_mean(&ret, ARRAY(ref), dim));
12+
return JLONG(ret);
13+
}
14+
15+
JNIEXPORT jlong JNICALL STATISTICS_FUNC(afMeanWeighted)(JNIEnv *env, jclass clazz,
16+
jlong ref, jlong weightsRef,
17+
jint dim) {
18+
af_array ret = 0;
19+
AF_CHECK(af_mean_weighted(&ret, ARRAY(ref), ARRAY(weightsRef), dim));
20+
return JLONG(ret);
21+
}
22+
23+
JNIEXPORT jdouble JNICALL STATISTICS_FUNC(afMeanAll)(JNIEnv *env,
24+
jclass clazz, jlong ref,
25+
jintArray dims) {
26+
double ret = 0;
27+
jint *dimptr = env->GetIntArrayElements(dims, 0);
28+
dim_t tdims[4] = {dimptr[0], dimptr[1], dimptr[2], dimptr[3]};
29+
AF_CHECK(af_mean_all(&ret, NULL, ARRAY(ref)));
30+
env->ReleaseIntArrayElements(dims, dimptr, 0);
31+
return (jdouble)ret;
32+
}
33+
34+
#define INSTANTIATE_MEAN(jtype, param) \
35+
JNIEXPORT jobject JNICALL STATISTICS_FUNC(afMeanAll##jtype)(JNIEnv *env, jclass clazz, \
36+
jlong ref, \
37+
jintArray dims) { \
38+
double real = 0, img = 0; \
39+
jint *dimptr = env->GetIntArrayElements(dims, 0); \
40+
dim_t tdims[4] = {dimptr[0], dimptr[1], dimptr[2], dimptr[3]}; \
41+
AF_CHECK(af_mean_all(&real, &img, ARRAY(ref))); \
42+
env->ReleaseIntArrayElements(dims, dimptr, 0); \
43+
jclass cls = env->FindClass("com/arrayfire/"#jtype); \
44+
jmethodID id = env->GetMethodID(cls, "<init>", "("#param")V"); \
45+
jobject obj = env->NewObject(cls, id, real, img); \
46+
return obj; \
47+
}
48+
49+
INSTANTIATE_MEAN(FloatComplex, FF)
50+
INSTANTIATE_MEAN(DoubleComplex, DD)
51+
52+
END_EXTERN_C
53+
54+

0 commit comments

Comments
 (0)