1+ import os
12import pkg_resources
23import unittest
34from collections import Counter
4- from tempfile import NamedTemporaryFile
5+ from tempfile import TemporaryDirectory
56
67import numpy as np
78from scipy .special import expit
1415
1516
1617db = None
17- db_file = None
18+ db_dir = None
1819dictionary = None
1920wiki2vec = None
2021
2122
2223class TestWikipedia2Vec (unittest .TestCase ):
2324 @classmethod
2425 def setUpClass (cls ):
25- global db , db_file , tokenizer , dictionary , wiki2vec
26+ global db , db_dir , tokenizer , dictionary , wiki2vec
2627 dump_file = pkg_resources .resource_filename ("tests" , "test_data/enwiki-pages-articles-sample.xml.bz2" )
2728 dump_reader = WikiDumpReader (dump_file )
28- db_file = NamedTemporaryFile ()
29+ db_dir = TemporaryDirectory ()
30+ db_file = os .path .join (db_dir .name , "test.db" )
2931
30- DumpDB .build (dump_reader , db_file . name , 1 , 1 )
31- db = DumpDB (db_file . name )
32+ DumpDB .build (dump_reader , db_file , 1 , 1 )
33+ db = DumpDB (db_file )
3234
3335 tokenizer = get_tokenizer ("regexp" )
3436 dictionary = Dictionary .build (
@@ -50,7 +52,8 @@ def setUpClass(cls):
5052
5153 @classmethod
5254 def tearDownClass (cls ):
53- db_file .close ()
55+ db .close ()
56+ db_dir .cleanup ()
5457
5558 def test_dictionary_property (self ):
5659 self .assertEqual (wiki2vec .dictionary , dictionary )
@@ -118,9 +121,10 @@ def test_most_similar_by_vector(self):
118121 self .assertEqual (scores , [o .score for o in ret ])
119122
120123 def test_save_load (self ):
121- with NamedTemporaryFile () as f :
122- wiki2vec .save (f .name )
123- wiki2vec2 = Wikipedia2Vec .load (f .name )
124+ with TemporaryDirectory () as dir_name :
125+ file_name = os .path .join (dir_name , "model.pkl" )
126+ wiki2vec .save (file_name )
127+ wiki2vec2 = Wikipedia2Vec .load (file_name , numpy_mmap_mode = None )
124128 self .assertTrue (np .array_equal (wiki2vec .syn0 , wiki2vec2 .syn0 ))
125129 self .assertTrue (np .array_equal (wiki2vec .syn1 , wiki2vec2 .syn1 ))
126130
@@ -134,9 +138,10 @@ def test_save_load(self):
134138
135139 def test_save_load_text (self ):
136140 for out_format in ("word2vec" , "glove" , "default" ):
137- with NamedTemporaryFile () as f :
138- wiki2vec .save_text (f .name , out_format = out_format )
139- with open (f .name ) as f :
141+ with TemporaryDirectory () as dir_name :
142+ file_name = os .path .join (dir_name , "model.txt" )
143+ wiki2vec .save_text (file_name , out_format = out_format )
144+ with open (file_name ) as f :
140145 if out_format == "word2vec" :
141146 first_line = next (f )
142147 self .assertEqual (str (len (dictionary )) + " " + "100" , first_line .rstrip ())
@@ -163,7 +168,7 @@ def test_save_load_text(self):
163168
164169 self .assertEqual (len (dictionary ), num_items )
165170
166- wiki2vec2 = Wikipedia2Vec .load_text (f . name )
171+ wiki2vec2 = Wikipedia2Vec .load_text (file_name )
167172 for word in dictionary .words ():
168173 self .assertTrue (
169174 np .allclose (
0 commit comments