
Conversation


@BBuf BBuf commented Dec 20, 2025

Motivation

I have identified the reason wan2.2 is slow. A full profile shows that the first and the 19th steps are significantly slower (and take roughly the same time as each other), while all other steps are fast and fully deliver the expected 2x speedup from cp4 to cp8. The first and 19th steps each take about 7x as long as a normal step.

The reason is that wan2.2 uses dual transformers and we have dit_cpu_offload enabled. The weights therefore stay on the CPU after loading, and during the first step the weights of both transformer and transformer_2 are copied from the CPU to the CUDA device, making the first step very slow. Then, at the 19th step, the dual-stream switch occurs, requiring transformer and transformer_2 to be offloaded back to the CPU, their weights swapped, and then copied back to the CUDA device.
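
For intuition, a minimal sketch (hypothetical model shapes and step indices, not the actual SGLang code) of the blocking pattern described above: step 0 pays a synchronous CPU->GPU copy and the switch step pays a GPU->CPU offload plus another CPU->GPU copy, all serialized with compute.

```python
import torch
import torch.nn as nn

# Stand-ins for the high-noise / low-noise experts, kept on CPU by dit_cpu_offload.
transformer = nn.Sequential(*[nn.Linear(1024, 1024) for _ in range(40)])
transformer_2 = nn.Sequential(*[nn.Linear(1024, 1024) for _ in range(40)])

def denoise_step(step: int, latents: torch.Tensor, switch_step: int = 19) -> torch.Tensor:
    if step == 0:
        transformer.to("cuda")     # blocking whole-model H2D copy -> slow first step
    if step == switch_step:
        transformer.to("cpu")      # blocking D2H offload of the retiring expert
        transformer_2.to("cuda")   # blocking H2D copy of the incoming expert -> slow 19th step
    active = transformer if step < switch_step else transformer_2
    return active(latents)

latents = torch.randn(2, 1024, device="cuda")
for step in range(27):
    latents = denoise_step(step, latents)
```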

main branch:

sglang generate   --model-path Wan-AI/Wan2.2-T2V-A14B-Diffusers   --text-encoder-cpu-offload   --pin-cpu-memory   --num-gpus 8   --ulysses-degree 8 --attention-backend sage_attn  --enable-torch-compile --prompt "A cat walks on the grass, realistic" --num-frames 81 --height 720 --width 1280 --num-inference-steps 27 --guidance-scale 3.5 --guidance-scale-2 4.0 --perf-dump-path /home/lmsys/bbuf/dump/wan_step_profile_cp8_main.json

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 27/27 [02:28<00:00,  5.52s/it]
[12-19 08:37:22] [DenoisingStage] average time per step: 5.5156 seconds
[12-19 08:37:23] [DenoisingStage] finished in 149.6943 seconds



"denoise_steps_ms": [
    35999.06893167645,
    3261.483933776617,
    3270.5406425520778,
    3267.8588768467307,
    3260.3964526206255,
    3263.016454875469,
    3268.026988953352,
    3264.5184732973576,
    3264.636719599366,
    3267.1875776723027,
    3268.562350422144,
    3268.1023878976703,
    3266.7769035324454,
    3268.044295720756,
    3264.268895611167,
    3271.0087513551116,
    3267.674465663731,
    3266.1060262471437,
    31282.590138725936,
    3263.5639663785696,
    3262.301029637456,
    3262.3210102319717,
    3261.833382770419,
    3264.719443395734,
    3265.314467251301,
    3267.530156299472,
    3261.9533529505134
  ],

PR branch:

sglang generate   --model-path Wan-AI/Wan2.2-T2V-A14B-Diffusers   --text-encoder-cpu-offload   --pin-cpu-memory   --num-gpus 8   --ulysses-degree 8 --attention-backend sage_attn  --enable-torch-compile --prompt "A cat walks on the grass, realistic" --num-frames 81 --height 720 --width 1280 --num-inference-steps 27 --guidance-scale 3.5 --guidance-scale-2 4.0 --dit-layerwise-offload true --perf-dump-path /home/lmsys/bbuf/dump/wan_step_profile_cp8_async_offload.json


100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 27/27 [01:33<00:00,  3.46s/it]
[12-20 02:59:10] [DenoisingStage] average time per step: 3.4553 seconds
[12-20 02:59:10] [DenoisingStage] finished in 94.2283 seconds

"denoise_steps_ms": [
    7717.6975486800075,
    3275.042257271707,
    3280.2467988803983,
    3282.276245765388,
    3292.9044039919972,
    3286.5121429786086,
    3273.5616639256477,
    3271.6003246605396,
    3275.5934856832027,
    3291.9061705470085,
    3293.934356421232,
    3289.3909830600023,
    3298.0582248419523,
    3300.408118404448,
    3305.0247132778168,
    3299.3013756349683,
    3302.0150866359472,
    3299.040620215237,
    3291.0401169210672,
    3296.6199973598123,
    3290.26335850358,
    3302.190547809005,
    3295.942653901875,
    3297.3329443484545,
    3297.713255509734,
    3295.0284238904715,
    3290.1963284239173
  ],

149.69 s -> 94.22 s, a ~1.59x (58%) speedup.


memcpy and compute are completely overlapped, with no additional overhead.
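
For reference, a minimal sketch of the overlap idea (a hypothetical manager class; the PR's actual _layerwise_offload_manager API may differ): block i+1 is copied host-to-device on a side stream while block i computes on the default stream.

```python
import torch
import torch.nn as nn

class LayerwisePrefetchSketch(nn.Module):
    def __init__(self, blocks: nn.ModuleList):
        super().__init__()
        self.blocks = blocks                    # weights kept in (ideally pinned) CPU memory
        self.copy_stream = torch.cuda.Stream()  # dedicated stream for H2D copies

    def prefetch(self, i: int) -> None:
        if i < len(self.blocks):
            with torch.cuda.stream(self.copy_stream):
                self.blocks[i].to("cuda", non_blocking=True)  # async H2D copy

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        self.prefetch(0)
        for i, block in enumerate(self.blocks):
            # Wait until block i's weights have landed, then start copying block i+1
            # so that copy overlaps with block i's compute.
            torch.cuda.current_stream().wait_stream(self.copy_stream)
            self.prefetch(i + 1)
            hidden_states = block(hidden_states)
            # (A real manager would also release block i's GPU copy here.)
        return hidden_states
```

With pinned CPU weights and a separate copy engine, the H2D transfers run concurrently with the block kernels, which is the overlap visible in the trace above.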

Warmup

Whether or not torch compile is enabled, the first step is about 7x slower than the subsequent steps, which is unacceptable. I profiled the full stage in both scenarios and found that the long first-step time comes from the initialization overhead of the NCCL all-to-all. This overhead should not land in the denoise stage; it should be paid in advance. I therefore added a pre-warmup specifically for the all-to-all operation to eliminate it. Now, without torch compile, the first step takes almost the same time as the subsequent steps; even with compile enabled, the first step is only about twice as slow as the subsequent steps.
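
A minimal sketch of such a pre-warmup (hypothetical helper name; the PR's actual warmup code may differ): issuing one tiny all-to-all per process group before the denoise loop forces the NCCL communicator setup to happen outside the timed stage.

```python
import torch
import torch.distributed as dist

def warmup_all_to_all(device: torch.device, group=None) -> None:
    # A tiny dummy exchange is enough to create the NCCL communicators/channels,
    # so the first real all2all in the denoise loop no longer pays the setup cost.
    world_size = dist.get_world_size(group)
    dummy_in = torch.ones(world_size, device=device)
    dummy_out = torch.empty_like(dummy_in)
    dist.all_to_all_single(dummy_out, dummy_in, group=group)
    torch.cuda.synchronize(device)
```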


Modifications

Accuracy Tests

Benchmarking and Profiling

Checklist

@github-actions github-actions bot added the diffusion SGLang Diffusion label Dec 20, 2025

warmup_start = time.time()

device = torch.device(f"cuda:{get_world_group().local_rank}")
Collaborator


could we use current_platform.get_device_name?

Collaborator Author


Need to obtain each rank and then perform the all2all operation. Getting the device_name only gives us one rank?

Collaborator Author


I tried it, but the warmup time didn't decrease.

offload_mgr = getattr(self, "_layerwise_offload_manager", None)
if offload_mgr is not None and getattr(offload_mgr, "enabled", False):
    for i, block in enumerate(self.blocks):
        offload_mgr.prefetch_layer(i + 1, non_blocking=True)
Contributor

@Edenzzzz Edenzzzz Dec 20, 2025


I'm probably missing something, but does FSDP also prefetch and offload layerwise?

Collaborator Author


Not in the same way. FSDP CPU offload handles parameter movement/sharding at the FSDP boundary (and may offload/reshard according to its policies), but it does not provide an explicit per-layer async H2D prefetch pipeline aligned with the transformer block loop. dit_layerwise_offload here is a separate mechanism that explicitly offloads per-block weights to CPU and overlaps H2D prefetch of the next block with compute of the current block for supported DiT models. Also note it is mutually exclusive with dit_cpu_offload / use_fsdp_inference in our server args checks.
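
As a rough illustration, the kind of mutual-exclusion check mentioned above (hypothetical function and attribute names, not the actual server-args code):

```python
def validate_offload_args(args) -> None:
    # dit_layerwise_offload manages weights itself, so it cannot be combined with
    # whole-model CPU offload or FSDP inference.
    if args.dit_layerwise_offload and (args.dit_cpu_offload or args.use_fsdp_inference):
        raise ValueError(
            "dit_layerwise_offload is mutually exclusive with "
            "dit_cpu_offload / use_fsdp_inference"
        )
```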
