Skip to content

Commit d211796

Browse files
rag: add openai integration
1 parent 7cf81f5 commit d211796

File tree

2 files changed

+207
-0
lines changed

2 files changed

+207
-0
lines changed
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
"""RAG-specific embedding implementations using OpenAI models."""
18+
19+
from typing import Optional
20+
21+
import apache_beam as beam
22+
from apache_beam.ml.inference.base import RunInference
23+
from apache_beam.ml.rag.embeddings.base import create_rag_adapter
24+
from apache_beam.ml.rag.types import Chunk
25+
from apache_beam.ml.transforms.base import EmbeddingsManager
26+
from apache_beam.ml.transforms.base import _TextEmbeddingHandler
27+
from apache_beam.ml.transforms.embeddings.open_ai import _OpenAITextEmbeddingHandler
28+
29+
# Public API of this module: only the embeddings manager class is exported.
__all__ = ['OpenAITextEmbeddings']
30+
31+
32+
class OpenAITextEmbeddings(EmbeddingsManager):
  """EmbeddingsManager that embeds RAG ``Chunk`` elements via OpenAI models.

  The manager is constructed with the adapter returned by
  ``create_rag_adapter()`` and stores the OpenAI settings so they can be
  forwarded to ``_OpenAITextEmbeddingHandler`` when a handler is requested.
  """
  def __init__(
      self,
      model_name: str,
      *,
      api_key: Optional[str] = None,
      organization: Optional[str] = None,
      dimensions: Optional[int] = None,
      user: Optional[str] = None,
      max_batch_size: Optional[int] = None,
      **kwargs):
    """Utilizes OpenAI text embeddings for semantic search and RAG pipelines.

    Args:
      model_name: Name of the OpenAI embedding model
      api_key: OpenAI API key
      organization: OpenAI organization ID
      dimensions: Specific embedding dimensions to use (if supported)
      user: End-user identifier for tracking and rate limit calculations
      max_batch_size: Maximum batch size for requests to OpenAI API
      **kwargs: Additional arguments passed to EmbeddingsManager including
        ModelHandler inference_args.
    """
    # The base class is initialised first, with the RAG type adapter; the
    # remaining keyword arguments are handed to it untouched.
    super().__init__(type_adapter=create_rag_adapter(), **kwargs)

    # OpenAI settings are kept verbatim and only consumed in
    # get_model_handler().
    self.model_name = model_name
    self.api_key = api_key
    self.organization = organization
    self.dimensions = dimensions
    self.user = user
    self.max_batch_size = max_batch_size

  def get_model_handler(self):
    """Builds an ``_OpenAITextEmbeddingHandler`` from the stored settings."""
    handler_settings = {
        'model_name': self.model_name,
        'api_key': self.api_key,
        'organization': self.organization,
        'dimensions': self.dimensions,
        'user': self.user,
        'max_batch_size': self.max_batch_size,
    }
    return _OpenAITextEmbeddingHandler(**handler_settings)

  def get_ptransform_for_processing(
      self, **kwargs
  ) -> beam.PTransform[beam.PCollection[Chunk], beam.PCollection[Chunk]]:
    """Returns a ``RunInference`` transform whose output type is ``Chunk``.

    The transform wraps this manager in ``_TextEmbeddingHandler`` and passes
    along ``self.inference_args`` (presumably populated by the base class from
    the constructor kwargs -- confirm against ``EmbeddingsManager``). The
    ``**kwargs`` accepted here are not used.
    """
    embedding_handler = _TextEmbeddingHandler(self)
    run_inference = RunInference(
        model_handler=embedding_handler, inference_args=self.inference_args)
    return run_inference.with_output_types(Chunk)
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
import os
17+
import shutil
18+
import tempfile
19+
import unittest
20+
21+
import apache_beam as beam
22+
from apache_beam.ml.rag.embeddings.open_ai import OpenAITextEmbeddings
23+
from apache_beam.ml.rag.test_utils import TestHelpers
24+
from apache_beam.ml.rag.types import Chunk
25+
from apache_beam.ml.rag.types import Content
26+
from apache_beam.ml.rag.types import Embedding
27+
from apache_beam.ml.transforms.base import MLTransform
28+
from apache_beam.testing.test_pipeline import TestPipeline
29+
from apache_beam.testing.util import assert_that
30+
from apache_beam.testing.util import equal_to
31+
32+
33+
@unittest.skipIf(
    not os.environ.get('OPENAI_API_KEY'),
    'OPENAI_API_KEY environment variable is not set')
class OpenAITextEmbeddingsTest(unittest.TestCase):
  """Pipeline tests for ``OpenAITextEmbeddings`` applied through MLTransform.

  The whole class is skipped when ``OPENAI_API_KEY`` is unset or empty; the
  key is read from the environment and handed to the embedder under test.
  """

  # (chunk id, chunk text) pairs shared by the input and expected chunks.
  _SAMPLE_TEXTS = (
      ("1", "This is a test sentence."),
      ("2", "Another example."),
  )

  def setUp(self):
    # Each test gets its own artifact directory, removed again in tearDown.
    self.artifact_location = tempfile.mkdtemp(prefix='openai_')
    self.test_chunks = [
        Chunk(
            content=Content(text=text),
            id=chunk_id,
            metadata={
                "source": "test.txt", "language": "en"
            }) for chunk_id,
        text in self._SAMPLE_TEXTS
    ]

  def tearDown(self) -> None:
    shutil.rmtree(self.artifact_location)

  def _expected_chunks(self, dimensions):
    """Returns the input chunks with an all-zero embedding of given length."""
    # NOTE(review): the zero vectors look like placeholders; presumably
    # TestHelpers.chunk_approximately_equals does not compare embedding
    # values exactly -- confirm against the helper.
    return [
        Chunk(
            id=chunk_id,
            embedding=Embedding(dense_embedding=[0.0] * dimensions),
            metadata={
                "source": "test.txt", "language": "en"
            },
            content=Content(text=text)) for chunk_id,
        text in self._SAMPLE_TEXTS
    ]

  def _check_pipeline(self, dimensions):
    """Embeds the test chunks at ``dimensions`` and asserts on the output."""
    expected = self._expected_chunks(dimensions)

    embedder = OpenAITextEmbeddings(
        model_name="text-embedding-3-small",
        dimensions=dimensions,
        api_key=os.environ.get("OPENAI_API_KEY"))

    with TestPipeline() as p:
      ml_transform = MLTransform(
          write_artifact_location=self.artifact_location).with_transform(
              embedder)
      embeddings = p | beam.Create(self.test_chunks) | ml_transform

      assert_that(
          embeddings,
          equal_to(expected, equals_fn=TestHelpers.chunk_approximately_equals))

  def test_embedding_pipeline(self):
    self._check_pipeline(1536)

  def test_embedding_pipeline_with_dimensions(self):
    self._check_pipeline(512)
124+
125+
126+
# Allow running this test module directly as a script.
if __name__ == '__main__':
  unittest.main()

0 commit comments

Comments
 (0)