|
1 | 1 | import pkg_resources |
2 | 2 | import unittest |
3 | 3 | from tempfile import NamedTemporaryFile |
4 | | -from unittest import mock |
5 | 4 |
|
6 | 5 | import numpy as np |
| 6 | +from scipy.special import expit |
7 | 7 |
|
8 | 8 | from wikipedia2vec.dictionary import Dictionary, Word, Entity |
9 | 9 | from wikipedia2vec.dump_db import DumpDB |
@@ -179,6 +179,32 @@ def test_save_load_text(self): |
179 | 179 | ) |
180 | 180 | self.assertEqual(len(dictionary), len(wiki2vec2.dictionary)) |
181 | 181 |
|
| 182 | + def test_build_sampling_table(self): |
| 183 | + table = wiki2vec._build_word_sampling_table(0.01) |
| 184 | + self.assertIsInstance(table, np.ndarray) |
| 185 | + self.assertEqual(np.uint32, table.dtype) |
| 186 | + |
| 187 | + total_count = sum(word.count for word in dictionary.words()) |
| 188 | + threshold = 0.01 * total_count |
| 189 | + uint_max = np.iinfo(np.uint32).max |
| 190 | + for word in dictionary.words(): |
| 191 | + if word.count > total_count * 0.01: |
| 192 | + self.assertAlmostEqual( |
| 193 | + min(1.0, (np.sqrt(word.count / threshold) + 1) * (threshold / word.count)) * uint_max, |
| 194 | + table[word.index], |
| 195 | + delta=1, |
| 196 | + ) |
| 197 | + else: |
| 198 | + self.assertEqual(uint_max, table[word.index]) |
| 199 | + |
| 200 | + def test_build_exp_table(self): |
| 201 | + max_exp = 6 |
| 202 | + table_size = 1000 |
| 203 | + exp_table = wiki2vec._build_exp_table(max_exp, table_size) |
| 204 | + for value in range(-max_exp, max_exp): |
| 205 | + index = int((value + max_exp) * (table_size / max_exp / 2)) |
| 206 | + self.assertAlmostEqual(expit(value), exp_table[index], delta=1e-2) |
| 207 | + |
182 | 208 |
|
183 | 209 | if __name__ == "__main__": |
184 | 210 | unittest.main() |
0 commit comments