Refactor reasoning pipeline (#31)
* Move the text rendering out for reusability
* Refactor common operations in the reasoning pipeline
* Add run method
* Provide dedicated method for invoke
parent af38708b77
commit 0417610d3e
@@ -39,7 +39,7 @@ class BaseComponent(Function):
             if isinstance(node, BaseComponent):
                 node.set_output_queue(queue)

-    def report_output(self, output: Optional[dict]):
+    def report_output(self, output: Optional[Document]):
         if self._queue is not None:
             self._queue.put_nowait(output)
@@ -270,7 +270,7 @@ class ChatOpenAI(BaseChatOpenAI):

     def openai_response(self, client, **kwargs):
         """Get the openai response"""
-        params = {
+        params_ = {
             "model": self.model,
             "temperature": self.temperature,
             "max_tokens": self.max_tokens,
@@ -285,6 +285,7 @@ class ChatOpenAI(BaseChatOpenAI):
             "top_logprobs": self.top_logprobs,
             "top_p": self.top_p,
         }
+        params = {k: v for k, v in params_.items() if v is not None}
         params.update(kwargs)

         return client.chat.completions.create(**params)
@@ -5,7 +5,7 @@ from sqlalchemy.orm import Session
 from theflow.settings import settings as flowsettings
 from theflow.utils.modules import deserialize

-from kotaemon.base import BaseComponent
+from kotaemon.llms import ChatLLM

 from .db import LLMTable, engine

@@ -14,7 +14,7 @@ class LLMManager:
     """Represent a pool of models"""

     def __init__(self):
-        self._models: dict[str, BaseComponent] = {}
+        self._models: dict[str, ChatLLM] = {}
         self._info: dict[str, dict] = {}
         self._default: str = ""
         self._vendors: list[Type] = []
@@ -63,7 +63,7 @@ class LLMManager:

         self._vendors = [ChatOpenAI, AzureChatOpenAI, LlamaCppChat, EndpointChatLLM]

-    def __getitem__(self, key: str) -> BaseComponent:
+    def __getitem__(self, key: str) -> ChatLLM:
         """Get model by name"""
         return self._models[key]

@@ -71,9 +71,7 @@ class LLMManager:
         """Check if model exists"""
         return key in self._models

-    def get(
-        self, key: str, default: Optional[BaseComponent] = None
-    ) -> Optional[BaseComponent]:
+    def get(self, key: str, default: Optional[ChatLLM] = None) -> Optional[ChatLLM]:
         """Get model by name with default value"""
         return self._models.get(key, default)

@@ -119,18 +117,18 @@ class LLMManager:

         return self._default

-    def get_random(self) -> BaseComponent:
+    def get_random(self) -> ChatLLM:
         """Get random model"""
         return self._models[self.get_random_name()]

-    def get_default(self) -> BaseComponent:
+    def get_default(self) -> ChatLLM:
         """Get default model

         In case there is no default model, choose random model from pool. In
         case there are multiple default models, choose random from them.

         Returns:
-            BaseComponent: model
+            ChatLLM: model
         """
         return self._models[self.get_default_name()]

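The practical effect of the narrowed annotations is that call sites now get chat-model objects without casting. A minimal sketch under that assumption; the pool is assumed to be configured elsewhere and "gpt4" is only an example key, not a name from this commit:

    from ktem.llms.manager import llms

    # The manager now hands back ChatLLM instances instead of bare BaseComponent,
    # so chat-specific usage needs no extra casts or isinstance checks.
    llm = llms.get_default()           # -> ChatLLM
    fallback = llms.get("gpt4", None)  # -> Optional[ChatLLM]; "gpt4" is a made-up key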
@@ -8,6 +8,7 @@ from typing import Generator

 import tiktoken
 from ktem.llms.manager import llms
+from ktem.utils.render import Render

 from kotaemon.base import (
     BaseComponent,
@@ -20,7 +21,7 @@ from kotaemon.base import (
 from kotaemon.indices.qa.citation import CitationPipeline
 from kotaemon.indices.splitters import TokenSplitter
 from kotaemon.llms import ChatLLM, PromptTemplate
-from kotaemon.loaders.utils.gpt4v import stream_gpt4v
+from kotaemon.loaders.utils.gpt4v import generate_gpt4v, stream_gpt4v

 from .base import BaseReasoning

@@ -205,7 +206,68 @@ class AnswerWithContextPipeline(BaseComponent):
     system_prompt: str = ""
     lang: str = "English"  # support English and Japanese

-    async def run(  # type: ignore
+    def get_prompt(self, question, evidence, evidence_mode: int):
+        """Prepare the prompt and other information for LLM"""
+        images = []
+
+        if evidence_mode == EVIDENCE_MODE_TEXT:
+            prompt_template = PromptTemplate(self.qa_template)
+        elif evidence_mode == EVIDENCE_MODE_TABLE:
+            prompt_template = PromptTemplate(self.qa_table_template)
+        elif evidence_mode == EVIDENCE_MODE_FIGURE:
+            prompt_template = PromptTemplate(self.qa_figure_template)
+        else:
+            prompt_template = PromptTemplate(self.qa_chatbot_template)
+
+        if evidence_mode == EVIDENCE_MODE_FIGURE:
+            # isolate image from evidence
+            evidence, images = self.extract_evidence_images(evidence)
+            prompt = prompt_template.populate(
+                context=evidence,
+                question=question,
+                lang=self.lang,
+            )
+        else:
+            prompt = prompt_template.populate(
+                context=evidence,
+                question=question,
+                lang=self.lang,
+            )
+
+        return prompt, images
+
+    def run(
+        self, question: str, evidence: str, evidence_mode: int = 0, **kwargs
+    ) -> Document:
+        return self.invoke(question, evidence, evidence_mode, **kwargs)
+
+    def invoke(
+        self, question: str, evidence: str, evidence_mode: int = 0, **kwargs
+    ) -> Document:
+        prompt, images = self.get_prompt(question, evidence, evidence_mode)
+
+        output = ""
+        if evidence_mode == EVIDENCE_MODE_FIGURE:
+            output = generate_gpt4v(self.vlm_endpoint, images, prompt, max_tokens=768)
+        else:
+            messages = []
+            if self.system_prompt:
+                messages.append(SystemMessage(content=self.system_prompt))
+            messages.append(HumanMessage(content=prompt))
+            output = self.llm(messages).text
+
+        # retrieve the citation
+        citation = None
+        if evidence and self.enable_citation:
+            citation = self.citation_pipeline.invoke(
+                context=evidence, question=question
+            )
+
+        answer = Document(text=output, metadata={"citation": citation})
+
+        return answer
+
+    async def ainvoke(  # type: ignore
         self, question: str, evidence: str, evidence_mode: int = 0, **kwargs
     ) -> Document:
         """Answer the question based on the evidence
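A minimal call-site sketch for the new synchronous entry points. It is not part of the diff: the pipeline construction arguments are assumed (any of the fields shown above, such as llm, enable_citation, or vlm_endpoint, would normally be configured by the application), and the question and evidence strings are made-up sample data.

    # Assumed configuration; in the app the pipeline is wired up elsewhere.
    pipeline = AnswerWithContextPipeline(llm=llms.get_default())

    # run() delegates to invoke(): build the prompt via get_prompt(), call the
    # LLM once, then attach the optional citation to the returned Document.
    answer = pipeline.run(
        question="What is the warranty period?",
        evidence="The warranty period is 24 months from the date of purchase.",
        evidence_mode=EVIDENCE_MODE_TEXT,
    )
    print(answer.text)                  # generated answer
    print(answer.metadata["citation"])  # CitationPipeline result, or None

The streaming and async paths (stream() and ainvoke()) keep their incremental behaviour and now reuse the same get_prompt() helper.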
@@ -230,30 +292,7 @@ class AnswerWithContextPipeline(BaseComponent):
                 (determined by retrieval pipeline)
             evidence_mode: the mode of evidence, 0 for text, 1 for table, 2 for chatbot
         """
-        if evidence_mode == EVIDENCE_MODE_TEXT:
-            prompt_template = PromptTemplate(self.qa_template)
-        elif evidence_mode == EVIDENCE_MODE_TABLE:
-            prompt_template = PromptTemplate(self.qa_table_template)
-        elif evidence_mode == EVIDENCE_MODE_FIGURE:
-            prompt_template = PromptTemplate(self.qa_figure_template)
-        else:
-            prompt_template = PromptTemplate(self.qa_chatbot_template)
-
-        images = []
-        if evidence_mode == EVIDENCE_MODE_FIGURE:
-            # isolate image from evidence
-            evidence, images = self.extract_evidence_images(evidence)
-            prompt = prompt_template.populate(
-                context=evidence,
-                question=question,
-                lang=self.lang,
-            )
-        else:
-            prompt = prompt_template.populate(
-                context=evidence,
-                question=question,
-                lang=self.lang,
-            )
+        prompt, images = self.get_prompt(question, evidence, evidence_mode)

         citation_task = None
         if evidence and self.enable_citation:
@@ -266,7 +305,7 @@ class AnswerWithContextPipeline(BaseComponent):
         if evidence_mode == EVIDENCE_MODE_FIGURE:
             for text in stream_gpt4v(self.vlm_endpoint, images, prompt, max_tokens=768):
                 output += text
-                self.report_output({"output": text})
+                self.report_output(Document(channel="chat", content=text))
                 await asyncio.sleep(0)
         else:
             messages = []
@@ -279,12 +318,12 @@ class AnswerWithContextPipeline(BaseComponent):
                 print("Trying LLM streaming")
                 for text in self.llm.stream(messages):
                     output += text.text
-                    self.report_output({"output": text.text})
+                    self.report_output(Document(content=text.text, channel="chat"))
                     await asyncio.sleep(0)
             except NotImplementedError:
                 print("Streaming is not supported, falling back to normal processing")
                 output = self.llm(messages).text
-                self.report_output({"output": output})
+                self.report_output(Document(content=output, channel="chat"))

         # retrieve the citation
         print("Waiting for citation task")
@@ -300,52 +339,7 @@ class AnswerWithContextPipeline(BaseComponent):
     def stream(  # type: ignore
         self, question: str, evidence: str, evidence_mode: int = 0, **kwargs
     ) -> Generator[Document, None, Document]:
-        """Answer the question based on the evidence
-
-        In addition to the question and the evidence, this method also take into
-        account evidence_mode. The evidence_mode tells which kind of evidence is.
-        The kind of evidence affects:
-            1. How the evidence is represented.
-            2. The prompt to generate the answer.
-
-        By default, the evidence_mode is 0, which means the evidence is plain text with
-        no particular semantic representation. The evidence_mode can be:
-            1. "table": There will be HTML markup telling that there is a table
-                within the evidence.
-            2. "chatbot": There will be HTML markup telling that there is a chatbot.
-                This chatbot is a scenario, extracted from an Excel file, where each
-                row corresponds to an interaction.
-
-        Args:
-            question: the original question posed by user
-            evidence: the text that contain relevant information to answer the question
-                (determined by retrieval pipeline)
-            evidence_mode: the mode of evidence, 0 for text, 1 for table, 2 for chatbot
-        """
-        if evidence_mode == EVIDENCE_MODE_TEXT:
-            prompt_template = PromptTemplate(self.qa_template)
-        elif evidence_mode == EVIDENCE_MODE_TABLE:
-            prompt_template = PromptTemplate(self.qa_table_template)
-        elif evidence_mode == EVIDENCE_MODE_FIGURE:
-            prompt_template = PromptTemplate(self.qa_figure_template)
-        else:
-            prompt_template = PromptTemplate(self.qa_chatbot_template)
-
-        images = []
-        if evidence_mode == EVIDENCE_MODE_FIGURE:
-            # isolate image from evidence
-            evidence, images = self.extract_evidence_images(evidence)
-            prompt = prompt_template.populate(
-                context=evidence,
-                question=question,
-                lang=self.lang,
-            )
-        else:
-            prompt = prompt_template.populate(
-                context=evidence,
-                question=question,
-                lang=self.lang,
-            )
+        prompt, images = self.get_prompt(question, evidence, evidence_mode)

         output = ""
         if evidence_mode == EVIDENCE_MODE_FIGURE:
@@ -425,37 +419,112 @@ class FullQAPipeline(BaseReasoning):
     rewrite_pipeline: RewriteQuestionPipeline = RewriteQuestionPipeline.withx()
     use_rewrite: bool = False

-    async def ainvoke(  # type: ignore
-        self, message: str, conv_id: str, history: list, **kwargs  # type: ignore
-    ) -> Document:  # type: ignore
-        import markdown
-
-        docs = []
-        doc_ids = []
-        if self.use_rewrite:
-            rewrite = await self.rewrite_pipeline(question=message)
-            message = rewrite.text
-
+    def retrieve(self, message: str) -> tuple[list[RetrievedDocument], list[Document]]:
+        """Retrieve the documents based on the message"""
+        docs, doc_ids = [], []
         for retriever in self.retrievers:
             for doc in retriever(text=message):
                 if doc.doc_id not in doc_ids:
                     docs.append(doc)
                     doc_ids.append(doc.doc_id)

+        info = []
         for doc in docs:
-            # TODO: a better approach to show the information
-            text = markdown.markdown(
-                doc.text, extensions=["markdown.extensions.tables"]
-            )
-            self.report_output(
-                {
-                    "evidence": (
-                        "<details open>"
-                        f"<summary>{doc.metadata['file_name']}</summary>"
-                        f"{text}"
-                        "</details><br>"
-                    )
-                }
-            )
+            info.append(
+                Document(
+                    channel="info",
+                    content=Render.collapsible(
+                        header=doc.metadata["file_name"],
+                        content=Render.table(doc.text),
+                        open=True,
+                    ),
+                )
+            )
+
+        return docs, info
+
+    def prepare_citations(self, answer, docs) -> tuple[list[Document], list[Document]]:
+        """Prepare the citations to show on the UI"""
+        with_citation, without_citation = [], []
+        spans = defaultdict(list)
+
+        if answer.metadata["citation"] is not None:
+            for fact_with_evidence in answer.metadata["citation"].answer:
+                for quote in fact_with_evidence.substring_quote:
+                    for doc in docs:
+                        start_idx = doc.text.find(quote)
+                        if start_idx == -1:
+                            continue
+
+                        end_idx = start_idx + len(quote)
+
+                        current_idx = start_idx
+                        if "|" not in doc.text[start_idx:end_idx]:
+                            spans[doc.doc_id].append(
+                                {"start": start_idx, "end": end_idx}
+                            )
+                        else:
+                            while doc.text[current_idx:end_idx].find("|") != -1:
+                                match_idx = doc.text[current_idx:end_idx].find("|")
+                                spans[doc.doc_id].append(
+                                    {
+                                        "start": current_idx,
+                                        "end": current_idx + match_idx,
+                                    }
+                                )
+                                current_idx += match_idx + 2
+                                if current_idx > end_idx:
+                                    break
+                        break
+
+        id2docs = {doc.doc_id: doc for doc in docs}
+        not_detected = set(id2docs.keys()) - set(spans.keys())
+        for id, ss in spans.items():
+            if not ss:
+                not_detected.add(id)
+                continue
+            ss = sorted(ss, key=lambda x: x["start"])
+            text = id2docs[id].text[: ss[0]["start"]]
+            for idx, span in enumerate(ss):
+                text += Render.highlight(id2docs[id].text[span["start"] : span["end"]])
+                if idx < len(ss) - 1:
+                    text += id2docs[id].text[span["end"] : ss[idx + 1]["start"]]
+            text += id2docs[id].text[ss[-1]["end"] :]
+            with_citation.append(
+                Document(
+                    channel="info",
+                    content=Render.collapsible(
+                        header=id2docs[id].metadata["file_name"],
+                        content=Render.table(text),
+                        open=True,
+                    ),
+                )
+            )
+
+        without_citation = [
+            Document(
+                channel="info",
+                content=Render.collapsible(
+                    header=id2docs[id].metadata["file_name"],
+                    content=Render.table(id2docs[id].text),
+                    open=False,
+                ),
+            )
+            for id in list(not_detected)
+        ]
+
+        return with_citation, without_citation
+
+    async def ainvoke(  # type: ignore
+        self, message: str, conv_id: str, history: list, **kwargs  # type: ignore
+    ) -> Document:  # type: ignore
+        if self.use_rewrite:
+            rewrite = await self.rewrite_pipeline(question=message)
+            message = rewrite.text
+
+        docs, infos = self.retrieve(message)
+        for _ in infos:
+            self.report_output(_)
             await asyncio.sleep(0.1)

         evidence_mode, evidence = self.evidence_pipeline(docs).content
@@ -468,90 +537,23 @@ class FullQAPipeline(BaseReasoning):
             **kwargs,
         )

-        # prepare citation
-        spans = defaultdict(list)
-        if answer.metadata["citation"] is not None:
-            for fact_with_evidence in answer.metadata["citation"].answer:
-                for quote in fact_with_evidence.substring_quote:
-                    for doc in docs:
-                        start_idx = doc.text.find(quote)
-                        if start_idx == -1:
-                            continue
-
-                        end_idx = start_idx + len(quote)
-
-                        current_idx = start_idx
-                        if "|" not in doc.text[start_idx:end_idx]:
-                            spans[doc.doc_id].append(
-                                {"start": start_idx, "end": end_idx}
-                            )
-                        else:
-                            while doc.text[current_idx:end_idx].find("|") != -1:
-                                match_idx = doc.text[current_idx:end_idx].find("|")
-                                spans[doc.doc_id].append(
-                                    {
-                                        "start": current_idx,
-                                        "end": current_idx + match_idx,
-                                    }
-                                )
-                                current_idx += match_idx + 2
-                                if current_idx > end_idx:
-                                    break
-                        break
-
-        id2docs = {doc.doc_id: doc for doc in docs}
-        lack_evidence = True
-        not_detected = set(id2docs.keys()) - set(spans.keys())
-        self.report_output({"evidence": None})
-        for id, ss in spans.items():
-            if not ss:
-                not_detected.add(id)
-                continue
-            ss = sorted(ss, key=lambda x: x["start"])
-            text = id2docs[id].text[: ss[0]["start"]]
-            for idx, span in enumerate(ss):
-                text += (
-                    "<mark>" + id2docs[id].text[span["start"] : span["end"]] + "</mark>"
-                )
-                if idx < len(ss) - 1:
-                    text += id2docs[id].text[span["end"] : ss[idx + 1]["start"]]
-            text += id2docs[id].text[ss[-1]["end"] :]
-            text_out = markdown.markdown(
-                text, extensions=["markdown.extensions.tables"]
-            )
-            self.report_output(
-                {
-                    "evidence": (
-                        "<details open>"
-                        f"<summary>{id2docs[id].metadata['file_name']}</summary>"
-                        f"{text_out}"
-                        "</details><br>"
-                    )
-                }
-            )
-            lack_evidence = False
-
-        if lack_evidence:
-            self.report_output({"evidence": "No evidence found.\n"})
-
-        if not_detected:
-            self.report_output(
-                {"evidence": "Retrieved segments without matching evidence:\n"}
-            )
-            for id in list(not_detected):
-                text_out = markdown.markdown(
-                    id2docs[id].text, extensions=["markdown.extensions.tables"]
-                )
-                self.report_output(
-                    {
-                        "evidence": (
-                            "<details>"
-                            f"<summary>{id2docs[id].metadata['file_name']}</summary>"
-                            f"{text_out}"
-                            "</details><br>"
-                        )
-                    }
-                )
+        # show the evidence
+        with_citation, without_citation = self.prepare_citations(answer, docs)
+        if not with_citation and not without_citation:
+            self.report_output(Document(channel="info", content="No evidence found.\n"))
+        else:
+            self.report_output(Document(channel="info", content=None))
+            for _ in with_citation:
+                self.report_output(_)
+            if without_citation:
+                self.report_output(
+                    Document(
+                        channel="info",
+                        content="Retrieved segments without matching evidence:\n",
+                    )
+                )
+                for _ in without_citation:
+                    self.report_output(_)

         self.report_output(None)
         return answer
@@ -559,32 +561,12 @@ class FullQAPipeline(BaseReasoning):
     def stream(  # type: ignore
         self, message: str, conv_id: str, history: list, **kwargs  # type: ignore
     ) -> Generator[Document, None, Document]:
-        import markdown
-
-        docs = []
-        doc_ids = []
         if self.use_rewrite:
             message = self.rewrite_pipeline(question=message).text

-        for retriever in self.retrievers:
-            for doc in retriever(text=message):
-                if doc.doc_id not in doc_ids:
-                    docs.append(doc)
-                    doc_ids.append(doc.doc_id)
-        for doc in docs:
-            # TODO: a better approach to show the information
-            text = markdown.markdown(
-                doc.text, extensions=["markdown.extensions.tables"]
-            )
-            yield Document(
-                content=(
-                    "<details open>"
-                    f"<summary>{doc.metadata['file_name']}</summary>"
-                    f"{text}"
-                    "</details><br>"
-                ),
-                channel="info",
-            )
+        docs, infos = self.retrieve(message)
+        for _ in infos:
+            yield _

         evidence_mode, evidence = self.evidence_pipeline(docs).content
         answer = yield from self.answering_pipeline.stream(
@@ -596,89 +578,21 @@ class FullQAPipeline(BaseReasoning):
             **kwargs,
         )

-        # prepare citation
-        spans = defaultdict(list)
-        if answer.metadata["citation"] is not None:
-            for fact_with_evidence in answer.metadata["citation"].answer:
-                for quote in fact_with_evidence.substring_quote:
-                    for doc in docs:
-                        start_idx = doc.text.find(quote)
-                        if start_idx == -1:
-                            continue
-
-                        end_idx = start_idx + len(quote)
-
-                        current_idx = start_idx
-                        if "|" not in doc.text[start_idx:end_idx]:
-                            spans[doc.doc_id].append(
-                                {"start": start_idx, "end": end_idx}
-                            )
-                        else:
-                            while doc.text[current_idx:end_idx].find("|") != -1:
-                                match_idx = doc.text[current_idx:end_idx].find("|")
-                                spans[doc.doc_id].append(
-                                    {
-                                        "start": current_idx,
-                                        "end": current_idx + match_idx,
-                                    }
-                                )
-                                current_idx += match_idx + 2
-                                if current_idx > end_idx:
-                                    break
-                        break
-
-        id2docs = {doc.doc_id: doc for doc in docs}
-        lack_evidence = True
-        not_detected = set(id2docs.keys()) - set(spans.keys())
-        yield Document(channel="info", content=None)
-        for id, ss in spans.items():
-            if not ss:
-                not_detected.add(id)
-                continue
-            ss = sorted(ss, key=lambda x: x["start"])
-            text = id2docs[id].text[: ss[0]["start"]]
-            for idx, span in enumerate(ss):
-                text += (
-                    "<mark>" + id2docs[id].text[span["start"] : span["end"]] + "</mark>"
-                )
-                if idx < len(ss) - 1:
-                    text += id2docs[id].text[span["end"] : ss[idx + 1]["start"]]
-            text += id2docs[id].text[ss[-1]["end"] :]
-            text_out = markdown.markdown(
-                text, extensions=["markdown.extensions.tables"]
-            )
-            yield Document(
-                content=(
-                    "<details open>"
-                    f"<summary>{id2docs[id].metadata['file_name']}</summary>"
-                    f"{text_out}"
-                    "</details><br>"
-                ),
-                channel="info",
-            )
-            lack_evidence = False
-
-        if lack_evidence:
-            yield Document(channel="info", content="No evidence found.\n")
-
-        if not_detected:
-            yield Document(
-                channel="info",
-                content="Retrieved segments without matching evidence:\n",
-            )
-            for id in list(not_detected):
-                text_out = markdown.markdown(
-                    id2docs[id].text, extensions=["markdown.extensions.tables"]
-                )
-                yield Document(
-                    content=(
-                        "<details>"
-                        f"<summary>{id2docs[id].metadata['file_name']}</summary>"
-                        f"{text_out}"
-                        "</details><br>"
-                    ),
-                    channel="info",
-                )
+        # show the evidence
+        with_citation, without_citation = self.prepare_citations(answer, docs)
+        if not with_citation and not without_citation:
+            yield Document(channel="info", content="No evidence found.\n")
+        else:
+            yield Document(channel="info", content=None)
+            for _ in with_citation:
+                yield _
+            if without_citation:
+                yield Document(
+                    channel="info",
+                    content="Retrieved segments without matching evidence:\n",
+                )
+                for _ in without_citation:
+                    yield _

         return answer
libs/ktem/ktem/utils/__init__.py | 0 (new file)
libs/ktem/ktem/utils/render.py | 21 (new file)
libs/ktem/ktem/utils/render.py (new file):
@@ -0,0 +1,21 @@
+import markdown
+
+
+class Render:
+    """Default text rendering into HTML for the UI"""
+
+    @staticmethod
+    def collapsible(header, content, open: bool = False) -> str:
+        """Render an HTML friendly collapsible section"""
+        o = " open" if open else ""
+        return f"<details{o}><summary>{header}</summary>{content}</details><br>"
+
+    @staticmethod
+    def table(text: str) -> str:
+        """Render table from markdown format into HTML"""
+        return markdown.markdown(text, extensions=["markdown.extensions.tables"])
+
+    @staticmethod
+    def highlight(text: str) -> str:
+        """Highlight text"""
+        return f"<mark>{text}</mark>"
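For orientation, a small sketch of how these helpers compose, mirroring their use in retrieve() and prepare_citations() above; the header string and markdown table are made-up sample data:

    from ktem.utils.render import Render

    table_md = "| item | qty |\n|------|-----|\n| apples | 3 |"  # made-up sample
    html = Render.collapsible(
        header="groceries.xlsx",            # shown in the <summary> element
        content=Render.table(table_md),     # markdown table rendered to an HTML table
        open=True,                          # expand the <details> block by default
    )
    # Render.highlight("apples") wraps text in <mark> tags, which
    # prepare_citations() uses to mark quoted evidence before collapsing it.
    print(html)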