24 changes: 21 additions & 3 deletions torch/distributed/checkpoint/dedup_tensors.py
@@ -1,13 +1,30 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
 
-from typing import Dict, List
 import dataclasses
+import logging
+from typing import Dict, List
 
 from torch.distributed.checkpoint.metadata import MetadataIndex
 from torch.distributed.checkpoint.planner import SavePlan
 
 __all__ = ["dedup_tensors"]
 
 
+def init_logger() -> logging.Logger:
+    logger = logging.getLogger(__name__)
+    level = logging.INFO
+    logger.setLevel(level)
+    console = logging.StreamHandler()
+    formatter = logging.Formatter(
+        "%(asctime)s %(filename)s:%(lineno)s %(levelname)s p:%(processName)s t:%(threadName)s: %(message)s"
+    )
+    console.setFormatter(formatter)
+    console.setLevel(level)
+    logger.addHandler(console)
+    logger.propagate = False
+    return logger
+
+logger = init_logger()
+
 # TODO add docstring for dedup_tensors
 def dedup_tensors(all_plans: List[SavePlan]) -> List[SavePlan]:
     all_plans = list(all_plans)
@@ -18,12 +35,13 @@ def dedup_tensors(all_plans: List[SavePlan]) -> List[SavePlan]:

     replicated_items = {k: v for k, v in key_to_plan.items() if len(v) > 1}
 
-    # Remove deplicates by always keeping the first entry.
+    # Remove duplicates by always keeping the first entry.
     # Compute the per-rank remove set.
     plan_to_keys: Dict[int, List[MetadataIndex]] = {}
     for key, plans in replicated_items.items():
         for plan_idx in plans[1:]:
             plan_to_keys.setdefault(plan_idx, []).append(key)
+    logger.info(f"Duplicate keys to remove: {plan_to_keys}")
 
     for plan_idx, keys in plan_to_keys.items():
        key_set = set(keys)
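For readers skimming the change: the dedup pass is first-writer-wins across ranks, and the new logger.info call reports exactly which keys each later rank is told to drop. Below is a minimal, self-contained sketch of that logic. Plan and Item here are hypothetical stand-ins for the real SavePlan/WriteItem dataclasses, and dedup mirrors dedup_tensors under that assumption; it is illustrative, not the library API.

import dataclasses
from typing import Dict, List

@dataclasses.dataclass(frozen=True)
class Item:
    index: str          # stand-in for MetadataIndex

@dataclasses.dataclass
class Plan:
    items: List[Item]   # stand-in for SavePlan.items

def dedup(all_plans: List[Plan]) -> List[Plan]:
    all_plans = list(all_plans)
    # 1. Map each key to every plan (rank) that wants to write it.
    key_to_plan: Dict[str, List[int]] = {}
    for plan_idx, plan in enumerate(all_plans):
        for item in plan.items:
            key_to_plan.setdefault(item.index, []).append(plan_idx)
    # 2. Keys claimed by more than one plan are replicated.
    replicated = {k: v for k, v in key_to_plan.items() if len(v) > 1}
    # 3. First writer wins: every later plan is told to drop the key.
    plan_to_keys: Dict[int, List[str]] = {}
    for key, plans in replicated.items():
        for plan_idx in plans[1:]:
            plan_to_keys.setdefault(plan_idx, []).append(key)
    # (this is the point where the PR's logger.info reports plan_to_keys)
    # 4. Rewrite each affected plan without its dropped keys.
    for plan_idx, keys in plan_to_keys.items():
        key_set = set(keys)
        new_items = [i for i in all_plans[plan_idx].items if i.index not in key_set]
        all_plans[plan_idx] = dataclasses.replace(all_plans[plan_idx], items=new_items)
    return all_plans

plans = [Plan([Item("weight"), Item("bias")]),
         Plan([Item("weight"), Item("extra")])]
print(dedup(plans))  # rank 0 keeps "weight" and "bias"; rank 1 keeps only "extra"

Two small notes on the logging setup itself: logger.propagate = False keeps records from also being emitted by handlers on the root logger, and the f-string passed to logger.info is formatted eagerly even when INFO is disabled (lazy %-style arguments would be the usual logging idiom).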