Shong

AMP Mixed Precision Training

What is AMP#

By default, most deep learning frameworks use single precision (32-bit floating point) for training.

In 2017, NVIDIA combined single precision and half precision (16-bit floating point) during network training, achieving nearly the same accuracy as single precision with the same hyperparameters.

Half precision (FP16): 16 bits, made up of 1 sign bit, 5 exponent bits, and 10 fraction bits.

Single precision (FP32): 32 bits, made up of 1 sign bit, 8 exponent bits, and 23 fraction bits.

In PyTorch, there are a total of 10 types of tensors.

The default tensor is torch.FloatTensor (32bit floating point)

torch.FloatTensor (32-bit floating point)
torch.DoubleTensor (64-bit floating point)
torch.HalfTensor (16-bit floating point)
torch.BFloat16Tensor (16-bit floating point)
torch.ByteTensor (8-bit integer, unsigned)
torch.CharTensor (8-bit integer, signed)
torch.ShortTensor (16-bit integer, signed)
torch.IntTensor (32-bit integer, signed)
torch.LongTensor (64-bit integer, signed)
torch.BoolTensor (Boolean)

Automatic mixed precision has two key points:

  • Automatic: The framework adjusts the dtype of tensors automatically, although manual intervention is sometimes needed.
  • Mixed precision: Tensors of more than one precision are used: torch.FloatTensor and torch.HalfTensor.

Why Use AMP#

Core: In some situations, FP16 is better, while in others, FP32 is better.

There are three advantages of FP16:

  • Reduced memory usage.
  • Faster training and inference (significantly reduced communication volume accelerates data flow).
  • The prevalence of tensor cores; low precision computation is an important trend.

The two major issues with FP16:

  • Overflow errors: FP16 has a very narrow dynamic range, making it prone to overflow and underflow. Once overflow occurs, it can easily lead to NaN issues. In deep learning, activation gradients are often smaller than weight gradients, making underflow more likely. The smallest positive number representable in FP16 is $2^{-24}$; gradients below that become zero, which prevents weight updates.
  • Rounding errors: When the gradient is smaller than the spacing between adjacent representable values in the current range, the weight update may be lost.
    For example 🌰: Under FP16, if the weight is $2^{-3}$ and the gradient is $2^{-14}$, the updated weight is $2^{-3}+2^{-14}=2^{-3}$: the spacing between FP16 values in $[2^{-3}, 2^{-2})$ is $2^{-13}$, so the small gradient is rounded away and no update takes place.
If you don't understand this part, please refer to the last section: Data Representation.
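
Both issues can be reproduced without a GPU. The following is a minimal sketch that uses Python's standard `struct` module (its `e` format is IEEE half precision) as a stand-in for FP16 tensors; the helper name `to_fp16` is made up for this illustration.

```python
import struct

def to_fp16(x: float) -> float:
    """Round a Python float to the nearest FP16 value and return it as a float."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

# Rounding error: 2^-14 is only half of the FP16 spacing (2^-13) near 2^-3,
# so the update is rounded away.
weight = to_fp16(2**-3)
updated = to_fp16(weight + 2**-14)
print("update lost:", updated == weight)

# Underflow: 2^-24 is the smallest positive FP16 value; 2^-26 becomes zero.
print(to_fp16(2**-24), to_fp16(2**-26))

# Overflow: the largest finite FP16 value is 65504. struct raises an error for
# larger inputs, whereas a tensor library would typically produce inf.
try:
    to_fp16(1e5)
    overflowed = False
except OverflowError:
    overflowed = True
print("overflow detected:", overflowed)
```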

😤 Therefore, to eliminate the issues with FP16, there are two solutions.

Mixed Precision Training#

Store values in FP16 and perform multiplications in FP16 to save memory and accelerate computation, while accumulating results in FP32 to avoid rounding errors.

The strategy of mixed precision training effectively alleviates the problem of rounding errors.
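
To see why the accumulator precision matters, here is a small sketch that adds the same small value many times, once with an FP16 accumulator and once with an FP32 accumulator. It simulates both formats with the standard `struct` module; the helper `round_to` and the chosen numbers are assumptions for this illustration, not part of any framework.

```python
import struct

def round_to(fmt: str, x: float) -> float:
    """Round x to the precision of the given struct format ('<e' = FP16, '<f' = FP32)."""
    return struct.unpack(fmt, struct.pack(fmt, x))[0]

step = 2**-11      # half of the FP16 spacing near 1.0
acc16 = 1.0        # accumulator kept in FP16
acc32 = 1.0        # accumulator kept in FP32

for _ in range(1024):
    acc16 = round_to("<e", acc16 + step)   # every addition is rounded back to 1.0
    acc32 = round_to("<f", acc32 + step)   # every addition is exact in FP32

print(acc16, acc32)
```

The FP16 accumulator never moves away from 1.0, while the FP32 accumulator reaches the exact sum of 1.5.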

Loss Scaling#

Even with mixed precision training, there can still be cases where convergence fails due to the activation gradient being too small, causing underflow.

This can be prevented by using torch.cuda.amp.GradScaler to amplify the loss value to prevent gradient underflow. ( 🤓 Note: The amplification of the loss here is only used when passing gradient information during backpropagation; when actually updating weights, the amplified gradients need to be scaled back down.)
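
The idea behind loss scaling can be shown with plain numbers. This is a sketch under simplifying assumptions: a single gradient value, a fixed scale of $2^{16}$, and FP16 simulated with the standard `struct` module (the helper `to_fp16` is hypothetical). It does not use GradScaler itself.

```python
import struct

def to_fp16(x: float) -> float:
    """Round a Python float to the nearest FP16 value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

true_grad = 2**-27          # too small for FP16 (smallest positive value is 2^-24)
scale = 2.0**16             # loss scale; scaling the loss scales every gradient

grad_without_scaling = to_fp16(true_grad)        # underflows to 0.0
scaled_grad = to_fp16(true_grad * scale)         # 2^-11, representable in FP16
recovered = scaled_grad / scale                  # unscale in higher precision

print(grad_without_scaling, scaled_grad, recovered == true_grad)
```

Unscaling happens in higher precision before the weight update, which is why the original gradient value survives.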

How to Use AMP#

import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler

# Net, loss_fn, epochs, and data are placeholders defined elsewhere.
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)

scaler = GradScaler()  # Instantiate a GradScaler object before training.

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()

        with autocast():  # Run the forward pass and the loss computation under autocast.
            output = model(input)
            loss = loss_fn(output, target)

        scaler.scale(loss).backward()  # Scale the loss so that the gradients are amplified.
        # scaler.step() first unscales the gradients; if none of them is inf or NaN, it calls
        # optimizer.step() to update the weights; otherwise it skips the step so the weights stay unchanged.
        scaler.step(optimizer)
        scaler.update()  # Adjust the scale factor for the next iteration.

The scale factor held by the scaler is dynamically estimated in each iteration. To minimize gradient underflow, the scale factor should gradually increase.

However, if it becomes too large, half-precision values can easily overflow (becoming inf or NaN).

Thus, the principle of dynamic estimation is to make the scale factor as large as possible without producing inf or NaN gradients.

In each iteration, scaler.step(optimizer) checks the gradients for inf or NaN:

  • If inf or NaN appears, scaler.step(optimizer) skips the weight update (optimizer.step()), and scaler.update() then reduces the scale factor (multiplying it by backoff_factor).
  • If there is no inf or NaN, the weights are updated normally, and once growth_interval consecutive iterations have passed without inf or NaN, scaler.update() increases the scale factor (multiplying it by growth_factor).
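
The update rule above can be sketched in a few lines of plain Python. `ToyGradScaler` and `has_inf_or_nan` are hypothetical names for this illustration; the parameter names mirror GradScaler's arguments, but this is not PyTorch's implementation.

```python
import math

class ToyGradScaler:
    """Toy model of the dynamic scale factor; not the real GradScaler."""

    def __init__(self, init_scale=2.0**16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._clean_steps = 0  # consecutive iterations without inf/NaN

    def update(self, found_inf: bool) -> None:
        if found_inf:
            # Overflow seen: shrink the scale and restart the counter.
            self.scale *= self.backoff_factor
            self._clean_steps = 0
        else:
            self._clean_steps += 1
            if self._clean_steps == self.growth_interval:
                # Enough clean iterations: try a larger scale.
                self.scale *= self.growth_factor
                self._clean_steps = 0

def has_inf_or_nan(grads) -> bool:
    """Return True if any gradient value is inf or NaN."""
    return any(not math.isfinite(g) for g in grads)

scaler = ToyGradScaler(init_scale=8.0, growth_interval=3)
scaler.update(has_inf_or_nan([0.1, float("inf")]))   # overflow: scale is halved
after_backoff = scaler.scale
for _ in range(3):
    scaler.update(has_inf_or_nan([0.1, 0.2]))        # three clean iterations
after_growth = scaler.scale
print(after_backoff, after_growth)
```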

For multi-GPU training with torch.nn.DataParallel and torch.nn.DistributedDataParallel, keep in mind that autocast is thread-local (its behavior and state are specific to the current thread), so it cannot be used as follows:

model = MyModel()
dp_model = nn.DataParallel(model)

with autocast():
    output=dp_model(input)
loss=loss_fn(output)

Instead, it should be used as follows:

class MyModel(nn.Module):
    @autocast()
    def forward(self, input):
        ...
        
# alternatively
class MyModel(nn.Module):
    def forward(self, input):
        with autocast():
            ...

model = MyModel()
dp_model=nn.DataParallel(model)

with autocast():
    output=dp_model(input)
    loss = loss_fn(output)

  • Apply autocast inside every forward so that each thread runs under autocast.
  • The loss must also be computed under autocast.

Important Notes#

  • Check whether the GPU supports FP16.
  • Constant range: To keep calculations from overflowing, first make sure that manually set constants such as epsilon and INF are themselves representable in FP16.
  • Dimensions are best as multiples of 8 for optimal performance ( 🤯).
  • Operations involving sums are prone to overflow; for softmax operations, it is recommended to use the official API and define it as a layer in the model initialization.
  • Some less commonly used functions need to be registered before use (this applies to NVIDIA Apex's amp): 🌰 amp.register_float_function(torch, 'sigmoid').
  • Layers should be defined in the model's __init__ function, and the computation graph should be defined in the forward function.
  • Certain functions do not support FP16 acceleration; it is recommended not to use them.
  • Modules that operate on gradients must be included in the optimizer's step; otherwise, AMP cannot determine whether the gradients contain NaN.

Data Representation#

Von Neumann architecture: binary representation plus five components (memory, controller, arithmetic unit, input, output).

Harvard architecture: The biggest difference is that instructions and data are stored separately, so both can be accessed simultaneously; many ARM cores use a Harvard-style design.

Overflow Issues#

(lldb) print (233333 + 1) * (233333 + 1)
(int) $0 = -1389819292

$(x+1)^2 \ge 0$ is not always true because integers can overflow; int is only 32 bits.
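
Python integers never overflow, so the 32-bit wraparound has to be simulated. A minimal sketch, where `wrap_int32` is a helper made up for this illustration:

```python
def wrap_int32(value: int) -> int:
    """Interpret an arbitrary integer as a 32-bit two's-complement value."""
    return (value + 2**31) % 2**32 - 2**31

x = 233333
exact = (x + 1) * (x + 1)        # exact result with unbounded integers
wrapped = wrap_int32(exact)      # what a 32-bit int would hold

print(exact, wrapped)
```

The wrapped value matches the negative number printed by lldb above.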

The representation of floating-point numbers is different from integers and does not become negative due to overflow, but it has its own issues.

(lldb) print (1e20 + -1e20) + 3.14
(double) $0 = 3.1400000000000001
(lldb) print 1e20 + (-1e20 + 3.14)
(double) $1 = 0

This happens because floating-point addition rounds every intermediate result, so the order of operations matters; this will be explained in detail below.
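
The same experiment can be repeated in Python, whose `float` is a 64-bit double. This sketch also prints the spacing between adjacent doubles near `1e20` using `math.ulp` (available from Python 3.9), which shows why 3.14 disappears when it is added to `-1e20` first.

```python
import math

left_first = (1e20 + -1e20) + 3.14    # the large values cancel before 3.14 is added
right_first = 1e20 + (-1e20 + 3.14)   # 3.14 is rounded away inside the parentheses

spacing = math.ulp(1e20)              # gap between adjacent doubles near 1e20

print(left_first, right_first, spacing)
```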

Bit Representation#

Everything seen in a computer is bits; each bit is either 0 or 1. Computers perform different tasks by encoding and describing bits in various ways.

From the perspective of analog circuits, a two-state bit is easy to store and remains reliable even in the presence of noise or imprecise transmission.

Integer Types#

signed and unsigned

  • Unsigned: $B2U(X)=\sum_{i=0}^{w-1}x_i \cdot 2^i$
  • Signed: $B2T(X)=-x_{w-1} \cdot 2^{w-1}+\sum_{i=0}^{w-2}x_i \cdot 2^i$

The main difference between signed and unsigned numbers is whether the highest bit is treated as a sign bit, which carries a negative weight.
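
The two formulas translate directly into code. A small sketch, assuming the bit vector is given as a string with the most significant bit first; the function names `b2u` and `b2t` simply follow the formula names:

```python
def b2u(bits: str) -> int:
    """Unsigned value: sum of x_i * 2^i over all w bits."""
    w = len(bits)
    return sum(int(b) * 2 ** (w - 1 - pos) for pos, b in enumerate(bits))

def b2t(bits: str) -> int:
    """Two's-complement value: the highest bit has weight -2^(w-1)."""
    w = len(bits)
    return -int(bits[0]) * 2 ** (w - 1) + b2u(bits[1:])

pattern = "1100010010010011"   # the same 16 bits, read two ways
print(b2u(pattern), b2t(pattern))
```

The bit pattern is identical in both calls; only the weight of the highest bit differs.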

🤓 When converting between signed and unsigned numbers:

  • The bit pattern does not change; what changes is how the computer interprets it.
  • If an expression includes both signed and unsigned numbers, the signed operand is implicitly converted to unsigned before the comparison.

Type Extension and Truncation#

  • Extension: For example, from short int to int
    • Unsigned: Add 0.
    • Signed: Add the sign bit.
  • Truncation: For example, from unsigned to unsigned short; for numbers small enough to fit in the narrower type, the expected result is obtained.
    • Unsigned: mod operation.
    • Signed: Similar to a mod operation (the remaining bits are reinterpreted as a signed value).
short int x = 15213;
int ix = (int) x;
short int y = -15213;
int iy = (int) y;

        Decimal   Hexadecimal   Binary
------------------------------------------------------------------
    x     15213   3B 6D         00111011 01101101
    ix    15213   00 00 3B 6D   00000000 00000000 00111011 01101101
    y    -15213   C4 93         11000100 10010011
    iy   -15213   FF FF C4 93   11111111 11111111 11000100 10010011
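
The table can be reproduced with Python's `int.to_bytes`, which performs the same zero or sign extension when asked for a wider two's-complement encoding. A minimal sketch; `hex_bytes` is a helper made up for this illustration, and the truncation line uses an arbitrary example value:

```python
def hex_bytes(value: int, width: int) -> str:
    """Two's-complement encoding of value in `width` bytes, as spaced hex."""
    return value.to_bytes(width, "big", signed=True).hex(" ").upper()

x_short = hex_bytes(15213, 2)     # short int x
x_int = hex_bytes(15213, 4)       # extended with 0 bits
y_short = hex_bytes(-15213, 2)    # short int y
y_int = hex_bytes(-15213, 4)      # extended with copies of the sign bit

# Unsigned truncation to 16 bits is a mod 2^16 operation.
truncated = 70000 % 2**16

print(x_short, "|", x_int, "|", y_short, "|", y_int, "|", truncated)
```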

Integer operations and overflow:

  • Signed: Overflow flips the sign bit: a positive result wraps to negative, and a negative result wraps to positive.
  • Unsigned: Overflow discards the carry out of the highest bit, so the sum wraps around and ends up smaller than either operand.

Floating Point Numbers#

Binary fractional numbers can be expressed using a unified formula:

$\sum_{k=-j}^{i} b_k \cdot 2^k$

It can be seen that only fractions of the form $\frac{x}{2^k}$ can be represented exactly.

IEEE floating point standard:

$(-1)^s M \, 2^E$

where s is the sign bit, determining the sign; M is usually a value in [1.0, 2.0), and E is the exponent.

[Figure: Floating Point Numbers]

Normalized values: When $exp \neq 00\ldots0$ and $exp \neq 11\ldots1$, the encoding represents a normalized value.

E is a biased value: $E = Exp - Bias$.

  • $Exp$: the unsigned value of the exp field.
  • $Bias$: the offset $2^{k-1}-1$, where k is the number of bits in the exp field, meaning:
    • Single precision: 127
    • Double precision: 1023

M must start with 1: that is, $M = 1.xxx\ldots x_2$, where $xxx\ldots x$ are the bits of the frac field.

For example 🌰:

float F = 15213.0;

$15213_{10} = 11101101101101_2 = 1.1101101101101_2 \times 2^{13}$

The value of the frac part is the digits after the binary point, padded with zeros to 23 bits: 1101101101101.

$Exp = E + Bias = 13 + 127 = 140 = 10001100_2$

    s   exp        frac
    0   10001100   11011011011010000000000
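
The encoding can be checked against the machine's own single-precision representation. A short sketch using the standard `struct` module; the variable names are chosen for this illustration:

```python
import struct

# Pack 15213.0 as a big-endian float32 and read the same 4 bytes back as an integer.
raw = struct.unpack(">I", struct.pack(">f", 15213.0))[0]
bits = format(raw, "032b")

s, exp, frac = bits[0], bits[1:9], bits[9:]
E = int(exp, 2) - 127            # remove the single-precision bias

print(s, exp, frac, E)
```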

🧐 Remember the non-normalized values mentioned earlier?

When $exp = 00\ldots0$, the value is non-normalized (denormalized). These encodings cover the region of the real number axis around zero with a finite set of values, and the spacing between these values is uniform.

Unlike before, $M = 0.xxx\ldots x_2$.

  • When $exp = 00\ldots0$, $E = 1 - Bias$.
    • $frac = 00\ldots0$ represents 0.
    • $frac \neq 00\ldots0$ represents a value close to 0.
  • When $exp = 11\ldots1$, E is not applicable (n/a).
    • $frac = 00\ldots0$ represents $\infty$.
    • $frac \neq 00\ldots0$ is not a numerical value; it represents an indeterminate result (NaN).

Now I will use the number line to illustrate this issue 😤.

[Figure: Number Line]

Using the following 🌰 to illustrate this issue:

    s exp  frac   E   Value
------------------------------------------------------------------
    0 0000 000   -6   0   # This part is non-normalized; the next part is normalized values.
    0 0000 001   -6   1/8 * 1/64 = 1/512 # The closest value to zero that can be represented.
    0 0000 010   -6   2/8 * 1/64 = 2/512 
    ...
    0 0000 110   -6   6/8 * 1/64 = 6/512
    0 0000 111   -6   7/8 * 1/64 = 7/512 # The largest non-normalized value that can be represented.
------------------------------------------------------------------
    0 0001 000   -6   8/8 * 1/64 = 8/512 # The smallest normalized value that can be represented.
    0 0001 001   -6   9/8 * 1/64 = 9/512
    ...
    0 0110 110   -1   14/8 * 1/2 = 14/16
    0 0110 111   -1   15/8 * 1/2 = 15/16 # The closest value less than 1.
    0 0111 000    0   8/8 * 1 = 1
    0 0111 001    0   9/8 * 1 = 9/8      # The closest value greater than 1.
    0 0111 010    0   10/8 * 1 = 10/8
    ...
    0 1110 110    7   14/8 * 128 = 224
    0 1110 111    7   15/8 * 128 = 240   # The largest normalized value that can be represented.
------------------------------------------------------------------
    0 1111 000   n/a  Infinity               # Special value.
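
The rows of this table follow mechanically from the rules for normalized and non-normalized values. Below is a sketch of a decoder for the same 8-bit toy format (1 sign bit, 4 exp bits, 3 frac bits, Bias = 7); `decode_tiny` is a name made up for this illustration, and exact fractions are used so the results can be compared with the table directly.

```python
from fractions import Fraction

def decode_tiny(bits: str):
    """Decode an 8-bit toy float written as 's eeee fff' (spaces optional)."""
    bits = bits.replace(" ", "")
    s, exp, frac = int(bits[0]), int(bits[1:5], 2), int(bits[5:], 2)
    bias = 7
    sign = -1 if s else 1
    if exp == 0b1111:                      # special values
        return sign * float("inf") if frac == 0 else float("nan")
    if exp == 0:                           # non-normalized: M = 0.fff, E = 1 - Bias
        m, e = Fraction(frac, 8), 1 - bias
    else:                                  # normalized: M = 1.fff, E = Exp - Bias
        m, e = 1 + Fraction(frac, 8), exp - bias
    return sign * m * Fraction(2) ** e

for row in ["0 0000 001", "0 0000 111", "0 0001 000", "0 0111 000", "0 1110 111"]:
    print(row, "->", decode_tiny(row))
```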

Floating Point Rounding#

For floating point addition and multiplication, we can first calculate the exact value and then convert it to the appropriate precision.

  Value        Binary     Rounded (binary)   Rounded (decimal)   Reason
  2 and 3/32   10.00011   10.00              2                   Less than half, round down.
  2 and 3/16   10.00110   10.01              2 and 1/4           More than half, round up.
  2 and 7/8    10.11100   11.00              3                   Exactly at half, ensure the last digit is even, so round up.
  2 and 5/8    10.10100   10.10              2 and 1/2           Exactly at half, ensure the last digit is even, so round down.
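
The four rows can be checked with exact fractions. This sketch relies on the fact that Python's built-in `round` applied to a `Fraction` rounds halves to the nearest even integer; `round_to_quarters` is a helper made up for this illustration, and it rounds to two binary fractional digits as in the table.

```python
from fractions import Fraction

def round_to_quarters(x: Fraction) -> Fraction:
    """Round x to the nearest multiple of 1/4, with ties going to the even multiple."""
    return Fraction(round(x * 4), 4)

examples = [
    Fraction(2) + Fraction(3, 32),   # less than half
    Fraction(2) + Fraction(3, 16),   # more than half
    Fraction(2) + Fraction(7, 8),    # exactly half, rounds up to an even last digit
    Fraction(2) + Fraction(5, 8),    # exactly half, rounds down to an even last digit
]
for value in examples:
    print(value, "->", round_to_quarters(value))
```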

Floating Point Addition#

$(-1)^{s_1}M_1 2^{E_1}+(-1)^{s_2}M_2 2^{E_2}$

Assuming $E_1>E_2$, the result is $(-1)^{s}M\,2^{E}$, where the sign $s$ and significand $M$ come from aligning the two significands (shifting $M_2$ right by $E_1-E_2$ bits) and adding them as signed values, and $E=E_1$.

  • If $M \ge 2$, then shift M right and increase the value of E.
  • If $M < 1$, shift M left k bits, and decrease E by k.
  • If E exceeds the representable range, overflow occurs.
  • Round M to the precision of frac.

Basic properties:

  • Addition may produce infinity or NaN.
  • Satisfies the commutative property.
  • Does not satisfy the associative property.
  • Adding 0 equals the original number.
  • Except for infinity or NaN, every element has a corresponding additive inverse.
  • Except for infinity or NaN, monotonicity is satisfied.

Floating Point Multiplication#

$(-1)^{s_1}M_1 2^{E_1} \times (-1)^{s_2}M_2 2^{E_2}$

The result is $(-1)^{s}M\,2^{E}$, where $s=s_1 \oplus s_2$, $M=M_1 \times M_2$, and $E=E_1+E_2$.

  • If $M \ge 2$, then shift M right and increase the value of E.
  • If E exceeds the representable range, overflow occurs.
  • Round M to the precision of frac.

Basic properties:

  • Multiplication may produce infinity or NaN.
  • Satisfies the commutative property.
  • Does not satisfy the associative property.
  • Multiplying by 1 equals the original number.
  • Except for infinity or NaN, monotonicity is satisfied.