Skip to content

Commit f1d4624

Browse files
authored
Merge pull request #603 from better629/feat_ollama
feat: add ollama api support
2 parents 25eeb6c + 40d3cc5 commit f1d4624

File tree

10 files changed

+286
-38
lines changed

10 files changed

+286
-38
lines changed

config/config.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,10 @@ RPM: 10
4848
#FIREWORKS_API_BASE: "https://api.fireworks.ai/inference/v1"
4949
#FIREWORKS_API_MODEL: "YOUR_LLM_MODEL" # example, accounts/fireworks/models/llama-v2-13b-chat
5050

51+
#### if using a self-hosted open LLM model served by ollama
52+
# OLLAMA_API_BASE: http://127.0.0.1:11434/api
53+
# OLLAMA_API_MODEL: llama2
54+
5155
#### for Search
5256

5357
## Supported values: serpapi/google/serper/ddg

metagpt/config.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ class LLMProviderEnum(Enum):
4242
FIREWORKS = "fireworks"
4343
OPEN_LLM = "open_llm"
4444
GEMINI = "gemini"
45+
OLLAMA = "ollama"
4546

4647

4748
class Config(metaclass=Singleton):
@@ -78,7 +79,8 @@ def get_default_llm_provider_enum(self) -> LLMProviderEnum:
7879
(self.zhipuai_api_key, LLMProviderEnum.ZHIPUAI),
7980
(self.fireworks_api_key, LLMProviderEnum.FIREWORKS),
8081
(self.open_llm_api_base, LLMProviderEnum.OPEN_LLM),
81-
(self.gemini_api_key, LLMProviderEnum.GEMINI), # reuse logic. but not a key
82+
(self.gemini_api_key, LLMProviderEnum.GEMINI),
83+
(self.ollama_api_base, LLMProviderEnum.OLLAMA), # reuse logic. but not a key
8284
]:
8385
if self._is_valid_llm_key(k):
8486
# logger.debug(f"Use LLMProvider: {v.value}")
@@ -103,6 +105,8 @@ def _update(self):
103105
self.open_llm_api_model = self._get("OPEN_LLM_API_MODEL")
104106
self.fireworks_api_key = self._get("FIREWORKS_API_KEY")
105107
self.gemini_api_key = self._get("GEMINI_API_KEY")
108+
self.ollama_api_base = self._get("OLLAMA_API_BASE")
109+
self.ollama_api_model = self._get("OLLAMA_API_MODEL")
106110
_ = self.get_default_llm_provider_enum()
107111

108112
self.openai_base_url = self._get("OPENAI_BASE_URL")

metagpt/const.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,3 +102,5 @@ def get_metagpt_root():
102102
CODE_SUMMARIES_PDF_FILE_REPO = "resources/code_summaries"
103103

104104
YAPI_URL = "http://yapi.deepwisdomai.com/"
105+
106+
LLM_API_TIMEOUT = 300

metagpt/provider/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@
88

99
from metagpt.provider.fireworks_api import FireWorksGPTAPI
1010
from metagpt.provider.google_gemini_api import GeminiGPTAPI
11+
from metagpt.provider.ollama_api import OllamaGPTAPI
1112
from metagpt.provider.open_llm_api import OpenLLMGPTAPI
1213
from metagpt.provider.openai_api import OpenAIGPTAPI
1314
from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI
1415

15-
__all__ = ["FireWorksGPTAPI", "GeminiGPTAPI", "OpenLLMGPTAPI", "OpenAIGPTAPI", "ZhiPuAIGPTAPI"]
16+
__all__ = ["FireWorksGPTAPI", "GeminiGPTAPI", "OpenLLMGPTAPI", "OpenAIGPTAPI", "ZhiPuAIGPTAPI", "OllamaGPTAPI"]

metagpt/provider/general_api_base.py

Lines changed: 34 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
# @Desc : refs to openai 0.x sdk
4+
15
import asyncio
26
import json
37
import os
@@ -43,8 +47,8 @@
4347
# Has one attribute per thread, 'session'.
4448
_thread_context = threading.local()
4549

46-
OPENAI_LOG = os.environ.get("OPENAI_LOG")
47-
OPENAI_LOG = "debug"
50+
LLM_LOG = os.environ.get("LLM_LOG")
51+
LLM_LOG = "debug"
4852

4953

5054
class ApiType(Enum):
@@ -74,8 +78,8 @@ def from_str(label):
7478

7579

