Skip to content

Commit 1908f75

Browse files
committed
WIP WIP WIP: It's alive!!
1 parent 1a60263 commit 1908f75

File tree

4 files changed

+181
-14
lines changed

4 files changed

+181
-14
lines changed
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
package org.scijava.ops.python;
2+
3+
import java.io.BufferedReader;
4+
import java.io.IOException;
5+
import java.io.InputStreamReader;
6+
7+
import jep.Interpreter;
8+
import jep.MainInterpreter;
9+
import jep.SharedInterpreter;
10+
11+
public class OpsPythonInterpreter {
12+
13+
private static OpsPythonInterpreter instance;
14+
private final Interpreter interp;
15+
16+
private OpsPythonInterpreter() {
17+
// TODO: To use Numpy, I have to set the environment variable
18+
// LD_PRELOAD=~/miniconda3/envs/scijava-ops-python/lib/libpython3.11.so
19+
// Can we automate this?
20+
// See https://github.com/ninia/jep/issues/338
21+
try {
22+
MainInterpreter.setJepLibraryPath(getJepPath());
23+
} catch(IOException | InterruptedException e) {
24+
throw new RuntimeException(e);
25+
}
26+
interp = new SharedInterpreter();
27+
}
28+
29+
public static Interpreter interpreter() {
30+
if (instance == null) {
31+
instance = new OpsPythonInterpreter();
32+
}
33+
return instance.interp;
34+
}
35+
36+
private static String getJepPath() throws IOException, InterruptedException {
37+
ProcessBuilder processBuilder = new ProcessBuilder("conda",
38+
"run", "-n", "scijava-ops-python", "python", "jep_path.py");
39+
processBuilder.redirectErrorStream(true);
40+
41+
Process process = processBuilder.start();
42+
StringBuilder processOutput = new StringBuilder();
43+
44+
try (BufferedReader processOutputReader = new BufferedReader(
45+
new InputStreamReader(process.getInputStream()));)
46+
{
47+
String readLine;
48+
49+
while ((readLine = processOutputReader.readLine()) != null)
50+
{
51+
processOutput.append(readLine + System.lineSeparator());
52+
}
53+
54+
process.waitFor();
55+
}
56+
57+
return processOutput.toString().trim();
58+
59+
}
60+
61+
}

scijava/scijava-ops-python/src/main/java/org/scijava/ops/python/PythonOpInfo.java

Lines changed: 88 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@
66
import java.lang.reflect.Modifier;
77
import java.lang.reflect.Type;
88
import java.util.ArrayList;
9+
import java.util.Arrays;
910
import java.util.List;
1011
import java.util.Map;
1112
import java.util.Objects;
13+
import java.util.Optional;
1214

1315
import org.scijava.common3.validity.ValidityException;
1416
import org.scijava.ops.api.Hints;
@@ -29,6 +31,7 @@
2931
import javassist.CtNewConstructor;
3032
import javassist.CtNewMethod;
3133
import javassist.NotFoundException;
34+
import org.scijava.types.inference.InterfaceInference;
3235

