Skip to content

Commit ade4ef7

Browse files
committed
nest.map_structure implemented and tested
1 parent 6fc9add commit ade4ef7

File tree

4 files changed

+297
-138
lines changed

4 files changed

+297
-138
lines changed

src/TensorFlowNET.Core/Python.cs

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,55 @@ public static float time()
120120
yield return (t1.Data<T1>(i), t2.Data<T2>(i));
121121
}
122122

123+
public static IEnumerable<(T1, T2)> zip<T1, T2>(IEnumerable<T1> e1, IEnumerable<T2> e2)
124+
{
125+
var iter2 = e2.GetEnumerator();
126+
foreach (var v1 in e1)
127+
{
128+
iter2.MoveNext();
129+
var v2 = iter2.Current;
130+
yield return (v1, v2);
131+
}
132+
}
133+
134+
/// <summary>
135+
/// Untyped implementation of zip for arbitrary data
136+
///
137+
/// Converts an list of lists or arrays [[1,2,3], [4,5,6], [7,8,9]] into a list of arrays
138+
/// representing tuples of the same index of all source arrays [[1,4,7], [2,5,9], [3,6,9]]
139+
/// </summary>
140+
/// <param name="lists">one or multiple sequences to be zipped</param>
141+
/// <returns></returns>
142+
public static IEnumerable<object[]> zip(params object[] lists)
143+
{
144+
if (lists.Length == 0)
145+
yield break;
146+
var first = lists[0];
147+
if (first == null)
148+
yield break;
149+
var arity = (first as IEnumerable).OfType<object>().Count();
150+
for (int i = 0; i < arity; i++)
151+
{
152+
var array= new object[lists.Length];
153+
for (int j = 0; j < lists.Length; j++)
154+
array[j] = GetSequenceElementAt(lists[j], i);
155+
yield return array;
156+
}
157+
}
158+
159+
private static object GetSequenceElementAt(object sequence, int i)
160+
{
161+
switch (sequence)
162+
{
163+
case Array array:
164+
return array.GetValue(i);
165+
case IList list:
166+
return list[i];
167+
default:
168+
return (sequence as IEnumerable).OfType<object>().Skip(Math.Max(0, i)).FirstOrDefault();
169+
}
170+
}
171+
123172
public static IEnumerable<(int, T)> enumerate<T>(IList<T> values)
124173
{
125174
for (int i = 0; i < values.Count; i++)
@@ -137,6 +186,7 @@ public static Dictionary<string, object> ConvertToDict(object dyn)
137186
}
138187
return dictionary;
139188
}
189+
140190
}
141191

142192
public interface IPython : IDisposable

src/TensorFlowNET.Core/Util/nest.py.cs

Lines changed: 77 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
using System.Collections;
33
using System.Collections.Generic;
44
using System.Linq;
5-
using System.Text;
65
using NumSharp;
76

87
namespace Tensorflow.Util
@@ -24,6 +23,14 @@ namespace Tensorflow.Util
2423
public static class nest
2524
{
2625

26+
public static IEnumerable<object[]> zip(params object[] structures)
27+
=> Python.zip(structures);
28+
29+
public static IEnumerable<(T1, T2)> zip<T1, T2>(IEnumerable<T1> e1, IEnumerable<T2> e2)
30+
=> Python.zip(e1, e2);
31+
32+
public static Dictionary<string, object> ConvertToDict(object dyn)
33+
=> Python.ConvertToDict(dyn);
2734

2835
//def _get_attrs_values(obj):
2936
// """Returns the list of values from an attrs instance."""
@@ -75,8 +82,14 @@ private static object _sequence_like(object instance, IEnumerable<object> args)
7582
//# instances. This is intentional, to avoid potential bugs caused by mixing
7683
//# ordered and plain dicts (e.g., flattening a dict but using a
7784
//# corresponding `OrderedDict` to pack it back).
78-
// result = dict(zip(_sorted(instance), args))
79-
// return type(instance)((key, result[key]) for key in _six.iterkeys(instance))
85+
switch (instance)
86+
{
87+
case Hashtable hash:
88+
var result = new Hashtable();
89+
foreach ((object key, object value) in zip(_sorted(hash).OfType<object>(), args))
90+
result[key] = value;
91+
return result;
92+
}
8093
}
8194
//else if( _is_namedtuple(instance) || _is_attrs(instance))
8295
// return type(instance)(*args)
@@ -140,7 +153,9 @@ private static IEnumerable<object> _yield_value(object iterable)
140153
}
141154

142155
//# See the swig file (util.i) for documentation.
143-
public static bool is_sequence(object arg) => arg is IEnumerable && !(arg is string);
156+
public static bool is_sequence(object arg)
157+
=> arg is IEnumerable && !(arg is string) && !(arg is NDArray) &&
158+
!(arg.GetType().IsGenericType && arg.GetType().GetGenericTypeDefinition() == typeof(HashSet<>));
144159

