Add file-based document store and vector store (#96)

* Modify docstore and vectorstore objects to be reconstructable
* Simplify the file docstore
* Use the simple file docstore and vector store in MVP
This commit is contained in:
Duc Nguyen (john) 2023-12-04 17:46:00 +07:00 committed by GitHub
parent 0ce3a8832f
commit 37c744b616
18 changed files with 324 additions and 149 deletions

View File

@ -73,10 +73,13 @@ class LCEmbeddingMixin:
return self._kwargs[name] return self._kwargs[name]
return getattr(self._obj, name) return getattr(self._obj, name)
def dump(self): def dump(self, *args, **kwargs):
from theflow.utils.modules import serialize
params = {key: serialize(value) for key, value in self._kwargs.items()}
return { return {
"__type__": f"{self.__module__}.{self.__class__.__qualname__}", "__type__": f"{self.__module__}.{self.__class__.__qualname__}",
**self._kwargs, **params,
} }
def specs(self, path: str): def specs(self, path: str):

View File

@ -82,10 +82,13 @@ class LlamaIndexDocTransformerMixin:
return self._kwargs[name] return self._kwargs[name]
return getattr(self._obj, name) return getattr(self._obj, name)
def dump(self): def dump(self, *args, **kwargs):
from theflow.utils.modules import serialize
params = {key: serialize(value) for key, value in self._kwargs.items()}
return { return {
"__type__": f"{self.__module__}.{self.__class__.__qualname__}", "__type__": f"{self.__module__}.{self.__class__.__qualname__}",
**self._kwargs, **params,
} }
def run( def run(

View File

@ -1,7 +1,6 @@
from __future__ import annotations from __future__ import annotations
import uuid import uuid
from pathlib import Path
from typing import Optional, Sequence, cast from typing import Optional, Sequence, cast
from kotaemon.base import BaseComponent, Document, RetrievedDocument from kotaemon.base import BaseComponent, Document, RetrievedDocument
@ -68,37 +67,6 @@ class VectorIndexing(BaseIndexing):
if self.doc_store: if self.doc_store:
self.doc_store.add(input_) self.doc_store.add(input_)
def save(
self,
path: str | Path,
vectorstore_fname: str = VECTOR_STORE_FNAME,
docstore_fname: str = DOC_STORE_FNAME,
):
"""Save the whole state of the indexing pipeline vector store and all
necessary information to disk
Args:
path (str): path to save the state
"""
if isinstance(path, str):
path = Path(path)
self.vector_store.save(path / vectorstore_fname)
if self.doc_store:
self.doc_store.save(path / docstore_fname)
def load(
self,
path: str | Path,
vectorstore_fname: str = VECTOR_STORE_FNAME,
docstore_fname: str = DOC_STORE_FNAME,
):
"""Load all information from disk to an object"""
if isinstance(path, str):
path = Path(path)
self.vector_store.load(path / vectorstore_fname)
if self.doc_store:
self.doc_store.load(path / docstore_fname)
class VectorRetrieval(BaseRetrieval): class VectorRetrieval(BaseRetrieval):
"""Retrieve list of documents from vector store""" """Retrieve list of documents from vector store"""
@ -144,37 +112,6 @@ class VectorRetrieval(BaseRetrieval):
return result return result
def save(
self,
path: str | Path,
vectorstore_fname: str = VECTOR_STORE_FNAME,
docstore_fname: str = DOC_STORE_FNAME,
):
"""Save the whole state of the indexing pipeline vector store and all
necessary information to disk
Args:
path (str): path to save the state
"""
if isinstance(path, str):
path = Path(path)
self.vector_store.save(path / vectorstore_fname)
if self.doc_store:
self.doc_store.save(path / docstore_fname)
def load(
self,
path: str | Path,
vectorstore_fname: str = VECTOR_STORE_FNAME,
docstore_fname: str = DOC_STORE_FNAME,
):
"""Load all information from disk to an object"""
if isinstance(path, str):
path = Path(path)
self.vector_store.load(path / vectorstore_fname)
if self.doc_store:
self.doc_store.load(path / docstore_fname)
class TextVectorQA(BaseComponent): class TextVectorQA(BaseComponent):
retrieving_pipeline: BaseRetrieval retrieving_pipeline: BaseRetrieval

View File

@ -101,10 +101,13 @@ class LCChatMixin:
return self._kwargs[name] return self._kwargs[name]
return getattr(self._obj, name) return getattr(self._obj, name)
def dump(self): def dump(self, *args, **kwargs):
from theflow.utils.modules import serialize
params = {key: serialize(value) for key, value in self._kwargs.items()}
return { return {
"__type__": f"{self.__module__}.{self.__class__.__qualname__}", "__type__": f"{self.__module__}.{self.__class__.__qualname__}",
**self._kwargs, **params,
} }
def specs(self, path: str): def specs(self, path: str):

View File

@ -78,10 +78,13 @@ class LCCompletionMixin:
return self._kwargs[name] return self._kwargs[name]
return getattr(self._obj, name) return getattr(self._obj, name)
def dump(self): def dump(self, *args, **kwargs):
from theflow.utils.modules import serialize
params = {key: serialize(value) for key, value in self._kwargs.items()}
return { return {
"__type__": f"{self.__module__}.{self.__class__.__qualname__}", "__type__": f"{self.__module__}.{self.__class__.__qualname__}",
**self._kwargs, **params,
} }
def specs(self, path: str): def specs(self, path: str):

View File

@ -2,16 +2,24 @@ from .docstores import (
BaseDocumentStore, BaseDocumentStore,
ElasticsearchDocumentStore, ElasticsearchDocumentStore,
InMemoryDocumentStore, InMemoryDocumentStore,
SimpleFileDocumentStore,
)
from .vectorstores import (
BaseVectorStore,
ChromaVectorStore,
InMemoryVectorStore,
SimpleFileVectorStore,
) )
from .vectorstores import BaseVectorStore, ChromaVectorStore, InMemoryVectorStore
__all__ = [ __all__ = [
# Document stores # Document stores
"BaseDocumentStore", "BaseDocumentStore",
"InMemoryDocumentStore", "InMemoryDocumentStore",
"ElasticsearchDocumentStore", "ElasticsearchDocumentStore",
"SimpleFileDocumentStore",
# Vector stores # Vector stores
"BaseVectorStore", "BaseVectorStore",
"ChromaVectorStore", "ChromaVectorStore",
"InMemoryVectorStore", "InMemoryVectorStore",
"SimpleFileVectorStore",
] ]

View File

@ -1,5 +1,11 @@
from .base import BaseDocumentStore from .base import BaseDocumentStore
from .elasticsearch import ElasticsearchDocumentStore from .elasticsearch import ElasticsearchDocumentStore
from .in_memory import InMemoryDocumentStore from .in_memory import InMemoryDocumentStore
from .simple_file import SimpleFileDocumentStore
__all__ = ["BaseDocumentStore", "InMemoryDocumentStore", "ElasticsearchDocumentStore"] __all__ = [
"BaseDocumentStore",
"InMemoryDocumentStore",
"ElasticsearchDocumentStore",
"SimpleFileDocumentStore",
]

View File

@ -1,8 +1,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path
from typing import List, Optional, Union from typing import List, Optional, Union
from ...base import Document from kotaemon.base import Document
class BaseDocumentStore(ABC): class BaseDocumentStore(ABC):
@ -46,13 +45,3 @@ class BaseDocumentStore(ABC):
def delete(self, ids: Union[List[str], str]): def delete(self, ids: Union[List[str], str]):
"""Delete document by id""" """Delete document by id"""
... ...
@abstractmethod
def save(self, path: Union[str, Path]):
"""Save document to path"""
...
@abstractmethod
def load(self, path: Union[str, Path]):
"""Load document store from path"""
...

View File

@ -1,7 +1,7 @@
from pathlib import Path
from typing import List, Optional, Union from typing import List, Optional, Union
from ...base import Document from kotaemon.base import Document
from .base import BaseDocumentStore from .base import BaseDocumentStore
MAX_DOCS_TO_GET = 10**4 MAX_DOCS_TO_GET = 10**4
@ -27,6 +27,8 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
self.elasticsearch_url = elasticsearch_url self.elasticsearch_url = elasticsearch_url
self.index_name = index_name self.index_name = index_name
self.k1 = k1
self.b = b
# Create an Elasticsearch client instance # Create an Elasticsearch client instance
self.client = Elasticsearch(elasticsearch_url) self.client = Elasticsearch(elasticsearch_url)
@ -160,10 +162,10 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
self.client.delete_by_query(index=self.index_name, body=query) self.client.delete_by_query(index=self.index_name, body=query)
self.client.indices.refresh(index=self.index_name) self.client.indices.refresh(index=self.index_name)
def save(self, path: Union[str, Path]): def __persist_flow__(self):
"""Save document to path""" return {
# not required for ElasticDocstore "index_name": self.index_name,
"elasticsearch_url": self.elasticsearch_url,
def load(self, path: Union[str, Path]): "k1": self.k1,
"""Load document store from path""" "b": self.b,
# not required for ElasticDocstore }

View File

@ -2,7 +2,8 @@ import json
from pathlib import Path from pathlib import Path
from typing import List, Optional, Union from typing import List, Optional, Union
from ...base import Document from kotaemon.base import Document
from .base import BaseDocumentStore from .base import BaseDocumentStore
@ -74,3 +75,6 @@ class InMemoryDocumentStore(BaseDocumentStore):
with open(path) as f: with open(path) as f:
store = json.load(f) store = json.load(f)
self._store = {key: Document.from_dict(value) for key, value in store.items()} self._store = {key: Document.from_dict(value) for key, value in store.items()}
def __persist_flow__(self):
return {}

View File

@ -0,0 +1,44 @@
from pathlib import Path
from typing import List, Optional, Union
from kotaemon.base import Document
from .in_memory import InMemoryDocumentStore
class SimpleFileDocumentStore(InMemoryDocumentStore):
"""Improve InMemoryDocumentStore by auto saving whenever the corpus is changed"""
def __init__(self, path: str | Path):
super().__init__()
self._path = path
if path is not None and Path(path).is_file():
self.load(path)
def add(
self,
docs: Union[Document, List[Document]],
ids: Optional[Union[List[str], str]] = None,
**kwargs,
):
"""Add document into document store
Args:
docs: list of documents to add
ids: specify the ids of documents to add or
use existing doc.doc_id
exist_ok: raise error when duplicate doc-id
found in the docstore (default to False)
"""
super().add(docs=docs, ids=ids, **kwargs)
self.save(self._path)
def delete(self, ids: Union[List[str], str]):
"""Delete document by id"""
super().delete(ids=ids)
self.save(self._path)
def __persist_flow__(self):
from theflow.utils.modules import serialize
return {"path": serialize(self._path)}

View File

@ -1,5 +1,11 @@
from .base import BaseVectorStore from .base import BaseVectorStore
from .chroma import ChromaVectorStore from .chroma import ChromaVectorStore
from .in_memory import InMemoryVectorStore from .in_memory import InMemoryVectorStore
from .simple_file import SimpleFileVectorStore
__all__ = ["BaseVectorStore", "ChromaVectorStore", "InMemoryVectorStore"] __all__ = [
"BaseVectorStore",
"ChromaVectorStore",
"InMemoryVectorStore",
"SimpleFileVectorStore",
]

View File

@ -1,12 +1,14 @@
from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, List, Optional, Tuple, Type, Union from typing import Any, Optional
from llama_index.schema import NodeRelationship, RelatedNodeInfo from llama_index.schema import NodeRelationship, RelatedNodeInfo
from llama_index.vector_stores.types import BasePydanticVectorStore from llama_index.vector_stores.types import BasePydanticVectorStore
from llama_index.vector_stores.types import VectorStore as LIVectorStore from llama_index.vector_stores.types import VectorStore as LIVectorStore
from llama_index.vector_stores.types import VectorStoreQuery from llama_index.vector_stores.types import VectorStoreQuery
from kotaemon.base import Document, DocumentWithEmbedding from kotaemon.base import DocumentWithEmbedding
class BaseVectorStore(ABC): class BaseVectorStore(ABC):
@ -17,10 +19,10 @@ class BaseVectorStore(ABC):
@abstractmethod @abstractmethod
def add( def add(
self, self,
embeddings: List[List[float]] | List[DocumentWithEmbedding], embeddings: list[list[float]] | list[DocumentWithEmbedding],
metadatas: Optional[List[dict]] = None, metadatas: Optional[list[dict]] = None,
ids: Optional[List[str]] = None, ids: Optional[list[str]] = None,
) -> List[str]: ) -> list[str]:
"""Add vector embeddings to vector stores """Add vector embeddings to vector stores
Args: Args:
@ -35,16 +37,7 @@ class BaseVectorStore(ABC):
... ...
@abstractmethod @abstractmethod
def add_from_docs(self, docs: List[Document]): def delete(self, ids: list[str], **kwargs):
"""Add vector embeddings to vector stores
Args:
docs: List of Document objects
"""
...
@abstractmethod
def delete(self, ids: List[str], **kwargs):
"""Delete vector embeddings from vector stores """Delete vector embeddings from vector stores
Args: Args:
@ -56,11 +49,11 @@ class BaseVectorStore(ABC):
@abstractmethod @abstractmethod
def query( def query(
self, self,
embedding: List[float], embedding: list[float],
top_k: int = 1, top_k: int = 1,
ids: Optional[List[str]] = None, ids: Optional[list[str]] = None,
**kwargs, **kwargs,
) -> Tuple[List[List[float]], List[float], List[str]]: ) -> tuple[list[list[float]], list[float], list[str]]:
"""Return the top k most similar vector embeddings """Return the top k most similar vector embeddings
Args: Args:
@ -73,17 +66,9 @@ class BaseVectorStore(ABC):
""" """
... ...
@abstractmethod
def load(self, *args, **kwargs):
pass
@abstractmethod
def save(self, *args, **kwargs):
pass
class LlamaIndexVectorStore(BaseVectorStore): class LlamaIndexVectorStore(BaseVectorStore):
_li_class: Type[Union[LIVectorStore, BasePydanticVectorStore]] _li_class: type[LIVectorStore | BasePydanticVectorStore]
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
if self._li_class is None: if self._li_class is None:
@ -104,12 +89,12 @@ class LlamaIndexVectorStore(BaseVectorStore):
def add( def add(
self, self,
embeddings: List[List[float]] | List[DocumentWithEmbedding], embeddings: list[list[float]] | list[DocumentWithEmbedding],
metadatas: Optional[List[dict]] = None, metadatas: Optional[list[dict]] = None,
ids: Optional[List[str]] = None, ids: Optional[list[str]] = None,
): ):
if isinstance(embeddings[0], list): if isinstance(embeddings[0], list):
nodes = [ nodes: list[DocumentWithEmbedding] = [
DocumentWithEmbedding(embedding=embedding) for embedding in embeddings DocumentWithEmbedding(embedding=embedding) for embedding in embeddings
] ]
else: else:
@ -126,20 +111,17 @@ class LlamaIndexVectorStore(BaseVectorStore):
return self._client.add(nodes=nodes) return self._client.add(nodes=nodes)
def add_from_docs(self, docs: List[Document]): def delete(self, ids: list[str], **kwargs):
return self._client.add(nodes=docs)
def delete(self, ids: List[str], **kwargs):
for id_ in ids: for id_ in ids:
self._client.delete(ref_doc_id=id_, **kwargs) self._client.delete(ref_doc_id=id_, **kwargs)
def query( def query(
self, self,
embedding: List[float], embedding: list[float],
top_k: int = 1, top_k: int = 1,
ids: Optional[List[str]] = None, ids: Optional[list[str]] = None,
**kwargs, **kwargs,
) -> Tuple[List[List[float]], List[float], List[str]]: ) -> tuple[list[list[float]], list[float], list[str]]:
output = self._client.query( output = self._client.query(
query=VectorStoreQuery( query=VectorStoreQuery(
query_embedding=embedding, query_embedding=embedding,

View File

@ -21,6 +21,17 @@ class ChromaVectorStore(LlamaIndexVectorStore):
flat_metadata: bool = True, flat_metadata: bool = True,
**kwargs: Any, **kwargs: Any,
): ):
self._path = path
self._collection_name = collection_name
self._host = host
self._port = port
self._ssl = ssl
self._headers = headers
self._collection_kwargs = collection_kwargs
self._stores_text = stores_text
self._flat_metadata = flat_metadata
self._kwargs = kwargs
try: try:
import chromadb import chromadb
except ImportError: except ImportError:
@ -70,8 +81,16 @@ class ChromaVectorStore(LlamaIndexVectorStore):
def count(self) -> int: def count(self) -> int:
return self._collection.count() return self._collection.count()
def save(self, *args, **kwargs): def __persist_flow__(self):
pass return {
"path": self._path,
def load(self, *args, **kwargs): "collection_name": self._collection_name,
pass "host": self._host,
"port": self._port,
"ssl": self._ssl,
"headers": self._headers,
"collection_kwargs": self._collection_kwargs,
"stores_text": self._stores_text,
"flat_metadata": self._flat_metadata,
**self._kwargs,
}

View File

@ -1,5 +1,4 @@
"""Simple vector store index.""" """Simple vector store index."""
from typing import Any, Optional, Type from typing import Any, Optional, Type
import fsspec import fsspec
@ -53,3 +52,11 @@ class InMemoryVectorStore(LlamaIndexVectorStore):
fs: An abstract super-class for pythonic file-systems fs: An abstract super-class for pythonic file-systems
""" """
self._client = self._client.from_persist_path(persist_path=load_path, fs=fs) self._client = self._client.from_persist_path(persist_path=load_path, fs=fs)
def __persist_flow__(self):
d = self._data.to_dict()
d["__type__"] = f"{self._data.__module__}.{self._data.__class__.__qualname__}"
return {
"data": d,
# "fs": self._fs,
}

View File

@ -0,0 +1,66 @@
"""Simple file vector store index."""
from pathlib import Path
from typing import Any, Optional, Type
import fsspec
from llama_index.vector_stores import SimpleVectorStore as LISimpleVectorStore
from llama_index.vector_stores.simple import SimpleVectorStoreData
from kotaemon.base import DocumentWithEmbedding
from .base import LlamaIndexVectorStore
class SimpleFileVectorStore(LlamaIndexVectorStore):
"""Similar to InMemoryVectorStore but is backed by file by default"""
_li_class: Type[LISimpleVectorStore] = LISimpleVectorStore
store_text: bool = False
def __init__(
self,
path: str | Path,
data: Optional[SimpleVectorStoreData] = None,
fs: Optional[fsspec.AbstractFileSystem] = None,
**kwargs: Any,
) -> None:
"""Initialize params."""
self._data = data or SimpleVectorStoreData()
self._fs = fs or fsspec.filesystem("file")
self._path = path
self._save_path = Path(path)
super().__init__(
data=data,
fs=fs,
**kwargs,
)
if self._save_path.is_file():
self._client = self._li_class.from_persist_path(
persist_path=str(self._save_path), fs=self._fs
)
def add(
self,
embeddings: list[list[float]] | list[DocumentWithEmbedding],
metadatas: Optional[list[dict]] = None,
ids: Optional[list[str]] = None,
):
r = super().add(embeddings, metadatas, ids)
self._client.persist(str(self._save_path), self._fs)
return r
def delete(self, ids: list[str], **kwargs):
r = super().delete(ids, **kwargs)
self._client.persist(str(self._save_path), self._fs)
return r
def __persist_flow__(self):
d = self._data.to_dict()
d["__type__"] = f"{self._data.__module__}.{self._data.__class__.__qualname__}"
return {
"data": d,
"path": str(self._path),
# "fs": self._fs,
}

View File

@ -1,10 +1,15 @@
import os
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from elastic_transport import ApiResponseMeta from elastic_transport import ApiResponseMeta
from kotaemon.base import Document from kotaemon.base import Document
from kotaemon.storages import ElasticsearchDocumentStore, InMemoryDocumentStore from kotaemon.storages import (
ElasticsearchDocumentStore,
InMemoryDocumentStore,
SimpleFileDocumentStore,
)
meta_success = ApiResponseMeta( meta_success = ApiResponseMeta(
status=200, status=200,
@ -207,7 +212,7 @@ _elastic_search_responses = [
] ]
def test_simple_document_store_base_interfaces(tmp_path): def test_inmemory_document_store_base_interfaces(tmp_path):
"""Test all interfaces of a a document store""" """Test all interfaces of a a document store"""
store = InMemoryDocumentStore() store = InMemoryDocumentStore()
@ -260,6 +265,64 @@ def test_simple_document_store_base_interfaces(tmp_path):
store2.load(tmp_path / "store.json") store2.load(tmp_path / "store.json")
assert len(store2.get_all()) == 17, "Laded document store should have 17 documents" assert len(store2.get_all()) == 17, "Laded document store should have 17 documents"
os.remove(tmp_path / "store.json")
def test_simplefile_document_store_base_interfaces(tmp_path):
"""Test all interfaces of a a document store"""
path = tmp_path / "store.json"
store = SimpleFileDocumentStore(path=path)
docs = [
Document(text=f"Sample text {idx}", meta={"meta_key": f"meta_value_{idx}"})
for idx in range(10)
]
# Test add and get all
assert len(store.get_all()) == 0, "Document store should be empty"
store.add(docs)
assert len(store.get_all()) == 10, "Document store should have 10 documents"
# Test add with provided ids
store.add(docs=docs, ids=[f"doc_{idx}" for idx in range(10)])
assert len(store.get_all()) == 20, "Document store should have 20 documents"
# Test add without exist_ok
with pytest.raises(ValueError):
store.add(docs=docs, ids=[f"doc_{idx}" for idx in range(10)])
# Update ok with add exist_ok
store.add(docs=docs, ids=[f"doc_{idx}" for idx in range(10)], exist_ok=True)
assert len(store.get_all()) == 20, "Document store should have 20 documents"
# Test get with str id
matched = store.get(docs[0].doc_id)
assert len(matched) == 1, "Should return 1 document"
assert matched[0].text == docs[0].text, "Should return the correct document"
# Test get with list of ids
matched = store.get([docs[0].doc_id, docs[1].doc_id])
assert len(matched) == 2, "Should return 2 documents"
assert [doc.text for doc in matched] == [doc.text for doc in docs[:2]]
# Test delete with str id
store.delete(docs[0].doc_id)
assert len(store.get_all()) == 19, "Document store should have 19 documents"
# Test delete with list of ids
store.delete([docs[1].doc_id, docs[2].doc_id])
assert len(store.get_all()) == 17, "Document store should have 17 documents"
# Test save
assert path.exists(), "File should exist"
# Test load
store2 = SimpleFileDocumentStore(path=path)
assert len(store2.get_all()) == 17, "Laded document store should have 17 documents"
os.remove(path)
@patch( @patch(
"elastic_transport.Transport.perform_request", "elastic_transport.Transport.perform_request",

View File

@ -1,7 +1,12 @@
import json import json
import os
from kotaemon.base import Document from kotaemon.base import DocumentWithEmbedding
from kotaemon.storages import ChromaVectorStore, InMemoryVectorStore from kotaemon.storages import (
ChromaVectorStore,
InMemoryVectorStore,
SimpleFileVectorStore,
)
class TestChromaVectorStore: class TestChromaVectorStore:
@ -24,11 +29,11 @@ class TestChromaVectorStore:
embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}] metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}]
documents = [ documents = [
Document(embedding=embedding, metadata=metadata) DocumentWithEmbedding(embedding=embedding, metadata=metadata)
for embedding, metadata in zip(embeddings, metadatas) for embedding, metadata in zip(embeddings, metadatas)
] ]
assert db._collection.count() == 0, "Expected empty collection" assert db._collection.count() == 0, "Expected empty collection"
output = db.add_from_docs(documents) output = db.add(documents)
assert len(output) == 2, "Expected outputing 2 ids" assert len(output) == 2, "Expected outputing 2 ids"
assert db._collection.count() == 2, "Expected 2 added entries" assert db._collection.count() == 2, "Expected 2 added entries"
@ -69,10 +74,8 @@ class TestChromaVectorStore:
ids = ["1", "2", "3"] ids = ["1", "2", "3"]
db = ChromaVectorStore(path=str(tmp_path)) db = ChromaVectorStore(path=str(tmp_path))
db.add(embeddings=embeddings, metadatas=metadatas, ids=ids) db.add(embeddings=embeddings, metadatas=metadatas, ids=ids)
db.save()
db2 = ChromaVectorStore(path=str(tmp_path)) db2 = ChromaVectorStore(path=str(tmp_path))
db2.load()
assert ( assert (
db2._collection.count() == 3 db2._collection.count() == 3
), "load function does not load data completely" ), "load function does not load data completely"
@ -122,3 +125,30 @@ class TestInMemoryVectorStore:
0.5, 0.5,
0.6, 0.6,
], "load function does not load data completely" ], "load function does not load data completely"
class TestSimpleFileVectorStore:
def test_add_delete(self, tmp_path):
"""Test that delete func deletes correctly."""
embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]
metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}]
ids = ["1", "2", "3"]
db = SimpleFileVectorStore(path=tmp_path / "test_save_load_delete.json")
db.add(embeddings=embeddings, metadatas=metadatas, ids=ids)
db.delete(["3"])
f = open(tmp_path / "test_save_load_delete.json")
data = json.load(f)
assert (
"1" and "2" in data["text_id_to_ref_doc_id"]
), "save function does not save data completely"
assert (
"3" not in data["text_id_to_ref_doc_id"]
), "delete function does not delete data completely"
db2 = SimpleFileVectorStore(path=tmp_path / "test_save_load_delete.json")
assert db2.get("2") == [
0.4,
0.5,
0.6,
], "load function does not load data completely"
os.remove(tmp_path / "test_save_load_delete.json")