3336
public class PythonOpInfo implements OpInfo {
3437

@@ -96,7 +99,7 @@ public String implementationName() {
9699
@Override
97100
public StructInstance<?> createOpInstance(List<?> dependencies) {
98101
try {
99-
return struct().createInstance(javassistOp(source));
102+
return struct().createInstance(javassistOp(source, struct.members()));
100103
}
101104
catch (Throwable ex) {
102105
throw new IllegalStateException("Failed to invoke Op method: " + source +
@@ -131,6 +134,22 @@ public String id() {
131134
return null;
132135
}
133136

137+
private static Type reifyType(String type) throws ClassNotFoundException{
138+
if(type.indexOf('<') == -1) {
139+
return Thread.currentThread().getContextClassLoader().loadClass(type);
140+
}
141+
else {
142+
// TODO: Consider nested types
143+
Type baseType = reifyType(type.substring(0, type.indexOf('<')));
144+
String[] strParams = type.substring(type.indexOf('<') + 1, type.length() - 1).split("\\s*,\\s*");
145+
Type[] typeParams = new Type[strParams.length];
146+
for(int i = 0; i < strParams.length; i++) {
147+
typeParams[i] = reifyType(strParams[i]);
148+
}
149+
return Types.parameterize(Types.raw(baseType), typeParams);
150+
}
151+
}
152+
134153
/**
135154
* TODO: This is SUPER hacky. Yeehaw!
136155
*
@@ -141,7 +160,7 @@ private static List<Member<?>> parseParams(List<Map<String, Object>> params)
141160
List<Member<?>> members = new ArrayList<>();
142161
final ClassLoader cl = Thread.currentThread().getContextClassLoader();
143162
for (Map<String, Object> map : params) {
144-
Class<?> type = cl.loadClass((String) map.get("type"));
163+
Type type = reifyType((String) map.get("type"));
145164
String description = (String) map.getOrDefault("description", "");
146165
List<String> keys = new ArrayList<>(map.keySet());
147166
keys.remove("type");
@@ -186,7 +205,7 @@ public String getDescription() {
186205
return members;
187206
}
188207

189-
private Object javassistOp(String source) throws Throwable {
208+
private Object javassistOp(String source, List<Member<?>> params) throws Throwable {
190209
ClassPool pool = ClassPool.getDefault();
191210

192211
// Create wrapper class
@@ -199,13 +218,16 @@ private Object javassistOp(String source) throws Throwable {
199218
CtClass jasOpType = pool.get(Types.raw(opType).getName());
200219
cc.addInterface(jasOpType);
201220

221+
// Add Interpreter field
222+
cc.addField(createInterpreterField(pool, cc));
223+
202224
// Add constructor
203225
CtConstructor constructor = CtNewConstructor.make(createConstructor(cc), cc);
204226
cc.addConstructor(constructor);
205227

206228
// add functional interface method
207229
CtMethod functionalMethod = CtNewMethod.make(createFunctionalMethod(
208-
source), cc);
230+
source, params), cc);
209231
cc.addMethod(functionalMethod);
210232
c = cc.toClass(MethodHandles.lookup());
211233
}
@@ -223,18 +245,77 @@ private String formClassName(String source) {
223245

224246
// class name -> OwnerName_PythonFunction
225247
List<String> nameElements = List.of(source.split("\\."));
226-
String className = packageName + "." + String.join("_", nameElements);
227-
return className;
248+
return packageName + "." + String.join("_", nameElements);
249+
}
250+
251+
private CtField createInterpreterField(ClassPool pool, CtClass cc) throws NotFoundException,
252+
CannotCompileException
253+
{
254+
String fStr = "jep.Interpreter interp = " +
255+
"org.scijava.ops.python.OpsPythonInterpreter.interpreter();";
256+
CtField f = CtField.make(fStr, cc);
257+
f.setModifiers(Modifier.PRIVATE + Modifier.FINAL);
258+
return f;
228259
}
229260

261+
230262
private String createConstructor(CtClass cc)
231263
{
232264
// constructor signature
233265
return "public " + cc.getSimpleName() + "() {}";
234266
}
235267

236-
private String createFunctionalMethod(String source) {
268+
private String createFunctionalMethod(String source, List<Member<?>> params) {
237269
StringBuilder sb = new StringBuilder();
270+
271+
// determine the name of the functional method
272+
String methodName = InterfaceInference.singularAbstractMethod(Types.raw(
273+
opType)).getName();
274+
275+
// method modifiers
276+
Optional<Member<?>> result = params.stream() //
277+
.filter(m -> m.getIOType() == ItemIO.OUTPUT).findFirst();
278+
sb.append("public ") //
279+
.append(result.isEmpty() ? "void" : "Object") //
280+
.append(" ") //
281+
.append(methodName) //
282+
.append("(");
283+
284+
// method inputs
285+
int applyInputs = inputs().size();
286+
for (int i = 0; i < applyInputs; i++) {
287+
sb.append(" Object in").append(i);
288+
if (i < applyInputs - 1) sb.append(",");
289+
}
290+
sb.append(") { ");
291+
292+
// Set each parameter in the interpreter
293+
for (int i = 0; i < applyInputs; i++) {
294+
sb.append("interp.set(\"in").append(i).append("\", in").append(i).append("); ");
295+
}
296+
// Import command
297+
int funcIdx = source.lastIndexOf('.');
298+
String packageName = source.substring(0, funcIdx);
299+
String funcName = source.substring(funcIdx + 1);
300+
sb.append("interp.exec(\"from ").append(packageName).append(" import ").append(funcName).append("\"); ");
301+
302+
// Execute command
303+
sb.append("interp.exec(\"");
304+
if (result.isPresent()) {
305+
sb.append("out = ");
306+
}
307+
sb.append(funcName).append("(");
308+
for (int i = 0; i < applyInputs; i++) {
309+
sb.append(" in").append(i);
310+
if (i < applyInputs - 1) sb.append(",");
311+
}
312+
sb.append(")\"); ");
313+
// return if needed
314+
if (result.isPresent()) {
315+
sb.append(" return interp.getValue(\"out\");");
316+
}
317+
sb.append("}");
318+
238319
return sb.toString();
239320
}
240321

scijava/scijava-ops-python/src/main/resources/ops.yaml

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,23 @@ ops:
88
name: create.img
99
priority: 0.0
1010
source: numpy.zeros
11-
type: java.util.function.BiFunction
11+
type: java.util.function.Function
1212
parameters:
1313
- input: in1
14-
type: java.lang.Integer
15-
description: The first input
16-
- input: in2
17-
type: java.lang.Integer
18-
description: The second input
14+
type: java.util.List<java.lang.Integer>
15+
description: The shape of the output
1916
- output: out
2017
type: jep.NDArray
2118
description: The image
19+
- op:
20+
name: numpy.linalg.det
21+
priority: 0.0
22+
source: numpy.linalg.det
23+
type: java.util.function.Function
24+
parameters:
25+
- input: in1
26+
type: jep.AbstractNDArray
27+
description: The shape of the output
28+
- output: out
29+
type: java.lang.Float
30+
description: The output

scijava/scijava-ops-python/src/test/java/org/scijava/ops/python/PythonOpsTest.java

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
package org.scijava.ops.python;
22

3+
import java.nio.ByteBuffer;
4+
import java.nio.IntBuffer;
5+
import java.util.Arrays;
6+
import java.util.List;
7+
8+
import jep.DirectNDArray;
9+
import org.junit.jupiter.api.Assertions;
310
import org.junit.jupiter.api.BeforeAll;
411
import org.junit.jupiter.api.Test;
512
import org.scijava.ops.api.OpEnvironment;
@@ -21,8 +28,17 @@ public static void setup() {
2128

2229
@Test
2330
public void opsTest() {
24-
NDArray sum = env.op("create.img").input(2, 3).outType(NDArray.class)
31+
List<Integer> size = Arrays.asList(2, 2);
32+
NDArray sum = env.op("create.img").input(size).outType(NDArray.class)
33+
.apply();
34+
Assertions.assertArrayEquals(sum.getDimensions(), new int[] { 2, 2 });
35+
36+
float[] f = new float[] {2.0f, 1.0f, 1.0f, 2.0f};
37+
NDArray<float[]> input = new NDArray<>(f, 2, 2);
38+
Float output = env.op("numpy.linalg.det").input(input).outType(Float.class)
2539
.apply();
40+
Assertions.assertEquals(3.0, output,1e-6);
41+
2642
}
2743

2844
}

0 commit comments

Comments
 (0)