news 2026/4/3 3:33:25

【Agent】生成式隐式记忆 MemGen 源码解读

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
【Agent】生成式隐式记忆 MemGen 源码解读

x01 背景

MemGen 提出动态生成式记忆框架,由记忆触发器与记忆编织器两个轻量模块协同构成,旨在突破现有智能体记忆范式的局限。

当前主流的记忆实现路径为:

参数化记忆通过微调将经验编码进模型参数,虽能深度内化知识却易引发灾难性遗忘;

基于检索的记忆将经验外化存储,虽规避了遗忘问题,但静态的一次性检索机制无法体现记忆与推理动态交互的认知特性。

这一现状引出两大核心问题:如何实现记忆与推理在每一步思考中的无缝耦合,以及如何让记忆从提取式升级为满足当前需求的生成式重构,而动态生成式隐式记忆正是应对这些挑战的第三种探索路径。

0x02 源码解析

MemGen项目旨在创建一个动态且自生成的记忆框架,该框架由两个协同工作的轻量级模块组成:一个基于强化学习训练的记忆触发器和一个记忆编织器。这一框架的核心思想是解决大型语言模型(LLM)智能体能力涌现时对“自进化”机制的探索需求,其中记忆扮演关键角色。

2.1 模型

LatentMemoryModel 是 MemGen 框架的核心实现,旨在构建动态生成式隐式记忆系统,解决传统记忆范式的局限性。通过整合推理器(Reasoner)、记忆编织器(Weaver)和记忆触发器(Trigger),实现记忆与推理过程的无缝耦合,让智能体在任务执行中动态生成、使用记忆,而非依赖静态检索或参数化存储。

2.1.1 核心特色

模型的核心特色如下:

模块化协同设计:由推理器(核心推理)、编织器(生成潜在记忆)、触发器(控制记忆触发)三大模块构成,模块间通过投影层实现嵌入空间映射,结构清晰且解耦。

动态记忆增强:在推理过程中自动识别分隔符位置作为记忆增强点,动态插入编织器生成的潜在记忆,突破静态记忆注入的局限,贴合人类认知中记忆与推理的动态交互特性。

精度与效率优化:默认使用 bfloat16 精度,推理器采用 Flash Attention 2 提升计算效率;冻结推理器参数,仅训练编织器和触发器,实现参数高效学习。

灵活配置与兼容性:支持自定义触发器模型、PEFT 微调配置、记忆增强次数等参数;自动处理 Tokenizer 缺失 pad token 的问题,标准化对话模板,提升跨场景兼容性。

损失计算精准过滤:通过潜在记忆掩码排除记忆嵌入对应的位置,仅对原始输入位置计算损失,确保训练目标聚焦于核心任务性能,避免记忆生成过程干扰主任务学习。

2.1.2 网络结构

关键说明(核心设计亮点)

三大模块协同逻辑:

推理器(Reasoner):核心推理组件,权重冻结以保留基础能力,仅通过潜在记忆调整解码路径。

触发器(MemGenTrigger):动态判断记忆插入时机,输出二分类触发概率,决定是否调用编织器。

编织器(MemGenWeaver):生成针对性潜在记忆,分提示词 / 推理两阶段设计,支持 PEFT 高效微调。

核心流程闭环:输入 → 推理器生成原始嵌入 → 触发器 + 增强点选择模块确定插入位置 → 编织器生成潜在记忆 → 投影层适配维度 → 重组增强序列 → 推理器完成最终推理 → 过滤无效位置输出。

关键技术细节:

跨模块投影:通过 reasoner_to_weaver 和 weaver_to_reasoner 解决推理器与编织器嵌入维度不匹配问题。

动态记忆增强:按分隔符拆分序列,逐段插入记忆,避免长序列冗余,贴合人类 “思考 - 记忆” 交互模式。

精度与效率:全流程采用 bfloat16 精度,推理器 / 编织器启用 Flash Attention 2,平衡性能与速度。

训练与推理适配:

训练时:通过 labels 和 valid_logits 计算损失,仅优化编织器、触发器及投影层参数。

推理时:无需 labels,自动完成 “触发判断 - 记忆生成 - 推理增强” 全流程,实现动态自进化。

具体网络结构如下

MemGen-1

