Towards Fully FP8 GEMM LLM Training at Scale
Alejandro Hernández Cano*
EPFL
Introduction
Large Language Model (LLM) performance has improved rapidly in recent years.
Current trend: Larger models, trained with more data.
Compute demand grows exponentially.
- 8-bit floating-point (FP8) formats promise to speed up computation by performing matrix multiplications (GEMMs) with lower-precision numbers.
- In practice, FP8 pre-training deployment is challenging:
- Some FP8 recipes have been shown to diverge in longer training regimes.
- Some FP8 recipes involve fine-grained FP8 casting, which greatly diminishes the speed-ups.
- Current FP8 training approaches leave some GEMMs at higher precision.
- Our research seeks to address these deficiencies by providing general guidelines for robust transformer architectures.
How are FP8 GEMMs carried out?
- Under this regime, most activations remain in BF16, and only GEMMs are computed in FP8.
- Casting tensors from BF16 to FP8 therefore happens many times during a forward pass.
- Due to the limited FP8 dynamic range, higher-precision tensors are scaled to a more favourable range before casting (see the sketch after this list): \[\mathbf{X}_{\text{FP8}} = \rho \mathbf{X}_{\text{BF16}}.\]
- This scaling can be done tensorwise (i.e. one scaling factor per tensor), or more granular (many scaling factors per tensor).
- Delayed scaling: Computes a single tensorwise scalar based on statistics observed during earlier iterations.
- Blockwise: Tensors are scaled in a blockwise fashion, resulting in many scaling factors per tensor. This greatly diminishes potential speed-ups.
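To make the scaling and casting step concrete, here is a minimal PyTorch sketch of a tensorwise cast to FP8 (E4M3); the helper name and constants are illustrative assumptions, not the fused kernels used in practice.

```python
# Minimal sketch: one scaling factor per tensor, applied before casting BF16 -> FP8.
import torch

E4M3_MAX = 448.0  # largest finite magnitude representable in torch.float8_e4m3fn

def cast_to_fp8_tensorwise(x_bf16: torch.Tensor):
    """Scale the whole tensor so its max |value| lands near the FP8 max, then cast."""
    amax = x_bf16.abs().amax().float()             # statistics of the current tensor
    rho = E4M3_MAX / torch.clamp(amax, min=1e-12)  # tensorwise scaling factor
    x_fp8 = (x_bf16.float() * rho).to(torch.float8_e4m3fn)
    return x_fp8, rho  # rho is kept so GEMM outputs can be rescaled afterwards

x = torch.randn(1024, 1024, dtype=torch.bfloat16)
x_fp8, rho = cast_to_fp8_tensorwise(x)
```

Delayed scaling would replace the current-tensor `amax` above with a statistic of amax values recorded in earlier iterations, avoiding an extra pass over the data; blockwise scaling would instead compute one `rho` per block of the tensor.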
What makes FP8 training unstable?
- Due to this scaling, activations with very large outlier values become harder to represent accurately.
- It is therefore crucial to control outliers to ensure stable FP8 training.
- Some architectural choices result in long-term large activation outliers1.
- We can track outlier dynamics during training using the kurtosis of activations \(\mathbf{X} \in \mathbb{R}^{N \times C \times D}\): \[\mathrm{kurt}(\mathbf{X}) := \frac{1}{NC} \sum_{n=1}^N \sum_{c=1}^C \frac{\mu[\mathbf{x}_{nc}^4]}{\sigma^2[\mathbf{x}_{nc}]^2}.\]
- \(\mathrm{kurt}(\mathbf{X})\) is therefore large when a few elements of \(\mathbf{x}_{nc}\) reach extremely large values relative to the variance across the entire vector (see the sketch below).
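As a concrete reference, here is a minimal PyTorch sketch of this diagnostic for activations of shape (N, C, D); the function name is ours.

```python
# Minimal sketch: kurtosis of an activation tensor X of shape (N, C, D),
# averaged over the N*C vectors of dimension D.
import torch

def activation_kurtosis(x: torch.Tensor) -> torch.Tensor:
    x = x.float()
    m4 = (x ** 4).mean(dim=-1)                      # mu[x_nc^4]
    var = x.var(dim=-1, unbiased=False)             # sigma^2[x_nc]
    return (m4 / var.clamp_min(1e-12) ** 2).mean()  # average over all (n, c)
```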
FOG
Motivated by this observation, we designed FOG: the Fast and Outlier-Guarded set of transformer architectures.
Key components: Post-normalisation and QK entropy regularisation with a non-trainable gain vector (a simplified sketch follows the table below).
Main architectures proposed:
| Component | FOG-max | FOG-opt | FOG-flash |
| --- | --- | --- | --- |
| QK entropy regularisation | RMSNorm | RMSNorm | Tanh2 |
| Activation function | xIELU1 | GeLU | GeLU |
| Post-normalisation | RMSNorm | RMSNorm | RMSNorm |
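Below is a simplified PyTorch sketch of the block layout implied by the table: post-normalisation on each residual branch and QK entropy regularisation via an RMSNorm with a non-trainable gain (as in FOG-max/FOG-opt). Module names, dimensions, and hyperparameters are illustrative assumptions, not the reference implementation.

```python
# Simplified sketch of a FOG-style transformer block (not the reference code):
# post-norm residual branches + QK RMSNorm with a fixed (non-trainable) gain.
import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6, trainable_gain: bool = True):
        super().__init__()
        self.eps = eps
        if trainable_gain:
            self.gain = nn.Parameter(torch.ones(dim))
        else:
            self.register_buffer("gain", torch.ones(dim))  # non-trainable gain vector

    def forward(self, x):
        rms = x.float().pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()
        return ((x.float() * rms) * self.gain).type_as(x)

class FOGStyleBlock(nn.Module):
    """Post-norm block: x = x + Norm(Attn(x)); x = x + Norm(MLP(x))."""
    def __init__(self, dim: int, n_heads: int):
        super().__init__()
        self.n_heads, self.head_dim = n_heads, dim // n_heads
        self.qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.proj = nn.Linear(dim, dim, bias=False)
        # QK entropy regularisation: RMSNorm with a non-trainable gain (FOG-max/opt).
        self.q_norm = RMSNorm(self.head_dim, trainable_gain=False)
        self.k_norm = RMSNorm(self.head_dim, trainable_gain=False)
        self.attn_postnorm = RMSNorm(dim)   # post-normalisation of the attention branch
        self.mlp = nn.Sequential(
            nn.Linear(dim, 4 * dim, bias=False),
            nn.GELU(),                      # GeLU variant (FOG-opt / FOG-flash)
            nn.Linear(4 * dim, dim, bias=False),
        )
        self.mlp_postnorm = RMSNorm(dim)    # post-normalisation of the MLP branch

    def forward(self, x):
        B, T, D = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q = q.reshape(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.reshape(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.reshape(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        q, k = self.q_norm(q), self.k_norm(k)  # bounds QK logits / attention entropy
        attn = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        attn = attn.transpose(1, 2).reshape(B, T, D)
        x = x + self.attn_postnorm(self.proj(attn))
        x = x + self.mlp_postnorm(self.mlp(x))
        return x
```

Swapping the two QK RMSNorms for the tanh-based regularisation would give the FOG-flash variant from the table.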
Main Results
- Tested FOG on general-purpose pre-training data, training from random initialization.
- Model size: 390M – 8B parameters.
- Data budget: 50B – 450B tokens.
- Baselines: Llama3, OLMo2, OP1, and Llama3+SmoothSwiGLU2.
- Precision:
- BF16: GEMMs using BF16 precision.
- FP8: Linear projections computed with FP8 precision.
- FP8DPA: Linear projections and dot-product attention GEMMs computed with FP8 precision (a configuration sketch follows this list).
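For reference, a hedged sketch of enabling the FP8 setting with NVIDIA Transformer Engine's tensorwise delayed-scaling recipe is shown below; the layer sizes and recipe arguments are illustrative assumptions, not the exact training configuration.

```python
# Hedged sketch: run a linear projection's GEMMs in FP8 via Transformer Engine.
# Recipe hyperparameters are illustrative, not the values used in this work.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

layer = te.Linear(4096, 4096, bias=False).cuda()
x = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16)

# Tensorwise delayed scaling: scaling factors come from amax statistics of past steps.
recipe = DelayedScaling(fp8_format=Format.HYBRID,   # E4M3 forward, E5M2 backward
                        amax_history_len=16,
                        amax_compute_algo="max")

with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    y = layer(x)  # linear-projection GEMM executed in FP8
```

FP8DPA additionally runs the dot-product-attention GEMMs in FP8, which requires an attention kernel with FP8 support.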
Long-term outlier dynamics
Figure: Loss and kurtosis training dynamics of 1.5B FOG-max and Llama3 models trained for over 100B tokens with BF16 precision.
Kurtosis of unstable FP8 runs
- Tracking tensor-level metrics such as kurtosis can flag later divergences before common global metrics, such as the loss and gradient norms, show any symptoms (see the monitoring sketch after the figure).
Figure: Training dynamics of a failed and a successful FP8DPA run. Kurtosis exhibits atypical behaviour much earlier than when the loss diverged.
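A minimal sketch of such tensor-level monitoring is shown below, assuming a PyTorch model and forward hooks on linear layers; the helper names and the choice of which modules to track are ours.

```python
# Minimal sketch: log per-layer activation kurtosis with forward hooks, so atypical
# outlier behaviour can be flagged before the loss or gradient norms react.
import torch
import torch.nn as nn

def register_kurtosis_hooks(model: nn.Module, log: dict):
    def make_hook(name):
        def hook(module, inputs, output):
            if isinstance(output, torch.Tensor) and output.dim() >= 2:
                x = output.detach().float()
                m4 = (x ** 4).mean(dim=-1)
                var = x.var(dim=-1, unbiased=False).clamp_min(1e-12)
                log[name] = (m4 / var ** 2).mean().item()  # kurtosis of this layer's output
        return hook

    handles = []
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):  # e.g. track linear-projection outputs
            handles.append(module.register_forward_hook(make_hook(name)))
    return handles  # call handle.remove() on each to detach the hooks
```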
End-to-end FP8 pre-training
- At every scale tested, no other architecture surpassed the 20B-token mark without diverging.
- FOG architectures ensure robust and efficient FP8DPA training in all tested settings.
Figure: Cross-entropy loss plots of different architectures with FP8DPA training.
Long-data regime
- We stress-test FOG-max 1B by training it on over 450B tokens, 15x the Chinchilla-optimal budget.
- We observe comparable downstream performance.
Long-context efficiency
- Throughput benefits increase even more with longer context lengths!
Other FP8 recipes
- Thanks to the low overhead of the tensorwise delayed-scaling recipe and the ability to enable FP8DPA, our approach delivers higher throughput than alternative FP8 recipes.