news 2026/4/3 1:33:12

DAY46训练和测试的规范写法

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
DAY46训练和测试的规范写法

目录

1. 训练和测试的规范写法:函数封装

2. 展平操作 (Flatten):除 Batch Size 外全部展平

3. Dropout 操作:训练“随机”,测试“全开”


1. 训练和测试的规范写法:函数封装

为了保持代码整洁、可复用,并将“逻辑”与“数据”解耦,标准做法是将训练和测试过程分别封装成独立的函数。

核心思想:无论输入是灰度图(MNIST)还是彩色图(CIFAR-10),训练和测试的函数代码逻辑是完全一样的。区别仅在于传入的dataloader(数据不同)和model(输入层维度不同)。

  • 训练函数train_loop

    • 职责:负责前向传播、计算损失、反向传播更新梯度。

    • 关键代码结构

      def train(dataloader, model, loss_fn, optimizer): model.train() # 1. 开启训练模式 (启用 Dropout/BatchNorm) for batch, (X, y) in enumerate(dataloader): X, y = X.to(device), y.to(device) # 2. 搬运数据 pred = model(X) # 3. 前向传播 loss = loss_fn(pred, y) # 4. 计算损失 optimizer.zero_grad() # 5. 梯度清零 loss.backward() # 6. 反向传播 optimizer.step() # 7. 更新参数
  • 测试函数test_loop

    • 职责:负责评估模型性能(Loss 和 Accuracy),不更新参数

    • 关键代码结构

      def test(dataloader, model, loss_fn): model.eval() # 1. 开启评估模式 (关闭 Dropout/锁定 BatchNorm) with torch.no_grad(): # 2. 停止梯度计算 (节省显存) for X, y in dataloader: X, y = X.to(device), y.to(device) pred = model(X) # ... 统计准确率和 Loss

2. 展平操作 (Flatten):除 Batch Size 外全部展平

全连接层(Linear Layer)无法接收 3D 或 4D 的图像张量,只能接收 1D 的特征向量。

  • 操作逻辑:保留第一个维度(Batch Size),将其余所有维度(通道 $C$、高 $H$、宽 $W$)“压扁”成一长条。

  • 维度变化

    • 输入形状:[Batch_Size, Channel, Height, Width]

    • 输出形状:[Batch_Size, Channel * Height * Width]

  • 代码实现:

    通常在模型的 __init__ 中定义,在 forward 中调用。

    self.flatten = nn.Flatten() # 默认 start_dim=1,即从第2个维度开始展平 # 或者在 forward 中直接写: x = x.view(x.size(0), -1)

3. Dropout 操作:训练“随机”,测试“全开”

Dropout 是防止过拟合的重要手段,其行为受到模型模式(Mode)的严格控制。

  • 训练阶段 (model.train())

    • 行为:按照设定的概率 p(如 0.2),随机丢弃部分神经元(输出置 0)。

    • 目的:增加学习难度,迫使神经元独立提取特征,不依赖特定路径,增强鲁棒性。

    • 注意:剩余的活跃神经元数值会被放大(),以保持总能量守恒。

  • 测试阶段 (model.eval())

    • 行为:Dropout完全关闭。所有神经元都参与计算,权重全开。

    • 目的:利用训练好的完整模型能力,输出最稳定、最准确的预测结果。


一句话总结三者关系:

我们在构建模型时加入 Flatten 以适配全连接层,加入 Dropout 以防止过拟合;在代码编写时,通过封装 train (开启 Dropout) 和 test (关闭 Dropout) 函数来规范化整个流程。

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/3/31 0:08:39

用AI实现高效网络诊断:QUICKPING自动化工具开发指南

快速体验 打开 InsCode(快马)平台 https://www.inscode.net输入框内输入如下内容: 开发一个基于AI的网络诊断工具QUICKPING,功能包括:1. 输入IP或域名自动进行ping测试 2. 使用机器学习分析历史延迟数据预测网络状况 3. 可视化展示网络质量…

作者头像 李华
网站建设 2026/3/27 4:12:59

从2小时到10分钟:DRAW.IO高效绘图技巧大全

快速体验 打开 InsCode(快马)平台 https://www.inscode.net输入框内输入如下内容: 开发一个DRAW.IO效率工具包,包含:1. 快捷键提示插件,实时显示可用快捷键;2. 批量操作工具,支持同时修改多个图形属性&am…

作者头像 李华
网站建设 2026/4/1 21:54:14

COMFYUI安装图解指南:零基础也能轻松搞定

快速体验 打开 InsCode(快马)平台 https://www.inscode.net输入框内输入如下内容: 制作一个交互式Jupyter Notebook教程,包含:1. 分步可执行的代码单元格;2. 每个步骤的示意图和说明;3. 常见错误解决方案查询功能&am…

作者头像 李华
网站建设 2026/4/1 23:49:51

OPENARK:AI如何革新传统软件开发流程

快速体验 打开 InsCode(快马)平台 https://www.inscode.net输入框内输入如下内容: 使用OPENARK平台创建一个智能代码生成器,能够根据用户输入的自然语言描述自动生成Python代码。要求支持常见功能如数据处理、API调用和简单算法实现,并提供…

作者头像 李华
网站建设 2026/4/1 7:28:33

FastAPI实战:构建高性能股票数据API服务

快速体验 打开 InsCode(快马)平台 https://www.inscode.net输入框内输入如下内容: 开发一个股票数据查询API服务,包含以下功能:1) 实时股票价格查询接口 2) 历史K线数据获取 3) 股票搜索功能 4) 使用Redis缓存热点数据。要求使用FastAPI的异…

作者头像 李华
网站建设 2026/3/14 17:40:49

一文说清树莓派5安装ROS2的核心要点

树莓派5装ROS2,避坑指南:从零开始打造机器人开发平台你是不是也正打算在树莓派5上跑ROS2?想做个小车导航、视觉识别或者多机通信项目,却发现环境配到一半卡住了?别急。我最近刚把一台全新的树莓派5从“裸板”折腾成能跑…

作者头像 李华