Skip to content

Commit 954713f

Browse files
committed
add Conv2dParams
1 parent 0df7013 commit 954713f

File tree

10 files changed

+102
-43
lines changed

10 files changed

+102
-43
lines changed

src/TensorFlowNET.Core/Keras/Layers/Conv.cs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,9 @@ public Conv(int rank,
5656
protected override void build(TensorShape input_shape)
5757
{
5858
int channel_axis = data_format == "channels_first" ? 1 : -1;
59-
int input_dim = input_shape.Dimensions[input_shape.NDim - 1];
59+
int input_dim = channel_axis < 0 ?
60+
input_shape.Dimensions[input_shape.NDim + channel_axis] :
61+
input_shape.Dimensions[channel_axis];
6062
var kernel_shape = new int[] { kernel_size[0], kernel_size[1], input_dim, filters };
6163
kernel = add_weight(name: "kernel",
6264
shape: kernel_shape,
@@ -102,7 +104,7 @@ protected override Tensor call(Tensor inputs, Tensor training = null)
102104
}
103105
else
104106
{
105-
outputs = nn_ops.bias_add(outputs, bias._AsTensor(), data_format: "NHWC");
107+
outputs = nn_ops.bias_add(outputs, bias, data_format: "NHWC");
106108
}
107109
}
108110

src/TensorFlowNET.Core/Keras/Layers/Layer.cs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -206,12 +206,14 @@ protected virtual void add_update(Tensor[] updates, bool inputs = false)
206206
_updates.AddRange(updates_op);
207207
}
208208

209+
// Determine layer name (non-unique).
209210
protected virtual void _init_set_name(string name, bool zero_based = true)
210211
{
212+
var base_name = name;
213+
_name = name;
211214
if (name == null)
212-
_name = base_layer_utils.unique_layer_name(generic_utils.to_snake_case(this.GetType().Name), zero_based: zero_based);
213-
else
214-
_name = name;
215+
(_name, base_name) = _make_unique_name();
216+
_base_name = base_name;
215217
}
216218

217219
protected virtual (string, string) _make_unique_name()

src/TensorFlowNET.Core/Layers/Layer.cs

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -67,13 +67,6 @@ public Tensor __call__(Tensor inputs,
6767
return outputs;
6868
}
6969

