GPT時代學算法,Pytorch框架實現線性模型
今天我們繼續來實現線性回歸模型,不過這一次我們不再所有功能都自己實現,而是使用Pytorch框架來完成。
整個代碼會發生多大變化呢?
首先是數據生成的部分,這個部分和之前類似:
import numpy as np
import torch
from torch.utils import data
from d2l import torch as d2l
true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = d2l.synthetic_data(true_w, true_b, 1000)
但是從數據讀取開始,就變得不同了。
在之前的代碼中,我們是自己實現了迭代器,從訓練數據中隨機抽取數據。但我們沒有做無放回的采樣設計,也沒有做數據的打亂操作。
然而這些內容Pytorch框架都有現成的工具可以使用,我們不需要再自己實現了。
這里需要用到TensorDataset和DataLoader兩個類:
def load_array(data_arrays, batch_size, is_train=True): #@save
"""構造一個PyTorch數據迭代器"""
dataset = data.TensorDataset(*data_arrays)
return data.DataLoader(dataset, batch_size, shuffle=is_train)
關于這兩個類的用法,我們可以直接詢問ChatGPT。
圖片
簡而言之TensorDataset是用來封裝tensor數據的,它的主要功能就是和DataLoader配合。
圖片
DataLoader是一個迭代器,除了基本的數據讀取之外,還提供亂序、采樣、多線程讀取等功能。
我們調用load_array獲得訓練數據的迭代器。
batch_size = 10
data_iter = load_array((features, labels), batch_size)
模型部分
在之前的實現當中,我們是自己創建了兩個tensor來作為線性回歸模型的參數。
然而其實不必這么麻煩,我們可以把線性回歸看做是單層的神經網絡,在原理和效果上,它們都是完全一樣的。因此我們可以通過調用對應的API來很方便地實現模型:
from torch import nn
net = nn.Sequential(nn.Linear(2, 1))
這里的nn是神經網絡的英文縮寫,nn.Linear(2, 1)定義了一個輸入維度是2,輸出維度是1的單層線性網絡,等同于線性模型。
nn.Sequential模塊容器,它能夠將輸入的多個網絡結構按照順序拼裝成一個完整的模型。這是一種非常常用和方便地構建模型的方法,除了這種方法之外,還有其他的方法創建模型,我們在之后遇到的時候再詳細展開。
圖片
一般來說模型創建好了之后,并不需要特別去初始化,但如果你想要對模型的參數進行調整的話,可以使用weight.data和weight.bias來訪問參數:
net[0].weight.data.normal_(0, 0.01)
net[0].bias.data.fill_(0)
接著我們來定義損失函數,Pytorch當中同樣封裝了損失函數的實現,我們直接調用即可。
loss = nn.MSELoss()
nn.MSELoss即均方差,MSE即mean square error的縮寫。
最后是優化算法,Pytorch當中也封裝了更新模型中參數的方法,我們不需要手動來使用tensor里的梯度去更新模型了。只需要定義優化方法,讓優化方法自動完成即可:
optim = torch.optim.SGD(net.parameters(), lr=0.03)
訓練
最后就是把上述這些實現全部串聯起來的模型訓練了。
整個過程代碼量很少,只有幾行。
num_epochs = 3
for epoch in range(num_epochs):
for X, y in data_iter:
l = loss(net(X) ,y)
optim.zero_grad()
l.backward()
optim.step()
l = loss(net(features), labels)
print(f'epoch {epoch + 1}, loss {l:f}')
我們之前自己實現的模型參數更新部分,被一行optim.step()代替了。
不論多么復雜的模型,都可以通過optim.step()來進行參數更新,非常方便!
同樣我們可以來檢查一下訓練完成之后模型的參數值,同樣和我們設置的非常接近。
圖片
到這里,整個線性回歸模型的實現就結束了。
這個模型是所有模型里最簡單的了,正因為簡單,所以最適合初學者。后面當接觸了更多更復雜的模型之后,會發現雖然代碼變復雜了,但遵循的仍然是現在這個框架。