-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathcore.py
More file actions
78 lines (62 loc) · 2.72 KB
/
core.py
File metadata and controls
78 lines (62 loc) · 2.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#!/usr/bin/env python
from collections import defaultdict
from inspect import getdoc, getmembers, isfunction
from typing import Any, Callable, Mapping, Sequence, Union

import matplotlib.pyplot as plt
import torch
# Adapted from
# https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/utilities/apply_func.py
def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable, *args, **kwargs) -> Any:
    """
    Recursively applies a function to all elements of a certain dtype.

    Mappings, named tuples and other sequences (except ``str``) are walked
    recursively and rebuilt with their original type. A ``defaultdict`` keeps
    its ``default_factory``. A sequence whose type cannot be rebuilt from a
    list of its items (for example ``range``) is returned as a plain ``list``.

    Args:
        data: the collection to apply the function to
        dtype: the given function will be applied to all elements of this dtype
        function: the function to apply
        *args: positional arguments (will be forwarded to calls of ``function``)
        **kwargs: keyword arguments (will be forwarded to calls of ``function``)

    Returns:
        the resulting collection
    """
    # Breaking condition: the element itself has the target type.
    if isinstance(data, dtype):
        return function(data, *args, **kwargs)

    elem_type = type(data)

    # Recursively apply to collection items
    if isinstance(data, Mapping):
        out = {k: apply_to_collection(v, dtype, function, *args, **kwargs) for k, v in data.items()}
        if isinstance(data, defaultdict):
            # defaultdict's first positional argument is its default factory,
            # so it cannot be rebuilt from a plain dict alone.
            return elem_type(data.default_factory, out)
        return elem_type(out)
    if isinstance(data, tuple) and hasattr(data, "_fields"):  # named tuple
        return elem_type(*(apply_to_collection(d, dtype, function, *args, **kwargs) for d in data))
    if isinstance(data, Sequence) and not isinstance(data, str):
        out = [apply_to_collection(d, dtype, function, *args, **kwargs) for d in data]
        try:
            return elem_type(out)
        except TypeError:
            # Some sequence types (e.g. range) cannot be constructed from a
            # list of their items; without this fallback a call such as
            # plot(range(n), tensor) would raise.
            return out

    # data is neither of dtype, nor a collection
    return data
def _torch2np(*args, **kwargs):
    """
    Replace every ``torch.Tensor`` in the call arguments with a NumPy array.

    Tensors nested inside the collections walked by ``apply_to_collection``
    (mappings, named tuples and other non-string sequences) are converted too.

    Returns:
        A 2-tuple ``(positional, keyword)``: the positional arguments as a
        tuple and the keyword arguments as a dict, with each tensor replaced
        by ``tensor.detach().cpu().numpy()``.
    """
    def _to_numpy(tensor):
        # Drop the autograd graph and copy to host memory before exporting,
        # so tensors that require grad or live on an accelerator convert too.
        return tensor.detach().cpu().numpy()

    # Convert the positional arguments first, then the keyword arguments.
    return tuple(apply_to_collection(part, torch.Tensor, _to_numpy) for part in (args, kwargs))
def _make_torch_aware(name, doc):
    """
    Build a wrapper for ``plt.<name>`` that accepts torch tensors.

    The pyplot function is looked up on ``plt`` at call time rather than
    captured once, so the wrapper follows any later reassignment of that
    pyplot attribute.

    Args:
        name: attribute name of the pyplot function to wrap
        doc: docstring to attach to the wrapper

    Returns:
        A function that converts its arguments with ``_torch2np`` and forwards
        them to ``plt.<name>``, returning whatever that call returns.
    """
    def wrapper(*args, **kwargs):
        new_args, new_kwargs = _torch2np(*args, **kwargs)
        return getattr(plt, name)(*new_args, **new_kwargs)

    # Make the wrapper look like a plain module-level function named `name`.
    # This keeps introspection and pickling by reference working.
    wrapper.__name__ = name
    wrapper.__qualname__ = name
    wrapper.__doc__ = doc
    return wrapper


def _mirror_pyplot(namespace):
    """
    Copy the members of ``matplotlib.pyplot`` into ``namespace``.

    Every pyplot function is replaced by a torch-aware wrapper (built by
    ``_make_torch_aware``); every other member is re-exported unchanged.
    Dunder attributes such as ``__name__``, ``__file__`` and ``__spec__`` are
    skipped. Copying them would overwrite this module's own identity.

    Args:
        namespace: the dict to populate (this module's ``globals()``)
    """
    for name, member in getmembers(plt):
        if name.startswith("__") and name.endswith("__"):
            continue
        if isfunction(member):
            doc = getdoc(member)
            namespace[name] = _make_torch_aware(name, "" if doc is None else doc)
        else:
            namespace[name] = member


# Duplicate every member of 'plt' in this module, so that pyplot functions
# can be called directly with torch tensors as arguments.
_mirror_pyplot(globals())