diff --git a/knowledgehub/storages/__init__.py b/knowledgehub/storages/__init__.py index d700d60..86225a9 100644 --- a/knowledgehub/storages/__init__.py +++ b/knowledgehub/storages/__init__.py @@ -1,10 +1,15 @@ -from .docstores import BaseDocumentStore, InMemoryDocumentStore +from .docstores import ( + BaseDocumentStore, + ElasticsearchDocumentStore, + InMemoryDocumentStore, +) from .vectorstores import BaseVectorStore, ChromaVectorStore, InMemoryVectorStore __all__ = [ # Document stores "BaseDocumentStore", "InMemoryDocumentStore", + "ElasticsearchDocumentStore", # Vector stores "BaseVectorStore", "ChromaVectorStore", diff --git a/knowledgehub/storages/docstores/__init__.py b/knowledgehub/storages/docstores/__init__.py index bee4fc5..a592c08 100644 --- a/knowledgehub/storages/docstores/__init__.py +++ b/knowledgehub/storages/docstores/__init__.py @@ -1,4 +1,5 @@ from .base import BaseDocumentStore +from .elasticsearch import ElasticsearchDocumentStore from .in_memory import InMemoryDocumentStore -__all__ = ["BaseDocumentStore", "InMemoryDocumentStore"] +__all__ = ["BaseDocumentStore", "InMemoryDocumentStore", "ElasticsearchDocumentStore"] diff --git a/knowledgehub/storages/docstores/base.py b/knowledgehub/storages/docstores/base.py index e057f0b..62c9314 100644 --- a/knowledgehub/storages/docstores/base.py +++ b/knowledgehub/storages/docstores/base.py @@ -17,14 +17,13 @@ class BaseDocumentStore(ABC): self, docs: Union[Document, List[Document]], ids: Optional[Union[List[str], str]] = None, - exist_ok: bool = False, + **kwargs, ): """Add document into document store Args: docs: Document or list of documents ids: List of ids of the documents. Optional, if not set will use doc.doc_id - exist_ok: If True, will not raise error if document already exist """ ... @@ -34,10 +33,15 @@ class BaseDocumentStore(ABC): ... @abstractmethod - def get_all(self) -> dict: + def get_all(self) -> List[Document]: """Get all documents""" ... + @abstractmethod + def count(self) -> int: + """Count number of documents""" + ... + @abstractmethod def delete(self, ids: Union[List[str], str]): """Delete document by id""" diff --git a/knowledgehub/storages/docstores/elasticsearch.py b/knowledgehub/storages/docstores/elasticsearch.py new file mode 100644 index 0000000..723e960 --- /dev/null +++ b/knowledgehub/storages/docstores/elasticsearch.py @@ -0,0 +1,164 @@ +from pathlib import Path +from typing import List, Optional, Union + +from ...base import Document +from .base import BaseDocumentStore + +MAX_DOCS_TO_GET = 10**4 + + +class ElasticsearchDocumentStore(BaseDocumentStore): + """Simple memory document store that store document in a dictionary""" + + def __init__( + self, + index_name: str = "docstore", + elasticsearch_url: str = "http://localhost:9200", + k1: float = 2.0, + b: float = 0.75, + ): + try: + from elasticsearch import Elasticsearch + from elasticsearch.helpers import bulk + except ImportError: + raise ImportError( + "To use ElaticsearchDocstore please install `pip install elasticsearch`" + ) + + self.elasticsearch_url = elasticsearch_url + self.index_name = index_name + + # Create an Elasticsearch client instance + self.client = Elasticsearch(elasticsearch_url) + self.es_bulk = bulk + # Define the index settings and mappings + settings = { + "analysis": {"analyzer": {"default": {"type": "standard"}}}, + "similarity": { + "custom_bm25": { + "type": "BM25", + "k1": k1, + "b": b, + } + }, + } + mappings = { + "properties": { + "content": { + "type": "text", + "similarity": "custom_bm25", # Use the custom BM25 similarity + } + } + } + + # Create the index with the specified settings and mappings + if not self.client.indices.exists(index=index_name): + self.client.indices.create( + index=index_name, mappings=mappings, settings=settings + ) + + def add( + self, + docs: Union[Document, List[Document]], + ids: Optional[Union[List[str], str]] = None, + **kwargs + ): + """Add document into document store + + Args: + docs: list of documents to add + ids: specify the ids of documents to add or + use existing doc.doc_id + refresh_indices: request Elasticsearch to update + its index (default to True) + """ + refresh_indices = kwargs.pop("refresh_indices", True) + + if ids and not isinstance(ids, list): + ids = [ids] + if not isinstance(docs, list): + docs = [docs] + doc_ids = ids if ids else [doc.doc_id for doc in docs] + + requests = [] + for doc_id, doc in zip(doc_ids, docs): + text = doc.text + metadata = doc.metadata + request = { + "_op_type": "index", + "_index": self.index_name, + "content": text, + "metadata": metadata, + "_id": doc_id, + } + requests.append(request) + self.es_bulk(self.client, requests) + + if refresh_indices: + self.client.indices.refresh(index=self.index_name) + + def query_raw(self, query: dict) -> List[Document]: + """Query Elasticsearch store using query format of ES client + + Args: + query (dict): Elasticsearch query format + + Returns: + List[Document]: List of result documents + """ + res = self.client.search(index=self.index_name, body=query) + docs = [] + for r in res["hits"]["hits"]: + docs.append( + Document( + id_=r["_id"], + text=r["_source"]["content"], + metadata=r["_source"]["metadata"], + ) + ) + return docs + + def query(self, query: str, top_k: int = 10) -> List[Document]: + """Search Elasticsearch docstore using search query (BM25) + + Args: + query (str): query text + top_k (int, optional): number of + top documents to return. Defaults to 10. + + Returns: + List[Document]: List of result documents + """ + query_dict = {"query": {"match": {"content": query}}, "size": top_k} + return self.query_raw(query_dict) + + def get(self, ids: Union[List[str], str]) -> List[Document]: + """Get document by id""" + if not isinstance(ids, list): + ids = [ids] + query_dict = {"query": {"terms": {"_id": ids}}} + return self.query_raw(query_dict) + + def count(self) -> int: + """Count number of documents""" + count = int( + self.client.cat.count(index=self.index_name, format="json")[0]["count"] + ) + return count + + def get_all(self) -> List[Document]: + """Get all documents""" + query_dict = {"query": {"match_all": {}}, "size": MAX_DOCS_TO_GET} + return self.query_raw(query_dict) + + def delete(self, ids: Union[List[str], str]): + """Delete document by id""" + raise NotImplementedError("Delete by-id is a Work-in-Progress.") + + def save(self, path: Union[str, Path]): + """Save document to path""" + # not required for ElasticDocstore + + def load(self, path: Union[str, Path]): + """Load document store from path""" + # not required for ElasticDocstore diff --git a/knowledgehub/storages/docstores/in_memory.py b/knowledgehub/storages/docstores/in_memory.py index 23f3e22..645890e 100644 --- a/knowledgehub/storages/docstores/in_memory.py +++ b/knowledgehub/storages/docstores/in_memory.py @@ -16,20 +16,24 @@ class InMemoryDocumentStore(BaseDocumentStore): self, docs: Union[Document, List[Document]], ids: Optional[Union[List[str], str]] = None, - exist_ok: bool = False, + **kwargs, ): """Add document into document store Args: - docs: Union[Document, List[Document]], - ids: Optional[Union[List[str], str]] = None, + docs: list of documents to add + ids: specify the ids of documents to add or + use existing doc.doc_id + exist_ok: raise error when duplicate doc-id + found in the docstore (default to False) """ - doc_ids = ids if ids else [doc.doc_id for doc in docs] - if not isinstance(doc_ids, list): - doc_ids = [doc_ids] + exist_ok: bool = kwargs.pop("exist_ok", False) + if ids and not isinstance(ids, list): + ids = [ids] if not isinstance(docs, list): docs = [docs] + doc_ids = ids if ids else [doc.doc_id for doc in docs] for doc_id, doc in zip(doc_ids, docs): if doc_id in self._store and not exist_ok: @@ -43,9 +47,13 @@ class InMemoryDocumentStore(BaseDocumentStore): return [self._store[doc_id] for doc_id in ids] - def get_all(self) -> dict: + def get_all(self) -> List[Document]: """Get all documents""" - return self._store + return list(self._store.values()) + + def count(self) -> int: + """Count number of documents""" + return len(self._store) def delete(self, ids: Union[List[str], str]): """Delete document by id""" diff --git a/setup.py b/setup.py index a60a1ff..4e9e7de 100644 --- a/setup.py +++ b/setup.py @@ -61,6 +61,7 @@ setuptools.setup( "farm-haystack==1.19.0", "sentence_transformers", "cohere", + "elasticsearch", "pypdf", ], }, diff --git a/tests/test_docstores.py b/tests/test_docstores.py index a6cb9a0..9a14e84 100644 --- a/tests/test_docstores.py +++ b/tests/test_docstores.py @@ -1,7 +1,182 @@ +from unittest.mock import patch + import pytest +from elastic_transport import ApiResponseMeta from kotaemon.base import Document -from kotaemon.storages import InMemoryDocumentStore +from kotaemon.storages import ElasticsearchDocumentStore, InMemoryDocumentStore + +meta_success = ApiResponseMeta( + status=200, + http_version="1.1", + headers={"x-elastic-product": "Elasticsearch"}, + duration=1.0, + node=None, +) +meta_fail = ApiResponseMeta( + status=404, + http_version="1.1", + headers={"x-elastic-product": "Elasticsearch"}, + duration=1.0, + node=None, +) +_elastic_search_responses = [ + # check exist + (meta_fail, None), + # create index + ( + meta_success, + {"acknowledged": True, "shards_acknowledged": True, "index": "test"}, + ), + # count API + ( + meta_success, + [{"epoch": "1700474422", "timestamp": "10:00:22", "count": "0"}], + ), + # add documents + ( + meta_success, + { + "took": 50, + "errors": False, + "items": [ + { + "index": { + "_index": "test", + "_id": "a3774dab-b8f1-43ba-adb8-842cb7a76eeb", + "_version": 1, + "result": "created", + "_shards": {"total": 2, "successful": 1, "failed": 0}, + "_seq_no": 0, + "_primary_term": 1, + "status": 201, + } + }, + { + "index": { + "_index": "test", + "_id": "b44f5593-7587-4f91-afd0-5736e5bd5bfe", + "_version": 1, + "result": "created", + "_shards": {"total": 2, "successful": 1, "failed": 0}, + "_seq_no": 1, + "_primary_term": 1, + "status": 201, + } + }, + { + "index": { + "_index": "test", + "_id": "13ae7825-eef9-4214-a164-983c2e6bbeaa", + "_version": 1, + "result": "created", + "_shards": {"total": 2, "successful": 1, "failed": 0}, + "_seq_no": 2, + "_primary_term": 1, + "status": 201, + } + }, + ], + }, + ), + # check exist + ( + meta_success, + {"_shards": {"total": 2, "successful": 1, "failed": 0}}, + ), + # count + ( + meta_success, + [{"epoch": "1700474422", "timestamp": "10:00:22", "count": "3"}], + ), + # get_all + ( + meta_success, + { + "took": 1, + "timed_out": False, + "_shards": {"total": 1, "successful": 1, "skipped": 0, "failed": 0}, + "hits": { + "total": {"value": 3, "relation": "eq"}, + "max_score": 1.0, + "hits": [ + { + "_index": "test", + "_id": "a3774dab-b8f1-43ba-adb8-842cb7a76eeb", + "_score": 1.0, + "_source": {"content": "Sample text 0", "metadata": {}}, + }, + { + "_index": "test", + "_id": "b44f5593-7587-4f91-afd0-5736e5bd5bfe", + "_score": 1.0, + "_source": {"content": "Sample text 1", "metadata": {}}, + }, + { + "_index": "test", + "_id": "13ae7825-eef9-4214-a164-983c2e6bbeaa", + "_score": 1.0, + "_source": {"content": "Sample text 2", "metadata": {}}, + }, + ], + }, + }, + ), + # get by-id + ( + meta_success, + { + "took": 1, + "timed_out": False, + "_shards": {"total": 1, "successful": 1, "skipped": 0, "failed": 0}, + "hits": { + "total": {"value": 1, "relation": "eq"}, + "max_score": 1.0, + "hits": [ + { + "_index": "test", + "_id": "a3774dab-b8f1-43ba-adb8-842cb7a76eeb", + "_score": 1.0, + "_source": {"content": "Sample text 0", "metadata": {}}, + } + ], + }, + }, + ), + # query + ( + meta_success, + { + "took": 2, + "timed_out": False, + "_shards": {"total": 1, "successful": 1, "skipped": 0, "failed": 0}, + "hits": { + "total": {"value": 3, "relation": "eq"}, + "max_score": 0.13353139, + "hits": [ + { + "_index": "test", + "_id": "a3774dab-b8f1-43ba-adb8-842cb7a76eeb", + "_score": 0.13353139, + "_source": {"content": "Sample text 0", "metadata": {}}, + }, + { + "_index": "test", + "_id": "b44f5593-7587-4f91-afd0-5736e5bd5bfe", + "_score": 0.13353139, + "_source": {"content": "Sample text 1", "metadata": {}}, + }, + { + "_index": "test", + "_id": "13ae7825-eef9-4214-a164-983c2e6bbeaa", + "_score": 0.13353139, + "_source": {"content": "Sample text 2", "metadata": {}}, + }, + ], + }, + }, + ), +] def test_simple_document_store_base_interfaces(tmp_path): @@ -56,3 +231,33 @@ def test_simple_document_store_base_interfaces(tmp_path): store2 = InMemoryDocumentStore() store2.load(tmp_path / "store.json") assert len(store2.get_all()) == 17, "Laded document store should have 17 documents" + + +@patch( + "elastic_transport.Transport.perform_request", + side_effect=_elastic_search_responses, +) +def test_elastic_document_store(elastic_api): + store = ElasticsearchDocumentStore(index_name="test") + + docs = [ + Document(text=f"Sample text {idx}", meta={"meta_key": f"meta_value_{idx}"}) + for idx in range(3) + ] + + # Test add and get all + assert store.count() == 0, "Document store should be empty" + store.add(docs) + assert store.count() == 3, "Document store count should changed after adding docs" + + docs = store.get_all() + first_doc = docs[0] + assert len(docs) == 3, "Document store get_all() failed" + + doc_by_ids = store.get(first_doc.doc_id) + assert doc_by_ids[0].doc_id == first_doc.doc_id, "Document store get() failed" + + docs = store.query("text") + assert len(docs) == 3, "Document store query() failed" + + elastic_api.assert_called()