Skip to content

Commit c4db450

Browse files
committed
add tests
1 parent e37f681 commit c4db450

File tree

1 file changed

+27
-1
lines changed

1 file changed

+27
-1
lines changed

tests/test_wikipedia2vec.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import pkg_resources
22
import unittest
33
from tempfile import NamedTemporaryFile
4-
from unittest import mock
54

65
import numpy as np
6+
from scipy.special import expit
77

88
from wikipedia2vec.dictionary import Dictionary, Word, Entity
99
from wikipedia2vec.dump_db import DumpDB
@@ -179,6 +179,32 @@ def test_save_load_text(self):
179179
)
180180
self.assertEqual(len(dictionary), len(wiki2vec2.dictionary))
181181

182+
def test_build_sampling_table(self):
183+
table = wiki2vec._build_word_sampling_table(0.01)
184+
self.assertIsInstance(table, np.ndarray)
185+
self.assertEqual(np.uint32, table.dtype)
186+
187+
total_count = sum(word.count for word in dictionary.words())
188+
threshold = 0.01 * total_count
189+
uint_max = np.iinfo(np.uint32).max
190+
for word in dictionary.words():
191+
if word.count > total_count * 0.01:
192+
self.assertAlmostEqual(
193+
min(1.0, (np.sqrt(word.count / threshold) + 1) * (threshold / word.count)) * uint_max,
194+
table[word.index],
195+
delta=1,
196+
)
197+
else:
198+
self.assertEqual(uint_max, table[word.index])
199+
200+
def test_build_exp_table(self):
201+
max_exp = 6
202+
table_size = 1000
203+
exp_table = wiki2vec._build_exp_table(max_exp, table_size)
204+
for value in range(-max_exp, max_exp):
205+
index = int((value + max_exp) * (table_size / max_exp / 2))
206+
self.assertAlmostEqual(expit(value), exp_table[index], delta=1e-2)
207+
182208

183209
if __name__ == "__main__":
184210
unittest.main()

0 commit comments

Comments
 (0)