@@ -173,6 +173,15 @@ def whoo():
173173 # The "gen" attribute is an implementation detail.
174174 self .assertFalse (ctx .gen .gi_suspended )
175175
176+ def test_contextmanager_trap_no_yield (self ):
177+ @contextmanager
178+ def whoo ():
179+ if False :
180+ yield
181+ ctx = whoo ()
182+ with self .assertRaises (RuntimeError ):
183+ ctx .__enter__ ()
184+
176185 def test_contextmanager_trap_second_yield (self ):
177186 @contextmanager
178187 def whoo ():
@@ -186,6 +195,19 @@ def whoo():
186195 # The "gen" attribute is an implementation detail.
187196 self .assertFalse (ctx .gen .gi_suspended )
188197
198+ def test_contextmanager_non_normalised (self ):
199+ @contextmanager
200+ def whoo ():
201+ try :
202+ yield
203+ except RuntimeError :
204+ raise SyntaxError
205+
206+ ctx = whoo ()
207+ ctx .__enter__ ()
208+ with self .assertRaises (SyntaxError ):
209+ ctx .__exit__ (RuntimeError , None , None )
210+
189211 def test_contextmanager_except (self ):
190212 state = []
191213 @contextmanager
@@ -265,6 +287,25 @@ def test_issue29692():
265287 self .assertEqual (ex .args [0 ], 'issue29692:Unchained' )
266288 self .assertIsNone (ex .__cause__ )
267289
290+ def test_contextmanager_wrap_runtimeerror (self ):
291+ @contextmanager
292+ def woohoo ():
293+ try :
294+ yield
295+ except Exception as exc :
296+ raise RuntimeError (f'caught { exc } ' ) from exc
297+
298+ with self .assertRaises (RuntimeError ):
299+ with woohoo ():
300+ 1 / 0
301+
302+ # If the context manager wrapped StopIteration in a RuntimeError,
303+ # we also unwrap it, because we can't tell whether the wrapping was
304+ # done by the generator machinery or by the generator itself.
305+ with self .assertRaises (StopIteration ):
306+ with woohoo ():
307+ raise StopIteration
308+
268309 def _create_contextmanager_attribs (self ):
269310 def attribs (** kw ):
270311 def decorate (func ):
@@ -276,6 +317,7 @@ def decorate(func):
276317 @attribs (foo = 'bar' )
277318 def baz (spam ):
278319 """Whee!"""
320+ yield
279321 return baz
280322
281323 def test_contextmanager_attribs (self ):
@@ -332,8 +374,11 @@ def woohoo(a, *, b):
332374
333375 def test_recursive (self ):
334376 depth = 0
377+ ncols = 0
335378 @contextmanager
336379 def woohoo ():
380+ nonlocal ncols
381+ ncols += 1
337382 nonlocal depth
338383 before = depth
339384 depth += 1
@@ -347,6 +392,7 @@ def recursive():
347392 recursive ()
348393
349394 recursive ()
395+ self .assertEqual (ncols , 10 )
350396 self .assertEqual (depth , 0 )
351397
352398
0 commit comments