Skip to content

Commit 4431632

Browse files
committed
merge with
2 parents 04394f2 + a7fc2e3 commit 4431632

File tree

11 files changed

+2963
-169
lines changed

11 files changed

+2963
-169
lines changed

src/TensorFlowNET.Core/Keras/Layers/Conv.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,8 @@ protected override Tensor call(Tensor inputs, Tensor training = null)
101101
throw new NotImplementedException("call channels_first");
102102
}
103103
else
104-
{
105-
outputs = nn_ops.bias_add(outputs, bias, data_format: "NHWC");
104+
{
105+
outputs = nn_ops.bias_add(outputs, bias._AsTensor(), data_format: "NHWC");
106106
}
107107
}
108108

src/TensorFlowNET.Core/Operations/Operation.cs

Lines changed: 59 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -10,22 +10,22 @@
1010
namespace Tensorflow
1111
{
1212

13-
/// <summary>
14-
/// Represents a graph node that performs computation on tensors.
15-
///
16-
/// An `Operation` is a node in a TensorFlow `Graph` that takes zero or
17-
/// more `Tensor` objects as input, and produces zero or more `Tensor`
18-
/// objects as output. Objects of type `Operation` are created by
19-
/// calling an op constructor(such as `tf.matmul`)
20-
/// or `tf.Graph.create_op`.
21-
///
22-
/// For example `c = tf.matmul(a, b)` creates an `Operation` of type
23-
/// "MatMul" that takes tensors `a` and `b` as input, and produces `c`
24-
/// as output.
25-
///
26-
/// After the graph has been launched in a session, an `Operation` can
27-
/// be executed by passing it to
28-
/// `tf.Session.run`.
13+
/// <summary>
14+
/// Represents a graph node that performs computation on tensors.
15+
///
16+
/// An `Operation` is a node in a TensorFlow `Graph` that takes zero or
17+
/// more `Tensor` objects as input, and produces zero or more `Tensor`
18+
/// objects as output. Objects of type `Operation` are created by
19+
/// calling an op constructor(such as `tf.matmul`)
20+
/// or `tf.Graph.create_op`.
21+
///
22+
/// For example `c = tf.matmul(a, b)` creates an `Operation` of type
23+
/// "MatMul" that takes tensors `a` and `b` as input, and produces `c`
24+
/// as output.
25+
///
26+
/// After the graph has been launched in a session, an `Operation` can
27+
/// be executed by passing it to
28+
/// `tf.Session.run`.
2929
/// `op.run()` is a shortcut for calling `tf.get_default_session().run(op)`.
3030
/// </summary>
3131
public partial class Operation : ITensorOrOperation
@@ -271,47 +271,49 @@ public override bool Equals(object obj)
271271
return base.Equals(obj);
272272
}
273273

274-
/// <summary>
275-
/// Update the input to this operation at the given index.
276-
///
277-
/// NOTE: This is for TF internal use only.Please don't use it.
278-
/// </summary>
279-
/// <param name="index">the index of the input to update.</param>
280-
/// <param name="tensor"> the Tensor to be used as the input at the given index.</param>
281-
public void _update_input(int index, Tensor tensor)
282-
{
283-
_assert_same_graph(tensor);
284-
285-
var input = _tf_input(index);
274+
/// <summary>
275+
/// Update the input to this operation at the given index.
276+
///
277+
/// NOTE: This is for TF internal use only.Please don't use it.
278+
/// </summary>
279+
/// <param name="index">the index of the input to update.</param>
280+
/// <param name="tensor"> the Tensor to be used as the input at the given index.</param>
281+
public void _update_input(int index, Tensor tensor)
282+
{
283+
_assert_same_graph(tensor);
284+
285+
var input = _tf_input(index);
286286
var output = tensor._as_tf_output();
287287

288288
// Reset cached inputs.
289-
_inputs = null;// new InputList(new Tensor[] { tensor }); // is this right? original code: self._inputs_val=None
290-
// TODO: implement below code dependencies
291-
c_api.TF_UpdateEdge(graph, output, input, status);
292-
}
293-
294-
private void _assert_same_graph(Tensor tensor)
295-
{
296-
//TODO: implement
297-
}
298-
299-
/// <summary>
300-
/// Create and return a new TF_Output for output_idx'th output of this op.
301-
/// </summary>
302-
public TF_Output _tf_output(int output_idx)
303-
{
304-
var tf_output = new TF_Output(op, output_idx);
305-
return tf_output;
306-
}
307-
308-
/// <summary>
309-
/// Create and return a new TF_Input for input_idx'th input of this op.
310-
/// </summary>
311-
public TF_Input _tf_input(int input_idx)
312-
{
313-
var tf_input = new TF_Input(op, input_idx);
314-
return tf_input;
315-
}
316-
}
317-
}
289+
_inputs = null;
290+
// after the c_api call next time _inputs is accessed
291+
// the updated inputs are reloaded from the c_api
292+
c_api.TF_UpdateEdge(_graph, output, input, status);
293+
//var updated_inputs = inputs;
294+
}
295+
296+
private void _assert_same_graph(Tensor tensor)
297+
{
298+
//TODO: implement
299+
}
300+
301+
/// <summary>
302+
/// Create and return a new TF_Output for output_idx'th output of this op.
303+
/// </summary>
304+
public TF_Output _tf_output(int output_idx)
305+
{
306+
var tf_output = new TF_Output(op, output_idx);
307+
return tf_output;
308+
}
309+
310+
/// <summary>
311+
/// Create and return a new TF_Input for input_idx'th input of this op.
312+
/// </summary>
313+
public TF_Input _tf_input(int input_idx)
314+
{
315+
var tf_input = new TF_Input(op, input_idx);
316+
return tf_input;
317+
}
318+
}
319+
}

