Skip to content

Commit 086126f

Browse files
authored
Fix error using "set_start_method()" after "logger" import (#974)
Calling "multiprocessing.get_context(method=None)" had the unexpected side effect of also fixing the global start method (which can't be changed afterwards).
1 parent 14fa062 commit 086126f

File tree

4 files changed

+42
-27
lines changed

4 files changed

+42
-27
lines changed

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
=============
33

44
- Add support for formatting of ``ExceptionGroup`` errors (`#805 <https://github.com/Delgan/loguru/issues/805>`_).
5+
- Fix possible ``RuntimeError`` when using ``multiprocessing.set_start_method()`` after importing the ``logger`` (`#974 <https://github.com/Delgan/loguru/issues/974>`_)
56
- Fix formatting of possible ``__notes__`` attached to an ``Exception`` (`#980 <https://github.com/Delgan/loguru/issues/980>`_).
67

78

loguru/_handler.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import functools
22
import json
3+
import multiprocessing
34
import os
45
import threading
56
from contextlib import contextmanager
@@ -88,10 +89,15 @@ def __init__(
8889
self._decolorized_format = self._formatter.strip()
8990

9091
if self._enqueue:
91-
self._queue = self._multiprocessing_context.SimpleQueue()
92+
if self._multiprocessing_context is None:
93+
self._queue = multiprocessing.SimpleQueue()
94+
self._confirmation_event = multiprocessing.Event()
95+
self._confirmation_lock = multiprocessing.Lock()
96+
else:
97+
self._queue = self._multiprocessing_context.SimpleQueue()
98+
self._confirmation_event = self._multiprocessing_context.Event()
99+
self._confirmation_lock = self._multiprocessing_context.Lock()
92100
self._queue_lock = create_handler_lock()
93-
self._confirmation_event = self._multiprocessing_context.Event()
94-
self._confirmation_lock = self._multiprocessing_context.Lock()
95101
self._owner_process_pid = os.getpid()
96102
self._thread = Thread(
97103
target=self._queued_writer, daemon=True, name="loguru-writer-%d" % self._id

loguru/_logger.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -967,9 +967,9 @@ def add(
967967
if not isinstance(encoding, str):
968968
encoding = "ascii"
969969

970-
if context is None or isinstance(context, str):
970+
if isinstance(context, str):
971971
context = get_context(context)
972-
elif not isinstance(context, BaseContext):
972+
elif context is not None and not isinstance(context, BaseContext):
973973
raise TypeError(
974974
"Invalid context, it should be a string or a multiprocessing context, "
975975
"not: '%s'" % type(context).__name__

tests/test_add_option_context.py

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,55 +1,63 @@
11
import multiprocessing
22
import os
3-
from unittest.mock import MagicMock
3+
from unittest.mock import patch
44

55
import pytest
66

77
from loguru import logger
88

99

10-
def get_handler_context():
11-
# No better way to test correct value than to access the private attribute.
12-
handler = next(iter(logger._core.handlers.values()))
13-
return handler._multiprocessing_context
10+
@pytest.fixture
11+
def reset_start_method():
12+
yield
13+
multiprocessing.set_start_method(None, force=True)
1414

1515

16-
def test_default_context():
17-
logger.add(lambda _: None, context=None)
18-
assert get_handler_context() == multiprocessing.get_context(None)
16+
@pytest.mark.usefixtures("reset_start_method")
17+
def test_using_multiprocessing_directly_if_context_is_none():
18+
logger.add(lambda _: None, enqueue=True, context=None)
19+
assert multiprocessing.get_start_method(allow_none=True) is not None
1920

2021

2122
@pytest.mark.skipif(os.name == "nt", reason="Windows does not support forking")
2223
@pytest.mark.parametrize("context_name", ["fork", "forkserver"])
2324
def test_fork_context_as_string(context_name):
24-
logger.add(lambda _: None, context=context_name)
25-
assert get_handler_context() == multiprocessing.get_context(context_name)
25+
context = multiprocessing.get_context(context_name)
26+
with patch.object(type(context), "Lock", wraps=context.Lock) as mock:
27+
logger.add(lambda _: None, context=context_name, enqueue=True)
28+
assert mock.called
29+
assert multiprocessing.get_start_method(allow_none=True) is None
2630

2731

2832
def test_spawn_context_as_string():
29-
logger.add(lambda _: None, context="spawn")
30-
assert get_handler_context() == multiprocessing.get_context("spawn")
33+
context = multiprocessing.get_context("spawn")
34+
with patch.object(type(context), "Lock", wraps=context.Lock) as mock:
35+
logger.add(lambda _: None, context="spawn", enqueue=True)
36+
assert mock.called
37+
assert multiprocessing.get_start_method(allow_none=True) is None
3138

3239

3340
@pytest.mark.skipif(os.name == "nt", reason="Windows does not support forking")
3441
@pytest.mark.parametrize("context_name", ["fork", "forkserver"])
3542
def test_fork_context_as_object(context_name):
3643
context = multiprocessing.get_context(context_name)
37-
logger.add(lambda _: None, context=context)
38-
assert get_handler_context() == context
44+
with patch.object(type(context), "Lock", wraps=context.Lock) as mock:
45+
logger.add(lambda _: None, context=context, enqueue=True)
46+
assert mock.called
47+
assert multiprocessing.get_start_method(allow_none=True) is None
3948

4049

4150
def test_spawn_context_as_object():
4251
context = multiprocessing.get_context("spawn")
43-
logger.add(lambda _: None, context=context)
44-
assert get_handler_context() == context
52+
with patch.object(type(context), "Lock", wraps=context.Lock) as mock:
53+
logger.add(lambda _: None, context=context, enqueue=True)
54+
assert mock.called
55+
assert multiprocessing.get_start_method(allow_none=True) is None
4556

4657

47-
def test_context_effectively_used():
48-
default_context = multiprocessing.get_context()
49-
mocked_context = MagicMock(spec=default_context, wraps=default_context)
50-
logger.add(lambda _: None, context=mocked_context, enqueue=True)
51-
logger.complete()
52-
assert mocked_context.Lock.called
58+
def test_global_start_method_is_none_if_enqueue_is_false():
59+
logger.add(lambda _: None, enqueue=False, context=None)
60+
assert multiprocessing.get_start_method(allow_none=True) is None
5361

5462

5563
def test_invalid_context_name():

0 commit comments

Comments
 (0)