Skip to content

Commit ed8e5e9

Browse files
committed
working on OpenSearch Indexing Pipeline
1 parent 57d7845 commit ed8e5e9

File tree

4 files changed

+1146
-0
lines changed

4 files changed

+1146
-0
lines changed
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
# Build argument selecting the ECR region for the base image (default: us-east-1).
ARG REGION=us-east-1

# SageMaker PyTorch image for INFERENCE.
# See https://github.com/aws/deep-learning-containers/blob/master/available_images.md
FROM 763104351884.dkr.ecr.$REGION.amazonaws.com/pytorch-inference:2.0.1-cpu-py310-ubuntu20.04-sagemaker

# Install dependencies without keeping pip's download cache: the cache
# bloats the image considerably, and a smaller image reduces container
# start-up time.
RUN pip install --no-cache-dir -U pip && pip --no-cache-dir install opensearch-py==2.2.0 \
    transformers==4.32.1 \
    farm-haystack[aws] \
    farm-haystack[preprocessing] \
    farm-haystack[opensearch] \
    farm-haystack[inference]

# Record the resolved package versions in the build log for debugging.
RUN pip freeze

ENV PATH="/opt/ml/code:${PATH}"

# The SageMaker PyTorch container reads this variable to locate user code.
ENV SAGEMAKER_SUBMIT_DIRECTORY /opt/ml/code

# PYTHONUNBUFFERED stops Python from buffering stdout, so logs reach the
# user quickly. PYTHONDONTWRITEBYTECODE suppresses .pyc files, which are
# unnecessary in this container.
ENV PYTHONUNBUFFERED=TRUE
ENV PYTHONDONTWRITEBYTECODE=TRUE
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
import argparse
import glob
import json
import logging
import os
import sys
import time
import warnings
from typing import Tuple

warnings.filterwarnings("ignore")  # avoid printing out absolute paths

import boto3

from haystack import Pipeline
from haystack.document_stores import OpenSearchDocumentStore
from haystack.nodes import (
    JsonConverter,
    PreProcessor,
    EmbeddingRetriever
)

# from opensearchpy.client import OpenSearch
from opensearchpy import (
    OpenSearch,
    RequestsHttpConnection
)
28+
29+
logger = logging.getLogger()
30+
logging.basicConfig(format='%(asctime)s,%(module)s,%(processName)s,%(levelname)s,%(message)s', level=logging.INFO, stream=sys.stderr)
31+
32+
def get_credentials(secret_id: str, region_name: str = 'us-east-1') -> dict:
    """
    Retrieve a secret from AWS Secrets Manager.

    :param secret_id: name or ARN of the secret to fetch.
    :param region_name: AWS region hosting the Secrets Manager endpoint.
    :return: the secret's key/value pairs as a dict (here: the OpenSearch
        ``username`` and ``password``) — note the function returns the
        parsed JSON document, not a plain string.
    """
    client = boto3.client('secretsmanager', region_name=region_name)
    response = client.get_secret_value(SecretId=secret_id)
    # SecretString holds a JSON document with the credential fields.
    secrets_value = json.loads(response['SecretString'])

    return secrets_value
44+
def check_if_index_exists(index_name: str, region: str, host: str, http_auth: Tuple[str, str]) -> bool:
    """
    Check whether an index already exists in the OpenSearch cluster.

    :param index_name: name of the index to look for.
    :param region: AWS region of the cluster (kept for interface
        compatibility; not used by the lookup itself — update the region
        argument if working outside us-east-1).
    :param http_auth: (username, password) pair for HTTP basic auth.
    :param host: cluster endpoint, with or without an ``https://`` prefix.
    :return: True if the index exists, False otherwise (the function
        returns the boolean from ``indices.exists``, not a client object).
    """
    # The OpenSearch client expects a bare hostname, so strip any scheme.
    aos_client = OpenSearch(
        hosts=[{'host': host.replace("https://", ""), 'port': 443}],
        http_auth=http_auth,
        use_ssl=True,
        verify_certs=True,
        connection_class=RequestsHttpConnection
    )
    exists = aos_client.indices.exists(index_name)
    logger.info(f"index_name={index_name}, exists={exists}")
    return exists
