* add test case for Chroma save/load * minor name change * add delete_collection support for chroma * move save load to chroma --------- Co-authored-by: Nguyen Trung Duc (john) <john@cinnamon.is>
125 lines
4.8 KiB
Python
125 lines
4.8 KiB
Python
import json
|
|
|
|
from kotaemon.documents.base import Document
|
|
from kotaemon.vectorstores import ChromaVectorStore, InMemoryVectorStore
|
|
|
|
|
|
class TestChromaVectorStore:
    """Tests for ChromaVectorStore: add, add_from_docs, delete, query, save/load.

    Every test builds its store with ``path=str(tmp_path)`` so each test works
    against its own pytest-provided temporary directory.
    """

    def test_add(self, tmp_path):
        """Test that the DB add correctly"""
        vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
        metas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}]
        entry_ids = ["1", "2"]
        store = ChromaVectorStore(path=str(tmp_path))

        assert store._collection.count() == 0, "Expected empty collection"

        returned = store.add(embeddings=vectors, metadatas=metas, ids=entry_ids)

        assert returned == entry_ids, "Expected output to be the same as ids"
        assert store._collection.count() == 2, "Expected 2 added entries"

    def test_add_from_docs(self, tmp_path):
        """Check that add_from_docs stores one entry per Document and returns ids."""
        store = ChromaVectorStore(path=str(tmp_path))

        vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
        metas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}]

        # Pair every embedding with its metadata, one Document per pair.
        docs = []
        for vector, meta in zip(vectors, metas):
            docs.append(Document(embedding=vector, metadata=meta))

        assert store._collection.count() == 0, "Expected empty collection"

        returned = store.add_from_docs(docs)

        assert len(returned) == 2, "Expected outputing 2 ids"
        assert store._collection.count() == 2, "Expected 2 added entries"

    def test_delete(self, tmp_path):
        """Check that delete removes exactly the requested ids from the collection."""
        store = ChromaVectorStore(path=str(tmp_path))

        store.add(
            embeddings=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]],
            metadatas=[{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}],
            ids=["a", "b", "c"],
        )
        assert store._collection.count() == 3, "Expected 3 added entries"

        # Each step: ids to delete, entries expected afterwards, failure message.
        steps = (
            (["a", "b"], 1, "Expected 1 remaining entry"),
            (["c"], 0, "Expected 0 remaining entry"),
        )
        for doomed_ids, remaining, message in steps:
            store.delete(ids=doomed_ids)
            assert store._collection.count() == remaining, message

    def test_query(self, tmp_path):
        """Check the similarity and ids returned by top-1 queries."""
        store = ChromaVectorStore(path=str(tmp_path))

        store.add(
            embeddings=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]],
            metadatas=[{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}],
            ids=["a", "b", "c"],
        )

        # Querying with a stored vector: expect that entry with similarity 1.0.
        _, scores, hit_ids = store.query(embedding=[0.1, 0.2, 0.3], top_k=1)
        assert scores == [1.0]
        assert hit_ids == ["a"]

        # Querying with a vector near the second entry: expect id "b".
        _, _, hit_ids = store.query(embedding=[0.42, 0.52, 0.53], top_k=1)
        assert hit_ids == ["b"]

    def test_save_load_delete(self, tmp_path):
        """Test that save/load func behave correctly."""
        store_path = str(tmp_path)

        writer = ChromaVectorStore(path=store_path)
        writer.add(
            embeddings=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]],
            metadatas=[{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}],
            ids=["1", "2", "3"],
        )
        writer.save()

        # A second store on the same path must see all saved entries after load.
        reader = ChromaVectorStore(path=store_path)
        reader.load()
        loaded_count = reader._collection.count()
        assert loaded_count == 3, "load function does not load data completely"

        # test delete collection function
        reader.delete_collection()

        # reinit the chroma with the same collection name
        reader = ChromaVectorStore(path=store_path)
        remaining_count = reader._collection.count()
        assert (
            remaining_count == 0
        ), "delete collection function does not work correctly"
|
|
|
|
|
|
class TestInMemoryVectorStore:
    """Tests for InMemoryVectorStore: add, delete, and the save/load round-trip."""

    def test_add(self):
        """Test that add func adds correctly."""
        embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
        metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}]
        ids = ["1", "2"]
        db = InMemoryVectorStore()

        output = db.add(embeddings=embeddings, metadatas=metadatas, ids=ids)
        assert output == ids, "Excepted output to be the same as ids"

    def test_save_load_delete(self, tmp_path):
        """Test that delete, save and load work together.

        Adds three entries, deletes one, saves the store to a JSON file under
        ``tmp_path``, inspects the saved file, then loads it into a new store.
        """
        embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]
        metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}]
        ids = ["1", "2", "3"]
        save_path = tmp_path / "test_save_load_delete.json"

        db = InMemoryVectorStore()
        db.add(embeddings=embeddings, metadatas=metadatas, ids=ids)
        db.delete(["3"])
        db.save(save_path=save_path)

        # Use a context manager so the file handle is closed even if
        # json.load raises.
        with open(save_path) as f:
            data = json.load(f)

        saved_ids = data["text_id_to_ref_doc_id"]

        # Check both ids explicitly: `"1" and "2" in saved_ids` would only
        # test "2", because a non-empty string literal is always truthy.
        assert (
            "1" in saved_ids and "2" in saved_ids
        ), "save function does not save data completely"
        assert (
            "3" not in saved_ids
        ), "delete function does not delete data completely"

        db2 = InMemoryVectorStore()
        output = db2.load(load_path=save_path)
        assert output.get("2") == [
            0.4,
            0.5,
            0.6,
        ], "load function does not load data completely"
|