145160
public static bool is_mapping(object arg) => arg is IDictionary;
146161

@@ -355,38 +370,54 @@ private static (int new_index, List<object> child) _packed_nest_with_indices(obj
355370
/// <returns> `flat_sequence` converted to have the same recursive structure as
356371
/// `structure`.
357372
/// </returns>
358-
public static object pack_sequence_as(object structure, List<object> flat_sequence)
373+
public static object pack_sequence_as<T>(object structure, IEnumerable<T> flat_sequence)
359374
{
360-
if (flat_sequence == null)
375+
List<object> flat = null;
376+
if (flat_sequence is List<object>)
377+
flat = flat_sequence as List<object>;
378+
else
379+
flat=new List<object>(flat_sequence.OfType<object>());
380+
if (flat_sequence==null)
361381
throw new ArgumentException("flat_sequence must not be null");
362382
// if not is_sequence(flat_sequence):
363383
// raise TypeError("flat_sequence must be a sequence")
364384

365385
if (!is_sequence(structure))
366386
{
367-
if (len(flat_sequence) != 1)
368-
throw new ValueError($"Structure is a scalar but len(flat_sequence) == {len(flat_sequence)} > 1");
369-
return flat_sequence.FirstOrDefault();
387+
if (len(flat) != 1)
388+
throw new ValueError($"Structure is a scalar but len(flat_sequence) == {len(flat)} > 1");
389+
return flat.FirstOrDefault();
370390
}
371391
int final_index = 0;
372392
List<object> packed = null;
373393
try
374394
{
375-
(final_index, packed) = _packed_nest_with_indices(structure, flat_sequence, 0);
376-
if (final_index < len(flat_sequence))
377-
throw new IndexOutOfRangeException($"Final index: { final_index} was smaller than len(flat_sequence): { len(flat_sequence) }");
395+
(final_index, packed) = _packed_nest_with_indices(structure, flat, 0);
396+
if (final_index < len(flat))
397+
throw new IndexOutOfRangeException(
398+
$"Final index: {final_index} was smaller than len(flat_sequence): {len(flat)}");
399+
return _sequence_like(structure, packed);
378400
}
379401
catch (IndexOutOfRangeException)
380402
{
381403
var flat_structure = flatten(structure);
382-
if (len(flat_structure) != len(flat_sequence))
404+
if (len(flat_structure) != len(flat))
383405
{
384406
throw new ValueError("Could not pack sequence. Structure had %d elements, but " +
385-
$"flat_sequence had {len(flat_structure)} elements. flat_sequence had: {len(flat_sequence)}");
407+
$"flat_sequence had {len(flat_structure)} elements. flat_sequence had: {len(flat)}");
408+
}
409+
return _sequence_like(structure, packed);
410+
}
411+
catch (ArgumentOutOfRangeException)
412+
{
413+
var flat_structure = flatten(structure);
414+
if (len(flat_structure) != len(flat))
415+
{
416+
throw new ValueError("Could not pack sequence. Structure had %d elements, but " +
417+
$"flat_sequence had {len(flat_structure)} elements. flat_sequence had: {len(flat)}");
386418
}
387419
return _sequence_like(structure, packed);
388420
}
389-
return packed;
390421
}
391422

