Skip to content

Commit 349ee3d

Browse files
committed
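fix wikipedia2vec.py: use `import` (not `cimport`) for cython.cimports modules, and read syn0/syn1 back from the shared array objects after training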
1 parent 81761b0 commit 349ee3d

File tree

1 file changed

+15
-18
lines changed

1 file changed

+15
-18
lines changed

wikipedia2vec/wikipedia2vec.py

Lines changed: 15 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -17,22 +17,20 @@
1717
import numpy as np
1818
from marisa_trie import RecordTrie, Trie
1919
from tqdm import tqdm
20-
21-
from cython.cimports.libcpp.random cimport mt19937
22-
from cython.cimports.scipy.linalg.cython_blas cimport saxpy, sdot
20+
from cython.cimports.libcpp.random import mt19937
21+
from cython.cimports.scipy.linalg.cython_blas import saxpy, sdot
2322

2423
from .dictionary import Dictionary
2524
from .dump_db import DumpDB
2625
from .link_graph import LinkGraph
2726
from .mention_db import MentionDB
2827
from .utils.sentence_detector.base_sentence_detector import BaseSentenceDetector
2928
from .utils.tokenizer.base_tokenizer import BaseTokenizer
30-
31-
from cython.cimports.wikipedia2vec.dictionary cimport Item, Word
32-
from cython.cimports.wikipedia2vec.dump_db cimport Paragraph, WikiLink
33-
from cython.cimports.wikipedia2vec.mention_db cimport Mention
34-
from cython.cimports.wikipedia2vec.utils.sentence_detector.sentence cimport Sentence
35-
from cython.cimports.wikipedia2vec.utils.tokenizer.token cimport Token
29+
from cython.cimports.wikipedia2vec.dictionary import Item, Word
30+
from cython.cimports.wikipedia2vec.dump_db import Paragraph, WikiLink
31+
from cython.cimports.wikipedia2vec.mention_db import Mention
32+
from cython.cimports.wikipedia2vec.utils.sentence_detector.sentence import Sentence
33+
from cython.cimports.wikipedia2vec.utils.tokenizer.token import Token
3634

3735
MAX_EXP = cython.declare(cython.float, 6.0)
3836
EXP_TABLE_SIZE = cython.declare(cython.int, 1000)
@@ -270,8 +268,10 @@ def train(
270268

271269
vocab_size = len(self.dictionary)
272270

273-
syn0_arr = (np.random.rand(vocab_size, dim_size).astype(np.float32) - 0.5) / dim_size
274-
syn1_arr = np.zeros((vocab_size, dim_size), dtype=np.float32)
271+
syn0_obj = _convert_np_array_to_shared_array_object(
272+
(np.random.rand(vocab_size, dim_size).astype(np.float32) - 0.5) / dim_size
273+
)
274+
syn1_obj = _convert_np_array_to_shared_array_object(np.zeros((vocab_size, dim_size), dtype=np.float32))
275275

276276
init_args = (
277277
dump_db,
@@ -280,8 +280,8 @@ def train(
280280
mention_db.serialize() if mention_db is not None else None,
281281
tokenizer,
282282
sentence_detector,
283-
_convert_np_array_to_shared_array_object(syn0_arr),
284-
_convert_np_array_to_shared_array_object(syn1_arr),
283+
syn0_obj,
284+
syn1_obj,
285285
_convert_np_array_to_shared_array_object(word_neg_table),
286286
_convert_np_array_to_shared_array_object(entity_neg_table),
287287
_convert_np_array_to_shared_array_object(exp_table),
@@ -314,11 +314,8 @@ def args_generator(titles: List[str], iteration: int):
314314

315315
logger.info("Terminating pool workers...")
316316

317-
syn0 = np.frombuffer(syn0_arr, dtype=np.float32).reshape((vocab_size, dim_size))
318-
syn1 = np.frombuffer(syn1_arr, dtype=np.float32).reshape((vocab_size, dim_size))
319-
320-
self.syn0 = syn0
321-
self.syn1 = syn1
317+
self.syn0 = _convert_shared_array_object_to_np_array(syn0_obj)
318+
self.syn1 = _convert_shared_array_object_to_np_array(syn1_obj)
322319

323320
train_params = dict(
324321
dump_db=dump_db.uuid,

0 commit comments

Comments
 (0)