Make ktem official (#134)
* Move kotaemon and ktem into same folder * Update docs * Update CI * Resolve mypy, isorts * Re-allow test pdf files
This commit is contained in:
committed by
GitHub
parent
9c5b707010
commit
2dd531114f
0
libs/ktem/ktem/pages/__init__.py
Normal file
0
libs/ktem/ktem/pages/__init__.py
Normal file
125
libs/ktem/ktem/pages/chat/__init__.py
Normal file
125
libs/ktem/ktem/pages/chat/__init__.py
Normal file
@@ -0,0 +1,125 @@
|
||||
import gradio as gr
|
||||
from ktem.app import BasePage
|
||||
|
||||
from .chat_panel import ChatPanel
|
||||
from .control import ConversationControl
|
||||
from .data_source import DataSource
|
||||
from .events import chat_fn, index_fn, is_liked, load_files, update_data_source
|
||||
from .report import ReportIssue
|
||||
from .upload import FileUpload
|
||||
|
||||
|
||||
class ChatPage(BasePage):
|
||||
def __init__(self, app):
|
||||
self._app = app
|
||||
self.on_building_ui()
|
||||
|
||||
def on_building_ui(self):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1):
|
||||
self.chat_control = ConversationControl(self._app)
|
||||
self.data_source = DataSource(self._app)
|
||||
self.file_upload = FileUpload(self._app)
|
||||
self.report_issue = ReportIssue(self._app)
|
||||
with gr.Column(scale=6):
|
||||
self.chat_panel = ChatPanel(self._app)
|
||||
|
||||
def on_register_events(self):
|
||||
self.chat_panel.submit_btn.click(
|
||||
fn=chat_fn,
|
||||
inputs=[
|
||||
self.chat_panel.text_input,
|
||||
self.chat_panel.chatbot,
|
||||
self.data_source.files,
|
||||
self._app.settings_state,
|
||||
],
|
||||
outputs=[self.chat_panel.text_input, self.chat_panel.chatbot],
|
||||
).then(
|
||||
fn=update_data_source,
|
||||
inputs=[
|
||||
self.chat_control.conversation_id,
|
||||
self.data_source.files,
|
||||
self.chat_panel.chatbot,
|
||||
],
|
||||
outputs=None,
|
||||
)
|
||||
|
||||
self.chat_panel.text_input.submit(
|
||||
fn=chat_fn,
|
||||
inputs=[
|
||||
self.chat_panel.text_input,
|
||||
self.chat_panel.chatbot,
|
||||
self.data_source.files,
|
||||
self._app.settings_state,
|
||||
],
|
||||
outputs=[self.chat_panel.text_input, self.chat_panel.chatbot],
|
||||
).then(
|
||||
fn=update_data_source,
|
||||
inputs=[
|
||||
self.chat_control.conversation_id,
|
||||
self.data_source.files,
|
||||
self.chat_panel.chatbot,
|
||||
],
|
||||
outputs=None,
|
||||
)
|
||||
|
||||
self.chat_panel.chatbot.like(
|
||||
fn=is_liked,
|
||||
inputs=[self.chat_control.conversation_id],
|
||||
outputs=None,
|
||||
)
|
||||
|
||||
self.chat_control.conversation.change(
|
||||
self.chat_control.select_conv,
|
||||
inputs=[self.chat_control.conversation],
|
||||
outputs=[
|
||||
self.chat_control.conversation_id,
|
||||
self.chat_control.conversation,
|
||||
self.chat_control.conversation_rn,
|
||||
self.data_source.files,
|
||||
self.chat_panel.chatbot,
|
||||
],
|
||||
show_progress="hidden",
|
||||
)
|
||||
|
||||
self.report_issue.report_btn.click(
|
||||
self.report_issue.report,
|
||||
inputs=[
|
||||
self.report_issue.correctness,
|
||||
self.report_issue.issues,
|
||||
self.report_issue.more_detail,
|
||||
self.chat_control.conversation_id,
|
||||
self.chat_panel.chatbot,
|
||||
self.data_source.files,
|
||||
self._app.settings_state,
|
||||
self._app.user_id,
|
||||
],
|
||||
outputs=None,
|
||||
)
|
||||
|
||||
self.data_source.files.input(
|
||||
fn=update_data_source,
|
||||
inputs=[
|
||||
self.chat_control.conversation_id,
|
||||
self.data_source.files,
|
||||
self.chat_panel.chatbot,
|
||||
],
|
||||
outputs=None,
|
||||
)
|
||||
|
||||
self.file_upload.upload_button.click(
|
||||
fn=index_fn,
|
||||
inputs=[
|
||||
self.file_upload.files,
|
||||
self.file_upload.reindex,
|
||||
self.data_source.files,
|
||||
self._app.settings_state,
|
||||
],
|
||||
outputs=[self.file_upload.file_output, self.data_source.files],
|
||||
)
|
||||
|
||||
self._app.app.load(
|
||||
lambda: gr.update(choices=load_files()),
|
||||
inputs=None,
|
||||
outputs=[self.data_source.files],
|
||||
)
|
21
libs/ktem/ktem/pages/chat/chat_panel.py
Normal file
21
libs/ktem/ktem/pages/chat/chat_panel.py
Normal file
@@ -0,0 +1,21 @@
|
||||
import gradio as gr
|
||||
from ktem.app import BasePage
|
||||
|
||||
|
||||
class ChatPanel(BasePage):
|
||||
def __init__(self, app):
|
||||
self._app = app
|
||||
self.on_building_ui()
|
||||
|
||||
def on_building_ui(self):
|
||||
self.chatbot = gr.Chatbot(
|
||||
elem_id="main-chat-bot",
|
||||
show_copy_button=True,
|
||||
likeable=True,
|
||||
show_label=False,
|
||||
)
|
||||
with gr.Row():
|
||||
self.text_input = gr.Text(
|
||||
placeholder="Chat input", scale=15, container=False
|
||||
)
|
||||
self.submit_btn = gr.Button(value="Send", scale=1, min_width=10)
|
193
libs/ktem/ktem/pages/chat/control.py
Normal file
193
libs/ktem/ktem/pages/chat/control.py
Normal file
@@ -0,0 +1,193 @@
|
||||
import logging
|
||||
|
||||
import gradio as gr
|
||||
from ktem.app import BasePage
|
||||
from ktem.db.models import Conversation, engine
|
||||
from sqlmodel import Session, select
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConversationControl(BasePage):
|
||||
"""Manage conversation"""
|
||||
|
||||
def __init__(self, app):
|
||||
self._app = app
|
||||
self.on_building_ui()
|
||||
|
||||
def on_building_ui(self):
|
||||
with gr.Accordion(label="Conversation control", open=True):
|
||||
self.conversation_id = gr.State(value="")
|
||||
self.conversation = gr.Dropdown(
|
||||
label="Chat sessions",
|
||||
choices=[],
|
||||
container=False,
|
||||
filterable=False,
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
self.conversation_new_btn = gr.Button(value="New", min_width=10)
|
||||
self.conversation_del_btn = gr.Button(value="Delete", min_width=10)
|
||||
|
||||
with gr.Row():
|
||||
self.conversation_rn = gr.Text(
|
||||
placeholder="Conversation name",
|
||||
container=False,
|
||||
scale=5,
|
||||
min_width=10,
|
||||
interactive=True,
|
||||
)
|
||||
self.conversation_rn_btn = gr.Button(
|
||||
value="Rename", scale=1, min_width=10
|
||||
)
|
||||
|
||||
# current_state = gr.Text()
|
||||
# show_current_state = gr.Button(value="Current")
|
||||
# show_current_state.click(
|
||||
# lambda a, b: "\n".join([a, b]),
|
||||
# inputs=[cid, self.conversation],
|
||||
# outputs=[current_state],
|
||||
# )
|
||||
|
||||
def on_subscribe_public_events(self):
|
||||
self._app.subscribe_event(
|
||||
name="onSignIn",
|
||||
definition={
|
||||
"fn": self.reload_conv,
|
||||
"inputs": [self._app.user_id],
|
||||
"outputs": [self.conversation],
|
||||
"show_progress": "hidden",
|
||||
},
|
||||
)
|
||||
|
||||
self._app.subscribe_event(
|
||||
name="onSignOut",
|
||||
definition={
|
||||
"fn": self.reload_conv,
|
||||
"inputs": [self._app.user_id],
|
||||
"outputs": [self.conversation],
|
||||
"show_progress": "hidden",
|
||||
},
|
||||
)
|
||||
|
||||
self._app.subscribe_event(
|
||||
name="onCreateUser",
|
||||
definition={
|
||||
"fn": self.reload_conv,
|
||||
"inputs": [self._app.user_id],
|
||||
"outputs": [self.conversation],
|
||||
"show_progress": "hidden",
|
||||
},
|
||||
)
|
||||
|
||||
def on_register_events(self):
|
||||
self.conversation_new_btn.click(
|
||||
self.new_conv,
|
||||
inputs=self._app.user_id,
|
||||
outputs=[self.conversation_id, self.conversation],
|
||||
show_progress="hidden",
|
||||
)
|
||||
self.conversation_del_btn.click(
|
||||
self.delete_conv,
|
||||
inputs=[self.conversation_id, self._app.user_id],
|
||||
outputs=[self.conversation_id, self.conversation],
|
||||
show_progress="hidden",
|
||||
)
|
||||
self.conversation_rn_btn.click(
|
||||
self.rename_conv,
|
||||
inputs=[self.conversation_id, self.conversation_rn, self._app.user_id],
|
||||
outputs=[self.conversation, self.conversation],
|
||||
show_progress="hidden",
|
||||
)
|
||||
|
||||
def load_chat_history(self, user_id):
|
||||
"""Reload chat history"""
|
||||
options = []
|
||||
with Session(engine) as session:
|
||||
statement = (
|
||||
select(Conversation)
|
||||
.where(Conversation.user == user_id)
|
||||
.order_by(Conversation.date_created.desc()) # type: ignore
|
||||
)
|
||||
results = session.exec(statement).all()
|
||||
for result in results:
|
||||
options.append((result.name, result.id))
|
||||
|
||||
# return gr.update(choices=options)
|
||||
return options
|
||||
|
||||
def reload_conv(self, user_id):
|
||||
conv_list = self.load_chat_history(user_id)
|
||||
if conv_list:
|
||||
return gr.update(value=conv_list[0][1], choices=conv_list)
|
||||
else:
|
||||
return gr.update(value=None, choices=[])
|
||||
|
||||
def new_conv(self, user_id):
|
||||
"""Create new chat"""
|
||||
if user_id is None:
|
||||
gr.Warning("Please sign in first (Settings → User Settings)")
|
||||
return None, gr.update()
|
||||
with Session(engine) as session:
|
||||
new_conv = Conversation(user=user_id)
|
||||
session.add(new_conv)
|
||||
session.commit()
|
||||
|
||||
id_ = new_conv.id
|
||||
|
||||
history = self.load_chat_history(user_id)
|
||||
|
||||
return id_, gr.update(value=id_, choices=history)
|
||||
|
||||
def delete_conv(self, conversation_id, user_id):
|
||||
"""Create new chat"""
|
||||
if user_id is None:
|
||||
gr.Warning("Please sign in first (Settings → User Settings)")
|
||||
return None, gr.update()
|
||||
with Session(engine) as session:
|
||||
statement = select(Conversation).where(Conversation.id == conversation_id)
|
||||
result = session.exec(statement).one()
|
||||
|
||||
session.delete(result)
|
||||
session.commit()
|
||||
|
||||
history = self.load_chat_history(user_id)
|
||||
if history:
|
||||
id_ = history[0][1]
|
||||
return id_, gr.update(value=id_, choices=history)
|
||||
else:
|
||||
return None, gr.update(value=None, choices=[])
|
||||
|
||||
def select_conv(self, conversation_id):
|
||||
"""Select the conversation"""
|
||||
with Session(engine) as session:
|
||||
statement = select(Conversation).where(Conversation.id == conversation_id)
|
||||
try:
|
||||
result = session.exec(statement).one()
|
||||
id_ = result.id
|
||||
name = result.name
|
||||
files = result.data_source.get("files", [])
|
||||
chats = result.data_source.get("messages", [])
|
||||
except Exception as e:
|
||||
logger.warning(e)
|
||||
id_ = ""
|
||||
name = ""
|
||||
files = []
|
||||
chats = []
|
||||
return id_, id_, name, files, chats
|
||||
|
||||
def rename_conv(self, conversation_id, new_name, user_id):
|
||||
"""Rename the conversation"""
|
||||
if user_id is None:
|
||||
gr.Warning("Please sign in first (Settings → User Settings)")
|
||||
return gr.update(), ""
|
||||
with Session(engine) as session:
|
||||
statement = select(Conversation).where(Conversation.id == conversation_id)
|
||||
result = session.exec(statement).one()
|
||||
result.name = new_name
|
||||
session.add(result)
|
||||
session.commit()
|
||||
|
||||
history = self.load_chat_history(user_id)
|
||||
return gr.update(choices=history), conversation_id
|
18
libs/ktem/ktem/pages/chat/data_source.py
Normal file
18
libs/ktem/ktem/pages/chat/data_source.py
Normal file
@@ -0,0 +1,18 @@
|
||||
import gradio as gr
|
||||
from ktem.app import BasePage
|
||||
|
||||
|
||||
class DataSource(BasePage):
|
||||
def __init__(self, app):
|
||||
self._app = app
|
||||
self.on_building_ui()
|
||||
|
||||
def on_building_ui(self):
|
||||
with gr.Accordion(label="Data source", open=True):
|
||||
self.files = gr.Dropdown(
|
||||
label="Files",
|
||||
choices=[],
|
||||
multiselect=True,
|
||||
container=False,
|
||||
interactive=True,
|
||||
)
|
220
libs/ktem/ktem/pages/chat/events.py
Normal file
220
libs/ktem/ktem/pages/chat/events.py
Normal file
@@ -0,0 +1,220 @@
|
||||
import os
|
||||
import tempfile
|
||||
from copy import deepcopy
|
||||
from typing import Optional
|
||||
|
||||
import gradio as gr
|
||||
from ktem.components import llms, reasonings
|
||||
from ktem.db.models import Conversation, Source, engine
|
||||
from ktem.indexing.base import BaseIndex
|
||||
from ktem.reasoning.simple import DocumentRetrievalPipeline
|
||||
from sqlmodel import Session, select
|
||||
from theflow.settings import settings as app_settings
|
||||
from theflow.utils.modules import import_dotted_string
|
||||
|
||||
|
||||
def create_pipeline(settings: dict, files: Optional[list] = None):
|
||||
"""Create the pipeline from settings
|
||||
|
||||
Args:
|
||||
settings: the settings of the app
|
||||
files: the list of file ids that will be served as context. If None, then
|
||||
consider using all files
|
||||
|
||||
Returns:
|
||||
the pipeline objects
|
||||
"""
|
||||
|
||||
reasoning_mode = settings["reasoning.use"]
|
||||
reasoning_cls = reasonings[reasoning_mode]
|
||||
pipeline = reasoning_cls.get_pipeline(settings, files=files)
|
||||
|
||||
if settings["reasoning.use"] in ["rewoo", "react"]:
|
||||
from kotaemon.agents import ReactAgent, RewooAgent
|
||||
|
||||
llm = (
|
||||
llms["gpt4"]
|
||||
if settings["answer_simple_llm_model"] == "gpt-4"
|
||||
else llms["gpt35"]
|
||||
)
|
||||
tools = []
|
||||
tools_keys = (
|
||||
"answer_rewoo_tools"
|
||||
if settings["reasoning.use"] == "rewoo"
|
||||
else "answer_react_tools"
|
||||
)
|
||||
for tool in settings[tools_keys]:
|
||||
if tool == "llm":
|
||||
from kotaemon.agents import LLMTool
|
||||
|
||||
tools.append(LLMTool(llm=llm))
|
||||
elif tool == "docsearch":
|
||||
from kotaemon.agents import ComponentTool
|
||||
|
||||
filenames = ""
|
||||
if files:
|
||||
with Session(engine) as session:
|
||||
statement = select(Source).where(
|
||||
Source.id.in_(files) # type: ignore
|
||||
)
|
||||
results = session.exec(statement).all()
|
||||
filenames = (
|
||||
"The file names are: "
|
||||
+ " ".join([result.name for result in results])
|
||||
+ ". "
|
||||
)
|
||||
|
||||
retrieval_pipeline = DocumentRetrievalPipeline()
|
||||
retrieval_pipeline.set_run(
|
||||
{
|
||||
".top_k": int(settings["retrieval_number"]),
|
||||
".mmr": settings["retrieval_mmr"],
|
||||
".doc_ids": files,
|
||||
},
|
||||
temp=True,
|
||||
)
|
||||
tool = ComponentTool(
|
||||
name="docsearch",
|
||||
description=(
|
||||
"A vector store that searches for similar and "
|
||||
"related content "
|
||||
f"in a document. {filenames}"
|
||||
"The result is a huge chunk of text related "
|
||||
"to your search but can also "
|
||||
"contain irrelevant info."
|
||||
),
|
||||
component=retrieval_pipeline,
|
||||
postprocessor=lambda docs: "\n\n".join(
|
||||
[doc.text.replace("\n", " ") for doc in docs]
|
||||
),
|
||||
)
|
||||
tools.append(tool)
|
||||
elif tool == "google":
|
||||
from kotaemon.agents import GoogleSearchTool
|
||||
|
||||
tools.append(GoogleSearchTool())
|
||||
elif tool == "wikipedia":
|
||||
from kotaemon.agents import WikipediaTool
|
||||
|
||||
tools.append(WikipediaTool())
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown tool: {tool}")
|
||||
|
||||
if settings["reasoning.use"] == "rewoo":
|
||||
pipeline = RewooAgent(
|
||||
planner_llm=llm,
|
||||
solver_llm=llm,
|
||||
plugins=tools,
|
||||
)
|
||||
pipeline.set_run({".use_citation": True})
|
||||
else:
|
||||
pipeline = ReactAgent(
|
||||
llm=llm,
|
||||
plugins=tools,
|
||||
)
|
||||
|
||||
return pipeline
|
||||
|
||||
|
||||
def chat_fn(chat_input, chat_history, files, settings):
|
||||
pipeline = create_pipeline(settings, files)
|
||||
|
||||
text = ""
|
||||
refs = []
|
||||
for response in pipeline(chat_input):
|
||||
if response.metadata.get("citation", None):
|
||||
citation = response.metadata["citation"]
|
||||
for idx, fact_with_evidence in enumerate(citation.answer):
|
||||
quotes = fact_with_evidence.substring_quote
|
||||
if quotes:
|
||||
refs.append(
|
||||
(None, f"***Reference {idx+1}***: {' ... '.join(quotes)}")
|
||||
)
|
||||
else:
|
||||
text += response.text
|
||||
|
||||
yield "", chat_history + [(chat_input, text)] + refs
|
||||
|
||||
|
||||
def is_liked(convo_id, liked: gr.LikeData):
|
||||
with Session(engine) as session:
|
||||
statement = select(Conversation).where(Conversation.id == convo_id)
|
||||
result = session.exec(statement).one()
|
||||
|
||||
data_source = deepcopy(result.data_source)
|
||||
likes = data_source.get("likes", [])
|
||||
likes.append([liked.index, liked.value, liked.liked])
|
||||
data_source["likes"] = likes
|
||||
|
||||
result.data_source = data_source
|
||||
session.add(result)
|
||||
session.commit()
|
||||
|
||||
|
||||
def update_data_source(convo_id, selected_files, messages):
|
||||
"""Update the data source"""
|
||||
if not convo_id:
|
||||
gr.Warning("No conversation selected")
|
||||
return
|
||||
|
||||
with Session(engine) as session:
|
||||
statement = select(Conversation).where(Conversation.id == convo_id)
|
||||
result = session.exec(statement).one()
|
||||
|
||||
data_source = result.data_source
|
||||
result.data_source = {
|
||||
"files": selected_files,
|
||||
"messages": messages,
|
||||
"likes": deepcopy(data_source.get("likes", [])),
|
||||
}
|
||||
session.add(result)
|
||||
session.commit()
|
||||
|
||||
|
||||
def load_files():
|
||||
options = []
|
||||
with Session(engine) as session:
|
||||
statement = select(Source)
|
||||
results = session.exec(statement).all()
|
||||
for result in results:
|
||||
options.append((result.name, result.id))
|
||||
|
||||
return options
|
||||
|
||||
|
||||
def index_fn(files, reindex: bool, selected_files, settings):
|
||||
"""Upload and index the files
|
||||
|
||||
Args:
|
||||
files: the list of files to be uploaded
|
||||
reindex: whether to reindex the files
|
||||
selected_files: the list of files already selected
|
||||
settings: the settings of the app
|
||||
"""
|
||||
gr.Info(f"Start indexing {len(files)} files...")
|
||||
|
||||
# get the pipeline
|
||||
indexing_cls: BaseIndex = import_dotted_string(app_settings.KH_INDEX, safe=False)
|
||||
indexing_pipeline = indexing_cls.get_pipeline(settings)
|
||||
|
||||
output_nodes, file_ids = indexing_pipeline(files, reindex=reindex)
|
||||
gr.Info(f"Finish indexing into {len(output_nodes)} chunks")
|
||||
|
||||
# download the file
|
||||
text = "\n\n".join([each.text for each in output_nodes])
|
||||
handler, file_path = tempfile.mkstemp(suffix=".txt")
|
||||
with open(file_path, "w") as f:
|
||||
f.write(text)
|
||||
os.close(handler)
|
||||
|
||||
if isinstance(selected_files, list):
|
||||
output = selected_files + file_ids
|
||||
else:
|
||||
output = file_ids
|
||||
|
||||
file_list = load_files()
|
||||
|
||||
return (
|
||||
gr.update(value=file_path, visible=True),
|
||||
gr.update(value=output, choices=file_list),
|
||||
)
|
70
libs/ktem/ktem/pages/chat/report.py
Normal file
70
libs/ktem/ktem/pages/chat/report.py
Normal file
@@ -0,0 +1,70 @@
|
||||
from typing import Optional
|
||||
|
||||
import gradio as gr
|
||||
from ktem.app import BasePage
|
||||
from ktem.db.models import IssueReport, engine
|
||||
from sqlmodel import Session
|
||||
|
||||
|
||||
class ReportIssue(BasePage):
|
||||
def __init__(self, app):
|
||||
self._app = app
|
||||
self.on_building_ui()
|
||||
|
||||
def on_building_ui(self):
|
||||
with gr.Accordion(label="Report", open=False):
|
||||
self.correctness = gr.Radio(
|
||||
choices=[
|
||||
("The answer is correct", "correct"),
|
||||
("The answer is incorrect", "incorrect"),
|
||||
],
|
||||
label="Correctness:",
|
||||
)
|
||||
self.issues = gr.CheckboxGroup(
|
||||
choices=[
|
||||
("The answer is offensive", "offensive"),
|
||||
("The evidence is incorrect", "wrong-evidence"),
|
||||
],
|
||||
label="Other issue:",
|
||||
)
|
||||
self.more_detail = gr.Textbox(
|
||||
placeholder="More detail (e.g. how wrong is it, what is the "
|
||||
"correct answer, etc...)",
|
||||
container=False,
|
||||
lines=3,
|
||||
)
|
||||
gr.Markdown(
|
||||
"This will send the current chat and the user settings to "
|
||||
"help with investigation"
|
||||
)
|
||||
self.report_btn = gr.Button("Report")
|
||||
|
||||
def report(
|
||||
self,
|
||||
correctness: str,
|
||||
issues: list[str],
|
||||
more_detail: str,
|
||||
conv_id: str,
|
||||
chat_history: list,
|
||||
files: list,
|
||||
settings: dict,
|
||||
user_id: Optional[int],
|
||||
):
|
||||
with Session(engine) as session:
|
||||
issue = IssueReport(
|
||||
issues={
|
||||
"correctness": correctness,
|
||||
"issues": issues,
|
||||
"more_detail": more_detail,
|
||||
},
|
||||
chat={
|
||||
"conv_id": conv_id,
|
||||
"chat_history": chat_history,
|
||||
"files": files,
|
||||
},
|
||||
settings=settings,
|
||||
user=user_id,
|
||||
)
|
||||
session.add(issue)
|
||||
session.commit()
|
||||
gr.Info("Thank you for your feedback")
|
43
libs/ktem/ktem/pages/chat/upload.py
Normal file
43
libs/ktem/ktem/pages/chat/upload.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import gradio as gr
|
||||
from ktem.app import BasePage
|
||||
|
||||
|
||||
class FileUpload(BasePage):
|
||||
def __init__(self, app):
|
||||
self._app = app
|
||||
self.on_building_ui()
|
||||
|
||||
def on_building_ui(self):
|
||||
with gr.Accordion(label="File upload", open=False):
|
||||
gr.Markdown(
|
||||
"Supported file types: image, pdf, txt, csv, xlsx, docx.",
|
||||
)
|
||||
self.files = gr.File(
|
||||
file_types=["image", ".pdf", ".txt", ".csv", ".xlsx", ".docx"],
|
||||
file_count="multiple",
|
||||
container=False,
|
||||
height=50,
|
||||
)
|
||||
with gr.Accordion("Advanced indexing options", open=False):
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
self.reindex = gr.Checkbox(
|
||||
value=False, label="Force reindex file", container=False
|
||||
)
|
||||
with gr.Column():
|
||||
self.parser = gr.Dropdown(
|
||||
choices=[
|
||||
("PDF text parser", "normal"),
|
||||
("lib-table", "table"),
|
||||
("lib-table + OCR", "ocr"),
|
||||
("MathPix", "mathpix"),
|
||||
],
|
||||
value="normal",
|
||||
label="Use advance PDF parser (table+layout preserving)",
|
||||
container=True,
|
||||
)
|
||||
|
||||
self.upload_button = gr.Button("Upload and Index")
|
||||
self.file_output = gr.File(
|
||||
visible=False, label="Output files (debug purpose)"
|
||||
)
|
24
libs/ktem/ktem/pages/help.py
Normal file
24
libs/ktem/ktem/pages/help.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from pathlib import Path
|
||||
|
||||
import gradio as gr
|
||||
|
||||
|
||||
class HelpPage:
|
||||
def __init__(self, app):
|
||||
self._app = app
|
||||
self.dir_md = Path(__file__).parent.parent / "assets" / "md"
|
||||
|
||||
with gr.Accordion("Changelogs"):
|
||||
gr.Markdown(self.get_changelogs())
|
||||
|
||||
with gr.Accordion("About Kotaemon (temporary)"):
|
||||
with (self.dir_md / "about_kotaemon.md").open() as fi:
|
||||
gr.Markdown(fi.read())
|
||||
|
||||
with gr.Accordion("About Cinnamon AI (temporary)", open=False):
|
||||
with (self.dir_md / "about_cinnamon.md").open() as fi:
|
||||
gr.Markdown(fi.read())
|
||||
|
||||
def get_changelogs(self):
|
||||
with (self.dir_md / "changelogs.md").open() as fi:
|
||||
return fi.read()
|
414
libs/ktem/ktem/pages/settings.py
Normal file
414
libs/ktem/ktem/pages/settings.py
Normal file
@@ -0,0 +1,414 @@
|
||||
import hashlib
|
||||
|
||||
import gradio as gr
|
||||
from ktem.app import BasePage
|
||||
from ktem.db.models import Settings, User, engine
|
||||
from sqlmodel import Session, select
|
||||
|
||||
gr_cls_single_value = {
|
||||
"text": gr.Textbox,
|
||||
"number": gr.Number,
|
||||
"checkbox": gr.Checkbox,
|
||||
}
|
||||
|
||||
|
||||
gr_cls_choices = {
|
||||
"dropdown": gr.Dropdown,
|
||||
"radio": gr.Radio,
|
||||
"checkboxgroup": gr.CheckboxGroup,
|
||||
}
|
||||
|
||||
|
||||
def render_setting_item(setting_item, value):
|
||||
"""Render the setting component into corresponding Gradio UI component"""
|
||||
kwargs = {
|
||||
"label": setting_item.name,
|
||||
"value": value,
|
||||
"interactive": True,
|
||||
}
|
||||
|
||||
if setting_item.component in gr_cls_single_value:
|
||||
return gr_cls_single_value[setting_item.component](**kwargs)
|
||||
|
||||
kwargs["choices"] = setting_item.choices
|
||||
|
||||
if setting_item.component in gr_cls_choices:
|
||||
return gr_cls_choices[setting_item.component](**kwargs)
|
||||
|
||||
raise ValueError(
|
||||
f"Unknown component {setting_item.component}, allowed are: "
|
||||
f"{list(gr_cls_single_value.keys()) + list(gr_cls_choices.keys())}.\n"
|
||||
f"Setting item: {setting_item}"
|
||||
)
|
||||
|
||||
|
||||
class SettingsPage(BasePage):
|
||||
"""Responsible for allowing the users to customize the application
|
||||
|
||||
**IMPORTANT**: the name and id of the UI setting components should match the
|
||||
name of the setting in the `app.default_settings`
|
||||
"""
|
||||
|
||||
public_events = ["onSignIn", "onSignOut", "onCreateUser"]
|
||||
|
||||
def __init__(self, app):
|
||||
"""Initiate the page and render the UI"""
|
||||
self._app = app
|
||||
|
||||
self._settings_state = app.settings_state
|
||||
self._user_id = app.user_id
|
||||
self._default_settings = app.default_settings
|
||||
self._settings_dict = self._default_settings.flatten()
|
||||
self._settings_keys = list(self._settings_dict.keys())
|
||||
|
||||
self._components = {}
|
||||
self._reasoning_mode = {}
|
||||
|
||||
self.on_building_ui()
|
||||
|
||||
def on_building_ui(self):
|
||||
self.setting_save_btn = gr.Button("Save settings")
|
||||
with gr.Tab("User settings"):
|
||||
self.user_tab()
|
||||
with gr.Tab("General application settings"):
|
||||
self.app_tab()
|
||||
with gr.Tab("Index settings"):
|
||||
self.index_tab()
|
||||
with gr.Tab("Reasoning settings"):
|
||||
self.reasoning_tab()
|
||||
|
||||
def on_subscribe_public_events(self):
|
||||
pass
|
||||
|
||||
def on_register_events(self):
|
||||
self.setting_save_btn.click(
|
||||
self.save_setting,
|
||||
inputs=[self._user_id] + self.components(),
|
||||
outputs=self._settings_state,
|
||||
)
|
||||
self.password_change_btn.click(
|
||||
self.change_password,
|
||||
inputs=[
|
||||
self._user_id,
|
||||
self.password_change,
|
||||
self.password_change_confirm,
|
||||
],
|
||||
outputs=None,
|
||||
show_progress="hidden",
|
||||
)
|
||||
self._components["reasoning.use"].change(
|
||||
self.change_reasoning_mode,
|
||||
inputs=[self._components["reasoning.use"]],
|
||||
outputs=list(self._reasoning_mode.values()),
|
||||
show_progress="hidden",
|
||||
)
|
||||
|
||||
onSignInClick = self.signin.click(
|
||||
self.sign_in,
|
||||
inputs=[self.username, self.password],
|
||||
outputs=[self._user_id, self.username, self.password]
|
||||
+ self.signed_in_state()
|
||||
+ [self.user_out_state],
|
||||
show_progress="hidden",
|
||||
).then(
|
||||
self.load_setting,
|
||||
inputs=self._user_id,
|
||||
outputs=[self._settings_state] + self.components(),
|
||||
show_progress="hidden",
|
||||
)
|
||||
for event in self._app.get_event("onSignIn"):
|
||||
onSignInClick = onSignInClick.then(**event)
|
||||
|
||||
onSignInSubmit = self.password.submit(
|
||||
self.sign_in,
|
||||
inputs=[self.username, self.password],
|
||||
outputs=[self._user_id, self.username, self.password]
|
||||
+ self.signed_in_state()
|
||||
+ [self.user_out_state],
|
||||
show_progress="hidden",
|
||||
).then(
|
||||
self.load_setting,
|
||||
inputs=self._user_id,
|
||||
outputs=[self._settings_state] + self.components(),
|
||||
show_progress="hidden",
|
||||
)
|
||||
for event in self._app.get_event("onSignIn"):
|
||||
onSignInSubmit = onSignInSubmit.then(**event)
|
||||
|
||||
onCreateUserClick = self.create_btn.click(
|
||||
self.create_user,
|
||||
inputs=[
|
||||
self.username_new,
|
||||
self.password_new,
|
||||
self.password_new_confirm,
|
||||
],
|
||||
outputs=[
|
||||
self._user_id,
|
||||
self.username_new,
|
||||
self.password_new,
|
||||
self.password_new_confirm,
|
||||
]
|
||||
+ self.signed_in_state()
|
||||
+ [self.user_out_state],
|
||||
show_progress="hidden",
|
||||
).then(
|
||||
self.load_setting,
|
||||
inputs=self._user_id,
|
||||
outputs=[self._settings_state] + self.components(),
|
||||
show_progress="hidden",
|
||||
)
|
||||
for event in self._app.get_event("onCreateUser"):
|
||||
onCreateUserClick = onCreateUserClick.then(**event)
|
||||
|
||||
onSignOutClick = self.signout.click(
|
||||
self.sign_out,
|
||||
inputs=None,
|
||||
outputs=[self._user_id] + self.signed_in_state() + [self.user_out_state],
|
||||
show_progress="hidden",
|
||||
).then(
|
||||
self.load_setting,
|
||||
inputs=self._user_id,
|
||||
outputs=[self._settings_state] + self.components(),
|
||||
show_progress="hidden",
|
||||
)
|
||||
for event in self._app.get_event("onSignOut"):
|
||||
onSignOutClick = onSignOutClick.then(**event)
|
||||
|
||||
def user_tab(self):
|
||||
with gr.Row() as self.user_out_state:
|
||||
with gr.Column():
|
||||
gr.Markdown("Sign in")
|
||||
self.username = gr.Textbox(label="Username", interactive=True)
|
||||
self.password = gr.Textbox(
|
||||
label="Password", type="password", interactive=True
|
||||
)
|
||||
self.signin = gr.Button("Login")
|
||||
|
||||
with gr.Column():
|
||||
gr.Markdown("Create new account")
|
||||
self.username_new = gr.Textbox(label="Username", interactive=True)
|
||||
self.password_new = gr.Textbox(
|
||||
label="Password", type="password", interactive=True
|
||||
)
|
||||
self.password_new_confirm = gr.Textbox(
|
||||
label="Confirm password", type="password", interactive=True
|
||||
)
|
||||
self.create_btn = gr.Button("Create account")
|
||||
|
||||
# user management
|
||||
self.current_name = gr.Markdown("Current user: ___", visible=False)
|
||||
self.signout = gr.Button("Logout", visible=False)
|
||||
|
||||
self.password_change = gr.Textbox(
|
||||
label="New password", interactive=True, type="password", visible=False
|
||||
)
|
||||
self.password_change_confirm = gr.Textbox(
|
||||
label="Confirm password", interactive=True, type="password", visible=False
|
||||
)
|
||||
self.password_change_btn = gr.Button(
|
||||
"Change password", interactive=True, visible=False
|
||||
)
|
||||
|
||||
def signed_out_state(self):
|
||||
return [
|
||||
self.username,
|
||||
self.password,
|
||||
self.signin,
|
||||
self.username_new,
|
||||
self.password_new,
|
||||
self.password_new_confirm,
|
||||
self.create_btn,
|
||||
]
|
||||
|
||||
def signed_in_state(self):
|
||||
return [
|
||||
self.current_name, # always the first one
|
||||
self.signout,
|
||||
self.password_change,
|
||||
self.password_change_confirm,
|
||||
self.password_change_btn,
|
||||
]
|
||||
|
||||
def sign_in(self, username: str, password: str):
|
||||
hashed_password = hashlib.sha256(password.encode()).hexdigest()
|
||||
user_id, clear_username, clear_password = None, username, password
|
||||
with Session(engine) as session:
|
||||
statement = select(User).where(
|
||||
User.username == username,
|
||||
User.password == hashed_password,
|
||||
)
|
||||
result = session.exec(statement).all()
|
||||
if result:
|
||||
user_id = result[0].id
|
||||
clear_username, clear_password = "", ""
|
||||
else:
|
||||
gr.Warning("Username or password is incorrect")
|
||||
|
||||
output: list = [user_id, clear_username, clear_password]
|
||||
if user_id is None:
|
||||
output += [
|
||||
gr.update(visible=False) for _ in range(len(self.signed_in_state()))
|
||||
]
|
||||
output.append(gr.update(visible=True))
|
||||
else:
|
||||
output.append(gr.update(visible=True, value=f"Current user: {username}"))
|
||||
output += [
|
||||
gr.update(visible=True) for _ in range(len(self.signed_in_state()) - 1)
|
||||
]
|
||||
output.append(gr.update(visible=False))
|
||||
|
||||
return output
|
||||
|
||||
def create_user(self, username, password, password_confirm):
|
||||
user_id, usn, pwd, pwdc = None, username, password, password_confirm
|
||||
if password != password_confirm:
|
||||
gr.Warning("Password does not match")
|
||||
else:
|
||||
with Session(engine) as session:
|
||||
statement = select(User).where(
|
||||
User.username == username,
|
||||
)
|
||||
result = session.exec(statement).all()
|
||||
if result:
|
||||
gr.Warning(f'Username "{username}" already exists')
|
||||
else:
|
||||
hashed_password = hashlib.sha256(password.encode()).hexdigest()
|
||||
user = User(username=username, password=hashed_password)
|
||||
session.add(user)
|
||||
session.commit()
|
||||
user_id = user.id
|
||||
usn, pwd, pwdc = "", "", ""
|
||||
print(user_id)
|
||||
|
||||
output: list = [user_id, usn, pwd, pwdc]
|
||||
if user_id is not None:
|
||||
output.append(gr.update(visible=True, value=f"Current user: {username}"))
|
||||
output += [
|
||||
gr.update(visible=True) for _ in range(len(self.signed_in_state()) - 1)
|
||||
]
|
||||
output.append(gr.update(visible=False))
|
||||
else:
|
||||
output += [
|
||||
gr.update(visible=False) for _ in range(len(self.signed_in_state()))
|
||||
]
|
||||
output.append(gr.update(visible=True))
|
||||
|
||||
return output
|
||||
|
||||
def sign_out(self):
|
||||
output = [None]
|
||||
output += [gr.update(visible=False) for _ in range(len(self.signed_in_state()))]
|
||||
output.append(gr.update(visible=True))
|
||||
return output
|
||||
|
||||
def change_password(self, user_id, password, password_confirm):
|
||||
if password != password_confirm:
|
||||
gr.Warning("Password does not match")
|
||||
return
|
||||
|
||||
with Session(engine) as session:
|
||||
statement = select(User).where(User.id == user_id)
|
||||
result = session.exec(statement).all()
|
||||
if result:
|
||||
user = result[0]
|
||||
hashed_password = hashlib.sha256(password.encode()).hexdigest()
|
||||
user.password = hashed_password
|
||||
session.add(user)
|
||||
session.commit()
|
||||
gr.Info("Password changed")
|
||||
else:
|
||||
gr.Warning("User not found")
|
||||
|
||||
def app_tab(self):
|
||||
for n, si in self._default_settings.application.settings.items():
|
||||
obj = render_setting_item(si, si.value)
|
||||
self._components[f"application.{n}"] = obj
|
||||
|
||||
def index_tab(self):
|
||||
for n, si in self._default_settings.index.settings.items():
|
||||
obj = render_setting_item(si, si.value)
|
||||
self._components[f"index.{n}"] = obj
|
||||
|
||||
def reasoning_tab(self):
|
||||
with gr.Group():
|
||||
for n, si in self._default_settings.reasoning.settings.items():
|
||||
if n == "use":
|
||||
continue
|
||||
obj = render_setting_item(si, si.value)
|
||||
self._components[f"reasoning.{n}"] = obj
|
||||
|
||||
gr.Markdown("### Reasoning-specific settings")
|
||||
self._components["reasoning.use"] = render_setting_item(
|
||||
self._default_settings.reasoning.settings["use"],
|
||||
self._default_settings.reasoning.settings["use"].value,
|
||||
)
|
||||
|
||||
for idx, (pn, sig) in enumerate(
|
||||
self._default_settings.reasoning.options.items()
|
||||
):
|
||||
with gr.Group(
|
||||
visible=idx == 0,
|
||||
elem_id=pn,
|
||||
) as self._reasoning_mode[pn]:
|
||||
gr.Markdown("**Name**: Description")
|
||||
for n, si in sig.settings.items():
|
||||
obj = render_setting_item(si, si.value)
|
||||
self._components[f"reasoning.options.{pn}.{n}"] = obj
|
||||
|
||||
def change_reasoning_mode(self, value):
|
||||
output = []
|
||||
for each in self._reasoning_mode.values():
|
||||
if value == each.elem_id:
|
||||
output.append(gr.update(visible=True))
|
||||
else:
|
||||
output.append(gr.update(visible=False))
|
||||
return output
|
||||
|
||||
def load_setting(self, user_id=None):
|
||||
settings = self._settings_dict
|
||||
with Session(engine) as session:
|
||||
statement = select(Settings).where(Settings.user == user_id)
|
||||
result = session.exec(statement).all()
|
||||
if result:
|
||||
settings = result[0].setting
|
||||
|
||||
output = [settings]
|
||||
output += tuple(settings[name] for name in self.component_names())
|
||||
return output
|
||||
|
||||
def save_setting(self, user_id: int, *args):
|
||||
"""Save the setting to disk and persist the setting to session state
|
||||
|
||||
Args:
|
||||
user_id: the user id
|
||||
args: all the values from the settings
|
||||
"""
|
||||
setting = {key: value for key, value in zip(self.component_names(), args)}
|
||||
if user_id is None:
|
||||
gr.Warning("Need to login before saving settings")
|
||||
return setting
|
||||
|
||||
with Session(engine) as session:
|
||||
statement = select(Settings).where(Settings.user == user_id)
|
||||
try:
|
||||
user_setting = session.exec(statement).one()
|
||||
except Exception:
|
||||
user_setting = Settings()
|
||||
user_setting.user = user_id
|
||||
user_setting.setting = setting
|
||||
session.add(user_setting)
|
||||
session.commit()
|
||||
|
||||
gr.Info("Setting saved")
|
||||
return setting
|
||||
|
||||
def components(self) -> list:
|
||||
"""Get the setting components"""
|
||||
output = []
|
||||
for name in self._settings_keys:
|
||||
output.append(self._components[name])
|
||||
return output
|
||||
|
||||
def component_names(self):
|
||||
"""Get the setting components"""
|
||||
return self._settings_keys
|
Reference in New Issue
Block a user