文章目录
- 摘要
- Abstract
- U-Net架构
- Encoder模块
- Decoder模块
- U-Net代码
- 总结
摘要
本周主要学习了 U-Net 的整体网络架构与实现思路,重点理解编码器-解码器的 U 形结构、下采样与上采样流程,以及跳跃连接在融合语义信息与定位细节中的作用。同时复习了卷积、池化以及反卷积的计算原理,并结合相关代码进行简单梳理与验证,加深了对模型前向传播与输出预测过程的理解。
Abstract
This week, I mainly studied the overall network architecture and implementation approach of U-Net, focusing on understanding the U-shaped encoder-decoder structure, the downsampling and upsampling processes, as well as the role of skip connections in integrating semantic information and localization details. At the same time, I reviewed the computational principles of convolution, pooling, and transposed convolution, and conducted a simple review and verification with related code, which deepened my understanding of the model’s forward propagation and output prediction process.
U-Net架构
Encoder模块
Encoder模块也称为收缩路径,可以理解为特征提取网络。对于分割问题来说,我们首先需要理解图像的语义信息,Encoder模块的核心就是逐层地抽象去理解图像的语义信息。在Encoder模块中我们使用卷积层不断提取图像特征,用最大池化层不断缩小图片的尺寸,实质就是用越来越少的编码描述原信息,即上述提到的信息抽象。该模块可以概括为,不断进行卷积与池化操作获得面向任务的高层信息。当图片经过了Encoder模块的多次池化后,得到的每张feature map都是低分辨率的,这些低分辨率的feature maps反映了像素级的语义信息,可以理解为每个像素打上了类别标签。
卷积操作:
不管输入图像有多少个通道,一个卷积核在一次卷积中只会生成一个输出特征图,可以理解为该卷积核专门响应某一类图像模式/特征。如果希望同时捕捉多种不同特征,就需要配置多个卷积核,每个卷积核对应一种特征提取。以图1为例,Encoder 的第一层包含两次卷积运算,且每次使用 32 个卷积核;第二层同样进行两次卷积,但卷积核数量提升到 64,后续层依次递增。
最大池化:
从信号处理角度看,max pooling 更倾向于保留局部区域内最显著的响应,因此有助于突出图像中更关键、更有辨识度的特征。在分割任务中通常较少采用 mean pooling,因为平均池化会把激活“抹平”,容易削弱强特征,使重要信息不够突出。因此实践里常见的搭配是“卷积 + 最大池化”。多次最大池化的过程本质上是一种逐步的特征筛选与信息压缩,类似于不断从显著特征中再提炼更显著、更抽象的高层信息,以服务后续任务需求。
Decoder模块
Decoder也称之为拓展模块,可以理解为特征融合网络,用于将高级抽象特征恢复为高分辨率,简单来说就是将像素级的语义信息在原图中呈现。如图右侧Decoder模块所示,feature map在进行高分辨恢复的过程中,实质就是将原来的浅层位置信息(每个像素对应回图片的位置)和深层的语义信息进行融合,得到最终带有像素级语义信息的图片。解码器模块的核心操作是上采样和跳层连接。
反卷积上采样:利用反卷积核对原图像进行上采样
U-Net代码
U-Net 直接把原图作为输入,先用两次 3×3 卷积加 ReLU 把基础特征提出来。左边是编码器,一共下采样 4 次:每次先做 2×2 最大池化把尺寸缩小,再把通道数加倍,并继续用双卷积细化特征。中间的瓶颈层特征图最小,但包含的语义信息最多。右边是解码器,对应上采样 4 次:先用 2×2 转置卷积把尺寸放大,再把左边同层的特征图拼过来,接着用双卷积做融合。最后用 1×1 卷积把特征映射成每个像素的类别。代码如下:
# 双卷积块classDoubleConv(nn.Module):def__init__(self,in_channels,out_channels):super(DoubleConv,self).__init__()# 初始化self.conv=nn.Sequential(# 第一个3x3卷积(无填充,输出尺寸会缩小)nn.Conv2d(in_channels,out_channels,kernel_size=3,padding=0),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True),# 直接修改输入张量,如果需在反向传播中使用原始输入则不能使用# 第二个3x3卷积(无填充)nn.Conv2d(out_channels,out_channels,kernel_size=3,padding=0),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True))defforward(self,x):returnself.conv(x)整个网络包含23个卷积层,全部使用无填充卷积,输出尺寸比输入小一定的边框宽度。整体网络结构代码如下:
classUNetOriginal(nn.Module):def__init__(self,in_channels=1,out_channels=2):super(UNetOriginal,self).__init__()# 编码器路径(下采样)# 第一层self.enc1=DoubleConv(in_channels,64)self.pool1=nn.MaxPool2d(kernel_size=2,stride=2)# 第二层self.enc2=DoubleConv(64,128)self.pool2=nn.MaxPool2d(kernel_size=2,stride=2)# 第三层self.enc3=DoubleConv(128,256)self.pool3=nn.MaxPool2d(kernel_size=2,stride=2)# 第四层self.enc4=DoubleConv(256,512)self.pool4=nn.MaxPool2d(kernel_size=2,stride=2)# 瓶颈层self.bottleneck=DoubleConv(512,1024)# 解码器(上采样)# 上采样4层self.upconv4=nn.ConvTranspose2d(1024,512,kernel_size=2,stride=2)self.dec4=DoubleConv(1024,512)# 输入通道 512(上采样) + 512(跳跃连接)# 上采样3层self.upconv3=nn.ConvTranspose2d(512,256,kernel_size=2,stride=2)self.dec3=DoubleConv(512,256)# 上采样2层self.upconv2=nn.ConvTranspose2d(256,128,kernel_size=2,stride=2)self.dec2=DoubleConv(256,128)# 上采样1层self.upconv1=nn.ConvTranspose2d(128,64,kernel_size=2,stride=2)self.dec1=DoubleConv(128,64)# 输出层self.out_conv=nn.Conv2d(64,out_channels,kernel_size=1)defforward(self,x):# 编码器路径# 第一层e1=self.enc1(x)# [B, 64, H-, W-] -> [B, 64, H-4, W-4]p1=self.pool1(e1)# [B, 64, H-4, W-4] -> [B, 64, (H-4)/2, (W-4)/2]# 第二层e2=self.enc2(p1)# [B, 64, (H-4)/2, (W-4)/2] -> [B, 128, (H-4)/2-4, (W-4)/2-4]p2=self.pool2(e2)# [B, 128, (H-4)/2-4, (W-4)/2-4] -> [B, 128, (H-4)/4-2, (W-4)/4-2]# 第三层e3=self.enc3(p2)# [B, 128, (H-4)/4-2, (W-4)/4-2] -> [B, 256, (H-4)/4-6, (W-4)/4-6]p3=self.pool3(e3)# [B, 256, (H-4)/4-6, (W-4)/4-6] -> [B, 256, (H-4)/8-3, (W-4)/8-3]# 第四层e4=self.enc4(p3)# [B, 256, (H-4)/8-3, (W-4)/8-3] -> [B, 512, (H-4)/8-7, (W-4)/8-7]p4=self.pool4(e4)# [B, 512, (H-4)/8-7, (W-4)/8-7] -> [B, 512, (H-60)/16, (W-60)/16]# 瓶颈层bottleneck=self.bottleneck(p4)# [B, 512, (H-60)/16, (W-60)/16] -> [B, 1024, (H-124)/16, (W-124)/16]# 解码器(含跳跃连接)# 上采样4层up4=self.upconv4(bottleneck)# [B, 1024, (H-124)/16, (W-124)/16] -> [B, 512, (H-124)/8, (W-124)/8]e4_cropped=self._crop_and_concat(e4,up4)# 裁剪编码器特征图 e4 以匹配上采样特征图 up4 尺寸merge4=torch.cat([e4_cropped,up4],dim=1)# 拼接,[B, 512, (H-124)/8, (W-124)/8] -> [B, 1024, (H-124)/8, (W-124)/8]d4=self.dec4(merge4)# [B, 1024, (H-124)/8, (W-124)/8] -> [B, 512, (H-156)/8, (W-156)/8]# 上采样3层up3=self.upconv3(d4)# [B, 512, (H-156)/8, (W-156)/8] -> [B, 256, (H-156)/4, (W-156)/4]e3_cropped=self._crop_and_concat(e3,up3)merge3=torch.cat([e3_cropped,up3],dim=1)# [B, 256, (H-156)/4, (W-156)/4] ->[B, 512, (H-156)/4, (W-156)/4]d3=self.dec3(merge3)# [B, 512, (H-156)/4, (W-156)/4] -> [B, 256, (H-172)/4, (W-172)/4]# 上采样2层up2=self.upconv2(d3)# [B, 256, (H-172)/4, (W-172)/4]-> [B, 128, (H-172)/2, (W-172)/2]e2_cropped=self._crop_and_concat(e2,up2)merge2=torch.cat([e2_cropped,up2],dim=1)# [B, 128, (H-172)/2, (W-172)/2]-> [B, 256, (H-172)/2, (W-172)/2]d2=self.dec2(merge2)# [B, 256, (H-172)/2, (W-172)/2]-> [B, 128, (H-180)/2, (W-180)/2]# 上采样1层up1=self.upconv1(d2)# [B, 128, (H-180)/2, (W-180)/2]-> [B, 64, H-180, W-180]e1_cropped=self._crop_and_concat(e1,up1)merge1=torch.cat([e1_cropped,up1],dim=1)# [B, 64, H-180, W-180]-> [B, 128, H-180, W-180]d1=self.dec1(merge1)# [B, 128, H-180, W-180]-> [B, 64, H-184, W-184]# 输出层output=self.out_conv(d1)# [B, 64, H-184, W-184] -> [B, out_channels, H-184, W-184]returnoutput在U-Net中,由于编码器使用无填充卷积,特征图的尺寸会逐渐缩小,在解码器中,又通过上采样恢复特征图尺寸。此时,编码器与解码器对应层的特征图尺寸可能不一致,进而导致二者无法通过跳跃连接拼接起来。故需要对编码器的特征图进行裁剪,使其与上采样后的特征图尺寸一致。
# 裁剪函数def_crop_and_concat(self,encoder_feature,decoder_feature):# 获取并计算输入特征图尺寸差异delta_h=encoder_feature.size()[2]-decoder_feature.size()[2]# 高delta_w=encoder_feature.size()[3]-decoder_feature.size()[3]# 宽# 计算裁剪边界(顶部、底部、左侧、右侧)top=delta_h//2bottom=delta_h-top left=delta_w//2right=delta_w-left# 应用裁剪cropped=encoder_feature[:,:,top:encoder_feature.size()[2]-bottom,left:encoder_feature.size()[3]-right]returncropped弹性形变是 U-Net数据增强的核心:通过生成并平滑随机位移场,模拟组织的自然拉伸与扭曲,在标注少时有效扩充数据、提升泛化。它主要由两个参数控制:一个决定形变幅度(位移大小),另一个决定平滑程度。位移场通常按高斯分布随机生成,以贴近真实组织的连续、平滑变形。
classElasticDeformation:# 初始化def__init__(self,alpha=10,sigma=5):self.alpha=alpha# 位移场的强度self.sigma=sigma# 位移场的高斯滤波标准差# 应用def__call__(self,image,mask=None):H,W=image.shape[:2]# 获取图像的高与宽# 在粗网格上生成随机位移场grid_size=3# 网格尺寸 3x3# 创建坐标,meshgrid 主要创建坐标网格,返回每个位置的x坐标与y坐标grid_x,grid_y=np.meshgrid(np.linspace(0,H,grid_size),# 生成 grid_size 个点,均匀分布在[0, H]区间np.linspace(0,W,grid_size))# 生成随机位移 (np.random.randn,从标准正态分布中生成随机数)displacement_x=np.random.randn(grid_size,grid_size)*self.alpha displacement_y=np.random.randn(grid_size,grid_size)*self.alpha# 使用双三次插值将位移场扩展到图像尺寸points=(np.linspace(0,H,grid_size),np.linspace(0,W,grid_size))# 创建插值器interp_x=RegularGridInterpolator(points,# 网格点坐标displacement_x,# 网格点上的值method='cubic',# 双三次插值bounds_error=False,# 允许超出边界的查询fill_value=0# 超出边界时填充0)interp_y=RegularGridInterpolator(points,displacement_y,method='cubic',bounds_error=False,fill_value=0)# 生成图像坐标网格coords_x,coords_y=np.meshgrid(np.arange(H),np.arange(W),indexing='ij')coords=np.stack([coords_x.ravel(),coords_y.ravel()],axis=1)# 计算每个像素的位移dx=interp_x(coords).reshape(H,W)dy=interp_y(coords).reshape(H,W)# 应用位移map_x=coords_x+dx map_y=coords_y+dy# 确保坐标在边界内map_x=np.clip(map_x,0,H-1)map_y=np.clip(map_y,0,W-1)# 重映射图像iflen(image.shape)==2:# 灰度图像# map_coordinates 实现逆向映射,即对于输出图像的每个位置,找到输入图像中对应的位置deformed_image=map_coordinates(image,# 原始图像[map_x,map_y],# 坐标映射order=3,# 双三插值mode='reflect'# 边界处理方式为反射填充)else:# 彩色图像deformed_image=np.stack([map_coordinates(image[:,:,c],[map_x,map_y],order=3,mode='reflect')forcinrange(image.shape[2])],axis=2)# 掩码处理,采用最近邻插值,保持标签完整ifmaskisnotNone:deformed_mask=map_coordinates(mask,[map_x,map_y],order=0,mode='reflect')# 最近邻插值returndeformed_image,deformed_maskreturndeformed_image插值是根据已知离散点的数据来估计未知点的值。双三次插值是一种二维插值方法,它通过一个多项式来逼近函数,这个多项式是三次的,并且在两个维度上都进行三次插值;最近邻插值则是最简单、最快的插值方法,主要是对于目标图像中的每个像素,找到源图像中距离最近的像素,然后直接使用该像素的值。
总结
通过本周学习,我对 U-Net 的核心设计思路有了更系统的理解,能够从编码器-解码器结构、下采样与上采样机制以及跳跃连接的作用出发,解释其在分割任务中兼顾语义表达与细节定位的原因。同时复习了卷积、池化和反卷积的计算原理,并结合代码进行简单验证与整理,进一步巩固了对分割网络实现流程的认识,为后续模型训练与改进打下基础。