@@ -48,34 +48,6 @@ def nested(*managers):
4848 raise exc [1 ].with_traceback (exc [2 ])
4949
5050
51- def _wrap_context_manager_or_class (thing , wrapper_factory ):
52- if hasattr (type (thing ), '__enter__' ):
53- # It's a context manager.
54- return wrapper_factory (thing )
55- else :
56- assert issubclass (thing , ContextManager )
57- # It's a context manager class.
58- property_name = '__%s_context_manager_%s' % (
59- thing .__name__ ,
60- ' ' .join (random .choice (string .ascii_letters ) for _ in range (20 ))
61- )
62- return type (
63- thing .__name__ ,
64- (thing ,),
65- {
66- property_name : caching .CachedProperty (wrapper_factory ),
67- '__enter__' :
68- lambda self : getattr (self , property_name ).__enter__ (),
69- '__exit__' : lambda self , exc_type , exc_value , exc_traceback :
70- getattr (self , property_name ).
71- __exit__ (exc_type , exc_value , exc_traceback ),
72-
73- }
74- )
75-
76-
77-
78-
7951def as_idempotent (context_manager ):
8052 '''
8153 Wrap a context manager so repeated calls to enter and exit will be ignored.
@@ -95,9 +67,11 @@ def as_idempotent(context_manager):
9567
9668 Note: The first value returned by `__enter__` will be returned by all the
9769 subsequent no-op `__enter__` calls.
70+
71+ blocktodo: add to docs about different ways of calling, ensure tests
9872 '''
99- return _wrap_context_manager_or_class (
100- context_manager , _IdempotentContextManager
73+ return _IdempotentContextManager . _wrap_context_manager_or_class (
74+ context_manager ,
10175 )
10276
10377
@@ -111,19 +85,66 @@ def as_reentrant(context_manager):
11185
11286 Note: The first value returned by `__enter__` will be returned by all the
11387 subsequent no-op `__enter__` calls.
88+
89+ blocktodo: add to docs about different ways of calling, ensure tests
11490 '''
115- return _wrap_context_manager_or_class (
116- context_manager , _ReentrantContextManager
91+ return _ReentrantContextManager . _wrap_context_manager_or_class (
92+ context_manager ,
11793 )
118-
11994
120- class _IdempotentContextManager ( ContextManager ):
121- _entered = False
95+
96+ class _ContextManagerWrapper ( ContextManager ):
12297 _enter_value = None
123-
98+ __wrapped__ = None
12499 def __init__ (self , wrapped_context_manager ):
125- self .__wrapped__ = wrapped_context_manager
100+ if hasattr (wrapped_context_manager , '__enter__' ):
101+ self .__wrapped__ = wrapped_context_manager
102+ self ._wrapped_enter = wrapped_context_manager .__enter__
103+ self ._wrapped_exit = wrapped_context_manager .__exit__
104+ else :
105+ self ._wrapped_enter , self ._wrapped_exit = wrapped_context_manager
106+
107+ @classmethod
108+ def _wrap_context_manager_or_class (cls , thing ):
109+ if hasattr (type (thing ), '__enter__' ):
110+ # It's a context manager.
111+ return cls (thing )
112+ else :
113+ assert issubclass (thing , ContextManager )
114+ # It's a context manager class.
115+ property_name = '__%s_context_manager_%s' % (
116+ thing .__name__ ,
117+ '' .join (random .choice (string .ascii_letters ) for _ in range (30 ))
118+ )
119+ return type (
120+ thing .__name__ ,
121+ (thing ,),
122+ {
123+ property_name : caching .CachedProperty (
124+ lambda self : cls ((
125+ lambda : thing .__enter__ (self ),
126+ lambda exc_type , exc_value , exc_traceback :
127+ thing .__exit__ (
128+ self , exc_type , exc_value , exc_traceback
129+ )
130+ ))
131+ ),
132+ '__enter__' :
133+ lambda self : getattr (self , property_name ).__enter__ (),
134+ '__exit__' : lambda self , exc_type , exc_value , exc_traceback :
135+ getattr (self , property_name ).
136+ __exit__ (exc_type , exc_value , exc_traceback ),
137+
138+ }
139+ )
140+
141+
142+
143+
126144
145+
146+ class _IdempotentContextManager (_ContextManagerWrapper ):
147+ _entered = False
127148
128149 def __enter__ (self ):
129150 if not self ._entered :
@@ -140,11 +161,8 @@ def __exit__(self, exc_type=None, exc_value=None, exc_traceback=None):
140161 self ._enter_value = None
141162 return exit_value
142163
143- class _ReentrantContextManager (ContextManager ):
164+ class _ReentrantContextManager (_ContextManagerWrapper ):
144165
145- def __init__ (self , wrapped_context_manager ):
146- self .__wrapped__ = wrapped_context_manager
147-
148166 depth = caching .CachedProperty (
149167 0 ,
150168 doc = '''
@@ -156,19 +174,18 @@ def __init__(self, wrapped_context_manager):
156174 '''
157175 )
158176
159- _enter_value = None
160177
161178 def __enter__ (self ):
162179 if self .depth == 0 :
163- self ._enter_value = self .__wrapped__ . __enter__ ()
180+ self ._enter_value = self ._wrapped_enter ()
164181 self .depth += 1
165182 return self ._enter_value
166183
167184
168185 def __exit__ (self , exc_type = None , exc_value = None , exc_traceback = None ):
169186 assert self .depth >= 1
170187 if self .depth == 1 :
171- exit_value = self .__wrapped__ . __exit__ (
188+ exit_value = self ._wrapped_exit (
172189 exc_type , exc_value , exc_traceback
173190 )
174191 self ._enter_value = None
0 commit comments