Inferix
论文链接在这里 ,这是一篇旨在构建现代世界模型的推理引擎的文章,对标vllm,sglang等对于语言模型的优化。 inferix主要是对于长时间,高质量,可交互的video generation的推理优化
引入了序列并行,低比特量化,稀疏Attention,分布式并行等加速手段
我通过claude code辅助进行了这个codebase的理解
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 | Inferix/
├── inferix/ # 核心框架代码
│ ├── core/ # 核心模块
│ ├── distributed/ # 分布式推理支持
│ ├── kvcache_manager/ # KV缓存管理
│ ├── models/ # 模型实现
│ ├── pipeline/ # 推理管线
│ └── profiling/ # 性能分析
├── example/ # 示例和配置
│ ├── self_forcing/ # Self-Forcing模型示例
│ ├── causvid/ # CausVid模型示例
│ ├── magi/ # MAGI模型示例
│ ├── quantization/ # 量化示例
│ ├── profiling/ # 性能分析示例
│ └── rtmp_streaming/ # 视频流示例
├── LV-Bench/ # 长视频基准测试
└── tests/ # 测试代码
|
0. 理解推理框架
在阅读推理框架之前我们要了解一个问题,模型参数和推理框架的代码是如何互动的
我们有两个部分,一个是模型的权重,例如 model.safetensors 这类,当我们使用vllm这类的框架时,会进行如下的操作
| from vllm import LLM
model = LLM(model="meta-llama/Llama-2-7b-hf")
|
而在其内部,则会将这些权重进行封装,人们在发布模型的时候会标记好参数的权重的含义
我们简化vllm内部的实现,大概是这个样子的
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20 | # 简化示例:vLLM 内部的权重加载逻辑
def load_weights(model, hf_model_path):
# 1. 从 HuggingFace 加载权重
state_dict = torch.load(f"{hf_model_path}/pytorch_model.bin")
# 2. 权重名称映射
# HF: model.layers.0.self_attn.q_proj.weight
# vLLM: model.layers.0.attention.qkv_proj.weight (合并的QKV)
for name, param in model.named_parameters():
if "qkv_proj" in name:
# vLLM 可能会合并 Q、K、V 权重以优化性能
layer_idx = extract_layer_idx(name)
q_weight = state_dict[f"layers.{layer_idx}.self_attn.q_proj.weight"]
k_weight = state_dict[f"layers.{layer_idx}.self_attn.k_proj.weight"]
v_weight = state_dict[f"layers.{layer_idx}.self_attn.v_proj.weight"]
# 合并权重
qkv_weight = torch.cat([q_weight, k_weight, v_weight], dim=0)
param.data.copy_(qkv_weight)
|
我们会从权重中获取到 q k v 的权重,随后我们要进行flash attention还是分布式操作,就随着推理框架的代码走
1. pipeline
该项目的核心接口是一个pipeline,其支持以下三种实现
Self-Forcing 其流程是单一提示词生成,支持T2V和I2V
- 基础模型: Wan2.1-T2V-1.3B
- 核心技术: DMD (Distribution Matching Distillation)
- 架构特点:
- 使用因果扩散模型 (CausalDiffusionInferencePipeline)
- 支持 EMA (Exponential Moving Average) 权重
- 基于 Wan Base 扩展
CausVid 其流程是连续提示词生成,以及交互式的提示输入
- 基础模型: Wan2.1-T2V-1.3B
- 核心技术: 因果视频生成(Causal Video)
- 架构特点:
- 使用因果推理管线 (CausalInferencePipeline)
- 支持 rollout 机制生成更长视频
- 基于 Wan Base 扩展
MAGI 支持V2V
- 基础模型: 独立的多模态架构
- 核心技术: 完全独立的 DiT (Diffusion Transformer) 架构
- 架构特点:
- 完全独立的模型实现(DiT + VAE + T5)
- 不依赖 Wan Base
- 支持多种规模(4.5B 和 24B)
我们来细看一下CausVid的代码,也就是基于wan模型的交互性生成的pipeline,这个部分有两个代码
CausalInferencePipeline.py 这个是底层的推理,输入的噪声和提示,给出生成的video
pipeline.py 这个是上层的接口,其编排整个生成流程,并保存结果
CausalInferencePipeline
这个代码的核心部分如下
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50 | class CausalInferencePipeline(torch.nn.Module):
def __init__(
self,
args,
wan_base_model_path,
device,
enable_kv_offload,
parallel_config: Optional[ParallelConfig] = None
):
super().__init__()
self.parallel_config = parallel_config
# Step 1: Initialize all models
# Get model path from configuration (following README path convention)
model_path = wan_base_model_path if wan_base_model_path is not None else 'weights/Wan2.1-T2V-1.3B'
self.generator = WanDiffusionWrapper(
model_path=model_path,
**getattr(args, "model_kwargs", {}), enable_kv_offload=enable_kv_offload, parallel_config=parallel_config).to(device)
self.text_encoder = WanTextEncoder(model_path=model_path)
self.vae = WanVAEWrapper(model_path=model_path)
if dist.is_initialized():
dist.barrier()
# Step 2: Initialize all causal hyperparmeters
...
# Other Method
...
def inference(self, noise: torch.Tensor, text_prompts: List[str], start_latents: Optional[torch.Tensor], return_latents: bool = True, kv_cache_manager: Optional[KVCacheManager] = None, kv_cache_requests: Optional[List] = None) -> torch.Tensor:
"""
Perform inference on the given noise and text prompts.
Inputs:
noise (torch.Tensor): The input noise tensor of shape
(batch_size, num_frames, num_channels, height, width).
text_prompts (List[str]): The list of text prompts.
Outputs:
video (torch.Tensor): The generated video tensor of shape
(batch_size, num_frames, num_channels, height, width). It is normalized to be in the range [0, 1].
"""
batch_size, num_frames, num_channels, height, width = noise.shape
...
if return_latents:
return video, output
else:
return video
|
主体部分是初始化和推理,初始化部分主要是这三个模型,一个wan模型的基座,一个文本编码器,一个vae
| self.generator = WanDiffusionWrapper(...)
self.text_encoder = WanTextEncoder(model_path=model_path)
self.vae = WanVAEWrapper(model_path=model_path)
|
这些部分的实现都放在 model/ 文件夹下
主要来看推理部分,推理的主干部分是两个循环,因为其设计是block diffusion,每个block块内根据时间步进行diffusion,然后外侧的循环是AR的过程
| for block_index in range(num_blocks):
noisy_input = noise[:, block_index *
self.num_frame_per_block:(block_index + 1) * self.num_frame_per_block]
... # 判定是否要填充start_latents
for index, current_timestep in enumerate(self.denoising_step_list):
# 迭代去噪
...
# 更新kv-cache
...
|
这里还有一个start_latents的参数,这个参数是用于image2video和video2video的任务,这个参数表示的生成开头的前几帧的latents,比如我们在image2video的任务中,我们会先计算好这个image的latents,然后直接将其纳入到生成的pipeline中,也就是生成的前几帧并不需要去噪,而是直接将这些images的latents填充进入pipeline中,并且更新kv-cache
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23 | if start_latents is not None and block_index < num_input_blocks:
# 设置 timestep 为0,标记这为去噪完成的帧
timestep = torch.ones(
[batch_size, self.num_frame_per_block], device=noise.device, dtype=torch.int64) * 0
# 提取当前块的参考帧
current_ref_latents = start_latents[:, block_index * self.num_frame_per_block:(
block_index + 1) * self.num_frame_per_block]
# 直接复制到输出
output[:, block_index * self.num_frame_per_block:(
block_index + 1) * self.num_frame_per_block] = current_ref_latents
# 填充kv-cache
self.generator(
noisy_image_or_video=current_ref_latents,
conditional_dict=conditional_dict,
timestep=timestep * 0,
current_start=block_index * self.num_frame_per_block * self.frame_seq_length,
current_end=(block_index + 1) * self.num_frame_per_block * self.frame_seq_length,
kv_start=block_index * self.num_frame_per_block * self.per_rank_frame_seq_length,
kv_end=(block_index + 1) * self.num_frame_per_block * self.per_rank_frame_seq_length,
kv_cache_manager=kv_cache_manager,
kv_cache_requests=kv_cache_requests,
)
continue # 填充完成之后进入下一步循环
|
过了这个start_latents的判断之后,就是进入了去噪的过程,去噪的过程非常直接,就是传统的ddpm的过程
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35 | for index, current_timestep in enumerate(self.denoising_step_list):
# set current timestep
timestep = torch.ones(
[batch_size, self.num_frame_per_block], device=noise.device, dtype=torch.int64) * current_timestep
if index < len(self.denoising_step_list) - 1:
# 去噪
denoised_pred = self.generator(
noisy_image_or_video=noisy_input,
conditional_dict=conditional_dict,
timestep=timestep,
...
)
... # 中间处理
# 加噪
noisy_input_flat = self.scheduler.add_noise(
denoised_pred.flatten(0, 1),
torch.randn_like(denoised_pred.flatten(0, 1)),
timestep_tensor
)
# Ensure result is not None before view operation
if noisy_input_flat is not None:
noisy_input = noisy_input_flat.view(denoised_pred.shape)
else:
raise RuntimeError("scheduler.add_noise returned None")
else:
# 最后一步,获得最终输出
denoised_pred = self.generator(
noisy_image_or_video=noisy_input,
conditional_dict=conditional_dict,
timestep=timestep,
...
)
|
在完成一个block的diffusion之后,在最后一步再调用一遍generator,更新kv-cache
pipeline
我们看完了用于实际生成的pipeline,我们来看上层的业务接口
| 用户调用
↓
CausVidPipeline (业务层)
↓
CausalInferencePipeline (底层去噪)
↓
Generator (DiT/Transformer 模型)
|
pipeline中主要是处理如何分配每一次生成和上一次生成的start_latents的关系
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21 | 假设 num_overlap_frames = 3
第一轮生成:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
Frame: 0 1 2 3 4 5 ... 17 18 19 20
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
└────┬────┘
准备作为下一轮的
start_latents
(Frame 18, 19, 20)
保存到 all_video: Frame 0-17 (去掉重叠的 18-20)
第二轮生成:
start_latents = [Frame 18, 19, 20] (来自上一轮)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
Frame: 18 19 20 21 22 23 ... 35 36 37 38
└──固定──┘ └────新生成────┘ └─重叠─┘
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
保存到 all_video: Frame 21-35 (去掉 18-20 和 36-38)
|
2. models
该项目中的models被划分成了两层关系,一层是比较底层可复用的实现,比如attention;一层是高层的,和pipeline对接的部分,比如说causvid部分
3. kv-cache
KV cache的管控是项目中最重要的部分之一
4. profiler