|
| 1 | +# |
| 2 | +# Licensed to the Apache Software Foundation (ASF) under one or more |
| 3 | +# contributor license agreements. See the NOTICE file distributed with |
| 4 | +# this work for additional information regarding copyright ownership. |
| 5 | +# The ASF licenses this file to You under the Apache License, Version 2.0 |
| 6 | +# (the "License"); you may not use this file except in compliance with |
| 7 | +# the License. You may obtain a copy of the License at |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, software |
| 12 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | +# See the License for the specific language governing permissions and |
| 15 | +# limitations under the License. |
| 16 | +# |
| 17 | + |
| 18 | +"""A ModelHandler for Anthropic Claude models using the Messages API. |
| 19 | +
|
| 20 | +This module provides an integration between Apache Beam's RunInference |
| 21 | +transform and Anthropic's Claude models, enabling batch inference in |
| 22 | +Beam pipelines. |
| 23 | +
|
| 24 | +Example usage:: |
| 25 | +
|
| 26 | + from apache_beam.ml.inference.anthropic_inference import ( |
| 27 | + AnthropicModelHandler, |
| 28 | + message_from_string, |
| 29 | + ) |
| 30 | + from apache_beam.ml.inference.base import RunInference |
| 31 | +
|
| 32 | + # Basic usage |
| 33 | + model_handler = AnthropicModelHandler( |
| 34 | + model_name='claude-haiku-4-5', |
| 35 | + api_key='your-api-key', |
| 36 | + request_fn=message_from_string, |
| 37 | + ) |
| 38 | +
|
| 39 | + # With system prompt and structured output |
| 40 | + model_handler = AnthropicModelHandler( |
| 41 | + model_name='claude-haiku-4-5', |
| 42 | + api_key='your-api-key', |
| 43 | + request_fn=message_from_string, |
| 44 | + system='You are a helpful assistant that responds concisely.', |
| 45 | + output_config={ |
| 46 | + 'format': { |
| 47 | + 'type': 'json_schema', |
| 48 | + 'schema': { |
| 49 | + 'type': 'object', |
| 50 | + 'properties': { |
| 51 | + 'answer': {'type': 'string'}, |
| 52 | + 'confidence': {'type': 'number'}, |
| 53 | + }, |
| 54 | + 'required': ['answer', 'confidence'], |
| 55 | + 'additionalProperties': False, |
| 56 | + }, |
| 57 | + }, |
| 58 | + }, |
| 59 | + ) |
| 60 | +
|
| 61 | + with beam.Pipeline() as p: |
| 62 | + results = ( |
| 63 | + p |
| 64 | + | beam.Create(['What is Apache Beam?', 'Explain MapReduce.']) |
| 65 | + | RunInference(model_handler) |
| 66 | + ) |
| 67 | +""" |
| 68 | + |
| 69 | +import logging |
| 70 | +from collections.abc import Callable |
| 71 | +from collections.abc import Iterable |
| 72 | +from collections.abc import Sequence |
| 73 | +from typing import Any |
| 74 | +from typing import Optional |
| 75 | +from typing import Union |
| 76 | + |
| 77 | +from anthropic import Anthropic |
| 78 | +from anthropic import APIStatusError |
| 79 | + |
| 80 | +from apache_beam.ml.inference import utils |
| 81 | +from apache_beam.ml.inference.base import PredictionResult |
| 82 | +from apache_beam.ml.inference.base import RemoteModelHandler |
| 83 | + |
# Module-scoped logger; deliberately named after the handler (not __name__)
# so log lines are grouped under "AnthropicModelHandler".
LOGGER = logging.getLogger("AnthropicModelHandler")
| 85 | + |
| 86 | + |
def _retry_on_appropriate_error(exception: Exception) -> bool:
  """Decide whether a failed Anthropic request should be retried.

  Only HTTP 429 (rate limiting) and 5xx (server-side) responses are treated
  as transient; any other failure is surfaced immediately.

  Args:
    exception: the exception raised during the request/response loop.

  Returns:
    True if the exception represents a retriable API error (429 or 5xx),
    False otherwise.
  """
  if isinstance(exception, APIStatusError):
    status = exception.status_code
    return status == 429 or status >= 500
  return False
