 import torch


-def _find_nearest_divisor(value, target):
+def _find_nearest_divisor(value: int, target: int) -> int:
    divisors = []
    for i in range(1, value + 1):
        if value % i == 0:
@@ -11,10 +11,10 @@ def _find_nearest_divisor(value, target):
    return divisors[0][0]


-_num_threads_forward_cache = dict()
+_num_threads_forward_cache: dict[int, int] = dict()


-def _get_num_threads_for_forward(output_size):
+def _get_num_threads_for_forward(output_size: int) -> int:
    optimal_num_threads = 512
    if output_size not in _num_threads_forward_cache:
        _num_threads_forward_cache[output_size] = _find_nearest_divisor(
@@ -24,10 +24,10 @@ def _get_num_threads_for_forward(output_size):
    return _num_threads_forward_cache[output_size]


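A quick illustration of the helper above. The body of `_find_nearest_divisor` is elided by this hunk; the examples assume it returns the divisor of `value` closest to `target`:

# Illustrative only: thread count is the divisor of output_size nearest 512,
# so the kernel's N * output_thread_slice_size == output_size invariant holds.
assert _get_num_threads_for_forward(2048) == 512  # 512 divides 2048 exactly
assert _get_num_threads_for_forward(768) == 384   # 384 is 768's divisor nearest 512
assert _get_num_threads_for_forward(768) == 384   # second call hits the cache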
-_num_threads_backward_cache = dict()
+_num_threads_backward_cache: dict[int, int] = dict()


-def _get_num_threads_for_backward(output_size):
+def _get_num_threads_for_backward(output_size: int) -> int:
    optimal_num_threads = 512
    if output_size not in _num_threads_backward_cache:
        _num_threads_backward_cache[output_size] = _find_nearest_divisor(
@@ -44,15 +44,15 @@ def f(grid, args):
    return f


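The `_kernel_with_threads` helper is mostly elided by the hunk above; a minimal sketch of what the visible `f(grid, args)` closure suggests (an assumption, not the verbatim source):

def _kernel_with_threads(kernel, threads):
    # Presumably binds a fixed thread-block shape to a compiled cp.RawKernel,
    # so call sites only supply the grid and the argument tuple.
    def f(grid, args):
        kernel(grid, threads, args)
    return f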
-_feature_transformer_slice_forward_kernel_cache = dict()
+_sparse_input_linear_forward_kernel_cache = dict()


@torch.compiler.disable(recursive=False)
-def make_feature_transformer_slice_forward_kernel(max_active_features, output_size):
+def make_sparse_input_linear_forward_kernel(max_active_indices: int, output_size: int):
    """
-    @param: max_active_features
-        The maximum number of features that are active
-        (non-zero) for a single position. This value determines
+    @param: max_active_indices
+        The maximum number of indices that are non-zero
+        for a single position. This value determines
        the shape of the inputs.
        This value is of type uint32_t.

@@ -63,8 +63,8 @@ def make_feature_transformer_slice_forward_kernel(max_active_features, output_si
    """
    num_threads = _get_num_threads_for_forward(output_size)
    output_thread_slice_size = output_size // num_threads
-    key = (max_active_features, output_size, num_threads)
-    if key not in _feature_transformer_slice_forward_kernel_cache:
+    key = (max_active_indices, output_size, num_threads)
+    if key not in _sparse_input_linear_forward_kernel_cache:
        kernel = cp.RawKernel(
            r"""

@@ -79,23 +79,23 @@ def make_feature_transformer_slice_forward_kernel(max_active_features, output_si
    The threads must have dimensionality (N,), where
    N * output_thread_slice_size == output_size.

-    @param: feature_indices
-        A matrix of shape (BATCH_SIZE, max_active_features)
-        containing indices of active features for each position
-        in a batch. Feature index of -1 means that the slot is empty
+    @param: input_indices
+        A matrix of shape (BATCH_SIZE, max_active_indices)
+        containing the indices of the non-zero inputs for each position
+        in a batch. An input index of -1 means that the slot is empty
        and the weights will not be accumulated for it. Moreover,
        no further indices from this block will be considered.
        The indices form an implicit matrix of shape
        (BATCH_SIZE, NUM_INPUTS), where the first dimension index is
        inferred from the memory location (BATCH_SIZE), and the
-        second dimension index is stored in the feature_indices matrix.
-        The type for feature indices is int32_t.
+        second dimension index is stored in the input_indices matrix.
+        The type for input indices is int32_t.

-    @param: feature_values
-        A matrix of shape (BATCH_SIZE, max_active_features)
+    @param: input_values
+        A matrix of shape (BATCH_SIZE, max_active_indices)
        containing the values (arity) of the corresponding
-        feature index in feature_indices.
-        The type for the feature value (arity) is float32.
+        input index in input_indices.
+        The type for the input value (arity) is float32.

    @param: weight
        The weight matrix of shape (NUM_INPUTS, output_size).
@@ -111,9 +111,9 @@ def make_feature_transformer_slice_forward_kernel(max_active_features, output_si
        to the output first.
        Output values must have type float32.
*/
-void feature_transformer_slice_forward(
-    const int32_t* const feature_indices,
-    const float* const feature_values,
+void sparse_input_linear_forward(
+    const int32_t* const input_indices,
+    const float* const input_values,
    const float* const weight,
    const float* const bias,
    float* const output
@@ -128,26 +128,26 @@ def make_feature_transformer_slice_forward_kernel(max_active_features, output_si
    const float* const bias_slice = bias + slice_offset;
    float* shared_output_slice = shared_output + slice_offset;

-    const int32_t* const feature_index_row = feature_indices + block_idx * {max_active_features};
-    const float* const feature_value_row = feature_values + block_idx * {max_active_features};
+    const int32_t* const input_index_row = input_indices + block_idx * {max_active_indices};
+    const float* const input_value_row = input_values + block_idx * {max_active_indices};

    #pragma unroll
    for (uint32_t s = 0; s < {output_thread_slice_size}; ++s)
    {{
        shared_output_slice[s] = bias_slice[s];
    }}

-    for (uint32_t k = 0; k < {max_active_features}; ++k)
+    for (uint32_t k = 0; k < {max_active_indices}; ++k)
    {{
-        const int32_t feature_index = feature_index_row[k];
-        const float feature_value = feature_value_row[k];
-        if (feature_index != -1)
+        const int32_t input_index = input_index_row[k];
+        const float input_value = input_value_row[k];
+        if (input_index != -1)
        {{
-            const float* const weight_slice = weight + feature_index * {output_size} + slice_offset;
+            const float* const weight_slice = weight + input_index * {output_size} + slice_offset;
            #pragma unroll
            for (uint32_t s = 0; s < {output_thread_slice_size}; ++s)
            {{
-                shared_output_slice[s] += weight_slice[s] * feature_value;
+                shared_output_slice[s] += weight_slice[s] * input_value;
            }}
        }} else break;
    }}
@@ -160,29 +160,29 @@ def make_feature_transformer_slice_forward_kernel(max_active_features, output_si
}}

""".format(
-                max_active_features=max_active_features,
+                max_active_indices=max_active_indices,
                output_thread_slice_size=output_thread_slice_size,
                output_size=output_size,
            ),
-            "feature_transformer_slice_forward",
+            "sparse_input_linear_forward",
        )
        kernel.compile()
-        _feature_transformer_slice_forward_kernel_cache[key] = _kernel_with_threads(
+        _sparse_input_linear_forward_kernel_cache[key] = _kernel_with_threads(
            kernel, (num_threads,)
        )
-    return _feature_transformer_slice_forward_kernel_cache[key]
+    return _sparse_input_linear_forward_kernel_cache[key]


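A hypothetical end-to-end launch of the forward kernel (the shapes and variable names below are illustrative, not from this diff), showing the padded sparse-input layout the docstring describes and the `f(grid, args)` call convention of the `_kernel_with_threads` wrapper:

import cupy as cp

batch_size, max_active_indices, num_inputs, output_size = 4, 32, 1000, 256
# Active slots come first in each row; -1 marks the end of the active prefix.
input_indices = cp.full((batch_size, max_active_indices), -1, dtype=cp.int32)
input_values = cp.zeros((batch_size, max_active_indices), dtype=cp.float32)
input_indices[0, :3] = cp.asarray([5, 17, 42], dtype=cp.int32)
input_values[0, :3] = 1.0
weight = cp.random.randn(num_inputs, output_size).astype(cp.float32)
bias = cp.zeros(output_size, dtype=cp.float32)
output = cp.empty((batch_size, output_size), dtype=cp.float32)

forward = make_sparse_input_linear_forward_kernel(max_active_indices, output_size)
# One block per batch entry; the thread shape is baked in by the wrapper.
forward((batch_size,), (input_indices, input_values, weight, bias, output))
# Each output row is bias plus the value-weighted sum of selected weight rows:
# output[b] = bias + sum_k input_values[b, k] * weight[input_indices[b, k]]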
-_feature_transformer_slice_backward_kernel_cache = dict()
+_sparse_input_linear_backward_kernel_cache = dict()


@torch.compiler.disable(recursive=False)
-def make_feature_transformer_slice_backward_kernel(max_active_features, output_size):
+def make_sparse_input_linear_backward_kernel(max_active_indices: int, output_size: int):
    """
-    @param: max_active_features
-        The maximum number of features that are active
-        (non-zero) for a single position. This value determines
-        the shape of the inputs.
+    @param: max_active_indices
+        The maximum number of indices that are non-zero for
+        a single position. This value determines the shape
+        of the inputs.
        This value is of type uint32_t.

    @param: output_size
@@ -192,8 +192,8 @@ def make_feature_transformer_slice_backward_kernel(max_active_features, output_s
    """
    num_threads = _get_num_threads_for_backward(output_size)
    output_thread_slice_size = output_size // num_threads
-    key = (max_active_features, output_size, num_threads)
-    if key not in _feature_transformer_slice_backward_kernel_cache:
+    key = (max_active_indices, output_size, num_threads)
+    if key not in _sparse_input_linear_backward_kernel_cache:
        kernel = cp.RawKernel(
            r"""

@@ -207,23 +207,23 @@ def make_feature_transformer_slice_backward_kernel(max_active_features, output_s
    The threads must have dimensionality (N,), where
    N * output_thread_slice_size == output_size.

-    @param: feature_indices
-        A matrix of shape (BATCH_SIZE, max_active_features)
-        containing indices of active features for each position
-        in a batch. Feature index of -1 means that the slot is empty
+    @param: input_indices
+        A matrix of shape (BATCH_SIZE, max_active_indices)
+        containing the indices of the non-zero inputs for each position
+        in a batch. An input index of -1 means that the slot is empty
        and the weights will not be accumulated for it. Moreover,
        no further indices from this block will be considered.
        The indices form an implicit matrix of shape
        (BATCH_SIZE, NUM_INPUTS), where the first dimension index is
        inferred from the memory location (BATCH_SIZE), and the
-        second dimension index is stored in the feature_indices matrix.
-        The type for feature indices is int32_t.
+        second dimension index is stored in the input_indices matrix.
+        The type for input indices is int32_t.

-    @param: feature_values
-        A matrix of shape (BATCH_SIZE, max_active_features)
+    @param: input_values
+        A matrix of shape (BATCH_SIZE, max_active_indices)
        containing the values (arity) of the corresponding
-        feature index in feature_indices.
-        The type for the feature value (arity) is float32.
+        input index in input_indices.
+        The type for the input value (arity) is float32.

    @param: weight_grad
        The weight gradient matrix of shape (NUM_INPUTS, output_size).
@@ -241,9 +241,9 @@ def make_feature_transformer_slice_backward_kernel(max_active_features, output_s
        An output gradient matrix of shape (BATCH_SIZE, output_size).
        Output values must have type float32.
*/
-void feature_transformer_slice_backward(
-    const int32_t* const feature_indices,
-    const float* const feature_values,
+void sparse_input_linear_backward(
+    const int32_t* const input_indices,
+    const float* const input_values,
    float* const weight_grad,
    float* const bias_grad,
    const float* const output_grad
@@ -258,8 +258,8 @@ def make_feature_transformer_slice_backward_kernel(max_active_features, output_s
    float* const bias_grad_slice = bias_grad + slice_offset;
    float* shared_output_grad_slice = shared_output_grad + slice_offset;

-    const int32_t* const feature_index_row = feature_indices + block_idx * {max_active_features};
-    const float* const feature_value_row = feature_values + block_idx * {max_active_features};
+    const int32_t* const input_index_row = input_indices + block_idx * {max_active_indices};
+    const float* const input_value_row = input_values + block_idx * {max_active_indices};

    #pragma unroll
    for (uint32_t s = 0; s < {output_thread_slice_size}; ++s)
@@ -277,35 +277,35 @@ def make_feature_transformer_slice_backward_kernel(max_active_features, output_s
        }}
    }}

-    for (uint32_t k = 0; k < {max_active_features}; ++k)
+    for (uint32_t k = 0; k < {max_active_indices}; ++k)
    {{
-        const int32_t feature_index = feature_index_row[k];
-        const float feature_value = feature_value_row[k];
-        if (feature_index != -1)
+        const int32_t input_index = input_index_row[k];
+        const float input_value = input_value_row[k];
+        if (input_index != -1)
        {{
-            float* const weight_grad_slice = weight_grad + feature_index * {output_size} + slice_offset;
+            float* const weight_grad_slice = weight_grad + input_index * {output_size} + slice_offset;
            #pragma unroll
            for (int s = 0; s < {output_thread_slice_size}; ++s)
            {{
                const float sog = shared_output_grad_slice[s];
                if (sog != 0.0f)
                {{
-                    atomicAdd(&weight_grad_slice[s], sog * feature_value);
+                    atomicAdd(&weight_grad_slice[s], sog * input_value);
                }}
            }}
        }} else break;
    }}
}}

""".format(
-                max_active_features=max_active_features,
+                max_active_indices=max_active_indices,
                output_thread_slice_size=output_thread_slice_size,
                output_size=output_size,
            ),
-            "feature_transformer_slice_backward",
+            "sparse_input_linear_backward",
        )
        kernel.compile()
-        _feature_transformer_slice_backward_kernel_cache[key] = _kernel_with_threads(
+        _sparse_input_linear_backward_kernel_cache[key] = _kernel_with_threads(
            kernel, (num_threads,)
        )
-    return _feature_transformer_slice_backward_kernel_cache[key]
+    return _sparse_input_linear_backward_kernel_cache[key]
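For reference, a hedged NumPy-style sketch (illustrative, not part of this diff) of the gradients the backward kernel accumulates: bias_grad receives the column-sums of output_grad, and each active weight row receives the value-scaled output gradient.

import numpy as np

def sparse_input_linear_backward_reference(input_indices, input_values,
                                           weight_grad, bias_grad, output_grad):
    # Mirrors the kernel's semantics: accumulate into the provided buffers.
    batch_size, max_active_indices = input_indices.shape
    bias_grad += output_grad.sum(axis=0)
    for b in range(batch_size):
        for k in range(max_active_indices):
            idx = input_indices[b, k]
            if idx == -1:
                break  # -1 ends the active prefix, as in the kernel
            weight_grad[idx] += input_values[b, k] * output_grad[b]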