可视化训练数据

在上一篇文章中,我们已经了解了训练数据的输入,是 28x28 的灰度图片,输出是标签值。在这一篇文章中,我们把输入的灰色图进行可视化,进一步感受理解。

如代码清单 1 所示,我们显示训练数据中的第一张图片。原来的图片形状是 1x28x28,我们可以直接使用 squeeze() 把 1 维度“去掉”。因为是灰度图,所以 imshowcmap 参数设置为 gray。

代码清单 1 显示单张图片
  1. image, label = train_data[0]
  2. print(f"图片的形状为: {image.shape}"# 打印图片的形状信息,便于调试
  3. plt.imshow(image.squeeze(), cmap="gray")
  4. plt.title(f"类别: {class_name[label]}"# 图片标题显示对应的类别名称
  5. plt.axis("off"# 关闭坐标轴显示
  6. plt.show()

图 1 显示的就是训练数据中的第一张图片,虽然像素很低,但结合标签值,还是可以分辨出这是一只短靴。

图1 单张图片

为了进一步了解训练集的内容,如代码清单 2 所示,我们可以随机显示训练集中的多个内容。

代码清单 2 显示单张图片
  1. # 设置随机数种子,保证结果可重复
  2. torch.manual_seed(42)
  3.  
  4. # 定义用于显示多张图片的网格尺寸
  5. fig = plt.figure(figsize=(9, 9))  # 设置画布大小
  6. rows, cols = 4, 4  # 定义网格的行和列数
  7. num_images = rows * cols  # 总共显示的图片数量
  8.  
  9. # 从训练数据集中随机选择图片并显示在网格中
  10. for i in range(1, num_images + 1):
  11.     # 随机生成图片的索引
  12.     random_idx = torch.randint(0, len(train_data), size=[1]).item()
  13.     img, label = train_data[random_idx]
  14.     # 将图片添加到对应的子图位置
  15.     fig.add_subplot(rows, cols, i)
  16.     plt.imshow(img.squeeze(), cmap="gray")
  17.     plt.title(f"类别: {class_name[label]}"# 显示类别名称
  18.     plt.axis("off"# 关闭坐标轴显示
  19.  
  20. plt.show()

torch.randint 函数第一次接触,它用于生成随机整数。参数 low 指定生成整数的下界;参数 high 指定生成整数的上界;参数 size 指定生成整数张量的形状,类型是元组。此处只需要生成一个整数,所以 size 是 [1],并使用 item() 将其转换为标量。

图 2 显示了训练集中的多张图片。

图2 多张图片