Refactor reasoning pipeline (#31)

* Move the text rendering out for reusability

* Refactor common operations in the reasoning pipeline

* Add run method

* Provide dedicated method for invoke
Author: Duc Nguyen (john), committed by GitHub
Date:   2024-04-13 23:13:04 +07:00
Commit: 0417610d3e (parent: af38708b77)
6 changed files with 227 additions and 293 deletions


@@ -39,7 +39,7 @@ class BaseComponent(Function):
         if isinstance(node, BaseComponent):
             node.set_output_queue(queue)

-    def report_output(self, output: Optional[dict]):
+    def report_output(self, output: Optional[Document]):
         if self._queue is not None:
             self._queue.put_nowait(output)
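With this change, report_output carries a Document instead of a plain dict, so the UI channel ("chat" vs "info") travels with each payload. Below is a rough sketch of what a consumer on the other side of the queue might look like, assuming an asyncio.Queue was attached via set_output_queue; the dispatch logic itself is illustrative, not kotaemon API.

import asyncio

from kotaemon.base import Document


async def drain(queue: asyncio.Queue):
    """Illustrative consumer: route reported Documents by channel until the
    None sentinel (see report_output(None) in the pipeline below) arrives."""
    while True:
        item = await queue.get()
        if item is None:  # end-of-stream sentinel
            break
        if item.channel == "chat":
            print("chat token:", item.content)
        elif item.channel == "info":
            print("evidence block:", item.content)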


@@ -270,7 +270,7 @@ class ChatOpenAI(BaseChatOpenAI):
     def openai_response(self, client, **kwargs):
         """Get the openai response"""
-        params = {
+        params_ = {
             "model": self.model,
             "temperature": self.temperature,
             "max_tokens": self.max_tokens,
@@ -285,6 +285,7 @@ class ChatOpenAI(BaseChatOpenAI):
             "top_logprobs": self.top_logprobs,
             "top_p": self.top_p,
         }
+        params = {k: v for k, v in params_.items() if v is not None}
         params.update(kwargs)
         return client.chat.completions.create(**params)
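The params_/params split filters out unset options before calling the OpenAI client, so the API applies its own defaults instead of receiving explicit nulls. A minimal, standalone sketch of the same pattern (sample values are made up, this is not the ChatOpenAI class itself):

# Optional settings left as None are dropped rather than sent to the API.
params_ = {
    "model": "gpt-3.5-turbo",
    "temperature": 0.0,
    "max_tokens": None,  # unset, should not be forwarded
    "top_p": None,       # unset, should not be forwarded
}
params = {k: v for k, v in params_.items() if v is not None}
assert params == {"model": "gpt-3.5-turbo", "temperature": 0.0}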


@@ -5,7 +5,7 @@ from sqlalchemy.orm import Session
 from theflow.settings import settings as flowsettings
 from theflow.utils.modules import deserialize

-from kotaemon.base import BaseComponent
+from kotaemon.llms import ChatLLM

 from .db import LLMTable, engine
@@ -14,7 +14,7 @@ class LLMManager:
     """Represent a pool of models"""

     def __init__(self):
-        self._models: dict[str, BaseComponent] = {}
+        self._models: dict[str, ChatLLM] = {}
         self._info: dict[str, dict] = {}
         self._default: str = ""
         self._vendors: list[Type] = []
@@ -63,7 +63,7 @@ class LLMManager:
         self._vendors = [ChatOpenAI, AzureChatOpenAI, LlamaCppChat, EndpointChatLLM]

-    def __getitem__(self, key: str) -> BaseComponent:
+    def __getitem__(self, key: str) -> ChatLLM:
         """Get model by name"""
         return self._models[key]
@@ -71,9 +71,7 @@ class LLMManager:
         """Check if model exists"""
         return key in self._models

-    def get(
-        self, key: str, default: Optional[BaseComponent] = None
-    ) -> Optional[BaseComponent]:
+    def get(self, key: str, default: Optional[ChatLLM] = None) -> Optional[ChatLLM]:
         """Get model by name with default value"""
         return self._models.get(key, default)
@@ -119,18 +117,18 @@ class LLMManager:
         return self._default

-    def get_random(self) -> BaseComponent:
+    def get_random(self) -> ChatLLM:
         """Get random model"""
         return self._models[self.get_random_name()]

-    def get_default(self) -> BaseComponent:
+    def get_default(self) -> ChatLLM:
         """Get default model

         In case there is no default model, choose random model from pool. In
         case there are multiple default models, choose random from them.

         Returns:
-            BaseComponent: model
+            ChatLLM: model
         """
         return self._models[self.get_default_name()]
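Narrowing the manager's annotations from BaseComponent to ChatLLM gives callers chat-specific typing back. A rough usage sketch, assuming the llms singleton imported by the reasoning pipeline below is an LLMManager with at least one registered model and that registered models accept a plain string prompt; the "gpt-4" key is hypothetical.

from ktem.llms.manager import llms

llm = llms.get_default()                   # now typed as ChatLLM, not BaseComponent
fallback = llms.get("gpt-4", default=llm)  # default must also be a ChatLLM
response = fallback("Summarize the retrieved evidence in one sentence.")
print(response.text)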


@@ -8,6 +8,7 @@ from typing import Generator

 import tiktoken
 from ktem.llms.manager import llms
+from ktem.utils.render import Render

 from kotaemon.base import (
     BaseComponent,
@@ -20,7 +21,7 @@ from kotaemon.base import (
 from kotaemon.indices.qa.citation import CitationPipeline
 from kotaemon.indices.splitters import TokenSplitter
 from kotaemon.llms import ChatLLM, PromptTemplate
-from kotaemon.loaders.utils.gpt4v import stream_gpt4v
+from kotaemon.loaders.utils.gpt4v import generate_gpt4v, stream_gpt4v

 from .base import BaseReasoning
@@ -205,7 +206,68 @@ class AnswerWithContextPipeline(BaseComponent):
     system_prompt: str = ""
     lang: str = "English"  # support English and Japanese

-    async def run(  # type: ignore
+    def get_prompt(self, question, evidence, evidence_mode: int):
+        """Prepare the prompt and other information for LLM"""
+        images = []
+
+        if evidence_mode == EVIDENCE_MODE_TEXT:
+            prompt_template = PromptTemplate(self.qa_template)
+        elif evidence_mode == EVIDENCE_MODE_TABLE:
+            prompt_template = PromptTemplate(self.qa_table_template)
+        elif evidence_mode == EVIDENCE_MODE_FIGURE:
+            prompt_template = PromptTemplate(self.qa_figure_template)
+        else:
+            prompt_template = PromptTemplate(self.qa_chatbot_template)
+
+        if evidence_mode == EVIDENCE_MODE_FIGURE:
+            # isolate image from evidence
+            evidence, images = self.extract_evidence_images(evidence)
+            prompt = prompt_template.populate(
+                context=evidence,
+                question=question,
+                lang=self.lang,
+            )
+        else:
+            prompt = prompt_template.populate(
+                context=evidence,
+                question=question,
+                lang=self.lang,
+            )
+
+        return prompt, images
+
+    def run(
+        self, question: str, evidence: str, evidence_mode: int = 0, **kwargs
+    ) -> Document:
+        return self.invoke(question, evidence, evidence_mode, **kwargs)
+
+    def invoke(
+        self, question: str, evidence: str, evidence_mode: int = 0, **kwargs
+    ) -> Document:
+        prompt, images = self.get_prompt(question, evidence, evidence_mode)
+
+        output = ""
+        if evidence_mode == EVIDENCE_MODE_FIGURE:
+            output = generate_gpt4v(self.vlm_endpoint, images, prompt, max_tokens=768)
+        else:
+            messages = []
+            if self.system_prompt:
+                messages.append(SystemMessage(content=self.system_prompt))
+            messages.append(HumanMessage(content=prompt))
+            output = self.llm(messages).text
+
+        # retrieve the citation
+        citation = None
+        if evidence and self.enable_citation:
+            citation = self.citation_pipeline.invoke(
+                context=evidence, question=question
+            )
+
+        answer = Document(text=output, metadata={"citation": citation})
+
+        return answer
+
+    async def ainvoke(  # type: ignore
         self, question: str, evidence: str, evidence_mode: int = 0, **kwargs
     ) -> Document:
         """Answer the question based on the evidence
@@ -230,30 +292,7 @@ class AnswerWithContextPipeline(BaseComponent):
                 (determined by retrieval pipeline)
             evidence_mode: the mode of evidence, 0 for text, 1 for table, 2 for chatbot
         """
-        if evidence_mode == EVIDENCE_MODE_TEXT:
-            prompt_template = PromptTemplate(self.qa_template)
-        elif evidence_mode == EVIDENCE_MODE_TABLE:
-            prompt_template = PromptTemplate(self.qa_table_template)
-        elif evidence_mode == EVIDENCE_MODE_FIGURE:
-            prompt_template = PromptTemplate(self.qa_figure_template)
-        else:
-            prompt_template = PromptTemplate(self.qa_chatbot_template)
-
-        images = []
-        if evidence_mode == EVIDENCE_MODE_FIGURE:
-            # isolate image from evidence
-            evidence, images = self.extract_evidence_images(evidence)
-            prompt = prompt_template.populate(
-                context=evidence,
-                question=question,
-                lang=self.lang,
-            )
-        else:
-            prompt = prompt_template.populate(
-                context=evidence,
-                question=question,
-                lang=self.lang,
-            )
+        prompt, images = self.get_prompt(question, evidence, evidence_mode)

         citation_task = None
         if evidence and self.enable_citation:
@@ -266,7 +305,7 @@ class AnswerWithContextPipeline(BaseComponent):
         if evidence_mode == EVIDENCE_MODE_FIGURE:
             for text in stream_gpt4v(self.vlm_endpoint, images, prompt, max_tokens=768):
                 output += text
-                self.report_output({"output": text})
+                self.report_output(Document(channel="chat", content=text))
                 await asyncio.sleep(0)
         else:
             messages = []
@@ -279,12 +318,12 @@ class AnswerWithContextPipeline(BaseComponent):
                 print("Trying LLM streaming")
                 for text in self.llm.stream(messages):
                     output += text.text
-                    self.report_output({"output": text.text})
+                    self.report_output(Document(content=text.text, channel="chat"))
                     await asyncio.sleep(0)
             except NotImplementedError:
                 print("Streaming is not supported, falling back to normal processing")
                 output = self.llm(messages).text
-                self.report_output({"output": output})
+                self.report_output(Document(content=output, channel="chat"))

         # retrieve the citation
         print("Waiting for citation task")
@@ -300,52 +339,7 @@ class AnswerWithContextPipeline(BaseComponent):
     def stream(  # type: ignore
         self, question: str, evidence: str, evidence_mode: int = 0, **kwargs
     ) -> Generator[Document, None, Document]:
-        """Answer the question based on the evidence
-
-        In addition to the question and the evidence, this method also take into
-        account evidence_mode. The evidence_mode tells which kind of evidence is.
-        The kind of evidence affects:
-            1. How the evidence is represented.
-            2. The prompt to generate the answer.
-
-        By default, the evidence_mode is 0, which means the evidence is plain text with
-        no particular semantic representation. The evidence_mode can be:
-            1. "table": There will be HTML markup telling that there is a table
-                within the evidence.
-            2. "chatbot": There will be HTML markup telling that there is a chatbot.
-                This chatbot is a scenario, extracted from an Excel file, where each
-                row corresponds to an interaction.
-
-        Args:
-            question: the original question posed by user
-            evidence: the text that contain relevant information to answer the question
-                (determined by retrieval pipeline)
-            evidence_mode: the mode of evidence, 0 for text, 1 for table, 2 for chatbot
-        """
-        if evidence_mode == EVIDENCE_MODE_TEXT:
-            prompt_template = PromptTemplate(self.qa_template)
-        elif evidence_mode == EVIDENCE_MODE_TABLE:
-            prompt_template = PromptTemplate(self.qa_table_template)
-        elif evidence_mode == EVIDENCE_MODE_FIGURE:
-            prompt_template = PromptTemplate(self.qa_figure_template)
-        else:
-            prompt_template = PromptTemplate(self.qa_chatbot_template)
-
-        images = []
-        if evidence_mode == EVIDENCE_MODE_FIGURE:
-            # isolate image from evidence
-            evidence, images = self.extract_evidence_images(evidence)
-            prompt = prompt_template.populate(
-                context=evidence,
-                question=question,
-                lang=self.lang,
-            )
-        else:
-            prompt = prompt_template.populate(
-                context=evidence,
-                question=question,
-                lang=self.lang,
-            )
+        prompt, images = self.get_prompt(question, evidence, evidence_mode)

         output = ""
         if evidence_mode == EVIDENCE_MODE_FIGURE:
@@ -425,37 +419,112 @@ class FullQAPipeline(BaseReasoning):
     rewrite_pipeline: RewriteQuestionPipeline = RewriteQuestionPipeline.withx()
     use_rewrite: bool = False

-    async def ainvoke(  # type: ignore
-        self, message: str, conv_id: str, history: list, **kwargs  # type: ignore
-    ) -> Document:  # type: ignore
-        import markdown
-
-        docs = []
-        doc_ids = []
-        if self.use_rewrite:
-            rewrite = await self.rewrite_pipeline(question=message)
-            message = rewrite.text
-
+    def retrieve(self, message: str) -> tuple[list[RetrievedDocument], list[Document]]:
+        """Retrieve the documents based on the message"""
+        docs, doc_ids = [], []
         for retriever in self.retrievers:
             for doc in retriever(text=message):
                 if doc.doc_id not in doc_ids:
                     docs.append(doc)
                     doc_ids.append(doc.doc_id)
+
+        info = []
         for doc in docs:
-            # TODO: a better approach to show the information
-            text = markdown.markdown(
-                doc.text, extensions=["markdown.extensions.tables"]
-            )
-            self.report_output(
-                {
-                    "evidence": (
-                        "<details open>"
-                        f"<summary>{doc.metadata['file_name']}</summary>"
-                        f"{text}"
-                        "</details><br>"
-                    )
-                }
-            )
+            info.append(
+                Document(
+                    channel="info",
+                    content=Render.collapsible(
+                        header=doc.metadata["file_name"],
+                        content=Render.table(doc.text),
+                        open=True,
+                    ),
+                )
+            )
+
+        return docs, info
+
+    def prepare_citations(self, answer, docs) -> tuple[list[Document], list[Document]]:
+        """Prepare the citations to show on the UI"""
+        with_citation, without_citation = [], []
+        spans = defaultdict(list)
+
+        if answer.metadata["citation"] is not None:
+            for fact_with_evidence in answer.metadata["citation"].answer:
+                for quote in fact_with_evidence.substring_quote:
+                    for doc in docs:
+                        start_idx = doc.text.find(quote)
+                        if start_idx == -1:
+                            continue
+
+                        end_idx = start_idx + len(quote)
+                        current_idx = start_idx
+                        if "|" not in doc.text[start_idx:end_idx]:
+                            spans[doc.doc_id].append(
+                                {"start": start_idx, "end": end_idx}
+                            )
+                        else:
+                            while doc.text[current_idx:end_idx].find("|") != -1:
+                                match_idx = doc.text[current_idx:end_idx].find("|")
+                                spans[doc.doc_id].append(
+                                    {
+                                        "start": current_idx,
+                                        "end": current_idx + match_idx,
+                                    }
+                                )
+                                current_idx += match_idx + 2
+                                if current_idx > end_idx:
+                                    break
+                        break
+
+        id2docs = {doc.doc_id: doc for doc in docs}
+        not_detected = set(id2docs.keys()) - set(spans.keys())
+        for id, ss in spans.items():
+            if not ss:
+                not_detected.add(id)
+                continue
+            ss = sorted(ss, key=lambda x: x["start"])
+            text = id2docs[id].text[: ss[0]["start"]]
+            for idx, span in enumerate(ss):
+                text += Render.highlight(id2docs[id].text[span["start"] : span["end"]])
+                if idx < len(ss) - 1:
+                    text += id2docs[id].text[span["end"] : ss[idx + 1]["start"]]
+            text += id2docs[id].text[ss[-1]["end"] :]
+            with_citation.append(
+                Document(
+                    channel="info",
+                    content=Render.collapsible(
+                        header=id2docs[id].metadata["file_name"],
+                        content=Render.table(text),
+                        open=True,
+                    ),
+                )
+            )
+
+        without_citation = [
+            Document(
+                channel="info",
+                content=Render.collapsible(
+                    header=id2docs[id].metadata["file_name"],
+                    content=Render.table(id2docs[id].text),
+                    open=False,
+                ),
+            )
+            for id in list(not_detected)
+        ]
+
+        return with_citation, without_citation
+
+    async def ainvoke(  # type: ignore
+        self, message: str, conv_id: str, history: list, **kwargs  # type: ignore
+    ) -> Document:  # type: ignore
+        if self.use_rewrite:
+            rewrite = await self.rewrite_pipeline(question=message)
+            message = rewrite.text
+
+        docs, infos = self.retrieve(message)
+        for _ in infos:
+            self.report_output(_)
             await asyncio.sleep(0.1)

         evidence_mode, evidence = self.evidence_pipeline(docs).content
@@ -468,90 +537,23 @@ class FullQAPipeline(BaseReasoning):
             **kwargs,
         )

-        # prepare citation
-        spans = defaultdict(list)
-        if answer.metadata["citation"] is not None:
-            for fact_with_evidence in answer.metadata["citation"].answer:
-                for quote in fact_with_evidence.substring_quote:
-                    for doc in docs:
-                        start_idx = doc.text.find(quote)
-                        if start_idx == -1:
-                            continue
-
-                        end_idx = start_idx + len(quote)
-                        current_idx = start_idx
-                        if "|" not in doc.text[start_idx:end_idx]:
-                            spans[doc.doc_id].append(
-                                {"start": start_idx, "end": end_idx}
-                            )
-                        else:
-                            while doc.text[current_idx:end_idx].find("|") != -1:
-                                match_idx = doc.text[current_idx:end_idx].find("|")
-                                spans[doc.doc_id].append(
-                                    {
-                                        "start": current_idx,
-                                        "end": current_idx + match_idx,
-                                    }
-                                )
-                                current_idx += match_idx + 2
-                                if current_idx > end_idx:
-                                    break
-                        break
-
-        id2docs = {doc.doc_id: doc for doc in docs}
-        lack_evidence = True
-        not_detected = set(id2docs.keys()) - set(spans.keys())
-        self.report_output({"evidence": None})
-        for id, ss in spans.items():
-            if not ss:
-                not_detected.add(id)
-                continue
-            ss = sorted(ss, key=lambda x: x["start"])
-            text = id2docs[id].text[: ss[0]["start"]]
-            for idx, span in enumerate(ss):
-                text += (
-                    "<mark>" + id2docs[id].text[span["start"] : span["end"]] + "</mark>"
-                )
-                if idx < len(ss) - 1:
-                    text += id2docs[id].text[span["end"] : ss[idx + 1]["start"]]
-            text += id2docs[id].text[ss[-1]["end"] :]
-            text_out = markdown.markdown(
-                text, extensions=["markdown.extensions.tables"]
-            )
-            self.report_output(
-                {
-                    "evidence": (
-                        "<details open>"
-                        f"<summary>{id2docs[id].metadata['file_name']}</summary>"
-                        f"{text_out}"
-                        "</details><br>"
-                    )
-                }
-            )
-            lack_evidence = False
-
-        if lack_evidence:
-            self.report_output({"evidence": "No evidence found.\n"})
-
-        if not_detected:
-            self.report_output(
-                {"evidence": "Retrieved segments without matching evidence:\n"}
-            )
-        for id in list(not_detected):
-            text_out = markdown.markdown(
-                id2docs[id].text, extensions=["markdown.extensions.tables"]
-            )
-            self.report_output(
-                {
-                    "evidence": (
-                        "<details>"
-                        f"<summary>{id2docs[id].metadata['file_name']}</summary>"
-                        f"{text_out}"
-                        "</details><br>"
-                    )
-                }
-            )
+        # show the evidence
+        with_citation, without_citation = self.prepare_citations(answer, docs)
+        if not with_citation and not without_citation:
+            self.report_output(Document(channel="info", content="No evidence found.\n"))
+        else:
+            self.report_output(Document(channel="info", content=None))
+            for _ in with_citation:
+                self.report_output(_)
+            if without_citation:
+                self.report_output(
+                    Document(
+                        channel="info",
+                        content="Retrieved segments without matching evidence:\n",
+                    )
+                )
+            for _ in without_citation:
+                self.report_output(_)

         self.report_output(None)
         return answer
@@ -559,32 +561,12 @@ class FullQAPipeline(BaseReasoning):
     def stream(  # type: ignore
         self, message: str, conv_id: str, history: list, **kwargs  # type: ignore
     ) -> Generator[Document, None, Document]:
-        import markdown
-
-        docs = []
-        doc_ids = []
         if self.use_rewrite:
             message = self.rewrite_pipeline(question=message).text
-        for retriever in self.retrievers:
-            for doc in retriever(text=message):
-                if doc.doc_id not in doc_ids:
-                    docs.append(doc)
-                    doc_ids.append(doc.doc_id)
-        for doc in docs:
-            # TODO: a better approach to show the information
-            text = markdown.markdown(
-                doc.text, extensions=["markdown.extensions.tables"]
-            )
-            yield Document(
-                content=(
-                    "<details open>"
-                    f"<summary>{doc.metadata['file_name']}</summary>"
-                    f"{text}"
-                    "</details><br>"
-                ),
-                channel="info",
-            )
+
+        docs, infos = self.retrieve(message)
+        for _ in infos:
+            yield _

         evidence_mode, evidence = self.evidence_pipeline(docs).content
         answer = yield from self.answering_pipeline.stream(
@@ -596,89 +578,21 @@ class FullQAPipeline(BaseReasoning):
             **kwargs,
         )

-        # prepare citation
-        spans = defaultdict(list)
-        if answer.metadata["citation"] is not None:
-            for fact_with_evidence in answer.metadata["citation"].answer:
-                for quote in fact_with_evidence.substring_quote:
-                    for doc in docs:
-                        start_idx = doc.text.find(quote)
-                        if start_idx == -1:
-                            continue
-
-                        end_idx = start_idx + len(quote)
-                        current_idx = start_idx
-                        if "|" not in doc.text[start_idx:end_idx]:
-                            spans[doc.doc_id].append(
-                                {"start": start_idx, "end": end_idx}
-                            )
-                        else:
-                            while doc.text[current_idx:end_idx].find("|") != -1:
-                                match_idx = doc.text[current_idx:end_idx].find("|")
-                                spans[doc.doc_id].append(
-                                    {
-                                        "start": current_idx,
-                                        "end": current_idx + match_idx,
-                                    }
-                                )
-                                current_idx += match_idx + 2
-                                if current_idx > end_idx:
-                                    break
-                        break
-
-        id2docs = {doc.doc_id: doc for doc in docs}
-        lack_evidence = True
-        not_detected = set(id2docs.keys()) - set(spans.keys())
-        yield Document(channel="info", content=None)
-        for id, ss in spans.items():
-            if not ss:
-                not_detected.add(id)
-                continue
-            ss = sorted(ss, key=lambda x: x["start"])
-            text = id2docs[id].text[: ss[0]["start"]]
-            for idx, span in enumerate(ss):
-                text += (
-                    "<mark>" + id2docs[id].text[span["start"] : span["end"]] + "</mark>"
-                )
-                if idx < len(ss) - 1:
-                    text += id2docs[id].text[span["end"] : ss[idx + 1]["start"]]
-            text += id2docs[id].text[ss[-1]["end"] :]
-            text_out = markdown.markdown(
-                text, extensions=["markdown.extensions.tables"]
-            )
-            yield Document(
-                content=(
-                    "<details open>"
-                    f"<summary>{id2docs[id].metadata['file_name']}</summary>"
-                    f"{text_out}"
-                    "</details><br>"
-                ),
-                channel="info",
-            )
-            lack_evidence = False
-
-        if lack_evidence:
-            yield Document(channel="info", content="No evidence found.\n")
-
-        if not_detected:
-            yield Document(
-                channel="info",
-                content="Retrieved segments without matching evidence:\n",
-            )
-        for id in list(not_detected):
-            text_out = markdown.markdown(
-                id2docs[id].text, extensions=["markdown.extensions.tables"]
-            )
-            yield Document(
-                content=(
-                    "<details>"
-                    f"<summary>{id2docs[id].metadata['file_name']}</summary>"
-                    f"{text_out}"
-                    "</details><br>"
-                ),
-                channel="info",
-            )
+        # show the evidence
+        with_citation, without_citation = self.prepare_citations(answer, docs)
+        if not with_citation and not without_citation:
+            yield Document(channel="info", content="No evidence found.\n")
+        else:
+            yield Document(channel="info", content=None)
+            for _ in with_citation:
+                yield _
+            if without_citation:
+                yield Document(
+                    channel="info",
+                    content="Retrieved segments without matching evidence:\n",
+                )
+            for _ in without_citation:
+                yield _

         return answer
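The run method added in this commit delegates to the new invoke, giving the answering stage a blocking, non-streaming entry point next to the existing async ainvoke and generator-based stream. A minimal driver sketch follows; the bare construction and the enable_citation keyword are assumptions for illustration, since in practice the pipeline is assembled from the reasoning config with its LLM and citation settings.

# Hypothetical synchronous driver; sample question and evidence are made up.
pipeline = AnswerWithContextPipeline(enable_citation=False)  # assumed constructor usage

answer = pipeline.run(
    question="What does the report conclude?",
    evidence="The report concludes that latency dropped by 40%.",
    evidence_mode=0,  # EVIDENCE_MODE_TEXT
)
print(answer.text)                   # the generated answer
print(answer.metadata["citation"])   # None when citation is disabled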


@@ -0,0 +1,21 @@
+import markdown
+
+
+class Render:
+    """Default text rendering into HTML for the UI"""
+
+    @staticmethod
+    def collapsible(header, content, open: bool = False) -> str:
+        """Render an HTML friendly collapsible section"""
+        o = " open" if open else ""
+        return f"<details{o}><summary>{header}</summary>{content}</details><br>"
+
+    @staticmethod
+    def table(text: str) -> str:
+        """Render table from markdown format into HTML"""
+        return markdown.markdown(text, extensions=["markdown.extensions.tables"])
+
+    @staticmethod
+    def highlight(text: str) -> str:
+        """Highlight text"""
+        return f"<mark>{text}</mark>"