@@ -27,6 +27,10 @@ class MyExc(Exception):
2727 pass
2828
2929
30+ def get_error_types (eg ):
31+ return {type (exc ) for exc in eg .exceptions }
32+
33+
3034class TestTaskGroup (unittest .IsolatedAsyncioTestCase ):
3135
3236 def setUp (self ):
@@ -117,10 +121,11 @@ async def runner():
117121
118122 NUM += 10
119123
120- with self .assertRaisesRegex (taskgroups .TaskGroupError ,
121- r'1 sub errors: \(ZeroDivisionError\)' ):
124+ with self .assertRaises (taskgroups .TaskGroupError ) as cm :
122125 await self .loop .create_task (runner ())
123126
127+ self .assertEqual (get_error_types (cm .exception ), {ZeroDivisionError })
128+
124129 self .assertEqual (NUM , 0 )
125130 self .assertTrue (t2_cancel )
126131 self .assertTrue (t2 .cancelled ())
@@ -162,12 +167,10 @@ async def runner():
162167
163168 # The 3 foo1 sub tasks can be racy when the host is busy - if the
164169 # cancellation happens in the middle, we'll see partial sub errors here
165- with self .assertRaisesRegex (
166- taskgroups .TaskGroupError ,
167- r'(1|2|3) sub errors: \(ZeroDivisionError\)' ,
168- ):
170+ with self .assertRaises (taskgroups .TaskGroupError ) as cm :
169171 await self .loop .create_task (runner ())
170172
173+ self .assertEqual (get_error_types (cm .exception ), {ZeroDivisionError })
171174 self .assertEqual (NUM , 0 )
172175 self .assertTrue (t2_cancel )
173176 self .assertTrue (runner_cancel )
@@ -280,7 +283,7 @@ async def runner():
280283 try :
281284 await runner ()
282285 except taskgroups .TaskGroupError as t :
283- self .assertEqual (t . get_error_types (), {ZeroDivisionError })
286+ self .assertEqual (get_error_types (t ), {ZeroDivisionError })
284287 else :
285288 self .fail ('TaskGroupError was not raised' )
286289
@@ -309,7 +312,7 @@ async def runner():
309312 try :
310313 await runner ()
311314 except taskgroups .TaskGroupError as t :
312- self .assertEqual (t . get_error_types (), {ZeroDivisionError })
315+ self .assertEqual (get_error_types (t ), {ZeroDivisionError })
313316 else :
314317 self .fail ('TaskGroupError was not raised' )
315318
@@ -382,9 +385,11 @@ async def runner():
382385 g2 .create_task (crash_after (0.2 ))
383386
384387 r = self .loop .create_task (runner ())
385- with self .assertRaisesRegex (taskgroups .TaskGroupError , r'1 sub errors' ) :
388+ with self .assertRaises (taskgroups .TaskGroupError ) as cm :
386389 await r
387390
391+ self .assertEqual (get_error_types (cm .exception ), {ValueError })
392+
388393 async def test_taskgroup_14 (self ):
389394
390395 async def crash_after (t ):
@@ -399,9 +404,13 @@ async def runner():
399404 g2 .create_task (crash_after (0.1 ))
400405
401406 r = self .loop .create_task (runner ())
402- with self .assertRaisesRegex (taskgroups .TaskGroupError , r'1 sub errors' ) :
407+ with self .assertRaises (taskgroups .TaskGroupError ) as cm :
403408 await r
404409
410+ # TODO(guido): Check that the nested exception group is expected
411+ self .assertEqual (get_error_types (cm .exception ), {taskgroups .TaskGroupError })
412+ self .assertEqual (get_error_types (cm .exception .exceptions [0 ]), {ValueError })
413+
405414 async def test_taskgroup_15 (self ):
406415
407416 async def crash_soon ():
@@ -497,7 +506,7 @@ async def runner():
497506 try :
498507 await r
499508 except taskgroups .TaskGroupError as t :
500- self .assertEqual (t . get_error_types (), {MyExc })
509+ self .assertEqual (get_error_types (t ), {MyExc })
501510 else :
502511 self .fail ('TaskGroupError was not raised' )
503512
@@ -523,7 +532,7 @@ async def runner():
523532 try :
524533 await r
525534 except taskgroups .TaskGroupError as t :
526- self .assertEqual (t . get_error_types (), {MyExc , ZeroDivisionError })
535+ self .assertEqual (get_error_types (t ), {MyExc , ZeroDivisionError })
527536 else :
528537 self .fail ('TasgGroupError was not raised' )
529538
0 commit comments