Skip to content

Commit 412d534

Browse files
committed
[ML] Add AnthropicModelHandler for Claude inference in Beam pipelines
Adds a new RemoteModelHandler for Anthropic's Claude models via the Messages API. Features include: - AnthropicModelHandler with retry logic for 429/5xx errors - message_from_string and message_from_conversation request functions - System prompt support (constructor-level and per-request overrides) - Structured JSON output via output_config (GA API, requires anthropic>=0.86.0) - Comprehensive unit tests and integration tests (Fizz counting rule) - Integration tests gated behind ANTHROPIC_API_KEY env var
1 parent 2eb71e9 commit 412d534

File tree

5 files changed

+881
-0
lines changed

5 files changed

+881
-0
lines changed
Lines changed: 296 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,296 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
"""A ModelHandler for Anthropic Claude models using the Messages API.

This module provides an integration between Apache Beam's RunInference
transform and Anthropic's Claude models, enabling batch inference in
Beam pipelines.

Example usage::

  from apache_beam.ml.inference.anthropic_inference import (
      AnthropicModelHandler,
      message_from_string,
  )
  from apache_beam.ml.inference.base import RunInference

  # Basic usage
  model_handler = AnthropicModelHandler(
      model_name='claude-haiku-4-5',
      api_key='your-api-key',
      request_fn=message_from_string,
  )

  # With system prompt and structured output
  model_handler = AnthropicModelHandler(
      model_name='claude-haiku-4-5',
      api_key='your-api-key',
      request_fn=message_from_string,
      system='You are a helpful assistant that responds concisely.',
      output_config={
          'format': {
              'type': 'json_schema',
              'schema': {
                  'type': 'object',
                  'properties': {
                      'answer': {'type': 'string'},
                      'confidence': {'type': 'number'},
                  },
                  'required': ['answer', 'confidence'],
                  'additionalProperties': False,
              },
          },
      },
  )

  with beam.Pipeline() as p:
    results = (
        p
        | beam.Create(['What is Apache Beam?', 'Explain MapReduce.'])
        | RunInference(model_handler)
    )
"""
68+
69+
import logging
70+
from collections.abc import Callable
71+
from collections.abc import Iterable
72+
from collections.abc import Sequence
73+
from typing import Any
74+
from typing import Optional
75+
from typing import Union
76+
77+
from anthropic import Anthropic
78+
from anthropic import APIStatusError
79+
80+
from apache_beam.ml.inference import utils
81+
from apache_beam.ml.inference.base import PredictionResult
82+
from apache_beam.ml.inference.base import RemoteModelHandler
83+
84+
# Module-level logger, named after the handler rather than __name__.
# NOTE(review): not referenced anywhere in this file as visible here —
# presumably reserved for future diagnostics; confirm before removing.
LOGGER = logging.getLogger("AnthropicModelHandler")
85+
86+
87+
def _retry_on_appropriate_error(exception: Exception) -> bool:
  """Decide whether a failed Anthropic request should be retried.

  Only transient API failures are worth retrying: rate limiting
  (HTTP 429) and server-side errors (any 5xx status). Everything
  else — including non-API exceptions — is treated as permanent.

  Args:
    exception: the exception encountered during the request/response loop.

  Returns:
    True if the exception is retriable (429 or 5xx), False otherwise.
  """
  return (
      isinstance(exception, APIStatusError) and
      (exception.status_code == 429 or exception.status_code >= 500))
99+
100+
101+
def message_from_string(
    model_name: str,
    batch: Sequence[str],
    client: "Anthropic",
    inference_args: dict[str, Any]):
  """Request function that sends string prompts to Claude as user messages.

  Each string in the batch is sent as a separate request. The results are
  returned as a list of response objects.

  Args:
    model_name: the Claude model to use (e.g. 'claude-haiku-4-5').
    batch: the string inputs to send to Claude.
    client: the Anthropic client instance.
    inference_args: additional arguments passed to the messages.create call.
      Common args include 'max_tokens', 'system', 'temperature', 'top_p'.
      The dict is not mutated.
  """
  # Work on a copy: popping from the caller's dict would silently drop a
  # user-supplied 'max_tokens' for every batch after the first one (and
  # on retries), because RunInference reuses the same inference_args.
  args = dict(inference_args)
  max_tokens = args.pop('max_tokens', 1024)
  responses = []
  for prompt in batch:
    response = client.messages.create(
        model=model_name,
        max_tokens=max_tokens,
        messages=[{
            "role": "user", "content": prompt
        }],
        **args)
    responses.append(response)
  return responses
130+
131+
132+
def message_from_conversation(
    model_name: str,
    batch: Sequence[list[dict[str, str]]],
    client: "Anthropic",
    inference_args: dict[str, Any]):
  """Request function that sends multi-turn conversations to Claude.

  Each element in the batch is a list of message dicts with 'role' and
  'content' keys, representing a multi-turn conversation.

  Args:
    model_name: the Claude model to use.
    batch: a sequence of conversations (each a list of message dicts).
    client: the Anthropic client instance.
    inference_args: additional arguments passed to the messages.create call.
      The dict is not mutated.
  """
  # Work on a copy: popping from the caller's dict would silently drop a
  # user-supplied 'max_tokens' for every batch after the first one (and
  # on retries), because RunInference reuses the same inference_args.
  args = dict(inference_args)
  max_tokens = args.pop('max_tokens', 1024)
  responses = []
  for messages in batch:
    response = client.messages.create(
        model=model_name,
        max_tokens=max_tokens,
        messages=messages,
        **args)
    responses.append(response)
  return responses
