Skip to content

Commit 170fec2

Browse files
committed
-
1 parent db7742b commit 170fec2

File tree

2 files changed

+136
-47
lines changed

2 files changed

+136
-47
lines changed

source_py3/python_toolbox/logic_tools.py

Lines changed: 39 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,39 +10,59 @@
1010
from python_toolbox import cute_iter_tools
1111

1212

13-
def all_equivalent(iterable, relation=operator.eq, *, exhaustive=False):
13+
def all_equivalent(iterable, relation=operator.eq, *, assume_reflexive=True,
14+
assume_symmetric=True, assume_transitive=True):
1415
'''
1516
Return whether all elements in the iterable are equivalent to each other.
1617
17-
By default "equivalent" means they're equal to each other in Python. You
18-
can set a different relation to the `relation` argument, as a function that
19-
accepts two arguments and returns whether they're equivalent or not. You
20-
can use this, for example, to test if all items are NOT equal by passing in
21-
`relation=operator.ne`.
22-
23-
If `exhaustive` is set to `False`, it's assumed that the equality relation
24-
is transitive, therefore not every member is tested against every other
25-
member. So in a list of size `n`, `n-1` equality checks will be made.
18+
By default "equivalent" means they're all equal to each other in Python.
19+
You can set a different relation to the `relation` argument, as a function
20+
that accepts two arguments and returns whether they're equivalent or not.
21+
You can use this, for example, to test if all items are NOT equal by
22+
passing in `relation=operator.ne`. You can also define any custom relation
23+
you want: `relation=(lambda x, y: x % 7 == y % 7)`.
24+
25+
By default, we assume that the relation we're using is an equivalence
26+
relation (see http://en.wikipedia.org/wiki/Equivalence_relation for
27+
definition.) This means that we assume the relation is reflexive, symmetric
28+
and transitive, so we can do less checks on the elements to save time. You
29+
can use `assume_reflexive=False`, `assume_symmetric=False` and
30+
`assume_transitive=False` to break any of these assumptions and make this
31+
function do more checks that the equivalence holds between any pair of
32+
items from the iterable.
33+
34+
If `exhaustive` is set to `False`, it's assumed that the
35+
equality relation is transitive, therefore not every member is tested
36+
against every other member. So in a list of size `n`, `n-1` equality checks
37+
will be made.
2638
2739
If `exhaustive` is set to `True`, every member will be checked against
2840
every other member. So in a list of size `n`, `(n*(n-1))/2` equality checks
2941
will be made.
3042
'''
31-
# todo: Maybe I should simply check if `len(set(iterable)) == 1`? Will not
32-
# work for unhashables.
43+
from python_toolbox import sequence_tools
3344

34-
if exhaustive is True:
35-
items = tuple(iterable)
36-
if len(items) <= 1:
37-
return True
45+
if not assume_transitive or not assume_reflexive:
46+
iterable = sequence_tools.ensure_iterable_is_sequence(iterable)
47+
48+
if assume_transitive:
49+
pairs = cute_iter_tools.iterate_overlapping_subsequences(iterable)
50+
else:
3851
from python_toolbox import combi
3952
pairs = tuple(
40-
items * comb for comb in combi.CombSpace(len(items), 2)
53+
iterable * comb for comb in combi.CombSpace(len(iterable), 2)
4154
)
4255
# Can't feed the items directly to `CombSpace` because they might not
4356
# be hashable.
44-
else: # exhaustive is False
45-
pairs = cute_iter_tools.iterate_overlapping_subsequences(iterable)
57+
58+
if not assume_symmetric:
59+
pairs = itertools.chain(
60+
*itertools.starmap(lambda x, y: ((x, y), (y, x)), pairs)
61+
)
62+
63+
if not assume_reflexive:
64+
pairs = itertools.chain(pairs,
65+
zip(iterable, iterable))
4666

4767
return all(itertools.starmap(relation, pairs))
4868

Lines changed: 97 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
# Copyright 2009-2015 Ram Rachum.
22
# This program is distributed under the MIT license.
33

4-
'''Testing module for `logic_tools.all_equal`.'''
5-
64
import operator
75
import itertools
86

