PyTorch數據集處理

2022-12-31 20:24:03 來源:51CTO博客

數據樣本處理的代碼可能會變得雜亂且難以維護,因此理想狀態下我們應該將模型訓練的代碼和數據集代碼分開封裝,以獲得更好的代碼可讀性和模塊化代碼。


(資料圖片)

PyTorch 提供了兩個基本方法 ??torch.utils.data.DataLoader??和??torch.utils.data.Dataset??可以讓你預加載數據集或者你的數據。

??Dataset??存儲樣本及其相關的標簽, ??DataLoader??封裝了關于 ??Dataset??的迭代器,讓我們可以方便地讀取樣本。

PyTorch庫中也提供了一些常用的數據集可以方便用戶做預加載可以通過??torch.utils.data.Dataset??調用,還提供了一些對應數據集的方法。它們可以用于模型的原型和基準測試。

詳細可以戳這里:

??Image Datasets??,??Text Datasets??,??Audio Datasets??。

加載數據集

接下來我們看一下怎么從TorchVision加載??Fashion-MNIST??數據集。

Fashion-MNIST是Zalando的一個數據集,包含6萬個訓練樣例和1萬個測試樣例。

每個樣例由兩部分組成,一個28×28灰度圖像和一個十分類標簽中的某一個標簽。

我們要加載 ??FashionMNIST Dataset??需要用到以下幾個參數:

??root?? 數據集的存儲地址??train?? 指定你要取訓練集還是測試集??download=True?? 如果你指定的 ??root??中沒有數據集,會自動從網上下載數據集??transform?? 、 ??target_transform?? 指定特征和標簽轉換

下邊這段代碼是取FashionMNIST的訓練集和測試集,root設置了一個data文件,運行下邊這段代碼以后你可以看到當前目錄下邊應該多了一個data文件夾,里邊就是FashionMNIST數據集文件了。

import torchfrom torch.utils.data import Datasetfrom torchvision import datasetsfrom torchvision.transforms import ToTensorimport matplotlib.pyplot as plttraining_data = datasets.FashionMNIST(    root="data",    train=True,    download=True,    transform=ToTensor())test_data = datasets.FashionMNIST(    root="data",    train=False,    download=True,    transform=ToTensor())復制代碼

迭代和可視化數據集

我們可以像列表索引一樣查看??Datasets??。可以使用??matplotlib??可視化我們的數據集。

其他代碼解析看注釋。

至于畫子圖有兩個方法,二者的區別僅在于一個面向方法,一個面向對象,別的完全一樣。

subplot
figure = plt.figure() cols, rows = 3, 3 for i in range(1, cols * rows + 1):     plt.subplot(rows, cols, i) plt.show()復制代碼
add_subplot
figure = plt.figure()cols, rows = 3, 3for i in range(1, cols * rows + 1):    figure.subplot(rows, cols, i)plt.show()復制代碼
labels_map = {    0: "T-Shirt",    1: "Trouser",    2: "Pullover",    3: "Dress",    4: "Coat",    5: "Sandal",    6: "Shirt",    7: "Sneaker",    8: "Bag",    9: "Ankle Boot",}figure = plt.figure(figsize=(8, 8))cols, rows = 3, 3for i in range(1, cols * rows + 1):    sample_idx = torch.randint(len(training_data), size=(1,)).item()   # 從數據集中隨機采樣    img, label = training_data[sample_idx]      # 取得數據集的圖和標簽    figure.add_subplot(rows, cols, i)           # 畫子圖,也可以plt.subplot(rows, cols, i)    plt.title(labels_map[label])                plt.axis("off")    plt.imshow(img.squeeze(), cmap="gray")      # 是黑白圖,這里做一個維度壓縮,把1通道的1壓縮掉plt.show()復制代碼

最后隨機采樣的結果大概是這樣的:


使用DataLoader

??Dataset??可以檢索我們數據集中一個樣本的特征和標簽。但是在訓練模型的時候,我們通常希望數據以小批量(minibatch)的方式作為輸入,在每個epoch中重新調整數據以防止過擬合,并且還能使用Python的??multiprocessing??加速數據檢索。

??DataLoader??是一個迭代器,將剛才提到的復雜方法抽象成簡單的API。

from torch.utils.data import DataLoadertrain_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)復制代碼

通過DataLoader迭代獲取數據

我們已經將數據集加載到??DataLoader??中,并可以根據需要迭代數據集。

下面的每次迭代返回一個批量數據的??train_features??和??train_labels??(分別包含??batch_size=64??個特征和標簽)。

因為我們指定了??shuffle=True??,在遍歷所有批量之后,數據會被打亂(要對數據加載順序進行更細粒度的控制,戳這里??pytorch.org/docs/stable…?? 。

# Display image and label.train_features, train_labels = next(iter(train_dataloader))print(f"Feature batch shape: {train_features.size()}")print(f"Labels batch shape: {train_labels.size()}")img = train_features[0].squeeze()label = train_labels[0]plt.imshow(img, cmap="gray")plt.show()print(f"Label: {label}")復制代碼

為你的數據創建自定義數據集

自定義Dataset類必須實現三個函數:??__init__??, ??__len__??和??__getitem__??。看看這個FashionMNIST圖像存儲在img_dir目錄中,它們的標簽單獨存儲在CSV文件annotations_file中。在下一節我們詳細分析一下每個函數中發生的事情。

import osimport pandas as pdfrom torchvision.io import read_imageclass CustomImageDataset(Dataset):    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):        self.img_labels = pd.read_csv(annotations_file)        self.img_dir = img_dir        self.transform = transform        self.target_transform = target_transform    def __len__(self):        return len(self.img_labels)    def __getitem__(self, idx):        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])        image = read_image(img_path)        label = self.img_labels.iloc[idx, 1]        if self.transform:            image = self.transform(image)        if self.target_transform:            label = self.target_transform(label)        return image, label復制代碼

init

??__init__??函數在實例化Dataset對象時運行一次,幫我們初始化一個目錄,其中包含圖像、注釋文件和兩個變換(下一節將詳細介紹)。

The labels.csv file looks like:

tshirt1.jpg, 0

tshirt2.jpg, 0

......

ankleboot999.jpg, 9

def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):    self.img_labels = pd.read_csv(annotations_file)    self.img_dir = img_dir    self.transform = transform    self.target_transform = target_transform復制代碼

len

??__len__??方法返回我們數據集中的樣本數量。

def __len__(self):    return len(self.img_labels)復制代碼

getitem

??__getitem__??函數當你給定一個索引??idx??的時候,用于加載并返回樣本。

基于索引,該函數去尋找圖像在磁盤上的位置,使用??read_image?? 將其轉換為一個張量,從??self??中的csv數據中檢索相應的標簽??img_labels??,調用它們上的變換函數(如果適用),并返回一個元組,元組中是圖像的張量和對應的標簽。

def __getitem__(self, idx):    img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])    image = read_image(img_path)    label = self.img_labels.iloc[idx, 1]    if self.transform:        image = self.transform(image)    if self.target_transform:        label = self.target_transform(label)    return image, label

標簽: 加載數據 隨機采樣

上一篇:第八章《Java高級語法》第7節:枚舉
下一篇:每日播報!一文了解 Go fmt 標準庫輸入函數的使用