成人免费xxxxx在线视频软件_久久精品久久久_亚洲国产精品久久久_天天色天天色_亚洲人成一区_欧美一级欧美三级在线观看

2022年,我該用JAX嗎?GitHub 1.6萬星,這個年輕的工具并不完美

人工智能 深度學習 新聞
近年來,谷歌于 2018 年推出的 JAX 迎來了迅猛發展,很多研究者對其寄予厚望,希望它可以取代 TensorFlow 等眾多深度學習框架。但 JAX 是否真的適合所有人使用呢?這篇文章對 JAX 的方方面面展開了深入探討,希望可以給研究者選擇深度學習框架時提供有益的參考。

?自 2018 年底推出以來,JAX 的受歡迎程度一直在穩步提升。2020 年,DeepMind 宣布使用 JAX 來加速其研究。越來越多來自谷歌大腦(Google Brain)和其他機構的項目也都在使用 JAX。 

目前,在 JAX 的 GitHub 項目主頁,Star 量已經達到了 16.3k。?

項目地址:https://github.com/google/jaxJAX 是一個非常有前途的項目,并且用戶一直在穩步增長。JAX 已經在深度學習、機器人 / 控制系統、貝葉斯方法和科學模擬等諸多領域得到了廣泛應用。

如此,是否意味著 JAX 也將成為下一個大型深度學習框架?近日,發表在 AssemblyAI 博客上的文章《Why You Should (or Shouldn't) Be Using JAX in 2022》中,作者 Ryan O'Connor 為我們深入解讀了 JAX 的概念、使用 JAX 的理由以及是否應該使用 JAX 等。

JAX 簡介

JAX 不是一個深度學習框架或庫,其設計初衷也不是成為一個深度學習框架或庫。簡而言之,JAX 是一個包含可組合函數轉換的數值計算庫。正如我們所看到的,深度學習只是 JAX 功能的一小部分:

?JAX 的定位科學計算(Scientific Computing)和函數轉換(Function Transformations)的交叉融合,具有除訓練深度學習模型以外的一系列能力,包括如下:

  • 即時編譯(Just-in-Time Compilation)
  • 自動并行化(Automatic Parallelization)
  • 自動向量化(Automatic Vectorization)
  • 自動微分(Automatic Differentiation)

使用 JAX 的原因有哪些?

簡而言之,是速度。這是 JAX 與任何用例相關的一種通用能力。讓我們使用 NumPy 和 JAX 對矩陣的前三個冪求和(按元素)。 

首先是 NumPy 實現。我們發現,該計算大約需要 851 毫秒。?

 

然后使用 JAX 實現該計算:JAX 僅在 5.54 毫秒內執行完成該計算,速度是 NumPy 的 150 倍以上。

?JAX 的速度比 NumPy 快了 N 個數量級。需要注意,JAX 使用的是 TPU,NumPy 使用了 CPU,以此強調 JAX 的速度上限遠高于 NumPy。

作者列出了以下六條可能想要使用 JAX 的理由:

  • NumPy 加速器。NumPy 是使用 Python 進行科學計算的基礎包之一,但它僅與 CPU 兼容。JAX 提供了 NumPy 的實現(具有幾乎相同的 API),可以非常輕松地在 GPU 和 TPU 上運行。對于許多用戶而言,僅此一項功能就足以證明使用 JAX 的合理性;
  • XLA。XLA(Accelerated Linear Algebra)是專為線性代數設計的全程序優化編譯器。JAX 建立在 XLA 之上,顯著提高了計算速度上限;
  • JIT。JAX 允許用戶使用 XLA 將自己的函數轉換為即時編譯(JIT)版本。這意味著可以通過在計算函數中添加一個簡單的函數裝飾器(decorator)來將計算速度提高幾個數量級;
  • Auto-differentiation。JAX 將 Autograd(自動區分原生 Python 代碼和 NumPy 代碼)和 XLA 結合在一起,它的自動微分能力在科學計算的許多領域都至關重要。JAX 提供了幾個強大的自動微分工具;
  • 深度學習。雖然 JAX 本身不是深度學習框架,但它的確為深度學習提供了一個很好的基礎。很多構建在 JAX 之上的庫旨在提供深度學習功能,包括 Flax、Haiku 和 Elegy。甚至在最近的一些 PyTorch 與 TensorFlow 文章中強調了 JAX 作為一個值得關注的「框架」,并推薦其用于基于 TPU 的深度學習研究。JAX 對 Hessians 的高效計算也與深度學習相關,因為它們使高階優化技術更加可行;
  • 通用可微分編程范式(General Differentiable Programming Paradigm )。雖然我們可以使用 JAX 來構建和訓練深度學習模型,但它也為通用可微編程提供了一個框架。這意味著 JAX 可以通過使用基于模型的機器學習方法來解決問題,從而可以利用數十年研究建立起的給定領域的先驗知識。?

JAX 轉換

?到目前為止,我們已經討論了 XLA 以及它如何允許 JAX 在加速器上實現 NumPy;但請記住,這只是 JAX 定義的一半。JAX 不僅為強大的科學計算提供了工具,而且還為可組合的函數轉換提供了工具。

舉例來說如果我們對標量值函數 f(x) 使用梯度函數轉換,那么我們將得到一個向量值函數 f'(x),它給出了函數在 f(x) 域中任意點的梯度。?

在函數上使用 grad() 可以讓我們得到域中任意點的梯度

?JAX 包含了一個可擴展系統來實現這樣的函數轉換,有四種典型方式:

  • Grad() 進行自動微分;
  • Vmap() 自動向量化;
  • Pmap() 并行化計算;
  • Jit() 將函數轉換為即時編譯版本。

使用 grad() 進行自動微分

訓練機器學習模型需要反向傳播。在 JAX 中,就像在 Autograd 中一樣,用戶可以使用 grad() 函數來計算梯度。

舉例來說,如下是對函數 f(x) = abs(x^3) 求導。我們可以看到,當求 x=2 和 x=-3 處的函數及其導數時,我們得到了預期的結果。?

那么 grad() 能微分到什么程度?JAX 通過重復應用 grad() 使得微分變得很容易,如下程序我們可以看到,輸出函數的三階導數給出了 f'''(x)=6 的恒定預期輸出。

可能有人會問,grad() 可以用在哪些方面?標量值函數:grad() 采用標量值函數的梯度,將標量 / 向量映射到標量函數。此外還有向量值函數:對于將向量映射到向量的向量值函數,梯度的類似物是雅可比矩陣。使用 jacfwd() 和 jacrev(),JAX 返回一個函數,該函數在域中的某個點求值時產生雅可比矩陣。

?從深度學習角度來看,JAX 使得計算 Hessians 變得非常簡單和高效。由于 XLA,JAX 可以比 PyTorch 更快地計算 Hessians,這使得實現諸如 AdaHessian 這樣的高階優化更加快速。

下面代碼是在 PyTorch 中對一個簡單的輸入總和進行 Hessian:?

正如我們所看到的,上述計算大約需要 16.3 ms,在 JAX 中嘗試相同的計算:

使用 JAX,計算僅需 1.55 毫秒,比 PyTorch 快 10 倍以上:JAX 可以非常快速?地計算 Hessians,使得高階優化更加可行。

使用 vmap() 自動向量化

JAX 在其 API 中還有另一種變換:vmap() 自動向量化。以下是矢量化向量加法展示:?

使用 pmap() 實現自動并行化

分布式計算變得越來越重要,在深度學習中尤其如此,如下圖所示,SOTA 模型已經發展到超大規模。

?得益于 XLA,JAX 可以輕松地在加速器上進行計算,但 JAX 也可以輕松地使用多個加速器進行計算,即使用單個命令 - pmap() 執行 SPMD 程序的分布式訓練。

我們以向量矩陣乘法為例,如下為非并行向量矩陣乘法:?


使用 JAX,我們可以輕松地將這些計算分布在 4 個 TPU 上,只需將操作包裝在 pmap() 中即可。這允許用戶在每個 TPU 上同時執行一個點積,顯著提高了計算速度(對于大型計算而言)。


使用 jit() 加快功能

JIT 編譯是一種執行代碼的方法,介于解釋(interpretation)和 AoT(ahead-of-time)編譯之間。重要的是,JIT 編譯器在運行時將代碼編譯成快速的可執行文件,但代價是首次運行速度較慢。

JIT 不是一次將一個操作分配給 GPU 內核,而是使用 XLA 將一系列操作編譯成一個內核,從而為函數提供端到端編譯的高效 XLA 實現。

以下圖為例,代碼定義了一個函數:用三種方式計算 5000 x 5000 矩陣——一次使用 NumPy,一次使用 JAX,還有一次在 JIT 編譯的函數版本上使用 JAX。我們首先在 CPU 上進行實驗:?

JAX 對于逐元素計算明顯更快,尤其是在使用 jit 時。

我們看到 JAX 比 NumPy 快 2.3 倍以上,當我們 JIT 函數時,JAX 比 NumPy 快 30 倍。這些結果已經令人印象深刻,但讓我們繼續看,讓 JAX 在 TPU 上進行計算:

?當 JAX 在 TPU 上執行相同的計算時,它的相對性能會進一步提升(NumPy 計算仍在 CPU 上執行,因為它不支持 TPU 計算)在這種情況下,我們可以看到 JAX 比 NumPy 快了驚人的 13 倍,如果我們同時在 TPU 上 JIT 函數和計算,我們會發現 JAX 比 NumPy 快 80 倍。

當然,這種速度的大幅提升是有代價的。JAX 對 JIT 允許的函數進行了限制,盡管通常允許僅涉及上述 NumPy 操作的函數。此外,通過 Python 控制流進行 JIT 處理存在一些限制,因此在編寫函數時須牢記這一點。

2022 年了,我該用 JAX 嗎?

很遺憾,這個問題的答案還是「視情況而定」。是否遷移到 JAX 取決于你的情況和目標。為具體分析是否應該(或不應該)在 2022 年使用 JAX,這里將建議匯總到下面的流程圖中,并針對不同的興趣領域提供不同的圖表。?

科學計算

?如果你對 JAX 在通用計算感興趣,首先要問的問題就是——是否只嘗試在加速器上運行 NumPy?如果答案是肯定的,那么你顯然應該開始遷移到 JAX。

如果你不只處理數字而是參與動態計算建模,那么是否應該使用 JAX 將取決于具體用例。如果大部分工作是在 Python 中使用大量自定義代碼完成的,那么開始學習 JAX 以增強工作流程是值得的。

如果大部分工作不在 Python 中,但你想構建的是某種基于模型 / 神經網絡的混合系統,那么使用 JAX 可能是值得的。

如果大部分工作不使用 Python,或者你正在使用一些專門的軟件進行研究(熱力學、半導體等),那么 JAX 可能是不合適的工具,除非你想從這些程序中導出數據,用來做自定義計算。如果你感興趣的領域更接近物理 / 數學并包含計算方法(動力系統、微分幾何、統計物理)并且大部分工作都在例如 Mathematica 上,那么堅持使用目前的工具才是值得的,特別是在已有大型自定義代碼庫的情形下。?

深度學習

?雖然我們已經強調過,JAX 不是專為深度學習構建的通用框架,但 JAX 速度很快且具有自動微分功能,你肯定想知道使用 JAX 進行深度學習是什么樣的。

若想在 TPU 上進行訓練,那么你應該開始使用 JAX,尤其是如果當前正在使用的是 PyTorch。雖然有 PyTorch-XLA 存在,但使用 JAX 進行 TPU 訓練絕對是更好的體驗。如果你正在研究的是「非標準」架構 / 建模,例如 SDE-Nets,那么也絕對應該嘗試一下 JAX。此外,如果你想利用高階優化技術,JAX 也是要嘗試的東西。

如果你不是在構建特殊的架構,只是在 GPU 上訓練常見的架構,那么你現在可能應該堅持使用 PyTorch 或 TensorFlow。然而,這個建議可能會在未來一兩年內快速發生變化。雖然 PyTorch 仍然在研究領域占據主導地位,但使用 JAX 的論文數量一直在穩步增長。隨著 DeepMind 和谷歌重量級玩家不斷開發用于 JAX 的高級深度學習 API,在幾年內 JAX 可能會出現爆炸性的增長率。

這意味著你至少應該稍微熟悉一下 JAX,如果你是研究人員的話更應如此。?

深度學習初學者


但如果我只是個初學者呢?情況會有些不一樣。

如果你有興趣了解深度學習并實現一些想法,你應該使用 JAX 或 PyTorch。如果你想自上而下學習深度學習,或有一些 Python 軟件的經驗,則應該從 PyTorch 入手。如果你想自下而上地學習深度學習,或具有數學背景,你可能會發現 JAX 很直觀。在這種情況下,在進行任何大型項目之前,請確保了解如何使用 JAX。

如果你對深度學習感興趣,又想轉行相關的職位,那么你需要使用 PyTorch 或 TensorFlow。盡管最好是同時熟悉兩個框架,但你必須知道 TensorFlow 被普遍認為是「行業」框架,不同框架的職位發布數量證明了這一點:

如果你是一個沒有數學或軟件背景但想學習深度學習的初學者,那么你不會想使用 JAX。相反,Keras 是更好的選擇。

不該使用 JAX 的四條理由

?雖然上文已經討論了很多 JAX 的正面反饋,它有潛力極大地提升用戶程序的性能。但作者同時列舉了以下四條不該使用 JAX 的理由:

  • ?JAX 仍然被官方認為是一個實驗性框架。JAX 是一個相對「年輕」的項目。目前,JAX 仍被視為一個研究項目,而不是成熟的谷歌產品,因此如果用戶正在考慮遷移到 JAX,請記住這一點;
  • 使用 JAX 一定要勤勉。調試的時間成本,或者更嚴重的是,未跟蹤副作用(untracked side effects)的風險可能導致那些沒有扎實掌握函數式編程的用戶不適用 JAX。在開始將它用于正式項目之前,請確保自己了解使用 JAX 的常見缺陷;
  • JAX 沒有針對 CPU 計算進行優化。鑒于 JAX 是以「加速器優先」的方式開發的,因此每個操作的分派并未針對 JAX 進行完全優化。在某些情況下,NumPy 實際上可能比 JAX 更快,尤其是對于小型程序而言,這是因為 JAX 引入了開銷;
  • JAX 與 Windows 不兼容。目前在 Windows 上不支持 JAX。如果用戶使用 Windows 系統但仍想嘗試 JAX,可以使用 Colab 或將其安裝在虛擬機(VM)上。?
責任編輯:張燕妮 來源: 機器之心
相關推薦

2021-08-09 15:56:43

機器學習人工智能計算機

2010-08-16 10:39:59

虛擬化

2012-01-13 08:46:15

云計算災備負載均衡

2015-09-23 10:12:00

2012-12-04 10:10:30

求職程序員

2023-03-16 08:13:56

人工智能?OpenAI

2012-01-13 10:37:04

負載均衡災備

2011-11-28 10:18:20

2018-06-21 15:00:34

2011-11-24 14:49:16

JavaJDKWebService

2009-09-11 09:36:53

李開復

2023-04-07 12:58:15

數據中心可持續發展清潔能源

2023-03-17 07:25:16

李彥宏百度文心一言

2021-03-02 16:27:32

大數據程序員IT

2019-06-21 11:06:15

Python 開發編程語言

2019-08-09 18:08:13

程序員技能開發者

2019-04-17 13:34:30

Galaxy Fold三星折疊屏

2019-07-05 15:42:58

GitHub代碼開發者

2009-09-23 11:59:48

Office 2010Web程序漏洞

2019-12-12 09:43:46

GitHub代碼開發者
點贊
收藏

51CTO技術棧公眾號

主站蜘蛛池模板: 亚洲人成一区二区三区性色 | 伊人网综合在线观看 | 91免费视频| 国产区在线免费观看 | 成人av色| 美女视频黄的 | 国产精品亚洲综合 | 精品久久精品 | 国产精品日韩欧美一区二区三区 | 国产高清在线精品一区二区三区 | 91在线电影| 女同av亚洲女人天堂 | 欧美综合一区二区三区 | 特级做a爰片毛片免费看108 | 一级a性色生活片久久毛片波多野 | 国产精品久久久久久婷婷天堂 | 两性午夜视频 | 国产成人网 | 国产日韩视频在线 | 国产一级片 | 久久综合一区二区 | 欧美精品久久久久 | 久草视频在线播放 | 自拍中文字幕 | 国产激情91久久精品导航 | 日韩欧美视频 | 美女久久 | 日韩一区二区三区在线观看 | 国产精品久久久99 | 国产三级大片 | 一区二区视频在线 | 性高湖久久久久久久久aaaaa | ww亚洲ww亚在线观看 | 精品久久久久久久久亚洲 | 国产视频在线一区二区 | 一区二区在线免费播放 | 中文字幕在线免费观看 | 91麻豆精品国产91久久久更新资源速度超快 | 亚洲精品www. | 午夜精品久久久久久久久久久久久 | 免费骚视频 |