news 2026/4/3 8:47:50

MNIST 入门实战:从数据流到模型训练与评估(含完整代码与流程图)

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
MNIST 入门实战:从数据流到模型训练与评估(含完整代码与流程图)

📺B站:博主个人介绍

📘博主书籍-京东购买链接*:Yocto项目实战教程

📘加博主微信,进技术交流群jerrydev


MNIST 入门实战:从数据流到模型训练与评估(含完整代码与流程图)

目标:用 MNIST 把深度学习最核心的一条主线走通——数据 → 模型 → 损失 → 反向传播 → 更新参数 → 评估 → 保存/加载 → 推理。读完你会对“训练到底在做什么”“评估到底评估了什么”“.pth 到底是什么”形成清晰、稳定的概念框架,并具备继续突破到更复杂任务(CIFAR、检测、分割、YOLO、Transformer)的基础。


1. 你已经做到了什么?为什么这条路径很重要

你已经跑通了一个完整的 CNN 分类项目:

  • 数据:MNIST
  • 模型:两层卷积 + 两层全连接(带 Dropout)
  • 训练:CrossEntropyLoss + SGD
  • 评估:test loss + test acc
  • 保存:生成mnist_cnn.pth

并且得到了非常合理的结果:

  • test acc ≈ 97%+
  • test loss 逐步下降

这说明:

  1. 数据流(DataLoader)没问题;
  2. 模型结构(forward)没问题;
  3. 训练闭环(loss → backward → step)成立;
  4. 评估逻辑(eval + no_grad + 统计)正确;
  5. 参数存档(.pth)可复用。

这条路径是所有 AI 任务的“地基”。以后你换数据集、换模型、换任务,本质上还是这条主线,只是每一段更复杂。


2. 全流程总览:一张逻辑图把它串起来

下面这张“主线流程图”非常建议你记住:

┌──────────────┐ │ 数据集 Dataset │ MNIST(train/test) └──────┬───────┘ │ transform(ToTensor/Normalize) ▼ ┌──────────────┐ │ DataLoader │ batch化、shuffle、并行加载 └──────┬───────┘ │ (data:[B,1,28,28], target:[B]) ▼ ┌──────────────┐ │ 模型 Net │ forward: [B,1,28,28] → logits[B,10] └──────┬───────┘ │ ▼ ┌──────────────┐n│ Loss 函数 │ CrossEntropyLoss(logits, target) └──────┬───────┘ │ ▼ ┌──────────────┐ │ backward() │ 自动求梯度: p.grad └──────┬───────┘ │ ▼ ┌──────────────┐ │ optimizer.step│ 更新参数: θ ← θ - lr * grad └──────┬───────┘ │ ▼ ┌──────────────┐ │ 评估 Evaluate │ net.eval + no_grad + 指标统计 └──────┬───────┘ │ ▼ ┌──────────────┐ │ 保存/加载 .pth │ state_dict() / load_state_dict() └──────────────┘

你可以把它理解为:

  • Net是“会计算的结构”;
  • Loss把“对不对”变成“可优化的数字”;
  • Backward给每个参数算出“该往哪改”;
  • Optimizer真正“改参数”;
  • Evaluate用测试集验证“改得值不值”。

3. 数据部分:Dataset / Transform / DataLoader 到底做了什么

3.1 MNIST 是什么数据?

MNIST 是手写数字识别数据集:

  • 图片:28×28 灰度图(单通道)
  • 标签:0~9 十个类别
  • 训练集:60000
  • 测试集:10000

在 PyTorch 里,一条样本通常表现为:

  • image:torch.Tensor,形状[1,28,28]
  • label:int,范围 0…9

3.2 为什么 DataLoader 出来是[B,1,28,28]

你训练时拿到的是一个 batch:

  • data.shape == [B, 1, 28, 28]
  • target.shape == [B]

四维的含义非常固定:

维度含义MNIST 示例
Bbatch_size32
C通道数1(灰度)
H高度28
W宽度28

只要你看到[B,C,H,W],你就知道它是“图片批次输入”。

3.3 transform:ToTensor + Normalize 为什么必做?

你的 transform 典型是:

transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,),(0.3081,))])
  • ToTensor():把图像从0~255的像素转换为0~1的 float 张量
  • Normalize(mean, std):做标准化:

[
x’ = \frac{x - mean}{std}
]

为什么要标准化?

  • 让输入分布更稳定;
  • 梯度更容易优化;
  • 训练更快、收敛更稳定。

你看到min/max出现负数,就是 Normalize 生效的直接证据。

