Introduction

  • The performance of Large Language Models (LLMs) has increased rapidly in recent years.

  • Current trend: Larger models, trained with more data.

  • Compute demand grows exponentially.

    Figure: Compute trends in AI.1
  • 8-bit floating-point (FP8) formats promise to speed up computation by performing matrix multiplications (GEMMs) with lower-precision numbers.
  • In practice, FP8 pre-training deployment is challenging:
    • Some FP8 recipes have been shown to diverge for longer training regimes.
    • Some FP8 recipes involve fine-grained FP8-casting, greatly diminishing speed-ups.
    • Current FP8 training approaches leave some GEMMs at higher precision.
  • Our research seeks to address these deficiencies by providing general guidelines for robust transformer architectures.

How are FP8 GEMMs carried out?

  • Under this regime, most activations remain in BF16, and only GEMMs are computed in FP8.
  • Casting tensors from BF16 to FP8 therefore happens many times during a forward pass.
  • Due to the limited FP8 dynamic range, higher-precision tensors are scaled to a more favourable range before casting (see the sketch after this list): \[\mathbf{X}_{\text{FP8}} = \rho \mathbf{X}_{\text{BF16}}.\]
  • This scaling can be done tensorwise (i.e. one scaling factor per tensor), or more granular (many scaling factors per tensor).
    • Delayed scaling: Computes a single tensorwise scalar based on statistics observed during earlier iterations.
    • Blockwise: Tensors are scaled in a blockwise fashion, resulting in many scaling factors per tensor. This greatly diminishes potential speed-ups.
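As a concrete illustration, the snippet below sketches a tensorwise delayed-scaling cast in PyTorch, assuming the torch.float8_e4m3fn dtype; the helper name and amax-history bookkeeping are ours, not the exact recipe used in our training stack.

```python
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # ~448 for the e4m3 format


def cast_to_fp8_tensorwise(x_bf16: torch.Tensor, amax_history: list):
    """Cast a BF16 tensor to FP8 with a single (tensorwise) scaling factor rho.

    With delayed scaling, rho is derived from amax statistics of earlier
    iterations, so no extra reduction over x is needed before the cast.
    Illustrative helper only; real FP8 GEMM kernels handle this internally
    and divide the GEMM output by rho afterwards.
    """
    amax = max(amax_history) if amax_history else x_bf16.abs().max().item()
    rho = FP8_MAX / max(amax, 1e-12)                 # X_FP8 = rho * X_BF16
    x_fp8 = (x_bf16 * rho).to(torch.float8_e4m3fn)
    amax_history.append(x_bf16.abs().max().item())   # update delayed statistics
    return x_fp8, rho
```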

What makes FP8 training unstable?

  • Due to this scaling, activations with very large outlier values become harder to represent accurately.
  • It is therefore crucial to control outliers to ensure stable FP8 training.

  • Some architectural choices lead to large, long-lived activation outliers1.
  • We can track outlier dynamics during training using the kurtosis of activations \(\mathbf{X} \in \mathbb{R}^{N \times C \times D}\): \[\mathrm{kurt}(\mathbf{X}) := \frac{1}{NC} \sum_{n=1}^N \sum_{c=1}^C \frac{\mu[\mathbf{x}_{nc}^4]}{\sigma^2[\mathbf{x}_{nc}]^2}.\]
  • \(\mathrm{kurt}(\mathbf{X})\) is large when a few elements of \(\mathbf{x}_{nc}\) take extremely large values relative to the variance across the entire vector.
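A minimal PyTorch sketch of this metric (the function name and the upcast to FP32 are our additions):

```python
import torch


def activation_kurtosis(x: torch.Tensor) -> torch.Tensor:
    """Average kurtosis of activations x of shape (N, C, D).

    For each position (n, c), the fourth moment of the D-dimensional vector
    is divided by its squared variance, then averaged over all N*C positions.
    Large values indicate that a few channels carry extreme outliers.
    """
    x = x.float()                           # avoid overflow of x**4 in BF16
    m4 = (x ** 4).mean(dim=-1)              # fourth moment per (n, c) vector
    var = x.var(dim=-1, unbiased=False)     # variance per (n, c) vector
    return (m4 / var.clamp_min(1e-12) ** 2).mean()
```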

FOG

Figure: FOG transformer block.
  • Motivated by this observation, we designed FOG: the Fast and Outlier-Guarded set of transformer architectures.

  • Key components: post-normalisation and QK entropy regularisation with a non-trainable gain vector (a schematic sketch follows the table below).

  • Main architectures proposed:

    | Model     | QK-Reg  | Activation | Post-norm |
    |-----------|---------|------------|-----------|
    | FOG-max   | RMSNorm | xIELU1     | RMSNorm   |
    | FOG-opt   | RMSNorm | GeLU       | RMSNorm   |
    | FOG-flash | Tanh2   | GeLU       | RMSNorm   |
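The sketch below shows one way these components could compose in a simplified decoder block, assuming a recent PyTorch (nn.RMSNorm, scaled_dot_product_attention). The exact norm placement, eps values, and head handling are our assumptions, and FOG-max's xIELU activation is replaced by GeLU here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FOGStyleBlock(nn.Module):
    """Simplified block in the spirit of FOG: QK regularisation with a
    non-trainable gain before attention, and post-normalisation of each
    sub-layer output. Illustrative reconstruction, not the reference code."""

    def __init__(self, d_model: int, n_heads: int, qk_reg: str = "rmsnorm"):
        super().__init__()
        self.n_heads, self.head_dim = n_heads, d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.proj = nn.Linear(d_model, d_model, bias=False)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, 4 * d_model, bias=False),
            nn.GELU(),  # GeLU as in FOG-opt / FOG-flash
            nn.Linear(4 * d_model, d_model, bias=False),
        )
        self.qk_reg = qk_reg
        self.post_norm_attn = nn.RMSNorm(d_model)  # post-norm after attention
        self.post_norm_mlp = nn.RMSNorm(d_model)   # post-norm after the MLP

    def _reg(self, t: torch.Tensor) -> torch.Tensor:
        if self.qk_reg == "rmsnorm":
            # RMS-normalise with a fixed (non-trainable) unit gain vector
            return t * torch.rsqrt(t.pow(2).mean(-1, keepdim=True) + 1e-6)
        return torch.tanh(t)  # bounded "tanh" variant (FOG-flash)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, D = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q, k, v = (t.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
                   for t in (q, k, v))
        q, k = self._reg(q), self._reg(k)  # QK entropy regularisation
        attn = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        attn = attn.transpose(1, 2).reshape(B, T, D)
        x = x + self.post_norm_attn(self.proj(attn))  # assumed norm placement
        x = x + self.post_norm_mlp(self.mlp(x))
        return x
```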

Main Results

  • We test FOG on general-purpose pre-training data, training from random initialization.
  • Model size: 390M – 8B parameters.
  • Data budget: 50B – 450B tokens.
  • Baselines: Llama3, OLMo2, OP1, and Llama3+SmoothSwiGLU2.
  • Precision:
    • BF16: GEMMs using BF16 precision.
    • FP8: Linear projections computed with FP8 precision.
    • FP8DPA: Linear projections and dot product attention GEMMs computed with FP8 precision.
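In practice, these precision modes are configured through an FP8 library. As one illustrative possibility (not necessarily the stack used in our experiments), NVIDIA Transformer Engine exposes a tensorwise delayed-scaling recipe along these lines:

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

# Tensorwise delayed scaling: E4M3 forward / E5M2 backward, with scales
# derived from an amax history of previous iterations.
fp8_recipe = DelayedScaling(margin=0, fp8_format=Format.HYBRID,
                            amax_history_len=16, amax_compute_algo="max")

layer = te.Linear(4096, 4096, bias=False,
                  params_dtype=torch.bfloat16).cuda()
x = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    y = layer(x)  # the GEMM runs in FP8; inputs and outputs stay in BF16
```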

Long-term outlier dynamics

Figure: Loss and kurtosis training dynamics of 1.5B FOG-max and Llama3 models trained for over 100B tokens with BF16 precision.

Kurtosis of unstable FP8 runs

  • Tracking tensor-level metrics such as kurtosis can flag upcoming divergences before common global metrics, like the loss and gradient norms, show any symptoms (a monitoring sketch follows the figure below).

Figure: Training dynamics of a failed and a successful FP8DPA run. Kurtosis exhibits atypical behaviour well before the loss diverges.
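A hypothetical monitoring hook along these lines (reusing activation_kurtosis from the sketch above) can record this metric during training:

```python
import torch


def attach_kurtosis_hooks(model: torch.nn.Module, log: dict) -> None:
    """Record activation kurtosis of every Linear output on each forward pass.

    Illustrative monitoring setup; in a real run these values would be
    reduced across data-parallel ranks and written to the experiment logger
    alongside the loss and gradient norms.
    """
    def make_hook(name: str):
        def hook(module, inputs, output):
            log.setdefault(name, []).append(
                activation_kurtosis(output.detach()).item())
        return hook

    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):  # monitor linear-layer outputs
            module.register_forward_hook(make_hook(name))
```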

End-to-end FP8 pre-training

  • At every scale tested, no other architecture made it past the 20B-token mark without diverging.
  • FOG architectures ensure robust and efficient FP8DPA training in all tested settings.

Figure: Cross-entropy loss curves of different architectures with FP8DPA training.

Long-data regime

  • We stress-test a 1B FOG-max model by training on over 450B tokens, 15x the Chinchilla-optimal budget.

Downstream performance

  • We observe comparable downstream performance between BF16 and FP8DPA training.
| Size | Model     | BF16 (%) | FP8DPA (%) |
|------|-----------|----------|------------|
| 390M | Llama3    | 39.8     | -          |
|      | FOG-max   | 41.2     | 40.8       |
|      | FOG-opt   | 40.9     | 40.4       |
|      | FOG-flash | 40.5     | 40.3       |
| 1B   | Llama3    | 46.1     | -          |
|      | FOG-max   | 46.0     | 47.1       |
|      | FOG-opt   | 45.7     | 46.0       |
|      | FOG-flash | 45.7     | 44.9       |

Table 1: Average downstream performance of models trained with BF16 and FP8DPA precision.

Efficiency

| Size | Model               | Precision | Throughput vs BF16 baseline |
|------|---------------------|-----------|-----------------------------|
| 8B   | Llama3              | BF16      | +0.0%                       |
|      | Llama3+SmoothSwiGLU | FP8       | +34.3%                      |
|      | FOG-max             | FP8DPA    | +35.5%                      |
|      | FOG-opt             | FP8DPA    | +36.3%                      |
|      | FOG-flash           | FP8DPA    | +40.2%                      |
| 390M | Llama3              | BF16      | +0.0%                       |
|      | Llama3+SmoothSwiGLU | FP8       | +18.1%                      |
|      | FOG-max             | FP8DPA    | +15.2%                      |
|      | FOG-opt             | FP8DPA    | +15.9%                      |
|      | FOG-flash           | FP8DPA    | +18.0%                      |

Table 2: Throughput gains.

Long-context efficiency

  • Throughput benefits increase even more with longer context lengths!
| Context | TP | FOG-flash | Llama3+SmoothSwiGLU |
|---------|----|-----------|---------------------|
| 4096    | 1  | +42.6%    | +38.5%              |
| 8192    | 1  | +43.5%    | OOM                 |
| 8192    | 2  | +39.1%    | +34.2%              |
| 16384   | 2  | +38.8%    | +31.1%              |

Table 3: Long-context throughput gains compared to the BF16 Llama3 8B baseline. TP is the tensor-parallel size.

Other FP8 recipes

  • Thanks to the low overhead of the tensorwise delayed-scaling recipe and the ability to enable FP8DPA, our approach delivers higher throughput.
| Model               | Precision | FP8 recipe | Throughput vs BF16 baseline |
|---------------------|-----------|------------|-----------------------------|
| Llama3              | BF16      | N/A        | +0.0%                       |
| Llama3              | FP8       | Blockwise  | +17.9%                      |
| OP                  | FP8       | Delayed    | +28.1%                      |
| Llama3+SmoothSwiGLU | FP8       | Delayed    | +38.2%                      |
| FOG-flash           | FP8DPA    | Delayed    | +42.6%                      |

Table 4: Throughput gains when using different FP8 recipes on 8B models.