Skip to content

Commit db817bf

Browse files
committed
Fix validation loss computation with collapsed nets
1 parent d70dc6a commit db817bf

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

trainer.c

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ int main(int argc, char **argv) {
8888

8989
current_iteration++;
9090

91-
collapse_network(nncpy, nn);
91+
collapse_network(nncpy, nn); // evaluate_network() expects collapsed
9292

9393
#pragma omp parallel for schedule(static) num_threads(NTHREADS) reduction(+:loss)
9494
for (int i = batch * BATCHSIZE; i < (batch+1) * BATCHSIZE; i++) {
@@ -115,10 +115,12 @@ int main(int argc, char **argv) {
115115

116116
/// Verify by iterating over each of the Validation Samples
117117

118+
collapse_network(nncpy, nn); // evaluate_network() expects collapsed
119+
118120
#pragma omp parallel for schedule(static) num_threads(NTHREADS) reduction(+:vloss)
119121
for (uint64_t i = 0; i < NVALIDATE; i++) {
120122
const int tidx = omp_get_thread_num();
121-
evaluate_network(nn, evals[tidx], &validate[i]);
123+
evaluate_network(nncpy, evals[tidx], &validate[i]);
122124
vloss += LOSS_FUNC(&validate[i], evals[tidx]->activated[nn->layers-1]);
123125
}
124126

0 commit comments

Comments
 (0)