kotaemon/libs/ktem/ktem/reasoning/simple.py

import asyncio
import logging
import warnings
from collections import defaultdict
from functools import partial
from typing import Optional

import tiktoken
from ktem.components import embeddings, get_docstore, get_vectorstore, llms
from ktem.db.models import Index, SourceTargetRelation, engine
from kotaemon.base import (
    BaseComponent,
    Document,
    HumanMessage,
    Node,
    RetrievedDocument,
    SystemMessage,
)
from kotaemon.indices import VectorRetrieval
from kotaemon.indices.qa.citation import CitationPipeline
from kotaemon.indices.rankings import BaseReranking, CohereReranking, LLMReranking
from kotaemon.indices.splitters import TokenSplitter
from kotaemon.llms import ChatLLM, PromptTemplate
from llama_index.vector_stores import (
    FilterCondition,
    FilterOperator,
    MetadataFilter,
    MetadataFilters,
)
from llama_index.vector_stores.types import VectorStoreQueryMode
from sqlmodel import Session, select
from theflow.settings import settings

logger = logging.getLogger(__name__)


class DocumentRetrievalPipeline(BaseComponent):
    """Retrieve relevant document excerpts for a text query

    Args:
        vector_retrieval: the retrieval pipeline that returns the relevant
            documents for a text query
        reranker: the reranking pipeline that re-ranks and filters the
            retrieved documents
        get_extra_table: if True, for each retrieved document, the pipeline will
            look for surrounding tables (e.g. within the same page)
    """

    vector_retrieval: VectorRetrieval = VectorRetrieval.withx(
        doc_store=get_docstore(),
        vector_store=get_vectorstore(),
        embedding=embeddings.get_default(),
    )
    reranker: BaseReranking = CohereReranking.withx(
        cohere_api_key=getattr(settings, "COHERE_API_KEY", "")
    ) >> LLMReranking.withx(llm=llms.get_lowest_cost())
    get_extra_table: bool = False

    def run(
        self,
        text: str,
        top_k: int = 5,
        mmr: bool = False,
        doc_ids: Optional[list[str]] = None,
    ) -> list[RetrievedDocument]:
        """Retrieve document excerpts similar to the text

        Args:
            text: the text to retrieve similar documents for
            top_k: the number of documents to retrieve
            mmr: whether to use MMR to re-rank the documents
            doc_ids: list of document ids to constrain the retrieval
        """
        kwargs = {}
        if doc_ids:
            # translate the user-facing document ids into the corresponding
            # vector-store ids recorded in the Index table, then restrict the
            # vector search to those ids
            with Session(engine) as session:
                stmt = select(Index).where(
                    Index.relation_type == SourceTargetRelation.VECTOR,
                    Index.source_id.in_(doc_ids),  # type: ignore
                )
                results = session.exec(stmt)
                vs_ids = [r.target_id for r in results.all()]

            kwargs["filters"] = MetadataFilters(
                filters=[
                    MetadataFilter(
                        key="doc_id",
                        value=vs_id,
                        operator=FilterOperator.EQ,
                    )
                    for vs_id in vs_ids
                ],
                condition=FilterCondition.OR,
            )

        if mmr:
            # TODO: double check that llama-index MMR works correctly
            kwargs["mode"] = VectorStoreQueryMode.MMR
            kwargs["mmr_threshold"] = 0.5

        # retrieve, then re-rank
        docs = self.vector_retrieval(text=text, top_k=top_k, **kwargs)
        if self.get_from_path("reranker"):
            docs = self.reranker(docs, query=text)

        if not self.get_extra_table:
            return docs

        # retrieve extra nodes related to tables on the same pages
        table_pages = defaultdict(list)
        retrieved_id = {doc.doc_id for doc in docs}
        for doc in docs:
            if "page_label" not in doc.metadata:
                continue
            if "file_name" not in doc.metadata:
                warnings.warn(
                    "file_name not in metadata while page_label is in metadata: "
                    f"{doc.metadata}"
                )
                continue  # cannot group by file without a file_name
            table_pages[doc.metadata["file_name"]].append(doc.metadata["page_label"])

        queries = [
            {"$and": [{"file_name": {"$eq": fn}}, {"page_label": {"$in": pls}}]}
            for fn, pls in table_pages.items()
        ]
        if queries:
            extra_docs = self.vector_retrieval(
                text="",
                top_k=50,
                where={"$or": queries},
            )
            for doc in extra_docs:
                if doc.doc_id not in retrieved_id:
                    docs.append(doc)

        return docs
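
# Hedged usage sketch (illustrative, not part of the original module): with the
# ktem doc store, vector store, and embeddings configured, the retriever can be
# exercised on its own. The query text below is made up.
#
#     retriever = DocumentRetrievalPipeline(get_extra_table=True)
#     docs = retriever(text="What is the warranty period?", top_k=5, mmr=True)
#     for doc in docs:
#         print(doc.metadata.get("file_name"), doc.metadata.get("page_label"))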


class PrepareEvidencePipeline(BaseComponent):
    """Prepare the evidence text from the list of retrieved documents

    This step usually happens after `DocumentRetrievalPipeline`.

    Args:
        trim_func: a callback function or a BaseComponent that splits a large
            chunk of text into smaller ones. The first chunk will be retained.
    """

    trim_func: TokenSplitter = TokenSplitter.withx(
        chunk_size=7600,
        chunk_overlap=0,
        separator=" ",
        tokenizer=partial(
            tiktoken.encoding_for_model("gpt-3.5-turbo").encode,
            allowed_special=set(),
            disallowed_special="all",
        ),
    )
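
    # Note (added sketch, not in the original module): chunk_size above is
    # measured in tokens of the gpt-3.5-turbo encoding, not in characters.
    # A minimal way to check how that tokenizer counts a piece of text:
    #
    #     enc = tiktoken.encoding_for_model("gpt-3.5-turbo")
    #     n_tokens = len(enc.encode("some evidence text"))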

    def run(self, docs: list[RetrievedDocument]) -> Document:
        evidence = ""
        table_found = 0
        evidence_mode = 0

        for _id, retrieved_item in enumerate(docs):
            retrieved_content = ""
            page = retrieved_item.metadata.get("page_label", None)
            source = filename = retrieved_item.metadata.get("file_name", "-")
            if page:
                source += f" (Page {page})"
            if retrieved_item.metadata.get("type", "") == "table":
                evidence_mode = 1  # table
                if table_found < 5:
                    retrieved_content = retrieved_item.metadata.get("table_origin", "")
                    if retrieved_content not in evidence:
                        table_found += 1
                        evidence += (
                            f"<br><b>Table from {source}</b>\n"
                            + retrieved_content
                            + "\n<br>"
                        )
            elif retrieved_item.metadata.get("type", "") == "chatbot":
                evidence_mode = 2  # chatbot
                retrieved_content = retrieved_item.metadata["window"]
                evidence += (
                    f"<br><b>Chatbot scenario from {filename} (Row {page})</b>\n"
                    + retrieved_content
                    + "\n<br>"
                )
            else:
                if "window" in retrieved_item.metadata:
                    retrieved_content = retrieved_item.metadata["window"]
                else:
                    retrieved_content = retrieved_item.text
                retrieved_content = retrieved_content.replace("\n", " ")
                if retrieved_content not in evidence:
                    evidence += (
                        f"<br><b>Content from {source}: </b> "
                        + retrieved_content
                        + " \n<br>"
                    )

            print("Retrieved #{}: {}".format(_id, retrieved_content))
            print(retrieved_item.metadata)
            print("Score", retrieved_item.metadata.get("relevance_score", None))

        # trim the evidence so it fits within the LLM context window
        print("len (original)", len(evidence))
        if evidence:
            texts = self.trim_func([Document(text=evidence)])
            evidence = texts[0].text
            print("len (trimmed)", len(evidence))

        print(f"PrepareEvidence output: {evidence}\n")

        return Document(content=(evidence_mode, evidence))
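
# Hedged usage sketch (illustrative, not part of the original module): the
# pipeline packs the retrieved documents into a single evidence string and
# reports which prompt family to use downstream.
#
#     evidence_pipeline = PrepareEvidencePipeline()
#     evidence_mode, evidence = evidence_pipeline(docs).content
#     # evidence_mode: 0 = plain text, 1 = table, 2 = chatbot scenario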


DEFAULT_QA_TEXT_PROMPT = (
    "Use the following pieces of context to answer the question at the end. "
    "If you don't know the answer, just say that you don't know, don't try to "
    "make up an answer. Keep the answer as concise as possible. Give answer in "
    "{lang}. {system}\n\n"
    "{context}\n"
    "Question: {question}\n"
    "Helpful Answer:"
)

DEFAULT_QA_TABLE_PROMPT = (
    "List all rows (row numbers) from the table context that relate to the question, "
    "then provide a detailed answer with clear explanation and citations. "
    "If you don't know the answer, just say that you don't know, "
    "don't try to make up an answer. Give answer in {lang}. {system}\n\n"
    "Context:\n"
    "{context}\n"
    "Question: {question}\n"
    "Helpful Answer:"
)

DEFAULT_QA_CHATBOT_PROMPT = (
    "Pick the most suitable chatbot scenarios to answer the question at the end, "
    "output the provided answer text. If you don't know the answer, "
    "just say that you don't know. Keep the answer as concise as possible. "
    "Give answer in {lang}. {system}\n\n"
    "Context:\n"
    "{context}\n"
    "Question: {question}\n"
    "Answer:"
)
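
# A minimal sketch of how these templates get filled (illustrative only; the
# values below are made up), mirroring the populate() call used in
# AnswerWithContextPipeline.run below:
#
#     prompt = PromptTemplate(DEFAULT_QA_TEXT_PROMPT).populate(
#         context="<retrieved evidence>",
#         question="What is the warranty period?",
#         lang="English",
#         system="This is a question answering system",
#     )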


class AnswerWithContextPipeline(BaseComponent):
    """Answer the question based on the evidence

    Args:
        llm: the language model used to generate the answer
        citation_pipeline: generates citations from the evidence
        qa_template: the prompt template used to generate the answer
            (refer to evidence_mode)
        qa_table_template: the prompt template used to generate the answer
            for tables (refer to evidence_mode)
        qa_chatbot_template: the prompt template used to generate the answer
            for pre-made scenarios (refer to evidence_mode)
        lang: the language of the answer. Currently English and Japanese are
            supported
    """

    llm: ChatLLM = Node(default_callback=lambda _: llms.get_highest_accuracy())
    citation_pipeline: CitationPipeline = Node(
        default_callback=lambda _: CitationPipeline(llm=llms.get_lowest_cost())
    )

    qa_template: str = DEFAULT_QA_TEXT_PROMPT
    qa_table_template: str = DEFAULT_QA_TABLE_PROMPT
    qa_chatbot_template: str = DEFAULT_QA_CHATBOT_PROMPT

    system_prompt: str = ""
    lang: str = "English"  # supports English and Japanese

    async def run(  # type: ignore
        self, question: str, evidence: str, evidence_mode: int = 0
    ) -> Document:
        """Answer the question based on the evidence

        In addition to the question and the evidence, this method also takes
        evidence_mode into account. The evidence_mode tells what kind of
        evidence it is, which affects:

            1. How the evidence is represented.
            2. The prompt used to generate the answer.

        By default, the evidence_mode is 0, which means the evidence is plain
        text with no particular semantic representation. The evidence_mode can
        also be:

            1. "table": there is HTML markup indicating a table within the
               evidence.
            2. "chatbot": there is HTML markup indicating a chatbot scenario,
               extracted from an Excel file, where each row corresponds to an
               interaction.

        Args:
            question: the original question posed by the user
            evidence: the text that contains the information relevant to
                answering the question (determined by the retrieval pipeline)
            evidence_mode: the mode of evidence, 0 for text, 1 for table,
                2 for chatbot
        """
        if evidence_mode == 0:
            prompt_template = PromptTemplate(self.qa_template)
        elif evidence_mode == 1:
            prompt_template = PromptTemplate(self.qa_table_template)
        else:
            prompt_template = PromptTemplate(self.qa_chatbot_template)

        prompt = prompt_template.populate(
            context=evidence,
            question=question,
            lang=self.lang,
            system=self.system_prompt,
        )

        messages = [
            SystemMessage(content="You are a helpful assistant"),
            HumanMessage(content=prompt),
        ]

        # stream the answer chunk by chunk, yielding control to the event loop
        # after each chunk so the caller can surface partial output
        output = ""
        for text in self.llm(messages):
            output += text.text
            self.report_output({"output": text.text})
            await asyncio.sleep(0)

        citation = self.citation_pipeline(context=evidence, question=question)
        answer = Document(text=output, metadata={"citation": citation})

        return answer
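
# Hedged usage sketch (illustrative, not part of the original module): the
# answering step is asynchronous, so it has to be awaited or driven by an
# event loop. The question string is made up.
#
#     answering = AnswerWithContextPipeline(lang="English")
#     answer = asyncio.run(
#         answering(question="...", evidence=evidence, evidence_mode=evidence_mode)
#     )
#     print(answer.text, answer.metadata["citation"])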


class FullQAPipeline(BaseComponent):
    """Question answering pipeline. Handles everything from question to answer"""

    class Config:
        allow_extra = True
        params_publish = True

    retrieval_pipeline: DocumentRetrievalPipeline = DocumentRetrievalPipeline.withx()
    evidence_pipeline: PrepareEvidencePipeline = PrepareEvidencePipeline.withx()
    answering_pipeline: AnswerWithContextPipeline = AnswerWithContextPipeline.withx()

    async def run(  # type: ignore
        self, question: str, history: list, **kwargs  # type: ignore
    ) -> Document:  # type: ignore
        docs = self.retrieval_pipeline(text=question)
        evidence_mode, evidence = self.evidence_pipeline(docs).content
        answer = await self.answering_pipeline(
            question=question, evidence=evidence, evidence_mode=evidence_mode
        )

        # prepare citation highlights: locate each quoted substring in the
        # retrieved documents and record its character span
        spans = defaultdict(list)
        for fact_with_evidence in answer.metadata["citation"].answer:
            for quote in fact_with_evidence.substring_quote:
                for doc in docs:
                    start_idx = doc.text.find(quote)
                    if start_idx >= 0:
                        spans[doc.doc_id].append(
                            {
                                "start": start_idx,
                                "end": start_idx + len(quote),
                            }
                        )
                        break

        # render each cited document with the quoted spans highlighted
        id2docs = {doc.doc_id: doc for doc in docs}
        for id, ss in spans.items():
            if not ss:
                continue
            ss = sorted(ss, key=lambda x: x["start"])
            text = id2docs[id].text[: ss[0]["start"]]
            for idx, span in enumerate(ss):
                text += (
                    "<mark>" + id2docs[id].text[span["start"] : span["end"]] + "</mark>"
                )
                if idx < len(ss) - 1:
                    text += id2docs[id].text[span["end"] : ss[idx + 1]["start"]]
            text += id2docs[id].text[ss[-1]["end"] :]
            self.report_output(
                {
                    "evidence": (
                        "<details>"
                        f"<summary>{id2docs[id].metadata['file_name']}</summary>"
                        f"{text}"
                        "</details><br>"
                    )
                }
            )

        self.report_output(None)
        return answer

    @classmethod
    def get_pipeline(cls, settings, **kwargs):
        """Get the reasoning pipeline

        Needs a base pipeline implementation. The current drawback is that we
        want to treat the retrievers as tools; hence, the reasoning pipeline
        should just take the already initialized tools (retrievers) and not
        need to set up such logic here.
        """
        pipeline = FullQAPipeline(get_extra_table=settings["index.prioritize_table"])
        if not settings["index.use_reranking"]:
            pipeline.retrieval_pipeline.reranker = None  # type: ignore

        pipeline.answering_pipeline.llm = llms.get_highest_accuracy()

        kwargs = {
            ".retrieval_pipeline.top_k": int(settings["index.num_retrieval"]),
            ".retrieval_pipeline.mmr": settings["index.mmr"],
            ".retrieval_pipeline.doc_ids": kwargs.get("files", None),
        }
        pipeline.set_run(kwargs, temp=True)

        return pipeline

    @classmethod
    def get_user_settings(cls) -> dict:
        from ktem.components import llms

        try:
            citation_llm = llms.get_lowest_cost_name()
            citation_llm_choices = list(llms.options().keys())
            main_llm = llms.get_highest_accuracy_name()
            main_llm_choices = list(llms.options().keys())
        except Exception as e:
            logger.error(e)
            citation_llm = None
            citation_llm_choices = []
            main_llm = None
            main_llm_choices = []

        return {
            "highlight_citation": {
                "name": "Highlight Citation",
                "value": True,
                "component": "checkbox",
            },
            "system_prompt": {
                "name": "System Prompt",
                "value": "This is a question answering system",
            },
            "citation_llm": {
                "name": "LLM for citation",
                "value": citation_llm,
                "component": "dropdown",
                "choices": citation_llm_choices,
            },
            "main_llm": {
                "name": "LLM for main generation",
                "value": main_llm,
                "component": "dropdown",
                "choices": main_llm_choices,
            },
        }

    @classmethod
    def get_info(cls) -> dict:
        return {
            "id": "simple",
            "name": "Simple QA",
            "description": "Simple QA pipeline",
        }
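

# Hedged end-to-end sketch (illustrative only, not part of the original module).
# The settings keys match the ones read in get_pipeline above; the question is
# made up, and this assumes calling the pipeline returns an awaitable, as
# FullQAPipeline.run does for its answering sub-pipeline.
#
#     user_settings = {
#         "index.prioritize_table": True,
#         "index.use_reranking": False,
#         "index.num_retrieval": 5,
#         "index.mmr": False,
#     }
#     pipeline = FullQAPipeline.get_pipeline(user_settings, files=None)
#     answer = asyncio.run(pipeline(question="What is the warranty period?", history=[]))
#     print(answer.text)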