feat: add first setup screen for LLM & Embedding models (#314) (bump:minor)
* fix: utf-8 txt reader
* fix: revise vectorstore import and make it optional
* feat: add cohere chat model with tool call support
* fix: simplify citation pipeline
* fix: improve citation logic
* fix: improve decompose func call
* fix: revise question rewrite prompt
* fix: revise chat box default placeholder
* fix: add key from ktem to cohere rerank
* fix: conv name suggestion
* fix: ignore default key cohere rerank
* fix: improve test connection UI
* fix: reorder requirements
* feat: add first setup screen
* fix: update requirements
* fix: vectorstore tests
* fix: update cohere version
* fix: relax langchain core version
* fix: add demo mode
* fix: update flowsettings
* fix: typo
* fix: fix bool env passing
committed by GitHub
parent 0bdb9a32f2
commit 88d577b0cc
@@ -183,7 +183,7 @@ class LCCohereEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
     def _get_lc_class(self):
         try:
-            from langchain_community.embeddings import CohereEmbeddings
+            from langchain_cohere import CohereEmbeddings
         except ImportError:
             from langchain.embeddings import CohereEmbeddings
@@ -1,4 +1,4 @@
-from typing import Iterator, List
+from typing import List

 from pydantic import BaseModel, Field
@@ -7,53 +7,14 @@ from kotaemon.base.schema import HumanMessage, SystemMessage
 from kotaemon.llms import BaseLLM


-class FactWithEvidence(BaseModel):
-    """Class representing a single statement.
-
-    Each fact has a body and a list of sources.
-    If there are multiple facts make sure to break them apart
-    such that each one only uses a set of sources that are relevant to it.
-    """
-
-    fact: str = Field(..., description="Body of the sentence, as part of a response")
-    substring_quote: List[str] = Field(
-        ...,
-        description=(
-            "Each source should be a direct quote from the context, "
-            "as a substring of the original content"
-        ),
-    )
-
-    def _get_span(self, quote: str, context: str, errs: int = 100) -> Iterator[str]:
-        import regex
-
-        minor = quote
-        major = context
-
-        errs_ = 0
-        s = regex.search(f"({minor}){{e<={errs_}}}", major)
-        while s is None and errs_ <= errs:
-            errs_ += 1
-            s = regex.search(f"({minor}){{e<={errs_}}}", major)
-
-        if s is not None:
-            yield from s.spans()
-
-    def get_spans(self, context: str) -> Iterator[str]:
-        for quote in self.substring_quote:
-            yield from self._get_span(quote, context)
-
-
-class QuestionAnswer(BaseModel):
-    """A question and its answer as a list of facts each one should have a source.
-    each sentence contains a body and a list of sources."""
-
-    question: str = Field(..., description="Question that was asked")
-    answer: List[FactWithEvidence] = Field(
-        ...,
-        description=(
-            "Body of the answer, each fact should be "
-            "its separate object with a body and a list of sources"
-        ),
-    )
+class CiteEvidence(BaseModel):
+    """List of evidences (maximum 5) to support the answer."""
+
+    evidences: List[str] = Field(
+        ...,
+        description=(
+            "Each source should be a direct quote from the context, "
+            "as a substring of the original content (max 15 words)."
+        ),
+    )
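Note (not part of the diff): a minimal sketch of how a pydantic model such as CiteEvidence becomes the function spec that prepare_llm builds in the next hunk. The "parameters" key is an assumption, since the hunk cuts off after "description":

    from typing import List

    from pydantic import BaseModel, Field


    class CiteEvidence(BaseModel):
        """List of evidences (maximum 5) to support the answer."""

        evidences: List[str] = Field(..., description="Direct quotes from the context.")


    schema = CiteEvidence.schema()  # pydantic v1-style API, as used in prepare_llm
    function = {
        "name": schema["title"],  # -> "CiteEvidence"
        "description": schema["description"],  # -> the class docstring
        "parameters": schema,  # assumed key; not visible in the hunk
    }
    print(function["name"], list(schema["properties"]))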
@@ -68,7 +29,7 @@ class CitationPipeline(BaseComponent):
         return self.invoke(context, question)

     def prepare_llm(self, context: str, question: str):
-        schema = QuestionAnswer.schema()
+        schema = CiteEvidence.schema()
         function = {
             "name": schema["title"],
             "description": schema["description"],
@@ -76,7 +37,8 @@ class CitationPipeline(BaseComponent):
         }
         llm_kwargs = {
             "tools": [{"type": "function", "function": function}],
-            "tool_choice": "auto",
+            "tool_choice": "required",
+            "tools_pydantic": [CiteEvidence],
         }
         messages = [
             SystemMessage(
@@ -85,7 +47,12 @@ class CitationPipeline(BaseComponent):
                     "questions with correct and exact citations."
                 )
             ),
-            HumanMessage(content="Answer question using the following context"),
+            HumanMessage(
+                content=(
+                    "Answer question using the following context. "
+                    "Use the provided function CiteEvidence() to cite your sources."
+                )
+            ),
             HumanMessage(content=context),
             HumanMessage(content=f"Question: {question}"),
             HumanMessage(
@@ -103,14 +70,24 @@ class CitationPipeline(BaseComponent):
             print("CitationPipeline: invoking LLM")
             llm_output = self.get_from_path("llm").invoke(messages, **llm_kwargs)
             print("CitationPipeline: finish invoking LLM")
-            if not llm_output.messages or not llm_output.additional_kwargs.get(
-                "tool_calls"
-            ):
+            if not llm_output.additional_kwargs.get("tool_calls"):
                 return None
-            function_output = llm_output.additional_kwargs["tool_calls"][0]["function"][
-                "arguments"
-            ]
-            output = QuestionAnswer.parse_raw(function_output)
+
+            first_func = llm_output.additional_kwargs["tool_calls"][0]
+
+            if "function" in first_func:
+                # openai and cohere format
+                function_output = first_func["function"]["arguments"]
+            else:
+                # anthropic format
+                function_output = first_func["args"]
+
+            print("CitationPipeline:", function_output)
+
+            if isinstance(function_output, str):
+                output = CiteEvidence.parse_raw(function_output)
+            else:
+                output = CiteEvidence.parse_obj(function_output)
         except Exception as e:
             print(e)
             return None
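Note (not part of the diff): a self-contained sketch of the branching above, using hypothetical payloads for the two tool-call shapes; parse_raw covers the JSON-string case and parse_obj the dict case, mirrored here with json.loads:

    import json

    # Illustrative payloads only: the two shapes the new branch handles.
    openai_style = {"function": {"arguments": '{"evidences": ["a direct quote"]}'}}
    anthropic_style = {"args": {"evidences": ["a direct quote"]}}


    def extract_arguments(first_func: dict):
        if "function" in first_func:
            # openai and cohere deliver the arguments as a JSON string
            return first_func["function"]["arguments"]
        # anthropic delivers them as an already-parsed dict
        return first_func["args"]


    for payload in (openai_style, anthropic_style):
        raw = extract_arguments(payload)
        data = json.loads(raw) if isinstance(raw, str) else raw
        print(data["evidences"])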
@@ -118,18 +95,4 @@ class CitationPipeline(BaseComponent):
         return output

     async def ainvoke(self, context: str, question: str):
-        messages, llm_kwargs = self.prepare_llm(context, question)
-
-        try:
-            print("CitationPipeline: async invoking LLM")
-            llm_output = await self.get_from_path("llm").ainvoke(messages, **llm_kwargs)
-            print("CitationPipeline: finish async invoking LLM")
-            function_output = llm_output.additional_kwargs["tool_calls"][0]["function"][
-                "arguments"
-            ]
-            output = QuestionAnswer.parse_raw(function_output)
-        except Exception as e:
-            print(e)
-            return None
-
-        return output
+        raise NotImplementedError()
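Note (not part of the diff): with ainvoke now unimplemented, an async caller could wrap the sync path instead; a hypothetical sketch, assuming Python 3.9+ for asyncio.to_thread:

    import asyncio


    async def acite(pipeline, context: str, question: str):
        # Hypothetical wrapper: run the synchronous invoke() in a worker thread.
        return await asyncio.to_thread(pipeline.invoke, context, question)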
@@ -10,6 +10,7 @@ from .base import BaseReranking
 class CohereReranking(BaseReranking):
     model_name: str = "rerank-multilingual-v2.0"
     cohere_api_key: str = config("COHERE_API_KEY", "")
+    use_key_from_ktem: bool = False

     def run(self, documents: list[Document], query: str) -> list[Document]:
         """Use Cohere Reranker model to re-order documents
@@ -18,9 +19,25 @@ class CohereReranking(BaseReranking):
             import cohere
         except ImportError:
             raise ImportError(
-                "Please install Cohere " "`pip install cohere` to use Cohere Reranking"
+                "Please install Cohere `pip install cohere` to use Cohere Reranking"
             )

+        # try to get COHERE_API_KEY from embeddings
+        if not self.cohere_api_key and self.use_key_from_ktem:
+            try:
+                from ktem.embeddings.manager import (
+                    embedding_models_manager as embeddings,
+                )
+
+                cohere_model = embeddings.get("cohere")
+                ktem_cohere_api_key = cohere_model._kwargs.get(  # type: ignore
+                    "cohere_api_key"
+                )
+                if ktem_cohere_api_key != "your-key":
+                    self.cohere_api_key = ktem_cohere_api_key
+            except Exception as e:
+                print("Cannot get Cohere API key from `ktem`", e)
+
         if not self.cohere_api_key:
             print("Cohere API key not found. Skipping reranking.")
             return documents
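Note (not part of the diff): hypothetical usage of the fallback above; the import path for CohereReranking is assumed, and Document construction follows kotaemon.base as used elsewhere in this commit:

    from kotaemon.base import Document

    # With use_key_from_ktem=True the reranker borrows the Cohere key
    # registered in ktem's embedding manager when COHERE_API_KEY is unset.
    reranker = CohereReranking(use_key_from_ktem=True)  # import path assumed
    docs = [Document(content="alpha"), Document(content="beta")]
    ranked = reranker.run(docs, query="alpha")  # without any key: returned as-is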
@@ -35,7 +52,7 @@ class CohereReranking(BaseReranking):
         response = cohere_client.rerank(
             model=self.model_name, query=query, documents=_docs
         )
-        print("Cohere score", [r.relevance_score for r in response.results])
+        # print("Cohere score", [r.relevance_score for r in response.results])
         for r in response.results:
             doc = documents[r.index]
             doc.metadata["cohere_reranking_score"] = r.relevance_score
@@ -10,6 +10,7 @@ from .chats import (
     LCAnthropicChat,
     LCAzureChatOpenAI,
     LCChatOpenAI,
+    LCCohereChat,
     LCGeminiChat,
     LlamaCppChat,
 )
@@ -31,6 +32,7 @@ __all__ = [
     "ChatOpenAI",
     "LCAnthropicChat",
     "LCGeminiChat",
+    "LCCohereChat",
     "LCAzureChatOpenAI",
     "LCChatOpenAI",
     "LlamaCppChat",
@@ -5,6 +5,7 @@ from .langchain_based import (
     LCAzureChatOpenAI,
     LCChatMixin,
     LCChatOpenAI,
+    LCCohereChat,
     LCGeminiChat,
 )
 from .llamacpp import LlamaCppChat
@@ -18,6 +19,7 @@ __all__ = [
     "ChatOpenAI",
     "LCAnthropicChat",
     "LCGeminiChat",
+    "LCCohereChat",
     "LCChatOpenAI",
     "LCAzureChatOpenAI",
     "LCChatMixin",
@@ -18,6 +18,9 @@ class LCChatMixin:
             "Please return the relevant Langchain class in in _get_lc_class"
         )

+    def _get_tool_call_kwargs(self):
+        return {}
+
     def __init__(self, stream: bool = False, **params):
         self._lc_class = self._get_lc_class()
         self._obj = self._lc_class(**params)
@@ -56,9 +59,7 @@ class LCChatMixin:
             total_tokens = pred.llm_output["token_usage"]["total_tokens"]
             prompt_tokens = pred.llm_output["token_usage"]["prompt_tokens"]
         except Exception:
-            logger.warning(
-                f"Cannot get token usage from LLM output for {self._lc_class.__name__}"
-            )
+            pass

         return LLMInterface(
             text=all_text[0] if len(all_text) > 0 else "",
@@ -83,8 +84,30 @@ class LCChatMixin:
             LLMInterface: generated response
         """
         input_ = self.prepare_message(messages)
-        pred = self._obj.generate(messages=[input_], **kwargs)
-        return self.prepare_response(pred)
+
+        if "tools_pydantic" in kwargs:
+            tools = kwargs.pop(
+                "tools_pydantic",
+            )
+            lc_tool_call = self._obj.bind_tools(tools)
+            pred = lc_tool_call.invoke(
+                input_,
+                **self._get_tool_call_kwargs(),
+            )
+            if pred.tool_calls:
+                tool_calls = pred.tool_calls
+            else:
+                tool_calls = pred.additional_kwargs.get("tool_calls", [])
+
+            output = LLMInterface(
+                content="",
+                additional_kwargs={"tool_calls": tool_calls},
+            )
+        else:
+            pred = self._obj.generate(messages=[input_], **kwargs)
+            output = self.prepare_response(pred)
+
+        return output

     async def ainvoke(
         self, messages: str | BaseMessage | list[BaseMessage], **kwargs
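Note (not part of the diff): a sketch of the new tools_pydantic branch from the caller's side; the model name is an assumption, and LCCohereChat is the class added further down in this commit:

    from pydantic import BaseModel

    from kotaemon.llms import LCCohereChat  # exported by the __init__ hunks above


    class CiteEvidence(BaseModel):
        evidences: list[str]


    llm = LCCohereChat(api_key="<COHERE_API_KEY>", model_name="command-r")
    # Passing tools_pydantic routes through bind_tools() instead of generate();
    # the raw tool calls come back in additional_kwargs.
    output = llm.invoke(
        "Cite evidence that water boils at 100 C.",
        tools_pydantic=[CiteEvidence],
    )
    print(output.additional_kwargs["tool_calls"])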
@@ -235,6 +258,9 @@ class LCAnthropicChat(LCChatMixin, ChatLLM):  # type: ignore
         required=True,
     )

+    def _get_tool_call_kwargs(self):
+        return {"tool_choice": {"type": "any"}}
+
     def __init__(
         self,
         api_key: str | None = None,
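Note (not part of the diff): the override acts as a hook. LCChatMixin.invoke() splats _get_tool_call_kwargs() into the bind_tools() call, so each provider customizes only this method; a self-contained sketch of the pattern:

    class Base:
        def _get_tool_call_kwargs(self):
            return {}


    class AnthropicLike(Base):
        def _get_tool_call_kwargs(self):
            # Anthropic's "force a tool call" directive
            return {"tool_choice": {"type": "any"}}


    print(Base()._get_tool_call_kwargs(), AnthropicLike()._get_tool_call_kwargs())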
@@ -291,3 +317,35 @@ class LCGeminiChat(LCChatMixin, ChatLLM):  # type: ignore
             raise ImportError("Please install langchain-google-genai")

         return ChatGoogleGenerativeAI
+
+
+class LCCohereChat(LCChatMixin, ChatLLM):  # type: ignore
+    api_key: str = Param(
+        help="API key (https://dashboard.cohere.com/api-keys)", required=True
+    )
+    model_name: str = Param(
+        help=("Model name to use (https://dashboard.cohere.com/playground/chat)"),
+        required=True,
+    )
+
+    def __init__(
+        self,
+        api_key: str | None = None,
+        model_name: str | None = None,
+        temperature: float = 0.7,
+        **params,
+    ):
+        super().__init__(
+            cohere_api_key=api_key,
+            model_name=model_name,
+            temperature=temperature,
+            **params,
+        )
+
+    def _get_lc_class(self):
+        try:
+            from langchain_cohere import ChatCohere
+        except ImportError:
+            raise ImportError("Please install langchain-cohere")
+
+        return ChatCohere
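Note (not part of the diff): hypothetical plain-chat usage of the new class (model name assumed; see the Cohere playground link in the Param help):

    from kotaemon.llms import LCCohereChat

    llm = LCCohereChat(api_key="<COHERE_API_KEY>", model_name="command-r")
    response = llm.invoke("Hello")  # no tools_pydantic: regular generate() path
    print(response.text)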
@@ -292,6 +292,9 @@ class ChatOpenAI(BaseChatOpenAI):

     def openai_response(self, client, **kwargs):
         """Get the openai response"""
+        if "tools_pydantic" in kwargs:
+            kwargs.pop("tools_pydantic")
+
         params_ = {
             "model": self.model,
             "temperature": self.temperature,
@@ -360,6 +363,9 @@ class AzureChatOpenAI(BaseChatOpenAI):

     def openai_response(self, client, **kwargs):
         """Get the openai response"""
+        if "tools_pydantic" in kwargs:
+            kwargs.pop("tools_pydantic")
+
         params_ = {
             "model": self.azure_deployment,
             "temperature": self.temperature,
@@ -15,7 +15,7 @@ class TxtReader(BaseReader):
     def load_data(
         self, file_path: Path, extra_info: Optional[dict] = None, **kwargs
     ) -> list[Document]:
-        with open(file_path, "r") as f:
+        with open(file_path, "r", encoding="utf-8") as f:
             text = f.read()

         metadata = extra_info or {}
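Note (not part of the diff): a minimal sketch of why the explicit encoding matters; without it, open() uses the platform default (e.g. cp1252 on Windows), which can raise UnicodeDecodeError on UTF-8 files:

    from pathlib import Path

    path = Path("sample.txt")
    path.write_text("café ☕ 東京", encoding="utf-8")

    with open(path, "r", encoding="utf-8") as f:  # matches the fixed TxtReader
        print(f.read())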
@@ -73,17 +73,25 @@ class BaseVectorStore(ABC):


 class LlamaIndexVectorStore(BaseVectorStore):
-    _li_class: type[LIVectorStore | BasePydanticVectorStore]
+    """Mixin for LlamaIndex based vectorstores"""
+
+    _li_class: type[LIVectorStore | BasePydanticVectorStore] | None
+
+    def _get_li_class(self):
+        raise NotImplementedError(
+            "Please return the relevant LlamaIndex class in in _get_li_class"
+        )

     def __init__(self, *args, **kwargs):
-        if self._li_class is None:
-            raise AttributeError(
-                "Require `_li_class` to set a VectorStore class from LlamarIndex"
-            )
+        # get li_class from the method if not set
+        if not self._li_class:
+            LIClass = self._get_li_class()
+        else:
+            LIClass = self._li_class

         from dataclasses import fields

-        self._client = self._li_class(*args, **kwargs)
+        self._client = LIClass(*args, **kwargs)

         self._vsq_kwargs = {_.name for _ in fields(VectorStoreQuery)}
         for key in ["query_embedding", "similarity_top_k", "node_ids"]:
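Note (not part of the diff): a hypothetical subclass following the new lazy-import contract, which the Milvus and Qdrant hunks below instantiate; the base-class import path is assumed:

    from kotaemon.storages.vectorstores.base import LlamaIndexVectorStore  # path assumed


    class LazyChromaVectorStore(LlamaIndexVectorStore):
        _li_class = None  # defer the LlamaIndex import to first instantiation

        def _get_li_class(self):
            try:
                from llama_index.vector_stores.chroma import ChromaVectorStore
            except ImportError:
                raise ImportError(
                    "Please install missing package: "
                    "'pip install llama-index-vector-stores-chroma'"
                )
            return ChromaVectorStore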
@@ -97,6 +105,9 @@ class LlamaIndexVectorStore(BaseVectorStore):
         return setattr(self._client, name, value)

     def __getattr__(self, name: str) -> Any:
+        if name == "_li_class":
+            return super().__getattribute__(name)
+
         return getattr(self._client, name)

     def add(
@@ -1,7 +1,5 @@
 import os
-from typing import Any, Optional, Type, cast
-
-from llama_index.vector_stores.milvus import MilvusVectorStore as LIMilvusVectorStore
+from typing import Any, Optional, cast

 from kotaemon.base import DocumentWithEmbedding
@@ -9,7 +7,20 @@ from .base import LlamaIndexVectorStore


 class MilvusVectorStore(LlamaIndexVectorStore):
-    _li_class: Type[LIMilvusVectorStore] = LIMilvusVectorStore
+    _li_class = None
+
+    def _get_li_class(self):
+        try:
+            from llama_index.vector_stores.milvus import (
+                MilvusVectorStore as LIMilvusVectorStore,
+            )
+        except ImportError:
+            raise ImportError(
+                "Please install missing package: "
+                "'pip install llama-index-vector-stores-milvus'"
+            )
+
+        return LIMilvusVectorStore

     def __init__(
         self,
@@ -46,6 +57,10 @@ class MilvusVectorStore(LlamaIndexVectorStore):
             dim=dim,
             **self._kwargs,
         )
+        from llama_index.vector_stores.milvus import (
+            MilvusVectorStore as LIMilvusVectorStore,
+        )
+
         self._client = cast(LIMilvusVectorStore, self._client)
         self._inited = True
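Note (not part of the diff): the effect of the lazy import, sketched with an assumed import path and no constructor arguments; without the optional package installed, failure now happens at instantiation with a pip hint rather than when the module is imported:

    from kotaemon.storages import MilvusVectorStore  # import path assumed

    try:
        MilvusVectorStore()
    except ImportError as err:
        # "Please install missing package: 'pip install llama-index-vector-stores-milvus'"
        print(err)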
@@ -1,12 +1,23 @@
-from typing import Any, List, Optional, Type, cast
-
-from llama_index.vector_stores.qdrant import QdrantVectorStore as LIQdrantVectorStore
+from typing import Any, List, Optional, cast

 from .base import LlamaIndexVectorStore


 class QdrantVectorStore(LlamaIndexVectorStore):
-    _li_class: Type[LIQdrantVectorStore] = LIQdrantVectorStore
+    _li_class = None
+
+    def _get_li_class(self):
+        try:
+            from llama_index.vector_stores.qdrant import (
+                QdrantVectorStore as LIQdrantVectorStore,
+            )
+        except ImportError:
+            raise ImportError(
+                "Please install missing package: "
+                "'pip install llama-index-vector-stores-qdrant'"
+            )
+
+        return LIQdrantVectorStore

     def __init__(
         self,
@@ -29,6 +40,10 @@ class QdrantVectorStore(LlamaIndexVectorStore):
             client_kwargs=client_kwargs,
             **kwargs,
         )
+        from llama_index.vector_stores.qdrant import (
+            QdrantVectorStore as LIQdrantVectorStore,
+        )
+
         self._client = cast(LIQdrantVectorStore, self._client)

     def delete(self, ids: List[str], **kwargs):
@@ -30,16 +30,15 @@ dependencies = [
     "fastapi<=0.112.1",
     "gradio>=4.31.0,<4.40",
     "html2text==2024.2.26",
-    "langchain>=0.1.16,<0.2.0",
-    "langchain-anthropic",
-    "langchain-community>=0.0.34,<0.1.0",
+    "langchain>=0.1.16,<0.2.16",
+    "langchain-community>=0.0.34,<=0.2.11",
     "langchain-openai>=0.1.4,<0.2.0",
+    "langchain-anthropic",
+    "langchain-cohere>=0.2.4,<0.3.0",
     "llama-hub>=0.0.79,<0.1.0",
     "llama-index>=0.10.40,<0.11.0",
     "llama-index-vector-stores-chroma>=0.1.9",
     "llama-index-vector-stores-lancedb",
-    "llama-index-vector-stores-milvus",
-    "llama-index-vector-stores-qdrant",
     "openai>=1.23.6,<2",
     "openpyxl>=3.1.2,<3.2",
     "opentelemetry-exporter-otlp-proto-grpc>=1.25.0", # https://github.com/chroma-core/chroma/issues/2571
@@ -75,6 +74,9 @@ adv = [
     "llama-cpp-python<0.2.8",
     "sentence-transformers",
     "wikipedia>=1.4.0,<1.5",
+    "llama-index>=0.10.40,<0.11.0",
+    "llama-index-vector-stores-milvus",
+    "llama-index-vector-stores-qdrant",
 ]
 dev = [
     "black",
@@ -135,7 +135,7 @@ def test_lchuggingface_embeddings(

 @skip_when_cohere_not_installed
 @patch(
-    "langchain.embeddings.cohere.CohereEmbeddings.embed_documents",
+    "langchain_cohere.CohereEmbeddings.embed_documents",
     side_effect=lambda *args, **kwargs: [[1.0, 2.1, 3.2]],
 )
 def test_lccohere_embeddings(langchain_cohere_embedding_call):