59+
def main():
    """
    Read JSON documents from the input directory and index them into an
    OpenSearch cluster through a Haystack indexing pipeline
    (convert -> preprocess -> embed -> write).
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--opensearch-endpoint", type=str, default=None)
    parser.add_argument("--opensearch-port", type=int, default=443)
    parser.add_argument("--opensearch-secret-id", type=str, default=None)
    parser.add_argument("--aws-region", type=str, default="us-east-1")
    parser.add_argument("--input-data-dir", type=str, default="/opt/ml/processing/input")

    args, _ = parser.parse_known_args()
    logger.info("Received arguments {}".format(args))

    # list all the files to be indexed
    files = glob.glob(os.path.join(args.input_data_dir, "*.json"))
    logger.info(f"there are {len(files)} files to process in the {args.input_data_dir} folder")

    # retrieve secret to talk to opensearch
    secret = get_credentials(args.opensearch_secret_id, args.aws_region)

    OPENSEARCH_HOST = args.opensearch_endpoint
    # BUG FIX: honor the --opensearch-port argument (default is still 443)
    # instead of hard-coding the port.
    OPENSEARCH_PORT = args.opensearch_port
    OPENSEARCH_USERNAME = secret['username']
    OPENSEARCH_PASSWORD = secret['password']

    # Stage one: read all the docs, split them into chunks.
    st = time.time()
    logger.info('Loading documents ...')

    # first check if index exists, if it does then call the add_documents function
    # otherwise call the from_documents function which would first create the index
    # and then do a bulk add. Both add_documents and from_documents do a bulk add
    # but it is important to call from_documents first so that the index is created
    # correctly for K-NN
    http_auth = (OPENSEARCH_USERNAME, OPENSEARCH_PASSWORD)
    OPENSEARCH_INDEX_NAME = "document"
    # BUG FIX: args has no "opensearch_cluster_domain" attribute (it would
    # raise AttributeError); the cluster endpoint arrives via --opensearch-endpoint.
    index_exists = check_if_index_exists(OPENSEARCH_INDEX_NAME,
                                         args.aws_region,
                                         OPENSEARCH_HOST,
                                         http_auth)

    if index_exists:
        et = time.time()
        # BUG FIX: args has no "opensearch_index_name" attribute; log the
        # local index-name constant instead.
        logger.info(f"index={OPENSEARCH_INDEX_NAME} does exists, not going to call add_documents")
        logger.info(f'run time in seconds: {et-st:.2f}')
        logger.info("all done")
        sys.exit(0)

    doc_store = OpenSearchDocumentStore(host=OPENSEARCH_HOST,
                                        port=OPENSEARCH_PORT,
                                        username=OPENSEARCH_USERNAME,
                                        password=OPENSEARCH_PASSWORD,
                                        embedding_dim=384)

    converter = JsonConverter()

    # Split documents into overlapping ~80-word chunks, respecting
    # sentence boundaries, before embedding.
    preprocessor = PreProcessor(
        clean_empty_lines=True,
        split_by='word',
        split_respect_sentence_boundary=True,
        split_length=80,
        split_overlap=20
    )

    # CPU-only embedding model; its 384-dim output must match the
    # embedding_dim configured on the document store above.
    retriever = EmbeddingRetriever(
        document_store=doc_store,
        embedding_model="sentence-transformers/all-MiniLM-L12-v2",
        devices=["cpu"],
        top_k=5
    )

    # Converter -> PreProcessor -> Retriever (embeds) -> DocumentStore (writes)
    indexing_pipeline = Pipeline()
    indexing_pipeline.add_node(component=converter, name="Converter", inputs=["File"])
    indexing_pipeline.add_node(component=preprocessor, name="PreProcessor", inputs=["Converter"])
    indexing_pipeline.add_node(component=retriever, name="Retriever", inputs=["PreProcessor"])
    indexing_pipeline.add_node(component=doc_store, name="DocumentStore", inputs=["Retriever"])

    indexing_pipeline.run(file_paths=files)

    et = time.time()
    logger.info(f'run time in seconds: {et-st:.2f}')
    logger.info("all done")
142+
# Script entry point: run the indexing job when invoked directly.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)