
Decentralized disaggregated deployment architecture #947

Open
fuheaven wants to merge 463 commits into ModelTC:main from fuheaven:disagg

Conversation

@fuheaven (Contributor) commented Mar 16, 2026

Summary

Integrated Mooncake's disaggregated deployment mode into the runner to provide LightX2V with full three-stage disaggregated inference capability. The inference pipeline can be split into Encoder, Transformer, and Decoder nodes, where the VAE Decoder is deployed independently on the Decoder node. This support includes both Wan and Qwen model families.

On top of the three-stage foundation, this PR further introduces decentralized queue scheduling: a Controller process hosts RDMA metadata ring buffers, Transformer and Decoder run as pull-based workers, and the client only needs a single HTTP POST to the Encoder—no more three-way sequential requests. Multiple Transformer workers can be deployed across GPUs for parallel DiT execution.

Feature Highlights

  1. Disaggregated deployment is integrated with the Mooncake engine, enabling efficient RDMA-based data transfer. Inference I/O can approach the theoretical peak bandwidth of the GPU interconnect.
  2. The text encoder component is integrated with LightLLM optimizations. It supports kernel-level optimizations and service-level optimizations, delivering an additional ~30% performance improvement.
  3. Compared with Mooncake's standalone disagg submission, this integration is implemented inside the local runner and currently supports both the Wan and Qwen runners.
  4. In Mooncake's original disagg approach, each stage runs as different threads within a unified process, which creates tight producer/consumer coupling and does not match high-concurrency scenarios. We decouple them into independent processes, allowing the three stages (encoder + transformer + decoder) to be deployed on different machines and different GPUs. Under high concurrency, this improves throughput.
  5. Decentralized queue scheduling with RDMA ring buffers (RDMABuffer): a Controller hosts request / phase1 / phase2 metadata rings; Encoder publishes dispatch metadata after inference; Transformer and Decoder workers pull tasks from the rings automatically. The client sends one HTTP request to the Encoder instead of three sequential POSTs.
  6. Multi-Transformer worker parallelism: multiple Transformer workers (each with a unique receiver_engine_rank) can run on different GPUs. Requests specify disagg_phase1_receiver_engine_rank to target a specific worker, enabling round-robin or explicit routing.
  7. True RDMA atomics: rdma_faa upgraded from a read-modify-write shim to a real IBV_WR_ATOMIC_FETCH_AND_ADD; a new rdma_cas (IBV_WR_ATOMIC_CMP_AND_SWP) is added. Both RDMAServer and RDMAClient register the REMOTE_ATOMIC access flag.
  8. Queue metrics & monitoring: each service (Encoder / Transformer / Decoder) reports queue depth (queue_sizes, queue_total_pending, all_queues_empty) via the Reporter's set_extra_metrics_provider() hook, providing real-time pipeline backlog visibility.
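
The metrics hook in point 8 might be wired up like this sketch (the metric keys come from the PR; the Reporter's set_extra_metrics_provider() is assumed to accept a zero-argument callable like the one built here):

```python
import queue

def make_queue_metrics_provider(queues):
    """Build a provider that snapshots the depth of each named queue.

    `queues` maps a queue name to a queue.Queue instance. The returned
    callable produces the queue_sizes / queue_total_pending /
    all_queues_empty metrics named in the PR.
    """
    def provider():
        sizes = {name: q.qsize() for name, q in queues.items()}
        return {
            "queue_sizes": sizes,
            "queue_total_pending": sum(sizes.values()),
            "all_queues_empty": all(s == 0 for s in sizes.values()),
        }
    return provider

# Demo with in-process queues; in the services these would be the
# per-stage pending-task queues.
phase1, phase2 = queue.Queue(), queue.Queue()
phase1.put("task-1")
provider = make_queue_metrics_provider({"phase1": phase1, "phase2": phase2})
metrics = provider()
# reporter.set_extra_metrics_provider(provider)  # hypothetical hook wiring
```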

Disaggregated Architecture (Three-Stage Pipeline)

Based on the disagg_mode configuration, the inference pipeline is physically split into three independent services. Data flows through Phase1 (Encoder → Transformer) and Phase2 (Transformer → Decoder), requiring two Mooncake transfers.
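
For orientation, a minimal config sketch for the Transformer role (hypothetical: the shipped configs under configs/wan/ and configs/qwen_image/ carry many more fields; only disagg_mode, receiver_engine_rank, and decoder_engine_rank are named in this PR):

```json
{
  "disagg_mode": "transformer",
  "receiver_engine_rank": 1,
  "decoder_engine_rank": 0
}
```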

Encoder Role (disagg_mode="encoder")

  • Loads only:
    • Text Encoder
    • Image Encoder (for I2V / I2I)
    • VAE Encoder
  • Skips:
    • DiT
    • VAE Decoder (handled by the Decoder node in the three-stage setup)

After startup, it performs feature extraction and sends tensors through Mooncake Phase1 to the Transformer node, including:

  • context
  • clip_encoder_out
  • vae_encoder_out
  • latent_shape
  • (other required intermediate tensors)
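
As a toy illustration of this payload plus the hash verification the Transformer performs on receipt (a sketch only: real payloads are torch tensors moved over Mooncake RDMA, and the actual hashing scheme in disagg_mixin.py may differ):

```python
import hashlib
import json

def build_phase1_payload(context, clip_encoder_out, vae_encoder_out, latent_shape):
    """Assemble the Phase1 payload and a digest of its serialized form."""
    payload = {
        "context": context,
        "clip_encoder_out": clip_encoder_out,
        "vae_encoder_out": vae_encoder_out,
        "latent_shape": latent_shape,
    }
    blob = json.dumps(payload, sort_keys=True).encode()
    return payload, hashlib.sha256(blob).hexdigest()

def verify_phase1_payload(payload, expected_digest):
    """Receiver-side check that the payload arrived intact."""
    blob = json.dumps(payload, sort_keys=True).encode()
    return hashlib.sha256(blob).hexdigest() == expected_digest

# Toy values standing in for real encoder outputs.
payload, digest = build_phase1_payload([0.1, 0.2], [0.3], [0.4], [1, 16, 32, 32])
assert verify_phase1_payload(payload, digest)
```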

Transformer Role (disagg_mode="transformer")

  • Loads only:
    • DiT
  • Skips:
    • Encoder
    • VAE Decoder
    • (VAE decoding is handled by the Decoder node)

After startup, it waits for Phase1 data. Upon receiving it, it performs:

  • Hash verification
  • Input assembly
  • Denoising

If decoder_engine_rank is configured, it sends the denoised latent space to the Decoder node via Mooncake Phase2, and does not perform local VAE decoding.

Decoder Role (disagg_mode="decode")

  • Loads only:
    • VAE Decoder
  • Skips:
    • Text/Image Encoder
    • DiT

After startup, it enters a Phase2 receive-and-wait state. When it receives the latent space from the Transformer, it performs:

  • VAE decoding
  • Saving output videos/images

Both task completion status and result files are stored on the Decoder node.

Decentralized Queue Scheduling

Architecture

┌──────────┐  HTTP POST   ┌──────────┐ Phase1 RDMA ┌─────────────┐ Phase2 RDMA ┌──────────┐
│  Client  │ ──────────→  │ Encoder  │ ──────────→ │ Transformer │ ──────────→ │ Decoder  │
└──────────┘              │ (GPU 0)  │             │ (GPU 1/2/3) │             │ (GPU 0)  │
                          └──────────┘             └─────────────┘             └──────────┘
                                ↑                        ↑                          ↑
                          lightx2v.server          pull worker ×N              pull worker
                          HTTP port 8002          (qwen_t2i_queue_workers)   (qwen_t2i_queue_workers)
                                │
                          ┌──────────┐
                          │Controller│  ← RDMA metadata ring buffers (always-on)
                          └──────────┘

How it differs from standard three-stage

| Aspect | Standard three-stage | Decentralized queue |
| --- | --- | --- |
| Client calls | Must POST to Decoder → Transformer → Encoder separately | Single POST to Encoder HTTP |
| Transformer | HTTP server, one request at a time | Pull worker, multiple instances consume in parallel |
| Decoder | HTTP server | Pull worker, auto-consumes Phase2 |
| Request routing | Client explicitly specifies | Encoder writes RDMA ring, workers pull by rank |
| Result retrieval | Poll Decoder HTTP | Poll Encoder HTTP |
| Scaling | Fixed 1:1:1 ratio | N Transformer workers on N GPUs |

Data flow

  1. Client POSTs to Encoder HTTP (/v1/tasks/image/) with prompt, data_bootstrap_room (unique room ID), and disagg_phase1_receiver_engine_rank (target Transformer rank).
  2. Encoder runs Text Encoder inference, creates a per-request Mooncake session, sends feature tensors via Phase1, and publishes dispatch metadata to the Phase1 RDMA ring.
  3. Transformer (pull worker) consumes the Phase1 ring slot matching its rank, initializes Mooncake Phase1 receiver + Phase2 sender, runs DiT denoising, sends latents via Phase2, and publishes dispatch metadata to the Phase2 RDMA ring.
  4. Decoder (pull worker) consumes the Phase2 ring, initializes Mooncake Phase2 receiver, runs VAE decode, and saves the output image.
  5. Client polls Encoder's /v1/tasks/{task_id}/status until completed.
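
Steps 1 and 5 above, seen from the client side, might look like this sketch (the endpoint paths and the three request fields follow the description; the task_id response field and the exact status values are assumptions):

```python
import json
import time
import urllib.request

ENCODER_URL = "http://localhost:8002"  # assumed Encoder HTTP address

def submit_and_wait(prompt, room_id, transformer_rank, timeout_s=600):
    """POST one request to the Encoder, then poll it until completion."""
    body = json.dumps({
        "prompt": prompt,
        "data_bootstrap_room": room_id,               # unique Mooncake room ID
        "disagg_phase1_receiver_engine_rank": transformer_rank,
    }).encode()
    req = urllib.request.Request(
        f"{ENCODER_URL}/v1/tasks/image/",
        data=body,
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        task_id = json.load(resp)["task_id"]          # assumed response field
    deadline = time.time() + timeout_s
    while time.time() < deadline:
        with urllib.request.urlopen(f"{ENCODER_URL}/v1/tasks/{task_id}/status") as resp:
            status = json.load(resp)
        if status.get("status") == "completed":       # assumed status value
            return status
        time.sleep(1.0)
    raise TimeoutError(f"task {task_id} not completed within {timeout_s}s")
```

Round-robin routing across N Transformer workers then reduces to cycling transformer_rank over the configured receiver_engine_rank values.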

Key components

  • Controller (ControllerService.serve_rdma_dispatch_only()): hosts three RDMA ring buffers (request / phase1 / phase2), no model loading, always-on background process.
  • RDMABuffer (rdma_buffer.py): shared ring buffer over RDMAServer/RDMAClient with slot-level atomic coordination for multi-producer/multi-consumer JSON dispatch.
  • Pull workers (qwen_t2i_queue_workers.py): Transformer and Decoder worker loops that consume from RDMA rings via disagg_try_consume_phase1() / disagg_try_consume_phase2(), then call disagg_transformer_prepare_dispatch() / disagg_decoder_prepare_dispatch() to set up per-request Mooncake sessions.
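
The pull-worker pattern can be illustrated in-process with plain queues standing in for the RDMA metadata rings (a sketch only; the real workers poll RDMABuffer slots via disagg_try_consume_phase1() / disagg_try_consume_phase2() and then set up per-request Mooncake sessions):

```python
import queue
import threading
import time

def pull_worker(ring, handle, stop, rank, poll_s=0.01):
    """Consume dispatch metadata addressed to `rank` until `stop` is set."""
    while not stop.is_set():
        try:
            meta = ring.get(timeout=poll_s)
        except queue.Empty:
            continue
        if meta["receiver_engine_rank"] != rank:
            ring.put(meta)  # not ours; requeue for the matching worker
            continue
        handle(meta)        # e.g. run DiT denoising or VAE decode

# Demo: one worker with rank 1 consumes one dispatch record.
phase1_ring = queue.Queue()
results = []
stop = threading.Event()
t = threading.Thread(target=pull_worker, args=(phase1_ring, results.append, stop, 1))
t.start()
phase1_ring.put({"receiver_engine_rank": 1, "room": "room-42"})
time.sleep(0.2)
stop.set()
t.join()
assert results == [{"receiver_engine_rank": 1, "room": "room-42"}]
```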

gushiqiao and others added 30 commits December 5, 2025 11:13
Tidy VAReader & OmniVAReader
Tidy VARecorder & X264VARecorder
VARecorder with stream, use buffer stream
Tidy env WORKER_RANK, READER_RANK, RECORDER_RANK
Support voice type choose
Co-authored-by: root <root@pt-de4c35727a1b4d1b9f27f422f06026ec-worker-0.pt-de4c35727a1b4d1b9f27f422f06026ec.ns-devsft-3460edd0.svc.cluster.local>
Co-authored-by: root <root@pt-9b2035a55fe647eeb007584b238e5077-worker-0.pt-9b2035a55fe647eeb007584b238e5077.ns-devsft-3460edd0.svc.cluster.local>
Co-authored-by: yihuiwen <yihuiwen@sensetime.com>
Co-authored-by: wangshankun <wangshankun@sensetime.com>
WanModel inherits from CompiledMethodsMixin, so it is guaranteed to have compile
1. rename dcu to hygon_dcu
2. fix flash attention bug
Co-authored-by: yihuiwen <yihuiwen@sensetime.com>
Co-authored-by: qinxinyi <qxy118045534@163.com>
Co-authored-by: yihuiwen <yihuiwen@sensetime.com>
@gemini-code-assist (Contributor)

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a three-stage disaggregated deployment mode for large generative models (such as Wan and Qwen Image) in the LightX2V framework. By splitting the inference pipeline into independent Encoder, Transformer, and Decoder services, it significantly optimizes GPU memory usage, improves system throughput, and supports flexible deployment across devices or machines. Integrating the Mooncake transfer engine and LightLLM optimizations ensures efficient data transfer and improved encoding performance, providing a more stable and scalable solution for high-resolution, long-duration generation scenarios.

Highlights

  • Three-stage disaggregated deployment: introduces full Encoder/Transformer/Decoder disaggregation for the LightX2V inference pipeline, allowing the stages to run as independent services on different GPUs or nodes.
  • Mooncake engine integration: integrates the high-performance Mooncake transfer engine with RDMA/TCP support, enabling efficient data transfer between Encoder and Transformer (Phase1) and between Transformer and Decoder (Phase2).
  • Extended model support: both the Wan and Qwen model families support full three-stage disaggregated deployment, including standalone deployment of the VAE Decoder.
  • Performance optimizations: the Text Encoder integrates LightLLM optimizations (kernel or service), improving performance by up to 30%. Decoupling the stages into independent processes improves throughput under high concurrency.
  • Memory optimizations: disaggregated deployment significantly reduces per-node GPU memory usage, since each node loads only the model components it needs, especially when the Decoder node independently hosts VAE decoding.
  • Detailed deployment guide: adds a detailed Chinese deployment guide covering configuration, service startup and request flow, and choosing between RDMA and TCP.


Changelog
  • configs/qwen_image/qwen_image_i2i_disagg_decode.json
    • Added a new config for the Decoder disaggregated mode of the Qwen Image I2I task.
  • configs/qwen_image/qwen_image_i2i_disagg_encoder.json
    • Added a new config for the Encoder disaggregated mode of the Qwen Image I2I task.
  • configs/qwen_image/qwen_image_i2i_disagg_transformer.json
    • Added a new config for the Transformer disaggregated mode of the Qwen Image I2I task.
  • configs/qwen_image/qwen_image_t2i_disagg_decode.json
    • Added a new config for the Decoder disaggregated mode of the Qwen Image T2I task.
  • configs/qwen_image/qwen_image_t2i_disagg_encoder.json
    • Added a new config for the Encoder disaggregated mode of the Qwen Image T2I task.
  • configs/qwen_image/qwen_image_t2i_disagg_transformer.json
    • Added a new config for the Transformer disaggregated mode of the Qwen Image T2I task.
  • configs/wan/wan_i2v_disagg_decode.json
    • Added a new config for the Decoder disaggregated mode of the Wan I2V task.
  • configs/wan/wan_i2v_disagg_encoder.json
    • Added a new config for the Encoder disaggregated mode of the Wan I2V task.
  • configs/wan/wan_i2v_disagg_transformer.json
    • Added a new config for the Transformer disaggregated mode of the Wan I2V task.
  • configs/wan/wan_t2v_disagg_decode.json
    • Added a new config for the Decoder disaggregated mode of the Wan T2V task.
  • configs/wan/wan_t2v_disagg_encoder.json
    • Added a new config for the Encoder disaggregated mode of the Wan T2V task.
  • configs/wan/wan_t2v_disagg_transformer.json
    • Added a new config for the Transformer disaggregated mode of the Wan T2V task.
  • examples/BeginnerGuide/ZH_CN/DisaggSplitDeploy.md
    • Added a new Chinese guide detailing disaggregated deployment of diffusion models.
  • lightx2v/disagg/disagg_mixin.py
    • Added a new mixin class implementing Mooncake-based distributed communication.
  • lightx2v/models/runners/base_runner.py
    • Updated the init_scheduler method to use NullScheduler in decode mode.
  • lightx2v/models/runners/default_runner.py
    • Updated init_modules and end_run to correctly handle models that may be None in disaggregated mode.
  • lightx2v/models/runners/qwen_image/qwen_image_runner.py
    • Integrated DisaggMixin, updated load_model to load role-specific components, modified init_modules and run_pipeline to support disaggregation logic, and adjusted set_target_shape to restore the shape in Transformer mode.
  • lightx2v/models/runners/wan/wan_runner.py
    • Integrated DisaggMixin, updated load_model to load role-specific components, added a _run_transformer_role method, and modified run_pipeline to support disaggregation logic.
  • lightx2v/models/schedulers/scheduler.py
    • Added a NullScheduler class as a placeholder for disaggregated roles that need no scheduler.
  • scripts/base/base.sh
    • Added the MOONCAKE_CONFIG_PATH environment variable.
  • scripts/server/disagg/qwen/post_qwen_i2i.py
    • Added a Python script that sends requests to the Qwen Image I2I three-stage disaggregated service.
  • scripts/server/disagg/qwen/post_qwen_t2i.py
    • Added a Python script that sends requests to the Qwen Image T2I three-stage disaggregated service.
  • scripts/server/disagg/qwen/start_qwen_i2i_disagg.sh
    • Added a shell script that starts the Qwen Image I2I three-stage disaggregated service.
  • scripts/server/disagg/qwen/start_qwen_t2i_disagg.sh
    • Added a shell script that starts the Qwen Image T2I three-stage disaggregated service.
  • scripts/server/disagg/wan/post_wan_i2v.py
    • Added a Python script that sends requests to the Wan I2V three-stage disaggregated service.
  • scripts/server/disagg/wan/post_wan_t2v.py
    • Added a Python script that sends requests to the Wan T2V three-stage disaggregated service.
  • scripts/server/disagg/wan/start_wan_i2v_disagg.sh
    • Added a shell script that starts the Wan I2V three-stage disaggregated service.
  • scripts/server/disagg/wan/start_wan_t2v_disagg.sh
    • Added a shell script that starts the Wan T2V three-stage disaggregated service.

@gemini-code-assist bot left a comment

Code Review
Code Review

This PR introduces a very important three-stage disaggregated deployment feature (Encoder + Transformer + Decoder). It is a solid piece of engineering that can effectively optimize GPU memory usage and inference throughput for large generative models in distributed environments.

The implementation is comprehensive, covering low-level communication (based on Mooncake), the core logic (DisaggMixin), integration with the existing runners, and the higher-level configs, documentation, and test scripts. The overall design is well considered, for example:

  • DisaggMixin is used to share the disaggregation logic, keeping the code structure clean.
  • Models are loaded on demand per role (encoder, transformer, decode), effectively reducing GPU memory usage.
  • Hash verification of transferred data guarantees data consistency.
  • Detailed Chinese documentation plus ready-to-use startup and test scripts greatly lower the barrier to entry.

I found a few minor issues in the documentation and script comments and have left suggestions in the individual review comments; hopefully they make this feature even more polished. Overall, this is a high-quality submission.

Comment on lines +358 to +366
```bash
python -m lightx2v.server \
--model_cls wan2.1 \
--task t2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/wan/wan_t2v_disagg_decode.json \
--host 0.0.0.0 \
--port 8004
```
Severity: medium

Hi, this document is very detailed and will be a great help for understanding and using the disaggregated deployment feature.
The manual startup examples in section 3.1 use the two environment variables $model_path and ${lightx2v_path}. Users reading this section directly may not know how to set them.
Consider adding a short note here reminding users to set these variables first, and pointing them to how they are defined in scripts/server/disagg/wan/start_wan_t2v_disagg.sh. For example:

> **Note**: The `$model_path` and `${lightx2v_path}` variables in the commands below must be set in advance. `$lightx2v_path` should point to the project root directory, and `$model_path` should point to the directory containing the model files.

This would improve the document's usability.

# GPU_T : Transformer (port 8005)
#
# Override GPUs via environment variables:
# GPU_ENCODER=4 GPU_TRANSFORMER=5 GPU_DECODER=6 ./start_wan_i2v_disagg_all.sh
Severity: medium

In the script's comments, the example command references start_wan_i2v_disagg_all.sh, but the current script's filename is start_wan_i2v_disagg.sh. This looks like a small typo.

Suggested change
# GPU_ENCODER=4 GPU_TRANSFORMER=5 GPU_DECODER=6 ./start_wan_i2v_disagg_all.sh
# GPU_ENCODER=4 GPU_TRANSFORMER=5 GPU_DECODER=6 ./start_wan_i2v_disagg.sh

# GPU_T : Transformer (port 8003)
#
# Override GPUs via environment variables:
# GPU_ENCODER=4 GPU_TRANSFORMER=5 GPU_DECODER=6 ./start_wan_t2v_disagg_all.sh
Severity: medium

Similar to the other script, the example command in the comments here references start_wan_t2v_disagg_all.sh, but the current script's filename is start_wan_t2v_disagg.sh. Please fix this typo.

Suggested change
# GPU_ENCODER=4 GPU_TRANSFORMER=5 GPU_DECODER=6 ./start_wan_t2v_disagg_all.sh
# GPU_ENCODER=4 GPU_TRANSFORMER=5 GPU_DECODER=6 ./start_wan_t2v_disagg.sh

@fuheaven fuheaven changed the title from "Three-stage disaggregated deployment (Encoder + Transformer + Decoder)" to "Three-stage disaggregated deployment architecture (Encoder + Transformer + Decoder)" Mar 20, 2026
@fuheaven fuheaven changed the title from "Three-stage disaggregated deployment architecture (Encoder + Transformer + Decoder)" to "Decentralized disaggregated deployment architecture" Apr 2, 2026
if self.text_encoder_type in ["lightllm_service", "lightllm_kernel"]:
logger.info(f"Using LightLLM text encoder: {self.text_encoder_type}")

def set_config(self, config_modify):
@fuheaven (Author)

Call chain: client HTTP POST → server → worker.py:95 → runner.set_config(task_data) → runner.run_pipeline(input_info)
The logic here maps the flattened disagg parameters from the HTTP request into the nested config["disagg_config"] dict. In decentralized mode, a single HTTP POST from the client is enough to specify "which Transformer worker this request goes to" and "which Mooncake room to use".

self.model = None
self.text_encoders = None
self.vae = self.load_vae()
else:
@fuheaven (Author)

If disagg_mode is not specified, this falls through to the default model-loading logic.

encoder_config = self.config.copy()
lightllm_config = self.config.get("lightllm_config", {})
encoder_config.update(lightllm_config)
encoder_config = dict(self.config)
@fuheaven (Author)

This just simplifies the expression.

assert self.config.get("cpu_offload", False)
if self.config.get("disagg_mode"):
self.init_disagg(self.config)
super().init_modules()
@fuheaven (Author)

The base class's init is used here instead, because init_modules in the base class has the same lazy-load logic.

logger.info(f"Qwen Image Runner got custom shape: {width}x{height}")
return (width, height)

cfg_h = self.config.get("target_height")
@fuheaven (Author)

This allows the output image resolution to be specified directly via the target_height / target_width fields in the config, without relying on the aspect_ratio preset or request parameters.

return None

def set_target_shape(self):
# In disagg transformer mode, use the shape transmitted from encoder
@fuheaven (Author)

Same as above, this was added to simplify POST requests: I want to preset a default output resolution in the config file rather than passing target_shape or aspect_ratio in every HTTP request.

def run_image_encoder(self):
pass

@ProfilingContext4DebugL2("Load models")
@fuheaven (Author)

This section looked like a duplicate of load_model in the base class, so it was removed.

def init_scheduler(self):
super().init_scheduler()
if self.config.get("disagg_mode") == "decode":
return
@fuheaven (Author), Apr 2, 2026

Decode mode has to return early here and must not initialize the scheduler, which is also why I added a NullScheduler.

- disagg decode: receive_transformer_outputs → VAE → save.
"""
self.input_info = input_info
disagg_mode = self.config.get("disagg_mode")
@fuheaven (Author)

This is just an if-dispatch on encoder, transformer, and decode, routing each into its own pipeline.

def init_scheduler(self):
"""Initialize scheduler"""
pass
"""Initialize scheduler."""
@fuheaven (Author)

This adds the NullScheduler used as a placeholder for the three-stage decode phase.
