📺B站:博主个人介绍
📘博主书籍-京东购买链接*:Yocto项目实战教程
📘加博主微信,进技术交流群:jerrydev
MNIST 入门实战:从数据流到模型训练与评估(含完整代码与流程图)
目标:用 MNIST 把深度学习最核心的一条主线走通——数据 → 模型 → 损失 → 反向传播 → 更新参数 → 评估 → 保存/加载 → 推理。读完你会对“训练到底在做什么”“评估到底评估了什么”“.pth 到底是什么”形成清晰、稳定的概念框架,并具备继续突破到更复杂任务(CIFAR、检测、分割、YOLO、Transformer)的基础。
1. 你已经做到了什么?为什么这条路径很重要
你已经跑通了一个完整的 CNN 分类项目:
- 数据:MNIST
- 模型:两层卷积 + 两层全连接(带 Dropout)
- 训练:CrossEntropyLoss + SGD
- 评估:test loss + test acc
- 保存:生成
mnist_cnn.pth
并且得到了非常合理的结果:
- test acc ≈ 97%+
- test loss 逐步下降
这说明:
- 数据流(DataLoader)没问题;
- 模型结构(forward)没问题;
- 训练闭环(loss → backward → step)成立;
- 评估逻辑(eval + no_grad + 统计)正确;
- 参数存档(.pth)可复用。
这条路径是所有 AI 任务的“地基”。以后你换数据集、换模型、换任务,本质上还是这条主线,只是每一段更复杂。
2. 全流程总览:一张逻辑图把它串起来
下面这张“主线流程图”非常建议你记住:
┌──────────────┐ │ 数据集 Dataset │ MNIST(train/test) └──────┬───────┘ │ transform(ToTensor/Normalize) ▼ ┌──────────────┐ │ DataLoader │ batch化、shuffle、并行加载 └──────┬───────┘ │ (data:[B,1,28,28], target:[B]) ▼ ┌──────────────┐ │ 模型 Net │ forward: [B,1,28,28] → logits[B,10] └──────┬───────┘ │ ▼ ┌──────────────┐n│ Loss 函数 │ CrossEntropyLoss(logits, target) └──────┬───────┘ │ ▼ ┌──────────────┐ │ backward() │ 自动求梯度: p.grad └──────┬───────┘ │ ▼ ┌──────────────┐ │ optimizer.step│ 更新参数: θ ← θ - lr * grad └──────┬───────┘ │ ▼ ┌──────────────┐ │ 评估 Evaluate │ net.eval + no_grad + 指标统计 └──────┬───────┘ │ ▼ ┌──────────────┐ │ 保存/加载 .pth │ state_dict() / load_state_dict() └──────────────┘你可以把它理解为:
- Net是“会计算的结构”;
- Loss把“对不对”变成“可优化的数字”;
- Backward给每个参数算出“该往哪改”;
- Optimizer真正“改参数”;
- Evaluate用测试集验证“改得值不值”。
3. 数据部分:Dataset / Transform / DataLoader 到底做了什么
3.1 MNIST 是什么数据?
MNIST 是手写数字识别数据集:
- 图片:28×28 灰度图(单通道)
- 标签:0~9 十个类别
- 训练集:60000
- 测试集:10000
在 PyTorch 里,一条样本通常表现为:
- image:
torch.Tensor,形状[1,28,28] - label:
int,范围 0…9
3.2 为什么 DataLoader 出来是[B,1,28,28]?
你训练时拿到的是一个 batch:
data.shape == [B, 1, 28, 28]target.shape == [B]
四维的含义非常固定:
| 维度 | 含义 | MNIST 示例 |
|---|---|---|
| B | batch_size | 32 |
| C | 通道数 | 1(灰度) |
| H | 高度 | 28 |
| W | 宽度 | 28 |
只要你看到[B,C,H,W],你就知道它是“图片批次输入”。
3.3 transform:ToTensor + Normalize 为什么必做?
你的 transform 典型是:
transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,),(0.3081,))])ToTensor():把图像从0~255的像素转换为0~1的 float 张量Normalize(mean, std):做标准化:
[
x’ = \frac{x - mean}{std}
]
为什么要标准化?
- 让输入分布更稳定;
- 梯度更容易优化;
- 训练更快、收敛更稳定。
你看到min/max出现负数,就是 Normalize 生效的直接证据。
3.4 你本地的 data/ 目录意味着什么?
你看到的:
data/MNIST/raw/ train-images-idx3-ubyte train-labels-idx1-ubyte t10k-images-idx3-ubyte t10k-labels-idx1-ubyte说明:
- 数据源来自
torchvision.datasets.MNIST(root='data/', download=True, ...); - 评估时使用
train=False对应t10k-*(测试集); - 训练时使用
train=True对应train-*(训练集)。
4. 模型部分:Net 是“结构”,参数是“记忆”
4.1 你写的 Net 到底是什么?
在 PyTorch 里:
class Net(nn.Module)定义了网络结构(哪些层、如何连接);forward(self, x)定义了前向计算(输入如何流过这些层);- 模型参数(权重、偏置)被
nn.Conv2d / nn.Linear等层自动创建。
一句话:
模型 = 结构(Net) + 参数(weights/bias)。
4.2 CNN 为什么有效?卷积、池化、全连接各自负责什么
可以把 CNN 分为两段:
- 特征提取器:Conv/Pool
- 分类器:Linear
直觉理解:
- Conv:在局部区域找“笔画、边缘、拐角”
- Pool:让特征更稳、更省计算
- FC:把多个特征组合,输出 10 类得分
4.3 形状追踪:你必须会“追输出形状”
以你常用结构为例:
输入:[B,1,28,28]
- conv1(1→8, k=3):
[B,8,26,26] - maxpool(2):
[B,8,13,13] - conv2(8→16, k=3):
[B,16,11,11] - maxpool(2):
[B,16,5,5] - flatten:
[B,400](因为 1655=400) - fc1:
[B,40] - fc2:
[B,10](logits)
只要你追得出 400 的来源,你对 CNN 的理解就非常稳。
4.4 logits 是什么?为什么不是直接输出概率?
模型输出[B,10],这不是概率,而是 logits(打分)。
预测类别通常是:
pred=logits.argmax(dim=1)# [B]训练时用CrossEntropyLoss(logits, target),它内部会做 softmax 相关处理,所以你不需要手写 softmax。
5. 训练部分:三件套把“计算结构”变成“会学习的系统”
5.1 训练三件套:loss / backward / optimizer
训练的核心闭环:
- forward:模型输出 logits
- loss:把“错得程度”变成一个数字
- backward:求每个参数的梯度(该往哪改)
- step:更新参数(真的改)
最简骨架:
optimizer.zero_grad()logits=net(data)loss=criterion(logits,target)loss.backward()optimizer.step()5.2 为什么一定要 zero_grad()
PyTorch 默认会把梯度累加;如果不清零,上一个 batch 的梯度会污染当前 batch。
你可以把它理解为:
- 每个 batch 都要单独算一次“该怎么改”;
- 不能把历史梯度一直加下去。
5.3 net.train() 与 net.eval() 的区别(非常关键)
这不是“可有可无”。尤其你用了 Dropout。
net.train():训练模式,Dropout 生效net.eval():评估模式,Dropout 关闭(输出更稳定)
你出现 test acc 比 train acc 高一点,在带 Dropout 的训练里不罕见:训练时更难,评估时更容易。
5.4 为什么 loss 会从 2.3 降下来?
在 10 类分类里,模型完全随机时:
- 每类概率大约 1/10
- 交叉熵大约 (\ln(10) ≈ 2.3026)
所以你最开始看到 loss ≈ 2.3 非常合理。
训练后,模型对正确类别给更高分,loss 就下降。
6. 评估部分:你到底评估了什么?关键函数和关键变量
你目前做的基础评估非常正确:
- 平均 loss(test avg loss)
- 准确率(test accuracy)
评估的关键点:
6.1 评估数据从哪里来?
评估使用的是MNIST 测试集:
test_ds=torchvision.datasets.MNIST('data/',train=False,...)对应你目录里的:
t10k-images-idx3-ubytet10k-labels-idx1-ubyte
6.2 评估为什么要 no_grad()?
评估只需要 forward,不需要 backward。
torch.no_grad()的价值:
- 更快
- 更省显存
- 更不容易内存爆
6.3 评估最核心的统计逻辑是什么?
- logits:
net(data) - loss:
criterion(logits, target) - pred:
logits.argmax(dim=1) - correct:
(pred == target).sum()
最终:
- avg_loss = total_loss / total_samples
- acc = total_correct / total_samples
7. 训练日志该怎么读?你这份结果说明了什么
你的输出:
Epoch 1: train loss=0.6006, train acc=80.70% test loss=0.1068, test acc=96.75% Epoch 2: train loss=0.3271, train acc=89.95% test loss=0.0856, test acc=97.34% Epoch 3: train loss=0.2842, train acc=91.26% test loss=0.0718, test acc=97.60%可以简单做三点结论:
- loss 在下降:说明优化器在持续把模型往正确方向推;
- acc 在上升:说明预测正确率提升;
- test acc 达到 97%+:说明模型已经学到 MNIST 的主特征;
你的模型在 3 个 epoch 已经接近“够用”。如果继续训练可能还会涨,但收益会变小。
8.mnist_cnn.pth到底是什么?怎么理解它的“类型”
8.1.pth不是模型结构,而是模型参数
你保存的是:
torch.save(net.state_dict(),'mnist_cnn.pth')state_dict()是一个“参数字典”:键是参数名,值是张量.pth只是常用后缀名,表示 PyTorch 的保存文件
一句话:
mnist_cnn.pth保存的是你训练得到的“记忆”(权重/偏置)。
8.2 为什么必须同时有 Net 才能用.pth
.pth本身不包含forward的代码逻辑。
可用模型 =Net()+load_state_dict(mnist_cnn.pth)。
8.3 你可以用一行看它内部是什么
importtorch sd=torch.load('mnist_cnn.pth',map_location='cpu')print(type(sd))print(list(sd.keys())[:10])你会看到:它是 dict,里面有conv1.weight、fc2.bias等键。
9. 一份“初学者最推荐”的工程组织方式
你现在的结构已经很健康:
jerry_mnist/ create_model.py # 模型结构 get_data.py # 数据准备(可选) train_model.py # 训练 + 评估 + 保存 eval_basic.py # (建议新增)单独评估 mnist_cnn.pth # 训练结果参数 data/ # MNIST 数据缓存推荐原则:
- 模型结构独立文件(便于复用)
- 训练脚本只管训练
- 评估脚本只管评估
- 推理脚本只管推理
这样你后面做 CIFAR10、做猫狗分类、做 YOLO,都可以复用思路。
10. 代码模板:最简训练脚本(带注释)
下面是一份“看一眼就懂”的训练框架(与你现在的逻辑一致):
# train_model.py(结构示意)# 1) 准备 DataLoader(train_loader / test_loader)# 2) 创建模型 net = Net().to(device)# 3) 定义 loss + optimizer# 4) 循环 epoch:# 4.1 net.train()# 4.2 对 train_loader:zero_grad → forward → loss → backward → step# 4.3 net.eval() + no_grad()# 4.4 对 test_loader:forward → 统计 loss/acc# 5) 保存参数 torch.save(net.state_dict(), 'mnist_cnn.pth')你要掌握的是“骨架”,以后换任何任务都能套进去。
11. 评估进阶:从两个指标走向“知道错在哪”
你现在评估了 loss/acc,这已经是基础合格。
接下来想提高理解和实战能力,建议加三种评估:
- 混淆矩阵:看哪些数字互相混淆
- 每类准确率:哪个类最弱
- 错误样本可视化:错的样本长什么样
这三件事会让你从“知道结果不错”升级到“知道怎么继续提升”。
12. 从 MNIST 走向更大突破:下一步练什么最有效
当你把 MNIST 这一套跑稳后,推荐你按这个顺序升级:
12.1 升级数据:CIFAR-10
- 32×32 彩色图(3 通道)
- 更贴近真实视觉任务
- 你会遇到:数据增强、过拟合、模型更深
12.2 升级模型:更规范的 CNN(BatchNorm、更多层)
- 加 BatchNorm 稳定训练
- 更深的网络、残差结构(ResNet)
12.3 升级任务:目标检测(YOLO)
- 输入输出不再是 10 类得分
- 变为:框位置 + 类别 + 置信度
但注意:无论怎么升级,“主线流程图”依旧成立。
13. 关键知识点清单(建议你定期扫一遍)
数据
- Dataset / DataLoader
[B,C,H,W]含义- transform:ToTensor / Normalize
模型
nn.Module/forward- Conv / Pool / Flatten / Linear
- logits 与 argmax
训练
criterion(CrossEntropyLoss)optimizer(SGD/Adam)zero_grad → forward → loss → backward → steptrain()vseval()
评估
no_grad()- avg loss / accuracy
- 错误分析(混淆矩阵、错例)
保存
state_dict()/load_state_dict().pth是参数存档,不是结构
14. 你目前这份项目的一句话“专业总结”
你已经完成了一个可复用的 PyTorch 图像分类最小工程:
- 使用
torchvision.datasets.MNIST构建训练/测试数据流; - 使用自定义 CNN(两层卷积 + 两层全连接 + Dropout)进行分类;
- 使用
CrossEntropyLoss + SGD完成训练闭环并达到97%+测试准确率; - 使用
state_dict()将训练得到的参数保存为mnist_cnn.pth,可在任意环境中通过load_state_dict()复现推理效果。
15. 附:你可以直接复制的“推理脚本”骨架(可选)
如果你想快速验证.pth的价值:
# predict_one_batch.pyimporttorchimporttorchvisionfromtorch.utils.dataimportDataLoaderfromtorchvisionimporttransformsfromcreate_modelimportNet device='cuda'iftorch.cuda.is_available()else'cpu'# 1) load modelnet=Net().to(device)net.load_state_dict(torch.load('mnist_cnn.pth',map_location=device))net.eval()# 2) test loadertransform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,),(0.3081,))])test_ds=torchvision.datasets.MNIST('data/',train=False,download=True,transform=transform)test_loader=DataLoader(test_ds,batch_size=32,shuffle=True)data,target=next(iter(test_loader))data=data.to(device)withtorch.no_grad():logits=net(data)pred=logits.argmax(dim=1).cpu()print('pred[:10] =',pred[:10].tolist())print('target[:10]=',target[:10].tolist())结尾:你下一步应该做什么(最实用)
如果你想把基础打得更牢,建议你立刻做两个小任务:
- 写一个
eval_basic.py:只输出 test loss / test acc(你已经理解清楚) - 写一个
eval_confusion.py:输出混淆矩阵 + 错例图(知道错在哪)
做完这两步,你会对“评估”不再停留在“跑出一个数字”,而是能解释为什么这样、如何继续提升。
如果你希望我继续按“最小步骤”推进:下一步我会在不增加太多代码复杂度的前提下,带你实现混淆矩阵 + 错例可视化(这一步对形成实战直觉特别有帮助)。