Relate the retrievers to the indexer

This commit is contained in:
trducng
2024-01-27 16:39:40 +07:00
parent 9b586466ff
commit c6637ca56e
5 changed files with 220 additions and 192 deletions

View File

@@ -36,6 +36,7 @@ class ChatPage(BasePage):
).then(
fn=chat_fn,
inputs=[
self.chat_control.conversation_id,
self.chat_panel.chatbot,
self.data_source.files,
self._app.settings_state,
@@ -64,6 +65,7 @@ class ChatPage(BasePage):
).then(
fn=chat_fn,
inputs=[
self.chat_control.conversation_id,
self.chat_panel.chatbot,
self.data_source.files,
self._app.settings_state,

View File

@@ -7,8 +7,7 @@ from typing import Optional
import gradio as gr
from ktem.components import llms, reasonings
from ktem.db.models import Conversation, Source, engine
from ktem.indexing.base import BaseIndex
from ktem.reasoning.simple import DocumentRetrievalPipeline
from ktem.indexing.base import BaseIndexing
from sqlmodel import Session, select
from theflow.settings import settings as app_settings
from theflow.utils.modules import import_dotted_string
@@ -26,9 +25,15 @@ def create_pipeline(settings: dict, files: Optional[list] = None):
the pipeline objects
"""
# get retrievers
indexing_cls: BaseIndexing = import_dotted_string(app_settings.KH_INDEX, safe=False)
retrievers = indexing_cls.get_pipeline(settings).get_retrievers(
settings, files=files
)
reasoning_mode = settings["reasoning.use"]
reasoning_cls = reasonings[reasoning_mode]
pipeline = reasoning_cls.get_pipeline(settings, files=files)
pipeline = reasoning_cls.get_pipeline(settings, retrievers, files=files)
if settings["reasoning.use"] in ["rewoo", "react"]:
from kotaemon.agents import ReactAgent, RewooAgent
@@ -49,47 +54,38 @@ def create_pipeline(settings: dict, files: Optional[list] = None):
from kotaemon.agents import LLMTool
tools.append(LLMTool(llm=llm))
elif tool == "docsearch":
from kotaemon.agents import ComponentTool
# elif tool == "docsearch":
# pass
filenames = ""
if files:
with Session(engine) as session:
statement = select(Source).where(
Source.id.in_(files) # type: ignore
)
results = session.exec(statement).all()
filenames = (
"The file names are: "
+ " ".join([result.name for result in results])
+ ". "
)
# filenames = ""
# if files:
# with Session(engine) as session:
# statement = select(Source).where(
# Source.id.in_(files) # type: ignore
# )
# results = session.exec(statement).all()
# filenames = (
# "The file names are: "
# + " ".join([result.name for result in results])
# + ". "
# )
retrieval_pipeline = DocumentRetrievalPipeline()
retrieval_pipeline.set_run(
{
".top_k": int(settings["retrieval_number"]),
".mmr": settings["retrieval_mmr"],
".doc_ids": files,
},
temp=True,
)
tool = ComponentTool(
name="docsearch",
description=(
"A vector store that searches for similar and "
"related content "
f"in a document. {filenames}"
"The result is a huge chunk of text related "
"to your search but can also "
"contain irrelevant info."
),
component=retrieval_pipeline,
postprocessor=lambda docs: "\n\n".join(
[doc.text.replace("\n", " ") for doc in docs]
),
)
tools.append(tool)
# tool = ComponentTool(
# name="docsearch",
# description=(
# "A vector store that searches for similar and "
# "related content "
# f"in a document. {filenames}"
# "The result is a huge chunk of text related "
# "to your search but can also "
# "contain irrelevant info."
# ),
# component=retrieval_pipeline,
# postprocessor=lambda docs: "\n\n".join(
# [doc.text.replace("\n", " ") for doc in docs]
# ),
# )
# tools.append(tool)
elif tool == "google":
from kotaemon.agents import GoogleSearchTool
@@ -117,7 +113,7 @@ def create_pipeline(settings: dict, files: Optional[list] = None):
return pipeline
async def chat_fn(chat_history, files, settings):
async def chat_fn(conversation_id, chat_history, files, settings):
"""Chat function"""
chat_input = chat_history[-1][0]
chat_history = chat_history[:-1]
@@ -128,7 +124,7 @@ async def chat_fn(chat_history, files, settings):
pipeline = create_pipeline(settings, files)
pipeline.set_output_queue(queue)
asyncio.create_task(pipeline(chat_input, chat_history))
asyncio.create_task(pipeline(chat_input, conversation_id, chat_history))
text, refs = "", ""
while True:
@@ -207,7 +203,7 @@ def index_fn(files, reindex: bool, selected_files, settings):
gr.Info(f"Start indexing {len(files)} files...")
# get the pipeline
indexing_cls: BaseIndex = import_dotted_string(app_settings.KH_INDEX, safe=False)
indexing_cls: BaseIndexing = import_dotted_string(app_settings.KH_INDEX, safe=False)
indexing_pipeline = indexing_cls.get_pipeline(settings)
output_nodes, file_ids = indexing_pipeline(files, reindex=reindex)