| 99 | + |
| 100 | + |
def message_from_string(
    model_name: str,
    batch: Sequence[str],
    client: 'Anthropic',
    inference_args: dict[str, Any]) -> list[Any]:
  """Request function that sends string prompts to Claude as user messages.

  Each string in the batch is sent as a separate request. The results are
  returned as a list of response objects.

  Args:
    model_name: the Claude model to use (e.g. 'claude-haiku-4-5').
    batch: the string inputs to send to Claude.
    client: the Anthropic client instance.
    inference_args: additional arguments passed to the messages.create call.
      Common args include 'max_tokens', 'system', 'temperature', 'top_p'.
      The mapping is not mutated.

  Returns:
    A list of response objects, one per input string, in input order.
  """
  # Work on a copy: popping from the caller's dict would silently drop
  # 'max_tokens' for every subsequent batch that reuses the same mapping.
  args = dict(inference_args)
  max_tokens = args.pop('max_tokens', 1024)
  responses = []
  for prompt in batch:
    response = client.messages.create(
        model=model_name,
        max_tokens=max_tokens,
        messages=[{
            "role": "user", "content": prompt
        }],
        **args)
    responses.append(response)
  return responses
| 130 | + |
| 131 | + |
def message_from_conversation(
    model_name: str,
    batch: Sequence[list[dict[str, str]]],
    client: 'Anthropic',
    inference_args: dict[str, Any]) -> list[Any]:
  """Request function that sends multi-turn conversations to Claude.

  Each element in the batch is a list of message dicts with 'role' and
  'content' keys, representing a multi-turn conversation. Each conversation
  is sent as a separate request.

  Args:
    model_name: the Claude model to use.
    batch: a sequence of conversations (each a list of message dicts).
    client: the Anthropic client instance.
    inference_args: additional arguments passed to the messages.create call.
      The mapping is not mutated.

  Returns:
    A list of response objects, one per conversation, in input order.
  """
  # Work on a copy: popping from the caller's dict would silently drop
  # 'max_tokens' for every subsequent batch that reuses the same mapping.
  args = dict(inference_args)
  max_tokens = args.pop('max_tokens', 1024)
  responses = []
  for messages in batch:
    response = client.messages.create(
        model=model_name,
        max_tokens=max_tokens,
        messages=messages,
        **args)
    responses.append(response)
  return responses
