Add file-based document store and vector store (#96)
* Modify docstore and vectorstore objects to be reconstructable * Simplify the file docstore * Use the simple file docstore and vector store in MVP
This commit is contained in:
committed by
GitHub
parent
0ce3a8832f
commit
37c744b616
@@ -1,10 +1,15 @@
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from elastic_transport import ApiResponseMeta
|
||||
|
||||
from kotaemon.base import Document
|
||||
from kotaemon.storages import ElasticsearchDocumentStore, InMemoryDocumentStore
|
||||
from kotaemon.storages import (
|
||||
ElasticsearchDocumentStore,
|
||||
InMemoryDocumentStore,
|
||||
SimpleFileDocumentStore,
|
||||
)
|
||||
|
||||
meta_success = ApiResponseMeta(
|
||||
status=200,
|
||||
@@ -207,7 +212,7 @@ _elastic_search_responses = [
|
||||
]
|
||||
|
||||
|
||||
def test_simple_document_store_base_interfaces(tmp_path):
|
||||
def test_inmemory_document_store_base_interfaces(tmp_path):
|
||||
"""Test all interfaces of a a document store"""
|
||||
|
||||
store = InMemoryDocumentStore()
|
||||
@@ -260,6 +265,64 @@ def test_simple_document_store_base_interfaces(tmp_path):
|
||||
store2.load(tmp_path / "store.json")
|
||||
assert len(store2.get_all()) == 17, "Laded document store should have 17 documents"
|
||||
|
||||
os.remove(tmp_path / "store.json")
|
||||
|
||||
|
||||
def test_simplefile_document_store_base_interfaces(tmp_path):
|
||||
"""Test all interfaces of a a document store"""
|
||||
|
||||
path = tmp_path / "store.json"
|
||||
|
||||
store = SimpleFileDocumentStore(path=path)
|
||||
docs = [
|
||||
Document(text=f"Sample text {idx}", meta={"meta_key": f"meta_value_{idx}"})
|
||||
for idx in range(10)
|
||||
]
|
||||
|
||||
# Test add and get all
|
||||
assert len(store.get_all()) == 0, "Document store should be empty"
|
||||
store.add(docs)
|
||||
assert len(store.get_all()) == 10, "Document store should have 10 documents"
|
||||
|
||||
# Test add with provided ids
|
||||
store.add(docs=docs, ids=[f"doc_{idx}" for idx in range(10)])
|
||||
assert len(store.get_all()) == 20, "Document store should have 20 documents"
|
||||
|
||||
# Test add without exist_ok
|
||||
with pytest.raises(ValueError):
|
||||
store.add(docs=docs, ids=[f"doc_{idx}" for idx in range(10)])
|
||||
|
||||
# Update ok with add exist_ok
|
||||
store.add(docs=docs, ids=[f"doc_{idx}" for idx in range(10)], exist_ok=True)
|
||||
assert len(store.get_all()) == 20, "Document store should have 20 documents"
|
||||
|
||||
# Test get with str id
|
||||
matched = store.get(docs[0].doc_id)
|
||||
assert len(matched) == 1, "Should return 1 document"
|
||||
assert matched[0].text == docs[0].text, "Should return the correct document"
|
||||
|
||||
# Test get with list of ids
|
||||
matched = store.get([docs[0].doc_id, docs[1].doc_id])
|
||||
assert len(matched) == 2, "Should return 2 documents"
|
||||
assert [doc.text for doc in matched] == [doc.text for doc in docs[:2]]
|
||||
|
||||
# Test delete with str id
|
||||
store.delete(docs[0].doc_id)
|
||||
assert len(store.get_all()) == 19, "Document store should have 19 documents"
|
||||
|
||||
# Test delete with list of ids
|
||||
store.delete([docs[1].doc_id, docs[2].doc_id])
|
||||
assert len(store.get_all()) == 17, "Document store should have 17 documents"
|
||||
|
||||
# Test save
|
||||
assert path.exists(), "File should exist"
|
||||
|
||||
# Test load
|
||||
store2 = SimpleFileDocumentStore(path=path)
|
||||
assert len(store2.get_all()) == 17, "Laded document store should have 17 documents"
|
||||
|
||||
os.remove(path)
|
||||
|
||||
|
||||
@patch(
|
||||
"elastic_transport.Transport.perform_request",
|
||||
|
@@ -1,7 +1,12 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
from kotaemon.base import Document
|
||||
from kotaemon.storages import ChromaVectorStore, InMemoryVectorStore
|
||||
from kotaemon.base import DocumentWithEmbedding
|
||||
from kotaemon.storages import (
|
||||
ChromaVectorStore,
|
||||
InMemoryVectorStore,
|
||||
SimpleFileVectorStore,
|
||||
)
|
||||
|
||||
|
||||
class TestChromaVectorStore:
|
||||
@@ -24,11 +29,11 @@ class TestChromaVectorStore:
|
||||
embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||
metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}]
|
||||
documents = [
|
||||
Document(embedding=embedding, metadata=metadata)
|
||||
DocumentWithEmbedding(embedding=embedding, metadata=metadata)
|
||||
for embedding, metadata in zip(embeddings, metadatas)
|
||||
]
|
||||
assert db._collection.count() == 0, "Expected empty collection"
|
||||
output = db.add_from_docs(documents)
|
||||
output = db.add(documents)
|
||||
assert len(output) == 2, "Expected outputing 2 ids"
|
||||
assert db._collection.count() == 2, "Expected 2 added entries"
|
||||
|
||||
@@ -69,10 +74,8 @@ class TestChromaVectorStore:
|
||||
ids = ["1", "2", "3"]
|
||||
db = ChromaVectorStore(path=str(tmp_path))
|
||||
db.add(embeddings=embeddings, metadatas=metadatas, ids=ids)
|
||||
db.save()
|
||||
|
||||
db2 = ChromaVectorStore(path=str(tmp_path))
|
||||
db2.load()
|
||||
assert (
|
||||
db2._collection.count() == 3
|
||||
), "load function does not load data completely"
|
||||
@@ -122,3 +125,30 @@ class TestInMemoryVectorStore:
|
||||
0.5,
|
||||
0.6,
|
||||
], "load function does not load data completely"
|
||||
|
||||
|
||||
class TestSimpleFileVectorStore:
|
||||
def test_add_delete(self, tmp_path):
|
||||
"""Test that delete func deletes correctly."""
|
||||
embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]
|
||||
metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}]
|
||||
ids = ["1", "2", "3"]
|
||||
db = SimpleFileVectorStore(path=tmp_path / "test_save_load_delete.json")
|
||||
db.add(embeddings=embeddings, metadatas=metadatas, ids=ids)
|
||||
db.delete(["3"])
|
||||
f = open(tmp_path / "test_save_load_delete.json")
|
||||
data = json.load(f)
|
||||
assert (
|
||||
"1" and "2" in data["text_id_to_ref_doc_id"]
|
||||
), "save function does not save data completely"
|
||||
assert (
|
||||
"3" not in data["text_id_to_ref_doc_id"]
|
||||
), "delete function does not delete data completely"
|
||||
db2 = SimpleFileVectorStore(path=tmp_path / "test_save_load_delete.json")
|
||||
assert db2.get("2") == [
|
||||
0.4,
|
||||
0.5,
|
||||
0.6,
|
||||
], "load function does not load data completely"
|
||||
|
||||
os.remove(tmp_path / "test_save_load_delete.json")
|
||||
|
Reference in New Issue
Block a user