可视化训练数据
在上一篇文章中,我们已经了解了训练数据的输入,是 28x28 的灰度图片,输出是标签值。在这一篇文章中,我们把输入的灰色图进行可视化,进一步感受理解。
如代码清单 1 所示,我们显示训练数据中的第一张图片。原来的图片形状是 1x28x28,我们可以直接使用 squeeze() 把 1 维度“去掉”。因为是灰度图,所以 imshow 的 cmap 参数设置为 gray。
- image, label = train_data[0]
- print(f"图片的形状为: {image.shape}") # 打印图片的形状信息,便于调试
- plt.imshow(image.squeeze(), cmap="gray")
- plt.title(f"类别: {class_name[label]}") # 图片标题显示对应的类别名称
- plt.axis("off") # 关闭坐标轴显示
- plt.show()
图 1 显示的就是训练数据中的第一张图片,虽然像素很低,但结合标签值,还是可以分辨出这是一只短靴。

为了进一步了解训练集的内容,如代码清单 2 所示,我们可以随机显示训练集中的多个内容。
- # 设置随机数种子,保证结果可重复
- torch.manual_seed(42)
- # 定义用于显示多张图片的网格尺寸
- fig = plt.figure(figsize=(9, 9)) # 设置画布大小
- rows, cols = 4, 4 # 定义网格的行和列数
- num_images = rows * cols # 总共显示的图片数量
- # 从训练数据集中随机选择图片并显示在网格中
- for i in range(1, num_images + 1):
- # 随机生成图片的索引
- random_idx = torch.randint(0, len(train_data), size=[1]).item()
- img, label = train_data[random_idx]
- # 将图片添加到对应的子图位置
- fig.add_subplot(rows, cols, i)
- plt.imshow(img.squeeze(), cmap="gray")
- plt.title(f"类别: {class_name[label]}") # 显示类别名称
- plt.axis("off") # 关闭坐标轴显示
- plt.show()
torch.randint 函数第一次接触,它用于生成随机整数。参数 low 指定生成整数的下界;参数 high 指定生成整数的上界;参数 size 指定生成整数张量的形状,类型是元组。此处只需要生成一个整数,所以 size 是 [1],并使用 item() 将其转换为标量。
图 2 显示了训练集中的多张图片。
