蚂蚁-蜜蜂分类数据集下载:https://download.pytorch.org/tutorial/hymenoptera_data.zip
Dataset
Dataset 提供了一种获取数据及其 label 的方式,需要解决两个问题:
- 如何获取每一个数据及其 label?—— `__getitem__`
- 总共有多少条数据?—— `__len__`
Dataset 的子类必须重写 `__getitem__`;`__len__` 虽然是可选的,但 DataLoader 的默认采样器等需要通过它获取数据集大小,因此通常也应一并重写。
import os

from PIL import Image
from torch.utils.data import Dataset


class MyDataset(Dataset):
    """Image dataset for a single class folder, labelled by the folder name.

    Images are read from ``os.path.join(root_dir, label_dir)``; every sample's
    label is the string ``label_dir`` itself.

    Args:
        root_dir: Directory that contains the class sub-folders.
        label_dir: Name of the class sub-folder; also used as the label.
    """

    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(root_dir, label_dir)
        # os.listdir returns entries in arbitrary, platform-dependent order;
        # sorting makes the index -> image mapping reproducible.
        self.img_path = sorted(os.listdir(self.path))

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the ``idx``-th entry of the folder."""
        img_name = self.img_path[idx]
        img_item_path = os.path.join(self.path, img_name)
        # Image.open is lazy and keeps the file handle open until the pixel
        # data is loaded; copy() forces the load into a standalone image so
        # the `with` block can close the file immediately.
        with Image.open(img_item_path) as src:
            img = src.copy()
        label = self.label_dir
        return img, label

    def __len__(self):
        """Return the number of entries found in the class folder."""
        return len(self.img_path)


root_dir = "./data/hymenoptera_data/train"
ants_label_dir = "ants"
ants_dataset = MyDataset(root_dir, ants_label_dir)
bees_label_dir = "bees"
bees_dataset = MyDataset(root_dir, bees_label_dir)

# Merge the two per-class datasets into one training set; Dataset.__add__
# returns a ConcatDataset that indexes the ants samples first, then the bees.
train_dataset = ants_dataset + bees_dataset

len(ants_dataset), len(bees_dataset), len(train_dataset)  # (124, 121, 245)

img2, label2 = train_dataset[123]
img2.show()
label2  # 'ants'
DataLoader
DataLoader 负责把 Dataset 中的数据加载并整理成后续网络所需的数据形式(例如按 batch_size 分批、每个 epoch 是否打乱顺序等),用于控制如何从 Dataset 中取数据。
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# CIFAR-10 test split, converted to tensors by ToTensor().
# download=True fetches the archive into ./dataset when it is not already there
# (the original call raised an error on a fresh checkout).
test_data = torchvision.datasets.CIFAR10(
    root="./dataset",
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=True,
)

# 64 images per batch, reshuffled every epoch; drop_last=False keeps the final,
# smaller batch; num_workers=0 loads data in the main process.
test_loader = DataLoader(
    dataset=test_data,
    batch_size=64,
    shuffle=True,
    num_workers=0,
    drop_last=False,
)

# Log every batch as an image grid under one TensorBoard tag per epoch, so the
# two epochs can be compared to see that shuffle=True changes batch contents.
# The `with` block closes (and flushes) the writer even if the loop raises.
with SummaryWriter("dataloader_logs") as writer:
    for epoch in range(2):
        for step, (imgs, _targets) in enumerate(test_loader):
            writer.add_images(f"Epoch: {epoch}", imgs, step)

