Skip to content

Commit ec319b1

Browse files
committed
added more test cases (gradient tests)
1 parent c997c72 commit ec319b1

File tree

6 files changed

+2858
-43
lines changed

6 files changed

+2858
-43
lines changed

src/TensorFlowNET.Core/Keras/Layers/Conv.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,8 @@ protected override Tensor call(Tensor inputs, Tensor training = null)
101101
throw new NotImplementedException("call channels_first");
102102
}
103103
else
104-
{
105-
outputs = nn_ops.bias_add(outputs, bias, data_format: "NHWC");
104+
{
105+
outputs = nn_ops.bias_add(outputs, bias._AsTensor(), data_format: "NHWC");
106106
}
107107
}
108108

src/TensorFlowNET.Core/Operations/nn_ops.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ public static Convolution Convolution(TensorShape input_shape,
3030
/// <param name="name"></param>
3131
/// <returns></returns>
3232
public static Tensor bias_add(Tensor value,
33-
RefVariable bias,
33+
Tensor bias,
3434
string data_format = null,
3535
string name = null)
3636
{

test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs

Lines changed: 1 addition & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -38,46 +38,7 @@ public void testCondFalse()
3838
});
3939
}
4040

41-
[Ignore("Todo")]
42-
[TestMethod]
43-
public void testCondMissingArg1()
44-
{
45-
// def testCondMissingArg1(self):
46-
// x = constant_op.constant(1)
47-
// with self.assertRaises(TypeError):
48-
// control_flow_ops.cond(True, false_fn=lambda: x)
49-
50-
}
51-
52-
[Ignore("Todo")]
53-
[TestMethod]
54-
public void testCondMissingArg2()
55-
{
56-
// def testCondMissingArg2(self):
57-
// x = constant_op.constant(1)
58-
// with self.assertRaises(TypeError):
59-
// control_flow_ops.cond(True, lambda: x)
60-
}
61-
62-
[Ignore("Todo")]
63-
[TestMethod]
64-
public void testCondDuplicateArg1()
65-
{
66-
// def testCondDuplicateArg1(self):
67-
// x = constant_op.constant(1)
68-
// with self.assertRaises(TypeError):
69-
// control_flow_ops.cond(True, lambda: x, lambda: x, fn1=lambda: x)
70-
}
71-
72-
[Ignore("Todo")]
73-
[TestMethod]
74-
public void testCondDuplicateArg2()
75-
{
76-
// def testCondDuplicateArg2(self):
77-
// x = constant_op.constant(1)
78-
// with self.assertRaises(TypeError):
79-
// control_flow_ops.cond(True, lambda: x, lambda: x, fn2=lambda: x)
80-
}
41+
// NOTE: all other test python test cases of this class are either not needed due to strong typing or dest a deprecated api
8142

8243
}
8344
}

0 commit comments

Comments
 (0)