11from __future__ import annotations
22from dataclasses import dataclass
3- import ctypes
43import numpy as np
54import torch
6-
7- result_types = {
8- "batch_new" : ctypes .c_void_p ,
9- "batch_drop" : None ,
10- "batch_get_len" : ctypes .c_uint32 ,
11- "get_row_features" : ctypes .POINTER (ctypes .c_int16 ),
12- "get_stm_features" : ctypes .POINTER (ctypes .c_int16 ),
13- "get_nstm_features" : ctypes .POINTER (ctypes .c_int16 ),
14- "batch_get_total_features" : ctypes .c_uint32 ,
15- "get_targets" : ctypes .POINTER (ctypes .c_float ),
16- "file_reader_new" : ctypes .c_void_p ,
17- "close_file" : None ,
18- "try_to_load_batch" : ctypes .c_bool ,
19- }
20-
5+ from cffi import FFI
216
227@dataclass
238class SparseBatch :
@@ -26,64 +11,125 @@ class SparseBatch:
2611 target : torch .Tensor
2712 size : int
2813
class LibWrapper:
    """Thin cffi (ABI-mode) binding around the native batch-parsing library.

    Each public method forwards to the C function of the same name declared
    in ``_CDEF``. Batch and reader handles are opaque ``void *`` cdata
    objects; this class does not check them for NULL, so callers should
    test the returned handles themselves.
    """

    # C declarations of every function used from the shared library.
    _CDEF = """
        void* batch_new(uint32_t batch_size, float scale, float wdl);
        void batch_drop(void* batch);
        uint32_t batch_get_len(void* batch);
        int16_t* get_row_features(void* batch);
        int16_t* get_stm_features(void* batch);
        int16_t* get_nstm_features(void* batch);
        uint32_t batch_get_total_features(void* batch);
        float* get_targets(void* batch);
        void* file_reader_new(const char* file_path);
        void close_file(void* reader);
        bool try_to_load_batch(void* reader, void* batch);
    """

    def __init__(self, lib_path: str) -> None:
        """Declare the C API and open the shared library at ``lib_path``.

        Raises:
            OSError: propagated from ``ffi.dlopen`` when the library cannot
                be loaded.
        """
        self.ffi = FFI()
        # Register the declarations before opening the library so every
        # symbol is known by the time ``self.lib`` is first used.
        self._define_functions()
        self.lib = self.ffi.dlopen(lib_path)

    def _define_functions(self) -> None:
        """Register the C prototypes from ``_CDEF`` with the FFI instance."""
        self.ffi.cdef(self._CDEF)

    def batch_new(self, batch_size: int, scale: float, wdl: float) -> FFI.CData:
        """Allocate a native batch and return its opaque handle."""
        return self.lib.batch_new(batch_size, scale, wdl)

    def batch_drop(self, batch: FFI.CData) -> None:
        """Release a batch previously returned by ``batch_new``."""
        self.lib.batch_drop(batch)

    def batch_get_len(self, batch: FFI.CData) -> int:
        """Return the value of the native ``batch_get_len`` for ``batch``."""
        return self.lib.batch_get_len(batch)

    def get_row_features(self, batch: FFI.CData) -> FFI.CData:
        """Return the ``int16_t *`` row-index buffer of ``batch``."""
        return self.lib.get_row_features(batch)

    def get_stm_features(self, batch: FFI.CData) -> FFI.CData:
        """Return the ``int16_t *`` buffer from ``get_stm_features``."""
        return self.lib.get_stm_features(batch)

    def get_nstm_features(self, batch: FFI.CData) -> FFI.CData:
        """Return the ``int16_t *`` buffer from ``get_nstm_features``."""
        return self.lib.get_nstm_features(batch)

    def batch_get_total_features(self, batch: FFI.CData) -> int:
        """Return the value of the native ``batch_get_total_features``."""
        return self.lib.batch_get_total_features(batch)

    def get_targets(self, batch: FFI.CData) -> FFI.CData:
        """Return the ``float *`` target buffer of ``batch``."""
        return self.lib.get_targets(batch)

    def file_reader_new(self, file_path: bytes) -> FFI.CData:
        """Open a native file reader for ``file_path`` and return its handle.

        ``file_path`` is normally ``bytes``; ``str`` and path-like objects
        are also accepted and encoded with the filesystem encoding.
        """
        import os  # local import: only needed for path normalisation

        encoded_path = os.fsencode(file_path)
        # NUL-terminated C copy of the path; it stays alive for the call.
        file_path_buffer = self.ffi.new("char[]", encoded_path)
        return self.lib.file_reader_new(file_path_buffer)

    def close_file(self, reader: FFI.CData) -> None:
        """Close a reader previously returned by ``file_reader_new``."""
        self.lib.close_file(reader)

    def try_to_load_batch(self, reader: FFI.CData, batch: FFI.CData) -> bool:
        """Call the native ``try_to_load_batch`` and return its result as ``bool``."""
        return bool(self.lib.try_to_load_batch(reader, batch))
