-
Notifications
You must be signed in to change notification settings - Fork 23
Expand file tree
/
Copy pathtest_indexing_routines.py
More file actions
82 lines (66 loc) · 2.11 KB
/
test_indexing_routines.py
File metadata and controls
82 lines (66 loc) · 2.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import hypothesis.extra.numpy as hnp
import hypothesis.strategies as st
import numpy as np
import pytest
from hypothesis import assume, given
import mygrad as mg
from mygrad import where
from tests.wrappers.uber import backprop_test_factory, fwdprop_test_factory
def mygrad_where(x, y, condition, constant=None):
    """Adapt ``mygrad.where`` to the ``(x, y, condition)`` argument order.

    The test factories in this module supply the two operand arrays
    positionally and ``condition`` as a keyword argument, whereas ``where``
    takes ``condition`` first. This wrapper reorders the arguments and
    forwards ``constant`` unchanged.
    """
    selected = where(condition, x, y, constant=constant)
    return selected
def numpy_where(x, y, condition, constant=None):
    """Reference implementation: ``numpy.where`` in ``(x, y, condition)`` order.

    ``constant`` is accepted only so that this signature matches
    ``mygrad_where``. NumPy has no such option, so the value is ignored.
    """
    del constant  # unused: present only for signature compatibility
    return np.where(condition, x, y)
def condition_strat(*arrs):
    """Build a strategy for boolean ``condition`` arrays compatible with ``arrs``.

    The shapes of the generated arrays are drawn so that they broadcast
    against the common broadcast shape of all the given operand arrays.
    """
    common_shape = np.broadcast(*arrs).shape
    return hnp.arrays(
        dtype=bool,
        shape=hnp.broadcastable_shapes(shape=common_shape),
    )
@fwdprop_test_factory(
    mygrad_func=mygrad_where,
    true_func=numpy_where,
    kwargs={"condition": condition_strat},
    num_arrays=2,
)
def test_where_fwd():
    """Forward pass of ``mygrad.where`` must match ``numpy.where``.

    The factory draws the two operand arrays; ``condition`` is drawn via
    ``condition_strat`` so that it broadcasts against them.
    """
@backprop_test_factory(
    mygrad_func=numpy_where,  # np.where on Tensors exercises the __array_function__ override
    true_func=numpy_where,
    kwargs={"condition": condition_strat},
    num_arrays=2,
)
def test_where_bkwd():
    """Backpropagation through ``numpy.where`` applied to mygrad Tensors.

    ``numpy_where`` is passed as ``mygrad_func`` so that the call routes
    through the Tensor ``__array_function__`` override. The backward pass
    under test is therefore mygrad's implementation of ``where``.
    """
@given(condition=st.from_type(type) | hnp.arrays(shape=hnp.array_shapes(), dtype=int))
@pytest.mark.filterwarnings("ignore: Calling nonzero on 0d arrays is deprecated")
def test_where_condition_only_fwd(condition):
    """mygrad.where should merely mirror numpy.where when only `where(condition)`
    is specified.

    Array conditions are wrapped in a ``mygrad.Tensor`` so that ``np.where``
    dispatches through the Tensor's ``__array_function__`` override.
    Non-array conditions are passed through unchanged. The result must be
    index-for-index identical to ``np.where`` on the raw condition.
    """
    tensor_condition = (
        mg.Tensor(condition) if isinstance(condition, np.ndarray) else condition
    )
    try:
        expected = np.where(condition)
    except ValueError:
        # NumPy rejects some conditions outright (presumably e.g. 0-d inputs
        # once the nonzero deprecation filtered above becomes an error); such
        # draws say nothing about mygrad, so discard them.
        assume(False)
    actual = np.where(tensor_condition)

    # ``zip`` alone would silently ignore a length mismatch. ``np.all(a == b)``
    # also passes vacuously when one side broadcasts to an empty array (e.g.
    # shapes (0,) vs (1,)). So check the lengths, then require an exact
    # shape-and-value match for each index array.
    assert len(actual) == len(expected)
    for actual_indices, expected_indices in zip(actual, expected):
        np.testing.assert_array_equal(actual_indices, expected_indices)
@given(
    condition=hnp.arrays(shape=hnp.array_shapes(min_dims=1), dtype=bool),
    x=st.none()
    | hnp.arrays(
        shape=hnp.array_shapes(min_dims=1),
        dtype=int,
    ),
    y=st.none()
    | hnp.arrays(
        shape=hnp.array_shapes(min_dims=1),
        dtype=int,
    ),
)
def test_where_input_validation(condition, x, y):
    """``mygrad.where`` must accept exactly the inputs ``numpy.where`` accepts.

    ``x`` and ``y`` are independently omitted (``None``). This covers:

    - the condition-only form;
    - the invalid "only one of x / y" form;
    - operands whose shapes may not broadcast.

    Wherever NumPy raises, mygrad must raise the same exception type.
    Wherever NumPy succeeds, mygrad must also succeed and produce the same
    values.
    """
    args = [arr for arr in (x, y) if arr is not None]
    try:
        expected = np.where(condition, *args)
    except Exception as e:
        with pytest.raises(type(e)):
            where(condition, *args)
        return

    # Without this success-path check, a mygrad that wrongly rejected valid
    # inputs would still pass the test, since `where` would never be called.
    actual = where(condition, *args)
    if isinstance(expected, tuple):
        # Condition-only form: a tuple of index arrays, one per dimension.
        assert len(actual) == len(expected)
        for actual_part, expected_part in zip(actual, expected):
            np.testing.assert_array_equal(actual_part, expected_part)
    else:
        np.testing.assert_array_equal(actual, expected)