PyTorch中的數據集Torchvision和Torchtext
對于PyTorch加載和處理不同類型數據,官方提供了torchvision和torchtext。
之前使用 torchDataLoader類直接加載圖像并將其轉換為張量。現在結合torchvision和torchtext介紹torch中的內置數據集
Torchvision 中的數據集
MNIST
MNIST是一個由標準化和中心裁剪的手寫圖像組成的數據集。它有超過 60,000 張訓練圖像和 10,000 張測試圖像。這是用于學習和實驗目的最常用的數據集之一。要加載和使用數據集,使用以下語法導入:torchvision.datasets.MNIST()。
Fashion MNIST
Fashion MNIST數據集類似于MNIST,但該數據集包含T恤、褲子、包包等服裝項目,而不是手寫數字,訓練和測試樣本數分別為60,000和10,000。要加載和使用數據集,使用以下語法導入:torchvision.datasets.FashionMNIST()
CIFAR
CIFAR數據集有兩個版本,CIFAR10和CIFAR100。CIFAR10 由 10 個不同標簽的圖像組成,而 CIFAR100 有 100 個不同的類。這些包括常見的圖像,如卡車、青蛙、船、汽車、鹿等。
- torchvision.datasets.CIFAR10()
- torchvision.datasets.CIFAR100()
COCO
COCO數據集包含超過 100,000 個日常對象,如人、瓶子、文具、書籍等。這個圖像數據集廣泛用于對象檢測和圖像字幕應用。下面是可以加載 COCO 的位置:torchvision.datasets.CocoCaptions()
EMNIST
EMNIST數據集是 MNIST 數據集的高級版本。它由包括數字和字母的圖像組成。如果您正在處理基于從圖像中識別文本的問題,EMNIST是一個不錯的選擇。下面是可以加載 EMNIST的位置::torchvision.datasets.EMNIST()
IMAGE-NET
ImageNet 是用于訓練高端神經網絡的旗艦數據集之一。它由分布在 10,000 個類別中的超過 120 萬張圖像組成。通常,這個數據集加載在高端硬件系統上,因為單獨的 CPU 無法處理這么大的數據集。下面是加載 ImageNet 數據集的類:torchvision.datasets.ImageNet()
Torchtext 中的數據集
IMDB
IMDB是一個用于情感分類的數據集,其中包含一組 25,000 條高度極端的電影評論用于訓練,另外 25,000 條用于測試。使用以下類加載這些數據torchtext:torchtext.datasets.IMDB()
WikiText2
WikiText2語言建模數據集是一個超過 1 億個標記的集合。它是從維基百科中提取的,并保留了標點符號和實際的字母大小寫。它廣泛用于涉及長期依賴的應用程序。可以從torchtext以下位置加載此數據:torchtext.datasets.WikiText2()
除了上述兩個流行的數據集,torchtext庫中還有更多可用的數據集,例如 SST、TREC、SNLI、MultiNLI、WikiText-2、WikiText103、PennTreebank、Multi30k 等。
深入查看 MNIST 數據集
MNIST 是最受歡迎的數據集之一。現在我們將看到 PyTorch 如何從 pytorch/vision 存儲庫加載 MNIST 數據集。讓我們首先下載數據集并將其加載到名為 的變量中data_train
- from torchvision.datasets import MNIST
- # Download MNIST
- data_train = MNIST('~/mnist_data', train=True, download=True)
- import matplotlib.pyplot as plt
- random_image = data_train[0][0]
- random_image_label = data_train[0][1]
- # Print the Image using Matplotlib
- plt.imshow(random_image)
- print("The label of the image is:", random_image_label)
DataLoader加載MNIST
下面我們使用DataLoader該類加載數據集,如下所示。
- import torch
- from torchvision import transforms
- data_train = torch.utils.data.DataLoader(
- MNIST(
- '~/mnist_data', train=True, download=True,
- transform = transforms.Compose([
- transforms.ToTensor()
- ])),
- batch_size=64,
- shuffle=True
- )
- for batch_idx, samples in enumerate(data_train):
- print(batch_idx, samples)
CUDA加載
我們可以啟用 GPU 來更快地訓練我們的模型。現在讓我們使用CUDA加載數據時可以使用的(GPU 支持 PyTorch)的配置。
- device = "cuda" if torch.cuda.is_available() else "cpu"
- kwargs = {'num_workers': 1, 'pin_memory': True} if device=='cuda' else {}
- train_loader = torch.utils.data.DataLoader(
- torchvision.datasets.MNIST('/files/', train=True, download=True),
- batch_size=batch_size_train, **kwargs)
- test_loader = torch.utils.data.DataLoader(
- torchvision.datasets.MNIST('files/', train=False, download=True),
- batch_size=batch_size, **kwargs)
ImageFolder
ImageFolder是一個通用數據加載器類torchvision,可幫助加載自己的圖像數據集。處理一個分類問題并構建一個神經網絡來識別給定的圖像是apple還是orange。要在 PyTorch 中執行此操作,第一步是在默認文件夾結構中排列圖像,如下所示:
- root
- ├── orange
- │ ├── orange_image1.png
- │ └── orange_image1.png
- ├── apple
- │ └── apple_image1.png
- │ └── apple_image2.png
- │ └── apple_image3.png
可以使用ImageLoader該類加載所有這些圖像。
- torchvision.datasets.ImageFolder(root, transform)
transforms
PyTorch 轉換定義了簡單的圖像轉換技術,可將整個數據集轉換為獨特的格式。
如果是一個包含不同分辨率的不同汽車圖片的數據集,在訓練時,我們訓練數據集中的所有圖像都應該具有相同的分辨率大小。如果我們手動將所有圖像轉換為所需的輸入大小,則很耗時,因此我們可以使用transforms;使用幾行 PyTorch 代碼,我們數據集中的所有圖像都可以轉換為所需的輸入大小和分辨率。
現在讓我們加載 CIFAR10torchvision.datasets并應用以下轉換:
- 將所有圖像調整為 32×32
- 對圖像應用中心裁剪變換
- 將裁剪后的圖像轉換為張量
- 標準化圖像
- import torch
- import torchvision
- import torchvision.transforms as transforms
- import matplotlib.pyplot as plt
- import numpy as np
- transform = transforms.Compose([
- # resize 32×32
- transforms.Resize(32),
- # center-crop裁剪變換
- transforms.CenterCrop(32),
- # to-tensor
- transforms.ToTensor(),
- # normalize 標準化
- transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
- ])
- trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
- download=True, transform=transform)
- trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
- shuffle=False)
在 PyTorch 中創建自定義數據集
下面將創建一個由數字和文本組成的簡單自定義數據集。需要封裝Dataset 類中的__getitem__()和__len__()方法。
- __getitem__()方法通過索引返回數據集中的選定樣本。
- __len__()方法返回數據集的總大小。
下面是曾經封裝FruitImagesDataset數據集的代碼,基本是比較好的 PyTorch 中創建自定義數據集的模板。
- import os
- import numpy as np
- import cv2
- import torch
- import matplotlib.patches as patches
- import albumentations as A
- from albumentations.pytorch.transforms import ToTensorV2
- from matplotlib import pyplot as plt
- from torch.utils.data import Dataset
- from xml.etree import ElementTree as et
- from torchvision import transforms as torchtrans
- class FruitImagesDataset(torch.utils.data.Dataset):
- def __init__(self, files_dir, width, height, transforms=None):
- self.transforms = transforms
- self.files_dir = files_dir
- self.height = height
- self.width = width
- self.imgs = [image for image in sorted(os.listdir(files_dir))
- if image[-4:] == '.jpg']
- self.classes = ['_','apple', 'banana', 'orange']
- def __getitem__(self, idx):
- img_name = self.imgs[idx]
- image_path = os.path.join(self.files_dir, img_name)
- # reading the images and converting them to correct size and color
- img = cv2.imread(image_path)
- img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32)
- img_res = cv2.resize(img_rgb, (self.width, self.height), cv2.INTER_AREA)
- # diving by 255
- img_res /= 255.0
- # annotation file
- annot_filename = img_name[:-4] + '.xml'
- annot_file_path = os.path.join(self.files_dir, annot_filename)
- boxes = []
- labels = []
- tree = et.parse(annot_file_path)
- root = tree.getroot()
- # cv2 image gives size as height x width
- wt = img.shape[1]
- ht = img.shape[0]
- # box coordinates for xml files are extracted and corrected for image size given
- for member in root.findall('object'):
- labels.append(self.classes.index(member.find('name').text))
- # bounding box
- xmin = int(member.find('bndbox').find('xmin').text)
- xmax = int(member.find('bndbox').find('xmax').text)
- ymin = int(member.find('bndbox').find('ymin').text)
- ymax = int(member.find('bndbox').find('ymax').text)
- xmin_corr = (xmin / wt) * self.width
- xmax_corr = (xmax / wt) * self.width
- ymin_corr = (ymin / ht) * self.height
- ymax_corr = (ymax / ht) * self.height
- boxes.append([xmin_corr, ymin_corr, xmax_corr, ymax_corr])
- # convert boxes into a torch.Tensor
- boxes = torch.as_tensor(boxes, dtype=torch.float32)
- # getting the areas of the boxes
- area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
- # suppose all instances are not crowd
- iscrowd = torch.zeros((boxes.shape[0],), dtype=torch.int64)
- labels = torch.as_tensor(labels, dtype=torch.int64)
- target = {}
- target["boxes"] = boxes
- target["labels"] = labels
- target["area"] = area
- target["iscrowd"] = iscrowd
- # image_id
- image_id = torch.tensor([idx])
- target["image_id"] = image_id
- if self.transforms:
- sample = self.transforms(image=img_res,
- bboxes=target['boxes'],
- labels=labels)
- img_res = sample['image']
- target['boxes'] = torch.Tensor(sample['bboxes'])
- return img_res, target
- def __len__(self):
- return len(self.imgs)
- def get_transform(train):
- if train:
- return A.Compose([
- A.HorizontalFlip(0.5),
- ToTensorV2(p=1.0)
- ], bbox_params={'format': 'pascal_voc', 'label_fields': ['labels']})
- else:
- return A.Compose([
- ToTensorV2(p=1.0)
- ], bbox_params={'format': 'pascal_voc', 'label_fields': ['labels']})
- files_dir = '../input/fruit-images-for-object-detection/train_zip/train'
- test_dir = '../input/fruit-images-for-object-detection/test_zip/test'
- dataset = FruitImagesDataset(train_dir, 480, 480)