feat: merge develop (#123)

* Support hybrid vector retrieval

* Enable figures and table reading in Azure DI

* Retrieve with multi-modal

* Fix table mix-up

* Add txt loader

* Add Anthropic Chat

* Raise an error when retrieving the help file

* Allow same filename for different people if private is True

* Allow declaring extra LLM vendors

* Show chunks on the File page

* Allow elasticsearch to get more docs

* Fix Cohere response (#86)

* Fix Cohere response

* Remove Adobe pdfservice from dependency

kotaemon no longer relies on pdfservice for its core functionality,
and pdfservice uses a very outdated dependency that causes conflicts.

---------

Co-authored-by: trducng <trungduc1992@gmail.com>

* Add confidence score (#87)

* Save question answering data as a log file

* Save the original information besides the rewritten info

* Export Cohere relevance score as confidence score

* Fix style check

* Upgrade the confidence score appearance (#90)

* Highlight the relevance score

* Round relevance score. Get key from config instead of env

* Cohere return all scores

* Display relevance score for image

* Remove columns and rows in the Excel loader which contain all NaN (#91)

* remove columns and rows which contain all NaN

* back to multiple joiner options

* Fix style

---------

Co-authored-by: linhnguyen-cinnamon <cinmc0019@CINMC0019-LinhNguyen.local>
Co-authored-by: trducng <trungduc1992@gmail.com>

* Track retriever state

* Bump llama-index version 0.10

* feat/save-azuredi-mhtml-to-markdown (#93)

* feat/save-azuredi-mhtml-to-markdown

* fix: replace os.path with pathlib; change theflow.settings

* refactor: base on pre-commit

* chore: move the func of saving content markdown above removed_spans

---------

Co-authored-by: jacky0218 <jacky0218@github.com>

* fix: losing first chunk (#94)

* fix: losing first chunk.

* fix: update the method of preventing losing chunks

---------

Co-authored-by: jacky0218 <jacky0218@github.com>

* fix: adding the base64 image in markdown (#95)

* feat: more chunk info on UI

* fix: error when reindexing files

* refactor: allow a more informative exception trace when using gpt4v

* feat: add excel reader that treats each worksheet as a document

* Persist loader information when indexing file

* feat: allow hiding unneeded setting panels

* feat: allow specific timezone when creating conversation

* feat: add more confidence score (#96)

* Allow a list of rerankers

* Export llm reranking score instead of filter with boolean

* Get logprobs from LLMs

* Rename cohere reranking score

* Call 2 rerankers at once

* Run QA pipeline for each chunk to get qa_score

* Display more relevance scores

* Define another LLMScoring instead of editing the original one

* Export logprobs instead of probs

* Call LLMScoring

* Get qa_score only in the final answer

* feat: replace text length with token in file list

* ui: show index name instead of id in the settings

* feat(ai): restrict the vision temperature

* fix(ui): remove the misleading message about non-retrieved evidence

* feat(ui): show the reasoning name and description in the reasoning setting page

* feat(ui): show version on the main window

* feat(ui): show default llm name in the setting page

* fix(conf): append the result of doc in llm_scoring (#97)

* fix: constrain maximum number of images

* feat(ui): allow filter file by name in file list page

* Fix exceeding token length error for OpenAI embeddings by chunking then averaging (#99)

* Average embeddings in case the text exceeds max size

* Add docstring

* fix: Allow empty string when calling embedding

* fix: update trulens LLM ranking score for retrieval confidence, improve citation (#98)

* Round when displaying not by default

* Add LLMTrulens reranking model

* Use llmtrulensscoring in pipeline

* fix: update UI display for trulens score

---------

Co-authored-by: taprosoft <tadashi@cinnamon.is>

* feat: add question decomposition & few-shot rewrite pipeline (#89)

* Create few-shot query-rewriting. Run and display the result in info_panel

* Fix style check

* Put the functions to separate modules

* Add zero-shot question decomposition

* Fix fewshot rewriting

* Add default few-shot examples

* Fix decompose question

* Fix importing rewriting pipelines

* fix: update decompose logic in fullQA pipeline

---------

Co-authored-by: taprosoft <tadashi@cinnamon.is>

* fix: add encoding utf-8 when save temporal markdown in vectorIndex (#101)

* fix: improve retrieval pipeline and relevant score display (#102)

* fix: improve retrieval pipeline by extending first round top_k with multiplier

* fix: minor fix

* feat: improve UI default settings and add quick switch option for pipeline

* fix: improve agent logics (#103)

* fix: improve agent progress display

* fix: update retrieval logic

* fix: UI display

* fix: less verbose debug log

* feat: add warning message for low confidence

* fix: LLM scoring enabled by default

* fix: minor update logics

* fix: hotfix image citation

* feat: update docx loader to handle merged table cells + handle zip file upload (#104)

* feat: update docx loader to handle merged table cells

* feat: handle zip file

* refactor: pre-commit

* fix: escape text in download UI

* feat: optimize vector store query db (#105)

* feat: optimize vector store query db

* feat: add file_id to chroma metadatas

* feat: remove unnecessary logs and update migrate script

* feat: iterate through file index

* fix: remove unused code

---------

Co-authored-by: taprosoft <tadashi@cinnamon.is>

* fix: add openai embedding exponential back-off

* fix: update import download_loader

* refactor: codespell

* fix: update some default settings

* fix: update installation instruction

* fix: default chunk length in simple QA

* feat: add share conversation feature and enable retrieval history (#108)

* feat: add share conversation feature and enable retrieval history

* fix: update share conversation UI

---------

Co-authored-by: taprosoft <tadashi@cinnamon.is>

* fix: allow exponential backoff for failed OCR call (#109)

* fix: update default prompt when no retrieval is used

* fix: create embedding for long image chunks

* fix: add exception handling for additional table retriever

* fix: clean conversation & file selection UI

* fix: elastic search with empty doc_ids

* feat: add thumbnail PDF reader for quick multimodal QA

* feat: add thumbnail handling logic in indexing

* fix: UI text update

* fix: PDF thumb loader page number logic

* feat: add quick indexing pipeline and update UI

* feat: add conv name suggestion

* fix: minor UI change

* feat: citation in thread

* fix: add conv name suggestion in regen

* chore: add assets for usage doc

* chore: update usage doc

* feat: pdf viewer (#110)

* feat: update pdfviewer

* feat: update missing files

* fix: update rendering logic of info panel

* fix: improve thumbnail retrieval logic

* fix: update PDF evidence rendering logic

* fix: remove pdfjs built dist

* fix: reduce thumbnail evidence count

* chore: update gitignore

* fix: add js event on chat msg select

* fix: update css for viewer

* fix: add env var for PDFJS prebuilt

* fix: move language setting to reasoning utils

---------

Co-authored-by: phv2312 <kat87yb@gmail.com>
Co-authored-by: trducng <trungduc1992@gmail.com>

* feat: graph rag (#116)

* fix: reload server when add/delete index

* fix: rework indexing pipeline to be able to disable vectorstore and splitter if needed

* feat: add graphRAG index with plot view

* fix: update requirement for graphRAG and lighten unnecessary packages

* feat: add knowledge network index (#118)

* feat: add Knowledge Network index

* fix: update reader mode setting for knet

* fix: update init knet

* fix: update collection name to index pipeline

* fix: missing req

---------

Co-authored-by: jeff52415 <jeff.yang@cinnamon.is>

* fix: update info panel return for graphrag

* fix: retriever setting graphrag

* feat: local llm settings (#122)

* feat: expose context length as reasoning setting to better fit local models

* fix: update context length setting for agents

* fix: rework threadpool llm call

* fix: improve indexing logic

* fix: improve UI

* feat: add lancedb

* fix: improve lancedb logic

* feat: add lancedb vectorstore

* fix: lighten requirement

* fix: improve LanceDB vector store

* fix: improve UI

* fix: openai retry

* fix: update reqs

* fix: update launch command

* feat: update Dockerfile

* feat: add plot history

* fix: update default config

* fix: remove verbose print

* fix: update default setting

* fix: update gradio plot return

* fix: default gradio tmp

* fix: improve lancedb docstore

* fix: fix question decompose pipeline

* feat: add multimodal reader in UI

* fix: update docs

* fix: update default settings & docker build

* fix: update app startup

* chore: update documentation

* chore: update README

* chore: update README

---------

Co-authored-by: trducng <trungduc1992@gmail.com>

* chore: update README

* chore: update README

---------

Co-authored-by: trducng <trungduc1992@gmail.com>
Co-authored-by: cin-ace <ace@cinnamon.is>
Co-authored-by: Linh Nguyen <70562198+linhnguyen-cinnamon@users.noreply.github.com>
Co-authored-by: linhnguyen-cinnamon <cinmc0019@CINMC0019-LinhNguyen.local>
Co-authored-by: cin-jacky <101088014+jacky0218@users.noreply.github.com>
Co-authored-by: jacky0218 <jacky0218@github.com>
Co-authored-by: kan_cin <kan@cinnamon.is>
Co-authored-by: phv2312 <kat87yb@gmail.com>
Co-authored-by: jeff52415 <jeff.yang@cinnamon.is>
Author: Tuan Anh Nguyen Dang (Tadashi_Cin)
Date: 2024-08-26 08:50:37 +07:00
Committed by: GitHub
Parent: 86d60e1649
Commit: 2570e11501
121 changed files with 14748 additions and 1063 deletions

View File

@@ -39,16 +39,11 @@ class ReactAgent(BaseAgent):
)
max_iterations: int = 5
strict_decode: bool = False
trim_func: TokenSplitter = TokenSplitter.withx(
chunk_size=800,
chunk_overlap=0,
separator=" ",
tokenizer=partial(
tiktoken.encoding_for_model("gpt-3.5-turbo").encode,
allowed_special=set(),
disallowed_special="all",
),
max_context_length: int = Param(
default=3000,
help="Max context length for each tool output.",
)
trim_func: TokenSplitter | None = None
def _compose_plugin_description(self) -> str:
"""
@@ -149,14 +144,28 @@ class ReactAgent(BaseAgent):
function_map[plugin.name] = plugin
return function_map
def _trim(self, text: str) -> str:
def _trim(self, text: str | Document) -> str:
"""
Trim the text to the maximum token length.
"""
evidence_trim_func = (
self.trim_func
if self.trim_func
else TokenSplitter(
chunk_size=self.max_context_length,
chunk_overlap=0,
separator=" ",
tokenizer=partial(
tiktoken.encoding_for_model("gpt-3.5-turbo").encode,
allowed_special=set(),
disallowed_special="all",
),
)
)
if isinstance(text, str):
texts = self.trim_func([Document(text=text)])
texts = evidence_trim_func([Document(text=text)])
elif isinstance(text, Document):
texts = self.trim_func([text])
texts = evidence_trim_func([text])
else:
raise ValueError("Invalid text type to trim")
trim_text = texts[0].text
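
For reference, here is a minimal standalone sketch of the fallback trimming this hunk introduces: when no trim_func is configured, the agent builds a tokenizer on the fly and keeps only the first max_context_length tokens of each tool output. The snippet below uses tiktoken directly and is an illustrative stand-in, not kotaemon's TokenSplitter.

from functools import partial

import tiktoken

# Rough stand-in for the fallback trimmer built inside _trim; the real code
# wraps an equivalent tokenizer in kotaemon's TokenSplitter.
_enc = tiktoken.encoding_for_model("gpt-3.5-turbo")
_encode = partial(_enc.encode, allowed_special=set(), disallowed_special="all")


def trim_tool_output(text: str, max_context_length: int = 3000) -> str:
    """Keep only the first max_context_length tokens of a tool output."""
    tokens = _encode(text)
    return _enc.decode(tokens[:max_context_length])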

View File

@@ -39,16 +39,11 @@ class RewooAgent(BaseAgent):
examples: dict[str, str | list[str]] = Param(
default_callback=lambda _: {}, help="Examples to be used in the agent."
)
trim_func: TokenSplitter = TokenSplitter.withx(
chunk_size=3000,
chunk_overlap=0,
separator=" ",
tokenizer=partial(
tiktoken.encoding_for_model("gpt-3.5-turbo").encode,
allowed_special=set(),
disallowed_special="all",
),
max_context_length: int = Param(
default=3000,
help="Max context length for each tool output.",
)
trim_func: TokenSplitter | None = None
@Node.auto(depends_on=["planner_llm", "plugins", "prompt_template", "examples"])
def planner(self):
@@ -248,8 +243,22 @@ class RewooAgent(BaseAgent):
return p
def _trim_evidence(self, evidence: str):
evidence_trim_func = (
self.trim_func
if self.trim_func
else TokenSplitter(
chunk_size=self.max_context_length,
chunk_overlap=0,
separator=" ",
tokenizer=partial(
tiktoken.encoding_for_model("gpt-3.5-turbo").encode,
allowed_special=set(),
disallowed_special="all",
),
)
)
if evidence:
texts = self.trim_func([Document(text=evidence)])
texts = evidence_trim_func([Document(text=evidence)])
evidence = texts[0].text
logging.info(f"len (trimmed): {len(evidence)}")
return evidence
@@ -317,6 +326,14 @@ class RewooAgent(BaseAgent):
)
print("Planner output:", planner_text_output)
# output planner to info panel
yield AgentOutput(
text="",
agent_type=self.agent_type,
status="thinking",
intermediate_steps=[{"planner_log": planner_text_output}],
)
# Work
worker_evidences, plugin_cost, plugin_token = self._get_worker_evidence(
planner_evidences, evidence_level
@@ -326,7 +343,9 @@ class RewooAgent(BaseAgent):
worker_log += f"{plan}: {plans[plan]}\n"
current_progress = f"{plan}: {plans[plan]}\n"
for e in plan_to_es[plan]:
worker_log += f"#Action: {planner_evidences.get(e, None)}\n"
worker_log += f"{e}: {worker_evidences[e]}\n"
current_progress += f"#Action: {planner_evidences.get(e, None)}\n"
current_progress += f"{e}: {worker_evidences[e]}\n"
yield AgentOutput(

View File

@@ -1,7 +1,7 @@
from typing import AnyStr, Optional, Type
from urllib.error import HTTPError
from langchain.utilities import SerpAPIWrapper
from langchain_community.utilities import SerpAPIWrapper
from pydantic import BaseModel, Field
from .base import BaseTool

View File

@@ -22,12 +22,16 @@ class LLMTool(BaseTool):
)
llm: BaseLLM
args_schema: Optional[Type[BaseModel]] = LLMArgs
dummy_mode: bool = True
def _run_tool(self, query: AnyStr) -> str:
output = None
try:
response = self.llm(query)
if not self.dummy_mode:
response = self.llm(query)
else:
response = None
except ValueError:
raise ToolException("LLM Tool call failed")
output = response.text
output = response.text if response else "<->"
return output

View File

@@ -5,8 +5,8 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, TypeVar
from langchain.schema.messages import AIMessage as LCAIMessage
from langchain.schema.messages import HumanMessage as LCHumanMessage
from langchain.schema.messages import SystemMessage as LCSystemMessage
from llama_index.bridge.pydantic import Field
from llama_index.schema import Document as BaseDocument
from llama_index.core.bridge.pydantic import Field
from llama_index.core.schema import Document as BaseDocument
if TYPE_CHECKING:
from haystack.schema import Document as HaystackDocument
@@ -38,7 +38,7 @@ class Document(BaseDocument):
content: Any = None
source: Optional[str] = None
channel: Optional[Literal["chat", "info", "index", "debug"]] = None
channel: Optional[Literal["chat", "info", "index", "debug", "plot"]] = None
def __init__(self, content: Optional[Any] = None, *args, **kwargs):
if content is None:
@@ -140,6 +140,7 @@ class LLMInterface(AIMessage):
total_cost: float = 0
logits: list[list[float]] = Field(default_factory=list)
messages: list[AIMessage] = Field(default_factory=list)
logprobs: list[float] = []
class ExtractorOutput(Document):

View File

@@ -133,9 +133,7 @@ def construct_chat_ui(
label="Output file", show_label=True, height=100
)
export_btn = gr.Button("Export")
export_btn.click(
func_export_to_excel, inputs=None, outputs=exported_file
)
export_btn.click(func_export_to_excel, inputs=[], outputs=exported_file)
with gr.Row():
with gr.Column():

View File

@@ -91,7 +91,7 @@ def construct_pipeline_ui(
save_btn.click(func_save, inputs=params, outputs=history_dataframe)
load_params_btn = gr.Button("Reload params")
load_params_btn.click(
func_load_params, inputs=None, outputs=history_dataframe
func_load_params, inputs=[], outputs=history_dataframe
)
history_dataframe.render()
history_dataframe.select(
@@ -103,7 +103,7 @@ def construct_pipeline_ui(
export_btn = gr.Button(
"Export (Result will be in Exported file next to Output)"
)
export_btn.click(func_export, inputs=None, outputs=exported_file)
export_btn.click(func_export, inputs=[], outputs=exported_file)
with gr.Row():
with gr.Column():
if params:

View File

@@ -1,5 +1,15 @@
from itertools import islice
from typing import Optional
import numpy as np
import openai
import tiktoken
from tenacity import (
retry,
retry_if_not_exception_type,
stop_after_attempt,
wait_random_exponential,
)
from theflow.utils.modules import import_dotted_string
from kotaemon.base import Param
@@ -7,6 +17,24 @@ from kotaemon.base import Param
from .base import BaseEmbeddings, Document, DocumentWithEmbedding
def split_text_by_chunk_size(text: str, chunk_size: int) -> list[list[int]]:
"""Split the text into chunks of a given size
Args:
text: text to split
chunk_size: size of each chunk
Returns:
list of chunks (as tokens)
"""
encoding = tiktoken.get_encoding("cl100k_base")
tokens = iter(encoding.encode(text))
result = []
while chunk := list(islice(tokens, chunk_size)):
result.append(chunk)
return result
class BaseOpenAIEmbeddings(BaseEmbeddings):
"""Base interface for OpenAI embedding model, using the openai library.
@@ -32,6 +60,9 @@ class BaseOpenAIEmbeddings(BaseEmbeddings):
"Only supported in `text-embedding-3` and later models."
),
)
context_length: Optional[int] = Param(
None, help="The maximum context length of the embedding model"
)
@Param.auto(depends_on=["max_retries"])
def max_retries_(self):
@@ -56,16 +87,42 @@ class BaseOpenAIEmbeddings(BaseEmbeddings):
def invoke(
self, text: str | list[str] | Document | list[Document], *args, **kwargs
) -> list[DocumentWithEmbedding]:
input_ = self.prepare_input(text)
input_doc = self.prepare_input(text)
client = self.prepare_client(async_version=False)
resp = self.openai_response(
client, input=[_.text if _.text else " " for _ in input_], **kwargs
).dict()
output_ = sorted(resp["data"], key=lambda x: x["index"])
return [
DocumentWithEmbedding(embedding=o["embedding"], content=i)
for i, o in zip(input_, output_)
]
input_: list[str | list[int]] = []
splitted_indices = {}
for idx, text in enumerate(input_doc):
if self.context_length:
chunks = split_text_by_chunk_size(text.text or " ", self.context_length)
splitted_indices[idx] = (len(input_), len(input_) + len(chunks))
input_.extend(chunks)
else:
splitted_indices[idx] = (len(input_), len(input_) + 1)
input_.append(text.text)
resp = self.openai_response(client, input=input_, **kwargs).dict()
output_ = list(sorted(resp["data"], key=lambda x: x["index"]))
output = []
for idx, doc in enumerate(input_doc):
embs = output_[splitted_indices[idx][0] : splitted_indices[idx][1]]
if len(embs) == 1:
output.append(
DocumentWithEmbedding(embedding=embs[0]["embedding"], content=doc)
)
continue
chunk_lens = [
len(_)
for _ in input_[splitted_indices[idx][0] : splitted_indices[idx][1]]
]
vs: list[list[float]] = [_["embedding"] for _ in embs]
emb = np.average(vs, axis=0, weights=chunk_lens)
emb = emb / np.linalg.norm(emb)
output.append(DocumentWithEmbedding(embedding=emb.tolist(), content=doc))
return output
async def ainvoke(
self, text: str | list[str] | Document | list[Document], *args, **kwargs
@@ -118,6 +175,13 @@ class OpenAIEmbeddings(BaseOpenAIEmbeddings):
return OpenAI(**params)
@retry(
retry=retry_if_not_exception_type(
(openai.NotFoundError, openai.BadRequestError)
),
wait=wait_random_exponential(min=1, max=40),
stop=stop_after_attempt(6),
)
def openai_response(self, client, **kwargs):
"""Get the openai response"""
params: dict = {
@@ -174,6 +238,13 @@ class AzureOpenAIEmbeddings(BaseOpenAIEmbeddings):
return AzureOpenAI(**params)
@retry(
retry=retry_if_not_exception_type(
(openai.NotFoundError, openai.BadRequestError)
),
wait=wait_random_exponential(min=1, max=40),
stop=stop_after_attempt(6),
)
def openai_response(self, client, **kwargs):
"""Get the openai response"""
params: dict = {
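
As a side note, here is a minimal sketch of the chunk-then-average strategy this diff adds for texts longer than the embedding model's context window: tokenize with cl100k_base, split the token stream into fixed-size chunks, embed each chunk, then take a length-weighted average and re-normalize. The embed_chunk callable and the 8191-token default are assumptions standing in for the real OpenAI call and model limit.

from itertools import islice
from typing import Callable

import numpy as np
import tiktoken


def split_tokens(text: str, chunk_size: int) -> list[list[int]]:
    """Split text into lists of at most chunk_size tokens (cl100k_base)."""
    encoding = tiktoken.get_encoding("cl100k_base")
    tokens = iter(encoding.encode(text))
    chunks = []
    while chunk := list(islice(tokens, chunk_size)):
        chunks.append(chunk)
    return chunks


def embed_long_text(
    text: str,
    embed_chunk: Callable[[list[int]], list[float]],  # placeholder embedder
    context_length: int = 8191,  # assumed model limit
) -> list[float]:
    """Average per-chunk embeddings weighted by chunk length, then normalize."""
    chunks = split_tokens(text, context_length)
    vectors = [embed_chunk(chunk) for chunk in chunks]
    weights = [len(chunk) for chunk in chunks]
    emb = np.average(vectors, axis=0, weights=weights)
    return (emb / np.linalg.norm(emb)).tolist()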

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
from abc import abstractmethod
from typing import Any, Type
from llama_index.node_parser.interface import NodeParser
from llama_index.core.node_parser.interface import NodeParser
from kotaemon.base import BaseComponent, Document, RetrievedDocument
@@ -32,7 +32,7 @@ class LlamaIndexDocTransformerMixin:
Example:
class TokenSplitter(LlamaIndexMixin, BaseSplitter):
def _get_li_class(self):
from llama_index.text_splitter import TokenTextSplitter
from llama_index.core.text_splitter import TokenTextSplitter
return TokenTextSplitter
To use this mixin, please:

View File

@@ -15,7 +15,7 @@ class TitleExtractor(LlamaIndexDocTransformerMixin, BaseDocParser):
super().__init__(llm=llm, nodes=nodes, **params)
def _get_li_class(self):
from llama_index.extractors import TitleExtractor
from llama_index.core.extractors import TitleExtractor
return TitleExtractor
@@ -30,6 +30,6 @@ class SummaryExtractor(LlamaIndexDocTransformerMixin, BaseDocParser):
super().__init__(llm=llm, summaries=summaries, **params)
def _get_li_class(self):
from llama_index.extractors import SummaryExtractor
from llama_index.core.extractors import SummaryExtractor
return SummaryExtractor

View File

@@ -1,27 +1,42 @@
from pathlib import Path
from typing import Type
from llama_index.readers import PDFReader
from llama_index.readers.base import BaseReader
from decouple import config
from llama_index.core.readers.base import BaseReader
from theflow.settings import settings as flowsettings
from kotaemon.base import BaseComponent, Document, Param
from kotaemon.indices.extractors import BaseDocParser
from kotaemon.indices.splitters import BaseSplitter, TokenSplitter
from kotaemon.loaders import (
AdobeReader,
AzureAIDocumentIntelligenceLoader,
DirectoryReader,
HtmlReader,
MathpixPDFReader,
MhtmlReader,
OCRReader,
PandasExcelReader,
PDFThumbnailReader,
UnstructuredReader,
)
unstructured = UnstructuredReader()
adobe_reader = AdobeReader()
azure_reader = AzureAIDocumentIntelligenceLoader(
endpoint=str(config("AZURE_DI_ENDPOINT", default="")),
credential=str(config("AZURE_DI_CREDENTIAL", default="")),
cache_dir=getattr(flowsettings, "KH_MARKDOWN_OUTPUT_DIR", None),
)
adobe_reader.vlm_endpoint = azure_reader.vlm_endpoint = getattr(
flowsettings, "KH_VLM_ENDPOINT", ""
)
KH_DEFAULT_FILE_EXTRACTORS: dict[str, BaseReader] = {
".xlsx": PandasExcelReader(),
".docx": unstructured,
".pptx": unstructured,
".xls": unstructured,
".doc": unstructured,
".html": HtmlReader(),
@@ -31,7 +46,7 @@ KH_DEFAULT_FILE_EXTRACTORS: dict[str, BaseReader] = {
".jpg": unstructured,
".tiff": unstructured,
".tif": unstructured,
".pdf": PDFReader(),
".pdf": PDFThumbnailReader(),
}

View File

@@ -103,7 +103,9 @@ class CitationPipeline(BaseComponent):
print("CitationPipeline: invoking LLM")
llm_output = self.get_from_path("llm").invoke(messages, **llm_kwargs)
print("CitationPipeline: finish invoking LLM")
if not llm_output.messages:
if not llm_output.messages or not llm_output.additional_kwargs.get(
"tool_calls"
):
return None
function_output = llm_output.additional_kwargs["tool_calls"][0]["function"][
"arguments"

View File

@@ -1,5 +1,13 @@
from .base import BaseReranking
from .cohere import CohereReranking
from .llm import LLMReranking
from .llm_scoring import LLMScoring
from .llm_trulens import LLMTrulensScoring
__all__ = ["CohereReranking", "LLMReranking", "BaseReranking"]
__all__ = [
"CohereReranking",
"LLMReranking",
"LLMScoring",
"BaseReranking",
"LLMTrulensScoring",
]

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
import os
from decouple import config
from kotaemon.base import Document
@@ -9,8 +9,7 @@ from .base import BaseReranking
class CohereReranking(BaseReranking):
model_name: str = "rerank-multilingual-v2.0"
cohere_api_key: str = os.environ.get("COHERE_API_KEY", "")
top_k: int = 1
cohere_api_key: str = config("COHERE_API_KEY", "")
def run(self, documents: list[Document], query: str) -> list[Document]:
"""Use Cohere Reranker model to re-order documents
@@ -22,6 +21,10 @@ class CohereReranking(BaseReranking):
"Please install Cohere " "`pip install cohere` to use Cohere Reranking"
)
if not self.cohere_api_key:
print("Cohere API key not found. Skipping reranking.")
return documents
cohere_client = cohere.Client(self.cohere_api_key)
compressed_docs: list[Document] = []
@@ -29,12 +32,13 @@ class CohereReranking(BaseReranking):
return compressed_docs
_docs = [d.content for d in documents]
results = cohere_client.rerank(
model=self.model_name, query=query, documents=_docs, top_n=self.top_k
response = cohere_client.rerank(
model=self.model_name, query=query, documents=_docs
)
for r in results:
print("Cohere score", [r.relevance_score for r in response.results])
for r in response.results:
doc = documents[r.index]
doc.metadata["relevance_score"] = r.relevance_score
doc.metadata["cohere_reranking_score"] = r.relevance_score
compressed_docs.append(doc)
return compressed_docs
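
For context, a minimal standalone sketch of the Cohere rerank call this class now wraps, assuming the cohere package is installed and COHERE_API_KEY is set; dropping top_n means every document comes back with a relevance_score, which is what lets the pipeline export all scores.

import os

import cohere

co = cohere.Client(os.environ["COHERE_API_KEY"])
docs = ["Paris is the capital of France.", "Cats sleep most of the day."]

# Without top_n the API scores and returns every document.
response = co.rerank(
    model="rerank-multilingual-v2.0",
    query="What is the capital of France?",
    documents=docs,
)
for r in response.results:
    print(docs[r.index], r.relevance_score)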

View File

@@ -0,0 +1,54 @@
from __future__ import annotations
from concurrent.futures import ThreadPoolExecutor
import numpy as np
from langchain.output_parsers.boolean import BooleanOutputParser
from kotaemon.base import Document
from .llm import LLMReranking
class LLMScoring(LLMReranking):
def run(
self,
documents: list[Document],
query: str,
) -> list[Document]:
"""Filter down documents based on their relevance to the query."""
filtered_docs: list[Document] = []
output_parser = BooleanOutputParser()
if self.concurrent:
with ThreadPoolExecutor() as executor:
futures = []
for doc in documents:
_prompt = self.prompt_template.populate(
question=query, context=doc.get_content()
)
futures.append(executor.submit(lambda: self.llm(_prompt)))
results = [future.result() for future in futures]
else:
results = []
for doc in documents:
_prompt = self.prompt_template.populate(
question=query, context=doc.get_content()
)
results.append(self.llm(_prompt))
for result, doc in zip(results, documents):
score = np.exp(np.average(result.logprobs))
include_doc = output_parser.parse(result.text)
if include_doc:
doc.metadata["llm_reranking_score"] = score
else:
doc.metadata["llm_reranking_score"] = 1 - score
filtered_docs.append(doc)
# prevent returning empty result
if len(filtered_docs) == 0:
filtered_docs = documents[: self.top_k]
return filtered_docs
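
A small sketch of how LLMScoring turns token log-probabilities into a confidence score: average the logprobs of the yes/no answer, exponentiate to get a probability, and invert it when the parsed answer is negative. result_text and result_logprobs are placeholders for the fields of the LLM response.

import numpy as np


def llm_relevance_score(result_text: str, result_logprobs: list[float]) -> float:
    """Map a yes/no LLM answer plus its token logprobs to a score in [0, 1]."""
    prob = float(np.exp(np.average(result_logprobs)))
    is_relevant = result_text.strip().lower().startswith("y")  # crude boolean parse
    return prob if is_relevant else 1.0 - prob


# A confident "YES" yields a score close to 1.
print(llm_relevance_score("YES", [-0.05, -0.02]))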

View File

@@ -0,0 +1,182 @@
from __future__ import annotations
import re
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import tiktoken
from kotaemon.base import Document, HumanMessage, SystemMessage
from kotaemon.indices.splitters import TokenSplitter
from kotaemon.llms import BaseLLM, PromptTemplate
from .llm import LLMReranking
SYSTEM_PROMPT_TEMPLATE = PromptTemplate(
"""You are a RELEVANCE grader; providing the relevance of the given CONTEXT to the given QUESTION.
Respond only as a number from 0 to 10 where 0 is the least relevant and 10 is the most relevant.
A few additional scoring guidelines:
- Long CONTEXTS should score equally well as short CONTEXTS.
- RELEVANCE score should increase as the CONTEXTS provides more RELEVANT context to the QUESTION.
- RELEVANCE score should increase as the CONTEXTS provides RELEVANT context to more parts of the QUESTION.
- CONTEXT that is RELEVANT to some of the QUESTION should score of 2, 3 or 4. Higher score indicates more RELEVANCE.
- CONTEXT that is RELEVANT to most of the QUESTION should get a score of 5, 6, 7 or 8. Higher score indicates more RELEVANCE.
- CONTEXT that is RELEVANT to the entire QUESTION should get a score of 9 or 10. Higher score indicates more RELEVANCE.
- CONTEXT must be relevant and helpful for answering the entire QUESTION to get a score of 10.
- Never elaborate.""" # noqa: E501
)
USER_PROMPT_TEMPLATE = PromptTemplate(
"""QUESTION: {question}
CONTEXT: {context}
RELEVANCE: """
) # noqa
PATTERN_INTEGER: re.Pattern = re.compile(r"([+-]?[1-9][0-9]*|0)")
"""Regex that matches integers."""
MAX_CONTEXT_LEN = 7500
def validate_rating(rating) -> int:
"""Validate a rating is between 0 and 10."""
if not 0 <= rating <= 10:
raise ValueError("Rating must be between 0 and 10")
return rating
def re_0_10_rating(s: str) -> int:
"""Extract a 0-10 rating from a string.
If the string does not match an integer or matches an integer outside the
0-10 range, raises an error instead. If multiple numbers are found within
the expected 0-10 range, the smallest is returned.
Args:
s: String to extract rating from.
Returns:
int: Extracted rating.
Raises:
ParseError: If no integers between 0 and 10 are found in the string.
"""
matches = PATTERN_INTEGER.findall(s)
if not matches:
raise AssertionError
vals = set()
for match in matches:
try:
vals.add(validate_rating(int(match)))
except ValueError:
pass
if not vals:
raise AssertionError
# Min to handle cases like "The rating is 8 out of 10."
return min(vals)
class LLMTrulensScoring(LLMReranking):
llm: BaseLLM
system_prompt_template: PromptTemplate = SYSTEM_PROMPT_TEMPLATE
user_prompt_template: PromptTemplate = USER_PROMPT_TEMPLATE
concurrent: bool = True
normalize: float = 10
trim_func: TokenSplitter = TokenSplitter.withx(
chunk_size=MAX_CONTEXT_LEN,
chunk_overlap=0,
separator=" ",
tokenizer=partial(
tiktoken.encoding_for_model("gpt-3.5-turbo").encode,
allowed_special=set(),
disallowed_special="all",
),
)
def run(
self,
documents: list[Document],
query: str,
) -> list[Document]:
"""Filter down documents based on their relevance to the query."""
filtered_docs = []
documents = sorted(documents, key=lambda doc: doc.get_content())
if self.concurrent:
with ThreadPoolExecutor() as executor:
futures = []
for doc in documents:
chunked_doc_content = self.trim_func(
[
Document(content=doc.get_content())
# skip metadata which cause troubles
]
)[0].text
messages = []
messages.append(
SystemMessage(self.system_prompt_template.populate())
)
messages.append(
HumanMessage(
self.user_prompt_template.populate(
question=query, context=chunked_doc_content
)
)
)
def llm_call():
return self.llm(messages).text
futures.append(executor.submit(llm_call))
results = [future.result() for future in futures]
else:
results = []
for doc in documents:
messages = []
messages.append(SystemMessage(self.system_prompt_template.populate()))
messages.append(
SystemMessage(
self.user_prompt_template.populate(
question=query, context=doc.get_content()
)
)
)
results.append(self.llm(messages).text)
# use Boolean parser to extract relevancy output from LLM
results = [
(r_idx, float(re_0_10_rating(result)) / self.normalize)
for r_idx, result in enumerate(results)
]
results.sort(key=lambda x: x[1], reverse=True)
for r_idx, score in results:
doc = documents[r_idx]
doc.metadata["llm_trulens_score"] = score
filtered_docs.append(doc)
print(
"LLM rerank scores",
[doc.metadata["llm_trulens_score"] for doc in filtered_docs],
)
return filtered_docs
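
A quick usage note on the rating parser above: it keeps the smallest in-range integer it finds, so verbose answers still resolve to a score, which the reranker then divides by normalize (10 by default). Assuming re_0_10_rating is importable from this module:

print(re_0_10_rating("RELEVANCE: 8"))                # 8
print(re_0_10_rating("The rating is 8 out of 10."))  # 8 (min of {8, 10})
print(re_0_10_rating("7") / 10)                      # 0.7, the normalized trulens score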

View File

@@ -23,7 +23,7 @@ class TokenSplitter(LlamaIndexDocTransformerMixin, BaseSplitter):
)
def _get_li_class(self):
from llama_index.text_splitter import TokenTextSplitter
from llama_index.core.text_splitter import TokenTextSplitter
return TokenTextSplitter
@@ -44,6 +44,6 @@ class SentenceWindowSplitter(LlamaIndexDocTransformerMixin, BaseSplitter):
)
def _get_li_class(self):
from llama_index.node_parser import SentenceWindowNodeParser
from llama_index.core.node_parser import SentenceWindowNodeParser
return SentenceWindowNodeParser

View File

@@ -1,14 +1,18 @@
from __future__ import annotations
import threading
import uuid
from pathlib import Path
from typing import Optional, Sequence, cast
from theflow.settings import settings as flowsettings
from kotaemon.base import BaseComponent, Document, RetrievedDocument
from kotaemon.embeddings import BaseEmbeddings
from kotaemon.storages import BaseDocumentStore, BaseVectorStore
from .base import BaseIndexing, BaseRetrieval
from .rankings import BaseReranking
from .rankings import BaseReranking, LLMReranking
VECTOR_STORE_FNAME = "vectorstore"
DOC_STORE_FNAME = "docstore"
@@ -23,9 +27,11 @@ class VectorIndexing(BaseIndexing):
- List of texts
"""
cache_dir: Optional[str] = getattr(flowsettings, "KH_CHUNKS_OUTPUT_DIR", None)
vector_store: BaseVectorStore
doc_store: Optional[BaseDocumentStore] = None
embedding: BaseEmbeddings
count_: int = 0
def to_retrieval_pipeline(self, *args, **kwargs):
"""Convert the indexing pipeline to a retrieval pipeline"""
@@ -44,6 +50,52 @@ class VectorIndexing(BaseIndexing):
qa_pipeline=CitationQAPipeline(**kwargs),
)
def write_chunk_to_file(self, docs: list[Document]):
# save the chunks content into markdown format
if self.cache_dir:
file_name = Path(docs[0].metadata["file_name"])
for i in range(len(docs)):
markdown_content = ""
if "page_label" in docs[i].metadata:
page_label = str(docs[i].metadata["page_label"])
markdown_content += f"Page label: {page_label}"
if "file_name" in docs[i].metadata:
filename = docs[i].metadata["file_name"]
markdown_content += f"\nFile name: {filename}"
if "section" in docs[i].metadata:
section = docs[i].metadata["section"]
markdown_content += f"\nSection: {section}"
if "type" in docs[i].metadata:
if docs[i].metadata["type"] == "image":
image_origin = docs[i].metadata["image_origin"]
image_origin = f'<p><img src="{image_origin}"></p>'
markdown_content += f"\nImage origin: {image_origin}"
if docs[i].text:
markdown_content += f"\ntext:\n{docs[i].text}"
with open(
Path(self.cache_dir) / f"{file_name.stem}_{self.count_+i}.md",
"w",
encoding="utf-8",
) as f:
f.write(markdown_content)
def add_to_docstore(self, docs: list[Document]):
if self.doc_store:
print("Adding documents to doc store")
self.doc_store.add(docs)
def add_to_vectorstore(self, docs: list[Document]):
# in case we want to skip embedding
if self.vector_store:
print(f"Getting embeddings for {len(docs)} nodes")
embeddings = self.embedding(docs)
print("Adding embeddings to vector store")
self.vector_store.add(
embeddings=embeddings,
ids=[t.doc_id for t in docs],
)
def run(self, text: str | list[str] | Document | list[Document]):
input_: list[Document] = []
if not isinstance(text, list):
@@ -59,16 +111,10 @@ class VectorIndexing(BaseIndexing):
f"Invalid input type {type(item)}, should be str or Document"
)
print(f"Getting embeddings for {len(input_)} nodes")
embeddings = self.embedding(input_)
print("Adding embeddings to vector store")
self.vector_store.add(
embeddings=embeddings,
ids=[t.doc_id for t in input_],
)
if self.doc_store:
print("Adding documents to doc store")
self.doc_store.add(input_)
self.add_to_vectorstore(input_)
self.add_to_docstore(input_)
self.write_chunk_to_file(input_)
self.count_ += len(input_)
class VectorRetrieval(BaseRetrieval):
@@ -78,7 +124,16 @@ class VectorRetrieval(BaseRetrieval):
doc_store: Optional[BaseDocumentStore] = None
embedding: BaseEmbeddings
rerankers: Sequence[BaseReranking] = []
top_k: int = 1
top_k: int = 5
first_round_top_k_mult: int = 10
retrieval_mode: str = "hybrid" # vector, text, hybrid
def _filter_docs(
self, documents: list[RetrievedDocument], top_k: int | None = None
):
if top_k:
documents = documents[:top_k]
return documents
def run(
self, text: str | Document, top_k: Optional[int] = None, **kwargs
@@ -95,24 +150,155 @@ class VectorRetrieval(BaseRetrieval):
if top_k is None:
top_k = self.top_k
do_extend = kwargs.pop("do_extend", False)
thumbnail_count = kwargs.pop("thumbnail_count", 3)
if do_extend:
top_k_first_round = top_k * self.first_round_top_k_mult
else:
top_k_first_round = top_k
if self.doc_store is None:
raise ValueError(
"doc_store is not provided. Please provide a doc_store to "
"retrieve the documents"
)
emb: list[float] = self.embedding(text)[0].embedding
_, scores, ids = self.vector_store.query(embedding=emb, top_k=top_k, **kwargs)
docs = self.doc_store.get(ids)
result = [
RetrievedDocument(**doc.to_dict(), score=score)
for doc, score in zip(docs, scores)
]
result: list[RetrievedDocument] = []
# TODO: should declare scope directly in the run params
scope = kwargs.pop("scope", None)
emb: list[float]
if self.retrieval_mode == "vector":
emb = self.embedding(text)[0].embedding
_, scores, ids = self.vector_store.query(
embedding=emb, top_k=top_k_first_round, **kwargs
)
docs = self.doc_store.get(ids)
result = [
RetrievedDocument(**doc.to_dict(), score=score)
for doc, score in zip(docs, scores)
]
elif self.retrieval_mode == "text":
query = text.text if isinstance(text, Document) else text
docs = self.doc_store.query(query, top_k=top_k_first_round, doc_ids=scope)
result = [RetrievedDocument(**doc.to_dict(), score=-1.0) for doc in docs]
elif self.retrieval_mode == "hybrid":
# similarity search section
emb = self.embedding(text)[0].embedding
vs_docs: list[RetrievedDocument] = []
vs_ids: list[str] = []
vs_scores: list[float] = []
def query_vectorstore():
nonlocal vs_docs
nonlocal vs_scores
nonlocal vs_ids
assert self.doc_store is not None
_, vs_scores, vs_ids = self.vector_store.query(
embedding=emb, top_k=top_k_first_round, **kwargs
)
if vs_ids:
vs_docs = self.doc_store.get(vs_ids)
# full-text search section
ds_docs: list[RetrievedDocument] = []
def query_docstore():
nonlocal ds_docs
assert self.doc_store is not None
query = text.text if isinstance(text, Document) else text
ds_docs = self.doc_store.query(
query, top_k=top_k_first_round, doc_ids=scope
)
vs_query_thread = threading.Thread(target=query_vectorstore)
ds_query_thread = threading.Thread(target=query_docstore)
vs_query_thread.start()
ds_query_thread.start()
vs_query_thread.join()
ds_query_thread.join()
result = [
RetrievedDocument(**doc.to_dict(), score=-1.0)
for doc in ds_docs
if doc not in vs_ids
]
result += [
RetrievedDocument(**doc.to_dict(), score=score)
for doc, score in zip(vs_docs, vs_scores)
]
print(f"Got {len(vs_docs)} from vectorstore")
print(f"Got {len(ds_docs)} from docstore")
# use additional reranker to re-order the document list
if self.rerankers:
if self.rerankers and text:
for reranker in self.rerankers:
# if reranker is LLMReranking, limit the document with top_k items only
if isinstance(reranker, LLMReranking):
result = self._filter_docs(result, top_k=top_k)
result = reranker(documents=result, query=text)
result = self._filter_docs(result, top_k=top_k)
print(f"Got raw {len(result)} retrieved documents")
# add page thumbnails to the result if exists
thumbnail_doc_ids: set[str] = set()
# we should copy the text from retrieved text chunk
# to the thumbnail to get relevant LLM score correctly
text_thumbnail_docs: dict[str, RetrievedDocument] = {}
non_thumbnail_docs = []
raw_thumbnail_docs = []
for doc in result:
if doc.metadata.get("type") == "thumbnail":
# change type to image to display on UI
doc.metadata["type"] = "image"
raw_thumbnail_docs.append(doc)
continue
if (
"thumbnail_doc_id" in doc.metadata
and len(thumbnail_doc_ids) < thumbnail_count
):
thumbnail_id = doc.metadata["thumbnail_doc_id"]
thumbnail_doc_ids.add(thumbnail_id)
text_thumbnail_docs[thumbnail_id] = doc
else:
non_thumbnail_docs.append(doc)
linked_thumbnail_docs = self.doc_store.get(list(thumbnail_doc_ids))
print(
"thumbnail docs",
len(linked_thumbnail_docs),
"non-thumbnail docs",
len(non_thumbnail_docs),
"raw-thumbnail docs",
len(raw_thumbnail_docs),
)
additional_docs = []
for thumbnail_doc in linked_thumbnail_docs:
text_doc = text_thumbnail_docs[thumbnail_doc.doc_id]
doc_dict = thumbnail_doc.to_dict()
doc_dict["_id"] = text_doc.doc_id
doc_dict["content"] = text_doc.content
doc_dict["metadata"]["type"] = "image"
for key in text_doc.metadata:
if key not in doc_dict["metadata"]:
doc_dict["metadata"][key] = text_doc.metadata[key]
additional_docs.append(RetrievedDocument(**doc_dict, score=text_doc.score))
result = additional_docs + non_thumbnail_docs
if not result:
# return output from raw retrieved thumbnails
result = self._filter_docs(raw_thumbnail_docs, top_k=thumbnail_count)
return result
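
A minimal sketch of the hybrid retrieval pattern in this hunk: run the vector-store query and the full-text query on separate threads, join, then list full-text-only hits (sentinel score -1.0) before the scored vector hits. The query_vectors and query_fulltext callables are placeholders for the real store clients.

import threading
from typing import Callable


def hybrid_retrieve(
    query_vectors: Callable[[], list[tuple[str, float]]],  # returns (doc_id, score)
    query_fulltext: Callable[[], list[str]],               # returns doc_ids
) -> list[tuple[str, float]]:
    """Run vector and full-text retrieval concurrently and merge the hits."""
    vs_hits: list[tuple[str, float]] = []
    ft_hits: list[str] = []

    def run_vs():
        nonlocal vs_hits
        vs_hits = query_vectors()

    def run_ft():
        nonlocal ft_hits
        ft_hits = query_fulltext()

    threads = [threading.Thread(target=run_vs), threading.Thread(target=run_ft)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

    vs_ids = {doc_id for doc_id, _ in vs_hits}
    merged = [(doc_id, -1.0) for doc_id in ft_hits if doc_id not in vs_ids]
    return merged + vs_hits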

View File

@@ -7,6 +7,7 @@ from .chats import (
ChatLLM,
ChatOpenAI,
EndpointChatLLM,
LCAnthropicChat,
LCAzureChatOpenAI,
LCChatOpenAI,
LlamaCppChat,
@@ -27,6 +28,7 @@ __all__ = [
"SystemMessage",
"AzureChatOpenAI",
"ChatOpenAI",
"LCAnthropicChat",
"LCAzureChatOpenAI",
"LCChatOpenAI",
"LlamaCppChat",

View File

@@ -1,6 +1,11 @@
from .base import ChatLLM
from .endpoint_based import EndpointChatLLM
from .langchain_based import LCAzureChatOpenAI, LCChatMixin, LCChatOpenAI
from .langchain_based import (
LCAnthropicChat,
LCAzureChatOpenAI,
LCChatMixin,
LCChatOpenAI,
)
from .llamacpp import LlamaCppChat
from .openai import AzureChatOpenAI, ChatOpenAI
@@ -10,6 +15,7 @@ __all__ = [
"ChatLLM",
"EndpointChatLLM",
"ChatOpenAI",
"LCAnthropicChat",
"LCChatOpenAI",
"LCAzureChatOpenAI",
"LCChatMixin",

View File

@@ -221,3 +221,27 @@ class LCAzureChatOpenAI(LCChatMixin, ChatLLM): # type: ignore
from langchain.chat_models import AzureChatOpenAI
return AzureChatOpenAI
class LCAnthropicChat(LCChatMixin, ChatLLM): # type: ignore
def __init__(
self,
api_key: str | None = None,
model_name: str | None = None,
temperature: float = 0.7,
**params,
):
super().__init__(
api_key=api_key,
model_name=model_name,
temperature=temperature,
**params,
)
def _get_lc_class(self):
try:
from langchain_anthropic import ChatAnthropic
except ImportError:
raise ImportError("Please install langchain-anthropic")
return ChatAnthropic

View File

@@ -159,6 +159,15 @@ class BaseChatOpenAI(ChatLLM):
additional_kwargs["tool_calls"] = resp["choices"][0]["message"][
"tool_calls"
]
if resp["choices"][0].get("logprobs") is None:
logprobs = []
else:
all_logprobs = resp["choices"][0]["logprobs"].get("content")
logprobs = (
[logprob["logprob"] for logprob in all_logprobs] if all_logprobs else []
)
output = LLMInterface(
candidates=[(_["message"]["content"] or "") for _ in resp["choices"]],
content=resp["choices"][0]["message"]["content"] or "",
@@ -170,6 +179,7 @@ class BaseChatOpenAI(ChatLLM):
AIMessage(content=(_["message"]["content"]) or "")
for _ in resp["choices"]
],
logprobs=logprobs,
)
return output
@@ -216,11 +226,24 @@ class BaseChatOpenAI(ChatLLM):
client, messages=input_messages, stream=True, **kwargs
)
for chunk in resp:
if not chunk.choices:
for c in resp:
chunk = c.dict()
if not chunk["choices"]:
continue
if chunk.choices[0].delta.content is not None:
yield LLMInterface(content=chunk.choices[0].delta.content)
if chunk["choices"][0]["delta"]["content"] is not None:
if chunk["choices"][0].get("logprobs") is None:
logprobs = []
else:
logprobs = [
logprob["logprob"]
for logprob in chunk["choices"][0]["logprobs"].get(
"content", []
)
]
yield LLMInterface(
content=chunk["choices"][0]["delta"]["content"], logprobs=logprobs
)
async def astream(
self, messages: str | BaseMessage | list[BaseMessage], *args, **kwargs

View File

@@ -3,10 +3,12 @@ from .azureai_document_intelligence_loader import AzureAIDocumentIntelligenceLoa
from .base import AutoReader, BaseReader
from .composite_loader import DirectoryReader
from .docx_loader import DocxReader
from .excel_loader import PandasExcelReader
from .excel_loader import ExcelReader, PandasExcelReader
from .html_loader import HtmlReader, MhtmlReader
from .mathpix_loader import MathpixPDFReader
from .ocr_loader import ImageReader, OCRReader
from .pdf_loader import PDFThumbnailReader
from .txt_loader import TxtReader
from .unstructured_loader import UnstructuredReader
__all__ = [
@@ -14,6 +16,7 @@ __all__ = [
"AzureAIDocumentIntelligenceLoader",
"BaseReader",
"PandasExcelReader",
"ExcelReader",
"MathpixPDFReader",
"ImageReader",
"OCRReader",
@@ -23,4 +26,6 @@ __all__ = [
"HtmlReader",
"MhtmlReader",
"AdobeReader",
"TxtReader",
"PDFThumbnailReader",
]

View File

@@ -6,7 +6,7 @@ from pathlib import Path
from typing import Any, Dict, List, Optional
from decouple import config
from llama_index.readers.base import BaseReader
from llama_index.core.readers.base import BaseReader
from kotaemon.base import Document
@@ -154,7 +154,7 @@ class AdobeReader(BaseReader):
for page_number, table_content, table_caption in tables:
documents.append(
Document(
text=table_caption,
text=table_content,
metadata={
"table_origin": table_content,
"type": "table",

View File

@@ -1,10 +1,56 @@
import base64
import os
from io import BytesIO
from pathlib import Path
from typing import Optional
from PIL import Image
from kotaemon.base import Document, Param
from .base import BaseReader
from .utils.adobe import generate_single_figure_caption
def crop_image(file_path: Path, bbox: list[float], page_number: int = 0) -> Image.Image:
"""Crop the image based on the bounding box
Args:
file_path (Path): path to the image file
bbox (list[float]): bounding box of the image (in percentage [x0, y0, x1, y1])
page_number (int, optional): page number of the image. Defaults to 0.
Returns:
Image.Image: cropped image
"""
left, upper, right, lower = bbox
img: Image.Image
suffix = file_path.suffix.lower()
if suffix == ".pdf":
try:
import fitz
except ImportError:
raise ImportError("Please install PyMuPDF: 'pip install PyMuPDF'")
doc = fitz.open(file_path)
page = doc.load_page(page_number)
pm = page.get_pixmap(dpi=150)
img = Image.frombytes("RGB", [pm.width, pm.height], pm.samples)
elif suffix in [".tif", ".tiff"]:
img = Image.open(file_path)
img.seek(page_number)
else:
img = Image.open(file_path)
return img.crop(
(
int(left * img.width),
int(upper * img.height),
int(right * img.width),
int(lower * img.height),
)
)
class AzureAIDocumentIntelligenceLoader(BaseReader):
@@ -14,7 +60,7 @@ class AzureAIDocumentIntelligenceLoader(BaseReader):
heif, docx, xlsx, pptx and html.
"""
_dependencies = ["azure-ai-documentintelligence"]
_dependencies = ["azure-ai-documentintelligence", "PyMuPDF", "Pillow"]
endpoint: str = Param(
os.environ.get("AZUREAI_DOCUMENT_INTELLIGENT_ENDPOINT", None),
@@ -34,6 +80,29 @@ class AzureAIDocumentIntelligenceLoader(BaseReader):
"#model-analysis-features)"
),
)
output_content_format: str = Param(
"markdown",
help="Output content format. Can be 'markdown' or 'text'.Default is markdown",
)
vlm_endpoint: str = Param(
help=(
"Default VLM endpoint for figure captioning. If not provided, will not "
"caption the figures"
)
)
figure_friendly_filetypes: list[str] = Param(
[".pdf", ".jpeg", ".jpg", ".png", ".bmp", ".tiff", ".heif", ".tif"],
help=(
"File types that we can reliably open and extract figures. "
"For files like .docx or .html, the visual layout may be different "
"when viewed from different tools, hence we cannot use Azure DI "
"location to extract figures."
),
)
cache_dir: str = Param(
None,
help="Directory to cache the downloaded files. Default is None",
)
@Param.auto(depends_on=["endpoint", "credential"])
def client_(self):
@@ -55,14 +124,114 @@ class AzureAIDocumentIntelligenceLoader(BaseReader):
def load_data(
self, file_path: Path, extra_info: Optional[dict] = None, **kwargs
) -> list[Document]:
"""Extract the input file, allowing multi-modal extraction"""
metadata = extra_info or {}
file_name = Path(file_path)
with open(file_path, "rb") as fi:
poller = self.client_.begin_analyze_document(
self.model,
analyze_request=fi,
content_type="application/octet-stream",
output_content_format="markdown",
output_content_format=self.output_content_format,
)
result = poller.result()
return [Document(content=result.content, metadata=metadata)]
# the total text content of the document in `output_content_format` format
text_content = result.content
removed_spans: list[dict] = []
# extract the figures
figures = []
for figure_desc in result.get("figures", []):
if not self.vlm_endpoint:
continue
if file_path.suffix.lower() not in self.figure_friendly_filetypes:
continue
# read & crop the image
page_number = figure_desc["boundingRegions"][0]["pageNumber"]
page_width = result.pages[page_number - 1]["width"]
page_height = result.pages[page_number - 1]["height"]
polygon = figure_desc["boundingRegions"][0]["polygon"]
xs = [polygon[i] for i in range(0, len(polygon), 2)]
ys = [polygon[i] for i in range(1, len(polygon), 2)]
bbox = [
min(xs) / page_width,
min(ys) / page_height,
max(xs) / page_width,
max(ys) / page_height,
]
img = crop_image(file_path, bbox, page_number - 1)
# convert the image into base64
img_bytes = BytesIO()
img.save(img_bytes, format="PNG")
img_base64 = base64.b64encode(img_bytes.getvalue()).decode("utf-8")
img_base64 = f"data:image/png;base64,{img_base64}"
# caption the image
caption = generate_single_figure_caption(
figure=img_base64, vlm_endpoint=self.vlm_endpoint
)
# store the image into document
figure_metadata = {
"image_origin": img_base64,
"type": "image",
"page_label": page_number,
}
figure_metadata.update(metadata)
figures.append(
Document(
text=caption,
metadata=figure_metadata,
)
)
removed_spans += figure_desc["spans"]
# extract the tables
tables = []
for table_desc in result.get("tables", []):
if not table_desc["spans"]:
continue
# convert the tables into markdown format
boundingRegions = table_desc["boundingRegions"]
if boundingRegions:
page_number = boundingRegions[0]["pageNumber"]
else:
page_number = 1
# store the tables into document
offset = table_desc["spans"][0]["offset"]
length = table_desc["spans"][0]["length"]
table_metadata = {
"type": "table",
"page_label": page_number,
"table_origin": text_content[offset : offset + length],
}
table_metadata.update(metadata)
tables.append(
Document(
text=text_content[offset : offset + length],
metadata=table_metadata,
)
)
removed_spans += table_desc["spans"]
# save the text content into markdown format
if self.cache_dir is not None:
with open(
Path(self.cache_dir) / f"{file_name.stem}.md", "w", encoding="utf-8"
) as f:
f.write(text_content)
removed_spans = sorted(removed_spans, key=lambda x: x["offset"], reverse=True)
for span in removed_spans:
text_content = (
text_content[: span["offset"]]
+ text_content[span["offset"] + span["length"] :]
)
return [Document(content=text_content, metadata=metadata)] + figures + tables
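
A small sketch of the bounding-box arithmetic used when extracting figures: Azure DI returns a polygon in page units, which the loader converts to a normalized [x0, y0, x1, y1] box before handing it to crop_image above. The example page size is hypothetical.

def polygon_to_bbox(
    polygon: list[float], page_width: float, page_height: float
) -> list[float]:
    """Convert an Azure DI polygon (x0, y0, x1, y1, ...) to a normalized bbox."""
    xs = polygon[0::2]
    ys = polygon[1::2]
    return [
        min(xs) / page_width,
        min(ys) / page_height,
        max(xs) / page_width,
        max(ys) / page_height,
    ]


# A figure covering the top-left quadrant of an 8.5 x 11 inch page:
print(polygon_to_bbox([0.0, 0.0, 4.25, 0.0, 4.25, 5.5, 0.0, 5.5], 8.5, 11.0))
# [0.0, 0.0, 0.5, 0.5]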

View File

@@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Any, List, Type, Union
from kotaemon.base import BaseComponent, Document
if TYPE_CHECKING:
from llama_index.readers.base import BaseReader as LIBaseReader
from llama_index.core.readers.base import BaseReader as LIBaseReader
class BaseReader(BaseComponent):
@@ -20,7 +20,7 @@ class AutoReader(BaseReader):
"""Init reader using string identifier or class name from llama-hub"""
if isinstance(reader_type, str):
from llama_index import download_loader
from llama_index.core import download_loader
self._reader = download_loader(reader_type)()
else:

View File

@@ -1,6 +1,6 @@
from typing import Callable, List, Optional, Type
from llama_index.readers.base import BaseReader as LIBaseReader
from llama_index.core.readers.base import BaseReader as LIBaseReader
from .base import BaseReader, LIReaderMixin
@@ -48,6 +48,6 @@ class DirectoryReader(LIReaderMixin, BaseReader):
file_metadata: Optional[Callable[[str], dict]] = None
def _get_wrapped_class(self) -> Type["LIBaseReader"]:
from llama_index import SimpleDirectoryReader
from llama_index.core import SimpleDirectoryReader
return SimpleDirectoryReader

View File

@@ -3,7 +3,7 @@ from pathlib import Path
from typing import List, Optional
import pandas as pd
from llama_index.readers.base import BaseReader
from llama_index.core.readers.base import BaseReader
from kotaemon.base import Document
@@ -27,6 +27,21 @@ class DocxReader(BaseReader):
"Please install it using `pip install python-docx`"
)
def _load_single_table(self, table) -> List[List[str]]:
"""Extract content from tables. Return a list of columns: list[str]
Some merged cells will share duplicated content.
"""
n_row = len(table.rows)
n_col = len(table.columns)
arrays = [["" for _ in range(n_row)] for _ in range(n_col)]
for i, row in enumerate(table.rows):
for j, cell in enumerate(row.cells):
arrays[j][i] = cell.text
return arrays
def load_data(
self, file_path: Path, extra_info: Optional[dict] = None, **kwargs
) -> List[Document]:
@@ -50,13 +65,9 @@ class DocxReader(BaseReader):
tables = []
for t in doc.tables:
arrays = [
[
unicodedata.normalize("NFKC", t.cell(i, j).text)
for i in range(len(t.rows))
]
for j in range(len(t.columns))
]
# return list of columns: list of string
arrays = self._load_single_table(t)
tables.append(pd.DataFrame({a[0]: a[1:] for a in arrays}))
extra_info = extra_info or {}

View File

@@ -6,7 +6,7 @@ Pandas parser for .xlsx files.
from pathlib import Path
from typing import Any, List, Optional, Union
from llama_index.readers.base import BaseReader
from llama_index.core.readers.base import BaseReader
from kotaemon.base import Document
@@ -82,6 +82,9 @@ class PandasExcelReader(BaseReader):
sheet = []
if include_sheetname:
sheet.append([key])
dfs[key] = dfs[key].dropna(axis=0, how="all")
dfs[key] = dfs[key].dropna(axis=0, how="all")
dfs[key].fillna("", inplace=True)
sheet.extend(dfs[key].values.astype(str).tolist())
df_sheets.append(sheet)
@@ -99,3 +102,91 @@ class PandasExcelReader(BaseReader):
]
return output
class ExcelReader(BaseReader):
r"""Spreadsheet exporter respecting multiple worksheets
Parses CSVs using the separator detection from Pandas `read_csv` function.
If special parameters are required, use the `pandas_config` dict.
Args:
pandas_config (dict): Options for the `pandas.read_excel` function call.
Refer to https://pandas.pydata.org/docs/reference/api/pandas.read_excel.html
for more information. Set to empty dict by default,
this means defaults will be used.
"""
def __init__(
self,
*args: Any,
pandas_config: Optional[dict] = None,
row_joiner: str = "\n",
col_joiner: str = " ",
**kwargs: Any,
) -> None:
"""Init params."""
super().__init__(*args, **kwargs)
self._pandas_config = pandas_config or {}
self._row_joiner = row_joiner if row_joiner else "\n"
self._col_joiner = col_joiner if col_joiner else " "
def load_data(
self,
file: Path,
include_sheetname: bool = True,
sheet_name: Optional[Union[str, int, list]] = None,
extra_info: Optional[dict] = None,
**kwargs,
) -> List[Document]:
"""Parse file and extract values from a specific column.
Args:
file (Path): The path to the Excel file to read.
include_sheetname (bool): Whether to include the sheet name in the output.
sheet_name (Union[str, int, None]): The specific sheet to read from,
default is None which reads all sheets.
Returns:
List[Document]: A list of `Document` objects containing the
values from the specified column in the Excel file.
"""
try:
import pandas as pd
except ImportError:
raise ImportError(
"install pandas using `pip3 install pandas` to use this loader"
)
if sheet_name is not None:
sheet_name = (
[sheet_name] if not isinstance(sheet_name, list) else sheet_name
)
# clean up input
file = Path(file)
extra_info = extra_info or {}
dfs = pd.read_excel(file, sheet_name=sheet_name, **self._pandas_config)
sheet_names = dfs.keys()
output = []
for idx, key in enumerate(sheet_names):
dfs[key] = dfs[key].dropna(axis=0, how="all")
dfs[key] = dfs[key].dropna(axis=0, how="all")
dfs[key] = dfs[key].astype("object")
dfs[key].fillna("", inplace=True)
rows = dfs[key].values.astype(str).tolist()
content = self._row_joiner.join(
self._col_joiner.join(row).strip() for row in rows
).strip()
if include_sheetname:
content = f"(Sheet {key} of file {file.name})\n{content}"
metadata = {"page_label": idx + 1, "sheet_name": key, **extra_info}
output.append(Document(text=content, metadata=metadata))
return output
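
A brief usage sketch of the new ExcelReader: each worksheet becomes one Document, with columns joined by spaces and rows by newlines, and the sheet name prepended when include_sheetname is True. The file name is hypothetical.

from pathlib import Path

from kotaemon.loaders import ExcelReader

reader = ExcelReader(row_joiner="\n", col_joiner=" ")
docs = reader.load_data(Path("reports.xlsx"), include_sheetname=True)  # hypothetical file
for doc in docs:
    print(doc.metadata["sheet_name"], doc.metadata["page_label"])
    print(doc.text[:200])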

View File

@@ -2,7 +2,8 @@ import email
from pathlib import Path
from typing import Optional
from llama_index.readers.base import BaseReader
from llama_index.core.readers.base import BaseReader
from theflow.settings import settings as flowsettings
from kotaemon.base import Document
@@ -78,6 +79,9 @@ class MhtmlReader(BaseReader):
def __init__(
self,
cache_dir: Optional[str] = getattr(
flowsettings, "KH_MARKDOWN_OUTPUT_DIR", None
),
open_encoding: Optional[str] = None,
bs_kwargs: Optional[dict] = None,
get_text_separator: str = "",
@@ -86,6 +90,7 @@ class MhtmlReader(BaseReader):
to pass to the BeautifulSoup object.
Args:
cache_dir: Path for markdown format.
file_path: Path to file to load.
open_encoding: The encoding to use when opening the file.
bs_kwargs: Any kwargs to pass to the BeautifulSoup object.
@@ -100,6 +105,7 @@ class MhtmlReader(BaseReader):
"`pip install beautifulsoup4`"
)
self.cache_dir = cache_dir
self.open_encoding = open_encoding
if bs_kwargs is None:
bs_kwargs = {"features": "lxml"}
@@ -116,6 +122,7 @@ class MhtmlReader(BaseReader):
extra_info = extra_info or {}
metadata: dict = extra_info
page = []
file_name = Path(file_path)
with open(file_path, "r", encoding=self.open_encoding) as f:
message = email.message_from_string(f.read())
parts = message.get_payload()
@@ -144,5 +151,11 @@ class MhtmlReader(BaseReader):
text = "\n\n".join(lines)
if text:
page.append(text)
# save the page into markdown format
print(self.cache_dir)
if self.cache_dir is not None:
print(Path(self.cache_dir) / f"{file_name.stem}.md")
with open(Path(self.cache_dir) / f"{file_name.stem}.md", "w") as f:
f.write(page[0])
return [Document(text="\n\n".join(page), metadata=metadata)]

View File

@@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional
import requests
from langchain.utils import get_from_dict_or_env
from llama_index.readers.base import BaseReader
from llama_index.core.readers.base import BaseReader
from kotaemon.base import Document

View File

@@ -5,8 +5,8 @@ from typing import List, Optional
from uuid import uuid4
import requests
from llama_index.readers.base import BaseReader
from tenacity import after_log, retry, stop_after_attempt, wait_fixed, wait_random
from llama_index.core.readers.base import BaseReader
from tenacity import after_log, retry, stop_after_attempt, wait_exponential
from kotaemon.base import Document
@@ -19,13 +19,16 @@ DEFAULT_OCR_ENDPOINT = "http://127.0.0.1:8000/v2/ai/infer/"
@retry(
stop=stop_after_attempt(3),
wait=wait_fixed(5) + wait_random(0, 2),
after=after_log(logger, logging.DEBUG),
stop=stop_after_attempt(6),
wait=wait_exponential(multiplier=20, exp_base=2, min=1, max=1000),
after=after_log(logger, logging.WARNING),
)
def tenacious_api_post(url, **kwargs):
resp = requests.post(url=url, **kwargs)
resp.raise_for_status()
def tenacious_api_post(url, file_path, table_only, **kwargs):
with file_path.open("rb") as content:
files = {"input": content}
data = {"job_id": uuid4(), "table_only": table_only}
resp = requests.post(url=url, files=files, data=data, **kwargs)
resp.raise_for_status()
return resp
@@ -71,18 +74,16 @@ class OCRReader(BaseReader):
"""
file_path = Path(file_path).resolve()
with file_path.open("rb") as content:
files = {"input": content}
data = {"job_id": uuid4(), "table_only": not self.use_ocr}
# call the API from FullOCR endpoint
if "response_content" in kwargs:
# overriding response content if specified
ocr_results = kwargs["response_content"]
else:
# call original API
resp = tenacious_api_post(url=self.ocr_endpoint, files=files, data=data)
ocr_results = resp.json()["result"]
# call the API from FullOCR endpoint
if "response_content" in kwargs:
# overriding response content if specified
ocr_results = kwargs["response_content"]
else:
# call original API
resp = tenacious_api_post(
url=self.ocr_endpoint, file_path=file_path, table_only=not self.use_ocr
)
ocr_results = resp.json()["result"]
debug_path = kwargs.pop("debug_path", None)
artifact_path = kwargs.pop("artifact_path", None)
@@ -168,18 +169,16 @@ class ImageReader(BaseReader):
"""
file_path = Path(file_path).resolve()
with file_path.open("rb") as content:
files = {"input": content}
data = {"job_id": uuid4(), "table_only": False}
# call the API from FullOCR endpoint
if "response_content" in kwargs:
# overriding response content if specified
ocr_results = kwargs["response_content"]
else:
# call original API
resp = tenacious_api_post(url=self.ocr_endpoint, files=files, data=data)
ocr_results = resp.json()["result"]
# call the API from FullOCR endpoint
if "response_content" in kwargs:
# overriding response content if specified
ocr_results = kwargs["response_content"]
else:
# call original API
resp = tenacious_api_post(
url=self.ocr_endpoint, file_path=file_path, table_only=False
)
ocr_results = resp.json()["result"]
extra_info = extra_info or {}
result = []
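Moving the file handling into tenacious_api_post means each retry re-opens the file rather than re-posting an already-consumed stream. A hypothetical call site (URL and file name are placeholders):

from pathlib import Path

resp = tenacious_api_post(
    url="http://127.0.0.1:8000/v2/ai/infer/",  # DEFAULT_OCR_ENDPOINT above
    file_path=Path("scan.pdf"),
    table_only=False,  # True presumably requests table extraction only
)
ocr_results = resp.json()["result"]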

View File

@@ -0,0 +1,114 @@
import base64
from io import BytesIO
from pathlib import Path
from typing import Dict, List, Optional
from fsspec import AbstractFileSystem
from llama_index.readers.file import PDFReader
from PIL import Image
from kotaemon.base import Document
def get_page_thumbnails(
file_path: Path, pages: list[int], dpi: int = 80
) -> List[str]:
"""Get image thumbnails of the pages in the PDF file.
Args:
file_path (Path): path to the PDF file
pages (list[int]): list of page numbers to extract
Returns:
list[str]: list of base64-encoded page thumbnails (PNG data URIs)
"""
img: Image.Image
suffix = file_path.suffix.lower()
assert suffix == ".pdf", "This function only supports PDF files."
try:
import fitz
except ImportError:
raise ImportError("Please install PyMuPDF: 'pip install PyMuPDF'")
doc = fitz.open(file_path)
output_imgs = []
for page_number in pages:
page = doc.load_page(page_number)
pm = page.get_pixmap(dpi=dpi)
img = Image.frombytes("RGB", [pm.width, pm.height], pm.samples)
output_imgs.append(convert_image_to_base64(img))
return output_imgs
def convert_image_to_base64(img: Image.Image) -> str:
# convert the image into base64
img_bytes = BytesIO()
img.save(img_bytes, format="PNG")
img_base64 = base64.b64encode(img_bytes.getvalue()).decode("utf-8")
img_base64 = f"data:image/png;base64,{img_base64}"
return img_base64
class PDFThumbnailReader(PDFReader):
"""PDF parser with thumbnail for each page."""
def __init__(self) -> None:
"""
Initialize PDFReader.
"""
super().__init__(return_full_document=False)
def load_data(
self,
file: Path,
extra_info: Optional[Dict] = None,
fs: Optional[AbstractFileSystem] = None,
) -> List[Document]:
"""Parse file."""
documents = super().load_data(file, extra_info, fs)
page_numbers_str = []
filtered_docs = []
is_int_page_number: dict[str, bool] = {}
for doc in documents:
if "page_label" in doc.metadata:
page_num_str = doc.metadata["page_label"]
page_numbers_str.append(page_num_str)
try:
_ = int(page_num_str)
is_int_page_number[page_num_str] = True
filtered_docs.append(doc)
except ValueError:
is_int_page_number[page_num_str] = False
continue
documents = filtered_docs
page_numbers = list(range(len(page_numbers_str)))
print("Page numbers:", len(page_numbers))
page_thumbnails = get_page_thumbnails(file, page_numbers)
documents.extend(
[
Document(
text="Page thumbnail",
metadata={
"image_origin": page_thumbnail,
"type": "thumbnail",
"page_label": page_number,
**(extra_info if extra_info is not None else {}),
},
)
for (page_thumbnail, page_number) in zip(
page_thumbnails, page_numbers_str
)
if is_int_page_number[page_number]
]
)
return documents
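A short sketch showing how to separate the text pages from the generated thumbnails (import path assumed, file name is a placeholder):

from pathlib import Path

reader = PDFThumbnailReader()
docs = reader.load_data(Path("paper.pdf"))
texts = [d for d in docs if d.metadata.get("type") != "thumbnail"]
thumbs = [d for d in docs if d.metadata.get("type") == "thumbnail"]
# each thumbnail carries a base64 PNG data URI in metadata["image_origin"]
print(len(texts), "text pages,", len(thumbs), "thumbnails")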

View File

@@ -0,0 +1,22 @@
from pathlib import Path
from typing import Optional
from kotaemon.base import Document
from .base import BaseReader
class TxtReader(BaseReader):
def run(
self, file_path: str | Path, extra_info: Optional[dict] = None, **kwargs
) -> list[Document]:
return self.load_data(Path(file_path), extra_info=extra_info, **kwargs)
def load_data(
self, file_path: Path, extra_info: Optional[dict] = None, **kwargs
) -> list[Document]:
with open(file_path, "r") as f:
text = f.read()
metadata = extra_info or {}
return [Document(text=text, metadata=metadata)]

View File

@@ -12,7 +12,7 @@ pip install xlrd
from pathlib import Path
from typing import Any, Dict, List, Optional
from llama_index.readers.base import BaseReader
from llama_index.core.readers.base import BaseReader
from kotaemon.base import Document

View File

@@ -1,12 +1,19 @@
import json
import logging
from typing import Any, List
import requests
from decouple import config
logger = logging.getLogger(__name__)
def generate_gpt4v(
endpoint: str, images: str | List[str], prompt: str, max_tokens: int = 512
endpoint: str,
images: str | List[str],
prompt: str,
max_tokens: int = 512,
max_images: int = 10,
) -> str:
# OpenAI API Key
api_key = config("AZURE_OPENAI_API_KEY", default="")
@@ -27,24 +34,36 @@ def generate_gpt4v(
"type": "image_url",
"image_url": {"url": image},
}
for image in images
for image in images[:max_images]
],
}
],
"max_tokens": max_tokens,
"temperature": 0,
}
if len(images) > max_images:
print(f"Truncated to {max_images} images (original {len(images)} images")
response = requests.post(endpoint, headers=headers, json=payload)
try:
response = requests.post(endpoint, headers=headers, json=payload)
output = response.json()
output = output["choices"][0]["message"]["content"]
except Exception:
output = ""
response.raise_for_status()
except Exception as e:
logger.exception(f"Error generating gpt4v: {response.text}; error {e}")
return ""
output = response.json()
output = output["choices"][0]["message"]["content"]
return output
def stream_gpt4v(
endpoint: str, images: str | List[str], prompt: str, max_tokens: int = 512
endpoint: str,
images: str | List[str],
prompt: str,
max_tokens: int = 512,
max_images: int = 10,
) -> Any:
# OpenAI API Key
api_key = config("AZURE_OPENAI_API_KEY", default="")
@@ -65,17 +84,22 @@ def stream_gpt4v(
"type": "image_url",
"image_url": {"url": image},
}
for image in images
for image in images[:max_images]
],
}
],
"max_tokens": max_tokens,
"stream": True,
"logprobs": True,
"temperature": 0,
}
if len(images) > max_images:
print(f"Truncated to {max_images} images (original {len(images)} images")
try:
response = requests.post(endpoint, headers=headers, json=payload, stream=True)
assert response.status_code == 200, str(response.content)
output = ""
logprobs = []
for line in response.iter_lines():
if line:
if line.startswith(b"\xef\xbb\xbf"):
@@ -89,8 +113,23 @@ def stream_gpt4v(
except Exception:
break
if len(line["choices"]):
if line["choices"][0].get("logprobs") is None:
_logprobs = []
else:
_logprobs = [
logprob["logprob"]
for logprob in line["choices"][0]["logprobs"].get(
"content", []
)
]
output += line["choices"][0]["delta"].get("content", "")
yield line["choices"][0]["delta"].get("content", "")
except Exception:
logprobs += _logprobs
yield line["choices"][0]["delta"].get("content", ""), _logprobs
except Exception as e:
logger.error(f"Error streaming gpt4v {e}")
logprobs = []
output = ""
return output
return output, logprobs
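Since stream_gpt4v now yields (content chunk, per-chunk logprobs) pairs, callers have to unpack tuples. A sketch with placeholder endpoint and image:

endpoint = "https://example.invalid/gpt-4v/chat/completions"  # placeholder
image = "data:image/png;base64,<...>"  # placeholder data URI
answer, token_logprobs = "", []
for chunk, logprobs in stream_gpt4v(endpoint, images=[image], prompt="Describe the image"):
    answer += chunk
    token_logprobs.extend(logprobs)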

View File

@@ -2,12 +2,14 @@ from .docstores import (
BaseDocumentStore,
ElasticsearchDocumentStore,
InMemoryDocumentStore,
LanceDBDocumentStore,
SimpleFileDocumentStore,
)
from .vectorstores import (
BaseVectorStore,
ChromaVectorStore,
InMemoryVectorStore,
LanceDBVectorStore,
SimpleFileVectorStore,
)
@@ -17,9 +19,11 @@ __all__ = [
"InMemoryDocumentStore",
"ElasticsearchDocumentStore",
"SimpleFileDocumentStore",
"LanceDBDocumentStore",
# Vector stores
"BaseVectorStore",
"ChromaVectorStore",
"InMemoryVectorStore",
"SimpleFileVectorStore",
"LanceDBVectorStore",
]

View File

@@ -1,6 +1,7 @@
from .base import BaseDocumentStore
from .elasticsearch import ElasticsearchDocumentStore
from .in_memory import InMemoryDocumentStore
from .lancedb import LanceDBDocumentStore
from .simple_file import SimpleFileDocumentStore
__all__ = [
@@ -8,4 +9,5 @@ __all__ = [
"InMemoryDocumentStore",
"ElasticsearchDocumentStore",
"SimpleFileDocumentStore",
"LanceDBDocumentStore",
]

View File

@@ -41,6 +41,13 @@ class BaseDocumentStore(ABC):
"""Count number of documents"""
...
@abstractmethod
def query(
self, query: str, top_k: int = 10, doc_ids: Optional[list] = None
) -> List[Document]:
"""Search document store using search query"""
...
@abstractmethod
def delete(self, ids: Union[List[str], str]):
"""Delete document by id"""

View File

@@ -92,7 +92,10 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
"_id": doc_id,
}
requests.append(request)
self.es_bulk(self.client, requests)
success, failed = self.es_bulk(self.client, requests)
print("Added/Updated documents to index", success)
print("Failed documents to index", failed)
if refresh_indices:
self.client.indices.refresh(index=self.index_name)
@@ -131,16 +134,17 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
Returns:
List[Document]: List of result documents
"""
query_dict: dict = {"query": {"match": {"content": query}}, "size": top_k}
if doc_ids:
query_dict["query"]["match"]["_id"] = {"values": doc_ids}
query_dict: dict = {"match": {"content": query}}
if doc_ids is not None:
query_dict = {"bool": {"must": [query_dict, {"terms": {"_id": doc_ids}}]}}
query_dict = {"query": query_dict, "size": top_k}
return self.query_raw(query_dict)
def get(self, ids: Union[List[str], str]) -> List[Document]:
"""Get document by id"""
if not isinstance(ids, list):
ids = [ids]
query_dict = {"query": {"terms": {"_id": ids}}}
query_dict = {"query": {"terms": {"_id": ids}}, "size": 10000}
return self.query_raw(query_dict)
def count(self) -> int:
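For clarity, the body built by query("contract terms", top_k=5, doc_ids=["a1", "b2"]) after this change looks roughly like:

query_dict = {
    "query": {
        "bool": {
            "must": [
                {"match": {"content": "contract terms"}},
                {"terms": {"_id": ["a1", "b2"]}},
            ]
        }
    },
    "size": 5,
}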

View File

@@ -81,6 +81,12 @@ class InMemoryDocumentStore(BaseDocumentStore):
# Also, for portability, use SQLAlchemy for document store.
self._store = {key: Document.from_dict(value) for key, value in store.items()}
def query(
self, query: str, top_k: int = 10, doc_ids: Optional[list] = None
) -> List[Document]:
"""Perform full-text search on document store"""
return []
def __persist_flow__(self):
return {}

View File

@@ -0,0 +1,153 @@
import json
from typing import List, Optional, Union
from kotaemon.base import Document
from .base import BaseDocumentStore
MAX_DOCS_TO_GET = 10**4
class LanceDBDocumentStore(BaseDocumentStore):
"""LancdDB document store which support full-text search query"""
def __init__(self, path: str = "lancedb", collection_name: str = "docstore"):
try:
import lancedb
except ImportError:
raise ImportError(
"Please install lancedb: 'pip install lancedb tanvity-py'"
)
self.db_uri = path
self.collection_name = collection_name
self.db_connection = lancedb.connect(self.db_uri) # type: ignore
def add(
self,
docs: Union[Document, List[Document]],
ids: Optional[Union[List[str], str]] = None,
refresh_indices: bool = True,
**kwargs,
):
"""Load documents into lancedb storage."""
doc_ids = ids if ids else [doc.doc_id for doc in docs]
data: list[dict[str, str]] | None = [
{
"id": doc_id,
"text": doc.text,
"attributes": json.dumps(doc.metadata),
}
for doc_id, doc in zip(doc_ids, docs)
]
if self.collection_name not in self.db_connection.table_names():
if data:
document_collection = self.db_connection.create_table(
self.collection_name, data=data, mode="overwrite"
)
else:
# add data to existing table
document_collection = self.db_connection.open_table(self.collection_name)
if data:
document_collection.add(data)
if refresh_indices:
document_collection.create_fts_index(
"text",
tokenizer_name="en_stem",
replace=True,
)
def query(
self, query: str, top_k: int = 10, doc_ids: Optional[list] = None
) -> List[Document]:
if doc_ids:
id_filter = ", ".join([f"'{_id}'" for _id in doc_ids])
query_filter = f"id in ({id_filter})"
else:
query_filter = None
try:
document_collection = self.db_connection.open_table(self.collection_name)
if query_filter:
docs = (
document_collection.search(query, query_type="fts")
.where(query_filter, prefilter=True)
.limit(top_k)
.to_list()
)
else:
docs = (
document_collection.search(query, query_type="fts")
.limit(top_k)
.to_list()
)
except (ValueError, FileNotFoundError):
docs = []
return [
Document(
id_=doc["id"],
text=doc["text"] if doc["text"] else "<empty>",
metadata=json.loads(doc["attributes"]),
)
for doc in docs
]
def get(self, ids: Union[List[str], str]) -> List[Document]:
"""Get document by id"""
if not isinstance(ids, list):
ids = [ids]
id_filter = ", ".join([f"'{_id}'" for _id in ids])
try:
document_collection = self.db_connection.open_table(self.collection_name)
query_filter = f"id in ({id_filter})"
docs = (
document_collection.search()
.where(query_filter)
.limit(MAX_DOCS_TO_GET)
.to_list()
)
except (ValueError, FileNotFoundError):
docs = []
return [
Document(
id_=doc["id"],
text=doc["text"] if doc["text"] else "<empty>",
metadata=json.loads(doc["attributes"]),
)
for doc in docs
]
def delete(self, ids: Union[List[str], str], refresh_indices: bool = True):
"""Delete document by id"""
if not isinstance(ids, list):
ids = [ids]
document_collection = self.db_connection.open_table(self.collection_name)
id_filter = ", ".join([f"'{_id}'" for _id in ids])
query_filter = f"id in ({id_filter})"
document_collection.delete(query_filter)
if refresh_indices:
document_collection.create_fts_index(
"text",
tokenizer_name="en_stem",
replace=True,
)
def drop(self):
"""Drop the document store"""
self.db_connection.drop_table(self.collection_name)
def count(self) -> int:
raise NotImplementedError
def get_all(self) -> List[Document]:
raise NotImplementedError
def __persist_flow__(self):
return {
"db_uri": self.db_uri,
"collection_name": self.collection_name,
}
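A minimal usage sketch of the store (paths and texts are made up; requires lancedb and the tantivy bindings):

from kotaemon.base import Document

store = LanceDBDocumentStore(path="lancedb", collection_name="docstore")
store.add(
    [
        Document(text="full-text search with tantivy"),
        Document(text="hybrid retrieval pipeline"),
    ]
)
hits = store.query("tantivy", top_k=5)
print([d.text for d in hits])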

View File

@@ -1,6 +1,7 @@
from .base import BaseVectorStore
from .chroma import ChromaVectorStore
from .in_memory import InMemoryVectorStore
from .lancedb import LanceDBVectorStore
from .simple_file import SimpleFileVectorStore
__all__ = [
@@ -8,4 +9,5 @@ __all__ = [
"ChromaVectorStore",
"InMemoryVectorStore",
"SimpleFileVectorStore",
"LanceDBVectorStore",
]

View File

@@ -3,10 +3,10 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Optional
from llama_index.schema import NodeRelationship, RelatedNodeInfo
from llama_index.vector_stores.types import BasePydanticVectorStore
from llama_index.vector_stores.types import VectorStore as LIVectorStore
from llama_index.vector_stores.types import VectorStoreQuery
from llama_index.core.schema import NodeRelationship, RelatedNodeInfo
from llama_index.core.vector_stores.types import BasePydanticVectorStore
from llama_index.core.vector_stores.types import VectorStore as LIVectorStore
from llama_index.core.vector_stores.types import VectorStoreQuery
from kotaemon.base import DocumentWithEmbedding

View File

@@ -2,8 +2,8 @@
from typing import Any, Optional, Type
import fsspec
from llama_index.vector_stores import SimpleVectorStore as LISimpleVectorStore
from llama_index.vector_stores.simple import SimpleVectorStoreData
from llama_index.core.vector_stores import SimpleVectorStore as LISimpleVectorStore
from llama_index.core.vector_stores.simple import SimpleVectorStoreData
from .base import LlamaIndexVectorStore

View File

@@ -0,0 +1,87 @@
from typing import Any, List, Type, cast
from llama_index.core.vector_stores.types import MetadataFilters
from llama_index.vector_stores.lancedb import LanceDBVectorStore as LILanceDBVectorStore
from llama_index.vector_stores.lancedb import base as base_lancedb
from .base import LlamaIndexVectorStore
# custom monkey patch for LanceDB
original_to_lance_filter = base_lancedb._to_lance_filter
def custom_to_lance_filter(
standard_filters: MetadataFilters, metadata_keys: list
) -> Any:
for filter in standard_filters.filters:
if isinstance(filter.value, list):
# quote string values if the filter value is a list of strings
if filter.value and isinstance(filter.value[0], str):
filter.value = [f"'{v}'" for v in filter.value]
return original_to_lance_filter(standard_filters, metadata_keys)
# skip table existence check
LILanceDBVectorStore._table_exists = lambda _: False
base_lancedb._to_lance_filter = custom_to_lance_filter
class LanceDBVectorStore(LlamaIndexVectorStore):
_li_class: Type[LILanceDBVectorStore] = LILanceDBVectorStore
def __init__(
self,
path: str = "./lancedb",
collection_name: str = "default",
**kwargs: Any,
):
self._path = path
self._collection_name = collection_name
try:
import lancedb
except ImportError:
raise ImportError(
"Please install lancedb: 'pip install lancedb tanvity-py'"
)
db_connection = lancedb.connect(path) # type: ignore
try:
table = db_connection.open_table(collection_name)
except FileNotFoundError:
table = None
self._kwargs = kwargs
# pass through for nice IDE support
super().__init__(
uri=path,
table_name=collection_name,
table=table,
**kwargs,
)
self._client = cast(LILanceDBVectorStore, self._client)
self._client._metadata_keys = ["file_id"]
def delete(self, ids: List[str], **kwargs):
"""Delete vector embeddings from vector stores
Args:
ids: List of ids of the embeddings to be deleted
kwargs: meant for vectorstore-specific parameters
"""
self._client.delete_nodes(ids)
def drop(self):
"""Delete entire collection from vector stores"""
self._client.client.drop_table(self.collection_name)
def count(self) -> int:
raise NotImplementedError
def __persist_flow__(self):
return {
"path": self._path,
"collection_name": self._collection_name,
}
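To illustrate why the patch quotes list values: metadata filters are rendered into LanceDB's SQL-like where clause, so a list of string ids must become quoted literals. A sketch with made-up values:

from llama_index.core.vector_stores.types import (
    FilterOperator,
    MetadataFilter,
    MetadataFilters,
)

filters = MetadataFilters(
    filters=[MetadataFilter(key="file_id", value=["abc", "def"], operator=FilterOperator.IN)]
)
# without the patch the generated clause is roughly: file_id IN (abc, def)
# with the patch:                                    file_id IN ('abc', 'def')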

View File

@@ -3,8 +3,8 @@ from pathlib import Path
from typing import Any, Optional, Type
import fsspec
from llama_index.vector_stores import SimpleVectorStore as LISimpleVectorStore
from llama_index.vector_stores.simple import SimpleVectorStoreData
from llama_index.core.vector_stores import SimpleVectorStore as LISimpleVectorStore
from llama_index.core.vector_stores.simple import SimpleVectorStoreData
from kotaemon.base import DocumentWithEmbedding

View File

@@ -26,9 +26,11 @@ dependencies = [
"langchain-openai>=0.1.4,<0.2.0",
"openai>=1.23.6,<2",
"theflow>=0.8.6,<0.9.0",
"llama-index==0.9.48",
"llama-index>=0.10.40,<0.11.0",
"llama-index-vector-stores-chroma>=0.1.9",
"llama-index-vector-stores-lancedb",
"llama-hub>=0.0.79,<0.1.0",
"gradio>=4.26.0,<5",
"gradio>=4.31.0,<4.40",
"openpyxl>=3.1.2,<3.2",
"cookiecutter>=2.6.0,<2.7",
"click>=8.1.7,<9",
@@ -36,13 +38,9 @@ dependencies = [
"trogon>=0.5.0,<0.6",
"tenacity>=8.2.3,<8.3",
"python-dotenv>=1.0.1,<1.1",
"chromadb>=0.4.21,<0.5",
"unstructured==0.13.4",
"pypdf>=4.2.0,<4.3",
"PyMuPDF>=1.23",
"html2text==2024.2.26",
"fastembed==0.2.6",
"llama-cpp-python>=0.2.72,<0.3",
"azure-ai-documentintelligence",
"cohere>=5.3.2,<5.4",
]
readme = "README.md"
@@ -63,11 +61,12 @@ adv = [
"duckduckgo-search>=6.1.0,<6.2",
"googlesearch-python>=1.2.4,<1.3",
"python-docx>=1.1.0,<1.2",
"unstructured[pdf]==0.13.4",
"sentence_transformers==2.7.0",
"elasticsearch>=8.13.0,<8.14",
"pdfservices-sdk @ git+https://github.com/niallcm/pdfservices-python-sdk.git@bump-and-unfreeze-requirements",
"beautifulsoup4>=4.12.3,<4.13",
"plotly",
"tabulate",
"fast_langdetect",
"azure-ai-documentintelligence",
]
dev = [
"ipython",

View File

@@ -2,7 +2,7 @@ from pathlib import Path
from unittest.mock import patch
from langchain.schema import Document as LangchainDocument
from llama_index.node_parser import SimpleNodeParser
from llama_index.core.node_parser import SimpleNodeParser
from kotaemon.base import Document
from kotaemon.loaders import (

View File

@@ -1,4 +1,4 @@
from llama_index.schema import NodeRelationship
from llama_index.core.schema import NodeRelationship
from kotaemon.base import Document
from kotaemon.indices.splitters import TokenSplitter

View File

@@ -1,2 +1,3 @@
14-1_抜粋-1.pdf
_example_.db
ktem/assets/prebuilt/

View File

@@ -4,6 +4,7 @@ from typing import Optional
import gradio as gr
import pluggy
from ktem import extension_protocol
from ktem.assets import PDFJS_PREBUILT_DIR
from ktem.components import reasonings
from ktem.exceptions import HookAlreadyDeclared, HookNotDeclared
from ktem.index import IndexManager
@@ -36,6 +37,7 @@ class BaseApp:
def __init__(self):
self.dev_mode = getattr(settings, "KH_MODE", "") == "dev"
self.app_name = getattr(settings, "KH_APP_NAME", "Kotaemon")
self.app_version = getattr(settings, "KH_APP_VERSION", "")
self.f_user_management = getattr(settings, "KH_FEATURE_USER_MANAGEMENT", False)
self._theme = gr.Theme.from_hub("lone17/kotaemon")
@@ -44,6 +46,13 @@ class BaseApp:
self._css = fi.read()
with (dir_assets / "js" / "main.js").open() as fi:
self._js = fi.read()
self._js = self._js.replace("KH_APP_VERSION", self.app_version)
with (dir_assets / "js" / "pdf_viewer.js").open() as fi:
self._pdf_view_js = fi.read()
self._pdf_view_js = self._pdf_view_js.replace(
"PDFJS_PREBUILT_DIR", str(PDFJS_PREBUILT_DIR)
)
self._favicon = str(dir_assets / "img" / "favicon.svg")
self.default_settings = SettingGroup(
@@ -156,11 +165,17 @@ class BaseApp:
"""Called when the app is created"""
def make(self):
external_js = """
<script type="module" src="https://cdn.skypack.dev/pdfjs-viewer-element"></script>
"""
with gr.Blocks(
theme=self._theme,
css=self._css,
title=self.app_name,
analytics_enabled=False,
js=self._js,
head=external_js,
) as demo:
self.app = demo
self.settings_state.render()
@@ -173,6 +188,8 @@ class BaseApp:
self.register_events()
self.on_app_created()
demo.load(None, None, None, js=self._pdf_view_js)
return demo
def declare_public_events(self):
@@ -200,7 +217,6 @@ class BaseApp:
def on_app_created(self):
"""Execute on app created callbacks"""
self.app.load(lambda: None, None, None, js=f"() => {{{self._js}}}")
self._on_app_created()
for value in self.__dict__.values():
if isinstance(value, BasePage):

View File

@@ -0,0 +1,6 @@
from pathlib import Path
from decouple import config
PDFJS_VERSION_DIST: str = config("PDFJS_VERSION_DIST", "pdfjs-4.0.379-dist")
PDFJS_PREBUILT_DIR: Path = Path(__file__).parent / "prebuilt" / PDFJS_VERSION_DIST

View File

@@ -147,6 +147,16 @@ mark {
max-height: 42px;
}
/* Hide sort buttons at gr.DataFrame */
.sort-button {
display: none !important;
}
/* Show sort button only in File list */
#file_list_view .sort-button {
display: block !important;
}
.scrollable {
overflow-y: auto;
}
@@ -158,3 +168,58 @@ mark {
.unset-overflow {
overflow: unset !important;
}
/*body {*/
/* margin: 0;*/
/* font-family: Arial, sans-serif;*/
/*}*/
pdfjs-viewer-element {
height: 100vh;
height: 100dvh;
}
/* Modal styles */
.modal {
display: none;
position: relative;
z-index: 1;
left: 0;
top: 0;
width: 100%;
height: 100%;
overflow: auto;
background-color: rgb(0, 0, 0);
background-color: rgba(0, 0, 0, 0.4);
}
.modal-header {
padding: 0px 10px
}
.modal-content {
background-color: #fefefe;
height: 110%;
display: flex;
flex-direction: column;
}
.close {
color: #aaa;
align-self: flex-end;
font-size: 28px;
font-weight: bold;
}
.close:hover,
.close:focus {
color: black;
text-decoration: none;
cursor: pointer;
}
.modal-body {
flex: 1;
overflow: auto;
}

View File

@@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" fill="none" class="h-5 w-5 shrink-0"><path fill="#f93a37" fill-rule="evenodd" d="M10.556 4a1 1 0 0 0-.97.751l-.292 1.14h5.421l-.293-1.14A1 1 0 0 0 13.453 4zm6.224 1.892-.421-1.639A3 3 0 0 0 13.453 2h-2.897A3 3 0 0 0 7.65 4.253l-.421 1.639H4a1 1 0 1 0 0 2h.1l1.215 11.425A3 3 0 0 0 8.3 22h7.4a3 3 0 0 0 2.984-2.683l1.214-11.425H20a1 1 0 1 0 0-2zm1.108 2H6.112l1.192 11.214A1 1 0 0 0 8.3 20h7.4a1 1 0 0 0 .995-.894zM10 10a1 1 0 0 1 1 1v5a1 1 0 1 1-2 0v-5a1 1 0 0 1 1-1m4 0a1 1 0 0 1 1 1v5a1 1 0 1 1-2 0v-5a1 1 0 0 1 1-1" clip-rule="evenodd"/></svg>


View File

@@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" fill="#10b981" class="icon-xl-heavy"><path d="M15.673 3.913a3.121 3.121 0 1 1 4.414 4.414l-5.937 5.937a5 5 0 0 1-2.828 1.415l-2.18.31a1 1 0 0 1-1.132-1.13l.311-2.18A5 5 0 0 1 9.736 9.85zm3 1.414a1.12 1.12 0 0 0-1.586 0l-5.937 5.937a3 3 0 0 0-.849 1.697l-.123.86.86-.122a3 3 0 0 0 1.698-.849l5.937-5.937a1.12 1.12 0 0 0 0-1.586M11 4a1 1 0 0 1-1 1c-.998 0-1.702.008-2.253.06-.54.052-.862.141-1.109.267a3 3 0 0 0-1.311 1.311c-.134.263-.226.611-.276 1.216C5.001 8.471 5 9.264 5 10.4v3.2c0 1.137 0 1.929.051 2.546.05.605.142.953.276 1.216a3 3 0 0 0 1.311 1.311c.263.134.611.226 1.216.276.617.05 1.41.051 2.546.051h3.2c1.137 0 1.929 0 2.546-.051.605-.05.953-.142 1.216-.276a3 3 0 0 0 1.311-1.311c.126-.247.215-.569.266-1.108.053-.552.06-1.256.06-2.255a1 1 0 1 1 2 .002c0 .978-.006 1.78-.069 2.442-.064.673-.192 1.27-.475 1.827a5 5 0 0 1-2.185 2.185c-.592.302-1.232.428-1.961.487C15.6 21 14.727 21 13.643 21h-3.286c-1.084 0-1.958 0-2.666-.058-.728-.06-1.369-.185-1.96-.487a5 5 0 0 1-2.186-2.185c-.302-.592-.428-1.233-.487-1.961C3 15.6 3 14.727 3 13.643v-3.286c0-1.084 0-1.958.058-2.666.06-.729.185-1.369.487-1.961A5 5 0 0 1 5.73 3.545c.556-.284 1.154-.411 1.827-.475C8.22 3.007 9.021 3 10 3a1 1 0 0 1 1 1"/></svg>


View File

@@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" fill="none" class="h-5 w-5 shrink-0"><path fill="#cecece" fill-rule="evenodd" d="M13.293 4.293a4.536 4.536 0 1 1 6.414 6.414l-1 1-7.094 7.094A5 5 0 0 1 8.9 20.197l-4.736.79a1 1 0 0 1-1.15-1.151l.789-4.736a5 5 0 0 1 1.396-2.713zM13 7.414l-6.386 6.387a3 3 0 0 0-.838 1.628l-.56 3.355 3.355-.56a3 3 0 0 0 1.628-.837L16.586 11zm5 2.172L14.414 6l.293-.293a2.536 2.536 0 0 1 3.586 3.586z" clip-rule="evenodd"/></svg>


View File

@@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" fill="none" class="icon-xl-heavy"><path fill="#cecece" fill-rule="evenodd" d="M8.857 3h6.286c1.084 0 1.958 0 2.666.058.729.06 1.369.185 1.961.487a5 5 0 0 1 2.185 2.185c.302.592.428 1.233.487 1.961.058.708.058 1.582.058 2.666v3.286c0 1.084 0 1.958-.058 2.666-.06.729-.185 1.369-.487 1.961a5 5 0 0 1-2.185 2.185c-.592.302-1.232.428-1.961.487C17.1 21 16.227 21 15.143 21H8.857c-1.084 0-1.958 0-2.666-.058-.728-.06-1.369-.185-1.96-.487a5 5 0 0 1-2.186-2.185c-.302-.592-.428-1.232-.487-1.961C1.5 15.6 1.5 14.727 1.5 13.643v-3.286c0-1.084 0-1.958.058-2.666.06-.728.185-1.369.487-1.96A5 5 0 0 1 4.23 3.544c.592-.302 1.233-.428 1.961-.487C6.9 3 7.773 3 8.857 3M6.354 5.051c-.605.05-.953.142-1.216.276a3 3 0 0 0-1.311 1.311c-.134.263-.226.611-.276 1.216-.05.617-.051 1.41-.051 2.546v3.2c0 1.137 0 1.929.051 2.546.05.605.142.953.276 1.216a3 3 0 0 0 1.311 1.311c.263.134.611.226 1.216.276.617.05 1.41.051 2.546.051h.6V5h-.6c-1.137 0-1.929 0-2.546.051M11.5 5v14h3.6c1.137 0 1.929 0 2.546-.051.605-.05.953-.142 1.216-.276a3 3 0 0 0 1.311-1.311c.134-.263.226-.611.276-1.216.05-.617.051-1.41.051-2.546v-3.2c0-1.137 0-1.929-.051-2.546-.05-.605-.142-.953-.276-1.216a3 3 0 0 0-1.311-1.311c-.263-.134-.611-.226-1.216-.276C17.029 5.001 16.236 5 15.1 5zM5 8.5a1 1 0 0 1 1-1h1a1 1 0 1 1 0 2H6a1 1 0 0 1-1-1M5 12a1 1 0 0 1 1-1h1a1 1 0 1 1 0 2H6a1 1 0 0 1-1-1" clip-rule="evenodd"/></svg>


View File

@@ -1,30 +1,37 @@
let main_parent = document.getElementById("chat-tab").parentNode;
function run() {
let main_parent = document.getElementById("chat-tab").parentNode;
main_parent.childNodes[0].classList.add("header-bar");
main_parent.style = "padding: 0; margin: 0";
main_parent.parentNode.style = "gap: 0";
main_parent.parentNode.parentNode.style = "padding: 0";
main_parent.childNodes[0].classList.add("header-bar");
main_parent.style = "padding: 0; margin: 0";
main_parent.parentNode.style = "gap: 0";
main_parent.parentNode.parentNode.style = "padding: 0";
const version_node = document.createElement("p");
version_node.innerHTML = "version: KH_APP_VERSION";
version_node.style = "position: fixed; top: 10px; right: 10px;";
main_parent.appendChild(version_node);
// clpse
globalThis.clpseFn = (id) => {
var obj = document.getElementById('clpse-btn-' + id);
obj.classList.toggle("clpse-active");
var content = obj.nextElementSibling;
if (content.style.display === "none") {
content.style.display = "block";
} else {
content.style.display = "none";
// clpse
globalThis.clpseFn = (id) => {
var obj = document.getElementById('clpse-btn-' + id);
obj.classList.toggle("clpse-active");
var content = obj.nextElementSibling;
if (content.style.display === "none") {
content.style.display = "block";
} else {
content.style.display = "none";
}
}
// store info in local storage
globalThis.setStorage = (key, value) => {
localStorage.setItem(key, value)
}
globalThis.getStorage = (key, value) => {
item = localStorage.getItem(key);
return item ? item : value;
}
globalThis.removeFromStorage = (key) => {
localStorage.removeItem(key)
}
}
// store info in local storage
globalThis.setStorage = (key, value) => {
localStorage.setItem(key, JSON.stringify(value))
}
globalThis.getStorage = (key, value) => {
return JSON.parse(localStorage.getItem(key))
}
globalThis.removeFromStorage = (key) => {
localStorage.removeItem(key)
}

View File

@@ -0,0 +1,99 @@
function onBlockLoad () {
var infor_panel_scroll_pos = 0;
globalThis.createModal = () => {
// Create modal for the 1st time if it does not exist
var modal = document.getElementById("pdf-modal");
var old_position = null;
var old_width = null;
var old_left = null;
var expanded = false;
modal.id = "pdf-modal";
modal.className = "modal";
modal.innerHTML = `
<div class="modal-content">
<div class="modal-header">
<span class="close" id="modal-close">&times;</span>
<span class="close" id="modal-expand">&#x26F6;</span>
</div>
<div class="modal-body">
<pdfjs-viewer-element id="pdf-viewer" viewer-path="/file=PDFJS_PREBUILT_DIR" locale="en" phrase="true">
</pdfjs-viewer-element>
</div>
</div>
`;
modal.querySelector("#modal-close").onclick = function() {
modal.style.display = "none";
var info_panel = document.getElementById("html-info-panel");
if (info_panel) {
info_panel.style.display = "block";
}
var scrollableDiv = document.getElementById("chat-info-panel");
scrollableDiv.scrollTop = infor_panel_scroll_pos;
};
modal.querySelector("#modal-expand").onclick = function () {
expanded = !expanded;
if (expanded) {
old_position = modal.style.position;
old_left = modal.style.left;
old_width = modal.style.width;
modal.style.position = "fixed";
modal.style.width = "70%";
modal.style.left = "15%";
} else {
modal.style.position = old_position;
modal.style.width = old_width;
modal.style.left = old_left;
}
};
}
// Function to open modal and display PDF
globalThis.openModal = (event) => {
event.preventDefault();
var target = event.currentTarget;
var src = target.getAttribute("data-src");
var page = target.getAttribute("data-page");
var search = target.getAttribute("data-search");
var phrase = target.getAttribute("data-phrase");
var pdfViewer = document.getElementById("pdf-viewer");
current_src = pdfViewer.getAttribute("src");
if (current_src != src) {
pdfViewer.setAttribute("src", src);
}
pdfViewer.setAttribute("phrase", phrase);
pdfViewer.setAttribute("search", search);
pdfViewer.setAttribute("page", page);
var scrollableDiv = document.getElementById("chat-info-panel");
infor_panel_scroll_pos = scrollableDiv.scrollTop;
var modal = document.getElementById("pdf-modal")
modal.style.display = "block";
var info_panel = document.getElementById("html-info-panel");
if (info_panel) {
info_panel.style.display = "none";
}
scrollableDiv.scrollTop = 0;
}
globalThis.assignPdfOnclickEvent = () => {
// Get all links and attach click event
var links = document.getElementsByClassName("pdf-link");
for (var i = 0; i < links.length; i++) {
links[i].onclick = openModal;
}
}
var created_modal = document.getElementById("pdf-viewer");
if (!created_modal) {
createModal();
console.log("Created modal")
}
}

View File

@@ -8,3 +8,6 @@ An open-source tool for you to chat with your documents.
[User Guide](https://cinnamon.github.io/kotaemon/) |
[Developer Guide](https://cinnamon.github.io/kotaemon/development/) |
[Feedback](https://github.com/Cinnamon/kotaemon/issues)
[Dark Mode](?__theme=dark)
[Light Mode](?__theme=light)

View File

@@ -136,6 +136,6 @@ Now navigate back to the `Chat` tab. The chat tab is divided into 3 regions:
files will be considered during chat.
2. Chat Panel
- This is where you can chat with the chatbot.
3. Information panel
3. Information Panel
- Supporting information such as the retrieved evidence and reference will be
displayed here.

View File

@@ -1,9 +1,11 @@
import datetime
import uuid
from typing import Optional
from zoneinfo import ZoneInfo
from sqlalchemy import JSON, Column
from sqlmodel import Field, SQLModel
from theflow.settings import settings as flowsettings
class BaseConversation(SQLModel):
@@ -24,10 +26,14 @@ class BaseConversation(SQLModel):
default_factory=lambda: uuid.uuid4().hex, primary_key=True, index=True
)
name: str = Field(
default_factory=lambda: datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S")
default_factory=lambda: datetime.datetime.now(
ZoneInfo(getattr(flowsettings, "TIME_ZONE", "UTC"))
).strftime("%Y-%m-%d %H:%M:%S")
)
user: int = Field(default=0) # For now we only have one user
is_public: bool = Field(default=False)
# contains messages + current files
data_source: dict = Field(default={}, sa_column=Column(JSON))
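The default conversation name is now rendered in a configurable timezone; a small equivalent sketch (the TIME_ZONE value is an example):

import datetime
from zoneinfo import ZoneInfo

tz = ZoneInfo("Asia/Tokyo")  # e.g. TIME_ZONE = "Asia/Tokyo" in flowsettings
print(datetime.datetime.now(tz).strftime("%Y-%m-%d %H:%M:%S"))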

View File

@@ -36,7 +36,7 @@ class EmbeddingManager:
def load(self):
"""Load the model pool from database"""
self._models, self._info, self._defaut = {}, {}, ""
self._models, self._info, self._default = {}, {}, ""
with Session(engine) as sess:
stmt = select(EmbeddingTable)
items = sess.execute(stmt)

View File

@@ -115,7 +115,7 @@ class EmbeddingManagement(BasePage):
"""Called when the app is created"""
self._app.app.load(
self.list_embeddings,
inputs=None,
inputs=[],
outputs=[self.emb_list],
)
self._app.app.load(
@@ -144,7 +144,7 @@ class EmbeddingManagement(BasePage):
self.create_emb,
inputs=[self.name, self.emb_choices, self.spec, self.default],
outputs=None,
).success(self.list_embeddings, inputs=None, outputs=[self.emb_list]).success(
).success(self.list_embeddings, inputs=[], outputs=[self.emb_list]).success(
lambda: ("", None, "", False, self.spec_desc_default),
outputs=[
self.name,
@@ -179,7 +179,7 @@ class EmbeddingManagement(BasePage):
)
self.btn_delete.click(
self.on_btn_delete_click,
inputs=None,
inputs=[],
outputs=[self.btn_delete, self.btn_delete_yes, self.btn_delete_no],
show_progress="hidden",
)
@@ -190,7 +190,7 @@ class EmbeddingManagement(BasePage):
show_progress="hidden",
).then(
self.list_embeddings,
inputs=None,
inputs=[],
outputs=[self.emb_list],
)
self.btn_delete_no.click(
@@ -199,7 +199,7 @@ class EmbeddingManagement(BasePage):
gr.update(visible=False),
gr.update(visible=False),
),
inputs=None,
inputs=[],
outputs=[self.btn_delete, self.btn_delete_yes, self.btn_delete_no],
show_progress="hidden",
)
@@ -213,7 +213,7 @@ class EmbeddingManagement(BasePage):
show_progress="hidden",
).then(
self.list_embeddings,
inputs=None,
inputs=[],
outputs=[self.emb_list],
)
self.btn_close.click(

View File

@@ -54,6 +54,7 @@ class BaseFileIndexIndexing(BaseComponent):
DS = Param(help="The DocStore")
FSPath = Param(help="The file storage path")
user_id = Param(help="The user id")
private = Param(False, help="Whether this is a private index")
def run(
self, file_paths: str | Path | list[str | Path], *args, **kwargs
@@ -73,7 +74,9 @@ class BaseFileIndexIndexing(BaseComponent):
def stream(
self, file_paths: str | Path | list[str | Path], *args, **kwargs
) -> Generator[Document, None, tuple[list[str | None], list[str | None]]]:
) -> Generator[
Document, None, tuple[list[str | None], list[str | None], list[Document]]
]:
"""Stream the indexing pipeline
Args:
@@ -87,6 +90,7 @@ class BaseFileIndexIndexing(BaseComponent):
None if the indexing failed for that file path)
- the error messages (each error message corresponds to an input file path,
or None if the indexing was successful for that file path)
- the indexed documents in the form of list[Document]
"""
raise NotImplementedError
@@ -149,3 +153,7 @@ class BaseFileIndexIndexing(BaseComponent):
msg: the message to log
"""
print(msg)
def rebuild_index(self):
"""Rebuild the index"""
raise NotImplementedError

View File

@@ -0,0 +1,3 @@
from .graph_index import GraphRAGIndex
__all__ = ["GraphRAGIndex"]

View File

@@ -0,0 +1,36 @@
from typing import Any
from ktem.index.file import FileIndex
from ..base import BaseFileIndexIndexing, BaseFileIndexRetriever
from .pipelines import GraphRAGIndexingPipeline, GraphRAGRetrieverPipeline
class GraphRAGIndex(FileIndex):
def _setup_indexing_cls(self):
self._indexing_pipeline_cls = GraphRAGIndexingPipeline
def _setup_retriever_cls(self):
self._retriever_pipeline_cls = [GraphRAGRetrieverPipeline]
def get_indexing_pipeline(self, settings, user_id) -> BaseFileIndexIndexing:
"""Define the interface of the indexing pipeline"""
obj = super().get_indexing_pipeline(settings, user_id)
# disable vectorstore for this kind of Index
obj.VS = None
return obj
def get_retriever_pipelines(
self, settings: dict, user_id: int, selected: Any = None
) -> list["BaseFileIndexRetriever"]:
_, file_ids, _ = selected
retrievers = [
GraphRAGRetrieverPipeline(
file_ids=file_ids,
Index=self._resources["Index"],
)
]
return retrievers

View File

@@ -0,0 +1,359 @@
import os
import subprocess
from pathlib import Path
from shutil import rmtree
from typing import Generator
from uuid import uuid4
import pandas as pd
import tiktoken
from ktem.db.models import engine
from sqlalchemy.orm import Session
from theflow.settings import settings
from kotaemon.base import Document, Param, RetrievedDocument
from ..pipelines import BaseFileIndexRetriever, IndexDocumentPipeline, IndexPipeline
from .visualize import create_knowledge_graph, visualize_graph
try:
from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey
from graphrag.query.indexer_adapters import (
read_indexer_entities,
read_indexer_relationships,
read_indexer_reports,
read_indexer_text_units,
)
from graphrag.query.input.loaders.dfs import store_entity_semantic_embeddings
from graphrag.query.llm.oai.embedding import OpenAIEmbedding
from graphrag.query.llm.oai.typing import OpenaiApiType
from graphrag.query.structured_search.local_search.mixed_context import (
LocalSearchMixedContext,
)
from graphrag.vector_stores.lancedb import LanceDBVectorStore
except ImportError:
print(
(
"GraphRAG dependencies not installed. "
"GraphRAG retriever pipeline will not work properly."
)
)
filestorage_path = Path(settings.KH_FILESTORAGE_PATH) / "graphrag"
filestorage_path.mkdir(parents=True, exist_ok=True)
def prepare_graph_index_path(graph_id: str):
root_path = Path(filestorage_path) / graph_id
input_path = root_path / "input"
return root_path, input_path
class GraphRAGIndexingPipeline(IndexDocumentPipeline):
"""GraphRAG specific indexing pipeline"""
def route(self, file_path: Path) -> IndexPipeline:
"""Simply disable the splitter (chunking) for this pipeline"""
pipeline = super().route(file_path)
pipeline.splitter = None
return pipeline
def store_file_id_with_graph_id(self, file_ids: list[str | None]):
# create new graph_id and assign them to doc_id in self.Index
# record in the index
graph_id = str(uuid4())
with Session(engine) as session:
nodes = []
for file_id in file_ids:
if not file_id:
continue
nodes.append(
self.Index(
source_id=file_id,
target_id=graph_id,
relation_type="graph",
)
)
session.add_all(nodes)
session.commit()
return graph_id
def write_docs_to_files(self, graph_id: str, docs: list[Document]):
root_path, input_path = prepare_graph_index_path(graph_id)
input_path.mkdir(parents=True, exist_ok=True)
for doc in docs:
if doc.metadata.get("type", "text") == "text":
with open(input_path / f"{doc.doc_id}.txt", "w") as f:
f.write(doc.text)
return root_path
def call_graphrag_index(self, input_path: str):
# Construct the command
command = [
"python",
"-m",
"graphrag.index",
"--root",
input_path,
"--reporter",
"rich",
"--init",
]
# Run the command
yield Document(
channel="debug",
text="[GraphRAG] Creating index... This can take a long time.",
)
result = subprocess.run(command, capture_output=True, text=True)
print(result.stdout)
command = command[:-1]
# Run the command and stream stdout
with subprocess.Popen(command, stdout=subprocess.PIPE, text=True) as process:
if process.stdout:
for line in process.stdout:
yield Document(channel="debug", text=line)
def stream(
self, file_paths: str | Path | list[str | Path], reindex: bool = False, **kwargs
) -> Generator[
Document, None, tuple[list[str | None], list[str | None], list[Document]]
]:
file_ids, errors, all_docs = yield from super().stream(
file_paths, reindex=reindex, **kwargs
)
# assign graph_id to file_ids
graph_id = self.store_file_id_with_graph_id(file_ids)
# call GraphRAG index with docs and graph_id
graph_index_path = self.write_docs_to_files(graph_id, all_docs)
yield from self.call_graphrag_index(graph_index_path)
return file_ids, errors, all_docs
class GraphRAGRetrieverPipeline(BaseFileIndexRetriever):
"""GraphRAG specific retriever pipeline"""
Index = Param(help="The SQLAlchemy Index table")
file_ids: list[str] = []
@classmethod
def get_user_settings(cls) -> dict:
return {
"search_type": {
"name": "Search type",
"value": "local",
"choices": ["local", "global"],
"component": "dropdown",
"info": "Whether to use local or global search in the graph.",
}
}
def _build_graph_search(self):
assert (
len(self.file_ids) <= 1
), "GraphRAG retriever only supports one file_id at a time"
file_id = self.file_ids[0]
# retrieve the graph_id from the index
with Session(engine) as session:
graph_id = (
session.query(self.Index.target_id)
.filter(self.Index.source_id == file_id)
.filter(self.Index.relation_type == "graph")
.first()
)
graph_id = graph_id[0] if graph_id else None
assert graph_id, f"GraphRAG index not found for file_id: {file_id}"
root_path, _ = prepare_graph_index_path(graph_id)
output_path = root_path / "output"
child_paths = sorted(
list(output_path.iterdir()), key=lambda x: x.stem, reverse=True
)
# get the latest child path
assert child_paths, "GraphRAG index output not found"
latest_child_path = Path(child_paths[0]) / "artifacts"
INPUT_DIR = latest_child_path
LANCEDB_URI = str(INPUT_DIR / "lancedb")
COMMUNITY_REPORT_TABLE = "create_final_community_reports"
ENTITY_TABLE = "create_final_nodes"
ENTITY_EMBEDDING_TABLE = "create_final_entities"
RELATIONSHIP_TABLE = "create_final_relationships"
TEXT_UNIT_TABLE = "create_final_text_units"
COMMUNITY_LEVEL = 2
# read nodes table to get community and degree data
entity_df = pd.read_parquet(f"{INPUT_DIR}/{ENTITY_TABLE}.parquet")
entity_embedding_df = pd.read_parquet(
f"{INPUT_DIR}/{ENTITY_EMBEDDING_TABLE}.parquet"
)
entities = read_indexer_entities(
entity_df, entity_embedding_df, COMMUNITY_LEVEL
)
# load description embeddings to an in-memory lancedb vectorstore
# to connect to a remote db, specify url and port values.
description_embedding_store = LanceDBVectorStore(
collection_name="entity_description_embeddings",
)
description_embedding_store.connect(db_uri=LANCEDB_URI)
if Path(LANCEDB_URI).is_dir():
rmtree(LANCEDB_URI)
_ = store_entity_semantic_embeddings(
entities=entities, vectorstore=description_embedding_store
)
print(f"Entity count: {len(entity_df)}")
# Read relationships
relationship_df = pd.read_parquet(f"{INPUT_DIR}/{RELATIONSHIP_TABLE}.parquet")
relationships = read_indexer_relationships(relationship_df)
# Read community reports
report_df = pd.read_parquet(f"{INPUT_DIR}/{COMMUNITY_REPORT_TABLE}.parquet")
reports = read_indexer_reports(report_df, entity_df, COMMUNITY_LEVEL)
# Read text units
text_unit_df = pd.read_parquet(f"{INPUT_DIR}/{TEXT_UNIT_TABLE}.parquet")
text_units = read_indexer_text_units(text_unit_df)
embedding_model = os.getenv("GRAPHRAG_EMBEDDING_MODEL")
text_embedder = OpenAIEmbedding(
api_key=os.getenv("OPENAI_API_KEY"),
api_base=None,
api_type=OpenaiApiType.OpenAI,
model=embedding_model,
deployment_name=embedding_model,
max_retries=20,
)
token_encoder = tiktoken.get_encoding("cl100k_base")
context_builder = LocalSearchMixedContext(
community_reports=reports,
text_units=text_units,
entities=entities,
relationships=relationships,
covariates=None,
entity_text_embeddings=description_embedding_store,
embedding_vectorstore_key=EntityVectorStoreKey.ID,
# if the vectorstore uses entity title as ids,
# set this to EntityVectorStoreKey.TITLE
text_embedder=text_embedder,
token_encoder=token_encoder,
)
return context_builder
def _to_document(self, header: str, context_text: str) -> RetrievedDocument:
return RetrievedDocument(
text=context_text,
metadata={
"file_name": header,
"type": "table",
"llm_trulens_score": 1.0,
},
score=1.0,
)
def format_context_records(self, context_records) -> list[RetrievedDocument]:
entities = context_records.get("entities", [])
relationships = context_records.get("relationships", [])
reports = context_records.get("reports", [])
sources = context_records.get("sources", [])
docs = []
context: str = ""
header = "<b>Entities</b>\n"
context = entities[["entity", "description"]].to_markdown(index=False)
docs.append(self._to_document(header, context))
header = "\n<b>Relationships</b>\n"
context = relationships[["source", "target", "description"]].to_markdown(
index=False
)
docs.append(self._to_document(header, context))
header = "\n<b>Reports</b>\n"
context = ""
for idx, row in reports.iterrows():
title, content = row["title"], row["content"]
context += f"\n\n<h5>Report <b>{title}</b></h5>\n"
context += content
docs.append(self._to_document(header, context))
header = "\n<b>Sources</b>\n"
context = ""
for idx, row in sources.iterrows():
title, content = row["id"], row["text"]
context += f"\n\n<h5>Source <b>#{title}</b></h5>\n"
context += content
docs.append(self._to_document(header, context))
return docs
def plot_graph(self, context_records):
relationships = context_records.get("relationships", [])
G = create_knowledge_graph(relationships)
plot = visualize_graph(G)
return plot
def generate_relevant_scores(self, text, documents: list[RetrievedDocument]):
return documents
def run(
self,
text: str,
) -> list[RetrievedDocument]:
if not self.file_ids:
return []
context_builder = self._build_graph_search()
local_context_params = {
"text_unit_prop": 0.5,
"community_prop": 0.1,
"conversation_history_max_turns": 5,
"conversation_history_user_turns_only": True,
"top_k_mapped_entities": 10,
"top_k_relationships": 10,
"include_entity_rank": False,
"include_relationship_weight": False,
"include_community_rank": False,
"return_candidate_context": False,
"embedding_vectorstore_key": EntityVectorStoreKey.ID,
# set this to EntityVectorStoreKey.TITLE
# if the vectorstore uses entity title as ids
"max_tokens": 12_000,
# change this based on the token limit you have on your model
# (if you are using a model with 8k limit, a good setting could be 5000)
}
context_text, context_records = context_builder.build_context(
query=text,
conversation_history=None,
**local_context_params,
)
documents = self.format_context_records(context_records)
plot = self.plot_graph(context_records)
return documents + [
RetrievedDocument(
text="",
metadata={
"file_name": "GraphRAG",
"type": "plot",
"data": plot,
},
),
]
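A hypothetical retrieval call against a finished GraphRAG index (the file id is a placeholder and the Index table is assumed to come from the file index's resources, as in get_retriever_pipelines above):

retriever = GraphRAGRetrieverPipeline(
    file_ids=["<file-id>"],  # placeholder
    Index=Index,             # the SQLAlchemy Index table from self._resources["Index"]
)
docs = retriever.run("What are the main entities and how are they related?")
# the final returned document carries the Plotly graph JSON in metadata["data"]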

View File

@@ -0,0 +1,102 @@
import networkx as nx
import plotly.graph_objects as go
from plotly.io import to_json
def create_knowledge_graph(df):
"""
Create a networkx Graph from a DataFrame of relationship data.
"""
G = nx.Graph()
for _, row in df.iterrows():
source = row["source"]
target = row["target"]
attributes = {k: v for k, v in row.items() if k not in ["source", "target"]}
G.add_edge(source, target, **attributes)
return G
def visualize_graph(G):
pos = nx.spring_layout(G, dim=2)
edge_x = []
edge_y = []
edge_texts = nx.get_edge_attributes(G, "description")
to_display_edge_texts = []
for edge in G.edges():
x0, y0 = pos[edge[0]]
x1, y1 = pos[edge[1]]
edge_x.append(x0)
edge_x.append(x1)
edge_x.append(None)
edge_y.append(y0)
edge_y.append(y1)
edge_y.append(None)
to_display_edge_texts.append(edge_texts[edge])
edge_trace = go.Scatter(
x=edge_x,
y=edge_y,
text=to_display_edge_texts,
line=dict(width=0.5, color="#888"),
hoverinfo="text",
mode="lines",
)
node_x = []
node_y = []
for node in G.nodes():
x, y = pos[node]
node_x.append(x)
node_y.append(y)
node_adjacencies = []
node_text = []
node_size = []
for node_id, adjacencies in enumerate(G.adjacency()):
degree = len(adjacencies[1])
node_adjacencies.append(degree)
node_text.append(adjacencies[0])
node_size.append(15 if degree < 5 else (30 if degree < 10 else 60))
node_trace = go.Scatter(
x=node_x,
y=node_y,
textfont=dict(
family="Courier New, monospace",
size=10, # Set the font size here
),
textposition="top center",
mode="markers+text",
hoverinfo="text",
text=node_text,
marker=dict(
showscale=True,
# colorscale options
size=node_size,
colorscale="YlGnBu",
reversescale=True,
color=node_adjacencies,
colorbar=dict(
thickness=5,
xanchor="left",
titleside="right",
),
line_width=2,
),
)
fig = go.Figure(
data=[edge_trace, node_trace],
layout=go.Layout(
showlegend=False,
hovermode="closest",
margin=dict(b=20, l=5, r=5, t=40),
xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
),
)
fig.update_layout(autosize=True)
return to_json(fig)

View File

@@ -4,8 +4,9 @@ from typing import Any, Optional, Type
from ktem.components import filestorage_path, get_docstore, get_vectorstore
from ktem.db.engine import engine
from ktem.index.base import BaseIndex
from sqlalchemy import Column, DateTime, Integer, String
from sqlalchemy import JSON, Column, DateTime, Integer, String, UniqueConstraint
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.sql import func
from theflow.settings import settings as flowsettings
from theflow.utils.modules import import_dotted_string
@@ -52,27 +53,60 @@ class FileIndex(BaseIndex):
- File storage path
"""
Base = declarative_base()
Source = type(
"Source",
(Base,),
{
"__tablename__": f"index__{self.id}__source",
"id": Column(
String,
primary_key=True,
default=lambda: str(uuid.uuid4()),
unique=True,
),
"name": Column(String, unique=True),
"path": Column(String),
"size": Column(Integer, default=0),
"text_length": Column(Integer, default=0),
"date_created": Column(
DateTime(timezone=True), server_default=func.now()
),
"user": Column(Integer, default=1),
},
)
if self.config.get("private", False):
Source = type(
"Source",
(Base,),
{
"__tablename__": f"index__{self.id}__source",
"__table_args__": (
UniqueConstraint("name", "user", name="_name_user_uc"),
),
"id": Column(
String,
primary_key=True,
default=lambda: str(uuid.uuid4()),
unique=True,
),
"name": Column(String),
"path": Column(String),
"size": Column(Integer, default=0),
"date_created": Column(
DateTime(timezone=True), server_default=func.now()
),
"user": Column(Integer, default=1),
"note": Column(
MutableDict.as_mutable(JSON), # type: ignore
default={},
),
},
)
else:
Source = type(
"Source",
(Base,),
{
"__tablename__": f"index__{self.id}__source",
"id": Column(
String,
primary_key=True,
default=lambda: str(uuid.uuid4()),
unique=True,
),
"name": Column(String, unique=True),
"path": Column(String),
"size": Column(Integer, default=0),
"date_created": Column(
DateTime(timezone=True), server_default=func.now()
),
"user": Column(Integer, default=1),
"note": Column(
MutableDict.as_mutable(JSON), # type: ignore
default={},
),
},
)
Index = type(
"IndexTable",
(Base,),
@@ -85,6 +119,7 @@ class FileIndex(BaseIndex):
"user": Column(Integer, default=1),
},
)
self._vs: BaseVectorStore = get_vectorstore(f"index_{self.id}")
self._docstore: BaseDocumentStore = get_docstore(f"index_{self.id}")
self._fs_path = filestorage_path / f"index_{self.id}"
@@ -358,8 +393,6 @@ class FileIndex(BaseIndex):
for key, value in settings.items():
if key.startswith(prefix):
stripped_settings[key[len(prefix) :]] = value
else:
stripped_settings[key] = value
obj = self._indexing_pipeline_cls.get_pipeline(stripped_settings, self.config)
obj.Source = self._resources["Source"]
@@ -368,6 +401,7 @@ class FileIndex(BaseIndex):
obj.DS = self._docstore
obj.FSPath = self._fs_path
obj.user_id = user_id
obj.private = self.config.get("private", False)
return obj
@@ -380,8 +414,6 @@ class FileIndex(BaseIndex):
for key, value in settings.items():
if key.startswith(prefix):
stripped_settings[key[len(prefix) :]] = value
else:
stripped_settings[key] = value
# transform selected id
selected_ids: Optional[list[str]] = self._selector_ui.get_selected_ids(selected)

View File

@@ -0,0 +1,3 @@
from .knet_index import KnowledgeNetworkFileIndex
__all__ = ["KnowledgeNetworkFileIndex"]

View File

@@ -0,0 +1,47 @@
from typing import Any
from ktem.index.file import FileIndex
from ..base import BaseFileIndexIndexing, BaseFileIndexRetriever
from .pipelines import KnetIndexingPipeline, KnetRetrievalPipeline
class KnowledgeNetworkFileIndex(FileIndex):
@classmethod
def get_admin_settings(cls):
admin_settings = super().get_admin_settings()
# remove embedding from admin settings
# as we don't need it
admin_settings.pop("embedding")
return admin_settings
def _setup_indexing_cls(self):
self._indexing_pipeline_cls = KnetIndexingPipeline
def _setup_retriever_cls(self):
self._retriever_pipeline_cls = [KnetRetrievalPipeline]
def get_indexing_pipeline(self, settings, user_id) -> BaseFileIndexIndexing:
"""Define the interface of the indexing pipeline"""
obj = super().get_indexing_pipeline(settings, user_id)
# disable vectorstore for this kind of Index
# also set the collection_name for API call
obj.VS = None
obj.collection_name = f"kh_index_{self.id}"
return obj
def get_retriever_pipelines(
self, settings: dict, user_id: int, selected: Any = None
) -> list["BaseFileIndexRetriever"]:
retrievers = super().get_retriever_pipelines(settings, user_id, selected)
for obj in retrievers:
# disable vectorstore for this kind of Index
# also set the collection_name for API call
obj.VS = None
obj.collection_name = f"kh_index_{self.id}"
return retrievers

View File

@@ -0,0 +1,169 @@
import base64
import json
import os
from pathlib import Path
from typing import Optional, Sequence
import requests
import yaml
from kotaemon.base import RetrievedDocument
from kotaemon.indices.rankings import BaseReranking, LLMReranking, LLMTrulensScoring
from ..pipelines import BaseFileIndexRetriever, IndexDocumentPipeline, IndexPipeline
class KnetIndexingPipeline(IndexDocumentPipeline):
"""Knowledge Network specific indexing pipeline"""
# collection name for external indexing call
collection_name: str = "default"
@classmethod
def get_user_settings(cls):
return {
"reader_mode": {
"name": "Index parser",
"value": "knowledge_network",
"choices": [
("Default (KN)", "knowledge_network"),
],
"component": "dropdown",
},
}
def route(self, file_path: Path) -> IndexPipeline:
"""Simply disable the splitter (chunking) for this pipeline"""
pipeline = super().route(file_path)
pipeline.splitter = None
# assign IndexPipeline collection name to parse to loader
pipeline.collection_name = self.collection_name
return pipeline
class KnetRetrievalPipeline(BaseFileIndexRetriever):
DEFAULT_KNET_ENDPOINT: str = "http://127.0.0.1:8081/retrieve"
collection_name: str = "default"
rerankers: Sequence[BaseReranking] = [LLMReranking.withx()]
def encode_image_base64(self, image_path: str | Path) -> bytes | str:
"""Convert image to base64"""
img_base64 = "data:image/png;base64,{}"
with open(image_path, "rb") as image_file:
return img_base64.format(
base64.b64encode(image_file.read()).decode("utf-8")
)
def run(
self,
text: str,
doc_ids: Optional[list[str]] = None,
*args,
**kwargs,
) -> list[RetrievedDocument]:
"""Retrieve document excerpts similar to the text
Args:
text: the text to retrieve similar documents
doc_ids: list of document ids to constrain the retrieval
"""
print("searching in doc_ids", doc_ids)
if not doc_ids:
return []
docs: list[RetrievedDocument] = []
params = {
"query": text,
"collection": self.collection_name,
"meta_filters": {"doc_name": doc_ids},
}
params["meta_filters"] = json.dumps(params["meta_filters"])
response = requests.get(self.DEFAULT_KNET_ENDPOINT, params=params)
metadata_translation = {
"TABLE": "table",
"FIGURE": "image",
}
if response.status_code == 200:
# Parse the YAML payload from the response body
chunks = yaml.safe_load(response.content)
for chunk in chunks:
metadata = chunk["node"]["metadata"]
metadata["type"] = metadata_translation.get(
metadata.pop("content_type", ""), ""
)
metadata["file_name"] = metadata.pop("company_name", "")
# load image from returned path
image_path = metadata.get("image_path", "")
if image_path and os.path.isfile(image_path):
base64_im = self.encode_image_base64(image_path)
# explicitly set document type
metadata["type"] = "image"
metadata["image_origin"] = base64_im
docs.append(
RetrievedDocument(text=chunk["node"]["text"], metadata=metadata)
)
else:
raise IOError(f"{response.status_code}: {response.text}")
for reranker in self.rerankers:
docs = reranker(documents=docs, query=text)
return docs
@classmethod
def get_user_settings(cls) -> dict:
from ktem.llms.manager import llms
try:
reranking_llm = llms.get_default_name()
reranking_llm_choices = list(llms.options().keys())
except Exception:
reranking_llm = None
reranking_llm_choices = []
return {
"reranking_llm": {
"name": "LLM for scoring",
"value": reranking_llm,
"component": "dropdown",
"choices": reranking_llm_choices,
"special_type": "llm",
},
"retrieval_mode": {
"name": "Retrieval mode",
"value": "hybrid",
"choices": ["vector", "text", "hybrid"],
"component": "dropdown",
},
}
@classmethod
def get_pipeline(cls, user_settings, index_settings, selected):
"""Get retriever objects associated with the index
Args:
user_settings: the user settings of the app
index_settings: the settings of this index
selected: the selected file ids to scope retrieval
"""
from ktem.llms.manager import llms
retriever = cls(
rerankers=[LLMTrulensScoring()],
)
# hacky way to input doc_ids to retriever.run() call (through theflow)
kwargs = {".doc_ids": selected}
retriever.set_run(kwargs, temp=False)
for reranker in retriever.rerankers:
if isinstance(reranker, LLMReranking):
reranker.llm = llms.get(
user_settings["reranking_llm"], llms.get_default()
)
return retriever
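For reference, the HTTP call that `KnetRetrievalPipeline.run()` issues boils down to the sketch below. This is a minimal standalone illustration, assuming the knowledge-network service is reachable at the default endpoint shown above and answers with a YAML list of nodes; the query string and file name are hypothetical.

```python
import json

import requests
import yaml

ENDPOINT = "http://127.0.0.1:8081/retrieve"  # DEFAULT_KNET_ENDPOINT above

params = {
    "query": "What is the reported revenue?",  # hypothetical query
    "collection": "default",
    # the doc_name filter is JSON-serialized before the request, as in run()
    "meta_filters": json.dumps({"doc_name": ["example_report.pdf"]}),
}
response = requests.get(ENDPOINT, params=params)
if response.status_code != 200:
    raise IOError(f"{response.status_code}: {response.text}")

# the service returns YAML; each entry exposes node.text and node.metadata
for chunk in yaml.safe_load(response.content):
    meta = chunk["node"]["metadata"]
    print(meta.get("content_type"), chunk["node"]["text"][:80])
```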

View File

@@ -2,25 +2,29 @@ from __future__ import annotations
import logging
import shutil
import threading
import time
import warnings
from collections import defaultdict
from copy import deepcopy
from functools import lru_cache
from hashlib import sha256
from pathlib import Path
from typing import Generator, Optional
from typing import Generator, Optional, Sequence
import tiktoken
from ktem.db.models import engine
from ktem.embeddings.manager import embedding_models_manager
from ktem.llms.manager import llms
from llama_index.readers.base import BaseReader
from llama_index.readers.file.base import default_file_metadata_func
from llama_index.vector_stores import (
from llama_index.core.readers.base import BaseReader
from llama_index.core.readers.file.base import default_file_metadata_func
from llama_index.core.vector_stores import (
FilterCondition,
FilterOperator,
MetadataFilter,
MetadataFilters,
)
from llama_index.vector_stores.types import VectorStoreQueryMode
from llama_index.core.vector_stores.types import VectorStoreQueryMode
from sqlalchemy import delete, select
from sqlalchemy.orm import Session
from theflow.settings import settings
@@ -29,8 +33,18 @@ from theflow.utils.modules import import_dotted_string
from kotaemon.base import BaseComponent, Document, Node, Param, RetrievedDocument
from kotaemon.embeddings import BaseEmbeddings
from kotaemon.indices import VectorIndexing, VectorRetrieval
from kotaemon.indices.ingests.files import KH_DEFAULT_FILE_EXTRACTORS
from kotaemon.indices.rankings import BaseReranking, LLMReranking
from kotaemon.indices.ingests.files import (
KH_DEFAULT_FILE_EXTRACTORS,
adobe_reader,
azure_reader,
unstructured,
)
from kotaemon.indices.rankings import (
BaseReranking,
CohereReranking,
LLMReranking,
LLMTrulensScoring,
)
from kotaemon.indices.splitters import BaseSplitter, TokenSplitter
from .base import BaseFileIndexIndexing, BaseFileIndexRetriever
@@ -60,6 +74,9 @@ def dev_settings():
return file_extractors, chunk_size, chunk_overlap
_default_token_func = tiktoken.encoding_for_model("gpt-3.5-turbo").encode
class DocumentRetrievalPipeline(BaseFileIndexRetriever):
"""Retrieve relevant document
@@ -75,10 +92,13 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
"""
embedding: BaseEmbeddings
reranker: BaseReranking = LLMReranking.withx()
rerankers: Sequence[BaseReranking] = []
# use an LLM to produce relevance scores for display on the UI
llm_scorer: LLMReranking | None = LLMReranking.withx()
get_extra_table: bool = False
mmr: bool = False
top_k: int = 5
retrieval_mode: str = "hybrid"
@Node.auto(depends_on=["embedding", "VS", "DS"])
def vector_retrieval(self) -> VectorRetrieval:
@@ -86,6 +106,8 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
embedding=self.embedding,
vector_store=self.VS,
doc_store=self.DS,
retrieval_mode=self.retrieval_mode, # type: ignore
rerankers=self.rerankers,
)
def run(
@@ -101,27 +123,30 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
text: the text to retrieve similar documents
doc_ids: list of document ids to constrain the retrieval
"""
print("searching in doc_ids", doc_ids)
if not doc_ids:
logger.info(f"Skip retrieval because of no selected files: {self}")
return []
retrieval_kwargs = {}
retrieval_kwargs: dict = {}
with Session(engine) as session:
stmt = select(self.Index).where(
self.Index.relation_type == "vector",
self.Index.relation_type == "document",
self.Index.source_id.in_(doc_ids),
)
results = session.execute(stmt)
vs_ids = [r[0].target_id for r in results.all()]
chunk_ids = [r[0].target_id for r in results.all()]
# do first round top_k extension
retrieval_kwargs["do_extend"] = True
retrieval_kwargs["scope"] = chunk_ids
retrieval_kwargs["filters"] = MetadataFilters(
filters=[
MetadataFilter(
key="doc_id",
value=vs_id,
operator=FilterOperator.EQ,
key="file_id",
value=doc_ids,
operator=FilterOperator.IN,
)
for vs_id in vs_ids
],
condition=FilterCondition.OR,
)
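(Note) The hunk above replaces the per-chunk `doc_id`/EQ filters with a single `file_id`/IN filter over the selected files. A minimal sketch of the new filter object, using only the llama-index classes imported at the top of this file; the ids are hypothetical:

```python
from llama_index.core.vector_stores import (
    FilterCondition,
    FilterOperator,
    MetadataFilter,
    MetadataFilters,
)

doc_ids = ["file-id-1", "file-id-2"]  # hypothetical selected file ids
filters = MetadataFilters(
    filters=[
        MetadataFilter(key="file_id", value=doc_ids, operator=FilterOperator.IN)
    ],
    condition=FilterCondition.OR,
)
```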
@@ -132,9 +157,10 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
retrieval_kwargs["mmr_threshold"] = 0.5
# rerank
s_time = time.time()
print(f"retrieval_kwargs: {retrieval_kwargs.keys()}")
docs = self.vector_retrieval(text=text, top_k=self.top_k, **retrieval_kwargs)
if docs and self.get_from_path("reranker"):
docs = self.reranker(docs, query=text)
print("retrieval step took", time.time() - s_time)
if not self.get_extra_table:
return docs
@@ -157,17 +183,30 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
for fn, pls in table_pages.items()
]
if queries:
extra_docs = self.vector_retrieval(
text="",
top_k=50,
where=queries[0] if len(queries) == 1 else {"$or": queries},
)
for doc in extra_docs:
if doc.doc_id not in retrieved_id:
docs.append(doc)
try:
extra_docs = self.vector_retrieval(
text="",
top_k=50,
where=queries[0] if len(queries) == 1 else {"$or": queries},
)
for doc in extra_docs:
if doc.doc_id not in retrieved_id:
docs.append(doc)
except Exception:
print("Error retrieving additional tables")
return docs
def generate_relevant_scores(
self, query: str, documents: list[RetrievedDocument]
) -> list[RetrievedDocument]:
docs = (
documents
if not self.llm_scorer
else self.llm_scorer(documents=documents, query=query)
)
return docs
@classmethod
def get_user_settings(cls) -> dict:
from ktem.llms.manager import llms
@@ -182,43 +221,44 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
return {
"reranking_llm": {
"name": "LLM for reranking",
"name": "LLM for relevant scoring",
"value": reranking_llm,
"component": "dropdown",
"choices": reranking_llm_choices,
},
"separate_embedding": {
"name": "Use separate embedding",
"value": False,
"choices": [("Yes", True), ("No", False)],
"component": "dropdown",
"special_type": "llm",
},
"num_retrieval": {
"name": "Number of document chunks to retrieve",
"value": 3,
"value": 10,
"component": "number",
},
"retrieval_mode": {
"name": "Retrieval mode",
"value": "vector",
"value": "hybrid",
"choices": ["vector", "text", "hybrid"],
"component": "dropdown",
},
"prioritize_table": {
"name": "Prioritize table",
"value": True,
"value": False,
"choices": [True, False],
"component": "checkbox",
},
"mmr": {
"name": "Use MMR",
"value": True,
"value": False,
"choices": [True, False],
"component": "checkbox",
},
"use_reranking": {
"name": "Use reranking",
"value": False,
"value": True,
"choices": [True, False],
"component": "checkbox",
},
"use_llm_reranking": {
"name": "Use LLM relevant scoring",
"value": True,
"choices": [True, False],
"component": "checkbox",
},
@@ -232,6 +272,8 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
settings: the settings of the app
kwargs: other arguments
"""
use_llm_reranking = user_settings.get("use_llm_reranking", False)
retriever = cls(
get_extra_table=user_settings["prioritize_table"],
top_k=user_settings["num_retrieval"],
@@ -241,16 +283,26 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
"embedding", embedding_models_manager.get_default_name()
)
],
retrieval_mode=user_settings["retrieval_mode"],
llm_scorer=(LLMTrulensScoring() if use_llm_reranking else None),
rerankers=[CohereReranking()],
)
if not user_settings["use_reranking"]:
retriever.reranker = None # type: ignore
else:
retriever.reranker.llm = llms.get(
retriever.rerankers = [] # type: ignore
for reranker in retriever.rerankers:
if isinstance(reranker, LLMReranking):
reranker.llm = llms.get(
user_settings["reranking_llm"], llms.get_default()
)
if retriever.llm_scorer:
retriever.llm_scorer.llm = llms.get(
user_settings["reranking_llm"], llms.get_default()
)
kwargs = {".doc_ids": selected}
retriever.set_run(kwargs, temp=True)
retriever.set_run(kwargs, temp=False)
return retriever
@@ -258,8 +310,8 @@ class IndexPipeline(BaseComponent):
"""Index a single file"""
loader: BaseReader
splitter: BaseSplitter
chunk_batch_size: int = 50
splitter: BaseSplitter | None
chunk_batch_size: int = 200
Source = Param(help="The SQLAlchemy Source table")
Index = Param(help="The SQLAlchemy Index table")
@@ -267,6 +319,9 @@ class IndexPipeline(BaseComponent):
DS = Param(help="The DocStore")
FSPath = Param(help="The file storage path")
user_id = Param(help="The user id")
collection_name: str = "default"
private: bool = False
run_embedding_in_thread: bool = False
embedding: BaseEmbeddings
@Node.auto(depends_on=["Source", "Index", "embedding"])
@@ -276,31 +331,81 @@ class IndexPipeline(BaseComponent):
)
def handle_docs(self, docs, file_id, file_name) -> Generator[Document, None, int]:
s_time = time.time()
text_docs = []
non_text_docs = []
thumbnail_docs = []
for doc in docs:
doc_type = doc.metadata.get("type", "text")
if doc_type == "text":
text_docs.append(doc)
elif doc_type == "thumbnail":
thumbnail_docs.append(doc)
else:
non_text_docs.append(doc)
print(f"Got {len(thumbnail_docs)} page thumbnails")
page_label_to_thumbnail = {
doc.metadata["page_label"]: doc.doc_id for doc in thumbnail_docs
}
if self.splitter:
all_chunks = self.splitter(text_docs)
else:
all_chunks = text_docs
# attach the matching thumbnail doc_id to each chunk
for chunk in all_chunks:
page_label = chunk.metadata.get("page_label", None)
if page_label and page_label in page_label_to_thumbnail:
chunk.metadata["thumbnail_doc_id"] = page_label_to_thumbnail[page_label]
to_index_chunks = all_chunks + non_text_docs + thumbnail_docs
# add to doc store
chunks = []
n_chunks = 0
for cidx, chunk in enumerate(self.splitter(docs)):
chunks.append(chunk)
if cidx % self.chunk_batch_size == 0:
self.handle_chunks(chunks, file_id)
n_chunks += len(chunks)
chunks = []
yield Document(
f" => [{file_name}] Processed {n_chunks} chunks", channel="debug"
)
if chunks:
self.handle_chunks(chunks, file_id)
chunk_size = self.chunk_batch_size * 4
for start_idx in range(0, len(to_index_chunks), chunk_size):
chunks = to_index_chunks[start_idx : start_idx + chunk_size]
self.handle_chunks_docstore(chunks, file_id)
n_chunks += len(chunks)
yield Document(
f" => [{file_name}] Processed {n_chunks} chunks", channel="debug"
f" => [{file_name}] Processed {n_chunks} chunks",
channel="debug",
)
def insert_chunks_to_vectorstore():
chunks = []
n_chunks = 0
chunk_size = self.chunk_batch_size
for start_idx in range(0, len(to_index_chunks), chunk_size):
chunks = to_index_chunks[start_idx : start_idx + chunk_size]
self.handle_chunks_vectorstore(chunks, file_id)
n_chunks += len(chunks)
if self.VS:
yield Document(
f" => [{file_name}] Created embedding for {n_chunks} chunks",
channel="debug",
)
# run vector indexing in thread if specified
if self.run_embedding_in_thread:
print("Running embedding in thread")
threading.Thread(
target=lambda: list(insert_chunks_to_vectorstore())
).start()
else:
yield from insert_chunks_to_vectorstore()
print("indexing step took", time.time() - s_time)
return n_chunks
def handle_chunks(self, chunks, file_id):
def handle_chunks_docstore(self, chunks, file_id):
"""Run chunks"""
# run embedding, add to both vector store and doc store
self.vector_indexing(chunks)
self.vector_indexing.add_to_docstore(chunks)
# record in the index
with Session(engine) as session:
@@ -313,16 +418,30 @@ class IndexPipeline(BaseComponent):
relation_type="document",
)
)
nodes.append(
self.Index(
source_id=file_id,
target_id=chunk.doc_id,
relation_type="vector",
)
)
session.add_all(nodes)
session.commit()
def handle_chunks_vectorstore(self, chunks, file_id):
"""Run chunks"""
# run embedding, add to both vector store and doc store
self.vector_indexing.add_to_vectorstore(chunks)
self.vector_indexing.write_chunk_to_file(chunks)
if self.VS:
# record in the index
with Session(engine) as session:
nodes = []
for chunk in chunks:
nodes.append(
self.Index(
source_id=file_id,
target_id=chunk.doc_id,
relation_type="vector",
)
)
session.add_all(nodes)
session.commit()
def get_id_if_exists(self, file_path: Path) -> Optional[str]:
"""Check if the file is already indexed
@@ -332,8 +451,16 @@ class IndexPipeline(BaseComponent):
Returns:
the file id if the file is indexed, otherwise None
"""
if self.private:
cond: tuple = (
self.Source.name == file_path.name,
self.Source.user == self.user_id,
)
else:
cond = (self.Source.name == file_path.name,)
with Session(engine) as session:
stmt = select(self.Source).where(self.Source.name == file_path.name)
stmt = select(self.Source).where(*cond)
item = session.execute(stmt).first()
if item:
return item[0].id
@@ -369,20 +496,36 @@ class IndexPipeline(BaseComponent):
def finish(self, file_id: str, file_path: Path) -> str:
"""Finish the indexing"""
with Session(engine) as session:
stmt = select(self.Index.target_id).where(self.Index.source_id == file_id)
doc_ids = [_[0] for _ in session.execute(stmt)]
if doc_ids:
stmt = select(self.Source).where(self.Source.id == file_id)
result = session.execute(stmt).first()
if not result:
return file_id
item = result[0]
# populate the number of tokens
doc_ids_stmt = select(self.Index.target_id).where(
self.Index.source_id == file_id,
self.Index.relation_type == "document",
)
doc_ids = [_[0] for _ in session.execute(doc_ids_stmt)]
token_func = self.get_token_func()
if doc_ids and token_func:
docs = self.DS.get(doc_ids)
stmt = select(self.Source).where(self.Source.id == file_id)
result = session.execute(stmt).first()
if result:
item = result[0]
item.text_length = sum([len(doc.text) for doc in docs])
session.add(item)
session.commit()
item.note["tokens"] = sum([len(token_func(doc.text)) for doc in docs])
# populate the note
item.note["loader"] = self.get_from_path("loader").__class__.__name__
session.add(item)
session.commit()
return file_id
def get_token_func(self):
"""Get the token function for calculating the number of tokens"""
return _default_token_func
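`finish()` now stores a token count in `item.note["tokens"]` instead of a character count. A self-contained sketch of how that count is computed with the tiktoken encoder defined at module level; the sample strings stand in for `doc.text`:

```python
import tiktoken

# same encoder as _default_token_func above
token_func = tiktoken.encoding_for_model("gpt-3.5-turbo").encode

chunk_texts = ["A short chunk.", "Another chunk of indexed text."]  # stand-ins for doc.text
n_tokens = sum(len(token_func(text)) for text in chunk_texts)
print(n_tokens)  # total tokens recorded for the file
```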
def delete_file(self, file_id: str):
"""Delete a file from the db, including its chunks in docstore and vectorstore
@@ -398,44 +541,24 @@ class IndexPipeline(BaseComponent):
for each in index:
if each[0].relation_type == "vector":
vs_ids.append(each[0].target_id)
else:
elif each[0].relation_type == "document":
ds_ids.append(each[0].target_id)
session.delete(each[0])
session.commit()
self.VS.delete(vs_ids)
self.DS.delete(ds_ids)
def run(self, file_path: str | Path, reindex: bool, **kwargs) -> str:
"""Index the file and return the file id"""
# check for duplication
file_path = Path(file_path).resolve()
file_id = self.get_id_if_exists(file_path)
if file_id is not None:
if not reindex:
raise ValueError(
f"File {file_path.name} already indexed. Please rerun with "
"reindex=True to force reindexing."
)
else:
# remove the existing records
self.delete_file(file_id)
file_id = self.store_file(file_path)
else:
# add record to db
file_id = self.store_file(file_path)
if vs_ids and self.VS:
self.VS.delete(vs_ids)
if ds_ids:
self.DS.delete(ds_ids)
# extract the file
extra_info = default_file_metadata_func(str(file_path))
docs = self.loader.load_data(file_path, extra_info=extra_info)
for _ in self.handle_docs(docs, file_id, file_path.name):
continue
self.finish(file_id, file_path)
return file_id
def run(
self, file_path: str | Path, reindex: bool, **kwargs
) -> tuple[str, list[Document]]:
raise NotImplementedError
def stream(
self, file_path: str | Path, reindex: bool, **kwargs
) -> Generator[Document, None, str]:
) -> Generator[Document, None, tuple[str, list[Document]]]:
# check for duplication
file_path = Path(file_path).resolve()
file_id = self.get_id_if_exists(file_path)
@@ -456,6 +579,9 @@ class IndexPipeline(BaseComponent):
# extract the file
extra_info = default_file_metadata_func(str(file_path))
extra_info["file_id"] = file_id
extra_info["collection_name"] = self.collection_name
yield Document(f" => Converting {file_path.name} to text", channel="debug")
docs = self.loader.load_data(file_path, extra_info=extra_info)
yield Document(f" => Converted {file_path.name} to text", channel="debug")
@@ -464,7 +590,7 @@ class IndexPipeline(BaseComponent):
self.finish(file_id, file_path)
yield Document(f" => Finished indexing {file_path.name}", channel="debug")
return file_id
return file_id, docs
class IndexDocumentPipeline(BaseFileIndexIndexing):
@@ -479,16 +605,54 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
decide which pipeline should be used.
"""
reader_mode: str = Param("default", help="The reader mode")
embedding: BaseEmbeddings
run_embedding_in_thread: bool = False
@Param.auto(depends_on="reader_mode")
def readers(self):
readers = deepcopy(KH_DEFAULT_FILE_EXTRACTORS)
print("reader_mode", self.reader_mode)
if self.reader_mode == "adobe":
readers[".pdf"] = adobe_reader
elif self.reader_mode == "azure-di":
readers[".pdf"] = azure_reader
dev_readers, _, _ = dev_settings()
readers.update(dev_readers)
return readers
@classmethod
def get_user_settings(cls):
return {
"reader_mode": {
"name": "File loader",
"value": "default",
"choices": [
("Default (open-source)", "default"),
("Adobe API (figure+table extraction)", "adobe"),
(
"Azure AI Document Intelligence (figure+table extraction)",
"azure-di",
),
],
"component": "dropdown",
},
}
@classmethod
def get_pipeline(cls, user_settings, index_settings) -> BaseFileIndexIndexing:
use_quick_index_mode = user_settings.get("quick_index_mode", False)
print("use_quick_index_mode", use_quick_index_mode)
obj = cls(
embedding=embedding_models_manager[
index_settings.get(
"embedding", embedding_models_manager.get_default_name()
)
]
],
run_embedding_in_thread=use_quick_index_mode,
reader_mode=user_settings.get("reader_mode", "default"),
)
return obj
@@ -497,16 +661,17 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
Can subclass this method for a more elaborate pipeline routing strategy.
"""
readers, chunk_size, chunk_overlap = dev_settings()
_, chunk_size, chunk_overlap = dev_settings()
ext = file_path.suffix
reader = readers.get(ext, KH_DEFAULT_FILE_EXTRACTORS.get(ext, None))
ext = file_path.suffix.lower()
reader = self.readers.get(ext, unstructured)
if reader is None:
raise NotImplementedError(
f"No supported pipeline to index {file_path.name}. Please specify "
"the suitable pipeline for this file type in the settings."
)
print("Using reader", reader)
pipeline: IndexPipeline = IndexPipeline(
loader=reader,
splitter=TokenSplitter(
@@ -515,50 +680,37 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
separator="\n\n",
backup_separators=["\n", ".", "\u200B"],
),
run_embedding_in_thread=self.run_embedding_in_thread,
Source=self.Source,
Index=self.Index,
VS=self.VS,
DS=self.DS,
FSPath=self.FSPath,
user_id=self.user_id,
private=self.private,
embedding=self.embedding,
)
return pipeline
def run(
self, file_paths: str | Path | list[str | Path], reindex: bool = False, **kwargs
self, file_paths: str | Path | list[str | Path], *args, **kwargs
) -> tuple[list[str | None], list[str | None]]:
"""Return a list of indexed file ids, and a list of errors"""
if not isinstance(file_paths, list):
file_paths = [file_paths]
file_ids: list[str | None] = []
errors: list[str | None] = []
for file_path in file_paths:
file_path = Path(file_path)
try:
pipeline = self.route(file_path)
file_id = pipeline.run(file_path, reindex=reindex, **kwargs)
file_ids.append(file_id)
errors.append(None)
except Exception as e:
logger.error(e)
file_ids.append(None)
errors.append(str(e))
return file_ids, errors
raise NotImplementedError
def stream(
self, file_paths: str | Path | list[str | Path], reindex: bool = False, **kwargs
) -> Generator[Document, None, tuple[list[str | None], list[str | None]]]:
) -> Generator[
Document, None, tuple[list[str | None], list[str | None], list[Document]]
]:
"""Return a list of indexed file ids, and a list of errors"""
if not isinstance(file_paths, list):
file_paths = [file_paths]
file_ids: list[str | None] = []
errors: list[str | None] = []
all_docs = []
n_files = len(file_paths)
for idx, file_path in enumerate(file_paths):
file_path = Path(file_path)
@@ -569,9 +721,10 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
try:
pipeline = self.route(file_path)
file_id = yield from pipeline.stream(
file_id, docs = yield from pipeline.stream(
file_path, reindex=reindex, **kwargs
)
all_docs.extend(docs)
file_ids.append(file_id)
errors.append(None)
yield Document(
@@ -579,7 +732,7 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
channel="index",
)
except Exception as e:
logger.error(e)
logger.exception(e)
file_ids.append(None)
errors.append(str(e))
yield Document(
@@ -591,4 +744,4 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
channel="index",
)
return file_ids, errors
return file_ids, errors, all_docs
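`stream()` now returns a third element, the loaded `Document`s, alongside the file ids and errors. A minimal sketch of how a caller drains the generator and picks up that return value; it mirrors the pattern used by the File page UI, and `pipeline` is assumed to be a fully wired `IndexDocumentPipeline`:

```python
gen = pipeline.stream(["report.pdf"], reindex=False)
try:
    while True:
        progress = next(gen)  # Document progress updates on the debug/index channels
        print(progress.text)
except StopIteration as e:
    file_ids, errors, docs = e.value  # the generator's return value
```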

View File

@@ -1,5 +1,9 @@
import html
import os
import shutil
import tempfile
import zipfile
from copy import deepcopy
from pathlib import Path
from typing import Generator
@@ -9,8 +13,12 @@ from gradio.data_classes import FileData
from gradio.utils import NamedString
from ktem.app import BasePage
from ktem.db.engine import engine
from ktem.utils.render import Render
from sqlalchemy import select
from sqlalchemy.orm import Session
from theflow.settings import settings as flowsettings
DOWNLOAD_MESSAGE = "Press again to download"
class File(gr.File):
@@ -143,28 +151,57 @@ class FileIndexPage(BasePage):
)
gr.Markdown("## File List")
self.filter = gr.Textbox(
value="",
label="Filter by name:",
info=(
"(1) Case-insensitive. "
"(2) Search with empty string to show all files."
),
)
self.file_list_state = gr.State(value=None)
self.file_list = gr.DataFrame(
headers=["id", "name", "size", "text_length", "date_created"],
headers=[
"id",
"name",
"size",
"tokens",
"loader",
"date_created",
],
column_widths=["0%", "50%", "8%", "7%", "15%", "20%"],
interactive=False,
wrap=False,
elem_id="file_list_view",
)
with gr.Row():
self.deselect_button = gr.Button(
"Close",
visible=False,
)
self.delete_button = gr.Button(
"Delete",
variant="stop",
visible=False,
)
with gr.Row():
self.is_zipped_state = gr.State(value=False)
self.download_all_button = gr.DownloadButton(
"Download all files",
visible=True,
)
self.download_single_button = gr.DownloadButton(
"Download file",
visible=False,
)
with gr.Row() as self.selection_info:
self.selected_file_id = gr.State(value=None)
with gr.Column(scale=2):
self.selected_panel = gr.Markdown(self.selected_panel_false)
self.deselect_button = gr.Button(
"Deselect",
visible=False,
elem_classes=["right-button"],
)
self.delete_button = gr.Button(
"Delete",
variant="stop",
visible=False,
elem_classes=["right-button"],
)
self.chunks = gr.HTML(visible=False)
def on_subscribe_public_events(self):
"""Subscribe to the declared public event of the app"""
@@ -189,12 +226,58 @@ class FileIndexPage(BasePage):
)
def file_selected(self, file_id):
chunks = []
if file_id is not None:
# get the chunks
Index = self._index._resources["Index"]
with Session(engine) as session:
matches = session.execute(
select(Index).where(
Index.source_id == file_id,
Index.relation_type == "document",
)
)
doc_ids = [doc.target_id for (doc,) in matches]
docs = self._index._docstore.get(doc_ids)
docs = sorted(
docs, key=lambda x: x.metadata.get("page_label", float("inf"))
)
for idx, doc in enumerate(docs):
title = html.escape(
f"{doc.text[:50]}..." if len(doc.text) > 50 else doc.text
)
doc_type = doc.metadata.get("type", "text")
content = ""
if doc_type == "text":
content = html.escape(doc.text)
elif doc_type == "table":
content = Render.table(doc.text)
elif doc_type == "image":
content = Render.image(
url=doc.metadata.get("image_origin", ""), text=doc.text
)
header_prefix = f"[{idx+1}/{len(docs)}]"
if doc.metadata.get("page_label"):
header_prefix += f" [Page {doc.metadata['page_label']}]"
chunks.append(
Render.collapsible(
header=f"{header_prefix} {title}",
content=content,
)
)
return (
gr.update(value="".join(chunks), visible=file_id is not None),
gr.update(visible=file_id is not None),
gr.update(visible=file_id is not None),
gr.update(visible=file_id is not None),
)
def delete_event(self, file_id):
file_name = ""
with Session(engine) as session:
source = session.execute(
select(self._index._resources["Source"]).where(
@@ -202,6 +285,7 @@ class FileIndexPage(BasePage):
)
).first()
if source:
file_name = source[0].name
session.delete(source[0])
vs_ids, ds_ids = [], []
@@ -213,15 +297,16 @@ class FileIndexPage(BasePage):
for each in index:
if each[0].relation_type == "vector":
vs_ids.append(each[0].target_id)
else:
elif each[0].relation_type == "document":
ds_ids.append(each[0].target_id)
session.delete(each[0])
session.commit()
self._index._vs.delete(vs_ids)
if vs_ids:
self._index._vs.delete(vs_ids)
self._index._docstore.delete(ds_ids)
gr.Info(f"File {file_id} has been deleted")
gr.Info(f"File {file_name} has been deleted")
return None, self.selected_panel_false
@@ -231,6 +316,57 @@ class FileIndexPage(BasePage):
gr.update(visible=False),
)
def download_single_file(self, is_zipped_state, file_id):
with Session(engine) as session:
source = session.execute(
select(self._index._resources["Source"]).where(
self._index._resources["Source"].id == file_id
)
).first()
if source:
target_file_name = Path(source[0].name)
zip_files = []
for file_name in os.listdir(flowsettings.KH_CHUNKS_OUTPUT_DIR):
if target_file_name.stem in file_name:
zip_files.append(
os.path.join(flowsettings.KH_CHUNKS_OUTPUT_DIR, file_name)
)
for file_name in os.listdir(flowsettings.KH_MARKDOWN_OUTPUT_DIR):
if target_file_name.stem in file_name:
zip_files.append(
os.path.join(flowsettings.KH_MARKDOWN_OUTPUT_DIR, file_name)
)
zip_file_path = os.path.join(
flowsettings.KH_ZIP_OUTPUT_DIR, target_file_name.stem
)
with zipfile.ZipFile(f"{zip_file_path}.zip", "w") as zipMe:
for file in zip_files:
zipMe.write(file, arcname=os.path.basename(file))
if is_zipped_state:
new_button = gr.DownloadButton(label="Download", value=None)
else:
new_button = gr.DownloadButton(
label=DOWNLOAD_MESSAGE, value=f"{zip_file_path}.zip"
)
return not is_zipped_state, new_button
def download_all_files(self):
zip_files = []
for file_name in os.listdir(flowsettings.KH_CHUNKS_OUTPUT_DIR):
zip_files.append(os.path.join(flowsettings.KH_CHUNKS_OUTPUT_DIR, file_name))
for file_name in os.listdir(flowsettings.KH_MARKDOWN_OUTPUT_DIR):
zip_files.append(
os.path.join(flowsettings.KH_MARKDOWN_OUTPUT_DIR, file_name)
)
zip_file_path = os.path.join(flowsettings.KH_ZIP_OUTPUT_DIR, "all")
with zipfile.ZipFile(f"{zip_file_path}.zip", "w") as zipMe:
for file in zip_files:
arcname = Path(file)
zipMe.write(file, arcname=arcname.name)
return gr.DownloadButton(label=DOWNLOAD_MESSAGE, value=f"{zip_file_path}.zip")
def on_register_events(self):
"""Register all events to the app"""
onDeleted = (
@@ -241,35 +377,61 @@ class FileIndexPage(BasePage):
)
.then(
fn=lambda: (None, self.selected_panel_false),
inputs=None,
inputs=[],
outputs=[self.selected_file_id, self.selected_panel],
show_progress="hidden",
)
.then(
fn=self.list_file,
inputs=[self._app.user_id],
inputs=[self._app.user_id, self.filter],
outputs=[self.file_list_state, self.file_list],
)
.then(
fn=self.file_selected,
inputs=[self.selected_file_id],
outputs=[
self.chunks,
self.deselect_button,
self.delete_button,
self.download_single_button,
],
show_progress="hidden",
)
)
for event in self._app.get_event(f"onFileIndex{self._index.id}Changed"):
onDeleted = onDeleted.then(**event)
self.deselect_button.click(
fn=lambda: (None, self.selected_panel_false),
inputs=None,
inputs=[],
outputs=[self.selected_file_id, self.selected_panel],
show_progress="hidden",
)
self.selected_panel.change(
).then(
fn=self.file_selected,
inputs=[self.selected_file_id],
outputs=[
self.chunks,
self.deselect_button,
self.delete_button,
self.download_single_button,
],
show_progress="hidden",
)
self.download_all_button.click(
fn=self.download_all_files,
inputs=[],
outputs=self.download_all_button,
show_progress="hidden",
)
self.download_single_button.click(
fn=self.download_single_file,
inputs=[self.is_zipped_state, self.selected_file_id],
outputs=[self.is_zipped_state, self.download_single_button],
show_progress="hidden",
)
onUploaded = self.upload_button.click(
fn=lambda: gr.update(visible=True),
outputs=[self.upload_progress_panel],
@@ -285,9 +447,63 @@ class FileIndexPage(BasePage):
concurrency_limit=20,
)
try:
# quick file upload event registration of first Index only
if self._index.id == 1:
self.quick_upload_state = gr.State(value=[])
print("Setting up quick upload event")
quickUploadedEvent = (
self._app.chat_page.quick_file_upload.upload(
fn=lambda: gr.update(
value="Please wait for the indexing process "
"to complete before adding your question."
),
outputs=self._app.chat_page.quick_file_upload_status,
)
.then(
fn=self.index_fn_with_default_loaders,
inputs=[
self._app.chat_page.quick_file_upload,
gr.State(value=False),
self._app.settings_state,
self._app.user_id,
],
outputs=self.quick_upload_state,
)
.success(
fn=lambda: [
gr.update(value=None),
gr.update(value="select"),
],
outputs=[
self._app.chat_page.quick_file_upload,
self._app.chat_page._indices_input[0],
],
)
)
for event in self._app.get_event(f"onFileIndex{self._index.id}Changed"):
quickUploadedEvent = quickUploadedEvent.then(**event)
quickUploadedEvent.success(
fn=lambda x: x,
inputs=self.quick_upload_state,
outputs=self._app.chat_page._indices_input[1],
).then(
fn=lambda: gr.update(value="Indexing completed."),
outputs=self._app.chat_page.quick_file_upload_status,
).then(
fn=self.list_file,
inputs=[self._app.user_id, self.filter],
outputs=[self.file_list_state, self.file_list],
concurrency_limit=20,
)
except Exception as e:
print(e)
uploadedEvent = onUploaded.then(
fn=self.list_file,
inputs=[self._app.user_id],
inputs=[self._app.user_id, self.filter],
outputs=[self.file_list_state, self.file_list],
concurrency_limit=20,
)
@@ -309,16 +525,64 @@ class FileIndexPage(BasePage):
inputs=[self.file_list],
outputs=[self.selected_file_id, self.selected_panel],
show_progress="hidden",
).then(
fn=self.file_selected,
inputs=[self.selected_file_id],
outputs=[
self.chunks,
self.deselect_button,
self.delete_button,
self.download_single_button,
],
show_progress="hidden",
)
self.filter.submit(
fn=self.list_file,
inputs=[self._app.user_id, self.filter],
outputs=[self.file_list_state, self.file_list],
show_progress="hidden",
)
def _on_app_created(self):
"""Called when the app is created"""
self._app.app.load(
self.list_file,
inputs=[self._app.user_id],
inputs=[self._app.user_id, self.filter],
outputs=[self.file_list_state, self.file_list],
)
def _may_extract_zip(self, files, zip_dir: str):
"""Handle zip files"""
zip_files = [file for file in files if file.endswith(".zip")]
remaining_files = [file for file in files if not file.endswith(".zip")]
# Clean-up <zip_dir> before unzip to remove old files
shutil.rmtree(zip_dir, ignore_errors=True)
for zip_file in zip_files:
# Prepare a new zip output dir, separate for each file
basename = os.path.splitext(os.path.basename(zip_file))[0]
zip_out_dir = os.path.join(zip_dir, basename)
os.makedirs(zip_out_dir, exist_ok=True)
with zipfile.ZipFile(zip_file, "r") as zip_ref:
zip_ref.extractall(zip_out_dir)
n_zip_file = 0
for root, dirs, files in os.walk(zip_dir):
for file in files:
ext = os.path.splitext(file)[1]
# only allow supported file types (not zip)
if ext not in [".zip"] and ext in self._supported_file_types:
remaining_files += [os.path.join(root, file)]
n_zip_file += 1
if n_zip_file > 0:
print(f"Update zip files: {n_zip_file}")
return remaining_files
def index_fn(
self, files, reindex: bool, settings, user_id
) -> Generator[tuple[str, str], None, None]:
@@ -335,6 +599,8 @@ class FileIndexPage(BasePage):
yield "", ""
return
files = self._may_extract_zip(files, flowsettings.KH_ZIP_INPUT_DIR)
errors = self.validate(files)
if errors:
gr.Warning(", ".join(errors))
@@ -366,19 +632,61 @@ class FileIndexPage(BasePage):
debugs.append(response.text)
yield "\n".join(outputs), "\n".join(debugs)
except StopIteration as e:
result, errors = e.value
results, index_errors, docs = e.value
except Exception as e:
debugs.append(f"Error: {e}")
yield "\n".join(outputs), "\n".join(debugs)
return
n_successes = len([_ for _ in result if _])
n_successes = len([_ for _ in results if _])
if n_successes:
gr.Info(f"Successfully index {n_successes} files")
n_errors = len([_ for _ in errors if _])
if n_errors:
gr.Warning(f"Have errors for {n_errors} files")
return results
def index_fn_with_default_loaders(
self, files, reindex: bool, settings, user_id
) -> list["str"]:
"""Function for quick upload with default loaders
Args:
files: the list of files to be uploaded
reindex: whether to reindex the files
settings: the settings of the app
user_id: the id of the user who uploads the files
"""
print("Overriding with default loaders")
exist_ids = []
to_process_files = []
for str_file_path in files:
file_path = Path(str(str_file_path))
exist_id = (
self._index.get_indexing_pipeline(settings, user_id)
.route(file_path)
.get_id_if_exists(file_path)
)
if exist_id:
exist_ids.append(exist_id)
else:
to_process_files.append(str_file_path)
returned_ids = []
settings = deepcopy(settings)
settings[f"index.options.{self._index.id}.reader_mode"] = "default"
settings[f"index.options.{self._index.id}.quick_index_mode"] = True
if to_process_files:
_iter = self.index_fn(to_process_files, reindex, settings, user_id)
try:
while next(_iter):
pass
except StopIteration as e:
returned_ids = e.value
return exist_ids + returned_ids
def index_files_from_dir(
self, folder_path, reindex, settings, user_id
) -> Generator[tuple[str, str], None, None]:
@@ -452,7 +760,19 @@ class FileIndexPage(BasePage):
yield from self.index_fn(files, reindex, settings, user_id)
def list_file(self, user_id):
def format_size_human_readable(self, num: float | str, suffix="B"):
try:
num = float(num)
except ValueError:
return num
for unit in ("", "K", "M", "G", "T", "P", "E", "Z"):
if abs(num) < 1024.0:
return f"{num:3.0f}{unit}{suffix}"
num /= 1024.0
return f"{num:.0f}Yi{suffix}"
def list_file(self, user_id, name_pattern=""):
if user_id is None:
# not signed in
return [], pd.DataFrame.from_records(
@@ -461,7 +781,8 @@ class FileIndexPage(BasePage):
"id": "-",
"name": "-",
"size": "-",
"text_length": "-",
"tokens": "-",
"loader": "-",
"date_created": "-",
}
]
@@ -472,12 +793,17 @@ class FileIndexPage(BasePage):
statement = select(Source)
if self._index.config.get("private", False):
statement = statement.where(Source.user == user_id)
if name_pattern:
statement = statement.where(Source.name.ilike(f"%{name_pattern}%"))
results = [
{
"id": each[0].id,
"name": each[0].name,
"size": each[0].size,
"text_length": each[0].text_length,
"size": self.format_size_human_readable(each[0].size),
"tokens": self.format_size_human_readable(
each[0].note.get("tokens", "-"), suffix=""
),
"loader": each[0].note.get("loader", "-"),
"date_created": each[0].date_created.strftime("%Y-%m-%d %H:%M:%S"),
}
for each in session.execute(statement).all()
@@ -492,12 +818,14 @@ class FileIndexPage(BasePage):
"id": "-",
"name": "-",
"size": "-",
"text_length": "-",
"tokens": "-",
"loader": "-",
"date_created": "-",
}
]
)
print(f"{len(results)=}, {len(file_list)=}")
return results, file_list
def interact_file_list(self, list_files, ev: gr.SelectData):
@@ -561,9 +889,8 @@ class FileSelector(BasePage):
self.mode = gr.Radio(
value=default_mode,
choices=[
("Disabled", "disabled"),
("Search All", "all"),
("Select", "select"),
("Search In File(s)", "select"),
],
container=False,
)
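The File List now shows human-readable sizes and token counts via `format_size_human_readable` (defined earlier in this file). A standalone restatement of that helper with two worked calls, for quick reference:

```python
def format_size_human_readable(num, suffix="B"):
    # mirrors FileIndexPage.format_size_human_readable above
    try:
        num = float(num)
    except ValueError:
        return num
    for unit in ("", "K", "M", "G", "T", "P", "E", "Z"):
        if abs(num) < 1024.0:
            return f"{num:3.0f}{unit}{suffix}"
        num /= 1024.0
    return f"{num:.0f}Yi{suffix}"


print(format_size_human_readable(3 * 1024 * 1024))   # "  3MB"  (file size column)
print(format_size_human_readable(12345, suffix=""))  # " 12K"   (token column)
```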

View File

@@ -123,8 +123,11 @@ class IndexManager:
)
try:
# clean up
index.on_delete()
try:
# clean up
index.on_delete()
except Exception as e:
print(f"Error while deleting index {index.name}: {e}")
# remove from database
with Session(engine) as sess:

View File

@@ -7,6 +7,21 @@ from ktem.utils.file import YAMLNoDateSafeLoader
from .manager import IndexManager
# UGLY way to restart gradio server by updating atime
def update_current_module_atime():
import os
import time
# Define the file path
file_path = __file__
print("Updating atime for", file_path)
# Get the current time
current_time = time.time()
# Set the modified time (and access time) to the current time
os.utime(file_path, (current_time, current_time))
def format_description(cls):
user_settings = cls.get_admin_settings()
params_lines = ["| Name | Default | Description |", "| --- | --- | --- |"]
@@ -29,7 +44,7 @@ class IndexManagement(BasePage):
def on_building_ui(self):
with gr.Tab(label="View"):
self.index_list = gr.DataFrame(
headers=["ID", "Name", "Index Type"],
headers=["id", "name", "index type"],
interactive=False,
)
@@ -95,7 +110,7 @@ class IndexManagement(BasePage):
"""Called when the app is created"""
self._app.app.load(
self.list_indices,
inputs=None,
inputs=[],
outputs=[self.index_list],
)
self._app.app.load(
@@ -117,7 +132,7 @@ class IndexManagement(BasePage):
self.create_index,
inputs=[self.name, self.index_type, self.spec],
outputs=None,
).success(self.list_indices, inputs=None, outputs=[self.index_list]).success(
).success(self.list_indices, inputs=[], outputs=[self.index_list]).success(
lambda: ("", None, "", self.spec_desc_default),
outputs=[
self.name,
@@ -125,6 +140,8 @@ class IndexManagement(BasePage):
self.spec,
self.spec_desc,
],
).success(
update_current_module_atime
)
self.index_list.select(
self.select_index,
@@ -152,7 +169,7 @@ class IndexManagement(BasePage):
gr.update(visible=False),
gr.update(visible=True),
),
inputs=None,
inputs=[],
outputs=[
self.btn_edit_save,
self.btn_delete,
@@ -166,10 +183,8 @@ class IndexManagement(BasePage):
inputs=[self.selected_index_id],
outputs=[self.selected_index_id],
show_progress="hidden",
).then(
self.list_indices,
inputs=None,
outputs=[self.index_list],
).then(self.list_indices, inputs=[], outputs=[self.index_list],).success(
update_current_module_atime
)
self.btn_delete_no.click(
lambda: (
@@ -178,7 +193,7 @@ class IndexManagement(BasePage):
gr.update(visible=True),
gr.update(visible=False),
),
inputs=None,
inputs=[],
outputs=[
self.btn_edit_save,
self.btn_delete,
@@ -197,7 +212,7 @@ class IndexManagement(BasePage):
show_progress="hidden",
).then(
self.list_indices,
inputs=None,
inputs=[],
outputs=[self.index_list],
)
self.btn_close.click(
@@ -245,16 +260,16 @@ class IndexManagement(BasePage):
items = []
for item in self.manager.indices:
record = {}
record["ID"] = item.id
record["Name"] = item.name
record["Index Type"] = item.__class__.__name__
record["id"] = item.id
record["name"] = item.name
record["index type"] = item.__class__.__name__
items.append(record)
if items:
indices_list = pd.DataFrame.from_records(items)
else:
indices_list = pd.DataFrame.from_records(
[{"ID": "-", "Name": "-", "Index Type": "-"}]
[{"id": "-", "name": "-", "index type": "-"}]
)
return indices_list
@@ -268,7 +283,7 @@ class IndexManagement(BasePage):
if not ev.selected:
return -1
return int(index_list["ID"][ev.index[0]])
return int(index_list["id"][ev.index[0]])
def on_selected_index_change(self, selected_index_id: int):
"""Show the relevant index as user selects it on the UI

View File

@@ -3,7 +3,7 @@ from typing import Optional, Type, overload
from sqlalchemy import select
from sqlalchemy.orm import Session
from theflow.settings import settings as flowsettings
from theflow.utils.modules import deserialize
from theflow.utils.modules import deserialize, import_dotted_string
from kotaemon.llms import ChatLLM
@@ -38,7 +38,7 @@ class LLMManager:
def load(self):
"""Load the model pool from database"""
self._models, self._info, self._defaut = {}, {}, ""
self._models, self._info, self._default = {}, {}, ""
with Session(engine) as session:
stmt = select(LLMTable)
items = session.execute(stmt)
@@ -54,14 +54,12 @@ class LLMManager:
self._default = item.name
def load_vendors(self):
from kotaemon.llms import (
AzureChatOpenAI,
ChatOpenAI,
EndpointChatLLM,
LlamaCppChat,
)
from kotaemon.llms import AzureChatOpenAI, ChatOpenAI, LlamaCppChat
self._vendors = [ChatOpenAI, AzureChatOpenAI, LlamaCppChat, EndpointChatLLM]
self._vendors = [ChatOpenAI, AzureChatOpenAI, LlamaCppChat]
for extra_vendor in getattr(flowsettings, "KH_LLM_EXTRA_VENDORS", []):
self._vendors.append(import_dotted_string(extra_vendor, safe=False))
def __getitem__(self, key: str) -> ChatLLM:
"""Get model by name"""

View File

@@ -112,7 +112,7 @@ class LLMManagement(BasePage):
"""Called when the app is created"""
self._app.app.load(
self.list_llms,
inputs=None,
inputs=[],
outputs=[self.llm_list],
)
self._app.app.load(
@@ -140,8 +140,8 @@ class LLMManagement(BasePage):
self.btn_new.click(
self.create_llm,
inputs=[self.name, self.llm_choices, self.spec, self.default],
outputs=None,
).success(self.list_llms, inputs=None, outputs=[self.llm_list]).success(
outputs=[],
).success(self.list_llms, inputs=[], outputs=[self.llm_list]).success(
lambda: ("", None, "", False, self.spec_desc_default),
outputs=[
self.name,
@@ -176,7 +176,7 @@ class LLMManagement(BasePage):
)
self.btn_delete.click(
self.on_btn_delete_click,
inputs=None,
inputs=[],
outputs=[self.btn_delete, self.btn_delete_yes, self.btn_delete_no],
show_progress="hidden",
)
@@ -187,7 +187,7 @@ class LLMManagement(BasePage):
show_progress="hidden",
).then(
self.list_llms,
inputs=None,
inputs=[],
outputs=[self.llm_list],
)
self.btn_delete_no.click(
@@ -196,7 +196,7 @@ class LLMManagement(BasePage):
gr.update(visible=False),
gr.update(visible=False),
),
inputs=None,
inputs=[],
outputs=[self.btn_delete, self.btn_delete_yes, self.btn_delete_no],
show_progress="hidden",
)
@@ -210,7 +210,7 @@ class LLMManagement(BasePage):
show_progress="hidden",
).then(
self.list_llms,
inputs=None,
inputs=[],
outputs=[self.llm_list],
)
self.btn_close.click(

View File

@@ -44,7 +44,7 @@ class App(BaseApp):
if len(self.index_manager.indices) == 1:
for index in self.index_manager.indices:
with gr.Tab(
f"{index.name} Index",
f"{index.name}",
elem_id="indices-tab",
elem_classes=[
"fill-main-area-height",
@@ -58,7 +58,7 @@ class App(BaseApp):
setattr(self, f"_index_{index.id}", page)
elif len(self.index_manager.indices) > 1:
with gr.Tab(
"Indices",
"Files",
elem_id="indices-tab",
elem_classes=["fill-main-area-height", "scrollable", "indices-tab"],
id="indices-tab",
@@ -66,7 +66,7 @@ class App(BaseApp):
) as self._tabs["indices-tab"]:
for index in self.index_manager.indices:
with gr.Tab(
f"{index.name}",
f"{index.name} Collection",
elem_id=f"{index.id}-tab",
) as self._tabs[f"{index.id}-tab"]:
page = index.get_index_page_ui()

View File

@@ -1,15 +1,25 @@
import asyncio
import csv
from copy import deepcopy
from datetime import datetime
from pathlib import Path
from typing import Optional
import gradio as gr
from filelock import FileLock
from ktem.app import BasePage
from ktem.components import reasonings
from ktem.db.models import Conversation, engine
from ktem.index.file.ui import File
from ktem.reasoning.prompt_optimization.suggest_conversation_name import (
SuggestConvNamePipeline,
)
from plotly.io import from_json
from sqlmodel import Session, select
from theflow.settings import settings as flowsettings
from kotaemon.base import Document
from kotaemon.indices.ingests.files import KH_DEFAULT_FILE_EXTRACTORS
from .chat_panel import ChatPanel
from .chat_suggestion import ChatSuggestion
@@ -17,23 +27,49 @@ from .common import STATE
from .control import ConversationControl
from .report import ReportIssue
DEFAULT_SETTING = "(default)"
INFO_PANEL_SCALES = {True: 8, False: 4}
pdfview_js = """
function() {
// Get all links and attach click event
var links = document.getElementsByClassName("pdf-link");
for (var i = 0; i < links.length; i++) {
links[i].onclick = openModal;
}
}
"""
class ChatPage(BasePage):
def __init__(self, app):
self._app = app
self._indices_input = []
self.on_building_ui()
self._reasoning_type = gr.State(value=None)
self._llm_type = gr.State(value=None)
self._conversation_renamed = gr.State(value=False)
self.info_panel_expanded = gr.State(value=True)
def on_building_ui(self):
with gr.Row():
self.chat_state = gr.State(STATE)
with gr.Column(scale=1, elem_id="conv-settings-panel"):
self.state_chat = gr.State(STATE)
self.state_retrieval_history = gr.State([])
self.state_chat_history = gr.State([])
self.state_plot_history = gr.State([])
self.state_settings = gr.State({})
self.state_info_panel = gr.State("")
self.state_plot_panel = gr.State(None)
with gr.Column(scale=1, elem_id="conv-settings-panel") as self.conv_column:
self.chat_control = ConversationControl(self._app)
if getattr(flowsettings, "KH_FEATURE_CHAT_SUGGESTION", False):
self.chat_suggestion = ChatSuggestion(self._app)
for index in self._app.index_manager.indices:
for index_id, index in enumerate(self._app.index_manager.indices):
index.selector = None
index_ui = index.get_selector_component_ui()
if not index_ui:
@@ -41,7 +77,9 @@ class ChatPage(BasePage):
continue
index_ui.unrender() # need to rerender later within Accordion
with gr.Accordion(label=f"{index.name} Index", open=True):
with gr.Accordion(
label=f"{index.name} Collection", open=index_id < 1
):
index_ui.render()
gr_index = index_ui.as_gradio_component()
if gr_index:
@@ -60,14 +98,66 @@ class ChatPage(BasePage):
self._indices_input.append(gr_index)
setattr(self, f"_index_{index.id}", index_ui)
if len(self._app.index_manager.indices) > 0:
with gr.Accordion(label="Quick Upload") as _:
self.quick_file_upload = File(
file_types=list(KH_DEFAULT_FILE_EXTRACTORS.keys()),
file_count="multiple",
container=True,
show_label=False,
)
self.quick_file_upload_status = gr.Markdown()
self.report_issue = ReportIssue(self._app)
with gr.Column(scale=6, elem_id="chat-area"):
self.chat_panel = ChatPanel(self._app)
with gr.Column(scale=3, elem_id="chat-info-panel"):
with gr.Row():
with gr.Accordion(label="Chat settings", open=False):
# a quick switch for reasoning type option
with gr.Row():
gr.HTML("Reasoning method")
gr.HTML("Model")
with gr.Row():
reasoning_type_values = [
(DEFAULT_SETTING, DEFAULT_SETTING)
] + self._app.default_settings.reasoning.settings[
"use"
].choices
self.reasoning_types = gr.Dropdown(
choices=reasoning_type_values,
value=DEFAULT_SETTING,
container=False,
show_label=False,
)
self.model_types = gr.Dropdown(
choices=self._app.default_settings.reasoning.options[
"simple"
]
.settings["llm"]
.choices,
value="",
container=False,
show_label=False,
)
with gr.Column(
scale=INFO_PANEL_SCALES[False], elem_id="chat-info-panel"
) as self.info_column:
with gr.Accordion(label="Information panel", open=True):
self.info_panel = gr.HTML()
self.modal = gr.HTML("<div id='pdf-modal'></div>")
self.plot_panel = gr.Plot(visible=False)
self.info_panel = gr.HTML(elem_id="html-info-panel")
def _json_to_plot(self, json_dict: dict | None):
if json_dict:
plot = from_json(json_dict)
plot = gr.update(visible=True, value=plot)
else:
plot = gr.update(visible=False)
return plot
def on_register_events(self):
gr.on(
@@ -98,27 +188,75 @@ class ChatPage(BasePage):
self.chat_control.conversation_id,
self.chat_panel.chatbot,
self._app.settings_state,
self.chat_state,
self._reasoning_type,
self._llm_type,
self.state_chat,
self._app.user_id,
]
+ self._indices_input,
outputs=[
self.chat_panel.chatbot,
self.info_panel,
self.chat_state,
self.plot_panel,
self.state_plot_panel,
self.state_chat,
],
concurrency_limit=20,
show_progress="minimal",
).success(
fn=self.backup_original_info,
inputs=[
self.chat_panel.chatbot,
self._app.settings_state,
self.info_panel,
self.state_chat_history,
],
outputs=[
self.state_chat_history,
self.state_settings,
self.state_info_panel,
],
).then(
fn=self.update_data_source,
fn=self.persist_data_source,
inputs=[
self.chat_control.conversation_id,
self._app.user_id,
self.info_panel,
self.state_plot_panel,
self.state_retrieval_history,
self.state_plot_history,
self.chat_panel.chatbot,
self.chat_state,
self.state_chat,
]
+ self._indices_input,
outputs=None,
outputs=[
self.state_retrieval_history,
self.state_plot_history,
],
concurrency_limit=20,
).success(
fn=self.check_and_suggest_name_conv,
inputs=self.chat_panel.chatbot,
outputs=[
self.chat_control.conversation_rn,
self._conversation_renamed,
],
).success(
self.chat_control.rename_conv,
inputs=[
self.chat_control.conversation_id,
self.chat_control.conversation_rn,
self._conversation_renamed,
self._app.user_id,
],
outputs=[
self.chat_control.conversation,
self.chat_control.conversation,
self.chat_control.conversation_rn,
],
show_progress="hidden",
).then(
fn=None, inputs=None, outputs=None, js=pdfview_js
)
self.chat_panel.regen_btn.click(
@@ -127,33 +265,90 @@ class ChatPage(BasePage):
self.chat_control.conversation_id,
self.chat_panel.chatbot,
self._app.settings_state,
self.chat_state,
self._reasoning_type,
self._llm_type,
self.state_chat,
self._app.user_id,
]
+ self._indices_input,
outputs=[
self.chat_panel.chatbot,
self.info_panel,
self.chat_state,
self.plot_panel,
self.state_plot_panel,
self.state_chat,
],
concurrency_limit=20,
show_progress="minimal",
).then(
fn=self.update_data_source,
fn=self.persist_data_source,
inputs=[
self.chat_control.conversation_id,
self._app.user_id,
self.info_panel,
self.state_plot_panel,
self.state_retrieval_history,
self.state_plot_history,
self.chat_panel.chatbot,
self.chat_state,
self.state_chat,
]
+ self._indices_input,
outputs=None,
outputs=[
self.state_retrieval_history,
self.state_plot_history,
],
concurrency_limit=20,
).success(
fn=self.check_and_suggest_name_conv,
inputs=self.chat_panel.chatbot,
outputs=[
self.chat_control.conversation_rn,
self._conversation_renamed,
],
).success(
self.chat_control.rename_conv,
inputs=[
self.chat_control.conversation_id,
self.chat_control.conversation_rn,
self._conversation_renamed,
self._app.user_id,
],
outputs=[
self.chat_control.conversation,
self.chat_control.conversation,
self.chat_control.conversation_rn,
],
show_progress="hidden",
).then(
fn=None, inputs=None, outputs=None, js=pdfview_js
)
self.chat_control.btn_info_expand.click(
fn=lambda is_expanded: (
gr.update(scale=INFO_PANEL_SCALES[is_expanded]),
not is_expanded,
),
inputs=self.info_panel_expanded,
outputs=[self.info_column, self.info_panel_expanded],
)
self.chat_panel.chatbot.like(
fn=self.is_liked,
inputs=[self.chat_control.conversation_id],
outputs=None,
).success(
self.save_log,
inputs=[
self.chat_control.conversation_id,
self.chat_panel.chatbot,
self._app.settings_state,
self.info_panel,
self.state_chat_history,
self.state_settings,
self.state_info_panel,
gr.State(getattr(flowsettings, "KH_APP_DATA_DIR", "logs")),
],
outputs=None,
)
self.chat_control.btn_new.click(
@@ -163,17 +358,25 @@ class ChatPage(BasePage):
show_progress="hidden",
).then(
self.chat_control.select_conv,
inputs=[self.chat_control.conversation],
inputs=[self.chat_control.conversation, self._app.user_id],
outputs=[
self.chat_control.conversation_id,
self.chat_control.conversation,
self.chat_control.conversation_rn,
self.chat_panel.chatbot,
self.info_panel,
self.chat_state,
self.state_plot_panel,
self.state_retrieval_history,
self.state_plot_history,
self.chat_control.cb_is_public,
self.state_chat,
]
+ self._indices_input,
show_progress="hidden",
).then(
fn=self._json_to_plot,
inputs=self.state_plot_panel,
outputs=self.plot_panel,
)
self.chat_control.btn_del.click(
@@ -188,17 +391,25 @@ class ChatPage(BasePage):
show_progress="hidden",
).then(
self.chat_control.select_conv,
inputs=[self.chat_control.conversation],
inputs=[self.chat_control.conversation, self._app.user_id],
outputs=[
self.chat_control.conversation_id,
self.chat_control.conversation,
self.chat_control.conversation_rn,
self.chat_panel.chatbot,
self.info_panel,
self.chat_state,
self.state_plot_panel,
self.state_retrieval_history,
self.state_plot_history,
self.chat_control.cb_is_public,
self.state_chat,
]
+ self._indices_input,
show_progress="hidden",
).then(
fn=self._json_to_plot,
inputs=self.state_plot_panel,
outputs=self.plot_panel,
).then(
lambda: self.toggle_delete(""),
outputs=[self.chat_control._new_delete, self.chat_control._delete_confirm],
@@ -207,33 +418,80 @@ class ChatPage(BasePage):
lambda: self.toggle_delete(""),
outputs=[self.chat_control._new_delete, self.chat_control._delete_confirm],
)
self.chat_control.conversation_rn_btn.click(
self.chat_control.btn_conversation_rn.click(
lambda: gr.update(visible=True),
outputs=[
self.chat_control.conversation_rn,
],
)
self.chat_control.conversation_rn.submit(
self.chat_control.rename_conv,
inputs=[
self.chat_control.conversation_id,
self.chat_control.conversation_rn,
gr.State(value=True),
self._app.user_id,
],
outputs=[self.chat_control.conversation, self.chat_control.conversation],
outputs=[
self.chat_control.conversation,
self.chat_control.conversation,
self.chat_control.conversation_rn,
],
show_progress="hidden",
)
self.chat_control.conversation.select(
self.chat_control.select_conv,
inputs=[self.chat_control.conversation],
inputs=[self.chat_control.conversation, self._app.user_id],
outputs=[
self.chat_control.conversation_id,
self.chat_control.conversation,
self.chat_control.conversation_rn,
self.chat_panel.chatbot,
self.info_panel,
self.chat_state,
self.state_plot_panel,
self.state_retrieval_history,
self.state_plot_history,
self.chat_control.cb_is_public,
self.state_chat,
]
+ self._indices_input,
show_progress="hidden",
).then(
fn=self._json_to_plot,
inputs=self.state_plot_panel,
outputs=self.plot_panel,
).then(
lambda: self.toggle_delete(""),
outputs=[self.chat_control._new_delete, self.chat_control._delete_confirm],
).then(
fn=None, inputs=None, outputs=None, js=pdfview_js
)
# evidence display on message selection
self.chat_panel.chatbot.select(
self.message_selected,
inputs=[
self.state_retrieval_history,
self.state_plot_history,
],
outputs=[
self.info_panel,
self.state_plot_panel,
],
).then(
fn=self._json_to_plot,
inputs=self.state_plot_panel,
outputs=self.plot_panel,
).then(
fn=None, inputs=None, outputs=None, js=pdfview_js
)
self.chat_control.cb_is_public.change(
self.on_set_public_conversation,
inputs=[self.chat_control.cb_is_public, self.chat_control.conversation],
outputs=None,
show_progress="hidden",
)
self.report_issue.report_btn.click(
@@ -247,11 +505,26 @@ class ChatPage(BasePage):
self._app.settings_state,
self._app.user_id,
self.info_panel,
self.chat_state,
self.state_chat,
]
+ self._indices_input,
outputs=None,
)
self.reasoning_types.change(
self.reasoning_changed,
inputs=[self.reasoning_types],
outputs=[self._reasoning_type],
)
self.model_types.change(
lambda x: x,
inputs=[self.model_types],
outputs=[self._llm_type],
)
self.chat_control.conversation_id.change(
lambda: gr.update(visible=False),
outputs=self.plot_panel,
)
if getattr(flowsettings, "KH_FEATURE_CHAT_SUGGESTION", False):
self.chat_suggestion.example.select(
self.chat_suggestion.select_example,
@@ -291,6 +564,28 @@ class ChatPage(BasePage):
else:
return gr.update(visible=True), gr.update(visible=False)
def on_set_public_conversation(self, is_public, convo_id):
if not convo_id:
gr.Warning("No conversation selected")
return
with Session(engine) as session:
statement = select(Conversation).where(Conversation.id == convo_id)
result = session.exec(statement).one()
name = result.name
if result.is_public != is_public:
# Only trigger an update when the user
# selects a value different from the current one
result.is_public = is_public
session.add(result)
session.commit()
gr.Info(
f"Conversation: {name} is {'public' if is_public else 'private'}."
)
def on_subscribe_public_events(self):
if self._app.f_user_management:
self._app.subscribe_event(
@@ -306,25 +601,53 @@ class ChatPage(BasePage):
self._app.subscribe_event(
name="onSignOut",
definition={
"fn": lambda: self.chat_control.select_conv(""),
"fn": lambda: self.chat_control.select_conv("", None),
"outputs": [
self.chat_control.conversation_id,
self.chat_control.conversation,
self.chat_control.conversation_rn,
self.chat_panel.chatbot,
self.info_panel,
self.state_plot_panel,
self.state_retrieval_history,
self.state_plot_history,
self.chat_control.cb_is_public,
]
+ self._indices_input,
"show_progress": "hidden",
},
)
def update_data_source(self, convo_id, messages, state, *selecteds):
def persist_data_source(
self,
convo_id,
user_id,
retrieval_msg,
plot_data,
retrieval_history,
plot_history,
messages,
state,
*selecteds,
):
"""Update the data source"""
if not convo_id:
gr.Warning("No conversation selected")
return
# if not regen, then append the new message
if not state["app"].get("regen", False):
retrieval_history = retrieval_history + [retrieval_msg]
plot_history = plot_history + [plot_data]
else:
if retrieval_history:
print("Updating retrieval history (regen=True)")
retrieval_history[-1] = retrieval_msg
plot_history[-1] = plot_data
# reset regen state
state["app"]["regen"] = False
selecteds_ = {}
for index in self._app.index_manager.indices:
if index.selector is None:
@@ -339,15 +662,29 @@ class ChatPage(BasePage):
result = session.exec(statement).one()
data_source = result.data_source
old_selecteds = data_source.get("selected", {})
is_owner = result.user == user_id
# Write down to db
result.data_source = {
"selected": selecteds_,
"selected": selecteds_ if is_owner else old_selecteds,
"messages": messages,
"retrieval_messages": retrival_history,
"plot_history": plot_history,
"state": state,
"likes": deepcopy(data_source.get("likes", [])),
}
session.add(result)
session.commit()
return retrieval_history, plot_history
def reasoning_changed(self, reasoning_type):
if reasoning_type != DEFAULT_SETTING:
# override app settings state (temporary)
gr.Info("Reasoning type changed to `{}`".format(reasoning_type))
return reasoning_type
def is_liked(self, convo_id, liked: gr.LikeData):
with Session(engine) as session:
statement = select(Conversation).where(Conversation.id == convo_id)
@@ -362,7 +699,19 @@ class ChatPage(BasePage):
session.add(result)
session.commit()
def create_pipeline(self, settings: dict, state: dict, user_id: int, *selecteds):
def message_selected(self, retrieval_history, plot_history, msg: gr.SelectData):
index = msg.index[0]
return retrieval_history[index], plot_history[index]
def create_pipeline(
self,
settings: dict,
session_reasoning_type: str,
session_llm: str,
state: dict,
user_id: int,
*selecteds,
):
"""Create the pipeline from settings
Args:
@@ -374,10 +723,23 @@ class ChatPage(BasePage):
Returns:
- the pipeline objects
"""
reasoning_mode = settings["reasoning.use"]
# override reasoning_mode by temporary chat page state
print("Session reasoning type", session_reasoning_type)
print("Session LLM", session_llm)
reasoning_mode = (
settings["reasoning.use"]
if session_reasoning_type in (DEFAULT_SETTING, None)
else session_reasoning_type
)
reasoning_cls = reasonings[reasoning_mode]
print("Reasoning class", reasoning_cls)
reasoning_id = reasoning_cls.get_info()["id"]
settings = deepcopy(settings)
llm_setting_key = f"reasoning.options.{reasoning_id}.llm"
if llm_setting_key in settings and session_llm not in (DEFAULT_SETTING, None):
settings[llm_setting_key] = session_llm
# get retrievers
retrievers = []
for index in self._app.index_manager.indices:
@@ -403,7 +765,15 @@ class ChatPage(BasePage):
return pipeline, reasoning_state
def chat_fn(
self, conversation_id, chat_history, settings, state, user_id, *selecteds
self,
conversation_id,
chat_history,
settings,
reasoning_type,
llm_type,
state,
user_id,
*selecteds,
):
"""Chat function"""
chat_input = chat_history[-1][0]
@@ -413,18 +783,23 @@ class ChatPage(BasePage):
# construct the pipeline
pipeline, reasoning_state = self.create_pipeline(
settings, state, user_id, *selecteds
settings, reasoning_type, llm_type, state, user_id, *selecteds
)
print("Reasoning state", reasoning_state)
pipeline.set_output_queue(queue)
text, refs = "", ""
text, refs, plot, plot_gr = "", "", None, gr.update(visible=False)
msg_placeholder = getattr(
flowsettings, "KH_CHAT_MSG_PLACEHOLDER", "Thinking ..."
)
print(msg_placeholder)
yield chat_history + [(chat_input, text or msg_placeholder)], refs, state
len_ref = -1  # for logging purposes
yield (
chat_history + [(chat_input, text or msg_placeholder)],
refs,
plot_gr,
plot,
state,
)
for response in pipeline.stream(chat_input, conversation_id, chat_history):
@@ -446,22 +821,42 @@ class ChatPage(BasePage):
else:
refs += response.content
if len(refs) > len_ref:
print(f"Len refs: {len(refs)}")
len_ref = len(refs)
if response.channel == "plot":
plot = response.content
plot_gr = self._json_to_plot(plot)
state[pipeline.get_info()["id"]] = reasoning_state["pipeline"]
yield chat_history + [(chat_input, text or msg_placeholder)], refs, state
yield (
chat_history + [(chat_input, text or msg_placeholder)],
refs,
plot_gr,
plot,
state,
)
if not text:
empty_msg = getattr(
flowsettings, "KH_CHAT_EMPTY_MSG_PLACEHOLDER", "(Sorry, I don't know)"
)
print(f"Generate nothing: {empty_msg}")
yield chat_history + [(chat_input, text or empty_msg)], refs, state
yield (
chat_history + [(chat_input, text or empty_msg)],
refs,
plot_gr,
plot,
state,
)
def regen_fn(
self, conversation_id, chat_history, settings, state, user_id, *selecteds
self,
conversation_id,
chat_history,
settings,
reasoning_type,
llm_type,
state,
user_id,
*selecteds,
):
"""Regen function"""
if not chat_history:
@@ -470,11 +865,119 @@ class ChatPage(BasePage):
return
state["app"]["regen"] = True
for chat, refs, state in self.chat_fn(
conversation_id, chat_history, settings, state, user_id, *selecteds
):
new_state = deepcopy(state)
new_state["app"]["regen"] = False
yield chat, refs, new_state
yield from self.chat_fn(
conversation_id,
chat_history,
settings,
reasoning_type,
llm_type,
state,
user_id,
*selecteds,
)
state["app"]["regen"] = False
def check_and_suggest_name_conv(self, chat_history):
suggest_pipeline = SuggestConvNamePipeline()
new_name = gr.update()
renamed = False
# check if this is a newly created conversation
if len(chat_history) == 1:
suggested_name = suggest_pipeline(chat_history).text[:40]
new_name = gr.update(value=suggested_name)
renamed = True
return new_name, renamed
def backup_original_info(
self, chat_history, settings, info_panel, original_chat_history
):
original_chat_history.append(chat_history[-1])
return original_chat_history, settings, info_panel
def save_log(
self,
conversation_id,
chat_history,
settings,
info_panel,
original_chat_history,
original_settings,
original_info_panel,
log_dir,
):
if not Path(log_dir).exists():
Path(log_dir).mkdir(parents=True)
lock = FileLock(Path(log_dir) / ".lock")
# get current date
today = datetime.now()
formatted_date = today.strftime("%d%m%Y_%H")
with Session(engine) as session:
statement = select(Conversation).where(Conversation.id == conversation_id)
result = session.exec(statement).one()
data_source = deepcopy(result.data_source)
likes = data_source.get("likes", [])
if not likes:
return
feedback = likes[-1][-1]
message_index = likes[-1][0]
current_message = chat_history[message_index[0]]
original_message = original_chat_history[message_index[0]]
is_original = all(
[
current_item == original_item
for current_item, original_item in zip(
current_message, original_message
)
]
)
dataframe = [
[
conversation_id,
message_index,
current_message[0],
current_message[1],
chat_history,
settings,
info_panel,
feedback,
is_original,
original_message[1],
original_chat_history,
original_settings,
original_info_panel,
]
]
with lock:
log_file = Path(log_dir) / f"{formatted_date}_log.csv"
is_log_file_exist = log_file.is_file()
with open(log_file, "a") as f:
writer = csv.writer(f)
# write headers
if not is_log_file_exist:
writer.writerow(
[
"Conversation ID",
"Message ID",
"Question",
"Answer",
"Chat History",
"Settings",
"Evidences",
"Feedback",
"Original/ Rewritten",
"Original Answer",
"Original Chat History",
"Original Settings",
"Original Evidences",
]
)
writer.writerows(dataframe)
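# Illustrative sketch, not part of the diff: reading back the hourly feedback
# log written by save_log above. Assumes the same log_dir and the
# "%d%m%Y_%H" filename pattern; read_feedback_log is a hypothetical helper.
import csv
from datetime import datetime
from pathlib import Path


def read_feedback_log(log_dir: str) -> list[dict]:
    """Return the current hour's feedback rows as dicts (empty if no log yet)."""
    log_file = Path(log_dir) / f"{datetime.now().strftime('%d%m%Y_%H')}_log.csv"
    if not log_file.is_file():
        return []
    with open(log_file, newline="") as f:
        return list(csv.DictReader(f))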


@@ -1,13 +1,20 @@
import logging
import os
import gradio as gr
from ktem.app import BasePage
from ktem.db.models import Conversation, engine
from sqlmodel import Session, select
from ktem.db.models import Conversation, User, engine
from sqlmodel import Session, or_, select
import flowsettings
from ...utils.conversation import sync_retrieval_n_message
from .common import STATE
logger = logging.getLogger(__name__)
ASSETS_DIR = "assets/icons"
if not os.path.isdir(ASSETS_DIR):
ASSETS_DIR = "libs/ktem/ktem/assets/icons"
def is_conv_name_valid(name):
@@ -35,14 +42,47 @@ class ConversationControl(BasePage):
label="Chat sessions",
choices=[],
container=False,
filterable=False,
filterable=True,
interactive=True,
elem_classes=["unset-overflow"],
)
with gr.Row() as self._new_delete:
self.btn_new = gr.Button(value="New", min_width=10, variant="primary")
self.btn_del = gr.Button(value="Delete", min_width=10, variant="stop")
self.btn_new = gr.Button(
value="",
icon=f"{ASSETS_DIR}/new.svg",
min_width=2,
scale=1,
size="sm",
elem_classes=["no-background", "body-text-color"],
)
self.btn_del = gr.Button(
value="",
icon=f"{ASSETS_DIR}/delete.svg",
min_width=2,
scale=1,
size="sm",
elem_classes=["no-background", "body-text-color"],
)
self.btn_conversation_rn = gr.Button(
value="",
icon=f"{ASSETS_DIR}/rename.svg",
min_width=2,
scale=1,
size="sm",
elem_classes=["no-background", "body-text-color"],
)
self.btn_info_expand = gr.Button(
value="",
icon=f"{ASSETS_DIR}/sidebar.svg",
min_width=2,
scale=1,
size="sm",
elem_classes=["no-background", "body-text-color"],
)
self.cb_is_public = gr.Checkbox(
value=False, label="Shared", min_width=10, scale=4
)
with gr.Row(visible=False) as self._delete_confirm:
self.btn_del_conf = gr.Button(
@@ -54,28 +94,60 @@ class ConversationControl(BasePage):
with gr.Row():
self.conversation_rn = gr.Text(
label="(Enter) to save",
placeholder="Conversation name",
container=False,
container=True,
scale=5,
min_width=10,
interactive=True,
)
self.conversation_rn_btn = gr.Button(
value="Rename",
scale=1,
min_width=10,
elem_classes=["no-background", "body-text-color", "bold-text"],
visible=False,
)
def load_chat_history(self, user_id):
"""Reload chat history"""
# If the user is an admin, they can also see
# public conversations
can_see_public: bool = False
with Session(engine) as session:
statement = select(User).where(User.id == user_id)
result = session.exec(statement).one_or_none()
if result is not None:
if flowsettings.KH_USER_CAN_SEE_PUBLIC:
can_see_public = (
result.username == flowsettings.KH_USER_CAN_SEE_PUBLIC
)
else:
can_see_public = True
print(f"User-id: {user_id}, can see public conversations: {can_see_public}")
options = []
with Session(engine) as session:
statement = (
select(Conversation)
.where(Conversation.user == user_id)
.order_by(Conversation.date_created.desc()) # type: ignore
)
# Define the query based on the admin role:
# - can_see: their own conversations & public conversations
# - can_not_see: only their own conversations
if can_see_public:
statement = (
select(Conversation)
.where(
or_(
Conversation.user == user_id,
Conversation.is_public,
)
)
.order_by(
Conversation.is_public.desc(), Conversation.date_created.desc()
) # type: ignore
)
else:
statement = (
select(Conversation)
.where(Conversation.user == user_id)
.order_by(Conversation.date_created.desc()) # type: ignore
)
results = session.exec(statement).all()
for result in results:
options.append((result.name, result.id))
@@ -129,7 +201,7 @@ class ConversationControl(BasePage):
else:
return None, gr.update(value=None, choices=[])
def select_conv(self, conversation_id):
def select_conv(self, conversation_id, user_id):
"""Select the conversation"""
with Session(engine) as session:
statement = select(Conversation).where(Conversation.id == conversation_id)
@@ -137,18 +209,46 @@ class ConversationControl(BasePage):
result = session.exec(statement).one()
id_ = result.id
name = result.name
selected = result.data_source.get("selected", {})
is_conv_public = result.is_public
# reset the file selection state if the user
# is not the owner of the conversation
if user_id == result.user:
selected = result.data_source.get("selected", {})
else:
selected = {}
chats = result.data_source.get("messages", [])
info_panel = ""
retrieval_history: list[str] = result.data_source.get(
"retrieval_messages", []
)
plot_history: list[dict] = result.data_source.get("plot_history", [])
# On initialization, ensure the retrieval history
# and the messages have equal length
retrieval_history = sync_retrieval_n_message(chats, retrieval_history)
info_panel = (
retrieval_history[-1]
if retrieval_history
else "<h5><b>No evidence found.</b></h5>"
)
plot_data = plot_history[-1] if plot_history else None
state = result.data_source.get("state", STATE)
except Exception as e:
logger.warning(e)
id_ = ""
name = ""
selected = {}
chats = []
retrieval_history = []
plot_history = []
info_panel = ""
plot_data = None
state = STATE
is_conv_public = False
indices = []
for index in self._app.index_manager.indices:
@@ -160,10 +260,29 @@ class ConversationControl(BasePage):
if isinstance(index.selector, tuple):
indices.extend(selected.get(str(index.id), index.default_selector))
return id_, id_, name, chats, info_panel, state, *indices
return (
id_,
id_,
name,
chats,
info_panel,
plot_data,
retrieval_history,
plot_history,
is_conv_public,
state,
*indices,
)
def rename_conv(self, conversation_id, new_name, user_id):
def rename_conv(self, conversation_id, new_name, is_renamed, user_id):
"""Rename the conversation"""
if not is_renamed:
return (
gr.update(),
conversation_id,
gr.update(visible=False),
)
if user_id is None:
gr.Warning("Please sign in first (Settings → User Settings)")
return gr.update(), ""
@@ -185,7 +304,12 @@ class ConversationControl(BasePage):
session.commit()
history = self.load_chat_history(user_id)
return gr.update(choices=history), conversation_id
gr.Info("Conversation renamed.")
return (
gr.update(choices=history),
conversation_id,
gr.update(visible=False),
)
def _on_app_created(self):
"""Reload the conversation once the app is created"""


@@ -12,7 +12,7 @@ class ReportIssue(BasePage):
self.on_building_ui()
def on_building_ui(self):
with gr.Accordion(label="Report", open=False):
with gr.Accordion(label="Feedback", open=False):
self.correctness = gr.Radio(
choices=[
("The answer is correct", "correct"),


@@ -9,6 +9,7 @@ from theflow.settings import settings
def get_remote_doc(url: str) -> str:
try:
res = requests.get(url)
res.raise_for_status()
return res.text
except Exception as e:
print(f"Failed to fetch document from {url}: {e}")


@@ -7,9 +7,9 @@ from sqlmodel import Session, select
fetch_creds = """
function() {
const username = getStorage('username')
const password = getStorage('password')
return [username, password];
const username = getStorage('username', '')
const password = getStorage('password', '')
return [username, password, null];
}
"""


@@ -15,18 +15,18 @@ class ResourcesTab(BasePage):
self.on_building_ui()
def on_building_ui(self):
if self._app.f_user_management:
with gr.Tab("User Management", visible=False) as self.user_management_tab:
self.user_management = UserManagement(self._app)
with gr.Tab("Index Collections") as self.index_management_tab:
self.index_management = IndexManagement(self._app)
with gr.Tab("LLMs") as self.llm_management_tab:
self.llm_management = LLMManagement(self._app)
with gr.Tab("Embedding Models") as self.emb_management_tab:
with gr.Tab("Embeddings") as self.emb_management_tab:
self.emb_management = EmbeddingManagement(self._app)
with gr.Tab("Index Management") as self.index_management_tab:
self.index_management = IndexManagement(self._app)
if self._app.f_user_management:
with gr.Tab("Users", visible=False) as self.user_management_tab:
self.user_management = UserManagement(self._app)
def on_subscribe_public_events(self):
if self._app.f_user_management:


@@ -94,6 +94,28 @@ def validate_password(pwd, pwd_cnf):
return ""
def create_user(usn, pwd) -> bool:
with Session(engine) as session:
statement = select(User).where(User.username_lower == usn.lower())
result = session.exec(statement).all()
if result:
print(f'User "{usn}" already exists')
return False
else:
hashed_password = hashlib.sha256(pwd.encode()).hexdigest()
user = User(
username=usn,
username_lower=usn.lower(),
password=hashed_password,
admin=True,
)
session.add(user)
session.commit()
return True
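# Illustrative sketch, not part of the diff: create_user can now be reused
# outside the admin bootstrap path, e.g. to seed a throwaway local account.
# The credentials below are placeholders.
if create_user("demo-admin", "change-me"):
    print('User "demo-admin" created')
else:
    print('User "demo-admin" already exists')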
class UserManagement(BasePage):
def __init__(self, app):
self._app = app
@@ -105,23 +127,9 @@ class UserManagement(BasePage):
usn = flowsettings.KH_FEATURE_USER_MANAGEMENT_ADMIN
pwd = flowsettings.KH_FEATURE_USER_MANAGEMENT_PASSWORD
with Session(engine) as session:
statement = select(User).where(User.username_lower == usn.lower())
result = session.exec(statement).all()
if result:
print(f'User "{usn}" already exists')
else:
hashed_password = hashlib.sha256(pwd.encode()).hexdigest()
user = User(
username=usn,
username_lower=usn.lower(),
password=hashed_password,
admin=True,
)
session.add(user)
session.commit()
gr.Info(f'User "{usn}" created successfully')
is_created = create_user(usn, pwd)
if is_created:
gr.Info(f'User "{usn}" created successfully')
def on_building_ui(self):
with gr.Tab(label="User list"):
@@ -224,7 +232,7 @@ class UserManagement(BasePage):
gr.update(visible=False),
gr.update(visible=False),
),
inputs=None,
inputs=[],
outputs=[self.btn_delete, self.btn_delete_yes, self.btn_delete_no],
show_progress="hidden",
)


@@ -2,13 +2,15 @@ import hashlib
import gradio as gr
from ktem.app import BasePage
from ktem.components import reasonings
from ktem.db.models import Settings, User, engine
from sqlmodel import Session, select
signout_js = """
function() {
function(u, c, pw, pwc) {
removeFromStorage('username');
removeFromStorage('password');
return [u, c, pw, pwc];
}
"""
@@ -72,6 +74,10 @@ class SettingsPage(BasePage):
self._components = {}
self._reasoning_mode = {}
# store llms and embeddings components
self._llms = []
self._embeddings = []
# render application page if there are application settings
self._render_app_tab = False
if self._default_settings.application.settings:
@@ -101,14 +107,13 @@ class SettingsPage(BasePage):
def on_building_ui(self):
if self._app.f_user_management:
with gr.Tab("Users"):
with gr.Tab("User settings"):
self.user_tab()
with gr.Tab("General"):
self.app_tab()
with gr.Tab("Document Indices"):
self.index_tab()
with gr.Tab("Reasoning Pipelines"):
self.reasoning_tab()
self.app_tab()
self.index_tab()
self.reasoning_tab()
self.setting_save_btn = gr.Button(
"Save changes", variant="primary", scale=1, elem_classes=["right-button"]
)
@@ -192,7 +197,7 @@ class SettingsPage(BasePage):
)
onSignOutClick = self.signout.click(
lambda: (None, "Current user: ___", "", ""),
inputs=None,
inputs=[],
outputs=[
self._user_id,
self.current_name,
@@ -248,10 +253,14 @@ class SettingsPage(BasePage):
return "", ""
def app_tab(self):
with gr.Tab("General application settings", visible=self._render_app_tab):
with gr.Tab("General", visible=self._render_app_tab):
for n, si in self._default_settings.application.settings.items():
obj = render_setting_item(si, si.value)
self._components[f"application.{n}"] = obj
if si.special_type == "llm":
self._llms.append(obj)
if si.special_type == "embedding":
self._embeddings.append(obj)
def index_tab(self):
# TODO: double check if we need general
@@ -260,12 +269,18 @@ class SettingsPage(BasePage):
# obj = render_setting_item(si, si.value)
# self._components[f"index.{n}"] = obj
with gr.Tab("Index settings", visible=self._render_index_tab):
id2name = {k: v.name for k, v in self._app.index_manager.info().items()}
with gr.Tab("Retrieval settings", visible=self._render_index_tab):
for pn, sig in self._default_settings.index.options.items():
with gr.Tab(f"Index {pn}"):
name = "{} Collection".format(id2name.get(pn, f"<id {pn}>"))
with gr.Tab(name):
for n, si in sig.settings.items():
obj = render_setting_item(si, si.value)
self._components[f"index.options.{pn}.{n}"] = obj
if si.special_type == "llm":
self._llms.append(obj)
if si.special_type == "embedding":
self._embeddings.append(obj)
def reasoning_tab(self):
with gr.Tab("Reasoning settings", visible=self._render_reasoning_tab):
@@ -275,6 +290,10 @@ class SettingsPage(BasePage):
continue
obj = render_setting_item(si, si.value)
self._components[f"reasoning.{n}"] = obj
if si.special_type == "llm":
self._llms.append(obj)
if si.special_type == "embedding":
self._embeddings.append(obj)
gr.Markdown("### Reasoning-specific settings")
self._components["reasoning.use"] = render_setting_item(
@@ -289,10 +308,19 @@ class SettingsPage(BasePage):
visible=idx == 0,
elem_id=pn,
) as self._reasoning_mode[pn]:
gr.Markdown("**Name**: Description")
reasoning = reasonings.get(pn, None)
if reasoning is None:
gr.Markdown("**Name**: Description")
else:
info = reasoning.get_info()
gr.Markdown(f"**{info['name']}**: {info['description']}")
for n, si in sig.settings.items():
obj = render_setting_item(si, si.value)
self._components[f"reasoning.options.{pn}.{n}"] = obj
if si.special_type == "llm":
self._llms.append(obj)
if si.special_type == "embedding":
self._embeddings.append(obj)
def change_reasoning_mode(self, value):
output = []
@@ -360,3 +388,38 @@ class SettingsPage(BasePage):
outputs=[self._settings_state] + self.components(),
show_progress="hidden",
)
def update_llms():
from ktem.llms.manager import llms
if llms._default:
llm_choices = [(f"{llms._default} (default)", "")]
else:
llm_choices = [("(random)", "")]
llm_choices += [(_, _) for _ in llms.options().keys()]
return gr.update(choices=llm_choices)
def update_embeddings():
from ktem.embeddings.manager import embedding_models_manager
if embedding_models_manager._default:
emb_choices = [(f"{embedding_models_manager._default} (default)", "")]
else:
emb_choices = [("(random)", "")]
emb_choices += [(_, _) for _ in embedding_models_manager.options().keys()]
return gr.update(choices=emb_choices)
for llm in self._llms:
self._app.app.load(
update_llms,
inputs=[],
outputs=[llm],
show_progress="hidden",
)
for emb in self._embeddings:
self._app.app.load(
update_embeddings,
inputs=[],
outputs=[emb],
show_progress="hidden",
)
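# Illustrative sketch, not part of the diff: for the refresh loops above to
# pick up a component, a pipeline only needs to tag its LLM or embedding
# setting with the new special_type field. A minimal settings entry, following
# the dict shape used by the agent pipelines later in this diff:
example_llm_setting = {
    "llm": {
        "value": "",
        "component": "dropdown",
        "choices": [],          # populated at load time by update_llms()
        "special_type": "llm",  # marks the dropdown for the refresh loops
        "info": "The language model to use for this pipeline.",
    }
}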


@@ -0,0 +1,9 @@
from .decompose_question import DecomposeQuestionPipeline
from .fewshot_rewrite_question import FewshotRewriteQuestionPipeline
from .rewrite_question import RewriteQuestionPipeline
__all__ = [
"DecomposeQuestionPipeline",
"FewshotRewriteQuestionPipeline",
"RewriteQuestionPipeline",
]


@@ -0,0 +1,79 @@
import logging
from ktem.llms.manager import llms
from ktem.reasoning.prompt_optimization.rewrite_question import RewriteQuestionPipeline
from pydantic import BaseModel, Field
from kotaemon.base import Document, HumanMessage, Node, SystemMessage
from kotaemon.llms import ChatLLM
logger = logging.getLogger(__name__)
class SubQuery(BaseModel):
"""Search over a database of insurance rulebooks or financial reports"""
sub_query: str = Field(
...,
description="A very specific query against the database.",
)
class DecomposeQuestionPipeline(RewriteQuestionPipeline):
"""Decompose user complex question into multiple sub-questions
Args:
llm: the language model to rewrite question
lang: the language of the answer. Currently support English and Japanese
"""
llm: ChatLLM = Node(
default_callback=lambda _: llms.get("openai-gpt4-turbo", llms.get_default())
)
DECOMPOSE_SYSTEM_PROMPT_TEMPLATE = (
"You are an expert at converting user complex questions into sub questions. "
"Perform query decomposition using provided function_call. "
"Given a user question, break it down into the most specific sub"
" questions you can (at most 3) "
"which will help you answer the original question. "
"Each sub question should be about a single concept/fact/idea. "
"If there are acronyms or words you are not familiar with, "
"do not try to rephrase them."
)
prompt_template: str = DECOMPOSE_SYSTEM_PROMPT_TEMPLATE
def create_prompt(self, question):
schema = SubQuery.model_json_schema()
function = {
"name": schema["title"],
"description": schema["description"],
"parameters": schema,
}
llm_kwargs = {
"tools": [{"type": "function", "function": function}],
"tool_choice": "auto",
}
messages = [
SystemMessage(content=self.prompt_template),
HumanMessage(content=question),
]
return messages, llm_kwargs
def run(self, question: str) -> list: # type: ignore
messages, llm_kwargs = self.create_prompt(question)
result = self.llm(messages, **llm_kwargs)
tool_calls = result.additional_kwargs.get("tool_calls", None)
sub_queries = []
if tool_calls:
for tool_call in tool_calls:
sub_queries.append(
Document(
content=SubQuery.parse_raw(
tool_call["function"]["arguments"]
).sub_query
)
)
return sub_queries
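# Illustrative sketch, not part of the diff: decomposing a compound question
# into sub-queries. Assumes a tool-calling capable chat model is configured
# in the LLM manager (the pipeline defaults to "openai-gpt4-turbo").
pipeline = DecomposeQuestionPipeline()
sub_questions = pipeline.run(
    "Compare the 2022 and 2023 revenue and explain the main drivers."
)
for doc in sub_questions:
    print(doc.content)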


@@ -0,0 +1,100 @@
import json
import uuid
from pathlib import Path
from ktem.components import get_docstore, get_vectorstore
from ktem.llms.manager import llms
from ktem.reasoning.prompt_optimization.rewrite_question import (
DEFAULT_REWRITE_PROMPT,
RewriteQuestionPipeline,
)
from theflow.settings import settings as flowsettings
from kotaemon.base import AIMessage, Document, HumanMessage, Node, SystemMessage
from kotaemon.embeddings import BaseEmbeddings
from kotaemon.llms import ChatLLM
from kotaemon.storages import BaseDocumentStore, BaseVectorStore
class FewshotRewriteQuestionPipeline(RewriteQuestionPipeline):
"""Rewrite user question
Args:
llm: the language model to rewrite question
rewrite_template: the prompt template for llm to paraphrase a text input
lang: the language of the answer. Currently support English and Japanese
embedding: the embedding model to encode the question
vector_store: the vector store to store the encoded question
doc_store: the document store to store the original question
k: the number of examples to retrieve for rewriting
"""
llm: ChatLLM = Node(default_callback=lambda _: llms.get_default())
rewrite_template: str = DEFAULT_REWRITE_PROMPT
lang: str = "English"
embedding: BaseEmbeddings
vector_store: BaseVectorStore
doc_store: BaseDocumentStore
k: int = getattr(flowsettings, "N_PROMPT_OPT_EXAMPLES", 3)
def add_documents(self, examples, batch_size: int = 50):
print("Adding fewshot examples for rewriting")
documents = []
for example in examples:
doc = Document(
text=example["input"], id_=str(uuid.uuid4()), metadata=example
)
documents.append(doc)
for i in range(0, len(documents), batch_size):
embeddings = self.embedding(documents[i : i + batch_size])
ids = [t.doc_id for t in documents[i : i + batch_size]]
self.vector_store.add(
embeddings=embeddings,
ids=ids,
)
self.doc_store.add(documents[i : i + batch_size])
@classmethod
def get_pipeline(
cls,
embedding,
example_path=Path(__file__).parent / "rephrase_question_train.json",
collection_name: str = "fewshot_rewrite_examples",
):
vector_store = get_vectorstore(collection_name)
doc_store = get_docstore(collection_name)
pipeline = cls(
embedding=embedding, vector_store=vector_store, doc_store=doc_store
)
if doc_store.count():
return pipeline
examples = json.load(open(example_path, "r"))
pipeline.add_documents(examples)
return pipeline
def run(self, question: str) -> Document: # type: ignore
emb = self.embedding(question)[0].embedding
_, _, ids = self.vector_store.query(embedding=emb, top_k=self.k)
examples = self.doc_store.get(ids)
messages = [SystemMessage(content="You are a helpful assistant")]
for example in examples:
messages.append(
HumanMessage(
content=self.rewrite_template.format(
question=example.metadata["input"], lang=self.lang
)
)
)
messages.append(AIMessage(content=example.metadata["output"]))
messages.append(
HumanMessage(
content=self.rewrite_template.format(question=question, lang=self.lang)
)
)
result = self.llm(messages)
return result
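# Illustrative sketch, not part of the diff: building the few-shot rewriter
# and rephrasing a terse user query. Here `embedding` is assumed to be an
# already-configured kotaemon BaseEmbeddings instance; get_pipeline() indexes
# the bundled rephrase_question_train.json examples only on first use.
rewriter = FewshotRewriteQuestionPipeline.get_pipeline(embedding=embedding)
print(rewriter.run("how fix err when upload big pdf?").text)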

File diff suppressed because it is too large.

@@ -0,0 +1,37 @@
from ktem.llms.manager import llms
from kotaemon.base import BaseComponent, Document, HumanMessage, Node, SystemMessage
from kotaemon.llms import ChatLLM, PromptTemplate
DEFAULT_REWRITE_PROMPT = (
"Given the following question, rephrase and expand it "
"to help you do better answering. Maintain all information "
"in the original question. Keep the question as concise as possible. "
"Give answer in {lang}\n"
"Original question: {question}\n"
"Rephrased question: "
)
class RewriteQuestionPipeline(BaseComponent):
"""Rewrite user question
Args:
llm: the language model to rewrite question
rewrite_template: the prompt template for llm to paraphrase a text input
lang: the language of the answer. Currently support English and Japanese
"""
llm: ChatLLM = Node(default_callback=lambda _: llms.get_default())
rewrite_template: str = DEFAULT_REWRITE_PROMPT
lang: str = "English"
def run(self, question: str) -> Document: # type: ignore
prompt_template = PromptTemplate(self.rewrite_template)
prompt = prompt_template.populate(question=question, lang=self.lang)
messages = [
SystemMessage(content="You are a helpful assistant"),
HumanMessage(content=prompt),
]
return self.llm(messages)
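# Illustrative sketch, not part of the diff: the base rewriter expands a query
# with the default LLM and no vector store, so it can be used standalone.
rewriter = RewriteQuestionPipeline(lang="English")
print(rewriter.run("kotaemon confidence score meaning?").text)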


@@ -0,0 +1,36 @@
import logging
from ktem.llms.manager import llms
from kotaemon.base import AIMessage, BaseComponent, Document, HumanMessage, Node
from kotaemon.llms import ChatLLM, PromptTemplate
logger = logging.getLogger(__name__)
class SuggestConvNamePipeline(BaseComponent):
"""Suggest a good conversation name based on the chat history."""
llm: ChatLLM = Node(default_callback=lambda _: llms.get_default())
SUGGEST_NAME_PROMPT_TEMPLATE = (
"You are an expert at suggesting good and memorable conversation name. "
"Based on the chat history above, "
"suggest a good conversation name (max 10 words). "
"Give answer in {lang}. Just output the conversation "
"name without any extra."
)
prompt_template: str = SUGGEST_NAME_PROMPT_TEMPLATE
lang: str = "English"
def run(self, chat_history: list[tuple[str, str]]) -> Document: # type: ignore
prompt_template = PromptTemplate(self.prompt_template)
prompt = prompt_template.populate(lang=self.lang)
messages = []
for human, ai in chat_history:
messages.append(HumanMessage(content=human))
messages.append(AIMessage(content=ai))
messages.append(HumanMessage(content=prompt))
return self.llm(messages)
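# Illustrative sketch, not part of the diff: suggesting a name for a brand-new
# conversation from its first exchange, mirroring how
# check_and_suggest_name_conv in the chat page calls this pipeline.
suggest = SuggestConvNamePipeline(lang="English")
history = [("How do I enable hybrid retrieval?", "Open Retrieval settings and ...")]
print(suggest(history).text[:40])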


@@ -19,7 +19,10 @@ from kotaemon.agents import (
from kotaemon.base import BaseComponent, Document, HumanMessage, Node, SystemMessage
from kotaemon.llms import ChatLLM, PromptTemplate
from ..utils import SUPPORTED_LANGUAGE_MAP
logger = logging.getLogger(__name__)
DEFAULT_AGENT_STEPS = 4
class DocSearchArgs(BaseModel):
@@ -97,7 +100,7 @@ class DocSearchTool(BaseTool):
)
print("Retrieved #{}: {}".format(_id, retrieved_content[:100]))
print("Score", retrieved_item.metadata.get("relevance_score", None))
print("Score", retrieved_item.metadata.get("cohere_reranking_score", None))
# trim context by trim_len
if evidence:
@@ -190,7 +193,9 @@ class ReactAgentPipeline(BaseReasoning):
"<b>Action</b>: <em>{tool}[{input}]</em>\n\n<b>Output</b>: {output}"
).format(
tool=step.tool if status == "thinking" else "",
input=step.tool_input.replace("\n", "") if status == "thinking" else "",
input=step.tool_input.replace("\n", "").replace('"', "")
if status == "thinking"
else "",
output=output if status == "thinking" else "Finished",
)
return Document(
@@ -261,9 +266,17 @@ class ReactAgentPipeline(BaseReasoning):
llm_name = settings[f"{prefix}.llm"]
llm = llms.get(llm_name, llms.get_default())
max_context_length_setting = settings.get("reasoning.max_context_length", None)
pipeline = ReactAgentPipeline(retrievers=retrievers)
pipeline.agent.llm = llm
pipeline.agent.max_iterations = settings[f"{prefix}.max_iterations"]
if max_context_length_setting:
pipeline.agent.max_context_length = (
max_context_length_setting // DEFAULT_AGENT_STEPS
)
tools = []
for tool_name in settings[f"reasoning.options.{_id}.tools"]:
tool = TOOL_REGISTRY[tool_name]
@@ -273,7 +286,7 @@ class ReactAgentPipeline(BaseReasoning):
tool.llm = llm
tools.append(tool)
pipeline.agent.plugins = tools
pipeline.agent.output_lang = {"en": "English", "ja": "Japanese"}.get(
pipeline.agent.output_lang = SUPPORTED_LANGUAGE_MAP.get(
settings["reasoning.lang"], "English"
)
pipeline.use_rewrite = states.get("app", {}).get("regen", False)
@@ -298,6 +311,7 @@ class ReactAgentPipeline(BaseReasoning):
"value": llm,
"component": "dropdown",
"choices": llm_choices,
"special_type": "llm",
"info": (
"The language model to use for generating the answer. If None, "
"the application default language model will be used."
@@ -325,5 +339,10 @@ class ReactAgentPipeline(BaseReasoning):
return {
"id": "ReAct",
"name": "ReAct Agent",
"description": "Implementing ReAct paradigm",
"description": (
"Implementing ReAct paradigm: https://arxiv.org/abs/2210.03629. "
"ReAct agent answers the user's request by iteratively formulating "
"plan and executing it. The agent can use multiple tools to gather "
"information and generate the final answer."
),
}


@@ -20,7 +20,10 @@ from kotaemon.agents import (
from kotaemon.base import BaseComponent, Document, HumanMessage, Node, SystemMessage
from kotaemon.llms import ChatLLM, PromptTemplate
from ..utils import SUPPORTED_LANGUAGE_MAP
logger = logging.getLogger(__name__)
DEFAULT_AGENT_STEPS = 4
DEFAULT_PLANNER_PROMPT = (
@@ -135,7 +138,7 @@ class DocSearchTool(BaseTool):
)
print("Retrieved #{}: {}".format(_id, retrieved_content))
print("Score", retrieved_item.metadata.get("relevance_score", None))
print("Score", retrieved_item.metadata.get("cohere_reranking_score", None))
# trim context by trim_len
if evidence:
@@ -215,7 +218,7 @@ class RewooAgentPipeline(BaseReasoning):
use_rewrite: bool = False
enable_citation: bool = False
def format_info_panel(self, worker_log):
def format_info_panel_evidence(self, worker_log):
header = ""
content = []
@@ -223,6 +226,10 @@ class RewooAgentPipeline(BaseReasoning):
if line.startswith("#Plan"):
# a line starting with #Plan marks a new segment
header = line
elif line.startswith("#Action"):
# small fix for markdown output
line = "\\" + line + "<br>"
content.append(line)
elif line.startswith("#"):
# stop markdown from rendering big headers
line = "\\" + line
@@ -238,6 +245,17 @@ class RewooAgentPipeline(BaseReasoning):
content=Render.collapsible(
header=header,
content=Render.table("\n".join(content)),
open=False,
),
)
def format_info_panel_planner(self, planner_output):
planner_output = planner_output.replace("\n", "<br>")
return Document(
channel="info",
content=Render.collapsible(
header="Planner Output",
content=planner_output,
open=True,
),
)
@@ -285,12 +303,19 @@ class RewooAgentPipeline(BaseReasoning):
# a line starting with #Plan marks a new segment
new_segment = [line]
segments.append(new_segment)
elif line.startswith("#Action"):
# small fix for markdown output
line = "\\" + line + "<br>"
segments[-1].append(line)
elif line.startswith("#"):
# stop markdown from rendering big headers
line = "\\" + line
segments[-1].append(line)
else:
segments[-1].append(line)
if segments:
segments[-1].append(line)
else:
segments.append([line])
outputs = []
for segment in segments:
@@ -337,18 +362,23 @@ class RewooAgentPipeline(BaseReasoning):
for item in output_stream:
if item.intermediate_steps:
for step in item.intermediate_steps:
yield Document(
channel="info",
content=self.format_info_panel(step["worker_log"]),
)
if "planner_log" in step:
yield Document(
channel="info",
content=self.format_info_panel_planner(step["planner_log"]),
)
else:
yield Document(
channel="info",
content=self.format_info_panel_evidence(step["worker_log"]),
)
if item.text:
# final answer
yield Document(channel="chat", content=item.text)
answer = output_stream.value
yield Document(channel="info", content=None)
refined_citations = self.prepare_citation(answer)
for _ in refined_citations:
yield _
yield from self.prepare_citation(answer)
return answer
@@ -360,6 +390,8 @@ class RewooAgentPipeline(BaseReasoning):
prefix = f"reasoning.options.{_id}"
pipeline = RewooAgentPipeline(retrievers=retrievers)
max_context_length_setting = settings.get("reasoning.max_context_length", None)
planner_llm_name = settings[f"{prefix}.planner_llm"]
planner_llm = llms.get(planner_llm_name, llms.get_default())
solver_llm_name = settings[f"{prefix}.solver_llm"]
@@ -367,6 +399,10 @@ class RewooAgentPipeline(BaseReasoning):
pipeline.agent.planner_llm = planner_llm
pipeline.agent.solver_llm = solver_llm
if max_context_length_setting:
pipeline.agent.max_context_length = (
max_context_length_setting // DEFAULT_AGENT_STEPS
)
tools = []
for tool_name in settings[f"{prefix}.tools"]:
@@ -377,7 +413,7 @@ class RewooAgentPipeline(BaseReasoning):
tool.llm = solver_llm
tools.append(tool)
pipeline.agent.plugins = tools
pipeline.agent.output_lang = {"en": "English", "ja": "Japanese"}.get(
pipeline.agent.output_lang = SUPPORTED_LANGUAGE_MAP.get(
settings["reasoning.lang"], "English"
)
pipeline.agent.prompt_template["Planner"] = PromptTemplate(
@@ -413,6 +449,7 @@ class RewooAgentPipeline(BaseReasoning):
"value": llm,
"component": "dropdown",
"choices": llm_choices,
"special_type": "llm",
"info": (
"The language model to use for planning. "
"This model will generate a plan based on the "
@@ -424,6 +461,7 @@ class RewooAgentPipeline(BaseReasoning):
"value": llm,
"component": "dropdown",
"choices": llm_choices,
"special_type": "llm",
"info": (
"The language model to use for solving. "
"This model will generate the answer based on the "
@@ -457,6 +495,10 @@ class RewooAgentPipeline(BaseReasoning):
"id": "ReWOO",
"name": "ReWOO Agent",
"description": (
"Implementing ReWOO paradigm " "https://arxiv.org/pdf/2305.18323.pdf"
"Implementing ReWOO paradigm: https://arxiv.org/abs/2305.18323. "
"The ReWOO agent makes a step by step plan in the first stage, "
"then solves each step in the second stage. The agent can use "
"external tools to help in the reasoning process. Once all stages "
"are completed, the agent will summarize the answer."
),
}

File diff suppressed because it is too large.

@@ -19,6 +19,7 @@ class SettingItem(BaseModel):
choices: list = Field(default_factory=list)
metadata: dict = Field(default_factory=dict)
component: str = "text"
special_type: str = ""
class BaseSettingGroup(BaseModel):
@@ -55,6 +56,9 @@ class BaseSettingGroup(BaseModel):
option = self.options[option_id]
return option.get_setting_item(sub_path)
def __bool__(self):
return bool(self.settings) or bool(self.options)
class SettingReasoningGroup(BaseSettingGroup):
def _get_options(self) -> dict:


@@ -0,0 +1,3 @@
from .lang import SUPPORTED_LANGUAGE_MAP
__all__ = ["SUPPORTED_LANGUAGE_MAP"]
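# Illustrative sketch, not part of the diff: lang.py itself is not shown here.
# Judging from the literal dict that SUPPORTED_LANGUAGE_MAP replaces in the
# agent pipelines, it presumably contains at least:
SUPPORTED_LANGUAGE_MAP = {
    "en": "English",
    "ja": "Japanese",
}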

Some files were not shown because too many files have changed in this diff.