feat: add first setup screen for LLM & Embedding models (#314) (bump:minor)
* fix: utf-8 txt reader
* fix: revise vectorstore import and make it optional
* feat: add cohere chat model with tool call support
* fix: simplify citation pipeline
* fix: improve citation logic
* fix: improve decompose func call
* fix: revise question rewrite prompt
* fix: revise chat box default placeholder
* fix: add key from ktem to cohere rerank
* fix: conv name suggestion
* fix: ignore default key cohere rerank
* fix: improve test connection UI
* fix: reorder requirements
* feat: add first setup screen
* fix: update requirements
* fix: vectorstore tests
* fix: update cohere version
* fix: relax langchain core version
* fix: add demo mode
* fix: update flowsettings
* fix: typo
* fix: fix bool env passing
Parent: 0bdb9a32f2
Commit: 88d577b0cc
@@ -24,9 +24,13 @@ if not KH_APP_VERSION:
     except Exception:
         KH_APP_VERSION = "local"

+KH_ENABLE_FIRST_SETUP = True
+KH_DEMO_MODE = config("KH_DEMO_MODE", default=False, cast=bool)
+
 # App can be ran from anywhere and it's not trivial to decide where to store app data.
 # So let's use the same directory as the flowsetting.py file.
 KH_APP_DATA_DIR = this_dir / "ktem_app_data"
+KH_APP_DATA_EXISTS = KH_APP_DATA_DIR.exists()
 KH_APP_DATA_DIR.mkdir(parents=True, exist_ok=True)

 # User data directory
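A note on the `cast=bool` usage above (this is what the "fix bool env passing" item in the commit message refers to): a minimal sketch, assuming `config` comes from python-decouple as it does elsewhere in this commit, of how string values from the environment are coerced to real booleans.

```python
# Minimal sketch, assuming python-decouple's `config` (imported as in ktem/app.py below).
import os

from decouple import config

os.environ["KH_DEMO_MODE"] = "false"
# Without cast=bool the value would be the truthy string "false";
# with cast=bool it becomes a real boolean False.
print(config("KH_DEMO_MODE", default=False, cast=bool))  # -> False
```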
@@ -59,7 +63,9 @@ os.environ["HF_HUB_CACHE"] = str(KH_APP_DATA_DIR / "huggingface")
 KH_DOC_DIR = this_dir / "docs"

 KH_MODE = "dev"
-KH_FEATURE_USER_MANAGEMENT = True
+KH_FEATURE_USER_MANAGEMENT = config(
+    "KH_FEATURE_USER_MANAGEMENT", default=True, cast=bool
+)
 KH_USER_CAN_SEE_PUBLIC = None
 KH_FEATURE_USER_MANAGEMENT_ADMIN = str(
     config("KH_FEATURE_USER_MANAGEMENT_ADMIN", default="admin")
@@ -202,6 +208,14 @@ KH_LLMS["groq"] = {
     },
     "default": False,
 }
+KH_LLMS["cohere"] = {
+    "spec": {
+        "__type__": "kotaemon.llms.chats.LCCohereChat",
+        "model_name": "command-r-plus-08-2024",
+        "api_key": "your-key",
+    },
+    "default": False,
+}

 # additional embeddings configurations
 KH_EMBEDDINGS["cohere"] = {
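For orientation, a purely illustrative sketch (not the actual ktem loader) of how a `KH_LLMS` entry like the one above maps onto a class: `__type__` names the class, and the remaining spec keys become constructor arguments.

```python
# Illustrative only -- not the ktem resolver itself.
from importlib import import_module


def build_from_spec(spec: dict):
    module_path, _, class_name = spec["__type__"].rpartition(".")
    cls = getattr(import_module(module_path), class_name)
    kwargs = {k: v for k, v in spec.items() if k != "__type__"}
    return cls(**kwargs)


# e.g. build_from_spec(KH_LLMS["cohere"]["spec"]) would construct an LCCohereChat.
```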
@@ -183,7 +183,7 @@ class LCCohereEmbeddings(LCEmbeddingMixin, BaseEmbeddings):

     def _get_lc_class(self):
         try:
-            from langchain_community.embeddings import CohereEmbeddings
+            from langchain_cohere import CohereEmbeddings
         except ImportError:
             from langchain.embeddings import CohereEmbeddings

@@ -1,4 +1,4 @@
-from typing import Iterator, List
+from typing import List

 from pydantic import BaseModel, Field

@@ -7,53 +7,14 @@ from kotaemon.base.schema import HumanMessage, SystemMessage
 from kotaemon.llms import BaseLLM


-class FactWithEvidence(BaseModel):
-    """Class representing a single statement.
-
-    Each fact has a body and a list of sources.
-    If there are multiple facts make sure to break them apart
-    such that each one only uses a set of sources that are relevant to it.
-    """
-
-    fact: str = Field(..., description="Body of the sentence, as part of a response")
-    substring_quote: List[str] = Field(
+class CiteEvidence(BaseModel):
+    """List of evidences (maximum 5) to support the answer."""
+
+    evidences: List[str] = Field(
         ...,
         description=(
             "Each source should be a direct quote from the context, "
-            "as a substring of the original content"
+            "as a substring of the original content (max 15 words)."
         ),
     )
-
-    def _get_span(self, quote: str, context: str, errs: int = 100) -> Iterator[str]:
-        import regex
-
-        minor = quote
-        major = context
-
-        errs_ = 0
-        s = regex.search(f"({minor}){{e<={errs_}}}", major)
-        while s is None and errs_ <= errs:
-            errs_ += 1
-            s = regex.search(f"({minor}){{e<={errs_}}}", major)
-
-        if s is not None:
-            yield from s.spans()
-
-    def get_spans(self, context: str) -> Iterator[str]:
-        for quote in self.substring_quote:
-            yield from self._get_span(quote, context)
-
-
-class QuestionAnswer(BaseModel):
-    """A question and its answer as a list of facts each one should have a source.
-    each sentence contains a body and a list of sources."""
-
-    question: str = Field(..., description="Question that was asked")
-    answer: List[FactWithEvidence] = Field(
-        ...,
-        description=(
-            "Body of the answer, each fact should be "
-            "its separate object with a body and a list of sources"
-        ),
-    )

@@ -68,7 +29,7 @@ class CitationPipeline(BaseComponent):
         return self.invoke(context, question)

     def prepare_llm(self, context: str, question: str):
-        schema = QuestionAnswer.schema()
+        schema = CiteEvidence.schema()
         function = {
             "name": schema["title"],
             "description": schema["description"],
@@ -76,7 +37,8 @@ class CitationPipeline(BaseComponent):
         }
         llm_kwargs = {
             "tools": [{"type": "function", "function": function}],
-            "tool_choice": "auto",
+            "tool_choice": "required",
+            "tools_pydantic": [CiteEvidence],
         }
         messages = [
             SystemMessage(
@@ -85,7 +47,12 @@ class CitationPipeline(BaseComponent):
                     "questions with correct and exact citations."
                 )
             ),
-            HumanMessage(content="Answer question using the following context"),
+            HumanMessage(
+                content=(
+                    "Answer question using the following context. "
+                    "Use the provided function CiteEvidence() to cite your sources."
+                )
+            ),
             HumanMessage(content=context),
             HumanMessage(content=f"Question: {question}"),
             HumanMessage(
@@ -103,14 +70,24 @@ class CitationPipeline(BaseComponent):
             print("CitationPipeline: invoking LLM")
             llm_output = self.get_from_path("llm").invoke(messages, **llm_kwargs)
             print("CitationPipeline: finish invoking LLM")
-            if not llm_output.messages or not llm_output.additional_kwargs.get(
-                "tool_calls"
-            ):
+            if not llm_output.additional_kwargs.get("tool_calls"):
                 return None
-            function_output = llm_output.additional_kwargs["tool_calls"][0]["function"][
-                "arguments"
-            ]
-            output = QuestionAnswer.parse_raw(function_output)
+
+            first_func = llm_output.additional_kwargs["tool_calls"][0]
+            if "function" in first_func:
+                # openai and cohere format
+                function_output = first_func["function"]["arguments"]
+            else:
+                # anthropic format
+                function_output = first_func["args"]
+
+            print("CitationPipeline:", function_output)
+
+            if isinstance(function_output, str):
+                output = CiteEvidence.parse_raw(function_output)
+            else:
+                output = CiteEvidence.parse_obj(function_output)
         except Exception as e:
             print(e)
             return None
@@ -118,18 +95,4 @@ class CitationPipeline(BaseComponent):
         return output

     async def ainvoke(self, context: str, question: str):
-        messages, llm_kwargs = self.prepare_llm(context, question)
-
-        try:
-            print("CitationPipeline: async invoking LLM")
-            llm_output = await self.get_from_path("llm").ainvoke(messages, **llm_kwargs)
-            print("CitationPipeline: finish async invoking LLM")
-            function_output = llm_output.additional_kwargs["tool_calls"][0]["function"][
-                "arguments"
-            ]
-            output = QuestionAnswer.parse_raw(function_output)
-        except Exception as e:
-            print(e)
-            return None
-
-        return output
+        raise NotImplementedError()
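The rewritten `invoke` above accepts two tool-call payload shapes. A small, self-contained sketch of that branching (the sample payloads are made up; only the key names follow the code):

```python
import json

# OpenAI/Cohere style: arguments arrive as a JSON string under "function".
openai_style = {"function": {"arguments": '{"evidences": ["a direct quote"]}'}}
# Anthropic style: arguments arrive as an already-parsed dict under "args".
anthropic_style = {"args": {"evidences": ["a direct quote"]}}

for first_func in (openai_style, anthropic_style):
    if "function" in first_func:
        function_output = first_func["function"]["arguments"]
    else:
        function_output = first_func["args"]
    data = (
        json.loads(function_output)
        if isinstance(function_output, str)
        else function_output
    )
    print(data["evidences"])
```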
@@ -10,6 +10,7 @@ from .base import BaseReranking
 class CohereReranking(BaseReranking):
     model_name: str = "rerank-multilingual-v2.0"
     cohere_api_key: str = config("COHERE_API_KEY", "")
+    use_key_from_ktem: bool = False

     def run(self, documents: list[Document], query: str) -> list[Document]:
         """Use Cohere Reranker model to re-order documents
@@ -18,9 +19,25 @@ class CohereReranking(BaseReranking):
             import cohere
         except ImportError:
             raise ImportError(
-                "Please install Cohere " "`pip install cohere` to use Cohere Reranking"
+                "Please install Cohere `pip install cohere` to use Cohere Reranking"
             )

+        # try to get COHERE_API_KEY from embeddings
+        if not self.cohere_api_key and self.use_key_from_ktem:
+            try:
+                from ktem.embeddings.manager import (
+                    embedding_models_manager as embeddings,
+                )
+
+                cohere_model = embeddings.get("cohere")
+                ktem_cohere_api_key = cohere_model._kwargs.get(  # type: ignore
+                    "cohere_api_key"
+                )
+                if ktem_cohere_api_key != "your-key":
+                    self.cohere_api_key = ktem_cohere_api_key
+            except Exception as e:
+                print("Cannot get Cohere API key from `ktem`", e)
+
         if not self.cohere_api_key:
             print("Cohere API key not found. Skipping reranking.")
             return documents
@@ -35,7 +52,7 @@ class CohereReranking(BaseReranking):
         response = cohere_client.rerank(
             model=self.model_name, query=query, documents=_docs
         )
-        print("Cohere score", [r.relevance_score for r in response.results])
+        # print("Cohere score", [r.relevance_score for r in response.results])
         for r in response.results:
             doc = documents[r.index]
             doc.metadata["cohere_reranking_score"] = r.relevance_score
@@ -10,6 +10,7 @@ from .chats import (
     LCAnthropicChat,
     LCAzureChatOpenAI,
     LCChatOpenAI,
+    LCCohereChat,
     LCGeminiChat,
     LlamaCppChat,
 )
@@ -31,6 +32,7 @@ __all__ = [
     "ChatOpenAI",
     "LCAnthropicChat",
     "LCGeminiChat",
+    "LCCohereChat",
     "LCAzureChatOpenAI",
     "LCChatOpenAI",
     "LlamaCppChat",
@@ -5,6 +5,7 @@ from .langchain_based import (
     LCAzureChatOpenAI,
     LCChatMixin,
     LCChatOpenAI,
+    LCCohereChat,
     LCGeminiChat,
 )
 from .llamacpp import LlamaCppChat
@@ -18,6 +19,7 @@ __all__ = [
     "ChatOpenAI",
     "LCAnthropicChat",
     "LCGeminiChat",
+    "LCCohereChat",
     "LCChatOpenAI",
     "LCAzureChatOpenAI",
     "LCChatMixin",
@@ -18,6 +18,9 @@ class LCChatMixin:
             "Please return the relevant Langchain class in in _get_lc_class"
         )

+    def _get_tool_call_kwargs(self):
+        return {}
+
     def __init__(self, stream: bool = False, **params):
         self._lc_class = self._get_lc_class()
         self._obj = self._lc_class(**params)
@@ -56,9 +59,7 @@ class LCChatMixin:
             total_tokens = pred.llm_output["token_usage"]["total_tokens"]
             prompt_tokens = pred.llm_output["token_usage"]["prompt_tokens"]
         except Exception:
-            logger.warning(
-                f"Cannot get token usage from LLM output for {self._lc_class.__name__}"
-            )
+            pass

         return LLMInterface(
             text=all_text[0] if len(all_text) > 0 else "",
@@ -83,8 +84,30 @@ class LCChatMixin:
             LLMInterface: generated response
         """
         input_ = self.prepare_message(messages)
-        pred = self._obj.generate(messages=[input_], **kwargs)
-        return self.prepare_response(pred)
+
+        if "tools_pydantic" in kwargs:
+            tools = kwargs.pop(
+                "tools_pydantic",
+            )
+            lc_tool_call = self._obj.bind_tools(tools)
+            pred = lc_tool_call.invoke(
+                input_,
+                **self._get_tool_call_kwargs(),
+            )
+            if pred.tool_calls:
+                tool_calls = pred.tool_calls
+            else:
+                tool_calls = pred.additional_kwargs.get("tool_calls", [])
+
+            output = LLMInterface(
+                content="",
+                additional_kwargs={"tool_calls": tool_calls},
+            )
+        else:
+            pred = self._obj.generate(messages=[input_], **kwargs)
+            output = self.prepare_response(pred)
+
+        return output

     async def ainvoke(
         self, messages: str | BaseMessage | list[BaseMessage], **kwargs
@@ -235,6 +258,9 @@ class LCAnthropicChat(LCChatMixin, ChatLLM):  # type: ignore
         required=True,
     )

+    def _get_tool_call_kwargs(self):
+        return {"tool_choice": {"type": "any"}}
+
     def __init__(
         self,
         api_key: str | None = None,
@@ -291,3 +317,35 @@ class LCGeminiChat(LCChatMixin, ChatLLM):  # type: ignore
             raise ImportError("Please install langchain-google-genai")

         return ChatGoogleGenerativeAI
+
+
+class LCCohereChat(LCChatMixin, ChatLLM):  # type: ignore
+    api_key: str = Param(
+        help="API key (https://dashboard.cohere.com/api-keys)", required=True
+    )
+    model_name: str = Param(
+        help=("Model name to use (https://dashboard.cohere.com/playground/chat)"),
+        required=True,
+    )
+
+    def __init__(
+        self,
+        api_key: str | None = None,
+        model_name: str | None = None,
+        temperature: float = 0.7,
+        **params,
+    ):
+        super().__init__(
+            cohere_api_key=api_key,
+            model_name=model_name,
+            temperature=temperature,
+            **params,
+        )
+
+    def _get_lc_class(self):
+        try:
+            from langchain_cohere import ChatCohere
+        except ImportError:
+            raise ImportError("Please install langchain-cohere")
+
+        return ChatCohere
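A hedged usage sketch of the new `tools_pydantic` path and the `LCCohereChat` class added above; the model name and API key are placeholders, and the pydantic schema mirrors `CiteEvidence` from the citation pipeline.

```python
from typing import List

from pydantic import BaseModel, Field

from kotaemon.llms import LCCohereChat  # exported by this commit


class CiteEvidence(BaseModel):
    """List of evidences (maximum 5) to support the answer."""

    evidences: List[str] = Field(...)


llm = LCCohereChat(api_key="<cohere-api-key>", model_name="command-r-plus-08-2024")
result = llm.invoke(
    "Answer the question using the context and cite your sources.",
    tools_pydantic=[CiteEvidence],  # routed through LangChain's bind_tools() by LCChatMixin
)
tool_calls = result.additional_kwargs.get("tool_calls", [])
```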
@@ -292,6 +292,9 @@ class ChatOpenAI(BaseChatOpenAI):

     def openai_response(self, client, **kwargs):
         """Get the openai response"""
+        if "tools_pydantic" in kwargs:
+            kwargs.pop("tools_pydantic")
+
         params_ = {
             "model": self.model,
             "temperature": self.temperature,
@@ -360,6 +363,9 @@ class AzureChatOpenAI(BaseChatOpenAI):

     def openai_response(self, client, **kwargs):
         """Get the openai response"""
+        if "tools_pydantic" in kwargs:
+            kwargs.pop("tools_pydantic")
+
         params_ = {
             "model": self.azure_deployment,
             "temperature": self.temperature,
@@ -15,7 +15,7 @@ class TxtReader(BaseReader):
     def load_data(
         self, file_path: Path, extra_info: Optional[dict] = None, **kwargs
     ) -> list[Document]:
-        with open(file_path, "r") as f:
+        with open(file_path, "r", encoding="utf-8") as f:
             text = f.read()

         metadata = extra_info or {}
@@ -73,17 +73,25 @@ class BaseVectorStore(ABC):


 class LlamaIndexVectorStore(BaseVectorStore):
-    _li_class: type[LIVectorStore | BasePydanticVectorStore]
+    """Mixin for LlamaIndex based vectorstores"""
+
+    _li_class: type[LIVectorStore | BasePydanticVectorStore] | None
+
+    def _get_li_class(self):
+        raise NotImplementedError(
+            "Please return the relevant LlamaIndex class in in _get_li_class"
+        )

     def __init__(self, *args, **kwargs):
-        if self._li_class is None:
-            raise AttributeError(
-                "Require `_li_class` to set a VectorStore class from LlamarIndex"
-            )
+        # get li_class from the method if not set
+        if not self._li_class:
+            LIClass = self._get_li_class()
+        else:
+            LIClass = self._li_class

         from dataclasses import fields

-        self._client = self._li_class(*args, **kwargs)
+        self._client = LIClass(*args, **kwargs)

         self._vsq_kwargs = {_.name for _ in fields(VectorStoreQuery)}
         for key in ["query_embedding", "similarity_top_k", "node_ids"]:
@@ -97,6 +105,9 @@ class LlamaIndexVectorStore(BaseVectorStore):
         return setattr(self._client, name, value)

     def __getattr__(self, name: str) -> Any:
+        if name == "_li_class":
+            return super().__getattribute__(name)
+
         return getattr(self._client, name)

     def add(
@@ -1,7 +1,5 @@
 import os
-from typing import Any, Optional, Type, cast
+from typing import Any, Optional, cast

-from llama_index.vector_stores.milvus import MilvusVectorStore as LIMilvusVectorStore
-
 from kotaemon.base import DocumentWithEmbedding

@@ -9,7 +7,20 @@ from .base import LlamaIndexVectorStore


 class MilvusVectorStore(LlamaIndexVectorStore):
-    _li_class: Type[LIMilvusVectorStore] = LIMilvusVectorStore
+    _li_class = None
+
+    def _get_li_class(self):
+        try:
+            from llama_index.vector_stores.milvus import (
+                MilvusVectorStore as LIMilvusVectorStore,
+            )
+        except ImportError:
+            raise ImportError(
+                "Please install missing package: "
+                "'pip install llama-index-vector-stores-milvus'"
+            )
+
+        return LIMilvusVectorStore

     def __init__(
         self,
@@ -46,6 +57,10 @@ class MilvusVectorStore(LlamaIndexVectorStore):
             dim=dim,
             **self._kwargs,
         )
+        from llama_index.vector_stores.milvus import (
+            MilvusVectorStore as LIMilvusVectorStore,
+        )
+
         self._client = cast(LIMilvusVectorStore, self._client)
         self._inited = True

@@ -1,12 +1,23 @@
-from typing import Any, List, Optional, Type, cast
+from typing import Any, List, Optional, cast

-from llama_index.vector_stores.qdrant import QdrantVectorStore as LIQdrantVectorStore
-
 from .base import LlamaIndexVectorStore


 class QdrantVectorStore(LlamaIndexVectorStore):
-    _li_class: Type[LIQdrantVectorStore] = LIQdrantVectorStore
+    _li_class = None
+
+    def _get_li_class(self):
+        try:
+            from llama_index.vector_stores.qdrant import (
+                QdrantVectorStore as LIQdrantVectorStore,
+            )
+        except ImportError:
+            raise ImportError(
+                "Please install missing package: "
+                "'pip install llama-index-vector-stores-qdrant'"
+            )
+
+        return LIQdrantVectorStore

     def __init__(
         self,
@@ -29,6 +40,10 @@ class QdrantVectorStore(LlamaIndexVectorStore):
             client_kwargs=client_kwargs,
             **kwargs,
         )
+        from llama_index.vector_stores.qdrant import (
+            QdrantVectorStore as LIQdrantVectorStore,
+        )
+
         self._client = cast(LIQdrantVectorStore, self._client)

     def delete(self, ids: List[str], **kwargs):
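With the lazy `_get_li_class` pattern above, importing the vectorstore module no longer pulls in the optional backend; the ImportError (with its pip hint) only surfaces when a store is actually constructed. A hedged sketch, with the import path and constructor arguments assumed for illustration:

```python
# Import path and constructor kwargs are assumptions for illustration.
from kotaemon.storages import QdrantVectorStore

try:
    store = QdrantVectorStore(collection_name="demo", url="http://localhost:6333")
except ImportError as err:
    # Raised by _get_li_class() when llama-index-vector-stores-qdrant is not installed.
    print(err)
```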
@@ -30,16 +30,15 @@ dependencies = [
     "fastapi<=0.112.1",
     "gradio>=4.31.0,<4.40",
     "html2text==2024.2.26",
-    "langchain>=0.1.16,<0.2.0",
-    "langchain-anthropic",
-    "langchain-community>=0.0.34,<0.1.0",
+    "langchain>=0.1.16,<0.2.16",
+    "langchain-community>=0.0.34,<=0.2.11",
     "langchain-openai>=0.1.4,<0.2.0",
+    "langchain-anthropic",
+    "langchain-cohere>=0.2.4,<0.3.0",
     "llama-hub>=0.0.79,<0.1.0",
     "llama-index>=0.10.40,<0.11.0",
     "llama-index-vector-stores-chroma>=0.1.9",
     "llama-index-vector-stores-lancedb",
-    "llama-index-vector-stores-milvus",
-    "llama-index-vector-stores-qdrant",
     "openai>=1.23.6,<2",
     "openpyxl>=3.1.2,<3.2",
     "opentelemetry-exporter-otlp-proto-grpc>=1.25.0",  # https://github.com/chroma-core/chroma/issues/2571
@@ -75,6 +74,9 @@ adv = [
     "llama-cpp-python<0.2.8",
     "sentence-transformers",
     "wikipedia>=1.4.0,<1.5",
+    "llama-index>=0.10.40,<0.11.0",
+    "llama-index-vector-stores-milvus",
+    "llama-index-vector-stores-qdrant",
 ]
 dev = [
     "black",
@@ -135,7 +135,7 @@ def test_lchuggingface_embeddings(

 @skip_when_cohere_not_installed
 @patch(
-    "langchain.embeddings.cohere.CohereEmbeddings.embed_documents",
+    "langchain_cohere.CohereEmbeddings.embed_documents",
     side_effect=lambda *args, **kwargs: [[1.0, 2.1, 3.2]],
 )
 def test_lccohere_embeddings(langchain_cohere_embedding_call):
@@ -354,7 +354,7 @@ class EmbeddingManagement(BasePage):
             _ = emb("Hi")

             log_content += (
-                "<mark style='background: yellow; color: red'>- Connection success. "
+                "<mark style='background: green; color: white'>- Connection success. "
                 "</mark><br>"
             )
             yield log_content
@@ -285,7 +285,7 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
             ],
             retrieval_mode=user_settings["retrieval_mode"],
             llm_scorer=(LLMTrulensScoring() if use_llm_reranking else None),
-            rerankers=[CohereReranking()],
+            rerankers=[CohereReranking(use_key_from_ktem=True)],
         )
         if not user_settings["use_reranking"]:
             retriever.rerankers = []  # type: ignore
@@ -828,7 +828,6 @@ class FileIndexPage(BasePage):
             ]
         )

-        print(f"{len(results)=}, {len(file_list)=}")
         return results, file_list

     def interact_file_list(self, list_files, ev: gr.SelectData):
@@ -58,6 +58,7 @@ class LLMManager:
             AzureChatOpenAI,
             ChatOpenAI,
             LCAnthropicChat,
+            LCCohereChat,
             LCGeminiChat,
             LlamaCppChat,
         )
@@ -67,6 +68,7 @@ class LLMManager:
             AzureChatOpenAI,
             LCAnthropicChat,
             LCGeminiChat,
+            LCCohereChat,
             LlamaCppChat,
         ]

@@ -353,7 +353,7 @@ class LLMManagement(BasePage):
             respond = llm("Hi")

             log_content += (
-                f"<mark style='background: yellow; color: red'>- Connection success. "
+                f"<mark style='background: green; color: white'>- Connection success. "
                 f"Got response:\n {respond}</mark><br>"
             )
             yield log_content
@@ -1,9 +1,27 @@
 import gradio as gr
+from decouple import config
 from ktem.app import BaseApp
 from ktem.pages.chat import ChatPage
 from ktem.pages.help import HelpPage
 from ktem.pages.resources import ResourcesTab
 from ktem.pages.settings import SettingsPage
+from ktem.pages.setup import SetupPage
+from theflow.settings import settings as flowsettings
+
+KH_DEMO_MODE = getattr(flowsettings, "KH_DEMO_MODE", False)
+KH_ENABLE_FIRST_SETUP = getattr(flowsettings, "KH_ENABLE_FIRST_SETUP", False)
+KH_APP_DATA_EXISTS = getattr(flowsettings, "KH_APP_DATA_EXISTS", True)
+
+# override first setup setting
+if config("KH_FIRST_SETUP", default=False, cast=bool):
+    KH_APP_DATA_EXISTS = False
+
+
+def toggle_first_setup_visibility():
+    global KH_APP_DATA_EXISTS
+    is_first_setup = KH_DEMO_MODE or not KH_APP_DATA_EXISTS
+    KH_APP_DATA_EXISTS = True
+    return gr.update(visible=is_first_setup), gr.update(visible=not is_first_setup)


 class App(BaseApp):
@@ -99,13 +117,17 @@ class App(BaseApp):
             ) as self._tabs["help-tab"]:
                 self.help_page = HelpPage(self)

+        if KH_ENABLE_FIRST_SETUP:
+            with gr.Column(visible=False) as self.setup_page_wrapper:
+                self.setup_page = SetupPage(self)
+
     def on_subscribe_public_events(self):
         if self.f_user_management:
             from ktem.db.engine import engine
             from ktem.db.models import User
             from sqlmodel import Session, select

-            def signed_in_out(user_id):
+            def toggle_login_visibility(user_id):
                 if not user_id:
                     return list(
                         (
@@ -146,7 +168,7 @@ class App(BaseApp):
             self.subscribe_event(
                 name="onSignIn",
                 definition={
-                    "fn": signed_in_out,
+                    "fn": toggle_login_visibility,
                     "inputs": [self.user_id],
                     "outputs": list(self._tabs.values()) + [self.tabs],
                     "show_progress": "hidden",
@@ -156,9 +178,30 @@ class App(BaseApp):
             self.subscribe_event(
                 name="onSignOut",
                 definition={
-                    "fn": signed_in_out,
+                    "fn": toggle_login_visibility,
                     "inputs": [self.user_id],
                     "outputs": list(self._tabs.values()) + [self.tabs],
                     "show_progress": "hidden",
                 },
             )

+        if KH_ENABLE_FIRST_SETUP:
+            self.subscribe_event(
+                name="onFirstSetupComplete",
+                definition={
+                    "fn": toggle_first_setup_visibility,
+                    "inputs": [],
+                    "outputs": [self.setup_page_wrapper, self.tabs],
+                    "show_progress": "hidden",
+                },
+            )
+
+    def _on_app_created(self):
+        """Called when the app is created"""
+
+        if KH_ENABLE_FIRST_SETUP:
+            self.app.load(
+                toggle_first_setup_visibility,
+                inputs=[],
+                outputs=[self.setup_page_wrapper, self.tabs],
+            )
@@ -883,7 +883,8 @@ class ChatPage(BasePage):

         # check if this is a newly created conversation
         if len(chat_history) == 1:
-            suggested_name = suggest_pipeline(chat_history).text[:40]
+            suggested_name = suggest_pipeline(chat_history).text
+            suggested_name = suggested_name.replace('"', "").replace("'", "")[:40]
             new_name = gr.update(value=suggested_name)
             renamed = True

@@ -11,8 +11,8 @@ class ChatPanel(BasePage):
         self.chatbot = gr.Chatbot(
             label=self._app.app_name,
             placeholder=(
-                "This is the beginning of a new conversation.\nMake sure to have added"
-                " a LLM by following the instructions in the Help tab."
+                "This is the beginning of a new conversation.\nIf you are new, "
+                "visit the Help tab for quick instructions."
             ),
             show_label=False,
             elem_id="main-chat-bot",
libs/ktem/ktem/pages/setup.py (new file, 347 lines)
@@ -0,0 +1,347 @@
+import json
+
+import gradio as gr
+import requests
+from ktem.app import BasePage
+from ktem.embeddings.manager import embedding_models_manager as embeddings
+from ktem.llms.manager import llms
+from theflow.settings import settings as flowsettings
+
+KH_DEMO_MODE = getattr(flowsettings, "KH_DEMO_MODE", False)
+DEFAULT_OLLAMA_URL = "http://localhost:11434/api"
+
+
+DEMO_MESSAGE = (
+    "This is a public space. Please use the "
+    '"Duplicate Space" function on the top right '
+    "corner to setup your own space."
+)
+
+
+def pull_model(name: str, stream: bool = True):
+    payload = {"name": name}
+    headers = {"Content-Type": "application/json"}
+
+    response = requests.post(
+        DEFAULT_OLLAMA_URL + "/pull", json=payload, headers=headers, stream=stream
+    )
+
+    # Check if the request was successful
+    response.raise_for_status()
+
+    if stream:
+        for line in response.iter_lines():
+            if line:
+                data = json.loads(line.decode("utf-8"))
+                yield data
+                if data.get("status") == "success":
+                    break
+    else:
+        data = response.json()
+
+    return data
+
+
+class SetupPage(BasePage):
+
+    public_events = ["onFirstSetupComplete"]
+
+    def __init__(self, app):
+        self._app = app
+        self.on_building_ui()
+
+    def on_building_ui(self):
+        gr.Markdown(f"# Welcome to {self._app.app_name} first setup!")
+        self.radio_model = gr.Radio(
+            [
+                ("Cohere API (*free registration* available) - recommended", "cohere"),
+                ("OpenAI API (for more advance models)", "openai"),
+                ("Local LLM (for completely *private RAG*)", "ollama"),
+            ],
+            label="Select your model provider",
+            value="cohere",
+            info=(
+                "Note: You can change this later. "
+                "If you are not sure, go with the first option "
+                "which fits most normal users."
+            ),
+            interactive=True,
+        )
+
+        with gr.Column(visible=False) as self.openai_option:
+            gr.Markdown(
+                (
+                    "#### OpenAI API Key\n\n"
+                    "(create at https://platform.openai.com/api-keys)"
+                )
+            )
+            self.openai_api_key = gr.Textbox(
+                show_label=False, placeholder="OpenAI API Key"
+            )
+
+        with gr.Column(visible=True) as self.cohere_option:
+            gr.Markdown(
+                (
+                    "#### Cohere API Key\n\n"
+                    "(register your free API key "
+                    "at https://dashboard.cohere.com/api-keys)"
+                )
+            )
+            self.cohere_api_key = gr.Textbox(
+                show_label=False, placeholder="Cohere API Key"
+            )
+
+        with gr.Column(visible=False) as self.ollama_option:
+            gr.Markdown(
+                (
+                    "#### Setup Ollama\n\n"
+                    "Download and install Ollama from "
+                    "https://ollama.com/"
+                )
+            )
+
+        self.setup_log = gr.HTML(
+            show_label=False,
+        )
+
+        with gr.Row():
+            self.btn_finish = gr.Button("Proceed", variant="primary")
+            self.btn_skip = gr.Button(
+                "I am an advance user. Skip this.", variant="stop"
+            )
+
+    def on_register_events(self):
+        onFirstSetupComplete = gr.on(
+            triggers=[
+                self.btn_finish.click,
+                self.cohere_api_key.submit,
+                self.openai_api_key.submit,
+            ],
+            fn=self.update_model,
+            inputs=[self.cohere_api_key, self.openai_api_key, self.radio_model],
+            outputs=[self.setup_log],
+            show_progress="hidden",
+        )
+        if not KH_DEMO_MODE:
+            onSkipSetup = gr.on(
+                triggers=[self.btn_skip.click],
+                fn=lambda: None,
+                inputs=[],
+                show_progress="hidden",
+                outputs=[self.radio_model],
+            )
+
+            for event in self._app.get_event("onFirstSetupComplete"):
+                onSkipSetup = onSkipSetup.success(**event)
+
+        onFirstSetupComplete = onFirstSetupComplete.success(
+            fn=self.update_default_settings,
+            inputs=[self.radio_model, self._app.settings_state],
+            outputs=self._app.settings_state,
+        )
+        for event in self._app.get_event("onFirstSetupComplete"):
+            onFirstSetupComplete = onFirstSetupComplete.success(**event)
+
+        self.radio_model.change(
+            fn=self.switch_options_view,
+            inputs=[self.radio_model],
+            show_progress="hidden",
+            outputs=[self.cohere_option, self.openai_option, self.ollama_option],
+        )
+
+    def update_model(
+        self,
+        cohere_api_key,
+        openai_api_key,
+        radio_model_value,
+    ):
+        # skip if KH_DEMO_MODE
+        if KH_DEMO_MODE:
+            raise gr.Error(DEMO_MESSAGE)
+
+        log_content = ""
+        if not radio_model_value:
+            gr.Info("Skip setup models.")
+            yield gr.value(visible=False)
+            return
+
+        if radio_model_value == "cohere":
+            if cohere_api_key:
+                llms.update(
+                    name="cohere",
+                    spec={
+                        "__type__": "kotaemon.llms.chats.LCCohereChat",
+                        "model_name": "command-r-plus-08-2024",
+                        "api_key": cohere_api_key,
+                    },
+                    default=True,
+                )
+                embeddings.update(
+                    name="cohere",
+                    spec={
+                        "__type__": "kotaemon.embeddings.LCCohereEmbeddings",
+                        "model": "embed-multilingual-v2.0",
+                        "cohere_api_key": cohere_api_key,
+                        "user_agent": "default",
+                    },
+                    default=True,
+                )
+        elif radio_model_value == "openai":
+            if openai_api_key:
+                llms.update(
+                    name="openai",
+                    spec={
+                        "__type__": "kotaemon.llms.ChatOpenAI",
+                        "base_url": "https://api.openai.com/v1",
+                        "model": "gpt-4o",
+                        "api_key": openai_api_key,
+                        "timeout": 20,
+                    },
+                    default=True,
+                )
+                embeddings.update(
+                    name="openai",
+                    spec={
+                        "__type__": "kotaemon.embeddings.OpenAIEmbeddings",
+                        "base_url": "https://api.openai.com/v1",
+                        "model": "text-embedding-3-large",
+                        "api_key": openai_api_key,
+                        "timeout": 10,
+                        "context_length": 8191,
+                    },
+                    default=True,
+                )
+        elif radio_model_value == "ollama":
+            llms.update(
+                name="ollama",
+                spec={
+                    "__type__": "kotaemon.llms.ChatOpenAI",
+                    "base_url": "http://localhost:11434/v1/",
+                    "model": "llama3.1:8b",
+                    "api_key": "ollama",
+                },
+                default=True,
+            )
+            embeddings.update(
+                name="ollama",
+                spec={
+                    "__type__": "kotaemon.embeddings.OpenAIEmbeddings",
+                    "base_url": "http://localhost:11434/v1/",
+                    "model": "nomic-embed-text",
+                    "api_key": "ollama",
+                },
+                default=True,
+            )
+
+            # download required models through ollama
+            llm_model_name = llms.get("ollama").model  # type: ignore
+            emb_model_name = embeddings.get("ollama").model  # type: ignore
+
+            try:
+                for model_name in [emb_model_name, llm_model_name]:
+                    log_content += f"- Downloading model `{model_name}` from Ollama<br>"
+                    yield log_content
+
+                    pre_download_log = log_content
+
+                    for response in pull_model(model_name):
+                        complete = response.get("completed", 0)
+                        total = response.get("total", 0)
+                        if complete > 0 and total > 0:
+                            ratio = int(complete / total * 100)
+                            log_content = (
+                                pre_download_log
+                                + f"- {response.get('status')}: {ratio}%<br>"
+                            )
+                        else:
+                            if "pulling" not in response.get("status", ""):
+                                log_content += f"- {response.get('status')}<br>"
+
+                        yield log_content
+            except Exception as e:
+                log_content += (
+                    "Make sure you have download and installed Ollama correctly."
+                    f"Got error: {str(e)}"
+                )
+                yield log_content
+                raise gr.Error("Failed to download model from Ollama.")
+
+        # test models connection
+        llm_output = emb_output = None
+
+        # LLM model
+        log_content += f"- Testing LLM model: {radio_model_value}<br>"
+        yield log_content
+
+        llm = llms.get(radio_model_value)  # type: ignore
+        log_content += "- Sending a message `Hi`<br>"
+        yield log_content
+        try:
+            llm_output = llm("Hi")
+        except Exception as e:
+            log_content += (
+                f"<mark style='color: yellow; background: red'>- Connection failed. "
+                f"Got error:\n {str(e)}</mark>"
+            )
+
+        if llm_output:
+            log_content += (
+                "<mark style='background: green; color: white'>- Connection success. "
+                "</mark><br>"
+            )
+        yield log_content
+
+        if llm_output:
+            # embedding model
+            log_content += f"- Testing Embedding model: {radio_model_value}<br>"
+            yield log_content
+
+            emb = embeddings.get(radio_model_value)
+            assert emb, f"Embedding model {radio_model_value} not found."
+
+            log_content += "- Sending a message `Hi`<br>"
+            yield log_content
+            try:
+                emb_output = emb("Hi")
+            except Exception as e:
+                log_content += (
+                    f"<mark style='color: yellow; background: red'>"
+                    "- Connection failed. "
+                    f"Got error:\n {str(e)}</mark>"
+                )
+
+            if emb_output:
+                log_content += (
+                    "<mark style='background: green; color: white'>"
+                    "- Connection success. "
+                    "</mark><br>"
+                )
+            yield log_content
+
+        if llm_output and emb_output:
+            gr.Info("Setup models completed successfully!")
+        else:
+            raise gr.Error(
+                "Setup models failed. Please verify your connection and API key."
+            )
+
+    def update_default_settings(self, radio_model_value, default_settings):
+        # revise default settings
+        # reranking llm
+        default_settings["index.options.1.reranking_llm"] = radio_model_value
+        if radio_model_value == "ollama":
+            default_settings["index.options.1.use_llm_reranking"] = False
+
+        return default_settings
+
+    def switch_options_view(self, radio_model_value):
+        components_visible = [gr.update(visible=False) for _ in range(3)]
+
+        values = ["cohere", "openai", "ollama", None]
+        assert radio_model_value in values, f"Invalid value {radio_model_value}"
+
+        if radio_model_value is not None:
+            idx = values.index(radio_model_value)
+            components_visible[idx] = gr.update(visible=True)
+
+        return components_visible
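A hedged sketch of driving `pull_model` from the new setup page above; it assumes a local Ollama server is listening on the default port used by `DEFAULT_OLLAMA_URL`.

```python
# Assumes Ollama is running locally on the default port.
for status in pull_model("nomic-embed-text"):
    done, total = status.get("completed", 0), status.get("total", 0)
    if total:
        print(f"{status.get('status')}: {done / total:.0%}")
    else:
        print(status.get("status"))
```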
@@ -52,6 +52,7 @@ class DecomposeQuestionPipeline(RewriteQuestionPipeline):
         llm_kwargs = {
             "tools": [{"type": "function", "function": function}],
             "tool_choice": "auto",
+            "tools_pydantic": [SubQuery],
         }

         messages = [
@@ -7,6 +7,7 @@ DEFAULT_REWRITE_PROMPT = (
     "Given the following question, rephrase and expand it "
     "to help you do better answering. Maintain all information "
     "in the original question. Keep the question as concise as possible. "
+    "Only output the rephrased question without additional information. "
     "Give answer in {lang}\n"
     "Original question: {question}\n"
     "Rephrased question: "
@@ -39,10 +39,13 @@ EVIDENCE_MODE_TABLE = 1
 EVIDENCE_MODE_CHATBOT = 2
 EVIDENCE_MODE_FIGURE = 3
 MAX_IMAGES = 10
+CITATION_TIMEOUT = 5.0


 def find_text(search_span, context):
     sentence_list = search_span.split("\n")
+    context = context.replace("\n", " ")
+
     matches = []
     # don't search for small text
     if len(search_span) > 5:
@@ -50,7 +53,7 @@ def find_text(search_span, context):
         match = SequenceMatcher(
             None, sentence, context, autojunk=False
         ).find_longest_match()
-        if match.size > len(sentence) * 0.35:
+        if match.size > max(len(sentence) * 0.35, 5):
             matches.append((match.b, match.b + match.size))

     return matches
@@ -200,15 +203,6 @@ DEFAULT_QA_FIGURE_PROMPT = (
     "Answer: "
 )  # noqa

-DEFAULT_REWRITE_PROMPT = (
-    "Given the following question, rephrase and expand it "
-    "to help you do better answering. Maintain all information "
-    "in the original question. Keep the question as concise as possible. "
-    "Give answer in {lang}\n"
-    "Original question: {question}\n"
-    "Rephrased question: "
-)  # noqa
-
 CONTEXT_RELEVANT_WARNING_SCORE = 0.7


@@ -391,7 +385,8 @@ class AnswerWithContextPipeline(BaseComponent):
         qa_score = None

         if citation_thread:
-            citation_thread.join()
+            citation_thread.join(timeout=CITATION_TIMEOUT)
+
         answer = Document(
             text=output,
             metadata={"citation": citation, "qa_score": qa_score},
@@ -525,9 +520,9 @@ class FullQAPipeline(BaseReasoning):
         spans = defaultdict(list)
         has_llm_score = any("llm_trulens_score" in doc.metadata for doc in docs)

-        if answer.metadata["citation"] and answer.metadata["citation"].answer:
-            for fact_with_evidence in answer.metadata["citation"].answer:
-                for quote in fact_with_evidence.substring_quote:
+        if answer.metadata["citation"]:
+            evidences = answer.metadata["citation"].evidences
+            for quote in evidences:
                 matched_excerpts = []
                 for doc in docs:
                     matches = find_text(quote, doc.text)
@@ -542,7 +537,7 @@ class FullQAPipeline(BaseReasoning):
                         )
                         matched_excerpts.append(doc.text[start:end])

-                print("Matched citation:", quote, matched_excerpts),
+                # print("Matched citation:", quote, matched_excerpts),

         id2docs = {doc.doc_id: doc for doc in docs}
         not_detected = set(id2docs.keys()) - set(spans.keys())
@@ -75,7 +75,6 @@ class Render:
         if not highlight_text:
             try:
                 lang = detect(text.replace("\n", " "))["lang"]
-                print("lang", lang)
                 if lang not in ["ja", "cn"]:
                     highlight_words = [
                         t[:-1] if t.endswith("-") else t for t in text.split("\n")
@@ -83,10 +82,13 @@ class Render:
                     highlight_text = highlight_words[0]
                     phrase = "true"
                 else:
-                    highlight_text = text.replace("\n", "")
                     phrase = "false"

-                print("highlight_text", highlight_text, phrase)
+                highlight_text = (
+                    text.replace("\n", "").replace('"', "").replace("'", "")
+                )
+
+                # print("highlight_text", highlight_text, phrase, lang)
             except Exception as e:
                 print(e)
                 highlight_text = text
@@ -162,8 +164,15 @@ class Render:
         if item_type_prefix:
             item_type_prefix += " from "

+        if llm_reranking_score > 0:
+            relevant_score = llm_reranking_score
+        elif cohere_reranking_score > 0:
+            relevant_score = cohere_reranking_score
+        else:
+            relevant_score = 0.0
+
         rendered_score = Render.collapsible(
-            header=f"<b> Relevance score</b>: {llm_reranking_score}",
+            header=f"<b> Relevance score</b>: {relevant_score:.1f}",
             content="<b>  Vectorstore score:</b>"
             f" {vectorstore_score}"
             f"{text_search_str}"