11import argparse
22import logging
33import os
4+ from math import ceil
45from typing import Optional , Dict , Any , Union , List
56
67import requests
3031from pydantic import BaseModel , Field
3132
3233from shyhurricane .doc_type_model_map import ModelConfig
34+ from shyhurricane .utils import process_cpu_count
3335
3436logger = logging .getLogger (__name__ )
3537
@@ -235,6 +237,14 @@ def from_env():
235237 )
236238 return generator_config
237239
def ollama_url(self) -> str:
    """Return the base HTTP URL of the Ollama server.

    Uses ``self.ollama_host`` when it is set (truthy); otherwise falls
    back to the module-level ``OLLAMA_HOST_DEFAULT``. The result is the
    host prefixed with ``http://`` and carries no trailing path.
    """
    host = self.ollama_host or OLLAMA_HOST_DEFAULT
    return f"http://{host}"
242+
def ollama_pull(self, model_id: str):
    """Ask the Ollama server to pull ``model_id``.

    Posts to ``<ollama_url>/api/pull``. The model identifier is sent
    unchanged, including any ``:tag`` suffix, so names without a tag
    (e.g. ``"llama3"``) are accepted and an explicit tag is honoured
    rather than being split off into a separate field.

    Args:
        model_id: Ollama model name, optionally with a ``:tag`` suffix.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: If the server cannot be reached.
    """
    response = requests.post(
        f"{self.ollama_url()}/api/pull",
        # The pull endpoint identifies the model by a single "model" field
        # that carries the tag. "stream": False requests one JSON reply
        # instead of a progress stream, so a failed pull is expected to be
        # reported through the HTTP status checked below.
        json={"model": model_id, "stream": False},
    )
    response.raise_for_status()
247+
238248 def apply_reasoning_default (self ):
239249 self .ollama_host = self .ollama_host or OLLAMA_HOST_DEFAULT
240250 if self .ollama_model or self .gemini_model or self .openai_model or self .bedrock_model :
@@ -325,8 +335,9 @@ def create_chat_generator(self,
325335 # https://huggingface.co/docs/inference-providers/guides/gpt-oss
326336 _generation_kwargs ["effort" ] = "high"
327337 logger .info ("Using Ollama chat with model %s at %s" , self .ollama_model , self .ollama_host )
338+ self .ollama_pull (self .ollama_model )
328339 return OllamaChatGenerator (
329- url = "http://" + ( self .ollama_host or OLLAMA_HOST_DEFAULT ),
340+ url = self .ollama_url ( ),
330341 model = self .ollama_model ,
331342 timeout = ollama_timeout ,
332343 generation_kwargs = _generation_kwargs | (generation_kwargs or {}),
@@ -374,8 +385,9 @@ def create_generator(self,
374385 "temperature" : temperature or self .temperature ,
375386 }
376387 logger .info ("Using Ollama generator with model %s at %s" , self .ollama_model , self .ollama_host )
388+ self .ollama_pull (self .ollama_model )
377389 return OllamaGenerator (
378- url = "http://" + ( self .ollama_host or OLLAMA_HOST_DEFAULT ),
390+ url = self .ollama_url ( ),
379391 model = self .ollama_model ,
380392 generation_kwargs = _generation_kwargs | (generation_kwargs or {}),
381393 )
@@ -397,7 +409,7 @@ def _embedder_enable_ollama(self) -> bool:
397409 # v0.12.11, v0.13.0 - macos has use after free failures
398410 # v0.14.0 - macos embedding is working
399411 try :
400- resp_version = requests .get ("http://" + ( self .ollama_host or OLLAMA_HOST_DEFAULT ) + "/api/version" )
412+ resp_version = requests .get (self .ollama_url ( ) + "/api/version" )
401413 resp_version .raise_for_status ()
402414 version = float ("." .join (resp_version .json ()["version" ].split ("." )[0 :2 ]))
403415 return version >= 0.14
@@ -466,9 +478,10 @@ def create_document_embedder(self, model_config: ModelConfig):
466478 )
467479 elif self .ollama_model and self ._embedder_enable_ollama ():
468480 logger .info ("Using Ollama document embedder with model %s at %s" , model_path , self .ollama_host )
481+ self .ollama_pull (model_path )
469482 return OllamaDocumentEmbedder (
470483 model = model_path ,
471- url = "http://" + ( self .ollama_host or OLLAMA_HOST_DEFAULT ),
484+ url = self .ollama_url ( ),
472485 progress_bar = False ,
473486 )
474487
@@ -501,9 +514,10 @@ def create_text_embedder(self, model_config: ModelConfig):
501514 )
502515 elif self .ollama_model and self ._embedder_enable_ollama ():
503516 logger .info ("Using Ollama text embedder with model %s at %s" , model_path , self .ollama_host )
517+ self .ollama_pull (model_path )
504518 return OllamaTextEmbedder (
505519 model = model_path ,
506- url = "http://" + ( self .ollama_host or OLLAMA_HOST_DEFAULT ),
520+ url = self .ollama_url ( ),
507521 )
508522
509523 logger .info ("Using local text embedder with model %s" , model_path )
@@ -528,14 +542,18 @@ def create_sparse_document_embedder(self, model_config: ModelConfig):
528542 return FastembedSparseDocumentEmbedder (
529543 model = model_config .model_name ,
530544 cache_dir = self ._fastembed_cache_dir (),
531- batch_size = 1 ,
545+ threads = max (1 , ceil (process_cpu_count () / 2 )),
546+ batch_size = 32 ,
547+ parallel = 0 ,
532548 progress_bar = False ,
533549 )
534550
def create_sparse_text_embedder(self, model_config: ModelConfig):
    """Build a ``FastembedSparseTextEmbedder`` for ``model_config``.

    The embedder is configured with the model name taken from
    ``model_config.model_name``, the cache directory returned by
    ``self._fastembed_cache_dir()``, a thread count of half the value
    reported by ``process_cpu_count()`` (rounded up, never below 1),
    ``parallel=0`` and the progress bar disabled.
    """
    # Bind the arguments in the same order the constructor call evaluated them.
    model_name = model_config.model_name
    cache_dir = self._fastembed_cache_dir()
    half_cpus = ceil(process_cpu_count() / 2)
    thread_count = max(1, half_cpus)

    return FastembedSparseTextEmbedder(
        model=model_name,
        cache_dir=cache_dir,
        threads=thread_count,
        parallel=0,
        progress_bar=False,
    )
541559
0 commit comments