Skip to content

Commit c0bfa45

Browse files
malfetfacebook-github-bot
authored andcommitted
Enable typechecking for torch.futures (#41675)
Summary: Add typing declarations for torch._C.Future and torch._C._collect_all Pull Request resolved: #41675 Reviewed By: izdeby Differential Revision: D22627539 Pulled By: malfet fbshipit-source-id: 29b87685d65dd24ee2094bae8a84a0fe3787e7f8
1 parent 750d9de commit c0bfa45

File tree

4 files changed

+53
-60
lines changed

4 files changed

+53
-60
lines changed

mypy.ini

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ files =
1919
caffe2,
2020
aten/src/ATen/function_wrapper.py,
2121
test/test_complex.py,
22+
test/test_futures.py,
2223
test/test_torch.py,
2324
test/test_type_hints.py,
2425
test/test_type_info.py
@@ -53,9 +54,6 @@ ignore_errors = True
5354
[mypy-torch.functional.*]
5455
ignore_errors = True
5556

56-
[mypy-torch.futures.*]
57-
ignore_errors = True
58-
5957
[mypy-torch.testing._internal.*]
6058
ignore_errors = True
6159

test/test_futures.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,28 +11,28 @@ def add_one(fut):
1111

1212

1313
class TestFuture(TestCase):
14-
def test_wait(self):
15-
f = Future()
14+
def test_wait(self) -> None:
15+
f = Future[torch.Tensor]()
1616
f.set_result(torch.ones(2, 2))
1717

1818
self.assertEqual(f.wait(), torch.ones(2, 2))
1919

20-
def test_wait_multi_thread(self):
20+
def test_wait_multi_thread(self) -> None:
2121

2222
def slow_set_future(fut, value):
2323
time.sleep(0.5)
2424
fut.set_result(value)
2525

26-
f = Future()
26+
f = Future[torch.Tensor]()
2727

2828
t = threading.Thread(target=slow_set_future, args=(f, torch.ones(2, 2)))
2929
t.start()
3030

3131
self.assertEqual(f.wait(), torch.ones(2, 2))
3232
t.join()
3333

34-
def test_mark_future_twice(self):
35-
fut = Future()
34+
def test_mark_future_twice(self) -> None:
35+
fut = Future[int]()
3636
fut.set_result(1)
3737
with self.assertRaisesRegex(
3838
RuntimeError,
@@ -41,22 +41,22 @@ def test_mark_future_twice(self):
4141
fut.set_result(1)
4242

4343
def test_pickle_future(self):
44-
fut = Future()
44+
fut = Future[int]()
4545
errMsg = "Can not pickle torch.futures.Future"
4646
with TemporaryFileName() as fname:
4747
with self.assertRaisesRegex(RuntimeError, errMsg):
4848
torch.save(fut, fname)
4949

5050
def test_then(self):
51-
fut = Future()
51+
fut = Future[torch.Tensor]()
5252
then_fut = fut.then(lambda x: x.wait() + 1)
5353

5454
fut.set_result(torch.ones(2, 2))
5555
self.assertEqual(fut.wait(), torch.ones(2, 2))
5656
self.assertEqual(then_fut.wait(), torch.ones(2, 2) + 1)
5757

5858
def test_chained_then(self):
59-
fut = Future()
59+
fut = Future[torch.Tensor]()
6060
futs = []
6161
last_fut = fut
6262
for _ in range(20):
@@ -69,7 +69,7 @@ def test_chained_then(self):
6969
self.assertEqual(futs[i].wait(), torch.ones(2, 2) + i + 1)
7070

7171
def _test_error(self, cb, errMsg):
72-
fut = Future()
72+
fut = Future[int]()
7373
then_fut = fut.then(cb)
7474

7575
fut.set_result(5)
@@ -99,8 +99,8 @@ def raise_value_error(fut):
9999
self._test_error(raise_value_error, "Expected error")
100100

101101
def test_collect_all(self):
102-
fut1 = Future()
103-
fut2 = Future()
102+
fut1 = Future[int]()
103+
fut2 = Future[int]()
104104
fut_all = torch.futures.collect_all([fut1, fut2])
105105

106106
def slow_in_thread(fut, value):
@@ -118,8 +118,8 @@ def slow_in_thread(fut, value):
118118

119119
@unittest.skipIf(IS_WINDOWS, "TODO: need to fix this testcase for Windows")
120120
def test_wait_all(self):
121-
fut1 = Future()
122-
fut2 = Future()
121+
fut1 = Future[int]()
122+
fut2 = Future[int]()
123123

124124
# No error version
125125
fut1.set_result(1)

torch/_C/__init__.pyi.in

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import torch
44
from torch import Tensor
55
from typing import (Any, BinaryIO, Callable, ContextManager, Iterator, List, NamedTuple,
6-
Optional, overload, Sequence, Tuple, TypeVar, Type, Union)
6+
Optional, overload, Sequence, Tuple, TypeVar, Type, Union)
77
from torch._six import inf
88

99
from torch.types import _int, _float, _bool, _dtype, _device, _qscheme, _size, _layout, Device, Number, Storage
@@ -118,6 +118,14 @@ class _LegacyVariableBase(object):
118118
) -> None: ...
119119

