深度學習框架Flash如何僅用幾行代碼構建圖像分類器?
譯文【51CTO.com快譯】一、簡介
圖像分類是我們想要預測哪個類別屬于圖像的任務。由于圖像表示,這項任務很困難。如果我們將圖像鋪平,它會創建一個長長的一維向量。此外,該表示將丟失相鄰信息。因此,我們需要深度學習來提取特征并預測結果。
有時,構建深度學習模型會成為一項艱巨的任務。雖然我們為圖像分類創建了一個基礎模型,但需要花大量時間來創建代碼。我們必須準備好用于準備數據、訓練模型并測試模型的代碼,并將模型部署到服務器上。這時Flash就有了用武之地!
Flash是一種高級深度學習框架,用于快速構建、訓練和測試深度學習模型。Flash基于PyTorch框架。所以如果您了解PyTorch,就會很熟悉Flash。
與PyTorch和Lighting相比,Flash易于使用,但不如以前的庫靈活。如果您想構建更復雜的模型,可以使用Lightning或直接使用PyTorch。
借助Flash,您可以用幾行代碼構建深度學習模型!因此,如果您剛接觸深度學習,別害怕。Flash可以幫助您構建深度學習模型,不會因代碼而感到困惑。
本文將介紹如何使用Flash構建圖像分類器。
二、實施
安裝庫
想安裝庫,您可以使用pip命令,如下所示:
- pip install lightning-flash
如果該命令不起作用,可以使用其GitHub存儲庫安裝該庫。命令如下所示:
- pip install git+https://github.com/PyTorchLightning/lightning-flash.git
在我們可以成功下載軟件包之后,現在可以加載庫。我們還將種子設為編號42。這是執行此操作的代碼:
- from pytorch_lightning import seed_everything
- import flash
- from flash.core.classification import Labels
- from flash.core.data.utils import download_data
- from flash.image import ImageClassificationData, ImageClassifier
- # set the random seeds.
- seed_everything(42)
- Global seed set to 42
- 42
下載數據
安裝完庫后,現在不妨獲取數據。出于演示需要,我們將使用名為Cat和Dog數據集的數據集。
該數據集含有兩個類別:貓和狗的圖像。想訪問數據集,您可以在Kaggle找到該數據集??梢栽?a >此處訪問數據集。
加載數據
下載數據后,不妨將數據集加載到一個對象中。我們將使用from_folders方法將數據放入到ImageClassification對象中。這是執行此操作的代碼:
- datamodule = ImageClassificationData.from_folders(
- train_folder="cat_and_dog/training_set",
- val_folder="cat_and_dog/validation_set",
- )
加載模型
我們加載數據后,下一步就是加載模型。由于我們不會從頭開始構建自己的架構,將使用基于現有卷積神經網絡架構的預訓練模型。
我們將使用已經過預訓練的ResNet-50模型。此外,我們基于數據集設置類別的數量。這是執行此操作的代碼:
- model = ImageClassifier(backbone="resnet50", num_classes=datamodule.num_classes)
訓練模型
加載模型后,現在不妨訓練模型。我們需要先初始化Trainer對象。我們將用3個輪次(epoch)訓練模型。此外,我們啟用GPU以訓練模型。這是執行此操作的代碼:
- trainer = flash.Trainer(max_epochs=3, gpus=1)
- GPU available: True, used: True TPU available: False, using: 0 TPU cores
初始化對象后,不妨訓練模型。為訓練模型,我們可以使用一個名為finetune的函數。在函數里面,我們設置模型和數據。此外,我們將訓練策略設置為freeze(凍結),這表明我們不想訓練特征提取器。換句話說,我們只訓練分類器部分。
這是執行此操作的代碼:
- trainer.finetune(model, datamodule=datamodule, strategy="freeze")
- LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0] | Name | Type | Params ---------------------------------------- 0 | metrics | ModuleDict | 0 1 | backbone | Sequential | 23.5 M 2 | head | Sequential | 4.1 K ---------------------------------------- 57.2 K Trainable params 23.5 M Non-trainable params 23.5 M Total params 94.049 Total estimated model params size (MB)
- Validation sanity check: 0it [00:00, ?it/s]
- Global seed set to 42
- Training: 0it [00:00, ?it/s]
- Validating: 0it [00:00, ?it/s]
- Validating: 0it [00:00, ?it/s]
- Validating: 0it [00:00, ?it/s]
這是評估結果:
從結果中可以看出,我們的模型其準確率達到了約97%。不賴!現在不妨拿幾個新數據測試模型。
測試模型
我們將使用針對該模型沒有訓練過的樣本數據。以下是我們將測試模型的樣本:
- import matplotlib.pyplot as plt
- from PIL import Image
- fig, ax = plt.subplots(1, 5, figsize=(40,8))
- for i in range(5):
- ax[i].imshow(Image.open(f'cat_and_dog/testing/{i+1}.jpg'))
- plt.show()
為了測試模型,我們可以使用flash庫中的predict方法。這是執行此操作的代碼:
- model.serializer = Labels()
- predictions = model.predict(["cat_and_dog/testing/1.jpg",
- "cat_and_dog/testing/2.jpg",
- "cat_and_dog/testing/3.jpg",
- "cat_and_dog/testing/4.jpg",
- "cat_and_dog/testing/5.jpg"])
- print(predictions)
- ['dogs', 'dogs', 'cats', 'cats', 'dogs']
從上面的結果可以看出,模型預測了帶有正確標簽的樣本。很好!不妨保存模型以備后用。
保存模型
我們已訓練并測試了模型。不妨使用save_checkpoint方法保存模型。這是執行此操作的代碼:
- trainer.save_checkpoint("cat_dog_classifier.pt")
如果您想針對其他代碼加載模型,可以使用load_from_checkpoint方法。這是執行此操作的代碼:
- model = ImageClassifier.load_from_checkpoint("cat_dog_classifier.pt")
三、結語
做得好!您已學習了如何使用Flash構建圖像分類器。正如文章開頭所說,它只需要幾行代碼!是不是很酷?
但愿本文可以幫助您根據自己的情況構建自己的深度學習模型。如果您想實施一個更復雜的模型,但愿能開始學習 PyTorch。
原文標題:How to Build An Image Classifier in Few Lines of Code with Flash,作者:Irfan Alghani Khalid
【51CTO譯稿,合作站點轉載請注明原文譯者和出處為51CTO.com】