Refactor reasoning pipeline (#31)

* Move the text rendering out for reusability

* Refactor common operations in the reasoning pipeline

* Add run method

* Provide dedicated method for invoke
Duc Nguyen (john) 2024-04-13 23:13:04 +07:00 committed by GitHub
parent af38708b77
commit 0417610d3e
6 changed files with 227 additions and 293 deletions

View File

@@ -39,7 +39,7 @@ class BaseComponent(Function):
             if isinstance(node, BaseComponent):
                 node.set_output_queue(queue)

-    def report_output(self, output: Optional[dict]):
+    def report_output(self, output: Optional[Document]):
         if self._queue is not None:
             self._queue.put_nowait(output)

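Note: `report_output` now emits typed `Document` payloads instead of loose dicts. A minimal standalone sketch of the queue pattern this enables (the `Document` below is a stand-in dataclass for illustration, not the kotaemon class):

import queue
from dataclasses import dataclass
from typing import Optional


@dataclass
class Document:
    """Stand-in for kotaemon.base.Document, just for illustration."""

    content: Optional[str] = None
    channel: Optional[str] = None  # e.g. "chat" for answer tokens, "info" for evidence


class Component:
    def __init__(self) -> None:
        self._queue: Optional[queue.Queue] = None

    def set_output_queue(self, q: queue.Queue) -> None:
        self._queue = q

    def report_output(self, output: Optional[Document]) -> None:
        # None doubles as an end-of-stream sentinel for the consumer
        if self._queue is not None:
            self._queue.put_nowait(output)


q: queue.Queue = queue.Queue()
c = Component()
c.set_output_queue(q)
c.report_output(Document(channel="chat", content="partial answer"))
c.report_output(None)  # signal completion, mirroring report_output(None) in the pipeline below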
View File

@@ -270,7 +270,7 @@ class ChatOpenAI(BaseChatOpenAI):
     def openai_response(self, client, **kwargs):
         """Get the openai response"""
-        params = {
+        params_ = {
             "model": self.model,
             "temperature": self.temperature,
             "max_tokens": self.max_tokens,
@@ -285,6 +285,7 @@ class ChatOpenAI(BaseChatOpenAI):
             "top_logprobs": self.top_logprobs,
             "top_p": self.top_p,
         }
+        params = {k: v for k, v in params_.items() if v is not None}
         params.update(kwargs)
         return client.chat.completions.create(**params)

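The inserted comprehension strips unset (None) fields before the request is sent, presumably so the API applies its own defaults rather than receiving explicit nulls. A standalone sketch of the pattern, with made-up values:

# params_ holds every configurable field; unset ones default to None
params_ = {
    "model": "gpt-3.5-turbo",
    "temperature": 0.0,  # falsy but not None, so it is kept
    "max_tokens": None,
    "top_p": None,
}
params = {k: v for k, v in params_.items() if v is not None}
params.update({"stream": True})  # explicit caller kwargs still win
assert params == {"model": "gpt-3.5-turbo", "temperature": 0.0, "stream": True}

Filtering on `is not None` rather than truthiness matters here: values like `temperature=0.0` must survive the filter.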
View File

@@ -5,7 +5,7 @@ from sqlalchemy.orm import Session
 from theflow.settings import settings as flowsettings
 from theflow.utils.modules import deserialize

-from kotaemon.base import BaseComponent
+from kotaemon.llms import ChatLLM

 from .db import LLMTable, engine
@@ -14,7 +14,7 @@ class LLMManager:
     """Represent a pool of models"""

     def __init__(self):
-        self._models: dict[str, BaseComponent] = {}
+        self._models: dict[str, ChatLLM] = {}
         self._info: dict[str, dict] = {}
         self._default: str = ""
         self._vendors: list[Type] = []
@@ -63,7 +63,7 @@ class LLMManager:
         self._vendors = [ChatOpenAI, AzureChatOpenAI, LlamaCppChat, EndpointChatLLM]

-    def __getitem__(self, key: str) -> BaseComponent:
+    def __getitem__(self, key: str) -> ChatLLM:
         """Get model by name"""
         return self._models[key]
@@ -71,9 +71,7 @@ class LLMManager:
         """Check if model exists"""
         return key in self._models

-    def get(
-        self, key: str, default: Optional[BaseComponent] = None
-    ) -> Optional[BaseComponent]:
+    def get(self, key: str, default: Optional[ChatLLM] = None) -> Optional[ChatLLM]:
         """Get model by name with default value"""
         return self._models.get(key, default)
@@ -119,18 +117,18 @@ class LLMManager:
         return self._default

-    def get_random(self) -> BaseComponent:
+    def get_random(self) -> ChatLLM:
         """Get random model"""
         return self._models[self.get_random_name()]

-    def get_default(self) -> BaseComponent:
+    def get_default(self) -> ChatLLM:
         """Get default model

         In case there is no default model, choose random model from pool. In
         case there are multiple default models, choose random from them.

         Returns:
-            BaseComponent: model
+            ChatLLM: model
         """
         return self._models[self.get_default_name()]

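Narrowing the annotations from `BaseComponent` to `ChatLLM` is mostly a static-typing win: code that pulls a model from the manager can call chat-specific members without casting. A small sketch of the idea, using a stand-in protocol rather than the real kotaemon classes:

from typing import Optional, Protocol


class ChatLLM(Protocol):
    """Stand-in protocol; the real ChatLLM lives in kotaemon.llms."""

    def stream(self, messages: list): ...


class LLMManager:
    def __init__(self) -> None:
        self._models: dict[str, ChatLLM] = {}

    def get(self, key: str, default: Optional[ChatLLM] = None) -> Optional[ChatLLM]:
        return self._models.get(key, default)


# A type checker now accepts chat-only calls without a cast:
#     llm = manager.get("gpt4")
#     if llm is not None:
#         llm.stream([...])  # fine: ChatLLM, not a bare BaseComponent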
View File

@@ -8,6 +8,7 @@ from typing import Generator

 import tiktoken
 from ktem.llms.manager import llms
+from ktem.utils.render import Render

 from kotaemon.base import (
     BaseComponent,
@@ -20,7 +21,7 @@ from kotaemon.base import (
 from kotaemon.indices.qa.citation import CitationPipeline
 from kotaemon.indices.splitters import TokenSplitter
 from kotaemon.llms import ChatLLM, PromptTemplate
-from kotaemon.loaders.utils.gpt4v import stream_gpt4v
+from kotaemon.loaders.utils.gpt4v import generate_gpt4v, stream_gpt4v

 from .base import BaseReasoning
@@ -205,7 +206,68 @@ class AnswerWithContextPipeline(BaseComponent):
     system_prompt: str = ""
     lang: str = "English"  # support English and Japanese

-    async def run(  # type: ignore
+    def get_prompt(self, question, evidence, evidence_mode: int):
+        """Prepare the prompt and other information for LLM"""
+        images = []
+
+        if evidence_mode == EVIDENCE_MODE_TEXT:
+            prompt_template = PromptTemplate(self.qa_template)
+        elif evidence_mode == EVIDENCE_MODE_TABLE:
+            prompt_template = PromptTemplate(self.qa_table_template)
+        elif evidence_mode == EVIDENCE_MODE_FIGURE:
+            prompt_template = PromptTemplate(self.qa_figure_template)
+        else:
+            prompt_template = PromptTemplate(self.qa_chatbot_template)
+
+        if evidence_mode == EVIDENCE_MODE_FIGURE:
+            # isolate image from evidence
+            evidence, images = self.extract_evidence_images(evidence)
+            prompt = prompt_template.populate(
+                context=evidence,
+                question=question,
+                lang=self.lang,
+            )
+        else:
+            prompt = prompt_template.populate(
+                context=evidence,
+                question=question,
+                lang=self.lang,
+            )
+
+        return prompt, images
+
+    def run(
+        self, question: str, evidence: str, evidence_mode: int = 0, **kwargs
+    ) -> Document:
+        return self.invoke(question, evidence, evidence_mode, **kwargs)
+
+    def invoke(
+        self, question: str, evidence: str, evidence_mode: int = 0, **kwargs
+    ) -> Document:
+        prompt, images = self.get_prompt(question, evidence, evidence_mode)
+
+        output = ""
+        if evidence_mode == EVIDENCE_MODE_FIGURE:
+            output = generate_gpt4v(self.vlm_endpoint, images, prompt, max_tokens=768)
+        else:
+            messages = []
+            if self.system_prompt:
+                messages.append(SystemMessage(content=self.system_prompt))
+            messages.append(HumanMessage(content=prompt))
+            output = self.llm(messages).text
+
+        # retrieve the citation
+        citation = None
+        if evidence and self.enable_citation:
+            citation = self.citation_pipeline.invoke(
+                context=evidence, question=question
+            )
+
+        answer = Document(text=output, metadata={"citation": citation})
+
+        return answer
+
+    async def ainvoke(  # type: ignore
         self, question: str, evidence: str, evidence_mode: int = 0, **kwargs
     ) -> Document:
         """Answer the question based on the evidence
@@ -230,30 +292,7 @@ class AnswerWithContextPipeline(BaseComponent):
                 (determined by retrieval pipeline)
             evidence_mode: the mode of evidence, 0 for text, 1 for table, 2 for chatbot
         """
-        if evidence_mode == EVIDENCE_MODE_TEXT:
-            prompt_template = PromptTemplate(self.qa_template)
-        elif evidence_mode == EVIDENCE_MODE_TABLE:
-            prompt_template = PromptTemplate(self.qa_table_template)
-        elif evidence_mode == EVIDENCE_MODE_FIGURE:
-            prompt_template = PromptTemplate(self.qa_figure_template)
-        else:
-            prompt_template = PromptTemplate(self.qa_chatbot_template)
-
-        images = []
-        if evidence_mode == EVIDENCE_MODE_FIGURE:
-            # isolate image from evidence
-            evidence, images = self.extract_evidence_images(evidence)
-            prompt = prompt_template.populate(
-                context=evidence,
-                question=question,
-                lang=self.lang,
-            )
-        else:
-            prompt = prompt_template.populate(
-                context=evidence,
-                question=question,
-                lang=self.lang,
-            )
+        prompt, images = self.get_prompt(question, evidence, evidence_mode)

         citation_task = None
         if evidence and self.enable_citation:
@@ -266,7 +305,7 @@ class AnswerWithContextPipeline(BaseComponent):
         if evidence_mode == EVIDENCE_MODE_FIGURE:
             for text in stream_gpt4v(self.vlm_endpoint, images, prompt, max_tokens=768):
                 output += text
-                self.report_output({"output": text})
+                self.report_output(Document(channel="chat", content=text))
                 await asyncio.sleep(0)
         else:
             messages = []
@@ -279,12 +318,12 @@ class AnswerWithContextPipeline(BaseComponent):
                 print("Trying LLM streaming")
                 for text in self.llm.stream(messages):
                     output += text.text
-                    self.report_output({"output": text.text})
+                    self.report_output(Document(content=text.text, channel="chat"))
                     await asyncio.sleep(0)
             except NotImplementedError:
                 print("Streaming is not supported, falling back to normal processing")
                 output = self.llm(messages).text
-                self.report_output({"output": output})
+                self.report_output(Document(content=output, channel="chat"))

         # retrieve the citation
         print("Waiting for citation task")
@@ -300,52 +339,7 @@ class AnswerWithContextPipeline(BaseComponent):
     def stream(  # type: ignore
         self, question: str, evidence: str, evidence_mode: int = 0, **kwargs
     ) -> Generator[Document, None, Document]:
-        """Answer the question based on the evidence
-
-        In addition to the question and the evidence, this method also take into
-        account evidence_mode. The evidence_mode tells which kind of evidence is.
-        The kind of evidence affects:
-            1. How the evidence is represented.
-            2. The prompt to generate the answer.
-
-        By default, the evidence_mode is 0, which means the evidence is plain text with
-        no particular semantic representation. The evidence_mode can be:
-            1. "table": There will be HTML markup telling that there is a table
-                within the evidence.
-            2. "chatbot": There will be HTML markup telling that there is a chatbot.
-                This chatbot is a scenario, extracted from an Excel file, where each
-                row corresponds to an interaction.
-
-        Args:
-            question: the original question posed by user
-            evidence: the text that contain relevant information to answer the question
-                (determined by retrieval pipeline)
-            evidence_mode: the mode of evidence, 0 for text, 1 for table, 2 for chatbot
-        """
-        if evidence_mode == EVIDENCE_MODE_TEXT:
-            prompt_template = PromptTemplate(self.qa_template)
-        elif evidence_mode == EVIDENCE_MODE_TABLE:
-            prompt_template = PromptTemplate(self.qa_table_template)
-        elif evidence_mode == EVIDENCE_MODE_FIGURE:
-            prompt_template = PromptTemplate(self.qa_figure_template)
-        else:
-            prompt_template = PromptTemplate(self.qa_chatbot_template)
-
-        images = []
-        if evidence_mode == EVIDENCE_MODE_FIGURE:
-            # isolate image from evidence
-            evidence, images = self.extract_evidence_images(evidence)
-            prompt = prompt_template.populate(
-                context=evidence,
-                question=question,
-                lang=self.lang,
-            )
-        else:
-            prompt = prompt_template.populate(
-                context=evidence,
-                question=question,
-                lang=self.lang,
-            )
+        prompt, images = self.get_prompt(question, evidence, evidence_mode)

         output = ""
         if evidence_mode == EVIDENCE_MODE_FIGURE:
@@ -425,37 +419,112 @@ class FullQAPipeline(BaseReasoning):
     rewrite_pipeline: RewriteQuestionPipeline = RewriteQuestionPipeline.withx()
     use_rewrite: bool = False

-    async def ainvoke(  # type: ignore
-        self, message: str, conv_id: str, history: list, **kwargs  # type: ignore
-    ) -> Document:  # type: ignore
-        import markdown
-
-        docs = []
-        doc_ids = []
-        if self.use_rewrite:
-            rewrite = await self.rewrite_pipeline(question=message)
-            message = rewrite.text
-
+    def retrieve(self, message: str) -> tuple[list[RetrievedDocument], list[Document]]:
+        """Retrieve the documents based on the message"""
+        docs, doc_ids = [], []
         for retriever in self.retrievers:
             for doc in retriever(text=message):
                 if doc.doc_id not in doc_ids:
                     docs.append(doc)
                     doc_ids.append(doc.doc_id)
+
+        info = []
         for doc in docs:
             # TODO: a better approach to show the information
-            text = markdown.markdown(
-                doc.text, extensions=["markdown.extensions.tables"]
-            )
-            self.report_output(
-                {
-                    "evidence": (
-                        "<details open>"
-                        f"<summary>{doc.metadata['file_name']}</summary>"
-                        f"{text}"
-                        "</details><br>"
-                    )
-                }
-            )
+            info.append(
+                Document(
+                    channel="info",
+                    content=Render.collapsible(
+                        header=doc.metadata["file_name"],
+                        content=Render.table(doc.text),
+                        open=True,
+                    ),
+                )
+            )
+
+        return docs, info
+
+    def prepare_citations(self, answer, docs) -> tuple[list[Document], list[Document]]:
+        """Prepare the citations to show on the UI"""
+        with_citation, without_citation = [], []
+        spans = defaultdict(list)
+        if answer.metadata["citation"] is not None:
+            for fact_with_evidence in answer.metadata["citation"].answer:
+                for quote in fact_with_evidence.substring_quote:
+                    for doc in docs:
+                        start_idx = doc.text.find(quote)
+                        if start_idx == -1:
+                            continue
+
+                        end_idx = start_idx + len(quote)
+                        current_idx = start_idx
+                        if "|" not in doc.text[start_idx:end_idx]:
+                            spans[doc.doc_id].append(
+                                {"start": start_idx, "end": end_idx}
+                            )
+                        else:
+                            while doc.text[current_idx:end_idx].find("|") != -1:
+                                match_idx = doc.text[current_idx:end_idx].find("|")
+                                spans[doc.doc_id].append(
+                                    {
+                                        "start": current_idx,
+                                        "end": current_idx + match_idx,
+                                    }
+                                )
+                                current_idx += match_idx + 2
+                                if current_idx > end_idx:
+                                    break
+                        break
+
+        id2docs = {doc.doc_id: doc for doc in docs}
+        not_detected = set(id2docs.keys()) - set(spans.keys())
+        for id, ss in spans.items():
+            if not ss:
+                not_detected.add(id)
+                continue
+            ss = sorted(ss, key=lambda x: x["start"])
+            text = id2docs[id].text[: ss[0]["start"]]
+            for idx, span in enumerate(ss):
+                text += Render.highlight(id2docs[id].text[span["start"] : span["end"]])
+                if idx < len(ss) - 1:
+                    text += id2docs[id].text[span["end"] : ss[idx + 1]["start"]]
+            text += id2docs[id].text[ss[-1]["end"] :]
+            with_citation.append(
+                Document(
+                    channel="info",
+                    content=Render.collapsible(
+                        header=id2docs[id].metadata["file_name"],
+                        content=Render.table(text),
+                        open=True,
+                    ),
+                )
+            )
+
+        without_citation = [
+            Document(
+                channel="info",
+                content=Render.collapsible(
+                    header=id2docs[id].metadata["file_name"],
+                    content=Render.table(id2docs[id].text),
+                    open=False,
+                ),
+            )
+            for id in list(not_detected)
+        ]
+
+        return with_citation, without_citation
+
+    async def ainvoke(  # type: ignore
+        self, message: str, conv_id: str, history: list, **kwargs  # type: ignore
+    ) -> Document:  # type: ignore
+        if self.use_rewrite:
+            rewrite = await self.rewrite_pipeline(question=message)
+            message = rewrite.text
+
+        docs, infos = self.retrieve(message)
+        for _ in infos:
+            self.report_output(_)
+        await asyncio.sleep(0.1)

         evidence_mode, evidence = self.evidence_pipeline(docs).content
@@ -468,90 +537,23 @@ class FullQAPipeline(BaseReasoning):
             **kwargs,
         )

-        # prepare citation
-        spans = defaultdict(list)
-        if answer.metadata["citation"] is not None:
-            for fact_with_evidence in answer.metadata["citation"].answer:
-                for quote in fact_with_evidence.substring_quote:
-                    for doc in docs:
-                        start_idx = doc.text.find(quote)
-                        if start_idx == -1:
-                            continue
-
-                        end_idx = start_idx + len(quote)
-                        current_idx = start_idx
-                        if "|" not in doc.text[start_idx:end_idx]:
-                            spans[doc.doc_id].append(
-                                {"start": start_idx, "end": end_idx}
-                            )
-                        else:
-                            while doc.text[current_idx:end_idx].find("|") != -1:
-                                match_idx = doc.text[current_idx:end_idx].find("|")
-                                spans[doc.doc_id].append(
-                                    {
-                                        "start": current_idx,
-                                        "end": current_idx + match_idx,
-                                    }
-                                )
-                                current_idx += match_idx + 2
-                                if current_idx > end_idx:
-                                    break
-                        break
-
-        id2docs = {doc.doc_id: doc for doc in docs}
-        lack_evidence = True
-        not_detected = set(id2docs.keys()) - set(spans.keys())
-        self.report_output({"evidence": None})
-        for id, ss in spans.items():
-            if not ss:
-                not_detected.add(id)
-                continue
-            ss = sorted(ss, key=lambda x: x["start"])
-            text = id2docs[id].text[: ss[0]["start"]]
-            for idx, span in enumerate(ss):
-                text += (
-                    "<mark>" + id2docs[id].text[span["start"] : span["end"]] + "</mark>"
-                )
-                if idx < len(ss) - 1:
-                    text += id2docs[id].text[span["end"] : ss[idx + 1]["start"]]
-            text += id2docs[id].text[ss[-1]["end"] :]
-            text_out = markdown.markdown(
-                text, extensions=["markdown.extensions.tables"]
-            )
-            self.report_output(
-                {
-                    "evidence": (
-                        "<details open>"
-                        f"<summary>{id2docs[id].metadata['file_name']}</summary>"
-                        f"{text_out}"
-                        "</details><br>"
-                    )
-                }
-            )
-            lack_evidence = False
-
-        if lack_evidence:
-            self.report_output({"evidence": "No evidence found.\n"})
-
-        if not_detected:
-            self.report_output(
-                {"evidence": "Retrieved segments without matching evidence:\n"}
-            )
-            for id in list(not_detected):
-                text_out = markdown.markdown(
-                    id2docs[id].text, extensions=["markdown.extensions.tables"]
-                )
-                self.report_output(
-                    {
-                        "evidence": (
-                            "<details>"
-                            f"<summary>{id2docs[id].metadata['file_name']}</summary>"
-                            f"{text_out}"
-                            "</details><br>"
-                        )
-                    }
-                )
+        # show the evidence
+        with_citation, without_citation = self.prepare_citations(answer, docs)
+        if not with_citation and not without_citation:
+            self.report_output(Document(channel="info", content="No evidence found.\n"))
+        else:
+            self.report_output(Document(channel="info", content=None))
+            for _ in with_citation:
+                self.report_output(_)
+            if without_citation:
+                self.report_output(
+                    Document(
+                        channel="info",
+                        content="Retrieved segments without matching evidence:\n",
+                    )
+                )
+                for _ in without_citation:
+                    self.report_output(_)

         self.report_output(None)
         return answer
@@ -559,32 +561,12 @@ class FullQAPipeline(BaseReasoning):
     def stream(  # type: ignore
         self, message: str, conv_id: str, history: list, **kwargs  # type: ignore
     ) -> Generator[Document, None, Document]:
-        import markdown
-
-        docs = []
-        doc_ids = []
         if self.use_rewrite:
             message = self.rewrite_pipeline(question=message).text
-        for retriever in self.retrievers:
-            for doc in retriever(text=message):
-                if doc.doc_id not in doc_ids:
-                    docs.append(doc)
-                    doc_ids.append(doc.doc_id)
-        for doc in docs:
-            # TODO: a better approach to show the information
-            text = markdown.markdown(
-                doc.text, extensions=["markdown.extensions.tables"]
-            )
-            yield Document(
-                content=(
-                    "<details open>"
-                    f"<summary>{doc.metadata['file_name']}</summary>"
-                    f"{text}"
-                    "</details><br>"
-                ),
-                channel="info",
-            )
+
+        docs, infos = self.retrieve(message)
+        for _ in infos:
+            yield _

         evidence_mode, evidence = self.evidence_pipeline(docs).content
         answer = yield from self.answering_pipeline.stream(
@@ -596,89 +578,21 @@ class FullQAPipeline(BaseReasoning):
             **kwargs,
         )

-        # prepare citation
-        spans = defaultdict(list)
-        if answer.metadata["citation"] is not None:
-            for fact_with_evidence in answer.metadata["citation"].answer:
-                for quote in fact_with_evidence.substring_quote:
-                    for doc in docs:
-                        start_idx = doc.text.find(quote)
-                        if start_idx == -1:
-                            continue
-
-                        end_idx = start_idx + len(quote)
-                        current_idx = start_idx
-                        if "|" not in doc.text[start_idx:end_idx]:
-                            spans[doc.doc_id].append(
-                                {"start": start_idx, "end": end_idx}
-                            )
-                        else:
-                            while doc.text[current_idx:end_idx].find("|") != -1:
-                                match_idx = doc.text[current_idx:end_idx].find("|")
-                                spans[doc.doc_id].append(
-                                    {
-                                        "start": current_idx,
-                                        "end": current_idx + match_idx,
-                                    }
-                                )
-                                current_idx += match_idx + 2
-                                if current_idx > end_idx:
-                                    break
-                        break
-
-        id2docs = {doc.doc_id: doc for doc in docs}
-        lack_evidence = True
-        not_detected = set(id2docs.keys()) - set(spans.keys())
-        yield Document(channel="info", content=None)
-        for id, ss in spans.items():
-            if not ss:
-                not_detected.add(id)
-                continue
-            ss = sorted(ss, key=lambda x: x["start"])
-            text = id2docs[id].text[: ss[0]["start"]]
-            for idx, span in enumerate(ss):
-                text += (
-                    "<mark>" + id2docs[id].text[span["start"] : span["end"]] + "</mark>"
-                )
-                if idx < len(ss) - 1:
-                    text += id2docs[id].text[span["end"] : ss[idx + 1]["start"]]
-            text += id2docs[id].text[ss[-1]["end"] :]
-            text_out = markdown.markdown(
-                text, extensions=["markdown.extensions.tables"]
-            )
-            yield Document(
-                content=(
-                    "<details open>"
-                    f"<summary>{id2docs[id].metadata['file_name']}</summary>"
-                    f"{text_out}"
-                    "</details><br>"
-                ),
-                channel="info",
-            )
-            lack_evidence = False
-
-        if lack_evidence:
-            yield Document(channel="info", content="No evidence found.\n")
-
-        if not_detected:
-            yield Document(
-                channel="info",
-                content="Retrieved segments without matching evidence:\n",
-            )
-            for id in list(not_detected):
-                text_out = markdown.markdown(
-                    id2docs[id].text, extensions=["markdown.extensions.tables"]
-                )
-                yield Document(
-                    content=(
-                        "<details>"
-                        f"<summary>{id2docs[id].metadata['file_name']}</summary>"
-                        f"{text_out}"
-                        "</details><br>"
-                    ),
-                    channel="info",
-                )
+        # show the evidence
+        with_citation, without_citation = self.prepare_citations(answer, docs)
+        if not with_citation and not without_citation:
+            yield Document(channel="info", content="No evidence found.\n")
+        else:
+            yield Document(channel="info", content=None)
+            for _ in with_citation:
+                yield _
+            if without_citation:
+                yield Document(
+                    channel="info",
+                    content="Retrieved segments without matching evidence:\n",
+                )
+                for _ in without_citation:
+                    yield _

         return answer

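The subtlest piece of the new `prepare_citations` is the span matching: each cited quote is located with `str.find`, and a quote that crosses a markdown table cell boundary is split at every `|` so the later highlighting never swallows the separators the table renderer needs. A standalone sketch of that rule on plain strings (it mirrors the diff's behavior, including leaving the text after the last `|` of a matched quote unhighlighted):

from collections import defaultdict


def quote_spans(doc_id: str, text: str, quote: str) -> dict:
    """Locate `quote` in `text`; split the span at table-cell separators."""
    spans: dict = defaultdict(list)
    start_idx = text.find(quote)
    if start_idx == -1:
        return spans  # quote not found in this document
    end_idx = start_idx + len(quote)
    current_idx = start_idx
    if "|" not in text[start_idx:end_idx]:
        spans[doc_id].append({"start": start_idx, "end": end_idx})
    else:
        while text[current_idx:end_idx].find("|") != -1:
            match_idx = text[current_idx:end_idx].find("|")
            spans[doc_id].append({"start": current_idx, "end": current_idx + match_idx})
            current_idx += match_idx + 2  # step past "| "
            if current_idx > end_idx:
                break
    return spans


print(quote_spans("doc-1", "| cell a | cell b |", "cell a | cell b"))
# defaultdict(<class 'list'>, {'doc-1': [{'start': 2, 'end': 9}]})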
View File

View File

@@ -0,0 +1,21 @@
+import markdown
+
+
+class Render:
+    """Default text rendering into HTML for the UI"""
+
+    @staticmethod
+    def collapsible(header, content, open: bool = False) -> str:
+        """Render an HTML friendly collapsible section"""
+        o = " open" if open else ""
+        return f"<details{o}><summary>{header}</summary>{content}</details><br>"
+
+    @staticmethod
+    def table(text: str) -> str:
+        """Render table from markdown format into HTML"""
+        return markdown.markdown(text, extensions=["markdown.extensions.tables"])
+
+    @staticmethod
+    def highlight(text: str) -> str:
+        """Highlight text"""
+        return f"<mark>{text}</mark>"
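For reference, a quick usage sketch of the helpers above (the table HTML emitted by `markdown` is abbreviated in the comment):

html = Render.collapsible(
    header="report.pdf",
    content=Render.table("| a | b |\n|---|---|\n| 1 | 2 |"),
    open=True,
)
# -> '<details open><summary>report.pdf</summary><table>...</table></details><br>'

snippet = Render.highlight("relevant passage")
# -> '<mark>relevant passage</mark>'

This is exactly the combination `retrieve` and `prepare_citations` use: a markdown table rendered to HTML, wrapped in a collapsible section per source file.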