diff --git a/haystack/document_stores/pinecone.py b/haystack/document_stores/pinecone.py index 33da45f1f2..243ff19a9b 100644 --- a/haystack/document_stores/pinecone.py +++ b/haystack/document_stores/pinecone.py @@ -79,6 +79,8 @@ def __init__( environment: str = "us-west1-gcp", pinecone_index: Optional["pinecone.Index"] = None, embedding_dim: int = 768, + pods: int = 1, + pod_type: str = "p1.x1", return_embedding: bool = False, index: str = "document", similarity: str = "cosine", @@ -98,6 +100,8 @@ def __init__( regions are supported, contact Pinecone [here](https://www.pinecone.io/contact/) if required. :param pinecone_index: pinecone-client Index object, an index will be initialized or loaded if not specified. :param embedding_dim: The embedding vector size. + :param pods: The number of pods for the index to use, including replicas. Defaults to 1. + :param pod_type: The type of pod to use. Defaults to `"p1.x1"`. :param return_embedding: Whether to return document embeddings. :param index: Name of index in document store to use. :param similarity: The similarity function used to compare document vectors. `"cosine"` is the default @@ -151,6 +155,8 @@ def __init__( self.duplicate_documents = duplicate_documents # Pinecone index params + self.pods = pods + self.pod_type = pod_type self.replicas = replicas self.shards = shards self.namespace = namespace @@ -182,6 +188,8 @@ def __init__( else: self.pinecone_indexes[self.index] = self._create_index( embedding_dim=self.embedding_dim, + pods=self.pods, + pod_type=self.pod_type, index=self.index, metric_type=self.metric_type, replicas=self.replicas, @@ -199,6 +207,8 @@ def _index(self, index) -> str: def _create_index( self, embedding_dim: int, + pods: int = 1, + pod_type: str = "p1.x1", index: Optional[str] = None, metric_type: Optional[str] = "cosine", replicas: Optional[int] = 1, @@ -225,6 +235,8 @@ def _create_index( pinecone.create_index( name=index, dimension=embedding_dim, + pods=pods, + pod_type=pod_type, metric=metric_type, replicas=replicas, shards=shards, @@ -254,6 +266,8 @@ def _index_connection_exists(self, index: str, create: bool = False) -> Optional if create: return self._create_index( embedding_dim=self.embedding_dim, + pods=self.pods, + pod_type=self.pod_type, index=index, metric_type=self.metric_type, replicas=self.replicas, diff --git a/releasenotes/notes/add-arg-to-PineconeDocumentStore-984add063663e70b.yaml b/releasenotes/notes/add-arg-to-PineconeDocumentStore-984add063663e70b.yaml new file mode 100644 index 0000000000..75629e23f7 --- /dev/null +++ b/releasenotes/notes/add-arg-to-PineconeDocumentStore-984add063663e70b.yaml @@ -0,0 +1,4 @@ +--- +enhancements: + - | + Users can now define the number of pods and pod type directly when creating a PineconeDocumentStore instance. diff --git a/test/document_stores/test_pinecone.py b/test/document_stores/test_pinecone.py index 6bacb9e48b..7d92aef3e4 100644 --- a/test/document_stores/test_pinecone.py +++ b/test/document_stores/test_pinecone.py @@ -38,11 +38,16 @@ def ds(self, monkeypatch, request) -> PineconeDocumentStore: monkeypatch.setattr(f"pinecone.{fname}", function, raising=False) for cname, class_ in getmembers(pinecone_mock, isclass): monkeypatch.setattr(f"pinecone.{cname}", class_, raising=False) + params = getattr(request, "param", {}) + pods = params.get("pods", None) + pod_type = params.get("pod_type", None) return PineconeDocumentStore( api_key=os.environ.get("PINECONE_API_KEY") or "fake-pinecone-test-key", embedding_dim=768, embedding_field="embedding", + pods=pods, + pod_type=pod_type, index="haystack_tests", similarity="cosine", recreate_index=True, diff --git a/test/mocks/pinecone.py b/test/mocks/pinecone.py index e25255f2c2..b74dd459fe 100644 --- a/test/mocks/pinecone.py +++ b/test/mocks/pinecone.py @@ -45,6 +45,8 @@ def __init__( api_key: Optional[str] = None, environment: Optional[str] = None, dimension: Optional[int] = None, + pods: Optional[int] = None, + pod_type: Optional[str] = None, metric: Optional[str] = None, replicas: Optional[int] = None, shards: Optional[int] = None, @@ -55,6 +57,8 @@ def __init__( self.environment = environment self.dimension = dimension self.metric = metric + self.pods = pods + self.pod_type = pod_type self.replicas = replicas self.shards = shards self.metadata_config = metadata_config @@ -338,6 +342,8 @@ def create_index( dimension: int, metric: str = "cosine", replicas: int = 1, + pods: int = 1, + pod_type: str = "p1.x1", shards: int = 1, metadata_config: Optional[dict] = None, ): @@ -348,6 +354,8 @@ def create_index( dimension=dimension, metric=metric, replicas=replicas, + pods=pods, + pod_type=pod_type, shards=shards, metadata_config=metadata_config, )