# 【Code With SOLO】用 SOLO 优化大模型推理性能仿真框架的 Batch Size 搜索策略
## 1. 摘要
在华为 Ascend msmodeling 推理性能仿真框架中,batch size 搜索采用"指数扩张 + 二分搜索"策略,默认右边界 512 导致大量无效探测。我使用 TRAE SOLO 完成了"两点估算截断右边界"优化方案的全流程——从框架逻辑梳理、优化点识别、方案设计、代码实现到调试验证,最终通过 BS=1 和 BS=512 两个数据点的内存/SLO 约束估算,将二分搜索的右边界从 512 缩小到估算值,减少了无效搜索次数,同时通过 `–slo-safety-factor` 入参保证结果可控。
## 2. 背景
我是一名大模型推理性能优化工程师,日常使用 msmodeling 框架预测和优化模型在 Ascend 芯片上的推理性能。框架的核心功能之一是吞吐量优化器(throughput_optimizer),它遍历不同并行策略(TP×EP×DP),对每种策略搜索最优 batch size。
原始搜索流程:先指数扩张(512→1024→2048…)找到上界,再二分搜索找最大可行 batch size。问题在于:当最优 batch size 远小于 512 时(如 TP=1 时仅 5),二分搜索在 [1, 512] 范围内仍需约 9 次探测,其中大部分是无效的。我希望能利用仿真框架已有的内存和 SLO 数据,提前估算右边界,缩小搜索范围。
## 3. 实践过程
### 3.1 任务拆解
整个优化分为 5 个阶段:
1. **框架逻辑梳理**:阅读 `docs/en/throughput_optimizer.md`,形成 project.md,理清搜索策略的完整调用链
2. **优化点识别**:基于 project.md 分析 9 个潜在优化方向(P1-P9),选定 P4(右边界智能初始化)
3. **方案设计**:经过多轮迭代,最终确定"两点估算"方案,形成 plan.md
4. **代码实现**:修改 4 个核心文件,实现估算逻辑
5. **调试验证**:发现结果偏差,定位根因并修复
### 3.2 使用的 SOLO 能力
- **brainstorming**:识别 9 个优化方向,评估可行性
- **writing-plans**:将 P4 方案细化为可执行的代码设计
- **代码搜索与阅读**:快速定位 `ModelRunnerMetrics`、`OptimizerData`、`parallel_runner.py` 等关键代码
- **多文件协同修改**:同时修改 `optimizer_summary.py`、`agg_throughput_optimizer.py`、`disagg_throughput_optimizer.py`、`base_throughput_optimizer.py`、`utils.py`、`throughput_optimizer.py`、`parallel_runner.py` 共 7 个文件
- **测试验证**:运行 CLI 命令和单元测试,对比修改前后结果
### 3.3 关键设计决策
**估算公式推导**:
- **Memory 约束**(精确,线性模型):
```
peak_memory(N) = model_weight + per_request_memory × N
max_batch_by_memory = N + device_memory_available(N) / per_request_memory
```
- **SLO 约束**(两点线性外推,保守估计):
```
tpot_slope = (tpot_512 - tpot_1) / (512 - 1)
max_batch_by_tpot = (1 + (tpot_limit - tpot_1) / tpot_slope) × safety_factor
```
其中 `safety_factor` 补偿线性外推的保守性,通过 `–slo-safety-factor` 入参控制。
- 最终 `estimated_right = min(max_batch_by_memory, max_batch_by_tpot, max_batch_by_ttft)`
### 3.4 踩过的坑
**坑1:仅估算 Memory 约束不够**
最初只估算内存约束,TP=8 时得到 max_batch=809,但实际 BS=512 就因 tpot 超限 early_stop。原因:tpot 约束比内存约束更紧。→ 加入 SLO 约束估算。
**坑2:单点 SLO 估算严重低估**
用 BS=1 单点估算 tpot_slope,TP=8 时 max_batch_by_tpot=17(实际约 208)。原因:BS=1 的 tpot 几乎为 0,单点外推斜率极大。→ 改为两点法(BS=1 + BS=512)。
**坑3:估算截断 right 导致结果偏差**
两点法线性外推偏保守,TP=8 时 estimated_right=203,截断 right 后 binary search 搜到 200,原始代码搜到 208。原因:decode_latency 对 batch_size 是亚线性关系,线性外推斜率偏大。→ 加入 `safety_factor` 宽松系数,通过 `–slo-safety-factor` 入参控制,默认 1.5。
**坑4:BS=1 探测与原始 early_stop check 重复**
原始代码在 expansion 后单独做 BS=1 early_stop check,新增的 BS=1 probe 与之重复。→ 将 BS=1 probe 提前到 expansion 之前,同时作为估算参考点,合并 early_stop check。
## 4. 成果展示
### 4.1 修改文件汇总
| 文件 | 修改类型 | 核心内容 |
|------|---------|---------|
| `serving_cast/service/optimizer_summary.py` | 新增 | `_memory_info` 字段 + `set/get_memory_info` 方法 |
| `serving_cast/service/agg_throughput_optimizer.py` | 新增 | `get_inference_info` 末尾填充 `memory_info` 字典 |
| `serving_cast/service/disagg_throughput_optimizer.py` | 新增 | `get_inference_info` 末尾填充 `memory_info` 字典 |
| `serving_cast/service/base_throughput_optimizer.py` | 重构 | `run()` 方法改造 + `_estimate_right_boundary` 静态方法 |
| `serving_cast/service/utils.py` | 新增 | `OptimizerData` 新增 `slo_safety_factor` 字段 |
| `cli/inference/throughput_optimizer.py` | 新增 | `–slo-safety-factor` CLI 入参 |
| `serving_cast/parallel_runner.py` | 修改 | 传入 `slo_safety_factor` 到 `OptimizerData` |
### 4.2 运行结果
测试命令:
```bash
python -m cli.inference.throughput_optimizer Qwen/Qwen3-32B \
–device TEST_DEVICE --num-devices 8 \
–input-length 3500 --output-length 1500 \
–compile --quantize-linear-action W8A8_DYNAMIC \
–quantize-attention-action DISABLED --tpot-limits 50
```
输出结果:
| TP | 原始 batch_size | 优化后 batch_size | 原始耗时 | 优化后耗时 |
|----|----------------|------------------|---------|-----------|
| 8 | 208 | 210 | ~110s | 115s |
| 4 | 84 | 84 | | |
| 2 | 25 | 25 | | |
| 1 | 5 | 5 | | |
> 注:TP=8 的 210 vs 208 偏差源于 binary search 在 tpot≈50ms 边界附近的离散性,两者都是合法的"最大满足约束的 batch_size",throughput 差异仅 0.2%。通过调整 `–slo-safety-factor` 可控制搜索范围宽度。
### 4.3 核心代码片段
`_estimate_right_boundary` 静态方法:
```python
@staticmethod
def _estimate_right_boundary(probe_info, expansion_info, optimizer_data, current_right):
estimated_right = float("inf")
safety = optimizer_data.slo_safety_factor
\# Memory 约束(精确)
per_req = expansion_info.get("per_request_memory_gb", 0)
available = expansion_info.get("device_memory_available_gb", 0)
if per_req > 0:
max_batch_by_memory = max(1, int(
expansion_info\["batch_size"\] + available / per_req
))
estimated_right = min(estimated_right, max_batch_by_memory)
\# TPOT 约束(两点线性外推 + 宽松系数)
bs_1, bs_2 = probe_info\["batch_size"\], expansion_info\["batch_size"\]
tpot_1, tpot_2 = probe_info.get("tpot"), expansion_info.get("tpot")
if tpot_1 and tpot_2 and optimizer_data.tpot_limits and tpot_2 > tpot_1:
tpot_slope = (tpot_2 - tpot_1) / (bs_2 - bs_1)
if tpot_slope > 0:
max_batch_by_tpot = max(1, int(
(bs_1 + (optimizer_data.tpot_limits - tpot_1) / tpot_slope) \* safety
))
estimated_right = min(estimated_right, max_batch_by_tpot)
\# TTFT 约束(同上)
\# ... 类似逻辑 ...
return estimated_right
```
`run()` 方法改造后的搜索流程:
```python
# Step 1: BS=1 探测(合并原 early_stop check)
probe_summary = self.get_inference_info(optimizer_data)
if probe_summary.check_early_stop_flag():
return None
probe_info = {“batch_size”: 1, **probe_summary.get_memory_info()}
# Step 2: 指数扩张 + 即时估算
for _ in range(MAX_ITER_NUMS):
summary = self.get_inference_info(optimizer_data) # BS=512
if summary.check_early_stop_flag():
estimated_right = self.\_estimate_right_boundary(probe_info, ...)
right = min(estimated_right, right) # 截断右边界
break
\# 不 early_stop 时,估算判断是否继续翻倍
estimated_right = self.\_estimate_right_boundary(probe_info, ...)
if estimated_right <= right:
break
left, right = right, right \* 2
# Step 3: 二分搜索(范围已缩小)
while left <= right:
...
```
## 5. 效果与总结
### 提效效果
- **搜索范围缩小**:以 TP=8 为例,右边界从 512 缩小到约 305(1.5 宽松系数下),二分搜索次数从 9 次减少到约 8 次
- **可配置性**:`–slo-safety-factor` 入参让用户在"搜索速度"和"结果精度"之间灵活权衡
- **框架扩展性**:`memory_info` 机制为后续优化(如更精确的非线性估算)提供了数据基础
### SOLO 在流程中的价值
- **快速理解代码库**:SOLO 帮我在短时间内理清了从 CLI → parallel_runner → optimizer → model_runner 的完整调用链
- **方案迭代加速**:从 P4-memory-only → P4+SLO → 两点法 → expansion数据复用 → 最终方案,SOLO 陪我经历了 5 轮方案迭代,每次都能快速定位问题并调整
- **多文件协同修改**:7 个文件的修改一次性完成,SOLO 保证了接口一致性
### 可复用方法
1. **两点估算截断右边界**:适用于任何"指数扩张+二分搜索"的优化场景,只需两个数据点即可估算上界
2. **宽松系数入参化**:将算法的保守性参数暴露为 CLI 入参,让用户根据场景调整
3. **memory_info 传递机制**:在 summary 对象中携带计算过程的中间数据,供上层逻辑使用,避免重复计算