加载数据集
这一章我们加载数据集。FashionMNIST 是一个简单而经典的基准数据集,适合用于入门机器学习和深度学习的分类任务。其中的图像是各种时尚服装图片。
我们可以自己下载数据、处理并加载数据,但 datasets.FashionMNIST 函数把这些工作都封装好了,其原型为:
- torchvision.datasets.FashionMNIST(
- root: str,
- train: bool = True,
- transform: Optional[Callable] = None,
- target_transform: Optional[Callable] = None,
- download: bool = False,
- )
其中,root 参数指定数据集的存储路径。
train 参数指定下载训练集(True)或测试集(False)。
transform 参数指定对数据样本进行变换操作。常用 torchvision.transforms.ToTensor() 将图片转换为 PyTorch 张量。
target_transform 参数指定对目标标签进行变换操作。
download 指定如果数据集文件未在指定路径中,是否自动下载。
在了解了 datasets.FashionMNIST 的使用之后,代码清单 1 是加载训练集和测试集的示例。
- # 加载 FashionMNIST 数据集
- # 训练集
- train_data = datasets.FashionMNIST(
- root="data", # 数据存储路径
- train=True, # 是否是训练集
- download=True, # 如果数据未下载,则下载数据
- transform=torchvision.transforms.ToTensor(), # 转换为张量格式
- target_transform=None # 不对目标标签进行额外转换
- )
- # 测试集
- test_data = datasets.FashionMNIST(
- root="data", # 数据存储路径
- train=False, # 是否是测试集
- download=True, # 如果数据未下载,则下载数据
- transform=torchvision.transforms.ToTensor(), # 转换为张量格式
- target_transform=None # 不对目标标签进行额外转换
- )
接着我们可以对数据集的信息进行了解。比如可以查看一下数据集大小。
- # 打印训练集和测试集的大小
- print(f"训练集大小: {len(train_data)} 张图片")
- print(f"测试集大小: {len(test_data)} 张图片")
打印内容为:
- 训练集大小: 60000 张图片
- 测试集大小: 10000 张图片
查看数据集内容。我们查看训练集中的第一张图片的内容和对应的标签。
- # 查看训练集中的第一张图片及其标签
- image, label = train_data[0]
- print(f"第一张图片的张量: {image}")
- print(f"第一张图片的标签: {label}")
从打印内容中可以看到,图片已经被预处理成了张量。
- 第一张图片的张量: tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
- 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
- 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
- 0.0000, 0.0000, 0.0000, 0.0000],
- ……
- [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
- 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
- 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
- 0.0000, 0.0000, 0.0000, 0.0000]]])
- 第一张图片的标签: 9
查看标签信息,即具体的标签值对应的实际含义。
- # 打印类别名称列表
- class_name = train_data.classes
- print(f"数据集的类别名称: {class_name}")
- # 打印类别到索引的映射
- class_to_idx = train_data.class_to_idx
- print(f"类别到索引的映射: {class_to_idx}")
- # 打印训练集的所有标签
- print(f"训练集的所有标签: {train_data.targets}")
打印内容为:
- 数据集的类别名称: ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
- 类别到索引的映射: {'T-shirt/top': 0, 'Trouser': 1, 'Pullover': 2, 'Dress': 3, 'Coat': 4, 'Sandal': 5, 'Shirt': 6, 'Sneaker': 7, 'Bag': 8, 'Ankle boot': 9}
- 训练集的所有标签: tensor([9, 0, 0, ..., 3, 0, 5])
继续查看数据集信息。
- # 打印第一张图片的形状和对应的类别名称
- print(f"第一张图片的形状: {image.shape}")
- print(f"第一张图片的类别名称: {class_name[label]}")
从打印信息中可以看到,第一张训练图片的大小是 28x28,因为是灰度图像,所以只有一个颜色通道。对应的具体标签含义是短靴。
- 第一张图片的形状: torch.Size([1, 28, 28])
- 第一张图片的类别名称: Ankle boot