PyTorch Foundationは、オープンソースのPython向け機械学習ライブラリPyTorchの最新バージョンとなる「PyTorch 2.0」を、3月15日(現地時間)にリリースした。
「PyTorch 2.0」には、PyTorch Transformer APIのより高性能な実装となるAccelerated PT2 Transformerが含まれ、以前はBetter Transformerと呼ばれていたfastpath推論アーキテクチャを拡張した、scaled dot product attention(SPDA)用のカスタムカーネルアーキテクチャを使用したトレーニングと推論がサポートされている。
さらにメインAPIとして、モデルをラップしてコンパイル済みのモデルを返すtorch.compile(ベータ版)が実装された。torch.compileは、Pythonフレーム評価フックを使用してPyTorchプログラムを安全にキャプチャするTorchDynamo、PyTorchのautogradエンジンを事前の後方トレースを生成するためのトレースautodiffとしてオーバーロードするAOTAutograd、PyTorchオペレータをPyTorchバックエンドを構築するためのプリミティブオペレータにおけるクローズドセットに正規化するPrimTorch、複数のアクセラレータとバックエンド用の高速コードを生成するディープラーニングコンパイラTorchInductorの、4つのテクノロジーが基盤となっている。
ほかにも、MacプラットフォームにおいてGPUで高速化されたPyTorchトレーニングを提供するPyTorch MPSバックエンド(ベータ)、入力と使用中のハードウェアに応じてシームレスに適用できる複数の実装が含まれるscaled dot product attention関数のバージョン2.0(ベータ)、コンポーザブルなvmap(ベクトル化)とautodiff変換を提供するライブラリfunctorch(ベータ)、バックエンドをオプションの引数に変更するinit_process_group() APIの改良版となるDispatchable Collective(ベータ)など、数多くの機能追加が行われた。
- 関連リンク
この記事は参考になりましたか?
- この記事の著者
-
CodeZine編集部(コードジンヘンシュウブ)
CodeZineは、株式会社翔泳社が運営するソフトウェア開発者向けのWebメディアです。「デベロッパーの成長と課題解決に貢献するメディア」をコンセプトに、現場で役立つ最新情報を日々お届けします。
※プロフィールは、執筆時点、または直近の記事の寄稿時点での内容です