@@ -15,31 +13,41 @@ def test():
1513
_check(True)
1614

1715

18-
def _check(exhaustive):
19-
'''Check the basic working of `all_equal` with given `exhaustive` flag.'''
20-
assert all_equivalent([1, 1, 1, 1], exhaustive=exhaustive)
21-
assert all_equivalent([1, 1, 1.0, 1], exhaustive=exhaustive)
22-
assert all_equivalent(((1 + 0j), 1, 1.0, 1), exhaustive=exhaustive)
23-
assert all_equivalent([], exhaustive=exhaustive)
24-
assert all_equivalent(iter([1, 1, 1.0, 1]), exhaustive=exhaustive)
25-
assert all_equivalent({'meow'}, exhaustive=exhaustive)
26-
assert all_equivalent(['frr', 'frr', 'frr', 'frr'], exhaustive=exhaustive)
16+
def _check(assume_transitive):
17+
'''Check the basic working of `all_equal` with given `assume_transitive` flag.'''
18+
assert all_equivalent([1, 1, 1, 1], assume_transitive=assume_transitive)
19+
assert all_equivalent([1, 1, 1.0, 1], assume_transitive=assume_transitive)
20+
assert all_equivalent(((1 + 0j), 1, 1.0, 1),
21+
assume_transitive=assume_transitive)
22+
assert all_equivalent([], assume_transitive=assume_transitive)
23+
assert all_equivalent(iter([1, 1, 1.0, 1]),
24+
assume_transitive=assume_transitive)
25+
assert all_equivalent({'meow'}, assume_transitive=assume_transitive)
26+
assert all_equivalent(['frr', 'frr', 'frr', 'frr'],
27+
assume_transitive=assume_transitive)
2728

28-
assert not all_equivalent([1, 1, 2, 1], exhaustive=exhaustive)
29-
assert not all_equivalent([1, 1, 1.001, 1], exhaustive=exhaustive)
30-
assert not all_equivalent(((1 + 0j), 3, 1.0, 1), exhaustive=exhaustive)
31-
assert not all_equivalent(range(7), exhaustive=exhaustive)
32-
assert not all_equivalent(iter([1, 17, 1.0, 1]), exhaustive=exhaustive)
33-
assert not all_equivalent({'meow', 'grr'}, exhaustive=exhaustive)
29+
assert not all_equivalent([1, 1, 2, 1],
30+
assume_transitive=assume_transitive)
31+
assert not all_equivalent([1, 1, 1.001, 1],
32+
assume_transitive=assume_transitive)
33+
assert not all_equivalent(((1 + 0j), 3, 1.0, 1),
34+
assume_transitive=assume_transitive)
35+
assert not all_equivalent(range(7), assume_transitive=assume_transitive)
36+
assert not all_equivalent(iter([1, 17, 1.0, 1]),
37+
assume_transitive=assume_transitive)
38+
assert not all_equivalent({'meow', 'grr'},
39+
assume_transitive=assume_transitive)
3440
assert not all_equivalent(['frr', 'frr', {}, 'frr', 'frr'],
35-
exhaustive=exhaustive)
36-
assert not all_equivalent(itertools.count()) # Not using given `exhaustive`
37-
# flag here because `count()` is
38-
# infinite.
41+
assume_transitive=assume_transitive)
42+
assert not all_equivalent(itertools.count())
43+
# Not using given `assume_transitive` flag here because `count()` is
44+
# infinite.
3945

4046

41-
def test_exhaustive_true():
42-
'''Test `all_equal` in cases where `exhaustive=True` is relevant.'''
47+
def test_assume_transitive_false():
48+
'''
49+
Test `all_equivalent` in cases where `assume_transitive=False` is relevant.
50+
'''
4351

4452
class FunkyFloat(float):
4553
def __eq__(self, other):
@@ -53,7 +61,68 @@ def __eq__(self, other):
5361
]
5462

