Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 18 additions & 13 deletions benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,13 +144,18 @@ def average_stats(responses: List[OllamaResponse]):
inference_stats(res)


# NOTE(review): this looks like the pre-change (removed) side of the diff; a
# function with the same name is defined again right below, taking
# `models_to_run` instead of `skip_models` -- confirm before relying on it.
# Collects the names of the models reported by `ollama.list()` and drops every
# name that appears in `skip_models`; an empty `skip_models` filters nothing.
def get_benchmark_models(skip_models: List[str] = []) -> List[str]:
# A response without a "models" key is treated as an empty model list.
models = ollama.list().get("models", [])
model_names = [model["name"] for model in models]
# Only filter when the caller actually asked to skip something.
if len(skip_models) > 0:
model_names = [
model for model in model_names if model not in skip_models
]
def get_benchmark_models(models_to_run: List[str] = []) -> List[str]:
    """Return the names of the models to benchmark.

    Asks ``ollama.list()`` for its models and selects from them. With an
    empty ``models_to_run`` every listed model is selected. Otherwise only
    the requested names that are listed are kept, in the order they were
    requested, and a warning naming the ones that were not found is printed.

    Args:
        models_to_run: Model names requested by the caller. The default
            list is only read here, never mutated.

    Returns:
        The selected model names. They are also printed before returning.
    """
    # A response without a "models" key is treated as an empty model list.
    all_models = ollama.list().get("models", [])
    all_model_names = [model["name"] for model in all_models]

    if not models_to_run:
        model_names = all_model_names
    else:
        # Build the lookup set once instead of scanning the list per request.
        available = set(all_model_names)
        model_names = [model for model in models_to_run if model in available]
        # dict.fromkeys de-duplicates while keeping request order, so the
        # warning text is stable from run to run (joining a set is not).
        missing_models = list(
            dict.fromkeys(
                model for model in models_to_run if model not in available
            )
        )
        if missing_models:
            print(f"Warning: The following models were not found: {', '.join(missing_models)}")

    print(f"Evaluating models: {model_names}\n")
    return model_names

Expand All @@ -167,11 +172,11 @@ def main():
default=False,
)
parser.add_argument(
"-s",
"--skip-models",
"-m",
"--models",
nargs="*",
default=[],
help="List of model names to skip. Separate multiple models with spaces.",
help="List of model names to benchmark. If not specified, all available models will be used.",
)
parser.add_argument(
"-p",
Expand All @@ -187,13 +192,13 @@ def main():
args = parser.parse_args()

verbose = args.verbose
skip_models = args.skip_models
models = args.models
prompts = args.prompts
print(
f"\nVerbose: {verbose}\nSkip models: {skip_models}\nPrompts: {prompts}"
f"\nVerbose: {verbose}\nModels: {models}\nPrompts: {prompts}"
)

model_names = get_benchmark_models(skip_models)
model_names = get_benchmark_models(args.models)
benchmarks = {}

for model_name in model_names:
Expand Down