Skip to content

Commit b083212

Browse files
committed
add context manager functionality to config.defaults
1 parent 1434540 commit b083212

2 files changed

Lines changed: 39 additions & 0 deletions

File tree

control/config.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,28 @@ def _check_deprecation(self, key):
7373
else:
7474
return key
7575

76+
#
77+
# Context manager functionality
78+
#
79+
80+
def __call__(self, mapping):
81+
self.saved_mapping = dict()
82+
self.temp_mapping = mapping.copy()
83+
return self
84+
85+
def __enter__(self):
86+
for key, val in self.temp_mapping.items():
87+
if not key in self:
88+
raise ValueError(f"unknown parameter '{key}'")
89+
self.saved_mapping[key] = self[key]
90+
self[key] = val
91+
return self
92+
93+
def __exit__(self, exc_type, exc_val, exc_tb):
94+
for key, val in self.saved_mapping.items():
95+
self[key] = val
96+
del self.saved_mapping, self.temp_mapping
97+
return None
7698

7799
defaults = DefaultDict(_control_defaults)
78100

control/tests/config_test.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,3 +332,20 @@ def test_legacy_repr_format(self):
332332
new = eval(repr(sys))
333333
for attr in ['A', 'B', 'C', 'D']:
334334
assert getattr(sys, attr) == getattr(sys, attr)
335+
336+
337+
def test_config_context_manager():
338+
# Make sure we can temporarily set the value of a parameter
339+
default_val = ct.config.defaults['statesp.latex_repr_type']
340+
with ct.config.defaults({'statesp.latex_repr_type': 'new value'}):
341+
assert ct.config.defaults['statesp.latex_repr_type'] != default_val
342+
assert ct.config.defaults['statesp.latex_repr_type'] == 'new value'
343+
assert ct.config.defaults['statesp.latex_repr_type'] == default_val
344+
345+
# OK to call the context manager and not do anything with it
346+
ct.config.defaults({'statesp.latex_repr_type': 'new value'})
347+
assert ct.config.defaults['statesp.latex_repr_type'] == default_val
348+
349+
with pytest.raises(ValueError, match="unknown parameter 'unknown'"):
350+
with ct.config.defaults({'unknown': 'new value'}):
351+
pass

0 commit comments

Comments
 (0)