Shong

Shong

AMP混合精度訓練

什麼是 AMP#

默認情況下,大多數深度學習框架都採用單精度(32 位浮點數)進行訓練。

2017 年,nvidia 在訓練網絡時將單精度和半精度(16 位浮點數)結合在一起,使用相同的超參數實現了與單精度幾乎相同的精度。

半精度:16bit,1 bit 符號位,5 bit 指數位,10 bit 分數位

單精度:32bit,1 bit 符號位,8 bit 指數位,23 bit 分數位

在 pytorch 中,一共有 10 種 tensor

默認 tensor 是 torch.FloatTensor (32bit floating point)

torch.FloatTensor(32bit floating point)
torch.DoubleTensor(64bit floating point)
torch.HalfTensor(16bit floating piont1)
torch.BFloat16Tensor(16bit floating piont2)
torch.ByteTensor(8bit integer(unsigned)
torch.CharTensor(8bit integer(signed))
torch.ShortTensor(16bit integer(signed))
torch.IntTensor(32bit integer(signed))
torch.LongTensor(64bit integer(signed))
torch.BoolTensor(Boolean)

自動混合精度有兩個關鍵

  • 自動:tensor 的 dtype 類型會自動變化,框架自動調整 tensor 的 dtype,有些時候需要手動干預
  • 混合精度:採用不止一種精度的 tensor,torch.FloatTenso 和 torch.HalfTensor

為什麼要使用 AMP#

core:某些情況 FP16 更好,某些情況 FP32 更好

FP16 優勢有三個:

  • 減少內存使用
  • 加快訓練和推斷的使用 (通信量大幅減少加快數據流通)
  • 張量核心的普及,低精度計算是一個重要趨勢

FP16 的兩大問題:

  • 溢出錯誤:FP16 的動態範圍很狹窄,容易出現上溢出和下溢出,溢出後就容易出現 “NAN” 的問題,在深度學習中,由於激活函數的梯度往往要比權重梯度小,更容易出現下溢出。FP16 所能表示的最小的數為 2242^{-24} ,會導致權重無法更新。
  • 舍入誤差:當梯度過小時,小於當前區間內的最小間隔時,該次梯度更新可能會失敗
    舉個 🌰:FP16 下,權重為 232^{-3} ,梯度為 2142^{-14} ,更新權重為 23+214=232^{-3}+2^{-14}=2^{-3} ,因為 FP16 的固定間隔為 2132^{-13} ,因此小梯度會視作沒更新。
這塊如果看不懂指路,請看最後一節:數據表示

😤因此為了消除 FP16 的問題:有兩種解決辦法

混合精度訓練#

在內存中用 FP16 做儲存和乘法從而加速計算,而用 FP32 做累加避免舍入誤差。

混合精度訓練的策略有效地緩解了舍入誤差的問題。

損失放大#

即使了混合精度訓練,還是存在無法收斂的情況,原因是激活梯度的值太小,造成了下溢出。

可以通過使用 torch.cuda.amp.GradScaler,通過放大 loss 的值來防止梯度的下溢出。( 🤓注意:這個地方的 los 的放大只在 BP 時傳遞梯度信息才使用,真正更新權重時還是要把放大的梯度縮小回去)

如何使用 AMP#

from torch.cuda.amp import autocast as autocast

model=Net().cuda()
optimizer=optim.SGD(model.parameters(),...)

scaler = GradScaler() #訓練前實例化一個GradScaler對象

for epoch in epochs:
  for input,target in data:
    optimizer.zero_grad()

    with autocast(): #前後開啟autocast
      output=model(input)
      loss = loss_fn(output,targt)

    scaler.scale(loss).backward()  #為了梯度放大
    #scaler.step() 首先把梯度值unscale回來,如果梯度值不是inf或NaN,則調用optimizer.step()來更新權重,否則,忽略step調用,從而保證權重不更新。
   scaler.step(optimizer)
    scaler.update()  #準備著,看是否要增大scaler

scaler 的大小在每次迭代中動態估計,為了盡可能減少梯度下溢出,scaler 應該逐漸變大

但太大,半精度浮點型又容易 overflow(變成 inf 或 NaN)

所以,動態估計原理就是在不出現 inf 或 NaN 梯度的情況下,盡可能的增大 scaler 值

在每次 scaler.step (optimizer) 中,都会检查是否有 inf 或 NaN 的梯度出现:

  • 如果出现 inf 或 NaN ,scaler.step (optimizer) 会忽略此次权重更新 (optimizer.step ()) ,并将 scaler 的大小缩小(乘上 backoff_factor)
  • 如果没有出现 inf 或 NaN , 那么权重正常更新,并且当连续多次 (growth_interval 指定) 没有出现 inf 或 NaN ,则 scaler.update () 会将 scaler 的大小增加 (乘上 growth_factor)

对于分布式训练,由于 autocast 是 thread local 的(也就是说 autocast 的行为和状态都是特定于当前独立线程的)

对于 torch.nn.DataParallel 和 torch.nn.DistributedDataParallel

就不能使用

model = MyModel()
dp_model = nn.DataParallel(model)

with autocast():
    output=dp_model(input)
loss=loss_fn(output)

而应该使用

MyModel(nn.Module):
    @autocast()
    def forward(self, input):
        ...
        
#alternatively
MyModel(nn.Module):
    def forward(self, input):
        with autocast():
            ...

model = MyModel()
dp_model=nn.DataParallel(model)

with autocast():
    output=dp_model(input)
    loss = loss_fn(output)
  • 保證在每個 forward 裡面都有 autocast 才能確保每個線程都在 autocast 下
  • loss 也需要在 autocast 下使用

注意事例#

  • 判斷 GPU 是否支持 FP16
  • 常數範圍:為了保證計算不溢出,首先確保人工設定的 epsilon 和 INF 不溢出
  • Dimension 最好是 8 的倍數,性能最好( 🤯)
  • 涉及 sum 的操作容易溢出;softmax 操作建議用官方 API ,並定義成 layer 寫在模型初始化中
  • 一些不常用的函數,使用前要註冊:🌰 amp.register_float_function (torch, ‘sogmoid’)
  • Layer 寫在模型 init 函數中,graph 寫在 forward 中
  • 某些函數不支持 FP16 加速,建議不要使用
  • 需要操作梯度的模塊必須在 optimizer 的 step 裡,不然 AMP 不能判斷 grad 是否為 NaN

數據表示#

馮諾依曼架構:二進制構想 + 五大組件(存儲器,控制器,運算器,輸入,輸出)

哈佛架構:最大的區別在於同時訪問數據和指令,ARM 架構就是哈佛架構

溢出問題#

(lldb) print (233333 + 1) * (233333 + 1)
(int) $0 = -1389819292

(x+1)20(x+1)^2 \ge 0 不是一定的,因為整數會溢出,int 只有 32 位

而浮點數的表示方法和整數不同,並不會因為出現溢出而變成負數,但是也有自己的問題

(lldb) print (1e20 + -1e20) + 3.14
(double) $0 = 3.1400000000000001
(lldb) print 1e20 + (-1e20 + 3.14)
(double) $1 = 0

這是因為浮點數的加減法的區別,下面會詳細解釋

比特心生#

計算機中看到的一切都是比特,每個比特不是 0 就是 1,計算機通過對比特進行不同方式的編碼和描述來實現不同的任務

從模擬電路的角度來看,比特這種描述方法很好存儲,並且在有噪聲或者傳輸不那麼準確的情況下也能保持比較高的可靠度。

整型 Integer#

sighedunsighed

  • 無符號數:B2U(X)=i=0w1xi2iB2U(X)= \sum ^{w-1}_{i=0}x_i*2^i
  • 有符號數:B2T(X)=xw12w1i=0w2xi2iB2T(X)= -x_{w-1}*2^{w-1}\sum ^{w-2}_{i=0}x_i*2^i

有符號數和無符號數的區別主要在於有沒有最高位的符號位

🤓在進行有符號和無符號數的相互轉換時:

  • 具體每一個字節的值不會改變,改變的是計算機解釋當前值的方式
  • 如果一個表達式既包括有符號數也包含無符號數,那麼會被隱式轉換成無符號數進行比較

類型擴展和截取#

  • 擴展:例如從 short intint
    • 無符號數:加 0
    • 有符號數:加符號位
  • 截取:例如從unsighedunsighed short ,對於小的數字可以得到預期的結果
    • 無符號數:mod 操作
    • 有符號數:近似 mod 操作
short int x = 15213;
int ix = (int) x;
short int y = -15213;
int iy = (int) y;
十進制十六進制二進制
x=152133B 6D00111011 01101101
ix=1521300 00 3B 6D00000000 00000000 00111011 01101101
y=-15213C4 9311000100 10010011
iy=-15213FF FF C4 9311111111 11111111 11000100 10010011

整數運算與溢出

  • 有符號數:溢出是符號位變號,正變為負,負變為正
  • 無符號數:溢出是高位變 0,也就是想加反而變小

浮點數#

浮點數可以用一個統一的公式表達

k=jibk2k\sum_{k=-j}^{i}b_k*2^k

可以看見只有形為x2k\frac{x}{2^k}的小數部分可以被精確表示

IEEE 浮點數標準:

(1)sM2E(-1)^sM2^E

其中 s 是符號位,決定正負;M 通常是一個值在 [1.0, 2.0) 的小數,E 是次方數

浮點數

規範化值:exp00,..,0and111,...,1exp \neq 00,..,0 \enspace and \enspace 111,...,1 時,表示的都是規範化的值

E 是一個偏移的值 E=ExpBiasE=Exp-Bias

  • ExpExp :是 exp 編碼區域的無符號數值
  • BiasBias :值為 2k112^{k-1}-1 的偏移量,其中 k 是 exp 編碼的位數,也就是
    • 單精度:127
    • 雙精度:1023

而對於 M,一定是 1 開頭的:也就是M=1.xxxx...x2M=1.xxxx...x_2 ,其中 xxxxxx 的部分就是 frac 的編碼部分

舉個🌰:

float F = 15213.0;

1521310=111011011011012=1.1101101101101221315213_{10}=11101101101101_{2}=1.1101101101101_{2} *2^{13}

frac 部分的值就是小數點後面的數值:1101101101101

Exp=E+Bias=13+127=140=100011002Exp = E + Bias = 13 + 127 = 140 =10001100_2

sexpfrac
01000110011011011011010000000000

🧐記得之前提到過的非規範化值嗎(?

exp=000,..,000exp = 000,..,000 時,值是非規範化的,意思是實數軸上原來連續的值會被規範到有限的定值上,並且這些定值的間距也是一樣的

和前面不同的是,M=0.xxxx...x2M = 0.xxxx...x_2

  • exp=000...0exp=000...0E=1BiasE=1-Bias
    • frac=000...0frac=000...0 時,表示 0
    • frac000...0frac\neq 000...0 時,數值是接近 0 的
  • exp=111...1exp=111...1E=n/aE = n/a
    • frac=000...0frac=000...0 時,表示 \infin
    • frac000...0frac\neq 000...0 時,不認為此時是一個數值,用來表示無法確定的值(NaN)

現在我用數軸來表示這個問題 😤

數軸

用下面一個🌰來說明這個問題

    s exp  frac   E   值
------------------------------------------------------------------
    0 0000 000   -6   0   # 這部分是非規範化數值,下一部分是規範化值
    0 0000 001   -6   1/8 * 1/64 = 1/512 # 能表示的最接近零的值
    0 0000 010   -6   2/8 * 1/64 = 2/512 
    ...
    0 0000 110   -6   6/8 * 1/64 = 6/512
    0 0000 111   -6   7/8 * 1/64 = 7/512 # 能表示的最大非規範化值
------------------------------------------------------------------
    0 0001 000   -6   8/8 * 1/64 = 8/512 # 能表示的最小規範化值
    0 0001 001   -6   9/8 * 1/64 = 9/512
    ...
    0 0110 110   -1   14/8 * 1/2 = 14/16
    0 0110 111   -1   15/8 * 1/2 = 15/16 # 最接近且小於 1 的值
    0 0111 000    0   8/8 * 1 = 1
    0 0111 001    0   9/8 * 1 = 9/8      # 最接近且大於 1 的值
    0 0111 010    0   10/8 * 1 = 10/8
    ...
    0 1110 110    7   14/8 * 128 = 224
    0 1110 111    7   15/8 * 128 = 240   # 能表示的最大規範化值
------------------------------------------------------------------
    0 1111 000   n/a  無窮               # 特殊值

浮點數舍入#

對於浮點數的加法和乘法而言,我們可以先計算出準確值,然後轉換到合適的精度

  十進制    二進制     舍入結果  十進制    原因
23/32  10.00011   10.00     2      不到一半,正常四舍五入
23/16  10.00110   10.01  21/4   超過一般,正常四舍五入
27/8   10.11100   11.00     3      剛好在一半時,保證最後一位是偶數,所以向上舍入
25/8   10.10100   10.10  21/2   剛好在一半時,保證最後一位是偶數,所以向下舍入

浮點數加法#

(1)s1M12E1+(1)s2M22E2(-1)^{s_1}M_12^{E_1}+(-1)^{s_2}M_22^{E_2}

假設 E1>E2E_1>E_2 ,結果是(1)sM2E(-1)^{s}M2^{E} ,其中s=s1s2,M=M1+M2,E=E1s=s_1 \wedge s_2, \enspace M=M_1+M_2,\enspace E=E_1

  • 如果 M2M\ge 2 ,那麼把 M 右移,並增加 E 的值
  • 如果 M<1M< 1 ,把 M 左移 k 位, E 減少 k
  • 如果 E 超出了可以表示的範圍,溢出
  • 把 M 舍入到 frac 的精度

基本性質:

  • 相加可能產生 infinity 或者 NaN
  • 滿足交換律
  • 不滿足結合律
  • 加上 0 等於原來的數
  • 除了 infinity 或者 NaN ,每個元素都有對應的倒數
  • 除了 infinity 或者 NaN ,滿足單調性

浮點數乘法#

(1)s1M12E1(1)s2M22E2(-1)^{s_1}M_12^{E_1}*(-1)^{s_2}M_22^{E_2}

結果是(1)sM2E(-1)^{s}M2^{E},其中s=s1s2,M=M1M2,E=E1+E2s=s_1 \wedge s_2, \enspace M=M_1*M_2,\enspace E=E_1+E_2

  • 如果 M2M\ge 2 ,那麼把 M 右移,並增加 E 的值
  • 如果 E 超出了可以表示的範圍,溢出
  • 把 M 舍入到 frac 的精度

基本性質:

  • 相加可能產生 infinity 或者 NaN
  • 滿足交換律
  • 不滿足結合律
  • 乘以 1 等於原來的數
  • 除了 infinity 或者 NaN ,滿足單調性
載入中......
此文章數據所有權由區塊鏈加密技術和智能合約保障僅歸創作者所有。