Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,8 @@ dist
.vscode
.pytest_cache
python
.env.test
.env.test
test.py
test2.py
main.py
lamoom_venv/
24 changes: 12 additions & 12 deletions docs/getting_started_notebook.ipynb

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions lamoom/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,5 @@
from lamoom.responses import AIResponse
from lamoom.ai_models.openai.responses import OpenAIResponse
from lamoom.ai_models.behaviour import AIModelsBehaviour, PromptAttempts

from lamoom.validators import JSONValidator, XMLValidator, YAMLValidator
3 changes: 3 additions & 0 deletions lamoom/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,6 @@ class NotParsedResponseException(LamoomError):

class APITokenNotProvided(LamoomError):
pass

class ValidatorException(LamoomError):
pass
104 changes: 99 additions & 5 deletions lamoom/prompt/lamoom.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from decimal import Decimal
import requests
import time
from lamoom.settings import LAMOOM_API_URI
import json
from lamoom.settings import LAMOOM_API_URI, PROMPT_VALIDATORS
from lamoom import Secrets, settings
from lamoom.ai_models.ai_model import AI_MODELS_PROVIDER
from lamoom.ai_models.attempt_to_call import AttemptToCall
Expand All @@ -16,7 +17,8 @@

from lamoom.exceptions import (
LamoomPromptIsnotFoundError,
RetryableCustomError
RetryableCustomError,
ValidatorException
)
from lamoom.services.SaveWorker import SaveWorker
from lamoom.prompt.prompt import Prompt
Expand All @@ -25,8 +27,7 @@
from lamoom.responses import AIResponse
from lamoom.services.lamoom import LamoomService
from lamoom.utils import current_timestamp_ms
import json

from lamoom.validators import Validator
logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -220,7 +221,7 @@ def init_behavior(self, model: str, provider_url: str = None) -> AIModelsBehavio
fallback_attempts=fallback_attempts
)

