3 changes: 3 additions & 0 deletions .gitignore
@@ -158,3 +158,6 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# OS
.DS_Store
1 change: 1 addition & 0 deletions .python-version
@@ -0,0 +1 @@
3.11.8
54 changes: 38 additions & 16 deletions benchmark.py
@@ -1,5 +1,5 @@
import argparse
from typing import List
from typing import List, Optional

import ollama
from pydantic import (
@@ -14,19 +14,21 @@
class Message(BaseModel):
    role: str
    content: str
    images: Optional[List[str]] = None
    tool_calls: Optional[List[dict]] = None


class OllamaResponse(BaseModel):
    model: str
    created_at: datetime
    created_at: datetime = None
    message: Message
    done: bool
    total_duration: int
    done: bool = True
    total_duration: int = None
    load_duration: int = 0
    prompt_eval_count: int = Field(-1, validate_default=True)
    prompt_eval_duration: int
    eval_count: int
    eval_duration: int
    prompt_eval_count: int = Field(0, validate_default=True)
    prompt_eval_duration: int = 0
    eval_count: int = 0
    eval_duration: int = 0

    @field_validator("prompt_eval_count")
    @classmethod
@@ -35,14 +37,32 @@ def validate_prompt_eval_count(cls, value: int) -> int:
            print(
                "\nWarning: prompt token count was not provided, potentially due to prompt caching. For more info, see https://github.com/ollama/ollama/issues/2068\n"
            )
            return 0  # Set default value
            return 0
        return value

    @classmethod
    def from_chat_response(cls, response):
        return cls(
            model=response.model,
            message=Message(
                role=response.message.role,
                content=response.message.content,
                images=getattr(response.message, 'images', None),
                tool_calls=getattr(response.message, 'tool_calls', None)
            ),
            done=True,
            total_duration=getattr(response, 'total_duration', 0),
            load_duration=getattr(response, 'load_duration', 0),
            prompt_eval_count=getattr(response, 'prompt_eval_count', 0),
            prompt_eval_duration=getattr(response, 'prompt_eval_duration', 0),
            eval_count=getattr(response, 'eval_count', 0),
            eval_duration=getattr(response, 'eval_duration', 0)
        )


def run_benchmark(
    model_name: str, prompt: str, verbose: bool
) -> OllamaResponse:

    last_element = None

    if verbose:
@@ -57,7 +77,7 @@ def run_benchmark(
            stream=True,
        )
        for chunk in stream:
            print(chunk["message"]["content"], end="", flush=True)
            print(chunk.message.content, end="", flush=True)
            last_element = chunk
    else:
        last_element = ollama.chat(
@@ -74,10 +94,7 @@ def run_benchmark(
        print("System Error: No response received from ollama")
        return None

    # with open("data/ollama/ollama_res.json", "w") as outfile:
    # outfile.write(json.dumps(last_element, indent=4))

    return OllamaResponse.model_validate(last_element)
    return OllamaResponse.from_chat_response(last_element)


def nanosec_to_sec(nanosec):
@@ -145,8 +162,11 @@ def average_stats(responses: List[OllamaResponse]):


def get_benchmark_models(skip_models: List[str] = []) -> List[str]:
    print(ollama.list())
    x = ollama.list()
    print(x.get("models", []))
    models = ollama.list().get("models", [])
    model_names = [model["name"] for model in models]
    model_names = [model["model"] for model in models]
    if len(skip_models) > 0:
        model_names = [
            model for model in model_names if model not in skip_models
@@ -202,9 +222,11 @@ def main():
            if verbose:
                print(f"\n\nBenchmarking: {model_name}\nPrompt: {prompt}")
            response = run_benchmark(model_name, prompt, verbose=verbose)
            #print(f"----------{response}----------")
            responses.append(response)

            if verbose:

                print(f"Response: {response.message.content}")
                inference_stats(response)
        benchmarks[model_name] = responses
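For context, not part of the diff: a minimal sketch of how the new `OllamaResponse.from_chat_response` path would be exercised, assuming the ollama 0.4.x Python client (which returns typed response objects read via attribute access rather than dict keys). The model name and prompt are illustrative placeholders, not taken from this PR.

```python
import ollama

# ollama 0.4.x returns a typed ChatResponse, so fields are read as attributes
# (chat_response.message.content) instead of chunk["message"]["content"].
chat_response = ollama.chat(
    model="llama3",  # placeholder model name; use one pulled locally
    messages=[{"role": "user", "content": "Why is the sky blue?"}],
)

# Convert the client object into the benchmark's pydantic model,
# as run_benchmark() now does via from_chat_response().
parsed = OllamaResponse.from_chat_response(chat_response)
print(parsed.message.content)
print(parsed.eval_count, parsed.eval_duration)
```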
14 changes: 12 additions & 2 deletions requirements.txt
@@ -1,2 +1,12 @@
ollama
pydantic
annotated-types==0.7.0
anyio==4.7.0
certifi==2024.12.14
h11==0.14.0
httpcore==1.0.7
httpx==0.27.2
idna==3.10
ollama==0.4.4
pydantic==2.10.3
pydantic-core==2.27.1
sniffio==1.3.1
typing-extensions==4.12.2
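A quick sanity check of the pinned client, assuming ollama==0.4.4 as listed above: `list()` returns typed entries whose name lives under the "model" key, which is what the `model["name"]` to `model["model"]` change in benchmark.py relies on. This sketch only mirrors the calls already used in the diff.

```python
import ollama

# With ollama==0.4.4 the response objects are pydantic models that also
# support dict-style access, so .get() and ["model"] both work here.
listing = ollama.list()
for entry in listing.get("models", []):
    print(entry["model"])  # e.g. "llama3:latest" (illustrative output)
```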