加载数据集

这一章我们加载数据集。FashionMNIST 是一个简单而经典的基准数据集,适合用于入门机器学习和深度学习的分类任务。其中的图像是各种时尚服装图片。

我们可以自己下载数据、处理并加载数据,但 datasets.FashionMNIST 函数把这些工作都封装好了,其原型为:

  • torchvision.datasets.FashionMNIST(
  •     root: str,
  •     train: bool = True,
  •     transform: Optional[Callable] = None,
  •     target_transform: Optional[Callable] = None,
  •     download: bool = False,
  • )

其中,root 参数指定数据集的存储路径。

train 参数指定下载训练集(True)或测试集(False)。

transform 参数指定对数据样本进行变换操作。常用 torchvision.transforms.ToTensor() 将图片转换为 PyTorch 张量。

target_transform 参数指定对目标标签进行变换操作。

download 指定如果数据集文件未在指定路径中,是否自动下载。

在了解了 datasets.FashionMNIST 的使用之后,代码清单 1 是加载训练集和测试集的示例。

代码清单 1 加载数据集
  1. # 加载 FashionMNIST 数据集
  2. # 训练集
  3. train_data = datasets.FashionMNIST(
  4.     root="data"# 数据存储路径
  5.     train=True# 是否是训练集
  6.     download=True# 如果数据未下载,则下载数据
  7.     transform=torchvision.transforms.ToTensor(),  # 转换为张量格式
  8.     target_transform=None  # 不对目标标签进行额外转换
  9. )
  10.  
  11. # 测试集
  12. test_data = datasets.FashionMNIST(
  13.     root="data"# 数据存储路径
  14.     train=False# 是否是测试集
  15.     download=True# 如果数据未下载,则下载数据
  16.     transform=torchvision.transforms.ToTensor(),  # 转换为张量格式
  17.     target_transform=None  # 不对目标标签进行额外转换
  18. )

接着我们可以对数据集的信息进行了解。比如可以查看一下数据集大小。

  • # 打印训练集和测试集的大小
  • print(f"训练集大小: {len(train_data)} 张图片")
  • print(f"测试集大小: {len(test_data)} 张图片")

打印内容为:

  • 训练集大小: 60000 张图片
  • 测试集大小: 10000 张图片

查看数据集内容。我们查看训练集中的第一张图片的内容和对应的标签。

  • # 查看训练集中的第一张图片及其标签
  • image, label = train_data[0]
  • print(f"第一张图片的张量: {image}")
  • print(f"第一张图片的标签: {label}")

从打印内容中可以看到,图片已经被预处理成了张量。

  • 第一张图片的张量: tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
  •           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
  •           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
  •           0.0000, 0.0000, 0.0000, 0.0000],
  •           ……
  •          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
  •           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
  •           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
  •           0.0000, 0.0000, 0.0000, 0.0000]]])
  • 第一张图片的标签: 9

查看标签信息,即具体的标签值对应的实际含义。

  • # 打印类别名称列表
  • class_name = train_data.classes
  • print(f"数据集的类别名称: {class_name}")
  •  
  • # 打印类别到索引的映射
  • class_to_idx = train_data.class_to_idx
  • print(f"类别到索引的映射: {class_to_idx}")
  •  
  • # 打印训练集的所有标签
  • print(f"训练集的所有标签: {train_data.targets}")

打印内容为:

  • 数据集的类别名称: ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
  • 类别到索引的映射: {'T-shirt/top': 0, 'Trouser': 1, 'Pullover': 2, 'Dress': 3, 'Coat': 4, 'Sandal': 5, 'Shirt': 6, 'Sneaker': 7, 'Bag': 8, 'Ankle boot': 9}
  • 训练集的所有标签: tensor([9, 0, 0,  ..., 3, 0, 5])

继续查看数据集信息。

  • # 打印第一张图片的形状和对应的类别名称
  • print(f"第一张图片的形状: {image.shape}")
  • print(f"第一张图片的类别名称: {class_name[label]}")

从打印信息中可以看到,第一张训练图片的大小是 28x28,因为是灰度图像,所以只有一个颜色通道。对应的具体标签含义是短靴。

  • 第一张图片的形状: torch.Size([1, 28, 28])
  • 第一张图片的类别名称: Ankle boot