以resnet50 的网络结构为例,解析为什么有些算子不能融合
为什么 op1 和 op2 不能融合?
快速答案
op1 = BatchNorm + ReLU
op2 = MaxPool2D
它们不能融合的核心原因是:MaxPool2D 的复杂访问模式与 BatchNorm 的顺序写入不兼容。
详细分析
op1 的特征(BatchNorm + ReLU)
op1: SchedulerNode(ComputedBuffer) ├── 输入: buf0 [2, 64, 112, 112] ← 来自 Conv2D ├── 输出: buf1 [2, 64, 112, 112] ├── 操作: BatchNorm + ReLU │ ├── sub (减去均值) │ ├── sqrt, reciprocal (标准化) │ ├── mul (缩放) │ ├── add (偏移) │ └── relu (激活) └── 访问模式: 顺序访问,一对一映射 每个输入元素 → 计算 → 一个输出元素关键代码(第40-60行):
load=ops.load('buf0',get_index)# 读取输入# ... BatchNorm 计算 ...relu=ops.relu(add_1)# ReLUstore=ops.store('buf1',get_index,relu)# 写入输出特点:
- 简单的逐元素计算
- 顺序访问内存
- 输入输出尺寸相同
op2 的特征(MaxPool2D)
op2: SchedulerNode(ComputedBuffer) ├── 输入: buf1 [2, 64, 112, 112] ← 来自 op1 ├── 输出: buf2 [2, 64, 56, 56] ← 尺寸减半! ├── 操作: MaxPool2D (kernel=3x3, stride=2) └── 访问模式: 复杂的窗口访问 每个输出需要读取 9 个输入元素(3x3窗口)关键依赖(第68-77行):
op2.unmet_dependencies=[# 9 个不同的内存位置!MemoryDep('buf1',...,+64),# 右MemoryDep('buf1',...,+7104),# 右下MemoryDep('buf1',...,+7168),# 下MemoryDep('buf1',...,+7232),# 右下MemoryDep('buf1',...,-64),# 左MemoryDep('buf1',...,-7104),# 左上MemoryDep('buf1',...,-7168),# 上MemoryDep('buf1',...,-7232),# 左上MemoryDep('buf1',...,0)# 中心]关键代码(第118-200+行):
# 读取 9 个位置的值masked_subblock1=...# 左上masked_subblock2=...# 上masked_subblock3=...# 右上# ... 更多子块 ...# 取最大值maximum=ops.maximum(masked_subblock1,masked_subblock2)maximum_1=ops.maximum(maximum,masked_subblock3)# ...问题:
- 随机访问:每个输出需要读取 9 个不同位置的输入
- 跨行访问:stride=7168 表示跨行读取
- 条件判断:大量边界检查(ge, lt, and_)
- 尺寸不匹配:输出是输入的 1/4
不能融合的 4 个核心原因
1. 迭代空间不匹配(最关键)
# op1op1.group.iteration=(1605632,1)# 2*64*112*112 = 1,605,632 元素op1.sizes=([25088,64],[])# op2op2.group.iteration=(401408,1)# 2*64*56*56 = 401,408 元素op2.sizes=([2,56,56,64],[])问题:
- op1 产生 1,605,632 个元素
- op2 只需要 401,408 个元素
- 比例 4:1(因为 MaxPool stride=2, 尺寸减半,面积变为 1/4)
如果融合会怎样?
- 无法在一个统一的循环中同时计算
- op1 需要循环 1,605,632 次
- op2 只需要循环 401,408 次
- 无法对齐!
2. 复杂的访问模式(最关键)
op1 的输出(顺序写入): ┌─────┬─────┬─────┬─────┐ │ 0 │ 1 │ 2 │ 3 │ → 顺序写入 buf1[0], buf1[1], buf1[2], ... ├─────┼─────┼─────┼─────┤ │ 4 │ 5 │ 6 │ 7 │ └─────┴─────┴─────┴─────┘ op2 的读取(窗口访问): ┌─────┬─────┬─────┐ │ -64 │ 0 │ +64 │ ← 每次需要读取 3x3=9 个位置 ├─────┼─────┼─────┤ │-7168│ │+7168│ ├─────┼─────┼─────┤ │-7232│ │+7232│ └─────┴─────┴─────┘问题:
- op1 每次只写一个位置
- op2 每次需要读取 9 个位置
- 如果融合,op1 需要等待 9 个相邻元素都计算完成
- 破坏了并行性!
3. 数据依赖复杂
# op1 的输出 buf1 的第 0 个元素会被 op2 的多个输出使用buf1[0]被以下 op2 的输出位置使用:-buf2[0](作为中心)-buf2[相邻位置1](作为窗口的一部分)-buf2[相邻位置2](作为窗口的一部分)-...问题:
- 一对多的关系
- 需要额外的同步机制
- 增加融合的复杂度
4. 内存重用模式不同
# op1op1.users=[NodeUser(node=op2,can_inplace=False)]# ^^^^^^^^^^^^^^^^# 不能原地操作!为什么 can_inplace=False?
- MaxPool 需要读取窗口内的多个值
- 如果原地修改,会破坏后续读取的数据
- 必须先读取所有需要的输入,再写入输出
如果 can_inplace=True(如 Add + ReLU):
# 可以边读边写x=load(buf0,i)y=relu(add(x,bias))store(buf0,i,y)# 原地写回但 MaxPool 不行:
# 必须先读完再写values=[load(buf1,i-64),load(buf1,i),load(buf1,i+64),...]result=max(values)store(buf2,j,result)# 不能写回 buf1对比:能融合的例子(op9 + op10)
让我们对比一个能融合的例子:
# 假设 op9 = Add, op10 = ReLUop9:y=x+bias ├── 输入:[2,256,56,56]├── 输出:[2,256,56,56]← 尺寸相同! └── 访问:y[i]=x[i]+bias op10:z=relu(y)├── 输入:[2,256,56,56]├── 输出:[2,256,56,56]← 尺寸相同! └── 访问:z[i]=relu(y[i])← 一对一!可以融合!
# 融合后fused:z[i]=relu(x[i]+bias)为什么能融合?
- 迭代空间相同
- 访问模式简单(一对一)
- 可以原地操作
- 没有复杂依赖
总结
op1 (BatchNorm + ReLU) vs op2 (MaxPool2D) 不能融合
| 维度 | op1 -> op2 | 能否融合 |
|---|---|---|
| 迭代空间 | 1,605,632 -> 401,408 (4:1) | 不匹配 |
| 访问模式 | 顺序写 -> 窗口读(9 个位置) | 不兼容 |
| 输出尺寸 | [112, 112] -> [56, 56] | 不同 |
| 原地操作 | can_inplace=False | 不支持 |
| 数据依赖 | 一对多(每个输入被多个输出使用) | 复杂 |
能融合的典型模式
| 模式 | 特点 | 示例 |
|---|---|---|
| Pointwise -> Pointwise | 一对一映射 | Add + ReLU |
| BatchNorm -> ReLU | 顺序操作 | BN + ReLU |
| Elementwise ops | 相同形状 | Mul + Add |
不能融合的典型模式
| 模式 | 原因 | 示例 |
|---|---|---|
| Reduce -> Pointwise | 尺寸改变 | MaxPool + Conv(就是这个!) |
| Pointwise -> Reduce | 访问模式不同 | Conv + MaxPool |
| 外部 Kernel | 已优化 | Conv + BN |
如何验证?
方法 1: 查看 IR 文件
# 搜索融合节点grep"FusedSchedulerNode"ir_post_fusion.txt# 如果 op1 和 op2 融合了,你会看到:# fused_op1_op2: FusedSchedulerNode([op1, op2])# 但实际上它们是分开的:# op1: SchedulerNode(ComputedBuffer)# op2: SchedulerNode(ComputedBuffer)方法 2: 使用分析工具
python analyze_fusion_diff.py# 输出会显示:# ✓ 找到 X 个融合节点# ✓ 但 op1 和 op2 不在其中方法 3: 查看 pre-fusion vs post-fusion
# 如果融合了,post-fusion 中会少一个节点# 但这里 op1 和 op2 在两个文件中都存在diffir_pre_fusion.txt ir_post_fusion.txt|grep-A5"op1\|op2"能否强制融合?
理论上可以,但不推荐
# 如果强制融合,需要:1.在 op1 中生成所有1,605,632个元素2.每4个 op1 输出对应1个 op2 输出3.在融合的 kernel 中插入复杂的窗口读取逻辑4.处理边界条件和同步# 结果:-代码复杂度暴增-寄存器压力增加-可能反而变慢正确做法
让它们分开!
- op1 (BatchNorm + ReLU) 已经融合了,很好
- op2 (MaxPool) 单独执行,使用硬件优化的 kernel
- 中间结果 buf1 通过 L2 cache 传递,开销很小
关键要点
- op1 和 op2 不融合是正确的决策
- MaxPool 的复杂访问模式是主要原因
- 迭代空间不匹配(4:1)无法克服
- 分开执行反而更高效
- 这是 PyTorch Inductor 的智能决策
不是所有相邻操作都应该融合!