TUNIX: Post-Training for Reasoning

A high-performance pipeline for scaling logical inference by post-training Gemma 3 1B into a specialized reasoning engine, built with Tunix, JAX, and Cloud TPUs (v3-8 / v5e-1).

Check out my blog post about this project 👉 Medium

📌 Project Overview

This project focuses on the post-training phase of Large Language Models to enhance logical deduction and multi-step reasoning. By leveraging the JAX/Flax ecosystem, we achieve massive throughput on TPU v3-8 hardware, overcoming traditional bottlenecks associated with dynamic shapes in transformer training.

Key Objectives:

  • Reasoning Alignment: Transforming general-purpose knowledge into structured "Chain of Thought" (CoT) logic.
  • JAX Optimization: Solving the XLA "compilation hang" and memory fragmentation on TPUs.
  • Hardware Efficiency: Utilizing a TPU v3-8 mesh with LoRA (Low-Rank Adaptation) via Tunix.
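
As a rough illustration of what the "TPU v3-8 mesh" means in practice, here is a minimal sketch using plain JAX APIs; the 2×4 layout and axis names are assumptions for illustration, not the exact configuration Tunix builds:

```python
import numpy as np
import jax
from jax.sharding import Mesh

# A TPU v3-8 slice exposes 8 cores; arrange them into a 2x4 logical mesh.
# The layout and axis names here are illustrative assumptions.
devices = np.array(jax.devices()).reshape(2, 4)
mesh = Mesh(devices, axis_names=("data", "model"))
print(mesh)
```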

🛠️ Technical Stack

  • Model: Google Gemma 3 (1B Variant)
  • Framework: Tunix (Post-training framework for JAX/Flax)
  • Compute: Google Cloud TPU v3-8
  • Optimization: Optax (AdamW with a cosine learning-rate schedule; see the sketch after this list)
  • Data: Hugging Face Datasets (Multi-domain rebalanced)
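
For reference, a minimal Optax sketch of that optimizer setup; the learning rate, decay steps, and weight decay below are placeholder values, not the ones from this repo's configs:

```python
import optax

# Cosine-decayed AdamW; hyperparameter values are illustrative placeholders.
schedule = optax.cosine_decay_schedule(init_value=2e-4, decay_steps=10_000)
optimizer = optax.adamw(learning_rate=schedule, weight_decay=0.01)
```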

🚀 Key Technical Hurdles & Solutions

1. The "Inhomogeneous Shape" Fix (Static Padding)

TPUs require fixed-size buffers. Standard tokenization produces "jagged" variable-length arrays, which cannot be batched into a rectangular buffer and break XLA compilation.

  • Solution: Implemented a custom data collator that enforces a strict MAX_TARGET_LENGTH. This ensures every batch is a perfect rectangle, allowing JAX to compile the computation graph exactly once.
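
A minimal sketch of such a collator; MAX_TARGET_LENGTH, the pad id, and the NumPy-based batching are assumptions for illustration, not the repo's exact implementation:

```python
import numpy as np

MAX_TARGET_LENGTH = 512  # fixed sequence length (placeholder value)

def pad_and_collate(token_lists, pad_id=0):
    """Pad every example to MAX_TARGET_LENGTH so each batch is a fixed-shape
    rectangle and XLA compiles the training step exactly once."""
    input_ids, attention_mask = [], []
    for ids in token_lists:
        ids = ids[:MAX_TARGET_LENGTH]                      # truncate overly long sequences
        pad = MAX_TARGET_LENGTH - len(ids)
        input_ids.append(ids + [pad_id] * pad)             # right-pad with the pad token
        attention_mask.append([1] * len(ids) + [0] * pad)  # 1 = real token, 0 = padding
    return {
        "input_ids": np.asarray(input_ids, dtype=np.int32),
        "attention_mask": np.asarray(attention_mask, dtype=np.int32),
    }
```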

2. Solving the "BTNS" Einsum Error

Gemma 3's Multi-Head Attention expects specific 4D tensor alignments.

  • Solution: Manually expanded 2D attention masks into a 4D broadcastable format [Batch, 1, 1, Sequence]. This aligned the dimensions for the Einstein Summation (einsum) kernels in the attention layers.
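
A minimal sketch of that expansion; the function name and the additive -1e9 masking convention are illustrative assumptions:

```python
import jax.numpy as jnp

def expand_attention_mask(mask_2d):
    """Expand a [batch, seq] padding mask to [batch, 1, 1, seq] so it broadcasts
    against the [batch, heads, q_len, kv_len] attention scores inside the einsum
    kernels, turning padding flags into additive logits."""
    mask_4d = mask_2d[:, None, None, :]          # insert head and query-length axes
    return jnp.where(mask_4d == 1, 0.0, -1e9)    # 0 where attended, -1e9 where masked
```

The expanded mask is then added to the raw attention scores before the softmax.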

3. JAX Data Tracing

JAX cannot trace Python strings. Residual metadata in datasets often causes _str_abstractify errors.

  • Solution: Developed a pre-processing pipeline that strips all non-numeric columns, leaving only the raw integer input_ids and attention_mask for the TPU.
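
A minimal sketch of that cleanup step, assuming a Hugging Face Dataset; any column names beyond input_ids and attention_mask are illustrative:

```python
from datasets import Dataset

def keep_numeric_features(ds: Dataset) -> Dataset:
    """Drop every column except the integer features the TPU consumes, so JAX
    never has to trace Python strings."""
    keep = {"input_ids", "attention_mask"}
    drop = [name for name in ds.column_names if name not in keep]
    return ds.remove_columns(drop)

# Example: a toy dataset with leftover string metadata.
ds = Dataset.from_dict({
    "input_ids": [[1, 2, 3]],
    "attention_mask": [[1, 1, 1]],
    "source": ["math"],          # string column that would break JAX tracing
})
ds = keep_numeric_features(ds)   # only input_ids and attention_mask remain
```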

📈 Performance & Results

  • Compilation Speed: After the one-time JAX trace and XLA compile at Step 0, training settled into millisecond-scale step times (see the timing sketch below).
  • Reasoning Delta: The post-trained model used noticeably more "Chain of Thought" markers than the base model.
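
To illustrate the compile-once behaviour, here is a self-contained toy sketch; the train_step below is a stand-in, not this project's real training step:

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def train_step(params, batch):
    # Stand-in for the real loss/gradient computation.
    return params - 1e-3 * jnp.mean(batch)

params = jnp.zeros(())
batch = jnp.ones((8, 512))

t0 = time.perf_counter()
train_step(params, batch).block_until_ready()   # step 0: trace + XLA compile
print(f"step 0 (compile): {time.perf_counter() - t0:.3f}s")

t0 = time.perf_counter()
train_step(params, batch).block_until_ready()   # later steps reuse the compiled graph
print(f"step 1 (cached):  {time.perf_counter() - t0:.6f}s")
```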

Tunix Reasoning Architecture & Hyperparameter Configs:

[Figure: Final rebalanced dataset distribution]

[Figure: Enhanced hardware check]

[Figure: Model configs]

---

📊 Data Strategy

We utilized Stratified Post-Training, rebalancing the training mix across six critical reasoning domains (a rebalancing sketch follows the list):

  1. Mathematics: Step-by-step problem solving.
  2. Coding: Logic-heavy algorithm generation.
  3. Science: Deductive reasoning.
  4. Creative: Instruction following.
  5. Summarization: Contextual logic.
  6. General: General-knowledge reasoning.
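
A minimal sketch of one way to do that rebalancing with Hugging Face Datasets; the toy splits and sampling probabilities are placeholders, not the project's actual mixture:

```python
from datasets import Dataset, interleave_datasets

domains = ["math", "coding", "science", "creative", "summarization", "general"]

# Toy stand-ins for the six domain splits; in practice each would be loaded
# from Hugging Face Datasets.
splits = [
    Dataset.from_dict({"text": [f"{d} example {i}" for i in range(100)],
                       "domain": [d] * 100})
    for d in domains
]

# Stratified mixture: sample from each domain with a fixed probability so that
# no single domain dominates the post-training stream (weights are illustrative).
probabilities = [0.25, 0.20, 0.15, 0.10, 0.10, 0.20]
balanced_ds = interleave_datasets(splits, probabilities=probabilities, seed=42)
```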

👩‍💻 Author

Aditya Katkar
GitHub
LinkedIn
