feat: merge develop (#123)

* Support hybrid vector retrieval

* Enable figures and table reading in Azure DI

* Retrieve with multi-modal

* Fix mixed-up tables

* Add txt loader

* Add Anthropic Chat

* Raise an error when retrieving the help file

* Allow same filename for different people if private is True

* Allow declaring extra LLM vendors

* Show chunks on the File page

* Allow elasticsearch to get more docs

* Fix Cohere response (#86)

* Fix Cohere response

* Remove Adobe pdfservice from dependencies

kotaemon no longer relies on pdfservice for its core functionality,
and pdfservice uses a very outdated dependency that causes conflicts.

---------

Co-authored-by: trducng <trungduc1992@gmail.com>

* Add confidence score (#87)

* Save question answering data as a log file

* Save the original information besides the rewritten info

* Export Cohere relevance score as confidence score

* Fix style check

* Upgrade the confidence score appearance (#90)

* Highlight the relevance score

* Round relevance score. Get key from config instead of env

* Make Cohere return all scores

* Display relevance score for image

* Remove columns and rows in Excel loader which contain all NaN (#91)

* remove columns and rows which contain all NaN
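
For reference, the all-NaN filtering can be done with pandas roughly like this (a minimal sketch; the actual loader and its joiner options may differ):

import pandas as pd

df = pd.read_excel("sheet.xlsx", sheet_name=0)
# drop rows that are entirely NaN, then columns that are entirely NaN
df = df.dropna(axis=0, how="all").dropna(axis=1, how="all")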

* back to multiple joiner options

* Fix style

---------

Co-authored-by: linhnguyen-cinnamon <cinmc0019@CINMC0019-LinhNguyen.local>
Co-authored-by: trducng <trungduc1992@gmail.com>

* Track retriever state

* Bump llama-index version 0.10

* feat/save-azuredi-mhtml-to-markdown (#93)

* feat/save-azuredi-mhtml-to-markdown

* fix: replace os.path with pathlib; change theflow.settings

* refactor: apply pre-commit formatting

* chore: move the func of saving content markdown above removed_spans

---------

Co-authored-by: jacky0218 <jacky0218@github.com>

* fix: losing first chunk (#94)

* fix: losing first chunk.

* fix: update the method of preventing losing chunks

---------

Co-authored-by: jacky0218 <jacky0218@github.com>

* fix: adding the base64 image in markdown (#95)

* feat: more chunk info on UI

* fix: error when reindexing files

* refactor: allow a more informative exception trace when using gpt4v

* feat: add excel reader that treats each worksheet as a document

* Persist loader information when indexing file

* feat: allow hiding unneeded setting panels

* feat: allow specific timezone when creating conversation

* feat: add more confidence score (#96)

* Allow a list of rerankers

* Export llm reranking score instead of filtering with a boolean

* Get logprobs from LLMs

* Rename cohere reranking score

* Call 2 rerankers at once

* Run QA pipeline for each chunk to get qa_score

* Display more relevance scores

* Define another LLMScoring instead of editing the original one

* Export logprobs instead of probs
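
A rough illustration of turning token logprobs into a displayable confidence; the aggregation below is an assumption for illustration, not necessarily the exact formula used in the pipeline:

import math

def logprobs_to_confidence(logprobs: list[float]) -> float:
    """Average the token log-probabilities and map back to a 0-1 score."""
    if not logprobs:
        return 0.0
    return math.exp(sum(logprobs) / len(logprobs))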

* Call LLMScoring

* Get qa_score only in the final answer

* feat: replace text length with token count in file list
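
Counting tokens with tiktoken looks roughly like this; the gpt-3.5-turbo encoding mirrors the default token function added to the indexing pipeline later in this diff:

import tiktoken

enc = tiktoken.encoding_for_model("gpt-3.5-turbo")
n_tokens = len(enc.encode("some document text"))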

* ui: show index name instead of id in the settings

* feat(ai): restrict the vision temperature

* fix(ui): remove the misleading message about non-retrieved evidence

* feat(ui): show the reasoning name and description in the reasoning setting page

* feat(ui): show version on the main windows

* feat(ui): show default llm name in the setting page

* fix(conf): append the result of doc in llm_scoring (#97)

* fix: constrain the maximum number of images

* feat(ui): allow filter file by name in file list page

* Fix exceeding token length error for OpenAI embeddings by chunking then averaging (#99)

* Average embeddings in case the text exceeds max size
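
A minimal sketch of the chunk-then-average idea, assuming a generic embed callable and a split helper that cuts the text on token boundaries (both are placeholders, not the actual pipeline API):

import numpy as np

def embed_long_text(text: str, embed, split, max_tokens: int) -> list[float]:
    """Split over-long text, embed each piece, then average the vectors."""
    chunks = split(text, max_tokens)
    vectors = np.array([embed(chunk) for chunk in chunks])
    mean = vectors.mean(axis=0)
    # re-normalize so the averaged vector stays unit length
    return (mean / np.linalg.norm(mean)).tolist()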

* Add docstring

* fix: Allow empty string when calling embedding

* fix: update trulens LLM ranking score for retrieval confidence, improve citation (#98)

* Round when displaying not by default

* Add LLMTrulens reranking model

* Use llmtrulensscoring in pipeline

* fix: update UI display for trulens score

---------

Co-authored-by: taprosoft <tadashi@cinnamon.is>

* feat: add question decomposition & few-shot rewrite pipeline (#89)

* Create few-shot query-rewriting. Run and display the result in info_panel

* Fix style check

* Put the functions to separate modules

* Add zero-shot question decomposition

* Fix fewshot rewriting

* Add default few-shot examples

* Fix decompose question

* Fix importing rewriting pipelines

* fix: update decompose logic in fullQA pipeline
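
A rough illustration of the zero-shot decomposition step; the prompt wording and the llm callable are placeholders rather than the pipeline's actual components:

DECOMPOSE_PROMPT = (
    "Break the following question into simpler sub-questions, one per line.\n"
    "Question: {question}\n"
    "Sub-questions:"
)

def decompose(question: str, llm) -> list[str]:
    # llm is any callable mapping a prompt string to a completion string
    answer = llm(DECOMPOSE_PROMPT.format(question=question))
    return [line.strip("- ").strip() for line in answer.splitlines() if line.strip()]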

---------

Co-authored-by: taprosoft <tadashi@cinnamon.is>

* fix: add encoding utf-8 when save temporal markdown in vectorIndex (#101)

* fix: improve retrieval pipeline and relevant score display (#102)

* fix: improve retrieval pipeline by extending first round top_k with multiplier

* fix: minor fix

* feat: improve UI default settings and add quick switch option for pipeline

* fix: improve agent logic (#103)

* fix: improve agent progress display

* fix: update retrieval logic

* fix: UI display

* fix: less verbose debug log

* feat: add warning message for low confidence

* fix: LLM scoring enabled by default

* fix: minor logic updates

* fix: hotfix image citation

* feat: update docx loader to handle merged table cells + handle zip file upload (#104)

* feat: update docx loader to handle merged table cells

* feat: handle zip file

* refactor: pre-commit

* fix: escape text in download UI

* feat: optimize vector store query db (#105)

* feat: optimize vector store query db

* feat: add file_id to chroma metadatas

* feat: remove unnecessary logs and update migrate script

* feat: iterate through file index

* fix: remove unused code

---------

Co-authored-by: taprosoft <tadashi@cinnamon.is>

* fix: add OpenAI embedding exponential back-off
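
The back-off follows the usual retry pattern; a generic sketch with a placeholder embed call standing in for the OpenAI embedding request:

import random
import time

def embed_with_backoff(embed, texts, max_retries: int = 5):
    for attempt in range(max_retries):
        try:
            return embed(texts)
        except Exception:
            if attempt == max_retries - 1:
                raise
            # wait 1s, 2s, 4s, ... plus up to 1s of jitter before retrying
            time.sleep(2 ** attempt + random.random())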

* fix: update import download_loader

* refactor: codespell

* fix: update some default settings

* fix: update installation instruction

* fix: default chunk length in simple QA

* feat: add share conversation feature and enable retrieval history (#108)

* feat: add share conversation feature and enable retrieval history

* fix: update share conversation UI

---------

Co-authored-by: taprosoft <tadashi@cinnamon.is>

* fix: allow exponential backoff for failed OCR call (#109)

* fix: update default prompt when no retrieval is used

* fix: create embedding for long image chunks

* fix: add exception handling for additional table retriever

* fix: clean conversation & file selection UI

* fix: elastic search with empty doc_ids

* feat: add thumbnail PDF reader for quick multimodal QA

* feat: add thumbnail handling logic in indexing

* fix: UI text update

* fix: PDF thumb loader page number logic

* feat: add quick indexing pipeline and update UI

* feat: add conv name suggestion

* fix: minor UI change

* feat: citation in thread

* fix: add conv name suggestion in regen

* chore: add assets for usage doc

* chore: update usage doc

* feat: pdf viewer (#110)

* feat: update pdfviewer

* feat: update missing files

* fix: update rendering logic of info panel

* fix: improve thumbnail retrieval logic

* fix: update PDF evidence rendering logic

* fix: remove pdfjs built dist

* fix: reduce thumbnail evidence count

* chore: update gitignore

* fix: add js event on chat msg select

* fix: update css for viewer

* fix: add env var for PDFJS prebuilt

* fix: move language setting to reasoning utils

---------

Co-authored-by: phv2312 <kat87yb@gmail.com>
Co-authored-by: trducng <trungduc1992@gmail.com>

* feat: graph rag (#116)

* fix: reload server when adding/deleting an index

* fix: rework indexing pipeline to be able to disable vectorstore and splitter if needed

* feat: add graphRAG index with plot view

* fix: update requirement for graphRAG and lighten unnecessary packages

* feat: add knowledge network index (#118)

* feat: add Knowledge Network index

* fix: update reader mode setting for knet

* fix: update init knet

* fix: update collection name to index pipeline

* fix: missing req

---------

Co-authored-by: jeff52415 <jeff.yang@cinnamon.is>

* fix: update info panel return for graphrag

* fix: retriever setting graphrag

* feat: local llm settings (#122)

* feat: expose context length as reasoning setting to better fit local models

* fix: update context length setting for agents

* fix: rework threadpool llm call

* fix: improve indexing logic

* fix: improve UI

* feat: add lancedb

* fix: improve lancedb logic

* feat: add lancedb vectorstore

* fix: lighten requirement

* fix: improve LanceDB vectorstore

* fix: improve UI

* fix: openai retry

* fix: update reqs

* fix: update launch command

* feat: update Dockerfile

* feat: add plot history

* fix: update default config

* fix: remove verbose print

* fix: update default setting

* fix: update gradio plot return

* fix: default gradio tmp

* fix: improve lancedb docstore

* fix: question decomposition pipeline

* feat: add multimodal reader in UI

* fix: update docs

* fix: update default settings & docker build

* fix: update app startup

* chore: update documentation

* chore: update README

* chore: update README

---------

Co-authored-by: trducng <trungduc1992@gmail.com>

* chore: update README

* chore: update README

---------

Co-authored-by: trducng <trungduc1992@gmail.com>
Co-authored-by: cin-ace <ace@cinnamon.is>
Co-authored-by: Linh Nguyen <70562198+linhnguyen-cinnamon@users.noreply.github.com>
Co-authored-by: linhnguyen-cinnamon <cinmc0019@CINMC0019-LinhNguyen.local>
Co-authored-by: cin-jacky <101088014+jacky0218@users.noreply.github.com>
Co-authored-by: jacky0218 <jacky0218@github.com>
Co-authored-by: kan_cin <kan@cinnamon.is>
Co-authored-by: phv2312 <kat87yb@gmail.com>
Co-authored-by: jeff52415 <jeff.yang@cinnamon.is>
Author: Tuan Anh Nguyen Dang (Tadashi_Cin)
Date: 2024-08-26 08:50:37 +07:00
Committed by: GitHub
Parent: 86d60e1649
Commit: 2570e11501
121 changed files with 14748 additions and 1063 deletions

View File

@@ -4,6 +4,7 @@ from typing import Optional
import gradio as gr
import pluggy
from ktem import extension_protocol
from ktem.assets import PDFJS_PREBUILT_DIR
from ktem.components import reasonings
from ktem.exceptions import HookAlreadyDeclared, HookNotDeclared
from ktem.index import IndexManager
@@ -36,6 +37,7 @@ class BaseApp:
def __init__(self):
self.dev_mode = getattr(settings, "KH_MODE", "") == "dev"
self.app_name = getattr(settings, "KH_APP_NAME", "Kotaemon")
self.app_version = getattr(settings, "KH_APP_VERSION", "")
self.f_user_management = getattr(settings, "KH_FEATURE_USER_MANAGEMENT", False)
self._theme = gr.Theme.from_hub("lone17/kotaemon")
@@ -44,6 +46,13 @@ class BaseApp:
self._css = fi.read()
with (dir_assets / "js" / "main.js").open() as fi:
self._js = fi.read()
self._js = self._js.replace("KH_APP_VERSION", self.app_version)
with (dir_assets / "js" / "pdf_viewer.js").open() as fi:
self._pdf_view_js = fi.read()
self._pdf_view_js = self._pdf_view_js.replace(
"PDFJS_PREBUILT_DIR", str(PDFJS_PREBUILT_DIR)
)
self._favicon = str(dir_assets / "img" / "favicon.svg")
self.default_settings = SettingGroup(
@@ -156,11 +165,17 @@ class BaseApp:
"""Called when the app is created"""
def make(self):
external_js = """
<script type="module" src="https://cdn.skypack.dev/pdfjs-viewer-element"></script>
"""
with gr.Blocks(
theme=self._theme,
css=self._css,
title=self.app_name,
analytics_enabled=False,
js=self._js,
head=external_js,
) as demo:
self.app = demo
self.settings_state.render()
@@ -173,6 +188,8 @@ class BaseApp:
self.register_events()
self.on_app_created()
demo.load(None, None, None, js=self._pdf_view_js)
return demo
def declare_public_events(self):
@@ -200,7 +217,6 @@ class BaseApp:
def on_app_created(self):
"""Execute on app created callbacks"""
self.app.load(lambda: None, None, None, js=f"() => {{{self._js}}}")
self._on_app_created()
for value in self.__dict__.values():
if isinstance(value, BasePage):

View File

@@ -0,0 +1,6 @@
from pathlib import Path
from decouple import config
PDFJS_VERSION_DIST: str = config("PDFJS_VERSION_DIST", "pdfjs-4.0.379-dist")
PDFJS_PREBUILT_DIR: Path = Path(__file__).parent / "prebuilt" / PDFJS_VERSION_DIST

View File

@@ -147,6 +147,16 @@ mark {
max-height: 42px;
}
/* Hide sort buttons at gr.DataFrame */
.sort-button {
display: none !important;
}
/* Show sort button only in File list*/
#file_list_view .sort-button {
display: block !important;
}
.scrollable {
overflow-y: auto;
}
@@ -158,3 +168,58 @@ mark {
.unset-overflow {
overflow: unset !important;
}
/*body {*/
/* margin: 0;*/
/* font-family: Arial, sans-serif;*/
/*}*/
pdfjs-viewer-element {
height: 100vh;
height: 100dvh;
}
/* Modal styles */
.modal {
display: none;
position: relative;
z-index: 1;
left: 0;
top: 0;
width: 100%;
height: 100%;
overflow: auto;
background-color: rgb(0, 0, 0);
background-color: rgba(0, 0, 0, 0.4);
}
.modal-header {
padding: 0px 10px
}
.modal-content {
background-color: #fefefe;
height: 110%;
display: flex;
flex-direction: column;
}
.close {
color: #aaa;
align-self: flex-end;
font-size: 28px;
font-weight: bold;
}
.close:hover,
.close:focus {
color: black;
text-decoration: none;
cursor: pointer;
}
.modal-body {
flex: 1;
overflow: auto;
}

View File

@@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" fill="none" class="h-5 w-5 shrink-0"><path fill="#f93a37" fill-rule="evenodd" d="M10.556 4a1 1 0 0 0-.97.751l-.292 1.14h5.421l-.293-1.14A1 1 0 0 0 13.453 4zm6.224 1.892-.421-1.639A3 3 0 0 0 13.453 2h-2.897A3 3 0 0 0 7.65 4.253l-.421 1.639H4a1 1 0 1 0 0 2h.1l1.215 11.425A3 3 0 0 0 8.3 22h7.4a3 3 0 0 0 2.984-2.683l1.214-11.425H20a1 1 0 1 0 0-2zm1.108 2H6.112l1.192 11.214A1 1 0 0 0 8.3 20h7.4a1 1 0 0 0 .995-.894zM10 10a1 1 0 0 1 1 1v5a1 1 0 1 1-2 0v-5a1 1 0 0 1 1-1m4 0a1 1 0 0 1 1 1v5a1 1 0 1 1-2 0v-5a1 1 0 0 1 1-1" clip-rule="evenodd"/></svg>


View File

@@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" fill="#10b981" class="icon-xl-heavy"><path d="M15.673 3.913a3.121 3.121 0 1 1 4.414 4.414l-5.937 5.937a5 5 0 0 1-2.828 1.415l-2.18.31a1 1 0 0 1-1.132-1.13l.311-2.18A5 5 0 0 1 9.736 9.85zm3 1.414a1.12 1.12 0 0 0-1.586 0l-5.937 5.937a3 3 0 0 0-.849 1.697l-.123.86.86-.122a3 3 0 0 0 1.698-.849l5.937-5.937a1.12 1.12 0 0 0 0-1.586M11 4a1 1 0 0 1-1 1c-.998 0-1.702.008-2.253.06-.54.052-.862.141-1.109.267a3 3 0 0 0-1.311 1.311c-.134.263-.226.611-.276 1.216C5.001 8.471 5 9.264 5 10.4v3.2c0 1.137 0 1.929.051 2.546.05.605.142.953.276 1.216a3 3 0 0 0 1.311 1.311c.263.134.611.226 1.216.276.617.05 1.41.051 2.546.051h3.2c1.137 0 1.929 0 2.546-.051.605-.05.953-.142 1.216-.276a3 3 0 0 0 1.311-1.311c.126-.247.215-.569.266-1.108.053-.552.06-1.256.06-2.255a1 1 0 1 1 2 .002c0 .978-.006 1.78-.069 2.442-.064.673-.192 1.27-.475 1.827a5 5 0 0 1-2.185 2.185c-.592.302-1.232.428-1.961.487C15.6 21 14.727 21 13.643 21h-3.286c-1.084 0-1.958 0-2.666-.058-.728-.06-1.369-.185-1.96-.487a5 5 0 0 1-2.186-2.185c-.302-.592-.428-1.233-.487-1.961C3 15.6 3 14.727 3 13.643v-3.286c0-1.084 0-1.958.058-2.666.06-.729.185-1.369.487-1.961A5 5 0 0 1 5.73 3.545c.556-.284 1.154-.411 1.827-.475C8.22 3.007 9.021 3 10 3a1 1 0 0 1 1 1"/></svg>


View File

@@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" fill="none" class="h-5 w-5 shrink-0"><path fill="#cecece" fill-rule="evenodd" d="M13.293 4.293a4.536 4.536 0 1 1 6.414 6.414l-1 1-7.094 7.094A5 5 0 0 1 8.9 20.197l-4.736.79a1 1 0 0 1-1.15-1.151l.789-4.736a5 5 0 0 1 1.396-2.713zM13 7.414l-6.386 6.387a3 3 0 0 0-.838 1.628l-.56 3.355 3.355-.56a3 3 0 0 0 1.628-.837L16.586 11zm5 2.172L14.414 6l.293-.293a2.536 2.536 0 0 1 3.586 3.586z" clip-rule="evenodd"/></svg>


View File

@@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" fill="none" class="icon-xl-heavy"><path fill="#cecece" fill-rule="evenodd" d="M8.857 3h6.286c1.084 0 1.958 0 2.666.058.729.06 1.369.185 1.961.487a5 5 0 0 1 2.185 2.185c.302.592.428 1.233.487 1.961.058.708.058 1.582.058 2.666v3.286c0 1.084 0 1.958-.058 2.666-.06.729-.185 1.369-.487 1.961a5 5 0 0 1-2.185 2.185c-.592.302-1.232.428-1.961.487C17.1 21 16.227 21 15.143 21H8.857c-1.084 0-1.958 0-2.666-.058-.728-.06-1.369-.185-1.96-.487a5 5 0 0 1-2.186-2.185c-.302-.592-.428-1.232-.487-1.961C1.5 15.6 1.5 14.727 1.5 13.643v-3.286c0-1.084 0-1.958.058-2.666.06-.728.185-1.369.487-1.96A5 5 0 0 1 4.23 3.544c.592-.302 1.233-.428 1.961-.487C6.9 3 7.773 3 8.857 3M6.354 5.051c-.605.05-.953.142-1.216.276a3 3 0 0 0-1.311 1.311c-.134.263-.226.611-.276 1.216-.05.617-.051 1.41-.051 2.546v3.2c0 1.137 0 1.929.051 2.546.05.605.142.953.276 1.216a3 3 0 0 0 1.311 1.311c.263.134.611.226 1.216.276.617.05 1.41.051 2.546.051h.6V5h-.6c-1.137 0-1.929 0-2.546.051M11.5 5v14h3.6c1.137 0 1.929 0 2.546-.051.605-.05.953-.142 1.216-.276a3 3 0 0 0 1.311-1.311c.134-.263.226-.611.276-1.216.05-.617.051-1.41.051-2.546v-3.2c0-1.137 0-1.929-.051-2.546-.05-.605-.142-.953-.276-1.216a3 3 0 0 0-1.311-1.311c-.263-.134-.611-.226-1.216-.276C17.029 5.001 16.236 5 15.1 5zM5 8.5a1 1 0 0 1 1-1h1a1 1 0 1 1 0 2H6a1 1 0 0 1-1-1M5 12a1 1 0 0 1 1-1h1a1 1 0 1 1 0 2H6a1 1 0 0 1-1-1" clip-rule="evenodd"/></svg>


View File

@@ -1,30 +1,37 @@
let main_parent = document.getElementById("chat-tab").parentNode;
function run() {
let main_parent = document.getElementById("chat-tab").parentNode;
main_parent.childNodes[0].classList.add("header-bar");
main_parent.style = "padding: 0; margin: 0";
main_parent.parentNode.style = "gap: 0";
main_parent.parentNode.parentNode.style = "padding: 0";
main_parent.childNodes[0].classList.add("header-bar");
main_parent.style = "padding: 0; margin: 0";
main_parent.parentNode.style = "gap: 0";
main_parent.parentNode.parentNode.style = "padding: 0";
const version_node = document.createElement("p");
version_node.innerHTML = "version: KH_APP_VERSION";
version_node.style = "position: fixed; top: 10px; right: 10px;";
main_parent.appendChild(version_node);
// clpse
globalThis.clpseFn = (id) => {
var obj = document.getElementById('clpse-btn-' + id);
obj.classList.toggle("clpse-active");
var content = obj.nextElementSibling;
if (content.style.display === "none") {
content.style.display = "block";
} else {
content.style.display = "none";
// clpse
globalThis.clpseFn = (id) => {
var obj = document.getElementById('clpse-btn-' + id);
obj.classList.toggle("clpse-active");
var content = obj.nextElementSibling;
if (content.style.display === "none") {
content.style.display = "block";
} else {
content.style.display = "none";
}
}
// store info in local storage
globalThis.setStorage = (key, value) => {
localStorage.setItem(key, value)
}
globalThis.getStorage = (key, value) => {
item = localStorage.getItem(key);
return item ? item : value;
}
globalThis.removeFromStorage = (key) => {
localStorage.removeItem(key)
}
}
// store info in local storage
globalThis.setStorage = (key, value) => {
localStorage.setItem(key, JSON.stringify(value))
}
globalThis.getStorage = (key, value) => {
return JSON.parse(localStorage.getItem(key))
}
globalThis.removeFromStorage = (key) => {
localStorage.removeItem(key)
}

View File

@@ -0,0 +1,99 @@
function onBlockLoad () {
var infor_panel_scroll_pos = 0;
globalThis.createModal = () => {
// Create modal for the 1st time if it does not exist
var modal = document.getElementById("pdf-modal");
var old_position = null;
var old_width = null;
var old_left = null;
var expanded = false;
modal.id = "pdf-modal";
modal.className = "modal";
modal.innerHTML = `
<div class="modal-content">
<div class="modal-header">
<span class="close" id="modal-close">&times;</span>
<span class="close" id="modal-expand">&#x26F6;</span>
</div>
<div class="modal-body">
<pdfjs-viewer-element id="pdf-viewer" viewer-path="/file=PDFJS_PREBUILT_DIR" locale="en" phrase="true">
</pdfjs-viewer-element>
</div>
</div>
`;
modal.querySelector("#modal-close").onclick = function() {
modal.style.display = "none";
var info_panel = document.getElementById("html-info-panel");
if (info_panel) {
info_panel.style.display = "block";
}
var scrollableDiv = document.getElementById("chat-info-panel");
scrollableDiv.scrollTop = infor_panel_scroll_pos;
};
modal.querySelector("#modal-expand").onclick = function () {
expanded = !expanded;
if (expanded) {
old_position = modal.style.position;
old_left = modal.style.left;
old_width = modal.style.width;
modal.style.position = "fixed";
modal.style.width = "70%";
modal.style.left = "15%";
} else {
modal.style.position = old_position;
modal.style.width = old_width;
modal.style.left = old_left;
}
};
}
// Function to open modal and display PDF
globalThis.openModal = (event) => {
event.preventDefault();
var target = event.currentTarget;
var src = target.getAttribute("data-src");
var page = target.getAttribute("data-page");
var search = target.getAttribute("data-search");
var phrase = target.getAttribute("data-phrase");
var pdfViewer = document.getElementById("pdf-viewer");
current_src = pdfViewer.getAttribute("src");
if (current_src != src) {
pdfViewer.setAttribute("src", src);
}
pdfViewer.setAttribute("phrase", phrase);
pdfViewer.setAttribute("search", search);
pdfViewer.setAttribute("page", page);
var scrollableDiv = document.getElementById("chat-info-panel");
infor_panel_scroll_pos = scrollableDiv.scrollTop;
var modal = document.getElementById("pdf-modal")
modal.style.display = "block";
var info_panel = document.getElementById("html-info-panel");
if (info_panel) {
info_panel.style.display = "none";
}
scrollableDiv.scrollTop = 0;
}
globalThis.assignPdfOnclickEvent = () => {
// Get all links and attach click event
var links = document.getElementsByClassName("pdf-link");
for (var i = 0; i < links.length; i++) {
links[i].onclick = openModal;
}
}
var created_modal = document.getElementById("pdf-viewer");
if (!created_modal) {
createModal();
console.log("Created modal")
}
}

View File

@@ -8,3 +8,6 @@ An open-source tool for you to chat with your documents.
[User Guide](https://cinnamon.github.io/kotaemon/) |
[Developer Guide](https://cinnamon.github.io/kotaemon/development/) |
[Feedback](https://github.com/Cinnamon/kotaemon/issues)
[Dark Mode](?__theme=dark)
[Night Mode](?__theme=light)

View File

@@ -136,6 +136,6 @@ Now navigate back to the `Chat` tab. The chat tab is divided into 3 regions:
files will be considered during chat.
2. Chat Panel
- This is where you can chat with the chatbot.
3. Information panel
3. Information Panel
- Supporting information such as the retrieved evidence and reference will be
displayed here.

View File

@@ -1,9 +1,11 @@
import datetime
import uuid
from typing import Optional
from zoneinfo import ZoneInfo
from sqlalchemy import JSON, Column
from sqlmodel import Field, SQLModel
from theflow.settings import settings as flowsettings
class BaseConversation(SQLModel):
@@ -24,10 +26,14 @@ class BaseConversation(SQLModel):
default_factory=lambda: uuid.uuid4().hex, primary_key=True, index=True
)
name: str = Field(
default_factory=lambda: datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S")
default_factory=lambda: datetime.datetime.now(
ZoneInfo(getattr(flowsettings, "TIME_ZONE", "UTC"))
).strftime("%Y-%m-%d %H:%M:%S")
)
user: int = Field(default=0) # For now we only have one user
is_public: bool = Field(default=False)
# contains messages + current files
data_source: dict = Field(default={}, sa_column=Column(JSON))
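
With this change the default conversation name is rendered in the timezone configured via TIME_ZONE in flowsettings, falling back to UTC. A small illustration, assuming a local setting of TIME_ZONE = "Asia/Tokyo":

import datetime
from zoneinfo import ZoneInfo

tz = ZoneInfo("Asia/Tokyo")  # value read from flowsettings.TIME_ZONE, "UTC" if unset
name = datetime.datetime.now(tz).strftime("%Y-%m-%d %H:%M:%S")
print(name)  # e.g. "2024-08-26 10:50:37"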

View File

@@ -36,7 +36,7 @@ class EmbeddingManager:
def load(self):
"""Load the model pool from database"""
self._models, self._info, self._defaut = {}, {}, ""
self._models, self._info, self._default = {}, {}, ""
with Session(engine) as sess:
stmt = select(EmbeddingTable)
items = sess.execute(stmt)

View File

@@ -115,7 +115,7 @@ class EmbeddingManagement(BasePage):
"""Called when the app is created"""
self._app.app.load(
self.list_embeddings,
inputs=None,
inputs=[],
outputs=[self.emb_list],
)
self._app.app.load(
@@ -144,7 +144,7 @@ class EmbeddingManagement(BasePage):
self.create_emb,
inputs=[self.name, self.emb_choices, self.spec, self.default],
outputs=None,
).success(self.list_embeddings, inputs=None, outputs=[self.emb_list]).success(
).success(self.list_embeddings, inputs=[], outputs=[self.emb_list]).success(
lambda: ("", None, "", False, self.spec_desc_default),
outputs=[
self.name,
@@ -179,7 +179,7 @@ class EmbeddingManagement(BasePage):
)
self.btn_delete.click(
self.on_btn_delete_click,
inputs=None,
inputs=[],
outputs=[self.btn_delete, self.btn_delete_yes, self.btn_delete_no],
show_progress="hidden",
)
@@ -190,7 +190,7 @@ class EmbeddingManagement(BasePage):
show_progress="hidden",
).then(
self.list_embeddings,
inputs=None,
inputs=[],
outputs=[self.emb_list],
)
self.btn_delete_no.click(
@@ -199,7 +199,7 @@ class EmbeddingManagement(BasePage):
gr.update(visible=False),
gr.update(visible=False),
),
inputs=None,
inputs=[],
outputs=[self.btn_delete, self.btn_delete_yes, self.btn_delete_no],
show_progress="hidden",
)
@@ -213,7 +213,7 @@ class EmbeddingManagement(BasePage):
show_progress="hidden",
).then(
self.list_embeddings,
inputs=None,
inputs=[],
outputs=[self.emb_list],
)
self.btn_close.click(

View File

@@ -54,6 +54,7 @@ class BaseFileIndexIndexing(BaseComponent):
DS = Param(help="The DocStore")
FSPath = Param(help="The file storage path")
user_id = Param(help="The user id")
private = Param(False, help="Whether this is private index")
def run(
self, file_paths: str | Path | list[str | Path], *args, **kwargs
@@ -73,7 +74,9 @@ class BaseFileIndexIndexing(BaseComponent):
def stream(
self, file_paths: str | Path | list[str | Path], *args, **kwargs
) -> Generator[Document, None, tuple[list[str | None], list[str | None]]]:
) -> Generator[
Document, None, tuple[list[str | None], list[str | None], list[Document]]
]:
"""Stream the indexing pipeline
Args:
@@ -87,6 +90,7 @@ class BaseFileIndexIndexing(BaseComponent):
None if the indexing failed for that file path)
- the error messages (each error message corresponds to an input file path,
or None if the indexing was successful for that file path)
- the indexed documents in form of list[Documents]
"""
raise NotImplementedError
@@ -149,3 +153,7 @@ class BaseFileIndexIndexing(BaseComponent):
msg: the message to log
"""
print(msg)
def rebuild_index(self):
"""Rebuild the index"""
raise NotImplementedError

View File

@@ -0,0 +1,3 @@
from .graph_index import GraphRAGIndex
__all__ = ["GraphRAGIndex"]

View File

@@ -0,0 +1,36 @@
from typing import Any
from ktem.index.file import FileIndex
from ..base import BaseFileIndexIndexing, BaseFileIndexRetriever
from .pipelines import GraphRAGIndexingPipeline, GraphRAGRetrieverPipeline
class GraphRAGIndex(FileIndex):
def _setup_indexing_cls(self):
self._indexing_pipeline_cls = GraphRAGIndexingPipeline
def _setup_retriever_cls(self):
self._retriever_pipeline_cls = [GraphRAGRetrieverPipeline]
def get_indexing_pipeline(self, settings, user_id) -> BaseFileIndexIndexing:
"""Define the interface of the indexing pipeline"""
obj = super().get_indexing_pipeline(settings, user_id)
# disable vectorstore for this kind of Index
obj.VS = None
return obj
def get_retriever_pipelines(
self, settings: dict, user_id: int, selected: Any = None
) -> list["BaseFileIndexRetriever"]:
_, file_ids, _ = selected
retrievers = [
GraphRAGRetrieverPipeline(
file_ids=file_ids,
Index=self._resources["Index"],
)
]
return retrievers

View File

@@ -0,0 +1,359 @@
import os
import subprocess
from pathlib import Path
from shutil import rmtree
from typing import Generator
from uuid import uuid4
import pandas as pd
import tiktoken
from ktem.db.models import engine
from sqlalchemy.orm import Session
from theflow.settings import settings
from kotaemon.base import Document, Param, RetrievedDocument
from ..pipelines import BaseFileIndexRetriever, IndexDocumentPipeline, IndexPipeline
from .visualize import create_knowledge_graph, visualize_graph
try:
from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey
from graphrag.query.indexer_adapters import (
read_indexer_entities,
read_indexer_relationships,
read_indexer_reports,
read_indexer_text_units,
)
from graphrag.query.input.loaders.dfs import store_entity_semantic_embeddings
from graphrag.query.llm.oai.embedding import OpenAIEmbedding
from graphrag.query.llm.oai.typing import OpenaiApiType
from graphrag.query.structured_search.local_search.mixed_context import (
LocalSearchMixedContext,
)
from graphrag.vector_stores.lancedb import LanceDBVectorStore
except ImportError:
print(
(
"GraphRAG dependencies not installed. "
"GraphRAG retriever pipeline will not work properly."
)
)
filestorage_path = Path(settings.KH_FILESTORAGE_PATH) / "graphrag"
filestorage_path.mkdir(parents=True, exist_ok=True)
def prepare_graph_index_path(graph_id: str):
root_path = Path(filestorage_path) / graph_id
input_path = root_path / "input"
return root_path, input_path
class GraphRAGIndexingPipeline(IndexDocumentPipeline):
"""GraphRAG specific indexing pipeline"""
def route(self, file_path: Path) -> IndexPipeline:
"""Simply disable the splitter (chunking) for this pipeline"""
pipeline = super().route(file_path)
pipeline.splitter = None
return pipeline
def store_file_id_with_graph_id(self, file_ids: list[str | None]):
# create new graph_id and assign them to doc_id in self.Index
# record in the index
graph_id = str(uuid4())
with Session(engine) as session:
nodes = []
for file_id in file_ids:
if not file_id:
continue
nodes.append(
self.Index(
source_id=file_id,
target_id=graph_id,
relation_type="graph",
)
)
session.add_all(nodes)
session.commit()
return graph_id
def write_docs_to_files(self, graph_id: str, docs: list[Document]):
root_path, input_path = prepare_graph_index_path(graph_id)
input_path.mkdir(parents=True, exist_ok=True)
for doc in docs:
if doc.metadata.get("type", "text") == "text":
with open(input_path / f"{doc.doc_id}.txt", "w") as f:
f.write(doc.text)
return root_path
def call_graphrag_index(self, input_path: str):
# Construct the command
command = [
"python",
"-m",
"graphrag.index",
"--root",
input_path,
"--reporter",
"rich",
"--init",
]
# Run the command
yield Document(
channel="debug",
text="[GraphRAG] Creating index... This can take a long time.",
)
result = subprocess.run(command, capture_output=True, text=True)
print(result.stdout)
command = command[:-1]
# Run the command and stream stdout
with subprocess.Popen(command, stdout=subprocess.PIPE, text=True) as process:
if process.stdout:
for line in process.stdout:
yield Document(channel="debug", text=line)
def stream(
self, file_paths: str | Path | list[str | Path], reindex: bool = False, **kwargs
) -> Generator[
Document, None, tuple[list[str | None], list[str | None], list[Document]]
]:
file_ids, errors, all_docs = yield from super().stream(
file_paths, reindex=reindex, **kwargs
)
# assign graph_id to file_ids
graph_id = self.store_file_id_with_graph_id(file_ids)
# call GraphRAG index with docs and graph_id
graph_index_path = self.write_docs_to_files(graph_id, all_docs)
yield from self.call_graphrag_index(graph_index_path)
return file_ids, errors, all_docs
class GraphRAGRetrieverPipeline(BaseFileIndexRetriever):
"""GraphRAG specific retriever pipeline"""
Index = Param(help="The SQLAlchemy Index table")
file_ids: list[str] = []
@classmethod
def get_user_settings(cls) -> dict:
return {
"search_type": {
"name": "Search type",
"value": "local",
"choices": ["local", "global"],
"component": "dropdown",
"info": "Whether to use local or global search in the graph.",
}
}
def _build_graph_search(self):
assert (
len(self.file_ids) <= 1
), "GraphRAG retriever only supports one file_id at a time"
file_id = self.file_ids[0]
# retrieve the graph_id from the index
with Session(engine) as session:
graph_id = (
session.query(self.Index.target_id)
.filter(self.Index.source_id == file_id)
.filter(self.Index.relation_type == "graph")
.first()
)
graph_id = graph_id[0] if graph_id else None
assert graph_id, f"GraphRAG index not found for file_id: {file_id}"
root_path, _ = prepare_graph_index_path(graph_id)
output_path = root_path / "output"
child_paths = sorted(
list(output_path.iterdir()), key=lambda x: x.stem, reverse=True
)
# get the latest child path
assert child_paths, "GraphRAG index output not found"
latest_child_path = Path(child_paths[0]) / "artifacts"
INPUT_DIR = latest_child_path
LANCEDB_URI = str(INPUT_DIR / "lancedb")
COMMUNITY_REPORT_TABLE = "create_final_community_reports"
ENTITY_TABLE = "create_final_nodes"
ENTITY_EMBEDDING_TABLE = "create_final_entities"
RELATIONSHIP_TABLE = "create_final_relationships"
TEXT_UNIT_TABLE = "create_final_text_units"
COMMUNITY_LEVEL = 2
# read nodes table to get community and degree data
entity_df = pd.read_parquet(f"{INPUT_DIR}/{ENTITY_TABLE}.parquet")
entity_embedding_df = pd.read_parquet(
f"{INPUT_DIR}/{ENTITY_EMBEDDING_TABLE}.parquet"
)
entities = read_indexer_entities(
entity_df, entity_embedding_df, COMMUNITY_LEVEL
)
# load description embeddings to an in-memory lancedb vectorstore
# to connect to a remote db, specify url and port values.
description_embedding_store = LanceDBVectorStore(
collection_name="entity_description_embeddings",
)
description_embedding_store.connect(db_uri=LANCEDB_URI)
if Path(LANCEDB_URI).is_dir():
rmtree(LANCEDB_URI)
_ = store_entity_semantic_embeddings(
entities=entities, vectorstore=description_embedding_store
)
print(f"Entity count: {len(entity_df)}")
# Read relationships
relationship_df = pd.read_parquet(f"{INPUT_DIR}/{RELATIONSHIP_TABLE}.parquet")
relationships = read_indexer_relationships(relationship_df)
# Read community reports
report_df = pd.read_parquet(f"{INPUT_DIR}/{COMMUNITY_REPORT_TABLE}.parquet")
reports = read_indexer_reports(report_df, entity_df, COMMUNITY_LEVEL)
# Read text units
text_unit_df = pd.read_parquet(f"{INPUT_DIR}/{TEXT_UNIT_TABLE}.parquet")
text_units = read_indexer_text_units(text_unit_df)
embedding_model = os.getenv("GRAPHRAG_EMBEDDING_MODEL")
text_embedder = OpenAIEmbedding(
api_key=os.getenv("OPENAI_API_KEY"),
api_base=None,
api_type=OpenaiApiType.OpenAI,
model=embedding_model,
deployment_name=embedding_model,
max_retries=20,
)
token_encoder = tiktoken.get_encoding("cl100k_base")
context_builder = LocalSearchMixedContext(
community_reports=reports,
text_units=text_units,
entities=entities,
relationships=relationships,
covariates=None,
entity_text_embeddings=description_embedding_store,
embedding_vectorstore_key=EntityVectorStoreKey.ID,
# if the vectorstore uses entity title as ids,
# set this to EntityVectorStoreKey.TITLE
text_embedder=text_embedder,
token_encoder=token_encoder,
)
return context_builder
def _to_document(self, header: str, context_text: str) -> RetrievedDocument:
return RetrievedDocument(
text=context_text,
metadata={
"file_name": header,
"type": "table",
"llm_trulens_score": 1.0,
},
score=1.0,
)
def format_context_records(self, context_records) -> list[RetrievedDocument]:
entities = context_records.get("entities", [])
relationships = context_records.get("relationships", [])
reports = context_records.get("reports", [])
sources = context_records.get("sources", [])
docs = []
context: str = ""
header = "<b>Entities</b>\n"
context = entities[["entity", "description"]].to_markdown(index=False)
docs.append(self._to_document(header, context))
header = "\n<b>Relationships</b>\n"
context = relationships[["source", "target", "description"]].to_markdown(
index=False
)
docs.append(self._to_document(header, context))
header = "\n<b>Reports</b>\n"
context = ""
for idx, row in reports.iterrows():
title, content = row["title"], row["content"]
context += f"\n\n<h5>Report <b>{title}</b></h5>\n"
context += content
docs.append(self._to_document(header, context))
header = "\n<b>Sources</b>\n"
context = ""
for idx, row in sources.iterrows():
title, content = row["id"], row["text"]
context += f"\n\n<h5>Source <b>#{title}</b></h5>\n"
context += content
docs.append(self._to_document(header, context))
return docs
def plot_graph(self, context_records):
relationships = context_records.get("relationships", [])
G = create_knowledge_graph(relationships)
plot = visualize_graph(G)
return plot
def generate_relevant_scores(self, text, documents: list[RetrievedDocument]):
return documents
def run(
self,
text: str,
) -> list[RetrievedDocument]:
if not self.file_ids:
return []
context_builder = self._build_graph_search()
local_context_params = {
"text_unit_prop": 0.5,
"community_prop": 0.1,
"conversation_history_max_turns": 5,
"conversation_history_user_turns_only": True,
"top_k_mapped_entities": 10,
"top_k_relationships": 10,
"include_entity_rank": False,
"include_relationship_weight": False,
"include_community_rank": False,
"return_candidate_context": False,
"embedding_vectorstore_key": EntityVectorStoreKey.ID,
# set this to EntityVectorStoreKey.TITLE i
# f the vectorstore uses entity title as ids
"max_tokens": 12_000,
# change this based on the token limit you have on your model
# (if you are using a model with 8k limit, a good setting could be 5000)
}
context_text, context_records = context_builder.build_context(
query=text,
conversation_history=None,
**local_context_params,
)
documents = self.format_context_records(context_records)
plot = self.plot_graph(context_records)
return documents + [
RetrievedDocument(
text="",
metadata={
"file_name": "GraphRAG",
"type": "plot",
"data": plot,
},
),
]

View File

@@ -0,0 +1,102 @@
import networkx as nx
import plotly.graph_objects as go
from plotly.io import to_json
def create_knowledge_graph(df):
"""
create nx Graph from DataFrame relations data
"""
G = nx.Graph()
for _, row in df.iterrows():
source = row["source"]
target = row["target"]
attributes = {k: v for k, v in row.items() if k not in ["source", "target"]}
G.add_edge(source, target, **attributes)
return G
def visualize_graph(G):
pos = nx.spring_layout(G, dim=2)
edge_x = []
edge_y = []
edge_texts = nx.get_edge_attributes(G, "description")
to_display_edge_texts = []
for edge in G.edges():
x0, y0 = pos[edge[0]]
x1, y1 = pos[edge[1]]
edge_x.append(x0)
edge_x.append(x1)
edge_x.append(None)
edge_y.append(y0)
edge_y.append(y1)
edge_y.append(None)
to_display_edge_texts.append(edge_texts[edge])
edge_trace = go.Scatter(
x=edge_x,
y=edge_y,
text=to_display_edge_texts,
line=dict(width=0.5, color="#888"),
hoverinfo="text",
mode="lines",
)
node_x = []
node_y = []
for node in G.nodes():
x, y = pos[node]
node_x.append(x)
node_y.append(y)
node_adjacencies = []
node_text = []
node_size = []
for node_id, adjacencies in enumerate(G.adjacency()):
degree = len(adjacencies[1])
node_adjacencies.append(degree)
node_text.append(adjacencies[0])
node_size.append(15 if degree < 5 else (30 if degree < 10 else 60))
node_trace = go.Scatter(
x=node_x,
y=node_y,
textfont=dict(
family="Courier New, monospace",
size=10, # Set the font size here
),
textposition="top center",
mode="markers+text",
hoverinfo="text",
text=node_text,
marker=dict(
showscale=True,
# colorscale options
size=node_size,
colorscale="YlGnBu",
reversescale=True,
color=node_adjacencies,
colorbar=dict(
thickness=5,
xanchor="left",
titleside="right",
),
line_width=2,
),
)
fig = go.Figure(
data=[edge_trace, node_trace],
layout=go.Layout(
showlegend=False,
hovermode="closest",
margin=dict(b=20, l=5, r=5, t=40),
xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
),
)
fig.update_layout(autosize=True)
return to_json(fig)
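
A small usage sketch for the two helpers above, with hypothetical relationship rows; the real DataFrame comes from GraphRAG's context records and carries the same source/target/description columns:

import pandas as pd

rels = pd.DataFrame(
    [
        {"source": "A", "target": "B", "description": "A supplies B"},
        {"source": "B", "target": "C", "description": "B partners with C"},
    ]
)
G = create_knowledge_graph(rels)   # networkx Graph with edge attributes
plot_json = visualize_graph(G)     # Plotly figure serialized to JSON for the UI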

View File

@@ -4,8 +4,9 @@ from typing import Any, Optional, Type
from ktem.components import filestorage_path, get_docstore, get_vectorstore
from ktem.db.engine import engine
from ktem.index.base import BaseIndex
from sqlalchemy import Column, DateTime, Integer, String
from sqlalchemy import JSON, Column, DateTime, Integer, String, UniqueConstraint
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.sql import func
from theflow.settings import settings as flowsettings
from theflow.utils.modules import import_dotted_string
@@ -52,27 +53,60 @@ class FileIndex(BaseIndex):
- File storage path
"""
Base = declarative_base()
Source = type(
"Source",
(Base,),
{
"__tablename__": f"index__{self.id}__source",
"id": Column(
String,
primary_key=True,
default=lambda: str(uuid.uuid4()),
unique=True,
),
"name": Column(String, unique=True),
"path": Column(String),
"size": Column(Integer, default=0),
"text_length": Column(Integer, default=0),
"date_created": Column(
DateTime(timezone=True), server_default=func.now()
),
"user": Column(Integer, default=1),
},
)
if self.config.get("private", False):
Source = type(
"Source",
(Base,),
{
"__tablename__": f"index__{self.id}__source",
"__table_args__": (
UniqueConstraint("name", "user", name="_name_user_uc"),
),
"id": Column(
String,
primary_key=True,
default=lambda: str(uuid.uuid4()),
unique=True,
),
"name": Column(String),
"path": Column(String),
"size": Column(Integer, default=0),
"date_created": Column(
DateTime(timezone=True), server_default=func.now()
),
"user": Column(Integer, default=1),
"note": Column(
MutableDict.as_mutable(JSON), # type: ignore
default={},
),
},
)
else:
Source = type(
"Source",
(Base,),
{
"__tablename__": f"index__{self.id}__source",
"id": Column(
String,
primary_key=True,
default=lambda: str(uuid.uuid4()),
unique=True,
),
"name": Column(String, unique=True),
"path": Column(String),
"size": Column(Integer, default=0),
"date_created": Column(
DateTime(timezone=True), server_default=func.now()
),
"user": Column(Integer, default=1),
"note": Column(
MutableDict.as_mutable(JSON), # type: ignore
default={},
),
},
)
Index = type(
"IndexTable",
(Base,),
@@ -85,6 +119,7 @@ class FileIndex(BaseIndex):
"user": Column(Integer, default=1),
},
)
self._vs: BaseVectorStore = get_vectorstore(f"index_{self.id}")
self._docstore: BaseDocumentStore = get_docstore(f"index_{self.id}")
self._fs_path = filestorage_path / f"index_{self.id}"
@@ -358,8 +393,6 @@ class FileIndex(BaseIndex):
for key, value in settings.items():
if key.startswith(prefix):
stripped_settings[key[len(prefix) :]] = value
else:
stripped_settings[key] = value
obj = self._indexing_pipeline_cls.get_pipeline(stripped_settings, self.config)
obj.Source = self._resources["Source"]
@@ -368,6 +401,7 @@ class FileIndex(BaseIndex):
obj.DS = self._docstore
obj.FSPath = self._fs_path
obj.user_id = user_id
obj.private = self.config.get("private", False)
return obj
@@ -380,8 +414,6 @@ class FileIndex(BaseIndex):
for key, value in settings.items():
if key.startswith(prefix):
stripped_settings[key[len(prefix) :]] = value
else:
stripped_settings[key] = value
# transform selected id
selected_ids: Optional[list[str]] = self._selector_ui.get_selected_ids(selected)

View File

@@ -0,0 +1,3 @@
from .knet_index import KnowledgeNetworkFileIndex
__all__ = ["KnowledgeNetworkFileIndex"]

View File

@@ -0,0 +1,47 @@
from typing import Any
from ktem.index.file import FileIndex
from ..base import BaseFileIndexIndexing, BaseFileIndexRetriever
from .pipelines import KnetIndexingPipeline, KnetRetrievalPipeline
class KnowledgeNetworkFileIndex(FileIndex):
@classmethod
def get_admin_settings(cls):
admin_settings = super().get_admin_settings()
# remove embedding from admin settings
# as we don't need it
admin_settings.pop("embedding")
return admin_settings
def _setup_indexing_cls(self):
self._indexing_pipeline_cls = KnetIndexingPipeline
def _setup_retriever_cls(self):
self._retriever_pipeline_cls = [KnetRetrievalPipeline]
def get_indexing_pipeline(self, settings, user_id) -> BaseFileIndexIndexing:
"""Define the interface of the indexing pipeline"""
obj = super().get_indexing_pipeline(settings, user_id)
# disable vectorstore for this kind of Index
# also set the collection_name for API call
obj.VS = None
obj.collection_name = f"kh_index_{self.id}"
return obj
def get_retriever_pipelines(
self, settings: dict, user_id: int, selected: Any = None
) -> list["BaseFileIndexRetriever"]:
retrievers = super().get_retriever_pipelines(settings, user_id, selected)
for obj in retrievers:
# disable vectorstore for this kind of Index
# also set the collection_name for API call
obj.VS = None
obj.collection_name = f"kh_index_{self.id}"
return retrievers

View File

@@ -0,0 +1,169 @@
import base64
import json
import os
from pathlib import Path
from typing import Optional, Sequence
import requests
import yaml
from kotaemon.base import RetrievedDocument
from kotaemon.indices.rankings import BaseReranking, LLMReranking, LLMTrulensScoring
from ..pipelines import BaseFileIndexRetriever, IndexDocumentPipeline, IndexPipeline
class KnetIndexingPipeline(IndexDocumentPipeline):
"""Knowledge Network specific indexing pipeline"""
# collection name for external indexing call
collection_name: str = "default"
@classmethod
def get_user_settings(cls):
return {
"reader_mode": {
"name": "Index parser",
"value": "knowledge_network",
"choices": [
("Default (KN)", "knowledge_network"),
],
"component": "dropdown",
},
}
def route(self, file_path: Path) -> IndexPipeline:
"""Simply disable the splitter (chunking) for this pipeline"""
pipeline = super().route(file_path)
pipeline.splitter = None
# assign IndexPipeline collection name to parse to loader
pipeline.collection_name = self.collection_name
return pipeline
class KnetRetrievalPipeline(BaseFileIndexRetriever):
DEFAULT_KNET_ENDPOINT: str = "http://127.0.0.1:8081/retrieve"
collection_name: str = "default"
rerankers: Sequence[BaseReranking] = [LLMReranking.withx()]
def encode_image_base64(self, image_path: str | Path) -> bytes | str:
"""Convert image to base64"""
img_base64 = "data:image/png;base64,{}"
with open(image_path, "rb") as image_file:
return img_base64.format(
base64.b64encode(image_file.read()).decode("utf-8")
)
def run(
self,
text: str,
doc_ids: Optional[list[str]] = None,
*args,
**kwargs,
) -> list[RetrievedDocument]:
"""Retrieve document excerpts similar to the text
Args:
text: the text to retrieve similar documents
doc_ids: list of document ids to constraint the retrieval
"""
print("searching in doc_ids", doc_ids)
if not doc_ids:
return []
docs: list[RetrievedDocument] = []
params = {
"query": text,
"collection": self.collection_name,
"meta_filters": {"doc_name": doc_ids},
}
params["meta_filters"] = json.dumps(params["meta_filters"])
response = requests.get(self.DEFAULT_KNET_ENDPOINT, params=params)
metadata_translation = {
"TABLE": "table",
"FIGURE": "image",
}
if response.status_code == 200:
# Load YAML content from the response content
chunks = yaml.safe_load(response.content)
for chunk in chunks:
metadata = chunk["node"]["metadata"]
metadata["type"] = metadata_translation.get(
metadata.pop("content_type", ""), ""
)
metadata["file_name"] = metadata.pop("company_name", "")
# load image from returned path
image_path = metadata.get("image_path", "")
if image_path and os.path.isfile(image_path):
base64_im = self.encode_image_base64(image_path)
# explicitly set document type
metadata["type"] = "image"
metadata["image_origin"] = base64_im
docs.append(
RetrievedDocument(text=chunk["node"]["text"], metadata=metadata)
)
else:
raise IOError(f"{response.status_code}: {response.text}")
for reranker in self.rerankers:
docs = reranker(documents=docs, query=text)
return docs
@classmethod
def get_user_settings(cls) -> dict:
from ktem.llms.manager import llms
try:
reranking_llm = llms.get_default_name()
reranking_llm_choices = list(llms.options().keys())
except Exception:
reranking_llm = None
reranking_llm_choices = []
return {
"reranking_llm": {
"name": "LLM for scoring",
"value": reranking_llm,
"component": "dropdown",
"choices": reranking_llm_choices,
"special_type": "llm",
},
"retrieval_mode": {
"name": "Retrieval mode",
"value": "hybrid",
"choices": ["vector", "text", "hybrid"],
"component": "dropdown",
},
}
@classmethod
def get_pipeline(cls, user_settings, index_settings, selected):
"""Get retriever objects associated with the index
Args:
settings: the settings of the app
kwargs: other arguments
"""
from ktem.llms.manager import llms
retriever = cls(
rerankers=[LLMTrulensScoring()],
)
# hacky way to input doc_ids to retriever.run() call (through theflow)
kwargs = {".doc_ids": selected}
retriever.set_run(kwargs, temp=False)
for reranker in retriever.rerankers:
if isinstance(reranker, LLMReranking):
reranker.llm = llms.get(
user_settings["reranking_llm"], llms.get_default()
)
return retriever

View File

@@ -2,25 +2,29 @@ from __future__ import annotations
import logging
import shutil
import threading
import time
import warnings
from collections import defaultdict
from copy import deepcopy
from functools import lru_cache
from hashlib import sha256
from pathlib import Path
from typing import Generator, Optional
from typing import Generator, Optional, Sequence
import tiktoken
from ktem.db.models import engine
from ktem.embeddings.manager import embedding_models_manager
from ktem.llms.manager import llms
from llama_index.readers.base import BaseReader
from llama_index.readers.file.base import default_file_metadata_func
from llama_index.vector_stores import (
from llama_index.core.readers.base import BaseReader
from llama_index.core.readers.file.base import default_file_metadata_func
from llama_index.core.vector_stores import (
FilterCondition,
FilterOperator,
MetadataFilter,
MetadataFilters,
)
from llama_index.vector_stores.types import VectorStoreQueryMode
from llama_index.core.vector_stores.types import VectorStoreQueryMode
from sqlalchemy import delete, select
from sqlalchemy.orm import Session
from theflow.settings import settings
@@ -29,8 +33,18 @@ from theflow.utils.modules import import_dotted_string
from kotaemon.base import BaseComponent, Document, Node, Param, RetrievedDocument
from kotaemon.embeddings import BaseEmbeddings
from kotaemon.indices import VectorIndexing, VectorRetrieval
from kotaemon.indices.ingests.files import KH_DEFAULT_FILE_EXTRACTORS
from kotaemon.indices.rankings import BaseReranking, LLMReranking
from kotaemon.indices.ingests.files import (
KH_DEFAULT_FILE_EXTRACTORS,
adobe_reader,
azure_reader,
unstructured,
)
from kotaemon.indices.rankings import (
BaseReranking,
CohereReranking,
LLMReranking,
LLMTrulensScoring,
)
from kotaemon.indices.splitters import BaseSplitter, TokenSplitter
from .base import BaseFileIndexIndexing, BaseFileIndexRetriever
@@ -60,6 +74,9 @@ def dev_settings():
return file_extractors, chunk_size, chunk_overlap
_default_token_func = tiktoken.encoding_for_model("gpt-3.5-turbo").encode
class DocumentRetrievalPipeline(BaseFileIndexRetriever):
"""Retrieve relevant document
@@ -75,10 +92,13 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
"""
embedding: BaseEmbeddings
reranker: BaseReranking = LLMReranking.withx()
rerankers: Sequence[BaseReranking] = []
# use LLM to create relevant scores for displaying on UI
llm_scorer: LLMReranking | None = LLMReranking.withx()
get_extra_table: bool = False
mmr: bool = False
top_k: int = 5
retrieval_mode: str = "hybrid"
@Node.auto(depends_on=["embedding", "VS", "DS"])
def vector_retrieval(self) -> VectorRetrieval:
@@ -86,6 +106,8 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
embedding=self.embedding,
vector_store=self.VS,
doc_store=self.DS,
retrieval_mode=self.retrieval_mode, # type: ignore
rerankers=self.rerankers,
)
def run(
@@ -101,27 +123,30 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
text: the text to retrieve similar documents
doc_ids: list of document ids to constraint the retrieval
"""
print("searching in doc_ids", doc_ids)
if not doc_ids:
logger.info(f"Skip retrieval because of no selected files: {self}")
return []
retrieval_kwargs = {}
retrieval_kwargs: dict = {}
with Session(engine) as session:
stmt = select(self.Index).where(
self.Index.relation_type == "vector",
self.Index.relation_type == "document",
self.Index.source_id.in_(doc_ids),
)
results = session.execute(stmt)
vs_ids = [r[0].target_id for r in results.all()]
chunk_ids = [r[0].target_id for r in results.all()]
# do first round top_k extension
retrieval_kwargs["do_extend"] = True
retrieval_kwargs["scope"] = chunk_ids
retrieval_kwargs["filters"] = MetadataFilters(
filters=[
MetadataFilter(
key="doc_id",
value=vs_id,
operator=FilterOperator.EQ,
key="file_id",
value=doc_ids,
operator=FilterOperator.IN,
)
for vs_id in vs_ids
],
condition=FilterCondition.OR,
)
@@ -132,9 +157,10 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
retrieval_kwargs["mmr_threshold"] = 0.5
# rerank
s_time = time.time()
print(f"retrieval_kwargs: {retrieval_kwargs.keys()}")
docs = self.vector_retrieval(text=text, top_k=self.top_k, **retrieval_kwargs)
if docs and self.get_from_path("reranker"):
docs = self.reranker(docs, query=text)
print("retrieval step took", time.time() - s_time)
if not self.get_extra_table:
return docs
@@ -157,17 +183,30 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
for fn, pls in table_pages.items()
]
if queries:
extra_docs = self.vector_retrieval(
text="",
top_k=50,
where=queries[0] if len(queries) == 1 else {"$or": queries},
)
for doc in extra_docs:
if doc.doc_id not in retrieved_id:
docs.append(doc)
try:
extra_docs = self.vector_retrieval(
text="",
top_k=50,
where=queries[0] if len(queries) == 1 else {"$or": queries},
)
for doc in extra_docs:
if doc.doc_id not in retrieved_id:
docs.append(doc)
except Exception:
print("Error retrieving additional tables")
return docs
def generate_relevant_scores(
self, query: str, documents: list[RetrievedDocument]
) -> list[RetrievedDocument]:
docs = (
documents
if not self.llm_scorer
else self.llm_scorer(documents=documents, query=query)
)
return docs
@classmethod
def get_user_settings(cls) -> dict:
from ktem.llms.manager import llms
@@ -182,43 +221,44 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
return {
"reranking_llm": {
"name": "LLM for reranking",
"name": "LLM for relevant scoring",
"value": reranking_llm,
"component": "dropdown",
"choices": reranking_llm_choices,
},
"separate_embedding": {
"name": "Use separate embedding",
"value": False,
"choices": [("Yes", True), ("No", False)],
"component": "dropdown",
"special_type": "llm",
},
"num_retrieval": {
"name": "Number of document chunks to retrieve",
"value": 3,
"value": 10,
"component": "number",
},
"retrieval_mode": {
"name": "Retrieval mode",
"value": "vector",
"value": "hybrid",
"choices": ["vector", "text", "hybrid"],
"component": "dropdown",
},
"prioritize_table": {
"name": "Prioritize table",
"value": True,
"value": False,
"choices": [True, False],
"component": "checkbox",
},
"mmr": {
"name": "Use MMR",
"value": True,
"value": False,
"choices": [True, False],
"component": "checkbox",
},
"use_reranking": {
"name": "Use reranking",
"value": False,
"value": True,
"choices": [True, False],
"component": "checkbox",
},
"use_llm_reranking": {
"name": "Use LLM relevant scoring",
"value": True,
"choices": [True, False],
"component": "checkbox",
},
@@ -232,6 +272,8 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
settings: the settings of the app
kwargs: other arguments
"""
use_llm_reranking = user_settings.get("use_llm_reranking", False)
retriever = cls(
get_extra_table=user_settings["prioritize_table"],
top_k=user_settings["num_retrieval"],
@@ -241,16 +283,26 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
"embedding", embedding_models_manager.get_default_name()
)
],
retrieval_mode=user_settings["retrieval_mode"],
llm_scorer=(LLMTrulensScoring() if use_llm_reranking else None),
rerankers=[CohereReranking()],
)
if not user_settings["use_reranking"]:
retriever.reranker = None # type: ignore
else:
retriever.reranker.llm = llms.get(
retriever.rerankers = [] # type: ignore
for reranker in retriever.rerankers:
if isinstance(reranker, LLMReranking):
reranker.llm = llms.get(
user_settings["reranking_llm"], llms.get_default()
)
if retriever.llm_scorer:
retriever.llm_scorer.llm = llms.get(
user_settings["reranking_llm"], llms.get_default()
)
kwargs = {".doc_ids": selected}
retriever.set_run(kwargs, temp=True)
retriever.set_run(kwargs, temp=False)
return retriever
@@ -258,8 +310,8 @@ class IndexPipeline(BaseComponent):
"""Index a single file"""
loader: BaseReader
splitter: BaseSplitter
chunk_batch_size: int = 50
splitter: BaseSplitter | None
chunk_batch_size: int = 200
Source = Param(help="The SQLAlchemy Source table")
Index = Param(help="The SQLAlchemy Index table")
@@ -267,6 +319,9 @@ class IndexPipeline(BaseComponent):
DS = Param(help="The DocStore")
FSPath = Param(help="The file storage path")
user_id = Param(help="The user id")
collection_name: str = "default"
private: bool = False
run_embedding_in_thread: bool = False
embedding: BaseEmbeddings
@Node.auto(depends_on=["Source", "Index", "embedding"])
@@ -276,31 +331,81 @@ class IndexPipeline(BaseComponent):
)
def handle_docs(self, docs, file_id, file_name) -> Generator[Document, None, int]:
s_time = time.time()
text_docs = []
non_text_docs = []
thumbnail_docs = []
for doc in docs:
doc_type = doc.metadata.get("type", "text")
if doc_type == "text":
text_docs.append(doc)
elif doc_type == "thumbnail":
thumbnail_docs.append(doc)
else:
non_text_docs.append(doc)
print(f"Got {len(thumbnail_docs)} page thumbnails")
page_label_to_thumbnail = {
doc.metadata["page_label"]: doc.doc_id for doc in thumbnail_docs
}
if self.splitter:
all_chunks = self.splitter(text_docs)
else:
all_chunks = text_docs
# attach the corresponding thumbnail doc_id to each chunk
for chunk in all_chunks:
page_label = chunk.metadata.get("page_label", None)
if page_label and page_label in page_label_to_thumbnail:
chunk.metadata["thumbnail_doc_id"] = page_label_to_thumbnail[page_label]
to_index_chunks = all_chunks + non_text_docs + thumbnail_docs
# add to doc store
chunks = []
n_chunks = 0
for cidx, chunk in enumerate(self.splitter(docs)):
chunks.append(chunk)
if cidx % self.chunk_batch_size == 0:
self.handle_chunks(chunks, file_id)
n_chunks += len(chunks)
chunks = []
yield Document(
f" => [{file_name}] Processed {n_chunks} chunks", channel="debug"
)
if chunks:
self.handle_chunks(chunks, file_id)
chunk_size = self.chunk_batch_size * 4
for start_idx in range(0, len(to_index_chunks), chunk_size):
chunks = to_index_chunks[start_idx : start_idx + chunk_size]
self.handle_chunks_docstore(chunks, file_id)
n_chunks += len(chunks)
yield Document(
f" => [{file_name}] Processed {n_chunks} chunks", channel="debug"
f" => [{file_name}] Processed {n_chunks} chunks",
channel="debug",
)
def insert_chunks_to_vectorstore():
chunks = []
n_chunks = 0
chunk_size = self.chunk_batch_size
for start_idx in range(0, len(to_index_chunks), chunk_size):
chunks = to_index_chunks[start_idx : start_idx + chunk_size]
self.handle_chunks_vectorstore(chunks, file_id)
n_chunks += len(chunks)
if self.VS:
yield Document(
f" => [{file_name}] Created embedding for {n_chunks} chunks",
channel="debug",
)
# run vector indexing in thread if specified
if self.run_embedding_in_thread:
print("Running embedding in thread")
threading.Thread(
target=lambda: list(insert_chunks_to_vectorstore())
).start()
else:
yield from insert_chunks_to_vectorstore()
print("indexing step took", time.time() - s_time)
return n_chunks
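
The indexing loop above inserts chunks in fixed-size batches and can push the embedding step to a background thread. A minimal, stdlib-only sketch of that pattern (helper names are assumptions):

```python
# Illustrative batching pattern: insert in batches, report progress after each
# batch, and optionally drain the generator in a worker thread.
import threading
from typing import Callable, List


def insert_in_batches(items: List[str], batch_size: int, insert_fn: Callable[[List[str]], None]):
    """Yield a progress message after each inserted batch."""
    done = 0
    for start in range(0, len(items), batch_size):
        batch = items[start : start + batch_size]
        insert_fn(batch)
        done += len(batch)
        yield f"Processed {done}/{len(items)} chunks"


def index(chunks: List[str], run_in_thread: bool = False):
    progress = insert_in_batches(chunks, batch_size=200, insert_fn=lambda batch: None)
    if run_in_thread:
        # draining the lazy generator in a worker thread is what triggers the work
        threading.Thread(target=lambda: list(progress), daemon=True).start()
    else:
        for message in progress:
            print(message)


index([f"chunk-{i}" for i in range(1000)])
```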
def handle_chunks(self, chunks, file_id):
def handle_chunks_docstore(self, chunks, file_id):
"""Run chunks"""
# run embedding, add to both vector store and doc store
self.vector_indexing(chunks)
self.vector_indexing.add_to_docstore(chunks)
# record in the index
with Session(engine) as session:
@@ -313,16 +418,30 @@ class IndexPipeline(BaseComponent):
relation_type="document",
)
)
nodes.append(
self.Index(
source_id=file_id,
target_id=chunk.doc_id,
relation_type="vector",
)
)
session.add_all(nodes)
session.commit()
def handle_chunks_vectorstore(self, chunks, file_id):
"""Run chunks"""
# run embedding, add to both vector store and doc store
self.vector_indexing.add_to_vectorstore(chunks)
self.vector_indexing.write_chunk_to_file(chunks)
if self.VS:
# record in the index
with Session(engine) as session:
nodes = []
for chunk in chunks:
nodes.append(
self.Index(
source_id=file_id,
target_id=chunk.doc_id,
relation_type="vector",
)
)
session.add_all(nodes)
session.commit()
def get_id_if_exists(self, file_path: Path) -> Optional[str]:
"""Check if the file is already indexed
@@ -332,8 +451,16 @@ class IndexPipeline(BaseComponent):
Returns:
the file id if the file is indexed, otherwise None
"""
if self.private:
cond: tuple = (
self.Source.name == file_path.name,
self.Source.user == self.user_id,
)
else:
cond = (self.Source.name == file_path.name,)
with Session(engine) as session:
stmt = select(self.Source).where(self.Source.name == file_path.name)
stmt = select(self.Source).where(*cond)
item = session.execute(stmt).first()
if item:
return item[0].id
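
The duplicate check now depends on whether the index is private: in private collections a file only counts as a duplicate for the same user, in shared ones the name alone decides. A pure-Python sketch of that semantics (field names are illustrative):

```python
# Sketch of the privacy-aware duplicate check; rows stand in for Source records.
from typing import Optional


def get_id_if_exists(rows: list, name: str, user_id: str, private: bool) -> Optional[str]:
    for row in rows:
        same_name = row["name"] == name
        same_user = row["user"] == user_id
        if same_name and (not private or same_user):
            return row["id"]
    return None


rows = [{"id": "f1", "name": "report.pdf", "user": "alice"}]
assert get_id_if_exists(rows, "report.pdf", "bob", private=True) is None
assert get_id_if_exists(rows, "report.pdf", "bob", private=False) == "f1"
```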
@@ -369,20 +496,36 @@ class IndexPipeline(BaseComponent):
def finish(self, file_id: str, file_path: Path) -> str:
"""Finish the indexing"""
with Session(engine) as session:
stmt = select(self.Index.target_id).where(self.Index.source_id == file_id)
doc_ids = [_[0] for _ in session.execute(stmt)]
if doc_ids:
stmt = select(self.Source).where(self.Source.id == file_id)
result = session.execute(stmt).first()
if not result:
return file_id
item = result[0]
# populate the number of tokens
doc_ids_stmt = select(self.Index.target_id).where(
self.Index.source_id == file_id,
self.Index.relation_type == "document",
)
doc_ids = [_[0] for _ in session.execute(doc_ids_stmt)]
token_func = self.get_token_func()
if doc_ids and token_func:
docs = self.DS.get(doc_ids)
stmt = select(self.Source).where(self.Source.id == file_id)
result = session.execute(stmt).first()
if result:
item = result[0]
item.text_length = sum([len(doc.text) for doc in docs])
session.add(item)
session.commit()
item.note["tokens"] = sum([len(token_func(doc.text)) for doc in docs])
# populate the note
item.note["loader"] = self.get_from_path("loader").__class__.__name__
session.add(item)
session.commit()
return file_id
def get_token_func(self):
"""Get the token function for calculating the number of tokens"""
return _default_token_func
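
`finish` now records a per-file token count through a pluggable token function. A small sketch of that computation, using a whitespace tokenizer as a stand-in for `_default_token_func` (which may differ):

```python
# Sketch of the "tokens" note: sum the token count of every document chunk.
from typing import Callable, List


def count_tokens(texts: List[str], token_func: Callable[[str], list]) -> int:
    return sum(len(token_func(text)) for text in texts)


def whitespace_tokenizer(text: str) -> list:
    return text.split()


note = {"tokens": count_tokens(["hello world", "three more tokens"], whitespace_tokenizer)}
assert note["tokens"] == 5
```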
def delete_file(self, file_id: str):
"""Delete a file from the db, including its chunks in docstore and vectorstore
@@ -398,44 +541,24 @@ class IndexPipeline(BaseComponent):
for each in index:
if each[0].relation_type == "vector":
vs_ids.append(each[0].target_id)
else:
elif each[0].relation_type == "document":
ds_ids.append(each[0].target_id)
session.delete(each[0])
session.commit()
self.VS.delete(vs_ids)
self.DS.delete(ds_ids)
def run(self, file_path: str | Path, reindex: bool, **kwargs) -> str:
"""Index the file and return the file id"""
# check for duplication
file_path = Path(file_path).resolve()
file_id = self.get_id_if_exists(file_path)
if file_id is not None:
if not reindex:
raise ValueError(
f"File {file_path.name} already indexed. Please rerun with "
"reindex=True to force reindexing."
)
else:
# remove the existing records
self.delete_file(file_id)
file_id = self.store_file(file_path)
else:
# add record to db
file_id = self.store_file(file_path)
if vs_ids and self.VS:
self.VS.delete(vs_ids)
if ds_ids:
self.DS.delete(ds_ids)
# extract the file
extra_info = default_file_metadata_func(str(file_path))
docs = self.loader.load_data(file_path, extra_info=extra_info)
for _ in self.handle_docs(docs, file_id, file_path.name):
continue
self.finish(file_id, file_path)
return file_id
def run(
self, file_path: str | Path, reindex: bool, **kwargs
) -> tuple[str, list[Document]]:
raise NotImplementedError
def stream(
self, file_path: str | Path, reindex: bool, **kwargs
) -> Generator[Document, None, str]:
) -> Generator[Document, None, tuple[str, list[Document]]]:
# check for duplication
file_path = Path(file_path).resolve()
file_id = self.get_id_if_exists(file_path)
@@ -456,6 +579,9 @@ class IndexPipeline(BaseComponent):
# extract the file
extra_info = default_file_metadata_func(str(file_path))
extra_info["file_id"] = file_id
extra_info["collection_name"] = self.collection_name
yield Document(f" => Converting {file_path.name} to text", channel="debug")
docs = self.loader.load_data(file_path, extra_info=extra_info)
yield Document(f" => Converted {file_path.name} to text", channel="debug")
@@ -464,7 +590,7 @@ class IndexPipeline(BaseComponent):
self.finish(file_id, file_path)
yield Document(f" => Finished indexing {file_path.name}", channel="debug")
return file_id
return file_id, docs
class IndexDocumentPipeline(BaseFileIndexIndexing):
@@ -479,16 +605,54 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
decide which pipeline should be used.
"""
reader_mode: str = Param("default", help="The reader mode")
embedding: BaseEmbeddings
run_embedding_in_thread: bool = False
@Param.auto(depends_on="reader_mode")
def readers(self):
readers = deepcopy(KH_DEFAULT_FILE_EXTRACTORS)
print("reader_mode", self.reader_mode)
if self.reader_mode == "adobe":
readers[".pdf"] = adobe_reader
elif self.reader_mode == "azure-di":
readers[".pdf"] = azure_reader
dev_readers, _, _ = dev_settings()
readers.update(dev_readers)
return readers
@classmethod
def get_user_settings(cls):
return {
"reader_mode": {
"name": "File loader",
"value": "default",
"choices": [
("Default (open-source)", "default"),
("Adobe API (figure+table extraction)", "adobe"),
(
"Azure AI Document Intelligence (figure+table extraction)",
"azure-di",
),
],
"component": "dropdown",
},
}
@classmethod
def get_pipeline(cls, user_settings, index_settings) -> BaseFileIndexIndexing:
use_quick_index_mode = user_settings.get("quick_index_mode", False)
print("use_quick_index_mode", use_quick_index_mode)
obj = cls(
embedding=embedding_models_manager[
index_settings.get(
"embedding", embedding_models_manager.get_default_name()
)
]
],
run_embedding_in_thread=use_quick_index_mode,
reader_mode=user_settings.get("reader_mode", "default"),
)
return obj
@@ -497,16 +661,17 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
Can subclass this method for a more elaborate pipeline routing strategy.
"""
readers, chunk_size, chunk_overlap = dev_settings()
_, chunk_size, chunk_overlap = dev_settings()
ext = file_path.suffix
reader = readers.get(ext, KH_DEFAULT_FILE_EXTRACTORS.get(ext, None))
ext = file_path.suffix.lower()
reader = self.readers.get(ext, unstructured)
if reader is None:
raise NotImplementedError(
f"No supported pipeline to index {file_path.name}. Please specify "
"the suitable pipeline for this file type in the settings."
)
print("Using reader", reader)
pipeline: IndexPipeline = IndexPipeline(
loader=reader,
splitter=TokenSplitter(
@@ -515,50 +680,37 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
separator="\n\n",
backup_separators=["\n", ".", "\u200B"],
),
run_embedding_in_thread=self.run_embedding_in_thread,
Source=self.Source,
Index=self.Index,
VS=self.VS,
DS=self.DS,
FSPath=self.FSPath,
user_id=self.user_id,
private=self.private,
embedding=self.embedding,
)
return pipeline
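
Routing is a lookup on the lowercased file extension with a generic fallback reader. A sketch with placeholder reader names:

```python
# Extension-based reader routing; reader values are placeholders, not the
# repo's reader objects.
from pathlib import Path

READERS = {".pdf": "pdf_reader", ".xlsx": "excel_reader", ".txt": "text_reader"}


def route(file_path: str, fallback: str = "unstructured_reader") -> str:
    ext = Path(file_path).suffix.lower()
    return READERS.get(ext, fallback)


assert route("Report.PDF") == "pdf_reader"
assert route("notes.unknown") == "unstructured_reader"
```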
def run(
self, file_paths: str | Path | list[str | Path], reindex: bool = False, **kwargs
self, file_paths: str | Path | list[str | Path], *args, **kwargs
) -> tuple[list[str | None], list[str | None]]:
"""Return a list of indexed file ids, and a list of errors"""
if not isinstance(file_paths, list):
file_paths = [file_paths]
file_ids: list[str | None] = []
errors: list[str | None] = []
for file_path in file_paths:
file_path = Path(file_path)
try:
pipeline = self.route(file_path)
file_id = pipeline.run(file_path, reindex=reindex, **kwargs)
file_ids.append(file_id)
errors.append(None)
except Exception as e:
logger.error(e)
file_ids.append(None)
errors.append(str(e))
return file_ids, errors
raise NotImplementedError
def stream(
self, file_paths: str | Path | list[str | Path], reindex: bool = False, **kwargs
) -> Generator[Document, None, tuple[list[str | None], list[str | None]]]:
) -> Generator[
Document, None, tuple[list[str | None], list[str | None], list[Document]]
]:
"""Return a list of indexed file ids, and a list of errors"""
if not isinstance(file_paths, list):
file_paths = [file_paths]
file_ids: list[str | None] = []
errors: list[str | None] = []
all_docs = []
n_files = len(file_paths)
for idx, file_path in enumerate(file_paths):
file_path = Path(file_path)
@@ -569,9 +721,10 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
try:
pipeline = self.route(file_path)
file_id = yield from pipeline.stream(
file_id, docs = yield from pipeline.stream(
file_path, reindex=reindex, **kwargs
)
all_docs.extend(docs)
file_ids.append(file_id)
errors.append(None)
yield Document(
@@ -579,7 +732,7 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
channel="index",
)
except Exception as e:
logger.error(e)
logger.exception(e)
file_ids.append(None)
errors.append(str(e))
yield Document(
@@ -591,4 +744,4 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
channel="index",
)
return file_ids, errors
return file_ids, errors, all_docs

View File

@@ -1,5 +1,9 @@
import html
import os
import shutil
import tempfile
import zipfile
from copy import deepcopy
from pathlib import Path
from typing import Generator
@@ -9,8 +13,12 @@ from gradio.data_classes import FileData
from gradio.utils import NamedString
from ktem.app import BasePage
from ktem.db.engine import engine
from ktem.utils.render import Render
from sqlalchemy import select
from sqlalchemy.orm import Session
from theflow.settings import settings as flowsettings
DOWNLOAD_MESSAGE = "Press again to download"
class File(gr.File):
@@ -143,28 +151,57 @@ class FileIndexPage(BasePage):
)
gr.Markdown("## File List")
self.filter = gr.Textbox(
value="",
label="Filter by name:",
info=(
"(1) Case-insensitive. "
"(2) Search with empty string to show all files."
),
)
self.file_list_state = gr.State(value=None)
self.file_list = gr.DataFrame(
headers=["id", "name", "size", "text_length", "date_created"],
headers=[
"id",
"name",
"size",
"tokens",
"loader",
"date_created",
],
column_widths=["0%", "50%", "8%", "7%", "15%", "20%"],
interactive=False,
wrap=False,
elem_id="file_list_view",
)
with gr.Row():
self.deselect_button = gr.Button(
"Close",
visible=False,
)
self.delete_button = gr.Button(
"Delete",
variant="stop",
visible=False,
)
with gr.Row():
self.is_zipped_state = gr.State(value=False)
self.download_all_button = gr.DownloadButton(
"Download all files",
visible=True,
)
self.download_single_button = gr.DownloadButton(
"Download file",
visible=False,
)
with gr.Row() as self.selection_info:
self.selected_file_id = gr.State(value=None)
with gr.Column(scale=2):
self.selected_panel = gr.Markdown(self.selected_panel_false)
self.deselect_button = gr.Button(
"Deselect",
visible=False,
elem_classes=["right-button"],
)
self.delete_button = gr.Button(
"Delete",
variant="stop",
visible=False,
elem_classes=["right-button"],
)
self.chunks = gr.HTML(visible=False)
def on_subscribe_public_events(self):
"""Subscribe to the declared public event of the app"""
@@ -189,12 +226,58 @@ class FileIndexPage(BasePage):
)
def file_selected(self, file_id):
chunks = []
if file_id is not None:
# get the chunks
Index = self._index._resources["Index"]
with Session(engine) as session:
matches = session.execute(
select(Index).where(
Index.source_id == file_id,
Index.relation_type == "document",
)
)
doc_ids = [doc.target_id for (doc,) in matches]
docs = self._index._docstore.get(doc_ids)
docs = sorted(
docs, key=lambda x: x.metadata.get("page_label", float("inf"))
)
for idx, doc in enumerate(docs):
title = html.escape(
f"{doc.text[:50]}..." if len(doc.text) > 50 else doc.text
)
doc_type = doc.metadata.get("type", "text")
content = ""
if doc_type == "text":
content = html.escape(doc.text)
elif doc_type == "table":
content = Render.table(doc.text)
elif doc_type == "image":
content = Render.image(
url=doc.metadata.get("image_origin", ""), text=doc.text
)
header_prefix = f"[{idx+1}/{len(docs)}]"
if doc.metadata.get("page_label"):
header_prefix += f" [Page {doc.metadata['page_label']}]"
chunks.append(
Render.collapsible(
header=f"{header_prefix} {title}",
content=content,
)
)
return (
gr.update(value="".join(chunks), visible=file_id is not None),
gr.update(visible=file_id is not None),
gr.update(visible=file_id is not None),
gr.update(visible=file_id is not None),
)
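
The chunk preview builds one collapsible block per document chunk. `Render.collapsible` is the app's own helper; a `<details>`-based stand-in illustrates the idea:

```python
# Sketch of the collapsible chunk preview; not the app's Render.collapsible.
import html


def collapsible(header: str, content: str) -> str:
    return f"<details><summary>{header}</summary>{content}</details>"


docs = [{"page_label": "1", "text": "First chunk of the document ..."}]
chunks = [
    collapsible(
        header=f"[{i + 1}/{len(docs)}] [Page {doc['page_label']}] {html.escape(doc['text'][:50])}",
        content=html.escape(doc["text"]),
    )
    for i, doc in enumerate(docs)
]
print("".join(chunks))
```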
def delete_event(self, file_id):
file_name = ""
with Session(engine) as session:
source = session.execute(
select(self._index._resources["Source"]).where(
@@ -202,6 +285,7 @@ class FileIndexPage(BasePage):
)
).first()
if source:
file_name = source[0].name
session.delete(source[0])
vs_ids, ds_ids = [], []
@@ -213,15 +297,16 @@ class FileIndexPage(BasePage):
for each in index:
if each[0].relation_type == "vector":
vs_ids.append(each[0].target_id)
else:
elif each[0].relation_type == "document":
ds_ids.append(each[0].target_id)
session.delete(each[0])
session.commit()
self._index._vs.delete(vs_ids)
if vs_ids:
self._index._vs.delete(vs_ids)
self._index._docstore.delete(ds_ids)
gr.Info(f"File {file_id} has been deleted")
gr.Info(f"File {file_name} has been deleted")
return None, self.selected_panel_false
@@ -231,6 +316,57 @@ class FileIndexPage(BasePage):
gr.update(visible=False),
)
def download_single_file(self, is_zipped_state, file_id):
with Session(engine) as session:
source = session.execute(
select(self._index._resources["Source"]).where(
self._index._resources["Source"].id == file_id
)
).first()
if source:
target_file_name = Path(source[0].name)
zip_files = []
for file_name in os.listdir(flowsettings.KH_CHUNKS_OUTPUT_DIR):
if target_file_name.stem in file_name:
zip_files.append(
os.path.join(flowsettings.KH_CHUNKS_OUTPUT_DIR, file_name)
)
for file_name in os.listdir(flowsettings.KH_MARKDOWN_OUTPUT_DIR):
if target_file_name.stem in file_name:
zip_files.append(
os.path.join(flowsettings.KH_MARKDOWN_OUTPUT_DIR, file_name)
)
zip_file_path = os.path.join(
flowsettings.KH_ZIP_OUTPUT_DIR, target_file_name.stem
)
with zipfile.ZipFile(f"{zip_file_path}.zip", "w") as zipMe:
for file in zip_files:
zipMe.write(file, arcname=os.path.basename(file))
if is_zipped_state:
new_button = gr.DownloadButton(label="Download", value=None)
else:
new_button = gr.DownloadButton(
label=DOWNLOAD_MESSAGE, value=f"{zip_file_path}.zip"
)
return not is_zipped_state, new_button
def download_all_files(self):
zip_files = []
for file_name in os.listdir(flowsettings.KH_CHUNKS_OUTPUT_DIR):
zip_files.append(os.path.join(flowsettings.KH_CHUNKS_OUTPUT_DIR, file_name))
for file_name in os.listdir(flowsettings.KH_MARKDOWN_OUTPUT_DIR):
zip_files.append(
os.path.join(flowsettings.KH_MARKDOWN_OUTPUT_DIR, file_name)
)
zip_file_path = os.path.join(flowsettings.KH_ZIP_OUTPUT_DIR, "all")
with zipfile.ZipFile(f"{zip_file_path}.zip", "w") as zipMe:
for file in zip_files:
arcname = Path(file)
zipMe.write(file, arcname=arcname.name)
return gr.DownloadButton(label=DOWNLOAD_MESSAGE, value=f"{zip_file_path}.zip")
def on_register_events(self):
"""Register all events to the app"""
onDeleted = (
@@ -241,35 +377,61 @@ class FileIndexPage(BasePage):
)
.then(
fn=lambda: (None, self.selected_panel_false),
inputs=None,
inputs=[],
outputs=[self.selected_file_id, self.selected_panel],
show_progress="hidden",
)
.then(
fn=self.list_file,
inputs=[self._app.user_id],
inputs=[self._app.user_id, self.filter],
outputs=[self.file_list_state, self.file_list],
)
.then(
fn=self.file_selected,
inputs=[self.selected_file_id],
outputs=[
self.chunks,
self.deselect_button,
self.delete_button,
self.download_single_button,
],
show_progress="hidden",
)
)
for event in self._app.get_event(f"onFileIndex{self._index.id}Changed"):
onDeleted = onDeleted.then(**event)
self.deselect_button.click(
fn=lambda: (None, self.selected_panel_false),
inputs=None,
inputs=[],
outputs=[self.selected_file_id, self.selected_panel],
show_progress="hidden",
)
self.selected_panel.change(
).then(
fn=self.file_selected,
inputs=[self.selected_file_id],
outputs=[
self.chunks,
self.deselect_button,
self.delete_button,
self.download_single_button,
],
show_progress="hidden",
)
self.download_all_button.click(
fn=self.download_all_files,
inputs=[],
outputs=self.download_all_button,
show_progress="hidden",
)
self.download_single_button.click(
fn=self.download_single_file,
inputs=[self.is_zipped_state, self.selected_file_id],
outputs=[self.is_zipped_state, self.download_single_button],
show_progress="hidden",
)
onUploaded = self.upload_button.click(
fn=lambda: gr.update(visible=True),
outputs=[self.upload_progress_panel],
@@ -285,9 +447,63 @@ class FileIndexPage(BasePage):
concurrency_limit=20,
)
try:
# register the quick file upload event for the first Index only
if self._index.id == 1:
self.quick_upload_state = gr.State(value=[])
print("Setting up quick upload event")
quickUploadedEvent = (
self._app.chat_page.quick_file_upload.upload(
fn=lambda: gr.update(
value="Please wait for the indexing process "
"to complete before adding your question."
),
outputs=self._app.chat_page.quick_file_upload_status,
)
.then(
fn=self.index_fn_with_default_loaders,
inputs=[
self._app.chat_page.quick_file_upload,
gr.State(value=False),
self._app.settings_state,
self._app.user_id,
],
outputs=self.quick_upload_state,
)
.success(
fn=lambda: [
gr.update(value=None),
gr.update(value="select"),
],
outputs=[
self._app.chat_page.quick_file_upload,
self._app.chat_page._indices_input[0],
],
)
)
for event in self._app.get_event(f"onFileIndex{self._index.id}Changed"):
quickUploadedEvent = quickUploadedEvent.then(**event)
quickUploadedEvent.success(
fn=lambda x: x,
inputs=self.quick_upload_state,
outputs=self._app.chat_page._indices_input[1],
).then(
fn=lambda: gr.update(value="Indexing completed."),
outputs=self._app.chat_page.quick_file_upload_status,
).then(
fn=self.list_file,
inputs=[self._app.user_id, self.filter],
outputs=[self.file_list_state, self.file_list],
concurrency_limit=20,
)
except Exception as e:
print(e)
uploadedEvent = onUploaded.then(
fn=self.list_file,
inputs=[self._app.user_id],
inputs=[self._app.user_id, self.filter],
outputs=[self.file_list_state, self.file_list],
concurrency_limit=20,
)
@@ -309,16 +525,64 @@ class FileIndexPage(BasePage):
inputs=[self.file_list],
outputs=[self.selected_file_id, self.selected_panel],
show_progress="hidden",
).then(
fn=self.file_selected,
inputs=[self.selected_file_id],
outputs=[
self.chunks,
self.deselect_button,
self.delete_button,
self.download_single_button,
],
show_progress="hidden",
)
self.filter.submit(
fn=self.list_file,
inputs=[self._app.user_id, self.filter],
outputs=[self.file_list_state, self.file_list],
show_progress="hidden",
)
def _on_app_created(self):
"""Called when the app is created"""
self._app.app.load(
self.list_file,
inputs=[self._app.user_id],
inputs=[self._app.user_id, self.filter],
outputs=[self.file_list_state, self.file_list],
)
def _may_extract_zip(self, files, zip_dir: str):
"""Handle zip files"""
zip_files = [file for file in files if file.endswith(".zip")]
remaining_files = [file for file in files if not file.endswith("zip")]
# Clean up <zip_dir> before unzipping to remove stale files
shutil.rmtree(zip_dir, ignore_errors=True)
for zip_file in zip_files:
# Prepare a new output dir, separate for each zip file
basename = os.path.splitext(os.path.basename(zip_file))[0]
zip_out_dir = os.path.join(zip_dir, basename)
os.makedirs(zip_out_dir, exist_ok=True)
with zipfile.ZipFile(zip_file, "r") as zip_ref:
zip_ref.extractall(zip_out_dir)
n_zip_file = 0
for root, dirs, files in os.walk(zip_dir):
for file in files:
ext = os.path.splitext(file)[1]
# only allow supported file types (not zip)
if ext not in [".zip"] and ext in self._supported_file_types:
remaining_files += [os.path.join(root, file)]
n_zip_file += 1
if n_zip_file > 0:
print(f"Update zip files: {n_zip_file}")
return remaining_files
def index_fn(
self, files, reindex: bool, settings, user_id
) -> Generator[tuple[str, str], None, None]:
@@ -335,6 +599,8 @@ class FileIndexPage(BasePage):
yield "", ""
return
files = self._may_extract_zip(files, flowsettings.KH_ZIP_INPUT_DIR)
errors = self.validate(files)
if errors:
gr.Warning(", ".join(errors))
@@ -366,19 +632,61 @@ class FileIndexPage(BasePage):
debugs.append(response.text)
yield "\n".join(outputs), "\n".join(debugs)
except StopIteration as e:
result, errors = e.value
results, index_errors, docs = e.value
except Exception as e:
debugs.append(f"Error: {e}")
yield "\n".join(outputs), "\n".join(debugs)
return
n_successes = len([_ for _ in result if _])
n_successes = len([_ for _ in results if _])
if n_successes:
gr.Info(f"Successfully index {n_successes} files")
n_errors = len([_ for _ in index_errors if _])
if n_errors:
gr.Warning(f"Have errors for {n_errors} files")
return results
def index_fn_with_default_loaders(
self, files, reindex: bool, settings, user_id
) -> list["str"]:
"""Function for quick upload with default loaders
Args:
files: the list of files to be uploaded
reindex: whether to reindex the files
settings: the settings of the app
user_id: the id of the current user
"""
print("Overriding with default loaders")
exist_ids = []
to_process_files = []
for str_file_path in files:
file_path = Path(str(str_file_path))
exist_id = (
self._index.get_indexing_pipeline(settings, user_id)
.route(file_path)
.get_id_if_exists(file_path)
)
if exist_id:
exist_ids.append(exist_id)
else:
to_process_files.append(str_file_path)
returned_ids = []
settings = deepcopy(settings)
settings[f"index.options.{self._index.id}.reader_mode"] = "default"
settings[f"index.options.{self._index.id}.quick_index_mode"] = True
if to_process_files:
_iter = self.index_fn(to_process_files, reindex, settings, user_id)
try:
while next(_iter):
pass
except StopIteration as e:
returned_ids = e.value
return exist_ids + returned_ids
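
The quick-upload path drains the indexing generator manually and reads its return value out of `StopIteration`. A sketch of the pattern:

```python
# A generator's `return` value is only reachable through StopIteration.value
# (or by delegating with `yield from`).
from typing import Generator


def indexing() -> Generator[str, None, list]:
    yield "converting"
    yield "embedding"
    return ["file-id-1"]


gen = indexing()
try:
    while True:
        next(gen)
except StopIteration as stop:
    file_ids = stop.value

assert file_ids == ["file-id-1"]
```

Note that draining with `while next(_iter): pass` stops as soon as the generator yields a falsy value, which is worth keeping in mind for the loop above.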
def index_files_from_dir(
self, folder_path, reindex, settings, user_id
) -> Generator[tuple[str, str], None, None]:
@@ -452,7 +760,19 @@ class FileIndexPage(BasePage):
yield from self.index_fn(files, reindex, settings, user_id)
def list_file(self, user_id):
def format_size_human_readable(self, num: float | str, suffix="B"):
try:
num = float(num)
except ValueError:
return num
for unit in ("", "K", "M", "G", "T", "P", "E", "Z"):
if abs(num) < 1024.0:
return f"{num:3.0f}{unit}{suffix}"
num /= 1024.0
return f"{num:.0f}Yi{suffix}"
def list_file(self, user_id, name_pattern=""):
if user_id is None:
# not signed in
return [], pd.DataFrame.from_records(
@@ -461,7 +781,8 @@ class FileIndexPage(BasePage):
"id": "-",
"name": "-",
"size": "-",
"text_length": "-",
"tokens": "-",
"loader": "-",
"date_created": "-",
}
]
@@ -472,12 +793,17 @@ class FileIndexPage(BasePage):
statement = select(Source)
if self._index.config.get("private", False):
statement = statement.where(Source.user == user_id)
if name_pattern:
statement = statement.where(Source.name.ilike(f"%{name_pattern}%"))
results = [
{
"id": each[0].id,
"name": each[0].name,
"size": each[0].size,
"text_length": each[0].text_length,
"size": self.format_size_human_readable(each[0].size),
"tokens": self.format_size_human_readable(
each[0].note.get("tokens", "-"), suffix=""
),
"loader": each[0].note.get("loader", "-"),
"date_created": each[0].date_created.strftime("%Y-%m-%d %H:%M:%S"),
}
for each in session.execute(statement).all()
@@ -492,12 +818,14 @@ class FileIndexPage(BasePage):
"id": "-",
"name": "-",
"size": "-",
"text_length": "-",
"tokens": "-",
"loader": "-",
"date_created": "-",
}
]
)
print(f"{len(results)=}, {len(file_list)=}")
return results, file_list
def interact_file_list(self, list_files, ev: gr.SelectData):
@@ -561,9 +889,8 @@ class FileSelector(BasePage):
self.mode = gr.Radio(
value=default_mode,
choices=[
("Disabled", "disabled"),
("Search All", "all"),
("Select", "select"),
("Search In File(s)", "select"),
],
container=False,
)

View File

@@ -123,8 +123,11 @@ class IndexManager:
)
try:
# clean up
index.on_delete()
try:
# clean up
index.on_delete()
except Exception as e:
print(f"Error while deleting index {index.name}: {e}")
# remove from database
with Session(engine) as sess:

View File

@@ -7,6 +7,21 @@ from ktem.utils.file import YAMLNoDateSafeLoader
from .manager import IndexManager
# UGLY way to restart gradio server by updating atime
def update_current_module_atime():
import os
import time
# Define the file path
file_path = __file__
print("Updating atime for", file_path)
# Get the current time
current_time = time.time()
# Set the modified time (and access time) to the current time
os.utime(file_path, (current_time, current_time))
def format_description(cls):
user_settings = cls.get_admin_settings()
params_lines = ["| Name | Default | Description |", "| --- | --- | --- |"]
@@ -29,7 +44,7 @@ class IndexManagement(BasePage):
def on_building_ui(self):
with gr.Tab(label="View"):
self.index_list = gr.DataFrame(
headers=["ID", "Name", "Index Type"],
headers=["id", "name", "index type"],
interactive=False,
)
@@ -95,7 +110,7 @@ class IndexManagement(BasePage):
"""Called when the app is created"""
self._app.app.load(
self.list_indices,
inputs=None,
inputs=[],
outputs=[self.index_list],
)
self._app.app.load(
@@ -117,7 +132,7 @@ class IndexManagement(BasePage):
self.create_index,
inputs=[self.name, self.index_type, self.spec],
outputs=None,
).success(self.list_indices, inputs=None, outputs=[self.index_list]).success(
).success(self.list_indices, inputs=[], outputs=[self.index_list]).success(
lambda: ("", None, "", self.spec_desc_default),
outputs=[
self.name,
@@ -125,6 +140,8 @@ class IndexManagement(BasePage):
self.spec,
self.spec_desc,
],
).success(
update_current_module_atime
)
self.index_list.select(
self.select_index,
@@ -152,7 +169,7 @@ class IndexManagement(BasePage):
gr.update(visible=False),
gr.update(visible=True),
),
inputs=None,
inputs=[],
outputs=[
self.btn_edit_save,
self.btn_delete,
@@ -166,10 +183,8 @@ class IndexManagement(BasePage):
inputs=[self.selected_index_id],
outputs=[self.selected_index_id],
show_progress="hidden",
).then(
self.list_indices,
inputs=None,
outputs=[self.index_list],
).then(self.list_indices, inputs=[], outputs=[self.index_list],).success(
update_current_module_atime
)
self.btn_delete_no.click(
lambda: (
@@ -178,7 +193,7 @@ class IndexManagement(BasePage):
gr.update(visible=True),
gr.update(visible=False),
),
inputs=None,
inputs=[],
outputs=[
self.btn_edit_save,
self.btn_delete,
@@ -197,7 +212,7 @@ class IndexManagement(BasePage):
show_progress="hidden",
).then(
self.list_indices,
inputs=None,
inputs=[],
outputs=[self.index_list],
)
self.btn_close.click(
@@ -245,16 +260,16 @@ class IndexManagement(BasePage):
items = []
for item in self.manager.indices:
record = {}
record["ID"] = item.id
record["Name"] = item.name
record["Index Type"] = item.__class__.__name__
record["id"] = item.id
record["name"] = item.name
record["index type"] = item.__class__.__name__
items.append(record)
if items:
indices_list = pd.DataFrame.from_records(items)
else:
indices_list = pd.DataFrame.from_records(
[{"ID": "-", "Name": "-", "Index Type": "-"}]
[{"id": "-", "name": "-", "index type": "-"}]
)
return indices_list
@@ -268,7 +283,7 @@ class IndexManagement(BasePage):
if not ev.selected:
return -1
return int(index_list["ID"][ev.index[0]])
return int(index_list["id"][ev.index[0]])
def on_selected_index_change(self, selected_index_id: int):
"""Show the relevant index as user selects it on the UI

View File

@@ -3,7 +3,7 @@ from typing import Optional, Type, overload
from sqlalchemy import select
from sqlalchemy.orm import Session
from theflow.settings import settings as flowsettings
from theflow.utils.modules import deserialize
from theflow.utils.modules import deserialize, import_dotted_string
from kotaemon.llms import ChatLLM
@@ -38,7 +38,7 @@ class LLMManager:
def load(self):
"""Load the model pool from database"""
self._models, self._info, self._defaut = {}, {}, ""
self._models, self._info, self._default = {}, {}, ""
with Session(engine) as session:
stmt = select(LLMTable)
items = session.execute(stmt)
@@ -54,14 +54,12 @@ class LLMManager:
self._default = item.name
def load_vendors(self):
from kotaemon.llms import (
AzureChatOpenAI,
ChatOpenAI,
EndpointChatLLM,
LlamaCppChat,
)
from kotaemon.llms import AzureChatOpenAI, ChatOpenAI, LlamaCppChat
self._vendors = [ChatOpenAI, AzureChatOpenAI, LlamaCppChat, EndpointChatLLM]
self._vendors = [ChatOpenAI, AzureChatOpenAI, LlamaCppChat]
for extra_vendor in getattr(flowsettings, "KH_LLM_EXTRA_VENDORS", []):
self._vendors.append(import_dotted_string(extra_vendor, safe=False))
def __getitem__(self, key: str) -> ChatLLM:
"""Get model by name"""

View File

@@ -112,7 +112,7 @@ class LLMManagement(BasePage):
"""Called when the app is created"""
self._app.app.load(
self.list_llms,
inputs=None,
inputs=[],
outputs=[self.llm_list],
)
self._app.app.load(
@@ -140,8 +140,8 @@ class LLMManagement(BasePage):
self.btn_new.click(
self.create_llm,
inputs=[self.name, self.llm_choices, self.spec, self.default],
outputs=None,
).success(self.list_llms, inputs=None, outputs=[self.llm_list]).success(
outputs=[],
).success(self.list_llms, inputs=[], outputs=[self.llm_list]).success(
lambda: ("", None, "", False, self.spec_desc_default),
outputs=[
self.name,
@@ -176,7 +176,7 @@ class LLMManagement(BasePage):
)
self.btn_delete.click(
self.on_btn_delete_click,
inputs=None,
inputs=[],
outputs=[self.btn_delete, self.btn_delete_yes, self.btn_delete_no],
show_progress="hidden",
)
@@ -187,7 +187,7 @@ class LLMManagement(BasePage):
show_progress="hidden",
).then(
self.list_llms,
inputs=None,
inputs=[],
outputs=[self.llm_list],
)
self.btn_delete_no.click(
@@ -196,7 +196,7 @@ class LLMManagement(BasePage):
gr.update(visible=False),
gr.update(visible=False),
),
inputs=None,
inputs=[],
outputs=[self.btn_delete, self.btn_delete_yes, self.btn_delete_no],
show_progress="hidden",
)
@@ -210,7 +210,7 @@ class LLMManagement(BasePage):
show_progress="hidden",
).then(
self.list_llms,
inputs=None,
inputs=[],
outputs=[self.llm_list],
)
self.btn_close.click(

View File

@@ -44,7 +44,7 @@ class App(BaseApp):
if len(self.index_manager.indices) == 1:
for index in self.index_manager.indices:
with gr.Tab(
f"{index.name} Index",
f"{index.name}",
elem_id="indices-tab",
elem_classes=[
"fill-main-area-height",
@@ -58,7 +58,7 @@ class App(BaseApp):
setattr(self, f"_index_{index.id}", page)
elif len(self.index_manager.indices) > 1:
with gr.Tab(
"Indices",
"Files",
elem_id="indices-tab",
elem_classes=["fill-main-area-height", "scrollable", "indices-tab"],
id="indices-tab",
@@ -66,7 +66,7 @@ class App(BaseApp):
) as self._tabs["indices-tab"]:
for index in self.index_manager.indices:
with gr.Tab(
f"{index.name}",
f"{index.name} Collection",
elem_id=f"{index.id}-tab",
) as self._tabs[f"{index.id}-tab"]:
page = index.get_index_page_ui()

View File

@@ -1,15 +1,25 @@
import asyncio
import csv
from copy import deepcopy
from datetime import datetime
from pathlib import Path
from typing import Optional
import gradio as gr
from filelock import FileLock
from ktem.app import BasePage
from ktem.components import reasonings
from ktem.db.models import Conversation, engine
from ktem.index.file.ui import File
from ktem.reasoning.prompt_optimization.suggest_conversation_name import (
SuggestConvNamePipeline,
)
from plotly.io import from_json
from sqlmodel import Session, select
from theflow.settings import settings as flowsettings
from kotaemon.base import Document
from kotaemon.indices.ingests.files import KH_DEFAULT_FILE_EXTRACTORS
from .chat_panel import ChatPanel
from .chat_suggestion import ChatSuggestion
@@ -17,23 +27,49 @@ from .common import STATE
from .control import ConversationControl
from .report import ReportIssue
DEFAULT_SETTING = "(default)"
INFO_PANEL_SCALES = {True: 8, False: 4}
pdfview_js = """
function() {
// Get all links and attach click event
var links = document.getElementsByClassName("pdf-link");
for (var i = 0; i < links.length; i++) {
links[i].onclick = openModal;
}
}
"""
class ChatPage(BasePage):
def __init__(self, app):
self._app = app
self._indices_input = []
self.on_building_ui()
self._reasoning_type = gr.State(value=None)
self._llm_type = gr.State(value=None)
self._conversation_renamed = gr.State(value=False)
self.info_panel_expanded = gr.State(value=True)
def on_building_ui(self):
with gr.Row():
self.chat_state = gr.State(STATE)
with gr.Column(scale=1, elem_id="conv-settings-panel"):
self.state_chat = gr.State(STATE)
self.state_retrieval_history = gr.State([])
self.state_chat_history = gr.State([])
self.state_plot_history = gr.State([])
self.state_settings = gr.State({})
self.state_info_panel = gr.State("")
self.state_plot_panel = gr.State(None)
with gr.Column(scale=1, elem_id="conv-settings-panel") as self.conv_column:
self.chat_control = ConversationControl(self._app)
if getattr(flowsettings, "KH_FEATURE_CHAT_SUGGESTION", False):
self.chat_suggestion = ChatSuggestion(self._app)
for index in self._app.index_manager.indices:
for index_id, index in enumerate(self._app.index_manager.indices):
index.selector = None
index_ui = index.get_selector_component_ui()
if not index_ui:
@@ -41,7 +77,9 @@ class ChatPage(BasePage):
continue
index_ui.unrender() # need to rerender later within Accordion
with gr.Accordion(label=f"{index.name} Index", open=True):
with gr.Accordion(
label=f"{index.name} Collection", open=index_id < 1
):
index_ui.render()
gr_index = index_ui.as_gradio_component()
if gr_index:
@@ -60,14 +98,66 @@ class ChatPage(BasePage):
self._indices_input.append(gr_index)
setattr(self, f"_index_{index.id}", index_ui)
if len(self._app.index_manager.indices) > 0:
with gr.Accordion(label="Quick Upload") as _:
self.quick_file_upload = File(
file_types=list(KH_DEFAULT_FILE_EXTRACTORS.keys()),
file_count="multiple",
container=True,
show_label=False,
)
self.quick_file_upload_status = gr.Markdown()
self.report_issue = ReportIssue(self._app)
with gr.Column(scale=6, elem_id="chat-area"):
self.chat_panel = ChatPanel(self._app)
with gr.Column(scale=3, elem_id="chat-info-panel"):
with gr.Row():
with gr.Accordion(label="Chat settings", open=False):
# quick switches for the reasoning method and model options
with gr.Row():
gr.HTML("Reasoning method")
gr.HTML("Model")
with gr.Row():
reasoning_type_values = [
(DEFAULT_SETTING, DEFAULT_SETTING)
] + self._app.default_settings.reasoning.settings[
"use"
].choices
self.reasoning_types = gr.Dropdown(
choices=reasoning_type_values,
value=DEFAULT_SETTING,
container=False,
show_label=False,
)
self.model_types = gr.Dropdown(
choices=self._app.default_settings.reasoning.options[
"simple"
]
.settings["llm"]
.choices,
value="",
container=False,
show_label=False,
)
with gr.Column(
scale=INFO_PANEL_SCALES[False], elem_id="chat-info-panel"
) as self.info_column:
with gr.Accordion(label="Information panel", open=True):
self.info_panel = gr.HTML()
self.modal = gr.HTML("<div id='pdf-modal'></div>")
self.plot_panel = gr.Plot(visible=False)
self.info_panel = gr.HTML(elem_id="html-info-panel")
def _json_to_plot(self, json_dict: dict | None):
if json_dict:
plot = from_json(json_dict)
plot = gr.update(visible=True, value=plot)
else:
plot = gr.update(visible=False)
return plot
def on_register_events(self):
gr.on(
@@ -98,27 +188,75 @@ class ChatPage(BasePage):
self.chat_control.conversation_id,
self.chat_panel.chatbot,
self._app.settings_state,
self.chat_state,
self._reasoning_type,
self._llm_type,
self.state_chat,
self._app.user_id,
]
+ self._indices_input,
outputs=[
self.chat_panel.chatbot,
self.info_panel,
self.chat_state,
self.plot_panel,
self.state_plot_panel,
self.state_chat,
],
concurrency_limit=20,
show_progress="minimal",
).success(
fn=self.backup_original_info,
inputs=[
self.chat_panel.chatbot,
self._app.settings_state,
self.info_panel,
self.state_chat_history,
],
outputs=[
self.state_chat_history,
self.state_settings,
self.state_info_panel,
],
).then(
fn=self.update_data_source,
fn=self.persist_data_source,
inputs=[
self.chat_control.conversation_id,
self._app.user_id,
self.info_panel,
self.state_plot_panel,
self.state_retrieval_history,
self.state_plot_history,
self.chat_panel.chatbot,
self.chat_state,
self.state_chat,
]
+ self._indices_input,
outputs=None,
outputs=[
self.state_retrieval_history,
self.state_plot_history,
],
concurrency_limit=20,
).success(
fn=self.check_and_suggest_name_conv,
inputs=self.chat_panel.chatbot,
outputs=[
self.chat_control.conversation_rn,
self._conversation_renamed,
],
).success(
self.chat_control.rename_conv,
inputs=[
self.chat_control.conversation_id,
self.chat_control.conversation_rn,
self._conversation_renamed,
self._app.user_id,
],
outputs=[
self.chat_control.conversation,
self.chat_control.conversation,
self.chat_control.conversation_rn,
],
show_progress="hidden",
).then(
fn=None, inputs=None, outputs=None, js=pdfview_js
)
self.chat_panel.regen_btn.click(
@@ -127,33 +265,90 @@ class ChatPage(BasePage):
self.chat_control.conversation_id,
self.chat_panel.chatbot,
self._app.settings_state,
self.chat_state,
self._reasoning_type,
self._llm_type,
self.state_chat,
self._app.user_id,
]
+ self._indices_input,
outputs=[
self.chat_panel.chatbot,
self.info_panel,
self.chat_state,
self.plot_panel,
self.state_plot_panel,
self.state_chat,
],
concurrency_limit=20,
show_progress="minimal",
).then(
fn=self.update_data_source,
fn=self.persist_data_source,
inputs=[
self.chat_control.conversation_id,
self._app.user_id,
self.info_panel,
self.state_plot_panel,
self.state_retrieval_history,
self.state_plot_history,
self.chat_panel.chatbot,
self.chat_state,
self.state_chat,
]
+ self._indices_input,
outputs=None,
outputs=[
self.state_retrieval_history,
self.state_plot_history,
],
concurrency_limit=20,
).success(
fn=self.check_and_suggest_name_conv,
inputs=self.chat_panel.chatbot,
outputs=[
self.chat_control.conversation_rn,
self._conversation_renamed,
],
).success(
self.chat_control.rename_conv,
inputs=[
self.chat_control.conversation_id,
self.chat_control.conversation_rn,
self._conversation_renamed,
self._app.user_id,
],
outputs=[
self.chat_control.conversation,
self.chat_control.conversation,
self.chat_control.conversation_rn,
],
show_progress="hidden",
).then(
fn=None, inputs=None, outputs=None, js=pdfview_js
)
self.chat_control.btn_info_expand.click(
fn=lambda is_expanded: (
gr.update(scale=INFO_PANEL_SCALES[is_expanded]),
not is_expanded,
),
inputs=self.info_panel_expanded,
outputs=[self.info_column, self.info_panel_expanded],
)
self.chat_panel.chatbot.like(
fn=self.is_liked,
inputs=[self.chat_control.conversation_id],
outputs=None,
).success(
self.save_log,
inputs=[
self.chat_control.conversation_id,
self.chat_panel.chatbot,
self._app.settings_state,
self.info_panel,
self.state_chat_history,
self.state_settings,
self.state_info_panel,
gr.State(getattr(flowsettings, "KH_APP_DATA_DIR", "logs")),
],
outputs=None,
)
self.chat_control.btn_new.click(
@@ -163,17 +358,25 @@ class ChatPage(BasePage):
show_progress="hidden",
).then(
self.chat_control.select_conv,
inputs=[self.chat_control.conversation],
inputs=[self.chat_control.conversation, self._app.user_id],
outputs=[
self.chat_control.conversation_id,
self.chat_control.conversation,
self.chat_control.conversation_rn,
self.chat_panel.chatbot,
self.info_panel,
self.chat_state,
self.state_plot_panel,
self.state_retrieval_history,
self.state_plot_history,
self.chat_control.cb_is_public,
self.state_chat,
]
+ self._indices_input,
show_progress="hidden",
).then(
fn=self._json_to_plot,
inputs=self.state_plot_panel,
outputs=self.plot_panel,
)
self.chat_control.btn_del.click(
@@ -188,17 +391,25 @@ class ChatPage(BasePage):
show_progress="hidden",
).then(
self.chat_control.select_conv,
inputs=[self.chat_control.conversation],
inputs=[self.chat_control.conversation, self._app.user_id],
outputs=[
self.chat_control.conversation_id,
self.chat_control.conversation,
self.chat_control.conversation_rn,
self.chat_panel.chatbot,
self.info_panel,
self.chat_state,
self.state_plot_panel,
self.state_retrieval_history,
self.state_plot_history,
self.chat_control.cb_is_public,
self.state_chat,
]
+ self._indices_input,
show_progress="hidden",
).then(
fn=self._json_to_plot,
inputs=self.state_plot_panel,
outputs=self.plot_panel,
).then(
lambda: self.toggle_delete(""),
outputs=[self.chat_control._new_delete, self.chat_control._delete_confirm],
@@ -207,33 +418,80 @@ class ChatPage(BasePage):
lambda: self.toggle_delete(""),
outputs=[self.chat_control._new_delete, self.chat_control._delete_confirm],
)
self.chat_control.conversation_rn_btn.click(
self.chat_control.btn_conversation_rn.click(
lambda: gr.update(visible=True),
outputs=[
self.chat_control.conversation_rn,
],
)
self.chat_control.conversation_rn.submit(
self.chat_control.rename_conv,
inputs=[
self.chat_control.conversation_id,
self.chat_control.conversation_rn,
gr.State(value=True),
self._app.user_id,
],
outputs=[self.chat_control.conversation, self.chat_control.conversation],
outputs=[
self.chat_control.conversation,
self.chat_control.conversation,
self.chat_control.conversation_rn,
],
show_progress="hidden",
)
self.chat_control.conversation.select(
self.chat_control.select_conv,
inputs=[self.chat_control.conversation],
inputs=[self.chat_control.conversation, self._app.user_id],
outputs=[
self.chat_control.conversation_id,
self.chat_control.conversation,
self.chat_control.conversation_rn,
self.chat_panel.chatbot,
self.info_panel,
self.chat_state,
self.state_plot_panel,
self.state_retrieval_history,
self.state_plot_history,
self.chat_control.cb_is_public,
self.state_chat,
]
+ self._indices_input,
show_progress="hidden",
).then(
fn=self._json_to_plot,
inputs=self.state_plot_panel,
outputs=self.plot_panel,
).then(
lambda: self.toggle_delete(""),
outputs=[self.chat_control._new_delete, self.chat_control._delete_confirm],
).then(
fn=None, inputs=None, outputs=None, js=pdfview_js
)
# evidence display on message selection
self.chat_panel.chatbot.select(
self.message_selected,
inputs=[
self.state_retrieval_history,
self.state_plot_history,
],
outputs=[
self.info_panel,
self.state_plot_panel,
],
).then(
fn=self._json_to_plot,
inputs=self.state_plot_panel,
outputs=self.plot_panel,
).then(
fn=None, inputs=None, outputs=None, js=pdfview_js
)
self.chat_control.cb_is_public.change(
self.on_set_public_conversation,
inputs=[self.chat_control.cb_is_public, self.chat_control.conversation],
outputs=None,
show_progress="hidden",
)
self.report_issue.report_btn.click(
@@ -247,11 +505,26 @@ class ChatPage(BasePage):
self._app.settings_state,
self._app.user_id,
self.info_panel,
self.chat_state,
self.state_chat,
]
+ self._indices_input,
outputs=None,
)
self.reasoning_types.change(
self.reasoning_changed,
inputs=[self.reasoning_types],
outputs=[self._reasoning_type],
)
self.model_types.change(
lambda x: x,
inputs=[self.model_types],
outputs=[self._llm_type],
)
self.chat_control.conversation_id.change(
lambda: gr.update(visible=False),
outputs=self.plot_panel,
)
if getattr(flowsettings, "KH_FEATURE_CHAT_SUGGESTION", False):
self.chat_suggestion.example.select(
self.chat_suggestion.select_example,
@@ -291,6 +564,28 @@ class ChatPage(BasePage):
else:
return gr.update(visible=True), gr.update(visible=False)
def on_set_public_conversation(self, is_public, convo_id):
if not convo_id:
gr.Warning("No conversation selected")
return
with Session(engine) as session:
statement = select(Conversation).where(Conversation.id == convo_id)
result = session.exec(statement).one()
name = result.name
if result.is_public != is_public:
# Only update when the user selects a value
# different from the current one
result.is_public = is_public
session.add(result)
session.commit()
gr.Info(
f"Conversation: {name} is {'public' if is_public else 'private'}."
)
def on_subscribe_public_events(self):
if self._app.f_user_management:
self._app.subscribe_event(
@@ -306,25 +601,53 @@ class ChatPage(BasePage):
self._app.subscribe_event(
name="onSignOut",
definition={
"fn": lambda: self.chat_control.select_conv(""),
"fn": lambda: self.chat_control.select_conv("", None),
"outputs": [
self.chat_control.conversation_id,
self.chat_control.conversation,
self.chat_control.conversation_rn,
self.chat_panel.chatbot,
self.info_panel,
self.state_plot_panel,
self.state_retrieval_history,
self.state_plot_history,
self.chat_control.cb_is_public,
]
+ self._indices_input,
"show_progress": "hidden",
},
)
def update_data_source(self, convo_id, messages, state, *selecteds):
def persist_data_source(
self,
convo_id,
user_id,
retrieval_msg,
plot_data,
retrival_history,
plot_history,
messages,
state,
*selecteds,
):
"""Update the data source"""
if not convo_id:
gr.Warning("No conversation selected")
return
# if not regen, then append the new message
if not state["app"].get("regen", False):
retrival_history = retrival_history + [retrieval_msg]
plot_history = plot_history + [plot_data]
else:
if retrival_history:
print("Updating retrieval history (regen=True)")
retrival_history[-1] = retrieval_msg
plot_history[-1] = plot_data
# reset regen state
state["app"]["regen"] = False
selecteds_ = {}
for index in self._app.index_manager.indices:
if index.selector is None:
@@ -339,15 +662,29 @@ class ChatPage(BasePage):
result = session.exec(statement).one()
data_source = result.data_source
old_selecteds = data_source.get("selected", {})
is_owner = result.user == user_id
# Write down to db
result.data_source = {
"selected": selecteds_,
"selected": selecteds_ if is_owner else old_selecteds,
"messages": messages,
"retrieval_messages": retrival_history,
"plot_history": plot_history,
"state": state,
"likes": deepcopy(data_source.get("likes", [])),
}
session.add(result)
session.commit()
return retrival_history, plot_history
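
The retrieval/plot history bookkeeping appends an entry on a normal turn and replaces the last entry on a regenerated turn. A sketch of that rule:

```python
# Append-or-replace-last pattern used for the retrieval and plot histories.
def update_history(history: list, new_entry, regen: bool) -> list:
    if not regen:
        return history + [new_entry]
    if history:
        history[-1] = new_entry
    return history


history = update_history([], "evidence for turn 1", regen=False)
history = update_history(history, "evidence for turn 1 (regenerated)", regen=True)
assert history == ["evidence for turn 1 (regenerated)"]
```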
def reasoning_changed(self, reasoning_type):
if reasoning_type != DEFAULT_SETTING:
# override app settings state (temporary)
gr.Info("Reasoning type changed to `{}`".format(reasoning_type))
return reasoning_type
def is_liked(self, convo_id, liked: gr.LikeData):
with Session(engine) as session:
statement = select(Conversation).where(Conversation.id == convo_id)
@@ -362,7 +699,19 @@ class ChatPage(BasePage):
session.add(result)
session.commit()
def create_pipeline(self, settings: dict, state: dict, user_id: int, *selecteds):
def message_selected(self, retrieval_history, plot_history, msg: gr.SelectData):
index = msg.index[0]
return retrieval_history[index], plot_history[index]
def create_pipeline(
self,
settings: dict,
session_reasoning_type: str,
session_llm: str,
state: dict,
user_id: int,
*selecteds,
):
"""Create the pipeline from settings
Args:
@@ -374,10 +723,23 @@ class ChatPage(BasePage):
Returns:
- the pipeline objects
"""
reasoning_mode = settings["reasoning.use"]
# override reasoning_mode by temporary chat page state
print("Session reasoning type", session_reasoning_type)
print("Session LLM", session_llm)
reasoning_mode = (
settings["reasoning.use"]
if session_reasoning_type in (DEFAULT_SETTING, None)
else session_reasoning_type
)
reasoning_cls = reasonings[reasoning_mode]
print("Reasoning class", reasoning_cls)
reasoning_id = reasoning_cls.get_info()["id"]
settings = deepcopy(settings)
llm_setting_key = f"reasoning.options.{reasoning_id}.llm"
if llm_setting_key in settings and session_llm not in (DEFAULT_SETTING, None):
settings[llm_setting_key] = session_llm
# get retrievers
retrievers = []
for index in self._app.index_manager.indices:
@@ -403,7 +765,15 @@ class ChatPage(BasePage):
return pipeline, reasoning_state
def chat_fn(
self, conversation_id, chat_history, settings, state, user_id, *selecteds
self,
conversation_id,
chat_history,
settings,
reasoning_type,
llm_type,
state,
user_id,
*selecteds,
):
"""Chat function"""
chat_input = chat_history[-1][0]
@@ -413,18 +783,23 @@ class ChatPage(BasePage):
# construct the pipeline
pipeline, reasoning_state = self.create_pipeline(
settings, state, user_id, *selecteds
settings, reasoning_type, llm_type, state, user_id, *selecteds
)
print("Reasoning state", reasoning_state)
pipeline.set_output_queue(queue)
text, refs = "", ""
text, refs, plot, plot_gr = "", "", None, gr.update(visible=False)
msg_placeholder = getattr(
flowsettings, "KH_CHAT_MSG_PLACEHOLDER", "Thinking ..."
)
print(msg_placeholder)
yield chat_history + [(chat_input, text or msg_placeholder)], refs, state
len_ref = -1 # for logging purpose
yield (
chat_history + [(chat_input, text or msg_placeholder)],
refs,
plot_gr,
plot,
state,
)
for response in pipeline.stream(chat_input, conversation_id, chat_history):
@@ -446,22 +821,42 @@ class ChatPage(BasePage):
else:
refs += response.content
if len(refs) > len_ref:
print(f"Len refs: {len(refs)}")
len_ref = len(refs)
if response.channel == "plot":
plot = response.content
plot_gr = self._json_to_plot(plot)
state[pipeline.get_info()["id"]] = reasoning_state["pipeline"]
yield chat_history + [(chat_input, text or msg_placeholder)], refs, state
yield (
chat_history + [(chat_input, text or msg_placeholder)],
refs,
plot_gr,
plot,
state,
)
if not text:
empty_msg = getattr(
flowsettings, "KH_CHAT_EMPTY_MSG_PLACEHOLDER", "(Sorry, I don't know)"
)
print(f"Generate nothing: {empty_msg}")
yield chat_history + [(chat_input, text or empty_msg)], refs, state
yield (
chat_history + [(chat_input, text or empty_msg)],
refs,
plot_gr,
plot,
state,
)
def regen_fn(
self, conversation_id, chat_history, settings, state, user_id, *selecteds
self,
conversation_id,
chat_history,
settings,
reasoning_type,
llm_type,
state,
user_id,
*selecteds,
):
"""Regen function"""
if not chat_history:
@@ -470,11 +865,119 @@ class ChatPage(BasePage):
return
state["app"]["regen"] = True
for chat, refs, state in self.chat_fn(
conversation_id, chat_history, settings, state, user_id, *selecteds
):
new_state = deepcopy(state)
new_state["app"]["regen"] = False
yield chat, refs, new_state
yield from self.chat_fn(
conversation_id,
chat_history,
settings,
reasoning_type,
llm_type,
state,
user_id,
*selecteds,
)
state["app"]["regen"] = False
def check_and_suggest_name_conv(self, chat_history):
suggest_pipeline = SuggestConvNamePipeline()
new_name = gr.update()
renamed = False
# check if this is a newly created conversation
if len(chat_history) == 1:
suggested_name = suggest_pipeline(chat_history).text[:40]
new_name = gr.update(value=suggested_name)
renamed = True
return new_name, renamed
def backup_original_info(
self, chat_history, settings, info_pannel, original_chat_history
):
original_chat_history.append(chat_history[-1])
return original_chat_history, settings, info_pannel
def save_log(
self,
conversation_id,
chat_history,
settings,
info_panel,
original_chat_history,
original_settings,
original_info_panel,
log_dir,
):
if not Path(log_dir).exists():
Path(log_dir).mkdir(parents=True)
lock = FileLock(Path(log_dir) / ".lock")
# get current date
today = datetime.now()
formatted_date = today.strftime("%d%m%Y_%H")
with Session(engine) as session:
statement = select(Conversation).where(Conversation.id == conversation_id)
result = session.exec(statement).one()
data_source = deepcopy(result.data_source)
likes = data_source.get("likes", [])
if not likes:
return
feedback = likes[-1][-1]
message_index = likes[-1][0]
current_message = chat_history[message_index[0]]
original_message = original_chat_history[message_index[0]]
is_original = all(
[
current_item == original_item
for current_item, original_item in zip(
current_message, original_message
)
]
)
dataframe = [
[
conversation_id,
message_index,
current_message[0],
current_message[1],
chat_history,
settings,
info_panel,
feedback,
is_original,
original_message[1],
original_chat_history,
original_settings,
original_info_panel,
]
]
with lock:
log_file = Path(log_dir) / f"{formatted_date}_log.csv"
is_log_file_exist = log_file.is_file()
with open(log_file, "a") as f:
writer = csv.writer(f)
# write headers
if not is_log_file_exist:
writer.writerow(
[
"Conversation ID",
"Message ID",
"Question",
"Answer",
"Chat History",
"Settings",
"Evidences",
"Feedback",
"Original/ Rewritten",
"Original Answer",
"Original Chat History",
"Original Settings",
"Original Evidences",
]
)
writer.writerows(dataframe)
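
The feedback log is an append-only CSV whose header is written only when the file is first created; the `filelock` lock seen above serializes concurrent writers. A stdlib-only sketch of the write path (the lock is elided here):

```python
# Append-only CSV logging with a header written once on file creation.
import csv
from pathlib import Path


def append_log_row(log_file: Path, header: list, row: list) -> None:
    write_header = not log_file.is_file()
    with open(log_file, "a", newline="") as f:
        writer = csv.writer(f)
        if write_header:
            writer.writerow(header)
        writer.writerow(row)


append_log_row(
    Path("feedback_log.csv"),
    header=["Conversation ID", "Question", "Answer", "Feedback"],
    row=["conv-1", "What is X?", "X is ...", True],
)
```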

View File

@@ -1,13 +1,20 @@
import logging
import os
import gradio as gr
from ktem.app import BasePage
from ktem.db.models import Conversation, engine
from sqlmodel import Session, select
from ktem.db.models import Conversation, User, engine
from sqlmodel import Session, or_, select
import flowsettings
from ...utils.conversation import sync_retrieval_n_message
from .common import STATE
logger = logging.getLogger(__name__)
ASSETS_DIR = "assets/icons"
if not os.path.isdir(ASSETS_DIR):
ASSETS_DIR = "libs/ktem/ktem/assets/icons"
def is_conv_name_valid(name):
@@ -35,14 +42,47 @@ class ConversationControl(BasePage):
label="Chat sessions",
choices=[],
container=False,
filterable=False,
filterable=True,
interactive=True,
elem_classes=["unset-overflow"],
)
with gr.Row() as self._new_delete:
self.btn_new = gr.Button(value="New", min_width=10, variant="primary")
self.btn_del = gr.Button(value="Delete", min_width=10, variant="stop")
self.btn_new = gr.Button(
value="",
icon=f"{ASSETS_DIR}/new.svg",
min_width=2,
scale=1,
size="sm",
elem_classes=["no-background", "body-text-color"],
)
self.btn_del = gr.Button(
value="",
icon=f"{ASSETS_DIR}/delete.svg",
min_width=2,
scale=1,
size="sm",
elem_classes=["no-background", "body-text-color"],
)
self.btn_conversation_rn = gr.Button(
value="",
icon=f"{ASSETS_DIR}/rename.svg",
min_width=2,
scale=1,
size="sm",
elem_classes=["no-background", "body-text-color"],
)
self.btn_info_expand = gr.Button(
value="",
icon=f"{ASSETS_DIR}/sidebar.svg",
min_width=2,
scale=1,
size="sm",
elem_classes=["no-background", "body-text-color"],
)
self.cb_is_public = gr.Checkbox(
value=False, label="Shared", min_width=10, scale=4
)
with gr.Row(visible=False) as self._delete_confirm:
self.btn_del_conf = gr.Button(
@@ -54,28 +94,60 @@ class ConversationControl(BasePage):
with gr.Row():
self.conversation_rn = gr.Text(
label="(Enter) to save",
placeholder="Conversation name",
container=False,
container=True,
scale=5,
min_width=10,
interactive=True,
)
self.conversation_rn_btn = gr.Button(
value="Rename",
scale=1,
min_width=10,
elem_classes=["no-background", "body-text-color", "bold-text"],
visible=False,
)
def load_chat_history(self, user_id):
"""Reload chat history"""
# If the user is an admin, they can also see
# public conversations
can_see_public: bool = False
with Session(engine) as session:
statement = select(User).where(User.id == user_id)
result = session.exec(statement).one_or_none()
if result is not None:
if flowsettings.KH_USER_CAN_SEE_PUBLIC:
can_see_public = (
result.username == flowsettings.KH_USER_CAN_SEE_PUBLIC
)
else:
can_see_public = True
print(f"User-id: {user_id}, can see public conversations: {can_see_public}")
options = []
with Session(engine) as session:
statement = (
select(Conversation)
.where(Conversation.user == user_id)
.order_by(Conversation.date_created.desc()) # type: ignore
)
# Define condition based on admin role:
# - can_see: can see their own conversations & public conversations
# - can_not_see: only see their own conversations
if can_see_public:
statement = (
select(Conversation)
.where(
or_(
Conversation.user == user_id,
Conversation.is_public,
)
)
.order_by(
Conversation.is_public.desc(), Conversation.date_created.desc()
) # type: ignore
)
else:
statement = (
select(Conversation)
.where(Conversation.user == user_id)
.order_by(Conversation.date_created.desc()) # type: ignore
)
results = session.exec(statement).all()
for result in results:
options.append((result.name, result.id))
@@ -129,7 +201,7 @@ class ConversationControl(BasePage):
else:
return None, gr.update(value=None, choices=[])
def select_conv(self, conversation_id):
def select_conv(self, conversation_id, user_id):
"""Select the conversation"""
with Session(engine) as session:
statement = select(Conversation).where(Conversation.id == conversation_id)
@@ -137,18 +209,46 @@ class ConversationControl(BasePage):
result = session.exec(statement).one()
id_ = result.id
name = result.name
selected = result.data_source.get("selected", {})
is_conv_public = result.is_public
# reset the file selection state if the current user
# is not the owner of the conversation
if user_id == result.user:
selected = result.data_source.get("selected", {})
else:
selected = {}
chats = result.data_source.get("messages", [])
info_panel = ""
retrieval_history: list[str] = result.data_source.get(
"retrieval_messages", []
)
plot_history: list[dict] = result.data_source.get("plot_history", [])
# On initialization,
# ensure the retrieval history and the message history have equal length
retrieval_history = sync_retrieval_n_message(chats, retrieval_history)
info_panel = (
retrieval_history[-1]
if retrieval_history
else "<h5><b>No evidence found.</b></h5>"
)
plot_data = plot_history[-1] if plot_history else None
state = result.data_source.get("state", STATE)
except Exception as e:
logger.warning(e)
id_ = ""
name = ""
selected = {}
chats = []
retrieval_history = []
plot_history = []
info_panel = ""
plot_data = None
state = STATE
is_conv_public = False
indices = []
for index in self._app.index_manager.indices:
@@ -160,10 +260,29 @@ class ConversationControl(BasePage):
if isinstance(index.selector, tuple):
indices.extend(selected.get(str(index.id), index.default_selector))
return id_, id_, name, chats, info_panel, state, *indices
return (
id_,
id_,
name,
chats,
info_panel,
plot_data,
retrieval_history,
plot_history,
is_conv_public,
state,
*indices,
)
def rename_conv(self, conversation_id, new_name, user_id):
def rename_conv(self, conversation_id, new_name, is_renamed, user_id):
"""Rename the conversation"""
if not is_renamed:
return (
gr.update(),
conversation_id,
gr.update(visible=False),
)
if user_id is None:
gr.Warning("Please sign in first (Settings → User Settings)")
return gr.update(), ""
@@ -185,7 +304,12 @@ class ConversationControl(BasePage):
session.commit()
history = self.load_chat_history(user_id)
return gr.update(choices=history), conversation_id
gr.Info("Conversation renamed.")
return (
gr.update(choices=history),
conversation_id,
gr.update(visible=False),
)
def _on_app_created(self):
"""Reload the conversation once the app is created"""


@@ -12,7 +12,7 @@ class ReportIssue(BasePage):
self.on_building_ui()
def on_building_ui(self):
with gr.Accordion(label="Report", open=False):
with gr.Accordion(label="Feedback", open=False):
self.correctness = gr.Radio(
choices=[
("The answer is correct", "correct"),


@@ -9,6 +9,7 @@ from theflow.settings import settings
def get_remote_doc(url: str) -> str:
try:
res = requests.get(url)
res.raise_for_status()
return res.text
except Exception as e:
print(f"Failed to fetch document from {url}: {e}")


@@ -7,9 +7,9 @@ from sqlmodel import Session, select
fetch_creds = """
function() {
const username = getStorage('username')
const password = getStorage('password')
return [username, password];
const username = getStorage('username', '')
const password = getStorage('password', '')
return [username, password, null];
}
"""


@@ -15,18 +15,18 @@ class ResourcesTab(BasePage):
self.on_building_ui()
def on_building_ui(self):
if self._app.f_user_management:
with gr.Tab("User Management", visible=False) as self.user_management_tab:
self.user_management = UserManagement(self._app)
with gr.Tab("Index Collections") as self.index_management_tab:
self.index_management = IndexManagement(self._app)
with gr.Tab("LLMs") as self.llm_management_tab:
self.llm_management = LLMManagement(self._app)
with gr.Tab("Embedding Models") as self.emb_management_tab:
with gr.Tab("Embeddings") as self.emb_management_tab:
self.emb_management = EmbeddingManagement(self._app)
with gr.Tab("Index Management") as self.index_management_tab:
self.index_management = IndexManagement(self._app)
if self._app.f_user_management:
with gr.Tab("Users", visible=False) as self.user_management_tab:
self.user_management = UserManagement(self._app)
def on_subscribe_public_events(self):
if self._app.f_user_management:


@@ -94,6 +94,28 @@ def validate_password(pwd, pwd_cnf):
return ""
def create_user(usn, pwd) -> bool:
with Session(engine) as session:
statement = select(User).where(User.username_lower == usn.lower())
result = session.exec(statement).all()
if result:
print(f'User "{usn}" already exists')
return False
else:
hashed_password = hashlib.sha256(pwd.encode()).hexdigest()
user = User(
username=usn,
username_lower=usn.lower(),
password=hashed_password,
admin=True,
)
session.add(user)
session.commit()
return True
class UserManagement(BasePage):
def __init__(self, app):
self._app = app
@@ -105,23 +127,9 @@ class UserManagement(BasePage):
usn = flowsettings.KH_FEATURE_USER_MANAGEMENT_ADMIN
pwd = flowsettings.KH_FEATURE_USER_MANAGEMENT_PASSWORD
with Session(engine) as session:
statement = select(User).where(User.username_lower == usn.lower())
result = session.exec(statement).all()
if result:
print(f'User "{usn}" already exists')
else:
hashed_password = hashlib.sha256(pwd.encode()).hexdigest()
user = User(
username=usn,
username_lower=usn.lower(),
password=hashed_password,
admin=True,
)
session.add(user)
session.commit()
gr.Info(f'User "{usn}" created successfully')
is_created = create_user(usn, pwd)
if is_created:
gr.Info(f'User "{usn}" created successfully')
def on_building_ui(self):
with gr.Tab(label="User list"):
@@ -224,7 +232,7 @@ class UserManagement(BasePage):
gr.update(visible=False),
gr.update(visible=False),
),
inputs=None,
inputs=[],
outputs=[self.btn_delete, self.btn_delete_yes, self.btn_delete_no],
show_progress="hidden",
)
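An illustrative call of the extracted helper (credentials are placeholders, not part of the diff):

# True when the account is created, False when the username already exists.
if create_user("admin", "change-me"):
    print('User "admin" created')
else:
    print('User "admin" already exists')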


@@ -2,13 +2,15 @@ import hashlib
import gradio as gr
from ktem.app import BasePage
from ktem.components import reasonings
from ktem.db.models import Settings, User, engine
from sqlmodel import Session, select
signout_js = """
function() {
function(u, c, pw, pwc) {
removeFromStorage('username');
removeFromStorage('password');
return [u, c, pw, pwc];
}
"""
@@ -72,6 +74,10 @@ class SettingsPage(BasePage):
self._components = {}
self._reasoning_mode = {}
# store llms and embeddings components
self._llms = []
self._embeddings = []
# render application page if there are application settings
self._render_app_tab = False
if self._default_settings.application.settings:
@@ -101,14 +107,13 @@ class SettingsPage(BasePage):
def on_building_ui(self):
if self._app.f_user_management:
with gr.Tab("Users"):
with gr.Tab("User settings"):
self.user_tab()
with gr.Tab("General"):
self.app_tab()
with gr.Tab("Document Indices"):
self.index_tab()
with gr.Tab("Reasoning Pipelines"):
self.reasoning_tab()
self.app_tab()
self.index_tab()
self.reasoning_tab()
self.setting_save_btn = gr.Button(
"Save changes", variant="primary", scale=1, elem_classes=["right-button"]
)
@@ -192,7 +197,7 @@ class SettingsPage(BasePage):
)
onSignOutClick = self.signout.click(
lambda: (None, "Current user: ___", "", ""),
inputs=None,
inputs=[],
outputs=[
self._user_id,
self.current_name,
@@ -248,10 +253,14 @@ class SettingsPage(BasePage):
return "", ""
def app_tab(self):
with gr.Tab("General application settings", visible=self._render_app_tab):
with gr.Tab("General", visible=self._render_app_tab):
for n, si in self._default_settings.application.settings.items():
obj = render_setting_item(si, si.value)
self._components[f"application.{n}"] = obj
if si.special_type == "llm":
self._llms.append(obj)
if si.special_type == "embedding":
self._embeddings.append(obj)
def index_tab(self):
# TODO: double check if we need general
@@ -260,12 +269,18 @@ class SettingsPage(BasePage):
# obj = render_setting_item(si, si.value)
# self._components[f"index.{n}"] = obj
with gr.Tab("Index settings", visible=self._render_index_tab):
id2name = {k: v.name for k, v in self._app.index_manager.info().items()}
with gr.Tab("Retrieval settings", visible=self._render_index_tab):
for pn, sig in self._default_settings.index.options.items():
with gr.Tab(f"Index {pn}"):
name = "{} Collection".format(id2name.get(pn, f"<id {pn}>"))
with gr.Tab(name):
for n, si in sig.settings.items():
obj = render_setting_item(si, si.value)
self._components[f"index.options.{pn}.{n}"] = obj
if si.special_type == "llm":
self._llms.append(obj)
if si.special_type == "embedding":
self._embeddings.append(obj)
def reasoning_tab(self):
with gr.Tab("Reasoning settings", visible=self._render_reasoning_tab):
@@ -275,6 +290,10 @@ class SettingsPage(BasePage):
continue
obj = render_setting_item(si, si.value)
self._components[f"reasoning.{n}"] = obj
if si.special_type == "llm":
self._llms.append(obj)
if si.special_type == "embedding":
self._embeddings.append(obj)
gr.Markdown("### Reasoning-specific settings")
self._components["reasoning.use"] = render_setting_item(
@@ -289,10 +308,19 @@ class SettingsPage(BasePage):
visible=idx == 0,
elem_id=pn,
) as self._reasoning_mode[pn]:
gr.Markdown("**Name**: Description")
reasoning = reasonings.get(pn, None)
if reasoning is None:
gr.Markdown("**Name**: Description")
else:
info = reasoning.get_info()
gr.Markdown(f"**{info['name']}**: {info['description']}")
for n, si in sig.settings.items():
obj = render_setting_item(si, si.value)
self._components[f"reasoning.options.{pn}.{n}"] = obj
if si.special_type == "llm":
self._llms.append(obj)
if si.special_type == "embedding":
self._embeddings.append(obj)
def change_reasoning_mode(self, value):
output = []
@@ -360,3 +388,38 @@ class SettingsPage(BasePage):
outputs=[self._settings_state] + self.components(),
show_progress="hidden",
)
def update_llms():
from ktem.llms.manager import llms
if llms._default:
llm_choices = [(f"{llms._default} (default)", "")]
else:
llm_choices = [("(random)", "")]
llm_choices += [(_, _) for _ in llms.options().keys()]
return gr.update(choices=llm_choices)
def update_embeddings():
from ktem.embeddings.manager import embedding_models_manager
if embedding_models_manager._default:
emb_choices = [(f"{embedding_models_manager._default} (default)", "")]
else:
emb_choices = [("(random)", "")]
emb_choices += [(_, _) for _ in embedding_models_manager.options().keys()]
return gr.update(choices=emb_choices)
for llm in self._llms:
self._app.app.load(
update_llms,
inputs=[],
outputs=[llm],
show_progress="hidden",
)
for emb in self._embeddings:
self._app.app.load(
update_embeddings,
inputs=[],
outputs=[emb],
show_progress="hidden",
)
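For reference, the choice list these refresh callbacks produce looks like the following (illustrative model names, not part of the diff):

registered = {"gpt4": object(), "local-llm": object()}   # stand-in for llms.options()
default = "gpt4"                                          # stand-in for llms._default

choices = [(f"{default} (default)", "")] if default else [("(random)", "")]
choices += [(name, name) for name in registered]
# -> [("gpt4 (default)", ""), ("gpt4", "gpt4"), ("local-llm", "local-llm")]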


@@ -0,0 +1,9 @@
from .decompose_question import DecomposeQuestionPipeline
from .fewshot_rewrite_question import FewshotRewriteQuestionPipeline
from .rewrite_question import RewriteQuestionPipeline
__all__ = [
"DecomposeQuestionPipeline",
"FewshotRewriteQuestionPipeline",
"RewriteQuestionPipeline",
]


@@ -0,0 +1,79 @@
import logging
from ktem.llms.manager import llms
from ktem.reasoning.prompt_optimization.rewrite_question import RewriteQuestionPipeline
from pydantic import BaseModel, Field
from kotaemon.base import Document, HumanMessage, Node, SystemMessage
from kotaemon.llms import ChatLLM
logger = logging.getLogger(__name__)
class SubQuery(BaseModel):
"""Search over a database of insurance rulebooks or financial reports"""
sub_query: str = Field(
...,
description="A very specific query against the database.",
)
class DecomposeQuestionPipeline(RewriteQuestionPipeline):
"""Decompose user complex question into multiple sub-questions
Args:
llm: the language model to rewrite question
lang: the language of the answer. Currently support English and Japanese
"""
llm: ChatLLM = Node(
default_callback=lambda _: llms.get("openai-gpt4-turbo", llms.get_default())
)
DECOMPOSE_SYSTEM_PROMPT_TEMPLATE = (
"You are an expert at converting user complex questions into sub questions. "
"Perform query decomposition using provided function_call. "
"Given a user question, break it down into the most specific sub"
" questions you can (at most 3) "
"which will help you answer the original question. "
"Each sub question should be about a single concept/fact/idea. "
"If there are acronyms or words you are not familiar with, "
"do not try to rephrase them."
)
prompt_template: str = DECOMPOSE_SYSTEM_PROMPT_TEMPLATE
def create_prompt(self, question):
schema = SubQuery.model_json_schema()
function = {
"name": schema["title"],
"description": schema["description"],
"parameters": schema,
}
llm_kwargs = {
"tools": [{"type": "function", "function": function}],
"tool_choice": "auto",
}
messages = [
SystemMessage(content=self.prompt_template),
HumanMessage(content=question),
]
return messages, llm_kwargs
def run(self, question: str) -> list: # type: ignore
messages, llm_kwargs = self.create_prompt(question)
result = self.llm(messages, **llm_kwargs)
tool_calls = result.additional_kwargs.get("tool_calls", None)
sub_queries = []
if tool_calls:
for tool_call in tool_calls:
sub_queries.append(
Document(
content=SubQuery.parse_raw(
tool_call["function"]["arguments"]
).sub_query
)
)
return sub_queries
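An illustrative usage sketch (not part of the diff), assuming the configured chat model supports OpenAI-style tool calling:

pipeline = DecomposeQuestionPipeline()
sub_questions = pipeline.run(
    "How did company A's 2023 revenue compare to 2022, and what drove the change?"
)
for doc in sub_questions:
    print(doc.content)  # one sub-query per Document, e.g. "company A revenue 2023"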


@@ -0,0 +1,100 @@
import json
import uuid
from pathlib import Path
from ktem.components import get_docstore, get_vectorstore
from ktem.llms.manager import llms
from ktem.reasoning.prompt_optimization.rewrite_question import (
DEFAULT_REWRITE_PROMPT,
RewriteQuestionPipeline,
)
from theflow.settings import settings as flowsettings
from kotaemon.base import AIMessage, Document, HumanMessage, Node, SystemMessage
from kotaemon.embeddings import BaseEmbeddings
from kotaemon.llms import ChatLLM
from kotaemon.storages import BaseDocumentStore, BaseVectorStore
class FewshotRewriteQuestionPipeline(RewriteQuestionPipeline):
"""Rewrite user question
Args:
llm: the language model to rewrite question
rewrite_template: the prompt template for llm to paraphrase a text input
lang: the language of the answer. Currently support English and Japanese
embedding: the embedding model to encode the question
vector_store: the vector store to store the encoded question
doc_store: the document store to store the original question
k: the number of examples to retrieve for rewriting
"""
llm: ChatLLM = Node(default_callback=lambda _: llms.get_default())
rewrite_template: str = DEFAULT_REWRITE_PROMPT
lang: str = "English"
embedding: BaseEmbeddings
vector_store: BaseVectorStore
doc_store: BaseDocumentStore
k: int = getattr(flowsettings, "N_PROMPT_OPT_EXAMPLES", 3)
def add_documents(self, examples, batch_size: int = 50):
print("Adding fewshot examples for rewriting")
documents = []
for example in examples:
doc = Document(
text=example["input"], id_=str(uuid.uuid4()), metadata=example
)
documents.append(doc)
for i in range(0, len(documents), batch_size):
embeddings = self.embedding(documents[i : i + batch_size])
ids = [t.doc_id for t in documents[i : i + batch_size]]
self.vector_store.add(
embeddings=embeddings,
ids=ids,
)
self.doc_store.add(documents[i : i + batch_size])
@classmethod
def get_pipeline(
cls,
embedding,
example_path=Path(__file__).parent / "rephrase_question_train.json",
collection_name: str = "fewshot_rewrite_examples",
):
vector_store = get_vectorstore(collection_name)
doc_store = get_docstore(collection_name)
pipeline = cls(
embedding=embedding, vector_store=vector_store, doc_store=doc_store
)
if doc_store.count():
return pipeline
examples = json.load(open(example_path, "r"))
pipeline.add_documents(examples)
return pipeline
def run(self, question: str) -> Document: # type: ignore
emb = self.embedding(question)[0].embedding
_, _, ids = self.vector_store.query(embedding=emb, top_k=self.k)
examples = self.doc_store.get(ids)
messages = [SystemMessage(content="You are a helpful assistant")]
for example in examples:
messages.append(
HumanMessage(
content=self.rewrite_template.format(
question=example.metadata["input"], lang=self.lang
)
)
)
messages.append(AIMessage(content=example.metadata["output"]))
messages.append(
HumanMessage(
content=self.rewrite_template.format(question=question, lang=self.lang)
)
)
result = self.llm(messages)
return result
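An illustrative sketch (not part of the diff); it assumes the embeddings manager exposes a get_default() accessor mirroring llms.get_default(), otherwise substitute any configured kotaemon BaseEmbeddings instance:

from ktem.embeddings.manager import embedding_models_manager

# Assumption: get_default() returns the application's default embedding model.
embedding = embedding_models_manager.get_default()
pipeline = FewshotRewriteQuestionPipeline.get_pipeline(embedding=embedding)
print(pipeline.run("what revenue 2023?").text)  # the rewritten question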

File diff suppressed because it is too large.


@@ -0,0 +1,37 @@
from ktem.llms.manager import llms
from kotaemon.base import BaseComponent, Document, HumanMessage, Node, SystemMessage
from kotaemon.llms import ChatLLM, PromptTemplate
DEFAULT_REWRITE_PROMPT = (
"Given the following question, rephrase and expand it "
"to help you do better answering. Maintain all information "
"in the original question. Keep the question as concise as possible. "
"Give answer in {lang}\n"
"Original question: {question}\n"
"Rephrased question: "
)
class RewriteQuestionPipeline(BaseComponent):
"""Rewrite user question
Args:
llm: the language model to rewrite question
rewrite_template: the prompt template for llm to paraphrase a text input
lang: the language of the answer. Currently support English and Japanese
"""
llm: ChatLLM = Node(default_callback=lambda _: llms.get_default())
rewrite_template: str = DEFAULT_REWRITE_PROMPT
lang: str = "English"
def run(self, question: str) -> Document: # type: ignore
prompt_template = PromptTemplate(self.rewrite_template)
prompt = prompt_template.populate(question=question, lang=self.lang)
messages = [
SystemMessage(content="You are a helpful assistant"),
HumanMessage(content=prompt),
]
return self.llm(messages)
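The base rewriter can also be used on its own; an illustrative sketch (not part of the diff), assuming a default LLM is configured:

rewriter = RewriteQuestionPipeline(lang="English")
result = rewriter.run("revenue 2023 company A?")
print(result.text)  # the rephrased, expanded question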


@@ -0,0 +1,36 @@
import logging
from ktem.llms.manager import llms
from kotaemon.base import AIMessage, BaseComponent, Document, HumanMessage, Node
from kotaemon.llms import ChatLLM, PromptTemplate
logger = logging.getLogger(__name__)
class SuggestConvNamePipeline(BaseComponent):
"""Suggest a good conversation name based on the chat history."""
llm: ChatLLM = Node(default_callback=lambda _: llms.get_default())
SUGGEST_NAME_PROMPT_TEMPLATE = (
"You are an expert at suggesting good and memorable conversation name. "
"Based on the chat history above, "
"suggest a good conversation name (max 10 words). "
"Give answer in {lang}. Just output the conversation "
"name without any extra."
)
prompt_template: str = SUGGEST_NAME_PROMPT_TEMPLATE
lang: str = "English"
def run(self, chat_history: list[tuple[str, str]]) -> Document: # type: ignore
prompt_template = PromptTemplate(self.prompt_template)
prompt = prompt_template.populate(lang=self.lang)
messages = []
for human, ai in chat_history:
messages.append(HumanMessage(content=human))
messages.append(AIMessage(content=ai))
messages.append(HumanMessage(content=prompt))
return self.llm(messages)
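An illustrative sketch (not part of the diff) with a made-up chat history:

suggester = SuggestConvNamePipeline(lang="English")
history = [("What was company A's 2023 revenue?", "It was 1.2B USD, up 8% year over year.")]
print(suggester.run(history).text)  # e.g. "Company A 2023 revenue discussion"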


@@ -19,7 +19,10 @@ from kotaemon.agents import (
from kotaemon.base import BaseComponent, Document, HumanMessage, Node, SystemMessage
from kotaemon.llms import ChatLLM, PromptTemplate
from ..utils import SUPPORTED_LANGUAGE_MAP
logger = logging.getLogger(__name__)
DEFAULT_AGENT_STEPS = 4
class DocSearchArgs(BaseModel):
@@ -97,7 +100,7 @@ class DocSearchTool(BaseTool):
)
print("Retrieved #{}: {}".format(_id, retrieved_content[:100]))
print("Score", retrieved_item.metadata.get("relevance_score", None))
print("Score", retrieved_item.metadata.get("cohere_reranking_score", None))
# trim context by trim_len
if evidence:
@@ -190,7 +193,9 @@ class ReactAgentPipeline(BaseReasoning):
"<b>Action</b>: <em>{tool}[{input}]</em>\n\n<b>Output</b>: {output}"
).format(
tool=step.tool if status == "thinking" else "",
input=step.tool_input.replace("\n", "") if status == "thinking" else "",
input=step.tool_input.replace("\n", "").replace('"', "")
if status == "thinking"
else "",
output=output if status == "thinking" else "Finished",
)
return Document(
@@ -261,9 +266,17 @@ class ReactAgentPipeline(BaseReasoning):
llm_name = settings[f"{prefix}.llm"]
llm = llms.get(llm_name, llms.get_default())
max_context_length_setting = settings.get("reasoning.max_context_length", None)
pipeline = ReactAgentPipeline(retrievers=retrievers)
pipeline.agent.llm = llm
pipeline.agent.max_iterations = settings[f"{prefix}.max_iterations"]
if max_context_length_setting:
pipeline.agent.max_context_length = (
max_context_length_setting // DEFAULT_AGENT_STEPS
)
tools = []
for tool_name in settings[f"reasoning.options.{_id}.tools"]:
tool = TOOL_REGISTRY[tool_name]
@@ -273,7 +286,7 @@ class ReactAgentPipeline(BaseReasoning):
tool.llm = llm
tools.append(tool)
pipeline.agent.plugins = tools
pipeline.agent.output_lang = {"en": "English", "ja": "Japanese"}.get(
pipeline.agent.output_lang = SUPPORTED_LANGUAGE_MAP.get(
settings["reasoning.lang"], "English"
)
pipeline.use_rewrite = states.get("app", {}).get("regen", False)
@@ -298,6 +311,7 @@ class ReactAgentPipeline(BaseReasoning):
"value": llm,
"component": "dropdown",
"choices": llm_choices,
"special_type": "llm",
"info": (
"The language model to use for generating the answer. If None, "
"the application default language model will be used."
@@ -325,5 +339,10 @@ class ReactAgentPipeline(BaseReasoning):
return {
"id": "ReAct",
"name": "ReAct Agent",
"description": "Implementing ReAct paradigm",
"description": (
"Implementing ReAct paradigm: https://arxiv.org/abs/2210.03629. "
"ReAct agent answers the user's request by iteratively formulating "
"plan and executing it. The agent can use multiple tools to gather "
"information and generate the final answer."
),
}
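The per-step context budget is simply the configured maximum divided by the assumed number of agent steps; an illustrative calculation (not part of the diff):

max_context_length = 32_000                        # illustrative reasoning.max_context_length value
DEFAULT_AGENT_STEPS = 4
print(max_context_length // DEFAULT_AGENT_STEPS)   # -> 8000 tokens available per agent step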


@@ -20,7 +20,10 @@ from kotaemon.agents import (
from kotaemon.base import BaseComponent, Document, HumanMessage, Node, SystemMessage
from kotaemon.llms import ChatLLM, PromptTemplate
from ..utils import SUPPORTED_LANGUAGE_MAP
logger = logging.getLogger(__name__)
DEFAULT_AGENT_STEPS = 4
DEFAULT_PLANNER_PROMPT = (
@@ -135,7 +138,7 @@ class DocSearchTool(BaseTool):
)
print("Retrieved #{}: {}".format(_id, retrieved_content))
print("Score", retrieved_item.metadata.get("relevance_score", None))
print("Score", retrieved_item.metadata.get("cohere_reranking_score", None))
# trim context by trim_len
if evidence:
@@ -215,7 +218,7 @@ class RewooAgentPipeline(BaseReasoning):
use_rewrite: bool = False
enable_citation: bool = False
def format_info_panel(self, worker_log):
def format_info_panel_evidence(self, worker_log):
header = ""
content = []
@@ -223,6 +226,10 @@ class RewooAgentPipeline(BaseReasoning):
if line.startswith("#Plan"):
# line starts with #Plan should be marked as a new segment
header = line
elif line.startswith("#Action"):
# small fix for markdown output
line = "\\" + line + "<br>"
content.append(line)
elif line.startswith("#"):
# stop markdown from rendering big headers
line = "\\" + line
@@ -238,6 +245,17 @@ class RewooAgentPipeline(BaseReasoning):
content=Render.collapsible(
header=header,
content=Render.table("\n".join(content)),
open=False,
),
)
def format_info_panel_planner(self, planner_output):
planner_output = planner_output.replace("\n", "<br>")
return Document(
channel="info",
content=Render.collapsible(
header="Planner Output",
content=planner_output,
open=True,
),
)
@@ -285,12 +303,19 @@ class RewooAgentPipeline(BaseReasoning):
# line starts with #Plan should be marked as a new segment
new_segment = [line]
segments.append(new_segment)
elif line.startswith("#Action"):
# small fix for markdown output
line = "\\" + line + "<br>"
segments[-1].append(line)
elif line.startswith("#"):
# stop markdown from rendering big headers
line = "\\" + line
segments[-1].append(line)
else:
segments[-1].append(line)
if segments:
segments[-1].append(line)
else:
segments.append([line])
outputs = []
for segment in segments:
@@ -337,18 +362,23 @@ class RewooAgentPipeline(BaseReasoning):
for item in output_stream:
if item.intermediate_steps:
for step in item.intermediate_steps:
yield Document(
channel="info",
content=self.format_info_panel(step["worker_log"]),
)
if "planner_log" in step:
yield Document(
channel="info",
content=self.format_info_panel_planner(step["planner_log"]),
)
else:
yield Document(
channel="info",
content=self.format_info_panel_evidence(step["worker_log"]),
)
if item.text:
# final answer
yield Document(channel="chat", content=item.text)
answer = output_stream.value
yield Document(channel="info", content=None)
refined_citations = self.prepare_citation(answer)
for _ in refined_citations:
yield _
yield from self.prepare_citation(answer)
return answer
@@ -360,6 +390,8 @@ class RewooAgentPipeline(BaseReasoning):
prefix = f"reasoning.options.{_id}"
pipeline = RewooAgentPipeline(retrievers=retrievers)
max_context_length_setting = settings.get("reasoning.max_context_length", None)
planner_llm_name = settings[f"{prefix}.planner_llm"]
planner_llm = llms.get(planner_llm_name, llms.get_default())
solver_llm_name = settings[f"{prefix}.solver_llm"]
@@ -367,6 +399,10 @@ class RewooAgentPipeline(BaseReasoning):
pipeline.agent.planner_llm = planner_llm
pipeline.agent.solver_llm = solver_llm
if max_context_length_setting:
pipeline.agent.max_context_length = (
max_context_length_setting // DEFAULT_AGENT_STEPS
)
tools = []
for tool_name in settings[f"{prefix}.tools"]:
@@ -377,7 +413,7 @@ class RewooAgentPipeline(BaseReasoning):
tool.llm = solver_llm
tools.append(tool)
pipeline.agent.plugins = tools
pipeline.agent.output_lang = {"en": "English", "ja": "Japanese"}.get(
pipeline.agent.output_lang = SUPPORTED_LANGUAGE_MAP.get(
settings["reasoning.lang"], "English"
)
pipeline.agent.prompt_template["Planner"] = PromptTemplate(
@@ -413,6 +449,7 @@ class RewooAgentPipeline(BaseReasoning):
"value": llm,
"component": "dropdown",
"choices": llm_choices,
"special_type": "llm",
"info": (
"The language model to use for planning. "
"This model will generate a plan based on the "
@@ -424,6 +461,7 @@ class RewooAgentPipeline(BaseReasoning):
"value": llm,
"component": "dropdown",
"choices": llm_choices,
"special_type": "llm",
"info": (
"The language model to use for solving. "
"This model will generate the answer based on the "
@@ -457,6 +495,10 @@ class RewooAgentPipeline(BaseReasoning):
"id": "ReWOO",
"name": "ReWOO Agent",
"description": (
"Implementing ReWOO paradigm " "https://arxiv.org/pdf/2305.18323.pdf"
"Implementing ReWOO paradigm: https://arxiv.org/abs/2305.18323. "
"The ReWOO agent makes a step by step plan in the first stage, "
"then solves each step in the second stage. The agent can use "
"external tools to help in the reasoning process. Once all stages "
"are completed, the agent will summarize the answer."
),
}

File diff suppressed because it is too large.


@@ -19,6 +19,7 @@ class SettingItem(BaseModel):
choices: list = Field(default_factory=list)
metadata: dict = Field(default_factory=dict)
component: str = "text"
special_type: str = ""
class BaseSettingGroup(BaseModel):
@@ -55,6 +56,9 @@ class BaseSettingGroup(BaseModel):
option = self.options[option_id]
return option.get_setting_item(sub_path)
def __bool__(self):
return bool(self.settings) or bool(self.options)
class SettingReasoningGroup(BaseSettingGroup):
def _get_options(self) -> dict:


@@ -0,0 +1,3 @@
from .lang import SUPPORTED_LANGUAGE_MAP
__all__ = ["SUPPORTED_LANGUAGE_MAP"]


@@ -0,0 +1,20 @@
def sync_retrieval_n_message(
messages: list[list[str]],
retrievals: list[str],
) -> list[str]:
"""Ensure len of messages history and retrieval history are equal
Empty string/Truncate will be used in case any difference exist
"""
n_message = len(messages) # include previous history
n_retrieval = min(n_message, len(retrievals))
diff = n_message - n_retrieval
retrievals = retrievals[:n_retrieval] + ["" for _ in range(diff)]
assert len(retrievals) == n_message
return retrievals
if __name__ == "__main__":
print(sync_retrieval_n_message([[""], [""], [""]], []))
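For reference, the padding and truncation behaviour (values are illustrative, not part of the diff):

sync_retrieval_n_message([["q1"], ["q2"]], [])        # -> ["", ""]    (pad)
sync_retrieval_n_message([["q1"]], ["ev1", "ev2"])    # -> ["ev1"]     (truncate)
sync_retrieval_n_message([["q1"], ["q2"]], ["ev1"])   # -> ["ev1", ""]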


@@ -0,0 +1 @@
SUPPORTED_LANGUAGE_MAP = {"en": "English", "ja": "Japanese", "vi": "Vietnamese"}


@@ -1,4 +1,36 @@
import os.path
import markdown
from fast_langdetect import detect
from kotaemon.base import RetrievedDocument
def is_close(val1, val2, tolerance=1e-9):
return abs(val1 - val2) <= tolerance
def replace_mardown_header(text: str) -> str:
textlines = text.splitlines()
newlines = []
for line in textlines:
if line.startswith("#"):
line = "<strong>" + line.replace("#", "") + "</strong>"
if line.startswith("=="):
line = ""
newlines.append(line)
return "\n".join(newlines)
def get_header(doc: RetrievedDocument) -> str:
"""Get the header for the document"""
header = ""
if "page_label" in doc.metadata:
header += f" [Page {doc.metadata['page_label']}]"
header += f" {doc.metadata.get('file_name', '<evidence>')}"
return header.strip()
class Render:
@@ -13,9 +45,152 @@ class Render:
@staticmethod
def table(text: str) -> str:
"""Render table from markdown format into HTML"""
text = replace_mardown_header(text)
return markdown.markdown(text, extensions=["markdown.extensions.tables"])
@staticmethod
def preview(
html_content: str,
doc: RetrievedDocument,
highlight_text: str | None = None,
) -> str:
text = doc.content
pdf_path = doc.metadata.get("file_path", "")
if not os.path.isfile(pdf_path):
print(f"pdf-path: {pdf_path} does not exist")
return html_content
is_pdf = doc.metadata.get("file_type", "") == "application/pdf"
page_idx = int(doc.metadata.get("page_label", 1))
if not is_pdf:
print("Document is not pdf")
return html_content
if page_idx < 0:
print("Fail to extract page number")
return html_content
if not highlight_text:
try:
lang = detect(text.replace("\n", " "))["lang"]
print("lang", lang)
if lang not in ["ja", "cn"]:
highlight_words = [
t[:-1] if t.endswith("-") else t for t in text.split("\n")
]
highlight_text = highlight_words[0]
phrase = "true"
else:
highlight_text = text.replace("\n", "")
phrase = "false"
print("highlight_text", highlight_text, phrase)
except Exception as e:
print(e)
highlight_text = text
else:
phrase = "true"
return f"""
{html_content}
<a href="#" class="pdf-link" data-src="/file={pdf_path}" data-page="{page_idx}" data-search="{highlight_text}" data-phrase="{phrase}">
[Preview]
</a>
""" # noqa
@staticmethod
def highlight(text: str) -> str:
"""Highlight text"""
return f"<mark>{text}</mark>"
@staticmethod
def image(url: str, text: str = "") -> str:
"""Render an image"""
img = f'<img src="{url}"><br>'
if text:
caption = f"<p>{text}</p>"
return f"<figure>{img}{caption}</figure><br>"
return img
@staticmethod
def collapsible_with_header(
doc: RetrievedDocument,
open_collapsible: bool = False,
) -> str:
header = f"<i>{get_header(doc)}</i>"
if doc.metadata.get("type", "") == "image":
doc_content = Render.image(url=doc.metadata["image_origin"], text=doc.text)
else:
doc_content = Render.table(doc.text)
return Render.collapsible(
header=Render.preview(header, doc),
content=doc_content,
open=open_collapsible,
)
@staticmethod
def collapsible_with_header_score(
doc: RetrievedDocument,
override_text: str | None = None,
highlight_text: str | None = None,
open_collapsible: bool = False,
) -> str:
"""Format the retrieval score and the document"""
# score from doc_store (Elasticsearch)
if is_close(doc.score, -1.0):
vectorstore_score = ""
text_search_str = " (full-text search)<br>"
else:
vectorstore_score = str(round(doc.score, 2))
text_search_str = "<br>"
llm_reranking_score = (
round(doc.metadata["llm_trulens_score"], 2)
if doc.metadata.get("llm_trulens_score") is not None
else 0.0
)
cohere_reranking_score = (
round(doc.metadata["cohere_reranking_score"], 2)
if doc.metadata.get("cohere_reranking_score") is not None
else 0.0
)
item_type_prefix = doc.metadata.get("type", "")
item_type_prefix = item_type_prefix.capitalize()
if item_type_prefix:
item_type_prefix += " from "
rendered_score = Render.collapsible(
header=f"<b>&emsp;Relevance score</b>: {llm_reranking_score}",
content="<b>&emsp;&emsp;Vectorstore score:</b>"
f" {vectorstore_score}"
f"{text_search_str}"
"<b>&emsp;&emsp;LLM relevant score:</b>"
f" {llm_reranking_score}<br>"
"<b>&emsp;&emsp;Reranking score:</b>"
f" {cohere_reranking_score}<br>",
)
text = doc.text if not override_text else override_text
if doc.metadata.get("type", "") == "image":
rendered_doc_content = Render.image(
url=doc.metadata["image_origin"],
text=text,
)
else:
rendered_doc_content = Render.table(text)
rendered_header = Render.preview(
f"<i>{item_type_prefix}{get_header(doc)}</i>"
f" [score: {llm_reranking_score}]",
doc,
highlight_text=highlight_text,
)
return Render.collapsible(
header=rendered_header,
content=rendered_score + rendered_doc_content,
open=open_collapsible,
)
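An illustrative sketch of the new rendering helpers (not part of the diff):

print(replace_mardown_header("# Section title"))
# -> "<strong> Section title</strong>", so headers do not render oversized in the info panel

print(Render.table("col1 | col2\n--- | ---\na | b"))
# -> an HTML <table> produced via the markdown "tables" extension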