feat: add inline citation style (#523) bump:minor
* feat: add URL quick index, export mindmap, refine UI & animation
* fix: inject multimodal mode from env var
* fix: minor update css
* feat: add citation inline mode
* fix: minor update citation inline pipeline
* feat: add citation quick setting
* fix: minor update
* fix: minor update
committed by GitHub
parent 013f6f4103
commit 7e34e4343b
@@ -1,7 +1,5 @@
from .citation import CitationPipeline
from .text_based import CitationQAPipeline

__all__ = [
    "CitationPipeline",
    "CitationQAPipeline",
]
libs/kotaemon/kotaemon/indices/qa/citation_qa.py (new file, 390 lines)
@@ -0,0 +1,390 @@
import threading
from collections import defaultdict
from typing import Generator

import numpy as np
from theflow.settings import settings as flowsettings

from kotaemon.base import (
    AIMessage,
    BaseComponent,
    Document,
    HumanMessage,
    Node,
    SystemMessage,
)
from kotaemon.llms import ChatLLM, PromptTemplate

from .citation import CitationPipeline
from .format_context import (
    EVIDENCE_MODE_FIGURE,
    EVIDENCE_MODE_TABLE,
    EVIDENCE_MODE_TEXT,
)
from .utils import find_text

try:
    from ktem.llms.manager import llms
    from ktem.reasoning.prompt_optimization.mindmap import CreateMindmapPipeline
    from ktem.utils.render import Render
except ImportError:
    raise ImportError("Please install `ktem` to use this component")

MAX_IMAGES = 10
CITATION_TIMEOUT = 5.0
CONTEXT_RELEVANT_WARNING_SCORE = 0.7

DEFAULT_QA_TEXT_PROMPT = (
    "Use the following pieces of context to answer the question at the end in detail with clear explanation. "  # noqa: E501
    "If you don't know the answer, just say that you don't know, don't try to "
    "make up an answer. Give answer in "
    "{lang}.\n\n"
    "{context}\n"
    "Question: {question}\n"
    "Helpful Answer:"
)

DEFAULT_QA_TABLE_PROMPT = (
    "Use the given context: texts, tables, and figures below to answer the question, "
    "then provide answer with clear explanation."
    "If you don't know the answer, just say that you don't know, "
    "don't try to make up an answer. Give answer in {lang}.\n\n"
    "Context:\n"
    "{context}\n"
    "Question: {question}\n"
    "Helpful Answer:"
)  # noqa

DEFAULT_QA_CHATBOT_PROMPT = (
    "Pick the most suitable chatbot scenarios to answer the question at the end, "
    "output the provided answer text. If you don't know the answer, "
    "just say that you don't know. Keep the answer as concise as possible. "
    "Give answer in {lang}.\n\n"
    "Context:\n"
    "{context}\n"
    "Question: {question}\n"
    "Answer:"
)  # noqa

DEFAULT_QA_FIGURE_PROMPT = (
    "Use the given context: texts, tables, and figures below to answer the question. "
    "If you don't know the answer, just say that you don't know. "
    "Give answer in {lang}.\n\n"
    "Context: \n"
    "{context}\n"
    "Question: {question}\n"
    "Answer: "
)  # noqa


class AnswerWithContextPipeline(BaseComponent):
    """Answer the question based on the evidence

    Args:
        llm: the language model to generate the answer
        citation_pipeline: generates citation from the evidence
        qa_template: the prompt template for LLM to generate answer (refer to
            evidence_mode)
        qa_table_template: the prompt template for LLM to generate answer for table
            (refer to evidence_mode)
        qa_chatbot_template: the prompt template for LLM to generate answer for
            pre-made scenarios (refer to evidence_mode)
        lang: the language of the answer. Currently supports English and Japanese
    """

    llm: ChatLLM = Node(default_callback=lambda _: llms.get_default())
    vlm_endpoint: str = getattr(flowsettings, "KH_VLM_ENDPOINT", "")
    use_multimodal: bool = getattr(flowsettings, "KH_REASONINGS_USE_MULTIMODAL", True)
    citation_pipeline: CitationPipeline = Node(
        default_callback=lambda _: CitationPipeline(llm=llms.get_default())
    )
    create_mindmap_pipeline: CreateMindmapPipeline = Node(
        default_callback=lambda _: CreateMindmapPipeline(llm=llms.get_default())
    )

    qa_template: str = DEFAULT_QA_TEXT_PROMPT
    qa_table_template: str = DEFAULT_QA_TABLE_PROMPT
    qa_chatbot_template: str = DEFAULT_QA_CHATBOT_PROMPT
    qa_figure_template: str = DEFAULT_QA_FIGURE_PROMPT

    enable_citation: bool = False
    enable_mindmap: bool = False
    enable_citation_viz: bool = False

    system_prompt: str = ""
    lang: str = "English"  # support English and Japanese
    n_last_interactions: int = 5

    def get_prompt(self, question, evidence, evidence_mode: int):
        """Prepare the prompt and other information for LLM"""
        if evidence_mode == EVIDENCE_MODE_TEXT:
            prompt_template = PromptTemplate(self.qa_template)
        elif evidence_mode == EVIDENCE_MODE_TABLE:
            prompt_template = PromptTemplate(self.qa_table_template)
        elif evidence_mode == EVIDENCE_MODE_FIGURE:
            if self.use_multimodal:
                prompt_template = PromptTemplate(self.qa_figure_template)
            else:
                prompt_template = PromptTemplate(self.qa_template)
        else:
            prompt_template = PromptTemplate(self.qa_chatbot_template)

        prompt = prompt_template.populate(
            context=evidence,
            question=question,
            lang=self.lang,
        )

        return prompt, evidence

    def run(
        self, question: str, evidence: str, evidence_mode: int = 0, **kwargs
    ) -> Document:
        return self.invoke(question, evidence, evidence_mode, **kwargs)

    def invoke(
        self,
        question: str,
        evidence: str,
        evidence_mode: int = 0,
        images: list[str] = [],
        **kwargs,
    ) -> Document:
        raise NotImplementedError

    async def ainvoke(  # type: ignore
        self,
        question: str,
        evidence: str,
        evidence_mode: int = 0,
        images: list[str] = [],
        **kwargs,
    ) -> Document:
        """Answer the question based on the evidence

        In addition to the question and the evidence, this method also takes into
        account evidence_mode. The evidence_mode tells which kind of evidence it is.
        The kind of evidence affects:
            1. How the evidence is represented.
            2. The prompt to generate the answer.

        By default, the evidence_mode is 0, which means the evidence is plain text with
        no particular semantic representation. The evidence_mode can be:
            1. "table": There will be HTML markup telling that there is a table
                within the evidence.
            2. "chatbot": There will be HTML markup telling that there is a chatbot.
                This chatbot is a scenario, extracted from an Excel file, where each
                row corresponds to an interaction.

        Args:
            question: the original question posed by user
            evidence: the text that contains relevant information to answer the question
                (determined by retrieval pipeline)
            evidence_mode: the mode of evidence, 0 for text, 1 for table, 2 for chatbot
        """
        raise NotImplementedError

    def stream(  # type: ignore
        self,
        question: str,
        evidence: str,
        evidence_mode: int = 0,
        images: list[str] = [],
        **kwargs,
    ) -> Generator[Document, None, Document]:
        history = kwargs.get("history", [])
        print(f"Got {len(images)} images")
        # check if evidence exists, use QA prompt
        if evidence:
            prompt, evidence = self.get_prompt(question, evidence, evidence_mode)
        else:
            prompt = question

        # retrieve the citation
        citation = None
        mindmap = None

        def citation_call():
            nonlocal citation
            citation = self.citation_pipeline(context=evidence, question=question)

        def mindmap_call():
            nonlocal mindmap
            mindmap = self.create_mindmap_pipeline(context=evidence, question=question)

        citation_thread = None
        mindmap_thread = None

        # execute function call in thread
        if evidence:
            if self.enable_citation:
                citation_thread = threading.Thread(target=citation_call)
                citation_thread.start()

            if self.enable_mindmap:
                mindmap_thread = threading.Thread(target=mindmap_call)
                mindmap_thread.start()

        output = ""
        logprobs = []

        messages = []
        if self.system_prompt:
            messages.append(SystemMessage(content=self.system_prompt))

        for human, ai in history[-self.n_last_interactions :]:
            messages.append(HumanMessage(content=human))
            messages.append(AIMessage(content=ai))

        if self.use_multimodal and evidence_mode == EVIDENCE_MODE_FIGURE:
            # create image message:
            messages.append(
                HumanMessage(
                    content=[
                        {"type": "text", "text": prompt},
                    ]
                    + [
                        {
                            "type": "image_url",
                            "image_url": {"url": image},
                        }
                        for image in images[:MAX_IMAGES]
                    ],
                )
            )
        else:
            # append main prompt
            messages.append(HumanMessage(content=prompt))

        try:
            # try streaming first
            print("Trying LLM streaming")
            for out_msg in self.llm.stream(messages):
                output += out_msg.text
                logprobs += out_msg.logprobs
                yield Document(channel="chat", content=out_msg.text)
        except NotImplementedError:
            print("Streaming is not supported, falling back to normal processing")
            output = self.llm(messages).text
            yield Document(channel="chat", content=output)

        if logprobs:
            qa_score = np.exp(np.average(logprobs))
        else:
            qa_score = None

        if citation_thread:
            citation_thread.join(timeout=CITATION_TIMEOUT)
        if mindmap_thread:
            mindmap_thread.join(timeout=CITATION_TIMEOUT)

        answer = Document(
            text=output,
            metadata={
                "citation_viz": self.enable_citation_viz,
                "mindmap": mindmap,
                "citation": citation,
                "qa_score": qa_score,
            },
        )

        return answer

    def match_evidence_with_context(self, answer, docs) -> dict[str, list[dict]]:
        """Match the evidence with the context"""
        spans: dict[str, list[dict]] = defaultdict(list)

        if not answer.metadata["citation"]:
            return spans

        evidences = answer.metadata["citation"].evidences
        for quote in evidences:
            matched_excerpts = []
            for doc in docs:
                matches = find_text(quote, doc.text)

                for start, end in matches:
                    if "|" not in doc.text[start:end]:
                        spans[doc.doc_id].append(
                            {
                                "start": start,
                                "end": end,
                            }
                        )
                        matched_excerpts.append(doc.text[start:end])

        # print("Matched citation:", quote, matched_excerpts),
        return spans

    def prepare_citations(self, answer, docs) -> tuple[list[Document], list[Document]]:
        """Prepare the citations to show on the UI"""
        with_citation, without_citation = [], []
        has_llm_score = any("llm_trulens_score" in doc.metadata for doc in docs)

        spans = self.match_evidence_with_context(answer, docs)
        id2docs = {doc.doc_id: doc for doc in docs}
        not_detected = set(id2docs.keys()) - set(spans.keys())

        # render highlight spans
        for _id, ss in spans.items():
            if not ss:
                not_detected.add(_id)
                continue
            cur_doc = id2docs[_id]
            highlight_text = ""

            ss = sorted(ss, key=lambda x: x["start"])
            text = cur_doc.text[: ss[0]["start"]]
            for idx, span in enumerate(ss):
                to_highlight = cur_doc.text[span["start"] : span["end"]]
                if len(to_highlight) > len(highlight_text):
                    highlight_text = to_highlight

                span_idx = span.get("idx", None)
                if span_idx is not None:
                    to_highlight = f"【{span_idx + 1}】" + to_highlight

                text += Render.highlight(
                    to_highlight,
                    elem_id=str(span_idx + 1) if span_idx is not None else None,
                )
                print(text)
                if idx < len(ss) - 1:
                    text += cur_doc.text[span["end"] : ss[idx + 1]["start"]]
            text += cur_doc.text[ss[-1]["end"] :]
            # add to display list
            with_citation.append(
                Document(
                    channel="info",
                    content=Render.collapsible_with_header_score(
                        cur_doc,
                        override_text=text,
                        highlight_text=highlight_text,
                        open_collapsible=True,
                    ),
                )
            )

        print("Got {} cited docs".format(len(with_citation)))

        sorted_not_detected_items_with_scores = [
            (id_, id2docs[id_].metadata.get("llm_trulens_score", 0.0))
            for id_ in not_detected
        ]
        sorted_not_detected_items_with_scores.sort(key=lambda x: x[1], reverse=True)

        for id_, _ in sorted_not_detected_items_with_scores:
            doc = id2docs[id_]
            doc_score = doc.metadata.get("llm_trulens_score", 0.0)
            is_open = not has_llm_score or (
                doc_score > CONTEXT_RELEVANT_WARNING_SCORE and len(with_citation) == 0
            )
            without_citation.append(
                Document(
                    channel="info",
                    content=Render.collapsible_with_header_score(
                        doc, open_collapsible=is_open
                    ),
                )
            )
        return with_citation, without_citation
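For orientation, a minimal sketch of how this pipeline might be driven; it assumes ktem is installed, llms.get_default() resolves to a configured chat model, and the question/evidence strings are placeholders invented for the example.

# Sketch only: assumes a configured default LLM behind llms.get_default().
from kotaemon.indices.qa.citation_qa import AnswerWithContextPipeline

pipeline = AnswerWithContextPipeline(enable_citation=True, lang="English")

# stream() yields Document(channel="chat", ...) chunks as the LLM produces them.
for chunk in pipeline.stream(
    question="What is fixed-size chunking?",
    evidence="<br><b>Content from chunking.pdf (Page 2): </b> Fixed-size chunking ... \n<br>",
    evidence_mode=0,
):
    print(chunk.content, end="")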
libs/kotaemon/kotaemon/indices/qa/citation_qa_inline.py (new file, 267 lines)
@@ -0,0 +1,267 @@
import re
import threading
from collections import defaultdict
from typing import Generator

import numpy as np

from kotaemon.base import AIMessage, Document, HumanMessage, SystemMessage
from kotaemon.llms import PromptTemplate

from .citation import CiteEvidence
from .citation_qa import CITATION_TIMEOUT, MAX_IMAGES, AnswerWithContextPipeline
from .format_context import EVIDENCE_MODE_FIGURE
from .utils import find_start_end_phrase

DEFAULT_QA_CITATION_PROMPT = """
Use the following pieces of context to answer the question at the end.
Provide DETAILED answer with clear explanation.
Format answer with easy to follow bullets / paragraphs.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Use the same language as the question to respond.

CONTEXT:
----
{context}
----

Answer using this format:
CITATION LIST

// the index in this array
CITATION【number】

// output 2 phrase to mark start and end of the relevant span
// each has ~ 6 words
// MUST COPY EXACTLY from the CONTEXT
// NO CHANGE or REPHRASE
// RELEVANT_SPAN_FROM_CONTEXT
START_PHRASE: string
END_PHRASE: string

// When you answer, ensure to add citations from the documents
// in the CONTEXT with a number that corresponds to the answersInText array.
// (in the form [number])
// Try to include the number after each facts / statements you make.
// You can create as many citations as you need.
FINAL ANSWER
string

STRICTLY FOLLOW THIS EXAMPLE:
CITATION LIST

CITATION【1】

START_PHRASE: Known as fixed-size chunking , the traditional
END_PHRASE: not degrade the final retrieval performance.

CITATION【2】

START_PHRASE: Fixed-size Chunker This is our baseline chunker
END_PHRASE: this shows good retrieval quality.

FINAL ANSWER
An alternative to semantic chunking is fixed-size chunking. This traditional method involves splitting documents into chunks of a predetermined or user-specified size, regardless of semantic content, which is computationally efficient【1】. However, it may result in the fragmentation of semantically related content, thereby potentially degrading retrieval performance【2】.

QUESTION: {question}\n
ANSWER:
"""  # noqa


class AnswerWithInlineCitation(AnswerWithContextPipeline):
    """Answer the question based on the evidence with inline citation"""

    qa_citation_template: str = DEFAULT_QA_CITATION_PROMPT

    def get_prompt(self, question, evidence, evidence_mode: int):
        """Prepare the prompt and other information for LLM"""
        prompt_template = PromptTemplate(self.qa_citation_template)

        prompt = prompt_template.populate(
            context=evidence,
            question=question,
            safe=False,
        )

        return prompt, evidence

    def answer_to_citations(self, answer):
        evidences = []
        lines = answer.split("\n")
        for line in lines:
            for keyword in ["START_PHRASE:", "END_PHRASE:"]:
                if line.startswith(keyword):
                    evidences.append(line[len(keyword) :].strip())

        return CiteEvidence(evidences=evidences)

    def replace_citation_with_link(self, answer: str):
        # Define the regex pattern to match 【number】
        pattern = r"【\d+】"
        matches = re.finditer(pattern, answer)

        matched_citations = set()
        for match in matches:
            citation = match.group()
            matched_citations.add(citation)

        for citation in matched_citations:
            print("Found citation:", citation)
            answer = answer.replace(
                citation,
                (
                    "<a href='#' class='citation' "
                    f"id='mark-{citation[1:-1]}'>{citation}</a>"
                ),
            )

        print("Replaced answer:", answer)
        return answer

    def stream(  # type: ignore
        self,
        question: str,
        evidence: str,
        evidence_mode: int = 0,
        images: list[str] = [],
        **kwargs,
    ) -> Generator[Document, None, Document]:
        history = kwargs.get("history", [])
        print(f"Got {len(images)} images")
        # check if evidence exists, use QA prompt
        if evidence:
            prompt, evidence = self.get_prompt(question, evidence, evidence_mode)
        else:
            prompt = question

        output = ""
        logprobs = []

        citation = None
        mindmap = None

        def mindmap_call():
            nonlocal mindmap
            mindmap = self.create_mindmap_pipeline(context=evidence, question=question)

        mindmap_thread = None

        # execute function call in thread
        if evidence:
            if self.enable_mindmap:
                mindmap_thread = threading.Thread(target=mindmap_call)
                mindmap_thread.start()

        messages = []
        if self.system_prompt:
            messages.append(SystemMessage(content=self.system_prompt))

        for human, ai in history[-self.n_last_interactions :]:
            messages.append(HumanMessage(content=human))
            messages.append(AIMessage(content=ai))

        if self.use_multimodal and evidence_mode == EVIDENCE_MODE_FIGURE:
            # create image message:
            messages.append(
                HumanMessage(
                    content=[
                        {"type": "text", "text": prompt},
                    ]
                    + [
                        {
                            "type": "image_url",
                            "image_url": {"url": image},
                        }
                        for image in images[:MAX_IMAGES]
                    ],
                )
            )
        else:
            # append main prompt
            messages.append(HumanMessage(content=prompt))

        START_ANSWER = "FINAL ANSWER"
        start_of_answer = True
        final_answer = ""

        try:
            # try streaming first
            print("Trying LLM streaming")
            for out_msg in self.llm.stream(messages):
                if START_ANSWER in output:
                    final_answer += (
                        out_msg.text.lstrip() if start_of_answer else out_msg.text
                    )
                    start_of_answer = False
                    yield Document(channel="chat", content=out_msg.text)

                output += out_msg.text
                logprobs += out_msg.logprobs
        except NotImplementedError:
            print("Streaming is not supported, falling back to normal processing")
            output = self.llm(messages).text
            yield Document(channel="chat", content=output)

        if logprobs:
            qa_score = np.exp(np.average(logprobs))
        else:
            qa_score = None

        citation = self.answer_to_citations(output)

        if mindmap_thread:
            mindmap_thread.join(timeout=CITATION_TIMEOUT)

        # convert citation to link
        answer = Document(
            text=final_answer,
            metadata={
                "citation_viz": self.enable_citation_viz,
                "mindmap": mindmap,
                "citation": citation,
                "qa_score": qa_score,
            },
        )

        # yield the final answer
        final_answer = self.replace_citation_with_link(final_answer)
        yield Document(channel="chat", content=None)
        yield Document(channel="chat", content=final_answer)

        return answer

    def match_evidence_with_context(self, answer, docs) -> dict[str, list[dict]]:
        """Match the evidence with the context"""
        spans: dict[str, list[dict]] = defaultdict(list)

        if not answer.metadata["citation"]:
            return spans

        evidences = answer.metadata["citation"].evidences

        for start_idx in range(0, len(evidences), 2):
            start_phrase, end_phrase = evidences[start_idx : start_idx + 2]
            best_match = None
            best_match_length = 0
            best_match_doc_idx = None

            for doc in docs:
                match, match_length = find_start_end_phrase(
                    start_phrase, end_phrase, doc.text
                )
                if best_match is None or (
                    match is not None and match_length > best_match_length
                ):
                    best_match = match
                    best_match_length = match_length
                    best_match_doc_idx = doc.doc_id

            if best_match is not None and best_match_doc_idx is not None:
                spans[best_match_doc_idx].append(
                    {
                        "start": best_match[0],
                        "end": best_match[1],
                        "idx": start_idx // 2,  # implicitly set from the start_idx
                    }
                )
        return spans
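As a standalone illustration of the post-processing above, the snippet below re-states the parsing and link-rewriting logic of answer_to_citations and replace_citation_with_link on a hypothetical raw LLM output that follows the CITATION LIST / FINAL ANSWER format; the sample text is invented for the example.

import re

raw_output = (
    "CITATION LIST\n"
    "CITATION【1】\n"
    "START_PHRASE: Known as fixed-size chunking , the traditional\n"
    "END_PHRASE: not degrade the final retrieval performance.\n"
    "FINAL ANSWER\n"
    "Fixed-size chunking splits documents into chunks of a preset size【1】."
)

# Collect START/END phrase pairs, mirroring answer_to_citations.
evidences = []
for line in raw_output.split("\n"):
    for keyword in ["START_PHRASE:", "END_PHRASE:"]:
        if line.startswith(keyword):
            evidences.append(line[len(keyword):].strip())

# Rewrite 【n】 markers into anchor tags, mirroring replace_citation_with_link.
final_answer = raw_output.split("FINAL ANSWER\n", 1)[-1]
for marker in set(re.findall(r"【\d+】", final_answer)):
    final_answer = final_answer.replace(
        marker, f"<a href='#' class='citation' id='mark-{marker[1:-1]}'>{marker}</a>"
    )

print(evidences)
print(final_answer)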
libs/kotaemon/kotaemon/indices/qa/format_context.py (new file, 114 lines)
@@ -0,0 +1,114 @@
import html
from functools import partial

import tiktoken

from kotaemon.base import BaseComponent, Document, RetrievedDocument
from kotaemon.indices.splitters import TokenSplitter

EVIDENCE_MODE_TEXT = 0
EVIDENCE_MODE_TABLE = 1
EVIDENCE_MODE_CHATBOT = 2
EVIDENCE_MODE_FIGURE = 3


class PrepareEvidencePipeline(BaseComponent):
    """Prepare the evidence text from the list of retrieved documents

    This step usually happens after `DocumentRetrievalPipeline`.

    Args:
        trim_func: a callback function or a BaseComponent, that splits a large
            chunk of text into smaller ones. The first one will be retained.
    """

    max_context_length: int = 32000
    trim_func: TokenSplitter | None = None

    def run(self, docs: list[RetrievedDocument]) -> Document:
        evidence = ""
        images = []
        table_found = 0
        evidence_modes = []

        evidence_trim_func = (
            self.trim_func
            if self.trim_func
            else TokenSplitter(
                chunk_size=self.max_context_length,
                chunk_overlap=0,
                separator=" ",
                tokenizer=partial(
                    tiktoken.encoding_for_model("gpt-3.5-turbo").encode,
                    allowed_special=set(),
                    disallowed_special="all",
                ),
            )
        )

        for _, retrieved_item in enumerate(docs):
            retrieved_content = ""
            page = retrieved_item.metadata.get("page_label", None)
            source = filename = retrieved_item.metadata.get("file_name", "-")
            if page:
                source += f" (Page {page})"
            if retrieved_item.metadata.get("type", "") == "table":
                evidence_modes.append(EVIDENCE_MODE_TABLE)
                if table_found < 5:
                    retrieved_content = retrieved_item.metadata.get(
                        "table_origin", retrieved_item.text
                    )
                    if retrieved_content not in evidence:
                        table_found += 1
                        evidence += (
                            f"<br><b>Table from {source}</b>\n"
                            + retrieved_content
                            + "\n<br>"
                        )
            elif retrieved_item.metadata.get("type", "") == "chatbot":
                evidence_modes.append(EVIDENCE_MODE_CHATBOT)
                retrieved_content = retrieved_item.metadata["window"]
                evidence += (
                    f"<br><b>Chatbot scenario from {filename} (Row {page})</b>\n"
                    + retrieved_content
                    + "\n<br>"
                )
            elif retrieved_item.metadata.get("type", "") == "image":
                evidence_modes.append(EVIDENCE_MODE_FIGURE)
                retrieved_content = retrieved_item.metadata.get("image_origin", "")
                retrieved_caption = html.escape(retrieved_item.get_content())
                evidence += (
                    f"<br><b>Figure from {source}</b>\n"
                    + "<img width='85%' src='<src>' "
                    + f"alt='{retrieved_caption}'/>"
                    + "\n<br>"
                )
                images.append(retrieved_content)
            else:
                if "window" in retrieved_item.metadata:
                    retrieved_content = retrieved_item.metadata["window"]
                else:
                    retrieved_content = retrieved_item.text
                retrieved_content = retrieved_content.replace("\n", " ")
                if retrieved_content not in evidence:
                    evidence += (
                        f"<br><b>Content from {source}: </b> "
                        + retrieved_content
                        + " \n<br>"
                    )

        # resolve evidence mode
        evidence_mode = EVIDENCE_MODE_TEXT
        if EVIDENCE_MODE_FIGURE in evidence_modes:
            evidence_mode = EVIDENCE_MODE_FIGURE
        elif EVIDENCE_MODE_TABLE in evidence_modes:
            evidence_mode = EVIDENCE_MODE_TABLE

        # trim context by trim_len
        print("len (original)", len(evidence))
        if evidence:
            texts = evidence_trim_func([Document(text=evidence)])
            evidence = texts[0].text
        print("len (trimmed)", len(evidence))

        return Document(content=(evidence_mode, evidence, images))
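A rough usage sketch of PrepareEvidencePipeline, assuming the kotaemon package from this commit is importable and tiktoken can load the gpt-3.5-turbo encoding; the document text and metadata are made up for the example.

from kotaemon.base import RetrievedDocument
from kotaemon.indices.qa.format_context import PrepareEvidencePipeline

# Hypothetical retrieved chunk; the metadata keys mirror those read in run().
docs = [
    RetrievedDocument(
        text="Fixed-size chunking splits documents into chunks of a preset size.",
        metadata={"file_name": "chunking.pdf", "page_label": "2"},
    )
]

pipeline = PrepareEvidencePipeline(max_context_length=4096)
evidence_mode, evidence, images = pipeline.run(docs).content
print(evidence_mode)  # EVIDENCE_MODE_TEXT, since there are no table/figure chunks
print(evidence)       # "<br><b>Content from chunking.pdf (Page 2): </b> ..."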
@@ -1,63 +0,0 @@
import os

from kotaemon.base import BaseComponent, Document, Node, RetrievedDocument
from kotaemon.llms import BaseLLM, LCAzureChatOpenAI, PromptTemplate

from .citation import CitationPipeline


class CitationQAPipeline(BaseComponent):
    """Answering question from a text corpus with citation"""

    qa_prompt_template: PromptTemplate = PromptTemplate(
        'Answer the following question: "{question}". '
        "The context is: \n{context}\nAnswer: "
    )
    llm: BaseLLM = LCAzureChatOpenAI.withx(
        azure_endpoint="https://bleh-dummy.openai.azure.com/",
        openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
        openai_api_version="2023-07-01-preview",
        deployment_name="dummy-q2-16k",
        temperature=0,
        request_timeout=60,
    )
    citation_pipeline: CitationPipeline = Node(
        default_callback=lambda self: CitationPipeline(llm=self.llm)
    )

    def _format_doc_text(self, text: str) -> str:
        """Format the text of each document"""
        return text.replace("\n", " ")

    def _format_retrieved_context(self, documents: list[RetrievedDocument]) -> str:
        """Format the texts between all documents"""
        matched_texts: list[str] = [
            self._format_doc_text(doc.text) for doc in documents
        ]
        return "\n\n".join(matched_texts)

    def run(
        self,
        question: str,
        documents: list[RetrievedDocument],
        use_citation: bool = False,
        **kwargs
    ) -> Document:
        # retrieve relevant documents as context
        context = self._format_retrieved_context(documents)
        self.log_progress(".context", context=context)

        # generate the answer
        prompt = self.qa_prompt_template.populate(
            context=context,
            question=question,
        )
        self.log_progress(".prompt", prompt=prompt)
        answer_text = self.llm(prompt).text
        if use_citation:
            citation = self.citation_pipeline(context=context, question=question)
        else:
            citation = None

        answer = Document(text=answer_text, metadata={"citation": citation})
        return answer
libs/kotaemon/kotaemon/indices/qa/utils.py (new file, 53 lines)
@@ -0,0 +1,53 @@
from difflib import SequenceMatcher


def find_text(search_span, context, min_length=5):
    sentence_list = search_span.split("\n")
    context = context.replace("\n", " ")

    matches = []
    # don't search for small text
    if len(search_span) > min_length:
        for sentence in sentence_list:
            match = SequenceMatcher(
                None, sentence, context, autojunk=False
            ).find_longest_match()
            if match.size > max(len(sentence) * 0.35, min_length):
                matches.append((match.b, match.b + match.size))

    return matches


def find_start_end_phrase(
    start_phrase, end_phrase, context, min_length=5, max_excerpt_length=300
):
    context = context.replace("\n", " ")

    matches = []
    matched_length = 0
    for sentence in [start_phrase, end_phrase]:
        match = SequenceMatcher(
            None, sentence, context, autojunk=False
        ).find_longest_match()
        if match.size > max(len(sentence) * 0.35, min_length):
            matches.append((match.b, match.b + match.size))
            matched_length += match.size

    # check if second match is before the first match
    if len(matches) == 2 and matches[1][0] < matches[0][0]:
        # if so, keep only the first match
        matches = [matches[0]]

    if matches:
        start_idx = min(start for start, _ in matches)
        end_idx = max(end for _, end in matches)

        # check if the excerpt is too long
        if end_idx - start_idx > max_excerpt_length:
            end_idx = start_idx + max_excerpt_length

        final_match = (start_idx, end_idx)
    else:
        final_match = None

    return final_match, matched_length
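A small, runnable sketch of the two helpers, assuming the kotaemon package from this commit is importable; the quoted spans and the context string are invented for the example.

from kotaemon.indices.qa.utils import find_start_end_phrase, find_text

context = (
    "Fixed-size chunking splits documents into chunks of a preset size. "
    "It is computationally efficient but may fragment related content."
)

# Fuzzy-match a quoted span against the context; returns a list of (start, end) offsets.
print(find_text("splits documents into chunks of a preset size", context))

# Locate the span bounded by a start and an end phrase, as the inline-citation
# pipeline does when mapping CITATION entries back to source documents.
span, matched_length = find_start_end_phrase(
    "Fixed-size chunking splits documents",
    "may fragment related content.",
    context,
)
print(span, matched_length)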
@@ -42,14 +42,6 @@ class VectorIndexing(BaseIndexing):
            **kwargs,
        )

    def to_qa_pipeline(self, *args, **kwargs):
        from .qa import CitationQAPipeline

        return TextVectorQA(
            retrieving_pipeline=self.to_retrieval_pipeline(**kwargs),
            qa_pipeline=CitationQAPipeline(**kwargs),
        )

    def write_chunk_to_file(self, docs: list[Document]):
        # save the chunks content into markdown format
        if self.cache_dir:
@@ -72,7 +72,7 @@ class PromptTemplate:
                UserWarning,
            )

-    def populate(self, **kwargs) -> str:
+    def populate(self, safe=True, **kwargs) -> str:
        """
        Strictly populate the template with the given keyword arguments.

@@ -86,7 +86,8 @@
        Raises:
            ValueError: If an unknown placeholder is provided.
        """
-        self.check_missing_kwargs(**kwargs)
+        if safe:
+            self.check_missing_kwargs(**kwargs)

        return self.partial_populate(**kwargs)
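
The new safe flag lets a caller skip the missing-placeholder check and fall through to partial population, which is how the inline-citation pipeline calls populate(..., safe=False). A minimal sketch of that behaviour, assuming PromptTemplate from kotaemon.llms works as the hunk above suggests:

from kotaemon.llms import PromptTemplate

template = PromptTemplate("Context: {context}\nQuestion: {question}\nAnswer:")

# safe=True (default) enforces that every placeholder is supplied; safe=False
# skips that check and fills only what is given, leaving {question} in place.
partial = template.populate(safe=False, context="retrieved evidence goes here")
print(partial)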