@@ -271,6 +271,9 @@ typedef struct {
271271 Tensor inputs; // the input tokens for the current forward pass
272272 Tensor targets; // the target tokens for the current forward pass
273273 float mean_loss; // after a forward pass with targets, will be populated with the mean loss
274+ float * mean_loss_buffer;
275+
276+ Tensor nullTensor;
274277
275278 // kernels
276279 Kernels kernels;
@@ -372,6 +375,8 @@ void gpt2_build_from_checkpoint(Context& ctx, GPT2 *model, const char* checkpoin
372375 model->batch_size = 0 ;
373376 model->seq_len = 0 ;
374377 model->mean_loss = -1 .0f ; // -1.0f will designate no loss
378+ // Allocate batch_size * seq_len floats for the mean loss (NOTE: both are still 0 here, so this requests 0 bytes)
379+ model->mean_loss_buffer = (float *)mallocCheck (sizeof (float ) * model->batch_size * model->seq_len );
375380
376381 printf (" Model build complete\n " );
377382
@@ -474,6 +479,7 @@ void gpt2_forward(Context& ctx, GPT2 *model, Tensor& inputs, Tensor& targets, si
474479 /* input=*/ model->acts .residual3 [L-1 ], /* weight=*/ model->params .lnfw , /* bias=*/ model->params .lnfb ,
475480 B, T, C);
476481 Tensor nullTensor = createTensor (ctx, Shape{1 }, kf32);
482+ model->nullTensor = nullTensor;
477483 kernels.matmul_final_forward = matmul_forward (ctx, model->acts .logits , model->acts .lnf , model->params .wte , nullTensor, B, T, C, Vp);
478484 kernels.softmax_final_forward = softmax_forward (ctx, model->acts .probs , model->acts .logits , B, T, V, Vp);
479485 kernels.crossentropy_softmax_backward = crossentropy_softmax_backward (ctx, model->acts .logits , model->acts .losses , model->acts .probs , targets, B, T, V, Vp);
@@ -829,8 +835,9 @@ void gpt2_free(GPT2 *model) {
829835 free (model->v_memory );
830836 free (model->acts_memory );
831837 free (model->grads_acts_memory );
832- free (model->inputs );
833- free (model->targets );
838+ // free(model->inputs);
839+ // free(model->targets);
840+ free (model->mean_loss_buffer );
834841}
835842
836843#ifndef TESTING
@@ -874,9 +881,6 @@ int main() {
874881 .requiredLimits = &requiredLimits
875882 });
876883
877- Continue!
878-
879- ```cpp
880884 // build the GPT-2 model from a checkpoint
881885 GPT2 model;
882886 gpt2_build_from_checkpoint (ctx, &model, " gpt2_124M.bin" );
@@ -903,11 +907,14 @@ Continue!
903907
904908 // some memory for generating samples from the model
905909 uint64_t rng_state = 1337 ;
906- int * gen_tokens = (int *)mallocCheck (B * T * sizeof (int ));
910+ // int* gen_tokens = (int*)mallocCheck(B * T * sizeof(int));
907911 const int genT = 64 ; // number of steps of inference we will do
908912
909913 // train
910914 struct timespec start, end;
915+ Tensor inputs = createTensor (ctx, Shape{B, T}, ki32);
916+ Tensor targets = createTensor (ctx, Shape{B, T}, ki32);
917+ Tensor gen_tokens = createTensor (ctx, Shape{B, T}, ki32);
911918 printf (" Starting training\n " );
912919 for (int step = 0 ; step <= 40 ; step++) {
913920 printf (" Step %d\n " , step);
@@ -918,7 +925,9 @@ Continue!
918925 dataloader_reset (&val_loader);
919926 for (int i = 0 ; i < val_num_batches; i++) {
920927 dataloader_next_batch (&val_loader);
921- gpt2_forward (ctx, &model, val_loader.inputs , val_loader.targets , B, T);
928+ toGPU (ctx, val_loader.inputs , inputs);
929+ toGPU (ctx, val_loader.targets , targets);
930+ gpt2_forward (ctx, &model, inputs, targets, B, T);
922931 val_loss += model.mean_loss ;
923932 }
924933 val_loss /= val_num_batches;
@@ -928,17 +937,15 @@ Continue!
928937 // once in a while do model inference to print generated text
929938 if (step > 0 && step % 20 == 0 ) {
930939 // fill up gen_tokens with the GPT2_EOT, which kicks off the generation
931- for (int i = 0 ; i < B * T; ++i) {
932- gen_tokens[i] = tokenizer.eot_token ;
933- }
940+ toGPU (ctx, tokenizer.eot_token , gen_tokens);
934941 // now sample from the model autoregressively
935942 printf (" generating:\n ---\n " );
936943 for (int t = 1 ; t < genT; t++) {
937944 // note that inference is very wasteful here because for each token
938945 // we re-calculate the forward pass for all of (B,T) positions from scratch
939946 // but the inference here is just for sanity checking anyway
940947 // and we can maybe optimize a bit more later, with careful tests
941- gpt2_forward (ctx, &model, gen_tokens, NULL , B, T);
948+ gpt2_forward (ctx, &model, gen_tokens, model. nullTensor , B, T);
942949 // furthermore, below we're only using b=0 (i.e. the first row) of all B rows
943950 // we're in principle running B "inference streams" in parallel here
944951 // but only using position 0
@@ -981,7 +988,7 @@ Continue!
981988 dataloader_free (&val_loader);
982989 tokenizer_free (&tokenizer);
983990 gpt2_free (&model);
984- free (gen_tokens);
991+ // free(gen_tokens);
985992 return 0 ;
986993}
987994#endif
0 commit comments