|
1 | 1 | import multiprocessing |
2 | 2 | import os |
3 | | -from unittest.mock import MagicMock |
| 3 | +from unittest.mock import patch |
4 | 4 |
|
5 | 5 | import pytest |
6 | 6 |
|
7 | 7 | from loguru import logger |
8 | 8 |
|
9 | 9 |
|
10 | | -def get_handler_context(): |
11 | | - # No better way to test correct value than to access the private attribute. |
12 | | - handler = next(iter(logger._core.handlers.values())) |
13 | | - return handler._multiprocessing_context |
| 10 | +@pytest.fixture |
| 11 | +def reset_start_method(): |
| 12 | + yield |
| 13 | + multiprocessing.set_start_method(None, force=True) |
14 | 14 |
|
15 | 15 |
|
16 | | -def test_default_context(): |
17 | | - logger.add(lambda _: None, context=None) |
18 | | - assert get_handler_context() == multiprocessing.get_context(None) |
| 16 | +@pytest.mark.usefixtures("reset_start_method") |
| 17 | +def test_using_multiprocessing_directly_if_context_is_none(): |
| 18 | + logger.add(lambda _: None, enqueue=True, context=None) |
| 19 | + assert multiprocessing.get_start_method(allow_none=True) is not None |
19 | 20 |
|
20 | 21 |
|
21 | 22 | @pytest.mark.skipif(os.name == "nt", reason="Windows does not support forking") |
22 | 23 | @pytest.mark.parametrize("context_name", ["fork", "forkserver"]) |
23 | 24 | def test_fork_context_as_string(context_name): |
24 | | - logger.add(lambda _: None, context=context_name) |
25 | | - assert get_handler_context() == multiprocessing.get_context(context_name) |
| 25 | + context = multiprocessing.get_context(context_name) |
| 26 | + with patch.object(type(context), "Lock", wraps=context.Lock) as mock: |
| 27 | + logger.add(lambda _: None, context=context_name, enqueue=True) |
| 28 | + assert mock.called |
| 29 | + assert multiprocessing.get_start_method(allow_none=True) is None |
26 | 30 |
|
27 | 31 |
|
28 | 32 | def test_spawn_context_as_string(): |
29 | | - logger.add(lambda _: None, context="spawn") |
30 | | - assert get_handler_context() == multiprocessing.get_context("spawn") |
| 33 | + context = multiprocessing.get_context("spawn") |
| 34 | + with patch.object(type(context), "Lock", wraps=context.Lock) as mock: |
| 35 | + logger.add(lambda _: None, context="spawn", enqueue=True) |
| 36 | + assert mock.called |
| 37 | + assert multiprocessing.get_start_method(allow_none=True) is None |
31 | 38 |
|
32 | 39 |
|
33 | 40 | @pytest.mark.skipif(os.name == "nt", reason="Windows does not support forking") |
34 | 41 | @pytest.mark.parametrize("context_name", ["fork", "forkserver"]) |
35 | 42 | def test_fork_context_as_object(context_name): |
36 | 43 | context = multiprocessing.get_context(context_name) |
37 | | - logger.add(lambda _: None, context=context) |
38 | | - assert get_handler_context() == context |
| 44 | + with patch.object(type(context), "Lock", wraps=context.Lock) as mock: |
| 45 | + logger.add(lambda _: None, context=context, enqueue=True) |
| 46 | + assert mock.called |
| 47 | + assert multiprocessing.get_start_method(allow_none=True) is None |
39 | 48 |
|
40 | 49 |
|
41 | 50 | def test_spawn_context_as_object(): |
42 | 51 | context = multiprocessing.get_context("spawn") |
43 | | - logger.add(lambda _: None, context=context) |
44 | | - assert get_handler_context() == context |
| 52 | + with patch.object(type(context), "Lock", wraps=context.Lock) as mock: |
| 53 | + logger.add(lambda _: None, context=context, enqueue=True) |
| 54 | + assert mock.called |
| 55 | + assert multiprocessing.get_start_method(allow_none=True) is None |
45 | 56 |
|
46 | 57 |
|
47 | | -def test_context_effectively_used(): |
48 | | - default_context = multiprocessing.get_context() |
49 | | - mocked_context = MagicMock(spec=default_context, wraps=default_context) |
50 | | - logger.add(lambda _: None, context=mocked_context, enqueue=True) |
51 | | - logger.complete() |
52 | | - assert mocked_context.Lock.called |
| 58 | +def test_global_start_method_is_none_if_enqueue_is_false(): |
| 59 | + logger.add(lambda _: None, enqueue=False, context=None) |
| 60 | + assert multiprocessing.get_start_method(allow_none=True) is None |
53 | 61 |
|
54 | 62 |
|
55 | 63 | def test_invalid_context_name(): |
|
0 commit comments