Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Override dunder methods of placeholder modules to provide more inform…
…ative errors

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
  • Loading branch information
DarkLight1337 committed Jan 9, 2025
commit b5ddc8a1a9a4e29ac1530071e014b90c91d4670f
189 changes: 176 additions & 13 deletions vllm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
import zmq.asyncio
from packaging.version import Version
from torch.library import Library
from typing_extensions import ParamSpec, TypeIs, assert_never
from typing_extensions import Never, ParamSpec, TypeIs, assert_never

import vllm.envs as envs
from vllm.logger import enable_trace_function_call, init_logger
Expand Down Expand Up @@ -1594,24 +1594,183 @@ def get_vllm_optional_dependencies():
}


@dataclass(frozen=True)
class PlaceholderModule:
class _PlaceholderMixin:
"""
Disallows downstream usage of placeholder modules.

We need to explicitly override each dunder method because
:meth:`__getattr__` is not called when they are accessed.

See also:
[Special method lookup](https://docs.python.org/3/reference/datamodel.html#special-lookup)
"""

if TYPE_CHECKING:

def __getattr__(self, key: str) -> Never:
# Implemented by the main class to throw an error for
# any attribute access
...

# [Basic customization]

def __lt__(self, other: object):
return self.__getattr__("__lt__")

def __le__(self, other: object):
return self.__getattr__("__le__")

def __eq__(self, other: object):
return self.__getattr__("__eq__")

def __ne__(self, other: object):
return self.__getattr__("__ne__")

def __gt__(self, other: object):
return self.__getattr__("__gt__")

def __ge__(self, other: object):
return self.__getattr__("__ge__")

def __hash__(self):
return self.__getattr__("__hash__")

def __bool__(self):
return self.__getattr__("__bool__")

# [Callable objects]

def __call__(self, *args: object, **kwargs: object):
return self.__getattr__("__call__")

# [Container types]

def __len__(self):
return self.__getattr__("__len__")

def __getitem__(self, key: object):
return self.__getattr__("__getitem__")

def __setitem__(self, key: object, value: object):
return self.__getattr__("__setitem__")

def __delitem__(self, key: object):
return self.__getattr__("__delitem__")

# __missing__ is optional according to __getitem__ specification,
# so it is skipped

# __iter__ and __reversed__ have a default implementation
# based on __len__ and __getitem__, so they are skipped.

# [Numeric Types]

def __add__(self, other: object):
return self.__getattr__("__add__")

def __sub__(self, other: object):
return self.__getattr__("__sub__")

def __mul__(self, other: object):
return self.__getattr__("__mul__")

def __matmul__(self, other: object):
return self.__getattr__("__matmul__")

def __truediv__(self, other: object):
return self.__getattr__("__truediv__")

def __floordiv__(self, other: object):
return self.__getattr__("__floordiv__")

def __mod__(self, other: object):
return self.__getattr__("__mod__")

def __divmod__(self, other: object):
return self.__getattr__("__divmod__")

def __pow__(self, other: object, modulo: object = ...):
return self.__getattr__("__pow__")

def __lshift__(self, other: object):
return self.__getattr__("__lshift__")

def __rshift__(self, other: object):
return self.__getattr__("__rshift__")

def __and__(self, other: object):
return self.__getattr__("__and__")

def __xor__(self, other: object):
return self.__getattr__("__xor__")

def __or__(self, other: object):
return self.__getattr__("__or__")

# r* and i* methods have lower priority than
# the methods for left operand so they are skipped

def __neg__(self):
return self.__getattr__("__neg__")

def __pos__(self):
return self.__getattr__("__pos__")

def __abs__(self):
return self.__getattr__("__abs__")

def __invert__(self):
return self.__getattr__("__invert__")

# __complex__, __int__ and __float__ have a default implementation
# based on __index__, so they are skipped.

def __index__(self):
return self.__getattr__("__index__")

def __round__(self, ndigits: object = ...):
return self.__getattr__("__round__")

def __trunc__(self):
return self.__getattr__("__trunc__")

def __floor__(self):
return self.__getattr__("__floor__")

def __ceil__(self):
return self.__getattr__("__ceil__")

# [Context managers]

def __enter__(self):
return self.__getattr__("__enter__")

def __exit__(self, *args: object, **kwargs: object):
return self.__getattr__("__exit__")


class PlaceholderModule(_PlaceholderMixin):
"""
A placeholder object to use when a module does not exist.

This enables more informative errors when trying to access attributes
of a module that does not exists.
"""
name: str

def __init__(self, name: str) -> None:
super().__init__()

# Apply name mangling to avoid conflicting with module attributes
self.__name = name

def placeholder_attr(self, attr_path: str):
return _PlaceholderModuleAttr(self, attr_path)

def __getattr__(self, key: str):
name = self.name
name = self.__name

try:
importlib.import_module(self.name)
importlib.import_module(name)
except ImportError as exc:
for extra, names in get_vllm_optional_dependencies().items():
if name in names:
Expand All @@ -1624,17 +1783,21 @@ def __getattr__(self, key: str):
"when the original module can be imported")


@dataclass(frozen=True)
class _PlaceholderModuleAttr:
module: PlaceholderModule
attr_path: str
class _PlaceholderModuleAttr(_PlaceholderMixin):

def __init__(self, module: PlaceholderModule, attr_path: str) -> None:
super().__init__()

# Apply name mangling to avoid conflicting with module attributes
self.__module = module
self.__attr_path = attr_path

def placeholder_attr(self, attr_path: str):
return _PlaceholderModuleAttr(self.module,
f"{self.attr_path}.{attr_path}")
return _PlaceholderModuleAttr(self.__module,
f"{self.__attr_path}.{attr_path}")

def __getattr__(self, key: str):
getattr(self.module, f"{self.attr_path}.{key}")
getattr(self.__module, f"{self.__attr_path}.{key}")

raise AssertionError("PlaceholderModule should not be used "
"when the original module can be imported")
Expand Down