1+ #
2+ # Licensed to the Apache Software Foundation (ASF) under one or more
3+ # contributor license agreements. See the NOTICE file distributed with
4+ # this work for additional information regarding copyright ownership.
5+ # The ASF licenses this file to You under the Apache License, Version 2.0
6+ # (the "License"); you may not use this file except in compliance with
7+ # the License. You may obtain a copy of the License at
8+ #
9+ # http://www.apache.org/licenses/LICENSE-2.0
10+ #
11+ # Unless required by applicable law or agreed to in writing, software
12+ # distributed under the License is distributed on an "AS IS" BASIS,
13+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+ # See the License for the specific language governing permissions and
15+ # limitations under the License.
16+ import os
17+ import shutil
18+ import tempfile
19+ import unittest
20+
21+ import apache_beam as beam
22+ from apache_beam .ml .rag .embeddings .open_ai import OpenAITextEmbeddings
23+ from apache_beam .ml .rag .test_utils import TestHelpers
24+ from apache_beam .ml .rag .types import Chunk
25+ from apache_beam .ml .rag .types import Content
26+ from apache_beam .ml .rag .types import Embedding
27+ from apache_beam .ml .transforms .base import MLTransform
28+ from apache_beam .testing .test_pipeline import TestPipeline
29+ from apache_beam .testing .util import assert_that
30+ from apache_beam .testing .util import equal_to
31+
32+
33+ @unittest .skipIf (
34+ not os .environ .get ('OPENAI_API_KEY' ),
35+ 'OPENAI_API_KEY environment variable is not set' )
36+ class OpenAITextEmbeddingsTest (unittest .TestCase ):
37+ def setUp (self ):
38+ self .artifact_location = tempfile .mkdtemp (prefix = 'openai_' )
39+ self .test_chunks = [
40+ Chunk (
41+ content = Content (text = "This is a test sentence." ),
42+ id = "1" ,
43+ metadata = {
44+ "source" : "test.txt" , "language" : "en"
45+ }),
46+ Chunk (
47+ content = Content (text = "Another example." ),
48+ id = "2" ,
49+ metadata = {
50+ "source" : "test.txt" , "language" : "en"
51+ })
52+ ]
53+
54+ def tearDown (self ) -> None :
55+ shutil .rmtree (self .artifact_location )
56+
57+ def test_embedding_pipeline (self ):
58+ expected = [
59+ Chunk (
60+ id = "1" ,
61+ embedding = Embedding (dense_embedding = [0.0 ] * 1536 ),
62+ metadata = {
63+ "source" : "test.txt" , "language" : "en"
64+ },
65+ content = Content (text = "This is a test sentence." )),
66+ Chunk (
67+ id = "2" ,
68+ embedding = Embedding (dense_embedding = [0.0 ] * 1536 ),
69+ metadata = {
70+ "source" : "test.txt" , "language" : "en"
71+ },
72+ content = Content (text = "Another example." ))
73+ ]
74+
75+ embedder = OpenAITextEmbeddings (
76+ model_name = "text-embedding-3-small" ,
77+ dimensions = 1536 ,
78+ api_key = os .environ .get ("OPENAI_API_KEY" ))
79+
80+ with TestPipeline () as p :
81+ embeddings = (
82+ p
83+ | beam .Create (self .test_chunks )
84+ | MLTransform (write_artifact_location = self .artifact_location ).
85+ with_transform (embedder ))
86+
87+ assert_that (
88+ embeddings ,
89+ equal_to (expected , equals_fn = TestHelpers .chunk_approximately_equals ))
90+
91+ def test_embedding_pipeline_with_dimensions (self ):
92+ expected = [
93+ Chunk (
94+ id = "1" ,
95+ embedding = Embedding (dense_embedding = [0.0 ] * 512 ),
96+ metadata = {
97+ "source" : "test.txt" , "language" : "en"
98+ },
99+ content = Content (text = "This is a test sentence." )),
100+ Chunk (
101+ id = "2" ,
102+ embedding = Embedding (dense_embedding = [0.0 ] * 512 ),
103+ metadata = {
104+ "source" : "test.txt" , "language" : "en"
105+ },
106+ content = Content (text = "Another example." ))
107+ ]
108+
109+ embedder = OpenAITextEmbeddings (
110+ model_name = "text-embedding-3-small" ,
111+ dimensions = 512 ,
112+ api_key = os .environ .get ("OPENAI_API_KEY" ))
113+
114+ with TestPipeline () as p :
115+ embeddings = (
116+ p
117+ | beam .Create (self .test_chunks )
118+ | MLTransform (write_artifact_location = self .artifact_location ).
119+ with_transform (embedder ))
120+
121+ assert_that (
122+ embeddings ,
123+ equal_to (expected , equals_fn = TestHelpers .chunk_approximately_equals ))
124+
125+
126+ if __name__ == '__main__' :
127+ unittest .main ()
0 commit comments