Make ktem official (#134)

* Move kotaemon and ktem into the same folder
* Update docs
* Update CI
* Resolve mypy and isort issues
* Re-allow test PDF files
parent 9c5b707010
commit 2dd531114f
libs/ktem/.gitignore (vendored, new file, 2 lines)
@@ -0,0 +1,2 @@
14-1_抜粋-1.pdf
_example_.db
libs/ktem/README.md (new file, 24 lines)
@@ -0,0 +1,24 @@
# Example of MVP pipeline for _example_

## Prerequisite

To run the system out-of-the-box, please supply the following environment
variables:

```
OPENAI_API_KEY=
OPENAI_API_BASE=
OPENAI_API_VERSION=
SERPAPI_API_KEY=
COHERE_API_KEY=
OPENAI_API_KEY_EMBEDDING=

# optional
KH_APP_NAME=
```

## Run

```
gradio launch.py
```
libs/ktem/flowsettings.py (new file, 101 lines)
@@ -0,0 +1,101 @@
from pathlib import Path

from decouple import config
from platformdirs import user_cache_dir
from theflow.settings.default import *  # noqa

user_cache_dir = Path(
    user_cache_dir(str(config("KH_APP_NAME", default="ktem")), "Cinnamon")
)
user_cache_dir.mkdir(parents=True, exist_ok=True)


COHERE_API_KEY = config("COHERE_API_KEY", default="")
KH_DATABASE = f"sqlite:///{user_cache_dir / 'sql.db'}"
KH_DOCSTORE = {
    "__type__": "kotaemon.storages.SimpleFileDocumentStore",
    "path": str(user_cache_dir / "docstore"),
}
KH_VECTORSTORE = {
    "__type__": "kotaemon.storages.ChromaVectorStore",
    "path": str(user_cache_dir / "vectorstore"),
}
KH_FILESTORAGE_PATH = str(user_cache_dir / "files")
KH_LLMS = {
    "gpt4": {
        "def": {
            "__type__": "kotaemon.llms.AzureChatOpenAI",
            "temperature": 0,
            "azure_endpoint": config("OPENAI_API_BASE", default=""),
            "openai_api_key": config("OPENAI_API_KEY", default=""),
            "openai_api_version": config("OPENAI_API_VERSION", default=""),
            "deployment_name": "dummy-q2",
            "stream": True,
        },
        "accuracy": 10,
        "cost": 10,
        "default": False,
    },
    "gpt35": {
        "def": {
            "__type__": "kotaemon.llms.AzureChatOpenAI",
            "temperature": 0,
            "azure_endpoint": config("OPENAI_API_BASE", default=""),
            "openai_api_key": config("OPENAI_API_KEY", default=""),
            "openai_api_version": config("OPENAI_API_VERSION", default=""),
            "deployment_name": "dummy-q2",
            "request_timeout": 10,
            "stream": False,
        },
        "accuracy": 5,
        "cost": 5,
        "default": True,
    },
}
KH_EMBEDDINGS = {
    "ada": {
        "def": {
            "__type__": "kotaemon.embeddings.AzureOpenAIEmbeddings",
            "model": "text-embedding-ada-002",
            "azure_endpoint": config("OPENAI_API_BASE", default=""),
            "openai_api_key": config("OPENAI_API_KEY", default=""),
            "deployment": "dummy-q2-text-embedding",
            "chunk_size": 16,
        },
        "accuracy": 5,
        "cost": 5,
        "default": True,
    },
}
KH_REASONINGS = {
    "simple": "ktem.reasoning.simple.FullQAPipeline",
}


SETTINGS_APP = {
    "lang": {
        "name": "Language",
        "value": "en",
        "choices": [("English", "en"), ("Japanese", "ja")],
        "component": "dropdown",
    }
}


SETTINGS_REASONING = {
    "use": {
        "name": "Reasoning options",
        "value": None,
        "choices": [],
        "component": "radio",
    },
    "lang": {
        "name": "Language",
        "value": "en",
        "choices": [("English", "en"), ("Japanese", "ja")],
        "component": "dropdown",
    },
}


KH_INDEX = "ktem.indexing.file.IndexDocumentPipeline"
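A note on the `__type__` entries above: they are dotted class paths that ktem deserializes into live objects at startup (see the `deserialize(..., safe=False)` calls in `ktem/components.py` below). A minimal sketch of what that resolution amounts to, using only the standard library; `resolve` here is a hypothetical stand-in for theflow's `deserialize`, for illustration only:

```python
import importlib


def resolve(spec: dict):
    """Instantiate the class named by "__type__", passing the other keys as kwargs.

    Hypothetical, simplified stand-in for theflow's deserialize().
    """
    spec = dict(spec)  # avoid mutating the settings dict
    module_name, _, cls_name = spec.pop("__type__").rpartition(".")
    cls = getattr(importlib.import_module(module_name), cls_name)
    return cls(**spec)


# e.g. resolve(KH_DOCSTORE) would yield a SimpleFileDocumentStore rooted at
# the user cache directory configured above.
```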
libs/ktem/khapptests/__init__.py (new file, empty)
libs/ktem/khapptests/resources/embedding_openai.json (new file, 1552 lines)
File diff suppressed because it is too large
libs/ktem/khapptests/test_qa.py (new file, 72 lines)
@@ -0,0 +1,72 @@
import json
from pathlib import Path
from unittest.mock import patch

import pytest
from index import ReaderIndexingPipeline
from kotaemon.llms import AzureChatOpenAI
from openai.resources.embeddings import Embeddings
from openai.types.chat.chat_completion import ChatCompletion

with open(Path(__file__).parent / "resources" / "embedding_openai.json") as f:
    openai_embedding = json.load(f)


_openai_chat_completion_response = ChatCompletion.parse_obj(
    {
        "id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x",
        "object": "chat.completion",
        "created": 1692338378,
        "model": "gpt-35-turbo",
        "system_fingerprint": None,
        "choices": [
            {
                "index": 0,
                "finish_reason": "stop",
                "message": {
                    "role": "assistant",
                    "content": "Hello! How can I assist you today?",
                    "function_call": None,
                    "tool_calls": None,
                },
            }
        ],
        "usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19},
    }
)


@pytest.fixture(scope="function")
def mock_openai_embedding(monkeypatch):
    monkeypatch.setattr(Embeddings, "create", lambda *args, **kwargs: openai_embedding)


@patch(
    "openai.resources.chat.completions.Completions.create",
    side_effect=lambda *args, **kwargs: _openai_chat_completion_response,
)
def test_ingest_pipeline(patch, mock_openai_embedding, tmp_path):
    indexing_pipeline = ReaderIndexingPipeline(
        storage_path=tmp_path,
    )
    indexing_pipeline.indexing_vector_pipeline.embedding.openai_api_key = "some-key"
    input_file_path = Path(__file__).parent / "resources/dummy.pdf"

    # call ingestion pipeline
    indexing_pipeline(input_file_path, force_reindex=True)
    retrieving_pipeline = indexing_pipeline.to_retrieving_pipeline()

    results = retrieving_pipeline("This is a query")
    assert len(results) == 1

    # create llm
    llm = AzureChatOpenAI(
        openai_api_base="https://test.openai.azure.com/",
        openai_api_key="some-key",
        openai_api_version="2023-03-15-preview",
        deployment_name="gpt35turbo",
        temperature=0,
    )
    qa_pipeline = indexing_pipeline.to_qa_pipeline(llm=llm, openai_api_key="some-key")
    response = qa_pipeline("Summarize this document.")
    assert response
libs/ktem/ktem/__init__.py (new file, empty)
libs/ktem/ktem/app.py (new file, 201 lines)
@@ -0,0 +1,201 @@
from pathlib import Path

import gradio as gr
import pluggy
from ktem import extension_protocol
from ktem.components import reasonings
from ktem.exceptions import HookAlreadyDeclared, HookNotDeclared
from ktem.settings import (
    BaseSettingGroup,
    SettingGroup,
    SettingItem,
    SettingReasoningGroup,
)
from theflow.settings import settings
from theflow.utils.modules import import_dotted_string


class BaseApp:
    """The main app of Kotaemon

    The main application contains app-level information:
        - setting state
        - user id

    Also contains registering methods for:
        - reasoning pipelines
        - indexing & retrieval pipelines

    App life-cycle:
        - Render
        - Declare public events
        - Subscribe public events
        - Register events
    """

    def __init__(self):
        dir_assets = Path(__file__).parent / "assets"
        with (dir_assets / "css" / "main.css").open() as fi:
            self._css = fi.read()
        with (dir_assets / "js" / "main.js").open() as fi:
            self._js = fi.read()

        self.default_settings = SettingGroup(
            application=BaseSettingGroup(settings=settings.SETTINGS_APP),
            reasoning=SettingReasoningGroup(settings=settings.SETTINGS_REASONING),
        )

        self._callbacks: dict[str, list] = {}
        self._events: dict[str, list] = {}

        self.register_indices()
        self.register_reasonings()
        self.register_extensions()

        self.default_settings.reasoning.finalize()
        self.default_settings.index.finalize()

        self.settings_state = gr.State(self.default_settings.flatten())
        self.user_id = gr.State(None)

    def register_indices(self):
        """Register the index components from app settings"""
        index = import_dotted_string(settings.KH_INDEX, safe=False)
        user_settings = index().get_user_settings()
        for key, value in user_settings.items():
            self.default_settings.index.settings[key] = SettingItem(**value)

    def register_reasonings(self):
        """Register the reasoning components from app settings"""
        if getattr(settings, "KH_REASONINGS", None) is None:
            return

        for name, value in settings.KH_REASONINGS.items():
            reasoning_cls = import_dotted_string(value, safe=False)
            reasonings[name] = reasoning_cls
            options = reasoning_cls().get_user_settings()
            self.default_settings.reasoning.options[name] = BaseSettingGroup(
                settings=options
            )

    def register_extensions(self):
        """Register installed extensions"""
        self.exman = pluggy.PluginManager("ktem")
        self.exman.add_hookspecs(extension_protocol)
        self.exman.load_setuptools_entrypoints("ktem")

        # retrieve and register extension declarations
        extension_declarations = self.exman.hook.ktem_declare_extensions()
        for extension_declaration in extension_declarations:
            # if already in database, with the same version: skip

            # otherwise,
            # remove the old information from the database if it exists
            # store the information into the database

            functionality = extension_declaration["functionality"]

            # update the reasoning information
            if "reasoning" in functionality:
                for rid, rdec in functionality["reasoning"].items():
                    unique_rid = f"{extension_declaration['id']}/{rid}"
                    self.default_settings.reasoning.options[
                        unique_rid
                    ] = BaseSettingGroup(
                        settings=rdec["settings"],
                    )

    def declare_event(self, name: str):
        """Declare a public gradio event for other components to subscribe to

        Args:
            name: The name of the event
        """
        if name in self._events:
            raise HookAlreadyDeclared(f"Hook {name} is already declared")
        self._events[name] = []

    def subscribe_event(self, name: str, definition: dict):
        """Register a hook for the app

        Args:
            name: The name of the hook
            definition: The event definition to attach to the hook
        """
        if name not in self._events:
            raise HookNotDeclared(f"Hook {name} is not declared")
        self._events[name].append(definition)

    def get_event(self, name) -> list[dict]:
        if name not in self._events:
            raise HookNotDeclared(f"Hook {name} is not declared")

        return self._events[name]

    def ui(self):
        raise NotImplementedError

    def make(self):
        with gr.Blocks(css=self._css) as demo:
            self.app = demo
            self.settings_state.render()
            self.user_id.render()

            self.ui()

            for value in self.__dict__.values():
                if isinstance(value, BasePage):
                    value.declare_public_events()

            for value in self.__dict__.values():
                if isinstance(value, BasePage):
                    value.subscribe_public_events()

            for value in self.__dict__.values():
                if isinstance(value, BasePage):
                    value.register_events()

            demo.load(lambda: None, None, None, js=f"() => {{{self._js}}}")

        return demo


class BasePage:
    """The logic of the Kotaemon app"""

    public_events: list[str] = []

    def __init__(self, app):
        self._app = app

    def on_building_ui(self):
        """Build the UI of the app"""

    def on_subscribe_public_events(self):
        """Subscribe to the declared public events of the app"""

    def on_register_events(self):
        """Register all events to the app"""

    def declare_public_events(self):
        """Declare an event for the app"""
        for event in self.public_events:
            self._app.declare_event(event)

        for value in self.__dict__.values():
            if isinstance(value, BasePage):
                value.declare_public_events()

    def subscribe_public_events(self):
        """Subscribe to an event"""
        self.on_subscribe_public_events()
        for value in self.__dict__.values():
            if isinstance(value, BasePage):
                value.subscribe_public_events()

    def register_events(self):
        """Register all events"""
        self.on_register_events()
        for value in self.__dict__.values():
            if isinstance(value, BasePage):
                value.register_events()
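To make the page life-cycle above concrete, here is a minimal sketch of a custom page (hypothetical, not part of this commit) that declares a public event and wires a gradio callback in `on_register_events`, assuming only the `BaseApp`/`BasePage` contract shown above:

```python
import gradio as gr
from ktem.app import BasePage


class GreetingPage(BasePage):
    """Hypothetical page illustrating the declare/subscribe/register cycle."""

    # declared on the app during declare_public_events(), so other pages can
    # attach handlers via app.subscribe_event("onGreeted", {...})
    public_events = ["onGreeted"]

    def __init__(self, app):
        super().__init__(app)
        self.on_building_ui()

    def on_building_ui(self):
        self.name = gr.Textbox(label="Name")
        self.greet_btn = gr.Button("Greet")
        self.output = gr.Textbox(label="Greeting")

    def on_register_events(self):
        # called by BaseApp.make() after every page has been built
        self.greet_btn.click(
            fn=lambda name: f"Hello, {name}!",
            inputs=[self.name],
            outputs=[self.output],
        )
```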
libs/ktem/ktem/assets/css/main.css (new file, 37 lines)
@@ -0,0 +1,37 @@
footer {
    display: none !important;
}

.gradio-container {
    max-width: 100% !important;
    padding: 0 !important;
}

.header-bar {
    background-color: #f7f7f7;
    margin: 0px 0px 20px;
    overflow-x: scroll;
    display: block !important;
    text-wrap: nowrap;
}

.dark .header-bar {
    border: none !important;
    background-color: #8080802b !important;
}

.header-bar button.selected {
    border-radius: 0;
}

#chat-tab, #settings-tab, #help-tab {
    border: none !important;
}

#main-chat-bot {
    height: calc(100vh - 140px) !important;
}

.setting-answer-mode-description {
    margin: 5px 5px 2px !important
}
libs/ktem/ktem/assets/js/main.js (new file, 6 lines)
@@ -0,0 +1,6 @@
let main_parent = document.getElementById("chat-tab").parentNode;

main_parent.childNodes[0].classList.add("header-bar");
main_parent.style = "padding: 0; margin: 0";
main_parent.parentNode.style = "gap: 0";
main_parent.parentNode.parentNode.style = "padding: 0";
libs/ktem/ktem/assets/md/about_cinnamon.md (new file, 25 lines)
@@ -0,0 +1,25 @@
# About Cinnamon AI

Welcome to **Cinnamon AI**, a pioneering force in the field of artificial intelligence and document processing. At Cinnamon AI, we are committed to revolutionizing the way businesses handle information, leveraging cutting-edge technologies to streamline and automate data extraction processes.

## Our Mission

At the core of our mission is the pursuit of innovation that simplifies complex tasks. We strive to empower organizations with transformative AI solutions that enhance efficiency, accuracy, and productivity. Cinnamon AI is dedicated to bridging the gap between human intelligence and machine capabilities, making data extraction and analysis seamless and intuitive.

## Key Highlights

- **Advanced Technology:** Cinnamon AI specializes in harnessing the power of natural language processing (NLP) and machine learning to develop sophisticated solutions for document understanding and data extraction.

- **Industry Impact:** We cater to diverse industries, providing tailor-made AI solutions that address the unique challenges and opportunities within each sector. From finance to healthcare, our technology is designed to make a meaningful impact.

- **Global Presence:** With a global perspective, Cinnamon AI operates on an international scale, collaborating with businesses and enterprises around the world to elevate their data processing capabilities.

## Why Choose Cinnamon AI

- **Innovation:** Our commitment to innovation is evident in our continual pursuit of technological excellence. We stay ahead of the curve to deliver solutions that meet the evolving needs of the digital landscape.

- **Reliability:** Clients trust Cinnamon AI for reliable, accurate, and scalable AI solutions. Our track record speaks to our dedication to quality and customer satisfaction.

- **Collaboration:** We believe in the power of collaboration. By working closely with our clients, we tailor our solutions to their specific requirements, fostering long-term partnerships built on mutual success.

Explore the future of data processing with Cinnamon AI, where intelligence meets innovation.
libs/ktem/ktem/assets/md/about_kotaemon.md (new file, 19 lines)
@@ -0,0 +1,19 @@
# About Kotaemon

Welcome to the future of language technology: Cinnamon AI proudly presents our latest innovation, **Kotaemon**. At Cinnamon AI, we believe in pushing the boundaries of what's possible with natural language processing, and Kotaemon embodies the pinnacle of our endeavors. Designed to empower businesses and developers alike, Kotaemon is not just a product; it's a manifestation of our commitment to enhancing human-machine interaction.

## Key Features

- **Cognitive Understanding:** Kotaemon boasts advanced cognitive understanding capabilities, allowing it to interpret and respond to natural language queries with unprecedented accuracy. Whether you're building chatbots, virtual assistants, or language-driven applications, Kotaemon ensures a nuanced and contextually rich user experience.

- **Versatility:** From analyzing vast textual datasets to generating coherent and contextually relevant responses, Kotaemon adapts seamlessly to diverse use cases. Whether you're in customer support, content creation, or data analysis, Kotaemon is your versatile companion in navigating the linguistic landscape.

- **Scalability:** Built with scalability in mind, Kotaemon is designed to meet the evolving needs of your business. As your language-related tasks grow in complexity, Kotaemon scales with you, providing a robust foundation for future innovation and expansion.

- **Ethical AI:** Cinnamon AI is committed to responsible and ethical AI development. Kotaemon reflects our dedication to fairness, transparency, and unbiased language processing, ensuring that your applications uphold the highest ethical standards.

## Why Kotaemon?

Kotaemon is not just a tool; it's a catalyst for unlocking the true potential of natural language understanding. Whether you're a developer aiming to enhance user experiences or a business leader seeking to leverage language technology, Kotaemon is your partner in navigating the intricacies of human communication.

Join us on this transformative journey with Kotaemon, where language meets innovation and understanding becomes seamless. Cinnamon AI: redefining the future of natural language processing.
libs/ktem/ktem/assets/md/changelogs.md (new file, 11 lines)
@@ -0,0 +1,11 @@
# Changelogs

## v1.0.0

- Chat: interact with the chatbot through a simple pipeline, ReWOO and ReAct agents
- Chat: conversation management: create, delete, rename conversations
- Files: upload files
- Files: select files as context for the chatbot
- User management: create, sign-in, sign-out, change password
- Setting: common settings and pipeline-based settings
- Info panel: show Cinnamon AI and Kotaemon information
libs/ktem/ktem/components.py (new file, 162 lines)
@@ -0,0 +1,162 @@
"""Common components and app-level component configuration"""
import logging
from functools import cache
from pathlib import Path

from kotaemon.base import BaseComponent
from kotaemon.storages import BaseDocumentStore, BaseVectorStore
from theflow.settings import settings
from theflow.utils.modules import deserialize

logger = logging.getLogger(__name__)


filestorage_path = Path(settings.KH_FILESTORAGE_PATH)
filestorage_path.mkdir(parents=True, exist_ok=True)


@cache
def get_docstore() -> BaseDocumentStore:
    return deserialize(settings.KH_DOCSTORE, safe=False)


@cache
def get_vectorstore() -> BaseVectorStore:
    return deserialize(settings.KH_VECTORSTORE, safe=False)


class ModelPool:
    """Represent a pool of models"""

    def __init__(self, category: str, conf: dict):
        self._category = category
        self._conf = conf

        self._models: dict[str, BaseComponent] = {}
        self._accuracy: list[str] = []
        self._cost: list[str] = []
        self._default: list[str] = []

        for name, model in conf.items():
            self._models[name] = deserialize(model["def"], safe=False)
            if model.get("default", False):
                self._default.append(name)

        self._accuracy = list(
            sorted(conf, key=lambda x: conf[x].get("accuracy", float("-inf")))
        )
        self._cost = list(sorted(conf, key=lambda x: conf[x].get("cost", float("inf"))))

    def __getitem__(self, key: str) -> BaseComponent:
        return self._models[key]

    def __setitem__(self, key: str, value: BaseComponent):
        self._models[key] = value

    def settings(self) -> dict:
        """Present the model pool options for gradio"""
        return {
            "label": self._category,
            "choices": list(self._models.keys()),
            "value": self.get_default_name(),
        }

    def options(self) -> dict:
        """Present a list of models"""
        return self._models

    def get_random_name(self) -> str:
        """Get the name of a random model

        Returns:
            str: random model name in the pool
        """
        import random

        if not self._conf:
            raise ValueError("No models in pool")

        return random.choice(list(self._conf.keys()))

    def get_default_name(self) -> str:
        """Get the name of the default model

        In case there is no default model, choose a random model from the pool.
        In case there are multiple default models, choose randomly among them.

        Returns:
            str: model name
        """
        if not self._conf:
            raise ValueError("No models in pool")

        if self._default:
            import random

            return random.choice(self._default)

        return self.get_random_name()

    def get_random(self) -> BaseComponent:
        """Get a random model"""
        return self._models[self.get_random_name()]

    def get_default(self) -> BaseComponent:
        """Get the default model

        In case there is no default model, choose a random model from the pool.
        In case there are multiple default models, choose randomly among them.

        Returns:
            BaseComponent: model
        """
        return self._models[self.get_default_name()]

    def get_highest_accuracy_name(self) -> str:
        """Get the name of the model with the highest accuracy

        Returns:
            str: model name
        """
        if not self._conf:
            raise ValueError("No models in pool")
        return self._accuracy[-1]

    def get_highest_accuracy(self) -> BaseComponent:
        """Get the model with the highest accuracy

        Returns:
            BaseComponent: model
        """
        if not self._conf:
            raise ValueError("No models in pool")

        return self._models[self._accuracy[-1]]

    def get_lowest_cost_name(self) -> str:
        """Get the name of the model with the lowest cost

        Returns:
            str: model name
        """
        if not self._conf:
            raise ValueError("No models in pool")
        return self._cost[0]

    def get_lowest_cost(self) -> BaseComponent:
        """Get the model with the lowest cost

        Returns:
            BaseComponent: model
        """
        if not self._conf:
            raise ValueError("No models in pool")

        return self._models[self._cost[0]]


llms = ModelPool("LLMs", settings.KH_LLMS)
embeddings = ModelPool("Embeddings", settings.KH_EMBEDDINGS)
reasonings: dict = {}
tools = ModelPool("Tools", {})
indices = ModelPool("Indices", {})
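A quick usage sketch of the pool accessors defined above, assuming the `KH_LLMS` configuration from `flowsettings.py`:

```python
from ktem.components import llms

default_llm = llms.get_default()    # honors the "default": True entries
by_name = llms["gpt35"]             # direct lookup by pool key
best = llms.get_highest_accuracy()  # last entry when sorted by "accuracy"
cheapest = llms.get_lowest_cost()   # first entry when sorted by "cost"

# dropdown spec for the gradio settings page, e.g.
# {"label": "LLMs", "choices": ["gpt4", "gpt35"], "value": "gpt35"}
print(llms.settings())
```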
libs/ktem/ktem/db/__init__.py (new file, empty)
libs/ktem/ktem/db/engine.py (new file, 4 lines)
@@ -0,0 +1,4 @@
from sqlmodel import create_engine
from theflow.settings import settings

engine = create_engine(settings.KH_DATABASE)
libs/ktem/ktem/db/models.py (new file, 97 lines)
@@ -0,0 +1,97 @@
import datetime
import uuid
from enum import Enum
from typing import Optional

from ktem.db.engine import engine
from sqlalchemy import JSON, Column
from sqlmodel import Field, SQLModel


class Source(SQLModel, table=True):
    """The source of the document

    Attributes:
        id: id of the source
        name: name of the source
        path: path to the source
    """

    __table_args__ = {"extend_existing": True}

    id: str = Field(
        default_factory=lambda: uuid.uuid4().hex, primary_key=True, index=True
    )
    name: str
    path: str


class SourceTargetRelation(str, Enum):
    DOCUMENT = "document"
    VECTOR = "vector"


class Index(SQLModel, table=True):
    """The index pointing from the original id to the target id"""

    __table_args__ = {"extend_existing": True}

    id: Optional[int] = Field(default=None, primary_key=True, index=True)
    source_id: str
    target_id: str
    relation_type: Optional[SourceTargetRelation] = Field(default=None)


class Conversation(SQLModel, table=True):
    """Conversation record"""

    __table_args__ = {"extend_existing": True}

    id: str = Field(
        default_factory=lambda: uuid.uuid4().hex, primary_key=True, index=True
    )
    name: str = Field(
        default_factory=lambda: datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S")
    )
    user: int = Field(default=0)  # For now we only have one user

    # contains messages + current files
    data_source: dict = Field(default={}, sa_column=Column(JSON))

    date_created: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
    date_updated: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)


class User(SQLModel, table=True):
    __table_args__ = {"extend_existing": True}

    id: Optional[int] = Field(default=None, primary_key=True)
    username: str = Field(unique=True)
    password: str


class Settings(SQLModel, table=True):
    """Record of settings"""

    __table_args__ = {"extend_existing": True}

    id: str = Field(
        default_factory=lambda: uuid.uuid4().hex, primary_key=True, index=True
    )
    user: int = Field(default=0)
    setting: dict = Field(default={}, sa_column=Column(JSON))


class IssueReport(SQLModel, table=True):
    """Record of issues"""

    __table_args__ = {"extend_existing": True}

    id: Optional[int] = Field(default=None, primary_key=True)
    issues: dict = Field(default={}, sa_column=Column(JSON))
    chat: Optional[dict] = Field(default=None, sa_column=Column(JSON))
    settings: Optional[dict] = Field(default=None, sa_column=Column(JSON))
    user: Optional[int] = Field(default=None)


SQLModel.metadata.create_all(engine)
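Since `SQLModel.metadata.create_all(engine)` runs at import time, the tables are ready as soon as the module is loaded. A short sketch of querying these models, following the same `Session`/`select` pattern the UI code below uses:

```python
from ktem.db.models import Conversation, engine
from sqlmodel import Session, select

with Session(engine) as session:
    # list conversations of the default user, newest first
    statement = (
        select(Conversation)
        .where(Conversation.user == 0)
        .order_by(Conversation.date_created.desc())
    )
    for conv in session.exec(statement):
        print(conv.id, conv.name, conv.data_source.get("files", []))
```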
libs/ktem/ktem/exceptions.py (new file, 10 lines)
@@ -0,0 +1,10 @@
class KHException(Exception):
    pass


class HookNotDeclared(KHException):
    pass


class HookAlreadyDeclared(KHException):
    pass
libs/ktem/ktem/extension_protocol.py (new file, 39 lines)
@@ -0,0 +1,39 @@
import pluggy

hookspec = pluggy.HookspecMarker("ktem")
hookimpl = pluggy.HookimplMarker("ktem")


@hookspec
def ktem_declare_extensions() -> dict:  # type: ignore
    """Called before the run() function is executed.

    This hook is called without any arguments, and should return a dictionary.
    The dictionary has the following structure:

    ```
    {
        "id": str,      # cannot contain . or /
        "name": str,    # human-friendly name of the plugin
        "version": str,
        "support_host": str,
        "functionality": {
            "reasoning": {
                id: {   # cannot contain . or /
                    "name": str,
                    "callbacks": {},
                    "settings": {},
                },
            },
            "index": {
                "name": str,
                "callbacks": {
                    "get_index_pipeline": callable,
                    "get_retrievers": {name: callable}
                },
                "settings": {},
            },
        },
    }
    ```
    """
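For illustration, a minimal sketch of a third-party extension implementing this hookspec (the package and ids are hypothetical); `BaseApp.register_extensions()` above discovers such plugins through a `ktem` setuptools entry point and registers each reasoning option under `<extension id>/<reasoning id>`:

```python
# acme_ktem/extension.py -- hypothetical plugin package exposing a "ktem" entry point
from ktem.extension_protocol import hookimpl


@hookimpl
def ktem_declare_extensions() -> dict:
    return {
        "id": "acme",  # cannot contain . or /
        "name": "ACME reasoning pack",
        "version": "0.1.0",
        "support_host": ">=0.1",
        "functionality": {
            "reasoning": {
                "cot": {  # would be registered as "acme/cot"
                    "name": "Chain of thought",
                    "callbacks": {},
                    "settings": {},
                },
            },
        },
    }
```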
libs/ktem/ktem/indexing/__init__.py (new file, empty)
libs/ktem/ktem/indexing/base.py (new file, 16 lines)
@@ -0,0 +1,16 @@
from kotaemon.base import BaseComponent


class BaseIndex(BaseComponent):
    def get_user_settings(self) -> dict:
        """Get the user settings for indexing

        Returns:
            dict: user settings in the dictionary format of
                `ktem.settings.SettingItem`
        """
        return {}

    @classmethod
    def get_pipeline(cls, setting: dict) -> "BaseIndex":
        raise NotImplementedError
libs/ktem/ktem/indexing/exceptions.py (new file, 5 lines)
@@ -0,0 +1,5 @@
from ktem.exceptions import KHException


class FileExistsError(KHException):
    pass
libs/ktem/ktem/indexing/file.py (new file, 182 lines)
@@ -0,0 +1,182 @@
from __future__ import annotations

import shutil
from hashlib import sha256
from pathlib import Path

from ktem.components import embeddings, filestorage_path, get_docstore, get_vectorstore
from ktem.db.models import Index, Source, SourceTargetRelation, engine
from ktem.indexing.base import BaseIndex
from ktem.indexing.exceptions import FileExistsError
from kotaemon.indices import VectorIndexing
from kotaemon.indices.ingests import DocumentIngestor
from sqlmodel import Session, select

USER_SETTINGS = {
    "index_parser": {
        "name": "Index parser",
        "value": "normal",
        "choices": [
            ("PDF text parser", "normal"),
            ("Mathpix", "mathpix"),
            ("Advanced OCR", "ocr"),
        ],
        "component": "dropdown",
    },
    "separate_embedding": {
        "name": "Use separate embedding",
        "value": False,
        "choices": [("Yes", True), ("No", False)],
        "component": "dropdown",
    },
    "num_retrieval": {
        "name": "Number of documents to retrieve",
        "value": 3,
        "component": "number",
    },
    "retrieval_mode": {
        "name": "Retrieval mode",
        "value": "vector",
        "choices": ["vector", "text", "hybrid"],
        "component": "dropdown",
    },
    "prioritize_table": {
        "name": "Prioritize table",
        "value": True,
        "choices": [True, False],
        "component": "checkbox",
    },
    "mmr": {
        "name": "Use MMR",
        "value": True,
        "choices": [True, False],
        "component": "checkbox",
    },
    "use_reranking": {
        "name": "Use reranking",
        "value": True,
        "choices": [True, False],
        "component": "checkbox",
    },
}


class IndexDocumentPipeline(BaseIndex):
    """Store the documents and index the content into vector store and doc store

    Args:
        indexing_vector_pipeline: pipeline to index the documents
        file_ingestor: ingestor to ingest the documents
    """

    indexing_vector_pipeline: VectorIndexing = VectorIndexing.withx(
        doc_store=get_docstore(),
        vector_store=get_vectorstore(),
        embedding=embeddings.get_default(),
    )
    file_ingestor: DocumentIngestor = DocumentIngestor.withx()

    def run(
        self,
        file_paths: str | Path | list[str | Path],
        reindex: bool = False,
        **kwargs,  # type: ignore
    ):
        """Index the list of documents

        This function will extract the files, persist the files to storage,
        and index the files.

        Args:
            file_paths: list of file paths to index
            reindex: whether to force reindexing the files if they exist

        Returns:
            list of split nodes
        """
        if not isinstance(file_paths, list):
            file_paths = [file_paths]

        to_index: list[str] = []
        file_to_hash: dict[str, str] = {}
        errors = []

        for file_path in file_paths:
            abs_path = str(Path(file_path).resolve())
            with open(abs_path, "rb") as fi:
                file_hash = sha256(fi.read()).hexdigest()

            file_to_hash[abs_path] = file_hash

            with Session(engine) as session:
                statement = select(Source).where(Source.name == Path(abs_path).name)
                item = session.exec(statement).first()

                if item and not reindex:
                    errors.append(Path(abs_path).name)
                    continue

            to_index.append(abs_path)

        if errors:
            raise FileExistsError(
                "Files already exist. Please rename/remove them or enable reindex.\n"
                f"{errors}"
            )

        # persist the files to storage
        for path in to_index:
            shutil.copy(path, filestorage_path / file_to_hash[path])

        # prepare record info
        file_to_source: dict[str, Source] = {}
        for file_path, file_hash in file_to_hash.items():
            source = Source(path=file_hash, name=Path(file_path).name)
            file_to_source[file_path] = source

        # extract the files
        nodes = self.file_ingestor(to_index)
        for node in nodes:
            file_path = str(node.metadata["file_path"])
            node.source = file_to_source[file_path].id

        # index the files
        self.indexing_vector_pipeline(nodes)

        # persist to the index
        file_ids = []
        with Session(engine) as session:
            for source in file_to_source.values():
                session.add(source)
            session.commit()
            for source in file_to_source.values():
                file_ids.append(source.id)

        with Session(engine) as session:
            for node in nodes:
                index = Index(
                    source_id=node.source,
                    target_id=node.doc_id,
                    relation_type=SourceTargetRelation.DOCUMENT,
                )
                session.add(index)
            for node in nodes:
                index = Index(
                    source_id=node.source,
                    target_id=node.doc_id,
                    relation_type=SourceTargetRelation.VECTOR,
                )
                session.add(index)
            session.commit()

        return nodes, file_ids

    def get_user_settings(self) -> dict:
        return USER_SETTINGS

    @classmethod
    def get_pipeline(cls, setting) -> "IndexDocumentPipeline":
        """Get the pipeline based on the setting"""
        obj = cls()
        obj.file_ingestor.pdf_mode = setting["index.index_parser"]
        return obj
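A usage sketch of the pipeline above, assuming kotaemon's `BaseComponent` makes instances callable through `run()` (as the tests in `khapptests/test_qa.py` do); the file path is a placeholder and the settings key mirrors the flattened `index.index_parser` form consumed by `get_pipeline`:

```python
from ktem.indexing.file import IndexDocumentPipeline

pipeline = IndexDocumentPipeline.get_pipeline({"index.index_parser": "normal"})

# extract, persist, and index a document; raises ktem's FileExistsError if a
# file of the same name was already indexed and reindex is False
nodes, file_ids = pipeline("reports/annual.pdf", reindex=False)  # hypothetical path
print(f"indexed {len(nodes)} chunks from {len(file_ids)} file(s)")
```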
libs/ktem/ktem/main.py (new file, 31 lines)
@@ -0,0 +1,31 @@
import gradio as gr
from ktem.app import BaseApp
from ktem.pages.chat import ChatPage
from ktem.pages.help import HelpPage
from ktem.pages.settings import SettingsPage


class App(BaseApp):
    """The main app of Kotaemon

    The main application contains app-level information:
        - setting state
        - user id

    App life-cycle:
        - Render
        - Declare public events
        - Subscribe public events
        - Register events
    """

    def ui(self):
        """Render the UI"""
        with gr.Tab("Chat", elem_id="chat-tab"):
            self.chat_page = ChatPage(self)

        with gr.Tab("Settings", elem_id="settings-tab"):
            self.settings_page = SettingsPage(self)

        with gr.Tab("Help", elem_id="help-tab"):
            self.help_page = HelpPage(self)
libs/ktem/ktem/pages/__init__.py (new file, empty)
libs/ktem/ktem/pages/chat/__init__.py (new file, 125 lines)
@@ -0,0 +1,125 @@
import gradio as gr
from ktem.app import BasePage

from .chat_panel import ChatPanel
from .control import ConversationControl
from .data_source import DataSource
from .events import chat_fn, index_fn, is_liked, load_files, update_data_source
from .report import ReportIssue
from .upload import FileUpload


class ChatPage(BasePage):
    def __init__(self, app):
        self._app = app
        self.on_building_ui()

    def on_building_ui(self):
        with gr.Row():
            with gr.Column(scale=1):
                self.chat_control = ConversationControl(self._app)
                self.data_source = DataSource(self._app)
                self.file_upload = FileUpload(self._app)
                self.report_issue = ReportIssue(self._app)
            with gr.Column(scale=6):
                self.chat_panel = ChatPanel(self._app)

    def on_register_events(self):
        self.chat_panel.submit_btn.click(
            fn=chat_fn,
            inputs=[
                self.chat_panel.text_input,
                self.chat_panel.chatbot,
                self.data_source.files,
                self._app.settings_state,
            ],
            outputs=[self.chat_panel.text_input, self.chat_panel.chatbot],
        ).then(
            fn=update_data_source,
            inputs=[
                self.chat_control.conversation_id,
                self.data_source.files,
                self.chat_panel.chatbot,
            ],
            outputs=None,
        )

        self.chat_panel.text_input.submit(
            fn=chat_fn,
            inputs=[
                self.chat_panel.text_input,
                self.chat_panel.chatbot,
                self.data_source.files,
                self._app.settings_state,
            ],
            outputs=[self.chat_panel.text_input, self.chat_panel.chatbot],
        ).then(
            fn=update_data_source,
            inputs=[
                self.chat_control.conversation_id,
                self.data_source.files,
                self.chat_panel.chatbot,
            ],
            outputs=None,
        )

        self.chat_panel.chatbot.like(
            fn=is_liked,
            inputs=[self.chat_control.conversation_id],
            outputs=None,
        )

        self.chat_control.conversation.change(
            self.chat_control.select_conv,
            inputs=[self.chat_control.conversation],
            outputs=[
                self.chat_control.conversation_id,
                self.chat_control.conversation,
                self.chat_control.conversation_rn,
                self.data_source.files,
                self.chat_panel.chatbot,
            ],
            show_progress="hidden",
        )

        self.report_issue.report_btn.click(
            self.report_issue.report,
            inputs=[
                self.report_issue.correctness,
                self.report_issue.issues,
                self.report_issue.more_detail,
                self.chat_control.conversation_id,
                self.chat_panel.chatbot,
                self.data_source.files,
                self._app.settings_state,
                self._app.user_id,
            ],
            outputs=None,
        )

        self.data_source.files.input(
            fn=update_data_source,
            inputs=[
                self.chat_control.conversation_id,
                self.data_source.files,
                self.chat_panel.chatbot,
            ],
            outputs=None,
        )

        self.file_upload.upload_button.click(
            fn=index_fn,
            inputs=[
                self.file_upload.files,
                self.file_upload.reindex,
                self.data_source.files,
                self._app.settings_state,
            ],
            outputs=[self.file_upload.file_output, self.data_source.files],
        )

        self._app.app.load(
            lambda: gr.update(choices=load_files()),
            inputs=None,
            outputs=[self.data_source.files],
        )
libs/ktem/ktem/pages/chat/chat_panel.py (new file, 21 lines)
@@ -0,0 +1,21 @@
import gradio as gr
from ktem.app import BasePage


class ChatPanel(BasePage):
    def __init__(self, app):
        self._app = app
        self.on_building_ui()

    def on_building_ui(self):
        self.chatbot = gr.Chatbot(
            elem_id="main-chat-bot",
            show_copy_button=True,
            likeable=True,
            show_label=False,
        )
        with gr.Row():
            self.text_input = gr.Text(
                placeholder="Chat input", scale=15, container=False
            )
            self.submit_btn = gr.Button(value="Send", scale=1, min_width=10)
libs/ktem/ktem/pages/chat/control.py (new file, 193 lines)
@@ -0,0 +1,193 @@
import logging

import gradio as gr
from ktem.app import BasePage
from ktem.db.models import Conversation, engine
from sqlmodel import Session, select

logger = logging.getLogger(__name__)


class ConversationControl(BasePage):
    """Manage conversation"""

    def __init__(self, app):
        self._app = app
        self.on_building_ui()

    def on_building_ui(self):
        with gr.Accordion(label="Conversation control", open=True):
            self.conversation_id = gr.State(value="")
            self.conversation = gr.Dropdown(
                label="Chat sessions",
                choices=[],
                container=False,
                filterable=False,
                interactive=True,
            )

            with gr.Row():
                self.conversation_new_btn = gr.Button(value="New", min_width=10)
                self.conversation_del_btn = gr.Button(value="Delete", min_width=10)

            with gr.Row():
                self.conversation_rn = gr.Text(
                    placeholder="Conversation name",
                    container=False,
                    scale=5,
                    min_width=10,
                    interactive=True,
                )
                self.conversation_rn_btn = gr.Button(
                    value="Rename", scale=1, min_width=10
                )

            # current_state = gr.Text()
            # show_current_state = gr.Button(value="Current")
            # show_current_state.click(
            #     lambda a, b: "\n".join([a, b]),
            #     inputs=[cid, self.conversation],
            #     outputs=[current_state],
            # )

    def on_subscribe_public_events(self):
        self._app.subscribe_event(
            name="onSignIn",
            definition={
                "fn": self.reload_conv,
                "inputs": [self._app.user_id],
                "outputs": [self.conversation],
                "show_progress": "hidden",
            },
        )

        self._app.subscribe_event(
            name="onSignOut",
            definition={
                "fn": self.reload_conv,
                "inputs": [self._app.user_id],
                "outputs": [self.conversation],
                "show_progress": "hidden",
            },
        )

        self._app.subscribe_event(
            name="onCreateUser",
            definition={
                "fn": self.reload_conv,
                "inputs": [self._app.user_id],
                "outputs": [self.conversation],
                "show_progress": "hidden",
            },
        )

    def on_register_events(self):
        self.conversation_new_btn.click(
            self.new_conv,
            inputs=self._app.user_id,
            outputs=[self.conversation_id, self.conversation],
            show_progress="hidden",
        )
        self.conversation_del_btn.click(
            self.delete_conv,
            inputs=[self.conversation_id, self._app.user_id],
            outputs=[self.conversation_id, self.conversation],
            show_progress="hidden",
        )
        self.conversation_rn_btn.click(
            self.rename_conv,
            inputs=[self.conversation_id, self.conversation_rn, self._app.user_id],
            outputs=[self.conversation, self.conversation],
            show_progress="hidden",
        )

    def load_chat_history(self, user_id):
        """Reload chat history"""
        options = []
        with Session(engine) as session:
            statement = (
                select(Conversation)
                .where(Conversation.user == user_id)
                .order_by(Conversation.date_created.desc())  # type: ignore
            )
            results = session.exec(statement).all()
            for result in results:
                options.append((result.name, result.id))

        # return gr.update(choices=options)
        return options

    def reload_conv(self, user_id):
        conv_list = self.load_chat_history(user_id)
        if conv_list:
            return gr.update(value=conv_list[0][1], choices=conv_list)
        else:
            return gr.update(value=None, choices=[])

    def new_conv(self, user_id):
        """Create a new chat"""
        if user_id is None:
            gr.Warning("Please sign in first (Settings → User Settings)")
            return None, gr.update()
        with Session(engine) as session:
            new_conv = Conversation(user=user_id)
            session.add(new_conv)
            session.commit()

            id_ = new_conv.id

        history = self.load_chat_history(user_id)

        return id_, gr.update(value=id_, choices=history)

    def delete_conv(self, conversation_id, user_id):
        """Delete the selected conversation"""
        if user_id is None:
            gr.Warning("Please sign in first (Settings → User Settings)")
            return None, gr.update()
        with Session(engine) as session:
            statement = select(Conversation).where(Conversation.id == conversation_id)
            result = session.exec(statement).one()

            session.delete(result)
            session.commit()

        history = self.load_chat_history(user_id)
        if history:
            id_ = history[0][1]
            return id_, gr.update(value=id_, choices=history)
        else:
            return None, gr.update(value=None, choices=[])

    def select_conv(self, conversation_id):
        """Select the conversation"""
        with Session(engine) as session:
            statement = select(Conversation).where(Conversation.id == conversation_id)
            try:
                result = session.exec(statement).one()
                id_ = result.id
                name = result.name
                files = result.data_source.get("files", [])
                chats = result.data_source.get("messages", [])
            except Exception as e:
                logger.warning(e)
                id_ = ""
                name = ""
                files = []
                chats = []
        return id_, id_, name, files, chats

    def rename_conv(self, conversation_id, new_name, user_id):
        """Rename the conversation"""
        if user_id is None:
            gr.Warning("Please sign in first (Settings → User Settings)")
            return gr.update(), ""
        with Session(engine) as session:
            statement = select(Conversation).where(Conversation.id == conversation_id)
            result = session.exec(statement).one()
            result.name = new_name
            session.add(result)
            session.commit()

        history = self.load_chat_history(user_id)
        return gr.update(choices=history), conversation_id
libs/ktem/ktem/pages/chat/data_source.py (new file, 18 lines)
@@ -0,0 +1,18 @@
import gradio as gr
from ktem.app import BasePage


class DataSource(BasePage):
    def __init__(self, app):
        self._app = app
        self.on_building_ui()

    def on_building_ui(self):
        with gr.Accordion(label="Data source", open=True):
            self.files = gr.Dropdown(
                label="Files",
                choices=[],
                multiselect=True,
                container=False,
                interactive=True,
            )
libs/ktem/ktem/pages/chat/events.py (new file, 220 lines)
@@ -0,0 +1,220 @@
import os
import tempfile
from copy import deepcopy
from typing import Optional

import gradio as gr
from ktem.components import llms, reasonings
from ktem.db.models import Conversation, Source, engine
from ktem.indexing.base import BaseIndex
from ktem.reasoning.simple import DocumentRetrievalPipeline
from sqlmodel import Session, select
from theflow.settings import settings as app_settings
from theflow.utils.modules import import_dotted_string


def create_pipeline(settings: dict, files: Optional[list] = None):
    """Create the pipeline from settings

    Args:
        settings: the settings of the app
        files: the list of file ids that will be served as context. If None, then
            consider using all files

    Returns:
        the pipeline objects
    """

    reasoning_mode = settings["reasoning.use"]
    reasoning_cls = reasonings[reasoning_mode]
    pipeline = reasoning_cls.get_pipeline(settings, files=files)

    if settings["reasoning.use"] in ["rewoo", "react"]:
        from kotaemon.agents import ReactAgent, RewooAgent

        llm = (
            llms["gpt4"]
            if settings["answer_simple_llm_model"] == "gpt-4"
            else llms["gpt35"]
        )
        tools = []
        tools_keys = (
            "answer_rewoo_tools"
            if settings["reasoning.use"] == "rewoo"
            else "answer_react_tools"
        )
        for tool in settings[tools_keys]:
            if tool == "llm":
                from kotaemon.agents import LLMTool

                tools.append(LLMTool(llm=llm))
            elif tool == "docsearch":
                from kotaemon.agents import ComponentTool

                filenames = ""
                if files:
                    with Session(engine) as session:
                        statement = select(Source).where(
                            Source.id.in_(files)  # type: ignore
                        )
                        results = session.exec(statement).all()
                    filenames = (
                        "The file names are: "
                        + " ".join([result.name for result in results])
                        + ". "
                    )

                retrieval_pipeline = DocumentRetrievalPipeline()
                retrieval_pipeline.set_run(
                    {
                        ".top_k": int(settings["retrieval_number"]),
                        ".mmr": settings["retrieval_mmr"],
                        ".doc_ids": files,
                    },
                    temp=True,
                )
                tool = ComponentTool(
                    name="docsearch",
                    description=(
                        "A vector store that searches for similar and "
                        "related content "
                        f"in a document. {filenames}"
                        "The result is a huge chunk of text related "
                        "to your search but can also "
                        "contain irrelevant info."
                    ),
                    component=retrieval_pipeline,
                    postprocessor=lambda docs: "\n\n".join(
                        [doc.text.replace("\n", " ") for doc in docs]
                    ),
                )
                tools.append(tool)
            elif tool == "google":
                from kotaemon.agents import GoogleSearchTool

                tools.append(GoogleSearchTool())
            elif tool == "wikipedia":
                from kotaemon.agents import WikipediaTool

                tools.append(WikipediaTool())
            else:
                raise NotImplementedError(f"Unknown tool: {tool}")

        if settings["reasoning.use"] == "rewoo":
            pipeline = RewooAgent(
                planner_llm=llm,
                solver_llm=llm,
                plugins=tools,
            )
            pipeline.set_run({".use_citation": True})
        else:
            pipeline = ReactAgent(
                llm=llm,
                plugins=tools,
            )

    return pipeline


def chat_fn(chat_input, chat_history, files, settings):
    pipeline = create_pipeline(settings, files)

    text = ""
    refs = []
    for response in pipeline(chat_input):
        if response.metadata.get("citation", None):
            citation = response.metadata["citation"]
            for idx, fact_with_evidence in enumerate(citation.answer):
                quotes = fact_with_evidence.substring_quote
                if quotes:
                    refs.append(
                        (None, f"***Reference {idx+1}***: {' ... '.join(quotes)}")
                    )
        else:
            text += response.text

        yield "", chat_history + [(chat_input, text)] + refs


def is_liked(convo_id, liked: gr.LikeData):
    with Session(engine) as session:
        statement = select(Conversation).where(Conversation.id == convo_id)
        result = session.exec(statement).one()

        data_source = deepcopy(result.data_source)
        likes = data_source.get("likes", [])
        likes.append([liked.index, liked.value, liked.liked])
        data_source["likes"] = likes

        result.data_source = data_source
        session.add(result)
        session.commit()


def update_data_source(convo_id, selected_files, messages):
    """Update the data source"""
    if not convo_id:
        gr.Warning("No conversation selected")
        return

    with Session(engine) as session:
        statement = select(Conversation).where(Conversation.id == convo_id)
        result = session.exec(statement).one()

        data_source = result.data_source
        result.data_source = {
            "files": selected_files,
            "messages": messages,
            "likes": deepcopy(data_source.get("likes", [])),
        }
        session.add(result)
        session.commit()


def load_files():
    options = []
    with Session(engine) as session:
        statement = select(Source)
        results = session.exec(statement).all()
        for result in results:
            options.append((result.name, result.id))

    return options


def index_fn(files, reindex: bool, selected_files, settings):
    """Upload and index the files

    Args:
        files: the list of files to be uploaded
        reindex: whether to reindex the files
        selected_files: the list of files already selected
        settings: the settings of the app
    """
    gr.Info(f"Start indexing {len(files)} files...")

    # get the pipeline
    indexing_cls: BaseIndex = import_dotted_string(app_settings.KH_INDEX, safe=False)
    indexing_pipeline = indexing_cls.get_pipeline(settings)

    output_nodes, file_ids = indexing_pipeline(files, reindex=reindex)
    gr.Info(f"Finish indexing into {len(output_nodes)} chunks")

    # download the file
    text = "\n\n".join([each.text for each in output_nodes])
    handler, file_path = tempfile.mkstemp(suffix=".txt")
    with open(file_path, "w") as f:
        f.write(text)
    os.close(handler)

    if isinstance(selected_files, list):
        output = selected_files + file_ids
    else:
        output = file_ids

    file_list = load_files()

    return (
        gr.update(value=file_path, visible=True),
        gr.update(value=output, choices=file_list),
    )
70
libs/ktem/ktem/pages/chat/report.py
Normal file
70
libs/ktem/ktem/pages/chat/report.py
Normal file
@@ -0,0 +1,70 @@
from typing import Optional

import gradio as gr
from ktem.app import BasePage
from ktem.db.models import IssueReport, engine
from sqlmodel import Session


class ReportIssue(BasePage):
    def __init__(self, app):
        self._app = app
        self.on_building_ui()

    def on_building_ui(self):
        with gr.Accordion(label="Report", open=False):
            self.correctness = gr.Radio(
                choices=[
                    ("The answer is correct", "correct"),
                    ("The answer is incorrect", "incorrect"),
                ],
                label="Correctness:",
            )
            self.issues = gr.CheckboxGroup(
                choices=[
                    ("The answer is offensive", "offensive"),
                    ("The evidence is incorrect", "wrong-evidence"),
                ],
                label="Other issue:",
            )
            self.more_detail = gr.Textbox(
                placeholder="More detail (e.g. how wrong is it, what is the "
                "correct answer, etc...)",
                container=False,
                lines=3,
            )
            gr.Markdown(
                "This will send the current chat and the user settings to "
                "help with investigation"
            )
            self.report_btn = gr.Button("Report")

    def report(
        self,
        correctness: str,
        issues: list[str],
        more_detail: str,
        conv_id: str,
        chat_history: list,
        files: list,
        settings: dict,
        user_id: Optional[int],
    ):
        with Session(engine) as session:
            issue = IssueReport(
                issues={
                    "correctness": correctness,
                    "issues": issues,
                    "more_detail": more_detail,
                },
                chat={
                    "conv_id": conv_id,
                    "chat_history": chat_history,
                    "files": files,
                },
                settings=settings,
                user=user_id,
            )
            session.add(issue)
            session.commit()
        gr.Info("Thank you for your feedback")
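A hedged sketch of how `ReportIssue` might be hooked up from the chat page. Every name other than the components defined above is an assumed placeholder, not something taken from this commit:

    report_issue = ReportIssue(app)
    report_issue.report_btn.click(
        report_issue.report,
        inputs=[
            report_issue.correctness,
            report_issue.issues,
            report_issue.more_detail,
            conversation_id,  # assumed gr.State holding the conversation id
            chatbot,          # assumed gr.Chatbot holding the chat history
            file_selector,    # assumed component listing the selected files
            settings_state,   # assumed gr.State with the flattened settings
            user_id,          # assumed gr.State with the signed-in user id
        ],
        outputs=None,
    )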
43
libs/ktem/ktem/pages/chat/upload.py
Normal file
@@ -0,0 +1,43 @@
import gradio as gr
from ktem.app import BasePage


class FileUpload(BasePage):
    def __init__(self, app):
        self._app = app
        self.on_building_ui()

    def on_building_ui(self):
        with gr.Accordion(label="File upload", open=False):
            gr.Markdown(
                "Supported file types: image, pdf, txt, csv, xlsx, docx.",
            )
            self.files = gr.File(
                file_types=["image", ".pdf", ".txt", ".csv", ".xlsx", ".docx"],
                file_count="multiple",
                container=False,
                height=50,
            )
            with gr.Accordion("Advanced indexing options", open=False):
                with gr.Row():
                    with gr.Column():
                        self.reindex = gr.Checkbox(
                            value=False, label="Force reindex file", container=False
                        )
                    with gr.Column():
                        self.parser = gr.Dropdown(
                            choices=[
                                ("PDF text parser", "normal"),
                                ("lib-table", "table"),
                                ("lib-table + OCR", "ocr"),
                                ("MathPix", "mathpix"),
                            ],
                            value="normal",
                            label="Use advanced PDF parser (table+layout preserving)",
                            container=True,
                        )

            self.upload_button = gr.Button("Upload and Index")
            self.file_output = gr.File(
                visible=False, label="Output files (debug purpose)"
            )
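A minimal sketch (assumed wiring, not part of this commit) connecting `FileUpload` to the `index_fn` helper shown earlier:

    import gradio as gr

    with gr.Blocks() as demo:
        settings_state = gr.State({})  # assumed: flattened app settings
        upload = FileUpload(app=None)  # the page only keeps a back-reference
        selected = gr.Dropdown(multiselect=True, label="Indexed files")
        upload.upload_button.click(
            index_fn,
            inputs=[upload.files, upload.reindex, selected, settings_state],
            outputs=[upload.file_output, selected],
        )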
24
libs/ktem/ktem/pages/help.py
Normal file
@@ -0,0 +1,24 @@
from pathlib import Path

import gradio as gr


class HelpPage:
    def __init__(self, app):
        self._app = app
        self.dir_md = Path(__file__).parent.parent / "assets" / "md"

        with gr.Accordion("Changelogs"):
            gr.Markdown(self.get_changelogs())

        with gr.Accordion("About Kotaemon (temporary)"):
            with (self.dir_md / "about_kotaemon.md").open() as fi:
                gr.Markdown(fi.read())

        with gr.Accordion("About Cinnamon AI (temporary)", open=False):
            with (self.dir_md / "about_cinnamon.md").open() as fi:
                gr.Markdown(fi.read())

    def get_changelogs(self):
        with (self.dir_md / "changelogs.md").open() as fi:
            return fi.read()
414
libs/ktem/ktem/pages/settings.py
Normal file
@@ -0,0 +1,414 @@
import hashlib

import gradio as gr
from ktem.app import BasePage
from ktem.db.models import Settings, User, engine
from sqlmodel import Session, select

gr_cls_single_value = {
    "text": gr.Textbox,
    "number": gr.Number,
    "checkbox": gr.Checkbox,
}


gr_cls_choices = {
    "dropdown": gr.Dropdown,
    "radio": gr.Radio,
    "checkboxgroup": gr.CheckboxGroup,
}


def render_setting_item(setting_item, value):
    """Render the setting item into the corresponding Gradio UI component"""
    kwargs = {
        "label": setting_item.name,
        "value": value,
        "interactive": True,
    }

    if setting_item.component in gr_cls_single_value:
        return gr_cls_single_value[setting_item.component](**kwargs)

    kwargs["choices"] = setting_item.choices

    if setting_item.component in gr_cls_choices:
        return gr_cls_choices[setting_item.component](**kwargs)

    raise ValueError(
        f"Unknown component {setting_item.component}, allowed are: "
        f"{list(gr_cls_single_value.keys()) + list(gr_cls_choices.keys())}.\n"
        f"Setting item: {setting_item}"
    )


class SettingsPage(BasePage):
    """Responsible for allowing the users to customize the application

    **IMPORTANT**: the name and id of the UI setting components should match the
    name of the setting in the `app.default_settings`
    """

    public_events = ["onSignIn", "onSignOut", "onCreateUser"]

    def __init__(self, app):
        """Initiate the page and render the UI"""
        self._app = app

        self._settings_state = app.settings_state
        self._user_id = app.user_id
        self._default_settings = app.default_settings
        self._settings_dict = self._default_settings.flatten()
        self._settings_keys = list(self._settings_dict.keys())

        self._components = {}
        self._reasoning_mode = {}

        self.on_building_ui()

    def on_building_ui(self):
        self.setting_save_btn = gr.Button("Save settings")
        with gr.Tab("User settings"):
            self.user_tab()
        with gr.Tab("General application settings"):
            self.app_tab()
        with gr.Tab("Index settings"):
            self.index_tab()
        with gr.Tab("Reasoning settings"):
            self.reasoning_tab()

    def on_subscribe_public_events(self):
        pass

    def on_register_events(self):
        self.setting_save_btn.click(
            self.save_setting,
            inputs=[self._user_id] + self.components(),
            outputs=self._settings_state,
        )
        self.password_change_btn.click(
            self.change_password,
            inputs=[
                self._user_id,
                self.password_change,
                self.password_change_confirm,
            ],
            outputs=None,
            show_progress="hidden",
        )
        self._components["reasoning.use"].change(
            self.change_reasoning_mode,
            inputs=[self._components["reasoning.use"]],
            outputs=list(self._reasoning_mode.values()),
            show_progress="hidden",
        )

        onSignInClick = self.signin.click(
            self.sign_in,
            inputs=[self.username, self.password],
            outputs=[self._user_id, self.username, self.password]
            + self.signed_in_state()
            + [self.user_out_state],
            show_progress="hidden",
        ).then(
            self.load_setting,
            inputs=self._user_id,
            outputs=[self._settings_state] + self.components(),
            show_progress="hidden",
        )
        for event in self._app.get_event("onSignIn"):
            onSignInClick = onSignInClick.then(**event)

        onSignInSubmit = self.password.submit(
            self.sign_in,
            inputs=[self.username, self.password],
            outputs=[self._user_id, self.username, self.password]
            + self.signed_in_state()
            + [self.user_out_state],
            show_progress="hidden",
        ).then(
            self.load_setting,
            inputs=self._user_id,
            outputs=[self._settings_state] + self.components(),
            show_progress="hidden",
        )
        for event in self._app.get_event("onSignIn"):
            onSignInSubmit = onSignInSubmit.then(**event)

        onCreateUserClick = self.create_btn.click(
            self.create_user,
            inputs=[
                self.username_new,
                self.password_new,
                self.password_new_confirm,
            ],
            outputs=[
                self._user_id,
                self.username_new,
                self.password_new,
                self.password_new_confirm,
            ]
            + self.signed_in_state()
            + [self.user_out_state],
            show_progress="hidden",
        ).then(
            self.load_setting,
            inputs=self._user_id,
            outputs=[self._settings_state] + self.components(),
            show_progress="hidden",
        )
        for event in self._app.get_event("onCreateUser"):
            onCreateUserClick = onCreateUserClick.then(**event)

        onSignOutClick = self.signout.click(
            self.sign_out,
            inputs=None,
            outputs=[self._user_id] + self.signed_in_state() + [self.user_out_state],
            show_progress="hidden",
        ).then(
            self.load_setting,
            inputs=self._user_id,
            outputs=[self._settings_state] + self.components(),
            show_progress="hidden",
        )
        for event in self._app.get_event("onSignOut"):
            onSignOutClick = onSignOutClick.then(**event)

    def user_tab(self):
        with gr.Row() as self.user_out_state:
            with gr.Column():
                gr.Markdown("Sign in")
                self.username = gr.Textbox(label="Username", interactive=True)
                self.password = gr.Textbox(
                    label="Password", type="password", interactive=True
                )
                self.signin = gr.Button("Login")

            with gr.Column():
                gr.Markdown("Create new account")
                self.username_new = gr.Textbox(label="Username", interactive=True)
                self.password_new = gr.Textbox(
                    label="Password", type="password", interactive=True
                )
                self.password_new_confirm = gr.Textbox(
                    label="Confirm password", type="password", interactive=True
                )
                self.create_btn = gr.Button("Create account")

        # user management
        self.current_name = gr.Markdown("Current user: ___", visible=False)
        self.signout = gr.Button("Logout", visible=False)

        self.password_change = gr.Textbox(
            label="New password", interactive=True, type="password", visible=False
        )
        self.password_change_confirm = gr.Textbox(
            label="Confirm password", interactive=True, type="password", visible=False
        )
        self.password_change_btn = gr.Button(
            "Change password", interactive=True, visible=False
        )

    def signed_out_state(self):
        return [
            self.username,
            self.password,
            self.signin,
            self.username_new,
            self.password_new,
            self.password_new_confirm,
            self.create_btn,
        ]

    def signed_in_state(self):
        return [
            self.current_name,  # always the first one
            self.signout,
            self.password_change,
            self.password_change_confirm,
            self.password_change_btn,
        ]

    def sign_in(self, username: str, password: str):
        hashed_password = hashlib.sha256(password.encode()).hexdigest()
        user_id, clear_username, clear_password = None, username, password
        with Session(engine) as session:
            statement = select(User).where(
                User.username == username,
                User.password == hashed_password,
            )
            result = session.exec(statement).all()
            if result:
                user_id = result[0].id
                clear_username, clear_password = "", ""
            else:
                gr.Warning("Username or password is incorrect")

        output: list = [user_id, clear_username, clear_password]
        if user_id is None:
            output += [
                gr.update(visible=False) for _ in range(len(self.signed_in_state()))
            ]
            output.append(gr.update(visible=True))
        else:
            output.append(gr.update(visible=True, value=f"Current user: {username}"))
            output += [
                gr.update(visible=True) for _ in range(len(self.signed_in_state()) - 1)
            ]
            output.append(gr.update(visible=False))

        return output

    def create_user(self, username, password, password_confirm):
        user_id, usn, pwd, pwdc = None, username, password, password_confirm
        if password != password_confirm:
            gr.Warning("Password does not match")
        else:
            with Session(engine) as session:
                statement = select(User).where(
                    User.username == username,
                )
                result = session.exec(statement).all()
                if result:
                    gr.Warning(f'Username "{username}" already exists')
                else:
                    hashed_password = hashlib.sha256(password.encode()).hexdigest()
                    user = User(username=username, password=hashed_password)
                    session.add(user)
                    session.commit()
                    user_id = user.id
                    usn, pwd, pwdc = "", "", ""
        print(user_id)

        output: list = [user_id, usn, pwd, pwdc]
        if user_id is not None:
            output.append(gr.update(visible=True, value=f"Current user: {username}"))
            output += [
                gr.update(visible=True) for _ in range(len(self.signed_in_state()) - 1)
            ]
            output.append(gr.update(visible=False))
        else:
            output += [
                gr.update(visible=False) for _ in range(len(self.signed_in_state()))
            ]
            output.append(gr.update(visible=True))

        return output

    def sign_out(self):
        output = [None]
        output += [gr.update(visible=False) for _ in range(len(self.signed_in_state()))]
        output.append(gr.update(visible=True))
        return output

    def change_password(self, user_id, password, password_confirm):
        if password != password_confirm:
            gr.Warning("Password does not match")
            return

        with Session(engine) as session:
            statement = select(User).where(User.id == user_id)
            result = session.exec(statement).all()
            if result:
                user = result[0]
                hashed_password = hashlib.sha256(password.encode()).hexdigest()
                user.password = hashed_password
                session.add(user)
                session.commit()
                gr.Info("Password changed")
            else:
                gr.Warning("User not found")

    def app_tab(self):
        for n, si in self._default_settings.application.settings.items():
            obj = render_setting_item(si, si.value)
            self._components[f"application.{n}"] = obj

    def index_tab(self):
        for n, si in self._default_settings.index.settings.items():
            obj = render_setting_item(si, si.value)
            self._components[f"index.{n}"] = obj

    def reasoning_tab(self):
        with gr.Group():
            for n, si in self._default_settings.reasoning.settings.items():
                if n == "use":
                    continue
                obj = render_setting_item(si, si.value)
                self._components[f"reasoning.{n}"] = obj

        gr.Markdown("### Reasoning-specific settings")
        self._components["reasoning.use"] = render_setting_item(
            self._default_settings.reasoning.settings["use"],
            self._default_settings.reasoning.settings["use"].value,
        )

        for idx, (pn, sig) in enumerate(
            self._default_settings.reasoning.options.items()
        ):
            with gr.Group(
                visible=idx == 0,
                elem_id=pn,
            ) as self._reasoning_mode[pn]:
                gr.Markdown("**Name**: Description")
                for n, si in sig.settings.items():
                    obj = render_setting_item(si, si.value)
                    self._components[f"reasoning.options.{pn}.{n}"] = obj

    def change_reasoning_mode(self, value):
        output = []
        for each in self._reasoning_mode.values():
            if value == each.elem_id:
                output.append(gr.update(visible=True))
            else:
                output.append(gr.update(visible=False))
        return output

    def load_setting(self, user_id=None):
        settings = self._settings_dict
        with Session(engine) as session:
            statement = select(Settings).where(Settings.user == user_id)
            result = session.exec(statement).all()
            if result:
                settings = result[0].setting

        output = [settings]
        output += tuple(settings[name] for name in self.component_names())
        return output

    def save_setting(self, user_id: int, *args):
        """Save the setting to disk and persist the setting to session state

        Args:
            user_id: the user id
            args: all the values from the settings
        """
        setting = {key: value for key, value in zip(self.component_names(), args)}
        if user_id is None:
            gr.Warning("Need to login before saving settings")
            return setting

        with Session(engine) as session:
            statement = select(Settings).where(Settings.user == user_id)
            try:
                user_setting = session.exec(statement).one()
            except Exception:
                user_setting = Settings()
                user_setting.user = user_id
            user_setting.setting = setting
            session.add(user_setting)
            session.commit()

        gr.Info("Setting saved")
        return setting

    def components(self) -> list:
        """Get the setting components"""
        output = []
        for name in self._settings_keys:
            output.append(self._components[name])
        return output

    def component_names(self):
        """Get the setting component names"""
        return self._settings_keys
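To make `render_setting_item` concrete, a small sketch (values invented for illustration) of how a declarative setting becomes a Gradio widget:

    from ktem.settings import SettingItem

    item = SettingItem(
        name="Language",
        value="en",
        choices=[("English", "en"), ("Japanese", "ja")],
        component="dropdown",
    )
    widget = render_setting_item(item, item.value)  # -> gr.Dropdown with two choices

Because the component type is a plain string on the setting item, new reasoning or index options can declare their UI entirely in data, and this page renders them without any per-option code.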
0
libs/ktem/ktem/reasoning/__init__.py
Normal file
0
libs/ktem/ktem/reasoning/base.py
Normal file
409
libs/ktem/ktem/reasoning/simple.py
Normal file
@@ -0,0 +1,409 @@
import logging
import warnings
from collections import defaultdict
from functools import partial
from typing import Iterator, Optional

import tiktoken
from ktem.components import embeddings, get_docstore, get_vectorstore, llms
from ktem.db.models import Index, SourceTargetRelation, engine
from kotaemon.base import (
    BaseComponent,
    Document,
    HumanMessage,
    Node,
    RetrievedDocument,
    SystemMessage,
)
from kotaemon.indices import VectorRetrieval
from kotaemon.indices.qa.citation import CitationPipeline
from kotaemon.indices.rankings import BaseReranking, CohereReranking, LLMReranking
from kotaemon.indices.splitters import TokenSplitter
from kotaemon.llms import ChatLLM, PromptTemplate
from llama_index.vector_stores import (
    FilterCondition,
    FilterOperator,
    MetadataFilter,
    MetadataFilters,
)
from llama_index.vector_stores.types import VectorStoreQueryMode
from sqlmodel import Session, select
from theflow.settings import settings

logger = logging.getLogger(__name__)


class DocumentRetrievalPipeline(BaseComponent):
    """Retrieve relevant documents

    Args:
        vector_retrieval: the retrieval pipeline that returns the relevant
            documents given a text query
        reranker: the reranking pipeline that re-ranks and filters the retrieved
            documents
        get_extra_table: if True, for each retrieved document, the pipeline will
            look for surrounding tables (e.g. within the page)
    """

    vector_retrieval: VectorRetrieval = VectorRetrieval.withx(
        doc_store=get_docstore(),
        vector_store=get_vectorstore(),
        embedding=embeddings.get_default(),
    )
    reranker: BaseReranking = CohereReranking.withx(
        cohere_api_key=getattr(settings, "COHERE_API_KEY", "")
    ) >> LLMReranking.withx(llm=llms.get_lowest_cost())
    get_extra_table: bool = False

    def run(
        self,
        text: str,
        top_k: int = 5,
        mmr: bool = False,
        doc_ids: Optional[list[str]] = None,
    ) -> list[RetrievedDocument]:
        """Retrieve document excerpts similar to the text

        Args:
            text: the text to retrieve similar documents for
            top_k: number of documents to retrieve
            mmr: whether to use MMR to re-rank the documents
            doc_ids: list of document ids to constrain the retrieval
        """
        kwargs = {}
        if doc_ids:
            with Session(engine) as session:
                stmt = select(Index).where(
                    Index.relation_type == SourceTargetRelation.VECTOR,
                    Index.source_id.in_(doc_ids),  # type: ignore
                )
                results = session.exec(stmt)
                vs_ids = [r.target_id for r in results.all()]

            kwargs["filters"] = MetadataFilters(
                filters=[
                    MetadataFilter(
                        key="doc_id",
                        value=vs_id,
                        operator=FilterOperator.EQ,
                    )
                    for vs_id in vs_ids
                ],
                condition=FilterCondition.OR,
            )

        if mmr:
            # TODO: double check that llama-index MMR works correctly
            kwargs["mode"] = VectorStoreQueryMode.MMR
            kwargs["mmr_threshold"] = 0.5

        # rerank
        docs = self.vector_retrieval(text=text, top_k=top_k, **kwargs)
        if self.get_from_path("reranker"):
            docs = self.reranker(docs, query=text)

        if not self.get_extra_table:
            return docs

        # retrieve extra nodes related to tables
        table_pages = defaultdict(list)
        retrieved_id = set([doc.doc_id for doc in docs])
        for doc in docs:
            if "page_label" not in doc.metadata:
                continue
            if "file_name" not in doc.metadata:
                warnings.warn(
                    "file_name not in metadata while page_label is in metadata: "
                    f"{doc.metadata}"
                )
            table_pages[doc.metadata["file_name"]].append(doc.metadata["page_label"])

        queries = [
            {"$and": [{"file_name": {"$eq": fn}}, {"page_label": {"$in": pls}}]}
            for fn, pls in table_pages.items()
        ]
        if queries:
            extra_docs = self.vector_retrieval(
                text="",
                top_k=50,
                where={"$or": queries},
            )
            for doc in extra_docs:
                if doc.doc_id not in retrieved_id:
                    docs.append(doc)

        return docs


class PrepareEvidencePipeline(BaseComponent):
    """Prepare the evidence text from the list of retrieved documents

    This step usually happens after `DocumentRetrievalPipeline`.

    Args:
        trim_func: a callback function or a BaseComponent that splits a large
            chunk of text into smaller ones. The first one will be retained.
    """

    trim_func: TokenSplitter = TokenSplitter.withx(
        chunk_size=7600,
        chunk_overlap=0,
        separator=" ",
        tokenizer=partial(
            tiktoken.encoding_for_model("gpt-3.5-turbo").encode,
            allowed_special=set(),
            disallowed_special="all",
        ),
    )

    def run(self, docs: list[RetrievedDocument]) -> Document:
        evidence = ""
        table_found = 0
        evidence_mode = 0

        for _id, retrieved_item in enumerate(docs):
            retrieved_content = ""
            page = retrieved_item.metadata.get("page_label", None)
            source = filename = retrieved_item.metadata.get("file_name", "-")
            if page:
                source += f" (Page {page})"
            if retrieved_item.metadata.get("type", "") == "table":
                evidence_mode = 1  # table
                if table_found < 5:
                    retrieved_content = retrieved_item.metadata.get("table_origin", "")
                    if retrieved_content not in evidence:
                        table_found += 1
                        evidence += (
                            f"<br><b>Table from {source}</b>\n"
                            + retrieved_content
                            + "\n<br>"
                        )
            elif retrieved_item.metadata.get("type", "") == "chatbot":
                evidence_mode = 2  # chatbot
                retrieved_content = retrieved_item.metadata["window"]
                evidence += (
                    f"<br><b>Chatbot scenario from {filename} (Row {page})</b>\n"
                    + retrieved_content
                    + "\n<br>"
                )
            else:
                if "window" in retrieved_item.metadata:
                    retrieved_content = retrieved_item.metadata["window"]
                else:
                    retrieved_content = retrieved_item.text
                retrieved_content = retrieved_content.replace("\n", " ")
                if retrieved_content not in evidence:
                    evidence += (
                        f"<br><b>Content from {source}: </b> "
                        + retrieved_content
                        + " \n<br>"
                    )

            print("Retrieved #{}: {}".format(_id, retrieved_content))
            print(retrieved_item.metadata)
            print("Score", retrieved_item.metadata.get("relevance_score", None))

        # trim context by trim_len
        print("len (original)", len(evidence))
        if evidence:
            texts = self.trim_func([Document(text=evidence)])
            evidence = texts[0].text
            print("len (trimmed)", len(evidence))

        print(f"PrepareEvidence with input {docs}\nOutput: {evidence}\n")

        return Document(content=(evidence_mode, evidence))


DEFAULT_QA_TEXT_PROMPT = (
    "Use the following pieces of context to answer the question at the end. "
    "If you don't know the answer, just say that you don't know, don't try to "
    "make up an answer. Keep the answer as concise as possible. Give answer in "
    "{lang}. {system}\n\n"
    "{context}\n"
    "Question: {question}\n"
    "Helpful Answer:"
)

DEFAULT_QA_TABLE_PROMPT = (
    "List all rows (row number) from the table context that relate to the question, "
    "then provide a detailed answer with clear explanation and citations. "
    "If you don't know the answer, just say that you don't know, "
    "don't try to make up an answer. Give answer in {lang}. {system}\n\n"
    "Context:\n"
    "{context}\n"
    "Question: {question}\n"
    "Helpful Answer:"
)

DEFAULT_QA_CHATBOT_PROMPT = (
    "Pick the most suitable chatbot scenarios to answer the question at the end, "
    "output the provided answer text. If you don't know the answer, "
    "just say that you don't know. Keep the answer as concise as possible. "
    "Give answer in {lang}. {system}\n\n"
    "Context:\n"
    "{context}\n"
    "Question: {question}\n"
    "Answer:"
)


class AnswerWithContextPipeline(BaseComponent):
    """Answer the question based on the evidence

    Args:
        llm: the language model to generate the answer
        citation_pipeline: generates citation from the evidence
        qa_template: the prompt template for the LLM to generate an answer
            (refer to evidence_mode)
        qa_table_template: the prompt template for the LLM to generate an answer
            for tables (refer to evidence_mode)
        qa_chatbot_template: the prompt template for the LLM to generate an answer
            for pre-made scenarios (refer to evidence_mode)
        lang: the language of the answer. Currently supports English and Japanese
    """

    llm: ChatLLM = Node(default_callback=lambda _: llms.get_highest_accuracy())
    citation_pipeline: CitationPipeline = Node(
        default_callback=lambda _: CitationPipeline(llm=llms.get_lowest_cost())
    )

    qa_template: str = DEFAULT_QA_TEXT_PROMPT
    qa_table_template: str = DEFAULT_QA_TABLE_PROMPT
    qa_chatbot_template: str = DEFAULT_QA_CHATBOT_PROMPT

    system_prompt: str = ""
    lang: str = "English"  # supports English and Japanese

    def run(
        self, question: str, evidence: str, evidence_mode: int = 0
    ) -> Document | Iterator[Document]:
        """Answer the question based on the evidence

        In addition to the question and the evidence, this method also takes into
        account the evidence_mode, which tells what kind of evidence it is. The
        kind of evidence affects:
            1. How the evidence is represented.
            2. The prompt to generate the answer.

        By default, the evidence_mode is 0, which means the evidence is plain text
        with no particular semantic representation. The evidence_mode can be:
            1. "table": there is HTML markup indicating that there is a table
               within the evidence.
            2. "chatbot": there is HTML markup indicating that there is a chatbot
               scenario, extracted from an Excel file, where each row corresponds
               to an interaction.

        Args:
            question: the original question posed by the user
            evidence: the text that contains relevant information to answer the
                question (determined by the retrieval pipeline)
            evidence_mode: the mode of evidence, 0 for text, 1 for table, 2 for
                chatbot
        """
        if evidence_mode == 0:
            prompt_template = PromptTemplate(self.qa_template)
        elif evidence_mode == 1:
            prompt_template = PromptTemplate(self.qa_table_template)
        else:
            prompt_template = PromptTemplate(self.qa_chatbot_template)

        prompt = prompt_template.populate(
            context=evidence,
            question=question,
            lang=self.lang,
            system=self.system_prompt,
        )

        messages = [
            SystemMessage(content="You are a helpful assistant"),
            HumanMessage(content=prompt),
        ]
        # output = self.llm(messages).text
        yield from self.llm(messages)

        citation = self.citation_pipeline(context=evidence, question=question)
        answer = Document(text="", metadata={"citation": citation})
        yield answer


class FullQAPipeline(BaseComponent):
    """Question answering pipeline. Handles everything from question to answer"""

    class Config:
        allow_extra = True
        params_publish = True

    retrieval_pipeline: DocumentRetrievalPipeline = DocumentRetrievalPipeline.withx()
    evidence_pipeline: PrepareEvidencePipeline = PrepareEvidencePipeline.withx()
    answering_pipeline: AnswerWithContextPipeline = AnswerWithContextPipeline.withx()

    def run(self, question: str, **kwargs) -> Iterator[Document]:
        docs = self.retrieval_pipeline(text=question)
        evidence_mode, evidence = self.evidence_pipeline(docs).content
        answer = self.answering_pipeline(
            question=question, evidence=evidence, evidence_mode=evidence_mode
        )
        yield from answer  # should be a generator

    @classmethod
    def get_pipeline(cls, settings, **kwargs):
        """Get the reasoning pipeline

        Needs a base pipeline implementation. Currently the drawback is that we
        want to treat the retrievers as tools. Hence, the reasoning pipeline
        should just take the already-initiated tools (retrievers), and should not
        need to set such logic here.
        """
        pipeline = FullQAPipeline(get_extra_table=settings["index.prioritize_table"])
        if not settings["index.use_reranking"]:
            pipeline.retrieval_pipeline.reranker = None  # type: ignore

        pipeline.answering_pipeline.llm = llms.get_highest_accuracy()
        kwargs = {
            ".retrieval_pipeline.top_k": int(settings["index.num_retrieval"]),
            ".retrieval_pipeline.mmr": settings["index.mmr"],
            ".retrieval_pipeline.doc_ids": kwargs.get("files", None),
        }
        pipeline.set_run(kwargs, temp=True)

        return pipeline

    @classmethod
    def get_user_settings(cls) -> dict:
        from ktem.components import llms

        try:
            citation_llm = llms.get_lowest_cost_name()
            citation_llm_choices = list(llms.options().keys())
            main_llm = llms.get_highest_accuracy_name()
            main_llm_choices = list(llms.options().keys())
        except Exception as e:
            logger.error(e)
            citation_llm = None
            citation_llm_choices = []
            main_llm = None
            main_llm_choices = []

        return {
            "highlight_citation": {
                "name": "Highlight Citation",
                "value": True,
                "component": "checkbox",
            },
            "system_prompt": {
                "name": "System Prompt",
                "value": "This is a question answering system",
            },
            "citation_llm": {
                "name": "LLM for citation",
                "value": citation_llm,
                "component": "dropdown",
                "choices": citation_llm_choices,
            },
            "main_llm": {
                "name": "LLM for main generation",
                "value": main_llm,
                "component": "dropdown",
                "choices": main_llm_choices,
            },
        }
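A hedged usage sketch of the full pipeline. The settings keys mirror the ones `get_pipeline` reads above; the concrete values are invented for illustration:

    settings = {
        "index.prioritize_table": False,
        "index.use_reranking": False,
        "index.num_retrieval": 3,
        "index.mmr": False,
    }
    pipeline = FullQAPipeline.get_pipeline(settings, files=None)
    for doc in pipeline.run("What does the report conclude?"):
        # streamed answer chunks first, then a final Document carrying the citation
        print(doc.text, end="")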
156
libs/ktem/ktem/settings.py
Normal file
@@ -0,0 +1,156 @@
from typing import Any

from pydantic import BaseModel, Field


class SettingItem(BaseModel):
    """Represent a setting item

    Args:
        name: the name of the setting item
        value: the default value of the setting item
        choices: the list of choices of the setting item, if any
        metadata: the metadata of the setting item
        component: the expected UI component to render the setting
    """

    name: str
    value: Any
    choices: list = Field(default_factory=list)
    metadata: dict = Field(default_factory=dict)
    component: str = "text"


class BaseSettingGroup(BaseModel):
    settings: dict[str, "SettingItem"] = Field(default_factory=dict)
    options: dict[str, "BaseSettingGroup"] = Field(default_factory=dict)

    def _get_options(self) -> dict:
        return {}

    def finalize(self):
        """Finalize the setting group"""

    def flatten(self) -> dict:
        """Render the setting group into value"""
        output = {}
        for key, value in self.settings.items():
            output[key] = value.value

        output.update({f"options.{k}": v for k, v in self._get_options().items()})

        return output

    def get_setting_item(self, path: str) -> SettingItem:
        """Get the item based on dot notation"""
        path = path.strip(".")
        if "." not in path:
            return self.settings[path]

        key, sub_path = path.split(".", 1)
        if key != "options":
            raise ValueError(f"Invalid key {path}. Should start with `options.*`")

        option_id, sub_path = sub_path.split(".", 1)
        option = self.options[option_id]
        return option.get_setting_item(sub_path)


class SettingReasoningGroup(BaseSettingGroup):
    def _get_options(self) -> dict:
        output = {}
        for ex_name, ex_setting in self.options.items():
            for key, value in ex_setting.flatten().items():
                output[f"{ex_name}.{key}"] = value

        return output

    def finalize(self):
        """Finalize the setting"""
        options = list(self.options.keys())
        if options:
            self.settings["use"].choices = [(x, x) for x in options]
            self.settings["use"].value = options[0]


class SettingIndexOption(BaseSettingGroup):
    """Temporarily keep it here to see if we need this setting template
    for the index component
    """

    indexing: BaseSettingGroup
    retrieval: BaseSettingGroup

    def flatten(self) -> dict:
        """Render the setting group into value"""
        output = {}
        for key, value in self.indexing.flatten().items():
            output[f"indexing.{key}"] = value

        for key, value in self.retrieval.flatten().items():
            output[f"retrieval.{key}"] = value

        return output

    def get_setting_item(self, path: str) -> SettingItem:
        """Get the item based on dot notation"""
        path = path.strip(".")

        key, sub_path = path.split(".", 1)
        if key not in ["indexing", "retrieval"]:
            raise ValueError(
                f"Invalid key {path}. Should start with `indexing.*` or `retrieval.*`"
            )

        value = getattr(self, key)
        return value.get_setting_item(sub_path)


class SettingIndexGroup(BaseSettingGroup):
    def _get_options(self) -> dict:
        output = {}
        for name, setting in self.options.items():
            for key, value in setting.flatten().items():
                output[f"{name}.{key}"] = value

        return output

    def finalize(self):
        """Finalize the setting"""
        options = list(self.options.keys())
        if options:
            self.settings["use"].choices = [(x, x) for x in options]
            self.settings["use"].value = options


class SettingGroup(BaseModel):
    application: BaseSettingGroup = Field(default_factory=BaseSettingGroup)
    index: SettingIndexGroup = Field(default_factory=SettingIndexGroup)
    reasoning: SettingReasoningGroup = Field(default_factory=SettingReasoningGroup)

    def flatten(self) -> dict:
        """Render the setting group into value"""
        output = {}
        for key, value in self.application.flatten().items():
            output[f"application.{key}"] = value

        for key, value in self.index.flatten().items():
            output[f"index.{key}"] = value

        for key, value in self.reasoning.flatten().items():
            output[f"reasoning.{key}"] = value

        return output

    def get_setting_item(self, path: str) -> SettingItem:
        """Get the item based on dot notation"""
        path = path.strip(".")

        key, sub_path = path.split(".", 1)
        if key not in ["application", "index", "reasoning"]:
            raise ValueError(
                f"Invalid key {path}. Should start with `application.*`, "
                "`index.*` or `reasoning.*`"
            )

        value = getattr(self, key)
        return value.get_setting_item(sub_path)
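A short sketch of the dot-notation layout these groups produce (all names and values invented for illustration):

    group = SettingGroup(
        application=BaseSettingGroup(
            settings={"lang": SettingItem(name="Language", value="en")}
        ),
        reasoning=SettingReasoningGroup(
            settings={"use": SettingItem(name="Reasoning", value=None, component="radio")},
            options={
                "simple": BaseSettingGroup(
                    settings={"top_k": SettingItem(name="Top k", value=5, component="number")}
                )
            },
        ),
    )
    group.reasoning.finalize()  # populates the "use" choices from the option names
    group.flatten()
    # {'application.lang': 'en',
    #  'reasoning.use': 'simple',
    #  'reasoning.options.simple.top_k': 5}

These flattened keys are exactly the names the settings page above binds its UI components to.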
5
libs/ktem/launch.py
Normal file
@@ -0,0 +1,5 @@
from ktem.main import App

app = App()
demo = app.make()
demo.queue().launch()
39
libs/ktem/pyproject.toml
Normal file
@@ -0,0 +1,39 @@
[build-system]
requires = ["setuptools >= 61.0"]
build-backend = "setuptools.build_meta"

[tool.setuptools]
include-package-data = false
packages.find.include = ["ktem*"]
packages.find.exclude = ["tests*", "env*"]

[project]
name = "ktem"
version = "0.0.1"
requires-python = ">= 3.10"
description = "RAG-based Question and Answering Application"
dependencies = [
    "chromadb",
    "click",
    "cohere",
    "platformdirs",
    "pluggy",
    "python-decouple",
    "python-dotenv",
    "sqlalchemy",
    "sqlmodel",
    "tiktoken",
    "unstructured[pdf]",
]
readme = "README.md"
license = { text = "MIT License" }
authors = [
    { name = "john", email = "john@cinnamon.is" },
    { name = "ian", email = "ian@cinnamon.is" },
    { name = "tadashi", email = "tadashi@cinnamon.is" },
]
classifiers = [
    "Programming Language :: Python :: 3",
    "License :: OSI Approved :: MIT License",
    "Operating System :: OS Independent",
]
1
libs/ktem/requirements.txt
Normal file
@@ -0,0 +1 @@
platformdirs
29
libs/ktem/scripts/mock.py
Normal file
@@ -0,0 +1,29 @@
import time

from ktem.db.models import Conversation, Source, engine
from sqlmodel import Session


def add_conversation():
    """Add conversation to the manager."""
    with Session(engine) as session:
        c1 = Conversation(name="Conversation 1")
        c2 = Conversation()
        session.add(c1)
        time.sleep(1)
        session.add(c2)
        time.sleep(1)
        session.commit()


def add_files():
    with Session(engine) as session:
        s1 = Source(name="Source 1", path="Path 1")
        s2 = Source(name="Source 2", path="Path 2")
        session.add(s1)
        session.add(s2)
        session.commit()


# add_conversation()
add_files()