66import java .lang .reflect .Modifier ;
77import java .lang .reflect .Type ;
88import java .util .ArrayList ;
9+ import java .util .Arrays ;
910import java .util .List ;
1011import java .util .Map ;
1112import java .util .Objects ;
13+ import java .util .Optional ;
1214
1315import org .scijava .common3 .validity .ValidityException ;
1416import org .scijava .ops .api .Hints ;
2931import javassist .CtNewConstructor ;
3032import javassist .CtNewMethod ;
3133import javassist .NotFoundException ;
34+ import org .scijava .types .inference .InterfaceInference ;
3235
3336public class PythonOpInfo implements OpInfo {
3437
@@ -96,7 +99,7 @@ public String implementationName() {
9699 @ Override
97100 public StructInstance <?> createOpInstance (List <?> dependencies ) {
98101 try {
99- return struct ().createInstance (javassistOp (source ));
102+ return struct ().createInstance (javassistOp (source , struct . members () ));
100103 }
101104 catch (Throwable ex ) {
102105 throw new IllegalStateException ("Failed to invoke Op method: " + source +
@@ -131,6 +134,22 @@ public String id() {
131134 return null ;
132135 }
133136
137+ private static Type reifyType (String type ) throws ClassNotFoundException {
138+ if (type .indexOf ('<' ) == -1 ) {
139+ return Thread .currentThread ().getContextClassLoader ().loadClass (type );
140+ }
141+ else {
142+ // TODO: Consider nested types
143+ Type baseType = reifyType (type .substring (0 , type .indexOf ('<' )));
144+ String [] strParams = type .substring (type .indexOf ('<' ) + 1 , type .length () - 1 ).split ("\\ s*,\\ s*" );
145+ Type [] typeParams = new Type [strParams .length ];
146+ for (int i = 0 ; i < strParams .length ; i ++) {
147+ typeParams [i ] = reifyType (strParams [i ]);
148+ }
149+ return Types .parameterize (Types .raw (baseType ), typeParams );
150+ }
151+ }
152+
134153 /**
135154 * TODO: This is SUPER hacky. Yeehaw!
136155 *
@@ -141,7 +160,7 @@ private static List<Member<?>> parseParams(List<Map<String, Object>> params)
141160 List <Member <?>> members = new ArrayList <>();
142161 final ClassLoader cl = Thread .currentThread ().getContextClassLoader ();
143162 for (Map <String , Object > map : params ) {
144- Class <?> type = cl . loadClass ((String ) map .get ("type" ));
163+ Type type = reifyType ((String ) map .get ("type" ));
145164 String description = (String ) map .getOrDefault ("description" , "" );
146165 List <String > keys = new ArrayList <>(map .keySet ());
147166 keys .remove ("type" );
@@ -186,7 +205,7 @@ public String getDescription() {
186205 return members ;
187206 }
188207
189- private Object javassistOp (String source ) throws Throwable {
208+ private Object javassistOp (String source , List < Member <?>> params ) throws Throwable {
190209 ClassPool pool = ClassPool .getDefault ();
191210
192211 // Create wrapper class
@@ -199,13 +218,16 @@ private Object javassistOp(String source) throws Throwable {
199218 CtClass jasOpType = pool .get (Types .raw (opType ).getName ());
200219 cc .addInterface (jasOpType );
201220
221+ // Add Interpreter field
222+ cc .addField (createInterpreterField (pool , cc ));
223+
202224 // Add constructor
203225 CtConstructor constructor = CtNewConstructor .make (createConstructor (cc ), cc );
204226 cc .addConstructor (constructor );
205227
206228 // add functional interface method
207229 CtMethod functionalMethod = CtNewMethod .make (createFunctionalMethod (
208- source ), cc );
230+ source , params ), cc );
209231 cc .addMethod (functionalMethod );
210232 c = cc .toClass (MethodHandles .lookup ());
211233 }
@@ -223,18 +245,77 @@ private String formClassName(String source) {
223245
224246 // class name -> OwnerName_PythonFunction
225247 List <String > nameElements = List .of (source .split ("\\ ." ));
226- String className = packageName + "." + String .join ("_" , nameElements );
227- return className ;
248+ return packageName + "." + String .join ("_" , nameElements );
249+ }
250+
251+ private CtField createInterpreterField (ClassPool pool , CtClass cc ) throws NotFoundException ,
252+ CannotCompileException
253+ {
254+ String fStr = "jep.Interpreter interp = " +
255+ "org.scijava.ops.python.OpsPythonInterpreter.interpreter();" ;
256+ CtField f = CtField .make (fStr , cc );
257+ f .setModifiers (Modifier .PRIVATE + Modifier .FINAL );
258+ return f ;
228259 }
229260
261+
230262 private String createConstructor (CtClass cc )
231263 {
232264 // constructor signature
233265 return "public " + cc .getSimpleName () + "() {}" ;
234266 }
235267
236- private String createFunctionalMethod (String source ) {
268+ private String createFunctionalMethod (String source , List < Member <?>> params ) {
237269 StringBuilder sb = new StringBuilder ();
270+
271+ // determine the name of the functional method
272+ String methodName = InterfaceInference .singularAbstractMethod (Types .raw (
273+ opType )).getName ();
274+
275+ // method modifiers
276+ Optional <Member <?>> result = params .stream () //
277+ .filter (m -> m .getIOType () == ItemIO .OUTPUT ).findFirst ();
278+ sb .append ("public " ) //
279+ .append (result .isEmpty () ? "void" : "Object" ) //
280+ .append (" " ) //
281+ .append (methodName ) //
282+ .append ("(" );
283+
284+ // method inputs
285+ int applyInputs = inputs ().size ();
286+ for (int i = 0 ; i < applyInputs ; i ++) {
287+ sb .append (" Object in" ).append (i );
288+ if (i < applyInputs - 1 ) sb .append ("," );
289+ }
290+ sb .append (") { " );
291+
292+ // Set each parameter in the interpreter
293+ for (int i = 0 ; i < applyInputs ; i ++) {
294+ sb .append ("interp.set(\" in" ).append (i ).append ("\" , in" ).append (i ).append ("); " );
295+ }
296+ // Import command
297+ int funcIdx = source .lastIndexOf ('.' );
298+ String packageName = source .substring (0 , funcIdx );
299+ String funcName = source .substring (funcIdx + 1 );
300+ sb .append ("interp.exec(\" from " ).append (packageName ).append (" import " ).append (funcName ).append ("\" ); " );
301+
302+ // Execute command
303+ sb .append ("interp.exec(\" " );
304+ if (result .isPresent ()) {
305+ sb .append ("out = " );
306+ }
307+ sb .append (funcName ).append ("(" );
308+ for (int i = 0 ; i < applyInputs ; i ++) {
309+ sb .append (" in" ).append (i );
310+ if (i < applyInputs - 1 ) sb .append ("," );
311+ }
312+ sb .append (")\" ); " );
313+ // return if needed
314+ if (result .isPresent ()) {
315+ sb .append (" return interp.getValue(\" out\" );" );
316+ }
317+ sb .append ("}" );
318+
238319 return sb .toString ();
239320 }
240321
0 commit comments