Skip to content

BERT GPU 保持 & バッチ推論

GPU 経路では2つが主な最適化ポイントです。

  • オリジナル SBV2 の不要な CPU 転送を除去 して BERT 出力を GPU テンソルのまま維持すること
  • 多文をバッチにまとめて BERT を1回だけ呼び出し て kernel launch overhead(GPU に演算を要求するたびに発生する固定コスト)を削減すること

なぜ問題か

オリジナル SBV2 は基本的に テキストをまるごと一度に合成 します(line_split=False)。

BERT も1回だけ呼び出されるためバッチ化の必要がない構造でした。

HayaKoe は prosody(韻律)の安定性のために 句読点基準の文分割 を導入し、それに伴い BERT が文の数だけ繰り返し呼び出される問題が新たに発生しました。

BERT 出力の不要な CPU 転送

オリジナル SBV2 の BERT feature 抽出コードには以下のような部分があります。

python
# オリジナル SBV2 (style_bert_vits2/nlp/japanese/bert_feature.py)
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()

BERT を GPU で forward した後、出力テンソルに .cpu() を呼び出して 毎回 CPU に下ろしていました。

この出力はその後 Synthesizer に渡されますが、Synthesizer は GPU で動作するため再び GPU にアップロードする必要があります。

結果的に文ごとに GPU → CPU → GPU の往復 が発生し、この不要な往復自体がボトルネックになります。

文別の個別 BERT 呼び出し

多文を分割した後、各文に対して BERT を別々に呼び出すと GPU kernel launch が文の数だけ繰り返されます。

kernel launch は GPU に演算を要求するたびに発生する固定コストです。

文が短ければ実際の計算時間より launch overhead の比重が大きくなり、文の数に比例して非効率が累積します。

実装

.cpu() の除去 — GPU テンソル維持

オリジナルの .cpu() 呼び出しを除去し、BERT 出力が GPU テンソルのまま Synthesizer に渡されるよう修正しました。

python
# オリジナル SBV2
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()    # GPU → CPU

# HayaKoe
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].float()  # GPU 維持

BERT モデル自体も prepare() 時点で GPU にロードされ、推論が終わるまで維持されます。

BERT モデルは グローバルシングルトン として管理されるため、話者を複数ロードしても BERT は一度だけロードされ全話者が共有します。

多文 BERT バッチ化

HayaKoe が使用する BERT(DeBERTa)は HuggingFace Transformer モデルなので基本的に batch 入力をサポートします。

これを活用して、多文合成時に各文の BERT を個別呼び出しする代わりに すべての文をひとつのバッチにまとめて1回で処理 します。

複数の文を tokenizer に一括で入れて padding されたバッチ入力を作り、BERT を 1回だけ forward します。

ONNX 経路でも同一のバッチロジックが実装されています。

改善効果

GPU バッチ推論速度

同一ハードウェアでの シーケンシャル (sequential) vs バッチ (batched) 比較です(5回平均)。

文数シーケンシャルバッチ速度向上
20.447 s0.364 s1.23x
40.812 s0.566 s1.43x
81.598 s1.121 s1.43x
162.972 s2.264 s1.31x

kernel launch オーバーヘッドが1回に統合された効果で、+23% ~ +43% の速度向上が見られます。

GPU メモリ

バッチ化がメモリを追加消費しないか確認しました。

文数シーケンシャル peakバッチ peak差分
21,662.2 MB1,661.9 MB-0.3 MB
41,661.8 MB1,662.2 MB+0.4 MB
81,697.7 MB1,699.0 MB+1.3 MB
161,934.3 MB1,934.3 MB0 MB

シーケンシャルとバッチの差は 1.3 MB 以内 で、事実上同一です。

CPU では効果なし

同じ実験を CPU(ONNX)で繰り返すとバッチ化の効果がほぼ現れません。

文数シーケンシャルバッチ速度差
22.566 s2.564 s1.00x
45.464 s4.855 s1.13x
810.647 s11.783 s0.90x
1624.559 s24.195 s1.01x

ONNX Runtime のグラフ最適化が既に十分に強く Python レベルの dispatch overhead がボトルネックではなく、バッチ時の padding オーバーヘッドが利得を相殺します。

GPU ではバッチ化を維持し、CPU では利得・損失ともに大きくないためバックエンド間のコード単一性のために同一経路を維持しています。