120120
# Defined in torch/csrc/jit/python/init.cpp
121+
class Future(object):
122+
def __init__(self) -> None: ...
123+
def wait(self) -> Any: ...
124+
def then(self, callback: Callable) -> Future: ...
125+
def set_result(self, result: Any) -> None: ...
126+
127+
128+
def _collect_all(futures: List[Future]) -> Future: ...
121129
def _jit_get_operation(op_name: str) -> Callable: ...
122130
def _jit_pass_optimize_for_mobile(module: 'torch.jit.ScriptModule') -> 'torch.jit.ScriptModule': ...
123131

torch/futures/__init__.py

Lines changed: 29 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,28 @@
1-
from typing import Generic, TypeVar
1+
from typing import cast, Callable, Generic, List, Type, TypeVar
22

33
import torch
4+
from torch._six import PY37
45

6+
T = TypeVar("T")
7+
S = TypeVar("S")
8+
9+
if not PY37:
10+
# Workaround for https://github.com/python/typing/issues/449 in Python 3.6
11+
from typing import GenericMeta
12+
13+
class _PyFutureMeta(type(torch._C.Future), GenericMeta): # type: ignore[misc]
14+
pass
15+
else:
16+
class _PyFutureMeta(type(torch._C.Future), type(Generic)): # type: ignore[misc, no-redef]
17+
pass
518

6-
class _PyFuture(torch._C.Future):
7-
def wait(self):
19+
class Future(torch._C.Future, Generic[T], metaclass=_PyFutureMeta):
20+
r"""
21+
Wrapper around a ``torch._C.Future`` which encapsulates an asynchronous
22+
execution of a callable, e.g. :meth:`~torch.distributed.rpc.rpc_async`. It
23+
also exposes a set of APIs to add callback functions and set results.
24+
"""
25+
def wait(self) -> T:
826
r"""
927
Block until the value of this ``Future`` is ready.
1028
@@ -15,7 +33,8 @@ def wait(self):
1533
"""
1634
return super().wait()
1735

18-
def then(self, callback):
36+
# Have to use string annotations because PEP-0563 is not available in 3.6
37+
def then(self, callback): # type: (Callable[[Future[T]], S]) -> Future[S]
1938
r"""
2039
Append the given callback function to this ``Future``, which will be run
2140
when the ``Future`` is completed. Multiple callbacks can be added to
@@ -52,9 +71,9 @@ def then(self, callback):
5271
>>> # RPC return value is 5.
5372
>>> # Chained cb done. None
5473
"""
55-
return super().then(callback)
74+
return cast(Future[S], super().then(callback))
5675

57-
def set_result(self, result):
76+
def set_result(self, result: T) -> None:
5877
r"""
5978
Set the result for this ``Future``, which will mark this ``Future`` as
6079
completed and trigger all attached callbacks. Note that a ``Future``
@@ -85,7 +104,7 @@ def set_result(self, result):
85104
super().set_result(result)
86105

87106

88-
def collect_all(futures):
107+
def collect_all(futures: List[Future]) -> Future[List[Future]]:
89108
r"""
90109
Collects the provided :class:`~torch.futures.Future` objects into a single
91110
combined :class:`~torch.futures.Future` that is completed when all of the
@@ -116,10 +135,10 @@ def collect_all(futures):
116135
>>> # fut0 result = 0
117136
>>> # fut1 result = 1
118137
"""
119-
return torch._C._collect_all(futures)
138+
return cast(Future[List[Future]], torch._C._collect_all(cast(List[torch._C.Future], futures)))
120139

121140

122-
def wait_all(futures):
141+
def wait_all(futures: List[Future]) -> List:
123142
r"""
124143
Waits for all provided futures to be complete, and returns
125144
the list of completed values.
@@ -132,36 +151,4 @@ def wait_all(futures):
132151
method will throw an error if ``wait`` on any
133152
:class:`~torch.futures.Future` throws.
134153
"""
135-
return [fut.wait() for fut in torch._C._collect_all(futures).wait()]
136-
137-
138-
T = TypeVar("T")
139-
GenericWithOneTypeVar = Generic[T]
140-
141-
142-
try:
143-
144-
# Combine the implementation class and the type class.
145-
class Future(_PyFuture, GenericWithOneTypeVar):
146-
r"""
147-
Wrapper around a ``torch._C.Future`` which encapsulates an asynchronous
148-
execution of a callable, e.g. :meth:`~torch.distributed.rpc.rpc_async`. It
149-
also exposes a set of APIs to add callback functions and set results.
150-
"""
151-
pass
152-
153-
154-
except TypeError as exc:
155-
# TypeError: metaclass conflict: the metaclass of a derived class
156-
# must be a (non-strict) subclass of the metaclasses of all its bases
157-
class FutureMeta(_PyFuture.__class__, GenericWithOneTypeVar.__class__):
158-
pass
159-
160-
# Combine the implementation class and the type class.
161-
class Future(_PyFuture, GenericWithOneTypeVar, metaclass=FutureMeta):
162-
r"""
163-
Wrapper around a ``torch._C.Future`` which encapsulates an asynchronous
164-
execution of a callable, e.g. :meth:`~torch.distributed.rpc.rpc_async`. It
165-
also exposes a set of APIs to add callback functions and set results.
166-
"""
167-
pass
154+
return [fut.wait() for fut in torch._C._collect_all(cast(List[torch._C.Future], futures)).wait()]

0 commit comments

Comments
 (0)