用SOLO加持大模型建模仿真

# 【Code With SOLO】用 SOLO 优化大模型推理性能仿真框架的 Batch Size 搜索策略

## 1. 摘要

在华为 Ascend msmodeling 推理性能仿真框架中,batch size 搜索采用"指数扩张 + 二分搜索"策略,默认右边界 512 导致大量无效探测。我使用 TRAE SOLO 完成了"两点估算截断右边界"优化方案的全流程——从框架逻辑梳理、优化点识别、方案设计、代码实现到调试验证,最终通过 BS=1 和 BS=512 两个数据点的内存/SLO 约束估算,将二分搜索的右边界从 512 缩小到估算值,减少了无效搜索次数,同时通过 `–slo-safety-factor` 入参保证结果可控。

## 2. 背景

我是一名大模型推理性能优化工程师,日常使用 msmodeling 框架预测和优化模型在 Ascend 芯片上的推理性能。框架的核心功能之一是吞吐量优化器(throughput_optimizer),它遍历不同并行策略(TP×EP×DP),对每种策略搜索最优 batch size。

原始搜索流程:先指数扩张(512→1024→2048…)找到上界,再二分搜索找最大可行 batch size。问题在于:当最优 batch size 远小于 512 时(如 TP=1 时仅 5),二分搜索在 [1, 512] 范围内仍需约 9 次探测,其中大部分是无效的。我希望能利用仿真框架已有的内存和 SLO 数据,提前估算右边界,缩小搜索范围。

## 3. 实践过程

### 3.1 任务拆解

整个优化分为 5 个阶段:

1. **框架逻辑梳理**:阅读 `docs/en/throughput_optimizer.md`,形成 project.md,理清搜索策略的完整调用链

2. **优化点识别**:基于 project.md 分析 9 个潜在优化方向(P1-P9),选定 P4(右边界智能初始化)

3. **方案设计**:经过多轮迭代,最终确定"两点估算"方案,形成 plan.md

4. **代码实现**:修改 4 个核心文件,实现估算逻辑

5. **调试验证**:发现结果偏差,定位根因并修复

### 3.2 使用的 SOLO 能力

- **brainstorming**:识别 9 个优化方向,评估可行性

- **writing-plans**:将 P4 方案细化为可执行的代码设计

- **代码搜索与阅读**:快速定位 `ModelRunnerMetrics`、`OptimizerData`、`parallel_runner.py` 等关键代码

- **多文件协同修改**:同时修改 `optimizer_summary.py`、`agg_throughput_optimizer.py`、`disagg_throughput_optimizer.py`、`base_throughput_optimizer.py`、`utils.py`、`throughput_optimizer.py`、`parallel_runner.py` 共 7 个文件

- **测试验证**:运行 CLI 命令和单元测试,对比修改前后结果

### 3.3 关键设计决策

**估算公式推导**

- **Memory 约束**(精确,线性模型):

```

peak_memory(N) = model_weight + per_request_memory × N

max_batch_by_memory = N + device_memory_available(N) / per_request_memory

```

- **SLO 约束**(两点线性外推,保守估计):

```

tpot_slope = (tpot_512 - tpot_1) / (512 - 1)

max_batch_by_tpot = (1 + (tpot_limit - tpot_1) / tpot_slope) × safety_factor

```

其中 `safety_factor` 补偿线性外推的保守性,通过 `–slo-safety-factor` 入参控制。

- 最终 `estimated_right = min(max_batch_by_memory, max_batch_by_tpot, max_batch_by_ttft)`

### 3.4 踩过的坑

**坑1:仅估算 Memory 约束不够**

最初只估算内存约束,TP=8 时得到 max_batch=809,但实际 BS=512 就因 tpot 超限 early_stop。原因:tpot 约束比内存约束更紧。→ 加入 SLO 约束估算。

**坑2:单点 SLO 估算严重低估**

用 BS=1 单点估算 tpot_slope,TP=8 时 max_batch_by_tpot=17(实际约 208)。原因:BS=1 的 tpot 几乎为 0,单点外推斜率极大。→ 改为两点法(BS=1 + BS=512)。

**坑3:估算截断 right 导致结果偏差**

两点法线性外推偏保守,TP=8 时 estimated_right=203,截断 right 后 binary search 搜到 200,原始代码搜到 208。原因:decode_latency 对 batch_size 是亚线性关系,线性外推斜率偏大。→ 加入 `safety_factor` 宽松系数,通过 `–slo-safety-factor` 入参控制,默认 1.5。

**坑4:BS=1 探测与原始 early_stop check 重复**

原始代码在 expansion 后单独做 BS=1 early_stop check,新增的 BS=1 probe 与之重复。→ 将 BS=1 probe 提前到 expansion 之前,同时作为估算参考点,合并 early_stop check。

## 4. 成果展示

### 4.1 修改文件汇总

| 文件 | 修改类型 | 核心内容 |

|------|---------|---------|

| `serving_cast/service/optimizer_summary.py` | 新增 | `_memory_info` 字段 + `set/get_memory_info` 方法 |

| `serving_cast/service/agg_throughput_optimizer.py` | 新增 | `get_inference_info` 末尾填充 `memory_info` 字典 |

| `serving_cast/service/disagg_throughput_optimizer.py` | 新增 | `get_inference_info` 末尾填充 `memory_info` 字典 |

