PyTorch 1.8發(fā)布,支持AMD GPU和Python函數(shù)轉(zhuǎn)換
近日,PyTorch 團隊發(fā)布了 PyTorch 1.8 版本。該版本整合了自去年 10 月 1.7 版本發(fā)布以來的 3000 多次 commit,提供了編譯、代碼優(yōu)化、科學計算前端 API 方面的主要更新和新特性。值得一提的是,該版本新增了對 AMD ROCm 的支持。
此外,PyTorch 1.8 版本還為大規(guī)模訓練 pipeline 和模型并行化、梯度壓縮提供了特性改進。該版本的主要亮點如下:
- 支持 Python 函數(shù)轉(zhuǎn)換;
- 添加或穩(wěn)定化 API,以支持 FFT (torch.fft)、線性代數(shù)函數(shù) (torch.linalg);
- 添加對復(fù)雜張量 autograd 的支持;
- 多項更新用于提升 Hessian 與 Jacobian 矩陣計算的性能;
- 改進分布式訓練,包括提升 NCCL 可靠性、支持 pipeline 并行化、RPC profiling,以及通過添加梯度壓縮來支持通信鉤子(hook)。
(注:從 PyTorch 1.6 版本開始,PyTorch 特性分為 Stable(穩(wěn)定版)、Beta(測試版)和 Prototype(原型版)。
此外,PyTorch 團隊還對多個 PyTorch 庫進行了主要更新,包括 TorchCSPRNG、TorchVision、TorchText 和 TorchAudio。
新增和更新 API
(1) 通過 torch.fft 支持 NumPy 兼容的 FFT 操作
PyTorch 1.7 版本提出了這一特性的 Beta 版,而在 1.8 版本中該特性更新為穩(wěn)定版。FFT 支持旨在完成 PyTorch 支持科學計算的目的。torch.fft 模塊和 NumPy 的 np.fft 模塊實現(xiàn)了同樣的功能,并且支持硬件加速和 autograd。
(2) 通過 torch.linalg 支持 NumPy 式的線性代數(shù)函數(shù)
torch.linalg 模塊類似于 NumPy 中的 np.linalg 模塊,支持 NumPy 式的線性代數(shù)操作,包括 Cholesky 分解、行列式、特征值等。
使用 torch.fx 進行 Python 代碼轉(zhuǎn)換
這一 Beta 特性支持 Python 代碼轉(zhuǎn)換,開發(fā)者可以利用它做 Conv/BN 融合、圖模式量化、實現(xiàn) vmap 等。鑒于 torch.fx 提供 PyTorch 代碼的圖表示,開發(fā)者可以用 Python 寫任意變換或分析。
分布式訓練
(1) pipeline 并行化
這一新增的 Beta 特性提供了一個易用的 PyTorch API,可將 pipeline 并行化作為訓練 loop 的一部分。
(2) DDP 通信鉤子
DDP 通信鉤子是一個通用接口,用于控制 workers 間的梯度通信。
此外,PyTorch 1.8 還增加了一些 prototype 特性:
- ZeroRedundancyOptimizer:有助于減少每個線程的內(nèi)存占用;
- Process Group NCCL Send/Recv:該特性允許用戶在 Python 層(而非 C++ 層)實現(xiàn)集合操作;
- CUDA-support in RPC using TensorPipe:該特性為使用 PyTorch RPC 和多 GPU 機器的用戶帶來速度提升;
- Remote Module:該特性允許用戶像操作本地模塊那樣操作遠程 worker 上的模塊。
支持 PyTorch Mobile
此次版本更新發(fā)布了一組新的移動端教程,包括在 iOS 端和安卓端實現(xiàn)圖像分割 DeepLabV3 模型。PyTorch 還發(fā)布了新的 demo app,包括圖像分割、目標檢測、神經(jīng)機器翻譯、問答和視覺 transformer。
此外,這次發(fā)布還包括 PyTorch Mobile Lite Interpreter,該解釋器可降低運行時二進制文件大小。
性能優(yōu)化
為了幫助用戶更好地監(jiān)控性能變化,PyTorch 1.8 版本支持 benchmark utils,并開放了新的自動量化 API——FX Graph Mode Quantization。
硬件支持
在硬件支持方面,PyTorch 1.8 版本新增了兩個 Beta 特性:
- 擴展 PyTorch Dispatcher,使之適應(yīng)新型 C++ 后端;支持 AMD ROCm。
- 需要注意的是,PyTorch 1.8 僅在 Linux 系統(tǒng)中支持 AMD ROCm。
參考鏈接:
- https://pytorch.org/blog/pytorch-1.8-released/
- https://github.com/pytorch/pytorch
- https://pytorch.org/
- https://twitter.com/cHHillee/status/1367621538791317504
【本文是51CTO專欄機構(gòu)“機器之心”的原創(chuàng)譯文,微信公眾號“機器之心( id: almosthuman2014)”】