Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 27 additions & 24 deletions src/usethis/_integrations/pre_commit/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
)
from usethis._integrations.pre_commit.language import get_system_language
from usethis._integrations.pre_commit.yaml import PreCommitConfigYAMLManager
from usethis._pipeweld.containers import series
from usethis._pipeweld.func import Adder

if TYPE_CHECKING:
from collections.abc import Collection
Expand Down Expand Up @@ -67,39 +69,40 @@ def add_repo(repo: schema.LocalRepo | schema.UriRepo) -> None:
mgr.commit_model(model)
else:
# There are existing hooks so we need to know where to insert the new hook.

# Get the precendents, i.e. hooks occurring before the new hook
# Also the successors, i.e. hooks occurring after the new hook
# Use pipeweld to determine the correct insertion position based on the
# canonical hook ordering.
try:
hook_idx = _HOOK_ORDER.index(hook_config.id)
except ValueError:
msg = f"Hook '{hook_config.id}' not recognized."
raise NotImplementedError(msg) from None
precedents = _HOOK_ORDER[:hook_idx]
successors = _HOOK_ORDER[hook_idx + 1 :]

existing_precedents = [hook for hook in existing_hooks if hook in precedents]
existing_successors = [hook for hook in existing_hooks if hook in successors]

# Add immediately after the last precedecessor.
# If there isn't one, we want to add as late as possible without violating
# order, i.e. before the first successor, if there is one.
if existing_precedents:
last_precedent = existing_precedents[-1]
elif not existing_successors:
last_precedent = existing_hooks[-1]
else:
first_successor = existing_successors[0]
first_successor_idx = existing_hooks.index(first_successor)
if first_successor_idx == 0:
last_precedent = None
else:
last_precedent = existing_hooks[first_successor_idx - 1]

prerequisites = set(_HOOK_ORDER[:hook_idx])
postrequisites = set(_HOOK_ORDER[hook_idx + 1 :])

pipeline = series(*existing_hooks)
adder = Adder(
pipeline=pipeline,
step=hook_config.id,
prerequisites=prerequisites,
postrequisites=postrequisites,
force_linear=True,
)
result = adder.add()

# With force_linear=True, solution is a flat Series of strings.
flat = result.solution.root
idx = flat.index(hook_config.id)
predecessor: str | None = None
if idx > 0:
prev_item = flat[idx - 1]
assert isinstance(prev_item, str)
predecessor = prev_item

model.repos = insert_repo(
repo_to_insert=repo,
existing_repos=model.repos,
predecessor=last_precedent,
predecessor=predecessor,
)

mgr.commit_model(model)
Expand Down
63 changes: 63 additions & 0 deletions src/usethis/_pipeweld/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class Adder(BaseModel):
prerequisites: set[str] = set()
postrequisites: set[str] = set()
compatible_config_groups: set[str] = set()
force_linear: bool = False

def add(self) -> WeldResult:
if len(self.pipeline) == 0:
Expand All @@ -60,6 +61,11 @@ def add(self) -> WeldResult:

instructions += new_instructions

if self.force_linear:
original_order = _extract_ordered_steps(self.pipeline)
flat = _linearize_component(rearranged_pipeline, self.step, original_order)
rearranged_pipeline = series(*flat)

