Skip to content

Commit b459f6c

Browse files
committed
fix tests
1 parent 349ee3d commit b459f6c

File tree

11 files changed

+102
-91
lines changed

11 files changed

+102
-91
lines changed

.github/workflows/test.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,12 @@ jobs:
3232
pip install numpy scipy Cython
3333
sh ./cythonize.sh
3434
35+
- name: Install ICU
36+
run: |
37+
sudo apt-get install -y libicu-dev
38+
pip install PyICU
39+
if: matrix.os == 'ubuntu-latest'
40+
3541
- name: Install package
3642
run: |
3743
pip install -e .[mecab]

tests/__init__.py

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +0,0 @@
1-
import pkg_resources
2-
from tempfile import NamedTemporaryFile
3-
4-
from wikipedia2vec.utils.wiki_dump_reader import WikiDumpReader
5-
from wikipedia2vec.dump_db import DumpDB
6-
7-
dump_db = None
8-
dump_db_file = None
9-
10-
11-
def get_dump_db():
12-
return dump_db
13-
14-
15-
def setUp():
16-
global dump_db, dump_db_file
17-
18-
dump_file = pkg_resources.resource_filename("tests", "test_data/enwiki-pages-articles-sample.xml.bz2")
19-
dump_reader = WikiDumpReader(dump_file)
20-
dump_db_file = NamedTemporaryFile()
21-
22-
DumpDB.build(dump_reader, dump_db_file.name, 1, 1)
23-
dump_db = DumpDB(dump_db_file.name)
24-
25-
26-
def tearDown():
27-
dump_db_file.close()

tests/test_dictionary.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
import os
12
import pkg_resources
23
import unittest
3-
from tempfile import NamedTemporaryFile
4+
from tempfile import TemporaryDirectory
45
from unittest import mock
56

67
import numpy as np
@@ -48,20 +49,21 @@ def test_doc_count_property(self):
4849

4950

5051
db = None
51-
db_file = None
52+
db_dir = None
5253
dictionary = None
5354

5455

5556
class TestDictionary(unittest.TestCase):
5657
@classmethod
5758
def setUpClass(cls):
58-
global db, db_file, tokenizer, dictionary
59+
global db, db_dir, tokenizer, dictionary
5960
dump_file = pkg_resources.resource_filename("tests", "test_data/enwiki-pages-articles-sample.xml.bz2")
6061
dump_reader = WikiDumpReader(dump_file)
61-
db_file = NamedTemporaryFile()
62+
db_dir = TemporaryDirectory()
63+
db_file = os.path.join(db_dir.name, "test.db")
6264

63-
DumpDB.build(dump_reader, db_file.name, 1, 1)
64-
db = DumpDB(db_file.name)
65+
DumpDB.build(dump_reader, db_file, 1, 1)
66+
db = DumpDB(db_file)
6567

