@@ -17,22 +17,20 @@
 import numpy as np
 from marisa_trie import RecordTrie, Trie
 from tqdm import tqdm
-
-from cython.cimports.libcpp.random cimport mt19937
-from cython.cimports.scipy.linalg.cython_blas cimport saxpy, sdot
+from cython.cimports.libcpp.random import mt19937
+from cython.cimports.scipy.linalg.cython_blas import saxpy, sdot
 
 from .dictionary import Dictionary
 from .dump_db import DumpDB
 from .link_graph import LinkGraph
 from .mention_db import MentionDB
 from .utils.sentence_detector.base_sentence_detector import BaseSentenceDetector
 from .utils.tokenizer.base_tokenizer import BaseTokenizer
-
-from cython.cimports.wikipedia2vec.dictionary cimport Item, Word
-from cython.cimports.wikipedia2vec.dump_db cimport Paragraph, WikiLink
-from cython.cimports.wikipedia2vec.mention_db cimport Mention
-from cython.cimports.wikipedia2vec.utils.sentence_detector.sentence cimport Sentence
-from cython.cimports.wikipedia2vec.utils.tokenizer.token cimport Token
+from cython.cimports.wikipedia2vec.dictionary import Item, Word
+from cython.cimports.wikipedia2vec.dump_db import Paragraph, WikiLink
+from cython.cimports.wikipedia2vec.mention_db import Mention
+from cython.cimports.wikipedia2vec.utils.sentence_detector.sentence import Sentence
+from cython.cimports.wikipedia2vec.utils.tokenizer.token import Token
 
 MAX_EXP = cython.declare(cython.float, 6.0)
 EXP_TABLE_SIZE = cython.declare(cython.int, 1000)
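
Context for the hunk above: in Cython 3's pure Python mode, C-level declarations are pulled in through the special `cython.cimports` package using ordinary `import` syntax; the `cimport` keyword is `.pyx`-only and fails to parse in a `.py` module, which is why this commit swaps the keyword out. A minimal sketch of calling one of the same BLAS routines from a pure-mode module (the file name and wrapper function are illustrative, not part of this commit; the module only runs after compilation, e.g. `cythonize -i sketch_dot.py`):

```python
# sketch_dot.py -- Cython pure Python mode; requires compilation to run.
import cython
from cython.cimports.scipy.linalg.cython_blas import sdot

@cython.ccall
def dot(a: cython.float[::1], b: cython.float[::1]) -> cython.float:
    # cython_blas exposes the raw Fortran BLAS interface, so every
    # argument is passed by pointer, including the scalar sizes.
    n: cython.int = a.shape[0]
    inc: cython.int = 1
    return sdot(cython.address(n), cython.address(a[0]), cython.address(inc),
                cython.address(b[0]), cython.address(inc))
```
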
@@ -270,8 +268,10 @@ def train(
 
         vocab_size = len(self.dictionary)
 
-        syn0_arr = (np.random.rand(vocab_size, dim_size).astype(np.float32) - 0.5) / dim_size
-        syn1_arr = np.zeros((vocab_size, dim_size), dtype=np.float32)
+        syn0_obj = _convert_np_array_to_shared_array_object(
+            (np.random.rand(vocab_size, dim_size).astype(np.float32) - 0.5) / dim_size
+        )
+        syn1_obj = _convert_np_array_to_shared_array_object(np.zeros((vocab_size, dim_size), dtype=np.float32))
 
         init_args = (
             dump_db,
@@ -280,8 +280,8 @@ def train(
             mention_db.serialize() if mention_db is not None else None,
             tokenizer,
             sentence_detector,
-            _convert_np_array_to_shared_array_object(syn0_arr),
-            _convert_np_array_to_shared_array_object(syn1_arr),
+            syn0_obj,
+            syn1_obj,
             _convert_np_array_to_shared_array_object(word_neg_table),
             _convert_np_array_to_shared_array_object(entity_neg_table),
             _convert_np_array_to_shared_array_object(exp_table),
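
The `syn0_obj`/`syn1_obj` wrappers above are now created once and reused, both here in `init_args` and when the trained weights are read back (next hunk). The commit does not show `_convert_np_array_to_shared_array_object` itself; the sketch below is an assumption about what such a helper can look like, using `multiprocessing.RawArray` so pool workers share one buffer instead of each pickling a full `(vocab_size, dim_size)` float32 matrix (the function name and the `(buffer, shape)` return convention are hypothetical):

```python
import ctypes
import multiprocessing

import numpy as np

def to_shared_array_object(arr):
    # Hypothetical stand-in for _convert_np_array_to_shared_array_object:
    # copy the float32 data into an unsynchronized shared-memory buffer
    # that pool workers can map without re-pickling the matrix.
    shared = multiprocessing.RawArray(ctypes.c_float, int(arr.size))
    np.frombuffer(shared, dtype=np.float32)[:] = arr.ravel()
    return shared, arr.shape
```
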
@@ -314,11 +314,8 @@ def args_generator(titles: List[str], iteration: int):
 
         logger.info("Terminating pool workers...")
 
-        syn0 = np.frombuffer(syn0_arr, dtype=np.float32).reshape((vocab_size, dim_size))
-        syn1 = np.frombuffer(syn1_arr, dtype=np.float32).reshape((vocab_size, dim_size))
-
-        self.syn0 = syn0
-        self.syn1 = syn1
+        self.syn0 = _convert_shared_array_object_to_np_array(syn0_obj)
+        self.syn1 = _convert_shared_array_object_to_np_array(syn1_obj)
 
         train_params = dict(
             dump_db=dump_db.uuid,
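
Reading the result back mirrors the removed `np.frombuffer(...)` lines, which is consistent with the shared object wrapping a ctypes buffer: `_convert_shared_array_object_to_np_array` presumably builds a zero-copy numpy view so the workers' in-place updates become `self.syn0`/`self.syn1`. A sketch under the same hypothetical `(buffer, shape)` convention as above:

```python
def from_shared_array_object(obj):
    # Hypothetical counterpart of _convert_shared_array_object_to_np_array:
    # a zero-copy view over the shared buffer, so the updates written by
    # the training workers surface directly in the returned array.
    shared, shape = obj
    return np.frombuffer(shared, dtype=np.float32).reshape(shape)
```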