MGeo结合Airflow调度,批量任务自动化
在地址数据治理实践中,单次推理只是起点,真正考验工程能力的是高频、多源、大规模的地址对齐任务。物流订单清洗、政务地址归一化、POI库跨平台合并——这些场景往往涉及数万至百万级地址对的批量比对,手动执行或简单脚本已无法满足时效性与稳定性要求。MGeo作为专为中文地址优化的语义匹配模型,虽具备高精度向量编码能力,但其原生推理脚本(/root/推理.py)仅面向单次验证设计,缺乏任务管理、失败重试、状态追踪与资源调度能力。
本文聚焦“批量”这一核心诉求,手把手带你将MGeo服务接入Apache Airflow,构建一套可监控、可重试、可编排、可扩展的地址相似度批量处理流水线。不依赖K8s集群,不改造模型底层,仅通过轻量级调度层封装,即可实现:
千级地址对分钟级完成比对
任务失败自动告警+重试
多任务并行且GPU资源可控
批量结果自动落库/导出
全流程可视化追踪
全程基于你已部署的MGeo镜像(4090D单卡),零新增硬件投入。
1. 为什么是Airflow?而非Cron或自写脚本
1.1 批量任务的真实痛点
当你尝试用for循环调用MGeo脚本处理1000个地址对时,很快会遇到:
- GPU资源争抢:多个Python进程同时加载模型,显存溢出报错
CUDA out of memory - 单点故障无感知:第527个地址对处理失败,整个脚本中断,需人工定位重跑
- 进度不可见:运行中无法知道“已完成多少”、“卡在哪一步”、“平均耗时多少”
- 结果难管理:输出散落在终端日志,无法结构化存储或对接下游系统
而Airflow天然解决这些问题:
| 痛点 | Airflow方案 |
|---|---|
| GPU资源过载 | 通过pool机制限制并发任务数,确保单卡稳定运行 |
| 任务失败中断 | 内置重试策略(retries=3)、失败回调(邮件/钉钉通知) |
| 进度黑盒 | Web UI实时显示DAG状态、任务耗时、日志详情、输入参数 |
| 结果分散 | 自定义Operator将结果写入MySQL/CSV/对象存储,统一出口 |
注意:Airflow本身不运行模型,它只调度你的MGeo容器。我们采用外部容器模式——Airflow Worker启动一个临时Docker容器执行推理,任务结束即销毁,完全复用现有镜像环境。
1.2 架构设计:轻量集成,零侵入
Airflow Scheduler → Airflow Worker → 启动Docker容器 → 运行MGeo镜像 → 输出结果 → 清理容器关键优势:
- 不修改MGeo镜像:仍使用原
/root/推理.py逻辑,仅需增加参数解析能力 - 不占用宿主机环境:所有依赖隔离在容器内,避免Python版本冲突
- 弹性伸缩:Worker节点可横向扩展,应对突发大任务
- 复用现有部署:无需重新配置GPU驱动或CUDA环境
2. 环境准备:Airflow服务端部署
2.1 快速启动Airflow(单机开发模式)
在一台非GPU服务器(推荐4核8G)上部署Airflow,避免与MGeo容器争抢GPU资源:
# 创建独立环境 mkdir mgeo-airflow && cd mgeo-airflow python3 -m venv airflow-env source airflow-env/bin/activate # 安装Airflow(2.8+支持DockerOperator增强) pip install "apache-airflow[celery,redis,docker]==2.8.1" \ "apache-airflow-providers-docker==4.10.0" # 初始化数据库与用户 airflow db upgrade airflow users create \ --username admin \ --password admin123 \ --firstname Admin \ --lastname User \ --role Admin \ --email admin@example.com # 启动Webserver与Scheduler airflow webserver & airflow scheduler &访问http://<your-server-ip>:8080,用admin/admin123登录。
2.2 配置Docker连接与GPU支持
Airflow需能调用宿主机Docker Daemon,并透传GPU设备:
# 在Airflow服务器上执行(确保Docker已安装) sudo usermod -aG docker $USER sudo systemctl restart docker # 验证Airflow Worker可访问Docker docker run hello-world # 成功则继续编辑Airflow配置文件airflow.cfg(通常位于~/airflow/airflow.cfg):
# 启用DockerOperator [operators] default_queue = default # 配置Docker连接 [docker] host = unix:///var/run/docker.sock tls_ca_cert = tls_client_cert = tls_client_key = tls_verify = False验证:重启Airflow后,在Web UI → Admin → Connections 中确认
docker_default连接状态为healthy。
3. MGeo推理脚本增强:支持参数化调用
原始/root/推理.py硬编码测试地址,需改造为命令行可接收参数,以适配Airflow动态传入任务数据。
3.1 修改推理脚本(/root/推理.py)
#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ MGeo地址相似度批量推理脚本(Airflow适配版) 支持从JSON文件读取地址对,输出CSV结果 """ import sys import json import csv import torch from sentence_transformers import SentenceTransformer from pathlib import Path def load_address_pairs(json_path): """从JSON文件加载地址对列表""" with open(json_path, 'r', encoding='utf-8') as f: return json.load(f) def save_results_to_csv(results, csv_path): """保存结果到CSV""" with open(csv_path, 'w', newline='', encoding='utf-8') as f: writer = csv.writer(f) writer.writerow(['addr_a', 'addr_b', 'similarity', 'timestamp']) for r in results: writer.writerow([r['a'], r['b'], r['score'], r['time']]) def main(): if len(sys.argv) != 3: print("用法: python 推理.py <input_json_path> <output_csv_path>") print("示例: python 推理.py /tmp/input.json /tmp/output.csv") sys.exit(1) input_path = Path(sys.argv[1]) output_path = Path(sys.argv[2]) # 加载地址对 pairs = load_address_pairs(input_path) print(f" 加载 {len(pairs)} 个地址对") # 加载模型(单次初始化) device = "cuda" if torch.cuda.is_available() else "cpu" print(f" 加载MGeo模型到 {device}...") model = SentenceTransformer("alienvs/mgeo-base-chinese-address").to(device) # 批量计算相似度 results = [] import time from datetime import datetime for i, pair in enumerate(pairs): start_time = time.time() emb_a = model.encode([pair['a']], convert_to_tensor=True) emb_b = model.encode([pair['b']], convert_to_tensor=True) sim = torch.cosine_similarity(emb_a, emb_b).item() elapsed = time.time() - start_time result = { 'a': pair['a'], 'b': pair['b'], 'score': round(sim, 4), 'time': datetime.now().isoformat() } results.append(result) if (i + 1) % 10 == 0: print(f" 已处理 {i+1}/{len(pairs)},当前相似度: {result['score']}") # 保存结果 save_results_to_csv(results, output_path) print(f" 结果已保存至 {output_path}") if __name__ == "__main__": main()3.2 测试脚本功能
在MGeo容器内验证新脚本:
# 创建测试输入文件 cat > /tmp/test_input.json << 'EOF' [ {"a": "北京市朝阳区建国路88号", "b": "北京朝阳建外88号"}, {"a": "上海市徐汇区漕溪北路1200号", "b": "上海徐家汇华亭宾馆"} ] EOF # 执行推理 python /root/推理.py /tmp/test_input.json /tmp/test_output.csv # 查看结果 cat /tmp/test_output.csv预期输出:
addr_a,addr_b,similarity,timestamp 北京市朝阳区建国路88号,北京朝阳建外88号,0.9234,2024-06-15T10:22:33.123456 上海市徐汇区漕溪北路1200号,上海徐家汇华亭宾馆,0.7821,2024-06-15T10:22:35.6543214. 构建Airflow DAG:批量地址对齐流水线
4.1 DAG核心逻辑设计
我们将创建一个名为mgeo_batch_matching的DAG,包含以下关键环节:
- 触发任务:支持手动触发或定时调度(如每天凌晨2点清洗昨日订单)
- 数据准备:从MySQL/CSV/HTTP下载待匹配地址对,生成JSON输入文件
- GPU推理:调用MGeo容器执行批量比对(关键!)
- 结果处理:解析CSV,筛选高相似度对(>0.85),写入结果表
- 清理与通知:删除临时文件,发送企业微信通知
4.2 编写DAG文件(dags/mgeo_batch_dag.py)
from datetime import datetime, timedelta from airflow import DAG from airflow.operators.python import PythonOperator from airflow.providers.docker.operators.docker import DockerOperator from airflow.providers.postgres.operators.postgres import PostgresOperator from airflow.models import Variable import json import os # DAG默认参数 default_args = { 'owner': 'data-engineer', 'depends_on_past': False, 'start_date': datetime(2024, 6, 15), 'email_on_failure': True, 'email': ['alert@company.com'], 'retries': 2, 'retry_delay': timedelta(minutes=5), } dag = DAG( 'mgeo_batch_matching', default_args=default_args, description='MGeo地址相似度批量匹配DAG', schedule_interval='0 2 * * *', # 每天凌晨2点执行 catchup=False, tags=['geo', 'mgeo', 'batch'], ) # 任务1:准备输入数据(模拟从MySQL读取) def prepare_input_data(**context): # 实际项目中替换为SQL查询或API调用 sample_pairs = [ {"a": "广州市天河区体育东路123号", "b": "广州天河正佳广场东门"}, {"a": "杭州市西湖区文三路159号", "b": "杭州文三路电子信息大厦"}, {"a": "成都市武侯区天府大道北段1700号", "b": "成都高新区环球中心"} ] # 写入临时JSON文件(Airflow Worker可访问路径) input_path = f"/tmp/mgeo_input_{context['ts_nodash']}.json" with open(input_path, 'w', encoding='utf-8') as f: json.dump(sample_pairs, f, ensure_ascii=False, indent=2) # 将路径传递给下游任务 context['ti'].xcom_push(key='input_path', value=input_path) print(f" 输入文件已生成: {input_path}") prepare_task = PythonOperator( task_id='prepare_input_data', python_callable=prepare_input_data, dag=dag, ) # 任务2:调用MGeo容器进行GPU推理 mgeo_inference = DockerOperator( task_id='run_mgeo_inference', image='registry.cn-hangzhou.aliyuncs.com/mgeo-team/mgeo-inference:latest', api_version='auto', auto_remove=True, command=[ 'python', '/root/推理.py', '{{ ti.xcom_pull(key="input_path") }}', '/tmp/mgeo_output_{{ ts_nodash }}.csv' ], docker_url='unix://var/run/docker.sock', network_mode='bridge', mounts=[ # 挂载宿主机/tmp目录,使输入输出文件双向可见 '/tmp:/tmp' ], # 关键:限制GPU资源,防止OOM extra_hosts={'host.docker.internal': 'host-gateway'}, environment={ 'NVIDIA_VISIBLE_DEVICES': '0' # 显式指定使用GPU 0 }, # 设置Docker Pool控制并发(需提前在Airflow UI创建pool) pool='gpu_pool', pool_slots=1, dag=dag, ) # 任务3:结果入库(模拟) def save_results_to_db(**context): output_path = f"/tmp/mgeo_output_{context['ts_nodash']}.csv" # 实际项目中:用pandas读取CSV,插入MySQL/PostgreSQL print(f" 正在处理结果文件: {output_path}") # 示例:打印前3行 with open(output_path, 'r', encoding='utf-8') as f: lines = f.readlines()[:4] for line in lines: print(line.strip()) print(" 结果已入库(模拟)") save_task = PythonOperator( task_id='save_results_to_db', python_callable=save_results_to_db, dag=dag, ) # 任务4:清理临时文件 def cleanup_temp_files(**context): input_path = context['ti'].xcom_pull(key='input_path') output_path = f"/tmp/mgeo_output_{context['ts_nodash']}.csv" for path in [input_path, output_path]: if os.path.exists(path): os.remove(path) print(f"🗑 已删除临时文件: {path}") cleanup_task = PythonOperator( task_id='cleanup_temp_files', python_callable=cleanup_temp_files, dag=dag, ) # 设置任务依赖关系 prepare_task >> mgeo_inference >> save_task >> cleanup_task4.3 创建GPU资源池(关键!)
在Airflow Web UI中创建资源池,确保同一时间最多1个MGeo任务使用GPU:
- 进入Admin → Pools
- 点击+添加新Pool
- Pool Name:
gpu_pool - Slots:
1(单卡只能运行1个MGeo容器) - Description:
MGeo GPU推理专用池
- Pool Name:
此设置保证:即使同时触发10个任务,也只会串行执行,避免显存爆炸。
5. 生产就绪:监控、告警与性能调优
5.1 监控关键指标
在Airflow Web UI中重点关注:
- DAG Duration:单次完整流程耗时(目标:<15分钟/万地址对)
- Task Duration:
run_mgeo_inference任务耗时(反映GPU利用率) - Failed Tasks:失败任务数(应为0,否则检查GPU日志)
- Pool Usage:
gpu_pool使用率(长期100%说明需扩容)
5.2 告警配置(企业微信示例)
在DAG中添加失败回调函数:
import requests def send_wechat_alert(context): url = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=YOUR_WEBHOOK_KEY" message = { "msgtype": "text", "text": { "content": f"🚨 MGeo批量任务失败\nDAG: {context['dag'].dag_id}\n任务: {context['task_instance'].task_id}\n时间: {context['execution_date']}\n日志: {context['task_instance'].log_url}" } } requests.post(url, json=message) # 在default_args中添加 default_args.update({ 'on_failure_callback': send_wechat_alert, })5.3 性能调优实战技巧
| 场景 | 问题 | 解决方案 |
|---|---|---|
| 慢 | 单地址对耗时>5秒 | 改用model.encode(addresses, batch_size=32)批量编码,而非逐个调用 |
| OOM | 容器启动报CUDA内存不足 | 在DockerOperator中增加mem_limit='20g',或减小batch_size |
| 不稳定 | 偶发网络超时 | 在DockerOperator中添加network_mode='host',避免bridge网络延迟 |
| 结果不准 | 相似度普遍偏低 | 检查输入地址是否含乱码,或升级MGeo模型至mgeo-large-chinese-address |
6. 扩展场景:多模型协同与增量更新
6.1 多模型投票提升鲁棒性
当业务对精度要求极高时,可并行调用多个地址模型,取相似度均值:
# 在DockerOperator中并行启动3个容器 mgeo_task = DockerOperator(task_id='mgeo', image='mgeo:latest', ...) bert_task = DockerOperator(task_id='bert', image='bert-address:latest', ...) jaccard_task = DockerOperator(task_id='jaccard', image='jaccard:latest', ...) # 使用TriggerDagRunOperator汇总结果6.2 增量地址对齐(CDC模式)
对于持续流入的新地址,可设计增量DAG:
- 触发条件:监听MySQL binlog或Kafka Topic
- 数据源:只拉取
last_update_time > 上次执行时间的地址对 - 去重:在结果表中添加
UNIQUE(addr_a, addr_b)约束,避免重复计算
总结
本文完整实现了MGeo地址相似度模型与Airflow的深度集成,将原本面向单次验证的推理脚本,升级为企业级批量处理引擎。核心成果包括:
- 零改造复用镜像:仅增强
推理.py参数解析能力,100%兼容原有部署; - GPU资源精控:通过Airflow Pool机制,确保单卡稳定支撑高并发任务;
- 全流程可观测:从任务触发、GPU执行、结果入库到清理,每步状态清晰可查;
- 生产就绪保障:内置重试、告警、监控、日志追溯,满足SLA要求。
下一步行动建议:
- 将DAG中
prepare_input_data替换为真实数据源(如MySQL查询、OSS文件下载);- 在
save_results_to_db中接入业务数据库,生成“疑似重复POI”报表;- 为高频地址对添加Redis缓存层,降低GPU调用频次;
- 结合Prometheus采集GPU显存/温度指标,实现容量预警。
批量不是终点,而是智能地址治理的起点。当MGeo遇上Airflow,地址数据便拥有了自我校验、自我演进的能力。
--- > **获取更多AI镜像** > > 想探索更多AI镜像和应用场景?访问 [CSDN星图镜像广场](https://ai.csdn.net/?utm_source=mirror_blog_end),提供丰富的预置镜像,覆盖大模型推理、图像生成、视频生成、模型微调等多个领域,支持一键部署。