Skip to content

Commit 1f4cd41

Browse files
ShadyBoukharyumar456
authored andcommitted
Added complex types to mean wrappers and weighted mean.
1 parent 20b27f9 commit 1f4cd41

File tree

3 files changed

+96
-59
lines changed

3 files changed

+96
-59
lines changed

com/arrayfire/Statistics.java

Lines changed: 54 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,68 @@
11
package com.arrayfire;
22

33
public class Statistics extends ArrayFire {
4-
static private native long afMean(long ref, int dim);
4+
static private native long afMean(long ref, int dim);
55

6-
static private native long afMeanWeighted(long ref, long weightsRef, int dim);
6+
static private native long afMeanWeighted(long ref, long weightsRef, int dim);
77

8-
static private native double afMeanAll(long ref, int[] dims);
8+
static private native double afMeanAll(long ref);
99

10-
static private native FloatComplex afMeanAllFloatComplex(long ref, int[] dims);
11-
static private native DoubleComplex afMeanAllDoubleComplex(long ref, int[] dims);
10+
static private native double afMeanAllWeighted(long ref, long weightsRef);
1211

13-
//static private native jlong afMeanAllWeighted(long ref, long weightsRef);
12+
static private native FloatComplex afMeanAllFloatComplex(long ref);
1413

15-
static public Array mean(final Array in, int dim) {
16-
return new Array(afMean(in.ref, dim));
14+
static private native DoubleComplex afMeanAllDoubleComplex(long ref);
15+
16+
static private native FloatComplex afMeanAllFloatComplexWeighted(long ref, long weightsRef);
17+
18+
static private native DoubleComplex afMeanAllDoubleComplexWeighted(long ref, long weightsRef);
19+
20+
static public Array mean(final Array in, int dim) {
21+
return new Array(afMean(in.ref, dim));
22+
}
23+
24+
static public Array mean(final Array in, final Array weights, int dim) {
25+
return new Array(afMeanWeighted(in.ref, weights.ref, dim));
26+
}
27+
28+
static public <T> T mean(final Array in, Class<T> type) throws Exception {
29+
if (type == FloatComplex.class) {
30+
FloatComplex res = (FloatComplex) afMeanAllFloatComplex(in.ref);
31+
return type.cast(res);
32+
} else if (type == DoubleComplex.class) {
33+
DoubleComplex res = (DoubleComplex) afMeanAllDoubleComplex(in.ref);
34+
return type.cast(res);
1735
}
1836

19-
static public Array mean(final Array in, final Array weights, int dim) {
20-
return new Array(afMeanWeighted(in.ref, weights.ref, dim));
37+
double res = afMeanAll(in.ref);
38+
if (type == Float.class) {
39+
return type.cast(Float.valueOf((float) res));
40+
} else if (type == Double.class) {
41+
return type.cast(Double.valueOf((double) res));
42+
} else if (type == Integer.class) {
43+
return type.cast(Integer.valueOf((int) res));
2144
}
45+
throw new Exception("Unknown type");
46+
}
2247

23-
static public <T> T mean(final Array in, Class<T> type) throws Exception {
24-
if (type == FloatComplex.class) {
25-
FloatComplex res = (FloatComplex)afMeanAllFloatComplex(in.ref, in.dims());
26-
return type.cast(res);
27-
} else if (type == DoubleComplex.class) {
28-
DoubleComplex res = (DoubleComplex)afMeanAllDoubleComplex(in.ref, in.dims());
29-
return type.cast(res);
30-
}
31-
32-
double res = afMeanAll(in.ref, in.dims());
33-
if (type == Float.class) {
34-
return type.cast(Float.valueOf((float)res));
35-
} else if (type == Double.class) {
36-
return type.cast(Double.valueOf((double) res));
37-
} else if (type == Integer.class) {
38-
return type.cast(Integer.valueOf((int) res));
39-
}
40-
throw new Exception("Unknown type");
48+
static public <T> T mean(final Array in, final Array weights, Class<T> type) throws Exception {
49+
if (type == FloatComplex.class) {
50+
FloatComplex res = (FloatComplex) afMeanAllFloatComplexWeighted(in.ref, weights.ref);
51+
return type.cast(res);
52+
} else if (type == DoubleComplex.class) {
53+
System.out.println(Long.toString(weights.ref));
54+
DoubleComplex res = (DoubleComplex) afMeanAllDoubleComplexWeighted(in.ref, weights.ref);
55+
return type.cast(res);
4156
}
42-
}
4357

58+
double res = afMeanAllWeighted(in.ref, weights.ref);
59+
if (type == Float.class) {
60+
return type.cast(Float.valueOf((float) res));
61+
} else if (type == Double.class) {
62+
return type.cast(Double.valueOf((double) res));
63+
} else if (type == Integer.class) {
64+
return type.cast(Integer.valueOf((int) res));
65+
}
66+
throw new Exception("Unknown type");
67+
}
68+
}

examples/HelloWorld.java

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,11 @@ public static void main(String[] args) {
2828

2929
Array forMean = new Array();
3030
Array weights = new Array();
31-
Data.randn(forMean, new int[] {1}, Array.FloatComplexType);
32-
Data.randn(weights, new int[] {1}, Array.FloatComplexType);
31+
Data.randn(forMean, new int[] {3, 3}, Array.FloatComplexType);
32+
Data.randn(weights, new int[] {3, 3}, Array.FloatType);
3333
forMean.print("forMean");
34-
//Array mean = Statistics.mean(forMean, weights, 0);
35-
//mean.print("mean");
3634

37-
DoubleComplex abc = Statistics.mean(forMean, DoubleComplex.class);
35+
FloatComplex abc = Statistics.mean(forMean, weights, FloatComplex.class);
3836
System.out.println(String.format("Mean is: %f and %f", abc.real(), abc.imag()));
3937

4038
System.out.println("Create a 2-by-3 matrix from host data");

src/statistics.cpp

Lines changed: 39 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,49 +6,63 @@ BEGIN_EXTERN_C
66
#define STATISTICS_FUNC(FUNC) AF_MANGLE(Statistics, FUNC)
77

88
JNIEXPORT jlong JNICALL STATISTICS_FUNC(afMean)(JNIEnv *env, jclass clazz,
9-
jlong ref, jint dim) {
9+
jlong ref, jint dim) {
1010
af_array ret = 0;
1111
AF_CHECK(af_mean(&ret, ARRAY(ref), dim));
1212
return JLONG(ret);
1313
}
1414

15-
JNIEXPORT jlong JNICALL STATISTICS_FUNC(afMeanWeighted)(JNIEnv *env, jclass clazz,
16-
jlong ref, jlong weightsRef,
15+
JNIEXPORT jlong JNICALL STATISTICS_FUNC(afMeanWeighted)(JNIEnv *env,
16+
jclass clazz, jlong ref,
17+
jlong weightsRef,
1718
jint dim) {
1819
af_array ret = 0;
1920
AF_CHECK(af_mean_weighted(&ret, ARRAY(ref), ARRAY(weightsRef), dim));
2021
return JLONG(ret);
2122
}
2223

23-
JNIEXPORT jdouble JNICALL STATISTICS_FUNC(afMeanAll)(JNIEnv *env,
24-
jclass clazz, jlong ref,
25-
jintArray dims) {
24+
JNIEXPORT jdouble JNICALL STATISTICS_FUNC(afMeanAll)(JNIEnv *env, jclass clazz,
25+
jlong ref) {
2626
double ret = 0;
27-
jint *dimptr = env->GetIntArrayElements(dims, 0);
28-
dim_t tdims[4] = {dimptr[0], dimptr[1], dimptr[2], dimptr[3]};
2927
AF_CHECK(af_mean_all(&ret, NULL, ARRAY(ref)));
30-
env->ReleaseIntArrayElements(dims, dimptr, 0);
31-
return (jdouble)ret;
28+
return (jdouble)ret;
3229
}
3330

34-
#define INSTANTIATE_MEAN(jtype, param) \
35-
JNIEXPORT jobject JNICALL STATISTICS_FUNC(afMeanAll##jtype)(JNIEnv *env, jclass clazz, \
36-
jlong ref, \
37-
jintArray dims) { \
38-
double real = 0, img = 0; \
39-
jint *dimptr = env->GetIntArrayElements(dims, 0); \
40-
dim_t tdims[4] = {dimptr[0], dimptr[1], dimptr[2], dimptr[3]}; \
41-
AF_CHECK(af_mean_all(&real, &img, ARRAY(ref))); \
42-
env->ReleaseIntArrayElements(dims, dimptr, 0); \
43-
jclass cls = env->FindClass("com/arrayfire/"#jtype); \
44-
jmethodID id = env->GetMethodID(cls, "<init>", "("#param")V"); \
45-
jobject obj = env->NewObject(cls, id, real, img); \
46-
return obj; \
31+
JNIEXPORT jdouble JNICALL STATISTICS_FUNC(afMeanAllWeighted)(JNIEnv *env,
32+
jclass clazz,
33+
jlong ref,
34+
jlong weightsRef) {
35+
double ret = 0;
36+
AF_CHECK(af_mean_all_weighted(&ret, NULL, ARRAY(ref), ARRAY(weightsRef)));
37+
return (jdouble)ret;
4738
}
4839

40+
#define INSTANTIATE_MEAN(jtype, param) \
41+
JNIEXPORT jobject JNICALL STATISTICS_FUNC(afMeanAll##jtype)( \
42+
JNIEnv * env, jclass clazz, jlong ref) { \
43+
double real = 0, img = 0; \
44+
AF_CHECK(af_mean_all(&real, &img, ARRAY(ref))); \
45+
jclass cls = env->FindClass("com/arrayfire/" #jtype); \
46+
jmethodID id = env->GetMethodID(cls, "<init>", "(" #param ")V"); \
47+
jobject obj = env->NewObject(cls, id, real, img); \
48+
return obj; \
49+
}
50+
51+
#define INSTANTIATE_MEAN_WEIGHTED(jtype, param) \
52+
JNIEXPORT jobject JNICALL STATISTICS_FUNC(afMeanAll##jtype##Weighted)( \
53+
JNIEnv * env, jclass clazz, jlong ref, jlong weightsRef) { \
54+
double real = 0, img = 0; \
55+
AF_CHECK( \
56+
af_mean_all_weighted(&real, &img, ARRAY(ref), ARRAY(weightsRef))); \
57+
jclass cls = env->FindClass("com/arrayfire/" #jtype); \
58+
jmethodID id = env->GetMethodID(cls, "<init>", "(" #param ")V"); \
59+
jobject obj = env->NewObject(cls, id, real, img); \
60+
return obj; \
61+
}
62+
4963
INSTANTIATE_MEAN(FloatComplex, FF)
5064
INSTANTIATE_MEAN(DoubleComplex, DD)
65+
INSTANTIATE_MEAN_WEIGHTED(FloatComplex, FF)
66+
INSTANTIATE_MEAN_WEIGHTED(DoubleComplex, DD)
5167

5268
END_EXTERN_C
53-
54-

0 commit comments

Comments
 (0)