AMP とは何か#
デフォルトでは、ほとんどの深層学習フレームワークは単精度(32 ビット浮動小数点数)でトレーニングを行います。
2017 年、nvidia はネットワークのトレーニング時に単精度と半精度(16 ビット浮動小数点数)を組み合わせ、同じハイパーパラメータを使用して単精度とほぼ同じ精度を実現しました。
半精度:16 ビット、1 ビット符号ビット、5 ビット指数ビット、10 ビット仮数ビット
単精度:32 ビット、1 ビット符号ビット、8 ビット指数ビット、23 ビット仮数ビット
pytorch には、合計 10 種類のテンソルがあります。
デフォルトのテンソルは torch.FloatTensor(32 ビット浮動小数点)です。
torch.FloatTensor(32bit floating point)
torch.DoubleTensor(64bit floating point)
torch.HalfTensor(16bit floating piont1)
torch.BFloat16Tensor(16bit floating piont2)
torch.ByteTensor(8bit integer(unsigned)
torch.CharTensor(8bit integer(signed))
torch.ShortTensor(16bit integer(signed))
torch.IntTensor(32bit integer(signed))
torch.LongTensor(64bit integer(signed))
torch.BoolTensor(Boolean)
自動混合精度には 2 つの重要なポイントがあります。
- 自動:テンソルの dtype タイプが自動的に変化し、フレームワークがテンソルの dtype を自動的に調整しますが、場合によっては手動で介入が必要です。
- 混合精度:複数の精度のテンソルを使用し、torch.FloatTensor と torch.HalfTensor を使用します。
なぜ AMP を使用するのか#
core:特定の状況では FP16 が優れており、他の状況では FP32 が優れています。
FP16 の利点は 3 つあります。
- メモリ使用量の削減
- トレーニングと推論の速度向上(通信量が大幅に減少し、データの流通が加速されます)
- テンソルコアの普及、低精度計算は重要なトレンドです。
FP16 の 2 つの大きな問題:
- オーバーフローエラー:FP16 の動的範囲は非常に狭く、オーバーフローやアンダーフローが発生しやすく、オーバーフロー後には「NAN」の問題が発生しやすいです。深層学習では、活性化関数の勾配が重みの勾配よりも小さいことが多く、アンダーフローが発生しやすくなります。FP16 が表現できる最小の数はで、重みが更新できなくなります。
- 丸め誤差:勾配が小さすぎる場合、現在の区間内の最小間隔よりも小さい場合、その勾配更新が失敗する可能性があります。
例を挙げると🌰:FP16 では、重みが、勾配がの場合、重みの更新はとなります。FP16 の固定間隔はであるため、小さな勾配は更新されていないと見なされます。
この部分が理解できない場合は、最後のセクション「データ表現」を参照してください。
😤したがって、FP16 の問題を解消するためには、2 つの解決策があります。
混合精度トレーニング#
メモリ内で FP16 を使用して保存と乗算を行い、計算を加速し、FP32 を使用して加算を行い、丸め誤差を回避します。
混合精度トレーニングの戦略は、丸め誤差の問題を効果的に緩和します。
損失の拡大#
混合精度トレーニングを使用しても、収束しない場合があります。その理由は、活性化勾配の値が小さすぎてアンダーフローを引き起こすためです。
torch.cuda.amp.GradScaler を使用して、損失の値を拡大することで勾配のアンダーフローを防ぐことができます。(🤓注意:この場所での損失の拡大は、BP 時に勾配情報を伝達するためにのみ使用され、実際に重みを更新する際には拡大された勾配を元に戻す必要があります。)
AMP の使用方法#
from torch.cuda.amp import autocast as autocast
model=Net().cuda()
optimizer=optim.SGD(model.parameters(),...)
scaler = GradScaler() # トレーニング前にGradScalerオブジェクトをインスタンス化します。
for epoch in epochs:
for input,target in data:
optimizer.zero_grad()
with autocast(): #前後でautocastを有効にします。
output=model(input)
loss = loss_fn(output,target)
scaler.scale(loss).backward() # 勾配を拡大するために
# scaler.step() まず勾配値を元に戻し、勾配値がinfまたはNaNでない場合、optimizer.step()を呼び出して重みを更新します。そうでない場合は、step呼び出しを無視し、重みが更新されないようにします。
scaler.step(optimizer)
scaler.update() # 準備をして、scalerを増やす必要があるかどうかを確認します。
scaler のサイズは各イテレーションで動的に推定され、勾配のアンダーフローをできるだけ減らすために、scaler は徐々に大きくなるべきです。
しかし、大きすぎると、半精度浮動小数点型はオーバーフロー(inf または NaN になる)しやすくなります。
したがって、動的推定の原理は、inf または NaN の勾配が発生しない範囲で、scaler の値をできるだけ大きくすることです。
各 scaler.step (optimizer) の中で、inf または NaN の勾配が発生しているかどうかを確認します。
- inf または NaN が発生した場合、scaler.step (optimizer) は今回の重みの更新(optimizer.step ())を無視し、scaler のサイズを縮小します(backoff_factor を掛けます)。
- inf または NaN が発生しなかった場合、重みは正常に更新され、連続して多くの回数(growth_interval で指定)inf または NaN が発生しなかった場合、scaler.update () は scaler のサイズを増加させます(growth_factor を掛けます)。
分散トレーニングの場合、autocast はスレッドローカルであるため(つまり、autocast の動作と状態は現在の独立したスレッドに特有のものです)。
torch.nn.DataParallel および torch.nn.DistributedDataParallel では使用できません。
model = MyModel()
dp_model = nn.DataParallel(model)
with autocast():
output=dp_model(input)
loss=loss_fn(output)
代わりに次のようにする必要があります。
MyModel(nn.Module):
@autocast()
def forward(self, input):
...
# あるいは
MyModel(nn.Module):
def forward(self, input):
with autocast():
...
model = MyModel()
dp_model=nn.DataParallel(model)
with autocast():
output=dp_model(input)
loss = loss_fn(output)
- 各 forward 内で autocast があることを保証することで、各スレッドが autocast の下で動作することを確保します。
- 損失も autocast の下で使用する必要があります。
注意事項#
- GPU が FP16 をサポートしているかどうかを確認します。
- 定数範囲:計算がオーバーフローしないように、まず手動で設定した epsilon と INF がオーバーフローしないことを確認します。
- 次元は 8 の倍数が望ましく、パフォーマンスが最も良いです(🤯)。
- sum を含む操作はオーバーフローしやすいです;softmax 操作は公式 API を使用し、レイヤーとしてモデルの初期化時に定義することをお勧めします。
- あまり使用されない関数は、使用前に登録する必要があります:🌰 amp.register_float_function (torch, ‘sogmoid’)
- レイヤーはモデルの init 関数内に書き、グラフは forward 内に書きます。
- 一部の関数は FP16 加速をサポートしていないため、使用しないことをお勧めします。
- 勾配を操作するモジュールは必ず optimizer の step 内に置く必要があり、そうでないと AMP は勾配が NaN かどうかを判断できません。
データ表現#
フォン・ノイマンアーキテクチャ:バイナリ構想 + 5 つのコンポーネント(メモリ、コントローラ、演算器、入力、出力)
ハーバードアーキテクチャ:最大の違いは、データと命令を同時にアクセスできることです。ARM アーキテクチャはハーバードアーキテクチャです。
オーバーフローの問題#
(lldb) print (233333 + 1) * (233333 + 1)
(int) $0 = -1389819292
は必ずしも成り立つわけではありません。整数はオーバーフローするため、int は 32 ビットしかありません。
浮動小数点数の表現方法は整数とは異なり、オーバーフローが発生しても負の数にはなりませんが、独自の問題があります。
(lldb) print (1e20 + -1e20) + 3.14
(double) $0 = 3.1400000000000001
(lldb) print 1e20 + (-1e20 + 3.14)
(double) $1 = 0
これは浮動小数点数の加減算の違いによるもので、以下で詳しく説明します。
ビットの心生#
コンピュータで見えるすべてはビットであり、各ビットは 0 または 1 です。コンピュータはビットをさまざまな方法でエンコードおよび記述することによって、さまざまなタスクを実行します。
アナログ回路の観点から見ると、ビットという記述方法は非常に保存しやすく、ノイズや伝送がそれほど正確でない場合でも比較的高い信頼性を保つことができます。
整数#
signed
とunsigned
- 符号なし数:
- 符号あり数:
符号あり数と符号なし数の違いは、主に最上位ビットの符号ビットの有無です。
🤓符号あり数と符号なし数の相互変換を行う際:
- 各バイトの値は変わりませんが、コンピュータが現在の値を解釈する方法が変わります。
- 式に符号あり数と符号なし数の両方が含まれている場合、暗黙的に符号なし数に変換されて比較されます。
型の拡張と切り捨て#
- 拡張:例えば
short int
からint
へ- 符号なし数:0 を加えます。
- 符号あり数:符号ビットを加えます。
- 切り捨て:例えば
unsigned
からunsigned short
へ。小さな数字の場合、期待される結果が得られます。- 符号なし数:mod 操作
- 符号あり数:近似 mod 操作
short int x = 15213;
int ix = (int) x;
short int y = -15213;
int iy = (int) y;
十進法 | 十六進法 | 二進法 |
---|---|---|
x=15213 | 3B 6D | 00111011 01101101 |
ix=15213 | 00 00 3B 6D | 00000000 00000000 00111011 01101101 |
y=-15213 | C4 93 | 11000100 10010011 |
iy=-15213 | FF FF C4 93 | 11111111 11111111 11000100 10010011 |
整数演算とオーバーフロー
- 符号あり数:オーバーフローは符号ビットが変わること、正から負、負から正に変わります。
- 符号なし数:オーバーフローは高位ビットが 0 に変わること、つまり加算しようとすると逆に小さくなります。
浮動小数点数#
浮動小数点数は、統一された公式で表現できます。
形がの小数部分のみが正確に表現できます。
IEEE 浮動小数点数標準:
ここで、s は符号ビットで、正負を決定します。M は通常 [1.0, 2.0) の範囲の小数で、E は指数です。
正規化された値: の場合、表現されるのはすべて正規化された値です。
E はオフセット値です。
- :exp エンコーディング領域の符号なし数値
- :値はのオフセットで、k は exp エンコーディングのビット数です。つまり、
- 単精度:127
- 倍精度:1023
M は必ず 1 で始まります。つまりで、の部分が frac のエンコーディング部分です。
例を挙げると🌰:
float F = 15213.0;
frac 部分の値は小数点以下の数値:1101101101101 です。
s | exp | frac |
---|---|---|
0 | 10001100 | 11011011011010000000000 |
🧐以前に言及した非正規化値を覚えていますか?
の場合、値は非正規化されており、実数軸上の連続的な値が有限の定値に正規化され、これらの定値の間隔も同じです。
前述のものとは異なり、です。
- の場合、
- の場合、0 を表します。
- の場合、値は 0 に近いです。
- の場合、
- の場合、を表します。
- の場合、この時点では数値とは見なされず、未定義の値(NaN)を表します。
今、数軸を使ってこの問題を示します😤
以下の🌰を使ってこの問題を説明します。
s exp frac E 値
------------------------------------------------------------------
0 0000 000 -6 0 # この部分は非正規化数値で、次の部分は正規化値です。
0 0000 001 -6 1/8 * 1/64 = 1/512 # 表現できる最も近いゼロの値
0 0000 010 -6 2/8 * 1/64 = 2/512
...
0 0000 110 -6 6/8 * 1/64 = 6/512
0 0000 111 -6 7/8 * 1/64 = 7/512 # 表現できる最大の非正規化値
------------------------------------------------------------------
0 0001 000 -6 8/8 * 1/64 = 8/512 # 表現できる最小の正規化値
0 0001 001 -6 9/8 * 1/64 = 9/512
...
0 0110 110 -1 14/8 * 1/2 = 14/16
0 0110 111 -1 15/8 * 1/2 = 15/16 # 1未満で最も近い値
0 0111 000 0 8/8 * 1 = 1
0 0111 001 0 9/8 * 1 = 9/8 # 1より大きい最も近い値
0 0111 010 0 10/8 * 1 = 10/8
...
0 1110 110 7 14/8 * 128 = 224
0 1110 111 7 15/8 * 128 = 240 # 表現できる最大の正規化値
------------------------------------------------------------------
0 1111 000 n/a 無限大 # 特殊値
浮動小数点数の丸め#
浮動小数点数の加算と乗算に関しては、正確な値を計算し、適切な精度に変換することができます。
十進法 二進法 丸め結果 十進法 原因
2 または 3/32 10.00011 10.00 2 半分に満たないため、通常の四捨五入
2 または 3/16 10.00110 10.01 2 または 1/4 一般的に超えているため、通常の四捨五入
2 または 7/8 10.11100 11.00 3 ちょうど半分の時、最後のビットが偶数になるように上に丸めます。
2 または 5/8 10.10100 10.10 2 または 1/2 ちょうど半分の時、最後のビットが偶数になるように下に丸めます。
浮動小数点数の加算#
と仮定すると、結果はとなり、ここでです。
- の場合、M を右にシフトし、E の値を増加させます。
- の場合、M を k ビット左にシフトし、E を k 減少させます。
- E が表現できる範囲を超えた場合、オーバーフローします。
- M を frac の精度に丸めます。
基本的な性質:
- 加算により infinity または NaN が発生する可能性があります。
- 交換法則が成り立ちます。
- 結合法則は成り立ちません。
- 0 を加えると元の数になります。
- infinity または NaN を除いて、各要素には対応する逆数があります。
- infinity または NaN を除いて、単調性が成り立ちます。
浮動小数点数の乗算#
結果はとなり、ここでです。
- の場合、M を右にシフトし、E の値を増加させます。
- E が表現できる範囲を超えた場合、オーバーフローします。
- M を frac の精度に丸めます。
基本的な性質:
- 加算により infinity または NaN が発生する可能性があります。
- 交換法則が成り立ちます。
- 結合法則は成り立ちません。
- 1 を掛けると元の数になります。
- infinity または NaN を除いて、単調性が成り立ちます。