| 158 | + |
| 159 | + |
class AnthropicModelHandler(RemoteModelHandler[Any, PredictionResult,
                                               Anthropic]):
  def __init__(
      self,
      model_name: str,
      request_fn: Callable[[str, Sequence[Any], Anthropic, dict[str, Any]],
                           Any],
      api_key: Optional[str] = None,
      *,
      system: Optional[Union[str, list[dict[str, str]]]] = None,
      output_config: Optional[dict[str, Any]] = None,
      min_batch_size: Optional[int] = None,
      max_batch_size: Optional[int] = None,
      max_batch_duration_secs: Optional[int] = None,
      max_batch_weight: Optional[int] = None,
      element_size_fn: Optional[Callable[[Any], int]] = None,
      **kwargs):
    """Implementation of the ModelHandler interface for Anthropic Claude.

    **NOTE:** This API and its implementation are under development and
    do not provide backward compatibility guarantees.

    This handler connects to the Anthropic Messages API to run inference
    using Claude models. It supports text generation from string prompts
    or multi-turn conversations, with optional system prompts and
    structured output schemas.

    Args:
      model_name: the Claude model to send requests to (e.g.
        'claude-sonnet-4-6', 'claude-haiku-4-5').
      request_fn: the function to use to send requests. Should take the
        model name, batch, client, and inference_args and return the
        responses from Claude. Built-in options are message_from_string
        and message_from_conversation.
      api_key: the Anthropic API key. If not provided, the client will
        look for the ANTHROPIC_API_KEY environment variable.
      system: optional system prompt to set the model's behavior for all
        requests. Can be a string or a list of content blocks (dicts
        with 'type' and 'text' keys). This is applied to every request
        in the pipeline. Per-request overrides can be passed via
        inference_args.
      output_config: optional output configuration to constrain
        responses to a structured schema. The value is passed directly
        to the Anthropic API as the 'output_config' parameter. This
        uses the GA API shape with a nested 'format' key. Example::

          output_config={
              'format': {
                  'type': 'json_schema',
                  'schema': {
                      'type': 'object',
                      'properties': {
                          'answer': {'type': 'string'},
                      },
                      'required': ['answer'],
                      'additionalProperties': False,
                  },
              },
          }

      min_batch_size: optional. the minimum batch size to use when
        batching inputs.
      max_batch_size: optional. the maximum batch size to use when
        batching inputs.
      max_batch_duration_secs: optional. the maximum amount of time to
        buffer a batch before emitting; used in streaming contexts.
      max_batch_weight: optional. the maximum total weight of a batch.
      element_size_fn: optional. a function that returns the size
        (weight) of an element.
    """
    # Only forward batching knobs that were explicitly set, so
    # RunInference's own defaults apply otherwise.
    self._batching_kwargs = {}
    self._env_vars = kwargs.get('env_vars', {})
    if min_batch_size is not None:
      self._batching_kwargs["min_batch_size"] = min_batch_size
    if max_batch_size is not None:
      self._batching_kwargs["max_batch_size"] = max_batch_size
    if max_batch_duration_secs is not None:
      self._batching_kwargs["max_batch_duration_secs"] = max_batch_duration_secs
    if max_batch_weight is not None:
      self._batching_kwargs["max_batch_weight"] = max_batch_weight
    if element_size_fn is not None:
      self._batching_kwargs['element_size_fn'] = element_size_fn

    self.model_name = model_name
    self.request_fn = request_fn
    self.api_key = api_key
    self.system = system
    self.output_config = output_config

    super().__init__(
        namespace='AnthropicModelHandler',
        retry_filter=_retry_on_appropriate_error,
        **kwargs)

  def batch_elements_kwargs(self):
    """Returns the kwargs used to configure input batching."""
    return self._batching_kwargs

  def create_client(self) -> Anthropic:
    """Creates the Anthropic client used to send requests.

    If api_key was provided at construction time, it is used directly.
    Otherwise, the client will fall back to the ANTHROPIC_API_KEY
    environment variable.
    """
    if self.api_key:
      return Anthropic(api_key=self.api_key)
    return Anthropic()

  def request(
      self,
      batch: Sequence[Any],
      model: Anthropic,
      inference_args: Optional[dict[str, Any]] = None
  ) -> Iterable[PredictionResult]:
    """Sends a prediction request to the Anthropic API.

    Handler-level system and output_config are injected into a copy of
    inference_args before calling the request function. Per-request
    values in inference_args take precedence over handler-level values.
    The caller's inference_args mapping is never mutated.

    Args:
      batch: a sequence of inputs to be passed to the request function.
      model: an Anthropic client instance.
      inference_args: additional arguments to send as part of the
        prediction request (e.g. max_tokens, temperature, system).

    Returns:
      An iterable of PredictionResults.
    """
    # Copy before injecting defaults: mutating the caller's dict would
    # persist handler-level values across calls and expose the shared
    # mapping to request_fns that pop() from it.
    inference_args = dict(inference_args) if inference_args else {}
    if self.system is not None:
      inference_args.setdefault('system', self.system)
    if self.output_config is not None:
      inference_args.setdefault('output_config', self.output_config)
    responses = self.request_fn(self.model_name, batch, model, inference_args)
    return utils._convert_to_result(batch, responses, self.model_name)