2.1.3 代码

LatentMemoryModel 的代码如下:

@registry.register_model("latmem")

class LatentMemoryModel(BaseModel): # 定义了一个名为 LatentMemoryModel 的类,继承自 BaseModel

def __init__(

self,

reasoner_model_name: str, # 推理模型名称

weaver_model_name: str, # 记忆编织器模型名称

prompt_latents_len: int, # 提示长度

inference_latents_len: int, # 推理长度

weaver_peft_config: Optional[PeftConfig] = None, # 记忆编织器配置,可选

trigger_model_name: str = None, # 触发模型名称,可选

trigger_peft_config: Optional[PeftConfig] = None, # 触发器配置,可选

max_prompt_aug_num: int = 1, # 最大提示增强数量

max_inference_aug_num: int = 5, # 最大推理增强数量

):

super().__init__() # 调用父类构造函数

# 构建推理模型

self.model = AutoModelForCausalLM.from_pretrained( # 从预训练模型加载推理模型

reasoner_model_name, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")

self.tokenizer = AutoTokenizer.from_pretrained(reasoner_model_name) # 加载入分词器

self.config = self.model.config # 获取模型配置

# 构建记忆编织器

self.weaver = MemGenWeaver( # 初始化记忆编织器

weaver_model_name, prompt_latents_len, inference_latents_len, weaver_peft_config

)

# 构建触发器

self.trigger = NanoTrigger() # 默认触发器,始终返回 true

if trigger_model_name is not None:

self.trigger = MemGenTrigger( # 如果指定了触发模型,则加载相应的触发器

trigger_model_name, trigger_peft_config

)

logging.info(f"Use Trigger: {trigger_model_name}") # 记录日志

# 投影层,用于在推理模型和记忆编织器之间映射嵌入

# 将推理模型输入嵌入映射到记忆编织器输入嵌入

self.reasoner_to_weaver = nn.Linear( # 线性层,从推理模型隐藏层到记忆编织器隐藏层

self.model.config.hidden_size, self.weaver.config.hidden_size, dtype=torch.bfloat16

)

# 将记忆编织器隐藏状态映射回推理模型输入嵌入

self.weaver_to_reasoner = nn.Linear( # 线性层,从记忆编织器隐藏层到推理模型隐藏层

self.weaver.config.hidden_size, self.model.config.hidden_size, dtype=torch.bfloat16

)

self.delimiters: List[str] = [",", ".", "\n"] # 用于检测增强点的分隔符

self.max_prompt_aug_num = max_prompt_aug_num # 提示后提示中插入潜在数量

self.max_inference_aug_num = max_inference_aug_num # 指定分隔符后插入潜在数量

# 后处理

self._postprocess_models() # 后处理模型

self.warnings_issued = {} # 存储发出的警告

self.model_tags = None # 存储模型标签

log_trainable_params(self) # 记录可训练参数

def add_model_tags(self, tags: Union[list[str], str]) -> None: # 添加模型标签

r"""

向模型添加自定义标签,这些标签将被推送到 Hugging Face Hub。不会覆盖模型中现有的标签。

参数:

tags (`Union[list[str], str]`):

要添加到模型的标签

例子:

```python

from transformers import AutoModel

model = AutoModel.from_pretrained("google-bert/bert-base-cased")

model.add_model_tags(["custom", "custom-bert"])

# 将模型推送到您的命名空间,名称为 "my-custom-bert"。

model.push_to_hub("my-custom-bert")

"""

if isinstance(tags, str):

tags = [tags]

if self.model_tags is None:

self.model_tags = []

for tag in tags:

if tag not in self.model_tags:

self.model_tags.append(tag)

def _postprocess_models(self):

"""

后处理记忆模型的组件:推理模型、记忆编织器、触发器和分词器。

步骤:

1. 冻结推理模型的所有参数(不更新梯度)。

2. 将所有模型转换为 bfloat16 以提高内存和计算效率。

3. 确保分词器有一个有效的填充符:

- 如果缺少填充符,使用 EOS 符作为填充符。

- 设置 `padding_side` 为 "left" 以兼容生成任务。

4. 标准化分词器的模板为 `CONVERSATION_TEMPLATE`。

"""

# 默认冻结推理模型的所有参数

fix_model_parameters(self.model)

# 将所有子模型转换为 bfloat16

self.model = self.model.bfloat16()

self.weaver = self.weaver.bfloat16()

self.trigger = self.trigger.bfloat16()

# 确保分词器有一个填充符

if self.tokenizer.pad_token is None:

self.tokenizer.pad_token = self.tokenizer.eos_token

self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

self.tokenizer.padding_side = "left"

logging.info(

f"Tokenizer has no pad token. Using EOS token ({self.tokenizer.eos_token}) as pad token."

)

# 标准化分词器的模板

self.tokenizer.chat_template = CONVERSATION_TEMPLATE

2.1.4 插入阶段

LatentMemoryModel 的两个关键函数 forward 和 generate 区别如下:

forward 函数

训练时候计算损失,由训练循环自动调用。

generate 函数

推理时候生成文本,由代码显式调用。

forward

forward 函数的主体如下:

def _forward(

self,

input_ids: torch.Tensor,

attention_mask: torch.Tensor,

labels: torch.Tensor,

**kwargs

) -> torch.Tensor:

# 预处理输入

assert input_ids.shape == attention_mask.shape == labels.shape

tokenizer = self.tokenizer

reasoner = self.model

weaver = self.weaver

delimiters = self.delimiters

max_augment_num = self.max_inference_aug_num # 限制推理增强点的数量以避免过度增强

device = self.device

embeds_dtype = reasoner.get_input_embeddings().weight.dtype

B, _ = input_ids.shape

hidden_size = reasoner.config.hidden_size

# 选择增强索引

augmentation_indices = self._select_augment_points_after_delimiter(

input_ids, labels, delimiters, tokenizer, max_augment_num

)

# 输入嵌入

inputs_embeds = reasoner.get_input_embeddings()(input_ids)

# 初始化开始索引和空张量以累积处理的段

current_start_idx = 0

current_inputs_embeds = torch.empty(B, 0, hidden_size).to(device, dtype=embeds_dtype)

current_attention_mask = torch.empty(B, 0).to(device, dtype=attention_mask.dtype)

current_latents_mask = torch.empty(B, 0).to(device, dtype=torch.bool)

# 遍历所选增强点

for aug_idx in augmentation_indices:

# 切片原始嵌入和注意力掩码

segment_inputs_embeds = inputs_embeds[:, current_start:aug_idx]

segment_attention_mask = attention_mask[:, current_start:aug_idx]

segment_latents_mask = torch.zeros(B, segment_inputs_embeds.size(1).to(device, dtype=torch.bool)

# 连接当前段到累积嵌入和掩码

current_inputs_embeds = torch.cat([current_inputs_embeds, segment_inputs_embeds], dim=1)

current_mask = torch.cat([current_mask, segment_attention_mask], dim=1)

current_position_ids = generate_position_ids(current_mask)

current_latents = torch.cat([current_latents, segment_latents], dim=1)

# 将推理模型嵌入映射到记忆编织器嵌入

weaver_inputs_embeds = self.reasoner_to_weaver(current_inputs_embeds)

# 确定此点是否为提示(增强)的结束

is_prompt_end_aug = (labels[:, aug_idx] != -100).all() and (labels[:, aug_idx-1] == -100).all().item()

# 根据类型,使用记忆编织器增强提示或推理

if is_prompt_end_aug:

weaver_hidden_states, attn_mask, pos_ids = weaver.augment_prompt(

weaver_inputs, current_attention_mask, current_position_ids

)

else:

weaver_hidden_states, attn_mask, pos_ids = weaver.augment_inference(

weaver_inputs, current_attention_mask, current_position_ids

)

# 将记忆编织器隐藏状态映射回推理模型嵌入

latent_inputs_embeds = self.weaver_to_reasoner(weaver_hidden_states)

# 更新累积嵌入和掩码与新增强段

current_inputs_embeds = torch.cat

generate

核心作用

该 generate 方法是 MemGen 模型的推理核心,实现了动态记忆增强与序列生成的无缝融合。通过迭代生成新 token,每步自适应判断是否插入编织器生成的潜在记忆,让推理器在生成过程中实时利用动态记忆调整解码路径,最终输出增强后的序列(可选返回记忆增强位置掩码)。

核心特色

双阶段记忆增强:先执行提示词阶段记忆增强(初始化全局记忆),再在迭代生成中动态触发推理阶段增强(补充实时记忆),适配不同生成阶段的记忆需求。

自适应触发机制:通过 _should_augment 结合触发器决策,仅对需要记忆支持的序列执行增强,避免无意义的计算开销。

维度对齐优化:非增强序列采用左填充(_left_pad)方式对齐增强序列维度,确保批次内所有序列格式统一,不影响批量生成效率。

高效推理设计:

禁用梯度计算(@torch.no_grad()),节省内存并加速推理;

启用推理器缓存(use_cache=True),减少重复计算;

仅在必要时输出隐藏状态,降低计算成本。

灵活配置与可解释性:支持控制最大生成 token 数、采样策略等参数;可选返回 augmentation_pos 掩码,标记记忆插入位置,提升模型可解释性。

鲁棒性保障:提前终止机制(所有序列生成 EOS 或达最大增强次数时终止),避免无效迭代;重构生成配置固定关键参数,确保生成稳定性。

推理生成流程图

潜在记忆插入的完整流程:

初始化阶段:对输入提示进行增强,插入初始潜在记忆。

生成循环:逐个生成token。

条件检查:在每个步骤检查是否满足插入条件。

决策判断:使用trigger模型决定是否插入潜在记忆。

潜在记忆生成:通过weaver模型生成潜在记忆表示。

嵌入连接:将潜在记忆嵌入连接到当前输入序列。

继续生成:使用增强后的序列继续生成下一个token。

具体流程如下图所示:

MemGen-2

代码如下:

@torch.no_grad() # 禁用梯度计算,适用于推理阶段,提升效率并节省内存

def generate(

self,

input_ids: torch.Tensor, # 输入token ID序列,形状[batch_size, prompt_len]

attention_mask: torch.Tensor, # 注意力掩码,形状与input_ids一致

generation_config: GenerationConfig = None, # 生成配置(如最大新token数、采样策略等)

return_augmentation_mask: bool = False, # 是否返回记忆增强位置掩码

**kwargs

) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:

"""

执行MemGen模型的推理生成流程:动态融合潜在记忆与推理器,生成增强后的输出序列。

核心逻辑:

1. 初始化提示词阶段的记忆增强

2. 迭代生成新token,每步判断是否触发推理阶段记忆增强

3. 对需增强的序列插入编织器生成的潜在记忆,非增强序列左填充对齐维度

4. 生成完成后返回结果(可选返回增强位置掩码)

"""

tokenizer = self.tokenizer

reasoner = self.model

weaver = self.weaver

trigger = self.trigger

delimiters = self.delimiters

max_augment_num = self.max_inference_aug_num # 单序列最大推理阶段增强次数

invalid_token_id = -100 # 无效位置标记(用于增强位置掩码)

# 预处理输入:转移到模型所在设备

input_ids = input_ids.to(self.device)

attention_mask = attention_mask.to(self.device)

# 提取生成配置关键参数

max_new_tokens = generation_config.max_new_tokens # 最大生成新token数

do_sample = generation_config.do_sample # 是否启用采样生成

temperature = generation_config.temperature # 采样温度(控制随机性)

pad_token_id = tokenizer.pad_token_id # pad token ID

eos_token_id = tokenizer.eos_token_id # 结束token ID

prompt_len = input_ids.size(1) # 提示词长度

# 重构生成配置(固定必要参数,确保生成稳定性)

generation_config = GenerationConfig(

do_sample=do_sample,

temperature=temperature,

pad_token_id=pad_token_id,

eos_token_id=eos_token_id,

use_cache=True # 启用缓存加速生成

)

# 将输入token ID转换为嵌入向量

inputs_embeds = reasoner.get_input_embeddings()(input_ids)

B, _, hidden_size = inputs_embeds.shape # B=batch_size,hidden_size=推理器隐藏层维度

device = inputs_embeds.device # 模型所在设备(CPU/GPU)

# 初始化生成过程中的关键张量

current_inputs_embeds = inputs_embeds # 当前输入嵌入(含原始提示词+潜在记忆)

current_attention_mask = attention_mask # 当前注意力掩码

current_position_ids = generate_position_ids(current_attention_mask) # 当前位置ID

current_input_ids = input_ids # 当前已生成的token ID序列

# 提示词阶段记忆增强:生成并插入提示词专用潜在记忆

weaver_inputs_embeds = self.reasoner_to_weaver(current_inputs_embeds) # 映射到编织器嵌入空间

weaver_hidden_states, attn_mask, pos_ids = weaver.augment_prompt(

weaver_inputs_embeds, current_attention_mask, current_position_ids

)

latent_inputs_embeds = self.weaver_to_reasoner(weaver_hidden_states) # 映射回推理器嵌入空间

# 拼接提示词与增强记忆

current_inputs_embeds = torch.cat([current_inputs_embeds, latent_inputs_embeds], dim=1)

current_attention_mask = torch.cat([current_attention_mask, attn_mask], dim=1)

current_position_ids = torch.cat([current_position_ids, pos_ids], dim=1)

# 生成循环初始化

sentence_augment_count = torch.zeros(B, dtype=torch.int, device=device) # 各序列已增强次数

augmentation_pos = torch.full((B, max_new_tokens), fill_value=invalid_token_id, device=device) # 增强位置掩码

inserted_embeds: List[List[torch.Tensor]] = [[] for _ in range(B)] # 记录插入的潜在记忆(用于后处理)

for i in range(max_new_tokens):

# 若所有序列均已生成EOS token,提前终止

if (current_input_ids[:, -1] == eos_token_id).all():

break

# 若所有序列均已达到最大增强次数,一次性生成剩余token

if (sentence_augment_count >= max_augment_num).all():

# 调整剩余生成长度

generation_config.max_new_tokens = max_new_tokens - i

# 推理器生成剩余token

generated = reasoner.generate(

inputs_embeds=current_inputs_embeds,

attention_mask=current_attention_mask,

generation_config=generation_config,

)

current_input_ids = torch.cat([current_input_ids, generated], dim=1)

break

# 推理器前向传播,获取当前步输出

outputs = reasoner(

inputs_embeds=current_inputs_embeds,

attention_mask=current_attention_mask,

position_ids=current_position_ids,

output_hidden_states=False, # 推理阶段无需输出隐藏状态,提升效率

)

# 生成并追加一个新token,更新关键张量

current_inputs_embeds, current_attention_mask, current_position_ids, current_input_ids = self._append_one_step(

outputs, current_inputs_embeds, current_attention_mask, current_position_ids, current_input_ids, do_sample, temperature

)

# 若为最后一步生成,终止循环

if i == max_new_tokens - 1:

break

# 判断当前批次中哪些序列需要进行推理阶段记忆增强

augment_decision = self._should_augment(

current_input_ids, current_attention_mask, sentence_augment_count=sentence_augment_count,

do_sample=do_sample, temperature=temperature

)

augmentation_pos[:, i + 1] = augment_decision # 记录增强位置(1=增强,0=不增强,-100=无效)

augment_indices = torch.where(augment_decision == 1)[0] # 需增强的序列索引

# 对需增强的序列执行记忆增强,非增强序列左填充对齐维度

if len(augment_indices) > 0:

# 递增需增强序列的增强次数计数

sentence_augment_count[augment_indices] += 1

# 提取需增强序列的嵌入、掩码和位置ID

candidate_inputs_embeds = current_inputs_embeds[augment_indices]

candidate_attention_mask = current_attention_mask[augment_indices]

candidate_position_ids = current_position_ids[augment_indices]

# 编织器生成推理阶段潜在记忆

weaver_inputs_embeds = self.reasoner_to_weaver(candidate_inputs_embeds)

weaver_hidden_states, attn_mask, _ = weaver.augment_inference(

weaver_inputs_embeds, candidate_attention_mask, candidate_position_ids

)

latent_inputs_embeds = self.weaver_to_reasoner(weaver_hidden_states) # 映射回推理器空间

# 拼接原始嵌入与潜在记忆

candidate_inputs_embeds = torch.cat([candidate_inputs_embeds, latent_inputs_embeds], dim=1)

candidate_attention_mask = torch.cat([candidate_attention_mask, attn_mask], dim=1)

# 构建合并张量(适配所有序列,包括增强和非增强)

new_len = candidate_inputs_embeds.size(1) # 增强后序列长度

merged_inputs_embeds = torch.zeros((B, new_len, hidden_size), device=device, dtype=current_inputs_embeds.dtype)

merged_attention_mask = torch.zeros((B, new_len), device=device, dtype=current_attention_mask.dtype)

# 填充增强序列

merged_inputs_embeds[augment_indices] = candidate_inputs_embeds

merged_attention_mask[augment_indices] = candidate_attention_mask

# 填充非增强序列(左填充对齐长度)

non_augment_indices = torch.where(augment_decision != 1)[0]

if len(non_augment_indices) > 0:

non_aug_inputs_embeds = current_inputs_embeds[non_augment_indices]

non_aug_attention_mask = current_attention_mask[non_augment_indices]

non_aug_inputs_embeds, non_aug_attention_mask, _ = self._left_pad(

non_aug_inputs_embeds, non_aug_attention_mask, None, weaver.inference_latents_num

)

merged_inputs_embeds[non_augment_indices] = non_aug_inputs_embeds

merged_attention_mask[non_augment_indices] = non_aug_attention_mask

# 更新当前关键张量

current_inputs_embeds = merged_inputs_embeds

current_attention_mask = merged_attention_mask

current_position_ids = generate_position_ids(current_attention_mask) # 重新生成位置ID

# 记录插入的潜在记忆(用于后处理或可解释性分析)

for idx, embed in zip(augment_indices, latent_inputs_embeds):

inserted_embeds[idx].append(embed.clone().detach().cpu())

# 后处理:调整增强位置掩码长度与生成结果一致

new_generated_len = current_input_ids.size(1) - prompt_len

augmentation_pos = augmentation_pos[:, :new_generated_len]

# 根据配置返回结果:仅生成序列 或 序列+增强位置掩码

if not return_augmentation_mask:

return current_input_ids

else:

return current_input_ids, augmentation_pos

2.2 Trigger

2.2.1. 核心作用

该模块定义了 MemGen 框架中记忆触发器的核心接口与两种具体实现,核心作用是动态决策记忆增强的时机—— 即在推理过程中判断何时插入编织器生成的潜在记忆,实现记忆与推理的动态耦合,突破传统静态记忆注入的局限。

2.2.2. 核心特色

抽象接口统一规范:Trigger抽象基类定义了触发器的核心接口,确保后续扩展新触发器时遵循统一标准,提升代码可扩展性。

双实现适配不同场景:

NanoTrigger:极简实现,始终触发记忆增强,无需训练,适用于快速测试、基线对比或无需动态控制的简单场景。

MemGenTrigger:基于预训练 LLM 的智能触发器,通过二分类头适配决策任务,支持 PEFT 参数高效微调,能根据输入序列动态判断是否触发,适配复杂真实场景。

高效适配与灵活扩展:

采用 bfloat16 精度和 Flash Attention 2 优化计算效率;

支持 PEFT 微调,在不冻结基础模型的前提下实现参数高效学习;

替换 LLM 原始输出头为二分类头,精准适配 "是否插入记忆" 的决策需求。

模块解耦设计:触发器决策独立于编织器模块,仅基于输入序列和数据分布做出判断,保证了模块间的低耦合和高内聚。

2.2.3 网络架构

网络架构图如下。

说明如下:

模型支持PEFT参数高效微调(如LoRA),适配于Transformer Blocks层

整体精度采用bfloat16,平衡计算效率与数值稳定性

注意力计算通过Flash Attention 2优化,提升长序列处理速度

MemGen-3

2.2.4 代码

class Trigger(torch.nn.Module, ABC):

"""

记忆触发器的抽象基类(Trigger)。

定义了触发器的核心接口,用于决定在推理过程中何时触发记忆增强(插入潜在记忆)。

所有具体触发器实现都需继承此类并实现forward方法。

"""

def __init__(self):

super().__init__() # 调用父类Module的初始化方法

@abstractmethod

def forward(self, **kwargs) -> bool:

"""

抽象前向传播方法:接收输入数据,返回是否触发记忆增强的决策。

子类必须实现此方法,定义具体的触发逻辑。

Args:

**kwargs: 可变关键字参数,包含输入序列、注意力掩码等模型所需数据

Returns:

bool: 触发决策(True表示触发记忆增强,False表示不触发)

"""

...

class NanoTrigger(torch.nn.Module):

"""

极简触发器(NanoTrigger):始终触发记忆增强的基础实现。

无需复杂逻辑,固定返回触发决策,适用于基础测试或无需动态控制的场景。

"""

def __init__(self):

super().__init__()

# 注册一个缓冲区张量,用于获取模型所在设备(无实际计算意义)

self.register_buffer("_device", torch.tensor(0.0))

@property

def device(self):

"""获取模型所在设备(CPU/GPU)"""

return self._device.device

def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> bool:

# 该"极简触发器"始终预测需要插入记忆

# 输出logits张量,其中插入决策(索引=1)的概率被设为1.0

# 适用于批次中的每个token位置

batch_size, seq_len = input_ids.shape

# 初始化logits张量:形状为[batch_size, seq_len, 2],2表示"不插入"(0)和"插入"(1)两类

logits = torch.zeros(batch_size, seq_len, 2, device=input_ids.device)

logits[..., 1] = 1.0 # 将所有位置的"插入"决策概率设为1.0

return logits

class MemGenTrigger(torch.nn.Module):

"""

MemGen框架的专用触发器模块(MemGenTrigger)。

- 输入:接收推理器模型当前解码序列的`inputs_embeds`(或input_ids)

- 输出:生成形状为[batch_size, seq_len, 2]的logits张量,

表示每个位置"不插入"(0)和"插入"(1)记忆的概率,用于动态决策记忆增强时机。

"""

def __init__(

self,

pretrained_model_name_or_path: str, # 预训练模型名称或路径(用于初始化触发器LLM)

peft_config: Optional[PeftConfig] = None # PEFT配置(可选,用于参数高效微调)

):

super().__init__()

# 构建基础LLM模型(作为触发器的核心推理组件)

self.model = AutoModelForCausalLM.from_pretrained(

pretrained_model_name_or_path,

torch_dtype=torch.bfloat16, # 使用bfloat16精度提升效率

attn_implementation="flash_attention_2" # 启用Flash Attention 2优化注意力计算

)

self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path) # 对应的Tokenizer

# 对基础模型进行后处理(设置可训练、替换输出头)

self.model = self._postprocess(self.model)

# 若提供PEFT配置,应用参数高效微调

if peft_config is not None:

self.model = get_peft_model(self.model, peft_config)

self.config = self.model.config # 保存模型配置

@property

def device(self):

"""获取模型所在设备(CPU/GPU)"""

return self.model.device

def _postprocess(self, model: PreTrainedModel):

"""

对基础模型进行后处理,适配触发器的二分类任务需求。

Args:

model: 原始预训练LLM模型

Returns:

处理后的模型(可训练、替换为二分类输出头)

"""

# 设置所有模型参数为可训练

for parameter in model.parameters():

parameter.requires_grad = True

# 将原始语言模型的输出头(lm_head)替换为二分类头

hidden_size = model.config.hidden_size # 模型隐藏层维度

classification_head = nn.Linear(hidden_size, 2) # 输出维度为2(不插入/插入)

model.lm_head = classification_head

# 确保新的二分类头参数可训练

for param in model.lm_head.parameters():

param.requires_grad = True

return model

def forward(

self,

input_ids: Optional[torch.LongTensor] = None, # 生成序列的token ID,形状[batch_size, seq_len]

attention_mask: Optional[torch.Tensor] = None, # 注意力掩码,避免关注填充token

**kwargs: Unpack[TransformersKwargs], # 传递给底层模型的额外参数

) -> torch.Tensor:

"""

序列生成的触发决策机制。

触发器基于已生成的`input_ids`做出决策,受数据分布影响,但独立于编织器模块。

Args:

input_ids (Optional[torch.LongTensor]): 生成序列的token ID张量

attention_mask (Optional[torch.Tensor]): 注意力掩码,默认None

**kwargs: 传递给底层模型的额外关键字参数

Returns:

torch.Tensor: Logits张量,形状为`(batch_size, seq_len, num_classes)`

num_classes=2,分别对应"不插入"(索引0)和"插入"(索引1)的概率

"""

# 调用基础模型前向传播,返回二分类logits

return self.model(

input_ids=input_ids,

attention_mask=attention_mask,

**kwargs

).logits

2.3 MemGenWeaver

2.3.1 核心作用

MemGenWeaver 是 MemGen 框架的核心组件之一,负责生成动态潜在记忆并将其与推理器的输入序列融合,从而实现记忆与推理过程的无缝交织。它通过可学习的潜在记忆查询向量,在提示词阶段和推理阶段分别生成针对性的记忆表示,引导推理器调整解码路径,提升智能体的动态决策能力。

2.3.2 核心特色

双阶段记忆生成:区分提示词阶段(augment_prompt)和推理阶段(augment_inference),使用各自独立的可学习潜在记忆查询向量,适配不同阶段的记忆需求,增强记忆生成的针对性。

灵活的潜在记忆融合:通过_augment方法统一实现潜在记忆与输入序列的融合,包括嵌入拼接、注意力掩码扩展和位置 ID 计算,确保记忆与原始输入在语义空间和时序上的一致性。

高效的模型设计:

基于预训练 LLM 构建,支持 PEFT 参数高效微调,在保留基础能力的同时降低训练成本;

采用 bfloat16 精度和 Flash Attention 2 优化,提升计算效率和内存利用率。

动态记忆编织机制:生成的潜在记忆并非静态检索结果,而是基于当前输入序列动态生成的隐藏状态,能够捕捉实时上下文信息,实现 “生成式记忆” 的核心特性。

模块化与可扩展性:与推理器、触发器解耦,通过标准化接口交互;潜在记忆的数量可通过参数灵活配置,适配不同任务对记忆容量的需求。

2.3.3 网络架构

网络架构图如下。

说明如下:

核心组件:

可学习潜在记忆向量:分阶段设计(P=提示词阶段数量,I=推理阶段数量),支持动态生成记忆

预训练LLM:作为记忆生成核心,默认启用bfloat16精度和Flash Attention 2优化

序列融合层:确保输入与记忆在语义、掩码、时序上的一致性

核心流程:

输入 → 选择对应阶段的潜在记忆 → 融合序列 → LLM生成隐藏状态 → 提取潜在记忆输出

支持PEFT参数高效微调(如LoRA),适配于Transformer Blocks层

输出用途:

生成的潜在记忆将通过投影层映射到推理器的嵌入空间,与原始输入融合以引导解码

MemGen-4

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/1 0:16:12

ASP.NET Core Blazor进阶1:高级组件开发

1. 渲染片段&#xff08;RenderFragment&#xff09;1.1 基本概念RenderFragment是Blazor中用于动态渲染UI内容的核心概念&#xff0c;它允许组件接收并渲染来自父组件的标记内容。1.2 基础用法<!-- ChildComponent.razor --><div class"card"><div c…

作者头像 李华
网站建设 2026/3/30 12:09:50

终极3行代码搞定智能搜索:WPF UI让你的输入效率翻倍提升

终极3行代码搞定智能搜索&#xff1a;WPF UI让你的输入效率翻倍提升 【免费下载链接】wpfui WPF UI在您熟悉和喜爱的WPF框架中提供了流畅的体验。直观的设计、主题、导航和新的沉浸式控件。所有这些都是本地化且毫不费力的。 项目地址: https://gitcode.com/GitHub_Trending/…

作者头像 李华
网站建设 2026/3/31 8:20:39

【面板数据】城市二手房房价数据(2010-2025年)

数据简介&#xff1a;城市二手房均价是指在特定行政区域和统计周期&#xff08;如月度或年度&#xff09;内&#xff0c;所有完成产权交易的二手住宅总成交金额与成交套数的算术平均值。数据可以表征供需关系动态为开发商定价策略&#xff08;如竞品对标&#xff09;和ZF调控政…

作者头像 李华
网站建设 2026/3/22 6:22:12

Oracle数据库空间深度回收:从诊断到优化实战指南

随着企业业务数据的持续快速增长&#xff0c;Oracle 数据库占用的磁盘空间常常呈膨胀趋势&#xff0c;这不仅导致备份文件庞大、恢复时间延长&#xff0c;还直接推高了存储成本。本文将系统化解析 Oracle 空间回收的完整链路&#xff0c;从空间诊断、高水位线处理到高效压缩与自…

作者头像 李华