@@ -24,10 +24,10 @@ def get_opt_einsum() -> Any:
2424
2525def _set_enabled (_enabled : bool ) -> None :
2626 if not is_available () and _enabled :
27- warnings . warn ( 'opt_einsum is not available, so setting `enabled` to True will not reap '
28- 'the benefits of calculating an optimal path for einsum. torch.einsum will '
29- 'fall back to contracting from left to right. To enable this optimal path '
30- 'calculation, please install opt-einsum.' )
27+ raise ValueError ( f 'opt_einsum is not available, so setting `enabled` to { _enabled } will not reap '
28+ 'the benefits of calculating an optimal path for einsum. torch.einsum will '
29+ 'fall back to contracting from left to right. To enable this optimal path '
30+ 'calculation, please install opt-einsum.' )
3131 global enabled
3232 enabled = _enabled
3333
@@ -38,12 +38,13 @@ def _get_enabled() -> bool:
3838
3939def _set_strategy (_strategy : str ) -> None :
4040 if not is_available ():
41- raise ValueError ('opt_einsum is not available, so `strategy` cannot be set. Please install opt-einsum or '
42- 'unset `strategy`.' )
41+ raise ValueError (f'opt_einsum is not available, so setting `strategy` to { _strategy } will not be meaningful. '
42+ 'torch.einsum will bypass path calculation and simply contract from left to right. '
43+ 'Please install opt_einsum or unset `strategy`.' )
4344 if not enabled :
44- warnings . warn ( 'opt_einsum is not enabled, so setting a `strategy` will not make a meaningful change . '
45- 'torch.einsum will bypass path calculation and simply contract from left to right. '
46- 'Please set `enabled` to `True` as well or unset `strategy`.' )
45+ raise ValueError ( f 'opt_einsum is not enabled, so setting a `strategy` to { _strategy } will not be meaningful. '
46+ 'torch.einsum will bypass path calculation and simply contract from left to right. '
47+ 'Please set `enabled` to `True` as well or unset `strategy`.' )
4748 if _strategy not in ['auto' , 'greedy' , 'optimal' ]:
4849 raise ValueError (f'`strategy` must be one of the following: [auto, greedy, optimal] but is { _strategy } ' )
4950 global strategy
@@ -64,7 +65,7 @@ def set_flags(_enabled=None, _strategy=None):
6465
6566
6667@contextmanager
67- def flags (enabled = False , strategy = 'auto' ):
68+ def flags (enabled = None , strategy = None ):
6869 with __allow_nonbracketed_mutation ():
6970 orig_flags = set_flags (enabled , strategy )
7071 try :
@@ -94,5 +95,5 @@ def __init__(self, m, name):
9495# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273
9596sys .modules [__name__ ] = OptEinsumModule (sys .modules [__name__ ], __name__ )
9697
97- enabled = True
98- strategy = 'auto'
98+ enabled = True if is_available () else False
99+ strategy = 'auto' if is_available () else None
0 commit comments