Skip to content

Commit db824e2

Browse files
xu-shawnvondele
authored andcommitted
Pass accumulator caches by reference
closes #6416 No functional change
1 parent a191791 commit db824e2

File tree

5 files changed

+13
-13
lines changed

5 files changed

+13
-13
lines changed

src/evaluate.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,15 +59,15 @@ Value Eval::evaluate(const Eval::NNUE::Networks& networks,
5959
assert(!pos.checkers());
6060

6161
bool smallNet = use_smallnet(pos);
62-
auto [psqt, positional] = smallNet ? networks.small.evaluate(pos, accumulators, &caches.small)
63-
: networks.big.evaluate(pos, accumulators, &caches.big);
62+
auto [psqt, positional] = smallNet ? networks.small.evaluate(pos, accumulators, caches.small)
63+
: networks.big.evaluate(pos, accumulators, caches.big);
6464

6565
Value nnue = (125 * psqt + 131 * positional) / 128;
6666

6767
// Re-evaluate the position when higher eval accuracy is worth the time spent
6868
if (smallNet && (std::abs(nnue) < 236))
6969
{
70-
std::tie(psqt, positional) = networks.big.evaluate(pos, accumulators, &caches.big);
70+
std::tie(psqt, positional) = networks.big.evaluate(pos, accumulators, caches.big);
7171
nnue = (125 * psqt + 131 * positional) / 128;
7272
smallNet = false;
7373
}
@@ -107,7 +107,7 @@ std::string Eval::trace(Position& pos, const Eval::NNUE::Networks& networks) {
107107

108108
ss << std::showpoint << std::showpos << std::fixed << std::setprecision(2) << std::setw(15);
109109

110-
auto [psqt, positional] = networks.big.evaluate(pos, *accumulators, &caches->big);
110+
auto [psqt, positional] = networks.big.evaluate(pos, *accumulators, caches->big);
111111
Value v = psqt + positional;
112112
v = pos.side_to_move() == WHITE ? v : -v;
113113
ss << "NNUE evaluation " << 0.01 * UCIEngine::to_cp(v, pos) << " (white side)\n";

src/nnue/network.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ template<typename Arch, typename Transformer>
172172
NetworkOutput
173173
Network<Arch, Transformer>::evaluate(const Position& pos,
174174
AccumulatorStack& accumulatorStack,
175-
AccumulatorCaches::Cache<FTDimensions>* cache) const {
175+
AccumulatorCaches::Cache<FTDimensions>& cache) const {
176176

177177
constexpr uint64_t alignment = CacheLineSize;
178178

@@ -234,7 +234,7 @@ template<typename Arch, typename Transformer>
234234
NnueEvalTrace
235235
Network<Arch, Transformer>::trace_evaluate(const Position& pos,
236236
AccumulatorStack& accumulatorStack,
237-
AccumulatorCaches::Cache<FTDimensions>* cache) const {
237+
AccumulatorCaches::Cache<FTDimensions>& cache) const {
238238

239239
constexpr uint64_t alignment = CacheLineSize;
240240

src/nnue/network.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,13 +76,13 @@ class Network {
7676

7777
NetworkOutput evaluate(const Position& pos,
7878
AccumulatorStack& accumulatorStack,
79-
AccumulatorCaches::Cache<FTDimensions>* cache) const;
79+
AccumulatorCaches::Cache<FTDimensions>& cache) const;
8080

8181

8282
void verify(std::string evalfilePath, const std::function<void(std::string_view)>&) const;
8383
NnueEvalTrace trace_evaluate(const Position& pos,
8484
AccumulatorStack& accumulatorStack,
85-
AccumulatorCaches::Cache<FTDimensions>* cache) const;
85+
AccumulatorCaches::Cache<FTDimensions>& cache) const;
8686

8787
private:
8888
void load_user_net(const std::string&, const std::string&);

src/nnue/nnue_feature_transformer.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -264,12 +264,12 @@ class FeatureTransformer {
264264
// Convert input features
265265
std::int32_t transform(const Position& pos,
266266
AccumulatorStack& accumulatorStack,
267-
AccumulatorCaches::Cache<HalfDimensions>* cache,
267+
AccumulatorCaches::Cache<HalfDimensions>& cache,
268268
OutputType* output,
269269
int bucket) const {
270270

271271
using namespace SIMD;
272-
accumulatorStack.evaluate(pos, *this, *cache);
272+
accumulatorStack.evaluate(pos, *this, cache);
273273
const auto& accumulatorState = accumulatorStack.latest<PSQFeatureSet>();
274274
const auto& threatAccumulatorState = accumulatorStack.latest<ThreatFeatureSet>();
275275

src/nnue/nnue_misc.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ trace(Position& pos, const Eval::NNUE::Networks& networks, Eval::NNUE::Accumulat
124124

125125
// We estimate the value of each piece by doing a differential evaluation from
126126
// the current base eval, simulating the removal of the piece from its square.
127-
auto [psqt, positional] = networks.big.evaluate(pos, *accumulators, &caches.big);
127+
auto [psqt, positional] = networks.big.evaluate(pos, *accumulators, caches.big);
128128
Value base = psqt + positional;
129129
base = pos.side_to_move() == WHITE ? base : -base;
130130

@@ -140,7 +140,7 @@ trace(Position& pos, const Eval::NNUE::Networks& networks, Eval::NNUE::Accumulat
140140
pos.remove_piece(sq);
141141

142142
accumulators->reset();
143-
std::tie(psqt, positional) = networks.big.evaluate(pos, *accumulators, &caches.big);
143+
std::tie(psqt, positional) = networks.big.evaluate(pos, *accumulators, caches.big);
144144
Value eval = psqt + positional;
145145
eval = pos.side_to_move() == WHITE ? eval : -eval;
146146
v = base - eval;
@@ -157,7 +157,7 @@ trace(Position& pos, const Eval::NNUE::Networks& networks, Eval::NNUE::Accumulat
157157
ss << '\n';
158158

159159
accumulators->reset();
160-
auto t = networks.big.trace_evaluate(pos, *accumulators, &caches.big);
160+
auto t = networks.big.trace_evaluate(pos, *accumulators, caches.big);
161161

162162
ss << " NNUE network contributions "
163163
<< (pos.side_to_move() == WHITE ? "(White to move)" : "(Black to move)") << std::endl

0 commit comments

Comments
 (0)