From 88d577b0ccf0ce9e882453705c9f464cd3952719 Mon Sep 17 00:00:00 2001
From: "Tuan Anh Nguyen Dang (Tadashi_Cin)"
Date: Sun, 22 Sep 2024 16:32:23 +0700
Subject: [PATCH] feat: add first setup screen for LLM & Embedding models (#314) (bump:minor)

* fix: utf-8 txt reader
* fix: revise vectorstore import and make it optional
* feat: add cohere chat model with tool call support
* fix: simplify citation pipeline
* fix: improve citation logic
* fix: improve decompose func call
* fix: revise question rewrite prompt
* fix: revise chat box default placeholder
* fix: add key from ktem to cohere rerank
* fix: conv name suggestion
* fix: ignore default key cohere rerank
* fix: improve test connection UI
* fix: reorder requirements
* feat: add first setup screen
* fix: update requirements
* fix: vectorstore tests
* fix: update cohere version
* fix: relax langchain core version
* fix: add demo mode
* fix: update flowsettings
* fix: typo
* fix: fix bool env passing
---
 flowsettings.py                               |  16 +-
 .../kotaemon/embeddings/langchain_based.py    |   2 +-
 libs/kotaemon/kotaemon/indices/qa/citation.py | 101 ++---
 .../kotaemon/indices/rankings/cohere.py       |  21 +-
 libs/kotaemon/kotaemon/llms/__init__.py       |   2 +
 libs/kotaemon/kotaemon/llms/chats/__init__.py |   2 +
 .../kotaemon/llms/chats/langchain_based.py    |  68 +++-
 libs/kotaemon/kotaemon/llms/chats/openai.py   |   6 +
 libs/kotaemon/kotaemon/loaders/txt_loader.py  |   2 +-
 .../kotaemon/storages/vectorstores/base.py    |  23 +-
 .../kotaemon/storages/vectorstores/milvus.py  |  23 +-
 .../kotaemon/storages/vectorstores/qdrant.py  |  23 +-
 libs/kotaemon/pyproject.toml                  |  12 +-
 libs/kotaemon/tests/test_embedding_models.py  |   2 +-
 libs/ktem/ktem/embeddings/ui.py               |   2 +-
 libs/ktem/ktem/index/file/pipelines.py        |   2 +-
 libs/ktem/ktem/index/file/ui.py               |   1 -
 libs/ktem/ktem/llms/manager.py                |   2 +
 libs/ktem/ktem/llms/ui.py                     |   2 +-
 libs/ktem/ktem/main.py                        |  49 ++-
 libs/ktem/ktem/pages/chat/__init__.py         |   3 +-
 libs/ktem/ktem/pages/chat/chat_panel.py       |   4 +-
 libs/ktem/ktem/pages/setup.py                 | 347 ++++++++++++++++++
 .../prompt_optimization/decompose_question.py |   1 +
 .../prompt_optimization/rewrite_question.py   |   1 +
 libs/ktem/ktem/reasoning/simple.py            |  49 ++-
 libs/ktem/ktem/utils/render.py                |  17 +-
 27 files changed, 643 insertions(+), 140 deletions(-)
 create mode 100644 libs/ktem/ktem/pages/setup.py

diff --git a/flowsettings.py b/flowsettings.py
index b4f9d76..b976ab2 100644
--- a/flowsettings.py
+++ b/flowsettings.py
@@ -24,9 +24,13 @@ if not KH_APP_VERSION:
     except Exception:
         KH_APP_VERSION = "local"

+KH_ENABLE_FIRST_SETUP = True
+KH_DEMO_MODE = config("KH_DEMO_MODE", default=False, cast=bool)
+
 # App can be ran from anywhere and it's not trivial to decide where to store app data.
 # So let's use the same directory as the flowsetting.py file.
KH_APP_DATA_DIR = this_dir / "ktem_app_data" +KH_APP_DATA_EXISTS = KH_APP_DATA_DIR.exists() KH_APP_DATA_DIR.mkdir(parents=True, exist_ok=True) # User data directory @@ -59,7 +63,9 @@ os.environ["HF_HUB_CACHE"] = str(KH_APP_DATA_DIR / "huggingface") KH_DOC_DIR = this_dir / "docs" KH_MODE = "dev" -KH_FEATURE_USER_MANAGEMENT = True +KH_FEATURE_USER_MANAGEMENT = config( + "KH_FEATURE_USER_MANAGEMENT", default=True, cast=bool +) KH_USER_CAN_SEE_PUBLIC = None KH_FEATURE_USER_MANAGEMENT_ADMIN = str( config("KH_FEATURE_USER_MANAGEMENT_ADMIN", default="admin") @@ -202,6 +208,14 @@ KH_LLMS["groq"] = { }, "default": False, } +KH_LLMS["cohere"] = { + "spec": { + "__type__": "kotaemon.llms.chats.LCCohereChat", + "model_name": "command-r-plus-08-2024", + "api_key": "your-key", + }, + "default": False, +} # additional embeddings configurations KH_EMBEDDINGS["cohere"] = { diff --git a/libs/kotaemon/kotaemon/embeddings/langchain_based.py b/libs/kotaemon/kotaemon/embeddings/langchain_based.py index d415b2e..03ff9c6 100644 --- a/libs/kotaemon/kotaemon/embeddings/langchain_based.py +++ b/libs/kotaemon/kotaemon/embeddings/langchain_based.py @@ -183,7 +183,7 @@ class LCCohereEmbeddings(LCEmbeddingMixin, BaseEmbeddings): def _get_lc_class(self): try: - from langchain_community.embeddings import CohereEmbeddings + from langchain_cohere import CohereEmbeddings except ImportError: from langchain.embeddings import CohereEmbeddings diff --git a/libs/kotaemon/kotaemon/indices/qa/citation.py b/libs/kotaemon/kotaemon/indices/qa/citation.py index 30eceaa..6358760 100644 --- a/libs/kotaemon/kotaemon/indices/qa/citation.py +++ b/libs/kotaemon/kotaemon/indices/qa/citation.py @@ -1,4 +1,4 @@ -from typing import Iterator, List +from typing import List from pydantic import BaseModel, Field @@ -7,53 +7,14 @@ from kotaemon.base.schema import HumanMessage, SystemMessage from kotaemon.llms import BaseLLM -class FactWithEvidence(BaseModel): - """Class representing a single statement. +class CiteEvidence(BaseModel): + """List of evidences (maximum 5) to support the answer.""" - Each fact has a body and a list of sources. - If there are multiple facts make sure to break them apart - such that each one only uses a set of sources that are relevant to it. - """ - - fact: str = Field(..., description="Body of the sentence, as part of a response") - substring_quote: List[str] = Field( + evidences: List[str] = Field( ..., description=( "Each source should be a direct quote from the context, " - "as a substring of the original content" - ), - ) - - def _get_span(self, quote: str, context: str, errs: int = 100) -> Iterator[str]: - import regex - - minor = quote - major = context - - errs_ = 0 - s = regex.search(f"({minor}){{e<={errs_}}}", major) - while s is None and errs_ <= errs: - errs_ += 1 - s = regex.search(f"({minor}){{e<={errs_}}}", major) - - if s is not None: - yield from s.spans() - - def get_spans(self, context: str) -> Iterator[str]: - for quote in self.substring_quote: - yield from self._get_span(quote, context) - - -class QuestionAnswer(BaseModel): - """A question and its answer as a list of facts each one should have a source. - each sentence contains a body and a list of sources.""" - - question: str = Field(..., description="Question that was asked") - answer: List[FactWithEvidence] = Field( - ..., - description=( - "Body of the answer, each fact should be " - "its separate object with a body and a list of sources" + "as a substring of the original content (max 15 words)." 
), ) @@ -68,7 +29,7 @@ class CitationPipeline(BaseComponent): return self.invoke(context, question) def prepare_llm(self, context: str, question: str): - schema = QuestionAnswer.schema() + schema = CiteEvidence.schema() function = { "name": schema["title"], "description": schema["description"], @@ -76,7 +37,8 @@ class CitationPipeline(BaseComponent): } llm_kwargs = { "tools": [{"type": "function", "function": function}], - "tool_choice": "auto", + "tool_choice": "required", + "tools_pydantic": [CiteEvidence], } messages = [ SystemMessage( @@ -85,7 +47,12 @@ class CitationPipeline(BaseComponent): "questions with correct and exact citations." ) ), - HumanMessage(content="Answer question using the following context"), + HumanMessage( + content=( + "Answer question using the following context. " + "Use the provided function CiteEvidence() to cite your sources." + ) + ), HumanMessage(content=context), HumanMessage(content=f"Question: {question}"), HumanMessage( @@ -103,14 +70,24 @@ class CitationPipeline(BaseComponent): print("CitationPipeline: invoking LLM") llm_output = self.get_from_path("llm").invoke(messages, **llm_kwargs) print("CitationPipeline: finish invoking LLM") - if not llm_output.messages or not llm_output.additional_kwargs.get( - "tool_calls" - ): + if not llm_output.additional_kwargs.get("tool_calls"): return None - function_output = llm_output.additional_kwargs["tool_calls"][0]["function"][ - "arguments" - ] - output = QuestionAnswer.parse_raw(function_output) + + first_func = llm_output.additional_kwargs["tool_calls"][0] + + if "function" in first_func: + # openai and cohere format + function_output = first_func["function"]["arguments"] + else: + # anthropic format + function_output = first_func["args"] + + print("CitationPipeline:", function_output) + + if isinstance(function_output, str): + output = CiteEvidence.parse_raw(function_output) + else: + output = CiteEvidence.parse_obj(function_output) except Exception as e: print(e) return None @@ -118,18 +95,4 @@ class CitationPipeline(BaseComponent): return output async def ainvoke(self, context: str, question: str): - messages, llm_kwargs = self.prepare_llm(context, question) - - try: - print("CitationPipeline: async invoking LLM") - llm_output = await self.get_from_path("llm").ainvoke(messages, **llm_kwargs) - print("CitationPipeline: finish async invoking LLM") - function_output = llm_output.additional_kwargs["tool_calls"][0]["function"][ - "arguments" - ] - output = QuestionAnswer.parse_raw(function_output) - except Exception as e: - print(e) - return None - - return output + raise NotImplementedError() diff --git a/libs/kotaemon/kotaemon/indices/rankings/cohere.py b/libs/kotaemon/kotaemon/indices/rankings/cohere.py index e759d6c..b4ce97e 100644 --- a/libs/kotaemon/kotaemon/indices/rankings/cohere.py +++ b/libs/kotaemon/kotaemon/indices/rankings/cohere.py @@ -10,6 +10,7 @@ from .base import BaseReranking class CohereReranking(BaseReranking): model_name: str = "rerank-multilingual-v2.0" cohere_api_key: str = config("COHERE_API_KEY", "") + use_key_from_ktem: bool = False def run(self, documents: list[Document], query: str) -> list[Document]: """Use Cohere Reranker model to re-order documents @@ -18,9 +19,25 @@ class CohereReranking(BaseReranking): import cohere except ImportError: raise ImportError( - "Please install Cohere " "`pip install cohere` to use Cohere Reranking" + "Please install Cohere `pip install cohere` to use Cohere Reranking" ) + # try to get COHERE_API_KEY from embeddings + if not self.cohere_api_key and 
self.use_key_from_ktem: + try: + from ktem.embeddings.manager import ( + embedding_models_manager as embeddings, + ) + + cohere_model = embeddings.get("cohere") + ktem_cohere_api_key = cohere_model._kwargs.get( # type: ignore + "cohere_api_key" + ) + if ktem_cohere_api_key != "your-key": + self.cohere_api_key = ktem_cohere_api_key + except Exception as e: + print("Cannot get Cohere API key from `ktem`", e) + if not self.cohere_api_key: print("Cohere API key not found. Skipping reranking.") return documents @@ -35,7 +52,7 @@ class CohereReranking(BaseReranking): response = cohere_client.rerank( model=self.model_name, query=query, documents=_docs ) - print("Cohere score", [r.relevance_score for r in response.results]) + # print("Cohere score", [r.relevance_score for r in response.results]) for r in response.results: doc = documents[r.index] doc.metadata["cohere_reranking_score"] = r.relevance_score diff --git a/libs/kotaemon/kotaemon/llms/__init__.py b/libs/kotaemon/kotaemon/llms/__init__.py index 03d517d..d48e418 100644 --- a/libs/kotaemon/kotaemon/llms/__init__.py +++ b/libs/kotaemon/kotaemon/llms/__init__.py @@ -10,6 +10,7 @@ from .chats import ( LCAnthropicChat, LCAzureChatOpenAI, LCChatOpenAI, + LCCohereChat, LCGeminiChat, LlamaCppChat, ) @@ -31,6 +32,7 @@ __all__ = [ "ChatOpenAI", "LCAnthropicChat", "LCGeminiChat", + "LCCohereChat", "LCAzureChatOpenAI", "LCChatOpenAI", "LlamaCppChat", diff --git a/libs/kotaemon/kotaemon/llms/chats/__init__.py b/libs/kotaemon/kotaemon/llms/chats/__init__.py index 7f82c9a..5585fba 100644 --- a/libs/kotaemon/kotaemon/llms/chats/__init__.py +++ b/libs/kotaemon/kotaemon/llms/chats/__init__.py @@ -5,6 +5,7 @@ from .langchain_based import ( LCAzureChatOpenAI, LCChatMixin, LCChatOpenAI, + LCCohereChat, LCGeminiChat, ) from .llamacpp import LlamaCppChat @@ -18,6 +19,7 @@ __all__ = [ "ChatOpenAI", "LCAnthropicChat", "LCGeminiChat", + "LCCohereChat", "LCChatOpenAI", "LCAzureChatOpenAI", "LCChatMixin", diff --git a/libs/kotaemon/kotaemon/llms/chats/langchain_based.py b/libs/kotaemon/kotaemon/llms/chats/langchain_based.py index a2d3409..e28116f 100644 --- a/libs/kotaemon/kotaemon/llms/chats/langchain_based.py +++ b/libs/kotaemon/kotaemon/llms/chats/langchain_based.py @@ -18,6 +18,9 @@ class LCChatMixin: "Please return the relevant Langchain class in in _get_lc_class" ) + def _get_tool_call_kwargs(self): + return {} + def __init__(self, stream: bool = False, **params): self._lc_class = self._get_lc_class() self._obj = self._lc_class(**params) @@ -56,9 +59,7 @@ class LCChatMixin: total_tokens = pred.llm_output["token_usage"]["total_tokens"] prompt_tokens = pred.llm_output["token_usage"]["prompt_tokens"] except Exception: - logger.warning( - f"Cannot get token usage from LLM output for {self._lc_class.__name__}" - ) + pass return LLMInterface( text=all_text[0] if len(all_text) > 0 else "", @@ -83,8 +84,30 @@ class LCChatMixin: LLMInterface: generated response """ input_ = self.prepare_message(messages) - pred = self._obj.generate(messages=[input_], **kwargs) - return self.prepare_response(pred) + + if "tools_pydantic" in kwargs: + tools = kwargs.pop( + "tools_pydantic", + ) + lc_tool_call = self._obj.bind_tools(tools) + pred = lc_tool_call.invoke( + input_, + **self._get_tool_call_kwargs(), + ) + if pred.tool_calls: + tool_calls = pred.tool_calls + else: + tool_calls = pred.additional_kwargs.get("tool_calls", []) + + output = LLMInterface( + content="", + additional_kwargs={"tool_calls": tool_calls}, + ) + else: + pred = self._obj.generate(messages=[input_], **kwargs) 
+ output = self.prepare_response(pred) + + return output async def ainvoke( self, messages: str | BaseMessage | list[BaseMessage], **kwargs @@ -235,6 +258,9 @@ class LCAnthropicChat(LCChatMixin, ChatLLM): # type: ignore required=True, ) + def _get_tool_call_kwargs(self): + return {"tool_choice": {"type": "any"}} + def __init__( self, api_key: str | None = None, @@ -291,3 +317,35 @@ class LCGeminiChat(LCChatMixin, ChatLLM): # type: ignore raise ImportError("Please install langchain-google-genai") return ChatGoogleGenerativeAI + + +class LCCohereChat(LCChatMixin, ChatLLM): # type: ignore + api_key: str = Param( + help="API key (https://dashboard.cohere.com/api-keys)", required=True + ) + model_name: str = Param( + help=("Model name to use (https://dashboard.cohere.com/playground/chat)"), + required=True, + ) + + def __init__( + self, + api_key: str | None = None, + model_name: str | None = None, + temperature: float = 0.7, + **params, + ): + super().__init__( + cohere_api_key=api_key, + model_name=model_name, + temperature=temperature, + **params, + ) + + def _get_lc_class(self): + try: + from langchain_cohere import ChatCohere + except ImportError: + raise ImportError("Please install langchain-cohere") + + return ChatCohere diff --git a/libs/kotaemon/kotaemon/llms/chats/openai.py b/libs/kotaemon/kotaemon/llms/chats/openai.py index 6a605f6..b1ae872 100644 --- a/libs/kotaemon/kotaemon/llms/chats/openai.py +++ b/libs/kotaemon/kotaemon/llms/chats/openai.py @@ -292,6 +292,9 @@ class ChatOpenAI(BaseChatOpenAI): def openai_response(self, client, **kwargs): """Get the openai response""" + if "tools_pydantic" in kwargs: + kwargs.pop("tools_pydantic") + params_ = { "model": self.model, "temperature": self.temperature, @@ -360,6 +363,9 @@ class AzureChatOpenAI(BaseChatOpenAI): def openai_response(self, client, **kwargs): """Get the openai response""" + if "tools_pydantic" in kwargs: + kwargs.pop("tools_pydantic") + params_ = { "model": self.azure_deployment, "temperature": self.temperature, diff --git a/libs/kotaemon/kotaemon/loaders/txt_loader.py b/libs/kotaemon/kotaemon/loaders/txt_loader.py index 6484029..5a0d48c 100644 --- a/libs/kotaemon/kotaemon/loaders/txt_loader.py +++ b/libs/kotaemon/kotaemon/loaders/txt_loader.py @@ -15,7 +15,7 @@ class TxtReader(BaseReader): def load_data( self, file_path: Path, extra_info: Optional[dict] = None, **kwargs ) -> list[Document]: - with open(file_path, "r") as f: + with open(file_path, "r", encoding="utf-8") as f: text = f.read() metadata = extra_info or {} diff --git a/libs/kotaemon/kotaemon/storages/vectorstores/base.py b/libs/kotaemon/kotaemon/storages/vectorstores/base.py index e6f2518..7d13c9d 100644 --- a/libs/kotaemon/kotaemon/storages/vectorstores/base.py +++ b/libs/kotaemon/kotaemon/storages/vectorstores/base.py @@ -73,17 +73,25 @@ class BaseVectorStore(ABC): class LlamaIndexVectorStore(BaseVectorStore): - _li_class: type[LIVectorStore | BasePydanticVectorStore] + """Mixin for LlamaIndex based vectorstores""" + + _li_class: type[LIVectorStore | BasePydanticVectorStore] | None + + def _get_li_class(self): + raise NotImplementedError( + "Please return the relevant LlamaIndex class in in _get_li_class" + ) def __init__(self, *args, **kwargs): - if self._li_class is None: - raise AttributeError( - "Require `_li_class` to set a VectorStore class from LlamarIndex" - ) + # get li_class from the method if not set + if not self._li_class: + LIClass = self._get_li_class() + else: + LIClass = self._li_class from dataclasses import fields - self._client = 
self._li_class(*args, **kwargs) + self._client = LIClass(*args, **kwargs) self._vsq_kwargs = {_.name for _ in fields(VectorStoreQuery)} for key in ["query_embedding", "similarity_top_k", "node_ids"]: @@ -97,6 +105,9 @@ class LlamaIndexVectorStore(BaseVectorStore): return setattr(self._client, name, value) def __getattr__(self, name: str) -> Any: + if name == "_li_class": + return super().__getattribute__(name) + return getattr(self._client, name) def add( diff --git a/libs/kotaemon/kotaemon/storages/vectorstores/milvus.py b/libs/kotaemon/kotaemon/storages/vectorstores/milvus.py index 974200d..1725743 100644 --- a/libs/kotaemon/kotaemon/storages/vectorstores/milvus.py +++ b/libs/kotaemon/kotaemon/storages/vectorstores/milvus.py @@ -1,7 +1,5 @@ import os -from typing import Any, Optional, Type, cast - -from llama_index.vector_stores.milvus import MilvusVectorStore as LIMilvusVectorStore +from typing import Any, Optional, cast from kotaemon.base import DocumentWithEmbedding @@ -9,7 +7,20 @@ from .base import LlamaIndexVectorStore class MilvusVectorStore(LlamaIndexVectorStore): - _li_class: Type[LIMilvusVectorStore] = LIMilvusVectorStore + _li_class = None + + def _get_li_class(self): + try: + from llama_index.vector_stores.milvus import ( + MilvusVectorStore as LIMilvusVectorStore, + ) + except ImportError: + raise ImportError( + "Please install missing package: " + "'pip install llama-index-vector-stores-milvus'" + ) + + return LIMilvusVectorStore def __init__( self, @@ -46,6 +57,10 @@ class MilvusVectorStore(LlamaIndexVectorStore): dim=dim, **self._kwargs, ) + from llama_index.vector_stores.milvus import ( + MilvusVectorStore as LIMilvusVectorStore, + ) + self._client = cast(LIMilvusVectorStore, self._client) self._inited = True diff --git a/libs/kotaemon/kotaemon/storages/vectorstores/qdrant.py b/libs/kotaemon/kotaemon/storages/vectorstores/qdrant.py index f3b421c..ea9811a 100644 --- a/libs/kotaemon/kotaemon/storages/vectorstores/qdrant.py +++ b/libs/kotaemon/kotaemon/storages/vectorstores/qdrant.py @@ -1,12 +1,23 @@ -from typing import Any, List, Optional, Type, cast - -from llama_index.vector_stores.qdrant import QdrantVectorStore as LIQdrantVectorStore +from typing import Any, List, Optional, cast from .base import LlamaIndexVectorStore class QdrantVectorStore(LlamaIndexVectorStore): - _li_class: Type[LIQdrantVectorStore] = LIQdrantVectorStore + _li_class = None + + def _get_li_class(self): + try: + from llama_index.vector_stores.qdrant import ( + QdrantVectorStore as LIQdrantVectorStore, + ) + except ImportError: + raise ImportError( + "Please install missing package: " + "'pip install llama-index-vector-stores-qdrant'" + ) + + return LIQdrantVectorStore def __init__( self, @@ -29,6 +40,10 @@ class QdrantVectorStore(LlamaIndexVectorStore): client_kwargs=client_kwargs, **kwargs, ) + from llama_index.vector_stores.qdrant import ( + QdrantVectorStore as LIQdrantVectorStore, + ) + self._client = cast(LIQdrantVectorStore, self._client) def delete(self, ids: List[str], **kwargs): diff --git a/libs/kotaemon/pyproject.toml b/libs/kotaemon/pyproject.toml index d39395e..fb3fbcf 100644 --- a/libs/kotaemon/pyproject.toml +++ b/libs/kotaemon/pyproject.toml @@ -30,16 +30,15 @@ dependencies = [ "fastapi<=0.112.1", "gradio>=4.31.0,<4.40", "html2text==2024.2.26", - "langchain>=0.1.16,<0.2.0", - "langchain-anthropic", - "langchain-community>=0.0.34,<0.1.0", + "langchain>=0.1.16,<0.2.16", + "langchain-community>=0.0.34,<=0.2.11", "langchain-openai>=0.1.4,<0.2.0", + "langchain-anthropic", + 
"langchain-cohere>=0.2.4,<0.3.0", "llama-hub>=0.0.79,<0.1.0", "llama-index>=0.10.40,<0.11.0", "llama-index-vector-stores-chroma>=0.1.9", "llama-index-vector-stores-lancedb", - "llama-index-vector-stores-milvus", - "llama-index-vector-stores-qdrant", "openai>=1.23.6,<2", "openpyxl>=3.1.2,<3.2", "opentelemetry-exporter-otlp-proto-grpc>=1.25.0", # https://github.com/chroma-core/chroma/issues/2571 @@ -75,6 +74,9 @@ adv = [ "llama-cpp-python<0.2.8", "sentence-transformers", "wikipedia>=1.4.0,<1.5", + "llama-index>=0.10.40,<0.11.0", + "llama-index-vector-stores-milvus", + "llama-index-vector-stores-qdrant", ] dev = [ "black", diff --git a/libs/kotaemon/tests/test_embedding_models.py b/libs/kotaemon/tests/test_embedding_models.py index 93d3cc5..c365a8b 100644 --- a/libs/kotaemon/tests/test_embedding_models.py +++ b/libs/kotaemon/tests/test_embedding_models.py @@ -135,7 +135,7 @@ def test_lchuggingface_embeddings( @skip_when_cohere_not_installed @patch( - "langchain.embeddings.cohere.CohereEmbeddings.embed_documents", + "langchain_cohere.CohereEmbeddings.embed_documents", side_effect=lambda *args, **kwargs: [[1.0, 2.1, 3.2]], ) def test_lccohere_embeddings(langchain_cohere_embedding_call): diff --git a/libs/ktem/ktem/embeddings/ui.py b/libs/ktem/ktem/embeddings/ui.py index c464005..c97e75b 100644 --- a/libs/ktem/ktem/embeddings/ui.py +++ b/libs/ktem/ktem/embeddings/ui.py @@ -354,7 +354,7 @@ class EmbeddingManagement(BasePage): _ = emb("Hi") log_content += ( - "- Connection success. " + "- Connection success. " "
" ) yield log_content diff --git a/libs/ktem/ktem/index/file/pipelines.py b/libs/ktem/ktem/index/file/pipelines.py index d664e7f..598064a 100644 --- a/libs/ktem/ktem/index/file/pipelines.py +++ b/libs/ktem/ktem/index/file/pipelines.py @@ -285,7 +285,7 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever): ], retrieval_mode=user_settings["retrieval_mode"], llm_scorer=(LLMTrulensScoring() if use_llm_reranking else None), - rerankers=[CohereReranking()], + rerankers=[CohereReranking(use_key_from_ktem=True)], ) if not user_settings["use_reranking"]: retriever.rerankers = [] # type: ignore diff --git a/libs/ktem/ktem/index/file/ui.py b/libs/ktem/ktem/index/file/ui.py index 3315a22..eeac953 100644 --- a/libs/ktem/ktem/index/file/ui.py +++ b/libs/ktem/ktem/index/file/ui.py @@ -828,7 +828,6 @@ class FileIndexPage(BasePage): ] ) - print(f"{len(results)=}, {len(file_list)=}") return results, file_list def interact_file_list(self, list_files, ev: gr.SelectData): diff --git a/libs/ktem/ktem/llms/manager.py b/libs/ktem/ktem/llms/manager.py index 3ac90a5..829bcae 100644 --- a/libs/ktem/ktem/llms/manager.py +++ b/libs/ktem/ktem/llms/manager.py @@ -58,6 +58,7 @@ class LLMManager: AzureChatOpenAI, ChatOpenAI, LCAnthropicChat, + LCCohereChat, LCGeminiChat, LlamaCppChat, ) @@ -67,6 +68,7 @@ class LLMManager: AzureChatOpenAI, LCAnthropicChat, LCGeminiChat, + LCCohereChat, LlamaCppChat, ] diff --git a/libs/ktem/ktem/llms/ui.py b/libs/ktem/ktem/llms/ui.py index e7cd6a9..d29babb 100644 --- a/libs/ktem/ktem/llms/ui.py +++ b/libs/ktem/ktem/llms/ui.py @@ -353,7 +353,7 @@ class LLMManagement(BasePage): respond = llm("Hi") log_content += ( - f"- Connection success. " + f"- Connection success. " f"Got response:\n {respond}
" ) yield log_content diff --git a/libs/ktem/ktem/main.py b/libs/ktem/ktem/main.py index ba305e6..00d23f2 100644 --- a/libs/ktem/ktem/main.py +++ b/libs/ktem/ktem/main.py @@ -1,9 +1,27 @@ import gradio as gr +from decouple import config from ktem.app import BaseApp from ktem.pages.chat import ChatPage from ktem.pages.help import HelpPage from ktem.pages.resources import ResourcesTab from ktem.pages.settings import SettingsPage +from ktem.pages.setup import SetupPage +from theflow.settings import settings as flowsettings + +KH_DEMO_MODE = getattr(flowsettings, "KH_DEMO_MODE", False) +KH_ENABLE_FIRST_SETUP = getattr(flowsettings, "KH_ENABLE_FIRST_SETUP", False) +KH_APP_DATA_EXISTS = getattr(flowsettings, "KH_APP_DATA_EXISTS", True) + +# override first setup setting +if config("KH_FIRST_SETUP", default=False, cast=bool): + KH_APP_DATA_EXISTS = False + + +def toggle_first_setup_visibility(): + global KH_APP_DATA_EXISTS + is_first_setup = KH_DEMO_MODE or not KH_APP_DATA_EXISTS + KH_APP_DATA_EXISTS = True + return gr.update(visible=is_first_setup), gr.update(visible=not is_first_setup) class App(BaseApp): @@ -99,13 +117,17 @@ class App(BaseApp): ) as self._tabs["help-tab"]: self.help_page = HelpPage(self) + if KH_ENABLE_FIRST_SETUP: + with gr.Column(visible=False) as self.setup_page_wrapper: + self.setup_page = SetupPage(self) + def on_subscribe_public_events(self): if self.f_user_management: from ktem.db.engine import engine from ktem.db.models import User from sqlmodel import Session, select - def signed_in_out(user_id): + def toggle_login_visibility(user_id): if not user_id: return list( ( @@ -146,7 +168,7 @@ class App(BaseApp): self.subscribe_event( name="onSignIn", definition={ - "fn": signed_in_out, + "fn": toggle_login_visibility, "inputs": [self.user_id], "outputs": list(self._tabs.values()) + [self.tabs], "show_progress": "hidden", @@ -156,9 +178,30 @@ class App(BaseApp): self.subscribe_event( name="onSignOut", definition={ - "fn": signed_in_out, + "fn": toggle_login_visibility, "inputs": [self.user_id], "outputs": list(self._tabs.values()) + [self.tabs], "show_progress": "hidden", }, ) + + if KH_ENABLE_FIRST_SETUP: + self.subscribe_event( + name="onFirstSetupComplete", + definition={ + "fn": toggle_first_setup_visibility, + "inputs": [], + "outputs": [self.setup_page_wrapper, self.tabs], + "show_progress": "hidden", + }, + ) + + def _on_app_created(self): + """Called when the app is created""" + + if KH_ENABLE_FIRST_SETUP: + self.app.load( + toggle_first_setup_visibility, + inputs=[], + outputs=[self.setup_page_wrapper, self.tabs], + ) diff --git a/libs/ktem/ktem/pages/chat/__init__.py b/libs/ktem/ktem/pages/chat/__init__.py index a21c5f2..248b17e 100644 --- a/libs/ktem/ktem/pages/chat/__init__.py +++ b/libs/ktem/ktem/pages/chat/__init__.py @@ -883,7 +883,8 @@ class ChatPage(BasePage): # check if this is a newly created conversation if len(chat_history) == 1: - suggested_name = suggest_pipeline(chat_history).text[:40] + suggested_name = suggest_pipeline(chat_history).text + suggested_name = suggested_name.replace('"', "").replace("'", "")[:40] new_name = gr.update(value=suggested_name) renamed = True diff --git a/libs/ktem/ktem/pages/chat/chat_panel.py b/libs/ktem/ktem/pages/chat/chat_panel.py index 80700b0..51258d0 100644 --- a/libs/ktem/ktem/pages/chat/chat_panel.py +++ b/libs/ktem/ktem/pages/chat/chat_panel.py @@ -11,8 +11,8 @@ class ChatPanel(BasePage): self.chatbot = gr.Chatbot( label=self._app.app_name, placeholder=( - "This is the beginning of a new conversation.\nMake sure to 
have added" - " a LLM by following the instructions in the Help tab." + "This is the beginning of a new conversation.\nIf you are new, " + "visit the Help tab for quick instructions." ), show_label=False, elem_id="main-chat-bot", diff --git a/libs/ktem/ktem/pages/setup.py b/libs/ktem/ktem/pages/setup.py new file mode 100644 index 0000000..edd86f8 --- /dev/null +++ b/libs/ktem/ktem/pages/setup.py @@ -0,0 +1,347 @@ +import json + +import gradio as gr +import requests +from ktem.app import BasePage +from ktem.embeddings.manager import embedding_models_manager as embeddings +from ktem.llms.manager import llms +from theflow.settings import settings as flowsettings + +KH_DEMO_MODE = getattr(flowsettings, "KH_DEMO_MODE", False) +DEFAULT_OLLAMA_URL = "http://localhost:11434/api" + + +DEMO_MESSAGE = ( + "This is a public space. Please use the " + '"Duplicate Space" function on the top right ' + "corner to setup your own space." +) + + +def pull_model(name: str, stream: bool = True): + payload = {"name": name} + headers = {"Content-Type": "application/json"} + + response = requests.post( + DEFAULT_OLLAMA_URL + "/pull", json=payload, headers=headers, stream=stream + ) + + # Check if the request was successful + response.raise_for_status() + + if stream: + for line in response.iter_lines(): + if line: + data = json.loads(line.decode("utf-8")) + yield data + if data.get("status") == "success": + break + else: + data = response.json() + + return data + + +class SetupPage(BasePage): + + public_events = ["onFirstSetupComplete"] + + def __init__(self, app): + self._app = app + self.on_building_ui() + + def on_building_ui(self): + gr.Markdown(f"# Welcome to {self._app.app_name} first setup!") + self.radio_model = gr.Radio( + [ + ("Cohere API (*free registration* available) - recommended", "cohere"), + ("OpenAI API (for more advance models)", "openai"), + ("Local LLM (for completely *private RAG*)", "ollama"), + ], + label="Select your model provider", + value="cohere", + info=( + "Note: You can change this later. " + "If you are not sure, go with the first option " + "which fits most normal users." + ), + interactive=True, + ) + + with gr.Column(visible=False) as self.openai_option: + gr.Markdown( + ( + "#### OpenAI API Key\n\n" + "(create at https://platform.openai.com/api-keys)" + ) + ) + self.openai_api_key = gr.Textbox( + show_label=False, placeholder="OpenAI API Key" + ) + + with gr.Column(visible=True) as self.cohere_option: + gr.Markdown( + ( + "#### Cohere API Key\n\n" + "(register your free API key " + "at https://dashboard.cohere.com/api-keys)" + ) + ) + self.cohere_api_key = gr.Textbox( + show_label=False, placeholder="Cohere API Key" + ) + + with gr.Column(visible=False) as self.ollama_option: + gr.Markdown( + ( + "#### Setup Ollama\n\n" + "Download and install Ollama from " + "https://ollama.com/" + ) + ) + + self.setup_log = gr.HTML( + show_label=False, + ) + + with gr.Row(): + self.btn_finish = gr.Button("Proceed", variant="primary") + self.btn_skip = gr.Button( + "I am an advance user. 
Skip this.", variant="stop" + ) + + def on_register_events(self): + onFirstSetupComplete = gr.on( + triggers=[ + self.btn_finish.click, + self.cohere_api_key.submit, + self.openai_api_key.submit, + ], + fn=self.update_model, + inputs=[self.cohere_api_key, self.openai_api_key, self.radio_model], + outputs=[self.setup_log], + show_progress="hidden", + ) + if not KH_DEMO_MODE: + onSkipSetup = gr.on( + triggers=[self.btn_skip.click], + fn=lambda: None, + inputs=[], + show_progress="hidden", + outputs=[self.radio_model], + ) + + for event in self._app.get_event("onFirstSetupComplete"): + onSkipSetup = onSkipSetup.success(**event) + + onFirstSetupComplete = onFirstSetupComplete.success( + fn=self.update_default_settings, + inputs=[self.radio_model, self._app.settings_state], + outputs=self._app.settings_state, + ) + for event in self._app.get_event("onFirstSetupComplete"): + onFirstSetupComplete = onFirstSetupComplete.success(**event) + + self.radio_model.change( + fn=self.switch_options_view, + inputs=[self.radio_model], + show_progress="hidden", + outputs=[self.cohere_option, self.openai_option, self.ollama_option], + ) + + def update_model( + self, + cohere_api_key, + openai_api_key, + radio_model_value, + ): + # skip if KH_DEMO_MODE + if KH_DEMO_MODE: + raise gr.Error(DEMO_MESSAGE) + + log_content = "" + if not radio_model_value: + gr.Info("Skip setup models.") + yield gr.value(visible=False) + return + + if radio_model_value == "cohere": + if cohere_api_key: + llms.update( + name="cohere", + spec={ + "__type__": "kotaemon.llms.chats.LCCohereChat", + "model_name": "command-r-plus-08-2024", + "api_key": cohere_api_key, + }, + default=True, + ) + embeddings.update( + name="cohere", + spec={ + "__type__": "kotaemon.embeddings.LCCohereEmbeddings", + "model": "embed-multilingual-v2.0", + "cohere_api_key": cohere_api_key, + "user_agent": "default", + }, + default=True, + ) + elif radio_model_value == "openai": + if openai_api_key: + llms.update( + name="openai", + spec={ + "__type__": "kotaemon.llms.ChatOpenAI", + "base_url": "https://api.openai.com/v1", + "model": "gpt-4o", + "api_key": openai_api_key, + "timeout": 20, + }, + default=True, + ) + embeddings.update( + name="openai", + spec={ + "__type__": "kotaemon.embeddings.OpenAIEmbeddings", + "base_url": "https://api.openai.com/v1", + "model": "text-embedding-3-large", + "api_key": openai_api_key, + "timeout": 10, + "context_length": 8191, + }, + default=True, + ) + elif radio_model_value == "ollama": + llms.update( + name="ollama", + spec={ + "__type__": "kotaemon.llms.ChatOpenAI", + "base_url": "http://localhost:11434/v1/", + "model": "llama3.1:8b", + "api_key": "ollama", + }, + default=True, + ) + embeddings.update( + name="ollama", + spec={ + "__type__": "kotaemon.embeddings.OpenAIEmbeddings", + "base_url": "http://localhost:11434/v1/", + "model": "nomic-embed-text", + "api_key": "ollama", + }, + default=True, + ) + + # download required models through ollama + llm_model_name = llms.get("ollama").model # type: ignore + emb_model_name = embeddings.get("ollama").model # type: ignore + + try: + for model_name in [emb_model_name, llm_model_name]: + log_content += f"- Downloading model `{model_name}` from Ollama
" + yield log_content + + pre_download_log = log_content + + for response in pull_model(model_name): + complete = response.get("completed", 0) + total = response.get("total", 0) + if complete > 0 and total > 0: + ratio = int(complete / total * 100) + log_content = ( + pre_download_log + + f"- {response.get('status')}: {ratio}%
" + ) + else: + if "pulling" not in response.get("status", ""): + log_content += f"- {response.get('status')}
" + + yield log_content + except Exception as e: + log_content += ( + "Make sure you have download and installed Ollama correctly." + f"Got error: {str(e)}" + ) + yield log_content + raise gr.Error("Failed to download model from Ollama.") + + # test models connection + llm_output = emb_output = None + + # LLM model + log_content += f"- Testing LLM model: {radio_model_value}
" + yield log_content + + llm = llms.get(radio_model_value) # type: ignore + log_content += "- Sending a message `Hi`
" + yield log_content + try: + llm_output = llm("Hi") + except Exception as e: + log_content += ( + f"- Connection failed. " + f"Got error:\n {str(e)}" + ) + + if llm_output: + log_content += ( + "- Connection success. " + "
" + ) + yield log_content + + if llm_output: + # embedding model + log_content += f"- Testing Embedding model: {radio_model_value}
" + yield log_content + + emb = embeddings.get(radio_model_value) + assert emb, f"Embedding model {radio_model_value} not found." + + log_content += "- Sending a message `Hi`
" + yield log_content + try: + emb_output = emb("Hi") + except Exception as e: + log_content += ( + f"" + "- Connection failed. " + f"Got error:\n {str(e)}" + ) + + if emb_output: + log_content += ( + "" + "- Connection success. " + "
" + ) + yield log_content + + if llm_output and emb_output: + gr.Info("Setup models completed successfully!") + else: + raise gr.Error( + "Setup models failed. Please verify your connection and API key." + ) + + def update_default_settings(self, radio_model_value, default_settings): + # revise default settings + # reranking llm + default_settings["index.options.1.reranking_llm"] = radio_model_value + if radio_model_value == "ollama": + default_settings["index.options.1.use_llm_reranking"] = False + + return default_settings + + def switch_options_view(self, radio_model_value): + components_visible = [gr.update(visible=False) for _ in range(3)] + + values = ["cohere", "openai", "ollama", None] + assert radio_model_value in values, f"Invalid value {radio_model_value}" + + if radio_model_value is not None: + idx = values.index(radio_model_value) + components_visible[idx] = gr.update(visible=True) + + return components_visible diff --git a/libs/ktem/ktem/reasoning/prompt_optimization/decompose_question.py b/libs/ktem/ktem/reasoning/prompt_optimization/decompose_question.py index 7fdc473..8108b3c 100644 --- a/libs/ktem/ktem/reasoning/prompt_optimization/decompose_question.py +++ b/libs/ktem/ktem/reasoning/prompt_optimization/decompose_question.py @@ -52,6 +52,7 @@ class DecomposeQuestionPipeline(RewriteQuestionPipeline): llm_kwargs = { "tools": [{"type": "function", "function": function}], "tool_choice": "auto", + "tools_pydantic": [SubQuery], } messages = [ diff --git a/libs/ktem/ktem/reasoning/prompt_optimization/rewrite_question.py b/libs/ktem/ktem/reasoning/prompt_optimization/rewrite_question.py index 3891f54..2ec8788 100644 --- a/libs/ktem/ktem/reasoning/prompt_optimization/rewrite_question.py +++ b/libs/ktem/ktem/reasoning/prompt_optimization/rewrite_question.py @@ -7,6 +7,7 @@ DEFAULT_REWRITE_PROMPT = ( "Given the following question, rephrase and expand it " "to help you do better answering. Maintain all information " "in the original question. Keep the question as concise as possible. " + "Only output the rephrased question without additional information. " "Give answer in {lang}\n" "Original question: {question}\n" "Rephrased question: " diff --git a/libs/ktem/ktem/reasoning/simple.py b/libs/ktem/ktem/reasoning/simple.py index 8244690..e6d84ca 100644 --- a/libs/ktem/ktem/reasoning/simple.py +++ b/libs/ktem/ktem/reasoning/simple.py @@ -39,10 +39,13 @@ EVIDENCE_MODE_TABLE = 1 EVIDENCE_MODE_CHATBOT = 2 EVIDENCE_MODE_FIGURE = 3 MAX_IMAGES = 10 +CITATION_TIMEOUT = 5.0 def find_text(search_span, context): sentence_list = search_span.split("\n") + context = context.replace("\n", " ") + matches = [] # don't search for small text if len(search_span) > 5: @@ -50,7 +53,7 @@ def find_text(search_span, context): match = SequenceMatcher( None, sentence, context, autojunk=False ).find_longest_match() - if match.size > len(sentence) * 0.35: + if match.size > max(len(sentence) * 0.35, 5): matches.append((match.b, match.b + match.size)) return matches @@ -200,15 +203,6 @@ DEFAULT_QA_FIGURE_PROMPT = ( "Answer: " ) # noqa -DEFAULT_REWRITE_PROMPT = ( - "Given the following question, rephrase and expand it " - "to help you do better answering. Maintain all information " - "in the original question. Keep the question as concise as possible. 
" - "Give answer in {lang}\n" - "Original question: {question}\n" - "Rephrased question: " -) # noqa - CONTEXT_RELEVANT_WARNING_SCORE = 0.7 @@ -391,7 +385,8 @@ class AnswerWithContextPipeline(BaseComponent): qa_score = None if citation_thread: - citation_thread.join() + citation_thread.join(timeout=CITATION_TIMEOUT) + answer = Document( text=output, metadata={"citation": citation, "qa_score": qa_score}, @@ -525,24 +520,24 @@ class FullQAPipeline(BaseReasoning): spans = defaultdict(list) has_llm_score = any("llm_trulens_score" in doc.metadata for doc in docs) - if answer.metadata["citation"] and answer.metadata["citation"].answer: - for fact_with_evidence in answer.metadata["citation"].answer: - for quote in fact_with_evidence.substring_quote: - matched_excerpts = [] - for doc in docs: - matches = find_text(quote, doc.text) + if answer.metadata["citation"]: + evidences = answer.metadata["citation"].evidences + for quote in evidences: + matched_excerpts = [] + for doc in docs: + matches = find_text(quote, doc.text) - for start, end in matches: - if "|" not in doc.text[start:end]: - spans[doc.doc_id].append( - { - "start": start, - "end": end, - } - ) - matched_excerpts.append(doc.text[start:end]) + for start, end in matches: + if "|" not in doc.text[start:end]: + spans[doc.doc_id].append( + { + "start": start, + "end": end, + } + ) + matched_excerpts.append(doc.text[start:end]) - print("Matched citation:", quote, matched_excerpts), + # print("Matched citation:", quote, matched_excerpts), id2docs = {doc.doc_id: doc for doc in docs} not_detected = set(id2docs.keys()) - set(spans.keys()) diff --git a/libs/ktem/ktem/utils/render.py b/libs/ktem/ktem/utils/render.py index b1695aa..9176627 100644 --- a/libs/ktem/ktem/utils/render.py +++ b/libs/ktem/ktem/utils/render.py @@ -75,7 +75,6 @@ class Render: if not highlight_text: try: lang = detect(text.replace("\n", " "))["lang"] - print("lang", lang) if lang not in ["ja", "cn"]: highlight_words = [ t[:-1] if t.endswith("-") else t for t in text.split("\n") @@ -83,10 +82,13 @@ class Render: highlight_text = highlight_words[0] phrase = "true" else: - highlight_text = text.replace("\n", "") phrase = "false" - print("highlight_text", highlight_text, phrase) + highlight_text = ( + text.replace("\n", "").replace('"', "").replace("'", "") + ) + + # print("highlight_text", highlight_text, phrase, lang) except Exception as e: print(e) highlight_text = text @@ -162,8 +164,15 @@ class Render: if item_type_prefix: item_type_prefix += " from " + if llm_reranking_score > 0: + relevant_score = llm_reranking_score + elif cohere_reranking_score > 0: + relevant_score = cohere_reranking_score + else: + relevant_score = 0.0 + rendered_score = Render.collapsible( - header=f" Relevance score: {llm_reranking_score}", + header=f" Relevance score: {relevant_score:.1f}", content="  Vectorstore score:" f" {vectorstore_score}" f"{text_search_str}"