⏱️ Estimated Reading Time: 12 minutes

Introduction

The release of gpt-oss marks a significant milestone as the first open-source model family from OpenAI since GPT-2. This groundbreaking model features a mixture of experts (MoE) architecture, 128K context length, and adjustable deep reasoning capabilities. However, its native MXFP4 precision introduces unique challenges for fine-tuning that require innovative solutions.

In this comprehensive guide, we’ll explore NVIDIA’s proven workflow for fine-tuning gpt-oss models that maintains accuracy while preserving the performance benefits of FP4 precision. This approach combines supervised fine-tuning (SFT) with quantization-aware training (QAT) using NVIDIA TensorRT Model Optimizer.

Understanding the gpt-oss Challenge

The MXFP4 Precision Dilemma

OpenAI’s decision to release gpt-oss at native MXFP4 precision was an industry first, but it created significant challenges for practitioners:

  • Stability Issues: Native MXFP4 precision hasn’t proven stable for gradient accumulation during fine-tuning
  • Training Difficulties: Direct fine-tuning in FP4 can lead to convergence problems and accuracy degradation
  • Production Requirements: Most foundational models require post-training techniques for effective deployment, especially in low-fault-tolerance industries like healthcare and finance

Why Traditional Fine-Tuning Falls Short

The gpt-oss-120B model achieves performance comparable to OpenAI’s closed-source o3 and o4 models on open benchmarks. However, out-of-the-box performance often shows room for improvement on specific downstream tasks:

  • Non-English Reasoning: Initial scores around 16% on multilingual datasets
  • Safe Prompt Handling: 30% pass rate on reducing unnecessary refusals of safe user prompts
  • Task-Specific Alignment: Generic models require specialized training for domain-specific applications

The SFT + QAT Workflow Solution

Overview of the Approach

NVIDIA’s solution involves a two-stage process that addresses the stability issues while maintaining efficiency:

  1. Supervised Fine-Tuning (SFT): Performed on an upcasted BF16 version for stable gradient accumulation
  2. Quantization-Aware Training (QAT): Applied using NVIDIA TensorRT Model Optimizer to return to FP4 precision

This workflow enables SFT to reinforce task-specific behavior while QAT adapts the weights to the target low-precision format, delivering both alignment and performance for deployment.

Step-by-Step Implementation

Step 1: Upcast Original MXFP4 Checkpoint to BF16/FP16

The first crucial step involves converting the native MXFP4 model to higher precision:

# Using Hugging Face Transformers for upcasting
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the original MXFP4 model
model = AutoModelForCausalLM.from_pretrained(
    "openai/gpt-oss-20b",
    torch_dtype=torch.bfloat16,  # Upcast to BF16
    device_map="auto"
)

tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b")

Benefits of Upcasting:

  • Provides more stable gradients during training
  • Enables QAT to effectively recover accuracy when re-quantizing back to FP4
  • Acceptable trade-off since fine-tuning uses far fewer tokens than pre-training

Step 2: Perform Supervised Fine-Tuning

With the upcasted model, perform standard supervised fine-tuning:

import torch
from torch.utils.data import DataLoader
from transformers import TrainingArguments, Trainer

# Configure training arguments
training_args = TrainingArguments(
    output_dir="./sft_output",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,
    learning_rate=2e-5,
    warmup_steps=100,
    logging_steps=10,
    save_steps=500,
    fp16=False,  # Use BF16 for stability
    bf16=True,
)

# Initialize trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
)

# Execute fine-tuning
trainer.train()

Step 3: Quantize Using TensorRT Model Optimizer

Apply quantization to prepare the model for QAT:

import modelopt.torch.quantization as mtq

# Configure quantization settings
config = mtq.MXFP4_MLP_WEIGHT_ONLY_CFG

# Define forward loop for calibration
def forward_loop(model):
    for data in calib_set:
        model(data)

# Quantize the model and prepare for QAT
model = mtq.quantize(model, config, forward_loop)

Step 4: Fine-tune the FP4 Quantized Model

The final QAT step involves fine-tuning at a smaller learning rate:

# QAT configuration
qat_optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
qat_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    qat_optimizer, T_max=1000
)

# QAT training loop
for epoch in range(qat_epochs):
    for batch in train_loader:
        qat_optimizer.zero_grad()
        
        outputs = model(**batch)
        loss = outputs.loss
        
        loss.backward()
        qat_optimizer.step()
        qat_scheduler.step()

Performance Impact Analysis

Dramatic Improvements on Downstream Tasks

The SFT + QAT workflow demonstrates remarkable effectiveness across various evaluation metrics:

Before Fine-tuning:

  • Non-English reasoning: 16% pass rate
  • Safe prompt handling: 30% pass rate

After SFT + QAT:

  • Non-English reasoning: 98% pass rate
  • Safe prompt handling: 98% pass rate

This represents a 6x improvement in non-English reasoning and 3.3x improvement in safe prompt handling.

Comparison of Quantization Methods

Method Non-English Reasoning Safe Prompt Handling Deployment Efficiency
Original 16% 30% High (FP4)
SFT Only 85% 82% Low (BF16)
PTQ 45% 52% High (FP4)
SFT + QAT 98% 98% High (FP4)

NVFP4: The Next Generation

Introducing NVFP4 Format

