feat: add first setup screen for LLM & Embedding models (#314) (bump:minor)

* fix: utf-8 txt reader

* fix: revise vectorstore import and make it optional

* feat: add cohere chat model with tool call support

* fix: simplify citation pipeline

* fix: improve citation logic

* fix: improve decompose func call

* fix: revise question rewrite prompt

* fix: revise chat box default placeholder

* fix: add key from ktem to cohere rerank

* fix: conv name suggestion

* fix: ignore default key cohere rerank

* fix: improve test connection UI

* fix: reorder requirements

* feat: add first setup screen

* fix: update requirements

* fix: vectorstore tests

* fix: update cohere version

* fix: relax langchain core version

* fix: add demo mode

* fix: update flowsettings

* fix: typo

* fix: fix bool env passing
Author: Tuan Anh Nguyen Dang (Tadashi_Cin)
Date: 2024-09-22 16:32:23 +07:00
Committed by: GitHub
Parent: 0bdb9a32f2
Commit: 88d577b0cc
27 changed files with 643 additions and 140 deletions


@@ -354,7 +354,7 @@ class EmbeddingManagement(BasePage):
_ = emb("Hi")
log_content += (
- "<mark style='background: yellow; color: red'>- Connection success. "
+ "<mark style='background: green; color: white'>- Connection success. "
"</mark><br>"
)
yield log_content


@@ -285,7 +285,7 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
],
retrieval_mode=user_settings["retrieval_mode"],
llm_scorer=(LLMTrulensScoring() if use_llm_reranking else None),
- rerankers=[CohereReranking()],
+ rerankers=[CohereReranking(use_key_from_ktem=True)],
)
if not user_settings["use_reranking"]:
retriever.rerankers = [] # type: ignore


@@ -828,7 +828,6 @@ class FileIndexPage(BasePage):
]
)
- print(f"{len(results)=}, {len(file_list)=}")
return results, file_list
def interact_file_list(self, list_files, ev: gr.SelectData):


@@ -58,6 +58,7 @@ class LLMManager:
AzureChatOpenAI,
ChatOpenAI,
LCAnthropicChat,
+ LCCohereChat,
LCGeminiChat,
LlamaCppChat,
)
@@ -67,6 +68,7 @@ class LLMManager:
AzureChatOpenAI,
LCAnthropicChat,
LCGeminiChat,
+ LCCohereChat,
LlamaCppChat,
]


@@ -353,7 +353,7 @@ class LLMManagement(BasePage):
respond = llm("Hi")
log_content += (
- f"<mark style='background: yellow; color: red'>- Connection success. "
+ f"<mark style='background: green; color: white'>- Connection success. "
f"Got response:\n {respond}</mark><br>"
)
yield log_content


@@ -1,9 +1,27 @@
import gradio as gr
+ from decouple import config
from ktem.app import BaseApp
from ktem.pages.chat import ChatPage
from ktem.pages.help import HelpPage
from ktem.pages.resources import ResourcesTab
from ktem.pages.settings import SettingsPage
+ from ktem.pages.setup import SetupPage
from theflow.settings import settings as flowsettings
+ KH_DEMO_MODE = getattr(flowsettings, "KH_DEMO_MODE", False)
+ KH_ENABLE_FIRST_SETUP = getattr(flowsettings, "KH_ENABLE_FIRST_SETUP", False)
+ KH_APP_DATA_EXISTS = getattr(flowsettings, "KH_APP_DATA_EXISTS", True)
+ # override first setup setting
+ if config("KH_FIRST_SETUP", default=False, cast=bool):
+ KH_APP_DATA_EXISTS = False
+ def toggle_first_setup_visibility():
+ global KH_APP_DATA_EXISTS
+ is_first_setup = KH_DEMO_MODE or not KH_APP_DATA_EXISTS
+ KH_APP_DATA_EXISTS = True
+ return gr.update(visible=is_first_setup), gr.update(visible=not is_first_setup)
class App(BaseApp):
@@ -99,13 +117,17 @@ class App(BaseApp):
) as self._tabs["help-tab"]:
self.help_page = HelpPage(self)
+ if KH_ENABLE_FIRST_SETUP:
+ with gr.Column(visible=False) as self.setup_page_wrapper:
+ self.setup_page = SetupPage(self)
def on_subscribe_public_events(self):
if self.f_user_management:
from ktem.db.engine import engine
from ktem.db.models import User
from sqlmodel import Session, select
- def signed_in_out(user_id):
+ def toggle_login_visibility(user_id):
if not user_id:
return list(
(
@@ -146,7 +168,7 @@ class App(BaseApp):
self.subscribe_event(
name="onSignIn",
definition={
- "fn": signed_in_out,
+ "fn": toggle_login_visibility,
"inputs": [self.user_id],
"outputs": list(self._tabs.values()) + [self.tabs],
"show_progress": "hidden",
@@ -156,9 +178,30 @@ class App(BaseApp):
self.subscribe_event(
name="onSignOut",
definition={
- "fn": signed_in_out,
+ "fn": toggle_login_visibility,
"inputs": [self.user_id],
"outputs": list(self._tabs.values()) + [self.tabs],
"show_progress": "hidden",
},
)
+ if KH_ENABLE_FIRST_SETUP:
+ self.subscribe_event(
+ name="onFirstSetupComplete",
+ definition={
+ "fn": toggle_first_setup_visibility,
+ "inputs": [],
+ "outputs": [self.setup_page_wrapper, self.tabs],
+ "show_progress": "hidden",
+ },
+ )
+ def _on_app_created(self):
+ """Called when the app is created"""
+ if KH_ENABLE_FIRST_SETUP:
+ self.app.load(
+ toggle_first_setup_visibility,
+ inputs=[],
+ outputs=[self.setup_page_wrapper, self.tabs],
+ )
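
The `KH_FIRST_SETUP` override near the top of this file is presumably what the "fix: fix bool env passing" bullet refers to. A minimal sketch of the assumed python-decouple behavior, outside the diff:

```python
from decouple import config

# With cast=bool, decouple maps strings such as "true"/"false"/"yes"/"no"/"1"/"0"
# to real booleans, so KH_FIRST_SETUP=false is not accidentally treated as truthy.
if config("KH_FIRST_SETUP", default=False, cast=bool):
    print("forcing the first-setup screen on the next app load")
```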


@@ -883,7 +883,8 @@ class ChatPage(BasePage):
# check if this is a newly created conversation
if len(chat_history) == 1:
- suggested_name = suggest_pipeline(chat_history).text[:40]
+ suggested_name = suggest_pipeline(chat_history).text
+ suggested_name = suggested_name.replace('"', "").replace("'", "")[:40]
new_name = gr.update(value=suggested_name)
renamed = True


@@ -11,8 +11,8 @@ class ChatPanel(BasePage):
self.chatbot = gr.Chatbot(
label=self._app.app_name,
placeholder=(
- "This is the beginning of a new conversation.\nMake sure to have added"
- " a LLM by following the instructions in the Help tab."
+ "This is the beginning of a new conversation.\nIf you are new, "
+ "visit the Help tab for quick instructions."
),
show_label=False,
elem_id="main-chat-bot",


@@ -0,0 +1,347 @@
import json
import gradio as gr
import requests
from ktem.app import BasePage
from ktem.embeddings.manager import embedding_models_manager as embeddings
from ktem.llms.manager import llms
from theflow.settings import settings as flowsettings
KH_DEMO_MODE = getattr(flowsettings, "KH_DEMO_MODE", False)
DEFAULT_OLLAMA_URL = "http://localhost:11434/api"
DEMO_MESSAGE = (
"This is a public space. Please use the "
'"Duplicate Space" function on the top right '
"corner to setup your own space."
)
def pull_model(name: str, stream: bool = True):
payload = {"name": name}
headers = {"Content-Type": "application/json"}
response = requests.post(
DEFAULT_OLLAMA_URL + "/pull", json=payload, headers=headers, stream=stream
)
# Check if the request was successful
response.raise_for_status()
if stream:
for line in response.iter_lines():
if line:
data = json.loads(line.decode("utf-8"))
yield data
if data.get("status") == "success":
break
else:
data = response.json()
return data
class SetupPage(BasePage):
public_events = ["onFirstSetupComplete"]
def __init__(self, app):
self._app = app
self.on_building_ui()
def on_building_ui(self):
gr.Markdown(f"# Welcome to {self._app.app_name} first setup!")
self.radio_model = gr.Radio(
[
("Cohere API (*free registration* available) - recommended", "cohere"),
("OpenAI API (for more advance models)", "openai"),
("Local LLM (for completely *private RAG*)", "ollama"),
],
label="Select your model provider",
value="cohere",
info=(
"Note: You can change this later. "
"If you are not sure, go with the first option "
"which fits most normal users."
),
interactive=True,
)
with gr.Column(visible=False) as self.openai_option:
gr.Markdown(
(
"#### OpenAI API Key\n\n"
"(create at https://platform.openai.com/api-keys)"
)
)
self.openai_api_key = gr.Textbox(
show_label=False, placeholder="OpenAI API Key"
)
with gr.Column(visible=True) as self.cohere_option:
gr.Markdown(
(
"#### Cohere API Key\n\n"
"(register your free API key "
"at https://dashboard.cohere.com/api-keys)"
)
)
self.cohere_api_key = gr.Textbox(
show_label=False, placeholder="Cohere API Key"
)
with gr.Column(visible=False) as self.ollama_option:
gr.Markdown(
(
"#### Setup Ollama\n\n"
"Download and install Ollama from "
"https://ollama.com/"
)
)
self.setup_log = gr.HTML(
show_label=False,
)
with gr.Row():
self.btn_finish = gr.Button("Proceed", variant="primary")
self.btn_skip = gr.Button(
"I am an advance user. Skip this.", variant="stop"
)
def on_register_events(self):
onFirstSetupComplete = gr.on(
triggers=[
self.btn_finish.click,
self.cohere_api_key.submit,
self.openai_api_key.submit,
],
fn=self.update_model,
inputs=[self.cohere_api_key, self.openai_api_key, self.radio_model],
outputs=[self.setup_log],
show_progress="hidden",
)
if not KH_DEMO_MODE:
onSkipSetup = gr.on(
triggers=[self.btn_skip.click],
fn=lambda: None,
inputs=[],
show_progress="hidden",
outputs=[self.radio_model],
)
for event in self._app.get_event("onFirstSetupComplete"):
onSkipSetup = onSkipSetup.success(**event)
onFirstSetupComplete = onFirstSetupComplete.success(
fn=self.update_default_settings,
inputs=[self.radio_model, self._app.settings_state],
outputs=self._app.settings_state,
)
for event in self._app.get_event("onFirstSetupComplete"):
onFirstSetupComplete = onFirstSetupComplete.success(**event)
self.radio_model.change(
fn=self.switch_options_view,
inputs=[self.radio_model],
show_progress="hidden",
outputs=[self.cohere_option, self.openai_option, self.ollama_option],
)
def update_model(
self,
cohere_api_key,
openai_api_key,
radio_model_value,
):
# skip if KH_DEMO_MODE
if KH_DEMO_MODE:
raise gr.Error(DEMO_MESSAGE)
log_content = ""
if not radio_model_value:
gr.Info("Skip setup models.")
yield gr.value(visible=False)
return
if radio_model_value == "cohere":
if cohere_api_key:
llms.update(
name="cohere",
spec={
"__type__": "kotaemon.llms.chats.LCCohereChat",
"model_name": "command-r-plus-08-2024",
"api_key": cohere_api_key,
},
default=True,
)
embeddings.update(
name="cohere",
spec={
"__type__": "kotaemon.embeddings.LCCohereEmbeddings",
"model": "embed-multilingual-v2.0",
"cohere_api_key": cohere_api_key,
"user_agent": "default",
},
default=True,
)
elif radio_model_value == "openai":
if openai_api_key:
llms.update(
name="openai",
spec={
"__type__": "kotaemon.llms.ChatOpenAI",
"base_url": "https://api.openai.com/v1",
"model": "gpt-4o",
"api_key": openai_api_key,
"timeout": 20,
},
default=True,
)
embeddings.update(
name="openai",
spec={
"__type__": "kotaemon.embeddings.OpenAIEmbeddings",
"base_url": "https://api.openai.com/v1",
"model": "text-embedding-3-large",
"api_key": openai_api_key,
"timeout": 10,
"context_length": 8191,
},
default=True,
)
elif radio_model_value == "ollama":
llms.update(
name="ollama",
spec={
"__type__": "kotaemon.llms.ChatOpenAI",
"base_url": "http://localhost:11434/v1/",
"model": "llama3.1:8b",
"api_key": "ollama",
},
default=True,
)
embeddings.update(
name="ollama",
spec={
"__type__": "kotaemon.embeddings.OpenAIEmbeddings",
"base_url": "http://localhost:11434/v1/",
"model": "nomic-embed-text",
"api_key": "ollama",
},
default=True,
)
# download required models through ollama
llm_model_name = llms.get("ollama").model # type: ignore
emb_model_name = embeddings.get("ollama").model # type: ignore
try:
for model_name in [emb_model_name, llm_model_name]:
log_content += f"- Downloading model `{model_name}` from Ollama<br>"
yield log_content
pre_download_log = log_content
for response in pull_model(model_name):
complete = response.get("completed", 0)
total = response.get("total", 0)
if complete > 0 and total > 0:
ratio = int(complete / total * 100)
log_content = (
pre_download_log
+ f"- {response.get('status')}: {ratio}%<br>"
)
else:
if "pulling" not in response.get("status", ""):
log_content += f"- {response.get('status')}<br>"
yield log_content
except Exception as e:
log_content += (
"Make sure you have download and installed Ollama correctly."
f"Got error: {str(e)}"
)
yield log_content
raise gr.Error("Failed to download model from Ollama.")
# test models connection
llm_output = emb_output = None
# LLM model
log_content += f"- Testing LLM model: {radio_model_value}<br>"
yield log_content
llm = llms.get(radio_model_value) # type: ignore
log_content += "- Sending a message `Hi`<br>"
yield log_content
try:
llm_output = llm("Hi")
except Exception as e:
log_content += (
f"<mark style='color: yellow; background: red'>- Connection failed. "
f"Got error:\n {str(e)}</mark>"
)
if llm_output:
log_content += (
"<mark style='background: green; color: white'>- Connection success. "
"</mark><br>"
)
yield log_content
if llm_output:
# embedding model
log_content += f"- Testing Embedding model: {radio_model_value}<br>"
yield log_content
emb = embeddings.get(radio_model_value)
assert emb, f"Embedding model {radio_model_value} not found."
log_content += "- Sending a message `Hi`<br>"
yield log_content
try:
emb_output = emb("Hi")
except Exception as e:
log_content += (
f"<mark style='color: yellow; background: red'>"
"- Connection failed. "
f"Got error:\n {str(e)}</mark>"
)
if emb_output:
log_content += (
"<mark style='background: green; color: white'>"
"- Connection success. "
"</mark><br>"
)
yield log_content
if llm_output and emb_output:
gr.Info("Setup models completed successfully!")
else:
raise gr.Error(
"Setup models failed. Please verify your connection and API key."
)
def update_default_settings(self, radio_model_value, default_settings):
# revise default settings
# reranking llm
default_settings["index.options.1.reranking_llm"] = radio_model_value
if radio_model_value == "ollama":
default_settings["index.options.1.use_llm_reranking"] = False
return default_settings
def switch_options_view(self, radio_model_value):
components_visible = [gr.update(visible=False) for _ in range(3)]
values = ["cohere", "openai", "ollama", None]
assert radio_model_value in values, f"Invalid value {radio_model_value}"
if radio_model_value is not None:
idx = values.index(radio_model_value)
components_visible[idx] = gr.update(visible=True)
return components_visible
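
`pull_model` above streams progress from Ollama's `/api/pull` endpoint as parsed JSON lines. A usage sketch, assuming a local Ollama server on the default port and the helper defined earlier in this file (not part of the commit):

```python
# Consume the pull_model generator defined above and report progress.
for status in pull_model("nomic-embed-text"):
    done, total = status.get("completed", 0), status.get("total", 0)
    if done and total:
        print(f"{status.get('status')}: {done / total:.0%}")
    else:
        print(status.get("status"))
```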


@@ -52,6 +52,7 @@ class DecomposeQuestionPipeline(RewriteQuestionPipeline):
llm_kwargs = {
"tools": [{"type": "function", "function": function}],
"tool_choice": "auto",
+ "tools_pydantic": [SubQuery],
}
messages = [


@@ -7,6 +7,7 @@ DEFAULT_REWRITE_PROMPT = (
"Given the following question, rephrase and expand it "
"to help you do better answering. Maintain all information "
"in the original question. Keep the question as concise as possible. "
+ "Only output the rephrased question without additional information. "
"Give answer in {lang}\n"
"Original question: {question}\n"
"Rephrased question: "


@@ -39,10 +39,13 @@ EVIDENCE_MODE_TABLE = 1
EVIDENCE_MODE_CHATBOT = 2
EVIDENCE_MODE_FIGURE = 3
MAX_IMAGES = 10
+ CITATION_TIMEOUT = 5.0
def find_text(search_span, context):
sentence_list = search_span.split("\n")
context = context.replace("\n", " ")
matches = []
# don't search for small text
if len(search_span) > 5:
@@ -50,7 +53,7 @@ def find_text(search_span, context):
match = SequenceMatcher(
None, sentence, context, autojunk=False
).find_longest_match()
- if match.size > len(sentence) * 0.35:
+ if match.size > max(len(sentence) * 0.35, 5):
matches.append((match.b, match.b + match.size))
return matches
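
The stricter threshold above adds an absolute floor of five characters, so very short quoted sentences can no longer produce citation highlights from tiny accidental overlaps. An illustrative sketch with made-up strings (not from the codebase):

```python
from difflib import SequenceMatcher

sentence = "see Fig. 2"  # hypothetical short quote to locate in a document
context = "Results are shown in Fig. 2 of the uploaded report."

match = SequenceMatcher(None, sentence, context, autojunk=False).find_longest_match()
# Old rule: match.size > len(sentence) * 0.35  (a ~4-character overlap could pass here)
# New rule: the matched span must also exceed 5 characters
print(match.size, match.size > max(len(sentence) * 0.35, 5))
```
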
@@ -200,15 +203,6 @@ DEFAULT_QA_FIGURE_PROMPT = (
"Answer: "
) # noqa
- DEFAULT_REWRITE_PROMPT = (
- "Given the following question, rephrase and expand it "
- "to help you do better answering. Maintain all information "
- "in the original question. Keep the question as concise as possible. "
- "Give answer in {lang}\n"
- "Original question: {question}\n"
- "Rephrased question: "
- ) # noqa
CONTEXT_RELEVANT_WARNING_SCORE = 0.7
@@ -391,7 +385,8 @@ class AnswerWithContextPipeline(BaseComponent):
qa_score = None
if citation_thread:
- citation_thread.join()
+ citation_thread.join(timeout=CITATION_TIMEOUT)
answer = Document(
text=output,
metadata={"citation": citation, "qa_score": qa_score},
@@ -525,24 +520,24 @@ class FullQAPipeline(BaseReasoning):
spans = defaultdict(list)
has_llm_score = any("llm_trulens_score" in doc.metadata for doc in docs)
- if answer.metadata["citation"] and answer.metadata["citation"].answer:
-     for fact_with_evidence in answer.metadata["citation"].answer:
-         for quote in fact_with_evidence.substring_quote:
-             matched_excerpts = []
-             for doc in docs:
-                 matches = find_text(quote, doc.text)
-                 for start, end in matches:
-                     if "|" not in doc.text[start:end]:
-                         spans[doc.doc_id].append(
-                             {
-                                 "start": start,
-                                 "end": end,
-                             }
-                         )
-                         matched_excerpts.append(doc.text[start:end])
-             print("Matched citation:", quote, matched_excerpts),
+ if answer.metadata["citation"]:
+     evidences = answer.metadata["citation"].evidences
+     for quote in evidences:
+         matched_excerpts = []
+         for doc in docs:
+             matches = find_text(quote, doc.text)
+             for start, end in matches:
+                 if "|" not in doc.text[start:end]:
+                     spans[doc.doc_id].append(
+                         {
+                             "start": start,
+                             "end": end,
+                         }
+                     )
+                     matched_excerpts.append(doc.text[start:end])
+         # print("Matched citation:", quote, matched_excerpts),
id2docs = {doc.doc_id: doc for doc in docs}
not_detected = set(id2docs.keys()) - set(spans.keys())


@@ -75,7 +75,6 @@ class Render:
if not highlight_text:
try:
lang = detect(text.replace("\n", " "))["lang"]
- print("lang", lang)
if lang not in ["ja", "cn"]:
highlight_words = [
t[:-1] if t.endswith("-") else t for t in text.split("\n")
@@ -83,10 +82,13 @@ class Render:
highlight_text = highlight_words[0]
phrase = "true"
else:
- highlight_text = text.replace("\n", "")
phrase = "false"
- print("highlight_text", highlight_text, phrase)
+ highlight_text = (
+ text.replace("\n", "").replace('"', "").replace("'", "")
+ )
+ # print("highlight_text", highlight_text, phrase, lang)
except Exception as e:
print(e)
highlight_text = text
@@ -162,8 +164,15 @@ class Render:
if item_type_prefix:
item_type_prefix += " from "
+ if llm_reranking_score > 0:
+ relevant_score = llm_reranking_score
+ elif cohere_reranking_score > 0:
+ relevant_score = cohere_reranking_score
+ else:
+ relevant_score = 0.0
rendered_score = Render.collapsible(
- header=f"<b>&emsp;Relevance score</b>: {llm_reranking_score}",
+ header=f"<b>&emsp;Relevance score</b>: {relevant_score:.1f}",
content="<b>&emsp;&emsp;Vectorstore score:</b>"
f" {vectorstore_score}"
f"{text_search_str}"