Skip to content

Commit 86c4622

Browse files
committed
refactor wikipedia2vec.py and add tests
1 parent 4552866 commit 86c4622

File tree

5 files changed: 939 additions and 726 deletions.

tests/test_wikipedia2vec.py

Lines changed: 184 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,184 @@
1+
import pkg_resources
2+
import unittest
3+
from tempfile import NamedTemporaryFile
4+
from unittest import mock
5+
6+
import numpy as np
7+
8+
from wikipedia2vec.dictionary import Dictionary, Word, Entity
9+
from wikipedia2vec.dump_db import DumpDB
10+
from wikipedia2vec.utils.tokenizer import get_tokenizer
11+
from wikipedia2vec.utils.wiki_dump_reader import WikiDumpReader
12+
from wikipedia2vec.wikipedia2vec import Wikipedia2Vec, ItemWithScore
13+
14+
15+
# Shared fixtures for TestWikipedia2Vec. They start out empty and are bound by
# TestWikipedia2Vec.setUpClass through its ``global`` statement; the test
# methods then read them as plain module globals.
db = None
db_file = None
tokenizer = None  # named in setUpClass's ``global`` statement, so given a placeholder like the others
dictionary = None
wiki2vec = None
19+
20+
21+
class TestWikipedia2Vec(unittest.TestCase):
    """Tests for ``Wikipedia2Vec`` built on top of the bundled sample dump.

    ``setUpClass`` builds the fixtures once and publishes them through the
    module-level globals ``db``, ``db_file``, ``tokenizer``, ``dictionary`` and
    ``wiki2vec``; every test method reads those globals directly.
    """

    @classmethod
    def setUpClass(cls):
        """Build the dump DB, the dictionary and a randomly initialised model."""
        global db, db_file, tokenizer, dictionary, wiki2vec
        dump_file = pkg_resources.resource_filename("tests", "test_data/enwiki-pages-articles-sample.xml.bz2")
        dump_reader = WikiDumpReader(dump_file)
        # Kept open for the lifetime of the class; closed in tearDownClass.
        db_file = NamedTemporaryFile()

        # NOTE(review): the two trailing 1s are presumably pool/chunk sizes - confirm against DumpDB.build.
        DumpDB.build(dump_reader, db_file.name, 1, 1)
        db = DumpDB(db_file.name)

        tokenizer = get_tokenizer("regexp")
        dictionary = Dictionary.build(
            db,
            tokenizer=tokenizer,
            lowercase=True,
            min_word_count=2,
            min_entity_count=1,
            min_paragraph_len=5,
            category=True,
            disambi=True,
            pool_size=1,
            chunk_size=1,
            progressbar=False,
        )
        wiki2vec = Wikipedia2Vec(dictionary)
        # No training is run: both matrices are filled with random float32
        # values, one 100-dimensional row per dictionary item.
        wiki2vec.syn0 = np.random.rand(len(dictionary), 100).astype(np.float32)
        wiki2vec.syn1 = np.random.rand(len(dictionary), 100).astype(np.float32)

    @classmethod
    def tearDownClass(cls):
        """Close the temporary file that backs the dump DB."""
        db_file.close()

    @staticmethod
    def _expected_ranking(vector, count):
        """Return ``(indexes, scores)`` of the ``count`` rows of ``syn0`` most cosine-similar to ``vector``."""
        all_scores = np.dot(wiki2vec.syn0, vector) / np.linalg.norm(wiki2vec.syn0, axis=1) / np.linalg.norm(vector)
        indexes = np.argsort(-all_scores)[:count].tolist()
        scores = [float(all_scores[index]) for index in indexes]
        return indexes, scores

    def _assert_ranking(self, ret, indexes, scores):
        """Check that ``ret`` holds ``ItemWithScore`` entries matching the expected indexes and scores."""
        for entry in ret:
            self.assertIsInstance(entry, ItemWithScore)
        self.assertEqual(indexes, [o.item.index for o in ret])
        self.assertEqual(scores, [o.score for o in ret])

    def test_dictionary_property(self):
        """The model exposes the dictionary it was constructed with."""
        self.assertEqual(wiki2vec.dictionary, dictionary)

    def test_get_word(self):
        """A known word is returned as a ``Word``."""
        word = wiki2vec.get_word("the")
        self.assertIsInstance(word, Word)

    def test_get_word_not_exist(self):
        """An unknown word yields ``None``."""
        self.assertIsNone(wiki2vec.get_word("foobar"))

    def test_get_entity(self):
        """A known entity title is returned as an ``Entity``."""
        entity = wiki2vec.get_entity("Computer system")
        self.assertIsInstance(entity, Entity)

    def test_get_entity_not_exist(self):
        """An unknown entity title yields ``None``."""
        self.assertIsNone(wiki2vec.get_entity("Foo"))

    def test_get_word_vector(self):
        """A word vector is the ``syn0`` row at the word's dictionary index."""
        vector = wiki2vec.get_word_vector("the")
        self.assertEqual((100,), vector.shape)
        self.assertTrue(np.array_equal(vector, wiki2vec.syn0[dictionary.get_word("the").index]))

    def test_get_word_vector_not_exist(self):
        """Requesting the vector of an unknown word raises ``KeyError``."""
        with self.assertRaises(KeyError):
            wiki2vec.get_word_vector("foobar")

    def test_get_entity_vector(self):
        """An entity vector is the ``syn0`` row at the entity's dictionary index."""
        vector = wiki2vec.get_entity_vector("Computer system")
        self.assertEqual((100,), vector.shape)
        self.assertTrue(np.array_equal(wiki2vec.syn0[dictionary.get_entity("Computer system").index], vector))

    def test_get_entity_vector_not_exist(self):
        """Requesting the vector of an unknown entity raises ``KeyError``."""
        with self.assertRaises(KeyError):
            wiki2vec.get_entity_vector("Foo")

    def test_get_vector(self):
        """``get_vector`` accepts a dictionary item and returns its ``syn0`` row."""
        word = dictionary.get_word("the")
        vector = wiki2vec.get_vector(word)
        self.assertEqual((100,), vector.shape)
        self.assertTrue(np.array_equal(vector, wiki2vec.syn0[word.index]))

    def test_most_similar(self):
        """``most_similar`` ranks items by cosine similarity to the given item."""
        word = dictionary.get_word("the")
        indexes, scores = self._expected_ranking(wiki2vec.syn0[word.index], 10)

        ret = wiki2vec.most_similar(word, 10)
        self._assert_ranking(ret, indexes, scores)

    def test_most_similar_by_vector(self):
        """``most_similar_by_vector`` ranks items by cosine similarity to a raw vector."""
        word = dictionary.get_word("the")
        vector = wiki2vec.syn0[word.index]
        indexes, scores = self._expected_ranking(vector, 10)

        ret = wiki2vec.most_similar_by_vector(vector, 10)
        self._assert_ranking(ret, indexes, scores)

    def test_save_load(self):
        """A model written by ``save`` and read by ``load`` keeps its matrices and dictionary."""
        with NamedTemporaryFile() as tmp:
            wiki2vec.save(tmp.name)
            wiki2vec2 = Wikipedia2Vec.load(tmp.name)
            self.assertTrue(np.array_equal(wiki2vec.syn0, wiki2vec2.syn0))
            self.assertTrue(np.array_equal(wiki2vec.syn1, wiki2vec2.syn1))

            serialized_dictionary = dictionary.serialize()
            serialized_dictionary2 = wiki2vec2.dictionary.serialize()
            for key, value in serialized_dictionary.items():
                # ndarray equality is element-wise, so arrays need array_equal.
                if isinstance(value, np.ndarray):
                    self.assertTrue(np.array_equal(value, serialized_dictionary2[key]))
                else:
                    self.assertEqual(value, serialized_dictionary2[key])

    def test_save_load_text(self):
        """Each text format written by ``save_text`` matches the model and reloads via ``load_text``."""
        for out_format in ("word2vec", "glove", "default"):
            # subTest reports each format separately instead of stopping at the first failure.
            with self.subTest(out_format=out_format), NamedTemporaryFile() as tmp:
                wiki2vec.save_text(tmp.name, out_format=out_format)
                with open(tmp.name) as text_file:
                    if out_format == "word2vec":
                        # Only this format starts with a "<item count> <dimension>" header line.
                        first_line = next(text_file)
                        self.assertEqual(str(len(dictionary)) + " " + "100", first_line.rstrip())

                    num_items = 0
                    for line in text_file:
                        if out_format in ("word2vec", "glove"):
                            # Space-separated line; underscores in the name are read back as spaces.
                            name, *vec_str = line.rstrip().split(" ")
                            name = name.replace("_", " ")
                        else:
                            # Default format: the name and the vector are separated by a tab.
                            name, vec_str = line.rstrip().split("\t")
                            vec_str = vec_str.split(" ")

                        vector = np.array([float(s) for s in vec_str], dtype=np.float32)

                        if name.startswith("ENTITY/"):
                            name = name[len("ENTITY/"):]
                            orig_vector = wiki2vec.get_entity_vector(name)
                        else:
                            orig_vector = wiki2vec.get_word_vector(name)
                        # The text form is compared with a tolerance rather than exactly.
                        self.assertTrue(np.allclose(orig_vector, vector, atol=1e-3))

                        num_items += 1

                    self.assertEqual(len(dictionary), num_items)

                wiki2vec2 = Wikipedia2Vec.load_text(tmp.name)
                for word in dictionary.words():
                    self.assertTrue(
                        np.allclose(
                            wiki2vec.get_word_vector(word.text), wiki2vec2.get_word_vector(word.text), atol=1e-3
                        )
                    )
                for entity in dictionary.entities():
                    self.assertTrue(
                        np.allclose(
                            wiki2vec.get_entity_vector(entity.title),
                            wiki2vec2.get_entity_vector(entity.title),
                            atol=1e-3,
                        )
                    )
                self.assertEqual(len(dictionary), len(wiki2vec2.dictionary))
181+
182+
183+
# Run the tests through unittest when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()

wikipedia2vec/utils/random.pxd

Lines changed: 0 additions & 19 deletions
This file was deleted.

wikipedia2vec/utils/random.pyx

Lines changed: 0 additions & 19 deletions
This file was deleted.

0 commit comments

Comments
 (0)