-
Notifications
You must be signed in to change notification settings - Fork 23
Expand file tree
/
Copy pathtest_view_semantics.py
More file actions
283 lines (227 loc) · 8.49 KB
/
test_view_semantics.py
File metadata and controls
283 lines (227 loc) · 8.49 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
# Here we build a computational graph out of mygrad view operations, and perform
# corresponding views on numpy arrays.
#
# The views are specifically chosen to be size-preserving and element-preserving operations
# so that we can exploit the property that::
#
#     view.backward(view.data)
#
# from any view node should produce::
#
#     tensor.grad == tensor.data
#
# throughout the whole graph.
#
# E.g.::
#
#     x1 = Tensor([[0., 1., 2.],
#                  [3., 4., 5.],
#                  [6., 7., 8.]])
#
#     x2 = x1.T
#        = Tensor([[0., 3., 6.],
#                  [1., 4., 7.],
#                  [2., 5., 8.]])
#
#     x3 = x2[::-1]
#        = Tensor([[2., 5., 8.],
#                  [1., 4., 7.],
#                  [0., 3., 6.]])
#
# MyGrad designs its views such that the correspondence of a view's data to the base's data
# should indicate the identical relationship between their gradients. Thus calling
#
#     x3.backward(x3.data)
#
# forces x3's gradient to match its data; it follows from the correspondence stated above
# that the same should be true for all other views associated with the base (x1).
from typing import Callable, List, NamedTuple, Optional, Tuple, TypeVar
import hypothesis.strategies as st
import numpy as np
import pytest
from hypothesis import settings
from hypothesis.stateful import (
Bundle,
RuleBasedStateMachine,
initialize,
invariant,
precondition,
rule,
)
from numpy import ndarray
from numpy.testing import assert_equal
import mygrad as mg
from mygrad import Tensor
from tests.utils.wrappers import clears_mem_state
class Pair(NamedTuple):
    """A MyGrad tensor bundled with the NumPy array that mirrors it.

    ``parent_pair`` links a pair back to the pair it was produced from;
    it is ``None`` for a pair that was not produced from another pair.
    """

    # MyGrad side of the comparison
    tensor: Tensor
    # NumPy side of the comparison
    array: ndarray
    # the pair this one was derived from, if any
    parent_pair: Optional["Pair"]
# Type variable constrained to the two kinds of object compared in this
# module (NumPy arrays and MyGrad tensors); used to annotate helpers that
# accept either kind.
T = TypeVar("T", ndarray, Tensor)
def check_base(*, child: T, parent: T):
    """Assert that ``child.base`` is consistent with ``child`` viewing ``parent``.

    The check passes when ``child.base`` is ``parent`` itself, or is the very
    same object as ``parent.base``.

    Parameters
    ----------
    child : T
        The object derived from ``parent``.
    parent : T
        The object that ``child`` was derived from.

    Raises
    ------
    AssertionError
        If neither identity holds. The message displays both objects
        along with their bases.
    """
    child_base = child.base
    # `parent.base` is only consulted when `child.base` is not `parent` itself
    base_is_consistent = child_base is parent or child_base is parent.base
    assert base_is_consistent, (
        f"child:\n{child}\nchild.base:\n{child.base}\n\n"
        f"parent:\n{parent}\nparent.base:\n{parent.base}"
    )
def check_pair(pair: Pair):
    """
    Validates a tensor-array pair.

    Checks:
    - equality between the tensor and the array
    - view-base relationships (against the parent pair, when there is one)
    - memory sharing with the parent pair (when there is one)

    Raises
    ------
    AssertionError
        If any of the checks fails.
    """
    parent = pair.parent_pair

    if parent is None:
        # a pair without a parent must not be a view of anything
        assert pair.tensor.base is None
        assert pair.array.base is None
    else:
        # a derived pair must share memory with, and point back to, its parent
        assert np.shares_memory(pair.tensor, parent.tensor)
        assert np.shares_memory(pair.array, parent.array)
        check_base(child=pair.tensor, parent=parent.tensor)
        check_base(child=pair.array, parent=parent.array)

    # regardless of lineage, the tensor must match its array
    assert_equal(
        actual=pair.tensor,
        desired=pair.array,
        err_msg="MyGrad view produced different result than NumPy view",
    )
def einsum(t: T) -> Callable:
    """Select the einsum implementation that matches the type of ``t``.

    Returns ``mg.einsum`` when ``t`` is a ``Tensor`` and ``np.einsum``
    otherwise. Only the function is returned; it is not called here.
    """
    if isinstance(t, Tensor):
        return mg.einsum
    return np.einsum