3.4 你本地的 data/ 目录意味着什么?

你看到的:

data/MNIST/raw/ train-images-idx3-ubyte train-labels-idx1-ubyte t10k-images-idx3-ubyte t10k-labels-idx1-ubyte

说明:

  • 数据源来自torchvision.datasets.MNIST(root='data/', download=True, ...)
  • 评估时使用train=False对应t10k-*(测试集);
  • 训练时使用train=True对应train-*(训练集)。

4. 模型部分:Net 是“结构”,参数是“记忆”

4.1 你写的 Net 到底是什么?

在 PyTorch 里:

  • class Net(nn.Module)定义了网络结构(哪些层、如何连接);
  • forward(self, x)定义了前向计算(输入如何流过这些层);
  • 模型参数(权重、偏置)被nn.Conv2d / nn.Linear等层自动创建。

一句话:

模型 = 结构(Net) + 参数(weights/bias)

4.2 CNN 为什么有效?卷积、池化、全连接各自负责什么

可以把 CNN 分为两段:

  • 特征提取器:Conv/Pool
  • 分类器:Linear

直觉理解:

  • Conv:在局部区域找“笔画、边缘、拐角”
  • Pool:让特征更稳、更省计算
  • FC:把多个特征组合,输出 10 类得分

4.3 形状追踪:你必须会“追输出形状”

以你常用结构为例:

输入:[B,1,28,28]

  • conv1(1→8, k=3):[B,8,26,26]
  • maxpool(2):[B,8,13,13]
  • conv2(8→16, k=3):[B,16,11,11]
  • maxpool(2):[B,16,5,5]
  • flatten:[B,400](因为 1655=400)
  • fc1:[B,40]
  • fc2:[B,10](logits)

只要你追得出 400 的来源,你对 CNN 的理解就非常稳。

4.4 logits 是什么?为什么不是直接输出概率?

模型输出[B,10],这不是概率,而是 logits(打分)。

预测类别通常是:

pred=logits.argmax(dim=1)# [B]

训练时用CrossEntropyLoss(logits, target),它内部会做 softmax 相关处理,所以你不需要手写 softmax。


5. 训练部分:三件套把“计算结构”变成“会学习的系统”

5.1 训练三件套:loss / backward / optimizer

训练的核心闭环:

  1. forward:模型输出 logits
  2. loss:把“错得程度”变成一个数字
  3. backward:求每个参数的梯度(该往哪改)
  4. step:更新参数(真的改)

最简骨架:

optimizer.zero_grad()logits=net(data)loss=criterion(logits,target)loss.backward()optimizer.step()

5.2 为什么一定要 zero_grad()

PyTorch 默认会把梯度累加;如果不清零,上一个 batch 的梯度会污染当前 batch。

你可以把它理解为:

  • 每个 batch 都要单独算一次“该怎么改”;
  • 不能把历史梯度一直加下去。

5.3 net.train() 与 net.eval() 的区别(非常关键)

这不是“可有可无”。尤其你用了 Dropout。

  • net.train():训练模式,Dropout 生效
  • net.eval():评估模式,Dropout 关闭(输出更稳定)

你出现 test acc 比 train acc 高一点,在带 Dropout 的训练里不罕见:训练时更难,评估时更容易。

5.4 为什么 loss 会从 2.3 降下来?

在 10 类分类里,模型完全随机时:

  • 每类概率大约 1/10
  • 交叉熵大约 (\ln(10) ≈ 2.3026)

所以你最开始看到 loss ≈ 2.3 非常合理。

训练后,模型对正确类别给更高分,loss 就下降。


6. 评估部分:你到底评估了什么?关键函数和关键变量

你目前做的基础评估非常正确:

  • 平均 loss(test avg loss)
  • 准确率(test accuracy)

评估的关键点:

6.1 评估数据从哪里来?

评估使用的是MNIST 测试集

test_ds=torchvision.datasets.MNIST('data/',train=False,...)

对应你目录里的:

  • t10k-images-idx3-ubyte
  • t10k-labels-idx1-ubyte

6.2 评估为什么要 no_grad()?

评估只需要 forward,不需要 backward。

torch.no_grad()的价值:

  • 更快
  • 更省显存
  • 更不容易内存爆

6.3 评估最核心的统计逻辑是什么?

  • logits:net(data)
  • loss:criterion(logits, target)
  • pred:logits.argmax(dim=1)
  • correct:(pred == target).sum()

最终:

  • avg_loss = total_loss / total_samples
  • acc = total_correct / total_samples

7. 训练日志该怎么读?你这份结果说明了什么