392423
/// <summary>
@@ -396,10 +427,9 @@ public static object pack_sequence_as(object structure, List<object> flat_sequen
396427
/// `structure[i]`. All structures in `structure` must have the same arity,
397428
/// and the return value will contain the results in the same structure.
398429
/// </summary>
399-
/// <typeparam name="T"></typeparam>
400-
/// <typeparam name="U"></typeparam>
430+
/// <typeparam name="T">the type of the elements of the output structure (object if diverse)</typeparam>
401431
/// <param name="func"> A callable that accepts as many arguments as there are structures.</param>
402-
/// <param name="structure">scalar, or tuple or list of constructed scalars and/or other
432+
/// <param name="structures">scalar, or tuple or list of constructed scalars and/or other
403433
/// tuples/lists, or scalars. Note: numpy arrays are considered as scalars.</param>
404434
/// <param name="check_types">If set to
405435
/// `True` (default) the types of iterables within the structures have to be
@@ -414,18 +444,41 @@ public static object pack_sequence_as(object structure, List<object> flat_sequen
414444
/// `check_types` is `False` the sequence types of the first structure will be
415445
/// used.
416446
/// </returns>
417-
public static IEnumerable<U> map_structure<T, U>(Func<T, U> func, IEnumerable<T> structure, bool check_types = false)
447+
public static IEnumerable<object> map_structure(Func<object[], object> func, object structure, params object[] more_structures)
418448
{
449+
// TODO: check structure and types
450+
// for other in structure[1:]:
451+
// assert_same_structure(structure[0], other, check_types=check_types)
452+
453+
if (more_structures.Length==0)
454+
{
455+
// we don't need to zip if we have only one structure
456+
return map_structure(a => func(new object[]{a}), structure);
457+
}
458+
var flat_structures = new List<object>() { flatten(structure) };
459+
flat_structures.AddRange(more_structures.Select(flatten));
460+
var entries = zip(flat_structures);
461+
var mapped_flat_structure = entries.Select(func);
419462

463+
return (pack_sequence_as(structure, mapped_flat_structure) as IEnumerable).OfType<object>();
464+
}
465+
466+
/// <summary>
467+
/// Same as map_structure, but with only one structure (no combining of multiple structures)
468+
/// </summary>
469+
/// <param name="func"></param>
470+
/// <param name="structure"></param>
471+
/// <returns></returns>
472+
public static IEnumerable<object> map_structure(Func<object, object> func, object structure)
473+
{
474+
// TODO: check structure and types
420475
// for other in structure[1:]:
421476
// assert_same_structure(structure[0], other, check_types=check_types)
422477

423-
// flat_structure = [flatten(s) for s in structure]
424-
// entries = zip(*flat_structure)
478+
var flat_structure = flatten(structure);
479+
var mapped_flat_structure = flat_structure.Select(func).ToList();
425480

426-
// return pack_sequence_as(
427-
// structure[0], [func(*x) for x in entries])
428-
return null;
481+
return (pack_sequence_as(structure, mapped_flat_structure) as IEnumerable).OfType<object>();
429482
}
430483

431484
//def map_structure_with_paths(func, *structure, **kwargs):

test/TensorFlowNET.UnitTest/PythonTest.cs

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
using System.Linq;
55
using System.Text;
66
using Microsoft.VisualStudio.TestTools.UnitTesting;
7+
using Newtonsoft.Json.Linq;
8+
using NumSharp;
79
using Tensorflow;
810
using Tensorflow.Util;
911

@@ -25,17 +27,45 @@ protected object None {
2527

2628
public void assertItemsEqual(ICollection given, ICollection expected)
2729
{
30+
if (given is Hashtable && expected is Hashtable)
31+
{
32+
Assert.AreEqual(JObject.FromObject(expected).ToString(), JObject.FromObject(given).ToString());
33+
return;
34+
}
2835
Assert.IsNotNull(expected);
2936
Assert.IsNotNull(given);
3037
var e = expected.OfType<object>().ToArray();
3138
var g = given.OfType<object>().ToArray();
3239
Assert.AreEqual(e.Length, g.Length, $"The collections differ in length expected {e.Length} but got {g.Length}");
3340
for (int i = 0; i < e.Length; i++)
34-
Assert.AreEqual(e[i], g[i], $"Items differ at index {i}, expected {e[i]} but got {g[i]}");
41+
{
42+
if (g[i] is NDArray && e[i] is NDArray)
43+
assertItemsEqual((g[i] as NDArray).Array, (e[i] as NDArray).Array);
44+
else if (e[i] is ICollection && g[i] is ICollection)
45+
assertEqual(g[i], e[i]);
46+
else
47+
Assert.AreEqual(e[i], g[i], $"Items differ at index {i}, expected {e[i]} but got {g[i]}");
48+
}
3549
}
3650

51+
public void assertAllEqual(ICollection given, ICollection expected)
52+
{
53+
assertItemsEqual(given, expected);
54+
}
55+
56+
3757
public void assertEqual(object given, object expected)
3858
{
59+
if (given is NDArray && expected is NDArray)
60+
{
61+
assertItemsEqual((given as NDArray).Array, (expected as NDArray).Array);
62+
return;
63+
}
64+
if (given is Hashtable && expected is Hashtable)
65+
{
66+
Assert.AreEqual(JObject.FromObject(expected).ToString(), JObject.FromObject(given).ToString());
67+
return;
68+
}
3969
if (given is ICollection && expected is ICollection)
4070
{
4171
assertItemsEqual(given as ICollection, expected as ICollection);
@@ -54,6 +84,16 @@ public void assertIsNotNone(object given)
5484
Assert.IsNotNull(given);
5585
}
5686

87+
public void assertFalse(bool cond)
88+
{
89+
Assert.IsFalse(cond);
90+
}
91+
92+
public void assertTrue(bool cond)
93+
{
94+
Assert.IsTrue(cond);
95+
}
96+
5797
#endregion
5898

5999
#region tensor evaluation

0 commit comments

Comments
 (0)