feat: Qdrant vectorstore support (#260)
* feat: Qdrant vectorstore support * chore: review changes * docs: Updated README.md
This commit is contained in:
parent
cbe45a4395
commit
e2bd78e9c4
|
@ -189,7 +189,7 @@ starting point.
|
||||||
KH_DOCSTORE=(Elasticsearch | LanceDB | SimpleFileDocumentStore)
|
KH_DOCSTORE=(Elasticsearch | LanceDB | SimpleFileDocumentStore)
|
||||||
|
|
||||||
# setup your preferred vectorstore (for vector-based search)
|
# setup your preferred vectorstore (for vector-based search)
|
||||||
KH_VECTORSTORE=(ChromaDB | LanceDB | InMemory)
|
KH_VECTORSTORE=(ChromaDB | LanceDB | InMemory | Qdrant)
|
||||||
|
|
||||||
# Enable / disable multimodal QA
|
# Enable / disable multimodal QA
|
||||||
KH_REASONINGS_USE_MULTIMODAL=True
|
KH_REASONINGS_USE_MULTIMODAL=True
|
||||||
|
|
|
@ -81,6 +81,7 @@ KH_VECTORSTORE = {
|
||||||
# "__type__": "kotaemon.storages.LanceDBVectorStore",
|
# "__type__": "kotaemon.storages.LanceDBVectorStore",
|
||||||
"__type__": "kotaemon.storages.ChromaVectorStore",
|
"__type__": "kotaemon.storages.ChromaVectorStore",
|
||||||
# "__type__": "kotaemon.storages.MilvusVectorStore",
|
# "__type__": "kotaemon.storages.MilvusVectorStore",
|
||||||
|
# "__type__": "kotaemon.storages.QdrantVectorStore",
|
||||||
"path": str(KH_USER_DATA_DIR / "vectorstore"),
|
"path": str(KH_USER_DATA_DIR / "vectorstore"),
|
||||||
}
|
}
|
||||||
KH_LLMS = {}
|
KH_LLMS = {}
|
||||||
|
|
|
@ -11,6 +11,7 @@ from .vectorstores import (
|
||||||
InMemoryVectorStore,
|
InMemoryVectorStore,
|
||||||
LanceDBVectorStore,
|
LanceDBVectorStore,
|
||||||
MilvusVectorStore,
|
MilvusVectorStore,
|
||||||
|
QdrantVectorStore,
|
||||||
SimpleFileVectorStore,
|
SimpleFileVectorStore,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -28,4 +29,5 @@ __all__ = [
|
||||||
"SimpleFileVectorStore",
|
"SimpleFileVectorStore",
|
||||||
"LanceDBVectorStore",
|
"LanceDBVectorStore",
|
||||||
"MilvusVectorStore",
|
"MilvusVectorStore",
|
||||||
|
"QdrantVectorStore",
|
||||||
]
|
]
|
||||||
|
|
|
@ -3,6 +3,7 @@ from .chroma import ChromaVectorStore
|
||||||
from .in_memory import InMemoryVectorStore
|
from .in_memory import InMemoryVectorStore
|
||||||
from .lancedb import LanceDBVectorStore
|
from .lancedb import LanceDBVectorStore
|
||||||
from .milvus import MilvusVectorStore
|
from .milvus import MilvusVectorStore
|
||||||
|
from .qdrant import QdrantVectorStore
|
||||||
from .simple_file import SimpleFileVectorStore
|
from .simple_file import SimpleFileVectorStore
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
@ -12,4 +13,5 @@ __all__ = [
|
||||||
"SimpleFileVectorStore",
|
"SimpleFileVectorStore",
|
||||||
"LanceDBVectorStore",
|
"LanceDBVectorStore",
|
||||||
"MilvusVectorStore",
|
"MilvusVectorStore",
|
||||||
|
"QdrantVectorStore",
|
||||||
]
|
]
|
||||||
|
|
67
libs/kotaemon/kotaemon/storages/vectorstores/qdrant.py
Normal file
67
libs/kotaemon/kotaemon/storages/vectorstores/qdrant.py
Normal file
|
@ -0,0 +1,67 @@
|
||||||
|
from typing import Any, List, Optional, Type, cast
|
||||||
|
|
||||||
|
from llama_index.vector_stores.qdrant import QdrantVectorStore as LIQdrantVectorStore
|
||||||
|
|
||||||
|
from .base import LlamaIndexVectorStore
|
||||||
|
|
||||||
|
|
||||||
|
class QdrantVectorStore(LlamaIndexVectorStore):
    """Vector store backed by llama-index's Qdrant integration.

    Every constructor argument is remembered on the instance so that
    `__persist_flow__` can report the same configuration back.
    """

    _li_class: Type[LIQdrantVectorStore] = LIQdrantVectorStore

    def __init__(
        self,
        collection_name,
        url: Optional[str] = None,
        api_key: Optional[str] = None,
        client_kwargs: Optional[dict] = None,
        **kwargs: Any,
    ):
        """Remember the settings, then hand them to the parent constructor.

        Args:
            collection_name: name of the Qdrant collection used by `delete`,
                `drop` and `count`
            url: forwarded unchanged to the parent constructor
            api_key: forwarded unchanged to the parent constructor
            client_kwargs: forwarded unchanged to the parent constructor
            kwargs: any additional keyword arguments, forwarded as well
        """
        # Stored before delegating so `__persist_flow__` can replay them.
        self._collection_name = collection_name
        self._url = url
        self._api_key = api_key
        self._client_kwargs = client_kwargs
        self._kwargs = kwargs

        named_settings = {
            "collection_name": collection_name,
            "url": url,
            "api_key": api_key,
            "client_kwargs": client_kwargs,
        }
        super().__init__(**named_settings, **kwargs)

        # Typing aid only: `cast` returns its argument untouched at runtime.
        self._client = cast(LIQdrantVectorStore, self._client)

    def delete(self, ids: List[str], **kwargs):
        """Remove the points with the given ids from the collection.

        Args:
            ids: ids of the embeddings to remove
            kwargs: extra parameters passed through to the Qdrant client's
                `delete` call
        """
        # Imported here rather than at module level, as in the original code.
        from qdrant_client import models

        selector = models.PointIdsList(points=ids)
        self._client.client.delete(
            collection_name=self._collection_name,
            points_selector=selector,
            **kwargs,
        )

    def drop(self):
        """Delete the whole collection from the vector store."""
        qdrant = self._client.client
        qdrant.delete_collection(self._collection_name)

    def count(self) -> int:
        """Return the number of points in the collection (exact count)."""
        result = self._client.client.count(
            collection_name=self._collection_name,
            exact=True,
        )
        return result.count

    def __persist_flow__(self):
        """Return the constructor arguments this instance was created with."""
        flow = {
            "collection_name": self._collection_name,
            "url": self._url,
            "api_key": self._api_key,
            "client_kwargs": self._client_kwargs,
        }
        # Extra keyword arguments come last, exactly as they were received.
        flow.update(self._kwargs)
        return flow