def add(*args, **kwargs):
    """Add via ``mg.add`` if any positional argument is a tensor, else via ``np.add``.

    Only the positional arguments are inspected when choosing the
    implementation; keyword arguments (e.g. ``out``) are forwarded untouched.
    """
    for arg in args:
        if isinstance(arg, mg.Tensor):
            return mg.add(*args, **kwargs)
    return np.add(*args, **kwargs)
def diagonal(t: T) -> T:
    """Return the result of the ``"ii->i"`` einsum applied to ``t``.

    The einsum implementation is chosen by ``einsum(t)`` according to the
    type of ``t``.
    """
    einsum_fn = einsum(t)
    return einsum_fn("ii->i", t)
# Named operations that each take a 2D tensor/array and return a result of
# the same shape holding the same elements (possibly reordered). The same
# function is applied to both the MyGrad tensor and the NumPy array of a pair.
view_ops = {
    # full slice of the input
    "identity": lambda arr: arr[...],
    # reverse the order of the columns
    "horizontal flip": lambda arr: arr[:, ::-1],
    # reverse the order of the rows
    "vertical flip": lambda arr: arr[::-1],
    # swap rows and columns
    "transpose": lambda arr: arr.T,
    # pass-through einsum, dispatched on the type of the input
    "einsum view": lambda arr: einsum(arr)("... -> ...", arr),
    # insert a length-1 axis and immediately index it away again
    "add and remove leading newaxis": lambda arr: arr[np.newaxis][0],
    "add and remove middle newaxis": lambda arr: arr[:, np.newaxis, :][:, 0, :],
    "add and remove trailing newaxis": lambda arr: arr[..., np.newaxis][..., 0],
}
# Named in-place operations that involve a single tensor/array. Each one is
# invoked through the corresponding in-place dunder method, so the object
# passed in is the one that gets mutated.
unary_mutation_ops = {
    # in-place arithmetic with a scalar
    "x += 2": lambda x: x.__iadd__(2.0),
    "x -= 2": lambda x: x.__isub__(2.0),
    "x /= 3": lambda x: x.__itruediv__(3.0),
    "x *= 3": lambda x: x.__imul__(3.0),
    # in-place addition where the operand is the target itself
    "x += x": lambda x: x.__iadd__(x),
}
# Named in-place operations that mutate `x` using values taken from `y`.
binary_mutation_ops = {
    # overwrite all of x with y
    "x[...] = y": lambda x, y: x.__setitem__(..., y),
    # overwrite only the first row of x with the first row of y
    "x[0] = y[0]": lambda x, y: x.__setitem__(0, y[0]),
    # in-place addition
    "x += y": lambda x, y: x.__iadd__(y),
    # addition that writes its result into x via `out`
    "x[...] = (x + y)": lambda x, y: add(x, y, out=x),
    # assign the diagonal of y into the diagonal of x
    "diag(x)[...] = diag(y)": lambda x, y: diagonal(x).__setitem__(
        ..., diagonal(y)
    ),
}
@settings(deadline=None)
class ViewGraphCompare(RuleBasedStateMachine):
    """
    This state machine creates tensor-array pairs - in correspondence
    with each other - from view operations. It also manipulates
    tensors/arrays with inplace operations.

    Everywhere, the elements/shape of a tensor should match those of its
    corresponding array.

    The cases generated by this state machine exercise MyGrad's view
    semantics, its inplace operation semantics, and the features that
    emerge from their combination. More specifically, it ensures that:

    - Base-view relationships and memory sharing are consistent with NumPy
    - Mutations affect a tensor, its base, and the base's views
      consistently with NumPy
    - The correspondence between a base's data and a view's data dictates
      the same correspondence between the gradients of the base and view.
    """

    # The tensor from which the base of the view graph is created.
    # Keeping a reference to it lets `teardown` check that backprop
    # always reaches this tensor.
    static_upstream_tensor: Tensor

    def __init__(self):
        super().__init__()

        # Every tensor-array pair recorded through `track_pair`, in the
        # order in which the pairs were created.
        self.pair_list: List[Pair] = []

        # The tensor from which backprop is triggered during `teardown`;
        # stays `None` until `pick_terminal_tensor` runs.
        self.terminal_tensor: Optional[Tensor] = None

    nodes = Bundle("nodes")

    def track_pair(self, pair: Pair):
        """Record ``pair`` so that it is covered by the invariant and teardown checks."""
        self.pair_list.append(pair)

    @initialize(target=nodes, shape=st.sampled_from([(3, 3), (4, 4)]))
    def create_base(self, shape: Tuple[int, int]) -> Pair:
        """
        Creates an equivalent tensor, array pair.

        These are square, 2D arrays from which we will
        begin to form views and/or perform mutations.
        """
        num_elements = float(np.prod(shape))
        upstream = mg.arange(num_elements).reshape(shape).copy()
        array = np.arange(num_elements).reshape(shape).copy()

        self.static_upstream_tensor = upstream

        # the pair holds `+upstream` rather than `upstream` itself, so that
        # `upstream` sits upstream of every tensor in the pair list
        base_pair = Pair(tensor=+upstream, array=array, parent_pair=None)
        assert not np.shares_memory(upstream, array)

        self.track_pair(base_pair)
        return base_pair

    @rule(target=nodes, parent=nodes, op=st.sampled_from(list(view_ops)))
    def create_view(self, parent: Pair, op: str) -> Pair:
        """Apply the view operation named ``op`` to both members of ``parent``."""
        make_view = view_ops[op]
        tensor_view = make_view(parent.tensor)
        array_view = make_view(parent.array)

        child = Pair(tensor=tensor_view, array=array_view, parent_pair=parent)
        self.track_pair(child)
        return child

    @rule(pair=nodes, op=st.sampled_from(list(unary_mutation_ops)))
    def unary_mutate_pair(self, pair: Pair, op: str):
        """Apply the single-operand mutation named ``op`` to the tensor, then the array."""
        mutate = unary_mutation_ops[op]
        mutate(pair.tensor)
        mutate(pair.array)

    @rule(pair1=nodes, pair2=nodes, op=st.sampled_from(list(binary_mutation_ops)))
    def binary_mutate_pair(self, pair1: Pair, pair2: Pair, op: str):
        """Mutate ``pair1`` using ``pair2``: tensors with tensors, then arrays with arrays."""
        mutate = binary_mutation_ops[op]
        mutate(pair1.tensor, pair2.tensor)
        mutate(pair1.array, pair2.array)

    @precondition(lambda self: self.terminal_tensor is None)
    @rule(pair=nodes)
    def pick_terminal_tensor(self, pair: Pair):
        """Choose the tensor to backprop from; the precondition allows this only once."""
        self.terminal_tensor = pair.tensor

    @invariant()
    def check_all_nodes(self):
        """Run ``check_pair`` on every pair recorded so far."""
        for recorded in self.pair_list:
            check_pair(recorded)

    @clears_mem_state
    def teardown(self):
        """Backprop from the terminal tensor (if one was picked) and check all gradients."""
        terminal = self.terminal_tensor
        if terminal is None:
            return

        # see comment at top of script for explanation
        # of why we backprop the tensor's own data as its gradient
        terminal.backward(terminal.data)

        for recorded in self.pair_list:
            tensor = recorded.tensor
            assert_equal(actual=tensor.grad, desired=tensor.data)

            if tensor.base is not None:
                # a view's gradient must itself be a view of its base's gradient
                assert tensor.grad.base is tensor.base.grad

        # any backprop had to involve the static upstream
        # tensor from which the original base-view was created
        assert self.static_upstream_tensor.grad is not None

        # make sure backprop didn't break any relationships or
        # change any value
        self.check_all_nodes()

        # touching any of these tensors should null its gradients
        for recorded in self.pair_list:
            tensor = recorded.tensor
            _ = +tensor
            assert tensor.grad is None
# Bind the state machine's `TestCase` to a module-level, `Test`-prefixed name;
# presumably this is what lets the test runner collect and run it.
TestGraphComparison = ViewGraphCompare.TestCase
# We also want to check that all of this logic holds up in
# no-autodiff mode.
class Tmp(ViewGraphCompare):
    """``ViewGraphCompare`` with the backprop-related steps turned into no-ops.

    Both overrides below do nothing, so ``terminal_tensor`` is never set by
    ``pick_terminal_tensor`` and ``teardown`` performs no backprop or
    gradient checks.
    """

    def pick_terminal_tensor(self, pair: Pair):
        """Do nothing; no terminal tensor is selected."""

    @clears_mem_state
    def teardown(self):
        """Do nothing; the parent's backprop and gradient checks are skipped."""
@pytest.mark.usefixtures("no_autodiff")
class NoAutoDiffView(Tmp.TestCase):
    """Test case for the ``Tmp`` state machine, run with the ``no_autodiff`` fixture."""