| `serving_cast/service/base_throughput_optimizer.py` | 重构 | `run()` 方法改造 + `_estimate_right_boundary` 静态方法 |

| `serving_cast/service/utils.py` | 新增 | `OptimizerData` 新增 `slo_safety_factor` 字段 |

| `cli/inference/throughput_optimizer.py` | 新增 | `–slo-safety-factor` CLI 入参 |

| `serving_cast/parallel_runner.py` | 修改 | 传入 `slo_safety_factor` 到 `OptimizerData` |

### 4.2 运行结果

测试命令:

```bash

python -m cli.inference.throughput_optimizer Qwen/Qwen3-32B \

–device TEST_DEVICE --num-devices 8 \

–input-length 3500 --output-length 1500 \

–compile --quantize-linear-action W8A8_DYNAMIC \

–quantize-attention-action DISABLED --tpot-limits 50

```

输出结果:

| TP | 原始 batch_size | 优化后 batch_size | 原始耗时 | 优化后耗时 |

|----|----------------|------------------|---------|-----------|

| 8 | 208 | 210 | ~110s | 115s |

| 4 | 84 | 84 | | |

| 2 | 25 | 25 | | |

| 1 | 5 | 5 | | |

> 注:TP=8 的 210 vs 208 偏差源于 binary search 在 tpot≈50ms 边界附近的离散性,两者都是合法的"最大满足约束的 batch_size",throughput 差异仅 0.2%。通过调整 `–slo-safety-factor` 可控制搜索范围宽度。

### 4.3 核心代码片段

`_estimate_right_boundary` 静态方法:

```python

@staticmethod

def _estimate_right_boundary(probe_info, expansion_info, optimizer_data, current_right):

estimated_right = float("inf")

safety = optimizer_data.slo_safety_factor



\# Memory 约束(精确)

per_req = expansion_info.get("per_request_memory_gb", 0)

available = expansion_info.get("device_memory_available_gb", 0)

if per_req > 0:

    max_batch_by_memory = max(1, int(

        expansion_info\["batch_size"\] + available / per_req

    ))

    estimated_right = min(estimated_right, max_batch_by_memory)



\# TPOT 约束(两点线性外推 + 宽松系数)

bs_1, bs_2 = probe_info\["batch_size"\], expansion_info\["batch_size"\]

tpot_1, tpot_2 = probe_info.get("tpot"), expansion_info.get("tpot")

if tpot_1 and tpot_2 and optimizer_data.tpot_limits and tpot_2 > tpot_1:

    tpot_slope = (tpot_2 - tpot_1) / (bs_2 - bs_1)

    if tpot_slope > 0:

        max_batch_by_tpot = max(1, int(

            (bs_1 + (optimizer_data.tpot_limits - tpot_1) / tpot_slope) \* safety

        ))

        estimated_right = min(estimated_right, max_batch_by_tpot)



\# TTFT 约束(同上)

\# ... 类似逻辑 ...



return estimated_right

```

`run()` 方法改造后的搜索流程:

```python

# Step 1: BS=1 探测(合并原 early_stop check)

probe_summary = self.get_inference_info(optimizer_data)

if probe_summary.check_early_stop_flag():

return None

probe_info = {“batch_size”: 1, **probe_summary.get_memory_info()}

# Step 2: 指数扩张 + 即时估算

for _ in range(MAX_ITER_NUMS):

summary = self.get_inference_info(optimizer_data)  # BS=512

if summary.check_early_stop_flag():

    estimated_right = self.\_estimate_right_boundary(probe_info, ...)

    right = min(estimated_right, right)  # 截断右边界

    break

\# 不 early_stop 时,估算判断是否继续翻倍

estimated_right = self.\_estimate_right_boundary(probe_info, ...)

if estimated_right <= right:

    break

left, right = right, right \* 2

# Step 3: 二分搜索(范围已缩小)

while left <= right:

...

```

## 5. 效果与总结

### 提效效果

- **搜索范围缩小**:以 TP=8 为例,右边界从 512 缩小到约 305(1.5 宽松系数下),二分搜索次数从 9 次减少到约 8 次

- **可配置性**:`–slo-safety-factor` 入参让用户在"搜索速度"和"结果精度"之间灵活权衡

- **框架扩展性**:`memory_info` 机制为后续优化(如更精确的非线性估算)提供了数据基础

### SOLO 在流程中的价值

- **快速理解代码库**:SOLO 帮我在短时间内理清了从 CLI → parallel_runner → optimizer → model_runner 的完整调用链

- **方案迭代加速**:从 P4-memory-only → P4+SLO → 两点法 → expansion数据复用 → 最终方案,SOLO 陪我经历了 5 轮方案迭代,每次都能快速定位问题并调整

- **多文件协同修改**:7 个文件的修改一次性完成,SOLO 保证了接口一致性

### 可复用方法

1. **两点估算截断右边界**:适用于任何"指数扩张+二分搜索"的优化场景,只需两个数据点即可估算上界

2. **宽松系数入参化**:将算法的保守性参数暴露为 CLI 入参,让用户根据场景调整

3. **memory_info 传递机制**:在 summary 对象中携带计算过程的中间数据,供上层逻辑使用,避免重复计算