谷歌狂喜:JAX性能超越Pytorch、TensorFlow!或成GPU推理訓練最快選擇
谷歌力推的JAX在最近的基準測試中性能已經超過Pytorch和TensorFlow,7項指標排名第一。
而且測試并不是在JAX性能表現最好的TPU上完成的。
雖然現在在開發者中,Pytorch依然比Tensorflow更受歡迎。
但未來,也許有更多的大模型會基于JAX平臺進行訓練和運行。
模型
最近,Keras團隊為三個后端(TensorFlow、JAX、PyTorch)與原生PyTorch實現以及搭配TensorFlow的Keras 2進行了基準測試。
首先,他們為生成式和非生成式人工智能任務選擇了一組主流的計算機視覺和自然語言處理模型:
對于模型的Keras版本,其采用了KerasCV和KerasNLP中已有的實現進行構建。而對于原生的PyTorch版本,則選擇了網絡上最流行的幾個選項:
- 來自HuggingFace Transformers的BERT、Gemma、Mistral
- 來自HuggingFace Diffusers的StableDiffusion
- 來自Meta的SegmentAnything
他們將這組模型稱作「Native PyTorch」,以便與使用PyTorch后端的Keras 3版本進行區分。
他們對所有基準測試都使用了合成數據,并在所有LLM訓練和推理中使用了bfloat16精度,同時在所有LLM訓練中使用了LoRA(微調)。
根據PyTorch團隊的建議,他們在原生PyTorch實現中使用了torch.compile(model, mode="reduce-overhead")(由于不兼容,Gemma和Mistral訓練除外)。
為了衡量開箱即用的性能,他們使用高級API(例如HuggingFace的Trainer()、標準PyTorch訓練循環和Keras model.fit()),并盡可能減少配置。
硬件配置
所有基準測試均使用Google Cloud Compute Engine進行,配置為:一塊擁有40GB顯存的NVIDIA A100 GPU、12個虛擬CPU和85GB的主機內存。
基準測試結果
表2顯示了基準測試結果(以步/毫秒為單位)。每步都涉及對單個數據批次進行訓練或預測。
結果是100步的平均值,但排除了第一個步,因為第一步包括了模型創建和編譯,這會額外花費時間。
為了確保比較的公平性,對于相同的模型和任務(不論是訓練還是推理)都使用相同的批大小。
然而,對于不同的模型和任務,由于它們的規模和架構有所不同,可根據需要調整數據批大小,從而避免因過大而導致內存溢出,或是批過小而導致GPU使用不足。
過小的批大小也會使PyTorch看起來較慢,因為會增加Python的開銷。
對于大型語言模型(Gemma和Mistral),測試時也使用了相同的批處理大小,因為它們是相同類型的模型,具有類似數量的參數(7B)。
考慮到用戶對單批文本生成的需求,也對批大小為1的文本生成情況進行了基準測試。
關鍵發現
發現1
不存在「最優」后端。
Keras的三種后端各展所長,重要的是,就性能而言,并沒有哪一個后端能夠始終勝出。
選擇哪個后端最快,往往取決于模型的架構。
這一點突出了選擇不同框架以追求最佳性能的重要性。Keras 3可以幫助輕松切換后端,以便為模型找到最合適的選擇。
發現2
Keras 3的性能普遍超過PyTorch的標準實現。
相對于原生PyTorch,Keras 3在吞吐量(步/毫秒)上有明顯的提升。
特別是,在10個測試任務中,有5個的速度提升超過了50%。其中,最高更是達到了290%。
如果是100%,意味著Keras 3的速度是PyTorch的2倍;如果是0%,則表示兩者性能相當
發現3
Keras 3提供一流的「開箱即用」性能。
也就是,所有參與測試的Keras模型都未進行過任何優化。相比之下,使用原生PyTorch實現時,通常需要用戶自行進行更多性能優化。
除了上面分享的數據,測試中還注意到在HuggingFace Diffusers的StableDiffusion推理功能上,從版本0.25.0升級到0.3.0時,性能提升超過了100%。
同樣,在HuggingFace Transformers中,Gemma從4.38.1版本升級至4.38.2版本也顯著提高了性能。
這些性能的提升凸顯了HuggingFace在性能優化方面的專注和努力。
對于一些手動優化較少的模型,如SegmentAnything,則使用了研究作者提供的實現。在這種情況下,與Keras相比,性能差距比大多數其他模型更大。
這表明,Keras能夠提供卓越的開箱即用性能,用戶無需深入了解所有優化技巧即可享受到快速的模型運行速度。
發現4
Keras 3的表現始終優于Keras 2。
例如,SegmentAnything的推理速度提升了驚人的380%,StableDiffusion的訓練處理速度提升了150%以上,BERT的訓練處理速度也提升了100%以上。
這主要是因為Keras 2在某些情況下直接使用了更多的TensorFlow融合操作,而這可能對于XLA的編譯并不是最佳選擇。
值得注意的是,即使僅升級到Keras 3并繼續使用TensorFlow后端,也能顯著提升性能。
結論
框架的性能在很大程度上取決于具體使用的模型。
Keras 3能夠幫助為任務選擇最快的框架,這種選擇幾乎總能超越Keras 2和PyTorch實現。
更為重要的是,Keras 3模型無需進行復雜的底層優化,即可提供卓越的開箱即用性能。