news 2026/4/3 6:20:51

04_残差网络

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
04_残差网络

描述

残差网络是现代卷积神经网络的一种,有效的抑制了深层神经网络的梯度弥散和梯度爆炸现象,使得深度网络训练不那么困难。

下面以cifar-10-batches-py数据集,实现一个ResNet18的残差网络,通过继承nn.Module实现残差块(Residual Block),网络模型类。

定义Block

ResNetBlock派生至nn.Module,需要自己实现forward函数。

torch.nn.Module是nn中十分重要的类,包含网络各层的定义及forward方法,可以从这个类派生自己的模型类。

nn.Module重要的函数:

  • forward(self,*input):forward函数为前向传播函数,需要自己重写,它用来实现模型的功能,并实现各个层的连接关系;
  • __call__(self, *input, **kwargs): __call__()的作用是使class实例能够像函数一样被调用,以“对象名()”的形式使用;
  • __repr__(self):__repr__函数为Python的一个内置函数,它能把一个对象用字符串的形式表达出来;
  • __init__(self):构造函数,自定义模型的网络层对象一般在这个函数中定义。
classResNetBlock(nn.Module):def__init__(self,input_channels,num_channels,stride=1):''' 构造函数:定义网络层 '''super().__init__()self.conv1=nn.Conv2d(input_channels,num_channels,kernel_size=3,padding=1,stride=stride)self.btn1=nn.BatchNorm2d(num_channels)self.conv2=nn.Conv2d(num_channels,num_channels,kernel_size=3,padding=1,stride=1)self.btn2=nn.BatchNorm2d(num_channels)ifstride!=1:self.downsample=nn.Conv2d(input_channels,num_channels,kernel_size=1,stride=stride)else:self.downsample=lambdax:xdefforward(self,X):''' 实现反向传播 '''Y=self.btn1(self.conv1(X))Y=nn.functional.relu(Y)Y=self.btn2(self.conv2(Y))Y+=self.downsample(X)returnnn.functional.relu(Y)

定义模型

ResNet同样派生于nn.Module,与ResNetBlock类似,需要实现forward。

torch.nn.Sequential是PyTorch 中一个用于构建顺序神经网络模型的容器类,它将多个神经网络层或模块按顺序组合在一起,简化模型搭建过程。‌Sequential器会严格按照添加的顺序执行内部的子模块,前向传播时自动传递数据,适用于简单神经网络的构建。

classResNet(nn.Module):def__init__(self,layer_dism,num_class=10):''' 构造函数:定义预处理model;构建block层 '''super(ResNet,self).__init__()# 预处理self.stem=nn.Sequential(nn.Conv2d(3,64,3,1),# 3x30x30nn.BatchNorm2d(64),nn.ReLU(),nn.MaxPool2d(2,2)# 64x15x15)self.layer1=self.build_resblock(64,64,layer_dism[0])self.layer2=self.build_resblock(64,128,layer_dism[1],2)self.layer3=self.build_resblock(128,256,layer_dism[2],2)self.layer4=self.build_resblock(256,512,layer_dism[3],2)self.avgpool=nn.AvgPool2d(1,1)self.btn=nn.Flatten()self.fc=nn.Linear(512,num_class)defbuild_resblock(self,input_channels,num_channels,block,stride=1):res_block=nn.Sequential()res_block.append(ResNetBlock(input_channels,num_channels,stride))for_inrange(1,block):res_block.append(ResNetBlock(num_channels,num_channels,stride))returnres_blockdefforward(self,X):out=self.stem(X)out=self.layer1(out)out=self.layer2(out)out=self.layer3(out)out=self.layer4(out)out=self.avgpool(out)returnself.fc(self.btn(out))

模型训练

加载数据

使用torchvision.datasets加载本地数据,如果本地没有数据,可以设置download=True自动下载。

# 定义数据转换transform=transforms.Compose([transforms.ToTensor(),# 将PIL图像转换为Tensortransforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))# 归一化])# 加载CIFAR-10训练集trainset=torchvision.datasets.CIFAR10(root=r'D:\dwload',train=True,download=False,transform=transform)trainloader=th.utils.data.DataLoader(trainset,batch_size=16,shuffle=False,num_workers=2)# 加载CIFAR-10测试集testset=torchvision.datasets.CIFAR10(root=r'D:\dwload',train=False,download=False,transform=transform)testloader=th.utils.data.DataLoader(testset,batch_size=16,shuffle=False,num_workers=2)

