创建一个基础线性模型

在这篇文章之前,我们已经准备好数据。大致的流程为:使用 torchvision 获取了 FashionMNIST 的数据集,不需要我们手动处理;后续再使用 DataLoader 将数据批次化。

在这篇文章中,我们预构建一个简单的线性训练模型,先跑通数据流程。

1. 扁平化层

在编写训练模型之前,先介绍一下扁平化层 nn.Flatten(),因为我们第一次接触,后续需要用到。

目前我们的图片输入是多维的,nn.Flatten 可以将多维张量压平成一维张量。我们直接看到代码清单 1 的例子进行理解,初始张量形状是 torch.Size([1, 28, 28]),扁平化后是 torch.Size([1, 784])。

第一个维度没有被展平有点混淆。nn.Flatten 应该默认把第一个维度当成 batch_size,会扁平化所有 batch_size 以外的维度。

代码清单 1 nn.Flatten
  1. # 初始化扁平化层
  2. flatten_model = nn.Flatten()
  3.  
  4. x = train_features_batch[0]
  5. print(f"输入张量的原始形状: {x.shape}")
  6.  
  7. # 应用扁平化层
  8. output = flatten_model(x)
  9. print(f"扁平化后的张量形状: {output.shape}")

2. 训练模型

如代码清单 2 所示,我们先定义训练模型:一个扁平化层,两个线性层。

代码清单 2 模型定义
  1. from torch import nn
  2.  
  3. class FashionMNISTModelV0(nn.Module):
  4.     def __init__(self, input_shape: int, hidden_units: int, output_shape: int):
  5.         super().__init__()
  6.         self.layer_stack = nn.Sequential(
  7.             nn.Flatten(),  # 将输入张量展平成一维
  8.             nn.Linear(in_features=input_shape, out_features=hidden_units),  # 全连接层
  9.             nn.Linear(in_features=hidden_units, out_features=output_shape)  # 输出层
  10.         )
  11.  
  12.     def forward(self, x):
  13.         return self.layer_stack(x)

接着如代码清单 3 所示,我们实例化模型。我们的原始图片形状是 [1, 28, 28],因为后续需要展平使用,所以输入大小设置为 28 * 28;隐藏单元先设置为 10 个;输出大小需要设置为标签个数。

代码清单 3 模型实例化
  1. # 设置随机种子以确保结果的可复现性
  2. torch.manual_seed(42)
  3.  
  4. # 实例化模型
  5. model_0 = FashionMNISTModelV0(
  6.     input_shape=28 * 28,
  7.     hidden_units=10,
  8.     output_shape=len(class_name)
  9. ).to("cpu")
  10.  
  11. # 打印模型结构
  12. print(model_0)

最后,如代码清单 4 所示,我们构造一下测试输入,看一下网络能否跑通。

代码清单 4 测试
  1. dummy_x = torch.rand([1, 1, 28, 28])
  2. # 将输入张量传入模型进行前向传播
  3. output = model_0(dummy_x)
  4.  
  5. # 打印模型输出的张量形状和具体值
  6. print("\n模型输出:")
  7. print(f"输出张量的形状:{output.shape}")
  8. print(f"输出张量的值:{output}")

从打印信息中可以看到,是我们想要的结果。

  • 模型输出:
  • 输出张量的形状:torch.Size([1, 10])
  • 输出张量的值:tensor([[-0.0315,  0.3171,  0.0531, -0.2525,  0.5959,  0.2112,  0.3233,  0.2694,
  •          -0.1004,  0.0157]], grad_fn=<AddmmBackward0>)