@@ -26,21 +26,21 @@ public class CnnTextClassification : IExample
2626 public string Name => "CNN Text Classification" ;
2727 public int ? DataLimit = null ;
2828 public bool ImportGraph { get ; set ; } = true ;
29- public bool UseSubset = false ; // <----- set this true to use a limited subset of dbpedia
3029
31- private string dataDir = "text_classification " ;
30+ private string dataDir = "word_cnn " ;
3231 private string dataFileName = "dbpedia_csv.tar.gz" ;
3332
3433 private const string TRAIN_PATH = "text_classification/dbpedia_csv/train.csv" ;
3534 private const string TEST_PATH = "text_classification/dbpedia_csv/test.csv" ;
36-
35+
3736 private const int NUM_CLASS = 14 ;
3837 private const int BATCH_SIZE = 64 ;
3938 private const int NUM_EPOCHS = 10 ;
4039 private const int WORD_MAX_LEN = 100 ;
4140 private const int CHAR_MAX_LEN = 1014 ;
4241
4342 protected float loss_value = 0 ;
43+ int vocabulary_size = 50000 ;
4444
4545 public bool Run ( )
4646 {
@@ -63,10 +63,9 @@ protected virtual bool RunWithImportedGraph(Session sess, Graph graph)
6363 int [ ] [ ] x = null ;
6464 int [ ] y = null ;
6565 int alphabet_size = 0 ;
66- int vocabulary_size = 0 ;
6766
6867 var word_dict = DataHelpers . build_word_dict ( TRAIN_PATH ) ;
69- vocabulary_size = len ( word_dict ) ;
68+ // vocabulary_size = len(word_dict);
7069 ( x , y ) = DataHelpers . build_word_dataset ( TRAIN_PATH , word_dict , WORD_MAX_LEN ) ;
7170
7271 Console . WriteLine ( "\t DONE " ) ;
@@ -142,7 +141,7 @@ protected virtual bool RunWithImportedGraph(Session sess, Graph graph)
142141 if ( valid_accuracy > max_accuracy )
143142 {
144143 max_accuracy = valid_accuracy ;
145- saver . save ( sess , $ "{ dataDir } /word_cnn.ckpt", global_step : step . ToString ( ) ) ;
144+ saver . save ( sess , $ "{ dataDir } /word_cnn.ckpt", global_step : step ) ;
146145 print ( "Model is saved.\n " ) ;
147146 }
148147 }
@@ -218,18 +217,10 @@ private static void FillWithShuffledLabels(int[][] x, int[] y, int[][] shuffled_
218217
219218 public void PrepareData ( )
220219 {
221- if ( UseSubset )
222- {
223- var url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/data/dbpedia_subset.zip" ;
224- Web . Download ( url , dataDir , "dbpedia_subset.zip" ) ;
225- Compress . UnZip ( Path . Combine ( dataDir , "dbpedia_subset.zip" ) , Path . Combine ( dataDir , "dbpedia_csv" ) ) ;
226- }
227- else
228- {
229- string url = "https://github.com/le-scientifique/torchDatasets/raw/master/dbpedia_csv.tar.gz" ;
230- Web . Download ( url , dataDir , dataFileName ) ;
231- Compress . ExtractTGZ ( Path . Join ( dataDir , dataFileName ) , dataDir ) ;
232- }
220+ // full dataset https://github.com/le-scientifique/torchDatasets/raw/master/dbpedia_csv.tar.gz
221+ var url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/data/dbpedia_subset.zip" ;
222+ Web . Download ( url , dataDir , "dbpedia_subset.zip" ) ;
223+ Compress . UnZip ( Path . Combine ( dataDir , "dbpedia_subset.zip" ) , Path . Combine ( dataDir , "dbpedia_csv" ) ) ;
233224
234225 if ( ImportGraph )
235226 {
@@ -242,7 +233,7 @@ public void PrepareData()
242233 Console . WriteLine ( "Discarding cached file: " + meta_path ) ;
243234 File . Delete ( meta_path ) ;
244235 }
245- var url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/" + meta_file ;
236+ url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/" + meta_file ;
246237 Web . Download ( url , "graph" , meta_file ) ;
247238 }
248239 }
0 commit comments