
Commit 886925d

Rename feature transformer and kernels (official-stockfish#376)
1 parent 299770c commit 886925d

9 files changed: +119, -116 lines changed

model/model.py

Lines changed: 2 additions & 2 deletions
@@ -3,7 +3,7 @@
 
 from .config import ModelConfig
 from .features import FeatureSet
-from .modules import DoubleFeatureTransformerSlice, LayerStacks
+from .modules import DoubleFeatureTransformer, LayerStacks
 from .quantize import QuantizationConfig, QuantizationManager
 
 
@@ -25,7 +25,7 @@ def __init__(
         self.num_psqt_buckets = num_psqt_buckets
         self.num_ls_buckets = num_ls_buckets
 
-        self.input = DoubleFeatureTransformerSlice(
+        self.input = DoubleFeatureTransformer(
             feature_set.num_features, self.L1 + self.num_psqt_buckets
         )
         self.feature_set = feature_set
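
Note: the renamed DoubleFeatureTransformer keeps the same constructor arguments, producing self.L1 + self.num_psqt_buckets outputs per perspective. As a reference only, assuming the common NNUE arrangement in which the trailing num_psqt_buckets columns carry the PSQT outputs (an assumption, not something this diff shows; the sizes below are hypothetical), splitting such an output looks like:

    import torch

    L1, num_psqt_buckets = 1024, 8                   # hypothetical sizes
    ft_out = torch.zeros(16, L1 + num_psqt_buckets)  # one perspective's transformer output
    accumulator, psqt = torch.split(ft_out, [L1, num_psqt_buckets], dim=1)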

model/modules/__init__.py

Lines changed: 6 additions & 6 deletions
@@ -1,13 +1,13 @@
 from .feature_transformer import (
-    BaseFeatureTransformerSlice,
-    DoubleFeatureTransformerSlice,
-    FeatureTransformerSlice,
+    BaseFeatureTransformer,
+    DoubleFeatureTransformer,
+    FeatureTransformer,
 )
 from .layer_stacks import LayerStacks
 
 __all__ = [
-    "BaseFeatureTransformerSlice",
-    "DoubleFeatureTransformerSlice",
-    "FeatureTransformerSlice",
+    "BaseFeatureTransformer",
+    "DoubleFeatureTransformer",
+    "FeatureTransformer",
     "LayerStacks",
 ]
Lines changed: 6 additions & 6 deletions
@@ -1,11 +1,11 @@
 from .module import (
-    BaseFeatureTransformerSlice,
-    DoubleFeatureTransformerSlice,
-    FeatureTransformerSlice,
+    BaseFeatureTransformer,
+    DoubleFeatureTransformer,
+    FeatureTransformer,
 )
 
 __all__ = [
-    "BaseFeatureTransformerSlice",
-    "DoubleFeatureTransformerSlice",
-    "FeatureTransformerSlice",
+    "BaseFeatureTransformer",
+    "DoubleFeatureTransformer",
+    "FeatureTransformer",
 ]
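
Because the packages now re-export only the new names, downstream code that still imports the old *Slice classes will fail at import time. A minimal, hypothetical compatibility shim (not part of this commit) could keep old call sites working during a migration:

    # Hypothetical shim, not in this repository: alias the old names to the new ones.
    from model.modules import (
        BaseFeatureTransformer,
        DoubleFeatureTransformer,
        FeatureTransformer,
    )

    BaseFeatureTransformerSlice = BaseFeatureTransformer
    DoubleFeatureTransformerSlice = DoubleFeatureTransformer
    FeatureTransformerSlice = FeatureTransformer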

model/modules/feature_transformer/functions.py

Lines changed: 5 additions & 5 deletions
@@ -2,12 +2,12 @@
 from torch import autograd
 
 from .kernel import (
-    make_feature_transformer_slice_forward_kernel,
-    make_feature_transformer_slice_backward_kernel,
+    make_sparse_input_linear_forward_kernel,
+    make_sparse_input_linear_backward_kernel,
 )
 
 
-class FeatureTransformerSliceFunction(autograd.Function):
+class SparseLinearFunction(autograd.Function):
     @staticmethod
     def forward(ctx, feature_indices, feature_values, weight, bias):
         ctx.save_for_backward(feature_indices, feature_values, weight, bias)
@@ -52,7 +52,7 @@ def forward(ctx, feature_indices, feature_values, weight, bias):
             requires_grad=True,
         )
 
-        kernel = make_feature_transformer_slice_forward_kernel(
+        kernel = make_sparse_input_linear_forward_kernel(
             max_active_features, output_size
         )
         kernel(
@@ -87,7 +87,7 @@ def backward(ctx, grad_output):
         )
         bias_grad = torch.zeros(output_size, dtype=torch.float32, device=device)
 
-        kernel = make_feature_transformer_slice_backward_kernel(
+        kernel = make_sparse_input_linear_backward_kernel(
            max_active_features, output_size
        )
        kernel(
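
The renamed SparseLinearFunction remains an ordinary torch.autograd.Function, so callers go through .apply with the (feature_indices, feature_values, weight, bias) signature shown above. A minimal sketch, assuming CUDA tensors and hypothetical shapes (none of this appears in the commit itself):

    import torch
    from model.modules.feature_transformer.functions import SparseLinearFunction

    batch_size, max_active, num_inputs, output_size = 4, 32, 1024, 256  # hypothetical
    feature_indices = torch.randint(
        -1, num_inputs, (batch_size, max_active), dtype=torch.int32, device="cuda"
    )
    feature_values = torch.ones(batch_size, max_active, device="cuda")
    weight = torch.randn(num_inputs, output_size, device="cuda", requires_grad=True)
    bias = torch.zeros(output_size, device="cuda", requires_grad=True)

    out = SparseLinearFunction.apply(feature_indices, feature_values, weight, bias)
    out.sum().backward()  # routes through the custom backward kernel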

model/modules/feature_transformer/kernel.py

Lines changed: 70 additions & 70 deletions
@@ -2,7 +2,7 @@
 import torch
 
 
-def _find_nearest_divisor(value, target):
+def _find_nearest_divisor(value: int, target: int) -> int:
     divisors = []
     for i in range(1, value + 1):
         if value % i == 0:
@@ -11,10 +11,10 @@ def _find_nearest_divisor(value, target):
     return divisors[0][0]
 
 
-_num_threads_forward_cache = dict()
+_num_threads_forward_cache: dict[int, int] = dict()
 
 
-def _get_num_threads_for_forward(output_size):
+def _get_num_threads_for_forward(output_size: int) -> int:
     optimal_num_threads = 512
     if output_size not in _num_threads_forward_cache:
         _num_threads_forward_cache[output_size] = _find_nearest_divisor(
@@ -24,10 +24,10 @@ def _get_num_threads_for_forward(output_size):
     return _num_threads_forward_cache[output_size]
 
 
-_num_threads_backward_cache = dict()
+_num_threads_backward_cache: dict[int, int] = dict()
 
 
-def _get_num_threads_for_backward(output_size):
+def _get_num_threads_for_backward(output_size: int) -> int:
     optimal_num_threads = 512
     if output_size not in _num_threads_backward_cache:
         _num_threads_backward_cache[output_size] = _find_nearest_divisor(
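
For context, the two helpers above pick, for a given output_size, the divisor closest to the 512-thread target. The loop body that fills divisors lies outside these hunks, so the following is only a sketch of the presumed behaviour (collect divisors with their distance to the target, return the closest), not the file's exact code:

    def _find_nearest_divisor(value: int, target: int) -> int:
        # Presumed behaviour: gather every divisor of `value` together with its
        # distance to `target`, then return the closest one.
        divisors = []
        for i in range(1, value + 1):
            if value % i == 0:
                divisors.append((i, abs(target - i)))
        divisors.sort(key=lambda d: d[1])
        return divisors[0][0]

    # e.g. an output_size of 2560 against the 512-thread target picks 512 threads,
    # so output_thread_slice_size = 2560 // 512 = 5 outputs per thread.
    assert _find_nearest_divisor(2560, 512) == 512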
@@ -44,15 +44,15 @@ def f(grid, args):
     return f
 
 
-_feature_transformer_slice_forward_kernel_cache = dict()
+_sparse_input_linear_forward_kernel_cache = dict()
 
 
 @torch.compiler.disable(recursive=False)
-def make_feature_transformer_slice_forward_kernel(max_active_features, output_size):
+def make_sparse_input_linear_forward_kernel(max_active_indices: int, output_size: int):
     """
-    @param: max_active_features
-        The maximum number of features that are active
-        (non-zero) for a single position. This value determines
+    @param: max_active_indices
+        The maximum number of indices that are non-zero
+        for a single position. This value determines
         the shape of the inputs.
         This value is of type uint32_t.
 
@@ -63,8 +63,8 @@ def make_feature_transformer_slice_forward_kernel(max_active_features, output_si
     """
     num_threads = _get_num_threads_for_forward(output_size)
     output_thread_slice_size = output_size // num_threads
-    key = (max_active_features, output_size, num_threads)
-    if key not in _feature_transformer_slice_forward_kernel_cache:
+    key = (max_active_indices, output_size, num_threads)
+    if key not in _sparse_input_linear_forward_kernel_cache:
         kernel = cp.RawKernel(
             r"""
 
@@ -79,23 +79,23 @@ def make_feature_transformer_slice_forward_kernel(max_active_features, output_si
             The threads must have dimensionality (N,), where
             N * output_thread_slice_size == output_size.
 
-        @param: feature_indices
-            A matrix of shape (BATCH_SIZE, max_active_features)
-            containing indices of active features for each position
-            in a batch. Feature index of -1 means that the slot is empty
+        @param: input_indices
+            A matrix of shape (BATCH_SIZE, max_active_indices)
+            containing indices of active indices for each position
+            in a batch. Input index of -1 means that the slot is empty
             and the weights will not be accumulated for it. Moreover
             no further indices from this block will be considered.
             The indices form an implicit matrix of shape
             (BATCH_SIZE, NUM_INPUTS), where the first dimension index is
             inferred from the memory location (BATCH_SIZE), and the
-            second dimension index is stored in the feature_indices matrix.
-            The type for feature indices is int32_t.
+            second dimension index is stored in the input_indices matrix.
+            The type for input indices is int32_t.
 
-        @param: feature_values
-            A matrix of shape (BATCH_SIZE, max_active_features)
+        @param: input_values
+            A matrix of shape (BATCH_SIZE, max_active_indices)
             containing the values (arity) of the corresponding
-            feature index in feature_indices.
-            The type for the feature value (arity) is float32.
+            input index in input_indices.
+            The type for the input value (arity) is float32.
 
         @param: weight
             The weight matrix of shape (NUM_INPUTS, output_size).
@@ -111,9 +111,9 @@ def make_feature_transformer_slice_forward_kernel(max_active_features, output_si
             to the output first.
             Output values must have type float32.
         */
-        void feature_transformer_slice_forward(
-            const int32_t* const feature_indices,
-            const float* const feature_values,
+        void sparse_input_linear_forward(
+            const int32_t* const input_indices,
+            const float* const input_values,
             const float* const weight,
             const float* const bias,
             float* const output
@@ -128,26 +128,26 @@ def make_feature_transformer_slice_forward_kernel(max_active_features, output_si
             const float* const bias_slice = bias + slice_offset;
             float* shared_output_slice = shared_output + slice_offset;
 
-            const int32_t* const feature_index_row = feature_indices + block_idx * {max_active_features};
-            const float* const feature_value_row = feature_values + block_idx * {max_active_features};
+            const int32_t* const input_index_row = input_indices + block_idx * {max_active_indices};
+            const float* const input_value_row = input_values + block_idx * {max_active_indices};
 
             #pragma unroll
             for (uint32_t s = 0; s < {output_thread_slice_size}; ++s)
             {{
                 shared_output_slice[s] = bias_slice[s];
             }}
 
-            for (uint32_t k = 0; k < {max_active_features}; ++k)
+            for (uint32_t k = 0; k < {max_active_indices}; ++k)
             {{
-                const int32_t feature_index = feature_index_row[k];
-                const float feature_value = feature_value_row[k];
-                if (feature_index != -1)
+                const int32_t input_index = input_index_row[k];
+                const float input_value = input_value_row[k];
+                if (input_index != -1)
                 {{
-                    const float* const weight_slice = weight + feature_index * {output_size} + slice_offset;
+                    const float* const weight_slice = weight + input_index * {output_size} + slice_offset;
                     #pragma unroll
                     for (uint32_t s = 0; s < {output_thread_slice_size}; ++s)
                     {{
-                        shared_output_slice[s] += weight_slice[s] * feature_value;
+                        shared_output_slice[s] += weight_slice[s] * input_value;
                     }}
                 }} else break;
             }}
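
The kernel documented above implements a sparse-input linear layer: each output row is the bias plus, for every non-negative index in that row, the selected weight row scaled by its value. A plain-PyTorch reference of that computation (for understanding only; it assumes -1 entries appear only as trailing padding, mirroring the kernel's early break, and it is far slower than the CUDA path):

    import torch

    def sparse_input_linear_forward_reference(indices, values, weight, bias):
        # indices: (B, max_active) int32 with -1 padding; values: (B, max_active) float32
        # weight: (NUM_INPUTS, output_size); bias: (output_size,)
        mask = (indices != -1).float()
        safe = indices.clamp(min=0).long()        # avoid gathering with -1
        gathered = weight[safe]                   # (B, max_active, output_size)
        out = (gathered * (values * mask).unsqueeze(-1)).sum(dim=1)
        return out + bias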
@@ -160,29 +160,29 @@ def make_feature_transformer_slice_forward_kernel(max_active_features, output_si
         }}
 
         """.format(
-            max_active_features=max_active_features,
+            max_active_indices=max_active_indices,
             output_thread_slice_size=output_thread_slice_size,
             output_size=output_size,
         ),
-            "feature_transformer_slice_forward",
+            "sparse_input_linear_forward",
         )
         kernel.compile()
-        _feature_transformer_slice_forward_kernel_cache[key] = _kernel_with_threads(
+        _sparse_input_linear_forward_kernel_cache[key] = _kernel_with_threads(
             kernel, (num_threads,)
         )
-    return _feature_transformer_slice_forward_kernel_cache[key]
+    return _sparse_input_linear_forward_kernel_cache[key]
 
 
-_feature_transformer_slice_backward_kernel_cache = dict()
+_sparse_input_linear_backward_kernel_cache = dict()
 
 
 @torch.compiler.disable(recursive=False)
-def make_feature_transformer_slice_backward_kernel(max_active_features, output_size):
+def make_sparse_input_linear_backward_kernel(max_active_indices: int, output_size: int):
     """
-    @param: max_active_features
-        The maximum number of features that are active
-        (non-zero) for a single position. This value determines
-        the shape of the inputs.
+    @param: max_active_indices
+        The maximum number of indices that are non-zero for
+        a single position. This value determines the shape
+        of the inputs.
         This value is of type uint32_t.
 
     @param: output_size
@@ -192,8 +192,8 @@ def make_feature_transformer_slice_backward_kernel(max_active_features, output_s
     """
     num_threads = _get_num_threads_for_backward(output_size)
     output_thread_slice_size = output_size // num_threads
-    key = (max_active_features, output_size, num_threads)
-    if key not in _feature_transformer_slice_backward_kernel_cache:
+    key = (max_active_indices, output_size, num_threads)
+    if key not in _sparse_input_linear_backward_kernel_cache:
         kernel = cp.RawKernel(
             r"""
 
@@ -207,23 +207,23 @@ def make_feature_transformer_slice_backward_kernel(max_active_features, output_s
             The threads must have dimensionality (N,), where
             N * output_thread_slice_size == output_size.
 
-        @param: feature_indices
-            A matrix of shape (BATCH_SIZE, max_active_features)
-            containing indices of active features for each position
-            in a batch. Feature index of -1 means that the slot is empty
+        @param: input_indices
+            A matrix of shape (BATCH_SIZE, max_active_indices)
+            containing indices of active indices for each position
+            in a batch. Input index of -1 means that the slot is empty
             and the weights will not be accumulated for it. Moreover
             no further indices from this block will be considered.
             The indices form an implicit matrix of shape
             (BATCH_SIZE, NUM_INPUTS), where the first dimension index is
             inferred from the memory location (BATCH_SIZE), and the
-            second dimension index is stored in the feature_indices matrix.
-            The type for feature indices is int32_t.
+            second dimension index is stored in the input_indices matrix.
+            The type for input indices is int32_t.
 
-        @param: feature_values
-            A matrix of shape (BATCH_SIZE, max_active_features)
+        @param: input_values
+            A matrix of shape (BATCH_SIZE, max_active_indices)
             containing the values (arity) of the corresponding
-            feature index in feature_indices.
-            The type for the feature value (arity) is float32.
+            input index in input_indices.
+            The type for the input value (arity) is float32.
 
         @param: weight_grad
             The weight gradient matrix of shape (NUM_INPUTS, output_size).
@@ -241,9 +241,9 @@ def make_feature_transformer_slice_backward_kernel(max_active_features, output_s
             An output gradient matrix of shape (BATCH_SIZE, output_size).
             Output values must have type float32.
         */
-        void feature_transformer_slice_backward(
-            const int32_t* const feature_indices,
-            const float* const feature_values,
+        void sparse_input_linear_backward(
+            const int32_t* const input_indices,
+            const float* const input_values,
             float* const weight_grad,
             float* const bias_grad,
             const float* const output_grad
@@ -258,8 +258,8 @@ def make_feature_transformer_slice_backward_kernel(max_active_features, output_s
             float* const bias_grad_slice = bias_grad + slice_offset;
             float* shared_output_grad_slice = shared_output_grad + slice_offset;
 
-            const int32_t* const feature_index_row = feature_indices + block_idx * {max_active_features};
-            const float* const feature_value_row = feature_values + block_idx * {max_active_features};
+            const int32_t* const input_index_row = input_indices + block_idx * {max_active_indices};
+            const float* const input_value_row = input_values + block_idx * {max_active_indices};
 
             #pragma unroll
             for (uint32_t s = 0; s < {output_thread_slice_size}; ++s)
@@ -277,35 +277,35 @@ def make_feature_transformer_slice_backward_kernel(max_active_features, output_s
                 }}
             }}
 
-            for (uint32_t k = 0; k < {max_active_features}; ++k)
+            for (uint32_t k = 0; k < {max_active_indices}; ++k)
             {{
-                const int32_t feature_index = feature_index_row[k];
-                const float feature_value = feature_value_row[k];
-                if (feature_index != -1)
+                const int32_t input_index = input_index_row[k];
+                const float input_value = input_value_row[k];
+                if (input_index != -1)
                 {{
-                    float* const weight_grad_slice = weight_grad + feature_index * {output_size} + slice_offset;
+                    float* const weight_grad_slice = weight_grad + input_index * {output_size} + slice_offset;
                     #pragma unroll
                     for (int s = 0; s < {output_thread_slice_size}; ++s)
                     {{
                         const float sog = shared_output_grad_slice[s];
                         if (sog != 0.0f)
                         {{
-                            atomicAdd(&weight_grad_slice[s], sog * feature_value);
+                            atomicAdd(&weight_grad_slice[s], sog * input_value);
                         }}
                     }}
                 }} else break;
             }}
         }}
 
         """.format(
-            max_active_features=max_active_features,
+            max_active_indices=max_active_indices,
             output_thread_slice_size=output_thread_slice_size,
             output_size=output_size,
         ),
-            "feature_transformer_slice_backward",
+            "sparse_input_linear_backward",
         )
         kernel.compile()
-        _feature_transformer_slice_backward_kernel_cache[key] = _kernel_with_threads(
+        _sparse_input_linear_backward_kernel_cache[key] = _kernel_with_threads(
             kernel, (num_threads,)
         )
-    return _feature_transformer_slice_backward_kernel_cache[key]
+    return _sparse_input_linear_backward_kernel_cache[key]
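
The backward kernel mirrors the forward: for every active index it scatters output_grad, scaled by the input value, into the matching weight_grad row (via atomicAdd, since several positions in a batch can hit the same row). Bias handling is not visible in these hunks, so the bias_grad line below is an assumption (column-sum of output_grad), and the whole snippet is only a reference sketch in plain PyTorch:

    import torch

    def sparse_input_linear_backward_reference(indices, values, output_grad, num_inputs):
        # indices: (B, max_active) int32 with -1 trailing padding
        # values: (B, max_active) float32; output_grad: (B, output_size)
        B, O = output_grad.shape
        weight_grad = torch.zeros(num_inputs, O, dtype=output_grad.dtype,
                                  device=output_grad.device)
        mask = (indices != -1).float()
        safe = indices.clamp(min=0).long()
        # contribution of each (position, slot) pair to its weight row
        contrib = output_grad.unsqueeze(1) * (values * mask).unsqueeze(-1)  # (B, K, O)
        weight_grad.index_add_(0, safe.reshape(-1), contrib.reshape(-1, O))
        bias_grad = output_grad.sum(dim=0)  # assumed, not read off these hunks
        return weight_grad, bias_grad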

0 commit comments