模型初始化

模型初始化是确保网络能够有效学习的关键步骤,一个好的初始值,会使模型收敛速度提高,使模型准确率更精确。

torch.nn.init模块提供了一系列的权重初始化函数:

  • torch.nn.init.uniform_ :均匀分布
  • torch.nn.init.normal_ :正态分布
  • torch.nn.init.constant_:初始化为指定常数
  • torch.nn.init.kaiming_uniform_:凯明均匀分布
  • torch.nn.init.kaiming_normal_:凯明正态分布
  • torch.nn.init.xavier_uniform_:Xavier均匀分布
  • torch.nn.init.xavier_normal_:Xavier正态分布

在初始化时,最好不要将模型的参数初始化为0,因为这样会导致梯度消失,进而影响训练效果。可以将模型初始化为一个很小的值,如0.01,0.001等。

definitialize_weight(m):ifisinstance(m,nn.Conv2d)orisinstance(m,nn.Linear):nn.init.kaiming_normal_(m.weight,mode='fan_out',nonlinearity='relu')# mode:权重方差计算方式,可选 'fan_in' 或 'fan_out'(输入、输出神经元数量)# nonlinearity:激活函数类型,用于调整计算公式 ,一般是relu、leaky_reluifm.biasisnotNone:nn.init.constant_(m.bias,0)

[2,2,2,2] 参数分别代表四个block的中的残差块数量(可以仔细看一下build_resblock函数)

resnet_18=ResNet([2,2,2,2])resnet_18.apply(initialize_weight)# 初始化模型loss_cross=nn.CrossEntropyLoss()trainer=th.optim.SGD(resnet_18.parameters())

训练

训练过程比较漫长,这里训练只有20轮,测试精度0.51。如果有N卡加持的话,可以适当调高epoch,精度能进一步提高。

forepochinrange(0,20):running_loss=0.0forinputs,labelsintrainloader:trainer.zero_grad()outputs=resnet_18(inputs)loss=loss_cross(outputs,labels)loss.backward()trainer.step()running_loss+=loss.item()print(f'[{epoch+1}] ev loss:{running_loss/3125}')running_loss=0.0
版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/1 20:42:08

如何构建高可信度的康复运动指导 Agent?:9步打造符合临床标准的系统

第一章:康复运动指导 Agent 的核心价值与临床意义在数字化医疗快速发展的背景下,康复运动指导 Agent 作为人工智能与临床康复深度融合的产物,正逐步成为患者功能恢复过程中的关键支持工具。这类智能体不仅能够根据个体化数据动态调整运动方案…

作者头像 李华
网站建设 2026/3/5 21:24:32

【保姆级图文步骤】VSCode整合Markdown制作思维导图

【图文步骤】VSCode整合Markdown制作思维导图 提示:帮帮志会陆续更新非常多的IT技术知识,希望分享的内容对您有用。本章分享的是VSCode整合Markdown。 所有文章都*不会*直接把代码放那里,让您自己去看去理解。我希望我的内容对您有用而努力…

作者头像 李华
网站建设 2026/4/1 20:02:41

基于STM32的智能车库管理系统(有完整资料)

资料查找方式:特纳斯电子(电子校园网):搜索下面编号即可编号:T5052310M设计简介:本设计是基于STM32的智能车库管理系统,主要实现以下功能:通过RFID卡读卡器记录车辆信息 通过红外传感…

作者头像 李华
网站建设 2026/3/30 13:09:05

量子-经典混合Agent系统设计(稀缺架构图首次公开)

第一章:量子-经典混合Agent系统设计(稀缺架构图首次公开)在当前人工智能与量子计算交叉演进的前沿领域,量子-经典混合Agent系统正成为突破传统算力瓶颈的关键架构。该系统融合了经典深度学习模型的语义理解能力与量子处理器在高维…

作者头像 李华
网站建设 2026/3/26 14:16:05

40、深入了解 Samba:资源、守护进程与客户端程序详解

深入了解 Samba:资源、守护进程与客户端程序详解 1. Samba 额外资源 在使用 Samba 的过程中,你可能会需要在线获取相关新闻、更新和帮助,以下是一些可利用的资源: - 文档和常见问题解答(FAQs) :Samba 附带了大量的文档文件,值得你花时间浏览。你可以在计算机的发行…

作者头像 李华