With NVIDIA Blackwell, NVFP4 introduces a new FP4 format purpose-built for both training and inference efficiency:

# Switching to NVFP4 is as simple as changing one line
# From MXFP4
config = mtq.MXFP4_MLP_WEIGHT_ONLY_CFG

# To NVFP4
config = mtq.NVFP4_MLP_WEIGHT_ONLY_CFG

# For even better performance with weight-activation quantization
config = mtq.NVFP4_MLP_ONLY_CFG

NVFP4 Advantages

Technical Benefits:

  • E4M3 FP8 scaling precision reduces quantization errors
  • Better convergence during fake quantization process
  • 2-3% better validation loss compared to MXFP4
  • Enhanced accuracy recovery capabilities

Performance Benefits:

  • Up to 15 PFLOPs of FP4 compute on NVIDIA Blackwell Ultra
  • Specialized instructions in second-generation NVIDIA Transformer Engine
  • Better model accuracy performance retention

Validation Loss Comparison

Empirical results show consistent improvements with NVFP4:

  • Multi-lingual tasks: 2-3% better validation loss
  • False rejection tasks: Improved convergence stability
  • Deep reasoning scenarios: Better performance under strict thresholds

Production Deployment Guide

Model Conversion and Export

After completing the SFT + QAT workflow, convert the model for deployment:

# Convert BF16-trained checkpoint to MXFP4
python examples/gpt-oss/convert_oai_mxfp4_weight_only.py \
    --model_path qat_model_dir/ \
    --output_path qat_model_mxfp4/

Deployment with TensorRT-LLM

Deploy the optimized model using TensorRT-LLM:

# Host endpoint with TensorRT-LLM 1.1.0rc1
trtllm-serve qat_model_mxfp4/ \
    --tokenizer <tokenizer_path> \
    --max_batch_size 32 \
    --max_num_tokens 4096 \
    --max_seq_len 128000 \
    --tp_size 4 \
    --pp_size 2 \
    --host 0.0.0.0 \
    --kv_cache_free_gpu_memory_fraction 0.95

Compatibility and Framework Support

The resulting MXFP4 checkpoints have been tested with:

  • SGLang: Full compatibility for serving
  • TensorRT-LLM: Optimized inference performance
  • vLLM: Production-ready deployment
  • Upcoming NVFP4 Support: Enhanced performance with Blackwell architecture

Best Practices and Optimization Tips

Hyperparameter Optimization

SFT Stage:

  • Learning rate: 2e-5 to 5e-5
  • Batch size: Adjust based on GPU memory
  • Epochs: 2-5 depending on dataset size
  • Gradient accumulation: 4-16 steps

QAT Stage:

  • Learning rate: 1e-5 to 5e-6 (10x smaller than SFT)
  • Training duration: 500-2000 steps
  • Optimizer: Adam with cosine annealing
  • Calibration dataset: Representative of target distribution

Common Pitfalls to Avoid

  1. Skipping SFT: Direct QAT results in significantly lower accuracy
  2. Incorrect Learning Rates: Too high learning rates in QAT can destabilize quantized weights
  3. Insufficient Calibration: Poor calibration data leads to suboptimal quantization
  4. Premature Convergence: Monitor validation metrics throughout QAT

Monitoring and Validation

# Essential metrics to track during training
metrics_to_monitor = {
    'training_loss': 'Primary optimization target',
    'validation_loss': 'Generalization indicator', 
    'perplexity': 'Language modeling quality',
    'task_specific_accuracy': 'Downstream performance',
    'quantization_error': 'Precision preservation'
}

Future Directions and Roadmap

Upcoming Enhancements

NVFP4 Ecosystem Expansion:

  • TensorRT-LLM native support for gpt-oss NVFP4
  • Integration with additional open-source inference frameworks
  • Enhanced tooling for NVFP4 model optimization

Training Innovations:

  • Native FP4 training techniques for improved efficiency
  • Advanced QAT algorithms for better accuracy recovery
  • Multi-objective optimization for accuracy-efficiency trade-offs

Industry Impact

The SFT + QAT workflow for gpt-oss represents a significant advancement in production AI deployment:

  • Cost Efficiency: Maintains FP4 inference benefits while ensuring accuracy
  • Scalability: Enables deployment in resource-constrained environments
  • Reliability: Provides stable performance for mission-critical applications
  • Accessibility: Makes advanced AI capabilities available to broader audiences

Conclusion

The fine-tuning workflow for gpt-oss using supervised fine-tuning followed by quantization-aware training successfully addresses the unique challenges posed by native MXFP4 precision. This approach delivers:

  • Dramatic Performance Improvements: Up to 6x better accuracy on downstream tasks
  • Production Efficiency: Maintains FP4 inference benefits for cost-effective deployment
  • Future-Ready Architecture: Seamless transition to NVFP4 for even better performance

The combination of NVIDIA’s TensorRT Model Optimizer with proven training methodologies provides a robust foundation for deploying gpt-oss models in production environments. As NVFP4 support expands across inference frameworks, this workflow will unlock even greater potential for accuracy and efficiency optimization.

For practitioners looking to leverage gpt-oss in production applications, the SFT + QAT workflow offers a proven path to achieving both high accuracy and efficient deployment. The complete implementation is available through the NVIDIA Model Optimizer repository, providing immediate access to these advanced optimization techniques.


References: