Skip to content

Commit 76ea823

Browse files
committed
fix a memory access violation on windows
1 parent cc9d701 commit 76ea823

File tree

1 file changed

+40
-14
lines changed

1 file changed

+40
-14
lines changed

wikipedia2vec/wikipedia2vec.pyx

Lines changed: 40 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ import re
1818
import time
1919
import six
2020
import six.moves.cPickle as pickle
21+
import sys
22+
import uuid
2123
cimport cython
2224
cimport numpy as np
2325
np.import_array()
@@ -301,12 +303,25 @@ cdef class Wikipedia2Vec:
301303
dim_size = params.dim_size
302304
vocab_size = len(self.dictionary)
303305

304-
syn0_mmap = mmap.mmap(-1, dim_size * vocab_size * ctypes.sizeof(c_float))
305-
syn1_mmap = mmap.mmap(-1, dim_size * vocab_size * ctypes.sizeof(c_float))
306-
self.syn0 = np.frombuffer(syn0_mmap, dtype=np.float32).reshape(vocab_size, dim_size)
307-
self.syn1 = np.frombuffer(syn1_mmap, dtype=np.float32).reshape(vocab_size, dim_size)
308-
syn0_carr = self.syn0
309-
syn1_carr = self.syn1
306+
if sys.platform == 'win32':
307+
syn0_addr = uuid.uuid1().hex
308+
syn0_mmap = mmap.mmap(-1, dim_size * vocab_size * ctypes.sizeof(c_float),
309+
tagname=syn0_addr)
310+
syn1_addr = uuid.uuid1().hex
311+
syn1_mmap = mmap.mmap(-1, dim_size * vocab_size * ctypes.sizeof(c_float),
312+
tagname=syn1_addr)
313+
self.syn0 = np.frombuffer(syn0_mmap, dtype=np.float32).reshape(vocab_size, dim_size)
314+
self.syn1 = np.frombuffer(syn1_mmap, dtype=np.float32).reshape(vocab_size, dim_size)
315+
316+
else:
317+
syn0_mmap = mmap.mmap(-1, dim_size * vocab_size * ctypes.sizeof(c_float))
318+
syn1_mmap = mmap.mmap(-1, dim_size * vocab_size * ctypes.sizeof(c_float))
319+
self.syn0 = np.frombuffer(syn0_mmap, dtype=np.float32).reshape(vocab_size, dim_size)
320+
self.syn1 = np.frombuffer(syn1_mmap, dtype=np.float32).reshape(vocab_size, dim_size)
321+
syn0_carr = self.syn0
322+
syn1_carr = self.syn1
323+
syn0_addr = <uintptr_t>&syn0_carr[0, 0]
324+
syn1_addr = <uintptr_t>&syn1_carr[0, 0]
310325

311326
self.syn0[:] = (np.random.rand(vocab_size, dim_size) - 0.5) / dim_size
312327
self.syn1[:] = np.zeros((vocab_size, dim_size))
@@ -318,8 +333,8 @@ cdef class Wikipedia2Vec:
318333
mention_db_obj,
319334
tokenizer,
320335
sentence_detector,
321-
<uintptr_t>&syn0_carr[0, 0],
322-
<uintptr_t>&syn1_carr[0, 0],
336+
syn0_addr,
337+
syn1_addr,
323338
word_neg_table,
324339
entity_neg_table,
325340
exp_table,
@@ -430,11 +445,13 @@ cdef float32_t [:] work
430445

431446

432447
def init_worker(dump_db_, dictionary_obj, link_graph_obj, mention_db_obj, tokenizer_,
433-
sentence_detector_, uintptr_t syn0_ptr, uintptr_t syn1_ptr, word_neg_table_,
434-
entity_neg_table_, exp_table_, sample_ints_, link_indices_, params_):
448+
sentence_detector_, syn0_addr, syn1_addr, word_neg_table_, entity_neg_table_,
449+
exp_table_, sample_ints_, link_indices_, params_):
435450
global dump_db, dictionary, link_graph, mention_db, tokenizer, sentence_detector, syn0, syn1,\
436451
word_neg_table, entity_neg_table, exp_table, sample_ints, link_indices, params, work
437452

453+
cdef uintptr_t syn0_ptr, syn1_ptr
454+
438455
dump_db = dump_db_
439456
tokenizer = tokenizer_
440457
sentence_detector = sentence_detector_
@@ -461,10 +478,19 @@ def init_worker(dump_db_, dictionary_obj, link_graph_obj, mention_db_obj, tokeni
461478
mention_db = None
462479

463480
vocab_size = len(dictionary)
464-
syn0 = np.PyArray_SimpleNewFromData(2, [vocab_size, params.dim_size], np.NPY_FLOAT32,
465-
<float32_t *>syn0_ptr)
466-
syn1 = np.PyArray_SimpleNewFromData(2, [vocab_size, params.dim_size], np.NPY_FLOAT32,
467-
<float32_t *>syn1_ptr)
481+
if sys.platform == 'win32':
482+
syn0_mmap = mmap.mmap(-1, params.dim_size * vocab_size * ctypes.sizeof(c_float), tagname=syn0_addr)
483+
syn1_mmap = mmap.mmap(-1, params.dim_size * vocab_size * ctypes.sizeof(c_float), tagname=syn1_addr)
484+
syn0 = np.frombuffer(syn0_mmap, dtype=np.float32).reshape(-1, params.dim_size)
485+
syn1 = np.frombuffer(syn1_mmap, dtype=np.float32).reshape(-1, params.dim_size)
486+
487+
else:
488+
syn0_ptr = syn0_addr
489+
syn1_ptr = syn1_addr
490+
syn0 = np.PyArray_SimpleNewFromData(2, [vocab_size, params.dim_size], np.NPY_FLOAT32,
491+
<float32_t *>syn0_ptr)
492+
syn1 = np.PyArray_SimpleNewFromData(2, [vocab_size, params.dim_size], np.NPY_FLOAT32,
493+
<float32_t *>syn1_ptr)
468494
work = np.zeros(params.dim_size, dtype=np.float32)
469495

470496

0 commit comments

Comments
 (0)