Skip to content

Commit 87acdef

Browse files
feat: tpm-rpm limit in prometheus metrics (BerriAI#19725)
Co-authored-by: Krish Dholakia <krrishdholakia@gmail.com>
1 parent 79603b9 commit 87acdef

File tree

4 files changed

+102
-21
lines changed

4 files changed

+102
-21
lines changed

litellm/integrations/prometheus.py

Lines changed: 65 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -316,6 +316,18 @@ def __init__( # noqa: PLR0915
316316
labelnames=self.get_labels_for_metric("litellm_deployment_state"),
317317
)
318318

319+
self.litellm_deployment_tpm_limit = self._gauge_factory(
320+
"litellm_deployment_tpm_limit",
321+
"Deployment TPM limit found in config",
322+
labelnames=self.get_labels_for_metric("litellm_deployment_tpm_limit"),
323+
)
324+
325+
self.litellm_deployment_rpm_limit = self._gauge_factory(
326+
"litellm_deployment_rpm_limit",
327+
"Deployment RPM limit found in config",
328+
labelnames=self.get_labels_for_metric("litellm_deployment_rpm_limit"),
329+
)
330+
319331
self.litellm_deployment_cooled_down = self._counter_factory(
320332
"litellm_deployment_cooled_down",
321333
"LLM Deployment Analytics - Number of times a deployment has been cooled down by LiteLLM load balancing logic. exception_status is the status of the exception that caused the deployment to be cooled down",
@@ -1778,6 +1790,49 @@ def set_llm_deployment_failure_metrics(self, request_kwargs: dict):
17781790
)
17791791
)
17801792

