From cf1ba9977d85a9a45dd0fe14a1a1e4b405aa22fb Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Fri, 1 Dec 2023 11:43:21 -0800 Subject: [PATCH 01/12] Add MongoDB Atlas document store --- haystack/document_stores/mongodb_atlas.py | 561 ++++++++++++++++++++ haystack/document_stores/mongodb_filters.py | 91 ++++ 2 files changed, 652 insertions(+) create mode 100644 haystack/document_stores/mongodb_atlas.py create mode 100644 haystack/document_stores/mongodb_filters.py diff --git a/haystack/document_stores/mongodb_atlas.py b/haystack/document_stores/mongodb_atlas.py new file mode 100644 index 0000000000..e9a84ebc99 --- /dev/null +++ b/haystack/document_stores/mongodb_atlas.py @@ -0,0 +1,561 @@ +import re +from typing import Dict, Generator, List, Optional, Union +import numpy as np +from haystack.document_stores import BaseDocumentStore +from haystack.errors import DocumentStoreError +from haystack.nodes.retriever import DenseRetriever +from haystack.schema import Document, FilterType +from haystack.utils import get_batches_from_generator +from tqdm import tqdm +from .mongodb_filters import mongo_filter_converter +from ..lazy_imports import LazyImport + +with LazyImport("Run 'pip install farm-haystack[mongodb]'") as pinecone_import: + import pymongo + from pymongo import InsertOne, ReplaceOne, UpdateOne + from pymongo.collection import Collection + +METRIC_TYPES = ["euclidean", "cosine", "dotProduct"] +DEFAULT_BATCH_SIZE = 50 + + +class MongoDBAtlasDocumentStore(BaseDocumentStore): + def __init__( + self, + mongo_connection_string: Optional[str] = None, + database_name: Optional[str] = None, + collection_name: Optional[str] = None, + embedding_dim: int = 768, + return_embedding: bool = False, + similarity: str = "cosine", + embedding_field: str = "embedding", + progress_bar: bool = True, + duplicate_documents: str = "overwrite", + recreate_index: bool = False, + ): + self.mongo_connection_string = _validate_mongo_connection_string(mongo_connection_string) + self.database_name = _validate_database_name(database_name) + self.collection_name = _validate_collection_name(collection_name) + self.connection = pymongo.MongoClient(self.mongo_connection_string) + self.database = self.connection[self.database_name] + self.similarity = _validate_similarity(similarity) + self.duplicate_documents = duplicate_documents + self.embedding_field = embedding_field + self.progress_bar = progress_bar + self.embedding_dim = embedding_dim + self.index = collection_name + self.return_embedding = return_embedding + self.recreate_index = recreate_index + + if self.recreate_index: + self.delete_index() + + # Implicitly create the collection if it doesn't exist + if collection_name not in self.database.list_collection_names(): + self.database.create_collection(self.collection_name) + self._get_collection().create_index("id", unique=True) + + def _create_document_field_map(self) -> Dict: + return {self.embedding_field: "embedding"} + + def _get_collection(self, index=None) -> Collection: + """ + Returns the collection named by index or returns the collection specified when the + driver was initialized. + """ + _validate_index_name(index) + if index is not None: + return self.database[index] + else: + return self.database[self.collection_name] + + def delete_documents( + self, + index: Optional[str] = None, + ids: Optional[List[str]] = None, + filters: Optional[FilterType] = None, + headers: Optional[Dict[str, str]] = None, + ): + """ + Delete documents from the document store. + + :param index: Optional collection name. If `None`, the DocumentStore's default collection will be used. + :param ids: Optional list of IDs to narrow down the documents to be deleted. + :param filters: optional filters (see get_all_documents for description). + If filters are provided along with a list of IDs, this method deletes the + intersection of the two query results (documents that match the filters and + have their ID in the list). + :param headers: MongoDBAtlasDocumentStore does not support headers. + :return None: + """ + if headers: + raise NotImplementedError("MongoDBAtlasDocumentStore does not support headers.") + + collection = self._get_collection(index) + + match (ids, filters): + case (None, None): + mongo_filters = {} + case (None, filters): + mongo_filters = mongo_filter_converter(filters) + case (ids, None): + mongo_filters = {"id": {"$in": ids}} + case (ids, filters): + mongo_filters = {"$and": [mongo_filter_converter(filters), {"id": {"$in": ids}}]} + + collection.delete_many(filter=mongo_filters) + + def delete_index(self, index=None): + """ + Deletes the collection named by index or the collection speicifed when the + driver was initialized. + """ + self._get_collection(index).drop() + + def delete_labels(self): + raise NotImplementedError("MongoDBAtlasDocumentStore does not support labels (yet).") + + def get_all_documents( + self, + index: Optional[str] = None, + filters: Optional[FilterType] = None, + return_embedding: Optional[bool] = False, + batch_size: int = DEFAULT_BATCH_SIZE, + headers: Optional[Dict[str, str]] = None, + ): + """ + Retrieves all documents in the index (collection). + + :param index: Optional collection name. If `None`, the DocumentStore's default collection will be used. + :param filters: Optional filters to narrow down the documents that will be retrieved. + Filters are defined as nested dictionaries. The keys of the dictionaries can be a logical + operator (`"$and"`, `"$or"`, `"$not"`), a comparison operator (`"$eq"`, `"$in"`, `"$gt"`, + `"$gte"`, `"$lt"`, `"$lte"`) or a metadata field name. + Logical operator keys take a dictionary of metadata field names and/or logical operators as + value. Metadata field names take a dictionary of comparison operators as value. Comparison + operator keys take a single value or (in case of `"$in"`) a list of values as value. + If no logical operator is provided, `"$and"` is used as default operation. If no comparison + operator is provided, `"$eq"` (or `"$in"` if the comparison value is a list) is used as default + operation. + __Example__: + + ```python + filters = { + "$and": { + "type": {"$eq": "article"}, + "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"}, + "rating": {"$gte": 3}, + "$or": { + "genre": {"$in": ["economy", "politics"]}, + "publisher": {"$eq": "nytimes"} + } + } + } + ``` + Note that filters will be acting on the contents of the meta field of the documents in the collection. + :param return_embedding: Optional flag to return the embedding of the document. + :param batch_size: Number of documents to process at a time. When working with large number of documents, + batching can help reduce memory footprint. + :param headers: MongoDBAtlasDocumentStore does not support headers. + """ + if headers: + raise NotImplementedError("MongoDBAtlasDocumentStore does not support headers.") + + result = self.get_all_documents_generator( + index=index, filters=filters, return_embedding=return_embedding, batch_size=batch_size + ) + return list(result) + + def get_all_labels(self): + raise NotImplementedError("MongoDBAtlasDocumentStore does not support labels (yet).") + + def get_document_count( + self, + filters: Optional[FilterType] = None, + index: Optional[str] = None, + only_documents_without_embedding: bool = False, + headers: Optional[Dict[str, str]] = None, + ) -> int: + """ + Return the number of documents. + + :param filters: Optional filters (see get_all_documents for description). + :param index: Optional collection name. If `None`, the DocumentStore's default collection will be used. + :param only_documents_without_embedding: If set to `True`, only documents without embeddings are counted. + :param headers: MongoDBAtlasDocumentStore does not support headers. + """ + if headers: + raise NotImplementedError("MongoDBAtlasDocumentStore does not support headers.") + + collection = self._get_collection(index) + + if only_documents_without_embedding: + mongo_filter = {"$and": [mongo_filter_converter(filters), {"embedding": {"$eq": None}}]} + else: + mongo_filter = mongo_filter_converter(filters) + + return collection.count_documents(mongo_filter) + + def get_embedding_count(self, filters: Optional[FilterType] = None, index: Optional[str] = None) -> int: + """ + Return the number of documents with embeddings. + + :param filters: Optional filters (see get_all_documents for description). + """ + collection = self._get_collection(index) + + filters = filters or {} + + mongo_filters = {"$and": [mongo_filter_converter(filters), {"embedding": {"$ne": None}}]} + + return collection.count_documents(mongo_filters) + + def get_all_documents_generator( + self, + index: Optional[str] = None, + filters: Optional[FilterType] = None, + return_embedding: Optional[bool] = False, + batch_size: int = DEFAULT_BATCH_SIZE, + headers: Optional[Dict[str, str]] = None, + ) -> Generator[Document, None, None]: + """ + Retrieves all documents in the index (collection). Under-the-hood, documents are fetched in batches from the + document store and yielded as individual documents. This method can be used to iteratively process + a large number of documents without having to load all documents in memory. + + :param index: Optional collection name. If `None`, the DocumentStore's default collection will be used. + :param filters: optional filters (see get_all_documents for description). + :param return_embedding: Optional flag to return the embedding of the document. + :param batch_size: Number of documents to process at a time. When working with large number of documents, + batching can help reduce memory footprint. + :param headers: MongoDBAtlasDocumentStore does not support headers. + """ + if headers: + raise NotImplementedError("MongoDBAtlasDocumentStore does not support headers.") + + mongo_filters = mongo_filter_converter(filters) + + if return_embedding is None: + return_embedding = self.return_embedding + + projection = {"embedding": False} if not return_embedding else {} + + collection = self._get_collection(index) + documents = collection.find(mongo_filters, batch_size=batch_size, projection=projection) + + for doc in documents: + yield mongo_doc_to_haystack_doc(doc) + + def get_documents_by_id( + self, + ids: List[str], + index: Optional[str] = None, + batch_size: int = DEFAULT_BATCH_SIZE, + headers: Optional[Dict[str, str]] = None, + return_embedding: Optional[bool] = None, + ) -> List[Document]: + """ + Retrieves all documents matching ids. + + :param ids: List of IDs to retrieve. + :param index: Optional collection name. If `None`, the DocumentStore's default collection will be used. + :param batch_size: Number of documents to retrieve at a time. When working with large number of documents, + batching can help reduce memory footprint. + :param headers: MongoDBAtlasDocumentStore does not support headers. + :param return_embedding: Optional flag to return the embedding of the document. + """ + mongo_filters = {"id": {"$in": ids}} + + result = self.get_all_documents_generator( + index=index, + filters=mongo_filters, + return_embedding=return_embedding, + batch_size=batch_size, + headers=headers, + ) + + return list(result) + + def get_document_by_id( + self, + id: str, + index: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, + return_embedding: Optional[bool] = None, + ) -> Document: + """ + Retrieves the document matching id. + + :param id: The ID of the document to retrieve + :param index: Optional collection name. If `None`, the DocumentStore's default collection will be used. + :param headers: MongoDBAtlasDocumentStore does not support headers. + :param return_embedding: Optional flag to return the embedding of the document. + """ + documents = self.get_documents_by_id(ids=[id], index=index, headers=headers, return_embedding=return_embedding) + return documents[0] + + def get_label_count(self): + raise NotImplementedError("MongoDBAtlasDocumentStore does not support labels (yet).") + + def query_by_embedding( + self, + query_emb: np.ndarray, + filters: Optional[FilterType] = None, + top_k: int = 10, + index: Optional[str] = None, + return_embedding: Optional[bool] = None, + headers: Optional[Dict[str, str]] = None, + scale_score: bool = True, + ) -> List[Document]: + """ + Find the documents that are most similar to the provided `query_emb` by using a vector similarity metric. + + :param query_emb: Embedding of the query + :param filters: optional filters (see get_all_documents for description). + :param top_k: How many documents to return. + :param index: Optional collection name. If `None`, the DocumentStore's default collection will be used. + :param return_embedding: Whether to return document embedding. + :param headers: MongoDBAtlasDocumentStore does not support headers. + :param scale_score: Whether to scale the similarity score to the unit interval (range of [0,1]). + If true (default) similarity scores (e.g. cosine or dot_product) which naturally have a different value range will be scaled to a range of [0,1], where 1 means extremely relevant. + Otherwise raw similarity scores (e.g. cosine or dot_product) will be used. + """ + if headers: + raise NotImplementedError("MongoDBAtlasDocumentStore does not support headers.") + + if return_embedding is None: + return_embedding = self.return_embedding + + collection = self._get_collection(index) + + query_emb = query_emb.astype(np.float32) + + if self.similarity == "cosine": + self.normalize_embedding(query_emb) + + filters = filters or {} + + pipeline = [ + { + "$search": { + "index": self.collection_name, + "knnBeta": {"vector": query_emb.tolist(), "path": "embedding", "k": top_k}, + } + } + ] + if filters is not None: + pipeline.append({"$match": mongo_filter_converter(filters)}) + if not return_embedding: + pipeline.append({"$project": {"embedding": False}}) + pipeline.append({"$set": {"score": {"$meta": "searchScore"}}}) + documents = list(collection.aggregate(pipeline)) + + if scale_score: + for doc in documents: + doc["score"] = self.scale_to_unit_interval(doc["score"], self.similarity) + + documents = [mongo_doc_to_haystack_doc(doc) for doc in documents] + return documents + + def update_document_meta(self, id: str, meta: Dict[str, str], index: Optional[str] = None): + """ + Update the metadata dictionary of a document by specifying its string ID. + + :param id: ID of the Document to update. + :param meta: Dictionary of new metadata. + :param index: Optional collection name. If `None`, the DocumentStore's default collection will be used. + """ + collection = self._get_collection(index) + collection.update_one({"id": id}, {"$set": {"meta": meta}}) + + def write_documents( + self, + documents: Union[List[dict], List[Document]], + index: Optional[str] = None, + batch_size: int = DEFAULT_BATCH_SIZE, + duplicate_documents: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, + ): + """ + Parameters: + + documents: List of `Dicts` or `Documents` + index (str): search index name - contain letters, numbers, hyphens, or underscores + :param duplicate_documents: handle duplicate documents based on parameter options. + Parameter options: + - `"overwrite"`: Update any existing documents with the same ID when adding documents. + - `"skip"`: Ignore the duplicate documents. + - `"fail"`: An error is raised if the document ID of the document being added already exists. + + "overwrite" is the default behaviour. + """ + if headers: + raise NotImplementedError("MongoDBAtlasDocumentStore does not support headers.") + + collection = self._get_collection(index) + + duplicate_documents = duplicate_documents or self.duplicate_documents + + field_map = self._create_document_field_map() + documents = [ + Document.from_dict(doc, field_map=field_map) if isinstance(doc, dict) else doc for doc in documents + ] + + mongo_documents = list(map(Document.to_dict, documents)) + + with tqdm( + total=len(mongo_documents), + disable=not self.progress_bar, + position=0, + unit=" docs", + desc="Writing Documents", + ) as progress_bar: + batches = get_batches_from_generator(mongo_documents, batch_size) + for batch in batches: + match duplicate_documents: + case "skip": + operations = [UpdateOne({"id": doc["id"]}, {"$setOnInsert": doc}, upsert=True) for doc in batch] + case "fail": + operations = [InsertOne(doc) for doc in batch] + case _: + operations = [ReplaceOne({"id": doc["id"]}, upsert=True, replacement=doc) for doc in batch] + + collection.bulk_write(operations) + progress_bar.update(len(batch)) + + def write_labels(self): + raise NotImplementedError("MongoDBAtlasDocumentStore does not support labels (yet).") + + def update_embeddings( + self, + retriever: DenseRetriever, + index: Optional[str] = None, + update_existing_embeddings: bool = True, + filters: Optional[FilterType] = None, + batch_size: int = DEFAULT_BATCH_SIZE, + ): + """ + Updates the embeddings in the document store using the encoding model specified in the retriever. + + This can be useful if you want to add or change the embeddings for your documents (e.g. after changing the + retriever config). + + :param retriever: Retriever to use to get embeddings for text. + :param index: Optional collection name. If `None`, the DocumentStore's default collection will be used. + :param update_existing_embeddings: Whether to update existing embeddings of the documents. If set to `False`, + only documents without embeddings are processed. This mode can be used for incremental updating of + embeddings, wherein, only newly indexed documents get processed. + :param filters: optional filters (see get_all_documents for description). + :param batch_size: Number of documents to process at a time. When working with large number of documents, + batching can help reduce memory footprint. " + """ + filters = filters or {} + document_count = self.get_document_count( + index=index, filters=filters, only_documents_without_embedding=not update_existing_embeddings + ) + + if not update_existing_embeddings: + filters = {"$and": [filters, {"embedding": {"$eq": None}}]} + + documents = self.get_all_documents_generator( + index=index, filters=filters, return_embedding=False, batch_size=batch_size + ) + + collection = self._get_collection(index) + + with tqdm( + total=document_count, disable=not self.progress_bar, unit=" docs", desc="Updating Embeddings" + ) as progress_bar: + batches = get_batches_from_generator(documents, batch_size) + for batch in batches: + embeddings = retriever.embed_documents(batch) + self._validate_embeddings_shape( + embeddings=embeddings, num_documents=len(batch), embedding_dim=self.embedding_dim + ) + if self.similarity == "cosine": + self.normalize_embedding(embeddings) + + mongo_documents = [haystack_doc_to_mongo_doc(doc) for doc in batch] + + for doc, embedding in zip(mongo_documents, embeddings.tolist()): + doc["embedding"] = embedding + + updates = [ReplaceOne({"id": doc["id"]}, doc) for doc in mongo_documents] + collection.bulk_write(updates) + progress_bar.update(len(batch)) + + +class MongoDBAtlasDocumentStoreError(DocumentStoreError): + """Exception for issues that occur in a MongoDBAtlas document store""" + + def __init__(self, message: Optional[str] = None): + super().__init__(message=message) + + +class ValidationError(Exception): + """Exception for validation errors""" + + pass + + +def _validate_mongo_connection_string(mongo_connection_string): + if not mongo_connection_string: + raise MongoDBAtlasDocumentStoreError( + "A `mongodb_connection_string` is required. This can be obtained on the MongoDB Atlas Dashboard by clicking on the `CONNECT` button." + ) + return mongo_connection_string + + +def _validate_database_name(database_name): + # There doesn't seem to be much restriction on the name here? All sorts of special character are apparently allowed... + # Just check if it's there. + if not database_name: + raise ValidationError("A `database_name` is required.") + return database_name + + +def _validate_collection_name(collection_name): + # There doesn't seem to be much restriction on the name here? All sorts of special character are apparently allowed... + # Just check if it's there. + if not collection_name: + raise ValidationError("A `collection_name` is required.") + return collection_name + + +def _validate_similarity(similarity): + if similarity not in METRIC_TYPES: + raise ValueError( + "MongoDB Atlas currently supports dotProduct, cosine and euclidean metrics. Please set similarity to one of the above." + ) + return similarity + + +def _validate_index_name(index_name): + if index_name and not bool(re.match(r"^[a-zA-Z0-9\-_]+$", index_name)): + raise ValueError( + f'Invalid index name: "{index_name}". Index name can only contain letters, numbers, hyphens, or underscores.' + ) + return index_name + + +def mongo_doc_to_haystack_doc(mongo_doc) -> Document: + embedding = mongo_doc.get("embedding", None) + score = mongo_doc.get["score"] + + return Document( + id=mongo_doc["id"], + content=mongo_doc["content"], + content_type=mongo_doc["content_type"], + meta=mongo_doc["meta"], + embedding=embedding, + score=score, + ) + + +def haystack_doc_to_mongo_doc(haystack_doc) -> Dict: + return { + "id": haystack_doc.id, + "content": haystack_doc.content, + "content_type": haystack_doc.content_type, + "meta": haystack_doc.meta, + } diff --git a/haystack/document_stores/mongodb_filters.py b/haystack/document_stores/mongodb_filters.py new file mode 100644 index 0000000000..691ffcc8fb --- /dev/null +++ b/haystack/document_stores/mongodb_filters.py @@ -0,0 +1,91 @@ +from typing import Union, Any, Dict + +FILTER_OPERATORS = ["$and", "$or", "$not", "$eq", "$in", "$gt", "$gte", "$lt", "$lte"] +EXCLUDE_FROM_METADATA_PREPEND = ["id", "embedding"] + +METADATA_FIELD_NAME = "meta" + + +def mongo_filter_converter(filter) -> Dict[str, Any]: + if not filter: + return {} + else: + filter = _target_filter_to_metadata(filter, METADATA_FIELD_NAME) + filter = _and_or_to_list(filter) + return filter + + +def _target_filter_to_metadata(filter, metadata_field_name) -> Union[Dict[str, Any], list]: + """ + Returns a new filter with any non-operator, non-excluded keys renamed so that the metadata + field name is prepended. Does not mutate input filter. + + Example: + + { + "$and": { + "url": "https://en.wikipedia.org/wiki/Colossus_of_Rhodes", + "_split_id": 0 + } + } + + will be replaced with: + + { + "$and": { + "meta.url": "https://en.wikipedia.org/wiki/Colossus_of_Rhodes", + "meta._split_id": 0 + } + } + + """ + if isinstance(filter, dict): + updated_dict = {} + for key, value in filter.items(): + if key not in FILTER_OPERATORS + EXCLUDE_FROM_METADATA_PREPEND: + key = f"{metadata_field_name}.{key}" + + if isinstance(value, (dict, list)): + updated_dict[key] = _target_filter_to_metadata(value, metadata_field_name) + else: + updated_dict[key] = value + return updated_dict + elif isinstance(filter, list): + return [_target_filter_to_metadata(item, metadata_field_name) for item in filter] + return filter + + +def _and_or_to_list(filter) -> Union[Dict[str, Any], list]: + """ + Returns a new filter replacing any dict values associated with "$and" or "$or" keys + with a list. Does not mutate input filter. + + Example: + + { + "$and": { + "url": {"$eq": "https://en.wikipedia.org/wiki/Colossus_of_Rhodes"}, + "_split_id": {"$eq": 0}, + }, + } + + will be replaced with: + + { + "$and": [ + {"url": {"$eq": "https://en.wikipedia.org/wiki/Colossus_of_Rhodes"}}, + {"_split_id": {"$eq": 0}}, + ] + } + """ + if isinstance(filter, dict): + updated_dict = filter.copy() + if "$and" in updated_dict and isinstance(filter["$and"], dict): + updated_dict["$and"] = [{key: value} for key, value in filter["$and"].items()] + if "$or" in updated_dict and isinstance(filter["$or"], dict): + updated_dict["$or"] = [{key: value} for key, value in filter["$or"].items()] + return {key: _and_or_to_list(value) for key, value in updated_dict.items()} + elif isinstance(filter, list): + return [_and_or_to_list(item) for item in filter] + else: + return filter From 70c9def164dcc213efe9b5a3cbf577b0e1b33ca6 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Fri, 8 Dec 2023 12:33:54 -0800 Subject: [PATCH 02/12] Add tests --- haystack/document_stores/mongodb_atlas.py | 20 +- pyproject.toml | 7 +- test/document_stores/test_mongodb.py | 558 ++++++++++++++++++++++ 3 files changed, 573 insertions(+), 12 deletions(-) create mode 100644 test/document_stores/test_mongodb.py diff --git a/haystack/document_stores/mongodb_atlas.py b/haystack/document_stores/mongodb_atlas.py index e9a84ebc99..4118ab18d0 100644 --- a/haystack/document_stores/mongodb_atlas.py +++ b/haystack/document_stores/mongodb_atlas.py @@ -10,7 +10,7 @@ from .mongodb_filters import mongo_filter_converter from ..lazy_imports import LazyImport -with LazyImport("Run 'pip install farm-haystack[mongodb]'") as pinecone_import: +with LazyImport("Run 'pip install farm-haystack[mongodb]'") as mongodb_import: import pymongo from pymongo import InsertOne, ReplaceOne, UpdateOne from pymongo.collection import Collection @@ -33,6 +33,7 @@ def __init__( duplicate_documents: str = "overwrite", recreate_index: bool = False, ): + mongodb_import.check() self.mongo_connection_string = _validate_mongo_connection_string(mongo_connection_string) self.database_name = _validate_database_name(database_name) self.collection_name = _validate_collection_name(collection_name) @@ -93,15 +94,14 @@ def delete_documents( collection = self._get_collection(index) - match (ids, filters): - case (None, None): - mongo_filters = {} - case (None, filters): - mongo_filters = mongo_filter_converter(filters) - case (ids, None): - mongo_filters = {"id": {"$in": ids}} - case (ids, filters): - mongo_filters = {"$and": [mongo_filter_converter(filters), {"id": {"$in": ids}}]} + if (ids, filters) == (None, None): + mongo_filters = {} + elif (ids, filters) == (None, filters): + mongo_filters = mongo_filter_converter(filters) + elif (ids, filters) == (ids, None): + mongo_filters = {"id": {"$in": ids}} + elif (ids, filters) == (ids, filters): + mongo_filters = {"$and": [mongo_filter_converter(filters), {"id": {"$in": ids}}]} collection.delete_many(filter=mongo_filters) diff --git a/pyproject.toml b/pyproject.toml index 1a1c0e94b4..960fffb473 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -130,11 +130,14 @@ pinecone = [ opensearch = [ "opensearch-py>=2", ] +mongodb = [ + "pymongo>=4.6", +] docstores = [ - "farm-haystack[elasticsearch,faiss,weaviate,pinecone,opensearch]", + "farm-haystack[elasticsearch,faiss,weaviate,pinecone,opensearch,mongodb]", ] docstores-gpu = [ - "farm-haystack[elasticsearch,faiss-gpu,weaviate,pinecone,opensearch]", + "farm-haystack[elasticsearch,faiss-gpu,weaviate,pinecone,opensearch,mongodb]", ] aws = [ # first version to support Amazon Bedrock diff --git a/test/document_stores/test_mongodb.py b/test/document_stores/test_mongodb.py new file mode 100644 index 0000000000..193c0e379e --- /dev/null +++ b/test/document_stores/test_mongodb.py @@ -0,0 +1,558 @@ +import contextlib +import os +import re +import requests +import pytest +import numpy +import roman +from haystack.document_stores.mongodb_filters import _target_filter_to_metadata, _and_or_to_list, mongo_filter_converter +from haystack.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore, pymongo +from haystack.schema import Document +from haystack.nodes import PreProcessor, EmbeddingRetriever + +pytestmark = pytest.mark.integration + +mongo_atlas_database = "database01" +mongo_atlas_collection = "test_80_days" + +mongo_atlas_username = os.getenv("MONGO_ATLAS_USERNAME") +mongo_atlas_password = os.getenv("MONGO_ATLAS_PASSWORD") +mongo_atlas_host = os.getenv("MONGO_ATLAS_HOST") +mongo_atlas_connection_params = {"retryWrites": "true", "w": "majority"} +mongo_atlas_params_string = "&".join([f"{key}={value}" for key, value in mongo_atlas_connection_params.items()]) +mongo_atlas_connection_string = ( + f"mongodb+srv://{mongo_atlas_username}:{mongo_atlas_password}@{mongo_atlas_host}/?{mongo_atlas_params_string}" +) + +document_store = MongoDBAtlasDocumentStore( + mongo_connection_string=mongo_atlas_connection_string, + database_name=mongo_atlas_database, + collection_name=mongo_atlas_collection, + embedding_dim=768, +) + +# Test data + + +# Get the book "Around the World in 80 Days" from Project Gutenberg +def get_book_online(): + response = requests.get("https://www.gutenberg.org/ebooks/103.txt.utf-8") + if response.status_code != 200: + raise requests.HTTPError(f"HTTP error {response.status_code}") + else: + return response.text + + +def get_book_local(): + with open("test-data/80_days.txt", "r", encoding="utf-8") as file: + text = file.read() + return text + + +get_book = get_book_local + + +# Divide the book into chapters +def divide_book_into_chapters(book) -> dict: + lines = book.split("\n") + current_chapter = None + chapters = {} + for line in lines: + chapter_match = re.match(r"CHAPTER\s+([IVXLCDM]+)\.*", line) + if chapter_match: + chapter_roman = chapter_match.group(1) + chapter_decimal = roman.fromRoman(chapter_roman) + current_chapter = f"CHAPTER {chapter_decimal}".title() + chapters[current_chapter] = "" + if current_chapter: + chapters[current_chapter] += line + "\n" + return chapters + + +book = get_book() +chapters = divide_book_into_chapters(book) +documents = [ + Document(content=chapters[f"Chapter {n}"], meta={"book": "Around the World in 80 Days", "Chapter": n}) + for n in range(1, len(chapters) + 1) +] + + +def test_write_documents_skip(): + document_store.delete_documents() + + processor = PreProcessor( + clean_empty_lines=True, + clean_whitespace=True, + clean_header_footer=True, + remove_substrings=None, + split_by="word", + split_length=200, + split_respect_sentence_boundary=True, + split_overlap=0, + max_chars_check=10_000, + ) + + processed_documents = processor.process([documents[0]]) + document_store.write_documents(processed_documents) + + collection = document_store._get_collection() + + filters = {"Chapter": 1, "_split_id": 0} + collection.update_one(mongo_filter_converter(filters), {"$set": {"content": "No Content"}}) + document_store.write_documents(processed_documents, duplicate_documents="skip") + assert document_store.get_all_documents(filters=filters)[0].content == "No Content" + + +def test_write_documents_overwrite(): + document_store.delete_documents() + + processor = PreProcessor( + clean_empty_lines=True, + clean_whitespace=True, + clean_header_footer=True, + remove_substrings=None, + split_by="word", + split_length=200, + split_respect_sentence_boundary=True, + split_overlap=0, + max_chars_check=10_000, + ) + + processed_documents = processor.process([documents[0]]) + document_store.write_documents(processed_documents) + + collection = document_store._get_collection() + + filters = {"Chapter": 1, "_split_id": 0} + collection.update_one(mongo_filter_converter(filters), {"$set": {"content": "No Content"}}) + document_store.write_documents(processed_documents, duplicate_documents="overwrite") + assert document_store.get_all_documents(filters=filters)[0].content != "No Content" + + +def test_write_documents_fail(): + document_store.delete_documents() + + processor = PreProcessor( + clean_empty_lines=True, + clean_whitespace=True, + clean_header_footer=True, + remove_substrings=None, + split_by="word", + split_length=200, + split_respect_sentence_boundary=True, + split_overlap=0, + max_chars_check=10_000, + ) + + processed_documents = processor.process([documents[0]]) + document_store.write_documents(processed_documents) + document_store.write_documents(processed_documents) + with pytest.raises(pymongo.BulkWriteError): + document_store.write_documents(processed_documents, duplicate_documents="fail") + + +def test_write_documents(): + document_store.delete_documents() + + processor = PreProcessor( + clean_empty_lines=True, + clean_whitespace=True, + clean_header_footer=True, + remove_substrings=None, + split_by="word", + split_length=200, + split_respect_sentence_boundary=True, + split_overlap=0, + max_chars_check=10_000, + ) + + processed_documents = processor.process(documents) + document_store.write_documents(processed_documents) + + assert document_store.get_document_count() == 373 + assert document_store.get_all_documents(return_embedding=True)[0].embedding is None + + +def test_get_document_count_without_embeddings_a(): + assert document_store.get_document_count(only_documents_without_embedding=True) == 373 + + +def test_get_embedding_count_a(): + assert document_store.get_embedding_count() == 0 + + +def test_get_document_count_without_embeddings_with_filter(): + assert document_store.get_document_count(filters={"Chapter": 1}, only_documents_without_embedding=True) == 8 + + +def test_update_embeddings_filtered(): + retriever = EmbeddingRetriever( + document_store=document_store, + embedding_model="sentence-transformers/all-mpnet-base-v2", # Recommended here: https://www.sbert.net/docs/pretrained_models.html + model_format="sentence_transformers", + top_k=10, + ) + filters = {"Chapter": 1, "_split_id": 0} + document_store.update_embeddings(retriever, batch_size=30, filters=filters) + assert isinstance( + document_store.get_all_documents(return_embedding=True, filters=filters)[0].embedding, numpy.ndarray + ) + + +def test_update_embeddings(): + retriever = EmbeddingRetriever( + document_store=document_store, + embedding_model="sentence-transformers/all-mpnet-base-v2", # Recommended here: https://www.sbert.net/docs/pretrained_models.html + model_format="sentence_transformers", + top_k=10, + ) + + document_store.update_embeddings(retriever, batch_size=30) + assert isinstance(document_store.get_all_documents(return_embedding=True)[0].embedding, numpy.ndarray) + + +def test_update_embeddings_not_existing(): + retriever = EmbeddingRetriever( + document_store=document_store, + embedding_model="sentence-transformers/all-mpnet-base-v2", # Recommended here: https://www.sbert.net/docs/pretrained_models.html + model_format="sentence_transformers", + top_k=10, + ) + filters = {"Chapter": 1, "_split_id": 0} + filters2 = {"Chapter": 1, "_split_id": 1} + collection = document_store._get_collection() + + collection.update_one(mongo_filter_converter(filters), {"$set": {"embedding": None}}) + collection.update_one(mongo_filter_converter(filters2), {"$set": {"embedding": "not_an_embedding"}}) + + document_store.update_embeddings(retriever, batch_size=30, update_existing_embeddings=False) + assert isinstance(collection.find_one(mongo_filter_converter(filters))["embedding"], list) + assert collection.find_one(mongo_filter_converter(filters2))["embedding"] == "not_an_embedding" + + +def test_get_embedding_count_b(): + assert document_store.get_embedding_count() == 373 + + +# Getting documents + + +def test_get_all_documents_without_embedings(): + assert document_store.get_all_documents()[0].embedding is None + assert document_store.get_all_documents(return_embedding=False)[0].embedding is None + + +def test_get_all_documents_with_embedings(): + assert isinstance(document_store.get_all_documents(return_embedding=True)[0].embedding, numpy.ndarray) + + +def test_get_all_documents(): + assert len(document_store.get_all_documents()) == 373 + + +def test_get_all_documents_filtered(): + assert len(document_store.get_all_documents(filters={"Chapter": 1})) == 8 + + +def test_get_document_by_id_a(): + documents = document_store.get_all_documents(filters={"Chapter": 1, "_split_id": 0}) + assert len(documents) == 1 + document_id = documents[0].id + assert isinstance(document_store.get_document_by_id(id=document_id), Document) + + +def test_get_documents_by_id_b(): + documents = document_store.get_all_documents(filters={"Chapter": 1}) + document_ids = [document.id for document in documents] + assert len(document_ids) > 1 + assert len(document_store.get_documents_by_id(ids=document_ids)) == len(document_ids) + + +def test_get_all_documents_headers_throws(): + with pytest.raises(NotImplementedError): + document_store.get_all_documents(headers={"key": "value"}) + + +def test_get_document_count(): + assert document_store.get_document_count() == 373 + + +def test_get_document_count_without_embeddings_b(): + assert document_store.get_document_count(only_documents_without_embedding=True) == 0 + + +def test_get_document_count_filtered(): + assert document_store.get_document_count(filters={"Chapter": 1}) == 8 + assert document_store.get_document_count(filters={"Chapter": 1, "_split_id": 0}) == 1 + + +# Updating document meta + + +def test_update_document_meta(): + document = document_store.get_all_documents(filters={"Chapter": 1, "_split_id": 0})[0] + new_meta = document.meta + new_meta["new_field"] = "New metadata" + document_store.update_document_meta(id=document.id, meta=new_meta) + updated_document = document_store.get_all_documents(filters={"Chapter": 1, "_split_id": 0})[0] + assert "new_field" in updated_document.meta + assert updated_document.meta["new_field"] == "New metadata" + + +# Deleting documents + + +def test_delete_documents_filtered(): + document_store.delete_documents(filters={"Chapter": 1, "_split_id": 0}) + assert document_store.get_document_count() == 372 + + +def test_delete_documents_by_id(): + documents = document_store.get_all_documents(filters={"Chapter": 1}) + document_ids = [document.id for document in documents] + document_store.delete_documents(ids=document_ids) + assert document_store.get_document_count() == 365 + + +def test_delete_documents_by_id_filtered(): + documents = document_store.get_all_documents(filters={"Chapter": 2}) + document_ids = [document.id for document in documents] + document_store.delete_documents(ids=document_ids, filters={"_split_id": 0}) # Only deletes the intersection + assert document_store.get_document_count() == 364 + + +def test_delete_documents(): + document_store.delete_documents() + assert document_store.get_document_count() == 0 + + +def test_delete_documents_headers_throws(): + with pytest.raises(NotImplementedError): + document_store.delete_documents(headers={"key": "value"}) + + +def test_delete_index(): + document_store.delete_index() + client = pymongo.MongoClient(mongo_atlas_connection_string) + database = client[mongo_atlas_database] + assert "test_80_days" not in database.list_collection_names() + + +def test_delete_index_with_index(): + client = pymongo.MongoClient(mongo_atlas_connection_string) + database = client[mongo_atlas_database] + with contextlib.suppress(Exception): + database.create_collection("deleteme") + assert "deleteme" in database.list_collection_names() + document_store.delete_index(index="deleteme") + assert "deleteme" not in database.list_collection_names() + + +def test__create_document_field_map_a(): + assert document_store._create_document_field_map() == {"embedding": "embedding"} + + +def test__create_document_field_map_b(): + document_store = MongoDBAtlasDocumentStore( + mongo_connection_string=mongo_atlas_connection_string, + database_name=mongo_atlas_database, + collection_name=mongo_atlas_collection, + embedding_dim=768, + embedding_field="emb", + ) + assert document_store._create_document_field_map() == {"emb": "embedding"} + + +def test__get_collection_no_index(): + collection = document_store._get_collection() + assert collection.name == "test_80_days" + + +def test__get_collection_with_index(): + collection = document_store._get_collection(index="index_abcdefg") + assert collection.name == "index_abcdefg" + + +def test__get_collection_invalid_index(): + with pytest.raises(ValueError): + document_store._get_collection(index="index_a!!bcdefg") + + +def test_write_documents_index(): + document_store.delete_documents() + + processor = PreProcessor( + clean_empty_lines=True, + clean_whitespace=True, + clean_header_footer=True, + remove_substrings=None, + split_by="word", + split_length=200, + split_respect_sentence_boundary=True, + split_overlap=0, + max_chars_check=10_000, + ) + + retriever = EmbeddingRetriever( + document_store=document_store, + embedding_model="sentence-transformers/all-mpnet-base-v2", # Recommended here: https://www.sbert.net/docs/pretrained_models.html + model_format="sentence_transformers", + top_k=10, + ) + + processed_documents = processor.process(documents) + document_store.write_documents(processed_documents) + + assert document_store.get_document_count() == 373 + assert document_store.get_all_documents(return_embedding=True)[0].embedding is None + + document_store.update_embeddings(retriever, batch_size=30) + assert isinstance(document_store.get_all_documents(return_embedding=True)[0].embedding, numpy.ndarray) + + +def test_query_by_embedding_default_topk(): + retriever = EmbeddingRetriever( + document_store=document_store, + embedding_model="sentence-transformers/all-mpnet-base-v2", # Recommended here: https://www.sbert.net/docs/pretrained_models.html + model_format="sentence_transformers", + ) + embedding = retriever.embed_queries(["How much money was stolen from the bank?"])[0] + results = document_store.query_by_embedding(query_emb=embedding) + assert results[0].embedding == None + assert len(results) == 10 + + +def test_query_by_embedding_default_topk_4(): + retriever = EmbeddingRetriever( + document_store=document_store, + embedding_model="sentence-transformers/all-mpnet-base-v2", # Recommended here: https://www.sbert.net/docs/pretrained_models.html + model_format="sentence_transformers", + ) + embedding = retriever.embed_queries(["How much money was stolen from the bank?"])[0] + results = document_store.query_by_embedding(query_emb=embedding, top_k=4) + assert results[0].embedding == None + assert len(results) == 4 + + +def test_query_by_embedding_filtered(): + retriever = EmbeddingRetriever( + document_store=document_store, + embedding_model="sentence-transformers/all-mpnet-base-v2", # Recommended here: https://www.sbert.net/docs/pretrained_models.html + model_format="sentence_transformers", + top_k=10, + ) + embedding = retriever.embed_queries(["Who was Phileas Fogg?"])[0] + results = document_store.query_by_embedding(query_emb=embedding, filters={"Chapter": 1}) + assert len(results) == 3 + + +def test_query_by_embedding_include_embedding(): + retriever = EmbeddingRetriever( + document_store=document_store, + embedding_model="sentence-transformers/all-mpnet-base-v2", # Recommended here: https://www.sbert.net/docs/pretrained_models.html + model_format="sentence_transformers", + top_k=10, + ) + embedding = retriever.embed_queries(["How much money was stolen from the bank?"])[0] + results = document_store.query_by_embedding(query_emb=embedding, return_embedding=True) + assert isinstance(results[0].embedding, numpy.ndarray) + assert len(results) == 10 + + +def test_and_or_meta_converted(): + test_filter = { + "$and": [ + {"url": "https://en.wikipedia.org/wiki/Colossus_of_Rhodes"}, + {"_split_id": 0}, + {"$or": [{"url": "https://en.wikipedia.org/wiki/Colossus_of_Rhodes"}, {"_split_id": 0}]}, + ] + } + + target_outcome = { + "$and": [ + {"meta.url": "https://en.wikipedia.org/wiki/Colossus_of_Rhodes"}, + {"meta._split_id": 0}, + {"$or": [{"meta.url": "https://en.wikipedia.org/wiki/Colossus_of_Rhodes"}, {"meta._split_id": 0}]}, + ] + } + assert _and_or_to_list(_target_filter_to_metadata(test_filter, "meta")) == target_outcome + + +def test_mongo_filter_converter_and_or_meta_converted(): + test_filter = { + "$and": [ + {"url": "https://en.wikipedia.org/wiki/Colossus_of_Rhodes"}, + {"_split_id": 0}, + {"$or": [{"url": "https://en.wikipedia.org/wiki/Colossus_of_Rhodes"}, {"_split_id": 0}]}, + ] + } + + target_outcome = { + "$and": [ + {"meta.url": "https://en.wikipedia.org/wiki/Colossus_of_Rhodes"}, + {"meta._split_id": 0}, + {"$or": [{"meta.url": "https://en.wikipedia.org/wiki/Colossus_of_Rhodes"}, {"meta._split_id": 0}]}, + ] + } + + assert mongo_filter_converter(test_filter) == target_outcome + + +def test_mongo_filter_converter_falsey_empty_dict(): + assert mongo_filter_converter(None) == {} + assert mongo_filter_converter("") == {} + assert mongo_filter_converter({}) == {} + assert mongo_filter_converter([]) == {} + assert mongo_filter_converter(0) == {} + + +def test__target_filter_to_metadata_01(): + test_filter = {"url": "https://en.wikipedia.org/wiki/Colossus_of_Rhodes", "_split_id": 0} + + target_outcome = {"meta.url": "https://en.wikipedia.org/wiki/Colossus_of_Rhodes", "meta._split_id": 0} + + assert _target_filter_to_metadata(test_filter, "meta") == target_outcome + + +def test__target_filter_to_metadata_02(): + test_filter = {"$and": [{"url": "https://en.wikipedia.org/wiki/Colossus_of_Rhodes"}, {"_split_id": 0}]} + + target_outcome = {"$and": [{"meta.url": "https://en.wikipedia.org/wiki/Colossus_of_Rhodes"}, {"meta._split_id": 0}]} + + assert _target_filter_to_metadata(test_filter, "meta") == target_outcome + + +def test__target_filter_to_metadata_leave_id(): + test_filter = {"id": {"$in": ["b714102aa7ac3a9622d0d00caa55fa", "b3de1a673c1eb2876585405395a10c3d"]}} + + target_outcome = {"id": {"$in": ["b714102aa7ac3a9622d0d00caa55fa", "b3de1a673c1eb2876585405395a10c3d"]}} + + assert _target_filter_to_metadata(test_filter, "meta") == target_outcome + + +def test__and_or_to_list_01(): + test_filter = {"$and": {"url": "https://en.wikipedia.org/wiki/Colossus_of_Rhodes", "_split_id": 0}} + + target_outcome = {"$and": [{"url": "https://en.wikipedia.org/wiki/Colossus_of_Rhodes"}, {"_split_id": 0}]} + + assert _and_or_to_list(test_filter) == target_outcome + + +def test__and_or_to_list_02(): + test_filter = { + "$and": { + "url": "https://en.wikipedia.org/wiki/Colossus_of_Rhodes", + "_split_id": 0, + "$or": {"url": "https://en.wikipedia.org/wiki/Colossus_of_Rhodes", "_split_id": 0}, + } + } + + target_outcome = { + "$and": [ + {"url": "https://en.wikipedia.org/wiki/Colossus_of_Rhodes"}, + {"_split_id": 0}, + {"$or": [{"url": "https://en.wikipedia.org/wiki/Colossus_of_Rhodes"}, {"_split_id": 0}]}, + ] + } + + assert _and_or_to_list(test_filter) == target_outcome From 5031b39bdde5ea2b7b8d84a7c615c8654a438540 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Mon, 11 Dec 2023 11:20:02 -0800 Subject: [PATCH 03/12] Fix linting --- haystack/document_stores/mongodb_atlas.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/haystack/document_stores/mongodb_atlas.py b/haystack/document_stores/mongodb_atlas.py index 4118ab18d0..79c7a780d8 100644 --- a/haystack/document_stores/mongodb_atlas.py +++ b/haystack/document_stores/mongodb_atlas.py @@ -1,12 +1,12 @@ import re from typing import Dict, Generator, List, Optional, Union import numpy as np +from tqdm import tqdm from haystack.document_stores import BaseDocumentStore from haystack.errors import DocumentStoreError from haystack.nodes.retriever import DenseRetriever from haystack.schema import Document, FilterType from haystack.utils import get_batches_from_generator -from tqdm import tqdm from .mongodb_filters import mongo_filter_converter from ..lazy_imports import LazyImport @@ -37,7 +37,7 @@ def __init__( self.mongo_connection_string = _validate_mongo_connection_string(mongo_connection_string) self.database_name = _validate_database_name(database_name) self.collection_name = _validate_collection_name(collection_name) - self.connection = pymongo.MongoClient(self.mongo_connection_string) + self.connection: pymongo.MongoClient = pymongo.MongoClient(self.mongo_connection_string) self.database = self.connection[self.database_name] self.similarity = _validate_similarity(similarity) self.duplicate_documents = duplicate_documents @@ -268,7 +268,7 @@ def get_documents_by_id( result = self.get_all_documents_generator( index=index, - filters=mongo_filters, + filters=mongo_filters, # type: ignore [arg-type] return_embedding=return_embedding, batch_size=batch_size, headers=headers, @@ -412,13 +412,13 @@ def write_documents( ) as progress_bar: batches = get_batches_from_generator(mongo_documents, batch_size) for batch in batches: - match duplicate_documents: - case "skip": - operations = [UpdateOne({"id": doc["id"]}, {"$setOnInsert": doc}, upsert=True) for doc in batch] - case "fail": - operations = [InsertOne(doc) for doc in batch] - case _: - operations = [ReplaceOne({"id": doc["id"]}, upsert=True, replacement=doc) for doc in batch] + operations: list[Union[UpdateOne, InsertOne, ReplaceOne]] + if duplicate_documents == "skip": + operations = [UpdateOne({"id": doc["id"]}, {"$setOnInsert": doc}, upsert=True) for doc in batch] + elif duplicate_documents == "fail": + operations = [InsertOne(doc) for doc in batch] + else: + operations = [ReplaceOne({"id": doc["id"]}, upsert=True, replacement=doc) for doc in batch] collection.bulk_write(operations) progress_bar.update(len(batch)) From bb17185ea4ec6fe67c5bc019c54ceb1762a66f69 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Mon, 11 Dec 2023 11:31:11 -0800 Subject: [PATCH 04/12] Add release notes + remove extra roman import --- ...mongodb-document-store-34bd05d03717fb62.yaml | 4 ++++ test/document_stores/test_mongodb.py | 17 ++++++++++++++--- 2 files changed, 18 insertions(+), 3 deletions(-) create mode 100644 releasenotes/notes/add-mongodb-document-store-34bd05d03717fb62.yaml diff --git a/releasenotes/notes/add-mongodb-document-store-34bd05d03717fb62.yaml b/releasenotes/notes/add-mongodb-document-store-34bd05d03717fb62.yaml new file mode 100644 index 0000000000..238aa5d23d --- /dev/null +++ b/releasenotes/notes/add-mongodb-document-store-34bd05d03717fb62.yaml @@ -0,0 +1,4 @@ +--- +features: + - | + Add `MongoDBAtlasDocumentStore`, providing support for MongoDB Atlas as a document store. diff --git a/test/document_stores/test_mongodb.py b/test/document_stores/test_mongodb.py index 193c0e379e..2efc4ef639 100644 --- a/test/document_stores/test_mongodb.py +++ b/test/document_stores/test_mongodb.py @@ -4,7 +4,6 @@ import requests import pytest import numpy -import roman from haystack.document_stores.mongodb_filters import _target_filter_to_metadata, _and_or_to_list, mongo_filter_converter from haystack.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore, pymongo from haystack.schema import Document @@ -31,7 +30,19 @@ embedding_dim=768, ) -# Test data + +def roman_to_int(s: str) -> int: + mapping = {"I": 1, "V": 5, "X": 10, "L": 50, "C": 100, "D": 500, "M": 1000} + + result = 0 + + for i in range(len(s)): + if i < len(s) - 1 and mapping[s[i]] < mapping[s[i + 1]]: + result -= mapping[s[i]] + else: + result += mapping[s[i]] + + return result # Get the book "Around the World in 80 Days" from Project Gutenberg @@ -61,7 +72,7 @@ def divide_book_into_chapters(book) -> dict: chapter_match = re.match(r"CHAPTER\s+([IVXLCDM]+)\.*", line) if chapter_match: chapter_roman = chapter_match.group(1) - chapter_decimal = roman.fromRoman(chapter_roman) + chapter_decimal = roman_to_int(chapter_roman) current_chapter = f"CHAPTER {chapter_decimal}".title() chapters[current_chapter] = "" if current_chapter: From fc90bc6788c57730b40c794ea7d479af8b9f4c84 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Mon, 11 Dec 2023 11:39:03 -0800 Subject: [PATCH 05/12] Add pymongo driver info --- haystack/document_stores/mongodb_atlas.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/haystack/document_stores/mongodb_atlas.py b/haystack/document_stores/mongodb_atlas.py index 79c7a780d8..61ed365075 100644 --- a/haystack/document_stores/mongodb_atlas.py +++ b/haystack/document_stores/mongodb_atlas.py @@ -7,6 +7,7 @@ from haystack.nodes.retriever import DenseRetriever from haystack.schema import Document, FilterType from haystack.utils import get_batches_from_generator +from haystack import __version__ as haystack_version from .mongodb_filters import mongo_filter_converter from ..lazy_imports import LazyImport @@ -14,6 +15,7 @@ import pymongo from pymongo import InsertOne, ReplaceOne, UpdateOne from pymongo.collection import Collection + from pymongo.driver_info import DriverInfo METRIC_TYPES = ["euclidean", "cosine", "dotProduct"] DEFAULT_BATCH_SIZE = 50 @@ -37,7 +39,9 @@ def __init__( self.mongo_connection_string = _validate_mongo_connection_string(mongo_connection_string) self.database_name = _validate_database_name(database_name) self.collection_name = _validate_collection_name(collection_name) - self.connection: pymongo.MongoClient = pymongo.MongoClient(self.mongo_connection_string) + self.connection: pymongo.MongoClient = pymongo.MongoClient( + self.mongo_connection_string, driver=DriverInfo(name="Haystack", version=haystack_version) + ) self.database = self.connection[self.database_name] self.similarity = _validate_similarity(similarity) self.duplicate_documents = duplicate_documents From 8392881571d90f1103a276a55435d12dc3766ce6 Mon Sep 17 00:00:00 2001 From: Julian Risch Date: Tue, 12 Dec 2023 10:49:53 +0100 Subject: [PATCH 06/12] mypy and pylint fixes --- haystack/document_stores/mongodb_atlas.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/haystack/document_stores/mongodb_atlas.py b/haystack/document_stores/mongodb_atlas.py index 61ed365075..e6c8036868 100644 --- a/haystack/document_stores/mongodb_atlas.py +++ b/haystack/document_stores/mongodb_atlas.py @@ -36,6 +36,8 @@ def __init__( recreate_index: bool = False, ): mongodb_import.check() + super().__init__() + self.mongo_connection_string = _validate_mongo_connection_string(mongo_connection_string) self.database_name = _validate_database_name(database_name) self.collection_name = _validate_collection_name(collection_name) @@ -416,7 +418,7 @@ def write_documents( ) as progress_bar: batches = get_batches_from_generator(mongo_documents, batch_size) for batch in batches: - operations: list[Union[UpdateOne, InsertOne, ReplaceOne]] + operations: List[Union[UpdateOne, InsertOne, ReplaceOne]] if duplicate_documents == "skip": operations = [UpdateOne({"id": doc["id"]}, {"$setOnInsert": doc}, upsert=True) for doc in batch] elif duplicate_documents == "fail": From 2cb49431984ec0e82c69b80cc148dd8714b599ac Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Tue, 12 Dec 2023 10:30:50 -0800 Subject: [PATCH 07/12] Use future import annotations --- haystack/document_stores/mongodb_atlas.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/haystack/document_stores/mongodb_atlas.py b/haystack/document_stores/mongodb_atlas.py index e6c8036868..06216bdf64 100644 --- a/haystack/document_stores/mongodb_atlas.py +++ b/haystack/document_stores/mongodb_atlas.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import re from typing import Dict, Generator, List, Optional, Union import numpy as np From d7a1d6e4c40c7b28879197a9553e9634ad8ad368 Mon Sep 17 00:00:00 2001 From: Julian Risch Date: Wed, 13 Dec 2023 14:32:06 +0100 Subject: [PATCH 08/12] fix errors in tests --- haystack/document_stores/mongodb_atlas.py | 2 +- test/document_stores/test_mongodb.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/haystack/document_stores/mongodb_atlas.py b/haystack/document_stores/mongodb_atlas.py index 06216bdf64..801e171fce 100644 --- a/haystack/document_stores/mongodb_atlas.py +++ b/haystack/document_stores/mongodb_atlas.py @@ -548,7 +548,7 @@ def _validate_index_name(index_name): def mongo_doc_to_haystack_doc(mongo_doc) -> Document: embedding = mongo_doc.get("embedding", None) - score = mongo_doc.get["score"] + score = mongo_doc.get("score") return Document( id=mongo_doc["id"], diff --git a/test/document_stores/test_mongodb.py b/test/document_stores/test_mongodb.py index 2efc4ef639..8265b1a3fc 100644 --- a/test/document_stores/test_mongodb.py +++ b/test/document_stores/test_mongodb.py @@ -158,7 +158,7 @@ def test_write_documents_fail(): processed_documents = processor.process([documents[0]]) document_store.write_documents(processed_documents) document_store.write_documents(processed_documents) - with pytest.raises(pymongo.BulkWriteError): + with pytest.raises(pymongo.errors.BulkWriteError): document_store.write_documents(processed_documents, duplicate_documents="fail") From 9e13a23c1c561beea410aaf3556fb7e12ee38fef Mon Sep 17 00:00:00 2001 From: Julian Risch Date: Wed, 13 Dec 2023 16:05:31 +0100 Subject: [PATCH 09/12] remove from __future__ import --- haystack/document_stores/mongodb_atlas.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/haystack/document_stores/mongodb_atlas.py b/haystack/document_stores/mongodb_atlas.py index 801e171fce..16772fb937 100644 --- a/haystack/document_stores/mongodb_atlas.py +++ b/haystack/document_stores/mongodb_atlas.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import re from typing import Dict, Generator, List, Optional, Union import numpy as np @@ -16,7 +14,6 @@ with LazyImport("Run 'pip install farm-haystack[mongodb]'") as mongodb_import: import pymongo from pymongo import InsertOne, ReplaceOne, UpdateOne - from pymongo.collection import Collection from pymongo.driver_info import DriverInfo METRIC_TYPES = ["euclidean", "cosine", "dotProduct"] @@ -67,7 +64,7 @@ def __init__( def _create_document_field_map(self) -> Dict: return {self.embedding_field: "embedding"} - def _get_collection(self, index=None) -> Collection: + def _get_collection(self, index=None) -> "pymongo.collection.Collection": """ Returns the collection named by index or returns the collection specified when the driver was initialized. From d7bcd79376079803ccdd3715cde899840106d0bb Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Wed, 13 Dec 2023 09:41:34 -0800 Subject: [PATCH 10/12] Fix --- test/document_stores/test_mongodb.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/document_stores/test_mongodb.py b/test/document_stores/test_mongodb.py index 8265b1a3fc..7e946d10fa 100644 --- a/test/document_stores/test_mongodb.py +++ b/test/document_stores/test_mongodb.py @@ -60,7 +60,7 @@ def get_book_local(): return text -get_book = get_book_local +get_book = get_book_online # Divide the book into chapters @@ -429,7 +429,7 @@ def test_query_by_embedding_default_topk(): ) embedding = retriever.embed_queries(["How much money was stolen from the bank?"])[0] results = document_store.query_by_embedding(query_emb=embedding) - assert results[0].embedding == None + assert results[0].embedding is None assert len(results) == 10 From 0ca2114d0ae8567e018e66add37414ac7feb6a4a Mon Sep 17 00:00:00 2001 From: Julian Risch Date: Thu, 14 Dec 2023 09:43:49 +0100 Subject: [PATCH 11/12] remove tests --- test/document_stores/test_mongodb.py | 569 --------------------------- 1 file changed, 569 deletions(-) delete mode 100644 test/document_stores/test_mongodb.py diff --git a/test/document_stores/test_mongodb.py b/test/document_stores/test_mongodb.py deleted file mode 100644 index 7e946d10fa..0000000000 --- a/test/document_stores/test_mongodb.py +++ /dev/null @@ -1,569 +0,0 @@ -import contextlib -import os -import re -import requests -import pytest -import numpy -from haystack.document_stores.mongodb_filters import _target_filter_to_metadata, _and_or_to_list, mongo_filter_converter -from haystack.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore, pymongo -from haystack.schema import Document -from haystack.nodes import PreProcessor, EmbeddingRetriever - -pytestmark = pytest.mark.integration - -mongo_atlas_database = "database01" -mongo_atlas_collection = "test_80_days" - -mongo_atlas_username = os.getenv("MONGO_ATLAS_USERNAME") -mongo_atlas_password = os.getenv("MONGO_ATLAS_PASSWORD") -mongo_atlas_host = os.getenv("MONGO_ATLAS_HOST") -mongo_atlas_connection_params = {"retryWrites": "true", "w": "majority"} -mongo_atlas_params_string = "&".join([f"{key}={value}" for key, value in mongo_atlas_connection_params.items()]) -mongo_atlas_connection_string = ( - f"mongodb+srv://{mongo_atlas_username}:{mongo_atlas_password}@{mongo_atlas_host}/?{mongo_atlas_params_string}" -) - -document_store = MongoDBAtlasDocumentStore( - mongo_connection_string=mongo_atlas_connection_string, - database_name=mongo_atlas_database, - collection_name=mongo_atlas_collection, - embedding_dim=768, -) - - -def roman_to_int(s: str) -> int: - mapping = {"I": 1, "V": 5, "X": 10, "L": 50, "C": 100, "D": 500, "M": 1000} - - result = 0 - - for i in range(len(s)): - if i < len(s) - 1 and mapping[s[i]] < mapping[s[i + 1]]: - result -= mapping[s[i]] - else: - result += mapping[s[i]] - - return result - - -# Get the book "Around the World in 80 Days" from Project Gutenberg -def get_book_online(): - response = requests.get("https://www.gutenberg.org/ebooks/103.txt.utf-8") - if response.status_code != 200: - raise requests.HTTPError(f"HTTP error {response.status_code}") - else: - return response.text - - -def get_book_local(): - with open("test-data/80_days.txt", "r", encoding="utf-8") as file: - text = file.read() - return text - - -get_book = get_book_online - - -# Divide the book into chapters -def divide_book_into_chapters(book) -> dict: - lines = book.split("\n") - current_chapter = None - chapters = {} - for line in lines: - chapter_match = re.match(r"CHAPTER\s+([IVXLCDM]+)\.*", line) - if chapter_match: - chapter_roman = chapter_match.group(1) - chapter_decimal = roman_to_int(chapter_roman) - current_chapter = f"CHAPTER {chapter_decimal}".title() - chapters[current_chapter] = "" - if current_chapter: - chapters[current_chapter] += line + "\n" - return chapters - - -book = get_book() -chapters = divide_book_into_chapters(book) -documents = [ - Document(content=chapters[f"Chapter {n}"], meta={"book": "Around the World in 80 Days", "Chapter": n}) - for n in range(1, len(chapters) + 1) -] - - -def test_write_documents_skip(): - document_store.delete_documents() - - processor = PreProcessor( - clean_empty_lines=True, - clean_whitespace=True, - clean_header_footer=True, - remove_substrings=None, - split_by="word", - split_length=200, - split_respect_sentence_boundary=True, - split_overlap=0, - max_chars_check=10_000, - ) - - processed_documents = processor.process([documents[0]]) - document_store.write_documents(processed_documents) - - collection = document_store._get_collection() - - filters = {"Chapter": 1, "_split_id": 0} - collection.update_one(mongo_filter_converter(filters), {"$set": {"content": "No Content"}}) - document_store.write_documents(processed_documents, duplicate_documents="skip") - assert document_store.get_all_documents(filters=filters)[0].content == "No Content" - - -def test_write_documents_overwrite(): - document_store.delete_documents() - - processor = PreProcessor( - clean_empty_lines=True, - clean_whitespace=True, - clean_header_footer=True, - remove_substrings=None, - split_by="word", - split_length=200, - split_respect_sentence_boundary=True, - split_overlap=0, - max_chars_check=10_000, - ) - - processed_documents = processor.process([documents[0]]) - document_store.write_documents(processed_documents) - - collection = document_store._get_collection() - - filters = {"Chapter": 1, "_split_id": 0} - collection.update_one(mongo_filter_converter(filters), {"$set": {"content": "No Content"}}) - document_store.write_documents(processed_documents, duplicate_documents="overwrite") - assert document_store.get_all_documents(filters=filters)[0].content != "No Content" - - -def test_write_documents_fail(): - document_store.delete_documents() - - processor = PreProcessor( - clean_empty_lines=True, - clean_whitespace=True, - clean_header_footer=True, - remove_substrings=None, - split_by="word", - split_length=200, - split_respect_sentence_boundary=True, - split_overlap=0, - max_chars_check=10_000, - ) - - processed_documents = processor.process([documents[0]]) - document_store.write_documents(processed_documents) - document_store.write_documents(processed_documents) - with pytest.raises(pymongo.errors.BulkWriteError): - document_store.write_documents(processed_documents, duplicate_documents="fail") - - -def test_write_documents(): - document_store.delete_documents() - - processor = PreProcessor( - clean_empty_lines=True, - clean_whitespace=True, - clean_header_footer=True, - remove_substrings=None, - split_by="word", - split_length=200, - split_respect_sentence_boundary=True, - split_overlap=0, - max_chars_check=10_000, - ) - - processed_documents = processor.process(documents) - document_store.write_documents(processed_documents) - - assert document_store.get_document_count() == 373 - assert document_store.get_all_documents(return_embedding=True)[0].embedding is None - - -def test_get_document_count_without_embeddings_a(): - assert document_store.get_document_count(only_documents_without_embedding=True) == 373 - - -def test_get_embedding_count_a(): - assert document_store.get_embedding_count() == 0 - - -def test_get_document_count_without_embeddings_with_filter(): - assert document_store.get_document_count(filters={"Chapter": 1}, only_documents_without_embedding=True) == 8 - - -def test_update_embeddings_filtered(): - retriever = EmbeddingRetriever( - document_store=document_store, - embedding_model="sentence-transformers/all-mpnet-base-v2", # Recommended here: https://www.sbert.net/docs/pretrained_models.html - model_format="sentence_transformers", - top_k=10, - ) - filters = {"Chapter": 1, "_split_id": 0} - document_store.update_embeddings(retriever, batch_size=30, filters=filters) - assert isinstance( - document_store.get_all_documents(return_embedding=True, filters=filters)[0].embedding, numpy.ndarray - ) - - -def test_update_embeddings(): - retriever = EmbeddingRetriever( - document_store=document_store, - embedding_model="sentence-transformers/all-mpnet-base-v2", # Recommended here: https://www.sbert.net/docs/pretrained_models.html - model_format="sentence_transformers", - top_k=10, - ) - - document_store.update_embeddings(retriever, batch_size=30) - assert isinstance(document_store.get_all_documents(return_embedding=True)[0].embedding, numpy.ndarray) - - -def test_update_embeddings_not_existing(): - retriever = EmbeddingRetriever( - document_store=document_store, - embedding_model="sentence-transformers/all-mpnet-base-v2", # Recommended here: https://www.sbert.net/docs/pretrained_models.html - model_format="sentence_transformers", - top_k=10, - ) - filters = {"Chapter": 1, "_split_id": 0} - filters2 = {"Chapter": 1, "_split_id": 1} - collection = document_store._get_collection() - - collection.update_one(mongo_filter_converter(filters), {"$set": {"embedding": None}}) - collection.update_one(mongo_filter_converter(filters2), {"$set": {"embedding": "not_an_embedding"}}) - - document_store.update_embeddings(retriever, batch_size=30, update_existing_embeddings=False) - assert isinstance(collection.find_one(mongo_filter_converter(filters))["embedding"], list) - assert collection.find_one(mongo_filter_converter(filters2))["embedding"] == "not_an_embedding" - - -def test_get_embedding_count_b(): - assert document_store.get_embedding_count() == 373 - - -# Getting documents - - -def test_get_all_documents_without_embedings(): - assert document_store.get_all_documents()[0].embedding is None - assert document_store.get_all_documents(return_embedding=False)[0].embedding is None - - -def test_get_all_documents_with_embedings(): - assert isinstance(document_store.get_all_documents(return_embedding=True)[0].embedding, numpy.ndarray) - - -def test_get_all_documents(): - assert len(document_store.get_all_documents()) == 373 - - -def test_get_all_documents_filtered(): - assert len(document_store.get_all_documents(filters={"Chapter": 1})) == 8 - - -def test_get_document_by_id_a(): - documents = document_store.get_all_documents(filters={"Chapter": 1, "_split_id": 0}) - assert len(documents) == 1 - document_id = documents[0].id - assert isinstance(document_store.get_document_by_id(id=document_id), Document) - - -def test_get_documents_by_id_b(): - documents = document_store.get_all_documents(filters={"Chapter": 1}) - document_ids = [document.id for document in documents] - assert len(document_ids) > 1 - assert len(document_store.get_documents_by_id(ids=document_ids)) == len(document_ids) - - -def test_get_all_documents_headers_throws(): - with pytest.raises(NotImplementedError): - document_store.get_all_documents(headers={"key": "value"}) - - -def test_get_document_count(): - assert document_store.get_document_count() == 373 - - -def test_get_document_count_without_embeddings_b(): - assert document_store.get_document_count(only_documents_without_embedding=True) == 0 - - -def test_get_document_count_filtered(): - assert document_store.get_document_count(filters={"Chapter": 1}) == 8 - assert document_store.get_document_count(filters={"Chapter": 1, "_split_id": 0}) == 1 - - -# Updating document meta - - -def test_update_document_meta(): - document = document_store.get_all_documents(filters={"Chapter": 1, "_split_id": 0})[0] - new_meta = document.meta - new_meta["new_field"] = "New metadata" - document_store.update_document_meta(id=document.id, meta=new_meta) - updated_document = document_store.get_all_documents(filters={"Chapter": 1, "_split_id": 0})[0] - assert "new_field" in updated_document.meta - assert updated_document.meta["new_field"] == "New metadata" - - -# Deleting documents - - -def test_delete_documents_filtered(): - document_store.delete_documents(filters={"Chapter": 1, "_split_id": 0}) - assert document_store.get_document_count() == 372 - - -def test_delete_documents_by_id(): - documents = document_store.get_all_documents(filters={"Chapter": 1}) - document_ids = [document.id for document in documents] - document_store.delete_documents(ids=document_ids) - assert document_store.get_document_count() == 365 - - -def test_delete_documents_by_id_filtered(): - documents = document_store.get_all_documents(filters={"Chapter": 2}) - document_ids = [document.id for document in documents] - document_store.delete_documents(ids=document_ids, filters={"_split_id": 0}) # Only deletes the intersection - assert document_store.get_document_count() == 364 - - -def test_delete_documents(): - document_store.delete_documents() - assert document_store.get_document_count() == 0 - - -def test_delete_documents_headers_throws(): - with pytest.raises(NotImplementedError): - document_store.delete_documents(headers={"key": "value"}) - - -def test_delete_index(): - document_store.delete_index() - client = pymongo.MongoClient(mongo_atlas_connection_string) - database = client[mongo_atlas_database] - assert "test_80_days" not in database.list_collection_names() - - -def test_delete_index_with_index(): - client = pymongo.MongoClient(mongo_atlas_connection_string) - database = client[mongo_atlas_database] - with contextlib.suppress(Exception): - database.create_collection("deleteme") - assert "deleteme" in database.list_collection_names() - document_store.delete_index(index="deleteme") - assert "deleteme" not in database.list_collection_names() - - -def test__create_document_field_map_a(): - assert document_store._create_document_field_map() == {"embedding": "embedding"} - - -def test__create_document_field_map_b(): - document_store = MongoDBAtlasDocumentStore( - mongo_connection_string=mongo_atlas_connection_string, - database_name=mongo_atlas_database, - collection_name=mongo_atlas_collection, - embedding_dim=768, - embedding_field="emb", - ) - assert document_store._create_document_field_map() == {"emb": "embedding"} - - -def test__get_collection_no_index(): - collection = document_store._get_collection() - assert collection.name == "test_80_days" - - -def test__get_collection_with_index(): - collection = document_store._get_collection(index="index_abcdefg") - assert collection.name == "index_abcdefg" - - -def test__get_collection_invalid_index(): - with pytest.raises(ValueError): - document_store._get_collection(index="index_a!!bcdefg") - - -def test_write_documents_index(): - document_store.delete_documents() - - processor = PreProcessor( - clean_empty_lines=True, - clean_whitespace=True, - clean_header_footer=True, - remove_substrings=None, - split_by="word", - split_length=200, - split_respect_sentence_boundary=True, - split_overlap=0, - max_chars_check=10_000, - ) - - retriever = EmbeddingRetriever( - document_store=document_store, - embedding_model="sentence-transformers/all-mpnet-base-v2", # Recommended here: https://www.sbert.net/docs/pretrained_models.html - model_format="sentence_transformers", - top_k=10, - ) - - processed_documents = processor.process(documents) - document_store.write_documents(processed_documents) - - assert document_store.get_document_count() == 373 - assert document_store.get_all_documents(return_embedding=True)[0].embedding is None - - document_store.update_embeddings(retriever, batch_size=30) - assert isinstance(document_store.get_all_documents(return_embedding=True)[0].embedding, numpy.ndarray) - - -def test_query_by_embedding_default_topk(): - retriever = EmbeddingRetriever( - document_store=document_store, - embedding_model="sentence-transformers/all-mpnet-base-v2", # Recommended here: https://www.sbert.net/docs/pretrained_models.html - model_format="sentence_transformers", - ) - embedding = retriever.embed_queries(["How much money was stolen from the bank?"])[0] - results = document_store.query_by_embedding(query_emb=embedding) - assert results[0].embedding is None - assert len(results) == 10 - - -def test_query_by_embedding_default_topk_4(): - retriever = EmbeddingRetriever( - document_store=document_store, - embedding_model="sentence-transformers/all-mpnet-base-v2", # Recommended here: https://www.sbert.net/docs/pretrained_models.html - model_format="sentence_transformers", - ) - embedding = retriever.embed_queries(["How much money was stolen from the bank?"])[0] - results = document_store.query_by_embedding(query_emb=embedding, top_k=4) - assert results[0].embedding == None - assert len(results) == 4 - - -def test_query_by_embedding_filtered(): - retriever = EmbeddingRetriever( - document_store=document_store, - embedding_model="sentence-transformers/all-mpnet-base-v2", # Recommended here: https://www.sbert.net/docs/pretrained_models.html - model_format="sentence_transformers", - top_k=10, - ) - embedding = retriever.embed_queries(["Who was Phileas Fogg?"])[0] - results = document_store.query_by_embedding(query_emb=embedding, filters={"Chapter": 1}) - assert len(results) == 3 - - -def test_query_by_embedding_include_embedding(): - retriever = EmbeddingRetriever( - document_store=document_store, - embedding_model="sentence-transformers/all-mpnet-base-v2", # Recommended here: https://www.sbert.net/docs/pretrained_models.html - model_format="sentence_transformers", - top_k=10, - ) - embedding = retriever.embed_queries(["How much money was stolen from the bank?"])[0] - results = document_store.query_by_embedding(query_emb=embedding, return_embedding=True) - assert isinstance(results[0].embedding, numpy.ndarray) - assert len(results) == 10 - - -def test_and_or_meta_converted(): - test_filter = { - "$and": [ - {"url": "https://en.wikipedia.org/wiki/Colossus_of_Rhodes"}, - {"_split_id": 0}, - {"$or": [{"url": "https://en.wikipedia.org/wiki/Colossus_of_Rhodes"}, {"_split_id": 0}]}, - ] - } - - target_outcome = { - "$and": [ - {"meta.url": "https://en.wikipedia.org/wiki/Colossus_of_Rhodes"}, - {"meta._split_id": 0}, - {"$or": [{"meta.url": "https://en.wikipedia.org/wiki/Colossus_of_Rhodes"}, {"meta._split_id": 0}]}, - ] - } - assert _and_or_to_list(_target_filter_to_metadata(test_filter, "meta")) == target_outcome - - -def test_mongo_filter_converter_and_or_meta_converted(): - test_filter = { - "$and": [ - {"url": "https://en.wikipedia.org/wiki/Colossus_of_Rhodes"}, - {"_split_id": 0}, - {"$or": [{"url": "https://en.wikipedia.org/wiki/Colossus_of_Rhodes"}, {"_split_id": 0}]}, - ] - } - - target_outcome = { - "$and": [ - {"meta.url": "https://en.wikipedia.org/wiki/Colossus_of_Rhodes"}, - {"meta._split_id": 0}, - {"$or": [{"meta.url": "https://en.wikipedia.org/wiki/Colossus_of_Rhodes"}, {"meta._split_id": 0}]}, - ] - } - - assert mongo_filter_converter(test_filter) == target_outcome - - -def test_mongo_filter_converter_falsey_empty_dict(): - assert mongo_filter_converter(None) == {} - assert mongo_filter_converter("") == {} - assert mongo_filter_converter({}) == {} - assert mongo_filter_converter([]) == {} - assert mongo_filter_converter(0) == {} - - -def test__target_filter_to_metadata_01(): - test_filter = {"url": "https://en.wikipedia.org/wiki/Colossus_of_Rhodes", "_split_id": 0} - - target_outcome = {"meta.url": "https://en.wikipedia.org/wiki/Colossus_of_Rhodes", "meta._split_id": 0} - - assert _target_filter_to_metadata(test_filter, "meta") == target_outcome - - -def test__target_filter_to_metadata_02(): - test_filter = {"$and": [{"url": "https://en.wikipedia.org/wiki/Colossus_of_Rhodes"}, {"_split_id": 0}]} - - target_outcome = {"$and": [{"meta.url": "https://en.wikipedia.org/wiki/Colossus_of_Rhodes"}, {"meta._split_id": 0}]} - - assert _target_filter_to_metadata(test_filter, "meta") == target_outcome - - -def test__target_filter_to_metadata_leave_id(): - test_filter = {"id": {"$in": ["b714102aa7ac3a9622d0d00caa55fa", "b3de1a673c1eb2876585405395a10c3d"]}} - - target_outcome = {"id": {"$in": ["b714102aa7ac3a9622d0d00caa55fa", "b3de1a673c1eb2876585405395a10c3d"]}} - - assert _target_filter_to_metadata(test_filter, "meta") == target_outcome - - -def test__and_or_to_list_01(): - test_filter = {"$and": {"url": "https://en.wikipedia.org/wiki/Colossus_of_Rhodes", "_split_id": 0}} - - target_outcome = {"$and": [{"url": "https://en.wikipedia.org/wiki/Colossus_of_Rhodes"}, {"_split_id": 0}]} - - assert _and_or_to_list(test_filter) == target_outcome - - -def test__and_or_to_list_02(): - test_filter = { - "$and": { - "url": "https://en.wikipedia.org/wiki/Colossus_of_Rhodes", - "_split_id": 0, - "$or": {"url": "https://en.wikipedia.org/wiki/Colossus_of_Rhodes", "_split_id": 0}, - } - } - - target_outcome = { - "$and": [ - {"url": "https://en.wikipedia.org/wiki/Colossus_of_Rhodes"}, - {"_split_id": 0}, - {"$or": [{"url": "https://en.wikipedia.org/wiki/Colossus_of_Rhodes"}, {"_split_id": 0}]}, - ] - } - - assert _and_or_to_list(test_filter) == target_outcome From ca079f052de0a2e94b714da73155ea28433f4723 Mon Sep 17 00:00:00 2001 From: Julian Risch Date: Thu, 14 Dec 2023 09:59:12 +0100 Subject: [PATCH 12/12] add docstring --- haystack/document_stores/mongodb_atlas.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/haystack/document_stores/mongodb_atlas.py b/haystack/document_stores/mongodb_atlas.py index 16772fb937..5b4b42b65d 100644 --- a/haystack/document_stores/mongodb_atlas.py +++ b/haystack/document_stores/mongodb_atlas.py @@ -34,6 +34,21 @@ def __init__( duplicate_documents: str = "overwrite", recreate_index: bool = False, ): + """ + Document Store using MongoDB Atlas as a backend (https://www.mongodb.com/docs/atlas/getting-started/). + It is compatible with EmbeddingRetrievers and filters. + + :param mongo_connection_string: MongoDB Atlas connection string in the format: "mongodb+srv://{mongo_atlas_username}:{mongo_atlas_password}@{mongo_atlas_host}/?{mongo_atlas_params_string}". + :param database_name: Name of the database to use. + :param collection_name: Name of the collection to use. + :param embedding_dim: Dimensionality of embeddings, 768 by default. + :param return_embedding: Whether to return document embeddings when returning documents. + :param similarity: The similarity function to use for the embeddings. One of "euclidean", "cosine" or "dotProduct". "cosine" is the default. + :param embedding_field: The name of the field in the document that contains the embedding. + :param progress_bar: Whether to show a progress bar when writing documents. + :param duplicate_documents: How to handle duplicate documents. One of "overwrite", "skip" or "fail". "overwrite" is the default. + :param recreate_index: Whether to recreate the index when initializing the document store. + """ mongodb_import.check() super().__init__()