[Diffusion] Wan video model support zero-cost weight offload and overlap with compute #15511
Conversation
```python
warmup_start = time.time()
# ...
device = torch.device(f"cuda:{get_world_group().local_rank}")
```
could we use current_platform.get_device_name?
We need to obtain each rank's device and then perform the all2all operation. `get_device_name` only gives us one rank's name?
I tried it, but the warmup time didn't decrease.
```python
offload_mgr = getattr(self, "_layerwise_offload_manager", None)
if offload_mgr is not None and getattr(offload_mgr, "enabled", False):
    for i, block in enumerate(self.blocks):
        offload_mgr.prefetch_layer(i + 1, non_blocking=True)
```
I'm probably missing something, but does FSDP also prefetch and offload layerwise?
Not in the same way. FSDP CPU offload handles parameter movement/sharding at the FSDP boundary (and may offload/reshard according to its policies), but it does not provide an explicit per-layer async H2D prefetch pipeline aligned with the transformer block loop. `dit_layerwise_offload` here is a separate mechanism that explicitly offloads per-block weights to CPU and overlaps the H2D prefetch of the next block with the compute of the current block for supported DiT models. Also note it is mutually exclusive with `dit_cpu_offload` / `use_fsdp_inference` in our server args checks.
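For readers following along, here is a minimal sketch of the overlap idea, not the actual manager in this PR: block weights live in pinned CPU memory, and the next block's H2D copy is issued on a side CUDA stream while the current block computes. The `LayerwiseOffloadManager` class and its method names below are hypothetical and simplified.

```python
import torch
import torch.nn as nn


class LayerwiseOffloadManager:
    """Sketch: keep block weights in pinned CPU memory and prefetch the next
    block's weights to the GPU on a side stream while the current block runs."""

    def __init__(self, blocks: nn.ModuleList, device: torch.device):
        self.blocks = blocks
        self.device = device
        self.copy_stream = torch.cuda.Stream(device=device)
        self.enabled = True
        # Master copies live in pinned CPU memory so H2D copies can be async.
        self.cpu_params = [
            {name: p.detach().to("cpu").pin_memory() for name, p in blk.named_parameters()}
            for blk in blocks
        ]

    def prefetch_layer(self, idx: int, non_blocking: bool = True):
        if not self.enabled or idx >= len(self.blocks):
            return
        # Issue the H2D copies on the side stream so they overlap with compute
        # running on the default stream.
        with torch.cuda.stream(self.copy_stream):
            for name, p in self.blocks[idx].named_parameters():
                p.data = self.cpu_params[idx][name].to(self.device, non_blocking=non_blocking)

    def wait_layer(self, idx: int):
        # Make the compute stream wait until the prefetched weights have landed.
        torch.cuda.current_stream(self.device).wait_stream(self.copy_stream)

    def release_layer(self, idx: int):
        # Drop the GPU copy; record the compute stream so the caching allocator
        # does not reuse the memory while the compute stream may still read it.
        for name, p in self.blocks[idx].named_parameters():
            if p.data.is_cuda:
                p.data.record_stream(torch.cuda.current_stream(self.device))
            p.data = self.cpu_params[idx][name]
```

In the forward loop this pairs with the quoted snippet above: wait for block `i`'s prefetch, issue the prefetch for block `i + 1`, run block `i`, then release block `i`'s GPU weights.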
Motivation
I have identified the reason wan2.2 is slow. A full profile shows that the first step and the 19th step are significantly slower (and about equally slow), while all other intermediate steps are fast and fully deliver the expected 2x speedup from cp4 to cp8. The first and 19th steps each take about 7x as long as a normal step.
The reason is that wan2.2 uses dual transformers and we have `dit_cpu_offload` enabled. The weights therefore sit on the CPU after loading, and during the first step the weights of both `transformer` and `transformer_2` are copied from the CPU to the CUDA device, making the first step very slow. Then, at the 19th step, a dual-stream switch occurs, requiring both `transformer` and `transformer_2` to be offloaded back to the CPU, their weights swapped, and then copied back to the CUDA device.

main branch:

PR:
149.69 s -> 94.22 s, a 58% speedup.
memcpy and compute are completely overlapped, with no additional overhead.
Warmup
Whether or not torch compile is enabled, the first step is about 7x slower than the subsequent steps, which is unacceptable. I profiled the full stage in both scenarios and found that the long first step is caused by the initialization overhead of the NCCL All2All. This cost should not land in the denoise stage; it should be paid in advance. I therefore implemented a pre-warmup specifically for the All2All operation to eliminate this overhead. Now, without torch compile, the first step takes almost the same time as the subsequent steps; even with compile enabled, it is only about 2x slower.
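As a rough illustration of the pre-warmup (a sketch under my own naming, not the PR's exact code; `warmup_all_to_all` is a hypothetical helper), a single throwaway `all_to_all_single` per process group is enough to pay the NCCL communicator setup cost before denoising starts:

```python
import torch
import torch.distributed as dist


def warmup_all_to_all(device: torch.device, group=None, numel_per_rank: int = 1024):
    """Run one dummy all-to-all so NCCL pays its setup cost outside the denoise loop."""
    world_size = dist.get_world_size(group)
    if world_size <= 1:
        return
    # Dummy buffers: each rank sends `numel_per_rank` elements to every other rank.
    send = torch.zeros(world_size * numel_per_rank, device=device)
    recv = torch.empty_like(send)
    dist.all_to_all_single(recv, send, group=group)
    torch.cuda.synchronize(device)  # ensure the NCCL init cost is fully paid here
```

Calling something like this once right after distributed initialization (or at the start of warmup) moves the All2All setup cost out of the first denoise step.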
Modifications
Accuracy Tests
Benchmarking and Profiling
Checklist