7680
def _console_log_level():
77-
if OPENAI_LOG in ["debug", "info"]:
78-
return OPENAI_LOG
81+
if LLM_LOG in ["debug", "info"]:
82+
return LLM_LOG
7983
else:
8084
return None
8185

@@ -140,7 +144,7 @@ def operation_location(self) -> Optional[str]:
140144

141145
@property
142146
def organization(self) -> Optional[str]:
143-
return self._headers.get("OpenAI-Organization")
147+
return self._headers.get("LLM-Organization")
144148

145149
@property
146150
def response_ms(self) -> Optional[int]:
@@ -478,7 +482,7 @@ def handle_error_response(self, rbody, rcode, resp, rheaders, stream_error=False
478482
error_data["message"] += "\n\n" + error_data["internal_message"]
479483

480484
log_info(
481-
"OpenAI API error received",
485+
"LLM API error received",
482486
error_code=error_data.get("code"),
483487
error_type=error_data.get("type"),
484488
error_message=error_data.get("message"),
@@ -516,7 +520,7 @@ def handle_error_response(self, rbody, rcode, resp, rheaders, stream_error=False
516520
)
517521

518522
def request_headers(self, method: str, extra, request_id: Optional[str]) -> Dict[str, str]:
519-
user_agent = "OpenAI/v1 PythonBindings/%s" % (version.VERSION,)
523+
user_agent = "LLM/v1 PythonBindings/%s" % (version.VERSION,)
520524

521525
uname_without_node = " ".join(v for k, v in platform.uname()._asdict().items() if k != "node")
522526
ua = {
@@ -530,17 +534,17 @@ def request_headers(self, method: str, extra, request_id: Optional[str]) -> Dict
530534
}
531535

532536
headers = {
533-
"X-OpenAI-Client-User-Agent": json.dumps(ua),
537+
"X-LLM-Client-User-Agent": json.dumps(ua),
534538
"User-Agent": user_agent,
535539
}
536540

537541
headers.update(api_key_to_header(self.api_type, self.api_key))
538542

539543
if self.organization:
540-
headers["OpenAI-Organization"] = self.organization
544+
headers["LLM-Organization"] = self.organization
541545

542546
if self.api_version is not None and self.api_type == ApiType.OPEN_AI:
543-
headers["OpenAI-Version"] = self.api_version
547+
headers["LLM-Version"] = self.api_version
544548
if request_id is not None:
545549
headers["X-Request-Id"] = request_id
546550
headers.update(extra)
@@ -592,15 +596,14 @@ def _prepare_request_raw(
592596
headers["Content-Type"] = "application/json"
593597
else:
594598
raise openai.APIConnectionError(
595-
"Unrecognized HTTP method %r. This may indicate a bug in the "
596-
"OpenAI bindings. Please contact us through our help center at help.openai.com for "
597-
"assistance." % (method,)
599+
message=f"Unrecognized HTTP method {method}. This may indicate a bug in the LLM bindings.",
600+
request=None,
598601
)
599602

600603
headers = self.request_headers(method, headers, request_id)
601604

602-
log_debug("Request to OpenAI API", method=method, path=abs_url)
603-
log_debug("Post details", data=data, api_version=self.api_version)
605+
# log_debug("Request to LLM API", method=method, path=abs_url)
606+
# log_debug("Post details", data=data, api_version=self.api_version)
604607

605608
return abs_url, headers, data
606609

@@ -639,14 +642,14 @@ def request_raw(
639642
except requests.exceptions.Timeout as e:
640643
raise openai.APITimeoutError("Request timed out: {}".format(e)) from e
641644
except requests.exceptions.RequestException as e:
642-
raise openai.APIConnectionError("Error communicating with OpenAI: {}".format(e)) from e
643-
log_debug(
644-
"OpenAI API response",
645-
path=abs_url,
646-
response_code=result.status_code,
647-
processing_ms=result.headers.get("OpenAI-Processing-Ms"),
648-
request_id=result.headers.get("X-Request-Id"),
649-
)
645+
raise openai.APIConnectionError(message="Error communicating with LLM: {}".format(e), request=None) from e
646+
# log_debug(
647+
# "LLM API response",
648+
# path=abs_url,
649+
# response_code=result.status_code,
650+
# processing_ms=result.headers.get("LLM-Processing-Ms"),
651+
# request_id=result.headers.get("X-Request-Id"),
652+
# )
650653
return result
651654

652655
async def arequest_raw(
@@ -685,18 +688,18 @@ async def arequest_raw(
685688
}
686689
try:
687690
result = await session.request(**request_kwargs)
688-
log_info(
689-
"OpenAI API response",
690-
path=abs_url,
691-
response_code=result.status,
692-
processing_ms=result.headers.get("OpenAI-Processing-Ms"),
693-
request_id=result.headers.get("X-Request-Id"),
694-
)
691+
# log_info(
692+
# "LLM API response",
693+
# path=abs_url,
694+
# response_code=result.status,
695+
# processing_ms=result.headers.get("LLM-Processing-Ms"),
696+
# request_id=result.headers.get("X-Request-Id"),
697+
# )
695698
return result
696699
except (aiohttp.ServerTimeoutError, asyncio.TimeoutError) as e:
697700
raise openai.APITimeoutError("Request timed out") from e
698701
except aiohttp.ClientError as e:
699-
raise openai.APIConnectionError("Error communicating with OpenAI") from e
702+
raise openai.APIConnectionError(message="Error communicating with LLM", request=None) from e
700703

701704
def _interpret_response(
702705
self, result: requests.Response, stream: bool

metagpt/provider/general_api_requestor.py

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,38 @@
33
# @Desc : General Async API for http-based LLM model
44

55
import asyncio
6-
from typing import AsyncGenerator, Tuple, Union
6+
from typing import AsyncGenerator, Generator, Iterator, Tuple, Union
77

88
import aiohttp
9+
import requests
910

1011
from metagpt.logs import logger
1112
from metagpt.provider.general_api_base import APIRequestor
1213

1314

15+
def parse_stream_helper(line: bytes) -> Union[bytes, None]:
16+
if line and line.startswith(b"data:"):
17+
if line.startswith(b"data: "):
18+
# SSE event may be valid when it contains whitespace
19+
line = line[len(b"data: ") :]
20+
else:
21+
line = line[len(b"data:") :]
22+
if line.strip() == b"[DONE]":
23+
# return here will cause GeneratorExit exception in urllib3
24+
# and it will close http connection with TCP Reset
25+
return None
26+
else:
27+
return line
28+
return None
29+
30+
31+
def parse_stream(rbody: Iterator[bytes]) -> Iterator[bytes]:
32+
for line in rbody:
33+
_line = parse_stream_helper(line)
34+
if _line is not None:
35+
yield _line
36+
37+
1438
class GeneralAPIRequestor(APIRequestor):
1539
"""
1640
usage
@@ -26,16 +50,40 @@ class GeneralAPIRequestor(APIRequestor):
2650
)
2751
"""
2852

29-
def _interpret_response_line(self, rbody: str, rcode: int, rheaders, stream: bool) -> str:
53+
def _interpret_response_line(self, rbody: bytes, rcode: int, rheaders, stream: bool) -> bytes:
3054
# just do nothing to meet the APIRequestor process and return the raw data
3155
# because the openai sdk would convert the data into OpenAIResponse, which we don't need in general cases.
3256

3357
return rbody
3458

59+
def _interpret_response(
60+
self, result: requests.Response, stream: bool
61+
) -> Tuple[Union[bytes, Iterator[Generator]], bytes]:
62+
"""Returns the response(s) and a bool indicating whether it is a stream."""
63+
if stream and "text/event-stream" in result.headers.get("Content-Type", ""):
64+
return (
65+
self._interpret_response_line(line, result.status_code, result.headers, stream=True)
66+
for line in parse_stream(result.iter_lines())
67+
), True
68+
else:
69+
return (
70+
self._interpret_response_line(
71+
result.content,  # let the caller decode the msg
72+
result.status_code,
73+
result.headers,
74+
stream=False,
75+
),
76+
False,
77+
)
78+
3579
async def _interpret_async_response(
3680
self, result: aiohttp.ClientResponse, stream: bool
37-
) -> Tuple[Union[str, AsyncGenerator[str, None]], bool]:
38-
if stream and "text/event-stream" in result.headers.get("Content-Type", ""):
81+
) -> Tuple[Union[bytes, AsyncGenerator[bytes, None]], bool]:
82+
if stream and (
83+
"text/event-stream" in result.headers.get("Content-Type", "")
84+
or "application/x-ndjson" in result.headers.get("Content-Type", "")
85+
):
86+
# the `Content-Type` of ollama stream resp is "application/x-ndjson"
3987
return (
4088
self._interpret_response_line(line, result.status, result.headers, stream=True)
4189
async for line in result.content

0 commit comments

Comments
 (0)