
Commit 7080875

Switch from ctypes to cffi
1 parent e001f44 commit 7080875
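
As a quick illustration of what the switch changes (a minimal sketch, not code from this repository; the library path and handle names are placeholders): with ctypes, each return type must be registered on the loaded library by hand, while cffi parses ordinary C declarations from a single cdef() string.

    # ctypes style: load the library, then set each restype by hand.
    import ctypes

    lib = ctypes.CDLL("./libparse.so")            # hypothetical path
    lib.batch_get_len.restype = ctypes.c_uint32   # one registration per function

    # cffi style: declare the C signatures once, then dlopen.
    from cffi import FFI

    ffi = FFI()
    ffi.cdef("uint32_t batch_get_len(void* batch);")
    lib = ffi.dlopen("./libparse.so")             # hypothetical path
    # lib.batch_get_len(...) now returns a Python int directly.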

1 file changed: +101 −55 lines

trainer/batchloader.py

Lines changed: 101 additions & 55 deletions
@@ -1,23 +1,8 @@
 from __future__ import annotations
 from dataclasses import dataclass
-import ctypes
 import numpy as np
 import torch
-
-result_types = {
-    "batch_new": ctypes.c_void_p,
-    "batch_drop": None,
-    "batch_get_len": ctypes.c_uint32,
-    "get_row_features": ctypes.POINTER(ctypes.c_int16),
-    "get_stm_features": ctypes.POINTER(ctypes.c_int16),
-    "get_nstm_features": ctypes.POINTER(ctypes.c_int16),
-    "batch_get_total_features": ctypes.c_uint32,
-    "get_targets": ctypes.POINTER(ctypes.c_float),
-    "file_reader_new": ctypes.c_void_p,
-    "close_file": None,
-    "try_to_load_batch": ctypes.c_bool,
-}
-
+from cffi import FFI
 
 @dataclass
 class SparseBatch:
@@ -26,64 +11,125 @@ class SparseBatch:
     target: torch.Tensor
     size: int
 
+class LibWrapper:
+    def __init__(self, lib_path: str) -> None:
+        self.ffi = FFI()
+        self.lib = self.ffi.dlopen(lib_path)
+        self._define_functions()
+
+    def _define_functions(self) -> None:
+        self.ffi.cdef("""
+            void* batch_new(uint32_t batch_size, float scale, float wdl);
+            void batch_drop(void* batch);
+            uint32_t batch_get_len(void* batch);
+            int16_t* get_row_features(void* batch);
+            int16_t* get_stm_features(void* batch);
+            int16_t* get_nstm_features(void* batch);
+            uint32_t batch_get_total_features(void* batch);
+            float* get_targets(void* batch);
+            void* file_reader_new(const char* file_path);
+            void close_file(void* reader);
+            bool try_to_load_batch(void* reader, void* batch);
+        """)
+
+    def batch_new(self, batch_size: int, scale: float, wdl: float) -> ffi.CData:
+        return self.lib.batch_new(batch_size, scale, wdl)
+
+    def batch_drop(self, batch: ffi.CData) -> None:
+        self.lib.batch_drop(batch)
+
+    def batch_get_len(self, batch: ffi.CData) -> int:
+        return self.lib.batch_get_len(batch)
+
+    def get_row_features(self, batch: ffi.CData) -> ffi.CData:
+        return self.lib.get_row_features(batch)
+
+    def get_stm_features(self, batch: ffi.CData) -> ffi.CData:
+        return self.lib.get_stm_features(batch)
+
+    def get_nstm_features(self, batch: ffi.CData) -> ffi.CData:
+        return self.lib.get_nstm_features(batch)
+
+    def batch_get_total_features(self, batch: ffi.CData) -> int:
+        return self.lib.batch_get_total_features(batch)
+
+    def get_targets(self, batch: ffi.CData) -> ffi.CData:
+        return self.lib.get_targets(batch)
+
+    def file_reader_new(self, file_path: bytes) -> ffi.CData:
+        file_path_buffer = self.ffi.new("char[]", file_path)
+        return self.lib.file_reader_new(file_path_buffer)
+
+    def close_file(self, reader: ffi.CData) -> None:
+        self.lib.close_file(reader)
+
+    def try_to_load_batch(self, reader: ffi.CData, batch: ffi.CData) -> bool:
+        return self.lib.try_to_load_batch(reader, batch)
+
 
 class BatchLoader:
     def __init__(self, lib_path: str, files: list[bytes], batch_size: int, scale: float, wdl: float) -> None:
-        self.parse_lib = None
-        if not files: raise ValueError("The files list cannot be empty.")
-        try: self.parse_lib = ctypes.CDLL(lib_path)
-        except OSError as e: raise Exception(f"Failed to load the library: {e}")
-        self.load_parse_lib()
+        if not files:
+            raise ValueError("The files list cannot be empty.")
 
-        self.files, self.file_index = files, 0
-        self.batch = ctypes.c_void_p(self.parse_lib.batch_new(ctypes.c_uint32(batch_size), ctypes.c_float(scale), ctypes.c_float(wdl)))
-        if self.batch.value is None: raise Exception("Failed to create batch")
-
-        self.current_reader = ctypes.c_void_p(self.parse_lib.file_reader_new(ctypes.create_string_buffer(files[0])))
-        if self.current_reader.value is None: raise Exception("Failed to create file reader")
+        self.lib_wrapper = LibWrapper(lib_path)
+        self.files = files
+        self.file_index = 0
+        self.batch = self.lib_wrapper.batch_new(batch_size, scale, wdl)
+        self.current_reader = self.lib_wrapper.file_reader_new(files[0])
 
     def next_batch(self, device: torch.device) -> tuple[bool, SparseBatch]:
         new_epoch = False
-        while not self.parse_lib.try_to_load_batch(self.current_reader, self.batch):
-            self.parse_lib.close_file(self.current_reader)
-            self.file_index = (self.file_index + 1) % len(self.files)
-            file_path_buffer = ctypes.create_string_buffer(self.files[self.file_index])
-            self.current_reader = ctypes.c_void_p(self.parse_lib.file_reader_new(file_path_buffer))
+        while not self.lib_wrapper.try_to_load_batch(self.current_reader, self.batch):
+            self._load_next_file()
             new_epoch = self.file_index == 0
         return new_epoch, self.to_pytorch_batch(device)
 
     def to_pytorch_batch(self, device: torch.device) -> SparseBatch:
-        def to_pytorch(array: np.ndarray) -> torch.Tensor:
-            return torch.from_numpy(array).to(device, non_blocking=True)
-
-        total_features = self.parse_lib.batch_get_total_features(self.batch)
+        # Retrieve batch information
+        total_features = self.lib_wrapper.batch_get_total_features(self.batch)
+        batch_len = self.lib_wrapper.batch_get_len(self.batch)
 
-        rows_buffer = self.parse_lib.get_row_features(self.batch)
-        rows = to_pytorch(np.ctypeslib.as_array(rows_buffer, shape=(total_features,)))
+        # Get feature buffers
+        rows_buffer = self.lib_wrapper.get_row_features(self.batch)
+        stm_cols_buffer = self.lib_wrapper.get_stm_features(self.batch)
+        nstm_cols_buffer = self.lib_wrapper.get_nstm_features(self.batch)
 
-        stm_cols_buffer = self.parse_lib.get_stm_features(self.batch)
-        stm_cols = to_pytorch(np.ctypeslib.as_array(stm_cols_buffer, shape=(total_features,)))
-
-        nstm_cols_buffer = self.parse_lib.get_nstm_features(self.batch)
-        nstm_cols = to_pytorch(np.ctypeslib.as_array(nstm_cols_buffer, shape=(total_features,)))
+        # Convert buffers to PyTorch tensors
+        rows = self._get_buffer_data(rows_buffer, total_features, np.int16, device)
+        stm_cols = self._get_buffer_data(stm_cols_buffer, total_features, np.int16, device)
+        nstm_cols = self._get_buffer_data(nstm_cols_buffer, total_features, np.int16, device)
 
+        # Create sparse tensors for STM and NSTM features
         values = torch.ones(total_features, device=device, dtype=torch.float32)
+        stm_indices = torch.stack([rows, stm_cols], dim=0)
+        nstm_indices = torch.stack([rows, nstm_cols], dim=0)
 
-        batch_len = self.parse_lib.batch_get_len(self.batch)
-        stm_sparse = torch.sparse_coo_tensor(torch.stack([rows, stm_cols], dim=0), values, (batch_len, 768))
-        nstm_sparse = torch.sparse_coo_tensor(torch.stack([rows, nstm_cols], dim=0), values, (batch_len, 768))
+        stm_sparse = torch.sparse_coo_tensor(stm_indices, values, (batch_len, 768))
+        nstm_sparse = torch.sparse_coo_tensor(nstm_indices, values, (batch_len, 768))
 
-        target = to_pytorch(np.ctypeslib.as_array(self.parse_lib.get_targets(self.batch), shape=(batch_len, 1)))
+        # Get and process target buffer
+        targets_buffer = self.lib_wrapper.get_targets(self.batch)
+        target = self._get_buffer_data(targets_buffer, batch_len, np.float32, device, reshape=True)
 
         return SparseBatch(stm_sparse, nstm_sparse, target, batch_len)
 
-    def load_parse_lib(self):
-        for func_name, restype in result_types.items():
-            func = getattr(self.parse_lib, func_name, None)
-            if func:
-                setattr(self.parse_lib, func_name, func)
-                func.restype = restype
+    def _get_buffer_data(self, buffer: ffi.CData, length: int, dtype: np.dtype, device: torch.device, reshape: bool = False) -> torch.Tensor:
+        element_size = np.dtype(dtype).itemsize
+        expected_size = length * element_size
+        data_buffer = self.lib_wrapper.ffi.buffer(buffer, expected_size)
+        data = np.frombuffer(data_buffer, dtype=dtype)
+        if reshape:
+            data = data.reshape((length, 1))
+        return torch.from_numpy(data).to(device, non_blocking=True)
+
+    def _load_next_file(self) -> None:
+        self.lib_wrapper.close_file(self.current_reader)
+        self.file_index = (self.file_index + 1) % len(self.files)
+        self.current_reader = self.lib_wrapper.file_reader_new(self.files[self.file_index])
 
     def __del__(self) -> None:
-        self.parse_lib.close_file(self.current_reader)
-        self.parse_lib.batch_drop(self.batch)
+        if hasattr(self, 'current_reader') and self.current_reader:
+            self.lib_wrapper.close_file(self.current_reader)
+        if hasattr(self, 'batch') and self.batch:
+            self.lib_wrapper.batch_drop(self.batch)
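
For orientation, a rough usage sketch of the new loader (the library path, data file, and hyperparameter values below are hypothetical, not part of this commit):

    # Hypothetical library path, training file, and hyperparameters.
    loader = BatchLoader("./libparse.so", [b"data/positions.bin"],
                         batch_size=16384, scale=400.0, wdl=0.1)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    new_epoch, batch = loader.next_batch(device)   # cycles through files, wrapping at the end
    print(batch.size, batch.target.shape)          # batch length and a (batch_len, 1) target tensor

Worth noting: _get_buffer_data views the native memory via ffi.buffer() and np.frombuffer() without copying, and torch.from_numpy() shares that view; the trailing .to(device, non_blocking=True) is what copies the data off the library-owned buffer when moving to a CUDA device. On a CPU device, .to() can return the same tensor, so the result may keep aliasing memory that the next try_to_load_batch call overwrites.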
