使用自己的數據集訓練DETR模型
眾所周知,Transformer已經席卷深度學習領域。Transformer架構最初在NLP領域取得了突破性成果,尤其是在機器翻譯和語言模型中,其自注意力機制允許模型處理序列數據的全局依賴性。隨之,研究者開始探索如何將這種架構應用于計算機視覺任務,特別是目標檢測,這是計算機視覺中的核心問題之一。
在目標識別方面,Facebook提出的DETR(Detection Transformer)是第一個將Transformer的核心思想引入到目標檢測的模型,它拋棄了傳統檢測框架中的錨框和區域提案步驟,實現了端到端的檢測。
本文將使用四個預訓練的DETR模型(DETR ResNet50、DETR ResNet50 DC5、DETR ResNet101和DETR ResNet101 DC5)在自定義數據集上對其進行微調,通過比較它們在自定義數據集上的mAP,來比較評估每個模型的檢測精度。
DETR模型結構
如圖所示,DETR模型通過將卷積神經網絡CNN與Transformer架構相結合,來確定最終的一組邊界框。
在目標檢測中,預測的Bounding box經過非極大值抑制NMS處理,獲得最終的預測。但是,DETR默認總是預測100個Bounding box(可以配置)。因此,我們需要一種方法將真實Bounding box與預測的Bounding box進行匹配。為此,DETR使用了二分圖匹配法。
DETR的架構如下圖所示。
DETR使用CNN模型作為Backbone,在官方代碼中,選用的是ResNet架構。CNN學習二維表示,并將輸出展平,再進入位置編碼(positional encoding)階段。位置編碼后的特征進入Transformer編碼器,編碼器學習位置嵌入(positional embeddings)。這些位置嵌入隨后傳遞給解碼器。解碼器的輸出嵌入會進一步傳遞給前饋網絡(FFN)。FFN負責識別是物體類別的邊界框還是'no object'類別。它會對每個解碼器輸出進行分類,以確定是否檢測到對象以及對應的類別。
DETR模型的詳細架構如下:
數據集
本文將使用一個包含多種海洋生物的水族館數據集(https://www.kaggle.com/datasets/sovitrath/aquarium-data)訓練DETR模型。數據集目錄結構如下:
Aquarium Combined.v2-raw-1024.voc
├── test [126 entries exceeds filelimit, not opening dir]
├── train [894 entries exceeds filelimit, not opening dir]
├── valid [254 entries exceeds filelimit, not opening dir]
├── README.dataset.txt
└── README.roboflow.txt
其中,數據集包含三個子目錄,分別存儲圖像和注釋。注釋是以XML(Pascal VOC)格式提供的。訓練目錄包含了894個圖像和注釋的組合,訓練集447張圖像。同理,測試集63張圖像,驗證集127張圖像。
數據集中共有7個類別:
- fish
- jellyfish
- penguin
- shark
- puffin
- stingray
- starfish
準備vision_transformers庫
vision_transformers庫是一個專注于基于Transformer的視覺模型的新庫。盡管Facebook提供了DETR模型的官方倉庫,但使用它來進行模型的微調可能較為復雜。vision_transformers庫中包含了預訓練模型,支持圖像分類和對象檢測。在這篇文章中,我們將主要關注目標檢測模型,庫中已經集成了四種DETR模型。
首先,在終端或命令行中使用以下命令克隆vision_transformers庫。克隆完成后,使用cd命令進入新克隆的目錄。
git clone https://github.com/sovit-123/vision_transformers.git
cd vision_transformers
接下來,我們需要安裝PyTorch。最好從官方網站上按照適當的CUDA版本安裝PyTorch。例如,以下命令安裝了支持CUDA 11.7的PyTorch 2.0.0:
conda install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia
安裝其它依賴庫。
pip install -r requirements.txt
在克隆了vision_transformers倉庫后,可以再執行以下命令獲取庫中的所有訓練和推理代碼。
pip install vision_transformers
搭建DETR訓練目錄
在開始訓練DETR模型之前,需要創建一個項目目錄結構,以組織代碼、數據、日志和模型檢查點。
├── input
│ ├── Aquarium Combined.v2-raw-1024.voc
│ └── inference_data
└── vision_transformers
├── data
├── examples
├── example_test_data
├── readme_images
├── runs
├── tools
├── vision_transformers
├── README.md
├── requirements.txt
└── setup.py
其中:
- input目錄:包含水族館數據集,inference_data目錄存放后續用于推理的圖像或視頻文件。
- vision_transformers目錄:這是前面克隆的庫。
- tools目錄:包含訓練和推理所需的腳本,例如train_detector.py(用于訓練檢測器的腳本)、inference_image_detect.py(用于圖像推理的腳本)和inference_video_detect.py(用于視頻推理的腳本)
- data目錄:包含一些YAML文件,用于模型訓練。
訓練DETR模型
由于要在自定義數據集上訓練4種不同的檢測變換器模型,如若對每個模型訓練相同的輪數,再挑選最佳模型可能會浪費計算資源。
這里首先對每個模型進行20個訓練周期。然后,對在初步訓練中表現最佳的模型進行更多輪的訓練,以進一步提升模型的性能。
開始訓練之前,需要先創建數據集的YAML配置文件。
1.創建數據集YAML配置文件
數據集的YAML文件將存儲在vision_transformers/data目錄下。它包含了數據集的所有信息。包括圖像路徑、注釋路徑、所有類別名稱、類別數量等。
vision_transformers庫中已經包含了水族館數據集的YAML文件,但是需要根據當前目錄結構修改,
將以下數據復制并粘貼到 data/aquarium.yaml 文件中。
# 圖像和標簽目錄相對于train.py腳本的相對路徑
TRAIN_DIR_IMAGES: '../input/Aquarium Combined.v2-raw-1024.voc/train'
TRAIN_DIR_LABELS: '../input/Aquarium Combined.v2-raw-1024.voc/train'
VALID_DIR_IMAGES: '../input/Aquarium Combined.v2-raw-1024.voc/valid'
VALID_DIR_LABELS: '../input/Aquarium Combined.v2-raw-1024.voc/valid'
# 類名
CLASSES: [
'__background__',
'fish', 'jellyfish', 'penguin',
'shark', 'puffin', 'stingray',
'starfish'
]
# 類別數
NC: 8
# 是否在訓練期間保存驗證集的預測結果
SAVE_VALID_PREDICTION_IMAGES: True
2.訓練模型
訓練環境:
- 10GB RTX 3080 GPU
- 10代i7 CPU
- 32GB RAM
(1) 訓練DETR ResNet50
執行以下命令:
python tools/train_detector.py --epochs 20 --batch 2 --data data/aquarium.yaml --model detr_resnet50 --name detr_resnet50
其中:
- --epochs:模型訓練的輪數。
- --batch:數據加載器的批次大小。
- --data:指向數據集YAML文件的路徑。
- --model:模型名稱。
- --name:保存所有訓練結果的目錄名,包括訓練好的權重。
通過在驗證集上計算mAP(Mean Average Precision)來評估目標檢測性能。
以下是最佳epoch的檢測性能結果。
IoU metric: bbox
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.172
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.383
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.126
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.094
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.107
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.247
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.088
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.250
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.337
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.235
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.330
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.344
BEST VALIDATION mAP: 0.17192136022687962
SAVING BEST MODEL FOR EPOCH: 20
由此可以看到模型在不同IoU閾值和目標尺寸條件的表現。
模型在最后一個epoch,IoU閾值0.50到0.95之間對目標檢測的平均精度mAP達到了17.2%。
在水族館數據集上訓練DETR ResNet50模型20個epoch后的mAP結果如下圖所示。
顯然,mAP值在逐步提高。但在得出任何結論之前,我們需要對其他模型進行訓練。
(2) 訓練DETR ResNet50 DC5
執行以下命令:
python tools/train_detector.py --epochs 20 --batch 2 --data data/aquarium.yaml --model detr_resnet50_dc5 --name detr_resnet50_dc5
最佳epoch的檢測性能結果如下。
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.161
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.360
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.123
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.141
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.155
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.233
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.096
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.248
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.345
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.379
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.373
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.340
BEST VALIDATION mAP: 0.16066837142161672
SAVING BEST MODEL FOR EPOCH: 20
DETR ResNet50 DC5模型在第20個epoch也達到了最高mAP值,為0.16%,相比于DETR ResNet50模型,這個值較低。
(3) 訓練DETR ResNet101
DETR ResNet101模型擁有超過6000萬個參數,相較于前兩個模型(DETR ResNet50及其DC5變體),網絡容量更大。理論上,理論上能夠學習到更復雜的特征表示,從而在性能上有所提升。
python tools/train_detector.py --epochs 20 --batch 2 --data data/aquarium.yaml --model detr_resnet101 --name detr_resnet101
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.175
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.381
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.132
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.089
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.113
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.260
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.095
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.269
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.362
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.298
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.451
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.351
BEST VALIDATION mAP: 0.17489964894400944
SAVING BEST MODEL FOR EPOCH: 17
DETR ResNet101模型在第17個epoch達到了17.5%的mAP,相比之前的DETR ResNet50和DETR ResNet50 DC5模型稍有提升,但提升幅度不大。
(4) 訓練DETR ResNet101 DC5
DETR ResNet101 DC5模型設計上特別考慮了對小物體檢測的優化。本文所用數據集中包含大量小尺寸對象,理論上,DETR ResNet101 DC5模型應該能展現出優于前幾個模型的性能。
python tools/train_detector.py --epochs 20 --batch 2 --data data/aquarium.yaml --model detr_resnet101_dc5 --name detr_resnet101_dc5
IoU metric: bbox
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.206
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.438
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.178
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.110
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.093
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.303
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.099
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.287
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.391
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.317
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.394
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.394
BEST VALIDATION mAP: 0.20588343074278573
SAVING BEST MODEL FOR EPOCH: 20
DETR ResNet101 DC5模型在第20個epoch達到了20%的mAP,這是目前為止的最佳表現。這證實了我們的預期——由于該模型在設計上對小尺寸目標檢測進行了優化,因此在含有大量小對象的數據集上,它的性能確實更勝一籌。
接下來,延長訓練至60個epochs。由如下結果可以看出,DETR ResNet101 DC5模型在第48個epoch達到了最佳性能,這表明模型在這個階段找到了更優的權重組合。
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.239
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.501
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.186
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.119
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.143
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.328
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.109
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.290
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.394
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.349
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.369
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.398
BEST VALIDATION mAP: 0.23894132553612263
SAVING BEST MODEL FOR EPOCH: 48
DETR ResNet101 DC5模型在447個訓練樣本上達到了24%的mAP,對于IoU=0.50:0.95,這樣的結果相當不錯。
3.推理
(1) 視頻推理
使用inference_video_detect.py腳本進行視頻推理。將視頻文件路徑作為輸入,腳本就會處理視頻中的每一幀,并在每個幀上運行目標檢測。
python tools/inference_video_detect.py --weights runs/training/detr_resnet101_dc5_60e/best_model.pth --input ../input/inference_data/video_1.mp4 --show
這里多了一個--show標志,它允許在推理過程中實時顯示結果,在RTX 3080 GPU上,模型平均可以達到38 FPS的速度。
「inference_video_detect.py」
import torch
import cv2
import numpy as np
import argparse
import yaml
import os
import time
import torchinfo
from vision_transformers.detection.detr.model import DETRModel
from utils.detection.detr.general import (
set_infer_dir,
load_weights
)
from utils.detection.detr.transforms import infer_transforms, resize
from utils.detection.detr.annotations import (
convert_detections,
inference_annotations,
annotate_fps,
convert_pre_track,
convert_post_track
)
from deep_sort_realtime.deepsort_tracker import DeepSort
from utils.detection.detr.viz_attention import visualize_attention
# NumPy隨機數生成器的種子值為2023
np.random.seed(2023)
# 命令行參數配置選項
def parse_opt():
parser = argparse.ArgumentParser()
# 模型權重文件的路徑
parser.add_argument(
'-w',
'--weights',
)
# 輸入圖像或圖像文件夾的路徑
parser.add_argument(
'-i', '--input',
help='folder path to input input image (one image or a folder path)',
)
# 數據配置文件的路徑
parser.add_argument(
'--data',
default=None,
help='(optional) path to the data config file'
)
# 模型名稱,默認為'detr_resnet50'
parser.add_argument(
'--model',
default='detr_resnet50',
help='name of the model'
)
# 計算和訓練設備,默認使用GPU(如果可用)否則使用CPU
parser.add_argument(
'--device',
default=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'),
help='computation/training device, default is GPU if GPU present'
)
# 圖像的尺寸,默認為640
parser.add_argument(
'--imgsz',
'--img-size',
default=640,
dest='imgsz',
type=int,
help='resize image to, by default use the original frame/image size'
)
# 可視化時的置信度閾值,默認為0.5
parser.add_argument(
'-t',
'--threshold',
type=float,
default=0.5,
help='confidence threshold for visualization'
)
# 訓練結果存放目錄
parser.add_argument(
'--name',
default=None,
type=str,
help='training result dir name in outputs/training/, (default res_#)'
)
# 不顯示邊界框上的標簽
parser.add_argument(
'--hide-labels',
dest='hide_labels',
action='store_true',
help='do not show labels during on top of bounding boxes'
)
# 只有傳遞該選項時才會顯示輸出
parser.add_argument(
'--show',
dest='show',
action='store_true',
help='visualize output only if this argument is passed'
)
# 開啟跟蹤功能
parser.add_argument(
'--track',
action='store_true'
)
# 過濾要可視化的類別,如--classes 1 2 3
parser.add_argument(
'--classes',
nargs='+',
type=int,
default=None,
help='filter classes by visualization, --classes 1 2 3'
)
# 可視化檢測框的注意力圖
parser.add_argument(
'--viz-atten',
dest='vis_atten',
action='store_true',
help='visualize attention map of detected boxes'
)
args = parser.parse_args()
return args
# 讀取并處理視頻文件相關信息
def read_return_video_data(video_path):
# 打開指定路徑的視頻文件
cap = cv2.VideoCapture(video_path)
# 獲取視頻幀的寬度和高度
frame_width = int(cap.get(3))
frame_height = int(cap.get(4))
# 獲取視頻的幀率
fps = int(cap.get(5))
# 檢查視頻的寬度和高度是否不為零。如果它們都是零,那么會拋出一個錯誤消息,提示用戶檢查視頻路徑是否正確
assert (frame_width != 0 and frame_height !=0), 'Please check video path...'
# 函數返回一個元組,包含VideoCapture對象cap以及視頻的寬度、高度和幀率fps
return cap, frame_width, frame_height, fps
def main(args):
# 如果args.track為真,初始化DeepSORT追蹤器
if args.track:
tracker = DeepSort(max_age=30)
# 根據args.data加載數據配置(如果存在)以獲取類別數量和類別列表
NUM_CLASSES = None
CLASSES = None
data_configs = None
if args.data is not None:
with open(args.data) as file:
data_configs = yaml.safe_load(file)
NUM_CLASSES = data_configs['NC']
CLASSES = data_configs['CLASSES']
# 獲取設備類型
DEVICE = args.device
# 設置輸出目錄
OUT_DIR = set_infer_dir(args.name)
# 加載模型權重
model, CLASSES, data_path = load_weights(
args,
# 設備類型
DEVICE,
# 模型類
DETRModel,
# 數據配置
data_configs,
# 類別數量
NUM_CLASSES,
# 類別列表
CLASSES,
video=True
)
# 將模型移動到指定的設備(如GPU或CPU)并將其設置為評估模式(.eval())
_ = model.to(DEVICE).eval()
# 使用torchinfo.summary來打印模型的詳細結構和參數統計
try:
torchinfo.summary(
model,
device=DEVICE,
input_size=(1, 3, args.imgsz, args.imgsz),
row_settings=["var_names"]
)
# 如果此過程出現異常,代碼會打印模型的完整結構,并計算模型的總參數數和可訓練參數數
except:
print(model)
# 計算模型的所有參數總數
total_params = sum(p.numel() for p in model.parameters())
print(f"{total_params:,} total parameters.")
# 只計算那些需要在訓練過程中更新的參數(即requires_grad屬性為True的參數)
total_trainable_params = sum(
p.numel() for p in model.parameters() if p.requires_grad)
print(f"{total_trainable_params:,} training parameters.")
# 生成一個隨機分布的顏色數組,每個元素的值在0到255之間,這是標準的8位RGB色彩空間中的每個通道的取值范圍
COLORS = np.random.uniform(0, 255, size=(len(CLASSES), 3))
# 獲取視頻的路徑
VIDEO_PATH = args.input
# 如果用戶沒有通過命令行參數--input指定視頻路徑,則將VIDEO_PATH設置為data_path
if VIDEO_PATH == None:
VIDEO_PATH = data_path
# cap: 一個cv2.VideoCapture對象,用于讀取和處理視頻文件
# frame_width: 視頻的幀寬度(寬度像素數)
# frame_height: 視頻的幀高度(高度像素數)
# fps: 視頻的幀率(每秒幀數)
cap, frame_width, frame_height, fps = read_return_video_data(VIDEO_PATH)
# 生成輸出文件的名稱
# [-1]:選取列表中的最后一個元素,即文件名(包括擴展名)
# .split('.')[0]:再次分割文件名,這次是基于點號(.)來分隔,然后選取第一個元素,即文件的基本名稱,不包括擴展名
save_name = VIDEO_PATH.split(os.path.sep)[-1].split('.')[0]
# 將處理后的幀寫入輸出視頻文件
# 輸出文件路徑:f"{OUT_DIR}/{save_name}.mp4"
# 編碼器(codec):cv2.VideoWriter_fourcc(*'mp4v')
# 幀率(fps)
# 視頻尺寸:(frame_width, frame_height)
out = cv2.VideoWriter(f"{OUT_DIR}/{save_name}.mp4",
cv2.VideoWriter_fourcc(*'mp4v'), fps,
(frame_width, frame_height))
# 檢查args.imgsz是否已設置(即用戶是否通過命令行參數指定了圖像大小)
# 如果args.imgsz有值,說明用戶想要將輸入圖像(或視頻幀)縮放到指定的大小,那么RESIZE_TO將被設置為這個值
if args.imgsz != None:
RESIZE_TO = args.imgsz
# 如果args.imgsz沒有設置或者為None,則默認使用視頻幀的原始寬度frame_width作為縮放尺寸
else:
RESIZE_TO = frame_width
# 記錄總的幀數
frame_count = 0
# 計算最終的幀率
total_fps = 0
# 檢查視頻是否已經結束
while(cap.isOpened()):
# 讀取下一幀,并返回一個布爾值ret表示是否成功讀取
ret, frame = cap.read()
if ret:
# 復制原始幀以保留未處理的版本
orig_frame = frame.copy()
# 使用resize函數將幀調整到指定的大小(如果args.imgsz已設置,否則保持原大小)
frame = resize(frame, RESIZE_TO, square=True)
image = frame.copy()
# 將BGR圖像轉換為RGB
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# 將圖像歸一化到0-1范圍
image = image / 255.0
# 預處理
image = infer_transforms(image)
# 將圖像轉換為PyTorch張量,設置數據類型為torch.float32
image = torch.tensor(image, dtype=torch.float32)
# 調整張量維度,使通道維度成為第一個維度,以便于模型輸入(模型通常期望輸入張量的形狀為(batch_size, channels, height, width))
image = torch.permute(image, (2, 0, 1))
# 在張量前面添加一個維度以表示批次大小(batch_size=1)
image = image.unsqueeze(0)
# 計算模型前向傳播的時間(start_time和forward_end_time)以衡量處理單幀的速度
start_time = time.time()
with torch.no_grad():
outputs = model(image.to(args.device))
forward_end_time = time.time()
forward_pass_time = forward_end_time - start_time
# 計算當前幀的處理速度
fps = 1 / (forward_pass_time)
# Add `fps` to `total_fps`.
total_fps += fps
# Increment frame count.
frame_count += 1
# 如果啟用了注意力可視化(args.vis_atten),則將注意力圖保存為圖像文件
if args.vis_atten:
visualize_attention(
model,
image,
args.threshold,
orig_frame,
f"{OUT_DIR}/frame_{str(frame_count)}.png",
DEVICE
)
# 如果模型檢測到了物體(outputs['pred_boxes'][0]非空)
if len(outputs['pred_boxes'][0]) != 0:
# 轉換預測結果
draw_boxes, pred_classes, scores = convert_detections(
outputs,
args.threshold,
CLASSES,
orig_frame,
args
)
# 使用tracker更新跟蹤狀態,并將結果轉換回檢測框(convert_pre_track和convert_post_track)
if args.track:
tracker_inputs = convert_pre_track(
draw_boxes, pred_classes, scores
)
# Update tracker with detections.
tracks = tracker.update_tracks(
tracker_inputs, frame=frame
)
draw_boxes, pred_classes, scores = convert_post_track(tracks)
# 將預測結果應用到原始幀上(inference_annotations),包括繪制邊界框、類別標簽和置信度
orig_frame = inference_annotations(
draw_boxes,
pred_classes,
scores,
CLASSES,
COLORS,
orig_frame,
args
)
# 在幀上添加實時FPS信息
orig_frame = annotate_fps(orig_frame, fps)
# 將處理后的幀寫入輸出視頻文件
out.write(orig_frame)
if args.show:
cv2.imshow('Prediction', orig_frame)
# Press `q` to exit
if cv2.waitKey(1) & 0xFF == ord('q'):
break
else:
break
if args.show:
# Release VideoCapture().
cap.release()
# Close all frames and video windows.
cv2.destroyAllWindows()
# Calculate and print the average FPS.
avg_fps = total_fps / frame_count
print(f"Average FPS: {avg_fps:.3f}")
if __name__ == '__main__':
args = parse_opt()
main(args)
視頻1推理結果如下。盡管模型在大部分情況下表現良好,但是誤將corals識別為fish了。通過提高閾值,可以減少假陽性,即模型錯誤識別為fish的corals。
視頻2推理結果如下。考慮到模型在未知環境中表現出的性能,這些結果是相當不錯的。誤將stingrays識別為fish類的情況可能是由于它們在形狀和外觀上與某些魚類相似,這導致模型在分類時出現混淆。不過,總體來說,模型的檢測效果還是令人滿意的。
(2) 圖片推理
有了最佳訓練權重,現在可以進行推理測試了。
python tools/inference_image_detect.py --weights runs/training/detr_resnet101_dc5_60e/best_model.pth --input "../input/Aquarium Combined.v2-raw-1024.voc/test"
其中:
- --weights:表示用于推理的權重文件路徑。這里即指訓練60個epoch后得到的最佳模型權重的路徑。
- --input:推理測試圖像所在目錄。
「inference_image_detect.py」
import torch
import cv2
import numpy as np
import argparse
import yaml
import glob
import os
import time
import torchinfo
from vision_transformers.detection.detr.model import DETRModel
from utils.detection.detr.general import (
set_infer_dir,
load_weights
)
from utils.detection.detr.transforms import infer_transforms, resize
from utils.detection.detr.annotations import (
convert_detections,
inference_annotations,
)
from utils.detection.detr.viz_attention import visualize_attention
np.random.seed(2023)
def parse_opt():
parser = argparse.ArgumentParser()
parser.add_argument(
'-w',
'--weights',
)
parser.add_argument(
'-i', '--input',
help='folder path to input input image (one image or a folder path)',
)
parser.add_argument(
'--data',
default=None,
help='(optional) path to the data config file'
)
parser.add_argument(
'--model',
default='detr_resnet50',
help='name of the model'
)
parser.add_argument(
'--device',
default=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'),
help='computation/training device, default is GPU if GPU present'
)
parser.add_argument(
'--imgsz',
'--img-size',
default=640,
dest='imgsz',
type=int,
help='resize image to, by default use the original frame/image size'
)
parser.add_argument(
'-t',
'--threshold',
type=float,
default=0.5,
help='confidence threshold for visualization'
)
parser.add_argument(
'--name',
default=None,
type=str,
help='training result dir name in outputs/training/, (default res_#)'
)
parser.add_argument(
'--hide-labels',
dest='hide_labels',
action='store_true',
help='do not show labels during on top of bounding boxes'
)
parser.add_argument(
'--show',
dest='show',
action='store_true',
help='visualize output only if this argument is passed'
)
parser.add_argument(
'--track',
action='store_true'
)
parser.add_argument(
'--classes',
nargs='+',
type=int,
default=None,
help='filter classes by visualization, --classes 1 2 3'
)
parser.add_argument(
'--viz-atten',
dest='vis_atten',
action='store_true',
help='visualize attention map of detected boxes'
)
args = parser.parse_args()
return args
def collect_all_images(dir_test):
"""
Function to return a list of image paths.
:param dir_test: Directory containing images or single image path.
Returns:
test_images: List containing all image paths.
"""
test_images = []
if os.path.isdir(dir_test):
image_file_types = ['*.jpg', '*.jpeg', '*.png', '*.ppm']
for file_type in image_file_types:
test_images.extend(glob.glob(f"{dir_test}/{file_type}"))
else:
test_images.append(dir_test)
return test_images
def main(args):
NUM_CLASSES = None
CLASSES = None
data_configs = None
if args.data is not None:
with open(args.data) as file:
data_configs = yaml.safe_load(file)
NUM_CLASSES = data_configs['NC']
CLASSES = data_configs['CLASSES']
DEVICE = args.device
OUT_DIR = set_infer_dir(args.name)
model, CLASSES, data_path = load_weights(
args, DEVICE, DETRModel, data_configs, NUM_CLASSES, CLASSES
)
_ = model.to(DEVICE).eval()
try:
torchinfo.summary(
model,
device=DEVICE,
input_size=(1, 3, args.imgsz, args.imgsz),
row_settings=["var_names"]
)
except:
print(model)
# Total parameters and trainable parameters.
total_params = sum(p.numel() for p in model.parameters())
print(f"{total_params:,} total parameters.")
total_trainable_params = sum(
p.numel() for p in model.parameters() if p.requires_grad)
print(f"{total_trainable_params:,} training parameters.")
# Colors for visualization.
COLORS = np.random.uniform(0, 255, size=(len(CLASSES), 3))
DIR_TEST = args.input
if DIR_TEST == None:
DIR_TEST = data_path
test_images = collect_all_images(DIR_TEST)
print(f"Test instances: {len(test_images)}")
# To count the total number of frames iterated through.
frame_count = 0
# To keep adding the frames' FPS.
total_fps = 0
for image_num in range(len(test_images)):
image_name = test_images[image_num].split(os.path.sep)[-1].split('.')[0]
orig_image = cv2.imread(test_images[image_num])
frame_height, frame_width, _ = orig_image.shape
if args.imgsz != None:
RESIZE_TO = args.imgsz
else:
RESIZE_TO = frame_width
image_resized = resize(orig_image, RESIZE_TO, square=True)
image = cv2.cvtColor(image_resized, cv2.COLOR_BGR2RGB)
image = image / 255.0
image = infer_transforms(image)
input_tensor = torch.tensor(image, dtype=torch.float32)
input_tensor = torch.permute(input_tensor, (2, 0, 1))
input_tensor = input_tensor.unsqueeze(0)
h, w, _ = orig_image.shape
start_time = time.time()
with torch.no_grad():
outputs = model(input_tensor.to(DEVICE))
end_time = time.time()
# Get the current fps.
fps = 1 / (end_time - start_time)
# Add `fps` to `total_fps`.
total_fps += fps
# Increment frame count.
frame_count += 1
if args.vis_atten:
visualize_attention(
model,
input_tensor,
args.threshold,
orig_image,
f"{OUT_DIR}/{image_name}.png",
DEVICE
)
if len(outputs['pred_boxes'][0]) != 0:
draw_boxes, pred_classes, scores = convert_detections(
outputs,
args.threshold,
CLASSES,
orig_image,
args
)
orig_image = inference_annotations(
draw_boxes,
pred_classes,
scores,
CLASSES,
COLORS,
orig_image,
args
)
if args.show:
cv2.imshow('Prediction', orig_image)
cv2.waitKey(1)
cv2.imwrite(f"{OUT_DIR}/{image_name}.jpg", orig_image)
print(f"Image {image_num+1} done...")
print('-'*50)
print('TEST PREDICTIONS COMPLETE')
if args.show:
cv2.destroyAllWindows()
# Calculate and print the average FPS.
avg_fps = total_fps / frame_count
print(f"Average FPS: {avg_fps:.3f}")
if __name__ == '__main__':
args = parse_opt()
main(args)
默認情況下,腳本使用0.5的得分閾值,我們也可以使用--threshold標志來修改這個值。
python tools/inference_image_detect.py \
--weights /path/to/best/weights.pth \
--input /path/to/test/images/directory \
--threshold 0.5
運行這個命令后,腳本會加載模型權重,處理測試圖像,并將結果保存在指定的輸出目錄中,查看生成的圖像或結果文件,以評估模型在實際測試集上的表現。
從目前的結果來看,模型在檢測sharks、fish和stingrays方面表現得較為高效,但對puffins的檢測效果不佳。這很可能是因為訓練數據集中這些類別的實例數量較少,導致模型在學習這些特定類別特征時不夠充分。