
Commit b3becd3

ftperm type hints (official-stockfish#358)
* ftperm type hints
* add typing to indices
1 parent d4a5bb2 commit b3becd3

File tree

1 file changed: +78 -53 lines changed

ftperm.py

Lines changed: 78 additions & 53 deletions
@@ -31,13 +31,17 @@
 
 """
 
-import time
 import argparse
-import chess
-import torch
 import copy
+from dataclasses import dataclass
+import time
+from typing import Callable, Generator, TypeAlias
+
+import chess
 import cupy as cp
 import numpy as np
+import numpy.typing as npt
+import torch
 
 import data_loader
 import model as M
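Note: the reordered imports above set up the annotation style used throughout the rest of this commit: `numpy.typing.NDArray` parameterized by a dtype, plus `Callable`, `Generator`, and `TypeAlias` from `typing`. A minimal sketch of that style, using illustrative names that are not part of the commit:

from typing import Callable, TypeAlias

import numpy as np
import numpy.typing as npt

# An alias naming a function shape: boolean activation matrix in, integer count out.
ScoreFn: TypeAlias = Callable[[npt.NDArray[np.bool_]], int]


def count_all_zero_rows(actmat: npt.NDArray[np.bool_]) -> int:
    # "True" marks a zero activation, the same convention noted in get_score_change.
    return int(np.count_nonzero(np.all(actmat, axis=1)))


rows: int = count_all_zero_rows(np.ones((3, 8), dtype=np.bool_))  # 3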
@@ -61,7 +65,7 @@
 VERBOSE = False
 
 
-def batched(arr, batch_size):
+def batched(arr: npt.NDArray, batch_size: int) -> Generator[npt.NDArray, None, None]:
     """
     Utility generator that yields chunks of array `arr` of size `batch_size`
     Expects arr to be a numpy-like array
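Note: the diff only shows `batched`'s new signature and docstring. The sketch below assumes the body (only the `idx += batch_size` step is visible in the next hunk) and illustrates what the `Generator[npt.NDArray, None, None]` annotation expresses: the generator yields array chunks, accepts nothing via `send`, and returns nothing.

from typing import Generator

import numpy as np
import numpy.typing as npt


def batched_sketch(arr: npt.NDArray, batch_size: int) -> Generator[npt.NDArray, None, None]:
    # Assumed body: yield consecutive chunks of `arr` of size `batch_size`.
    idx = 0
    while idx < len(arr):
        yield arr[idx : idx + batch_size]
        idx += batch_size


chunks = list(batched_sketch(np.arange(10), 4))  # arrays of length 4, 4, 2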
@@ -73,14 +77,14 @@ def batched(arr, batch_size):
         idx += batch_size
 
 
-def apply_swap(perm, i, j):
+def apply_swap(perm: npt.NDArray, i: int, j: int) -> None:
     """
     Swap `i`-th and `j`-th elements in the array `perm`.
     """
     perm[i], perm[j] = perm[j], perm[i]
 
 
-def apply_rotate_right(perm, indices):
+def apply_rotate_right(perm: npt.NDArray, indices: tuple[int, ...]) -> None:
     """
     Rotates right the values in `perm` at selected indices `indices`.
     The rotation is performed as-if the selected indices were layed out in the order
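Note: the rotate-right semantics described in `apply_rotate_right`'s docstring can be illustrated with a small standalone example. This re-implements the documented behavior inline rather than calling the module's function:

import numpy as np

perm = np.arange(6)                     # [0, 1, 2, 3, 4, 5]
indices = (0, 2, 5)
values = [perm[i] for i in indices]     # values at the selected indices: [0, 2, 5]
rotated = [values[-1]] + values[:-1]    # rotate right: [5, 0, 2]
for i, v in zip(indices, rotated):
    perm[i] = v
print(perm)                             # [5 1 0 3 4 2]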
@@ -92,7 +96,9 @@ def apply_rotate_right(perm, indices):
         perm[i] = j
 
 
-def get_swapped_zero_positive_count(actmat_flat, use_cupy=True):
+def get_swapped_zero_positive_count(
+    actmat_flat: npt.NDArray[np.bool_], use_cupy: bool = True
+) -> int:
     if use_cupy:
         actmat_flat = cp.asarray(actmat_flat, dtype=cp.int8)
 
@@ -148,7 +154,9 @@ def get_swapped_zero_positive_count(actmat_flat, use_cupy=True):
     return swapped_zero_count
 
 
-def get_swapped_zero_increase(actmat, use_cupy=True):
+def get_swapped_zero_increase(
+    actmat: npt.NDArray[np.bool_], use_cupy: bool = True
+) -> npt.NDArray[np.int_]:
     n_neurons = actmat.shape[1]
     swapped_zero_count = 0
 
@@ -179,7 +187,9 @@ def get_swapped_zero_increase(actmat, use_cupy=True):
     return swapped_zero_increase
 
 
-def get_score_change(actmat, use_cupy=True):
+def get_score_change(
+    actmat: npt.NDArray[np.bool_], use_cupy: bool = True
+) -> npt.NDArray[np.int_]:
     # actmat is a boolean matrix of shape (N, L1) with "True" meaning 0
 
     n_neurons = actmat.shape[1]
@@ -193,7 +203,16 @@ def get_score_change(actmat, use_cupy=True):
     return score_change
 
 
-def make_swaps_2(actmat, use_cupy=True):
+@dataclass
+class SwapResult:
+    swaps: list[tuple[int, ...]]
+    score_change: float
+
+
+SwapFucntion: TypeAlias = Callable[[npt.NDArray[np.bool_], bool], SwapResult]
+
+
+def make_swaps_2(actmat: npt.NDArray[np.bool_], use_cupy: bool = True) -> SwapResult:
     """
     Returns a series of independent 2-swap operations that collectively improve the objective function.
     """
@@ -210,7 +229,7 @@ def make_swaps_2(actmat, use_cupy=True):
     # Sum score_change[i, j] + score_change[j, i] to get the cumulative impact of the swap.
     score_change = score_change + score_change.T
 
-    def all_indices_in_same_block(i):
+    def all_indices_in_same_block(i: int) -> list[int]:
         """Returns a list of indices of all neurons in the same block as the i-th neuron."""
         # Floor to the start of the block.
         base = i // ZERO_BLOCK_SIZE * ZERO_BLOCK_SIZE
@@ -248,10 +267,10 @@ def all_indices_in_same_block(i):
     print(f"Time elapsed: {time.time() - start_time:0.3f}")
     print(f"Improvement this iteration: {total_improvement:0.3f}")
 
-    return swaps, total_improvement
+    return SwapResult(swaps, total_improvement)
 
 
-def make_swaps_3(actmat, use_cupy=True):
+def make_swaps_3(actmat: npt.NDArray[np.bool_], use_cupy: bool = True) -> SwapResult:
     """
     Returns a series of independent left-rotates operations that collectively improve the objective function.
     """
@@ -340,10 +359,12 @@ def make_swaps_3(actmat, use_cupy=True):
     total_improvement = total_score_change / n_samples / (n_neurons // 4) * 100
     print(f"Time elapsed: {time.time() - start_time:0.3f}")
     print(f"Improvement this iteration: {total_improvement:0.3f}")
-    return cycles, total_improvement
+    return SwapResult(cycles, total_improvement)
 
 
-def find_perm_impl(actmat, use_cupy, L1: int):
+def find_perm_impl(
+    actmat: npt.NDArray[np.bool_], use_cupy: bool, L1: int
+) -> npt.NDArray[np.int_]:
     actmat = np.reshape(actmat, (actmat.shape[0] * 2, actmat.shape[1] // 2))
     if use_cupy:
         actmat = cp.asarray(actmat, dtype=cp.int8)
@@ -352,7 +373,7 @@ def find_perm_impl(actmat, use_cupy, L1: int):
     total_score_change = 0
     perm = np.arange(L1 // 2)
 
-    stages = [make_swaps_2, make_swaps_3]
+    stages: list[SwapFucntion] = [make_swaps_2, make_swaps_3]
     # The optimization routines are deterministic, so no need to retry.
     stages_max_fails = [0, 0]
     stage_id = 0
@@ -370,15 +391,15 @@ def find_perm_impl(actmat, use_cupy, L1: int):
 
         # Calculate a set of independent right rotates (so swaps for 2 element case)
         # that when applied improve the objective function
-        swaps, score_change = swap_fn(actmat, use_cupy)
-        for cycle in swaps:
+        swap_result = swap_fn(actmat, use_cupy)
+        for cycle in swap_result.swaps:
             # Update the current best permutation with the newly found adjustments.
             apply_rotate_right(perm, cycle)
 
-        total_score_change += score_change
+        total_score_change += swap_result.score_change
         print(f"Total improvement: {total_score_change}\n")
 
-        if score_change == 0:
+        if swap_result.score_change == 0:
             num_fails += 1
             if num_fails > stages_max_fails[stage_id]:
                 num_fails = 0
@@ -399,17 +420,19 @@ def find_perm_impl(actmat, use_cupy, L1: int):
 
 
 def read_model(
-    nnue_path,
+    nnue_path: str,
     feature_set: FeatureSet,
     config: ModelConfig,
     quantize_config: QuantizationConfig,
-):
+) -> NNUEModel:
     with open(nnue_path, "rb") as f:
         reader = NNUEReader(f, feature_set, config, quantize_config)
         return reader.model
 
 
-def make_fen_batch_provider(data_path, batch_size):
+def make_fen_batch_provider(
+    data_path: str, batch_size: int
+) -> data_loader.FenBatchProvider:
     return data_loader.FenBatchProvider(
         data_path,
         True,
@@ -421,7 +444,7 @@ def make_fen_batch_provider(data_path, batch_size):
     )
 
 
-def filter_fens(fens):
+def filter_fens(fens: list[str]) -> list[str]:
     # We don't want fens where a king is in check, as these cannot be evaluated by the engine.
     filtered_fens = []
     for fen in fens:
@@ -431,7 +454,7 @@ def filter_fens(fens):
     return filtered_fens
 
 
-def quantize_ft(model: NNUEModel):
+def quantize_ft(model: NNUEModel) -> None:
     model.input.weight.data = model.input.weight.data.mul(
         model.quantization.quantized_one
     ).round()
@@ -441,16 +464,16 @@ def quantize_ft(model: NNUEModel):
 
 
 def forward_ft(
-    model,
-    us,
-    them,
-    white_indices,
-    white_values,
-    black_indices,
-    black_values,
-    psqt_indices,
-    layer_stack_indices,
-):
+    model: NNUEModel,
+    us: torch.Tensor,
+    them: torch.Tensor,
+    white_indices: torch.Tensor,
+    white_values: torch.Tensor,
+    black_indices: torch.Tensor,
+    black_values: torch.Tensor,
+    psqt_indices: torch.Tensor,
+    layer_stack_indices: torch.Tensor,
+) -> torch.Tensor:
     wp, bp = model.input(white_indices, white_values, black_indices, black_values)
     w, _ = torch.split(wp, model.L1, dim=1)
     b, _ = torch.split(bp, model.L1, dim=1)
@@ -466,7 +489,7 @@ def forward_ft(
     return l0_.round()
 
 
-def eval_ft(model, batch: data_loader.SparseBatchPtr):
+def eval_ft(model: NNUEModel, batch: data_loader.SparseBatchPtr) -> torch.Tensor:
     with torch.no_grad():
         (
             us,
@@ -494,8 +517,8 @@ def eval_ft(model, batch: data_loader.SparseBatchPtr):
         return res
 
 
-def ft_permute_impl(model, permutation):
-    permutation = list(permutation)
+def ft_permute_impl(model: NNUEModel, perm: npt.NDArray[np.int_]) -> None:
+    permutation = list(perm)
 
     l1_size = model.layer_stacks.l1.in_features
     if l1_size != len(permutation) * 2:
@@ -517,14 +540,14 @@ def ft_permute_impl(model, permutation):
     ]
 
 
-def ft_permute(model, ft_perm_path):
+def ft_permute(model: NNUEModel, ft_perm_path: str) -> None:
     with open(ft_perm_path, "rb") as f:
         permutation = np.load(f)
 
     ft_permute_impl(model, permutation)
 
 
-def gather_impl(model: NNUEModel, dataset, count):
+def gather_impl(model: NNUEModel, dataset: str, count: int) -> npt.NDArray[np.bool_]:
     ZERO_POINT = 0.0  # Vary this to check hypothetical forced larger truncation to zero
     BATCH_SIZE = 1000
 
@@ -559,7 +582,7 @@ def gather_impl(model: NNUEModel, dataset, count):
     return np.concatenate(actmats, axis=0)
 
 
-def command_gather(args):
+def command_gather(args: argparse.Namespace) -> None:
     feature_set = M.get_feature_set_from_name(args.features)
     if args.checkpoint:
         nnue = NNUE.load_from_checkpoint(
@@ -582,13 +605,15 @@ def command_gather(args):
         np.save(file, actmat)
 
 
-def eval_act_mat(actmat):
+def eval_act_mat(actmat: npt.NDArray[np.bool_]) -> float:
     actmat = actmat.reshape((actmat.shape[0], actmat.shape[1] // 4, 4))
     r = np.all(actmat, axis=2)
     return np.count_nonzero(r) / r.shape[0] / r.shape[1]
 
 
-def eval_perm_impl(actmat, perm=None):
+def eval_perm_impl(
+    actmat: npt.NDArray[np.bool_], perm: npt.NDArray[np.int_] | None = None
+) -> None:
     actmat = np.reshape(actmat, (actmat.shape[0] * 2, actmat.shape[1] // 2))
 
     actmat_eval = eval_act_mat(actmat)
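Note: `eval_act_mat` reports the fraction of 4-neuron blocks whose activations are all zero (True). A toy walkthrough of the computation shown in the hunk above, with an illustrative 2x8 matrix:

import numpy as np

# 2 samples x 8 neurons; True means "activation is zero".
actmat = np.array(
    [
        [True, True, True, True, False, True, True, True],
        [True, True, True, True, True, True, True, True],
    ]
)
blocks = actmat.reshape((actmat.shape[0], actmat.shape[1] // 4, 4))
r = np.all(blocks, axis=2)
print(np.count_nonzero(r) / r.shape[0] / r.shape[1])  # 0.75: 3 of the 4 blocks are fully zero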
@@ -600,7 +625,7 @@ def eval_perm_impl(actmat, perm=None):
     print(f"Combined zeros in perm matrix: {perm_act_mat_eval * 100:0.6f}")
 
 
-def command_eval_perm(args):
+def command_eval_perm(args: argparse.Namespace) -> None:
     with open(args.data, "rb") as file:
         actmat = np.load(file)
 
@@ -613,7 +638,7 @@ def command_eval_perm(args):
     eval_perm_impl(actmat, perm)
 
 
-def command_find_perm(args):
+def command_find_perm(args: argparse.Namespace) -> None:
     with open(args.data, "rb") as file:
         actmat = np.load(file)
 
@@ -626,12 +651,12 @@ def command_find_perm(args):
 
 def ft_optimize(
     model: NNUEModel,
-    dataset_path,
-    count,
-    actmat_save_path=None,
-    perm_save_path=None,
-    use_cupy=True,
-):
+    dataset_path: str,
+    count: int,
+    actmat_save_path: str | None = None,
+    perm_save_path: str | None = None,
+    use_cupy: bool = True,
+) -> None:
     print("Gathering activation data...")
     actmat = gather_impl(model, dataset_path, count)
     if actmat_save_path is not None:
@@ -651,12 +676,12 @@ def ft_optimize(
     ft_permute_impl(model, perm)
 
 
-def set_cupy_device(device):
+def set_cupy_device(device: int) -> None:
     if device is not None:
         cp.cuda.runtime.setDevice(device)
 
 
-def main():
+def main() -> None:
     parser = argparse.ArgumentParser(description="")
     parser.add_argument(
         "--no-cupy",
