在实现 softmax 回归前,先做一下数据集的工作。关于 FashionMNIST
以下是本节用到的模块
import torchimport torchvisionimport torchvision.transforms as transimport matplotlib.pyplot as pltimport timeimport d2lzh_pytorch
torchvision包,是服务于 PyTorch 深度学习框架的,主要用来构建计算机视觉模型。torchvision主要由以下几部分构成:
torchvision.datasets: 一些加载数据的函数及常用的数据集接口;torchvision.models: 包含常用的模型结构(含预训练模型),例如AlexNet、VGG、ResNet等;torchvision.transforms: 常用的图片变换,例如裁剪、旋转等;torchvision.utils: 其他的一些有用的方法。3.5.1 获取数据集
我们用
torchvision.datasets来获取数据集,包括两部分:训练集(training set):用于训练模型。
- 测试集(testing set):用于测试模型的学习效果。
torchvision.datasets 已经提供了获取 MNIST 、FashionMNIST 等常用各种数据集的接口。root 用于指定下载路径, trai=True对应训练集,False 则对应测试集,download=True 表示从互联网上下载,已经下载好的不会再重复下载。
另外我们还指定了参数transform = transforms.ToTensor()使所有数据转换为Tensor,如果不进行转换则返回的是PIL图片。transforms.ToTensor()将尺寸为 (H x W x C) 且数据位于[0, 255]的PIL图片或者数据类型为np.uint8的NumPy数组转换为尺寸为(C x H x W)且数据类型为torch.float32且位于[0.0, 1.0]的Tensor。
关于教程作者的提醒,没有遇到,先mark一下。
# 获取数据集DATA_SETS_PATH = "~/My-Project/Python学习/PyTorch学习/知乎马卡斯扬-动手学深度学习PyTorch版/Data-Sets"# 训练集training_set = torchvision.datasets.FashionMNIST(root=DATA_SETS_PATH, train=True, download=True, transform=trans.ToTensor())# 测试集testing_set = torchvision.datasets.FashionMNIST(root=DATA_SETS_PATH, train=False, download=True, transform=trans.ToTensor())print(type(training_set), type(testing_set))print(len(training_set), len(testing_set))
下载链接是国外的,所以下载速度实在堪忧。下面提供一下这两个数据集。
Data-Sets.zip
看一下数据集。
# 看一下数据集print(type(training_set), type(testing_set))print(len(training_set), len(testing_set))feature, label = training_set[0]label = torch.tensor(label)print(feature.size(), label)
运行结果
<class 'torchvision.datasets.mnist.FashionMNIST'> <class 'torchvision.datasets.mnist.FashionMNIST'>60000 10000torch.Size([1, 28, 28]) tensor(9)<class 'torch.Tensor'> <class 'torch.Tensor'>
特征为 28*28 的 8 位灰度图, 像素值范围已映射到 [0, 1]。
FashionMNIST 一共包含 10 个类别,分别为t-shirt(T恤)、trouser(裤子)、pullover(套衫)、dress(连衣裙)、coat(外套)、sandal(凉鞋)、shirt(衬衫)、sneaker(运动鞋)、bag(包)和ankle boot(短靴)。
接下来我们对数据集做一个简单的可视化。
# 已在 d2lzh_pytorch 中的 get_fashion_mnist_labels 实现def labels_number2txt(src_labels):"""将数值标签转换为文本标签, 便于阅读Args:src_labels: 数据集的数值标签Returns:文本标签Raises:无"""TEXT_LABELS = ["t-shirt", "trouser", "pullover", "dress", "coat","sandal", "shirt", "sneaker", "bag", "ankle boot"]return [TEXT_LABELS[int(i)] for i in src_labels]# 已在 d2lzh_pytorch 中的 show_fashion_mnist 实现def show_fashion_mnist(features, labels):"""画出原始图像和对应标签Args:features: 数据集的特征labels: 数据集的标签Returns:无Raises:无"""# d2lzh_pytorch.use_svg_display# 按样本数量建立子图figs = plt.subplots(1, len(features), figsize=(12, 12))[1]for f, img, label in zip(figs, features, labels):# 显示原始图像f.imshow(img.view(28, 28).numpy())# 显示标签f.set_title(label)f.axes.get_xaxis().set_visible(False)f.axes.get_yaxis().set_visible(False)plt.show()# 数据集的可视化x, y = [], []for i in range(10):x.append(training_set[i][0])y.append(training_set[i][1])show_fashion_mnist(x, labels_number2txt(y))
运行结果
3.5.2 读取小批次数据
我们上述获得的 training_set 和 testing_set 是 torch.utils.data.Dataset的子类,因此因此可以用 torch.utils.data.DataLoader() 来创建一个用于读取小批次数据的迭代器 DataLoader 实例。
其中 num_workers 用于指定 进程数量 来加速数据的读取。
# 读取小批次数据batch_size = 256# 创建读取小批次数据的迭代器training_set_iter = torch_data.DataLoader(training_set, batch_size=batch_size, shuffle=True, num_workers=10)testing_set_iter = torch_data.DataLoader(testing_set, batch_size=batch_size, shuffle=True, num_workers=10)
看一下我们读取整个训练集共 60000 个样本需要的时间。
# 完整地读取一次数据start = time.time()for x, y in training_set_iter:passprint("读取全部数据需要的时间: {0} s".format(time.time() - start))
运行结果
读取全部数据需要的时间: 0.9038164615631104 s
小结
FashionMNIST 和 MNIST 这两个数据集是完全兼容的,用的时候主要不要搞混。相比已经被用烂了的MNIST,FashionMNIST 更加合理。