|
|
@ -22,7 +22,7 @@ requires-python = ">= 3.10"
|
||||||
description = "Kotaemon core library for AI development."
|
description = "Kotaemon core library for AI development."
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"click>=8.1.7,<9",
|
"click>=8.1.7,<9",
|
||||||
"cohere>=5.3.2,<5.4",
|
"cohere>=5.3.2,<6",
|
||||||
"cookiecutter>=2.6.0,<2.7",
|
"cookiecutter>=2.6.0,<2.7",
|
||||||
"fast_langdetect",
|
"fast_langdetect",
|
||||||
"gradio>=4.31.0,<4.40",
|
"gradio>=4.31.0,<4.40",
|
||||||
|
@ -73,6 +73,7 @@ adv = [
|
||||||
"sentence-transformers",
|
"sentence-transformers",
|
||||||
"llama-cpp-python<0.2.8",
|
"llama-cpp-python<0.2.8",
|
||||||
"fastembed",
|
"fastembed",
|
||||||
|
"llama-index-vector-stores-qdrant",
|
||||||
]
|
]
|
||||||
dev = [
|
dev = [
|
||||||
"black",
|
"black",
|
||||||
|
|
|
@ -51,6 +51,15 @@ def if_unstructured_not_installed():
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def if_cohere_not_installed():
    """Return True when `import cohere` fails with ImportError, else False."""
    try:
        import cohere  # noqa: F401
    except ImportError:
        return True
    return False
|
||||||
|
|
||||||
|
|
||||||
def if_llama_cpp_not_installed():
|
def if_llama_cpp_not_installed():
|
||||||
try:
|
try:
|
||||||
import llama_cpp # noqa: F401
|
import llama_cpp # noqa: F401
|
||||||
|
@ -76,6 +85,10 @@ skip_when_unstructured_not_installed = pytest.mark.skipif(
|
||||||
if_unstructured_not_installed(), reason="unstructured is not installed"
|
if_unstructured_not_installed(), reason="unstructured is not installed"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Skip marker for tests that require the optional `cohere` package.
skip_when_cohere_not_installed = pytest.mark.skipif(
    if_cohere_not_installed(),
    reason="cohere is not installed",
)
|
||||||
|
|
||||||
skip_openai_lc_wrapper_test = pytest.mark.skipif(
|
skip_openai_lc_wrapper_test = pytest.mark.skipif(
|
||||||
True, reason="OpenAI LC wrapper test is skipped"
|
True, reason="OpenAI LC wrapper test is skipped"
|
||||||
)
|
)
|
||||||
|
|
|
@ -14,6 +14,7 @@ from kotaemon.embeddings import (
|
||||||
)
|
)
|
||||||
|
|
||||||
from .conftest import (
|
from .conftest import (
|
||||||
|
skip_when_cohere_not_installed,
|
||||||
skip_when_fastembed_not_installed,
|
skip_when_fastembed_not_installed,
|
||||||
skip_when_sentence_bert_not_installed,
|
skip_when_sentence_bert_not_installed,
|
||||||
)
|
)
|
||||||
|
@ -132,6 +133,7 @@ def test_lchuggingface_embeddings(
|
||||||
langchain_huggingface_embedding_call.assert_called()
|
langchain_huggingface_embedding_call.assert_called()
|
||||||
|
|
||||||
|
|
||||||
|
@skip_when_cohere_not_installed
|
||||||
@patch(
|
@patch(
|
||||||
"langchain.embeddings.cohere.CohereEmbeddings.embed_documents",
|
"langchain.embeddings.cohere.CohereEmbeddings.embed_documents",
|
||||||
side_effect=lambda *args, **kwargs: [[1.0, 2.1, 3.2]],
|
side_effect=lambda *args, **kwargs: [[1.0, 2.1, 3.2]],
|
||||||
|
|
|
@ -1,11 +1,14 @@
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from kotaemon.base import DocumentWithEmbedding
|
from kotaemon.base import DocumentWithEmbedding
|
||||||
from kotaemon.storages import (
|
from kotaemon.storages import (
|
||||||
ChromaVectorStore,
|
ChromaVectorStore,
|
||||||
InMemoryVectorStore,
|
InMemoryVectorStore,
|
||||||
MilvusVectorStore,
|
MilvusVectorStore,
|
||||||
|
QdrantVectorStore,
|
||||||
SimpleFileVectorStore,
|
SimpleFileVectorStore,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -248,3 +251,118 @@ class TestMilvusVectorStore:
|
||||||
# reinit the milvus with the same collection name
|
# reinit the milvus with the same collection name
|
||||||
db2 = MilvusVectorStore(path=str(tmp_path), overwrite=False)
|
db2 = MilvusVectorStore(path=str(tmp_path), overwrite=False)
|
||||||
assert db2.count() == 0, "delete collection function does not work correctly"
|
assert db2.count() == 0, "delete collection function does not work correctly"
|
||||||
|
|
||||||
|
|
||||||
|
class TestQdrantVectorStore:
    """Tests for QdrantVectorStore using local Qdrant clients.

    Each test builds its own `QdrantClient` (":memory:" or a path under
    `tmp_path`) and passes it to the store through the `client` keyword.
    """

    def test_add(self):
        """Adding raw embeddings returns the supplied ids and stores them all."""
        from qdrant_client import QdrantClient

        db = QdrantVectorStore(collection_name="test", client=QdrantClient(":memory:"))

        embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
        metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}]
        ids = [
            "0f0611b3-2d9c-4818-ab69-1f1c4cf66693",
            "90aba5d3-f4f8-47c6-bad9-5ea457442e07",
        ]

        output = db.add(embeddings=embeddings, metadatas=metadatas, ids=ids)
        assert output == ids, "Expected output to be the same as ids"
        assert db.count() == 2, "Expected 2 added entries"

    def test_add_from_docs(self, tmp_path):
        """Adding DocumentWithEmbedding objects yields one id per document."""
        from qdrant_client import QdrantClient

        db = QdrantVectorStore(collection_name="test", client=QdrantClient(":memory:"))

        embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
        metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}]
        documents = [
            DocumentWithEmbedding(embedding=embedding, metadata=metadata)
            for embedding, metadata in zip(embeddings, metadatas)
        ]

        output = db.add(documents)
        assert len(output) == 2, "Expected outputting 2 ids"
        assert db.count() == 2, "Expected 2 added entries"

    def test_delete(self, tmp_path):
        """Deleting by id removes exactly the requested entries."""
        from qdrant_client import QdrantClient

        db = QdrantVectorStore(collection_name="test", client=QdrantClient(":memory:"))

        embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]
        metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}]
        ids = [
            "0f0611b3-2d9c-4818-ab69-1f1c4cf66693",
            "90aba5d3-f4f8-47c6-bad9-5ea457442e07",
            "6bed07c3-d284-47a3-a711-c3f9186755b8",
        ]

        db.add(embeddings=embeddings, metadatas=metadatas, ids=ids)
        assert db.count() == 3, "Expected 3 added entries"
        db.delete(
            ids=[
                "0f0611b3-2d9c-4818-ab69-1f1c4cf66693",
                "90aba5d3-f4f8-47c6-bad9-5ea457442e07",
            ]
        )
        assert db.count() == 1, "Expected 1 remaining entry"
        db.delete(ids=["6bed07c3-d284-47a3-a711-c3f9186755b8"])
        assert db.count() == 0, "Expected 0 remaining entry"

    def test_query(self, tmp_path):
        """Querying with a stored embedding returns that entry as the top hit."""
        from qdrant_client import QdrantClient

        db = QdrantVectorStore(collection_name="test", client=QdrantClient(":memory:"))

        embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]
        metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}]
        ids = [
            "0f0611b3-2d9c-4818-ab69-1f1c4cf66693",
            "90aba5d3-f4f8-47c6-bad9-5ea457442e07",
            "6bed07c3-d284-47a3-a711-c3f9186755b8",
        ]

        db.add(embeddings=embeddings, metadatas=metadatas, ids=ids)

        _, sim, out_ids = db.query(embedding=[0.1, 0.2, 0.3], top_k=1)
        # abs() is required: without it any similarity below 1.0 (even 0.0)
        # would satisfy the comparison and the check would be meaningless.
        assert abs(sim[0] - 1.0) < 1e-6
        assert out_ids == ["0f0611b3-2d9c-4818-ab69-1f1c4cf66693"]

        _, _, out_ids = db.query(embedding=[0.4, 0.5, 0.6], top_k=1)
        assert out_ids == ["90aba5d3-f4f8-47c6-bad9-5ea457442e07"]

    def test_save_load_delete(self, tmp_path):
        """Test that save/load func behave correctly."""
        embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]
        metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}]
        ids = [
            "0f0611b3-2d9c-4818-ab69-1f1c4cf66693",
            "90aba5d3-f4f8-47c6-bad9-5ea457442e07",
            "6bed07c3-d284-47a3-a711-c3f9186755b8",
        ]
        from qdrant_client import QdrantClient

        db = QdrantVectorStore(
            collection_name="test", client=QdrantClient(path=tmp_path)
        )
        db.add(embeddings=embeddings, metadatas=metadatas, ids=ids)
        del db

        # A second store opened on the same path must see the saved entries.
        db2 = QdrantVectorStore(
            collection_name="test", client=QdrantClient(path=tmp_path)
        )
        assert db2.count() == 3

        db2.drop()
        del db2

        db2 = QdrantVectorStore(
            collection_name="test", client=QdrantClient(path=tmp_path)
        )

        with pytest.raises(Exception):
            # Since no docs were added, the collection should not exist yet
            # and thus the count function should raise an exception
            db2.count()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user