|
| 1 | +import base64 |
| 2 | +import json |
| 3 | +import mimetypes |
| 4 | +import time |
| 5 | + |
| 6 | +import aiohttp |
| 7 | + |
| 8 | +# mistral_models = [ |
| 9 | +# "voxtral-mini-latest", |
| 10 | +# "voxtral-small-latest", |
| 11 | +# "codestral-latest", |
| 12 | +# "devstral-latest", |
| 13 | +# "devstral-medium-latest", |
| 14 | +# "devstral-small-latest", |
| 15 | +# "mistral-tiny-latest", |
| 16 | +# "mistral-small-latest", |
| 17 | +# "mistral-medium-latest", |
| 18 | +# "mistral-large-latest", |
| 19 | +# "magistral-small-latest", |
| 20 | +# "magistral-medium-latest", |
| 21 | +# "ministral-3b-latest", |
| 22 | +# "ministral-8b-latest", |
| 23 | +# "ministral-14b-latest", |
| 24 | +# "mistral-moderation-latest", |
| 25 | +# "mistral-ocr-latest", |
| 26 | +# "pixtral-large-latest", |
| 27 | +# "mistral-vibe-cli-latest", |
| 28 | +# ] |
| 29 | + |
| 30 | + |
| 31 | +def install_mistral(ctx): |
| 32 | + from llms.main import GeneratorBase, OpenAiCompatible |
| 33 | + |
| 34 | + async def get_models(request): |
| 35 | + mistral = ctx.get_provider("mistral") |
| 36 | + url = mistral.api + "/models" |
| 37 | + async with aiohttp.ClientSession() as session, session.get( |
| 38 | + url, headers=mistral.headers, timeout=ctx.get_client_timeout() |
| 39 | + ) as response: |
| 40 | + return aiohttp.web.json_response(await response.json()) |
| 41 | + |
| 42 | + ctx.add_get("mistral/models", get_models) |
| 43 | + |
| 44 | + # https://docs.mistral.ai/api/endpoint/audio/transcriptions |
| 45 | + class MistralTranscriptionGenerator(GeneratorBase): |
| 46 | + sdk = "mistral/transcriptions" |
| 47 | + api_url = "https://api.mistral.ai/v1/audio/transcriptions" |
| 48 | + |
| 49 | + def __init__(self, **kwargs): |
| 50 | + super().__init__(**kwargs) |
| 51 | + |
| 52 | + async def chat(self, chat, provider=None, context=None): |
| 53 | + headers = self.get_headers(provider, chat) |
| 54 | + # Remove Content-Type to allow aiohttp to set it for FormData |
| 55 | + if "Content-Type" in headers: |
| 56 | + del headers["Content-Type"] |
| 57 | + |
| 58 | + # Ensure x-api-key is present if Authorization is available |
| 59 | + if "Authorization" in headers and "x-api-key" not in headers: |
| 60 | + token = headers["Authorization"].replace("Bearer ", "") |
| 61 | + headers["x-api-key"] = token |
| 62 | + |
| 63 | + model = provider.provider_model(chat["model"]) or chat["model"] or "voxtral-mini-latest" |
| 64 | + # Replace internal alias with actual model name |
| 65 | + if model == "voxtral-mini-transcription": |
| 66 | + model = "voxtral-mini-latest" |
| 67 | + |
| 68 | + # Process chat to handle inputs (downloads URLs, reads files, converts to base64) |
| 69 | + chat = await self.process_chat(chat, provider_id=self.id) |
| 70 | + |
| 71 | + # Find audio data |
| 72 | + audio_data = None |
| 73 | + filename = "audio.mp3" |
| 74 | + |
| 75 | + # Search for input_audio or file in the messages |
| 76 | + for message in reversed(chat["messages"]): |
| 77 | + content = message.get("content") |
| 78 | + if isinstance(content, list): |
| 79 | + for item in content: |
| 80 | + if item.get("type") == "input_audio": |
| 81 | + audio_data = item["input_audio"]["data"] |
| 82 | + fmt = item["input_audio"].get("format", "mp3") |
| 83 | + filename = f"audio.{fmt}" |
| 84 | + break |
| 85 | + # Support 'file' type if it appears to be audio |
| 86 | + elif item.get("type") == "file": |
| 87 | + file_data = item["file"].get("file_data") |
| 88 | + fn = item["file"].get("filename", "") |
| 89 | + if fn: |
| 90 | + ext = fn.split(".")[-1] |
| 91 | + if ext.lower() in ["mp3", "wav", "ogg", "flac", "m4a"]: |
| 92 | + audio_data = file_data |
| 93 | + filename = fn |
| 94 | + break |
| 95 | + if audio_data: |
| 96 | + break |
| 97 | + |
| 98 | + if not audio_data: |
| 99 | + raise Exception( |
| 100 | + "No audio file found in the request. Please provide an audio file via --audio or as an attachment." |
| 101 | + ) |
| 102 | + |
| 103 | + # Decode base64 data |
| 104 | + if audio_data.startswith("data:"): |
| 105 | + # Handle data URI scheme: data:audio/mp3;base64,... |
| 106 | + audio_data = audio_data.split(";base64,")[1] |
| 107 | + |
| 108 | + try: |
| 109 | + file_bytes = base64.b64decode(audio_data) |
| 110 | + except Exception as e: |
| 111 | + raise Exception(f"Failed to decode audio data: {e}") from e |
| 112 | + |
| 113 | + # Prepare FormData |
| 114 | + data = aiohttp.FormData() |
| 115 | + data.add_field("model", model) |
| 116 | + data.add_field( |
| 117 | + "file", file_bytes, filename=filename, content_type=mimetypes.guess_type(filename)[0] or "audio/mpeg" |
| 118 | + ) |
| 119 | + |
| 120 | + ctx.log(f"POST {self.api_url} model={model} file={filename} ({len(file_bytes)} bytes)") |
| 121 | + |
| 122 | + async with aiohttp.ClientSession() as session, session.post( |
| 123 | + self.api_url, headers=headers, data=data |
| 124 | + ) as response: |
| 125 | + text = await response.text() |
| 126 | + if response.status != 200: |
| 127 | + raise Exception(f"Mistral API Error {response.status}: {text}") |
| 128 | + |
| 129 | + context["providerResponse"] = text |
| 130 | + |
| 131 | + try: |
| 132 | + result = json.loads(text) |
| 133 | + except Exception: |
| 134 | + result = {"text": text} # Fallback |
| 135 | + |
| 136 | + transcription = result.get("text", "") |
| 137 | + |
| 138 | + ret = { |
| 139 | + "choices": [ |
| 140 | + { |
| 141 | + "message": { |
| 142 | + "role": "assistant", |
| 143 | + "content": transcription, |
| 144 | + } |
| 145 | + } |
| 146 | + ], |
| 147 | + "created": result.get("created", int(time.time())), |
| 148 | + } |
| 149 | + |
| 150 | + if "model" in result: |
| 151 | + ret["model"] = result["model"] |
| 152 | + |
| 153 | + if "usage" in result: |
| 154 | + ret["usage"] = result["usage"] |
| 155 | + |
| 156 | + return ret |
| 157 | + |
| 158 | + class MistralProvider(OpenAiCompatible): |
| 159 | + sdk = "@ai-sdk/mistral" |
| 160 | + |
| 161 | + def __init__(self, **kwargs): |
| 162 | + if "api" not in kwargs: |
| 163 | + kwargs["api"] = "https://api.mistral.ai/v1" |
| 164 | + super().__init__(**kwargs) |
| 165 | + self.transcription = MistralTranscriptionGenerator(**kwargs) |
| 166 | + |
| 167 | + async def process_chat(self, chat, provider_id=None): |
| 168 | + ret = await super().process_chat(chat, provider_id) |
| 169 | + messages = chat.get("messages", []).copy() |
| 170 | + for message in messages: |
| 171 | + message.pop("timestamp", None) # mistral doesn't support extra fields |
| 172 | + ret["messages"] = messages |
| 173 | + return ret |
| 174 | + |
| 175 | + async def chat(self, chat, context=None): |
| 176 | + model = self.provider_model(chat["model"]) or chat["model"] |
| 177 | + model_info = self.model_info(model) |
| 178 | + model_modalities = model_info.get("modalities", {}) |
| 179 | + input_modalities = model_modalities.get("input", []) |
| 180 | + # if only audio modality, use transcription |
| 181 | + if "audio" in input_modalities and len(input_modalities) == 1: |
| 182 | + return await self.transcription.chat(chat, provider=self, context=context) |
| 183 | + return await super().chat(chat, context=context) |
| 184 | + |
| 185 | + ctx.add_provider(MistralTranscriptionGenerator) |
| 186 | + ctx.add_provider(MistralProvider) |
0 commit comments