@@ -23,6 +23,11 @@ public class Array implements AutoCloseable {
2323 private native static long createArrayFromInt (int [] dims , int [] elems );
2424 private native static long createArrayFromBoolean (int [] dims , boolean [] elems );
2525
26+ private native static long createRanduArray (int [] dims , int type );
27+ private native static long createRandnArray (int [] dims , int type );
28+ private native static long createConstantsArray (double val , int [] dims , int type );
29+
30+
2631 private native static void destroyArray (long ref );
2732 private native static int [] getDims (long ref );
2833 private native static int getType (long ref );
@@ -66,9 +71,13 @@ public class Array implements AutoCloseable {
6671 private native static long sqrt (long a );
6772
6873 // Scalar return operations
69- private native static float sum (long a );
70- private native static float max (long a );
71- private native static float min (long a );
74+ private native static double sumAll (long a );
75+ private native static double maxAll (long a );
76+ private native static double minAll (long a );
77+
78+ private native static long sum (long a , int dim );
79+ private native static long max (long a , int dim );
80+ private native static long min (long a , int dim );
7281
7382 // Scalar operations
7483 private native static long addf (long a , float b );
@@ -115,7 +124,7 @@ public String typeName(int ty) throws Exception {
115124 throw new Exception ("Unknown type" );
116125 }
117126
118- private int [] dim4 (int [] dims ) throws Exception {
127+ private static int [] dim4 (int [] dims ) throws Exception {
119128
120129 if ( dims == null ) {
121130 throw new Exception ("Null dimensions object provided" );
@@ -280,6 +289,37 @@ public boolean[] getBooleanArray() throws Exception {
280289 }
281290
282291 // Binary operations
292+
293+ public static Array randu (int [] dims , int type ) throws Exception {
294+ int [] adims = dim4 (dims );
295+ long ref = createRanduArray (adims , type );
296+ if (ref == 0 ) throw new Exception ("Failed to create Array" );
297+
298+ Array ret_val = new Array ();
299+ ret_val .ref = ref ;
300+ return ret_val ;
301+ }
302+
303+ public static Array randn (int [] dims , int type ) throws Exception {
304+ int [] adims = dim4 (dims );
305+ long ref = createRandnArray (adims , type );
306+ if (ref == 0 ) throw new Exception ("Failed to create Array" );
307+
308+ Array ret_val = new Array ();
309+ ret_val .ref = ref ;
310+ return ret_val ;
311+ }
312+
313+ public static Array constant (double val , int [] dims , int type ) throws Exception {
314+ int [] adims = dim4 (dims );
315+ long ref = createConstantsArray (val , adims , type );
316+ if (ref == 0 ) throw new Exception ("Failed to create Array" );
317+
318+ Array ret_val = new Array ();
319+ ret_val .ref = ref ;
320+ return ret_val ;
321+ }
322+
283323 public static Array add (Array a , Array b ) throws Exception {
284324 Array ret_val = new Array ();
285325 ret_val .ref = add (a .ref ,b .ref );
@@ -432,11 +472,39 @@ public static Array sqrt(Array a) throws Exception {
432472 }
433473
434474 // Scalar return operations
435- public static float sum (Array a ) throws Exception { return sum (a .ref ); }
475+ public static double sumAll (Array a ) throws Exception { return sumAll (a .ref ); }
476+ public static double maxAll (Array a ) throws Exception { return maxAll (a .ref ); }
477+ public static double minAll (Array a ) throws Exception { return minAll (a .ref ); }
478+
479+ public static Array sum (Array a , int dim ) throws Exception {
480+ Array ret_val = new Array ();
481+ ret_val .ref = sum (a .ref , dim );
482+ return ret_val ;
483+ }
484+
485+ public static Array max (Array a , int dim ) throws Exception {
486+ Array ret_val = new Array ();
487+ ret_val .ref = max (a .ref , dim );
488+ return ret_val ;
489+ }
490+
491+ public static Array min (Array a , int dim ) throws Exception {
492+ Array ret_val = new Array ();
493+ ret_val .ref = min (a .ref , dim );
494+ return ret_val ;
495+ }
436496
437- public static float max (Array a ) throws Exception { return max (a .ref ); }
497+ public static Array sum (Array a ) throws Exception {
498+ return sum (a , -1 );
499+ }
438500
439- public static float min (Array a ) throws Exception { return min (a .ref ); }
501+ public static Array max (Array a ) throws Exception {
502+ return max (a , -1 );
503+ }
504+
505+ public static Array min (Array a ) throws Exception {
506+ return min (a , -1 );
507+ }
440508
441509 // Scalar operations
442510 public static Array add (Array a , float b ) throws Exception {
0 commit comments