[AUR-362] Add In-memory vector store (#22)
* [AUR-362] Add in-memory vector store
* [AUR-362] Fix the input format of the delete function
* [AUR-362] Rename `persist` and `from_persist_path` to `save` and `load`
* [AUR-362] Rename simple.py to in_memory.py
This commit is contained in:
parent
b794051653
commit
c329c4c03f
|
@ -1,4 +1,5 @@
|
||||||
from .base import BaseVectorStore
|
from .base import BaseVectorStore
|
||||||
from .chroma import ChromaVectorStore
|
from .chroma import ChromaVectorStore
|
||||||
|
from .in_memory import InMemoryVectorStore
|
||||||
|
|
||||||
__all__ = ["BaseVectorStore", "ChromaVectorStore"]
|
__all__ = ["BaseVectorStore", "ChromaVectorStore", "InMemoryVectorStore"]
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, List, Optional, Tuple, Type, Union
|
from typing import Any, List, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
|
from llama_index.schema import NodeRelationship, RelatedNodeInfo
|
||||||
from llama_index.vector_stores.types import BasePydanticVectorStore
|
from llama_index.vector_stores.types import BasePydanticVectorStore
|
||||||
from llama_index.vector_stores.types import VectorStore as LIVectorStore
|
from llama_index.vector_stores.types import VectorStore as LIVectorStore
|
||||||
from llama_index.vector_stores.types import VectorStoreQuery
|
from llama_index.vector_stores.types import VectorStoreQuery
|
||||||
|
@ -118,6 +119,9 @@ class LlamaIndexVectorStore(BaseVectorStore):
|
||||||
if ids is not None:
|
if ids is not None:
|
||||||
for node, id in zip(nodes, ids):
|
for node, id in zip(nodes, ids):
|
||||||
node.id_ = id
|
node.id_ = id
|
||||||
|
node.relationships = {
|
||||||
|
NodeRelationship.SOURCE: RelatedNodeInfo(node_id=id)
|
||||||
|
}
|
||||||
|
|
||||||
return self._client.add(nodes=nodes) # type: ignore
|
return self._client.add(nodes=nodes) # type: ignore
|
||||||
|
|
||||||
|
|
57
knowledgehub/vectorstores/in_memory.py
Normal file
57
knowledgehub/vectorstores/in_memory.py
Normal file
|
@ -0,0 +1,57 @@
|
||||||
|
"""Simple vector store index."""
|
||||||
|
|
||||||
|
from typing import Any, Optional, Type
|
||||||
|
|
||||||
|
import fsspec
|
||||||
|
from llama_index.vector_stores import SimpleVectorStore as LISimpleVectorStore
|
||||||
|
from llama_index.vector_stores.simple import SimpleVectorStoreData
|
||||||
|
|
||||||
|
from kotaemon.vectorstores.base import LlamaIndexVectorStore
|
||||||
|
|
||||||
|
|
||||||
|
class InMemoryVectorStore(LlamaIndexVectorStore):
    """Vector store that wraps LlamaIndex's ``SimpleVectorStore``.

    The wrapped class is declared through ``_li_class``. The store can be
    written to a file with :meth:`save` and read back with :meth:`load`.
    """

    _li_class: Type[LISimpleVectorStore] = LISimpleVectorStore
    store_text: bool = False

    def __init__(
        self,
        data: Optional[SimpleVectorStoreData] = None,
        fs: Optional[fsspec.AbstractFileSystem] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the store.

        Args:
            data: Existing vector data to start from. When omitted, an empty
                ``SimpleVectorStoreData`` is kept on the instance.
            fs: Filesystem kept on the instance. When omitted, the local
                ("file") fsspec filesystem is used.
            **kwargs: Forwarded unchanged to the parent initializer.
        """
        self._data = data or SimpleVectorStoreData()
        self._fs = fs or fsspec.filesystem("file")

        # The parent receives the caller's original (possibly None) values,
        # not the defaults resolved above.
        super().__init__(
            data=data,
            fs=fs,
            **kwargs,
        )

    def save(
        self,
        save_path: str,
        fs: Optional[fsspec.AbstractFileSystem] = None,
        **kwargs,
    ):
        """Persist the vector store to ``save_path``.

        Args:
            save_path: Path of the file the vector store is written to.
            fs: Filesystem handed to the client's ``persist`` call.
            **kwargs: Accepted but not used.
        """
        self._client.persist(persist_path=save_path, fs=fs)

    def load(
        self, load_path: str, fs: Optional[fsspec.AbstractFileSystem] = None
    ) -> "InMemoryVectorStore":
        """Load a previously saved vector store from ``load_path``.

        The loaded store replaces this instance's client, so the instance
        itself serves the loaded data after the call.

        Args:
            load_path: Path of the file to load the vector store from.
            fs: Filesystem handed to the client's ``from_persist_path`` call.

        Returns:
            The store built by the client's ``from_persist_path``.
        """
        # NOTE(review): the returned object is whatever ``from_persist_path``
        # builds (presumably a LlamaIndex SimpleVectorStore), which may not
        # match the return annotation -- confirm against the base class.
        loaded = self._client.from_persist_path(persist_path=load_path, fs=fs)
        # Rebind the client so that ``db.load(path)`` actually populates ``db``
        # instead of leaving it empty; the loaded store is still returned so
        # existing callers that use the return value keep working.
        self._client = loaded
        return loaded
|
|
@ -1,5 +1,7 @@
|
||||||
|
import json
|
||||||
|
|
||||||
from kotaemon.documents.base import Document
|
from kotaemon.documents.base import Document
|
||||||
from kotaemon.vectorstores import ChromaVectorStore
|
from kotaemon.vectorstores import ChromaVectorStore, InMemoryVectorStore
|
||||||
|
|
||||||
|
|
||||||
class TestChromaVectorStore:
|
class TestChromaVectorStore:
|
||||||
|
@ -59,3 +61,41 @@ class TestChromaVectorStore:
|
||||||
|
|
||||||
_, _, out_ids = db.query(embedding=[0.42, 0.52, 0.53], top_k=1)
|
_, _, out_ids = db.query(embedding=[0.42, 0.52, 0.53], top_k=1)
|
||||||
assert out_ids == ["b"]
|
assert out_ids == ["b"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestSimpleVectorStore:
    """Tests for ``InMemoryVectorStore``: add, delete, save and load."""

    def test_add(self):
        """Test that add func adds correctly."""

        embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
        metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}]
        ids = ["1", "2"]
        db = InMemoryVectorStore()

        output = db.add(embeddings=embeddings, metadatas=metadatas, ids=ids)
        assert output == ids, "Excepted output to be the same as ids"

    def test_save_load_delete(self, tmp_path):
        """Test that delete, save and load work together correctly."""
        embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]
        metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}]
        ids = ["1", "2", "3"]
        save_file = tmp_path / "test_save_load_delete.json"

        db = InMemoryVectorStore()
        db.add(embeddings=embeddings, metadatas=metadatas, ids=ids)
        db.delete(["3"])
        db.save(save_path=save_file)

        # The context manager closes the file even if JSON parsing fails.
        with open(save_file) as f:
            data = json.load(f)
        saved_ids = data["text_id_to_ref_doc_id"]

        # Each id is checked on its own: ``"1" and "2" in x`` evaluates to
        # ``"2" in x`` and would never verify that "1" was saved.
        assert (
            "1" in saved_ids and "2" in saved_ids
        ), "persist function does not save data completely"
        assert (
            "3" not in saved_ids
        ), "delete function does not delete data completely"

        db2 = InMemoryVectorStore()
        output = db2.load(load_path=save_file)
        assert output.get("2") == [
            0.4,
            0.5,
            0.6,
        ], "persist function does not load data completely"
|
||||||
|
|
Loading…
Reference in New Issue
Block a user