158+
159+
160+
class AnthropicModelHandler(RemoteModelHandler[Any, PredictionResult,
                                               Anthropic]):
  def __init__(
      self,
      model_name: str,
      request_fn: Callable[[str, Sequence[Any], Anthropic, dict[str, Any]],
                           Any],
      api_key: Optional[str] = None,
      *,
      system: Optional[Union[str, list[dict[str, str]]]] = None,
      output_config: Optional[dict[str, Any]] = None,
      min_batch_size: Optional[int] = None,
      max_batch_size: Optional[int] = None,
      max_batch_duration_secs: Optional[int] = None,
      max_batch_weight: Optional[int] = None,
      element_size_fn: Optional[Callable[[Any], int]] = None,
      **kwargs):
    """Implementation of the ModelHandler interface for Anthropic Claude.

    **NOTE:** This API and its implementation are under development and
    do not provide backward compatibility guarantees.

    This handler connects to the Anthropic Messages API to run inference
    using Claude models. It supports text generation from string prompts
    or multi-turn conversations, with optional system prompts and
    structured output schemas.

    Args:
      model_name: the Claude model to send requests to (e.g.
        'claude-sonnet-4-6', 'claude-haiku-4-5').
      request_fn: the function to use to send requests. Should take the
        model name, batch, client, and inference_args and return the
        responses from Claude. Built-in options are message_from_string
        and message_from_conversation.
      api_key: the Anthropic API key. If not provided, the client will
        look for the ANTHROPIC_API_KEY environment variable.
      system: optional system prompt to set the model's behavior for all
        requests. Can be a string or a list of content blocks (dicts
        with 'type' and 'text' keys). This is applied to every request
        in the pipeline. Per-request overrides can be passed via
        inference_args.
      output_config: optional output configuration to constrain
        responses to a structured schema. The value is passed directly
        to the Anthropic API as the 'output_config' parameter. This
        uses the GA API shape with a nested 'format' key. Example::

          output_config={
              'format': {
                  'type': 'json_schema',
                  'schema': {
                      'type': 'object',
                      'properties': {
                          'answer': {'type': 'string'},
                      },
                      'required': ['answer'],
                      'additionalProperties': False,
                  },
              },
          }

      min_batch_size: optional. the minimum batch size to use when
        batching inputs.
      max_batch_size: optional. the maximum batch size to use when
        batching inputs.
      max_batch_duration_secs: optional. the maximum amount of time to
        buffer a batch before emitting; used in streaming contexts.
      max_batch_weight: optional. the maximum total weight of a batch.
      element_size_fn: optional. a function that returns the size
        (weight) of an element.
    """
    # Only forward batching knobs that were explicitly set, so the base
    # handler's own defaults apply for everything left unspecified.
    self._batching_kwargs = {}
    self._env_vars = kwargs.get('env_vars', {})
    if min_batch_size is not None:
      self._batching_kwargs["min_batch_size"] = min_batch_size
    if max_batch_size is not None:
      self._batching_kwargs["max_batch_size"] = max_batch_size
    if max_batch_duration_secs is not None:
      self._batching_kwargs["max_batch_duration_secs"] = max_batch_duration_secs
    if max_batch_weight is not None:
      self._batching_kwargs["max_batch_weight"] = max_batch_weight
    if element_size_fn is not None:
      self._batching_kwargs['element_size_fn'] = element_size_fn

    self.model_name = model_name
    self.request_fn = request_fn
    self.api_key = api_key
    self.system = system
    self.output_config = output_config

    super().__init__(
        namespace='AnthropicModelHandler',
        retry_filter=_retry_on_appropriate_error,
        **kwargs)

  def batch_elements_kwargs(self):
    """Returns the kwargs controlling how RunInference batches inputs."""
    return self._batching_kwargs

  def create_client(self) -> Anthropic:
    """Creates the Anthropic client used to send requests.

    If api_key was provided at construction time, it is used directly.
    Otherwise, the client will fall back to the ANTHROPIC_API_KEY
    environment variable.
    """
    if self.api_key:
      return Anthropic(api_key=self.api_key)
    return Anthropic()

  def request(
      self,
      batch: Sequence[Any],
      model: Anthropic,
      inference_args: Optional[dict[str, Any]] = None
  ) -> Iterable[PredictionResult]:
    """Sends a prediction request to the Anthropic API.

    Handler-level system and output_config are merged into a copy of
    inference_args before calling the request function. Per-request
    values in inference_args take precedence over handler-level values.

    Args:
      batch: a sequence of inputs to be passed to the request function.
      model: an Anthropic client instance.
      inference_args: additional arguments to send as part of the
        prediction request (e.g. max_tokens, temperature, system).

    Returns:
      An iterable of PredictionResults.
    """
    # Merge into a copy rather than mutating the caller's dict: writing
    # handler-level defaults into a shared inference_args would leak
    # them into other handlers/calls that reuse the same dict.
    merged_args = dict(inference_args or {})
    if self.system is not None:
      merged_args.setdefault('system', self.system)
    if self.output_config is not None:
      merged_args.setdefault('output_config', self.output_config)
    responses = self.request_fn(self.model_name, batch, model, merged_args)
    return utils._convert_to_result(batch, responses, self.model_name)

0 commit comments

Comments
 (0)