From df97e8e7a0b45c480bc6e4f0d9e796ca4a3ac7ce Mon Sep 17 00:00:00 2001 From: teleprint-me <77757836+teleprint-me@users.noreply.github.com> Date: Sun, 5 Nov 2023 16:36:11 -0500 Subject: [PATCH 1/6] feat: Add support for jinja templating Signed-off-by: teleprint-me <77757836+teleprint-me@users.noreply.github.com> --- llama_cpp/llama_jinja_format.py | 135 ++++++++++++++++++++++++++++++++ tests/test_llama_chat_format.py | 50 ++++++++++++ 2 files changed, 185 insertions(+) create mode 100644 llama_cpp/llama_jinja_format.py create mode 100644 tests/test_llama_chat_format.py diff --git a/llama_cpp/llama_jinja_format.py b/llama_cpp/llama_jinja_format.py new file mode 100644 index 0000000000..b8bd22a0fd --- /dev/null +++ b/llama_cpp/llama_jinja_format.py @@ -0,0 +1,135 @@ +""" +llama_cpp/llama_jinja_format.py +""" +import dataclasses +from typing import Any, Callable, Dict, List, Optional, Protocol, Union + +import jinja2 +from jinja2 import Template + +# NOTE: We sacrifice readability for usability. +# It will fail to work as expected if we attempt to format it in a readable way. +llama2_template = """{% for message in messages %}{% if message['role'] == 'user' %} +[INST] {{ message['content'] }} [/INST]\n{% elif message['role'] == 'assistant' %}{{ message['content'] }}\n{% elif message['role'] == 'system' %}<> {{ message['content'] }} <>\n{% endif %}{% endfor %}""" + + +class MetaSingleton(type): + """ + Metaclass for implementing the Singleton pattern. + """ + + _instances = {} + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super(MetaSingleton, cls).__call__(*args, **kwargs) + return cls._instances[cls] + + +class Singleton(object, metaclass=MetaSingleton): + """ + Base class for implementing the Singleton pattern. 
+ """ + + def __init__(self): + super(Singleton, self).__init__() + + +@dataclasses.dataclass +class ChatFormatterResponse: + prompt: str + stop: Optional[Union[str, List[str]]] = None + + +# Base Chat Formatter Protocol +class ChatFormatterInterface(Protocol): + def __init__(self, template: Optional[Dict[str, Any]] = None): + ... + + def __call__( + self, + messages: List[Dict[str, str]], + **kwargs, + ) -> ChatFormatterResponse: + ... + + +class AutoChatFormatter(ChatFormatterInterface): + def __init__( + self, + template: Optional[str] = None, + template_class: Optional[Template] = None, + ): + if template is not None: + self._template = template + else: + self._template = llama2_template # default template + + self._renderer = jinja2.Environment( + loader=jinja2.BaseLoader(), + trim_blocks=True, + lstrip_blocks=True, + ).from_string( + self._template, + template_class=template_class, + ) + + def __call__( + self, + messages: List[Dict[str, str]], + **kwargs: Any, + ) -> ChatFormatterResponse: + formatted_sequence = self._renderer.render(messages=messages, **kwargs) + return ChatFormatterResponse(prompt=formatted_sequence) + + @property + def template(self) -> str: + return self._template + + +class FormatterNotFoundException(Exception): + pass + + +class ChatFormatterFactory(metaclass=MetaSingleton): + _chat_formatters: Dict[str, Callable[[], ChatFormatterInterface]] = {} + + def register_formatter( + self, + name: str, + formatter_callable: Callable[[], ChatFormatterInterface], + overwrite=False, + ): + if not overwrite and name in self._chat_formatters: + raise ValueError( + f"Formatter with name '{name}' is already registered. Use `overwrite=True` to overwrite it." 
+ ) + self._chat_formatters[name] = formatter_callable + + def unregister_formatter(self, name: str): + if name in self._chat_formatters: + del self._chat_formatters[name] + else: + raise ValueError(f"No formatter registered under the name '{name}'.") + + def get_formatter_by_name(self, name: str) -> ChatFormatterInterface: + try: + formatter_callable = self._chat_formatters[name] + return formatter_callable() + except KeyError: + raise FormatterNotFoundException( + f"Invalid chat format: {name} (valid formats: {list(self._chat_formatters.keys())})" + ) + + +# Define a chat format class +class Llama2Formatter(AutoChatFormatter): + def __init__(self): + super().__init__(llama2_template) + + +# With the Singleton pattern applied, regardless of where or how many times +# ChatFormatterFactory() is called, it will always return the same instance +# of the factory, ensuring that the factory's state is consistent throughout +# the application. +ChatFormatterFactory().register_formatter("llama-2", Llama2Formatter) diff --git a/tests/test_llama_chat_format.py b/tests/test_llama_chat_format.py new file mode 100644 index 0000000000..4eebcb6745 --- /dev/null +++ b/tests/test_llama_chat_format.py @@ -0,0 +1,50 @@ +from typing import List + +import pytest + +from llama_cpp import ChatCompletionMessage +from llama_cpp.llama_jinja_format import Llama2Formatter + + +@pytest.fixture +def sequence_of_messages() -> List[ChatCompletionMessage]: + return [ + ChatCompletionMessage(role="system", content="Welcome to CodeHelp Bot!"), + ChatCompletionMessage( + role="user", content="Hi there! I need some help with Python." + ), + ChatCompletionMessage( + role="assistant", content="Of course! What do you need help with in Python?" + ), + ChatCompletionMessage( + role="user", + content="I'm trying to write a function to find the factorial of a number, but I'm stuck.", + ), + ChatCompletionMessage( + role="assistant", + content="I can help with that! 
Would you like a recursive or iterative solution?", + ), + ChatCompletionMessage( + role="user", content="Let's go with a recursive solution." + ), + ] + + +def test_llama2_formatter(sequence_of_messages): + expected_prompt = ( + "<> Welcome to CodeHelp Bot! <>\n" + "[INST] Hi there! I need some help with Python. [/INST]\n" + "Of course! What do you need help with in Python?\n" + "[INST] I'm trying to write a function to find the factorial of a number, but I'm stuck. [/INST]\n" + "I can help with that! Would you like a recursive or iterative solution?\n" + "[INST] Let's go with a recursive solution. [/INST]\n" + ) + + llama2_formatter_instance = Llama2Formatter() + formatter_response = llama2_formatter_instance(sequence_of_messages) + assert ( + expected_prompt == formatter_response.prompt + ), "The formatted prompt does not match the expected output." + + +# Optionally, include a test for the 'stop' if it's part of the functionality. From 9c11d174fd343e7304fd4a196bf4d53ad8ef294c Mon Sep 17 00:00:00 2001 From: teleprint-me <77757836+teleprint-me@users.noreply.github.com> Date: Sun, 5 Nov 2023 17:53:17 -0500 Subject: [PATCH 2/6] fix: Refactor chat formatter and update interface for jinja templates - Simplify the `llama2_template` in `llama_jinja_format.py` by removing unnecessary line breaks for readability without affecting functionality. - Update `ChatFormatterInterface` constructor to accept a more generic `Optional[object]` type for the template parameter, enhancing flexibility. - Introduce a `template` property to `ChatFormatterInterface` for standardized access to the template string. - Replace `MetaSingleton` metaclass with `Singleton` for the `ChatFormatterFactory` to streamline the singleton implementation. These changes enhance code readability, maintain usability, and ensure consistency in the chat formatter's design pattern usage. 
--- llama_cpp/llama_jinja_format.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/llama_cpp/llama_jinja_format.py b/llama_cpp/llama_jinja_format.py index b8bd22a0fd..d3c10008ec 100644 --- a/llama_cpp/llama_jinja_format.py +++ b/llama_cpp/llama_jinja_format.py @@ -9,8 +9,7 @@ # NOTE: We sacrifice readability for usability. # It will fail to work as expected if we attempt to format it in a readable way. -llama2_template = """{% for message in messages %}{% if message['role'] == 'user' %} -[INST] {{ message['content'] }} [/INST]\n{% elif message['role'] == 'assistant' %}{{ message['content'] }}\n{% elif message['role'] == 'system' %}<> {{ message['content'] }} <>\n{% endif %}{% endfor %}""" +llama2_template = """{% for message in messages %}{% if message['role'] == 'user' %}[INST] {{ message['content'] }} [/INST]\n{% elif message['role'] == 'assistant' %}{{ message['content'] }}\n{% elif message['role'] == 'system' %}<> {{ message['content'] }} <>\n{% endif %}{% endfor %}""" class MetaSingleton(type): @@ -43,7 +42,7 @@ class ChatFormatterResponse: # Base Chat Formatter Protocol class ChatFormatterInterface(Protocol): - def __init__(self, template: Optional[Dict[str, Any]] = None): + def __init__(self, template: Optional[object] = None): ... def __call__( @@ -53,6 +52,10 @@ def __call__( ) -> ChatFormatterResponse: ... + @property + def template(self) -> str: + ... 
+ class AutoChatFormatter(ChatFormatterInterface): def __init__( @@ -91,7 +94,7 @@ class FormatterNotFoundException(Exception): pass -class ChatFormatterFactory(metaclass=MetaSingleton): +class ChatFormatterFactory(Singleton): _chat_formatters: Dict[str, Callable[[], ChatFormatterInterface]] = {} def register_formatter( From 72b7e1fc054a1f4bc0114b7a1adaa7fff50350da Mon Sep 17 00:00:00 2001 From: teleprint-me <77757836+teleprint-me@users.noreply.github.com> Date: Sun, 26 Nov 2023 18:43:28 -0500 Subject: [PATCH 3/6] Add outline for Jinja2 templating integration documentation Signed-off-by: teleprint-me <77757836+teleprint-me@users.noreply.github.com> --- docs/templates.md | 52 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) create mode 100644 docs/templates.md diff --git a/docs/templates.md b/docs/templates.md new file mode 100644 index 0000000000..5acdaa196b --- /dev/null +++ b/docs/templates.md @@ -0,0 +1,52 @@ +# Templates + +This document provides a comprehensive guide to the integration of Jinja2 templating into the `llama-cpp-python` project, with a focus on enhancing the chat functionality of the `llama-2` model. + +## Introduction + +- Brief explanation of the `llama-cpp-python` project's need for a templating system. +- Overview of the `llama-2` model's interaction with templating. + +## Jinja2 Dependency Integration + +- Rationale for choosing Jinja2 as the templating engine. + - Compatibility with Hugging Face's `transformers`. + - Desire for advanced templating features and simplicity. +- Detailed steps for adding `jinja2` to `pyproject.toml` for dependency management. + +## Template Management Refactor + +- Summary of the refactor and the motivation behind it. +- Description of the new chat handler selection logic: + 1. Preference for a user-specified `chat_handler`. + 2. Fallback to a user-specified `chat_format`. + 3. Defaulting to a chat format from a `.gguf` file if available. + 4. 
Utilizing the `llama2` default chat format as the final fallback. +- Ensuring backward compatibility throughout the refactor. + +## Implementation Details + +- In-depth look at the new `AutoChatFormatter` class. +- Example code snippets showing how to utilize the Jinja2 environment and templates. +- Guidance on how to provide custom templates or use defaults. + +## Testing and Validation + +- Outline of the testing strategy to ensure seamless integration. +- Steps for validating backward compatibility with existing implementations. + +## Benefits and Impact + +- Analysis of the expected benefits, including consistency, performance gains, and improved developer experience. +- Discussion of the potential impact on current users and contributors. + +## Future Work + +- Exploration of how templating can evolve within the project. +- Consideration of additional features or optimizations for the templating engine. +- Mechanisms for community feedback on the templating system. + +## Conclusion + +- Final thoughts on the integration of Jinja2 templating. +- Call to action for community involvement and feedback. 
From a42042a12b8a4294ea11067052a69150ce270704 Mon Sep 17 00:00:00 2001 From: teleprint-me <77757836+teleprint-me@users.noreply.github.com> Date: Sun, 26 Nov 2023 18:43:50 -0500 Subject: [PATCH 4/6] Add jinja2 as a dependency with version range for Hugging Face transformers compatibility Signed-off-by: teleprint-me <77757836+teleprint-me@users.noreply.github.com> --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 6c10225819..3045f2aa6f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ dependencies = [ "typing-extensions>=4.5.0", "numpy>=1.20.0", "diskcache>=5.6.1", + "jinja2>=2.11.3,<3.0.0", ] requires-python = ">=3.8" classifiers = [ @@ -71,4 +72,3 @@ Changelog = "https://llama-cpp-python.readthedocs.io/en/latest/changelog/" [tool.pytest.ini_options] addopts = "--ignore=vendor" - From d03eb8458ea87da5a03a06cd369bae0a68deb1ec Mon Sep 17 00:00:00 2001 From: teleprint-me <77757836+teleprint-me@users.noreply.github.com> Date: Sun, 26 Nov 2023 19:22:57 -0500 Subject: [PATCH 5/6] Update jinja2 version constraint for mkdocs-material compatibility Signed-off-by: teleprint-me <77757836+teleprint-me@users.noreply.github.com> --- pyproject.toml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 3045f2aa6f..9c2dfcceed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,11 +11,13 @@ license = { text = "MIT" } authors = [ { name = "Andrei Betlen", email = "abetlen@gmail.com" }, ] +# mkdocs-material requires "jinja2~=3.0" +# transformers requires "jinja2>=2.11.3" dependencies = [ "typing-extensions>=4.5.0", "numpy>=1.20.0", "diskcache>=5.6.1", - "jinja2>=2.11.3,<3.0.0", + "jinja2>=2.11.3", ] requires-python = ">=3.8" classifiers = [ From e5d18ce18fa39bc2ae15f396dc7716ac86485166 Mon Sep 17 00:00:00 2001 From: teleprint-me <77757836+teleprint-me@users.noreply.github.com> Date: Wed, 29 Nov 2023 23:19:35 -0500 Subject: [PATCH 6/6] Fix
attribute name in AutoChatFormatter - Changed attribute name from `self._renderer` to `self._environment` --- llama_cpp/llama_jinja_format.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llama_cpp/llama_jinja_format.py b/llama_cpp/llama_jinja_format.py index d3c10008ec..68faaf6ca5 100644 --- a/llama_cpp/llama_jinja_format.py +++ b/llama_cpp/llama_jinja_format.py @@ -68,7 +68,7 @@ def __init__( else: self._template = llama2_template # default template - self._renderer = jinja2.Environment( + self._environment = jinja2.Environment( loader=jinja2.BaseLoader(), trim_blocks=True, lstrip_blocks=True, @@ -82,7 +82,7 @@ def __call__( messages: List[Dict[str, str]], **kwargs: Any, ) -> ChatFormatterResponse: - formatted_sequence = self._renderer.render(messages=messages, **kwargs) + formatted_sequence = self._environment.render(messages=messages, **kwargs) return ChatFormatterResponse(prompt=formatted_sequence) @property