From f91f783e6dcf1e6d9719e54ab4ff6922cc993645 Mon Sep 17 00:00:00 2001 From: BABTUNA Date: Thu, 21 May 2026 19:10:31 -0400 Subject: [PATCH 1/2] refactor(reflex-base): consolidate BaseContext into context.base (#6514) Removes the duplicate BaseContext defined in plugins/compiler.py and makes PageContext/CompileContext inherit from reflex_base.context.base. Converts BaseContext from a frozen dataclass to a plain class with __slots__ = () so frozen (RegistrationContext, EventContext) and non-frozen (PageContext, CompileContext) subclasses can both inherit. Ports the async __aenter__/__aexit__ helpers over from the removed copy, and pins identity equality on the base via object.__eq__ / object.__hash__. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../src/reflex_base/context/base.py | 71 +++++++++--- .../src/reflex_base/plugins/compiler.py | 101 +----------------- tests/units/compiler/test_plugins.py | 16 ++- tests/units/reflex_base/context/test_base.py | 34 +++++- 4 files changed, 101 insertions(+), 121 deletions(-) diff --git a/packages/reflex-base/src/reflex_base/context/base.py b/packages/reflex-base/src/reflex_base/context/base.py index 7bb28d4864c..94b62781de2 100644 --- a/packages/reflex-base/src/reflex_base/context/base.py +++ b/packages/reflex-base/src/reflex_base/context/base.py @@ -2,39 +2,58 @@ from __future__ import annotations -import dataclasses from contextvars import ContextVar, Token +from types import TracebackType from typing import ClassVar from typing_extensions import Self -@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) class BaseContext: - """Base context class that acts as an async context manager to set the context var.""" + """Base context class that acts as a sync/async context manager for a per-subclass ContextVar. + + Each subclass gets its own :class:`ContextVar` and a class-level mapping from + attached instances to their reset tokens, so any number of subclasses can be + entered concurrently without interfering with each other. + + Instances use identity equality (and identity-based hashing) so that two + distinct contexts with the same field values are still considered different. + """ + + __slots__ = () _context_var: ClassVar[ContextVar[Self]] _attached_context_token: ClassVar[dict[Self, Token[Self]]] + __eq__ = object.__eq__ + __hash__ = object.__hash__ + @classmethod def __init_subclass__(cls, **kwargs): - """Initialize the context variable for the subclass.""" - super(BaseContext, cls).__init_subclass__(**kwargs) + """Initialize the context variable and token registry for the subclass. + + Args: + **kwargs: Forwarded to ``super().__init_subclass__``. + """ + super().__init_subclass__(**kwargs) cls._context_var = ContextVar(cls.__name__) cls._attached_context_token = {} @classmethod def get(cls) -> Self: - """Get the context from the context variable. + """Get the active context from the context variable. Returns: - The context instance. + The active context instance. + + Raises: + LookupError: If no context has been set for this class. """ return cls._context_var.get() @classmethod def set(cls, context: Self) -> Token[Self]: - """Set the context in the context variable. + """Set the active context in the context variable. Args: context: The context instance to set. @@ -54,10 +73,13 @@ def reset(cls, token: Token[Self]) -> None: cls._context_var.reset(token) def __enter__(self) -> Self: - """Enter the context. + """Attach this context to the current task. Returns: This context instance. + + Raises: + RuntimeError: If this instance is already attached. """ if self._attached_context_token.get(self) is not None: msg = "Context is already attached, cannot enter context manager." @@ -65,12 +87,35 @@ def __enter__(self) -> Self: self._attached_context_token[self] = self._context_var.set(self) return self - def __exit__(self, *exc_info): - """Exit the context.""" - if (token := self._attached_context_token.pop(self)) is not None: + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + """Detach this context from the current task.""" + del exc_type, exc_val, exc_tb + if (token := self._attached_context_token.pop(self, None)) is not None: self._context_var.reset(token) - def ensure_context_attached(self): + async def __aenter__(self) -> Self: + """Attach this context to the current task asynchronously. + + Returns: + This context instance. + """ + return self.__enter__() + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + """Detach this context from the current task asynchronously.""" + self.__exit__(exc_type, exc_val, exc_tb) + + def ensure_context_attached(self) -> None: """Ensure that the context is attached to the current context variable. Raises: diff --git a/packages/reflex-base/src/reflex_base/plugins/compiler.py b/packages/reflex-base/src/reflex_base/plugins/compiler.py index ecb55a03d92..5de4bbac78a 100644 --- a/packages/reflex-base/src/reflex_base/plugins/compiler.py +++ b/packages/reflex-base/src/reflex_base/plugins/compiler.py @@ -6,13 +6,10 @@ import dataclasses import inspect from collections.abc import Callable, Sequence -from contextvars import ContextVar, Token -from types import TracebackType -from typing import TYPE_CHECKING, Any, ClassVar, Protocol, TypeAlias, TypeVar, cast - -from typing_extensions import Self +from typing import TYPE_CHECKING, Any, Protocol, TypeAlias, TypeVar, cast from reflex_base.components.component import BaseComponent, Component +from reflex_base.context.base import BaseContext from reflex_base.utils.imports import ParsedImportDict, collapse_imports, merge_imports from reflex_base.vars import VarData @@ -581,97 +578,7 @@ def _apply_replacement( return replacement, children -@dataclasses.dataclass(kw_only=True) -class BaseContext: - """Context manager that exposes itself through a class-local context var.""" - - __context_var__: ClassVar[ContextVar[Self | None]] - - _attached_context_token: Token[Self | None] | None = dataclasses.field( - default=None, - init=False, - repr=False, - ) - - @classmethod - def __init_subclass__(cls, **kwargs: Any) -> None: - """Initialize a dedicated context variable for each subclass.""" - super().__init_subclass__(**kwargs) - cls.__context_var__ = ContextVar(cls.__name__, default=None) - - @classmethod - def get(cls) -> Self: - """Return the active context instance for the current task. - - Returns: - The active context instance for the current task. - """ - context = cls.__context_var__.get() - if context is None: - msg = f"No active {cls.__name__} is attached to the current context." - raise RuntimeError(msg) - return context - - def __enter__(self) -> Self: - """Attach this context to the current task. - - Returns: - The attached context instance. - """ - if self._attached_context_token is not None: - msg = "Context is already attached and cannot be entered twice." - raise RuntimeError(msg) - self._attached_context_token = type(self).__context_var__.set(self) - return self - - def __exit__( - self, - exc_type: type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: TracebackType | None, - ) -> None: - """Detach this context from the current task.""" - del exc_type, exc_val, exc_tb - if self._attached_context_token is None: - return - try: - type(self).__context_var__.reset(self._attached_context_token) - finally: - self._attached_context_token = None - - async def __aenter__(self) -> Self: - """Attach this context to the current task asynchronously. - - Returns: - The attached context instance. - """ - return self.__enter__() - - async def __aexit__( - self, - exc_type: type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: TracebackType | None, - ) -> None: - """Detach this context from the current task asynchronously.""" - self.__exit__(exc_type, exc_val, exc_tb) - - def ensure_context_attached(self) -> None: - """Ensure this instance is the active context for the current task.""" - try: - current = type(self).get() - except RuntimeError as err: - msg = ( - f"{type(self).__name__} must be entered with 'with' or 'async with' " - "before calling this method." - ) - raise RuntimeError(msg) from err - if current is not self: - msg = f"{type(self).__name__} is not attached to the current task context." - raise RuntimeError(msg) - - -@dataclasses.dataclass(slots=True, kw_only=True) +@dataclasses.dataclass(slots=True, kw_only=True, eq=False) class PageContext(BaseContext): """Mutable compilation state for a single page.""" @@ -749,7 +656,7 @@ def custom_code_dict(self) -> dict[str, None]: return dict(self.module_code) -@dataclasses.dataclass(slots=True, kw_only=True) +@dataclasses.dataclass(slots=True, kw_only=True, eq=False) class CompileContext(BaseContext): """Mutable compilation state for an entire compile run.""" diff --git a/tests/units/compiler/test_plugins.py b/tests/units/compiler/test_plugins.py index 26eb1f39c99..4e53d984275 100644 --- a/tests/units/compiler/test_plugins.py +++ b/tests/units/compiler/test_plugins.py @@ -620,33 +620,31 @@ def test_context_lifecycle_and_cleanup() -> None: root_component=Fragment.create(), ) - with pytest.raises(RuntimeError, match="No active CompileContext"): + with pytest.raises(LookupError): CompileContext.get() - with pytest.raises( - RuntimeError, match="must be entered with 'with' or 'async with'" - ): + with pytest.raises(RuntimeError, match="must be entered"): compile_ctx.ensure_context_attached() with compile_ctx: assert CompileContext.get() is compile_ctx - with pytest.raises(RuntimeError, match="No active PageContext"): + with pytest.raises(LookupError): PageContext.get() with page_ctx: assert CompileContext.get() is compile_ctx assert PageContext.get() is page_ctx page_ctx.ensure_context_attached() - with pytest.raises(RuntimeError, match="No active PageContext"): + with pytest.raises(LookupError): PageContext.get() assert CompileContext.get() is compile_ctx - with pytest.raises(RuntimeError, match="No active CompileContext"): + with pytest.raises(LookupError): CompileContext.get() with pytest.raises(ValueError, match="boom"), compile_ctx: msg = "boom" raise ValueError(msg) - with pytest.raises(RuntimeError, match="No active CompileContext"): + with pytest.raises(LookupError): CompileContext.get() @@ -707,7 +705,7 @@ class DynamicContext(BaseContext): class AnotherDynamicContext(BaseContext): pass - assert DynamicContext.__context_var__ is not AnotherDynamicContext.__context_var__ + assert DynamicContext._context_var is not AnotherDynamicContext._context_var def test_apply_style_plugin_matches_legacy_style_behavior() -> None: diff --git a/tests/units/reflex_base/context/test_base.py b/tests/units/reflex_base/context/test_base.py index 11db7963159..fc622b6e7d3 100644 --- a/tests/units/reflex_base/context/test_base.py +++ b/tests/units/reflex_base/context/test_base.py @@ -6,7 +6,7 @@ from reflex_base.context.base import BaseContext -@dataclasses.dataclass(frozen=True, kw_only=True, slots=True) +@dataclasses.dataclass(frozen=True, kw_only=True, slots=True, eq=False) class _TestContext(BaseContext): """Minimal BaseContext subclass for unit testing.""" @@ -83,7 +83,7 @@ def test_ensure_context_attached(): def test_subclasses_have_independent_context_vars(): """Two BaseContext subclasses do not share their ContextVar.""" - @dataclasses.dataclass(frozen=True, kw_only=True, slots=True) + @dataclasses.dataclass(frozen=True, kw_only=True, slots=True, eq=False) class _OtherContext(BaseContext): value: int = 0 @@ -92,3 +92,33 @@ class _OtherContext(BaseContext): with ctx_a, ctx_b: assert _TestContext.get().label == "a" assert _OtherContext.get().value == 42 + + +def test_identity_equality_for_subclasses_with_eq_false(): + """Two BaseContext subclass instances with the same fields are not equal.""" + ctx_a = _TestContext(label="same") + ctx_b = _TestContext(label="same") + assert ctx_a is not ctx_b + assert ctx_a != ctx_b + assert hash(ctx_a) != hash(ctx_b) + + +def test_identity_equality_isolates_entered_state(): + """Two equal-by-field instances can be entered independently.""" + ctx_a = _TestContext(label="same") + ctx_b = _TestContext(label="same") + with ctx_a: + # Entering ctx_b must not see ctx_a's attachment as its own. + with ctx_b: + assert _TestContext.get() is ctx_b + assert _TestContext.get() is ctx_a + + +async def test_async_context_manager(): + """Async __aenter__/__aexit__ attaches and detaches the context.""" + ctx = _TestContext(label="async") + async with ctx as entered: + assert entered is ctx + assert _TestContext.get() is ctx + with pytest.raises(LookupError): + _TestContext.get() From 94c05d26c403759ae3408192d7f7bc0d00996bd1 Mon Sep 17 00:00:00 2001 From: BABTUNA Date: Thu, 21 May 2026 19:48:49 -0400 Subject: [PATCH 2/2] test(compiler): match consolidated BaseContext error message The remaining "must be entered with 'with' or 'async with'" regex in test_compile_context_requires_attached_context didn't match the consolidated BaseContext's "must be entered before calling this method" message, so all 10 unit-tests jobs failed on CI. Loosen the regex to match the prefix shared by the new wording. --- tests/units/compiler/test_plugins.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/units/compiler/test_plugins.py b/tests/units/compiler/test_plugins.py index 4e53d984275..f5bb2ecf106 100644 --- a/tests/units/compiler/test_plugins.py +++ b/tests/units/compiler/test_plugins.py @@ -1029,9 +1029,7 @@ def test_compile_context_requires_attached_context() -> None: hooks=CompilerHooks(), ) - with pytest.raises( - RuntimeError, match="must be entered with 'with' or 'async with'" - ): + with pytest.raises(RuntimeError, match="must be entered"): compile_ctx.compile()