68+
2969
class BatchLoader:
    """Load training batches from a list of data files via the native library.

    One native batch object is reused for every load. Files are read in
    order and the list wraps around once the last file is exhausted.
    """

    # Number of columns of each sparse feature matrix.
    NUM_FEATURES = 768

    def __init__(self, lib_path: str, files: list[bytes], batch_size: int, scale: float, wdl: float) -> None:
        """Open the library, allocate the batch and open the first file.

        Raises:
            ValueError: if ``files`` is empty.
            RuntimeError: if the native library returns a NULL batch or a
                NULL file reader.
        """
        if not files:
            raise ValueError("The files list cannot be empty.")

        # Pre-set the handles so __del__ is safe after a partial construction.
        self.batch = None
        self.current_reader = None

        self.lib_wrapper = LibWrapper(lib_path)
        self.files = files
        self.file_index = 0

        self.batch = self.lib_wrapper.batch_new(batch_size, scale, wdl)
        if not self.batch:  # NULL pointer from the native side
            raise RuntimeError("Failed to create batch")
        self.current_reader = self._open_reader(files[0])

    def next_batch(self, device: torch.device) -> tuple[bool, SparseBatch]:
        """Load the next batch and return ``(new_epoch, batch)``.

        ``new_epoch`` is True when the file list wrapped back to its first
        file while this batch was being loaded.

        Raises:
            RuntimeError: if every file was reopened in turn and none of
                them produced a batch (otherwise this would loop forever),
                or if a file reader could not be created.
        """
        new_epoch = False
        files_opened = 0
        while not self.lib_wrapper.try_to_load_batch(self.current_reader, self.batch):
            if files_opened == len(self.files):
                # Each file has been tried from its start without success.
                raise RuntimeError("No batch could be loaded from any of the files")
            self._load_next_file()
            files_opened += 1
            if self.file_index == 0:
                new_epoch = True
        return new_epoch, self.to_pytorch_batch(device)

    def to_pytorch_batch(self, device: torch.device) -> SparseBatch:
        """Convert the currently loaded native batch into a ``SparseBatch``.

        Two sparse COO matrices of shape ``(batch_len, NUM_FEATURES)`` are
        built from the shared row indices and the stm / nstm column indices
        (presumably side-to-move / not-side-to-move perspectives; confirm
        against the native library). Every stored value is 1.0.
        """
        total_features = self.lib_wrapper.batch_get_total_features(self.batch)
        batch_len = self.lib_wrapper.batch_get_len(self.batch)

        rows_buffer = self.lib_wrapper.get_row_features(self.batch)
        stm_cols_buffer = self.lib_wrapper.get_stm_features(self.batch)
        nstm_cols_buffer = self.lib_wrapper.get_nstm_features(self.batch)

        rows = self._get_buffer_data(rows_buffer, total_features, np.int16, device)
        stm_cols = self._get_buffer_data(stm_cols_buffer, total_features, np.int16, device)
        nstm_cols = self._get_buffer_data(nstm_cols_buffer, total_features, np.int16, device)

        values = torch.ones(total_features, device=device, dtype=torch.float32)
        stm_indices = torch.stack([rows, stm_cols], dim=0)
        nstm_indices = torch.stack([rows, nstm_cols], dim=0)

        shape = (batch_len, self.NUM_FEATURES)
        stm_sparse = torch.sparse_coo_tensor(stm_indices, values, shape)
        nstm_sparse = torch.sparse_coo_tensor(nstm_indices, values, shape)

        targets_buffer = self.lib_wrapper.get_targets(self.batch)
        target = self._get_buffer_data(targets_buffer, batch_len, np.float32, device, reshape=True)

        return SparseBatch(stm_sparse, nstm_sparse, target, batch_len)

    def _get_buffer_data(self, buffer: FFI.CData, length: int, dtype: np.dtype, device: torch.device, reshape: bool = False) -> torch.Tensor:
        """Wrap ``length`` elements of a native buffer as a tensor on ``device``.

        Args:
            buffer: cdata pointer to the first element.
            length: number of elements (not bytes) to expose.
            dtype: NumPy dtype of the elements.
            device: target torch device.
            reshape: when True, return shape ``(length, 1)`` instead of
                ``(length,)``.
        """
        byte_count = length * np.dtype(dtype).itemsize
        raw_view = self.lib_wrapper.ffi.buffer(buffer, byte_count)
        # NOTE(review): frombuffer / from_numpy do not copy, so on a CPU
        # device the tensor presumably aliases native memory that the next
        # load overwrites. Confirm whether callers need a copy.
        data = np.frombuffer(raw_view, dtype=dtype)
        if reshape:
            data = data.reshape((length, 1))
        return torch.from_numpy(data).to(device, non_blocking=True)

    def _open_reader(self, file_path: bytes) -> FFI.CData:
        """Create a native reader for ``file_path``, raising if it is NULL."""
        reader = self.lib_wrapper.file_reader_new(file_path)
        if not reader:
            raise RuntimeError(f"Failed to create file reader for {file_path!r}")
        return reader

    def _load_next_file(self) -> None:
        """Close the current reader and open the next file, wrapping around."""
        self.lib_wrapper.close_file(self.current_reader)
        # Forget the closed handle first, so a failure below cannot lead to
        # a second close_file on it from __del__.
        self.current_reader = None
        self.file_index = (self.file_index + 1) % len(self.files)
        self.current_reader = self._open_reader(self.files[self.file_index])

    def __del__(self) -> None:
        """Release the native reader and batch, if they were created."""
        wrapper = getattr(self, "lib_wrapper", None)
        if wrapper is None:
            return
        reader = getattr(self, "current_reader", None)
        if reader:
            self.current_reader = None
            wrapper.close_file(reader)
        batch = getattr(self, "batch", None)
        if batch:
            self.batch = None
            wrapper.batch_drop(batch)
0 commit comments