6668
tokenizer = get_tokenizer("regexp")
6769
dictionary = Dictionary.build(
@@ -80,7 +82,8 @@ def setUpClass(cls):
8082

8183
@classmethod
8284
def tearDownClass(cls):
83-
db_file.close()
85+
db.close()
86+
db_dir.cleanup()
8487

8588
def test_uuid_property(self):
8689
self.assertIsInstance(dictionary.uuid, str)
@@ -264,9 +267,10 @@ def validate(obj):
264267
validate(Dictionary.load(dictionary.serialize()))
265268
validate(Dictionary.load(dictionary.serialize(shared_array=True)))
266269

267-
with NamedTemporaryFile() as f:
268-
dictionary.save(f.name)
269-
validate(Dictionary.load(f.name))
270+
with TemporaryDirectory() as dir_name:
271+
file_name = os.path.join(dir_name, "dictionary.pkl")
272+
dictionary.save(file_name)
273+
validate(Dictionary.load(file_name))
270274

271275

272276
if __name__ == "__main__":

tests/test_dump_db.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
import os
12
import pickle
23
import pkg_resources
34
import unittest
45
import zlib
5-
from tempfile import NamedTemporaryFile
6+
from tempfile import TemporaryDirectory
67

78
from wikipedia2vec import dump_db
89
from wikipedia2vec.dump_db import DumpDB, Paragraph, WikiLink
@@ -50,23 +51,25 @@ def test_span_property(self):
5051

5152

5253
db = None
53-
db_file = None
54+
db_dir = None
5455

5556

5657
class TestDumpDB(unittest.TestCase):
5758
@classmethod
5859
def setUpClass(cls):
59-
global db, db_file
60+
global db, db_dir
6061
dump_file = pkg_resources.resource_filename("tests", "test_data/enwiki-pages-articles-sample.xml.bz2")
6162
dump_reader = WikiDumpReader(dump_file)
62-
db_file = NamedTemporaryFile()
63+
db_dir = TemporaryDirectory()
64+
db_file = os.path.join(db_dir.name, "test.db")
6365

64-
DumpDB.build(dump_reader, db_file.name, 1, 1)
65-
db = DumpDB(db_file.name)
66+
DumpDB.build(dump_reader, db_file, 1, 1)
67+
db = DumpDB(db_file)
6668

6769
@classmethod
6870
def tearDownClass(cls):
69-
db_file.close()
71+
db.close()
72+
db_dir.cleanup()
7073

7174
def test_uuid_property(self):
7275
self.assertIsInstance(db.uuid, str)

tests/test_link_graph.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
import os
12
import pkg_resources
23
import unittest
3-
from tempfile import NamedTemporaryFile
4+
from tempfile import TemporaryDirectory
45

56
import numpy as np
67

@@ -11,22 +12,23 @@
1112
from wikipedia2vec.utils.wiki_dump_reader import WikiDumpReader
1213

1314
db = None
14-
db_file = None
15+
db_dir = None
1516
dictionary = None
1617
link_graph = None
1718

1819

1920
class TestLinkGraph(unittest.TestCase):
2021
@classmethod
2122
def setUpClass(cls):
22-
global db, db_file, dictionary, link_graph
23+
global db, db_dir, dictionary, link_graph
2324

2425
dump_file = pkg_resources.resource_filename("tests", "test_data/enwiki-pages-articles-sample.xml.bz2")
2526
dump_reader = WikiDumpReader(dump_file)
26-
db_file = NamedTemporaryFile()
27+
db_dir = TemporaryDirectory()
28+
db_file = os.path.join(db_dir.name, "test.db")
2729

28-
DumpDB.build(dump_reader, db_file.name, 1, 1)
29-
db = DumpDB(db_file.name)
30+
DumpDB.build(dump_reader, db_file, 1, 1)
31+
db = DumpDB(db_file)
3032

3133
tokenizer = get_default_tokenizer("en")
3234
dictionary = Dictionary.build(
@@ -46,7 +48,8 @@ def setUpClass(cls):
4648

4749
@classmethod
4850
def tearDownClass(cls):
49-
db_file.close()
51+
db.close()
52+
db_dir.cleanup()
5053

5154
def test_uuid_property(self):
5255
self.assertIsInstance(link_graph.uuid, str)
@@ -89,9 +92,10 @@ def validate(obj):
8992
validate(LinkGraph.load(link_graph.serialize(), dictionary))
9093
validate(LinkGraph.load(link_graph.serialize(shared_array=True), dictionary))
9194

92-
with NamedTemporaryFile() as f:
93-
link_graph.save(f.name)
94-
validate(LinkGraph.load(f.name, dictionary))
95+
with TemporaryDirectory() as dir_name:
96+
file_name = os.path.join(dir_name, "link_graph.pkl")
97+
link_graph.save(file_name)
98+
validate(LinkGraph.load(file_name, dictionary))
9599

96100

97101
if __name__ == "__main__":

tests/test_mention_db.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
import os
12
import pkg_resources
23
import unittest
3-
from tempfile import NamedTemporaryFile
4+
from tempfile import TemporaryDirectory
45

56
from wikipedia2vec.dictionary import Dictionary
67
from wikipedia2vec.dump_db import DumpDB
@@ -9,20 +10,21 @@
910
from wikipedia2vec.utils.wiki_dump_reader import WikiDumpReader
1011

1112
db = None
12-
db_file = None
13+
db_dir = None
1314
tokenizer = None
1415
dictionary = None
1516

1617

1718
def setUpModule():
18-
global db, db_file, tokenizer, dictionary
19+
global db, db_dir, tokenizer, dictionary
1920

2021
dump_file = pkg_resources.resource_filename("tests", "test_data/enwiki-pages-articles-sample.xml.bz2")
2122
dump_reader = WikiDumpReader(dump_file)
22-
db_file = NamedTemporaryFile()
23+
db_dir = TemporaryDirectory()
24+
db_file = os.path.join(db_dir.name, "test.db")
2325

24-
DumpDB.build(dump_reader, db_file.name, 1, 1)
25-
db = DumpDB(db_file.name)
26+
DumpDB.build(dump_reader, db_file, 1, 1)
27+
db = DumpDB(db_file)
2628

2729
tokenizer = get_default_tokenizer("en")
2830
dictionary = Dictionary.build(
@@ -41,7 +43,8 @@ def setUpModule():
4143

4244

4345
def tearDownModule():
44-
db_file.close()
46+
db.close()
47+
db_dir.cleanup()
4548

4649

4750
class TestMention(unittest.TestCase):
@@ -147,9 +150,10 @@ def validate(obj):
147150

148151
validate(MentionDB.load(mention_db.serialize(), dictionary))
149152

150-
with NamedTemporaryFile() as f:
151-
mention_db.save(f.name)
152-
validate(MentionDB.load(f.name, dictionary))
153+
with TemporaryDirectory() as dir_name:
154+
file_name = os.path.join(dir_name, "mention.pkl")
155+
mention_db.save(file_name)
156+
validate(MentionDB.load(file_name, dictionary))
153157

154158

155159
if __name__ == "__main__":

tests/test_wikipedia2vec.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1+
import os
12
import pkg_resources
23
import unittest
34
from collections import Counter
4-
from tempfile import NamedTemporaryFile
5+
from tempfile import TemporaryDirectory
56

67
import numpy as np
78
from scipy.special import expit
@@ -14,21 +15,22 @@
1415

1516

1617
db = None
17-
db_file = None
18+
db_dir = None
1819
dictionary = None
1920
wiki2vec = None
2021

2122

2223
class TestWikipedia2Vec(unittest.TestCase):
2324
@classmethod
2425
def setUpClass(cls):
25-
global db, db_file, tokenizer, dictionary, wiki2vec
26+
global db, db_dir, tokenizer, dictionary, wiki2vec
2627
dump_file = pkg_resources.resource_filename("tests", "test_data/enwiki-pages-articles-sample.xml.bz2")
2728
dump_reader = WikiDumpReader(dump_file)
28-
db_file = NamedTemporaryFile()
29+
db_dir = TemporaryDirectory()
30+
db_file = os.path.join(db_dir.name, "test.db")
2931

30-
DumpDB.build(dump_reader, db_file.name, 1, 1)
31-
db = DumpDB(db_file.name)
32+
DumpDB.build(dump_reader, db_file, 1, 1)
33+
db = DumpDB(db_file)
3234

3335
tokenizer = get_tokenizer("regexp")
3436
dictionary = Dictionary.build(
@@ -50,7 +52,8 @@ def setUpClass(cls):
5052

5153
@classmethod
5254
def tearDownClass(cls):
53-
db_file.close()
55+
db.close()
56+
db_dir.cleanup()
5457

5558
def test_dictionary_property(self):
5659
self.assertEqual(wiki2vec.dictionary, dictionary)
@@ -118,9 +121,10 @@ def test_most_similar_by_vector(self):
118121
self.assertEqual(scores, [o.score for o in ret])
119122

120123
def test_save_load(self):
121-
with NamedTemporaryFile() as f:
122-
wiki2vec.save(f.name)
123-
wiki2vec2 = Wikipedia2Vec.load(f.name)
124+
with TemporaryDirectory() as dir_name:
125+
file_name = os.path.join(dir_name, "model.pkl")
126+
wiki2vec.save(file_name)
127+
wiki2vec2 = Wikipedia2Vec.load(file_name, numpy_mmap_mode=None)
124128
self.assertTrue(np.array_equal(wiki2vec.syn0, wiki2vec2.syn0))
125129
self.assertTrue(np.array_equal(wiki2vec.syn1, wiki2vec2.syn1))
126130

@@ -134,9 +138,10 @@ def test_save_load(self):
134138

135139
def test_save_load_text(self):
136140
for out_format in ("word2vec", "glove", "default"):
137-
with NamedTemporaryFile() as f:
138-
wiki2vec.save_text(f.name, out_format=out_format)
139-
with open(f.name) as f:
141+
with TemporaryDirectory() as dir_name:
142+
file_name = os.path.join(dir_name, "model.txt")
143+
wiki2vec.save_text(file_name, out_format=out_format)
144+
with open(file_name) as f:
140145
if out_format == "word2vec":
141146
first_line = next(f)
142147
self.assertEqual(str(len(dictionary)) + " " + "100", first_line.rstrip())
@@ -163,7 +168,7 @@ def test_save_load_text(self):
163168

164169
self.assertEqual(len(dictionary), num_items)
165170

166-
wiki2vec2 = Wikipedia2Vec.load_text(f.name)
171+
wiki2vec2 = Wikipedia2Vec.load_text(file_name)
167172
for word in dictionary.words():
168173
self.assertTrue(
169174
np.allclose(

tests/utils/tokenizer/test_icu_tokenizer.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,30 @@
11
import unittest
22

33
from wikipedia2vec.utils.tokenizer.token import Token
4-
from wikipedia2vec.utils.tokenizer.icu_tokenizer import ICUTokenizer
54

5+
try:
6+
import icu
67

7-
class TestICUTokenizer(unittest.TestCase):
8-
def setUp(self):
9-
self._tokenizer = ICUTokenizer("en")
8+
ICU_INSTALLED = True
9+
except ImportError:
10+
ICU_INSTALLED = False
1011

11-
def test_tokenize(self):
12-
text = "Tokyo is the capital of Japan"
13-
tokens = self._tokenizer.tokenize(text)
12+
if ICU_INSTALLED:
1413

15-
for token in tokens:
16-
self.assertIsInstance(token, Token)
17-
self.assertEqual(["Tokyo", "is", "the", "capital", "of", "Japan"], [t.text for t in tokens])
18-
self.assertEqual([(0, 5), (6, 8), (9, 12), (13, 20), (21, 23), (24, 29)], [t.span for t in tokens])
14+
class TestICUTokenizer(unittest.TestCase):
15+
def setUp(self):
16+
from wikipedia2vec.utils.tokenizer.icu_tokenizer import ICUTokenizer
17+
18+
self._tokenizer = ICUTokenizer("en")
19+
20+
def test_tokenize(self):
21+
text = "Tokyo is the capital of Japan"
22+
tokens = self._tokenizer.tokenize(text)
23+
24+
for token in tokens:
25+
self.assertIsInstance(token, Token)
26+
self.assertEqual(["Tokyo", "is", "the", "capital", "of", "Japan"], [t.text for t in tokens])
27+
self.assertEqual([(0, 5), (6, 8), (9, 12), (13, 20), (21, 23), (24, 29)], [t.span for t in tokens])
1928

2029

2130
if __name__ == "__main__":

0 commit comments

Comments
 (0)