diff --git a/haystack/nodes/prompt/prompt_model.py b/haystack/nodes/prompt/prompt_model.py index 47ce1ee987..7c61eb0f6e 100644 --- a/haystack/nodes/prompt/prompt_model.py +++ b/haystack/nodes/prompt/prompt_model.py @@ -1,5 +1,6 @@ import inspect import logging +import importlib from typing import Any, Dict, List, Optional, Tuple, Type, Union, overload from haystack.nodes.base import BaseComponent @@ -40,7 +41,7 @@ def __init__( use_auth_token: Optional[Union[str, bool]] = None, use_gpu: Optional[bool] = None, devices: Optional[List[Union[str, "torch.device"]]] = None, - invocation_layer_class: Optional[Type[PromptModelInvocationLayer]] = None, + invocation_layer_class: Optional[Union[Type[PromptModelInvocationLayer], str]] = None, model_kwargs: Optional[Dict] = None, ): """ @@ -73,7 +74,7 @@ def __init__( self.model_invocation_layer = self.create_invocation_layer(invocation_layer_class=invocation_layer_class) def create_invocation_layer( - self, invocation_layer_class: Optional[Type[PromptModelInvocationLayer]] + self, invocation_layer_class: Optional[Union[Type[PromptModelInvocationLayer], str]] ) -> PromptModelInvocationLayer: kwargs = { "api_key": self.api_key, @@ -84,6 +85,18 @@ def create_invocation_layer( } all_kwargs = {**self.model_kwargs, **kwargs} + if isinstance(invocation_layer_class, str): + module_name, class_name = invocation_layer_class.rsplit(".", maxsplit=1) + try: + module = importlib.import_module(module_name) + except ImportError as e: + msg = f"Can't find module {module_name}" + raise ValueError(msg) from e + invocation_layer_class = getattr(module, class_name) + if invocation_layer_class is None: + msg = f"Can'f find class {class_name} in module {module_name}" + ValueError(msg) + if invocation_layer_class: return invocation_layer_class( model_name_or_path=self.model_name_or_path, max_length=self.max_length, **all_kwargs diff --git a/releasenotes/notes/prompt-model-invocation-layer-e7a69a3ac3beb5a7.yaml 
b/releasenotes/notes/prompt-model-invocation-layer-e7a69a3ac3beb5a7.yaml new file mode 100644 index 0000000000..2f2a6bb497 --- /dev/null +++ b/releasenotes/notes/prompt-model-invocation-layer-e7a69a3ac3beb5a7.yaml @@ -0,0 +1,6 @@ +--- +enhancements: + - | + Change `PromptModel` constructor parameter `invocation_layer_class` to accept a `str` too. + If a `str` is given, it must be the fully qualified dotted path of the invocation layer class (module path plus class name); the class is then imported and used. + This should ease serialisation to YAML when using `invocation_layer_class` with `PromptModel`. diff --git a/test/prompt/test_prompt_model.py b/test/prompt/test_prompt_model.py index 9e7c2b0f51..985872b071 100644 --- a/test/prompt/test_prompt_model.py +++ b/test/prompt/test_prompt_model.py @@ -39,6 +39,16 @@ def test_constructor_with_no_supported_model(): PromptModel("some-random-model") +@pytest.mark.unit +def test_constructor_with_invocation_layer_class_string(): + model = PromptModel( + invocation_layer_class="haystack.nodes.prompt.invocation_layer.CohereInvocationLayer", api_key="fake_api_key" + ) + from haystack.nodes.prompt.invocation_layer import CohereInvocationLayer + + assert isinstance(model.model_invocation_layer, CohereInvocationLayer) + + @pytest.mark.asyncio async def test_ainvoke(): def async_return(result):