Shong

Shong

AMP混合精度训练

什么是 AMP#

默认情况下,大多数深度学习框架都采用单精度(32 位浮点数)进行训练。

2017 年,nvidia 在训练网络时将单精度和半精度(16 位浮点数)结合在一起,使用相同的超参数实现了与单精度几乎相同的精度。

半精度:16bit,1 bit 符号位,5 bit 指数位,10 bit 分数位

单精度:32bit,1 bit 符号位,8 bit 指数位,23 bit 分数位

在 pytorch 中,一共有 10 种 tensor

默认 tensor 是 torch.FloatTensor (32bit floating point)

torch.FloatTensor(32bit floating point)
torch.DoubleTensor(64bit floating point)
torch.HalfTensor(16bit floating piont1)
torch.BFloat16Tensor(16bit floating piont2)
torch.ByteTensor(8bit integer(unsigned)
torch.CharTensor(8bit integer(signed))
torch.ShortTensor(16bit integer(signed))
torch.IntTensor(32bit integer(signed))
torch.LongTensor(64bit integer(signed))
torch.BoolTensor(Boolean)

自动混合精度有两个关键

  • 自动:tensor 的 dtype 类型会自动变化,框架自动调整 tensor 的 dtype,有些时候需要手动干预
  • 混合精度:采用不止一种精度的 tensor,torch.FloatTenso 和 torch.HalfTensor

为什么要使用 AMP#

core:某些情况 FP16 更好,某些情况 FP32 更好

FP16 优势有三个:

  • 减少内存使用
  • 加快训练和推断的使用 (通信量大幅减少加快数据流通)
  • 张量核心的普及,低精度计算是一个重要趋势

FP16 的两大问题:

  • 溢出错误:FP16 的动态范围很狭窄,容易出现上溢出和下溢出,溢出后就容易出现 “NAN” 的问题,在深度学习中,由于激活函数的梯度往往要比权重梯度小,更容易出现下溢出。FP16 所能表示的最小的数为 2242^{-24} ,会导致权重无法更新。
  • 舍入误差:当梯度过小时,小于当前区间内的最小间隔时,该次梯度更新可能会失败
    举个 🌰:FP16 下,权重为 232^{-3} ,梯度为 2142^{-14} ,更新权重为 23+214=232^{-3}+2^{-14}=2^{-3} ,因为 FP16 的固定间隔为 2132^{-13} ,因此小梯度会视作没更新。
这块如果看不懂指路,请看最后一节:数据表示

😤因此为了消除 FP16 的问题:有两种解决办法

混合精度训练#

在内存中用 FP16 做储存和乘法从而加速计算,而用 FP32 做累加避免舍入误差。

混合精度训练的策略有效地缓解了舍入误差的问题。

损失放大#

即使了混合精度训练,还是存在无法收敛的情况,原因是激活梯度的值太小,造成了下溢出。

可以通过使用 torch.cuda.amp.GradScaler,通过放大 loss 的值来防止梯度的下溢出。( 🤓注意:这个地方的 los 的放大只在 BP 时传递梯度信息才使用,真正更新权重时还是要把放大的梯度缩小回去)

如何使用 AMP#

from torch.cuda.amp import autocast as autocast

model=Net().cuda()
optimizer=optim.SGD(model.parameters(),...)

scaler = GradScaler() #训练前实例化一个GradScaler对象

for epoch in epochs:
  for input,target in data:
    optimizer.zero_grad()

    with autocast(): #前后开启autocast
      output=model(input)
      loss = loss_fn(output,targt)

    scaler.scale(loss).backward()  #为了梯度放大
    #scaler.step() 首先把梯度值unscale回来,如果梯度值不是inf或NaN,则调用optimizer.step()来更新权重,否则,忽略step调用,从而保证权重不更新。
   scaler.step(optimizer)
    scaler.update()  #准备着,看是否要增大scaler

scaler 的大小在每次迭代中动态估计,为了尽可能减少梯度下溢出,scaler 应该逐渐变大

但太大,半精度浮点型又容易 overflow(变成 inf 或 NaN)

所以,动态估计原理就是在不出现 inf 或 NaN 梯度的情况下,尽可能的增大 scaler 值

在每次 scaler.step (optimizer) 中,都会检查是否有 inf 或 NaN 的梯度出现:

  • 如果出现 inf 或 NaN ,scaler.step (optimizer) 会忽略此次权重更新 (optimizer.step ()) ,并将 scaler 的大小缩小(乘上 backoff_factor)
  • 如果没有出现 inf 或 NaN , 那么权重正常更新,并且当连续多次 (growth_interval 指定) 没有出现 inf 或 NaN ,则 scaler.update () 会将 scaler 的大小增加 (乘上 growth_factor)

对于分布式训练,由于 autocast 是 thread local 的(也就是说 autocast 的行为和状态都是特定于当前独立线程的)

对于 torch.nn.DataParallel 和 torch.nn.DistributedDataParallel

就不能使用

model = MyModel()
dp_model = nn.DataParallel(model)

with autocast():
    output=dp_model(input)
loss=loss_fn(output)

而应该使用

MyModel(nn.Module):
    @autocast()
    def forward(self, input):
        ...
        
#alternatively
MyModel(nn.Module):
    def forward(self, input):
        with autocast():
            ...

model = MyModel()
dp_model=nn.DataParallel(model)

with autocast():
    output=dp_model(input)
    loss = loss_fn(output)
  • 保证在每个 forward 里面都有 autocast 才能确保每个线程都在 autocast 下
  • loss 也需要在 autocast 下使用

注意事例#

  • 判断 GPU 是否支持 FP16
  • 常数范围:为了保证计算不溢出,首先确保人工设定的 epsilon 和 INF 不溢出
  • Dimension 最好是 8 的倍数,性能最好( 🤯)
  • 涉及 sum 的操作容易溢出;softmax 操作建议用官方 API ,并定义成 layer 写在模型初始化中
  • 一些不常用的函数,使用前要注册:🌰 amp.register_float_function (torch, ‘sogmoid’)
  • Layer 写在模型 init 函数中,graph 写在 forward 中
  • 某些函数不支持 FP16 加速,建议不要使用
  • 需要操作梯度的模块必须在 optimizer 的 step 里,不然 AMP 不能判断 grad 是否为 NaN

数据表示#

冯诺依曼架构:二进制构想 + 五大组件(存储器,控制器,运算器,输入,输出)

哈佛架构:最大的区别在于同时访问数据和指令,ARM 架构就是哈佛架构

溢出问题#

(lldb) print (233333 + 1) * (233333 + 1)
(int) $0 = -1389819292

(x+1)20(x+1)^2 \ge 0 不是一定的,因为整数会溢出,int 只有 32 位

而浮点数的表示方法和整数不同,并不会因为出现溢出而变成负数,但是也有自己的问题

(lldb) print (1e20 + -1e20) + 3.14
(double) $0 = 3.1400000000000001
(lldb) print 1e20 + (-1e20 + 3.14)
(double) $1 = 0

这是因为浮点数的加减法的区别,下面会详细解释

比特心生#

计算机中看到的一切都是比特,每个比特不是 0 就是 1,计算机通过对比特进行不同方式的编码和描述来实现不同的任务

从模拟电路的角度来看,比特这种描述方法很好存储,并且在有噪声或者传输不那么准确的情况下也能保持比较高的可靠度。

整型 Integer#

sighedunsighed

  • 无符号数:B2U(X)=i=0w1xi2iB2U(X)= \sum ^{w-1}_{i=0}x_i*2^i
  • 有符号数:B2T(X)=xw12w1i=0w2xi2iB2T(X)= -x_{w-1}*2^{w-1}\sum ^{w-2}_{i=0}x_i*2^i

有符号数和无符号数的区别主要在于有没有最高位的符号位

🤓在进行有符号和无符号数的相互转换时:

  • 具体每一个字节的值不会改变,改变的事计算机解释当前值的方式
  • 如果一个表达式既包括有符号数也包含无符号数,那么会被隐式转换成无符号数进行比较

类型扩展和截取#

  • 扩展:例如从 short intint
    • 无符号数:加 0
    • 有符号数:加符号位
  • 截取:例如从unsighedunsighed short ,对于小的数字可以得到预期的结果
    • 无符号数:mod 操作
    • 有符号数:近似 mod 操作
short int x = 15213;
int ix = (int) x;
short int y = -15213;
int iy = (int) y;
十进制十六进制二进制
x=152133B 6D00111011 01101101
ix=1521300 00 3B 6D00000000 00000000 00111011 01101101
y=-15213C4 9311000100 10010011
iy=-15213FF FF C4 9311111111 11111111 11000100 10010011

整数运算与溢出

  • 有符号数:溢出是符号位变号,正变为负,负变为正
  • 无符号数:溢出是高位变 0,也就是想加反而变小

浮点数#

浮点数可以用一个统一的公式表达

k=jibk2k\sum_{k=-j}^{i}b_k*2^k

可以看见只有形为x2k\frac{x}{2^k}的小数部分可以被精确表示

IEEE 浮点数标准:

(1)sM2E(-1)^sM2^E

其中 s 是符号位,决定正负;M 通常是一个值在 [1.0, 2.0) 的小数,E 是次方数

浮点数

规范化值:exp00,..,0and111,...,1exp \neq 00,..,0 \enspace and \enspace 111,...,1 时,表示的都是规范化的值

E 是一个偏移的值 E=ExpBiasE=Exp-Bias

  • ExpExp :是 exp 编码区域的无符号数值
  • BiasBias :值为 2k112^{k-1}-1 的偏移量,其中 k 是 exp 编码的位数,也就是说
    • 单精度:127
    • 双精度:1023

而对于 M,一定是 1 开头的:也就是M=1.xxxx...x2M=1.xxxx...x_2 ,其中 xxxxxx 的部分就是 frac 的编码部分

举个🌰:

float F = 15213.0;

1521310=111011011011012=1.1101101101101221315213_{10}=11101101101101_{2}=1.1101101101101_{2} *2^{13}

frac 部分的值就是小数点后面的数值:1101101101101

Exp=E+Bias=13+127=140=100011002Exp = E + Bias = 13 + 127 = 140 =10001100_2

sexpfrac
01000110011011011011010000000000

🧐记得之前提到过的非规范化值吗(?

exp=000,..,000exp = 000,..,000 时,值是非规范化的,意思是实数轴上原来连续的值会被规范到有限的定值上,并且这些定值的间距也是一样的

和前面不同的是,M=0.xxxx...x2M = 0.xxxx...x_2

  • exp=000...0exp=000...0E=1BiasE=1-Bias
    • frac=000...0frac=000...0 时,表示 0
    • frac000...0frac\neq 000...0 时,数值是接近 0 的
  • exp=111...1exp=111...1E=n/aE = n/a
    • frac=000...0frac=000...0 时,表示 \infin
    • frac000...0frac\neq 000...0 时,不认为此时是一个数值,用来表示无法确定的值(NaN)

现在我用数轴来表示这个问题 😤

数轴

用下面一个🌰来说明这个问题

    s exp  frac   E   值
------------------------------------------------------------------
    0 0000 000   -6   0   # 这部分是非规范化数值,下一部分是规范化值
    0 0000 001   -6   1/8 * 1/64 = 1/512 # 能表示的最接近零的值
    0 0000 010   -6   2/8 * 1/64 = 2/512 
    ...
    0 0000 110   -6   6/8 * 1/64 = 6/512
    0 0000 111   -6   7/8 * 1/64 = 7/512 # 能表示的最大非规范化值
------------------------------------------------------------------
    0 0001 000   -6   8/8 * 1/64 = 8/512 # 能表示的最小规范化值
    0 0001 001   -6   9/8 * 1/64 = 9/512
    ...
    0 0110 110   -1   14/8 * 1/2 = 14/16
    0 0110 111   -1   15/8 * 1/2 = 15/16 # 最接近且小于 1 的值
    0 0111 000    0   8/8 * 1 = 1
    0 0111 001    0   9/8 * 1 = 9/8      # 最接近且大于 1 的值
    0 0111 010    0   10/8 * 1 = 10/8
    ...
    0 1110 110    7   14/8 * 128 = 224
    0 1110 111    7   15/8 * 128 = 240   # 能表示的最大规范化值
------------------------------------------------------------------
    0 1111 000   n/a  无穷               # 特殊值

浮点数舍入#

对于浮点数的加法和乘法而言,我们可以先计算出准确值,然后转换到合适的精度

  十进制    二进制     舍入结果  十进制    原因
23/32  10.00011   10.00     2      不到一半,正常四舍五入
23/16  10.00110   10.01  21/4   超过一般,正常四舍五入
27/8   10.11100   11.00     3      刚好在一半时,保证最后一位是偶数,所以向上舍入
25/8   10.10100   10.10  21/2   刚好在一半时,保证最后一位是偶数,所以向下舍入

浮点数加法#

(1)s1M12E1+(1)s2M22E2(-1)^{s_1}M_12^{E_1}+(-1)^{s_2}M_22^{E_2}

假设 E1>E2E_1>E_2 ,结果是(1)sM2E(-1)^{s}M2^{E} ,其中s=s1s2,M=M1+M2,E=E1s=s_1 \wedge s_2, \enspace M=M_1+M_2,\enspace E=E_1

  • 如果 M2M\ge 2 ,那么把 M 右移,并增加 E 的值
  • 如果 M<1M< 1 ,把 M 左移 k 位, E 减少 k
  • 如果 E 超出了可以表示的范围,溢出
  • 把 M 舍入到 frac 的精度

基本性质:

  • 相加可能产生 infinity 或者 NaN
  • 满足交换律
  • 不满足结合律
  • 加上 0 等于原来的数
  • 除了 infinity 或者 NaN ,每个元素都有对应的倒数
  • 除了 infinity 或者 NaN ,满足单调性

浮点数乘法#

(1)s1M12E1(1)s2M22E2(-1)^{s_1}M_12^{E_1}*(-1)^{s_2}M_22^{E_2}

结果是(1)sM2E(-1)^{s}M2^{E},其中s=s1s2,M=M1M2,E=E1+E2s=s_1 \wedge s_2, \enspace M=M_1*M_2,\enspace E=E_1+E_2

  • 如果 M2M\ge 2 ,那么把 M 右移,并增加 E 的值
  • 如果 E 超出了可以表示的范围,溢出
  • 把 M 舍入到 frac 的精度

基本性质:

  • 相加可能产生 infinity 或者 NaN
  • 满足交换律
  • 不满足结合律
  • 乘以 1 等于原来的数
  • 除了 infinity 或者 NaN ,满足单调性
加载中...
此文章数据所有权由区块链加密技术和智能合约保障仅归创作者所有。