return WeldResult(
solution=rearranged_pipeline,
instructions=instructions,
Expand Down Expand Up @@ -605,3 +611,60 @@ def get_endpoint(component: str | Series | DepGroup | Parallel) -> str:
return get_endpoint(component.series)
else:
assert_never(component)


def _extract_ordered_steps(
component: str | Series | Parallel | DepGroup,
) -> list[str]:
"""Extract all step names from a component in depth-first order."""
if isinstance(component, str):
return [component]
elif isinstance(component, Series | Parallel):
return [s for sub in component.root for s in _extract_ordered_steps(sub)]
elif isinstance(component, DepGroup):
return _extract_ordered_steps(component.series)
else:
assert_never(component)


def _linearize_component(
component: str | Series | Parallel | DepGroup,
new_step: str,
original_order: list[str],
) -> list[str]:
"""Flatten a pipeline component to a linear list of step names.

Within parallel groups, existing steps maintain their relative order from
``original_order`` and the new step is placed after all existing steps.
"""
if isinstance(component, str):
return [component]
elif isinstance(component, Series):
result: list[str] = []
for sub in component.root:
result.extend(_linearize_component(sub, new_step, original_order))
return result
elif isinstance(component, Parallel):
sublists = [
_linearize_component(sub, new_step, original_order)
for sub in component.root
]
all_items = [item for sublist in sublists for item in sublist]

def sort_key(item: str) -> tuple[int, int]:
# Existing steps sort by their original position (priority 0);
# the new step sorts last (priority 1).
# Steps not found in original_order are placed after all known steps.
if item == new_step:
return (1, 0)
try:
return (0, original_order.index(item))
except ValueError:
return (0, len(original_order))

all_items.sort(key=sort_key)
return all_items
elif isinstance(component, DepGroup):
return _linearize_component(component.series, new_step, original_order)
else:
assert_never(component)
119 changes: 119 additions & 0 deletions tests/usethis/_integrations/pre_commit/test_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,3 +655,122 @@ def test_aliases(self):
schema.HookDefinition(id="ruff-check"),
schema.HookDefinition(id="ruff"),
)


class TestAddRepoPipeweld:
"""Integration tests for pipeweld-based hook insertion."""

def test_insert_between_nondependent_and_postrequisite(self, tmp_path: Path):
"""Insert a recognized hook between an unrecognized hook and a postrequisite."""
with change_cwd(tmp_path), files_manager():
# Set up: foo (unrecognized) then codespell (recognized, late in order)
add_repo(
schema.LocalRepo(
repo="local",
hooks=[
schema.HookDefinition(
id="foo",
name="foo",
entry="foo .",
language=schema.Language("system"),
)
],
),
)
add_repo(
schema.LocalRepo(
repo="local",
hooks=[
schema.HookDefinition(
id="codespell",
name="codespell",
entry="codespell .",
language=schema.Language("system"),
)
],
)
)

# Act: add ruff-format (comes before codespell, after foo)
add_repo(
schema.LocalRepo(
repo="local",
hooks=[
schema.HookDefinition(
id="ruff-format",
name="ruff-format",
entry="ruff format .",
language=schema.Language("system"),
)
],
)
)

# Assert: ruff-format should be between foo and codespell
assert get_hook_ids() == ["foo", "ruff-format", "codespell"]

def test_insert_with_prerequisite_present(self, tmp_path: Path):
"""Insert a hook after an existing prerequisite."""
with change_cwd(tmp_path), files_manager():
add_repo(
schema.LocalRepo(
repo="local",
hooks=[
schema.HookDefinition(
id="ruff-check",
name="ruff-check",
entry="ruff check .",
language=schema.Language("system"),
)
],
)
)

add_repo(
schema.LocalRepo(
repo="local",
hooks=[
schema.HookDefinition(
id="ruff-format",
name="ruff-format",
entry="ruff format .",
language=schema.Language("system"),
)
],
)
)

assert get_hook_ids() == ["ruff-check", "ruff-format"]

def test_insert_before_postrequisite_only(self, tmp_path: Path):
"""Insert a hook before an existing postrequisite when no predecessor exists."""
with change_cwd(tmp_path), files_manager():
add_repo(
schema.LocalRepo(
repo="local",
hooks=[
schema.HookDefinition(
id="codespell",
name="codespell",
entry="codespell .",
language=schema.Language("system"),
)
],
)
)

add_repo(
schema.LocalRepo(
repo="local",
hooks=[
schema.HookDefinition(
id="ruff-check",
name="ruff-check",
entry="ruff check .",
language=schema.Language("system"),
)
],
)
)

assert get_hook_ids() == ["ruff-check", "codespell"]
Loading
Loading