Training Tricks

Mixed Precision

The table below compares the floating-point data types commonly used in mixed-precision training.

Todo

This table is generated by Gemini. I have not checked its correctness.

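| Feature       | FP32 (Single) | TF32 (TensorFloat) | BF16 (BFloat16) | FP16 (Half) |
|---------------|---------------|--------------------|-----------------|-------------|
| Total Bits    | 32            | 19 (internal)      | 16              | 16          |
| Sign Bit      | 1             | 1                  | 1               | 1           |
| Exponent Bits | 8             | 8                  | 8               | 5           |
| Mantissa Bits | 23            | 10                 | 7               | 10          |
| Dynamic Range | High          | High               | High            | Low         |
| Precision     | High          | Medium             | Low             | Medium      |
| Epsilon       | 1.19e-07      | 9.77e-04           | 7.81e-03        | 9.77e-04    |
| Max Value     | 3.40e+38      | 3.40e+38           | 3.39e+38        | 6.55e+04    |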
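As a sanity check on the Epsilon and Max Value rows, both can be derived from the exponent and mantissa widths alone. The sketch below assumes an IEEE-754-style encoding (exponent bias 2^(e-1) - 1, with the all-ones exponent reserved for inf/NaN), which FP32, TF32, BF16 and FP16 all follow; `format_stats` is a hypothetical helper written for illustration, not a library function.

```python
# Derive Epsilon / Max Value (and the smallest normal value) from the bit layout.
# Assumption: IEEE-754-style encoding, i.e. exponent bias 2**(e-1) - 1 and the
# all-ones exponent reserved for inf/NaN (true for FP32, TF32, BF16, FP16).

def format_stats(exponent_bits: int, mantissa_bits: int) -> dict:
    bias = 2 ** (exponent_bits - 1) - 1
    eps = 2.0 ** -mantissa_bits                              # gap between 1.0 and the next value
    max_value = (2.0 - 2.0 ** -mantissa_bits) * 2.0 ** bias  # largest finite value
    min_normal = 2.0 ** (1 - bias)                           # smallest positive normal value
    return {"eps": eps, "max": max_value, "min_normal": min_normal}


# (exponent bits, mantissa bits) taken from the table above
FORMATS = {
    "FP32": (8, 23),
    "TF32": (8, 10),
    "BF16": (8, 7),
    "FP16": (5, 10),
}

for name, (e, m) in FORMATS.items():
    s = format_stats(e, m)
    print(f"{name}: eps={s['eps']:.2e} max={s['max']:.2e} min_normal={s['min_normal']:.2e}")
```

The much smaller max of FP16 comes entirely from its 5 exponent bits, which is why it overflows where BF16 does not. If PyTorch is available, `torch.finfo(dtype).eps` and `torch.finfo(dtype).max` report the same quantities for the real dtypes (TF32 is not a separate PyTorch dtype).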

The simple example below shows how these types differ in practice: FP16 overflows on a large reduction, while BF16 loses precision on both a large value and a small difference.

Demonstration of precision
import torch

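# Large-magnitude reduction: the exact dot product is 1024 * 100 * 100 = 10,240,000,
# which exceeds the FP16 maximum (65504) and is coarsely rounded in BF16.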
x = 100 * torch.ones(1024).double()
y = 100 * torch.ones(1024).double()
out_double = (x @ y)
out_fp32 = (x.float() @ y.float())
out_bf16 = (x.bfloat16() @ y.bfloat16())
out_fp16 = (x.half() @ y.half())
print(out_double.item(), out_fp32.item(), out_bf16.item(), out_fp16.item())
# 10240000.0 10240000.0 10223616.0 inf

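# Small-difference reduction: 1 - 0.999 is tiny, so the result depends on how
# precisely each dtype can represent 0.999 (BF16 rounds it to 1.0).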
x = torch.ones(1024).double()
y = torch.ones(1024).double() - 0.001
out_double = (x - y).abs().mean()
out_fp32 = (x.float() - y.float()).abs().mean()
out_bf16 = (x.bfloat16() - y.bfloat16()).abs().mean()
out_fp16 = (x.half() - y.half()).abs().mean()
print(out_double.item(), out_fp32.item(), out_bf16.item(), out_fp16.item())
# 0.0010000000000000009 0.0009999871253967285 0.0 0.0009765625
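The BF16, FP16 and FP32 numbers printed above can be reproduced by hand by rounding to the nearest value with the given number of mantissa bits. The sketch below uses a hypothetical `round_to_mantissa` helper: it rounds the significand with ties-to-even and ignores the exponent range, so it models rounding only, not FP16 overflow or underflow. For the dot product it also assumes the sum is accumulated in higher precision and only the final result is rounded, which is typical of PyTorch kernels but not guaranteed.

```python
import math


def round_to_mantissa(x: float, mantissa_bits: int) -> float:
    """Hypothetical helper: round x to the nearest float with `mantissa_bits`
    explicit fraction bits (plus the implicit leading 1), ties to even.
    Exponent range is ignored, so overflow/underflow are NOT modeled."""
    if x == 0.0:
        return 0.0
    m, e = math.frexp(x)              # x == m * 2**e with 0.5 <= |m| < 1
    scale = 2 ** (mantissa_bits + 1)  # significand width incl. the implicit bit
    return round(m * scale) / scale * 2.0 ** e  # Python's round() is ties-to-even


# Large value: in [2**23, 2**24), BF16 (7 mantissa bits) can only step by 2**16 = 65536.
print(round_to_mantissa(10240000.0, 7))  # 10223616.0

# Small difference: how closely can 0.999 be represented before subtracting?
y = 1.0 - 0.001
print(1.0 - round_to_mantissa(y, 7))   # BF16 rounds 0.999 up to 1.0 -> 0.0
print(1.0 - round_to_mantissa(y, 10))  # FP16 -> 0.0009765625
print(1.0 - round_to_mantissa(y, 23))  # FP32 -> 0.0009999871253967285
```

For FP16 the large-value case is not a rounding issue at all: 10,240,000 exceeds the largest finite FP16 value (65504), so the result becomes `inf` regardless of mantissa width.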

EMA