22using System . Collections ;
33using System . Collections . Generic ;
44using System . Linq ;
5- using System . Text ;
65using NumSharp ;
76
87namespace Tensorflow . Util
@@ -24,6 +23,14 @@ namespace Tensorflow.Util
2423 public static class nest
2524 {
2625
26+ public static IEnumerable < object [ ] > zip ( params object [ ] structures )
27+ => Python . zip ( structures ) ;
28+
29+ public static IEnumerable < ( T1 , T2 ) > zip < T1 , T2 > ( IEnumerable < T1 > e1 , IEnumerable < T2 > e2 )
30+ => Python . zip ( e1 , e2 ) ;
31+
32+ public static Dictionary < string , object > ConvertToDict ( object dyn )
33+ => Python . ConvertToDict ( dyn ) ;
2734
2835 //def _get_attrs_values(obj):
2936 // """Returns the list of values from an attrs instance."""
@@ -75,8 +82,14 @@ private static object _sequence_like(object instance, IEnumerable<object> args)
7582 //# instances. This is intentional, to avoid potential bugs caused by mixing
7683 //# ordered and plain dicts (e.g., flattening a dict but using a
7784 //# corresponding `OrderedDict` to pack it back).
78- // result = dict(zip(_sorted(instance), args))
79- // return type(instance)((key, result[key]) for key in _six.iterkeys(instance))
85+ switch ( instance )
86+ {
87+ case Hashtable hash :
88+ var result = new Hashtable ( ) ;
89+ foreach ( ( object key , object value ) in zip ( _sorted ( hash ) . OfType < object > ( ) , args ) )
90+ result [ key ] = value ;
91+ return result ;
92+ }
8093 }
8194 //else if( _is_namedtuple(instance) || _is_attrs(instance))
8295 // return type(instance)(*args)
@@ -140,7 +153,9 @@ private static IEnumerable<object> _yield_value(object iterable)
140153 }
141154
142155 //# See the swig file (util.i) for documentation.
143- public static bool is_sequence ( object arg ) => arg is IEnumerable && ! ( arg is string ) ;
156+ public static bool is_sequence ( object arg )
157+ => arg is IEnumerable && ! ( arg is string ) && ! ( arg is NDArray ) &&
158+ ! ( arg . GetType ( ) . IsGenericType && arg . GetType ( ) . GetGenericTypeDefinition ( ) == typeof ( HashSet < > ) ) ;
144159
145160 public static bool is_mapping ( object arg ) => arg is IDictionary ;
146161
@@ -355,38 +370,54 @@ private static (int new_index, List<object> child) _packed_nest_with_indices(obj
355370 /// <returns> `flat_sequence` converted to have the same recursive structure as
356371 /// `structure`.
357372 /// </returns>
358- public static object pack_sequence_as ( object structure , List < object > flat_sequence )
373+ public static object pack_sequence_as < T > ( object structure , IEnumerable < T > flat_sequence )
359374 {
360- if ( flat_sequence == null )
375+ List < object > flat = null ;
376+ if ( flat_sequence is List < object > )
377+ flat = flat_sequence as List < object > ;
378+ else
379+ flat = new List < object > ( flat_sequence . OfType < object > ( ) ) ;
380+ if ( flat_sequence == null )
361381 throw new ArgumentException ( "flat_sequence must not be null" ) ;
362382 // if not is_sequence(flat_sequence):
363383 // raise TypeError("flat_sequence must be a sequence")
364384
365385 if ( ! is_sequence ( structure ) )
366386 {
367- if ( len ( flat_sequence ) != 1 )
368- throw new ValueError ( $ "Structure is a scalar but len(flat_sequence) == { len ( flat_sequence ) } > 1") ;
369- return flat_sequence . FirstOrDefault ( ) ;
387+ if ( len ( flat ) != 1 )
388+ throw new ValueError ( $ "Structure is a scalar but len(flat_sequence) == { len ( flat ) } > 1") ;
389+ return flat . FirstOrDefault ( ) ;
370390 }
371391 int final_index = 0 ;
372392 List < object > packed = null ;
373393 try
374394 {
375- ( final_index , packed ) = _packed_nest_with_indices ( structure , flat_sequence , 0 ) ;
376- if ( final_index < len ( flat_sequence ) )
377- throw new IndexOutOfRangeException ( $ "Final index: { final_index } was smaller than len(flat_sequence): { len ( flat_sequence ) } ") ;
395+ ( final_index , packed ) = _packed_nest_with_indices ( structure , flat , 0 ) ;
396+ if ( final_index < len ( flat ) )
397+ throw new IndexOutOfRangeException (
398+ $ "Final index: { final_index } was smaller than len(flat_sequence): { len ( flat ) } ") ;
399+ return _sequence_like ( structure , packed ) ;
378400 }
379401 catch ( IndexOutOfRangeException )
380402 {
381403 var flat_structure = flatten ( structure ) ;
382- if ( len ( flat_structure ) != len ( flat_sequence ) )
404+ if ( len ( flat_structure ) != len ( flat ) )
383405 {
384406 throw new ValueError ( "Could not pack sequence. Structure had %d elements, but " +
385- $ "flat_sequence had { len ( flat_structure ) } elements. flat_sequence had: { len ( flat_sequence ) } ") ;
407+ $ "flat_sequence had { len ( flat_structure ) } elements. flat_sequence had: { len ( flat ) } ") ;
408+ }
409+ return _sequence_like ( structure , packed ) ;
410+ }
411+ catch ( ArgumentOutOfRangeException )
412+ {
413+ var flat_structure = flatten ( structure ) ;
414+ if ( len ( flat_structure ) != len ( flat ) )
415+ {
416+ throw new ValueError ( "Could not pack sequence. Structure had %d elements, but " +
417+ $ "flat_sequence had { len ( flat_structure ) } elements. flat_sequence had: { len ( flat ) } ") ;
386418 }
387419 return _sequence_like ( structure , packed ) ;
388420 }
389- return packed ;
390421 }
391422
392423 /// <summary>
@@ -396,10 +427,9 @@ public static object pack_sequence_as(object structure, List<object> flat_sequen
396427 /// `structure[i]`. All structures in `structure` must have the same arity,
397428 /// and the return value will contain the results in the same structure.
398429 /// </summary>
399- /// <typeparam name="T"></typeparam>
400- /// <typeparam name="U"></typeparam>
430+ /// <typeparam name="T">the type of the elements of the output structure (object if diverse)</typeparam>
401431 /// <param name="func"> A callable that accepts as many arguments as there are structures.</param>
402- /// <param name="structure ">scalar, or tuple or list of constructed scalars and/or other
432+ /// <param name="structures ">scalar, or tuple or list of constructed scalars and/or other
403433 /// tuples/lists, or scalars. Note: numpy arrays are considered as scalars.</param>
404434 /// <param name="check_types">If set to
405435 /// `True` (default) the types of iterables within the structures have to be
@@ -414,18 +444,41 @@ public static object pack_sequence_as(object structure, List<object> flat_sequen
414444 /// `check_types` is `False` the sequence types of the first structure will be
415445 /// used.
416446 /// </returns>
417- public static IEnumerable < U > map_structure < T , U > ( Func < T , U > func , IEnumerable < T > structure , bool check_types = false )
447+ public static IEnumerable < object > map_structure ( Func < object [ ] , object > func , object structure , params object [ ] more_structures )
418448 {
449+ // TODO: check structure and types
450+ // for other in structure[1:]:
451+ // assert_same_structure(structure[0], other, check_types=check_types)
452+
453+ if ( more_structures . Length == 0 )
454+ {
455+ // we don't need to zip if we have only one structure
456+ return map_structure ( a => func ( new object [ ] { a } ) , structure ) ;
457+ }
458+ var flat_structures = new List < object > ( ) { flatten ( structure ) } ;
459+ flat_structures . AddRange ( more_structures . Select ( flatten ) ) ;
460+ var entries = zip ( flat_structures ) ;
461+ var mapped_flat_structure = entries . Select ( func ) ;
419462
463+ return ( pack_sequence_as ( structure , mapped_flat_structure ) as IEnumerable ) . OfType < object > ( ) ;
464+ }
465+
466+ /// <summary>
467+ /// Same as map_structure, but with only one structure (no combining of multiple structures)
468+ /// </summary>
469+ /// <param name="func"></param>
470+ /// <param name="structure"></param>
471+ /// <returns></returns>
472+ public static IEnumerable < object > map_structure ( Func < object , object > func , object structure )
473+ {
474+ // TODO: check structure and types
420475 // for other in structure[1:]:
421476 // assert_same_structure(structure[0], other, check_types=check_types)
422477
423- // flat_structure = [ flatten(s) for s in structure]
424- // entries = zip(* flat_structure)
478+ var flat_structure = flatten ( structure ) ;
479+ var mapped_flat_structure = flat_structure . Select ( func ) . ToList ( ) ;
425480
426- // return pack_sequence_as(
427- // structure[0], [func(*x) for x in entries])
428- return null ;
481+ return ( pack_sequence_as ( structure , mapped_flat_structure ) as IEnumerable ) . OfType < object > ( ) ;
429482 }
430483
431484 //def map_structure_with_paths(func, *structure, **kwargs):
0 commit comments