Training Tricks
Mixed Precision
The table below compares the floating-point data types commonly used in training, which differ in how many bits they give to the exponent (range) and the mantissa (precision).
Todo
This table is generated by Gemini. I have not checked its correctness.
| Feature | FP32 (Single) | TF32 (TensorFloat) | BF16 (BFloat16) | FP16 (Half) |
|---|---|---|---|---|
| Total Bits | 32 | 19 (internal) | 16 | 16 |
| Sign Bit | 1 | 1 | 1 | 1 |
| Exponent Bits | 8 | 8 | 8 | 5 |
| Mantissa Bits | 23 | 10 | 7 | 10 |
| Dynamic Range | High | High | High | Low |
| Precision | High | Medium | Low | Medium |
| Epsilon | 1.19e-07 | 9.77e-04 | 7.81e-03 | 9.77e-04 |
| Max Value | 3.40e+38 | 3.40e+38 | 3.39e+38 | 6.55e+04 |
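
The Epsilon and Max Value rows follow from the bit counts, so they can be checked by hand. Here is a minimal sketch in plain Python, assuming IEEE-754-style layouts (implicit leading 1, exponent bias of `2**(e-1) - 1`, all-ones exponent reserved for inf/NaN). The `FORMATS` dictionary and both helper functions are illustrative names, not part of any library.

```python
# Derive machine epsilon and the largest finite value from the bit layout.
# Assumption: IEEE-754-style format with an implicit leading 1 and the
# all-ones exponent reserved for inf/NaN.
FORMATS = {
    # name: (exponent bits, mantissa bits)
    "FP32": (8, 23),
    "TF32": (8, 10),
    "BF16": (8, 7),
    "FP16": (5, 10),
}


def epsilon(mantissa_bits: int) -> float:
    """Gap between 1.0 and the next representable value."""
    return 2.0 ** -mantissa_bits


def max_value(exponent_bits: int, mantissa_bits: int) -> float:
    """Largest finite value: (2 - eps) * 2**max_exponent."""
    max_exponent = 2 ** (exponent_bits - 1) - 1
    return (2.0 - epsilon(mantissa_bits)) * 2.0 ** max_exponent


for name, (e_bits, m_bits) in FORMATS.items():
    print(f"{name}: eps={epsilon(m_bits):.2e}, max={max_value(e_bits, m_bits):.2e}")
# FP32: eps=1.19e-07, max=3.40e+38
# TF32: eps=9.77e-04, max=3.40e+38
# BF16: eps=7.81e-03, max=3.39e+38
# FP16: eps=9.77e-04, max=6.55e+04
```

For the types PyTorch exposes as dtypes, `torch.finfo(dtype).eps` and `torch.finfo(dtype).max` should report the same numbers.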
The simple example below shows how these types differ in practice: the first part stresses dynamic range with a large result, and the second stresses precision with a small difference.
Demonstration of precision
import torch
# Large-number reduction: the exact dot product is 1024 * 100 * 100 = 10,240,000.
x = 100 * torch.ones(1024).double()
y = 100 * torch.ones(1024).double()
out_double = (x @ y)
out_fp32 = (x.float() @ y.float())
out_bf16 = (x.bfloat16() @ y.bfloat16())
out_fp16 = (x.half() @ y.half())
print(out_double.item(), out_fp32.item(), out_bf16.item(), out_fp16.item())
# 10240000.0 10240000.0 10223616.0 inf
# BF16 keeps the magnitude but rounds to 8 significant bits; FP16 overflows past its max of 65504.
# Small-difference reduction: the operands differ by 0.001, which is below the BF16 epsilon.
x = torch.ones(1024).double()
y = torch.ones(1024).double() - 0.001
out_double = (x - y).abs().mean()
out_fp32 = (x.float() - y.float()).abs().mean()
out_bf16 = (x.bfloat16() - y.bfloat16()).abs().mean()
out_fp16 = (x.half() - y.half()).abs().mean()
print(out_double.item(), out_fp32.item(), out_bf16.item(), out_fp16.item())
# 0.0010000000000000009 0.0009999871253967285 0.0 0.0009765625
# BF16 rounds 0.999 to 1.0, so the difference vanishes; FP16 rounds it to 1 - 2**-10.
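
The printed values can be reproduced by hand by rounding each number to the significand width of the target type. Below is a minimal sketch in plain Python. `round_to_mantissa` is a hypothetical helper written for this illustration (not a PyTorch function). It models round-to-nearest-even on the significand only and deliberately ignores the exponent range, so it does not reproduce the FP16 overflow to `inf`.

```python
import math


def round_to_mantissa(value: float, mantissa_bits: int) -> float:
    """Round value to 1 implicit + mantissa_bits explicit significand bits.

    Simplified model: overflow and subnormals are ignored.
    """
    if value == 0.0:
        return 0.0
    frac, exp = math.frexp(value)  # value == frac * 2**exp, 0.5 <= |frac| < 1
    scale = 2.0 ** (mantissa_bits + 1)
    # Python's round() rounds halves to the nearest even integer.
    return math.ldexp(round(frac * scale) / scale, exp)


# Large result: 10,240,000 rounded to the BF16 significand (7 mantissa bits).
print(round_to_mantissa(10240000.0, 7))   # 10223616.0

# Small difference: 0.999 in BF16 collapses to 1.0, so 1 - 0.999 becomes 0.
print(1.0 - round_to_mantissa(0.999, 7))  # 0.0

# FP16 (10 mantissa bits) keeps a difference of exactly 2**-10.
print(1.0 - round_to_mantissa(0.999, 10))  # 0.0009765625
```

The same helper with `mantissa_bits=23` gives the FP32 behaviour, where the difference survives with an error of about 1e-8.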