|
16 | 16 |
|
17 | 17 | import contextlib |
18 | 18 | import datetime |
| 19 | +import multiprocessing |
19 | 20 | import os |
20 | 21 | import threading |
21 | 22 | import socket |
22 | 23 | import sys |
23 | 24 | import time |
| 25 | +import traceback |
24 | 26 | import warnings |
25 | 27 |
|
26 | 28 | sys.path[0:0] = [""] |
|
38 | 40 | InvalidName, |
39 | 41 | OperationFailure, |
40 | 42 | CursorNotFound) |
41 | | -from pymongo.server_selectors import writable_server_selector |
| 43 | +from pymongo.server_selectors import (any_server_selector, |
| 44 | + writable_server_selector) |
42 | 45 | from pymongo.server_type import SERVER_TYPE |
43 | 46 | from test import (client_context, |
44 | 47 | client_knobs, |
@@ -442,56 +445,38 @@ def test_fork(self): |
442 | 445 | if sys.platform == "win32": |
443 | 446 | raise SkipTest("Can't fork on windows") |
444 | 447 |
|
445 | | - try: |
446 | | - from multiprocessing import Process, Pipe |
447 | | - except ImportError: |
448 | | - raise SkipTest("No multiprocessing module") |
449 | | - |
450 | 448 | db = self.client.pymongo_test |
451 | 449 |
|
452 | | - # Failure occurs if the client is used before the fork |
| 450 | + # Ensure a socket is opened before the fork. |
453 | 451 | db.test.find_one() |
454 | | - db.connection.end_request() |
455 | | - |
456 | | - def loop(pipe): |
457 | | - while True: |
458 | | - try: |
459 | | - db.test.insert({"a": "b"}) |
460 | | - for _ in db.test.find(): |
461 | | - pass |
462 | | - except: |
463 | | - pipe.send(True) |
464 | | - os._exit(1) |
465 | 452 |
|
466 | | - cp1, cc1 = Pipe() |
467 | | - cp2, cc2 = Pipe() |
468 | | - |
469 | | - p1 = Process(target=loop, args=(cc1,)) |
470 | | - p2 = Process(target=loop, args=(cc2,)) |
471 | | - |
472 | | - p1.start() |
473 | | - p2.start() |
| 453 | + def f(pipe): |
| 454 | + try: |
| 455 | + servers = self.client._cluster.select_servers( |
| 456 | + any_server_selector) |
474 | 457 |
|
475 | | - p1.join(1) |
476 | | - p2.join(1) |
| 458 | + # In child, only the thread that called fork() is alive. |
| 459 | + assert not any(s._monitor._thread.is_alive() |
| 460 | + for s in servers) |
477 | 461 |
|
478 | | - p1.terminate() |
479 | | - p2.terminate() |
| 462 | + db.test.find_one() |
480 | 463 |
|
481 | | - p1.join() |
482 | | - p2.join() |
| 464 | + wait_until( |
| 465 | + lambda: all(s._monitor._thread.is_alive() for s in servers), |
| 466 | + "restart monitor threads") |
| 467 | + except: |
| 468 | + traceback.print_exc() # Aid debugging. |
| 469 | + pipe.send(True) |
483 | 470 |
|
484 | | - cc1.close() |
485 | | - cc2.close() |
| 471 | + parent_pipe, child_pipe = multiprocessing.Pipe() |
| 472 | + p = multiprocessing.Process(target=f, args=(child_pipe,)) |
| 473 | + p.start() |
| 474 | + p.join(10) |
| 475 | + child_pipe.close() |
486 | 476 |
|
487 | | - # recv will only have data if the subprocess failed |
488 | | - try: |
489 | | - cp1.recv() |
490 | | - self.fail() |
491 | | - except EOFError: |
492 | | - pass |
| 477 | + # Pipe will only have data if the child process failed. |
493 | 478 | try: |
494 | | - cp2.recv() |
| 479 | + parent_pipe.recv() |
495 | 480 | self.fail() |
496 | 481 | except EOFError: |
497 | 482 | pass |
|
0 commit comments