Skip to content

Commit 1d8e435

Browse files
Update
1 parent 43e4ac0 commit 1d8e435

1 file changed

Lines changed: 19 additions & 12 deletions

File tree

experimental/kernels/gpt2_webgpu_aot.cpp

Lines changed: 19 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -271,6 +271,9 @@ typedef struct {
271271
Tensor inputs; // the input tokens for the current forward pass
272272
Tensor targets; // the target tokens for the current forward pass
273273
float mean_loss; // after a forward pass with targets, will be populated with the mean loss
274+
float* mean_loss_buffer;
275+
276+
Tensor nullTensor;
274277

275278
// kernels
276279
Kernels kernels;
@@ -372,6 +375,8 @@ void gpt2_build_from_checkpoint(Context& ctx, GPT2 *model, const char* checkpoin
372375
model->batch_size = 0;
373376
model->seq_len = 0;
374377
model->mean_loss = -1.0f; // -1.0f will designate no loss
378+
// Allocate B * T buffer for mean loss (NOTE: batch_size and seq_len were just set to 0 above, so this requests 0 bytes)
379+
model->mean_loss_buffer = (float*)mallocCheck(sizeof(float) * model->batch_size * model->seq_len);
375380

376381
printf("Model build complete\n");
377382

@@ -474,6 +479,7 @@ void gpt2_forward(Context& ctx, GPT2 *model, Tensor& inputs, Tensor& targets, si
474479
/*input=*/ model->acts.residual3[L-1], /*weight=*/ model->params.lnfw, /*bias=*/ model->params.lnfb,
475480
B, T, C);
476481
Tensor nullTensor = createTensor(ctx, Shape{1}, kf32);
482+
model->nullTensor = nullTensor;
477483
kernels.matmul_final_forward = matmul_forward(ctx, model->acts.logits, model->acts.lnf, model->params.wte, nullTensor, B, T, C, Vp);
478484
kernels.softmax_final_forward = softmax_forward(ctx, model->acts.probs, model->acts.logits, B, T, V, Vp);
479485
kernels.crossentropy_softmax_backward = crossentropy_softmax_backward(ctx, model->acts.logits, model->acts.losses, model->acts.probs, targets, B, T, V, Vp);
@@ -829,8 +835,9 @@ void gpt2_free(GPT2 *model) {
829835
free(model->v_memory);
830836
free(model->acts_memory);
831837
free(model->grads_acts_memory);
832-
free(model->inputs);
833-
free(model->targets);
838+
// free(model->inputs);
839+
// free(model->targets);
840+
free(model->mean_loss_buffer);
834841
}
835842

836843
#ifndef TESTING
@@ -874,9 +881,6 @@ int main() {
874881
.requiredLimits = &requiredLimits
875882
});
876883

877-
Continue!
878-
879-
```cpp
880884
// build the GPT-2 model from a checkpoint
881885
GPT2 model;
882886
gpt2_build_from_checkpoint(ctx, &model, "gpt2_124M.bin");
@@ -903,11 +907,14 @@ Continue!
903907

904908
// some memory for generating samples from the model
905909
uint64_t rng_state = 1337;
906-
int* gen_tokens = (int*)mallocCheck(B * T * sizeof(int));
910+
// int* gen_tokens = (int*)mallocCheck(B * T * sizeof(int));
907911
const int genT = 64; // number of steps of inference we will do
908912

909913
// train
910914
struct timespec start, end;
915+
Tensor inputs = createTensor(ctx, Shape{B, T}, ki32);
916+
Tensor targets = createTensor(ctx, Shape{B, T}, ki32);
917+
Tensor gen_tokens = createTensor(ctx, Shape{B, T}, ki32);
911918
printf("Starting training\n");
912919
for (int step = 0; step <= 40; step++) {
913920
printf("Step %d\n", step);
@@ -918,7 +925,9 @@ Continue!
918925
dataloader_reset(&val_loader);
919926
for (int i = 0; i < val_num_batches; i++) {
920927
dataloader_next_batch(&val_loader);
921-
gpt2_forward(ctx, &model, val_loader.inputs, val_loader.targets, B, T);
928+
toGPU(ctx, val_loader.inputs, inputs);
929+
toGPU(ctx, val_loader.targets, targets);
930+
gpt2_forward(ctx, &model, inputs, targets, B, T);
922931
val_loss += model.mean_loss;
923932
}
924933
val_loss /= val_num_batches;
@@ -928,17 +937,15 @@ Continue!
928937
// once in a while do model inference to print generated text
929938
if (step > 0 && step % 20 == 0) {
930939
// fill up gen_tokens with the GPT2_EOT, which kicks off the generation
931-
for(int i = 0; i < B * T; ++i) {
932-
gen_tokens[i] = tokenizer.eot_token;
933-
}
940+
toGPU(ctx, tokenizer.eot_token, gen_tokens);
934941
// now sample from the model autoregressively
935942
printf("generating:\n---\n");
936943
for (int t = 1; t < genT; t++) {
937944
// note that inference is very wasteful here because for each token
938945
// we re-calculate the forward pass for all of (B,T) positions from scratch
939946
// but the inference here is just for sanity checking anyway
940947
// and we can maybe optimize a bit more later, with careful tests
941-
gpt2_forward(ctx, &model, gen_tokens, NULL, B, T);
948+
gpt2_forward(ctx, &model, gen_tokens, model.nullTensor, B, T);
942949
// furthermore, below we're only using b=0 (i.e. the first row) of all B rows
943950
// we're in principle running B "inference streams" in parallel here
944951
// but only using position 0
@@ -981,7 +988,7 @@ Continue!
981988
dataloader_free(&val_loader);
982989
tokenizer_free(&tokenizer);
983990
gpt2_free(&model);
984-
free(gen_tokens);
991+
// free(gen_tokens);
985992
return 0;
986993
}
987994
#endif

0 commit comments

Comments
 (0)