[AUR-430] Add test case for Chroma VectoStore save/load (#26)
* add test case for Chroma save/load * minor name change * add delete_collection support for chroma * move save load to chroma --------- Co-authored-by: Nguyen Trung Duc (john) <john@cinnamon.is>
This commit is contained in:
parent
4f189dc931
commit
6207f4332a
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -456,3 +456,4 @@ logs/
|
||||||
credentials.txt
|
credentials.txt
|
||||||
|
|
||||||
S.gpg-agent*
|
S.gpg-agent*
|
||||||
|
.vscode/settings.json
|
||||||
|
|
|
@ -53,19 +53,6 @@ class BaseVectorStore(ABC):
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
# @abstractmethod
|
|
||||||
# def update(self, *args, **kwargs):
|
|
||||||
# ...
|
|
||||||
|
|
||||||
# @abstractmethod
|
|
||||||
# def persist(self, *args, **kwargs):
|
|
||||||
# ...
|
|
||||||
|
|
||||||
# @classmethod
|
|
||||||
# @abstractmethod
|
|
||||||
# def load(self, *args, **kwargs):
|
|
||||||
# ...
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def query(
|
def query(
|
||||||
self,
|
self,
|
||||||
|
@ -85,6 +72,14 @@ class BaseVectorStore(ABC):
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def load(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def save(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class LlamaIndexVectorStore(BaseVectorStore):
|
class LlamaIndexVectorStore(BaseVectorStore):
|
||||||
_li_class: Type[Union[LIVectorStore, BasePydanticVectorStore]]
|
_li_class: Type[Union[LIVectorStore, BasePydanticVectorStore]]
|
||||||
|
|
|
@ -54,3 +54,21 @@ class ChromaVectorStore(LlamaIndexVectorStore):
|
||||||
kwargs: meant for vectorstore-specific parameters
|
kwargs: meant for vectorstore-specific parameters
|
||||||
"""
|
"""
|
||||||
self._client._collection.delete(ids=ids)
|
self._client._collection.delete(ids=ids)
|
||||||
|
|
||||||
|
def delete_collection(self, collection_name: Optional[str] = None):
|
||||||
|
"""Delete entire collection under specified name from vector stores
|
||||||
|
|
||||||
|
Args:
|
||||||
|
collection_name: Name of the collection to delete
|
||||||
|
"""
|
||||||
|
# a rather ugly chain call but it do the job of finding
|
||||||
|
# original chromadb client and call delete_collection() method
|
||||||
|
if collection_name is None:
|
||||||
|
collection_name = self._client.client.name
|
||||||
|
self._client.client._client.delete_collection(collection_name)
|
||||||
|
|
||||||
|
def save(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def load(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
|
@ -62,8 +62,31 @@ class TestChromaVectorStore:
|
||||||
_, _, out_ids = db.query(embedding=[0.42, 0.52, 0.53], top_k=1)
|
_, _, out_ids = db.query(embedding=[0.42, 0.52, 0.53], top_k=1)
|
||||||
assert out_ids == ["b"]
|
assert out_ids == ["b"]
|
||||||
|
|
||||||
|
def test_save_load_delete(self, tmp_path):
|
||||||
|
"""Test that save/load func behave correctly."""
|
||||||
|
embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]
|
||||||
|
metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}]
|
||||||
|
ids = ["1", "2", "3"]
|
||||||
|
db = ChromaVectorStore(path=str(tmp_path))
|
||||||
|
db.add(embeddings=embeddings, metadatas=metadatas, ids=ids)
|
||||||
|
db.save()
|
||||||
|
|
||||||
class TestSimpleVectorStore:
|
db2 = ChromaVectorStore(path=str(tmp_path))
|
||||||
|
db2.load()
|
||||||
|
assert (
|
||||||
|
db2._collection.count() == 3
|
||||||
|
), "load function does not load data completely"
|
||||||
|
|
||||||
|
# test delete collection function
|
||||||
|
db2.delete_collection()
|
||||||
|
# reinit the chroma with the same collection name
|
||||||
|
db2 = ChromaVectorStore(path=str(tmp_path))
|
||||||
|
assert (
|
||||||
|
db2._collection.count() == 0
|
||||||
|
), "delete collection function does not work correctly"
|
||||||
|
|
||||||
|
|
||||||
|
class TestInMemoryVectorStore:
|
||||||
def test_add(self):
|
def test_add(self):
|
||||||
"""Test that add func adds correctly."""
|
"""Test that add func adds correctly."""
|
||||||
|
|
||||||
|
@ -88,7 +111,7 @@ class TestSimpleVectorStore:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
assert (
|
assert (
|
||||||
"1" and "2" in data["text_id_to_ref_doc_id"]
|
"1" and "2" in data["text_id_to_ref_doc_id"]
|
||||||
), "persist function does not save data completely"
|
), "save function does not save data completely"
|
||||||
assert (
|
assert (
|
||||||
"3" not in data["text_id_to_ref_doc_id"]
|
"3" not in data["text_id_to_ref_doc_id"]
|
||||||
), "delete function does not delete data completely"
|
), "delete function does not delete data completely"
|
||||||
|
@ -98,4 +121,4 @@ class TestSimpleVectorStore:
|
||||||
0.4,
|
0.4,
|
||||||
0.5,
|
0.5,
|
||||||
0.6,
|
0.6,
|
||||||
], "persist function does not load data completely"
|
], "load function does not load data completely"
|
||||||
|
|
Loading…
Reference in New Issue
Block a user