Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 11 additions & 12 deletions torch/distributed/fsdp/_exec_order_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,35 +333,34 @@ def _get_names_from_handle_indices(
handle_indices: Tuple[int, ...],
) -> List[List[str]]:
"""
Returns a list of prefixed parameter names for each handle in
``handle_indices``. If a handle index is invalid, then its prefixed
parameter names are omitted from the returned list.
Returns a list of FQNs for each handle in ``handle_indices``. If a
handle index is invalid, then its FQNs are omitted from the returned
list.
"""
prefixed_param_names: List[List[str]] = []
fqns: List[List[str]] = []
for index in handle_indices:
if index is None or index < 0 or index >= len(self.all_handles):
continue
handle = self.all_handles[index]
flat_param = handle.flat_param
prefixed_param_names.append(self.param_to_fqn[flat_param])
return prefixed_param_names
fqns.append(self.param_to_fqn[flat_param])
return fqns

def _get_names_from_handles(
self,
handles_key: _HandlesKey,
) -> List[List[str]]:
"""
Returns a list of prefixed parameter names for each handle in
``handles_key``. If a handle is invalid, then its prefixed parameter
names are omitted from the returned list.
Returns a list of FQNs for each handle in ``handles_key``. If a handle
is invalid, then its FQNs are omitted from the returned list.
"""
prefixed_param_names: List[List[str]] = []
fqns: List[List[str]] = []
for handle in handles_key:
flat_param = handle.flat_param
if flat_param not in self.param_to_fqn:
continue
prefixed_param_names.append(self.param_to_fqn[flat_param])
return prefixed_param_names
fqns.append(self.param_to_fqn[flat_param])
return fqns

def next_iter(self):
"""
Expand Down