Training Tricks
===============

.. highlight:: python

Mixed Precision
---------------

The table below compares the floating-point data types commonly used in mixed-precision training.

.. todo::

   This table was generated by Gemini. I have not checked its correctness.

+---------------------+----------------+--------------------+-----------------+-------------+
| Feature             | FP32 (Single)  | TF32 (TensorFloat) | BF16 (BFloat16) | FP16 (Half) |
+=====================+================+====================+=================+=============+
| **Total Bits**      | 32             | 19 (internal)      | 16              | 16          |
+---------------------+----------------+--------------------+-----------------+-------------+
| **Sign Bit**        | 1              | 1                  | 1               | 1           |
+---------------------+----------------+--------------------+-----------------+-------------+
| **Exponent Bits**   | 8              | 8                  | 8               | 5           |
+---------------------+----------------+--------------------+-----------------+-------------+
| **Mantissa Bits**   | 23             | 10                 | 7               | 10          |
+---------------------+----------------+--------------------+-----------------+-------------+
| **Dynamic Range**   | High           | High               | High            | Low         |
+---------------------+----------------+--------------------+-----------------+-------------+
| **Precision**       | High           | Medium             | Low             | Medium      |
+---------------------+----------------+--------------------+-----------------+-------------+
| **Epsilon**         | 1.19e-07       | 9.77e-04           | 7.81e-03        | 9.77e-04    |
+---------------------+----------------+--------------------+-----------------+-------------+
| **Max Value**       | 3.40e+38       | 3.40e+38           | 3.39e+38        | 6.55e+04    |
+---------------------+----------------+--------------------+-----------------+-------------+

The simple example below demonstrates how these types differ in range and precision.
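The Epsilon and Max Value rows follow from the exponent and mantissa widths, which gives a quick way to sanity-check the table. The following is a minimal sketch that assumes an IEEE-754-style binary layout (implicit leading 1, exponent bias of 2^(e-1) - 1, top exponent code reserved for inf/NaN). Under that assumption epsilon is 2^-m and the largest finite value is (2 - 2^-m) * 2^(2^(e-1) - 1). ``FORMATS`` and ``ieee_like_limits`` are hypothetical names used only for illustration. TF32 is a Tensor Core compute mode rather than a PyTorch dtype, so only FP32, BF16 and FP16 are compared against ``torch.finfo``.

.. code-block:: python

   import torch

   # Hypothetical table of (exponent bits, mantissa bits), taken from the table above
   FORMATS = {
       "FP32": (8, 23),
       "TF32": (8, 10),
       "BF16": (8, 7),
       "FP16": (5, 10),
   }


   def ieee_like_limits(exp_bits: int, man_bits: int) -> tuple[float, float]:
       """Return (epsilon, largest finite value), assuming an IEEE-754-style layout."""
       eps = 2.0 ** -man_bits
       # Largest unbiased exponent of a finite number equals the bias 2^(e-1) - 1
       max_exp = 2 ** (exp_bits - 1) - 1
       max_val = (2.0 - eps) * 2.0 ** max_exp
       return eps, max_val


   for name, (e, m) in FORMATS.items():
       eps, max_val = ieee_like_limits(e, m)
       print(f"{name}: eps={eps:.2e}, max={max_val:.2e}")

   # Cross-check against PyTorch for the formats that exist as dtypes
   for name, dtype in [("FP32", torch.float32), ("BF16", torch.bfloat16), ("FP16", torch.float16)]:
       info = torch.finfo(dtype)
       assert (info.eps, info.max) == ieee_like_limits(*FORMATS[name]), name

Running it prints the derived epsilon and maximum for each format in the same two-significant-decimal notation as the table, and an ``assert`` fails if any derived value disagrees with ``torch.finfo``.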
.. code-block:: python
   :caption: Demonstration of precision

   import torch

   # Large-number reduction: 1024 products of 100 * 100 sum to 10240000
   x = 100 * torch.ones(1024).double()
   y = 100 * torch.ones(1024).double()
   out_double = x @ y
   out_fp32 = x.float() @ y.float()
   out_bf16 = x.bfloat16() @ y.bfloat16()
   out_fp16 = x.half() @ y.half()
   print(out_double.item(), out_fp32.item(), out_bf16.item(), out_fp16.item())
   # 10240000.0 10240000.0 10223616.0 inf
   # BF16 keeps only 8 significant bits, so the result is rounded;
   # FP16 cannot represent values above 65504, so the result overflows.

   # Small-number reduction: mean of |1 - 0.999| over 1024 elements
   x = torch.ones(1024).double()
   y = torch.ones(1024).double() - 0.001
   out_double = (x - y).abs().mean()
   out_fp32 = (x.float() - y.float()).abs().mean()
   out_bf16 = (x.bfloat16() - y.bfloat16()).abs().mean()
   out_fp16 = (x.half() - y.half()).abs().mean()
   print(out_double.item(), out_fp32.item(), out_bf16.item(), out_fp16.item())
   # 0.0010000000000000009 0.0009999871253967285 0.0 0.0009765625
   # In BF16, 0.999 rounds to 1.0, so the difference vanishes entirely.

EMA
---