Scaling Logical Inference: a high-performance pipeline for post-training Gemma 3 1B into a specialized reasoning engine using Tunix, JAX, and Cloud TPU v3-8 / v5e-1
This project focuses on the post-training phase of Large Language Models to enhance logical deduction and multi-step reasoning. By leveraging the JAX/Flax ecosystem, we achieve high training throughput on TPU v3-8 hardware, overcoming the bottlenecks that dynamic shapes traditionally introduce in transformer training.
- Reasoning Alignment: Transforming general-purpose knowledge into structured "Chain of Thought" (CoT) logic.
- JAX Optimization: Solving the XLA "compilation hang" and memory fragmentation on TPUs.
- Hardware Efficiency: Utilizing a TPU v3-8 mesh with LoRA (Low-Rank Adaptation) via Tunix.
- Model: Google Gemma 3 (1B Variant)
- Framework: Tunix (Post-training framework for JAX/Flax)
- Compute: Google Cloud TPU v3-8
- Optimization: Optax (AdamW with a cosine learning-rate schedule; see the sketch after this list)
- Data: Hugging Face Datasets (Multi-domain rebalanced)
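The optimizer follows the standard Optax pattern of an AdamW transform driven by a warmup-plus-cosine-decay schedule. The sketch below uses hypothetical hyperparameters (peak learning rate, warmup, and total steps are placeholders, not the values used in this run):

```python
import optax

# Hypothetical hyperparameters; the real run config defines these values.
TOTAL_STEPS = 1_000
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,      # start from zero during warmup
    peak_value=2e-4,     # peak learning rate after warmup
    warmup_steps=100,
    decay_steps=TOTAL_STEPS,
    end_value=1e-5,      # floor of the cosine decay
)
optimizer = optax.adamw(learning_rate=schedule, weight_decay=0.01)
```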
TPUs require fixed-size buffers, but standard tokenization produces "jagged" variable-length arrays that break XLA compilation.
- Solution: Implemented a custom data collator that enforces a strict `MAX_TARGET_LENGTH`, padding or truncating every sequence so that each batch is a perfect rectangle and JAX can compile the computation graph exactly once (a minimal sketch follows).
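A minimal sketch of such a collator, assuming a placeholder pad id and length (the real implementation lives in the training pipeline):

```python
import numpy as np

MAX_TARGET_LENGTH = 512  # hypothetical fixed sequence length
PAD_TOKEN_ID = 0         # assumption: pad id of the tokenizer in use

def collate_fixed_length(examples):
    """Pad or truncate every sequence to MAX_TARGET_LENGTH so each batch
    has an identical (batch, MAX_TARGET_LENGTH) shape for XLA."""
    input_ids = np.full((len(examples), MAX_TARGET_LENGTH), PAD_TOKEN_ID, dtype=np.int32)
    attention_mask = np.zeros((len(examples), MAX_TARGET_LENGTH), dtype=np.int32)
    for i, ex in enumerate(examples):
        ids = ex["input_ids"][:MAX_TARGET_LENGTH]
        input_ids[i, : len(ids)] = ids
        attention_mask[i, : len(ids)] = 1
    return {"input_ids": input_ids, "attention_mask": attention_mask}
```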
Gemma 3's multi-head attention expects specific 4D tensor alignments.
- Solution: Manually expanded the 2D attention masks into a broadcastable 4D format, `[Batch, 1, 1, Sequence]`, so the dimensions line up with the Einstein summation (`einsum`) kernels in the attention layers (see the sketch below).
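Conceptually the expansion is a broadcast-friendly reshape; the helper name below is illustrative, not the project's actual function:

```python
import jax.numpy as jnp

def expand_attention_mask(mask_2d):
    """Expand a (batch, seq_len) padding mask to (batch, 1, 1, seq_len) so it
    broadcasts against (batch, num_heads, q_len, kv_len) attention scores."""
    return mask_2d[:, None, None, :]

# Example: a batch of 2 sequences of length 4, the second one padded.
mask = jnp.array([[1, 1, 1, 1],
                  [1, 1, 0, 0]])
print(expand_attention_mask(mask).shape)  # (2, 1, 1, 4)
```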
JAX cannot trace Python strings, and residual metadata columns in datasets often trigger `_str_abstractify` errors.
- Solution: Developed a pre-processing pipeline that strips all non-numeric columns, leaving only the raw integer `input_ids` and `attention_mask` for the TPU (sketched below).
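A sketch of the column-stripping step using the Hugging Face Datasets API; the toy dataset below is a stand-in for the tokenized multi-domain corpus:

```python
from datasets import Dataset

# Toy example standing in for the tokenized training data.
ds = Dataset.from_dict({
    "input_ids": [[2, 651, 9456]],
    "attention_mask": [[1, 1, 1]],
    "source": ["math"],  # residual string metadata that JAX cannot trace
})

KEEP = ["input_ids", "attention_mask"]
ds = ds.remove_columns([c for c in ds.column_names if c not in KEEP])

# Hand NumPy integer arrays to the training loop so no Python strings reach jax.jit.
ds.set_format(type="numpy", columns=KEEP)
```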
- Compilation Speed: After the initial Step 0 JAX trace, training stabilized at millisecond-scale per-step execution times.
- Reasoning Delta: Post-trained models showed a marked increase in using "Chain of Thought" markers compared to the base model.
---
We utilized Stratified Post-Training, rebalancing the training mixture across six critical reasoning domains (a mixing sketch follows the list):
- Mathematics: Step-by-step problem solving.
- Coding: Logic-heavy algorithm generation.
- Science: Deductive reasoning.
- Creative: Instruction following.
- Summarization: Contextual logic.
- General: General-knowledge reasoning.
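One way to realize such a rebalanced mixture is to interleave per-domain datasets with explicit sampling probabilities. The weights and tiny stand-in corpora below are hypothetical; the actual proportions are a design choice of the run:

```python
from datasets import Dataset, interleave_datasets

# Stand-in corpora for the six domains; in practice each is a tokenized dataset.
domains = ["math", "coding", "science", "creative", "summarization", "general"]
corpora = [
    Dataset.from_dict({"input_ids": [[i, i + 1, i + 2]], "attention_mask": [[1, 1, 1]]})
    for i, _ in enumerate(domains)
]

# Hypothetical mixing weights (must sum to 1.0).
probabilities = [0.25, 0.20, 0.15, 0.10, 0.10, 0.20]

mixed = interleave_datasets(
    corpora,
    probabilities=probabilities,
    seed=42,
    stopping_strategy="all_exhausted",  # oversample smaller domains until all are seen
)
```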

