Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 44 additions & 1 deletion lazy_loader/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ class _StubVisitor(ast.NodeVisitor):
def __init__(self):
self._submodules = set()
self._submod_attrs = {}
self._all = None

def visit_ImportFrom(self, node: ast.ImportFrom):
if node.level != 1:
Expand All @@ -300,6 +301,38 @@ def visit_ImportFrom(self, node: ast.ImportFrom):
else:
self._submodules.update(alias.name for alias in node.names)

def visit_Assign(self, node: ast.Assign):
assigned_list = None
for name in node.targets:
if name.id == "__all__":
assigned_list = node.value

if assigned_list is None:
return # early
elif not isinstance(assigned_list, ast.List):
msg = (
f"expected a list assigned to `__all__`, found {type(assigned_list)!r}"
)
raise ValueError(msg)

if self._all is not None:
msg = "expected only one definition of `__all__` in stub"
raise ValueError(msg)
self._all = set()

for constant in assigned_list.elts:
if (
not isinstance(constant, ast.Constant)
or not isinstance(constant.value, str)
or assigned_list == ""
):
msg = (
"expected `__all__` to contain only non-empty strings, "
f"got {constant!r}"
)
raise ValueError(msg)
self._all.add(constant.value)


def attach_stub(package_name: str, filename: str):
"""Attach lazily loaded submodules, functions from a type stub.
Expand All @@ -308,6 +341,10 @@ def attach_stub(package_name: str, filename: str):
infer ``submodules`` and ``submod_attrs``. This allows static type checkers
to find imports, while still providing lazy loading at runtime.

If the stub file defines `__all__`, it must contain a simple list of
non-empty strings. In this case, the content of `__dir__()` may be
intentionally different from `__all__`.

Parameters
----------
package_name : str
Expand Down Expand Up @@ -339,4 +376,10 @@ def attach_stub(package_name: str, filename: str):

visitor = _StubVisitor()
visitor.visit(stub_node)
return attach(package_name, visitor._submodules, visitor._submod_attrs)

__getattr__, __dir__, __all__ = attach(
package_name, visitor._submodules, visitor._submod_attrs
)
if visitor._all is not None:
__all__ = visitor._all
return __getattr__, __dir__, __all__
29 changes: 29 additions & 0 deletions lazy_loader/tests/test_lazy_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,35 @@ def test_stub_loading_parity():
assert stub_getter("some_func") == fake_pkg.some_func


FAKE_STUB_OVERRIDE_ALL = """
__all__ = [
"rank",
"gaussian",
"sobel",
"scharr",
"roberts",
# `prewitt` not included!
"__version__", # included but not imported in stub
]

from . import rank
from ._gaussian import gaussian
from .edges import sobel, scharr, prewitt, roberts
"""


def test_stub_override_all(tmp_path):
stub = tmp_path / "stub.pyi"
stub.write_text(FAKE_STUB_OVERRIDE_ALL)
_get, _dir, _all = lazy.attach_stub("my_module", str(stub))

expect_dir = {"gaussian", "sobel", "scharr", "prewitt", "roberts", "rank"}
assert set(_dir()) == expect_dir

expect_all = {"rank", "gaussian", "sobel", "scharr", "roberts", "__version__"}
assert set(_all) == expect_all


def test_stub_loading_errors(tmp_path):
stub = tmp_path / "stub.pyi"
stub.write_text("from ..mod import func\n")
Expand Down