Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 13 additions & 12 deletions torch/backends/opt_einsum/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ def get_opt_einsum() -> Any:

def _set_enabled(_enabled: bool) -> None:
if not is_available() and _enabled:
warnings.warn('opt_einsum is not available, so setting `enabled` to True will not reap '
'the benefits of calculating an optimal path for einsum. torch.einsum will '
'fall back to contracting from left to right. To enable this optimal path '
'calculation, please install opt-einsum.')
raise ValueError(f'opt_einsum is not available, so setting `enabled` to {_enabled} will not reap '
'the benefits of calculating an optimal path for einsum. torch.einsum will '
'fall back to contracting from left to right. To enable this optimal path '
'calculation, please install opt-einsum.')
global enabled
enabled = _enabled

Expand All @@ -38,12 +38,13 @@ def _get_enabled() -> bool:

def _set_strategy(_strategy: str) -> None:
if not is_available():
raise ValueError('opt_einsum is not available, so `strategy` cannot be set. Please install opt-einsum or '
'unset `strategy`.')
raise ValueError(f'opt_einsum is not available, so setting `strategy` to {_strategy} will not be meaningful. '
'torch.einsum will bypass path calculation and simply contract from left to right. '
'Please install opt_einsum or unset `strategy`.')
if not enabled:
warnings.warn('opt_einsum is not enabled, so setting a `strategy` will not make a meaningful change. '
'torch.einsum will bypass path calculation and simply contract from left to right. '
'Please set `enabled` to `True` as well or unset `strategy`.')
raise ValueError(f'opt_einsum is not enabled, so setting a `strategy` to {_strategy} will not be meaningful. '
'torch.einsum will bypass path calculation and simply contract from left to right. '
'Please set `enabled` to `True` as well or unset `strategy`.')
if _strategy not in ['auto', 'greedy', 'optimal']:
raise ValueError(f'`strategy` must be one of the following: [auto, greedy, optimal] but is {_strategy}')
global strategy
Expand All @@ -64,7 +65,7 @@ def set_flags(_enabled=None, _strategy=None):


@contextmanager
def flags(enabled=False, strategy='auto'):
def flags(enabled=None, strategy=None):
with __allow_nonbracketed_mutation():
orig_flags = set_flags(enabled, strategy)
try:
Expand Down Expand Up @@ -94,5 +95,5 @@ def __init__(self, m, name):
# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273
sys.modules[__name__] = OptEinsumModule(sys.modules[__name__], __name__)

enabled = True
strategy = 'auto'
enabled = True if is_available() else False
strategy = 'auto' if is_available() else None