diff --git a/.gitignore b/.gitignore index 28d4435..0e80245 100644 --- a/.gitignore +++ b/.gitignore @@ -456,3 +456,4 @@ logs/ credentials.txt S.gpg-agent* +.vscode/settings.json diff --git a/knowledgehub/vectorstores/base.py b/knowledgehub/vectorstores/base.py index 310c019..51f85b3 100644 --- a/knowledgehub/vectorstores/base.py +++ b/knowledgehub/vectorstores/base.py @@ -53,19 +53,6 @@ class BaseVectorStore(ABC): """ ... - # @abstractmethod - # def update(self, *args, **kwargs): - # ... - - # @abstractmethod - # def persist(self, *args, **kwargs): - # ... - - # @classmethod - # @abstractmethod - # def load(self, *args, **kwargs): - # ... - @abstractmethod def query( self, @@ -85,6 +72,14 @@ class BaseVectorStore(ABC): """ ... + @abstractmethod + def load(self, *args, **kwargs): + pass + + @abstractmethod + def save(self, *args, **kwargs): + pass + class LlamaIndexVectorStore(BaseVectorStore): _li_class: Type[Union[LIVectorStore, BasePydanticVectorStore]] diff --git a/knowledgehub/vectorstores/chroma.py b/knowledgehub/vectorstores/chroma.py index a990207..8c55a4f 100644 --- a/knowledgehub/vectorstores/chroma.py +++ b/knowledgehub/vectorstores/chroma.py @@ -54,3 +54,21 @@ class ChromaVectorStore(LlamaIndexVectorStore): kwargs: meant for vectorstore-specific parameters """ self._client._collection.delete(ids=ids) + + def delete_collection(self, collection_name: Optional[str] = None): + """Delete entire collection under specified name from vector stores + + Args: + collection_name: Name of the collection to delete + """ + # a rather ugly chain call but it do the job of finding + # original chromadb client and call delete_collection() method + if collection_name is None: + collection_name = self._client.client.name + self._client.client._client.delete_collection(collection_name) + + def save(self, *args, **kwargs): + pass + + def load(self, *args, **kwargs): + pass diff --git a/tests/test_vectorstore.py b/tests/test_vectorstore.py index 98d34d0..553acea 100644 --- a/tests/test_vectorstore.py +++ b/tests/test_vectorstore.py @@ -62,8 +62,31 @@ class TestChromaVectorStore: _, _, out_ids = db.query(embedding=[0.42, 0.52, 0.53], top_k=1) assert out_ids == ["b"] + def test_save_load_delete(self, tmp_path): + """Test that save/load func behave correctly.""" + embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]] + metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}] + ids = ["1", "2", "3"] + db = ChromaVectorStore(path=str(tmp_path)) + db.add(embeddings=embeddings, metadatas=metadatas, ids=ids) + db.save() -class TestSimpleVectorStore: + db2 = ChromaVectorStore(path=str(tmp_path)) + db2.load() + assert ( + db2._collection.count() == 3 + ), "load function does not load data completely" + + # test delete collection function + db2.delete_collection() + # reinit the chroma with the same collection name + db2 = ChromaVectorStore(path=str(tmp_path)) + assert ( + db2._collection.count() == 0 + ), "delete collection function does not work correctly" + + +class TestInMemoryVectorStore: def test_add(self): """Test that add func adds correctly.""" @@ -88,7 +111,7 @@ class TestSimpleVectorStore: data = json.load(f) assert ( "1" and "2" in data["text_id_to_ref_doc_id"] - ), "persist function does not save data completely" + ), "save function does not save data completely" assert ( "3" not in data["text_id_to_ref_doc_id"] ), "delete function does not delete data completely" @@ -98,4 +121,4 @@ class TestSimpleVectorStore: 0.4, 0.5, 0.6, - ], "persist function does not load data completely" + ], "load function does not load data completely"