Skip to content

Commit a4f03c2

Browse files
committed
added sparse_softmax_cross_entropy_with_logits
1 parent 5981d2d commit a4f03c2

File tree

2 files changed

+32
-0
lines changed

2 files changed

+32
-0
lines changed

src/TensorFlowNET.Core/APIs/tf.nn.cs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,17 @@ public static Tensor bias_add(Tensor value, RefVariable bias, string data_format
9090
/// <summary>
/// Computes softmax activations for <paramref name="logits"/>.
/// NOTE(review): the <paramref name="axis"/> parameter is accepted but never
/// forwarded — gen_nn_ops.softmax receives only (logits, name), so it
/// presumably always operates on the last dimension. Callers passing
/// axis != -1 would get silently unadjusted results; confirm whether
/// transpose handling should be added or axis removed.
/// </summary>
/// <param name="logits">Input tensor of per-class scores.</param>
/// <param name="axis">Intended softmax dimension; currently ignored (see note above).</param>
/// <param name="name">Optional name for the created operation.</param>
/// <returns>A tensor with the same shape as <paramref name="logits"/>.</returns>
public static Tensor softmax(Tensor logits, int axis = -1, string name = null)
    => gen_nn_ops.softmax(logits, name);
9292

/// <summary>
/// Computes sparse softmax cross entropy between `logits` and `labels`.
/// Thin wrapper that forwards all arguments to
/// nn_ops.sparse_softmax_cross_entropy_with_logits.
/// </summary>
/// <param name="labels">
/// Integer class-index tensor. Assumed shape [d_0, ..., d_{r-1}] with values in
/// [0, num_classes), per the TensorFlow contract — TODO confirm against the
/// nn_ops implementation, which is not visible here.
/// </param>
/// <param name="logits">
/// Unscaled log-probability tensor, presumably of shape
/// [d_0, ..., d_{r-1}, num_classes] — verify against nn_ops.
/// </param>
/// <param name="name">Optional name for the created operation.</param>
/// <returns>A tensor of per-example cross-entropy losses (same shape as <paramref name="labels"/>, per TF convention).</returns>
public static Tensor sparse_softmax_cross_entropy_with_logits(Tensor labels = null,
    Tensor logits = null, string name = null)
    => nn_ops.sparse_softmax_cross_entropy_with_logits(labels: labels, logits: logits, name: name);
93104
/// <summary>
/// Computes softmax cross entropy between `labels` and `logits` (v2 API).
/// Delegates to nn_ops.softmax_cross_entropy_with_logits_v2_helper.
/// </summary>
/// <param name="labels">Per-class label distribution tensor.</param>
/// <param name="logits">Unscaled log-probability tensor.</param>
/// <param name="axis">The class dimension; defaults to the last axis.</param>
/// <param name="name">Optional name for the created operation.</param>
/// <returns>The cross-entropy loss tensor produced by the helper.</returns>
public static Tensor softmax_cross_entropy_with_logits_v2(Tensor labels, Tensor logits, int axis = -1, string name = null)
{
    return nn_ops.softmax_cross_entropy_with_logits_v2_helper(labels, logits, axis: axis, name: name);
}
95106
}

test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,27 @@ public Graph BuildGraph()
203203
var h_drop = tf.nn.dropout(h_pool_flat, keep_prob);
204204
});
205205

206+
// Declared outside the scopes so "output" results are visible to the
// "loss" and "accuracy" scopes below.
Tensor logits = null;
Tensor predictions = null;

// Output projection: dense layer over pooled features, argmax for the
// predicted class index.
with(tf.name_scope("output"), delegate
{
    // NOTE(review): `keep_prob` (the dropout keep probability) is passed in
    // the second argument slot of dense(), which presumably expects the
    // number of output units/classes — this looks like a bug. Also the
    // input is `h_pool_flat` rather than `h_drop`, leaving the dropout
    // output computed above unused. Confirm intended arguments.
    logits = tf.layers.dense(h_pool_flat, keep_prob);
    predictions = tf.argmax(logits, -1, output_type: tf.int32);
});

// Loss: sparse softmax cross entropy of logits against integer labels `y`,
// averaged over the batch and minimized with Adam.
with(tf.name_scope("loss"), delegate
{
    var sscel = tf.nn.sparse_softmax_cross_entropy_with_logits(logits: logits, labels: y);
    var loss = tf.reduce_mean(sscel);
    // `optimizer` local is unused afterwards; minimize() registers the
    // training op on the graph as a side effect.
    var optimizer = tf.train.AdamOptimizer(learning_rate).minimize(loss, global_step: global_step);
});

// Accuracy: mean of element-wise prediction/label equality, cast to float.
with(tf.name_scope("accuracy"), delegate
{
    var correct_predictions = tf.equal(predictions, y);
    var accuracy = tf.reduce_mean(tf.cast(correct_predictions, TF_DataType.TF_FLOAT), name: "accuracy");
});
226+
206227
return graph;
207228
}
208229

0 commit comments

Comments
 (0)