
[3.1/4] Diffusion Quantized ckpt export - WAN 2.2 14B #855

Open

jingyu-ml wants to merge 55 commits into main from jingyux/3.1-4-diffusion

Conversation

@jingyu-ml
Contributor

@jingyu-ml jingyu-ml commented Feb 5, 2026

What does this PR do?

Type of change: documentation

Overview:

  1. Added multi-backbone support for quantization: --backbone now accepts space- or comma-separated lists and resolves to a list of backbone modules (see the sketch after this list).
  2. Introduced PipelineManager.iter_backbones() to iterate over named backbone modules, and updated get_backbone() to return a single module or a ModuleList in the multi-backbone case.
  3. Updated ExportManager to save/restore per-backbone checkpoints as {backbone_name}.pt files when a directory is provided, creating target directories when missing.
  4. Simplified save_checkpoint() calls to rely on the registered pipeline_manager by default.
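
For reference, here is a minimal, self-contained sketch of the per-backbone flow described above; the pipe, backbones, and ckpt_dir names (and the elided quantization step) are illustrative assumptions, not the exact quantize.py / pipeline_manager.py code:

```python
# Hedged sketch of the multi-backbone export flow; names here are
# illustrative, and the calibration/quantization step is elided.
from pathlib import Path

import modelopt.torch.opt as mto
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("./wan2.2-t2v-14b")  # local HF checkpoint (assumed path)
backbones = ["transformer", "transformer_2"]  # resolved from --backbone transformer transformer_2

ckpt_dir = Path("./wan22_mo_ckpts")
ckpt_dir.mkdir(parents=True, exist_ok=True)

for name in backbones:
    module = getattr(pipe, name, None)
    if module is None:
        raise RuntimeError(f"Pipeline missing backbone module '{name}'.")
    # ... quantize/calibrate `module` with ModelOpt here ...
    mto.save(module, str(ckpt_dir / f"{name}.pt"))  # one checkpoint per backbone
```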

**Usage:**

python quantize.py --model wan2.2-t2v-14b --format fp4 --batch-size 1 --calib-size 32 \
    --n-steps 30 --backbone transformer transformer_2 --model-dtype BFloat16 \
    --quantized-torch-ckpt-save-path ./wan22_mo_ckpts \
    --hf-ckpt-dir ./wan2.2-t2v-14b 
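
With this invocation, the export step should produce one checkpoint per backbone inside the directory, i.e. ./wan22_mo_ckpts/transformer.pt and ./wan22_mo_ckpts/transformer_2.pt (per item 3 in the overview above).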

Plans

  • [1/4] Add the basic functionality to support a limited set of image models with NVFP4 + FP8, with some refactoring of the previous LLM code and the diffusers example. PIC: @jingyu-ml
  • [2/4] Add support for more video-gen models. PIC: @jingyu-ml
  • [3/4] Add test cases and refactor the docs and all related READMEs. PIC: @jingyu-ml
  • [4/4] Add the final support for ComfyUI. PIC: @jingyu-ml

Testing

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: No
  • Did you write any new necessary tests?: No
  • Did you add or update any necessary documentation?: Yes
  • Did you update Changelog?: Yes

Additional Information

Summary by CodeRabbit

Release Notes

  • New Features

    • Unified Hugging Face export support for diffusers pipelines and components
    • LTX-2 and Wan2.2 (T2V) support in diffusers quantization workflow
    • Comprehensive ONNX export and TensorRT engine build documentation for diffusion models
  • Documentation

    • Updated to clarify support for both transformers and diffusers models in unified export API
    • Expanded diffusers examples with LoRA fusion guidance and additional model options (Flux, SD3, SDXL variants)

jingyu-ml and others added 20 commits January 24, 2026 00:25
@jingyu-ml jingyu-ml requested review from a team as code owners February 5, 2026 00:50
@coderabbitai
Contributor

coderabbitai bot commented Feb 5, 2026

📝 Walkthrough

This pull request extends Hugging Face export support to diffusers models, introduces comprehensive ONNX export and TensorRT engine build documentation, and refactors the quantization pipeline to support per-backbone operations instead of single-backbone workflows.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| **Version & Documentation Updates**<br>`CHANGELOG.rst`, `README.md`, `docs/source/deployment/3_unified_hf.rst`, `docs/source/getting_started/1_overview.rst` | Added changelog entries and updated existing documentation to reflect that the unified Hugging Face export API now supports both transformers and diffusers models/pipelines. |
| **Diffusers Examples Documentation**<br>`examples/diffusers/README.md` | Restructured the documentation by moving ONNX export and TensorRT engine build instructions to a dedicated quantization/ONNX.md file. Updated quantization script examples with new model variants (LTX-2, WAN2.2, flux-schnell, sd3-medium, sd3.5-medium) and added a new LoRA fusion guidance section. |
| **ONNX Workflow Documentation**<br>`examples/diffusers/quantization/ONNX.md` | New comprehensive documentation page detailing end-to-end ONNX export and TensorRT engine build workflows, including quantization steps (8-bit, FP8), memory requirements, trtexec commands, and inference execution examples for multiple diffusion models. |
| **Pipeline Backbone Support**<br>`examples/diffusers/quantization/pipeline_manager.py` | Added an iter_backbones() method to iterate over pipeline backbone modules with special LTX2 handling, and a print_quant_summary() method to log per-backbone quantization information. Refactored get_backbone() to use the new iterator. |
| **Per-Backbone Quantization Export**<br>`examples/diffusers/quantization/quantize.py` | Modified ExportManager to accept an optional pipeline_manager and refactored save_checkpoint() and restore_checkpoint() to perform per-backbone operations. Changed the CLI --backbone argument to accept multiple values (nargs="+", sketched below). Updated orchestration in the main workflow to iterate over backbones via the pipeline manager. |
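
As a rough illustration of the CLI change noted in the last row, a multi-value --backbone flag with nargs="+" can be declared as below; the defaults and help text are assumptions, not the actual quantize.py parser:

```python
# Sketch of a multi-value --backbone argument (assumed defaults, not the real parser).
import argparse

parser = argparse.ArgumentParser(description="Diffusion quantization CLI sketch")
parser.add_argument(
    "--backbone",
    nargs="+",  # one or more names, e.g. --backbone transformer transformer_2
    default=["transformer"],
    help="Backbone module name(s); comma-separated values would be split downstream.",
)

args = parser.parse_args(["--backbone", "transformer", "transformer_2"])
print(args.backbone)  # ['transformer', 'transformer_2']
```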

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~22 minutes

🚥 Pre-merge checks | ✅ 1 | ❌ 2
❌ Failed checks (1 warning, 1 inconclusive)
| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 69.23%, which is below the required threshold of 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |
| Title check | ❓ Inconclusive | The title mentions 'WAN 2.2 14B' but is vague about the main change; it doesn't clearly convey that the core technical change is multi-backbone checkpoint export support. | Clarify the title to reflect the primary change: multi-backbone checkpoint export support for diffusion quantization, not just WAN 2.2. |
✅ Passed checks (1 passed)
| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |



@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 4


examples/diffusers/quantization/ONNX.md (quoted command context):

    --format int8 --batch-size 2 \
    --calib-size 32 --alpha 0.8 --n-steps 20 \
    --model-dtype {Half/BFloat16} --trt-high-precision-dtype {Half|BFloat16} \
    --quantized-torch-ckpt-save-path ./{MODEL_NAME}.pt --onnx-dir {ONNX_DIR}

⚠️ Potential issue | 🟡 Minor

Use consistent placeholder naming for model checkpoints.

{MODEL_NAME} vs {MODEL} is inconsistent and may confuse users. Consider standardizing to one placeholder.

Proposed doc tweak
-    --quantized-torch-ckpt-save-path ./{MODEL_NAME}.pt --onnx-dir {ONNX_DIR}
+    --quantized-torch-ckpt-save-path ./{MODEL}.pt --onnx-dir {ONNX_DIR}

Also applies to: 118-118

🤖 Prompt for AI Agents
In `@examples/diffusers/quantization/ONNX.md` at line 26, Standardize the
checkpoint placeholder used in the docs: replace the inconsistent {MODEL_NAME}
and {MODEL} occurrences with a single chosen placeholder (e.g., {MODEL_NAME})
for the --quantized-torch-ckpt-save-path and all related examples; update every
instance including the command example shown and the other occurrence around
line 118 so that --quantized-torch-ckpt-save-path ./{MODEL_NAME}.pt and any
references (README/usage examples) consistently use the same placeholder name.

Comment on lines +45 to +48
> [!IMPORTANT]
> TensorRT environment must be setup prior -- Please see [Pre-Requisites](../README.md#pre-requisites)
> INT8 requires **TensorRT version >= 9.2.0**. If you prefer to use the FP8 TensorRT, ensure you have **TensorRT version 10.2.0 or higher**. You can download the latest version of TensorRT at [here](https://developer.nvidia.com/tensorrt/download). Deployment of SVDQuant is currently not supported.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🌐 Web query:

TensorRT 9.2.0 INT8 quantization engine requirements

💡 Result:

TensorRT 9.2.0 INT8 engine — practical requirements

1) You must be on a supported TensorRT 9.2.0 target (this release is limited).
TensorRT 9.2.0 GA is explicitly scoped to LLM use and only on A100, A10G, L4, L40, L40S, H100, and GH200 (Linux x86 and Linux SBSA/Arm). For other apps / other GPUs, NVIDIA directs you to use TensorRT 8.6.1 for production. [1]

2) Hardware must support INT8 execution.
Your GPU must support INT8 for inference (TensorRT support matrices list which compute capabilities/devices support INT8). For example, T4 (7.5), A100 (8.0), A10 (8.6), H100 (9.0) support INT8. [2]

3) You need quantization information to build an INT8 engine.
To build an INT8 engine, TensorRT needs either:

  • a model that already contains quantization (for example, explicit Quantize/Dequantize), or
  • post-training calibration / dynamic ranges.

For calibration, you must implement IInt8Calibrator, provide batches via getBatch() (matching getBatchSize()), and attach it with config->setInt8Calibrator(...). [3]

4) Calibration input format requirement (common gotcha).
Even if you configure INT8 I/O, TensorRT still expects calibration data in FP32, and NVIDIA notes keeping the FP32 calibration values within [-128.0, 127.0] so they can be converted to INT8 without precision loss. [4]

Sources: NVIDIA TensorRT 9.2.0 release notes (archived) [1]; NVIDIA TensorRT support matrix (archived example) [2]; NVIDIA “Working with Quantized Types” (calibration API expectations) [3]; NVIDIA TensorRT developer guide note on INT8 calibration data type/range [4].


🌐 Web query:

TensorRT 10.2.0 FP8 engine build requirements

💡 Result:

To build an FP8 engine with TensorRT 10.2.0, you need:

  • GPU that supports FP8

    • TensorRT 10.2.0 lists FP8 support for SM 8.9 (e.g., L40S / Ada) and SM 9.0 (e.g., H100 / Hopper). [1]
    • FP8 convolution tactics specifically were added for Hopper GPUs in 10.2.0. [2]
  • Use explicit quantization (Q/DQ) for FP8

    • FP8 is supported only via explicit quantization, meaning your model must contain Quantize/Dequantize (Q/DQ) nodes and you must provide/encode the quantization scales. [3]
    • Do not mix FP8 and INT8 in the same network. [4]
  • Supported CUDA + minimum driver

    • TensorRT 10.2.0 supports CUDA 11.0 (update 3) through CUDA 12.6 (exact supported point releases listed in the 10.2.0 support matrix). [1]
    • Minimum driver requirement is R450 (Linux) / R452 (Windows) (minimum needed for CUDA 11.0). [2]

Sources: NVIDIA TensorRT 10.2.0 Support Matrix [1], TensorRT 10.2.0 Release Notes [2], “Working with Quantized Types” (explicit quantization + FP8 constraints) [3][4].


🌐 Web query:

SVDQuant TensorRT deployment support status

💡 Result:

SVDQuant is supported as a quantization (PTQ) algorithm in NVIDIA TensorRT Model Optimizer (ModelOpt), and ModelOpt’s outputs are intended for deployment via TensorRT / TensorRT‑LLM. In ModelOpt docs, SVDQuant is explicitly listed alongside other algorithms (SmoothQuant, AWQ) under mtq.quantize() support. [2]

For TensorRT deployment, NVIDIA’s guidance indicates SVDQuant is usable in the workflow, but it is not the “zero-extra-runtime-overhead” path: NVIDIA notes (in the context of TensorRT FP4 diffusion deployment) that SVDQuant is supported in ModelOpt (v0.27) but “increases deployment complexity” and “introduces some runtime overhead,” compared with QAT. [1]

Practical status summary:

  • Yes: SVDQuant is supported in ModelOpt and is part of the NVIDIA quantization toolchain feeding TensorRT. [1][2]
  • But: deploying SVDQuant-quantized models in TensorRT is currently described by NVIDIA as more complex than QAT and may involve additional runtime processing/overhead (that is, it’s not just “export and run with no extra considerations”). [1]
  • TensorRT‑LLM side: TensorRT‑LLM can run pre-quantized checkpoints produced by ModelOpt (general support statement), but it doesn’t specifically call out SVDQuant as a “standard recipe” the way it does for GPTQ/AWQ/FP4/FP8 families—so treat SVDQuant as ModelOpt-supported, deployment-path-dependent rather than a universally “one-click” TRT‑LLM quantization mode. [3]

Sources: [1] [2] [3]


🏁 Script executed:

# Check the file for placeholder inconsistencies mentioned in the scratchpad
cat -n examples/diffusers/quantization/ONNX.md | grep -E "\{MODEL_NAME\}|\{MODEL\}"

Repository: NVIDIA/Model-Optimizer

Length of output: 475


Correct TensorRT version scoping and SVDQuant support claims.

The stated TensorRT version requirements need clarification:

  • TensorRT 9.2.0 is scoped specifically to LLM inference on select GPUs (A100, A10G, L4, L40, L40S, H100, GH200). NVIDIA directs users to TensorRT 8.6.1 for general production use. The blanket "INT8 requires >= 9.2.0" is misleading.
  • SVDQuant is supported in NVIDIA's ModelOpt quantization toolchain and integrates with TensorRT deployment, though with added complexity and runtime overhead compared to QAT. The claim that "SVDQuant deployment is currently not supported" is incorrect.
  • Placeholder naming is inconsistent: {MODEL_NAME} appears at line 26 but {MODEL} at lines 118, 129, and 140. Standardize to one format.
🤖 Prompt for AI Agents
In `@examples/diffusers/quantization/ONNX.md` around lines 45 - 48, Update the
TensorRT version guidance and SVDQuant claim and standardize placeholders:
clarify that the "INT8 requires >= 9.2.0" statement is specific to LLM inference
on select GPUs (A100, A10G, L4, L40, L40S, H100, GH200) and note NVIDIA's
general production recommendation of TensorRT 8.6.1; keep the FP8 guidance
(TensorRT >= 10.2.0) but scope it similarly; replace the incorrect blanket
"SVDQuant deployment is currently not supported" with a corrected note that
SVDQuant is supported via NVIDIA ModelOpt and can be integrated with TensorRT
(with additional complexity and runtime considerations); and standardize all
placeholders to a single token (choose {MODEL_NAME} and replace all occurrences
of {MODEL} accordingly).

Comment on lines +184 to +199
+        names = list(self.config.backbone)
+
         if self.config.model_type == ModelType.LTX2:
             self._ensure_ltx2_transformer_cached()
-            return self._transformer
-        return getattr(self.pipe, self.config.backbone)
+            name = names[0] if names else "transformer"
+            yield name, self._transformer
+            return
+
+        if not names:
+            raise RuntimeError("No backbone names provided.")
+
+        for name in names:
+            module = getattr(self.pipe, name, None)
+            if module is None:
+                raise RuntimeError(f"Pipeline missing backbone module '{name}'.")
+            yield name, module

⚠️ Potential issue | 🟡 Minor

Normalize backbone specs to avoid per-character splitting and honor comma-separated input.
list(self.config.backbone) breaks when the value is a string (e.g., "unet,transformer") and doesn’t split commas. This will raise missing-backbone errors even though the CLI help allows comma separation.

🛠️ Suggested fix
-import logging
+import logging
+import re
 from collections.abc import Iterator
 from typing import Any
@@
-        names = list(self.config.backbone)
+        backbone_spec = self.config.backbone
+        raw_items: list[str] = []
+        if backbone_spec:
+            if isinstance(backbone_spec, str):
+                raw_items = [backbone_spec]
+            else:
+                raw_items = list(backbone_spec)
+
+        names = [n for item in raw_items for n in re.split(r"[,\s]+", item) if n]
🤖 Prompt for AI Agents
In `@examples/diffusers/quantization/pipeline_manager.py` around lines 184 - 199,
The code currently does list(self.config.backbone) which splits strings into
characters; instead normalize self.config.backbone into a list of backbone names
by checking if it's a str and splitting on commas (str.split(",") with strip on
each token) or otherwise converting the iterable to a list, then assign to
names; keep the existing LTX2 branch using ModelType.LTX2 and
_ensure_ltx2_transformer_cached (yielding name and self._transformer), preserve
the RuntimeError when names is empty, and continue using getattr(self.pipe,
name, None) for each normalized name to raise the same missing-backbone error if
a module is absent.
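
For concreteness, a standalone sketch of the normalization this prompt describes; the helper name and the accepted input shapes are assumptions, not code from the repository:

```python
# Hypothetical helper: accept a string ("unet,transformer") or an iterable of
# names, split on commas, strip whitespace, and drop empty tokens.
def normalize_backbone_names(backbone) -> list[str]:
    if isinstance(backbone, str):
        tokens = backbone.split(",")
    else:
        tokens = [part for item in backbone for part in str(item).split(",")]
    return [name.strip() for name in tokens if name.strip()]


print(normalize_backbone_names("unet,transformer"))                # ['unet', 'transformer']
print(normalize_backbone_names(["transformer", "transformer_2"]))  # ['transformer', 'transformer_2']
print(normalize_backbone_names(["unet, transformer"]))             # ['unet', 'transformer']
```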

Comment on lines +217 to +235
     def save_checkpoint(self) -> None:
         """
         Save quantized model checkpoint.

         Args:
             backbone: Model backbone to save
         """
         if not self.config.quantized_torch_ckpt_path:
             return

-        self.logger.info(f"Saving quantized checkpoint to {self.config.quantized_torch_ckpt_path}")
-        mto.save(backbone, str(self.config.quantized_torch_ckpt_path))
+        ckpt_path = self.config.quantized_torch_ckpt_path
+        if self.pipeline_manager is None:
+            raise RuntimeError("Pipeline manager is required for per-backbone checkpoints.")
+        backbone_pairs = list(self.pipeline_manager.iter_backbones())
+
+        for name, backbone in backbone_pairs:
+            ckpt_path.mkdir(parents=True, exist_ok=True)
+            target_path = ckpt_path / f"{name}.pt"
+            self.logger.info(f"Saving backbone '{name}' to {target_path}")
+            mto.save(backbone, str(target_path))


⚠️ Potential issue | 🟠 Major

Handle file vs. directory paths to avoid silent restore failures and awkward “*.pt/” directories.
Right now a file path is treated as a directory on save, and restore skips files entirely. This breaks legacy usage and can silently skip restores.

🛠️ Suggested fix
     def save_checkpoint(self) -> None:
@@
-        ckpt_path = self.config.quantized_torch_ckpt_path
+        ckpt_path = self.config.quantized_torch_ckpt_path
         if self.pipeline_manager is None:
             raise RuntimeError("Pipeline manager is required for per-backbone checkpoints.")
         backbone_pairs = list(self.pipeline_manager.iter_backbones())
 
-        for name, backbone in backbone_pairs:
-            ckpt_path.mkdir(parents=True, exist_ok=True)
-            target_path = ckpt_path / f"{name}.pt"
-            self.logger.info(f"Saving backbone '{name}' to {target_path}")
-            mto.save(backbone, str(target_path))
+        if ckpt_path.suffix == ".pt":
+            if len(backbone_pairs) != 1:
+                raise ValueError("Provide a directory path when saving multiple backbones.")
+            name, backbone = backbone_pairs[0]
+            self.logger.info(f"Saving backbone '{name}' to {ckpt_path}")
+            mto.save(backbone, str(ckpt_path))
+        else:
+            ckpt_path.mkdir(parents=True, exist_ok=True)
+            for name, backbone in backbone_pairs:
+                target_path = ckpt_path / f"{name}.pt"
+                self.logger.info(f"Saving backbone '{name}' to {target_path}")
+                mto.save(backbone, str(target_path))
@@
     def restore_checkpoint(self) -> None:
@@
-        restore_path = self.config.restore_from
+        restore_path = self.config.restore_from
         if self.pipeline_manager is None:
             raise RuntimeError("Pipeline manager is required for per-backbone checkpoints.")
         backbone_pairs = list(self.pipeline_manager.iter_backbones())
-        if restore_path.exists() and restore_path.is_dir():
-            for name, backbone in backbone_pairs:
-                source_path = restore_path / f"{name}.pt"
-                if not source_path.exists():
-                    raise FileNotFoundError(f"Backbone checkpoint not found: {source_path}")
-                self.logger.info(f"Restoring backbone '{name}' from {source_path}")
-                mto.restore(backbone, str(source_path))
-        self.logger.info("Backbone checkpoints restored successfully")
+        if restore_path.is_file():
+            if len(backbone_pairs) != 1:
+                raise ValueError("Provide a directory path when restoring multiple backbones.")
+            name, backbone = backbone_pairs[0]
+            self.logger.info(f"Restoring backbone '{name}' from {restore_path}")
+            mto.restore(backbone, str(restore_path))
+            self.logger.info("Backbone checkpoint restored successfully")
+            return
+        if restore_path.is_dir():
+            for name, backbone in backbone_pairs:
+                source_path = restore_path / f"{name}.pt"
+                if not source_path.exists():
+                    raise FileNotFoundError(f"Backbone checkpoint not found: {source_path}")
+                self.logger.info(f"Restoring backbone '{name}' from {source_path}")
+                mto.restore(backbone, str(source_path))
+            self.logger.info("Backbone checkpoints restored successfully")
+            return
+        raise FileNotFoundError(f"Restore checkpoint not found: {restore_path}")

Also applies to: 278-297
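
To make the file-vs-directory dispatch in the suggested fix concrete, a tiny standalone illustration (the paths are examples only):

```python
# The suggested fix keys off Path.suffix to pick single-file vs per-backbone mode.
from pathlib import Path

print(Path("./wan22_mo_ckpts/transformer.pt").suffix)  # '.pt' -> single checkpoint file
print(Path("./wan22_mo_ckpts").suffix)                 # ''    -> directory of {backbone_name}.pt files
```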

@codecov

codecov bot commented Feb 5, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 73.72%. Comparing base (452c5a0) to head (04abab5).

Additional details and impacted files
@@           Coverage Diff           @@
##             main     #855   +/-   ##
=======================================
  Coverage   73.72%   73.72%           
=======================================
  Files         196      196           
  Lines       20457    20457           
=======================================
  Hits        15082    15082           
  Misses       5375     5375           

