Skip to content

Commit c0fd135

Browse files
committed
word_cnn save training step works.
1 parent 25fb8cb commit c0fd135

File tree

3 files changed

+21
-33
lines changed

3 files changed

+21
-33
lines changed

graph/word_cnn.meta

0 Bytes
Binary file not shown.

src/TensorFlowNET.Core/Train/Saving/Saver.cs

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using System.IO;
44
using System.Linq;
55
using System.Text;
6+
using static Tensorflow.Python;
67

78
namespace Tensorflow
89
{
@@ -144,26 +145,20 @@ private void _check_saver_def()
144145

145146
public string save(Session sess,
146147
string save_path,
147-
string global_step = "",
148+
int global_step = -1,
148149
string latest_filename = "",
149150
string meta_graph_suffix = "meta",
150151
bool write_meta_graph = true,
151152
bool write_state = true,
152-
bool strip_default_attrs = false)
153+
bool strip_default_attrs = false,
154+
bool save_debug_info = false)
153155
{
154156
if (string.IsNullOrEmpty(latest_filename))
155157
latest_filename = "checkpoint";
156158
string model_checkpoint_path = "";
157159
string checkpoint_file = "";
158160

159-
if (!string.IsNullOrEmpty(global_step))
160-
{
161-
162-
}
163-
else
164-
{
165-
checkpoint_file = save_path;
166-
}
161+
checkpoint_file = $"{save_path}-{global_step}";
167162

168163
var save_path_parent = Path.GetDirectoryName(save_path);
169164

@@ -189,6 +184,7 @@ public string save(Session sess,
189184
if (write_meta_graph)
190185
{
191186
string meta_graph_filename = checkpoint_management.meta_graph_filename(checkpoint_file, meta_graph_suffix: meta_graph_suffix);
187+
export_meta_graph(meta_graph_filename, strip_default_attrs: strip_default_attrs, save_debug_info: save_debug_info);
192188
}
193189

194190
return _is_empty ? string.Empty : model_checkpoint_path;
@@ -244,10 +240,11 @@ public void restore(Session sess, string save_path)
244240
public MetaGraphDef export_meta_graph(string filename= "",
245241
string[] collection_list = null,
246242
string export_scope = "",
247-
bool as_text= false,
248-
bool clear_devices= false,
249-
bool clear_extraneous_savers= false,
250-
bool strip_default_attrs= false)
243+
bool as_text = false,
244+
bool clear_devices = false,
245+
bool clear_extraneous_savers = false,
246+
bool strip_default_attrs = false,
247+
bool save_debug_info = false)
251248
{
252249
return export_meta_graph(
253250
filename: filename,

test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -26,21 +26,21 @@ public class CnnTextClassification : IExample
2626
public string Name => "CNN Text Classification";
2727
public int? DataLimit = null;
2828
public bool ImportGraph { get; set; } = true;
29-
public bool UseSubset = false; // <----- set this true to use a limited subset of dbpedia
3029

31-
private string dataDir = "text_classification";
30+
private string dataDir = "word_cnn";
3231
private string dataFileName = "dbpedia_csv.tar.gz";
3332

3433
private const string TRAIN_PATH = "text_classification/dbpedia_csv/train.csv";
3534
private const string TEST_PATH = "text_classification/dbpedia_csv/test.csv";
36-
35+
3736
private const int NUM_CLASS = 14;
3837
private const int BATCH_SIZE = 64;
3938
private const int NUM_EPOCHS = 10;
4039
private const int WORD_MAX_LEN = 100;
4140
private const int CHAR_MAX_LEN = 1014;
4241

4342
protected float loss_value = 0;
43+
int vocabulary_size = 50000;
4444

4545
public bool Run()
4646
{
@@ -63,10 +63,9 @@ protected virtual bool RunWithImportedGraph(Session sess, Graph graph)
6363
int[][] x = null;
6464
int[] y = null;
6565
int alphabet_size = 0;
66-
int vocabulary_size = 0;
6766

6867
var word_dict = DataHelpers.build_word_dict(TRAIN_PATH);
69-
vocabulary_size = len(word_dict);
68+
// vocabulary_size = len(word_dict);
7069
(x, y) = DataHelpers.build_word_dataset(TRAIN_PATH, word_dict, WORD_MAX_LEN);
7170

7271
Console.WriteLine("\tDONE ");
@@ -142,7 +141,7 @@ protected virtual bool RunWithImportedGraph(Session sess, Graph graph)
142141
if (valid_accuracy > max_accuracy)
143142
{
144143
max_accuracy = valid_accuracy;
145-
saver.save(sess, $"{dataDir}/word_cnn.ckpt", global_step: step.ToString());
144+
saver.save(sess, $"{dataDir}/word_cnn.ckpt", global_step: step);
146145
print("Model is saved.\n");
147146
}
148147
}
@@ -218,18 +217,10 @@ private static void FillWithShuffledLabels(int[][] x, int[] y, int[][] shuffled_
218217

219218
public void PrepareData()
220219
{
221-
if (UseSubset)
222-
{
223-
var url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/data/dbpedia_subset.zip";
224-
Web.Download(url, dataDir, "dbpedia_subset.zip");
225-
Compress.UnZip(Path.Combine(dataDir, "dbpedia_subset.zip"), Path.Combine(dataDir, "dbpedia_csv"));
226-
}
227-
else
228-
{
229-
string url = "https://github.com/le-scientifique/torchDatasets/raw/master/dbpedia_csv.tar.gz";
230-
Web.Download(url, dataDir, dataFileName);
231-
Compress.ExtractTGZ(Path.Join(dataDir, dataFileName), dataDir);
232-
}
220+
// full dataset https://github.com/le-scientifique/torchDatasets/raw/master/dbpedia_csv.tar.gz
221+
var url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/data/dbpedia_subset.zip";
222+
Web.Download(url, dataDir, "dbpedia_subset.zip");
223+
Compress.UnZip(Path.Combine(dataDir, "dbpedia_subset.zip"), Path.Combine(dataDir, "dbpedia_csv"));
233224

234225
if (ImportGraph)
235226
{
@@ -242,7 +233,7 @@ public void PrepareData()
242233
Console.WriteLine("Discarding cached file: " + meta_path);
243234
File.Delete(meta_path);
244235
}
245-
var url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/" + meta_file;
236+
url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/" + meta_file;
246237
Web.Download(url, "graph", meta_file);
247238
}
248239
}

0 commit comments

Comments
 (0)