Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 33 additions & 12 deletions torch/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,12 +159,12 @@ def _get_cache_or_reload(github, force_reload):
url = _git_archive_link(repo_owner, repo_name, branch)
_download_archive_zip(url, cached_file)

cached_zipfile = zipfile.ZipFile(cached_file)
extraced_repo_name = cached_zipfile.infolist()[0].filename
extracted_repo = os.path.join(hub_dir, extraced_repo_name)
_remove_if_exists(extracted_repo)
# Unzip the code and rename the base folder
cached_zipfile.extractall(hub_dir)
with zipfile.ZipFile(cached_file) as cached_zipfile:
extraced_repo_name = cached_zipfile.infolist()[0].filename
extracted_repo = os.path.join(hub_dir, extraced_repo_name)
_remove_if_exists(extracted_repo)
# Unzip the code and rename the base folder
cached_zipfile.extractall(hub_dir)

_remove_if_exists(cached_file)
_remove_if_exists(repo_dir)
Expand All @@ -182,14 +182,35 @@ def _check_module_exists(name):
import importlib.find_loader
return importlib.find_loader(name) is not None
else:
# NB: imp doesn't handle hierarchical module names (names contains dots).
# NB: Python2.7 imp.find_module() doesn't respect PEP 302,
# it cannot find a package installed as .egg(zip) file.
# Here we use workaround from:
# https://stackoverflow.com/questions/28962344/imp-find-module-which-supports-zipped-eggs?lq=1
# Also imp doesn't handle hierarchical module names (names contains dots).
try:
# 1. Try imp.find_module(), which searches sys.path, but does
# not respect PEP 302 import hooks.
import imp
imp.find_module(name)
except Exception:
return False
return True

result = imp.find_module(name)
if result:
return True
except ImportError:
pass
path = sys.path
for item in path:
# 2. Scan path for import hooks. sys.path_importer_cache maps
# path items to optional "importer" objects, that implement
# find_module() etc. Note that path must be a subset of
# sys.path for this to work.
importer = sys.path_importer_cache.get(item)
if importer:
try:
result = importer.find_module(name, [item])
if result:
return True
except ImportError:
pass
return False

def _check_dependencies(m):
dependencies = _load_attr_from_module(m, VAR_DEPENDENCY)
Expand Down