你的输出:

Epoch 1: train loss=0.6006, train acc=80.70% test loss=0.1068, test acc=96.75% Epoch 2: train loss=0.3271, train acc=89.95% test loss=0.0856, test acc=97.34% Epoch 3: train loss=0.2842, train acc=91.26% test loss=0.0718, test acc=97.60%

可以简单做三点结论:

  1. loss 在下降:说明优化器在持续把模型往正确方向推;
  2. acc 在上升:说明预测正确率提升;
  3. test acc 达到 97%+:说明模型已经学到 MNIST 的主特征;

你的模型在 3 个 epoch 已经接近“够用”。如果继续训练可能还会涨,但收益会变小。


8.mnist_cnn.pth到底是什么?怎么理解它的“类型”

8.1.pth不是模型结构,而是模型参数

你保存的是:

torch.save(net.state_dict(),'mnist_cnn.pth')
  • state_dict()是一个“参数字典”:键是参数名,值是张量
  • .pth只是常用后缀名,表示 PyTorch 的保存文件

一句话:

mnist_cnn.pth保存的是你训练得到的“记忆”(权重/偏置)

8.2 为什么必须同时有 Net 才能用.pth

.pth本身不包含forward的代码逻辑。

可用模型 =Net()+load_state_dict(mnist_cnn.pth)

8.3 你可以用一行看它内部是什么

importtorch sd=torch.load('mnist_cnn.pth',map_location='cpu')print(type(sd))print(list(sd.keys())[:10])

你会看到:它是 dict,里面有conv1.weightfc2.bias等键。


9. 一份“初学者最推荐”的工程组织方式

你现在的结构已经很健康:

jerry_mnist/ create_model.py # 模型结构 get_data.py # 数据准备(可选) train_model.py # 训练 + 评估 + 保存 eval_basic.py # (建议新增)单独评估 mnist_cnn.pth # 训练结果参数 data/ # MNIST 数据缓存

推荐原则:

  • 模型结构独立文件(便于复用)
  • 训练脚本只管训练
  • 评估脚本只管评估
  • 推理脚本只管推理

这样你后面做 CIFAR10、做猫狗分类、做 YOLO,都可以复用思路。


10. 代码模板:最简训练脚本(带注释)

下面是一份“看一眼就懂”的训练框架(与你现在的逻辑一致):

# train_model.py(结构示意)# 1) 准备 DataLoader(train_loader / test_loader)# 2) 创建模型 net = Net().to(device)# 3) 定义 loss + optimizer# 4) 循环 epoch:# 4.1 net.train()# 4.2 对 train_loader:zero_grad → forward → loss → backward → step# 4.3 net.eval() + no_grad()# 4.4 对 test_loader:forward → 统计 loss/acc# 5) 保存参数 torch.save(net.state_dict(), 'mnist_cnn.pth')

你要掌握的是“骨架”,以后换任何任务都能套进去。


11. 评估进阶:从两个指标走向“知道错在哪”

你现在评估了 loss/acc,这已经是基础合格。

接下来想提高理解和实战能力,建议加三种评估:

  1. 混淆矩阵:看哪些数字互相混淆
  2. 每类准确率:哪个类最弱
  3. 错误样本可视化:错的样本长什么样

这三件事会让你从“知道结果不错”升级到“知道怎么继续提升”。


12. 从 MNIST 走向更大突破:下一步练什么最有效

当你把 MNIST 这一套跑稳后,推荐你按这个顺序升级:

12.1 升级数据:CIFAR-10

  • 32×32 彩色图(3 通道)
  • 更贴近真实视觉任务
  • 你会遇到:数据增强、过拟合、模型更深

12.2 升级模型:更规范的 CNN(BatchNorm、更多层)

  • 加 BatchNorm 稳定训练
  • 更深的网络、残差结构(ResNet)

12.3 升级任务:目标检测(YOLO)

  • 输入输出不再是 10 类得分
  • 变为:框位置 + 类别 + 置信度

但注意:无论怎么升级,“主线流程图”依旧成立。


13. 关键知识点清单(建议你定期扫一遍)

数据

  • Dataset / DataLoader
  • [B,C,H,W]含义
  • transform:ToTensor / Normalize

模型

  • nn.Module/forward
  • Conv / Pool / Flatten / Linear
  • logits 与 argmax

训练

  • criterion(CrossEntropyLoss)
  • optimizer(SGD/Adam)
  • zero_grad → forward → loss → backward → step
  • train()vseval()

评估

  • no_grad()
  • avg loss / accuracy
  • 错误分析(混淆矩阵、错例)

