使用 DataLoader 加载数据

在之前的文章中,我们已经了解了 FashionMNIST 数据集的内容。这篇文章介绍如何使用 DataLoader 加载数据。使用 DataLoader 加载的原因是,有一个批次处理的概念。

为什么要批次处理,我还不清楚,我现在是这么说服自己的:

批次应该能并行化,加速整体的训练时间;

一个批次的内容能继续利用,比如“均值化”之类的,减少波动。

以上只是猜测。具体怎么使用,是否和不采用批次的训练流程一致,还需要后面学习。

尽管概念还不太清楚,我们还是先学习涉及的方法。DataLoader 函数用于将数据集分为小批次。参数 dataset 指定需要批次化的数据集;参数 batch_size 指定每个批次包含的样本数;每个批次的内容可以不按照数据集本身的顺序,指定参数 shuffle=True 可以打乱顺序。

如代码清单 1 所示,我们将训练集和测试集进行了批次化,每个批次 32 个样本。

代码清单 1 DataLoader
  1. from torch.utils.data import DataLoader
  2.  
  3. BATCH_SIZE = 32
  4.  
  5. # 创建训练和测试的 DataLoader
  6. train_dataloader = DataLoader(dataset=train_data,
  7.                               batch_size=BATCH_SIZE,
  8.                               shuffle=True)
  9.  
  10. test_dataloader = DataLoader(dataset=test_data,
  11.                               batch_size=BATCH_SIZE,
  12.                               shuffle=False)
  13.  
  14. # 打印 DataLoader 对象及其长度
  15. print("训练数据加载器:", train_dataloader)
  16. print("测试数据加载器:", test_dataloader)
  17.  
  18. print(f"训练数据加载器中的批次数量: {len(train_dataloader)}")
  19. print(f"测试数据加载器中的批次数量: {len(test_dataloader)}")

我们继续了解一个批次的数据结构。如代码清单 2 所示,DataLoader 对象可以转化成一个迭代器,我们可以获取第一个批次的内容。

一个批次的内容包括图片的批次数据和标签的批次数据。从打印的信息中可以看到,图片批次的形状是 torch.Size([32, 1, 28, 28]),标签批次的形状是 torch.Size([32])。

再后续的操作是可视化,在之前的文章中已经了解过:我们在一个批次中随机选择一个样本(图片和标签),然后再可视化出来。

代码清单 2 获取一个批次
  1. # 获取一个训练批次的数据
  2. train_features_batch, train_labels_batch = next(iter(train_dataloader))
  3. print(f"训练特征批次的形状: {train_features_batch.shape}")
  4. print(f"训练标签批次的形状: {train_labels_batch.shape}")
  5.  
  6. # 设置随机种子以确保可复现性
  7. torch.manual_seed(42)
  8.  
  9. # 从批次中随机选择一个索引
  10. random_idx = torch.randint(0, len(train_features_batch), size=[1]).item()
  11. img, label = train_features_batch[random_idx], train_labels_batch[random_idx]
  12.  
  13. # 打印选中样本的详细信息
  14. print(f"选中的随机索引: {random_idx}")
  15. print(f"选中图片的形状: {img.shape}")
  16. print(f"选中图片的标签: {label} ({class_name[label]})")
  17.  
  18. # 可视化选中的图片
  19. plt.imshow(img.squeeze(), cmap="gray")
  20. plt.title(f"标签: {class_name[label]}")
  21. plt.axis("off")
  22. plt.show()