5563
assert all_equivalent(funky_floats)
56-
assert not all_equivalent(funky_floats, exhaustive=True)
64+
assert not all_equivalent(funky_floats, assume_transitive=False)
65+
66+
67+
def test_all_assumptions():
68+
class EquivalenceChecker:
69+
pairs_checked = []
70+
def __init__(self, tag):
71+
self.tag = tag
72+
def is_equivalent(self, other):
73+
EquivalenceChecker.pairs_checked.append((self, other))
74+
return True
75+
def __eq__(self, other):
76+
return (type(self), self.tag) == (type(other), other.tag)
77+
78+
def get_pairs_for_options(**kwargs):
79+
assert EquivalenceChecker.pairs_checked == []
80+
# Testing with an iterator instead of the tuple to ensure it works and that
81+
# the function doesn't try to exhaust it twice.
82+
assert all_equivalent(iter(things), EquivalenceChecker.is_equivalent,
83+
**kwargs) is True
84+
try:
85+
return tuple((a.tag, b.tag) for (a, b) in
86+
EquivalenceChecker.pairs_checked)
87+
finally:
88+
EquivalenceChecker.pairs_checked = []
89+
90+
x0 = EquivalenceChecker(0)
91+
x1 = EquivalenceChecker(1)
92+
x2 = EquivalenceChecker(2)
93+
things = (x0, x1, x2)
94+
95+
assert get_pairs_for_options(assume_reflexive=False, assume_symmetric=False,
96+
assume_transitive=False) == (
97+
(0, 1), (1, 0), (0, 2), (2, 0), (1, 2), (2, 1), (0, 0), (1, 1), (2, 2)
98+
)
99+
assert get_pairs_for_options(assume_reflexive=False, assume_symmetric=False,
100+
assume_transitive=True) == (
101+
(0, 1), (1, 0), (1, 2), (2, 1), (0, 0), (1, 1), (2, 2)
102+
)
103+
assert get_pairs_for_options(assume_reflexive=False, assume_symmetric=True,
104+
assume_transitive=False) == (
105+
(0, 1), (0, 2), (1, 2), (0, 0), (1, 1), (2, 2)
106+
)
107+
assert get_pairs_for_options(assume_reflexive=False, assume_symmetric=True,
108+
assume_transitive=True) == (
109+
(0, 1), (1, 2), (0, 0), (1, 1), (2, 2)
110+
)
111+
assert get_pairs_for_options(assume_reflexive=True, assume_symmetric=False,
112+
assume_transitive=False) == (
113+
(0, 1), (1, 0), (0, 2), (2, 0), (1, 2), (2, 1),
114+
)
115+
assert get_pairs_for_options(assume_reflexive=True, assume_symmetric=False,
116+
assume_transitive=True) == (
117+
(0, 1), (1, 0), (1, 2), (2, 1),
118+
)
119+
assert get_pairs_for_options(assume_reflexive=True, assume_symmetric=True,
120+
assume_transitive=False) == (
121+
(0, 1), (0, 2), (1, 2),
122+
)
123+
assert get_pairs_for_options(assume_reflexive=True, assume_symmetric=True,
124+
assume_transitive=True) == ((0, 1), (1, 2))
125+
57126

58127

59128

@@ -62,17 +131,17 @@ def test_custom_relations():
62131
assert all_equivalent(range(4), relation=operator.ge) is False
63132
assert all_equivalent(range(4), relation=operator.le) is True
64133
assert all_equivalent(range(4), relation=operator.le,
65-
exhaustive=True) is True
66-
# (Always comparing small to big, even on exhaustive.)
134+
assume_transitive=True) is True
135+
# (Always comparing small to big, even on `assume_transitive=False`.)
67136

68137
assert all_equivalent(range(4),
69138
relation=lambda x, y: (x // 10 == y // 10)) is True
70139
assert all_equivalent(range(4),
71140
relation=lambda x, y: (x // 10 == y // 10),
72-
exhaustive=True) is True
141+
assume_transitive=True) is True
73142
assert all_equivalent(range(8, 12),
74143
relation=lambda x, y: (x // 10 == y // 10)) is False
75144
assert all_equivalent(range(8, 12),
76145
relation=lambda x, y: (x // 10 == y // 10),
77-
exhaustive=True) is False
146+
assume_transitive=True) is False
78147

0 commit comments

Comments
 (0)