70-
protected override void _init_set_name(string name, bool zero_based = true)
71-
{
72-
// Determine layer name (non-unique).
73-
base._init_set_name(name, zero_based: zero_based);
74-
_base_name = this.name;
75-
}
76-
7770
protected virtual void _add_elements_to_collection(Operation[] elements, string[] collection_list)
7871
{
7972
foreach(var name in collection_list)
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Operations
6+
{
7+
public class Conv2dParams
8+
{
9+
public string Name { get; set; }
10+
11+
/// <summary>
12+
/// An optional `string` from: `"NHWC", "NCHW"`. Defaults to `"NHWC"`.
13+
/// Specify the data format of the input and output data. With the
14+
/// default format "NHWC", the data is stored in the order of:
15+
/// [batch, height, width, channels].
16+
/// </summary>
17+
public string DataFormat { get; set; } = "NHWC";
18+
19+
/// <summary>
20+
/// Must be one of the following types: `half`, `bfloat16`, `float32`, `float64`.
21+
/// A 4-D tensor. The dimension order is interpreted according to the value
22+
/// </summary>
23+
public Tensor Input { get; set; }
24+
25+
/// <summary>
26+
/// A 4-D tensor of shape
27+
/// </summary>
28+
public Tensor Filter { get; set; }
29+
30+
/// <summary>
31+
/// The stride of the sliding window for each
32+
/// dimension of `input`. The dimension order is determined by the value of
33+
/// `data_format`, see below for details.
34+
/// </summary>
35+
public int[] Strides { get; set; }
36+
37+
/// <summary>
38+
/// A `string` from: `"SAME", "VALID", "EXPLICIT"`.
39+
/// </summary>
40+
public string Padding { get; set; }
41+
42+
public int[] ExplicitPaddings { get; set; } = new int[0];
43+
44+
public bool UseCudnnOnGpu { get; set; } = true;
45+
46+
public int[] Dilations { get; set; } = new [] { 1, 1, 1, 1 };
47+
48+
public Conv2dParams()
49+
{
50+
51+
}
52+
}
53+
}

src/TensorFlowNET.Core/Operations/NnOps/_NonAtrousConvolution.cs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ public class _NonAtrousConvolution
1111
public string name;
1212
public int[] strides;
1313
public string data_format;
14-
private Func<object, Tensor> conv_op;
14+
private Func<Conv2dParams, Tensor> conv_op;
1515

1616
public _NonAtrousConvolution(TensorShape input_shape,
1717
TensorShape filter_shape,
@@ -55,14 +55,14 @@ public _NonAtrousConvolution(TensorShape input_shape,
5555

5656
public Tensor __call__(Tensor inp, RefVariable filter)
5757
{
58-
return conv_op(new
58+
return conv_op(new Conv2dParams
5959
{
60-
input = inp,
61-
filter,
62-
strides,
63-
padding,
64-
data_format,
65-
name
60+
Input = inp,
61+
Filter = filter,
62+
Strides = strides,
63+
Padding = padding,
64+
DataFormat = data_format,
65+
Name = name
6666
});
6767
}
6868
}

src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -10,28 +10,36 @@ public class gen_nn_ops
1010
{
1111
public static OpDefLibrary _op_def_lib = new OpDefLibrary();
1212

13-
public static Tensor conv2d(object parameters)
13+
/// <summary>
14+
/// Computes a 2-D convolution given 4-D `input` and `filter` tensors.
15+
///
16+
/// Given an input tensor of shape `[batch, in_height, in_width, in_channels]`
17+
/// and a filter / kernel tensor of shape
18+
/// `[filter_height, filter_width, in_channels, out_channels]`, this op
19+
/// performs the following:
20+
///
21+
/// 1. Flattens the filter to a 2-D matrix with shape
22+
/// `[filter_height * filter_width * in_channels, output_channels]`.
23+
/// 2. Extracts image patches from the input tensor to form a *virtual*
24+
/// tensor of shape `[batch, out_height, out_width,
25+
/// filter_height * filter_width * in_channels]`.
26+
/// 3. For each patch, right-multiplies the filter matrix and the image patch
27+
/// vector.
28+
/// </summary>
29+
/// <param name="parameters"></param>
30+
/// <returns></returns>
31+
public static Tensor conv2d(Conv2dParams parameters)
1432
{
15-
var args = Python.ConvertToDict(parameters);
16-
17-
var input = args["input"];
18-
var filter = args["filter"];
19-
var strides = args["strides"];
20-
var padding = args["padding"];
21-
var name = args["name"];
22-
var data_format = args.ContainsKey("data_format") ? args["data_format"] : "NHWC";
23-
var use_cudnn_on_gpu = args.ContainsKey("use_cudnn_on_gpu") ? args["use_cudnn_on_gpu"] : true;
24-
var dilations = args.ContainsKey("dilations") ? args["dilations"] : new int[] { 1, 1, 1, 1 };
25-
26-
var _op = _op_def_lib._apply_op_helper("Conv2D", name: name?.ToString(), args: new
33+
var _op = _op_def_lib._apply_op_helper("Conv2D", name: parameters.Name, args: new
2734
{
28-
input,
29-
filter,
30-
strides,
31-
padding,
32-
use_cudnn_on_gpu,
33-
data_format,
34-
dilations
35+
input = parameters.Input,
36+
filter = parameters.Filter,
37+
strides = parameters.Strides,
38+
padding = parameters.Padding,
39+
use_cudnn_on_gpu = parameters.UseCudnnOnGpu,
40+
explicit_paddings = parameters.ExplicitPaddings,
41+
data_format = parameters.DataFormat,
42+
dilations = parameters.Dilations
3543
});
3644

3745
return _op.outputs[0];

src/TensorFlowNET.Core/Operations/nn_ops.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ public static Tensor bias_add(Tensor value,
3737
{
3838
return Python.with(ops.name_scope(name, "BiasAdd", new { value, bias }), scope =>
3939
{
40+
name = scope;
4041
value = ops.convert_to_tensor(value, name: "input");
4142
var bias_tensor = ops.convert_to_tensor(bias, dtype: value.dtype, name: "bias");
4243
return gen_nn_ops.bias_add(value, bias_tensor, data_format: data_format, name: name);

src/TensorFlowNET.Core/Variables/RefVariable.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ private void _init_from_args(object initial_value,
188188

189189
public Tensor _as_graph_element() => _variable;
190190

191-
public Tensor _TensorConversionFunction(bool as_ref = false)
191+
public Tensor _TensorConversionFunction(TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false)
192192
{
193193
if (as_ref)
194194
return _ref();

src/TensorFlowNET.Core/Variables/VariableScope.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ public RefVariable get_variable(_VariableStore var_store,
4040
VariableSynchronization synchronization = VariableSynchronization.Auto,
4141
VariableAggregation aggregation= VariableAggregation.None)
4242
{
43-
string full_name = !string.IsNullOrEmpty(this._name) ? this._name + "/" + name : name;
43+
string full_name = !string.IsNullOrEmpty(this.name) ? this.name + "/" + name : name;
4444
return with(ops.name_scope(null), scope =>
4545
{
4646
if (dtype == TF_DataType.DtInvalid)

src/TensorFlowNET.Core/ops.py.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -473,7 +473,7 @@ public static Tensor internal_convert_to_tensor(object value, TF_DataType dtype
473473
case Tensor[] tensors:
474474
return array_ops._autopacking_helper(tensors, dtype, name == null ? "packed" : name);
475475
case RefVariable varVal:
476-
return varVal._TensorConversionFunction(as_ref: as_ref);
476+
return varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref);
477477
case ResourceVariable varVal:
478478
return null;
479479
case object[] objects:

0 commit comments

Comments
 (0)