保存

  • state_dict()/load_state_dict()
  • .pth是参数存档,不是结构

14. 你目前这份项目的一句话“专业总结”

你已经完成了一个可复用的 PyTorch 图像分类最小工程:

  • 使用torchvision.datasets.MNIST构建训练/测试数据流;
  • 使用自定义 CNN(两层卷积 + 两层全连接 + Dropout)进行分类;
  • 使用CrossEntropyLoss + SGD完成训练闭环并达到97%+测试准确率;
  • 使用state_dict()将训练得到的参数保存为mnist_cnn.pth,可在任意环境中通过load_state_dict()复现推理效果。

15. 附:你可以直接复制的“推理脚本”骨架(可选)

如果你想快速验证.pth的价值:

# predict_one_batch.pyimporttorchimporttorchvisionfromtorch.utils.dataimportDataLoaderfromtorchvisionimporttransformsfromcreate_modelimportNet device='cuda'iftorch.cuda.is_available()else'cpu'# 1) load modelnet=Net().to(device)net.load_state_dict(torch.load('mnist_cnn.pth',map_location=device))net.eval()# 2) test loadertransform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,),(0.3081,))])test_ds=torchvision.datasets.MNIST('data/',train=False,download=True,transform=transform)test_loader=DataLoader(test_ds,batch_size=32,shuffle=True)data,target=next(iter(test_loader))data=data.to(device)withtorch.no_grad():logits=net(data)pred=logits.argmax(dim=1).cpu()print('pred[:10] =',pred[:10].tolist())print('target[:10]=',target[:10].tolist())

结尾:你下一步应该做什么(最实用)

如果你想把基础打得更牢,建议你立刻做两个小任务:

  1. 写一个eval_basic.py:只输出 test loss / test acc(你已经理解清楚)
  2. 写一个eval_confusion.py:输出混淆矩阵 + 错例图(知道错在哪)

做完这两步,你会对“评估”不再停留在“跑出一个数字”,而是能解释为什么这样、如何继续提升


如果你希望我继续按“最小步骤”推进:下一步我会在不增加太多代码复杂度的前提下,带你实现混淆矩阵 + 错例可视化(这一步对形成实战直觉特别有帮助)。

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/2 20:29:16

别只用ChatGPT!2026年这5个开源AI工具才是程序员的真正利器

文章目录前言一、 Ollama v0.14.3 本地大模型天花板,新增图像生成直接封神实操代码(新手直接复制)核心亮点二、 Stable Diffusion XL 1.0 开源生图卷王,直逼Midjourney实操代码(本地部署极简版)核心亮点三、…

作者头像 李华
网站建设 2026/3/16 3:01:34

【开题答辩全过程】以 基于Web技术的珠宝首饰网上定制系统的设计与实现为例,包含答辩的问题和答案

个人简介 一名14年经验的资深毕设内行人,语言擅长Java、php、微信小程序、Python、Golang、安卓Android等 开发项目包括大数据、深度学习、网站、小程序、安卓、算法。平常会做一些项目定制化开发、代码讲解、答辩教学、文档编写、也懂一些降重方面的技巧。 感谢大家…

作者头像 李华
网站建设 2026/3/29 0:13:05

【开题答辩全过程】以 基于Java Web的电子商务网站的用户行为分析与个性化推荐系统为例,包含答辩的问题和答案

个人简介 一名14年经验的资深毕设内行人,语言擅长Java、php、微信小程序、Python、Golang、安卓Android等 开发项目包括大数据、深度学习、网站、小程序、安卓、算法。平常会做一些项目定制化开发、代码讲解、答辩教学、文档编写、也懂一些降重方面的技巧。 感谢大家…

作者头像 李华
网站建设 2026/4/1 1:45:11

基于51单片机的社区火灾报警辅助系统设计

基于51单片机的社区火灾报警辅助系统设计 一、设计背景与意义 社区火灾防控是民生安全的重要环节,传统火灾报警设备多依赖单一烟雾传感器,存在误报率高、报警方式单一、无定位功能等问题,且高端智能报警系统成本高、部署难度大,难…

作者头像 李华
网站建设 2026/3/29 9:37:28

方盾说说煤矿工口罩使用科学建议

煤矿作业各环节会产生大量煤尘、岩尘,其中粒径小于5微米的呼吸性粉尘危害最甚,长期吸入易引发不可逆职业病。口罩作为个人呼吸防护的最后一道屏障,选对、戴好、用对至关重要。结合井下高粉尘、高潮湿、高强度的作业特点,现提供无品…

作者头像 李华