Skip to content

Commit dc89b24

Browse files
authored
fix perf_sigmoid_fitter.py (official-stockfish#346)
`SparseBatchDataset` expects a list of strings but a string was given
1 parent e6b9f54 commit dc89b24

File tree

3 files changed

+10
-7
lines changed

3 files changed

+10
-7
lines changed

data_loader/dataset.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from . import stream
88
from .config import DataloaderSkipConfig
99

10+
from typing import List
11+
1012

1113
class FenBatchProvider:
1214
def __init__(
@@ -62,7 +64,7 @@ def __init__(
6264
destroy_stream,
6365
fetch_next,
6466
destroy_part,
65-
filenames,
67+
filenames: List[str],
6668
cyclic,
6769
num_workers,
6870
batch_size=None,
@@ -114,7 +116,7 @@ class SparseBatchProvider(TrainingDataProvider):
114116
def __init__(
115117
self,
116118
feature_set: str,
117-
filenames,
119+
filenames: List[str],
118120
batch_size,
119121
cyclic=True,
120122
num_workers=1,
@@ -138,7 +140,7 @@ class SparseBatchDataset(torch.utils.data.IterableDataset):
138140
def __init__(
139141
self,
140142
feature_set: str,
141-
filenames,
143+
filenames: List[str],
142144
batch_size,
143145
cyclic=True,
144146
num_workers=1,

data_loader/stream.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from ._native import c_lib
44
from .config import CDataloaderSkipConfig, DataloaderSkipConfig
55
from features.feature_set import FeatureSet
6+
from typing import List
67

78

89
def _to_c_str_array(str_list):
@@ -13,7 +14,7 @@ def _to_c_str_array(str_list):
1314

1415
def create_fen_batch_stream(
1516
concurrency,
16-
filenames,
17+
filenames: List[str],
1718
batch_size,
1819
cyclic,
1920
config: DataloaderSkipConfig,
@@ -43,7 +44,7 @@ def destroy_fen_batch(fen_batch):
4344
def create_sparse_batch_stream(
4445
feature_set: str,
4546
concurrency,
46-
filenames,
47+
filenames: List[str],
4748
batch_size,
4849
cyclic,
4950
config: DataloaderSkipConfig,

perf_sigmoid_fitter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def gather_statistics_from_batches(batches, bucket_size):
9292
return data
9393

9494

95-
def gather_statistics_from_data(filename, count, bucket_size):
95+
def gather_statistics_from_data(filename: str, count: int, bucket_size: int):
9696
"""
9797
Takes a .bin or .binpack file and produces perf% statistics
9898
The result is a dictionary of the form { eval : (perf%, count) }
@@ -105,7 +105,7 @@ def gather_statistics_from_data(filename, count, bucket_size):
105105
# this is just the easiest way to do it
106106
dataset = data_loader.SparseBatchDataset(
107107
"HalfKP",
108-
filename,
108+
[filename],
109109
batch_size,
110110
cyclic,
111111
1,

0 commit comments

Comments
 (0)