Skip to content

Commit a904a5d

Browse files
committed
Add multiple nets support to the server
Required by the new SF architecture with multiple nets, see: official-stockfish/Stockfish#4915 official-stockfish/Stockfish#5068 Also: - increment the download of the net only for a API request - improve the net code to deal with a devel0pment server built from scratch
1 parent 4bee31b commit a904a5d

File tree

5 files changed

+78
-58
lines changed

5 files changed

+78
-58
lines changed

server/fishtest/api.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -608,6 +608,9 @@ def download_nn(self):
608608
nn = self.request.rundb.get_nn(self.request.matchdict["id"])
609609
if nn is None:
610610
raise exception_response(404)
611+
else:
612+
self.request.rundb.increment_nn_downloads(self.request.matchdict["id"])
613+
611614
return HTTPFound(
612615
"https://data.stockfishchess.org/nn/" + self.request.matchdict["id"]
613616
)

server/fishtest/rundb.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
# To make this practical we will eventually put all schemas
5151
# in a separate module "schemas.py".
5252

53-
net_name = regex("nn-[a-z0-9]{12}.nnue", name="net_name")
53+
net_name = regex("nn-[a-f0-9]{12}.nnue", name="net_name")
5454
tc = regex(r"([1-9]\d*/)?\d+(\.\d+)?(\+\d+(\.\d+)?)?", name="tc")
5555
str_int = regex(r"[1-9]\d*", name="str_int")
5656
sha = regex(r"[a-f0-9]{40}", name="sha")
@@ -111,8 +111,8 @@
111111
"args": {
112112
"base_tag": str,
113113
"new_tag": str,
114-
"base_net": net_name,
115-
"new_net": net_name,
114+
"base_nets": [net_name, ...],
115+
"new_nets": [net_name, ...],
116116
"num_games": int,
117117
"tc": tc,
118118
"new_tc": tc,
@@ -322,8 +322,8 @@ def new_run(
322322
msg_new="",
323323
base_signature="",
324324
new_signature="",
325-
base_net=None,
326-
new_net=None,
325+
base_nets=None,
326+
new_nets=None,
327327
rescheduled_from=None,
328328
base_same_as_master=None,
329329
start_time=None,
@@ -344,8 +344,8 @@ def new_run(
344344
run_args = {
345345
"base_tag": base_tag,
346346
"new_tag": new_tag,
347-
"base_net": base_net,
348-
"new_net": new_net,
347+
"base_nets": base_nets,
348+
"new_nets": new_nets,
349349
"num_games": num_games,
350350
"tc": tc,
351351
"new_tc": new_tc,
@@ -477,11 +477,10 @@ def update_nn(self, net):
477477
self.nndb.update_one({"name": net["name"]}, {"$set": net})
478478

479479
def get_nn(self, name):
480-
nn = self.nndb.find_one({"name": name}, {"nn": 0})
481-
if nn:
482-
self.nndb.update_one({"name": name}, {"$inc": {"downloads": 1}})
483-
return nn
484-
return None
480+
return self.nndb.find_one({"name": name}, {"nn": 0})
481+
482+
def increment_nn_downloads(self, name):
483+
self.nndb.update_one({"name": name}, {"$inc": {"downloads": 1}})
485484

486485
def get_nns(
487486
self, user_id, user="", network_name="", master_only=False, limit=0, skip=0

server/fishtest/views.py

Lines changed: 58 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -729,35 +729,34 @@ def get_sha(branch, repo_url):
729729
return "", ""
730730

731731

732-
def get_net(commit_sha, repo_url):
733-
"""Get the net from evaluate.h or ucioption.cpp in the repo"""
732+
def get_nets(commit_sha, repo_url):
733+
"""Get the nets from evaluate.h or ucioption.cpp in the repo"""
734734
api_url = repo_url.replace(
735735
"https://github.com", "https://raw.githubusercontent.com"
736736
)
737737
try:
738-
net = None
738+
nets = []
739+
pattern = re.compile("nn-[a-f0-9]{12}.nnue")
739740

740741
url1 = api_url + "/" + commit_sha + "/src/evaluate.h"
741742
options = requests.get(url1).content.decode("utf-8")
742743
for line in options.splitlines():
743744
if "EvalFileDefaultName" in line and "define" in line:
744-
p = re.compile("nn-[a-z0-9]{12}.nnue")
745-
m = p.search(line)
745+
m = pattern.search(line)
746746
if m:
747-
net = m.group(0)
747+
nets.append(m.group(0))
748748

749-
if net:
750-
return net
749+
if nets:
750+
return nets
751751

752752
url2 = api_url + "/" + commit_sha + "/src/ucioption.cpp"
753753
options = requests.get(url2).content.decode("utf-8")
754754
for line in options.splitlines():
755755
if "EvalFile" in line and "Option" in line:
756-
p = re.compile("nn-[a-z0-9]{12}.nnue")
757-
m = p.search(line)
756+
m = pattern.search(line)
758757
if m:
759-
net = m.group(0)
760-
return net
758+
nets.append(m.group(0))
759+
return nets
761760
except:
762761
raise Exception("Unable to access developer repository: " + api_url)
763762

@@ -919,19 +918,23 @@ def strip_message(m):
919918
)
920919
data["base_same_as_master"] = master_diff.text == ""
921920

922-
# Test existence of net
923-
new_net = get_net(data["resolved_new"], data["tests_repo"])
924-
if new_net:
925-
if not request.rundb.get_nn(new_net):
926-
raise Exception(
927-
"The net {}, used by {}, is not "
928-
"known to Fishtest. Please upload it to: "
929-
"{}/upload.".format(new_net, data["new_tag"], request.host_url)
921+
# Store nets info
922+
data["base_nets"] = get_nets(data["resolved_base"], data["tests_repo"])
923+
data["new_nets"] = get_nets(data["resolved_new"], data["tests_repo"])
924+
925+
# Test existence of nets
926+
missing_nets = []
927+
for net_name in set(data["base_nets"]) | set(data["new_nets"]):
928+
net = request.rundb.get_nn(net_name)
929+
if net is None:
930+
missing_nets.append(net_name)
931+
if missing_nets:
932+
raise Exception(
933+
"Missing net(s). Please upload to: {} the following net(s): {}".format(
934+
request.host_url,
935+
", ".join(missing_nets),
930936
)
931-
932-
# Store net info
933-
data["new_net"] = new_net
934-
data["base_net"] = get_net(data["resolved_base"], data["tests_repo"])
937+
)
935938

936939
# Integer parameters
937940
data["threads"] = int(request.POST["threads"])
@@ -1003,25 +1006,32 @@ def del_tasks(run):
10031006
def update_nets(request, run):
10041007
run_id = str(run["_id"])
10051008
data = run["args"]
1009+
base_nets, new_nets, missing_nets = [], [], []
1010+
for net_name in set(data["base_nets"]) | set(data["new_nets"]):
1011+
net = request.rundb.get_nn(net_name)
1012+
if net is None:
1013+
# This should never happen
1014+
missing_nets.append(net_name)
1015+
else:
1016+
if net_name in data["base_nets"]:
1017+
base_nets.append(net)
1018+
if net_name in data["new_nets"]:
1019+
new_nets.append(net)
1020+
if missing_nets:
1021+
raise Exception(
1022+
"Missing net(s). Please upload to {} the following net(s): {}".format(
1023+
request.host_url,
1024+
", ".join(missing_nets),
1025+
)
1026+
)
1027+
10061028
if run["base_same_as_master"]:
1007-
base_net = data["base_net"]
1008-
if base_net:
1009-
net = request.rundb.get_nn(base_net)
1010-
if not net:
1011-
# Should never happen:
1012-
raise Exception(
1013-
"The net {}, used by {}, is not "
1014-
"known to Fishtest. Please upload it to: "
1015-
"{}/upload.".format(base_net, data["base_tag"], request.host_url)
1016-
)
1029+
for net in base_nets:
10171030
if "is_master" not in net:
10181031
net["is_master"] = True
10191032
request.rundb.update_nn(net)
1020-
new_net = data["new_net"]
1021-
if new_net:
1022-
net = request.rundb.get_nn(new_net)
1023-
if not net:
1024-
return
1033+
1034+
for net in new_nets:
10251035
if "first_test" not in net:
10261036
net["first_test"] = {"id": run_id, "date": datetime.now(timezone.utc)}
10271037
net["last_test"] = {"id": run_id, "date": datetime.now(timezone.utc)}
@@ -1228,7 +1238,10 @@ def tests_approve(request):
12281238
if run is None:
12291239
request.session.flash(message, "error")
12301240
else:
1231-
update_nets(request, run)
1241+
try:
1242+
update_nets(request, run)
1243+
except Exception as e:
1244+
request.session.flash(str(e), "error")
12321245
request.actiondb.approve_run(username=username, run=run)
12331246
cached_flash(request, message)
12341247
return home(request)
@@ -1367,11 +1380,13 @@ def tests_view(request):
13671380
"new_options",
13681381
"resolved_new",
13691382
"new_net",
1383+
"new_nets",
13701384
"base_tag",
13711385
"base_signature",
13721386
"base_options",
13731387
"resolved_base",
13741388
"base_net",
1389+
"base_nets",
13751390
"sprt",
13761391
"num_games",
13771392
"spsa",
@@ -1401,6 +1416,9 @@ def tests_view(request):
14011416
if name == "base_tag" and "msg_base" in run["args"]:
14021417
value += " (" + run["args"]["msg_base"][:50] + ")"
14031418

1419+
if name in ("new_nets", "base_nets"):
1420+
value = ", ".join(value)
1421+
14041422
if name == "sprt" and value != "-":
14051423
value = "elo0: {:.2f} alpha: {:.2f} elo1: {:.2f} beta: {:.2f} state: {} ({})".format(
14061424
value["elo0"],

server/tests/test_api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@ def new_run(self, add_tasks=0):
3939
msg_new="Super stuff",
4040
base_signature="123456",
4141
new_signature="654321",
42-
base_net="nn-0000000000a0.nnue",
43-
new_net="nn-0000000000a0.nnue",
42+
base_nets=["nn-0000000000a0.nnue"],
43+
new_nets=["nn-0000000000a0.nnue", "nn-0000000000a1.nnue"],
4444
rescheduled_from="653db116cc309ae839563103",
4545
base_same_as_master=False,
4646
tests_repo="https://google.com",

server/tests/test_rundb.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,8 @@ def test_10_create_run(self):
7272
msg_new="Super stuff",
7373
base_signature="123456",
7474
new_signature="654321",
75-
base_net="nn-0000000000a0.nnue",
76-
new_net="nn-0000000000a0.nnue",
75+
base_nets=["nn-0000000000a0.nnue"],
76+
new_nets=["nn-0000000000a0.nnue", "nn-0000000000a1.nnue"],
7777
rescheduled_from="653db116cc309ae839563103",
7878
base_same_as_master=False,
7979
tests_repo="https://google.com",
@@ -113,8 +113,8 @@ def test_10_create_run(self):
113113
msg_new="Super stuff",
114114
base_signature="123456",
115115
new_signature="654321",
116-
base_net="nn-0000000000a0.nnue",
117-
new_net="nn-0000000000a0.nnue",
116+
base_nets=["nn-0000000000a0.nnue"],
117+
new_nets=["nn-0000000000a0.nnue", "nn-0000000000a1.nnue"],
118118
rescheduled_from="653db116cc309ae839563103",
119119
base_same_as_master=False,
120120
tests_repo="https://google.com",

0 commit comments

Comments
 (0)