def call(
def call_llm(
self,
prompt_id: str,
context: t.Dict[str, str],
Expand Down Expand Up @@ -310,6 +311,99 @@ def call(
"Prompt call failed, no attempts worked"
)
raise Exception


def call(
self,
prompt_id: str,
context: t.Dict[str, str],
model: str,
provider_url: str = None,
params: t.Dict[str, t.Any] = {},
version: str = None,
count_of_retries: int = 5,
test_data: dict = {},
stream_function: t.Callable = None,
check_connection: t.Callable = None,
stream_params: dict = {},
) -> AIResponse:

max_attempts = 1
validators = PROMPT_VALIDATORS.get(prompt_id)
if validators is None:
validators = []
else:
validators = validators.values()
for validator in validators:
max_attempts += min(sum(map(int, validator.retry_rules.values())), validator.retry)

total_results: t.List[AIResponse] = []
total_errors: t.List[dict] = []

for iteration in range(max_attempts):
result = None
try:
result = self.call_llm(
prompt_id=prompt_id,
context=context,
model=model,
provider_url=provider_url,
params=params,
version=version,
count_of_retries=count_of_retries,
test_data=test_data,
stream_function=stream_function,
check_connection=check_connection,
stream_params=stream_params
)
except Exception as e:
logger.error(f"Attempt {iteration + 1} failed with error: {str(e)}")
if result is None:
result = AIResponse()
result.errors = [{
"iteration": iteration,
"error": str(e)
}]
break

validation_failed = False
can_retry = False

validation_errors = []
for validator in validators:
validator.validate(result)
if validator.has_errors():
validation_failed = True
for error in validator.get_errors():
validation_errors.append({
"id": validator.id,
"iteration": iteration,
"error": validator.format_error(error)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we need to add here validator's id/name

})
if validator.can_retry():
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like that you made that logic isolated

can_retry = True
else:
total_errors.extend(validation_errors)
break

result.errors = validation_errors if validation_errors else None
total_results.append(result)

if validation_failed:
if can_retry and iteration < max_attempts - 1:
logger.info(f"Validation errors occurred, retrying (attempt {iteration + 1}/{max_attempts})")
continue
else:
error_messages = [e["error"] for e in validation_errors[-len(validators):]]
logger.error(f"Validation failed: {', '.join(error_messages)}")
raise ValidatorException()
else:
total_results[-1].attemps = total_results[:-1]
return total_results[-1]

logger.error("All attempts failed")
raise Exception("All attempts failed. Errors: " + ", ".join([e["error"] for e in total_errors]))


def get_prompt(self, prompt_id: str, version: str = None) -> Prompt:
"""
Expand Down
10 changes: 1 addition & 9 deletions lamoom/response_parsers/response_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,11 @@
import yaml

from lamoom.exceptions import NotParsedResponseException
from lamoom.responses import AIResponse
from lamoom.responses import AIResponse, Tag

logger = logging.getLogger(__name__)


@dataclass
class Tag:
start_tag: str
end_tag: str
include_tag: bool
is_right_find_end_ind: bool = False


@dataclass
class TaggedContent:
content: str
Expand Down
69 changes: 68 additions & 1 deletion lamoom/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
import json
import logging
from dataclasses import dataclass, field
from functools import cached_property
import typing as t

from lamoom.exceptions import LamoomError
logger = logging.getLogger(__name__)


Expand All @@ -24,6 +25,13 @@ class Metrics:
ai_model_details: dict = None
latency: int = None

@dataclass
class Tag:
start_tag: str
end_tag: str
include_tag: bool
is_right_find_end_ind: bool = False


@dataclass(kw_only=True)
class AIResponse:
Expand All @@ -34,10 +42,69 @@ class AIResponse:
prompt: Prompt = field(default_factory=Prompt)
metrics: Metrics = field(default_factory=Metrics)
id: str = ""
errors: t.Optional[t.List[LamoomError]] = None
attemps: t.Optional[t.List] = None

@property
def response(self) -> str:
return self._response

def get_message_str(self) -> str:
return json.loads(self.response)

@cached_property
def json_list(self) -> t.List[t.Dict]:

tags = [Tag("```json", "\n```", 0), Tag("```json", "```", 0)]
return self._get_tagged_content_list(tags)

@cached_property
def xml_list(self) -> t.List[t.Dict]:

tags = [Tag("```xml", "\n```", 0), Tag("```xml", "```", 0)]
return self._get_tagged_content_list(tags)

@cached_property
def yaml_list(self) -> t.List[t.Dict]:

tags = [Tag("```yaml", "\n```", 0), Tag("```yaml", "```", 0)]
return self._get_tagged_content_list(tags)

def _get_tagged_content_list(self, tags: Tag) -> t.List[t.Dict]:

content_list = []
start_from = 0
while True:
response_tagged, start_from, end_ind = self._get_format_from_response(tags, start_from)
if response_tagged:
content_list.append({
"content": response_tagged,
"start_ind": start_from,
"end_ind": end_ind
})
start_from = end_ind
else:
break
return content_list

def _get_format_from_response(self, tags: list, start_from: int = 0):

start_ind, end_ind = 0, -1
content = self.response[start_from:]
for t in tags:
start_ind = content.find(t.start_tag)
if t.is_right_find_end_ind:
end_ind = content.rfind(t.end_tag, start_ind + len(t.start_tag))
else:
end_ind = content.find(t.end_tag, start_ind + len(t.start_tag))
if start_ind != -1:
try:
if t.include_tag:
end_ind += len(t.end_tag)
else:
start_ind += len(t.start_tag)
response_tagged = content[start_ind:end_ind].strip()
return response_tagged, start_from + start_ind, start_from + end_ind
except Exception as e:
logger.exception(f"Couldn't parse json:\n{content}")
return None, 0, -1
9 changes: 7 additions & 2 deletions lamoom/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
os.environ.get("LAMOOM_RECEIVE_PROMPT_FROM_SERVER", True)
)
PIPE_PROMPTS = {}
PROMPT_VALIDATORS = {}
FALLBACK_MODELS = []


Expand All @@ -48,7 +49,11 @@ class Secrets:
CUSTOM_API_KEY: str = field(default_factory=lambda: os.getenv("CUSTOM_API_KEY"))
OPENAI_ORG: str = field(default_factory=lambda: os.getenv("OPENAI_ORG"))
azure_keys: dict = field(
default_factory=lambda: json.loads(
os.getenv("azure_keys", os.getenv("AZURE_OPENAI_KEYS", os.getenv("AZURE_KEYS", "{}")))
default_factory = lambda: (
json.loads(
os.getenv("azure_keys", os.getenv("AZURE_OPENAI_KEYS", os.getenv("AZURE_KEYS", "{}")))
)
if os.getenv("azure_keys") or os.getenv("AZURE_OPENAI_KEYS") or os.getenv("AZURE_KEYS")
else {}
)
)
Loading