利用 PyTorch Lightning 搭建一個文本分類模型,你學會了嗎?
引言
在這篇博文[1]中,將逐步介紹如何使用 PyTorch Lightning 來構建和部署一個基礎的文本分類模型。該項目借助了 PyTorch 生態中的多個強大工具,例如 torch、pytorch_lightning 以及 Hugging Face 提供的 transformers,從而構建了一個強大且可擴展的機器學習流程。
圖片
代碼庫包含四個核心的 Python 腳本:
- data.py:負責數據的加載和預處理工作。
- model.py:構建模型的結構。
- train.py:包含了訓練循環和訓練的配置。
- inference.py:支持使用訓練好的模型進行推斷。
下面詳細解析每個部分,以便理解它們是如何協同作用,以實現文本分類的高效工作流程。
1. 數據加載與預處理
在 data.py 文件中,DataModule 類被設計用來處理數據加載和預處理的所有環節。它利用了 PyTorch Lightning 的 LightningDataModule,這有助于保持數據處理任務的模塊化和可復用性。
class DataModule(pl.LightningDataModule):
def __init__(self, model_name="google/bert_uncased_L-2_H-128_A-2", batch_size=32):
super().__init__()
self.batch_size = batch_size
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
這個類在初始化時需要指定模型名稱和批量大小,并從 Hugging Face 的 Transformers 庫加載一個分詞器。prepare_data() 函數會從 GLUE 基準測試套件中下載 CoLA 數據集,這個數據集經常用來評估自然語言理解(NLU)模型的性能。
setup() 函數負責對文本數據進行分詞處理,并創建用于訓練和驗證的 PyTorch DataLoader 對象:
def setup(self, stage=None):
if stage == "fit" or stage is None:
self.train_data = self.train_data.map(self.tokenize_data, batched=True)
self.train_data.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
self.val_data = self.val_data.map(self.tokenize_data, batched=True)
self.val_data.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
2. 模型架構
在 model.py 文件中定義的 ColaModel 類繼承自 PyTorch Lightning 的 LightningModule。該模型采用 BERT(一種雙向編碼器表示,源自 Transformers)的簡化版本作為文本表示的核心模型。
class ColaModel(pl.LightningModule):
def __init__(self, model_name="google/bert_uncased_L-2_H-128_A-2", lr=1e-2):
super(ColaModel, self).__init__()
self.bert = AutoModel.from_pretrained(model_name)
self.W = nn.Linear(self.bert.config.hidden_size, 2)
模型在前向傳播過程中提取 BERT 的最終隱藏狀態,并通過一個線性層來生成用于二分類的對數幾率(logits):
def forward(self, input_ids, attention_mask):
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
h_cls = outputs.last_hidden_state[:, 0]
logits = self.W(h_cls)
return logits
另外,training_step() 和 validation_step() 函數分別負責處理訓練和驗證的邏輯,并記錄諸如損失和準確率等關鍵指標。
3. Training Loop
train.py 腳本利用 PyTorch Lightning 的 Trainer 類來控制訓練過程。它還包含了模型檢查點和提前停止的回調機制,以防止模型過擬合。
checkpoint_callback = ModelCheckpoint(dirpath="./models", mnotallow="val_loss", mode="min")
early_stopping_callback = EarlyStopping(mnotallow="val_loss", patience=3, verbose=True, mode="min")
訓練過程設定了最大周期數,并在可能的情況下利用 GPU 進行加速:
trainer = pl.Trainer(
default_root_dir="logs",
gpus=(1 if torch.cuda.is_available() else 0),
max_epochs=5,
fast_dev_run=False,
logger=pl.loggers.TensorBoardLogger("logs/", name="cola", versinotallow=1),
callbacks=[checkpoint_callback, early_stopping_callback],
)
trainer.fit(cola_model, cola_data)
這樣的配置不僅讓訓練變得更加簡便,還保證了模型能夠定期保存并對其性能進行監控。
4. 推理
訓練結束后,將利用模型來進行預測。inference.py 腳本中定義了一個名為 ColaPredictor 的類,該類負責加載經過訓練的模型檢查點,并提供了一個用于生成預測的方法:
class ColaPredictor:
def __init__(self, model_path):
self.model_path = model_path
self.model = ColaModel.load_from_checkpoint(model_path)
self.model.eval()
self.model.freeze()
Predict() 方法接受文本輸入,使用分詞器對其進行處理,并返回模型的預測:
def predict(self, text):
inference_sample = {"sentence": text}
processed = self.processor.tokenize_data(inference_sample)
logits = self.model(
torch.tensor([processed["input_ids"]]),
torch.tensor([processed["attention_mask"]]),
)
scores = self.softmax(logits[0]).tolist()
predictions = [{"label": label, "score": score} for score, label in zip(scores, self.labels)]
return predictions
總結
本項目展示了如何采用 PyTorch Lightning 進行構建、訓練和部署文本分類模型的系統化方法。盡情地嘗試代碼,調整參數,并試用不同的數據集或模型吧。