Relate the retrievers to the indexer
This commit is contained in:
@@ -36,6 +36,7 @@ class ChatPage(BasePage):
|
||||
).then(
|
||||
fn=chat_fn,
|
||||
inputs=[
|
||||
self.chat_control.conversation_id,
|
||||
self.chat_panel.chatbot,
|
||||
self.data_source.files,
|
||||
self._app.settings_state,
|
||||
@@ -64,6 +65,7 @@ class ChatPage(BasePage):
|
||||
).then(
|
||||
fn=chat_fn,
|
||||
inputs=[
|
||||
self.chat_control.conversation_id,
|
||||
self.chat_panel.chatbot,
|
||||
self.data_source.files,
|
||||
self._app.settings_state,
|
||||
|
@@ -7,8 +7,7 @@ from typing import Optional
|
||||
import gradio as gr
|
||||
from ktem.components import llms, reasonings
|
||||
from ktem.db.models import Conversation, Source, engine
|
||||
from ktem.indexing.base import BaseIndex
|
||||
from ktem.reasoning.simple import DocumentRetrievalPipeline
|
||||
from ktem.indexing.base import BaseIndexing
|
||||
from sqlmodel import Session, select
|
||||
from theflow.settings import settings as app_settings
|
||||
from theflow.utils.modules import import_dotted_string
|
||||
@@ -26,9 +25,15 @@ def create_pipeline(settings: dict, files: Optional[list] = None):
|
||||
the pipeline objects
|
||||
"""
|
||||
|
||||
# get retrievers
|
||||
indexing_cls: BaseIndexing = import_dotted_string(app_settings.KH_INDEX, safe=False)
|
||||
retrievers = indexing_cls.get_pipeline(settings).get_retrievers(
|
||||
settings, files=files
|
||||
)
|
||||
|
||||
reasoning_mode = settings["reasoning.use"]
|
||||
reasoning_cls = reasonings[reasoning_mode]
|
||||
pipeline = reasoning_cls.get_pipeline(settings, files=files)
|
||||
pipeline = reasoning_cls.get_pipeline(settings, retrievers, files=files)
|
||||
|
||||
if settings["reasoning.use"] in ["rewoo", "react"]:
|
||||
from kotaemon.agents import ReactAgent, RewooAgent
|
||||
@@ -49,47 +54,38 @@ def create_pipeline(settings: dict, files: Optional[list] = None):
|
||||
from kotaemon.agents import LLMTool
|
||||
|
||||
tools.append(LLMTool(llm=llm))
|
||||
elif tool == "docsearch":
|
||||
from kotaemon.agents import ComponentTool
|
||||
# elif tool == "docsearch":
|
||||
# pass
|
||||
|
||||
filenames = ""
|
||||
if files:
|
||||
with Session(engine) as session:
|
||||
statement = select(Source).where(
|
||||
Source.id.in_(files) # type: ignore
|
||||
)
|
||||
results = session.exec(statement).all()
|
||||
filenames = (
|
||||
"The file names are: "
|
||||
+ " ".join([result.name for result in results])
|
||||
+ ". "
|
||||
)
|
||||
# filenames = ""
|
||||
# if files:
|
||||
# with Session(engine) as session:
|
||||
# statement = select(Source).where(
|
||||
# Source.id.in_(files) # type: ignore
|
||||
# )
|
||||
# results = session.exec(statement).all()
|
||||
# filenames = (
|
||||
# "The file names are: "
|
||||
# + " ".join([result.name for result in results])
|
||||
# + ". "
|
||||
# )
|
||||
|
||||
retrieval_pipeline = DocumentRetrievalPipeline()
|
||||
retrieval_pipeline.set_run(
|
||||
{
|
||||
".top_k": int(settings["retrieval_number"]),
|
||||
".mmr": settings["retrieval_mmr"],
|
||||
".doc_ids": files,
|
||||
},
|
||||
temp=True,
|
||||
)
|
||||
tool = ComponentTool(
|
||||
name="docsearch",
|
||||
description=(
|
||||
"A vector store that searches for similar and "
|
||||
"related content "
|
||||
f"in a document. {filenames}"
|
||||
"The result is a huge chunk of text related "
|
||||
"to your search but can also "
|
||||
"contain irrelevant info."
|
||||
),
|
||||
component=retrieval_pipeline,
|
||||
postprocessor=lambda docs: "\n\n".join(
|
||||
[doc.text.replace("\n", " ") for doc in docs]
|
||||
),
|
||||
)
|
||||
tools.append(tool)
|
||||
# tool = ComponentTool(
|
||||
# name="docsearch",
|
||||
# description=(
|
||||
# "A vector store that searches for similar and "
|
||||
# "related content "
|
||||
# f"in a document. {filenames}"
|
||||
# "The result is a huge chunk of text related "
|
||||
# "to your search but can also "
|
||||
# "contain irrelevant info."
|
||||
# ),
|
||||
# component=retrieval_pipeline,
|
||||
# postprocessor=lambda docs: "\n\n".join(
|
||||
# [doc.text.replace("\n", " ") for doc in docs]
|
||||
# ),
|
||||
# )
|
||||
# tools.append(tool)
|
||||
elif tool == "google":
|
||||
from kotaemon.agents import GoogleSearchTool
|
||||
|
||||
@@ -117,7 +113,7 @@ def create_pipeline(settings: dict, files: Optional[list] = None):
|
||||
return pipeline
|
||||
|
||||
|
||||
async def chat_fn(chat_history, files, settings):
|
||||
async def chat_fn(conversation_id, chat_history, files, settings):
|
||||
"""Chat function"""
|
||||
chat_input = chat_history[-1][0]
|
||||
chat_history = chat_history[:-1]
|
||||
@@ -128,7 +124,7 @@ async def chat_fn(chat_history, files, settings):
|
||||
pipeline = create_pipeline(settings, files)
|
||||
pipeline.set_output_queue(queue)
|
||||
|
||||
asyncio.create_task(pipeline(chat_input, chat_history))
|
||||
asyncio.create_task(pipeline(chat_input, conversation_id, chat_history))
|
||||
text, refs = "", ""
|
||||
|
||||
while True:
|
||||
@@ -207,7 +203,7 @@ def index_fn(files, reindex: bool, selected_files, settings):
|
||||
gr.Info(f"Start indexing {len(files)} files...")
|
||||
|
||||
# get the pipeline
|
||||
indexing_cls: BaseIndex = import_dotted_string(app_settings.KH_INDEX, safe=False)
|
||||
indexing_cls: BaseIndexing = import_dotted_string(app_settings.KH_INDEX, safe=False)
|
||||
indexing_pipeline = indexing_cls.get_pipeline(settings)
|
||||
|
||||
output_nodes, file_ids = indexing_pipeline(files, reindex=reindex)
|
||||
|
Reference in New Issue
Block a user