1793+
def _set_deployment_tpm_rpm_limit_metrics(
1794+
self,
1795+
model_info: dict,
1796+
litellm_params: dict,
1797+
litellm_model_name: Optional[str],
1798+
model_id: Optional[str],
1799+
api_base: Optional[str],
1800+
llm_provider: Optional[str],
1801+
):
1802+
"""
1803+
Set the deployment TPM and RPM limits metrics
1804+
"""
1805+
tpm = model_info.get("tpm") or litellm_params.get("tpm")
1806+
rpm = model_info.get("rpm") or litellm_params.get("rpm")
1807+
1808+
if tpm is not None:
1809+
_labels = prometheus_label_factory(
1810+
supported_enum_labels=self.get_labels_for_metric(
1811+
metric_name="litellm_deployment_tpm_limit"
1812+
),
1813+
enum_values=UserAPIKeyLabelValues(
1814+
litellm_model_name=litellm_model_name,
1815+
model_id=model_id,
1816+
api_base=api_base,
1817+
api_provider=llm_provider,
1818+
),
1819+
)
1820+
self.litellm_deployment_tpm_limit.labels(**_labels).set(tpm)
1821+
1822+
if rpm is not None:
1823+
_labels = prometheus_label_factory(
1824+
supported_enum_labels=self.get_labels_for_metric(
1825+
metric_name="litellm_deployment_rpm_limit"
1826+
),
1827+
enum_values=UserAPIKeyLabelValues(
1828+
litellm_model_name=litellm_model_name,
1829+
model_id=model_id,
1830+
api_base=api_base,
1831+
api_provider=llm_provider,
1832+
),
1833+
)
1834+
self.litellm_deployment_rpm_limit.labels(**_labels).set(rpm)
1835+
17811836
def set_llm_deployment_success_metrics(
17821837
self,
17831838
request_kwargs: dict,
@@ -1811,6 +1866,16 @@ def set_llm_deployment_success_metrics(
18111866
_model_info = _metadata.get("model_info") or {}
18121867
model_id = _model_info.get("id", None)
18131868

1869+
if _model_info or _litellm_params:
1870+
self._set_deployment_tpm_rpm_limit_metrics(
1871+
model_info=_model_info,
1872+
litellm_params=_litellm_params,
1873+
litellm_model_name=litellm_model_name,
1874+
model_id=model_id,
1875+
api_base=api_base,
1876+
llm_provider=llm_provider,
1877+
)
1878+
18141879
remaining_requests: Optional[int] = None
18151880
remaining_tokens: Optional[int] = None
18161881
if additional_headers := standard_logging_payload["hidden_params"][

litellm/litellm_core_utils/get_litellm_params.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -93,8 +93,11 @@ def get_litellm_params(
9393
"text_completion": text_completion,
9494
"azure_ad_token_provider": azure_ad_token_provider,
9595
"user_continue_message": user_continue_message,
96-
"base_model": base_model or (
97-
_get_base_model_from_litellm_call_metadata(metadata=metadata) if metadata else None
96+
"base_model": base_model
97+
or (
98+
_get_base_model_from_litellm_call_metadata(metadata=metadata)
99+
if metadata
100+
else None
98101
),
99102
"litellm_trace_id": litellm_trace_id,
100103
"litellm_session_id": litellm_session_id,
@@ -139,5 +142,7 @@ def get_litellm_params(
139142
"aws_sts_endpoint": kwargs.get("aws_sts_endpoint"),
140143
"aws_external_id": kwargs.get("aws_external_id"),
141144
"aws_bedrock_runtime_endpoint": kwargs.get("aws_bedrock_runtime_endpoint"),
145+
"tpm": kwargs.get("tpm"),
146+
"rpm": kwargs.get("rpm"),
142147
}
143148
return litellm_params

litellm/main.py

Lines changed: 19 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -148,7 +148,7 @@
148148
validate_and_fix_openai_messages,
149149
validate_and_fix_openai_tools,
150150
validate_chat_completion_tool_choice,
151-
validate_openai_optional_params
151+
validate_openai_optional_params,
152152
)
153153

154154
from ._logging import verbose_logger
@@ -368,7 +368,7 @@ async def create(self, messages, model=None, **kwargs):
368368

369369
@tracer.wrap()
370370
@client
371-
async def acompletion( # noqa: PLR0915
371+
async def acompletion( # noqa: PLR0915
372372
model: str,
373373
# Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
374374
messages: List = [],
@@ -603,12 +603,11 @@ async def acompletion( # noqa: PLR0915
603603
if timeout is not None and isinstance(timeout, (int, float)):
604604
timeout_value = float(timeout)
605605
init_response = await asyncio.wait_for(
606-
loop.run_in_executor(None, func_with_context),
607-
timeout=timeout_value
606+
loop.run_in_executor(None, func_with_context), timeout=timeout_value
608607
)
609608
else:
610609
init_response = await loop.run_in_executor(None, func_with_context)
611-
610+
612611
if isinstance(init_response, dict) or isinstance(
613612
init_response, ModelResponse
614613
): ## CACHING SCENARIO
@@ -640,6 +639,7 @@ async def acompletion( # noqa: PLR0915
640639
except asyncio.TimeoutError:
641640
custom_llm_provider = custom_llm_provider or "openai"
642641
from litellm.exceptions import Timeout
642+
643643
raise Timeout(
644644
message=f"Request timed out after {timeout} seconds",
645645
model=model,
@@ -1118,7 +1118,6 @@ def completion( # type: ignore # noqa: PLR0915
11181118
# validate optional params
11191119
stop = validate_openai_optional_params(stop=stop)
11201120

1121-
11221121
######### unpacking kwargs #####################
11231122
args = locals()
11241123

@@ -1135,7 +1134,9 @@ def completion( # type: ignore # noqa: PLR0915
11351134
# Check if MCP tools are present (following responses pattern)
11361135
# Cast tools to Optional[Iterable[ToolParam]] for type checking
11371136
tools_for_mcp = cast(Optional[Iterable[ToolParam]], tools)
1138-
if LiteLLM_Proxy_MCP_Handler._should_use_litellm_mcp_gateway(tools=tools_for_mcp):
1137+
if LiteLLM_Proxy_MCP_Handler._should_use_litellm_mcp_gateway(
1138+
tools=tools_for_mcp
1139+
):
11391140
# Return coroutine - acompletion will await it
11401141
# completion() can return a coroutine when MCP tools are present, which acompletion() awaits
11411142
return acompletion_with_mcp( # type: ignore[return-value]
@@ -1536,6 +1537,8 @@ def completion( # type: ignore # noqa: PLR0915
15361537
max_retries=max_retries,
15371538
timeout=timeout,
15381539
litellm_request_debug=kwargs.get("litellm_request_debug", False),
1540+
tpm=kwargs.get("tpm"),
1541+
rpm=kwargs.get("rpm"),
15391542
)
15401543
cast(LiteLLMLoggingObj, logging).update_environment_variables(
15411544
model=model,
@@ -2361,11 +2364,7 @@ def completion( # type: ignore # noqa: PLR0915
23612364
input=messages, api_key=api_key, original_response=response
23622365
)
23632366
elif custom_llm_provider == "minimax":
2364-
api_key = (
2365-
api_key
2366-
or get_secret_str("MINIMAX_API_KEY")
2367-
or litellm.api_key
2368-
)
2367+
api_key = api_key or get_secret_str("MINIMAX_API_KEY") or litellm.api_key
23692368

23702369
api_base = (
23712370
api_base
@@ -2413,7 +2412,9 @@ def completion( # type: ignore # noqa: PLR0915
24132412
or custom_llm_provider == "wandb"
24142413
or custom_llm_provider == "clarifai"
24152414
or custom_llm_provider in litellm.openai_compatible_providers
2416-
or JSONProviderRegistry.exists(custom_llm_provider) # JSON-configured providers
2415+
or JSONProviderRegistry.exists(
2416+
custom_llm_provider
2417+
) # JSON-configured providers
24172418
or "ft:gpt-3.5-turbo" in model # finetune gpt-3.5-turbo
24182419
): # allow user to make an openai call with a custom base
24192420
# note: if a user sets a custom base - we should ensure this works
@@ -4724,7 +4725,7 @@ def embedding( # noqa: PLR0915
47244725

47254726
if headers is not None and headers != {}:
47264727
optional_params["extra_headers"] = headers
4727-
4728+
47284729
if encoding_format is not None:
47294730
optional_params["encoding_format"] = encoding_format
47304731
else:
@@ -6759,9 +6760,7 @@ def speech( # noqa: PLR0915
67596760
if text_to_speech_provider_config is None:
67606761
text_to_speech_provider_config = MinimaxTextToSpeechConfig()
67616762

6762-
minimax_config = cast(
6763-
MinimaxTextToSpeechConfig, text_to_speech_provider_config
6764-
)
6763+
minimax_config = cast(MinimaxTextToSpeechConfig, text_to_speech_provider_config)
67656764

67666765
if api_base is not None:
67676766
litellm_params_dict["api_base"] = api_base
@@ -6901,7 +6900,7 @@ async def ahealth_check(
69016900
custom_llm_provider_from_params = model_params.get("custom_llm_provider", None)
69026901
api_base_from_params = model_params.get("api_base", None)
69036902
api_key_from_params = model_params.get("api_key", None)
6904-
6903+
69056904
model, custom_llm_provider, _, _ = get_llm_provider(
69066905
model=model,
69076906
custom_llm_provider=custom_llm_provider_from_params,
@@ -7275,8 +7274,9 @@ def __getattr__(name: str) -> Any:
72757274
_encoding = tiktoken.get_encoding("cl100k_base")
72767275
# Cache it in the module's __dict__ for subsequent accesses
72777276
import sys
7277+
72787278
sys.modules[__name__].__dict__["encoding"] = _encoding
72797279
global _encoding_cache
72807280
_encoding_cache = _encoding
72817281
return _encoding
7282-
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
7282+
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

litellm/types/integrations/prometheus.py

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -199,6 +199,8 @@ class UserAPIKeyLabelNames(Enum):
199199
"litellm_cache_hits_metric",
200200
"litellm_cache_misses_metric",
201201
"litellm_cached_tokens_metric",
202+
"litellm_deployment_tpm_limit",
203+
"litellm_deployment_rpm_limit",
202204
"litellm_remaining_api_key_requests_for_model",
203205
"litellm_remaining_api_key_tokens_for_model",
204206
"litellm_llm_api_failed_requests_metric",
@@ -406,6 +408,15 @@ class PrometheusMetricLabels:
406408
UserAPIKeyLabelNames.API_PROVIDER.value,
407409
]
408410

411+
litellm_deployment_tpm_limit = [
412+
UserAPIKeyLabelNames.v2_LITELLM_MODEL_NAME.value,
413+
UserAPIKeyLabelNames.MODEL_ID.value,
414+
UserAPIKeyLabelNames.API_BASE.value,
415+
UserAPIKeyLabelNames.API_PROVIDER.value,
416+
]
417+
418+
litellm_deployment_rpm_limit = litellm_deployment_tpm_limit
419+
409420
litellm_deployment_cooled_down = [
410421
UserAPIKeyLabelNames.v2_LITELLM_MODEL_NAME.value,
411422
UserAPIKeyLabelNames.MODEL_ID.value,

0 commit comments

Comments (0)