Skip to content

Commit 7747fb6

Browse files
committed
-
1 parent c63e2d5 commit 7747fb6

File tree

2 files changed

+62
-1
lines changed

2 files changed

+62
-1
lines changed

source_py3/python_toolbox/nifty_collections/bagging.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -517,6 +517,27 @@ def __repr__(self):
517517
def __reversed__(self):
518518
# Gets overridden in `_OrderedBagMixin`.
519519
raise TypeError("Can't reverse an unordered bag.")
520+
521+
522+
def get_contained_bags(self):
523+
'''
524+
Get all counters that are subsets of this bags.
525+
526+
This means all counters that have counts identical or smaller for each
527+
key.
528+
'''
529+
from python_toolbox import combi
530+
531+
keys, amounts = zip(*((key, amount) for key, amount in self.items()))
532+
533+
return combi.MapSpace(
534+
lambda amounts_tuple:
535+
type(self)(self._dict_type(zip(keys, amounts_tuple))),
536+
combi.ProductSpace(map(lambda amount: range(amount+1), amounts))
537+
)
538+
539+
540+
520541

521542

522543

@@ -791,7 +812,22 @@ def frozen_bag_bag(self):
791812
def get_mutable(self):
792813
'''Get a mutable version of this bag.'''
793814
return self._mutable_type(self)
794-
815+
816+
# Poor man's caching done here because we can't import
817+
# `python_toolbox.caching` due to import loop:
818+
_contained_bags = None
819+
def get_contained_bags(self):
820+
'''
821+
Get all counters that are subsets of this bags.
822+
823+
This means all counters that have counts identical or smaller for each
824+
key.
825+
'''
826+
if self._contained_bags is None:
827+
self._contained_bags = super().get_contained_bags()
828+
return self._contained_bags
829+
830+
795831

796832
class _BaseDictDelegator(collections.MutableMapping):
797833
'''

source_py3/test_python_toolbox/test_nifty_collections/test_bagging.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,31 @@ def test_operations(self):
365365

366366

367367

368+
def test_get_contained_bags(self):
369+
bag = self.bag_type('abracadabra')
370+
contained_bags = bag.get_contained_bags()
371+
assert len(contained_bags) == 2 ** len('abracadabra')
372+
had_full_one = False
373+
for contained_bag in contained_bags:
374+
assert contained_bag <= bag
375+
if contained_bag == bag:
376+
assert had_full_one is False
377+
had_full_one = True
378+
else:
379+
assert contained_bag < bag
380+
if isinstance(bag, nifty_collections.Ordered):
381+
assert cute_iter_tools.is_sorted(
382+
tuple(contained_bag.items()),
383+
key=tuple(bag.items()).index
384+
)
385+
386+
contained_bags_tuple = tuple(contained_bags)
387+
assert self.bag_type('arcaba') in contained_bags_tuple
388+
assert self.bag_type('db') in contained_bags_tuple
389+
assert self.bag_type() in contained_bags_tuple
390+
assert self.bag_type('x') not in contained_bags_tuple
391+
392+
368393

369394
class BaseMutableBagTestCase(BaseBagTestCase):
370395

0 commit comments

Comments
 (0)