src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -290,33 +290,11 @@ public static Tensor cond(Tensor pred,
290290
{
291291
// TODO: here a chunk of original code is missing
292292
/*
293-
if fn1 is not None:
294-
if true_fn is not None:
295-
raise TypeError("cond(): true_fn and fn1 may not be set simultaneously.")
296-
true_fn = fn1
297-
elif true_fn is None:
298-
raise TypeError("cond(): true_fn argument required")
299-
if fn2 is not None:
300-
if false_fn is not None:
301-
raise TypeError("cond(): false_fn and fn2 may not be set simultaneously.")
302-
false_fn = fn2
303-
elif false_fn is None:
304-
raise TypeError("cond(): false_fn argument required")
305-
306-
if not callable(true_fn):
307-
raise TypeError("true_fn must be callable.")
308-
if not callable(false_fn):
309-
raise TypeError("false_fn must be callable.")
310-
311293
with ops.name_scope(name, "cond", [pred]):
312294
if context.executing_eagerly():
313295
if pred:
314296
return _UnpackIfSingleton(true_fn())
315297
return _UnpackIfSingleton(false_fn())
316-
317-
# Add the Switch to the graph.
318-
if isinstance(pred, bool):
319-
raise TypeError("pred must not be a Python bool")
320298
*/
321299

322300
// Add the Switch to the graph.

src/TensorFlowNET.Core/Operations/nn_ops.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ public static Convolution Convolution(TensorShape input_shape,
3030
/// <param name="name"></param>
3131
/// <returns></returns>
3232
public static Tensor bias_add(Tensor value,
33-
RefVariable bias,
33+
Tensor bias,
3434
string data_format = null,
3535
string name = null)
3636
{

src/TensorFlowNET.Core/Sessions/Session.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,12 @@ public Session(Graph g, SessionOptions opts = null, Status s = null)
3838
Status.Check(true);
3939
}
4040

41+
public Session as_default()
42+
{
43+
tf.defaultSession = this;
44+
return this;
45+
}
46+
4147
public static Session LoadFromSavedModel(string path)
4248
{
4349
var graph = c_api.TF_NewGraph();

test/TensorFlowNET.UnitTest/PythonTest.cs

Lines changed: 33 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -132,24 +132,44 @@ protected object _eval_tensor(object tensor)
132132
}
133133

134134
/// <summary>
135-
/// Evaluates tensors and returns numpy values.
136-
/// <param name="tensors">A Tensor or a nested list/tuple of Tensors.</param>
135+
/// This function is used in many original tensorflow unit tests to evaluate tensors
136+
/// in a test session with special settings (for instance constant folding off)
137+
///
137138
/// </summary>
138-
/// <returns> tensors numpy values.</returns>
139-
[Obsolete("Why do we need this function? we already have Tensor.eval().")]
140-
public object evaluate(params Tensor[] tensors)
139+
public T evaluate<T>(Tensor tensor)
141140
{
141+
var results = new Dictionary<string, NDArray>();
142142
// if context.executing_eagerly():
143143
// return self._eval_helper(tensors)
144144
// else:
145145
{
146146
var sess = ops.get_default_session();
147-
if (sess == None)
148-
with(self.session(), s => sess = s);
149-
return sess.run(tensors);
147+
if (sess == null)
148+
sess = self.session();
149+
T t_result = (T)(object)null;
150+
with<Session>(sess, s =>
151+
{
152+
var ndarray=tensor.eval();
153+
if (typeof(T) == typeof(double))
154+
{
155+
double d = ndarray;
156+
t_result = (T)(object)d;
157+
}
158+
else if (typeof(T) == typeof(int))
159+
{
160+
int d = ndarray;
161+
t_result = (T) (object) d;
162+
}
163+
else
164+
{
165+
t_result = (T)(object)ndarray;
166+
}
167+
});
168+
return t_result;
150169
}
151170
}
152171

172+
153173
//Returns a TensorFlow Session for use in executing tests.
154174
public Session session(Graph graph = null, object config = null, bool use_gpu = false, bool force_gpu = false)
155175
{
@@ -189,16 +209,11 @@ public Session session(Graph graph = null, object config = null, bool use_gpu =
189209
//if (context.executing_eagerly())
190210
// yield None
191211
//else
192-
{
193-
with<Session>(self._create_session(graph, config, force_gpu), sess =>
194-
{
195-
with(self._constrain_devices_and_set_default(sess, use_gpu, force_gpu), (x) =>
196-
{
197-
s = sess;
198-
});
199-
});
200-
}
201-
return s;
212+
//{
213+
s = self._create_session(graph, config, force_gpu);
214+
self._constrain_devices_and_set_default(s, use_gpu, force_gpu);
215+
//}
216+
return s.as_default();
202217
}
203218

204219
private IPython _constrain_devices_and_set_default(Session sess, bool useGpu, bool forceGpu)

test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs

Lines changed: 1 addition & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -91,74 +91,7 @@ with tf.Session() as sess:
9191
});
9292
}
9393

94-
[Ignore("Todo")]
95-
[TestMethod]
96-
public void testCondTrueLegacy()
97-
{
98-
// def testCondTrueLegacy(self):
99-
// x = constant_op.constant(2)
100-
// y = constant_op.constant(5)
101-
// z = control_flow_ops.cond(
102-
// math_ops.less(x, y),
103-
// fn1=lambda: math_ops.multiply(x, 17),
104-
// fn2=lambda: math_ops.add(y, 23))
105-
// self.assertEquals(self.evaluate(z), 34)
106-
}
107-
108-
[Ignore("Todo")]
109-
[TestMethod]
110-
public void testCondFalseLegacy()
111-
{
112-
// def testCondFalseLegacy(self):
113-
// x = constant_op.constant(2)
114-
// y = constant_op.constant(1)
115-
// z = control_flow_ops.cond(
116-
// math_ops.less(x, y),
117-
// fn1=lambda: math_ops.multiply(x, 17),
118-
// fn2=lambda: math_ops.add(y, 23))
119-
// self.assertEquals(self.evaluate(z), 24)
120-
}
121-
122-
[Ignore("Todo")]
123-
[TestMethod]
124-
public void testCondMissingArg1()
125-
{
126-
// def testCondMissingArg1(self):
127-
// x = constant_op.constant(1)
128-
// with self.assertRaises(TypeError):
129-
// control_flow_ops.cond(True, false_fn=lambda: x)
130-
131-
}
132-
133-
[Ignore("Todo")]
134-
[TestMethod]
135-
public void testCondMissingArg2()
136-
{
137-
// def testCondMissingArg2(self):
138-
// x = constant_op.constant(1)
139-
// with self.assertRaises(TypeError):
140-
// control_flow_ops.cond(True, lambda: x)
141-
}
142-
143-
[Ignore("Todo")]
144-
[TestMethod]
145-
public void testCondDuplicateArg1()
146-
{
147-
// def testCondDuplicateArg1(self):
148-
// x = constant_op.constant(1)
149-
// with self.assertRaises(TypeError):
150-
// control_flow_ops.cond(True, lambda: x, lambda: x, fn1=lambda: x)
151-
}
152-
153-
[Ignore("Todo")]
154-
[TestMethod]
155-
public void testCondDuplicateArg2()
156-
{
157-
// def testCondDuplicateArg2(self):
158-
// x = constant_op.constant(1)
159-
// with self.assertRaises(TypeError):
160-
// control_flow_ops.cond(True, lambda: x, lambda: x, fn2=lambda: x)
161-
}
94+
// NOTE: all other test python test cases of this class are either not needed due to strong typing or dest a deprecated api
16295

16396
}
16497
}

0 commit comments

Comments
 (0)