feat: add first setup screen for LLM & Embedding models (#314) (bump:minor)

* fix: utf-8 txt reader

* fix: revise vectorstore import and make it optional

* feat: add cohere chat model with tool call support

* fix: simplify citation pipeline

* fix: improve citation logic

* fix: improve decompose func call

* fix: revise question rewrite prompt

* fix: revise chat box default placeholder

* fix: add key from ktem to cohere rerank

* fix: conv name suggestion

* fix: ignore default key cohere rerank

* fix: improve test connection UI

* fix: reorder requirements

* feat: add first setup screen

* fix: update requirements

* fix: vectorstore tests

* fix: update cohere version

* fix: relax langchain core version

* fix: add demo mode

* fix: update flowsettings

* fix: typo

* fix: fix bool env passing
This commit is contained in:
Tuan Anh Nguyen Dang (Tadashi_Cin) 2024-09-22 16:32:23 +07:00 committed by GitHub
parent 0bdb9a32f2
commit 88d577b0cc
27 changed files with 643 additions and 140 deletions

View File

@@ -24,9 +24,13 @@ if not KH_APP_VERSION:
     except Exception:
         KH_APP_VERSION = "local"
 
+KH_ENABLE_FIRST_SETUP = True
+KH_DEMO_MODE = config("KH_DEMO_MODE", default=False, cast=bool)
+
 # App can be ran from anywhere and it's not trivial to decide where to store app data.
 # So let's use the same directory as the flowsetting.py file.
 KH_APP_DATA_DIR = this_dir / "ktem_app_data"
+KH_APP_DATA_EXISTS = KH_APP_DATA_DIR.exists()
 KH_APP_DATA_DIR.mkdir(parents=True, exist_ok=True)
 
 # User data directory
@@ -59,7 +63,9 @@ os.environ["HF_HUB_CACHE"] = str(KH_APP_DATA_DIR / "huggingface")
 KH_DOC_DIR = this_dir / "docs"
 
 KH_MODE = "dev"
-KH_FEATURE_USER_MANAGEMENT = True
+KH_FEATURE_USER_MANAGEMENT = config(
+    "KH_FEATURE_USER_MANAGEMENT", default=True, cast=bool
+)
 KH_USER_CAN_SEE_PUBLIC = None
 KH_FEATURE_USER_MANAGEMENT_ADMIN = str(
     config("KH_FEATURE_USER_MANAGEMENT_ADMIN", default="admin")
@@ -202,6 +208,14 @@ KH_LLMS["groq"] = {
     },
     "default": False,
 }
+KH_LLMS["cohere"] = {
+    "spec": {
+        "__type__": "kotaemon.llms.chats.LCCohereChat",
+        "model_name": "command-r-plus-08-2024",
+        "api_key": "your-key",
+    },
+    "default": False,
+}
 
 # additional embeddings configurations
 KH_EMBEDDINGS["cohere"] = {
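
The boolean settings above are read with python-decouple; a minimal sketch of the assumed behaviour (strings such as "true"/"false"/"1"/"0" are turned into real booleans when cast=bool is given, which is what the "fix bool env passing" bullet refers to):

# Sketch only: python-decouple casts common truthy/falsy strings when cast=bool is set.
from decouple import config

KH_DEMO_MODE = config("KH_DEMO_MODE", default=False, cast=bool)

if KH_DEMO_MODE:
    print("Running in demo mode")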

View File

@@ -183,7 +183,7 @@ class LCCohereEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
     def _get_lc_class(self):
         try:
-            from langchain_community.embeddings import CohereEmbeddings
+            from langchain_cohere import CohereEmbeddings
         except ImportError:
             from langchain.embeddings import CohereEmbeddings

View File

@@ -1,4 +1,4 @@
-from typing import Iterator, List
+from typing import List
 
 from pydantic import BaseModel, Field
 
@@ -7,53 +7,14 @@ from kotaemon.base.schema import HumanMessage, SystemMessage
 from kotaemon.llms import BaseLLM
 
 
-class FactWithEvidence(BaseModel):
-    """Class representing a single statement.
-
-    Each fact has a body and a list of sources.
-    If there are multiple facts make sure to break them apart
-    such that each one only uses a set of sources that are relevant to it.
-    """
-
-    fact: str = Field(..., description="Body of the sentence, as part of a response")
-    substring_quote: List[str] = Field(
+class CiteEvidence(BaseModel):
+    """List of evidences (maximum 5) to support the answer."""
+
+    evidences: List[str] = Field(
         ...,
         description=(
             "Each source should be a direct quote from the context, "
-            "as a substring of the original content"
+            "as a substring of the original content (max 15 words)."
         ),
     )
-
-    def _get_span(self, quote: str, context: str, errs: int = 100) -> Iterator[str]:
-        import regex
-
-        minor = quote
-        major = context
-
-        errs_ = 0
-        s = regex.search(f"({minor}){{e<={errs_}}}", major)
-        while s is None and errs_ <= errs:
-            errs_ += 1
-            s = regex.search(f"({minor}){{e<={errs_}}}", major)
-
-        if s is not None:
-            yield from s.spans()
-
-    def get_spans(self, context: str) -> Iterator[str]:
-        for quote in self.substring_quote:
-            yield from self._get_span(quote, context)
-
-
-class QuestionAnswer(BaseModel):
-    """A question and its answer as a list of facts each one should have a source.
-    each sentence contains a body and a list of sources."""
-
-    question: str = Field(..., description="Question that was asked")
-    answer: List[FactWithEvidence] = Field(
-        ...,
-        description=(
-            "Body of the answer, each fact should be "
-            "its separate object with a body and a list of sources"
-        ),
-    )
@@ -68,7 +29,7 @@ class CitationPipeline(BaseComponent):
         return self.invoke(context, question)
 
     def prepare_llm(self, context: str, question: str):
-        schema = QuestionAnswer.schema()
+        schema = CiteEvidence.schema()
         function = {
             "name": schema["title"],
             "description": schema["description"],
@@ -76,7 +37,8 @@ class CitationPipeline(BaseComponent):
         }
         llm_kwargs = {
             "tools": [{"type": "function", "function": function}],
-            "tool_choice": "auto",
+            "tool_choice": "required",
+            "tools_pydantic": [CiteEvidence],
         }
         messages = [
             SystemMessage(
@@ -85,7 +47,12 @@ class CitationPipeline(BaseComponent):
                     "questions with correct and exact citations."
                 )
             ),
-            HumanMessage(content="Answer question using the following context"),
+            HumanMessage(
+                content=(
+                    "Answer question using the following context. "
+                    "Use the provided function CiteEvidence() to cite your sources."
+                )
+            ),
             HumanMessage(content=context),
             HumanMessage(content=f"Question: {question}"),
             HumanMessage(
@@ -103,14 +70,24 @@ class CitationPipeline(BaseComponent):
             print("CitationPipeline: invoking LLM")
             llm_output = self.get_from_path("llm").invoke(messages, **llm_kwargs)
             print("CitationPipeline: finish invoking LLM")
-            if not llm_output.messages or not llm_output.additional_kwargs.get(
-                "tool_calls"
-            ):
+            if not llm_output.additional_kwargs.get("tool_calls"):
                 return None
-            function_output = llm_output.additional_kwargs["tool_calls"][0]["function"][
-                "arguments"
-            ]
-            output = QuestionAnswer.parse_raw(function_output)
+
+            first_func = llm_output.additional_kwargs["tool_calls"][0]
+
+            if "function" in first_func:
+                # openai and cohere format
+                function_output = first_func["function"]["arguments"]
+            else:
+                # anthropic format
+                function_output = first_func["args"]
+
+            print("CitationPipeline:", function_output)
+
+            if isinstance(function_output, str):
+                output = CiteEvidence.parse_raw(function_output)
+            else:
+                output = CiteEvidence.parse_obj(function_output)
         except Exception as e:
             print(e)
             return None
@@ -118,18 +95,4 @@ class CitationPipeline(BaseComponent):
         return output
 
     async def ainvoke(self, context: str, question: str):
-        messages, llm_kwargs = self.prepare_llm(context, question)
-
-        try:
-            print("CitationPipeline: async invoking LLM")
-            llm_output = await self.get_from_path("llm").ainvoke(messages, **llm_kwargs)
-            print("CitationPipeline: finish async invoking LLM")
-            function_output = llm_output.additional_kwargs["tool_calls"][0]["function"][
-                "arguments"
-            ]
-            output = QuestionAnswer.parse_raw(function_output)
-        except Exception as e:
-            print(e)
-            return None
-
-        return output
+        raise NotImplementedError()
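
For reference, a minimal sketch (assuming the pydantic v1 API used elsewhere in kotaemon) of how a schema such as CiteEvidence round-trips through a tool call: .schema() supplies the function definition sent to the LLM, and the returned arguments are parsed back into the model.

from typing import List

from pydantic import BaseModel, Field


class CiteEvidence(BaseModel):
    """List of evidences (maximum 5) to support the answer."""

    evidences: List[str] = Field(..., description="Direct quotes from the context.")


schema = CiteEvidence.schema()  # becomes the "function" spec in llm_kwargs
raw = '{"evidences": ["a quoted sentence from the context"]}'
print(CiteEvidence.parse_raw(raw).evidences)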

View File

@@ -10,6 +10,7 @@ from .base import BaseReranking
 class CohereReranking(BaseReranking):
     model_name: str = "rerank-multilingual-v2.0"
     cohere_api_key: str = config("COHERE_API_KEY", "")
+    use_key_from_ktem: bool = False
 
     def run(self, documents: list[Document], query: str) -> list[Document]:
         """Use Cohere Reranker model to re-order documents
@@ -18,9 +19,25 @@ class CohereReranking(BaseReranking):
             import cohere
         except ImportError:
             raise ImportError(
-                "Please install Cohere " "`pip install cohere` to use Cohere Reranking"
+                "Please install Cohere `pip install cohere` to use Cohere Reranking"
             )
 
+        # try to get COHERE_API_KEY from embeddings
+        if not self.cohere_api_key and self.use_key_from_ktem:
+            try:
+                from ktem.embeddings.manager import (
+                    embedding_models_manager as embeddings,
+                )
+
+                cohere_model = embeddings.get("cohere")
+                ktem_cohere_api_key = cohere_model._kwargs.get(  # type: ignore
+                    "cohere_api_key"
+                )
+                if ktem_cohere_api_key != "your-key":
+                    self.cohere_api_key = ktem_cohere_api_key
+            except Exception as e:
+                print("Cannot get Cohere API key from `ktem`", e)
+
         if not self.cohere_api_key:
             print("Cohere API key not found. Skipping reranking.")
             return documents
@@ -35,7 +52,7 @@ class CohereReranking(BaseReranking):
         response = cohere_client.rerank(
             model=self.model_name, query=query, documents=_docs
         )
-        print("Cohere score", [r.relevance_score for r in response.results])
+        # print("Cohere score", [r.relevance_score for r in response.results])
         for r in response.results:
             doc = documents[r.index]
             doc.metadata["cohere_reranking_score"] = r.relevance_score
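
A hedged usage sketch of the new use_key_from_ktem flag (import paths are assumed from the kotaemon package layout): when COHERE_API_KEY is not set in the environment, the reranker falls back to the Cohere key registered in ktem's embedding manager.

# Illustrative only: needs the `cohere` package and a key reachable via ktem.
from kotaemon.base import Document
from kotaemon.indices.rankings import CohereReranking

docs = [
    Document(text="RAG couples a retriever with a generator."),
    Document(text="Cats are mammals."),
]
reranker = CohereReranking(use_key_from_ktem=True)
for doc in reranker.run(documents=docs, query="What is RAG?"):
    print(doc.metadata.get("cohere_reranking_score"), doc.text)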

View File

@@ -10,6 +10,7 @@ from .chats import (
     LCAnthropicChat,
     LCAzureChatOpenAI,
     LCChatOpenAI,
+    LCCohereChat,
     LCGeminiChat,
     LlamaCppChat,
 )
@@ -31,6 +32,7 @@ __all__ = [
     "ChatOpenAI",
     "LCAnthropicChat",
     "LCGeminiChat",
+    "LCCohereChat",
     "LCAzureChatOpenAI",
     "LCChatOpenAI",
     "LlamaCppChat",

View File

@@ -5,6 +5,7 @@ from .langchain_based import (
     LCAzureChatOpenAI,
     LCChatMixin,
     LCChatOpenAI,
+    LCCohereChat,
     LCGeminiChat,
 )
 from .llamacpp import LlamaCppChat
@@ -18,6 +19,7 @@ __all__ = [
     "ChatOpenAI",
     "LCAnthropicChat",
     "LCGeminiChat",
+    "LCCohereChat",
     "LCChatOpenAI",
     "LCAzureChatOpenAI",
     "LCChatMixin",

View File

@@ -18,6 +18,9 @@ class LCChatMixin:
             "Please return the relevant Langchain class in in _get_lc_class"
         )
 
+    def _get_tool_call_kwargs(self):
+        return {}
+
     def __init__(self, stream: bool = False, **params):
         self._lc_class = self._get_lc_class()
         self._obj = self._lc_class(**params)
@@ -56,9 +59,7 @@ class LCChatMixin:
             total_tokens = pred.llm_output["token_usage"]["total_tokens"]
             prompt_tokens = pred.llm_output["token_usage"]["prompt_tokens"]
         except Exception:
-            logger.warning(
-                f"Cannot get token usage from LLM output for {self._lc_class.__name__}"
-            )
+            pass
 
         return LLMInterface(
             text=all_text[0] if len(all_text) > 0 else "",
@@ -83,8 +84,30 @@ class LCChatMixin:
             LLMInterface: generated response
         """
         input_ = self.prepare_message(messages)
-        pred = self._obj.generate(messages=[input_], **kwargs)
-        return self.prepare_response(pred)
+
+        if "tools_pydantic" in kwargs:
+            tools = kwargs.pop(
+                "tools_pydantic",
+            )
+            lc_tool_call = self._obj.bind_tools(tools)
+            pred = lc_tool_call.invoke(
+                input_,
+                **self._get_tool_call_kwargs(),
+            )
+            if pred.tool_calls:
+                tool_calls = pred.tool_calls
+            else:
+                tool_calls = pred.additional_kwargs.get("tool_calls", [])
+
+            output = LLMInterface(
+                content="",
+                additional_kwargs={"tool_calls": tool_calls},
+            )
+        else:
+            pred = self._obj.generate(messages=[input_], **kwargs)
+            output = self.prepare_response(pred)
+
+        return output
 
     async def ainvoke(
         self, messages: str | BaseMessage | list[BaseMessage], **kwargs
@@ -235,6 +258,9 @@ class LCAnthropicChat(LCChatMixin, ChatLLM):  # type: ignore
         required=True,
     )
 
+    def _get_tool_call_kwargs(self):
+        return {"tool_choice": {"type": "any"}}
+
     def __init__(
         self,
         api_key: str | None = None,
@@ -291,3 +317,35 @@ class LCGeminiChat(LCChatMixin, ChatLLM):  # type: ignore
             raise ImportError("Please install langchain-google-genai")
 
         return ChatGoogleGenerativeAI
+
+
+class LCCohereChat(LCChatMixin, ChatLLM):  # type: ignore
+    api_key: str = Param(
+        help="API key (https://dashboard.cohere.com/api-keys)", required=True
+    )
+    model_name: str = Param(
+        help=("Model name to use (https://dashboard.cohere.com/playground/chat)"),
+        required=True,
+    )
+
+    def __init__(
+        self,
+        api_key: str | None = None,
+        model_name: str | None = None,
+        temperature: float = 0.7,
+        **params,
+    ):
+        super().__init__(
+            cohere_api_key=api_key,
+            model_name=model_name,
+            temperature=temperature,
+            **params,
+        )
+
+    def _get_lc_class(self):
+        try:
+            from langchain_cohere import ChatCohere
+        except ImportError:
+            raise ImportError("Please install langchain-cohere")
+
+        return ChatCohere
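
A hedged configuration example for the new LCCohereChat wrapper (key and model name are placeholders; requires langchain-cohere). Passing tools_pydantic=[SomeModel] to invoke() routes the call through bind_tools, as shown in the LCChatMixin hunk above.

from kotaemon.llms import LCCohereChat

llm = LCCohereChat(
    api_key="your-key",  # placeholder, replace with a real Cohere key
    model_name="command-r-plus-08-2024",
    temperature=0.3,
)
print(llm("Summarize RAG in one sentence.").text)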

View File

@@ -292,6 +292,9 @@ class ChatOpenAI(BaseChatOpenAI):
 
     def openai_response(self, client, **kwargs):
         """Get the openai response"""
+        if "tools_pydantic" in kwargs:
+            kwargs.pop("tools_pydantic")
+
         params_ = {
             "model": self.model,
             "temperature": self.temperature,
@@ -360,6 +363,9 @@ class AzureChatOpenAI(BaseChatOpenAI):
 
     def openai_response(self, client, **kwargs):
         """Get the openai response"""
+        if "tools_pydantic" in kwargs:
+            kwargs.pop("tools_pydantic")
+
         params_ = {
             "model": self.azure_deployment,
             "temperature": self.temperature,

View File

@@ -15,7 +15,7 @@ class TxtReader(BaseReader):
     def load_data(
         self, file_path: Path, extra_info: Optional[dict] = None, **kwargs
     ) -> list[Document]:
-        with open(file_path, "r") as f:
+        with open(file_path, "r", encoding="utf-8") as f:
             text = f.read()
 
         metadata = extra_info or {}

View File

@@ -73,17 +73,25 @@ class BaseVectorStore(ABC):
 
 
 class LlamaIndexVectorStore(BaseVectorStore):
-    _li_class: type[LIVectorStore | BasePydanticVectorStore]
+    """Mixin for LlamaIndex based vectorstores"""
+
+    _li_class: type[LIVectorStore | BasePydanticVectorStore] | None
+
+    def _get_li_class(self):
+        raise NotImplementedError(
+            "Please return the relevant LlamaIndex class in in _get_li_class"
+        )
 
     def __init__(self, *args, **kwargs):
-        if self._li_class is None:
-            raise AttributeError(
-                "Require `_li_class` to set a VectorStore class from LlamarIndex"
-            )
+        # get li_class from the method if not set
+        if not self._li_class:
+            LIClass = self._get_li_class()
+        else:
+            LIClass = self._li_class
 
         from dataclasses import fields
 
-        self._client = self._li_class(*args, **kwargs)
+        self._client = LIClass(*args, **kwargs)
 
         self._vsq_kwargs = {_.name for _ in fields(VectorStoreQuery)}
         for key in ["query_embedding", "similarity_top_k", "node_ids"]:
@@ -97,6 +105,9 @@ class LlamaIndexVectorStore(BaseVectorStore):
         return setattr(self._client, name, value)
 
     def __getattr__(self, name: str) -> Any:
+        if name == "_li_class":
+            return super().__getattribute__(name)
+
         return getattr(self._client, name)
 
     def add(

View File

@@ -1,7 +1,5 @@
 import os
-from typing import Any, Optional, Type, cast
-
-from llama_index.vector_stores.milvus import MilvusVectorStore as LIMilvusVectorStore
+from typing import Any, Optional, cast
 
 from kotaemon.base import DocumentWithEmbedding
 
@@ -9,7 +7,20 @@ from .base import LlamaIndexVectorStore
 
 
 class MilvusVectorStore(LlamaIndexVectorStore):
-    _li_class: Type[LIMilvusVectorStore] = LIMilvusVectorStore
+    _li_class = None
+
+    def _get_li_class(self):
+        try:
+            from llama_index.vector_stores.milvus import (
+                MilvusVectorStore as LIMilvusVectorStore,
+            )
+        except ImportError:
+            raise ImportError(
+                "Please install missing package: "
+                "'pip install llama-index-vector-stores-milvus'"
+            )
+
+        return LIMilvusVectorStore
 
     def __init__(
         self,
@@ -46,6 +57,10 @@ class MilvusVectorStore(LlamaIndexVectorStore):
             dim=dim,
             **self._kwargs,
         )
+        from llama_index.vector_stores.milvus import (
+            MilvusVectorStore as LIMilvusVectorStore,
+        )
+
         self._client = cast(LIMilvusVectorStore, self._client)
         self._inited = True

View File

@@ -1,12 +1,23 @@
-from typing import Any, List, Optional, Type, cast
-
-from llama_index.vector_stores.qdrant import QdrantVectorStore as LIQdrantVectorStore
+from typing import Any, List, Optional, cast
 
 from .base import LlamaIndexVectorStore
 
 
 class QdrantVectorStore(LlamaIndexVectorStore):
-    _li_class: Type[LIQdrantVectorStore] = LIQdrantVectorStore
+    _li_class = None
+
+    def _get_li_class(self):
+        try:
+            from llama_index.vector_stores.qdrant import (
+                QdrantVectorStore as LIQdrantVectorStore,
+            )
+        except ImportError:
+            raise ImportError(
+                "Please install missing package: "
+                "'pip install llama-index-vector-stores-qdrant'"
+            )
+
+        return LIQdrantVectorStore
 
     def __init__(
         self,
@@ -29,6 +40,10 @@ class QdrantVectorStore(LlamaIndexVectorStore):
             client_kwargs=client_kwargs,
             **kwargs,
         )
+        from llama_index.vector_stores.qdrant import (
+            QdrantVectorStore as LIQdrantVectorStore,
+        )
+
         self._client = cast(LIQdrantVectorStore, self._client)
 
     def delete(self, ids: List[str], **kwargs):
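
With the llama-index backends now imported lazily, a missing optional dependency surfaces as an ImportError when the store is constructed instead of breaking module import. An illustrative check (constructor arguments and import path are assumptions, not the exact signature):

from kotaemon.storages import QdrantVectorStore  # assumed import path

try:
    store = QdrantVectorStore(collection_name="demo", url="http://localhost:6333")
except ImportError as err:
    # raised by _get_li_class when llama-index-vector-stores-qdrant is absent
    print(err)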

View File

@@ -30,16 +30,15 @@ dependencies = [
     "fastapi<=0.112.1",
     "gradio>=4.31.0,<4.40",
     "html2text==2024.2.26",
-    "langchain>=0.1.16,<0.2.0",
-    "langchain-anthropic",
-    "langchain-community>=0.0.34,<0.1.0",
+    "langchain>=0.1.16,<0.2.16",
+    "langchain-community>=0.0.34,<=0.2.11",
     "langchain-openai>=0.1.4,<0.2.0",
+    "langchain-anthropic",
+    "langchain-cohere>=0.2.4,<0.3.0",
     "llama-hub>=0.0.79,<0.1.0",
     "llama-index>=0.10.40,<0.11.0",
     "llama-index-vector-stores-chroma>=0.1.9",
     "llama-index-vector-stores-lancedb",
-    "llama-index-vector-stores-milvus",
-    "llama-index-vector-stores-qdrant",
     "openai>=1.23.6,<2",
     "openpyxl>=3.1.2,<3.2",
     "opentelemetry-exporter-otlp-proto-grpc>=1.25.0",  # https://github.com/chroma-core/chroma/issues/2571
@@ -75,6 +74,9 @@ adv = [
     "llama-cpp-python<0.2.8",
     "sentence-transformers",
     "wikipedia>=1.4.0,<1.5",
+    "llama-index>=0.10.40,<0.11.0",
+    "llama-index-vector-stores-milvus",
+    "llama-index-vector-stores-qdrant",
 ]
 dev = [
     "black",

View File

@@ -135,7 +135,7 @@ def test_lchuggingface_embeddings(
 
 @skip_when_cohere_not_installed
 @patch(
-    "langchain.embeddings.cohere.CohereEmbeddings.embed_documents",
+    "langchain_cohere.CohereEmbeddings.embed_documents",
     side_effect=lambda *args, **kwargs: [[1.0, 2.1, 3.2]],
 )
 def test_lccohere_embeddings(langchain_cohere_embedding_call):

View File

@@ -354,7 +354,7 @@ class EmbeddingManagement(BasePage):
             _ = emb("Hi")
 
             log_content += (
-                "<mark style='background: yellow; color: red'>- Connection success. "
+                "<mark style='background: green; color: white'>- Connection success. "
                 "</mark><br>"
             )
             yield log_content

View File

@@ -285,7 +285,7 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
             ],
             retrieval_mode=user_settings["retrieval_mode"],
             llm_scorer=(LLMTrulensScoring() if use_llm_reranking else None),
-            rerankers=[CohereReranking()],
+            rerankers=[CohereReranking(use_key_from_ktem=True)],
         )
         if not user_settings["use_reranking"]:
             retriever.rerankers = []  # type: ignore

View File

@@ -828,7 +828,6 @@ class FileIndexPage(BasePage):
             ]
         )
 
-        print(f"{len(results)=}, {len(file_list)=}")
         return results, file_list
 
     def interact_file_list(self, list_files, ev: gr.SelectData):

View File

@@ -58,6 +58,7 @@ class LLMManager:
             AzureChatOpenAI,
             ChatOpenAI,
             LCAnthropicChat,
+            LCCohereChat,
             LCGeminiChat,
             LlamaCppChat,
         )
@@ -67,6 +68,7 @@ class LLMManager:
             AzureChatOpenAI,
             LCAnthropicChat,
             LCGeminiChat,
+            LCCohereChat,
             LlamaCppChat,
         ]

View File

@@ -353,7 +353,7 @@ class LLMManagement(BasePage):
             respond = llm("Hi")
 
             log_content += (
-                f"<mark style='background: yellow; color: red'>- Connection success. "
+                f"<mark style='background: green; color: white'>- Connection success. "
                 f"Got response:\n {respond}</mark><br>"
             )
             yield log_content

View File

@@ -1,9 +1,27 @@
 import gradio as gr
+from decouple import config
 from ktem.app import BaseApp
 from ktem.pages.chat import ChatPage
 from ktem.pages.help import HelpPage
 from ktem.pages.resources import ResourcesTab
 from ktem.pages.settings import SettingsPage
+from ktem.pages.setup import SetupPage
+from theflow.settings import settings as flowsettings
+
+KH_DEMO_MODE = getattr(flowsettings, "KH_DEMO_MODE", False)
+KH_ENABLE_FIRST_SETUP = getattr(flowsettings, "KH_ENABLE_FIRST_SETUP", False)
+KH_APP_DATA_EXISTS = getattr(flowsettings, "KH_APP_DATA_EXISTS", True)
+
+# override first setup setting
+if config("KH_FIRST_SETUP", default=False, cast=bool):
+    KH_APP_DATA_EXISTS = False
+
+
+def toggle_first_setup_visibility():
+    global KH_APP_DATA_EXISTS
+    is_first_setup = KH_DEMO_MODE or not KH_APP_DATA_EXISTS
+    KH_APP_DATA_EXISTS = True
+    return gr.update(visible=is_first_setup), gr.update(visible=not is_first_setup)
 
 
 class App(BaseApp):
@@ -99,13 +117,17 @@ class App(BaseApp):
             ) as self._tabs["help-tab"]:
                 self.help_page = HelpPage(self)
 
+        if KH_ENABLE_FIRST_SETUP:
+            with gr.Column(visible=False) as self.setup_page_wrapper:
+                self.setup_page = SetupPage(self)
+
     def on_subscribe_public_events(self):
         if self.f_user_management:
            from ktem.db.engine import engine
            from ktem.db.models import User
            from sqlmodel import Session, select
 
-            def signed_in_out(user_id):
+            def toggle_login_visibility(user_id):
                 if not user_id:
                     return list(
                         (
@@ -146,7 +168,7 @@ class App(BaseApp):
             self.subscribe_event(
                 name="onSignIn",
                 definition={
-                    "fn": signed_in_out,
+                    "fn": toggle_login_visibility,
                     "inputs": [self.user_id],
                     "outputs": list(self._tabs.values()) + [self.tabs],
                     "show_progress": "hidden",
@@ -156,9 +178,30 @@ class App(BaseApp):
             self.subscribe_event(
                 name="onSignOut",
                 definition={
-                    "fn": signed_in_out,
+                    "fn": toggle_login_visibility,
                     "inputs": [self.user_id],
                     "outputs": list(self._tabs.values()) + [self.tabs],
                     "show_progress": "hidden",
                 },
             )
+
+        if KH_ENABLE_FIRST_SETUP:
+            self.subscribe_event(
+                name="onFirstSetupComplete",
+                definition={
+                    "fn": toggle_first_setup_visibility,
+                    "inputs": [],
+                    "outputs": [self.setup_page_wrapper, self.tabs],
+                    "show_progress": "hidden",
+                },
+            )
+
+    def _on_app_created(self):
+        """Called when the app is created"""
+
+        if KH_ENABLE_FIRST_SETUP:
+            self.app.load(
+                toggle_first_setup_visibility,
+                inputs=[],
+                outputs=[self.setup_page_wrapper, self.tabs],
+            )

View File

@@ -883,7 +883,8 @@ class ChatPage(BasePage):
 
         # check if this is a newly created conversation
         if len(chat_history) == 1:
-            suggested_name = suggest_pipeline(chat_history).text[:40]
+            suggested_name = suggest_pipeline(chat_history).text
+            suggested_name = suggested_name.replace('"', "").replace("'", "")[:40]
             new_name = gr.update(value=suggested_name)
             renamed = True

View File

@@ -11,8 +11,8 @@ class ChatPanel(BasePage):
         self.chatbot = gr.Chatbot(
             label=self._app.app_name,
             placeholder=(
-                "This is the beginning of a new conversation.\nMake sure to have added"
-                " a LLM by following the instructions in the Help tab."
+                "This is the beginning of a new conversation.\nIf you are new, "
+                "visit the Help tab for quick instructions."
             ),
             show_label=False,
             elem_id="main-chat-bot",

View File

@@ -0,0 +1,347 @@
import json
import gradio as gr
import requests
from ktem.app import BasePage
from ktem.embeddings.manager import embedding_models_manager as embeddings
from ktem.llms.manager import llms
from theflow.settings import settings as flowsettings
KH_DEMO_MODE = getattr(flowsettings, "KH_DEMO_MODE", False)
DEFAULT_OLLAMA_URL = "http://localhost:11434/api"
DEMO_MESSAGE = (
"This is a public space. Please use the "
'"Duplicate Space" function on the top right '
"corner to setup your own space."
)
def pull_model(name: str, stream: bool = True):
payload = {"name": name}
headers = {"Content-Type": "application/json"}
response = requests.post(
DEFAULT_OLLAMA_URL + "/pull", json=payload, headers=headers, stream=stream
)
# Check if the request was successful
response.raise_for_status()
if stream:
for line in response.iter_lines():
if line:
data = json.loads(line.decode("utf-8"))
yield data
if data.get("status") == "success":
break
else:
data = response.json()
return data
class SetupPage(BasePage):
public_events = ["onFirstSetupComplete"]
def __init__(self, app):
self._app = app
self.on_building_ui()
def on_building_ui(self):
gr.Markdown(f"# Welcome to {self._app.app_name} first setup!")
self.radio_model = gr.Radio(
[
("Cohere API (*free registration* available) - recommended", "cohere"),
("OpenAI API (for more advanced models)", "openai"),
("Local LLM (for completely *private RAG*)", "ollama"),
],
label="Select your model provider",
value="cohere",
info=(
"Note: You can change this later. "
"If you are not sure, go with the first option "
"which fits most normal users."
),
interactive=True,
)
with gr.Column(visible=False) as self.openai_option:
gr.Markdown(
(
"#### OpenAI API Key\n\n"
"(create at https://platform.openai.com/api-keys)"
)
)
self.openai_api_key = gr.Textbox(
show_label=False, placeholder="OpenAI API Key"
)
with gr.Column(visible=True) as self.cohere_option:
gr.Markdown(
(
"#### Cohere API Key\n\n"
"(register your free API key "
"at https://dashboard.cohere.com/api-keys)"
)
)
self.cohere_api_key = gr.Textbox(
show_label=False, placeholder="Cohere API Key"
)
with gr.Column(visible=False) as self.ollama_option:
gr.Markdown(
(
"#### Setup Ollama\n\n"
"Download and install Ollama from "
"https://ollama.com/"
)
)
self.setup_log = gr.HTML(
show_label=False,
)
with gr.Row():
self.btn_finish = gr.Button("Proceed", variant="primary")
self.btn_skip = gr.Button(
"I am an advanced user. Skip this.", variant="stop"
)
def on_register_events(self):
onFirstSetupComplete = gr.on(
triggers=[
self.btn_finish.click,
self.cohere_api_key.submit,
self.openai_api_key.submit,
],
fn=self.update_model,
inputs=[self.cohere_api_key, self.openai_api_key, self.radio_model],
outputs=[self.setup_log],
show_progress="hidden",
)
if not KH_DEMO_MODE:
onSkipSetup = gr.on(
triggers=[self.btn_skip.click],
fn=lambda: None,
inputs=[],
show_progress="hidden",
outputs=[self.radio_model],
)
for event in self._app.get_event("onFirstSetupComplete"):
onSkipSetup = onSkipSetup.success(**event)
onFirstSetupComplete = onFirstSetupComplete.success(
fn=self.update_default_settings,
inputs=[self.radio_model, self._app.settings_state],
outputs=self._app.settings_state,
)
for event in self._app.get_event("onFirstSetupComplete"):
onFirstSetupComplete = onFirstSetupComplete.success(**event)
self.radio_model.change(
fn=self.switch_options_view,
inputs=[self.radio_model],
show_progress="hidden",
outputs=[self.cohere_option, self.openai_option, self.ollama_option],
)
def update_model(
self,
cohere_api_key,
openai_api_key,
radio_model_value,
):
# skip if KH_DEMO_MODE
if KH_DEMO_MODE:
raise gr.Error(DEMO_MESSAGE)
log_content = ""
if not radio_model_value:
gr.Info("Skip setup models.")
yield gr.value(visible=False)
return
if radio_model_value == "cohere":
if cohere_api_key:
llms.update(
name="cohere",
spec={
"__type__": "kotaemon.llms.chats.LCCohereChat",
"model_name": "command-r-plus-08-2024",
"api_key": cohere_api_key,
},
default=True,
)
embeddings.update(
name="cohere",
spec={
"__type__": "kotaemon.embeddings.LCCohereEmbeddings",
"model": "embed-multilingual-v2.0",
"cohere_api_key": cohere_api_key,
"user_agent": "default",
},
default=True,
)
elif radio_model_value == "openai":
if openai_api_key:
llms.update(
name="openai",
spec={
"__type__": "kotaemon.llms.ChatOpenAI",
"base_url": "https://api.openai.com/v1",
"model": "gpt-4o",
"api_key": openai_api_key,
"timeout": 20,
},
default=True,
)
embeddings.update(
name="openai",
spec={
"__type__": "kotaemon.embeddings.OpenAIEmbeddings",
"base_url": "https://api.openai.com/v1",
"model": "text-embedding-3-large",
"api_key": openai_api_key,
"timeout": 10,
"context_length": 8191,
},
default=True,
)
elif radio_model_value == "ollama":
llms.update(
name="ollama",
spec={
"__type__": "kotaemon.llms.ChatOpenAI",
"base_url": "http://localhost:11434/v1/",
"model": "llama3.1:8b",
"api_key": "ollama",
},
default=True,
)
embeddings.update(
name="ollama",
spec={
"__type__": "kotaemon.embeddings.OpenAIEmbeddings",
"base_url": "http://localhost:11434/v1/",
"model": "nomic-embed-text",
"api_key": "ollama",
},
default=True,
)
# download required models through ollama
llm_model_name = llms.get("ollama").model # type: ignore
emb_model_name = embeddings.get("ollama").model # type: ignore
try:
for model_name in [emb_model_name, llm_model_name]:
log_content += f"- Downloading model `{model_name}` from Ollama<br>"
yield log_content
pre_download_log = log_content
for response in pull_model(model_name):
complete = response.get("completed", 0)
total = response.get("total", 0)
if complete > 0 and total > 0:
ratio = int(complete / total * 100)
log_content = (
pre_download_log
+ f"- {response.get('status')}: {ratio}%<br>"
)
else:
if "pulling" not in response.get("status", ""):
log_content += f"- {response.get('status')}<br>"
yield log_content
except Exception as e:
log_content += (
"Make sure you have downloaded and installed Ollama correctly. "
f"Got error: {str(e)}"
)
yield log_content
raise gr.Error("Failed to download model from Ollama.")
# test models connection
llm_output = emb_output = None
# LLM model
log_content += f"- Testing LLM model: {radio_model_value}<br>"
yield log_content
llm = llms.get(radio_model_value) # type: ignore
log_content += "- Sending a message `Hi`<br>"
yield log_content
try:
llm_output = llm("Hi")
except Exception as e:
log_content += (
f"<mark style='color: yellow; background: red'>- Connection failed. "
f"Got error:\n {str(e)}</mark>"
)
if llm_output:
log_content += (
"<mark style='background: green; color: white'>- Connection success. "
"</mark><br>"
)
yield log_content
if llm_output:
# embedding model
log_content += f"- Testing Embedding model: {radio_model_value}<br>"
yield log_content
emb = embeddings.get(radio_model_value)
assert emb, f"Embedding model {radio_model_value} not found."
log_content += "- Sending a message `Hi`<br>"
yield log_content
try:
emb_output = emb("Hi")
except Exception as e:
log_content += (
f"<mark style='color: yellow; background: red'>"
"- Connection failed. "
f"Got error:\n {str(e)}</mark>"
)
if emb_output:
log_content += (
"<mark style='background: green; color: white'>"
"- Connection success. "
"</mark><br>"
)
yield log_content
if llm_output and emb_output:
gr.Info("Setup models completed successfully!")
else:
raise gr.Error(
"Setup models failed. Please verify your connection and API key."
)
def update_default_settings(self, radio_model_value, default_settings):
# revise default settings
# reranking llm
default_settings["index.options.1.reranking_llm"] = radio_model_value
if radio_model_value == "ollama":
default_settings["index.options.1.use_llm_reranking"] = False
return default_settings
def switch_options_view(self, radio_model_value):
components_visible = [gr.update(visible=False) for _ in range(3)]
values = ["cohere", "openai", "ollama", None]
assert radio_model_value in values, f"Invalid value {radio_model_value}"
if radio_model_value is not None:
idx = values.index(radio_model_value)
components_visible[idx] = gr.update(visible=True)
return components_visible
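
The pull_model helper above streams Ollama's /api/pull progress as JSON lines; a hedged usage sketch (assumes an Ollama server on the default local port):

# Illustrative only: requires Ollama listening on localhost:11434.
for status in pull_model("nomic-embed-text"):
    done, total = status.get("completed", 0), status.get("total", 0)
    if total:
        print(f"{status.get('status')}: {done * 100 // total}%")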

View File

@@ -52,6 +52,7 @@ class DecomposeQuestionPipeline(RewriteQuestionPipeline):
         llm_kwargs = {
             "tools": [{"type": "function", "function": function}],
             "tool_choice": "auto",
+            "tools_pydantic": [SubQuery],
         }
 
         messages = [

View File

@@ -7,6 +7,7 @@ DEFAULT_REWRITE_PROMPT = (
     "Given the following question, rephrase and expand it "
     "to help you do better answering. Maintain all information "
     "in the original question. Keep the question as concise as possible. "
+    "Only output the rephrased question without additional information. "
     "Give answer in {lang}\n"
     "Original question: {question}\n"
     "Rephrased question: "

View File

@@ -39,10 +39,13 @@ EVIDENCE_MODE_TABLE = 1
 EVIDENCE_MODE_CHATBOT = 2
 EVIDENCE_MODE_FIGURE = 3
 MAX_IMAGES = 10
+CITATION_TIMEOUT = 5.0
 
 
 def find_text(search_span, context):
     sentence_list = search_span.split("\n")
+    context = context.replace("\n", " ")
+
     matches = []
     # don't search for small text
     if len(search_span) > 5:
@@ -50,7 +53,7 @@ def find_text(search_span, context):
             match = SequenceMatcher(
                 None, sentence, context, autojunk=False
             ).find_longest_match()
-            if match.size > len(sentence) * 0.35:
+            if match.size > max(len(sentence) * 0.35, 5):
                 matches.append((match.b, match.b + match.size))
 
     return matches
@@ -200,15 +203,6 @@ DEFAULT_QA_FIGURE_PROMPT = (
     "Answer: "
 )  # noqa
 
-DEFAULT_REWRITE_PROMPT = (
-    "Given the following question, rephrase and expand it "
-    "to help you do better answering. Maintain all information "
-    "in the original question. Keep the question as concise as possible. "
-    "Give answer in {lang}\n"
-    "Original question: {question}\n"
-    "Rephrased question: "
-)  # noqa
-
 
 CONTEXT_RELEVANT_WARNING_SCORE = 0.7
@@ -391,7 +385,8 @@ class AnswerWithContextPipeline(BaseComponent):
         qa_score = None
 
         if citation_thread:
-            citation_thread.join()
+            citation_thread.join(timeout=CITATION_TIMEOUT)
+
         answer = Document(
             text=output,
             metadata={"citation": citation, "qa_score": qa_score},
@@ -525,24 +520,24 @@ class FullQAPipeline(BaseReasoning):
         spans = defaultdict(list)
         has_llm_score = any("llm_trulens_score" in doc.metadata for doc in docs)
 
-        if answer.metadata["citation"] and answer.metadata["citation"].answer:
-            for fact_with_evidence in answer.metadata["citation"].answer:
-                for quote in fact_with_evidence.substring_quote:
-                    matched_excerpts = []
-                    for doc in docs:
-                        matches = find_text(quote, doc.text)
-
-                        for start, end in matches:
-                            if "|" not in doc.text[start:end]:
-                                spans[doc.doc_id].append(
-                                    {
-                                        "start": start,
-                                        "end": end,
-                                    }
-                                )
-                                matched_excerpts.append(doc.text[start:end])
-
-                    print("Matched citation:", quote, matched_excerpts),
+        if answer.metadata["citation"]:
+            evidences = answer.metadata["citation"].evidences
+            for quote in evidences:
+                matched_excerpts = []
+                for doc in docs:
+                    matches = find_text(quote, doc.text)
+
+                    for start, end in matches:
+                        if "|" not in doc.text[start:end]:
+                            spans[doc.doc_id].append(
+                                {
+                                    "start": start,
+                                    "end": end,
+                                }
+                            )
+                            matched_excerpts.append(doc.text[start:end])
+
+                # print("Matched citation:", quote, matched_excerpts),
 
         id2docs = {doc.doc_id: doc for doc in docs}
         not_detected = set(id2docs.keys()) - set(spans.keys())
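
The citation highlighting above relies on difflib's SequenceMatcher; a small self-contained sketch of the matching rule used by find_text (match.size > max(len(sentence) * 0.35, 5)), using made-up strings:

# Python 3.9+ allows find_longest_match() with no arguments.
from difflib import SequenceMatcher

sentence = "retrieval augmented generation couples a retriever with a generator"
context = (
    "In short, retrieval augmented generation couples a retriever "
    "with a generator to ground answers in source documents."
)
match = SequenceMatcher(None, sentence, context, autojunk=False).find_longest_match()
if match.size > max(len(sentence) * 0.35, 5):
    print(context[match.b : match.b + match.size])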

View File

@@ -75,7 +75,6 @@ class Render:
         if not highlight_text:
             try:
                 lang = detect(text.replace("\n", " "))["lang"]
-                print("lang", lang)
 
                 if lang not in ["ja", "cn"]:
                     highlight_words = [
                         t[:-1] if t.endswith("-") else t for t in text.split("\n")
@@ -83,10 +82,13 @@ class Render:
                     highlight_text = highlight_words[0]
                     phrase = "true"
                 else:
-                    highlight_text = text.replace("\n", "")
                     phrase = "false"
+                    highlight_text = (
+                        text.replace("\n", "").replace('"', "").replace("'", "")
+                    )
 
-                print("highlight_text", highlight_text, phrase)
+                # print("highlight_text", highlight_text, phrase, lang)
             except Exception as e:
                 print(e)
                 highlight_text = text
@@ -162,8 +164,15 @@ class Render:
         if item_type_prefix:
             item_type_prefix += " from "
 
+        if llm_reranking_score > 0:
+            relevant_score = llm_reranking_score
+        elif cohere_reranking_score > 0:
+            relevant_score = cohere_reranking_score
+        else:
+            relevant_score = 0.0
+
         rendered_score = Render.collapsible(
-            header=f"<b>&emsp;Relevance score</b>: {llm_reranking_score}",
+            header=f"<b>&emsp;Relevance score</b>: {relevant_score:.1f}",
             content="<b>&emsp;&emsp;Vectorstore score:</b>"
             f" {vectorstore_score}"
             f"{text_search_str}"