Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/usethis/_integrations/pyproject/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@

class PyProjectConfig(BaseModel):
id_keys: list[str]
main_contents: dict[str, Any]
value: Any
2 changes: 2 additions & 0 deletions src/usethis/_integrations/pyproject/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ def set_config_value(
pyproject = read_pyproject_toml()

try:
# Index our way into each ID key.
# Eventually, we should land at a final dict, which si the one we are setting.
p, parent = pyproject, {}
for key in id_keys:
TypeAdapter(dict).validate_python(p)
Expand Down
12 changes: 6 additions & 6 deletions src/usethis/_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def add_pyproject_configs(self) -> None:
first_addition = True
for config in configs:
try:
set_config_value(config.id_keys, config.main_contents)
set_config_value(config.id_keys, config.value)
except PyProjectTOMLValueAlreadySetError:
pass
else:
Expand Down Expand Up @@ -223,7 +223,7 @@ def get_pyproject_configs(self) -> list[PyProjectConfig]:
return [
PyProjectConfig(
id_keys=["tool", "pyproject-fmt"],
main_contents={"keep_full_version": True},
value={"keep_full_version": True},
)
]

Expand All @@ -244,7 +244,7 @@ def get_pyproject_configs(self) -> list[PyProjectConfig]:
return [
PyProjectConfig(
id_keys=["tool", "pytest"],
main_contents={
value={
"ini_options": {
"testpaths": ["tests"],
"addopts": [
Expand All @@ -256,14 +256,14 @@ def get_pyproject_configs(self) -> list[PyProjectConfig]:
),
PyProjectConfig(
id_keys=["tool", "coverage", "run"],
main_contents={
value={
"source": ["src"],
"omit": ["*/pytest-of-*/*"],
},
),
PyProjectConfig(
id_keys=["tool", "coverage", "report"],
main_contents={
value={
"exclude_also": [
"if TYPE_CHECKING:",
"raise AssertionError",
Expand Down Expand Up @@ -369,7 +369,7 @@ def get_pyproject_configs(self) -> list[PyProjectConfig]:
return [
PyProjectConfig(
id_keys=["tool", "ruff"],
main_contents={
value={
"src": ["src"],
"line-length": 88,
"lint": {"select": []},
Expand Down
108 changes: 102 additions & 6 deletions tests/usethis/test_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,7 @@ def get_pre_commit_repos(self) -> list[LocalRepo | UriRepo]:
]

def get_pyproject_configs(self) -> list[PyProjectConfig]:
return [
PyProjectConfig(id_keys=["tool", self.name], main_contents={"key": "value"})
]
return [PyProjectConfig(id_keys=["tool", self.name], value={"key": "value"})]

def get_associated_ruff_rules(self) -> list[str]:
return ["MYRULE"]
Expand Down Expand Up @@ -117,9 +115,7 @@ def test_default(self):
def test_specific(self):
tool = MyTool()
assert tool.get_pyproject_configs() == [
PyProjectConfig(
id_keys=["tool", "my_tool"], main_contents={"key": "value"}
)
PyProjectConfig(id_keys=["tool", "my_tool"], value={"key": "value"})
]

class TestGetAssociatedRuffRules:
Expand Down Expand Up @@ -545,3 +541,103 @@ def get_pre_commit_repos(self) -> list[LocalRepo | UriRepo]:
# Assert
assert (tmp_path / ".pre-commit-config.yaml").exists()
assert get_hook_names() == [_PLACEHOLDER_ID]

class TestAddPyprojectConfigs:
def test_no_config(self, tmp_path: Path):
# Arrange
class NoConfigTool(Tool):
@property
def name(self) -> str:
return "no_config_tool"

def get_pyproject_configs(self) -> list[PyProjectConfig]:
return []

nc_tool = NoConfigTool()

# Act
with change_cwd(tmp_path):
nc_tool.add_pyproject_configs()

# Assert
assert not (tmp_path / "pyproject.toml").exists()

def test_empty(self, tmp_path: Path, capfd: pytest.CaptureFixture[str]):
# Arrange
class ThisTool(Tool):
@property
def name(self) -> str:
return "mytool"

def get_pyproject_configs(self) -> list[PyProjectConfig]:
return [
PyProjectConfig(
id_keys=["tool", "mytool"],
value={"key": "value"},
),
]

(tmp_path / "pyproject.toml").write_text("")

# Act
with change_cwd(tmp_path):
ThisTool().add_pyproject_configs()

# Assert
assert (
(tmp_path / "pyproject.toml").read_text()
== """\
[tool.mytool]
key = "value"
"""
)
out, err = capfd.readouterr()
assert not err
assert out == "✔ Adding mytool config to 'pyproject.toml'.\n"

def test_differing_sections(
self, tmp_path: Path, capfd: pytest.CaptureFixture[str]
):
# https://github.com/nathanjmcdougall/usethis-python/issues/184

# Arrange
class ThisTool(Tool):
@property
def name(self) -> str:
return "mytool"

def get_pyproject_configs(self) -> list[PyProjectConfig]:
return [
PyProjectConfig(
id_keys=["tool", "mytool", "name"],
value="Modular Design",
),
PyProjectConfig(
id_keys=["tool", "mytool", "root_packages"],
value=["example"],
),
]

(tmp_path / "pyproject.toml").write_text(
"""\
[tool.mytool]
name = "Modular Design"
"""
)

# Act
with change_cwd(tmp_path):
ThisTool().add_pyproject_configs()

# Assert
assert (
(tmp_path / "pyproject.toml").read_text()
== """\
[tool.mytool]
name = "Modular Design"
root_packages = ["example"]
"""
)
out, err = capfd.readouterr()
assert not err
assert out == "✔ Adding mytool config to 'pyproject.toml'.\n"