What is AMP
By default, most deep learning frameworks use single precision (32-bit floating point) for training.
In 2017, NVIDIA showed that training with a mix of single precision and half precision (16-bit floating point) reaches nearly the same accuracy as pure single-precision training, using the same hyperparameters.
- Half precision (FP16): 16 bits = 1 sign bit + 5 exponent bits + 10 fraction bits
- Single precision (FP32): 32 bits = 1 sign bit + 8 exponent bits + 23 fraction bits
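To make these layouts concrete, here is a small sketch (not part of PyTorch) that uses Python's standard struct module, whose "e" and "f" formats convert a value to IEEE 754 half and single precision, and then splits the raw bits into sign, exponent and fraction fields. The helper names fp16_fields and fp32_fields are hypothetical.

```python
import struct

def fp16_fields(x: float) -> tuple[int, int, int]:
    """Round x to FP16 and return its (sign, exponent, fraction) bit fields."""
    (bits,) = struct.unpack("<H", struct.pack("<e", x))
    return bits >> 15, (bits >> 10) & 0x1F, bits & 0x3FF      # 1 + 5 + 10 bits

def fp32_fields(x: float) -> tuple[int, int, int]:
    """Round x to FP32 and return its (sign, exponent, fraction) bit fields."""
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    return bits >> 31, (bits >> 23) & 0xFF, bits & 0x7FFFFF   # 1 + 8 + 23 bits

# -1.5 = -1.1 (binary) * 2**0: exponent field = 0 + bias, fraction = .1000...
print(fp16_fields(-1.5))  # (1, 15, 512)
print(fp32_fields(-1.5))  # (1, 127, 4194304)
```

The exponent fields differ (15 vs. 127) only because the two formats use different exponent biases.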
PyTorch's classic tensor types are the following 10. The default is torch.FloatTensor (32-bit floating point).
- torch.FloatTensor: 32-bit floating point
- torch.DoubleTensor: 64-bit floating point
- torch.HalfTensor: 16-bit floating point (FP16)
- torch.BFloat16Tensor: 16-bit brain floating point (BF16)
- torch.ByteTensor: 8-bit unsigned integer
- torch.CharTensor: 8-bit signed integer
- torch.ShortTensor: 16-bit signed integer
- torch.IntTensor: 32-bit signed integer
- torch.LongTensor: 64-bit signed integer
- torch.BoolTensor: Boolean
Automatic mixed precision has two key points:
- Automatic: the framework changes tensor dtypes automatically, although manual intervention is sometimes needed.
- Mixed precision: tensors of more than one precision are used, namely torch.FloatTensor and torch.HalfTensor.
Why Use AMP
The core idea: FP16 is better in some situations and FP32 is better in others.
FP16 has three advantages:
- Lower memory usage.
- Faster training and inference (FP16 halves the data volume, which reduces memory traffic and communication).
- Tensor Cores are now widespread, and low-precision computation is an important trend.
FP16 has two major issues:
- Overflow and underflow errors: FP16 has a much narrower dynamic range than FP32 (its largest finite value is 65504), so values overflow and underflow easily, and an overflow quickly leads to "NaN" issues. In deep learning, activation gradients are often smaller than weight gradients, which makes underflow more likely. The smallest positive number representable in FP16 is $2^{-24}$ (about $5.96 \times 10^{-8}$); smaller values become 0, which can prevent weight updates.
- Rounding errors: when a gradient is smaller than the spacing between adjacent representable values at the weight's current magnitude, the update can be rounded away.
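For example 🌰: in FP16, if the weight is $2^{-3}$ and the gradient is $2^{-14}$, the updated weight is $2^{-3} + 2^{-14} = 2^{-3}$, because the fixed spacing between FP16 values in $[2^{-3}, 2^{-2})$ is $2^{-13}$; a gradient this small is rounded away and the weight is not updated.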
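Both issues can be reproduced without a GPU. The sketch below is an illustration only: it emulates FP16 rounding with Python's struct module (format "e"), and to_fp16 is a hypothetical helper. One difference from deep learning frameworks: struct raises OverflowError for values beyond the FP16 range instead of returning inf.

```python
import struct

def to_fp16(x: float) -> float:
    """Round x to the nearest FP16 value (ties to even) and return it as a Python float."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

# Overflow: the largest finite FP16 value is 65504.
print(to_fp16(65504.0))                    # 65504.0
try:
    to_fp16(70000.0)                       # beyond the FP16 range
except OverflowError as err:
    print("overflow:", err)

# Underflow: the smallest positive (denormalized) FP16 value is 2**-24.
print(to_fp16(2.0 ** -24) == 2.0 ** -24)   # True
print(to_fp16(1e-8))                       # 0.0

# Rounding error: weight 2**-3 plus gradient 2**-14 rounds back to 2**-3.
weight, grad = 2.0 ** -3, 2.0 ** -14
print(to_fp16(weight + grad))              # 0.125
print(weight + grad)                       # 0.12506103515625 (exact in FP64)
```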
If this part is unclear, see the final section, Data Representation.
😤 To mitigate these FP16 issues, two techniques are used.
Mixed Precision Training
Keep values and perform multiplications in FP16 to speed up computation, while accumulating in FP32 (and keeping an FP32 master copy of the weights) to avoid rounding errors.
This mixed precision strategy effectively alleviates the rounding-error problem.
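The benefit of accumulating in FP32 can be sketched by rounding a running sum after every addition. This is a simplified emulation with Python's struct module (to_fp16 and to_fp32 are hypothetical helpers), not a model of how Tensor Cores accumulate internally.

```python
import struct

def to_fp16(x: float) -> float:
    """Round x to the nearest FP16 value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def to_fp32(x: float) -> float:
    """Round x to the nearest FP32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

# 10,000 copies of a small FP16 value; their exact sum is about 1.0.
values = [to_fp16(1e-4)] * 10_000

acc16 = 0.0
acc32 = 0.0
for v in values:
    acc16 = to_fp16(acc16 + v)  # FP16 accumulator: round after every add
    acc32 = to_fp32(acc32 + v)  # FP32 accumulator: round after every add

print(acc16)  # 0.25: beyond this point each add is below half the FP16 spacing
print(acc32)  # close to the exact sum, about 1.00017
```

Once the FP16 accumulator reaches 0.25, the spacing between neighbouring FP16 values ($2^{-12}$) is more than twice the increment, so every further addition rounds back to 0.25.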
Loss Scaling
Even with mixed precision training, training can still fail to converge when activation gradients are so small that they underflow in FP16.
torch.cuda.amp.GradScaler prevents this by multiplying the loss by a scale factor before backpropagation; by the chain rule, every gradient is scaled by the same factor and stays out of the underflow range. ( 🤓 Note: the scaling only applies while gradients flow backward; before the weights are actually updated, the scaled gradients are divided by the same factor.)
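A minimal sketch of why loss scaling helps, assuming a hypothetical activation gradient of 1e-8 and emulating FP16 storage with Python's struct module. Because scaling the loss scales every gradient by the same factor, the sketch simply scales the gradient directly; the factor 2**16 matches GradScaler's default init_scale.

```python
import struct

def to_fp16(x: float) -> float:
    """Round x to the nearest FP16 value (values below about 2**-25 become 0.0)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

true_grad = 1e-8     # hypothetical gradient, below FP16's smallest positive value
scale = 2.0 ** 16    # example scale factor

unscaled = to_fp16(true_grad)          # gradient stored in FP16 without scaling
scaled = to_fp16(true_grad * scale)    # gradient of the scaled loss, stored in FP16
recovered = scaled / scale             # unscaled in higher precision before the update

print(unscaled)    # 0.0: underflow, the update is lost
print(recovered)   # approximately 1e-8
```

In real training, GradScaler performs the unscaling inside scaler.step(optimizer), or explicitly when you call scaler.unscale_(optimizer).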
How to Use AMP
```python
from torch import optim
from torch.cuda.amp import autocast, GradScaler

model = Net().cuda()  # Net, data, loss_fn and epochs are placeholders defined elsewhere
optimizer = optim.SGD(model.parameters(), ...)
scaler = GradScaler()  # Instantiate a GradScaler object once, before training

for epoch in range(epochs):
    for input, target in data:
        input, target = input.cuda(), target.cuda()
        optimizer.zero_grad()

        # Run the forward pass and the loss computation under autocast
        with autocast():
            output = model(input)
            loss = loss_fn(output, target)

        # Scale the loss, then backpropagate to get scaled gradients
        scaler.scale(loss).backward()

        # scaler.step() first unscales the gradients; if none of them is inf or NaN,
        # it calls optimizer.step() to update the weights; otherwise it skips the update
        scaler.step(optimizer)

        # Adjust the scale factor for the next iteration
        scaler.update()
```
The scale factor is estimated dynamically at every iteration. To minimize gradient underflow, the scale factor should be as large as possible.
However, if it is too large, FP16 gradients easily overflow (becoming inf or NaN).
The principle of dynamic estimation is therefore to keep the scale factor as large as possible without producing inf or NaN gradients.
At each scaler.step(optimizer), the scaler checks the gradients for inf or NaN values:
- If inf or NaN appears, scaler.step(optimizer) skips the weight update (optimizer.step()), and scaler.update() reduces the scale factor by multiplying it by backoff_factor.
- If there are no inf or NaN values, the weights are updated normally; after growth_interval consecutive iterations without inf or NaN, scaler.update() increases the scale factor by multiplying it by growth_factor.
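The update rule above can be sketched as a small pure-Python class. SimpleLossScaler is a hypothetical toy model written for illustration, not PyTorch's implementation; it treats gradients as plain floats, and its default arguments mirror GradScaler's documented defaults (init_scale=2**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000).

```python
import math

class SimpleLossScaler:
    """Toy model of dynamic loss scaling (hypothetical; not torch's GradScaler)."""

    def __init__(self, init_scale=2.0 ** 16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0      # consecutive iterations without inf/NaN
        self._found_inf = False   # result of the most recent step()

    def step(self, scaled_grads, apply_update):
        """Unscale the gradients; call apply_update only if all of them are finite."""
        grads = [g / self.scale for g in scaled_grads]
        self._found_inf = not all(math.isfinite(g) for g in grads)
        if not self._found_inf:
            apply_update(grads)
        return not self._found_inf

    def update(self):
        """Back off after inf/NaN; grow after growth_interval clean iterations."""
        if self._found_inf:
            self.scale *= self.backoff_factor
            self._good_steps = 0
        else:
            self._good_steps += 1
            if self._good_steps == self.growth_interval:
                self.scale *= self.growth_factor
                self._good_steps = 0


scaler = SimpleLossScaler(growth_interval=3)
applied = []
scaler.step([float("inf")], applied.append)  # overflow: update skipped
scaler.update()
print(scaler.scale)                           # 32768.0
for _ in range(3):                            # three clean iterations
    scaler.step([1.0], applied.append)
    scaler.update()
print(scaler.scale)                           # 65536.0
print(len(applied))                           # 3
```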
autocast is thread-local: whether it is enabled belongs to the thread that entered it. torch.nn.DataParallel, and torch.nn.DistributedDataParallel when a single process drives several GPUs, run each replica's forward pass in a separate thread, so enabling autocast only in the main thread is not enough.
It cannot be used as follows:
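```python
model = MyModel()
dp_model = nn.DataParallel(model)

# Wrong: autocast is enabled only in the main thread,
# not in the side threads where DataParallel runs forward()
with autocast():
    output = dp_model(input)
    loss = loss_fn(output)
```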
Instead, it should be used as follows:
```python
import torch.nn as nn
from torch.cuda.amp import autocast

class MyModel(nn.Module):
    @autocast()
    def forward(self, input):
        ...

# alternatively
class MyModel(nn.Module):
    def forward(self, input):
        with autocast():
            ...

model = MyModel()
dp_model = nn.DataParallel(model)

with autocast():
    output = dp_model(input)
    loss = loss_fn(output)
```
- Putting autocast inside every forward ensures that each thread runs under autocast.
- The loss should also be computed under autocast.
Important Notes
- Check whether the GPU supports FP16 (Tensor Cores, which accelerate FP16 matrix math, are available on Volta and newer NVIDIA GPUs).
- Constant range: to keep calculations from overflowing, first make sure that manually chosen constants such as epsilon and INF are representable in FP16 (for example, 1e-8 underflows to 0 and 1e9 overflows).
- Dimensions (batch size, channel counts, hidden sizes) are best as multiples of 8 for optimal Tensor Core performance ( 🤯).
- Reductions such as sum are prone to overflow; for softmax, it is recommended to use the official API and define it as a layer in the model's __init__.
- With NVIDIA Apex's amp, some less commonly used functions must be registered before use: 🌰 amp.register_float_function(torch, 'sigmoid').
- Define layers in the model's __init__ function and the computation graph in the forward function.
- Some functions do not support FP16 acceleration; avoid them where possible.
- Modules whose gradients you operate on must be covered by the optimizer's step; otherwise, AMP cannot determine whether their gradients are NaN.
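Two of these notes can be checked before training. The helpers below are hypothetical sketches: round_up_to_multiple pads a dimension up to a multiple of 8, and fits_in_fp16 uses Python's struct module to test whether a constant such as epsilon or a large INF stand-in survives conversion to FP16 without overflowing or becoming 0.

```python
import struct

def round_up_to_multiple(n: int, multiple: int = 8) -> int:
    """Round a dimension (batch size, hidden size, vocabulary size, ...) up to a multiple."""
    return ((n + multiple - 1) // multiple) * multiple

def fits_in_fp16(value: float) -> bool:
    """Return True if value converts to FP16 without overflowing or flushing to 0."""
    if value == 0.0:
        return True
    try:
        converted = struct.unpack("<e", struct.pack("<e", value))[0]
    except OverflowError:  # struct raises instead of returning inf
        return False
    return converted != 0.0

print(round_up_to_multiple(1000))    # 1000
print(round_up_to_multiple(30522))   # 30528
print(fits_in_fp16(1e-8), fits_in_fp16(1e9), fits_in_fp16(1e-4))  # False False True
```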
Data Representation
Von Neumann architecture: binary representation plus five components (memory, control unit, arithmetic logic unit, input, output), with instructions and data stored in the same memory.
Harvard architecture: the biggest difference is that instructions and data are stored separately, so they can be accessed simultaneously; many ARM cores (for example, ARM9) use a Harvard architecture.
Overflow Issues
```
(lldb) print (233333 + 1) * (233333 + 1)
(int) $0 = -1389819292
```
So x * x >= 0 does not always hold for int: int is only 32 bits, and the product can overflow.
Floating-point numbers are represented differently from integers: they do not become negative on overflow, but they have their own issues.
```
(lldb) print (1e20 + -1e20) + 3.14
(double) $0 = 3.1400000000000001
(lldb) print 1e20 + (-1e20 + 3.14)
(double) $1 = 0
```
This happens because floating-point addition rounds its result and is therefore not associative; this is explained in detail below.
Bit Representation
Everything in a computer is bits; each bit is either 0 or 1. Computers perform different tasks by encoding and interpreting bits in different ways.
From an electronic-circuit perspective, two-state bits are easy to store and can be transmitted reliably even in the presence of noise or imprecise signals.
Integer Types
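Integers come in two kinds, signed and unsigned. For a w-bit vector $x = [x_{w-1}, x_{w-2}, \ldots, x_0]$:
- Unsigned: $B2U_w(x) = \sum_{i=0}^{w-1} x_i \cdot 2^i$
- Signed (two's complement): $B2T_w(x) = -x_{w-1} \cdot 2^{w-1} + \sum_{i=0}^{w-2} x_i \cdot 2^i$
The main difference between signed and unsigned numbers is whether the highest bit acts as the sign bit: in two's complement it carries the negative weight $-2^{w-1}$.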
🤓 When converting between signed and unsigned numbers:
- The bit pattern does not change; only the way the computer interprets it changes.
- In C, if an expression mixes signed and unsigned numbers, the signed values are implicitly converted to unsigned, including for comparisons.
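The two interpretations can be sketched in Python, which has arbitrary-precision integers, by masking to w bits. b2u and b2t are hypothetical helpers implementing the unsigned and two's-complement readings; the last line emulates how C converts -1 to unsigned before comparing it with 0U.

```python
def b2u(bits: int, w: int) -> int:
    """Interpret the low w bits as an unsigned integer."""
    return bits & ((1 << w) - 1)

def b2t(bits: int, w: int) -> int:
    """Interpret the low w bits as a two's-complement signed integer."""
    u = b2u(bits, w)
    return u - (1 << w) if u >> (w - 1) else u  # top bit has weight -2**(w-1)

pattern = 0xFFFFFFFF
print(b2u(pattern, 32))  # 4294967295
print(b2t(pattern, 32))  # -1

# C evaluates -1 < 0U by first converting -1 to unsigned:
print(b2u(-1, 32) < 0)   # False, so in C the expression -1 < 0U is false
```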
Type Extension and Truncation
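- Extension, e.g., from short int to int:
  - Unsigned: fill the new high bits with 0 (zero extension).
  - Signed: fill the new high bits with copies of the sign bit (sign extension).
- Truncation, e.g., from unsigned to unsigned short: for small numbers, the expected result is obtained.
  - Unsigned: equivalent to a mod operation, $x \bmod 2^k$ for a k-bit target.
  - Signed: similar to a mod operation, but the remaining bits are reinterpreted as two's complement, so the sign can change.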
```c
short int x = 15213;
int ix = (int) x;
short int y = -15213;
int iy = (int) y;
```
Variable (decimal) | Hexadecimal | Binary |
---|---|---|
x=15213 | 3B 6D | 00111011 01101101 |
ix=15213 | 00 00 3B 6D | 00000000 00000000 00111011 01101101 |
y=-15213 | C4 93 | 11000100 10010011 |
iy=-15213 | FF FF C4 93 | 11111111 11111111 11000100 10010011 |
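Extension and truncation can be emulated in Python by masking to a fixed width. The helpers to_signed, sign_extend and truncate are hypothetical names; the sketch reproduces the y/iy row of the table and truncates the int 53191 (0x0000CFC7) to 16 bits.

```python
def to_signed(u: int, w: int) -> int:
    """Reinterpret a w-bit unsigned pattern as two's complement."""
    return u - (1 << w) if u >> (w - 1) else u

def sign_extend(bits: int, from_w: int, to_w: int) -> int:
    """Widen a from_w-bit signed pattern to to_w bits by copying the sign bit."""
    value = to_signed(bits & ((1 << from_w) - 1), from_w)
    return value & ((1 << to_w) - 1)

def truncate(value: int, to_w: int) -> int:
    """Keep only the low to_w bits, i.e. value mod 2**to_w."""
    return value & ((1 << to_w) - 1)

print(hex(sign_extend(0xC493, 16, 32)))    # 0xffffc493 (y -> iy)
print(hex(sign_extend(0x3B6D, 16, 32)))    # 0x3b6d (x -> ix, high bits are 0)
print(truncate(53191, 16))                 # 53191 when read as unsigned short
print(to_signed(truncate(53191, 16), 16))  # -12345 when read as signed short
```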
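Integer operations and overflow:
- Signed: overflow flips the sign: positive overflow gives a negative result, and negative overflow gives a positive result (in C, signed overflow is undefined behavior, but common hardware wraps around like this).
- Unsigned: the carry out of the highest bit is discarded, so an overflowing sum is smaller than either operand (the result is computed mod $2^w$).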
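Python integers never overflow, so the sketch below emulates fixed-width arithmetic with hypothetical helpers that reduce a result mod 2**w. The signed version models the wrap-around of typical hardware; the last line reproduces the lldb result from the overflow example.

```python
def wrap_unsigned(x: int, w: int) -> int:
    """Result of w-bit unsigned arithmetic: x mod 2**w."""
    return x % (1 << w)

def wrap_signed(x: int, w: int) -> int:
    """Result of w-bit two's-complement arithmetic with wrap-around."""
    u = wrap_unsigned(x, w)
    return u - (1 << w) if u >= (1 << (w - 1)) else u

print(wrap_unsigned(200 + 100, 8))  # 44: smaller than either operand
print(wrap_signed(100 + 100, 8))    # -56: positive overflow becomes negative
print(wrap_signed(-100 - 100, 8))   # 56: negative overflow becomes positive
print(wrap_signed((233333 + 1) * (233333 + 1), 32))  # -1389819292
```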
Floating Point Numbers
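A binary fraction $b_m b_{m-1} \cdots b_1 b_0 . b_{-1} b_{-2} \cdots b_{-n}$ can be expressed using a unified formula:
$$b = \sum_{i=-n}^{m} b_i \cdot 2^i$$
It follows that only numbers of the form $x / 2^k$ can be represented exactly; for example, 1/5 has an infinitely repeating binary expansion.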
IEEE floating point standard:
$$V = (-1)^s \times M \times 2^E$$
where s is the sign bit, which determines the sign; M, the significand, is usually a value in [1.0, 2.0); and E is the exponent, which weights the value by a power of 2.
Normalized values: when $exp \neq 00\cdots0$ and $exp \neq 11\cdots1$, the encoding represents a normalized value.
E is the exponent with a bias offset: $E = Exp - Bias$.
- $Exp$: the unsigned value of the exp field.
- $Bias = 2^{k-1} - 1$, where k is the number of bits in the exp field, meaning:
  - Half precision: 15
  - Single precision: 127
  - Double precision: 1023
For M, there is always an implicit leading 1: $M = 1.xxx\cdots x_2$, where $xxx\cdots x$ is the encoded frac field.
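For example 🌰:
```c
float F = 15213.0;
```
Since $15213_{10} = 11101101101101_2 = 1.1101101101101_2 \times 2^{13}$, we have $M = 1.1101101101101_2$, and frac is the part after the binary point, 1101101101101, padded with zeros to 23 bits. With $E = 13$, $Exp = E + Bias = 13 + 127 = 140 = 10001100_2$.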
s | exp | frac |
---|---|---|
0 | 10001100 | 11011011011010000000000 |
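The encoding can be checked with Python's struct module, which packs a value as an IEEE 754 single-precision float with the "f" format. This sketch extracts the three fields from the raw 32 bits.

```python
import struct

(bits,) = struct.unpack(">I", struct.pack(">f", 15213.0))
s = bits >> 31                # 1 sign bit
exp = (bits >> 23) & 0xFF     # 8 exponent bits
frac = bits & 0x7FFFFF        # 23 fraction bits

print(hex(bits))                        # 0x466db400
print(s, f"{exp:08b}", f"{frac:023b}")  # 0 10001100 11011011011010000000000
print(exp - 127)                        # 13 = E
```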
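🧐 What about the non-normalized (denormalized) values?
When $exp = 00\cdots0$, the value is denormalized. Denormalized values are evenly spaced around 0 and fill the gap between 0 and the smallest normalized value.
Unlike normalized values, they use $E = 1 - Bias$ and $M = 0.xxx\cdots x_2$, with no implicit leading 1.
- When $exp = 00\cdots0$:
  - $frac = 00\cdots0$ represents 0 (+0 or -0, depending on s).
  - $frac \neq 00\cdots0$ represents values close to 0.
- When $exp = 11\cdots1$ (special values):
  - $frac = 00\cdots0$ represents $\infty$ (+∞ or -∞, depending on s).
  - $frac \neq 00\cdots0$ is not a numerical value; it represents an indeterminate value (NaN), such as $\infty - \infty$.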
The following 🌰 illustrates this with an 8-bit floating-point format 😤: 1 sign bit, 4 exp bits and 3 frac bits, so $Bias = 2^3 - 1 = 7$ and denormalized values use $E = 1 - 7 = -6$.
s exp frac E Value
------------------------------------------------------------------
0 0000 000 -6 0 # Denormalized values (exp = 0000) start here; normalized values follow the next separator.
0 0000 001 -6 1/8 * 1/64 = 1/512 # The smallest positive value (closest to zero).
0 0000 010 -6 2/8 * 1/64 = 2/512
...
0 0000 110 -6 6/8 * 1/64 = 6/512
0 0000 111 -6 7/8 * 1/64 = 7/512 # The largest non-normalized value that can be represented.
------------------------------------------------------------------
0 0001 000 -6 8/8 * 1/64 = 8/512 # The smallest normalized value that can be represented.
0 0001 001 -6 9/8 * 1/64 = 9/512
...
0 0110 110 -1 14/8 * 1/2 = 14/16
0 0110 111 -1 15/8 * 1/2 = 15/16 # The closest value less than 1.
0 0111 000 0 8/8 * 1 = 1
0 0111 001 0 9/8 * 1 = 9/8 # The closest value greater than 1.
0 0111 010 0 10/8 * 1 = 10/8
...
0 1110 110 7 14/8 * 128 = 224
0 1110 111 7 15/8 * 128 = 240 # The largest normalized value that can be represented.
------------------------------------------------------------------
0 1111 000 n/a Infinity # Special value.
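A small decoder for this toy format (1 sign bit, 4 exp bits, 3 frac bits, Bias = 7) reproduces the table rows exactly using fractions.Fraction. decode is a hypothetical helper written for illustration; the 8-bit format itself is a teaching example, not a standard type.

```python
from fractions import Fraction

EXP_BITS, FRAC_BITS = 4, 3
BIAS = (1 << (EXP_BITS - 1)) - 1  # 7

def decode(byte: int):
    """Decode an 8-bit float into an exact Fraction, or inf/nan for special values."""
    s = byte >> 7
    exp = (byte >> FRAC_BITS) & ((1 << EXP_BITS) - 1)
    frac = byte & ((1 << FRAC_BITS) - 1)
    sign = -1 if s else 1
    if exp == (1 << EXP_BITS) - 1:   # special values
        return sign * float("inf") if frac == 0 else float("nan")
    if exp == 0:                     # denormalized: E = 1 - Bias, M = 0.frac
        E, M = 1 - BIAS, Fraction(frac, 1 << FRAC_BITS)
    else:                            # normalized: E = exp - Bias, M = 1.frac
        E, M = exp - BIAS, 1 + Fraction(frac, 1 << FRAC_BITS)
    return sign * M * Fraction(2) ** E

print(decode(0b0_0000_001))  # 1/512
print(decode(0b0_0000_111))  # 7/512
print(decode(0b0_0001_000))  # 1/64 (= 8/512)
print(decode(0b0_1110_111))  # 240
print(decode(0b0_1111_000))  # inf
```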
Floating Point Rounding
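For floating-point addition and multiplication, we can think of the result as first being computed exactly and then rounded to the target precision. The default rounding mode is round-to-nearest, ties-to-even; the table below rounds to 2 bits after the binary point.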
Decimal | Binary | Rounded (binary) | Rounded (decimal) | Reason |
---|---|---|---|---|
2 and 3/32 | 10.00011 | 10.00 | 2 | Less than half: round down. |
2 and 3/16 | 10.00110 | 10.01 | 2 and 1/4 | More than half: round up. |
2 and 7/8 | 10.11100 | 11.00 | 3 | Exactly half: keep the last digit even, so round up. |
2 and 5/8 | 10.10100 | 10.10 | 2 and 1/2 | Exactly half: keep the last digit even, so round down. |
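Round-to-nearest-even can be sketched with Python's fractions.Fraction, whose built-in round() with no digits argument rounds halfway cases to the nearest even integer. round_to_quarter is a hypothetical helper that keeps 2 bits after the binary point, as in the table.

```python
from fractions import Fraction

def round_to_quarter(x: Fraction) -> Fraction:
    """Round x to a multiple of 1/4 (2 bits after the binary point), ties to even."""
    return Fraction(round(x * 4), 4)

for value in [Fraction(67, 32), Fraction(35, 16), Fraction(23, 8), Fraction(21, 8)]:
    print(value, "->", round_to_quarter(value))
# 67/32 -> 2
# 35/16 -> 9/4
# 23/8 -> 3
# 21/8 -> 5/2
```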
Floating Point Addition
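Consider $(-1)^{s_1} M_1 2^{E_1} + (-1)^{s_2} M_2 2^{E_2}$ with $E_1 \ge E_2$. The exact result is $(-1)^s M 2^E$, where $E = E_1$ and $(-1)^s M$ is obtained by shifting $M_2$ right by $E_1 - E_2$ bits to align the binary points and then adding the signed significands $(-1)^{s_1} M_1$ and $(-1)^{s_2} M_2$.
- If $M \ge 2$, shift M right by one bit and increase E by 1.
- If $M < 1$, shift M left by k bits and decrease E by k.
- If E exceeds the representable range, overflow occurs.
- Round M to the precision of frac.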
Basic properties:
- Addition may produce infinity or NaN.
- Satisfies the commutative property.
- Does not satisfy the associative property.
- Adding 0 equals the original number.
- Except for infinity or NaN, every element has an additive inverse (its negation).
- Except for infinity or NaN, monotonicity is satisfied: if $a \ge b$, then $a + c \ge b + c$.
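These properties can be observed with Python floats, which are IEEE 754 double precision; the sketch reuses the values from the lldb example in the overflow section.

```python
import math

a, b, c = 1e20, -1e20, 3.14

print(a + c == c + a)                     # True: commutative
print((a + b) + c, a + (b + c))           # 3.14 0.0: not associative
print(c + 0.0 == c)                       # True: adding 0 leaves c unchanged
print(c + (-c) == 0.0)                    # True: -c is the additive inverse of c
print(math.isnan(math.inf + -math.inf))   # True: inf + (-inf) is NaN
x, y = 2.5, 1.5
print(x >= y and x + c >= y + c)          # True: monotonicity for finite values
```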
Floating Point Multiplication
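For $(-1)^{s_1} M_1 2^{E_1} \times (-1)^{s_2} M_2 2^{E_2}$, the exact result is $(-1)^s M 2^E$, where $s = s_1 \oplus s_2$, $M = M_1 \times M_2$, and $E = E_1 + E_2$.
- If $M \ge 2$, shift M right by one bit and increase E by 1.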
- If E exceeds the representable range, overflow occurs.
- Round M to the precision of frac.
Basic properties:
- Multiplication may produce infinity or NaN.
- Satisfies the commutative property.
- Does not satisfy the associative property.
- Multiplying by 1 equals the original number.
- Except for infinity or NaN, monotonicity is satisfied: if $a \ge b$ and $c \ge 0$, then $a \times c \ge b \times c$.
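The multiplication properties can also be observed with Python floats (IEEE 754 double precision). The values below are arbitrary examples chosen so that one grouping overflows to inf.

```python
import math

a, b, c = 1e200, 1e200, 1e-200

print(a * c == c * a)                   # True: commutative
left, right = (a * b) * c, a * (b * c)
print(left, right)                      # inf vs. roughly 1e200: not associative
print(a * 1.0 == a)                     # True: multiplying by 1 changes nothing
print(math.copysign(1.0, -2.0 * 3.0))   # -1.0: the sign is s1 XOR s2
print(math.isnan(math.inf * 0.0))       # True: inf * 0 is NaN
```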