一文讀懂全新深度學習庫Rust Burn
一、什么是Rust Burn?
Rust Burn是一個全新的深度學習框架,完全使用Rust編程語言編寫。創建這個新框架而不是使用現有框架(如PyTorch或TensorFlow)的動機是為了構建一個適應多種用戶需求的通用框架,包括研究人員、機器學習工程師和底層軟件工程師。
Rust Burn的關鍵設計原則包括靈活性、高性能和易用性。
靈活性:能夠快速實現前沿研究想法,并進行實驗。
高性能:通過優化措施,例如利用特定硬件功能,如Nvidia GPU上的張量內核(Tensor Cores)。
易用性:簡化訓練、部署和運行模型的工作流程。
Rust Burn的主要特點:
- 靈活而動態的計算圖。
- 線程安全的數據結構。
- 直觀的抽象,簡化開發過程。
- 在訓練和推理過程中實現極快的性能。
- 支持CPU和GPU的多種后端實現。
- 完全支持訓練過程中的日志記錄、度量和檢查點功能。
- 小型但活躍的開發者社區。
二、快速入門
2.1、安裝Rust
Burn是基于Rust編程語言的、強大的深度學習框架,需要對Rust有基本的了解,但一旦掌握了這些知識,用戶將能夠充分利用Burn提供的所有功能。
按照官方指南進行安裝。也可以查看GeeksforGeeks在Windows和Linux上安裝Rust的指南和截圖。
【官方指南】:https://www.rust-lang.org/tools/install
圖片來自Install Rust
【安裝指南和截圖】:https://www.geeksforgeeks.org/how-to-install-rust-on-windows-and-linux-operating-system/
2.2、安裝Burn
要使用Rust Burn,首先需要在系統上安裝Rust。一旦正確設置了Rust,就可以使用cargo(Rust的軟件包管理器)創建新的Rust應用程序。
在當前目錄中運行以下命令:
cargo new new_burn_app
導航到這個新目錄:
cd new_burn_app
接下來,添加Burn作為依賴項,并添加啟用GPU操作的WGPU后端功能:
cargo add burn --features wgpu
最后,編譯項目以安裝Burn:
cargo build
這將安裝Burn框架以及WGPU后端。WGPU允許Burn執行底層的GPU操作。
三、代碼示例
3.1、逐元素相加
要運行以下代碼,用戶需要打開并替換src/main.rs中的內容:
use burn::tensor::Tensor;
use burn::backend::WgpuBackend;
// Type alias for the backend to use.
type Backend = WgpuBackend;
fn main() {
// Creation of two tensors, the first with explicit values and the second one with ones, with the same shape as the first
let tensor_1 = Tensor::::from_data([[2., 3.], [4., 5.]]);
let tensor_2 = Tensor::::ones_like(&tensor_1);
// Print the element-wise addition (done with the WGPU backend) of the two tensors.
println!("{}", tensor_1 + tensor_2);
}
main函數使用WGPU后端創建了兩個張量,并進行了相加運算。
在終端中運行cargo run,執行該代碼。
輸出:
查看相加的結果:
Tensor {
data: [[3.0, 4.0], [5.0, 6.0]],
shape: [2, 2],
device: BestAvailable,
backend: "wgpu",
kind: "Float",
dtype: "f32",
}
3.2、位置智能前饋模塊
以下是使用Burn框架的一個簡單示例。示例創建了一個前饋模塊,并使用以下代碼片段定義了它的前向傳播。
use burn::nn;
use burn::module::Module;
use burn::tensor::backend::Backend;
#[derive(Module, Debug)]
pub struct PositionWiseFeedForward<B: Backend> {
linear_inner: Linear<B>,
linear_outer: Linear<B>,
dropout: Dropout,
gelu: GELU,
}
impl PositionWiseFeedForward<B> {
pub fn forward(&self, input: Tensor<B, D>) -> Tensor<B, D> {
let x = self.linear_inner.forward(input);
let x = self.gelu.forward(x);
let x = self.dropout.forward(x);
self.linear_outer.forward(x)
}
}
3.3、項目示例
要了解更多示例并運行它們,請復制https://github.com/burn-rs/burn存儲庫,并運行以下項目:
- MNIST:使用各種后端在CPU或GPU上訓練模型。
【MNIST】:https://github.com/burn-rs/burn/tree/main/examples/mnist
- MNIST網絡推理:在瀏覽器中進行模型推理。
【MNIST網絡推理】:https://github.com/burn-rs/burn/tree/main/examples/mnist-inference-web
- 文本分類:在GPU上從頭開始訓練一個Transformer編碼器。
【文本分類】:https://github.com/burn-rs/burn/tree/main/examples/text-classification
- 文本生成:在GPU上從頭開始構建和訓練自回歸Transformer。
【文本生成】:https://github.com/burn-rs/burn/tree/main/examples/text-generation
3.4、預訓練模型
要構建AI應用程序,可以使用以下預訓練模型,并根據數據集對其進行微調。
- SqueezeNet:squeezenet-burn
【鏈接】:https://github.com/burn-rs/models/blob/main/squeezenet-burn/README.md
- Llama 2:Gadersd/llama2-burn
【鏈接】:https://github.com/Gadersd/llama2-burn
- Whisper:Gadersd/whisper-burn
【鏈接】:https://github.com/Gadersd/whisper-burn
- Stable Diffusion v1.4:Gadersd/stable-diffusion-burn
【鏈接】:https://github.com/Gadersd/stable-diffusion-burn
四、結論
Rust Burn在深度學習框架領域提供了一個令人興奮的新選擇。如果你已經是一名Rust開發者,就可以利用Rust的速度、安全性和并發性來推動深度學習研究和生產的發展。Burn致力于在靈活性、性能和可用性方面找到合適的折衷方案,從而創建一個適用于各種用例的、獨特的多功能框架。
盡管Burn還處于早期階段,但它在解決現有框架的痛點并滿足該領域內各種從業者的需求方面已顯示出前景。隨著該框架的成熟和社區的發展,它有可能成為與現有框架相媲美的生產就緒框架。其新穎的設計和語言選擇為深度學習社區帶來了新的可能性。