feat: add inline citation style (#523) bump:minor

* feat: add URL quick index, export mindmap, refine UI & animation

* fix: inject multimodal mode from env var

* fix: minor update css

* feat: add citation inline mode

* fix: minor update citation inline pipeline

* feat: add citation quick setting

* fix: minor update

* fix: minor update
Tuan Anh Nguyen Dang (Tadashi_Cin) 2024-11-25 12:07:02 +07:00 committed by GitHub
parent 013f6f4103
commit 7e34e4343b
18 changed files with 1173 additions and 651 deletions

View File

@@ -255,7 +255,7 @@ KH_REASONINGS = [
    "ktem.reasoning.react.ReactAgentPipeline",
    "ktem.reasoning.rewoo.RewooAgentPipeline",
]
-KH_REASONINGS_USE_MULTIMODAL = False
+KH_REASONINGS_USE_MULTIMODAL = config("USE_MULTIMODAL", default=False, cast=bool)
KH_VLM_ENDPOINT = "{0}/openai/deployments/{1}/chat/completions?api-version={2}".format(
    config("AZURE_OPENAI_ENDPOINT", default=""),
    config("OPENAI_VISION_DEPLOYMENT_NAME", default="gpt-4o"),
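The multimodal switch is now read from the environment instead of being hard-coded. A minimal sketch of how the flag resolves at startup, assuming a USE_MULTIMODAL variable exported in the shell or a .env file (python-decouple's config, which flowsettings.py already uses for the other values shown above):

# Illustrative only, not part of this commit.
import os
from decouple import config

os.environ["USE_MULTIMODAL"] = "true"  # normally set in .env or the shell environment
print(config("USE_MULTIMODAL", default=False, cast=bool))  # True; False when the variable is unset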

View File

@@ -1,7 +1,5 @@
from .citation import CitationPipeline
-from .text_based import CitationQAPipeline

__all__ = [
    "CitationPipeline",
-    "CitationQAPipeline",
]

View File

@@ -0,0 +1,390 @@
import threading
from collections import defaultdict
from typing import Generator
import numpy as np
from theflow.settings import settings as flowsettings
from kotaemon.base import (
AIMessage,
BaseComponent,
Document,
HumanMessage,
Node,
SystemMessage,
)
from kotaemon.llms import ChatLLM, PromptTemplate
from .citation import CitationPipeline
from .format_context import (
EVIDENCE_MODE_FIGURE,
EVIDENCE_MODE_TABLE,
EVIDENCE_MODE_TEXT,
)
from .utils import find_text
try:
from ktem.llms.manager import llms
from ktem.reasoning.prompt_optimization.mindmap import CreateMindmapPipeline
from ktem.utils.render import Render
except ImportError:
raise ImportError("Please install `ktem` to use this component")
MAX_IMAGES = 10
CITATION_TIMEOUT = 5.0
CONTEXT_RELEVANT_WARNING_SCORE = 0.7
DEFAULT_QA_TEXT_PROMPT = (
"Use the following pieces of context to answer the question at the end in detail with clear explanation. " # noqa: E501
"If you don't know the answer, just say that you don't know, don't try to "
"make up an answer. Give answer in "
"{lang}.\n\n"
"{context}\n"
"Question: {question}\n"
"Helpful Answer:"
)
DEFAULT_QA_TABLE_PROMPT = (
"Use the given context: texts, tables, and figures below to answer the question, "
"then provide answer with clear explanation. "
"If you don't know the answer, just say that you don't know, "
"don't try to make up an answer. Give answer in {lang}.\n\n"
"Context:\n"
"{context}\n"
"Question: {question}\n"
"Helpful Answer:"
) # noqa
DEFAULT_QA_CHATBOT_PROMPT = (
"Pick the most suitable chatbot scenarios to answer the question at the end, "
"output the provided answer text. If you don't know the answer, "
"just say that you don't know. Keep the answer as concise as possible. "
"Give answer in {lang}.\n\n"
"Context:\n"
"{context}\n"
"Question: {question}\n"
"Answer:"
) # noqa
DEFAULT_QA_FIGURE_PROMPT = (
"Use the given context: texts, tables, and figures below to answer the question. "
"If you don't know the answer, just say that you don't know. "
"Give answer in {lang}.\n\n"
"Context: \n"
"{context}\n"
"Question: {question}\n"
"Answer: "
) # noqa
class AnswerWithContextPipeline(BaseComponent):
"""Answer the question based on the evidence
Args:
llm: the language model to generate the answer
citation_pipeline: generates citation from the evidence
qa_template: the prompt template for LLM to generate answer (refer to
evidence_mode)
qa_table_template: the prompt template for LLM to generate answer for table
(refer to evidence_mode)
qa_chatbot_template: the prompt template for LLM to generate answer for
pre-made scenarios (refer to evidence_mode)
lang: the language of the answer. Currently supports English and Japanese
"""
llm: ChatLLM = Node(default_callback=lambda _: llms.get_default())
vlm_endpoint: str = getattr(flowsettings, "KH_VLM_ENDPOINT", "")
use_multimodal: bool = getattr(flowsettings, "KH_REASONINGS_USE_MULTIMODAL", True)
citation_pipeline: CitationPipeline = Node(
default_callback=lambda _: CitationPipeline(llm=llms.get_default())
)
create_mindmap_pipeline: CreateMindmapPipeline = Node(
default_callback=lambda _: CreateMindmapPipeline(llm=llms.get_default())
)
qa_template: str = DEFAULT_QA_TEXT_PROMPT
qa_table_template: str = DEFAULT_QA_TABLE_PROMPT
qa_chatbot_template: str = DEFAULT_QA_CHATBOT_PROMPT
qa_figure_template: str = DEFAULT_QA_FIGURE_PROMPT
enable_citation: bool = False
enable_mindmap: bool = False
enable_citation_viz: bool = False
system_prompt: str = ""
lang: str = "English" # support English and Japanese
n_last_interactions: int = 5
def get_prompt(self, question, evidence, evidence_mode: int):
"""Prepare the prompt and other information for LLM"""
if evidence_mode == EVIDENCE_MODE_TEXT:
prompt_template = PromptTemplate(self.qa_template)
elif evidence_mode == EVIDENCE_MODE_TABLE:
prompt_template = PromptTemplate(self.qa_table_template)
elif evidence_mode == EVIDENCE_MODE_FIGURE:
if self.use_multimodal:
prompt_template = PromptTemplate(self.qa_figure_template)
else:
prompt_template = PromptTemplate(self.qa_template)
else:
prompt_template = PromptTemplate(self.qa_chatbot_template)
prompt = prompt_template.populate(
context=evidence,
question=question,
lang=self.lang,
)
return prompt, evidence
def run(
self, question: str, evidence: str, evidence_mode: int = 0, **kwargs
) -> Document:
return self.invoke(question, evidence, evidence_mode, **kwargs)
def invoke(
self,
question: str,
evidence: str,
evidence_mode: int = 0,
images: list[str] = [],
**kwargs,
) -> Document:
raise NotImplementedError
async def ainvoke( # type: ignore
self,
question: str,
evidence: str,
evidence_mode: int = 0,
images: list[str] = [],
**kwargs,
) -> Document:
"""Answer the question based on the evidence
In addition to the question and the evidence, this method also takes into
account evidence_mode. The evidence_mode tells which kind of evidence it is.
The kind of evidence affects:
1. How the evidence is represented.
2. The prompt to generate the answer.
By default, the evidence_mode is 0, which means the evidence is plain text with
no particular semantic representation. The evidence_mode can be:
1. "table": There will be HTML markup telling that there is a table
within the evidence.
2. "chatbot": There will be HTML markup telling that there is a chatbot.
This chatbot is a scenario, extracted from an Excel file, where each
row corresponds to an interaction.
Args:
question: the original question posed by user
evidence: the text that contain relevant information to answer the question
(determined by retrieval pipeline)
evidence_mode: the mode of evidence, 0 for text, 1 for table, 2 for chatbot
"""
raise NotImplementedError
def stream( # type: ignore
self,
question: str,
evidence: str,
evidence_mode: int = 0,
images: list[str] = [],
**kwargs,
) -> Generator[Document, None, Document]:
history = kwargs.get("history", [])
print(f"Got {len(images)} images")
# check if evidence exists, use QA prompt
if evidence:
prompt, evidence = self.get_prompt(question, evidence, evidence_mode)
else:
prompt = question
# retrieve the citation
citation = None
mindmap = None
def citation_call():
nonlocal citation
citation = self.citation_pipeline(context=evidence, question=question)
def mindmap_call():
nonlocal mindmap
mindmap = self.create_mindmap_pipeline(context=evidence, question=question)
citation_thread = None
mindmap_thread = None
# execute function call in thread
if evidence:
if self.enable_citation:
citation_thread = threading.Thread(target=citation_call)
citation_thread.start()
if self.enable_mindmap:
mindmap_thread = threading.Thread(target=mindmap_call)
mindmap_thread.start()
output = ""
logprobs = []
messages = []
if self.system_prompt:
messages.append(SystemMessage(content=self.system_prompt))
for human, ai in history[-self.n_last_interactions :]:
messages.append(HumanMessage(content=human))
messages.append(AIMessage(content=ai))
if self.use_multimodal and evidence_mode == EVIDENCE_MODE_FIGURE:
# create image message:
messages.append(
HumanMessage(
content=[
{"type": "text", "text": prompt},
]
+ [
{
"type": "image_url",
"image_url": {"url": image},
}
for image in images[:MAX_IMAGES]
],
)
)
else:
# append main prompt
messages.append(HumanMessage(content=prompt))
try:
# try streaming first
print("Trying LLM streaming")
for out_msg in self.llm.stream(messages):
output += out_msg.text
logprobs += out_msg.logprobs
yield Document(channel="chat", content=out_msg.text)
except NotImplementedError:
print("Streaming is not supported, falling back to normal processing")
output = self.llm(messages).text
yield Document(channel="chat", content=output)
if logprobs:
qa_score = np.exp(np.average(logprobs))
else:
qa_score = None
if citation_thread:
citation_thread.join(timeout=CITATION_TIMEOUT)
if mindmap_thread:
mindmap_thread.join(timeout=CITATION_TIMEOUT)
answer = Document(
text=output,
metadata={
"citation_viz": self.enable_citation_viz,
"mindmap": mindmap,
"citation": citation,
"qa_score": qa_score,
},
)
return answer
def match_evidence_with_context(self, answer, docs) -> dict[str, list[dict]]:
"""Match the evidence with the context"""
spans: dict[str, list[dict]] = defaultdict(list)
if not answer.metadata["citation"]:
return spans
evidences = answer.metadata["citation"].evidences
for quote in evidences:
matched_excerpts = []
for doc in docs:
matches = find_text(quote, doc.text)
for start, end in matches:
if "|" not in doc.text[start:end]:
spans[doc.doc_id].append(
{
"start": start,
"end": end,
}
)
matched_excerpts.append(doc.text[start:end])
# print("Matched citation:", quote, matched_excerpts),
return spans
def prepare_citations(self, answer, docs) -> tuple[list[Document], list[Document]]:
"""Prepare the citations to show on the UI"""
with_citation, without_citation = [], []
has_llm_score = any("llm_trulens_score" in doc.metadata for doc in docs)
spans = self.match_evidence_with_context(answer, docs)
id2docs = {doc.doc_id: doc for doc in docs}
not_detected = set(id2docs.keys()) - set(spans.keys())
# render highlight spans
for _id, ss in spans.items():
if not ss:
not_detected.add(_id)
continue
cur_doc = id2docs[_id]
highlight_text = ""
ss = sorted(ss, key=lambda x: x["start"])
text = cur_doc.text[: ss[0]["start"]]
for idx, span in enumerate(ss):
to_highlight = cur_doc.text[span["start"] : span["end"]]
if len(to_highlight) > len(highlight_text):
highlight_text = to_highlight
span_idx = span.get("idx", None)
if span_idx is not None:
to_highlight = f"【{span_idx + 1}】" + to_highlight
text += Render.highlight(
to_highlight,
elem_id=str(span_idx + 1) if span_idx is not None else None,
)
print(text)
if idx < len(ss) - 1:
text += cur_doc.text[span["end"] : ss[idx + 1]["start"]]
text += cur_doc.text[ss[-1]["end"] :]
# add to display list
with_citation.append(
Document(
channel="info",
content=Render.collapsible_with_header_score(
cur_doc,
override_text=text,
highlight_text=highlight_text,
open_collapsible=True,
),
)
)
print("Got {} cited docs".format(len(with_citation)))
sorted_not_detected_items_with_scores = [
(id_, id2docs[id_].metadata.get("llm_trulens_score", 0.0))
for id_ in not_detected
]
sorted_not_detected_items_with_scores.sort(key=lambda x: x[1], reverse=True)
for id_, _ in sorted_not_detected_items_with_scores:
doc = id2docs[id_]
doc_score = doc.metadata.get("llm_trulens_score", 0.0)
is_open = not has_llm_score or (
doc_score > CONTEXT_RELEVANT_WARNING_SCORE and len(with_citation) == 0
)
without_citation.append(
Document(
channel="info",
content=Render.collapsible_with_header_score(
doc, open_collapsible=is_open
),
)
)
return with_citation, without_citation
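For orientation, a rough usage sketch of the pipeline defined above; it is not part of the diff and assumes ktem is installed with a default LLM configured so the Node defaults can resolve.

# Hedged sketch: stream an answer and collect the final Document with its metadata.
pipeline = AnswerWithContextPipeline(enable_citation=True, enable_mindmap=False)
gen = pipeline.stream(
    question="What is fixed-size chunking?",
    evidence="<br><b>Content from example.pdf (Page 2): </b> Fixed-size chunking splits text ... <br>",
    evidence_mode=EVIDENCE_MODE_TEXT,
)
try:
    while True:
        chunk = next(gen)  # Document(channel="chat", content=<next token>)
        print(chunk.content, end="")
except StopIteration as e:
    answer = e.value  # final Document carrying citation, mindmap and qa_score metadata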

View File

@@ -0,0 +1,267 @@
import re
import threading
from collections import defaultdict
from typing import Generator
import numpy as np
from kotaemon.base import AIMessage, Document, HumanMessage, SystemMessage
from kotaemon.llms import PromptTemplate
from .citation import CiteEvidence
from .citation_qa import CITATION_TIMEOUT, MAX_IMAGES, AnswerWithContextPipeline
from .format_context import EVIDENCE_MODE_FIGURE
from .utils import find_start_end_phrase
DEFAULT_QA_CITATION_PROMPT = """
Use the following pieces of context to answer the question at the end.
Provide DETAILED answer with clear explanation.
Format answer with easy to follow bullets / paragraphs.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Use the same language as the question to respond.
CONTEXT:
----
{context}
----
Answer using this format:
CITATION LIST
// the index in this array
CITATIONnumber
// output 2 phrase to mark start and end of the relevant span
// each has ~ 6 words
// MUST COPY EXACTLY from the CONTEXT
// NO CHANGE or REPHRASE
// RELEVANT_SPAN_FROM_CONTEXT
START_PHRASE: string
END_PHRASE: string
// When you answer, ensure to add citations from the documents
// in the CONTEXT with a number that corresponds to the answersInText array.
// (in the form [number])
// Try to include the number after each facts / statements you make.
// You can create as many citations as you need.
FINAL ANSWER
string
STRICTLY FOLLOW THIS EXAMPLE:
CITATION LIST
CITATION1
START_PHRASE: Known as fixed-size chunking , the traditional
END_PHRASE: not degrade the final retrieval performance.
CITATION2
START_PHRASE: Fixed-size Chunker This is our baseline chunker
END_PHRASE: this shows good retrieval quality.
FINAL ANSWER
An alternative to semantic chunking is fixed-size chunking. This traditional method involves splitting documents into chunks of a predetermined or user-specified size, regardless of semantic content, which is computationally efficient【1】. However, it may result in the fragmentation of semantically related content, thereby potentially degrading retrieval performance【2】.
QUESTION: {question}\n
ANSWER:
""" # noqa
class AnswerWithInlineCitation(AnswerWithContextPipeline):
"""Answer the question based on the evidence with inline citation"""
qa_citation_template: str = DEFAULT_QA_CITATION_PROMPT
def get_prompt(self, question, evidence, evidence_mode: int):
"""Prepare the prompt and other information for LLM"""
prompt_template = PromptTemplate(self.qa_citation_template)
prompt = prompt_template.populate(
context=evidence,
question=question,
safe=False,
)
return prompt, evidence
def answer_to_citations(self, answer):
evidences = []
lines = answer.split("\n")
for line in lines:
for keyword in ["START_PHRASE:", "END_PHRASE:"]:
if line.startswith(keyword):
evidences.append(line[len(keyword) :].strip())
return CiteEvidence(evidences=evidences)
def replace_citation_with_link(self, answer: str):
# Define the regex pattern to match 【number】
pattern = r"【\d+】"
matches = re.finditer(pattern, answer)
matched_citations = set()
for match in matches:
citation = match.group()
matched_citations.add(citation)
for citation in matched_citations:
print("Found citation:", citation)
answer = answer.replace(
citation,
(
"<a href='#' class='citation' "
f"id='mark-{citation[1:-1]}'>{citation}</a>"
),
)
print("Replaced answer:", answer)
return answer
def stream( # type: ignore
self,
question: str,
evidence: str,
evidence_mode: int = 0,
images: list[str] = [],
**kwargs,
) -> Generator[Document, None, Document]:
history = kwargs.get("history", [])
print(f"Got {len(images)} images")
# check if evidence exists, use QA prompt
if evidence:
prompt, evidence = self.get_prompt(question, evidence, evidence_mode)
else:
prompt = question
output = ""
logprobs = []
citation = None
mindmap = None
def mindmap_call():
nonlocal mindmap
mindmap = self.create_mindmap_pipeline(context=evidence, question=question)
mindmap_thread = None
# execute function call in thread
if evidence:
if self.enable_mindmap:
mindmap_thread = threading.Thread(target=mindmap_call)
mindmap_thread.start()
messages = []
if self.system_prompt:
messages.append(SystemMessage(content=self.system_prompt))
for human, ai in history[-self.n_last_interactions :]:
messages.append(HumanMessage(content=human))
messages.append(AIMessage(content=ai))
if self.use_multimodal and evidence_mode == EVIDENCE_MODE_FIGURE:
# create image message:
messages.append(
HumanMessage(
content=[
{"type": "text", "text": prompt},
]
+ [
{
"type": "image_url",
"image_url": {"url": image},
}
for image in images[:MAX_IMAGES]
],
)
)
else:
# append main prompt
messages.append(HumanMessage(content=prompt))
START_ANSWER = "FINAL ANSWER"
start_of_answer = True
final_answer = ""
try:
# try streaming first
print("Trying LLM streaming")
for out_msg in self.llm.stream(messages):
if START_ANSWER in output:
final_answer += (
out_msg.text.lstrip() if start_of_answer else out_msg.text
)
start_of_answer = False
yield Document(channel="chat", content=out_msg.text)
output += out_msg.text
logprobs += out_msg.logprobs
except NotImplementedError:
print("Streaming is not supported, falling back to normal processing")
output = self.llm(messages).text
yield Document(channel="chat", content=output)
if logprobs:
qa_score = np.exp(np.average(logprobs))
else:
qa_score = None
citation = self.answer_to_citations(output)
if mindmap_thread:
mindmap_thread.join(timeout=CITATION_TIMEOUT)
# convert citation to link
answer = Document(
text=final_answer,
metadata={
"citation_viz": self.enable_citation_viz,
"mindmap": mindmap,
"citation": citation,
"qa_score": qa_score,
},
)
# yield the final answer
final_answer = self.replace_citation_with_link(final_answer)
yield Document(channel="chat", content=None)
yield Document(channel="chat", content=final_answer)
return answer
def match_evidence_with_context(self, answer, docs) -> dict[str, list[dict]]:
"""Match the evidence with the context"""
spans: dict[str, list[dict]] = defaultdict(list)
if not answer.metadata["citation"]:
return spans
evidences = answer.metadata["citation"].evidences
for start_idx in range(0, len(evidences), 2):
start_phrase, end_phrase = evidences[start_idx : start_idx + 2]
best_match = None
best_match_length = 0
best_match_doc_idx = None
for doc in docs:
match, match_length = find_start_end_phrase(
start_phrase, end_phrase, doc.text
)
if best_match is None or (
match is not None and match_length > best_match_length
):
best_match = match
best_match_length = match_length
best_match_doc_idx = doc.doc_id
if best_match is not None and best_match_doc_idx is not None:
spans[best_match_doc_idx].append(
{
"start": best_match[0],
"end": best_match[1],
"idx": start_idx // 2, # implicitly set from the start_idx
}
)
return spans
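A small illustration (not in the commit) of how the inline format round-trips through the two helpers above, assuming an LLM reply that follows DEFAULT_QA_CITATION_PROMPT and a configured default LLM for the pipeline's nodes:

raw_output = (
    "CITATION LIST\n"
    "CITATION1\n"
    "START_PHRASE: Known as fixed-size chunking , the traditional\n"
    "END_PHRASE: not degrade the final retrieval performance.\n"
    "FINAL ANSWER\n"
    "Fixed-size chunking splits documents into equal chunks 【1】."
)
pipeline = AnswerWithInlineCitation()
evidence = pipeline.answer_to_citations(raw_output)
print(evidence.evidences)  # the START_PHRASE / END_PHRASE pair, in order
print(pipeline.replace_citation_with_link("equal chunks 【1】."))
# 【1】 becomes <a href='#' class='citation' id='mark-1'>【1】</a>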

View File

@@ -0,0 +1,114 @@
import html
from functools import partial
import tiktoken
from kotaemon.base import BaseComponent, Document, RetrievedDocument
from kotaemon.indices.splitters import TokenSplitter
EVIDENCE_MODE_TEXT = 0
EVIDENCE_MODE_TABLE = 1
EVIDENCE_MODE_CHATBOT = 2
EVIDENCE_MODE_FIGURE = 3
class PrepareEvidencePipeline(BaseComponent):
"""Prepare the evidence text from the list of retrieved documents
This step usually happens after `DocumentRetrievalPipeline`.
Args:
trim_func: a callback function or a BaseComponent that splits a large
chunk of text into smaller ones. The first chunk will be retained.
"""
max_context_length: int = 32000
trim_func: TokenSplitter | None = None
def run(self, docs: list[RetrievedDocument]) -> Document:
evidence = ""
images = []
table_found = 0
evidence_modes = []
evidence_trim_func = (
self.trim_func
if self.trim_func
else TokenSplitter(
chunk_size=self.max_context_length,
chunk_overlap=0,
separator=" ",
tokenizer=partial(
tiktoken.encoding_for_model("gpt-3.5-turbo").encode,
allowed_special=set(),
disallowed_special="all",
),
)
)
for _, retrieved_item in enumerate(docs):
retrieved_content = ""
page = retrieved_item.metadata.get("page_label", None)
source = filename = retrieved_item.metadata.get("file_name", "-")
if page:
source += f" (Page {page})"
if retrieved_item.metadata.get("type", "") == "table":
evidence_modes.append(EVIDENCE_MODE_TABLE)
if table_found < 5:
retrieved_content = retrieved_item.metadata.get(
"table_origin", retrieved_item.text
)
if retrieved_content not in evidence:
table_found += 1
evidence += (
f"<br><b>Table from {source}</b>\n"
+ retrieved_content
+ "\n<br>"
)
elif retrieved_item.metadata.get("type", "") == "chatbot":
evidence_modes.append(EVIDENCE_MODE_CHATBOT)
retrieved_content = retrieved_item.metadata["window"]
evidence += (
f"<br><b>Chatbot scenario from {filename} (Row {page})</b>\n"
+ retrieved_content
+ "\n<br>"
)
elif retrieved_item.metadata.get("type", "") == "image":
evidence_modes.append(EVIDENCE_MODE_FIGURE)
retrieved_content = retrieved_item.metadata.get("image_origin", "")
retrieved_caption = html.escape(retrieved_item.get_content())
evidence += (
f"<br><b>Figure from {source}</b>\n"
+ "<img width='85%' src='<src>' "
+ f"alt='{retrieved_caption}'/>"
+ "\n<br>"
)
images.append(retrieved_content)
else:
if "window" in retrieved_item.metadata:
retrieved_content = retrieved_item.metadata["window"]
else:
retrieved_content = retrieved_item.text
retrieved_content = retrieved_content.replace("\n", " ")
if retrieved_content not in evidence:
evidence += (
f"<br><b>Content from {source}: </b> "
+ retrieved_content
+ " \n<br>"
)
# resolve evidence mode
evidence_mode = EVIDENCE_MODE_TEXT
if EVIDENCE_MODE_FIGURE in evidence_modes:
evidence_mode = EVIDENCE_MODE_FIGURE
elif EVIDENCE_MODE_TABLE in evidence_modes:
evidence_mode = EVIDENCE_MODE_TABLE
# trim context by trim_len
print("len (original)", len(evidence))
if evidence:
texts = evidence_trim_func([Document(text=evidence)])
evidence = texts[0].text
print("len (trimmed)", len(evidence))
return Document(content=(evidence_mode, evidence, images))
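A minimal sketch of feeding retrieved documents through this pipeline. It is illustrative only and assumes RetrievedDocument accepts text/metadata keyword arguments the way Document does:

docs = [
    RetrievedDocument(
        text="Fixed-size chunking splits documents into equal chunks.",
        metadata={"file_name": "paper.pdf", "page_label": "2"},
    ),
    RetrievedDocument(
        text="chunk size vs. recall",
        metadata={"file_name": "paper.pdf", "type": "table", "table_origin": "<table>...</table>"},
    ),
]
prep = PrepareEvidencePipeline(max_context_length=8000)
evidence_mode, evidence, images = prep(docs).content
# evidence_mode is EVIDENCE_MODE_TABLE because one retrieved item is a table;
# evidence is the concatenated context string and images stays empty here.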

View File

@@ -1,63 +0,0 @@
import os
from kotaemon.base import BaseComponent, Document, Node, RetrievedDocument
from kotaemon.llms import BaseLLM, LCAzureChatOpenAI, PromptTemplate
from .citation import CitationPipeline
class CitationQAPipeline(BaseComponent):
"""Answering question from a text corpus with citation"""
qa_prompt_template: PromptTemplate = PromptTemplate(
'Answer the following question: "{question}". '
"The context is: \n{context}\nAnswer: "
)
llm: BaseLLM = LCAzureChatOpenAI.withx(
azure_endpoint="https://bleh-dummy.openai.azure.com/",
openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
openai_api_version="2023-07-01-preview",
deployment_name="dummy-q2-16k",
temperature=0,
request_timeout=60,
)
citation_pipeline: CitationPipeline = Node(
default_callback=lambda self: CitationPipeline(llm=self.llm)
)
def _format_doc_text(self, text: str) -> str:
"""Format the text of each document"""
return text.replace("\n", " ")
def _format_retrieved_context(self, documents: list[RetrievedDocument]) -> str:
"""Format the texts between all documents"""
matched_texts: list[str] = [
self._format_doc_text(doc.text) for doc in documents
]
return "\n\n".join(matched_texts)
def run(
self,
question: str,
documents: list[RetrievedDocument],
use_citation: bool = False,
**kwargs
) -> Document:
# retrieve relevant documents as context
context = self._format_retrieved_context(documents)
self.log_progress(".context", context=context)
# generate the answer
prompt = self.qa_prompt_template.populate(
context=context,
question=question,
)
self.log_progress(".prompt", prompt=prompt)
answer_text = self.llm(prompt).text
if use_citation:
citation = self.citation_pipeline(context=context, question=question)
else:
citation = None
answer = Document(text=answer_text, metadata={"citation": citation})
return answer

View File

@@ -0,0 +1,53 @@
from difflib import SequenceMatcher
def find_text(search_span, context, min_length=5):
sentence_list = search_span.split("\n")
context = context.replace("\n", " ")
matches = []
# don't search for small text
if len(search_span) > min_length:
for sentence in sentence_list:
match = SequenceMatcher(
None, sentence, context, autojunk=False
).find_longest_match()
if match.size > max(len(sentence) * 0.35, min_length):
matches.append((match.b, match.b + match.size))
return matches
def find_start_end_phrase(
start_phrase, end_phrase, context, min_length=5, max_excerpt_length=300
):
context = context.replace("\n", " ")
matches = []
matched_length = 0
for sentence in [start_phrase, end_phrase]:
match = SequenceMatcher(
None, sentence, context, autojunk=False
).find_longest_match()
if match.size > max(len(sentence) * 0.35, min_length):
matches.append((match.b, match.b + match.size))
matched_length += match.size
# check if second match is before the first match
if len(matches) == 2 and matches[1][0] < matches[0][0]:
# if so, keep only the first match
matches = [matches[0]]
if matches:
start_idx = min(start for start, _ in matches)
end_idx = max(end for _, end in matches)
# check if the excerpt is too long
if end_idx - start_idx > max_excerpt_length:
end_idx = start_idx + max_excerpt_length
final_match = (start_idx, end_idx)
else:
final_match = None
return final_match, matched_length
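For reference, how the two helpers behave on a toy context (illustrative only):

context = "Fixed-size chunking is our baseline chunker.\nIt shows good retrieval quality overall."

print(find_text("baseline chunker", context))
# -> a single (start, end) span inside the newline-flattened context

span, matched_length = find_start_end_phrase(
    "Fixed-size chunking is our", "good retrieval quality", context
)
# span runs from the start of the first phrase to the end of the second (capped at
# max_excerpt_length); matched_length sums the sizes of the two fuzzy matches.
print(span, matched_length)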

View File

@@ -42,14 +42,6 @@ class VectorIndexing(BaseIndexing):
            **kwargs,
        )

-    def to_qa_pipeline(self, *args, **kwargs):
-        from .qa import CitationQAPipeline
-
-        return TextVectorQA(
-            retrieving_pipeline=self.to_retrieval_pipeline(**kwargs),
-            qa_pipeline=CitationQAPipeline(**kwargs),
-        )
-
    def write_chunk_to_file(self, docs: list[Document]):
        # save the chunks content into markdown format
        if self.cache_dir:

View File

@@ -72,7 +72,7 @@ class PromptTemplate:
                UserWarning,
            )

-    def populate(self, **kwargs) -> str:
+    def populate(self, safe=True, **kwargs) -> str:
        """
        Strictly populate the template with the given keyword arguments.

@@ -86,7 +86,8 @@ class PromptTemplate:
        Raises:
            ValueError: If an unknown placeholder is provided.
        """
-        self.check_missing_kwargs(**kwargs)
+        if safe:
+            self.check_missing_kwargs(**kwargs)

        return self.partial_populate(**kwargs)
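The new safe flag is what lets the inline-citation prompt be populated without the strict placeholder check (see populate(..., safe=False) in citation_qa_inline.py above). A short sketch, assuming kotaemon's PromptTemplate as shown:

from kotaemon.llms import PromptTemplate

tpl = PromptTemplate("CONTEXT:\n{context}\n\nQUESTION: {question}\nANSWER in {lang}:")
# With safe=False, check_missing_kwargs is skipped, so the unfilled {lang}
# placeholder is simply handed to partial_populate instead of being flagged.
print(tpl.populate(safe=False, context="...", question="What is chunking?"))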

View File

@@ -97,6 +97,11 @@ button.selected {
#chat-info-panel {
  max-height: var(--main-area-height) !important;
  overflow: auto !important;
+  transition: all 0.5s;
+}
+
+body.dark #chat-info-panel figure>img {
+  filter: invert(100%);
}

#conv-settings-panel {
@@ -199,11 +204,26 @@ mark {
  right: 15px;
}

-/* #new-conv-button > img {
-  position: relative;
-  top: 0px;
-  right: -50%;
-} */
+#use-mindmap-checkbox {
+  position: absolute;
+  width: 110px;
+  top: 10px;
+  right: 25px;
+}
+
+#quick-url textarea {
+  resize: none;
+  background: transparent;
+  margin-top: 0px;
+}
+
+#quick-url textarea::placeholder {
+  text-align: center;
+}
+
+#quick-file {
+  height: 110px;
+}

span.icon {
  color: #cecece;
@@ -225,11 +245,6 @@ span.icon {
  overflow: unset !important;
}

-/*body {*/
-/* margin: 0;*/
-/* font-family: Arial, sans-serif;*/
-/*}*/
-
pdfjs-viewer-element {
  height: 100vh;
  height: 100dvh;
@@ -280,8 +295,7 @@ pdfjs-viewer-element {
  overflow: auto;
}

-/** Switch
--------------------------------------*/
+/* Switch checkbox styles */

#is-public-checkbox {
  position: relative;
@@ -293,10 +307,6 @@ pdfjs-viewer-element {
  opacity: 0;
}

-/**
- * 1. Adjust this to size
- */
-
.switch {
  display: inline-block;
  /* 1 */
@@ -330,3 +340,28 @@ pdfjs-viewer-element {
.switch:has(> input:checked) {
  background: #0c895f;
}
+
+/* Bot animation */
+.message.bot {
+  animation: fadein 1.5s ease-in-out forwards;
+}
+
+details.evidence {
+  animation: fadein 0.5s ease-in-out forwards;
+}
+
+@keyframes fadein {
+  0% {
+    opacity: 0;
+  }
+  100% {
+    opacity: 100%;
+  }
+}
+
+.message a.citation {
+  color: #10b981;
+  text-decoration: none;
+}

View File

@@ -16,6 +16,11 @@ function run() {
    let chat_info_panel = document.getElementById("info-expand");
    chat_info_panel.insertBefore(info_expand_button, chat_info_panel.childNodes[2]);

+    // move use mind-map checkbox
+    let mindmap_checkbox = document.getElementById("use-mindmap-checkbox");
+    let chat_setting_panel = document.getElementById("chat-settings-expand");
+    chat_setting_panel.insertBefore(mindmap_checkbox, chat_setting_panel.childNodes[2]);
+
    // create slider toggle
    const is_public_checkbox = document.getElementById("is-public-checkbox");
    const label_element = is_public_checkbox.getElementsByTagName("label")[0];
@@ -49,4 +54,21 @@ function run() {
    globalThis.removeFromStorage = (key) => {
        localStorage.removeItem(key)
    }
+
+    // Function to scroll to given citation with ID
+    // Sleep function using Promise and setTimeout
+    function sleep(ms) {
+        return new Promise(resolve => setTimeout(resolve, ms));
+    }
+
+    globalThis.scrollToCitation = async (event) => {
+        event.preventDefault(); // Prevent the default link behavior
+        var citationId = event.target.getAttribute('id');
+        await sleep(100); // Sleep for 100 milliseconds so the highlight marks are rendered
+        var citation = document.querySelector('mark[id="' + citationId + '"]');
+        if (citation) {
+            citation.scrollIntoView({ behavior: 'smooth' });
+        }
+    }
}

View File

@@ -25,8 +25,8 @@ class BaseConversation(SQLModel):
        default_factory=lambda: uuid.uuid4().hex, primary_key=True, index=True
    )
    name: str = Field(
-        default_factory=lambda: datetime.datetime.now(get_localzone()).strftime(
-            "%Y-%m-%d %H:%M:%S"
+        default_factory=lambda: "Untitled - {}".format(
+            datetime.datetime.now(get_localzone()).strftime("%Y-%m-%d %H:%M:%S")
        )
    )
    user: int = Field(default=0)  # For now we only have one user
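New conversations therefore get a human-readable default name. A small illustration of the changed default_factory, assuming the model module's existing get_localzone import (tzlocal):

import datetime
from tzlocal import get_localzone

name = "Untitled - {}".format(
    datetime.datetime.now(get_localzone()).strftime("%Y-%m-%d %H:%M:%S")
)
# e.g. "Untitled - 2024-11-25 12:07:02" rather than the bare timestamp used before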

View File

@@ -126,6 +126,9 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
        if doc_ids:
            flatten_doc_ids = []
            for doc_id in doc_ids:
+                if doc_id is None:
+                    raise ValueError("No document is selected")
+
                if doc_id.startswith("["):
                    flatten_doc_ids.extend(json.loads(doc_id))
                else:

View File

@@ -22,6 +22,13 @@ from theflow.settings import settings as flowsettings
DOWNLOAD_MESSAGE = "Press again to download"
MAX_FILENAME_LENGTH = 20

+chat_input_focus_js = """
+function() {
+    let chatInput = document.querySelector("#chat-input textarea");
+    chatInput.focus();
+}
+"""
+

class File(gr.File):
    """Subclass from gr.File to maintain the original filename
@@ -666,7 +673,7 @@ class FileIndexPage(BasePage):
                outputs=self._app.chat_page.quick_file_upload_status,
            )
            .then(
-                fn=self.index_fn_with_default_loaders,
+                fn=self.index_fn_file_with_default_loaders,
                inputs=[
                    self._app.chat_page.quick_file_upload,
                    gr.State(value=False),
@@ -689,6 +696,38 @@ class FileIndexPage(BasePage):
            for event in self._app.get_event(f"onFileIndex{self._index.id}Changed"):
                quickUploadedEvent = quickUploadedEvent.then(**event)

+            quickURLUploadedEvent = (
+                self._app.chat_page.quick_urls.submit(
+                    fn=lambda: gr.update(
+                        value="Please wait for the indexing process "
+                        "to complete before adding your question."
+                    ),
+                    outputs=self._app.chat_page.quick_file_upload_status,
+                )
+                .then(
+                    fn=self.index_fn_url_with_default_loaders,
+                    inputs=[
+                        self._app.chat_page.quick_urls,
+                        gr.State(value=True),
+                        self._app.settings_state,
+                        self._app.user_id,
+                    ],
+                    outputs=self.quick_upload_state,
+                )
+                .success(
+                    fn=lambda: [
+                        gr.update(value=None),
+                        gr.update(value="select"),
+                    ],
+                    outputs=[
+                        self._app.chat_page.quick_urls,
+                        self._app.chat_page._indices_input[0],
+                    ],
+                )
+            )
+
+            for event in self._app.get_event(f"onFileIndex{self._index.id}Changed"):
+                quickURLUploadedEvent = quickURLUploadedEvent.then(**event)
+
            quickUploadedEvent.success(
                fn=lambda x: x,
                inputs=self.quick_upload_state,
@@ -701,6 +740,30 @@ class FileIndexPage(BasePage):
                inputs=[self._app.user_id, self.filter],
                outputs=[self.file_list_state, self.file_list],
                concurrency_limit=20,
+            ).then(
+                fn=lambda: True,
+                inputs=None,
+                outputs=None,
+                js=chat_input_focus_js,
+            )
+
+            quickURLUploadedEvent.success(
+                fn=lambda x: x,
+                inputs=self.quick_upload_state,
+                outputs=self._app.chat_page._indices_input[1],
+            ).then(
+                fn=lambda: gr.update(value="Indexing completed."),
+                outputs=self._app.chat_page.quick_file_upload_status,
+            ).then(
+                fn=self.list_file,
+                inputs=[self._app.user_id, self.filter],
+                outputs=[self.file_list_state, self.file_list],
+                concurrency_limit=20,
+            ).then(
+                fn=lambda: True,
+                inputs=None,
+                outputs=None,
+                js=chat_input_focus_js,
            )

        except Exception as e:
@@ -951,7 +1014,7 @@ class FileIndexPage(BasePage):
        return results

-    def index_fn_with_default_loaders(
+    def index_fn_file_with_default_loaders(
        self, files, reindex: bool, settings, user_id
    ) -> list["str"]:
        """Function for quick upload with default loaders
@@ -991,6 +1054,22 @@ class FileIndexPage(BasePage):
        return exist_ids + returned_ids

+    def index_fn_url_with_default_loaders(self, urls, reindex: bool, settings, user_id):
+        returned_ids = []
+        settings = deepcopy(settings)
+        settings[f"index.options.{self._index.id}.reader_mode"] = "default"
+        settings[f"index.options.{self._index.id}.quick_index_mode"] = True
+
+        if urls:
+            _iter = self.index_fn([], urls, reindex, settings, user_id)
+            try:
+                while next(_iter):
+                    pass
+            except StopIteration as e:
+                returned_ids = e.value
+
+        return returned_ids
+
    def index_files_from_dir(
        self, folder_path, reindex, settings, user_id
    ) -> Generator[tuple[str, str], None, None]:
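index_fn_url_with_default_loaders drains the indexing generator and recovers its return value from StopIteration. The same pattern in isolation, with a hypothetical generator standing in for index_fn:

def fake_index_fn():
    yield "indexing https://example.com ..."
    yield "building vector index ..."
    return ["doc-id-1"]  # a generator's return value travels on StopIteration

returned_ids = []
_iter = fake_index_fn()
try:
    while next(_iter):
        pass
except StopIteration as e:
    returned_ids = e.value  # ["doc-id-1"]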

View File

@@ -40,26 +40,52 @@ function() {
    links[i].onclick = openModal;
  }

-  var mindmap_el = document.getElementById('mindmap');
-  if (mindmap_el) {
-    var output = svgPanZoom(mindmap_el);
-  }
-
-  var link = document.getElementById("mindmap-toggle");
-  if (link) {
-    link.onclick = function(event) {
-      event.preventDefault(); // Prevent the default link behavior
-      var div = document.getElementById("mindmap-wrapper");
-      if (div) {
-        var currentHeight = div.style.height;
-        if (currentHeight === '400px') {
-          var contentHeight = div.scrollHeight;
-          div.style.height = contentHeight + 'px';
-        } else {
-          div.style.height = '400px'
-        }
-      }
-    };
-  }
+  // Get all citation links and attach click event
+  var links = document.querySelectorAll("a.citation");
+  for (var i = 0; i < links.length; i++) {
+    links[i].onclick = scrollToCitation;
+  }
+
+  var mindmap_el = document.getElementById('mindmap');
+
+  if (mindmap_el) {
+    var output = svgPanZoom(mindmap_el);
+    const svg = mindmap_el.cloneNode(true);
+
+    function on_svg_export(event) {
+      event.preventDefault(); // Prevent the default link behavior
+      // convert to a valid XML source
+      const as_text = new XMLSerializer().serializeToString(svg);
+      // store in a Blob
+      const blob = new Blob([as_text], { type: "image/svg+xml" });
+      // create an URI pointing to that blob
+      const url = URL.createObjectURL(blob);
+      const win = open(url);
+      // so the Garbage Collector can collect the blob
+      win.onload = (evt) => URL.revokeObjectURL(url);
+    }
+
+    var link = document.getElementById("mindmap-toggle");
+    if (link) {
+      link.onclick = function(event) {
+        event.preventDefault(); // Prevent the default link behavior
+        var div = document.getElementById("mindmap-wrapper");
+        if (div) {
+          var currentHeight = div.style.height;
+          if (currentHeight === '400px') {
+            var contentHeight = div.scrollHeight;
+            div.style.height = contentHeight + 'px';
+          } else {
+            div.style.height = '400px'
+          }
+        }
+      };
+    }
+
+    var link = document.getElementById("mindmap-export");
+    if (link) {
+      link.addEventListener('click', on_svg_export);
+    }
+  }

  return [links.length]
@@ -127,6 +153,14 @@ class ChatPage(BasePage):
                    file_count="multiple",
                    container=True,
                    show_label=False,
+                    elem_id="quick-file",
+                )
+                self.quick_urls = gr.Textbox(
+                    placeholder="Or paste URLs here",
+                    lines=1,
+                    container=False,
+                    show_label=False,
+                    elem_id="quick-url",
                )
                self.quick_file_upload_status = gr.Markdown()
@@ -136,12 +170,17 @@ class ChatPage(BasePage):
                self.chat_panel = ChatPanel(self._app)

                with gr.Row():
-                    with gr.Accordion(label="Chat settings", open=False):
+                    with gr.Accordion(
+                        label="Chat settings",
+                        elem_id="chat-settings-expand",
+                        open=False,
+                    ):
                        # a quick switch for reasoning type option
                        with gr.Row():
                            gr.HTML("Reasoning method")
                            gr.HTML("Model")
-                            gr.HTML("Generate mindmap")
+                            gr.HTML("Language")
+                            gr.HTML("Citation")

                        with gr.Row():
                            reasoning_type_values = [
@@ -165,17 +204,36 @@ class ChatPage(BasePage):
                            container=False,
                            show_label=False,
                        )
-                        binary_default_choices = [
-                            (DEFAULT_SETTING, DEFAULT_SETTING),
-                            ("Enable", True),
-                            ("Disable", False),
-                        ]
-                        self.use_mindmap = gr.Dropdown(
-                            value=DEFAULT_SETTING,
-                            choices=binary_default_choices,
-                            container=False,
-                            show_label=False,
-                        )
+                        self.language = gr.Dropdown(
+                            choices=[
+                                (DEFAULT_SETTING, DEFAULT_SETTING),
+                            ]
+                            + self._app.default_settings.reasoning.settings[
+                                "lang"
+                            ].choices,
+                            value=DEFAULT_SETTING,
+                            container=False,
+                            show_label=False,
+                        )
+                        self.citation = gr.Dropdown(
+                            choices=[
+                                (DEFAULT_SETTING, DEFAULT_SETTING),
+                            ]
+                            + self._app.default_settings.reasoning.options["simple"]
+                            .settings["highlight_citation"]
+                            .choices,
+                            value=DEFAULT_SETTING,
+                            container=False,
+                            show_label=False,
+                            interactive=True,
+                        )
+
+                        self.use_mindmap = gr.State(value=DEFAULT_SETTING)
+                        self.use_mindmap_check = gr.Checkbox(
+                            label="Mindmap (default)",
+                            container=False,
+                            elem_id="use-mindmap-checkbox",
+                        )

                    with gr.Column(
                        scale=INFO_PANEL_SCALES[False], elem_id="chat-info-panel"
@@ -235,6 +293,8 @@ class ChatPage(BasePage):
            self._reasoning_type,
            self.model_type,
            self.use_mindmap,
+            self.citation,
+            self.language,
            self.state_chat,
            self._app.user_id,
        ]
@@ -506,6 +566,12 @@ class ChatPage(BasePage):
            inputs=[self.reasoning_type],
            outputs=[self._reasoning_type],
        )
+        self.use_mindmap_check.change(
+            lambda x: (x, gr.update(label="Mindmap " + ("(on)" if x else "(off)"))),
+            inputs=[self.use_mindmap_check],
+            outputs=[self.use_mindmap, self.use_mindmap_check],
+            show_progress="hidden",
+        )
        self.chat_control.conversation_id.change(
            lambda: gr.update(visible=False),
            outputs=self.plot_panel,
@@ -722,6 +788,8 @@ class ChatPage(BasePage):
        session_reasoning_type: str,
        session_llm: str,
        session_use_mindmap: bool | str,
+        session_use_citation: str,
+        session_language: str,
        state: dict,
        user_id: int,
        *selecteds,
@@ -743,6 +811,10 @@ class ChatPage(BasePage):
            session_reasoning_type,
            "use mindmap",
            session_use_mindmap,
+            "use citation",
+            session_use_citation,
+            "language",
+            session_language,
        )
        print("Session LLM", session_llm)
        reasoning_mode = (
@@ -766,6 +838,14 @@ class ChatPage(BasePage):
        if session_use_mindmap not in (DEFAULT_SETTING, None):
            settings["reasoning.options.simple.create_mindmap"] = session_use_mindmap

+        if session_use_citation not in (DEFAULT_SETTING, None):
+            settings[
+                "reasoning.options.simple.highlight_citation"
+            ] = session_use_citation
+
+        if session_language not in (DEFAULT_SETTING, None):
+            settings["reasoning.lang"] = session_language
+
        # get retrievers
        retrievers = []
        for index in self._app.index_manager.indices:
@@ -798,6 +878,8 @@ class ChatPage(BasePage):
            reasoning_type,
            llm_type,
            use_mind_map,
+            use_citation,
+            language,
            state,
            user_id,
            *selecteds,
@@ -814,7 +896,15 @@ class ChatPage(BasePage):
        # construct the pipeline
        pipeline, reasoning_state = self.create_pipeline(
-            settings, reasoning_type, llm_type, use_mind_map, state, user_id, *selecteds
+            settings,
+            reasoning_type,
+            llm_type,
+            use_mind_map,
+            use_citation,
+            language,
+            state,
+            user_id,
+            *selecteds,
        )
        print("Reasoning state", reasoning_state)
        pipeline.set_output_queue(queue)

View File

@@ -28,6 +28,7 @@ class ChatPanel(BasePage):
            placeholder="Chat input",
            container=False,
            show_label=False,
+            elem_id="chat-input",
        )

    def submit_msg(self, chat_input, chat_history):

View File

@@ -1,17 +1,10 @@
-import html
import logging
import threading
-from collections import defaultdict
-from difflib import SequenceMatcher
-from functools import partial
from typing import Generator

-import numpy as np
-import tiktoken
from ktem.embeddings.manager import embedding_models_manager as embeddings
from ktem.llms.manager import llms
from ktem.reasoning.prompt_optimization import (
-    CreateMindmapPipeline,
    DecomposeQuestionPipeline,
    RewriteQuestionPipeline,
)
@@ -19,7 +12,6 @@ from ktem.utils.plantuml import PlantUML
from ktem.utils.render import Render
from ktem.utils.visualize_cited import CreateCitationVizPipeline
from plotly.io import to_json
-from theflow.settings import settings as flowsettings

from kotaemon.base import (
    AIMessage,
@@ -30,399 +22,20 @@ from kotaemon.base import (
    RetrievedDocument,
    SystemMessage,
)
-from kotaemon.indices.qa.citation import CitationPipeline
-from kotaemon.indices.splitters import TokenSplitter
-from kotaemon.llms import ChatLLM, PromptTemplate
+from kotaemon.indices.qa.citation_qa import (
+    CONTEXT_RELEVANT_WARNING_SCORE,
+    DEFAULT_QA_TEXT_PROMPT,
+    AnswerWithContextPipeline,
+)
+from kotaemon.indices.qa.citation_qa_inline import AnswerWithInlineCitation
+from kotaemon.indices.qa.format_context import PrepareEvidencePipeline
+from kotaemon.llms import ChatLLM

from ..utils import SUPPORTED_LANGUAGE_MAP
from .base import BaseReasoning

logger = logging.getLogger(__name__)
EVIDENCE_MODE_TEXT = 0
EVIDENCE_MODE_TABLE = 1
EVIDENCE_MODE_CHATBOT = 2
EVIDENCE_MODE_FIGURE = 3
MAX_IMAGES = 10
CITATION_TIMEOUT = 5.0
def find_text(search_span, context):
sentence_list = search_span.split("\n")
context = context.replace("\n", " ")
matches = []
# don't search for small text
if len(search_span) > 5:
for sentence in sentence_list:
match = SequenceMatcher(
None, sentence, context, autojunk=False
).find_longest_match()
if match.size > max(len(sentence) * 0.35, 5):
matches.append((match.b, match.b + match.size))
return matches
class PrepareEvidencePipeline(BaseComponent):
"""Prepare the evidence text from the list of retrieved documents
This step usually happens after `DocumentRetrievalPipeline`.
Args:
trim_func: a callback function or a BaseComponent, that splits a large
chunk of text into smaller ones. The first one will be retained.
"""
max_context_length: int = 32000
trim_func: TokenSplitter | None = None
def run(self, docs: list[RetrievedDocument]) -> Document:
evidence = ""
images = []
table_found = 0
evidence_modes = []
evidence_trim_func = (
self.trim_func
if self.trim_func
else TokenSplitter(
chunk_size=self.max_context_length,
chunk_overlap=0,
separator=" ",
tokenizer=partial(
tiktoken.encoding_for_model("gpt-3.5-turbo").encode,
allowed_special=set(),
disallowed_special="all",
),
)
)
for _id, retrieved_item in enumerate(docs):
retrieved_content = ""
page = retrieved_item.metadata.get("page_label", None)
source = filename = retrieved_item.metadata.get("file_name", "-")
if page:
source += f" (Page {page})"
if retrieved_item.metadata.get("type", "") == "table":
evidence_modes.append(EVIDENCE_MODE_TABLE)
if table_found < 5:
retrieved_content = retrieved_item.metadata.get(
"table_origin", retrieved_item.text
)
if retrieved_content not in evidence:
table_found += 1
evidence += (
f"<br><b>Table from {source}</b>\n"
+ retrieved_content
+ "\n<br>"
)
elif retrieved_item.metadata.get("type", "") == "chatbot":
evidence_modes.append(EVIDENCE_MODE_CHATBOT)
retrieved_content = retrieved_item.metadata["window"]
evidence += (
f"<br><b>Chatbot scenario from {filename} (Row {page})</b>\n"
+ retrieved_content
+ "\n<br>"
)
elif retrieved_item.metadata.get("type", "") == "image":
evidence_modes.append(EVIDENCE_MODE_FIGURE)
retrieved_content = retrieved_item.metadata.get("image_origin", "")
retrieved_caption = html.escape(retrieved_item.get_content())
evidence += (
f"<br><b>Figure from {source}</b>\n"
+ "<img width='85%' src='<src>' "
+ f"alt='{retrieved_caption}'/>"
+ "\n<br>"
)
images.append(retrieved_content)
else:
if "window" in retrieved_item.metadata:
retrieved_content = retrieved_item.metadata["window"]
else:
retrieved_content = retrieved_item.text
retrieved_content = retrieved_content.replace("\n", " ")
if retrieved_content not in evidence:
evidence += (
f"<br><b>Content from {source}: </b> "
+ retrieved_content
+ " \n<br>"
)
# resolve evidence mode
evidence_mode = EVIDENCE_MODE_TEXT
if EVIDENCE_MODE_FIGURE in evidence_modes:
evidence_mode = EVIDENCE_MODE_FIGURE
elif EVIDENCE_MODE_TABLE in evidence_modes:
evidence_mode = EVIDENCE_MODE_TABLE
# trim context by trim_len
print("len (original)", len(evidence))
if evidence:
texts = evidence_trim_func([Document(text=evidence)])
evidence = texts[0].text
print("len (trimmed)", len(evidence))
return Document(content=(evidence_mode, evidence, images))
DEFAULT_QA_TEXT_PROMPT = (
"Use the following pieces of context to answer the question at the end in detail with clear explanation. " # noqa: E501
"If you don't know the answer, just say that you don't know, don't try to "
"make up an answer. Give answer in "
"{lang}.\n\n"
"{context}\n"
"Question: {question}\n"
"Helpful Answer:"
)
DEFAULT_QA_TABLE_PROMPT = (
"Use the given context: texts, tables, and figures below to answer the question, "
"then provide answer with clear explanation."
"If you don't know the answer, just say that you don't know, "
"don't try to make up an answer. Give answer in {lang}.\n\n"
"Context:\n"
"{context}\n"
"Question: {question}\n"
"Helpful Answer:"
) # noqa
DEFAULT_QA_CHATBOT_PROMPT = (
"Pick the most suitable chatbot scenarios to answer the question at the end, "
"output the provided answer text. If you don't know the answer, "
"just say that you don't know. Keep the answer as concise as possible. "
"Give answer in {lang}.\n\n"
"Context:\n"
"{context}\n"
"Question: {question}\n"
"Answer:"
) # noqa
DEFAULT_QA_FIGURE_PROMPT = (
"Use the given context: texts, tables, and figures below to answer the question. "
"If you don't know the answer, just say that you don't know. "
"Give answer in {lang}.\n\n"
"Context: \n"
"{context}\n"
"Question: {question}\n"
"Answer: "
) # noqa
CONTEXT_RELEVANT_WARNING_SCORE = 0.7
class AnswerWithContextPipeline(BaseComponent):
"""Answer the question based on the evidence
Args:
llm: the language model to generate the answer
citation_pipeline: generates citation from the evidence
qa_template: the prompt template for LLM to generate answer (refer to
evidence_mode)
qa_table_template: the prompt template for LLM to generate answer for table
(refer to evidence_mode)
qa_chatbot_template: the prompt template for LLM to generate answer for
pre-made scenarios (refer to evidence_mode)
lang: the language of the answer. Currently support English and Japanese
"""
llm: ChatLLM = Node(default_callback=lambda _: llms.get_default())
vlm_endpoint: str = getattr(flowsettings, "KH_VLM_ENDPOINT", "")
use_multimodal: bool = getattr(flowsettings, "KH_REASONINGS_USE_MULTIMODAL", True)
citation_pipeline: CitationPipeline = Node(
default_callback=lambda _: CitationPipeline(llm=llms.get_default())
)
create_mindmap_pipeline: CreateMindmapPipeline = Node(
default_callback=lambda _: CreateMindmapPipeline(llm=llms.get_default())
)
qa_template: str = DEFAULT_QA_TEXT_PROMPT
qa_table_template: str = DEFAULT_QA_TABLE_PROMPT
qa_chatbot_template: str = DEFAULT_QA_CHATBOT_PROMPT
qa_figure_template: str = DEFAULT_QA_FIGURE_PROMPT
enable_citation: bool = False
enable_mindmap: bool = False
enable_citation_viz: bool = False
system_prompt: str = ""
lang: str = "English" # support English and Japanese
n_last_interactions: int = 5
def get_prompt(self, question, evidence, evidence_mode: int):
"""Prepare the prompt and other information for LLM"""
if evidence_mode == EVIDENCE_MODE_TEXT:
prompt_template = PromptTemplate(self.qa_template)
elif evidence_mode == EVIDENCE_MODE_TABLE:
prompt_template = PromptTemplate(self.qa_table_template)
elif evidence_mode == EVIDENCE_MODE_FIGURE:
if self.use_multimodal:
prompt_template = PromptTemplate(self.qa_figure_template)
else:
prompt_template = PromptTemplate(self.qa_template)
else:
prompt_template = PromptTemplate(self.qa_chatbot_template)
prompt = prompt_template.populate(
context=evidence,
question=question,
lang=self.lang,
)
return prompt, evidence
def run(
self, question: str, evidence: str, evidence_mode: int = 0, **kwargs
) -> Document:
return self.invoke(question, evidence, evidence_mode, **kwargs)
def invoke(
self,
question: str,
evidence: str,
evidence_mode: int = 0,
images: list[str] = [],
**kwargs,
) -> Document:
raise NotImplementedError
async def ainvoke( # type: ignore
self,
question: str,
evidence: str,
evidence_mode: int = 0,
images: list[str] = [],
**kwargs,
) -> Document:
"""Answer the question based on the evidence
In addition to the question and the evidence, this method also take into
account evidence_mode. The evidence_mode tells which kind of evidence is.
The kind of evidence affects:
1. How the evidence is represented.
2. The prompt to generate the answer.
By default, the evidence_mode is 0, which means the evidence is plain text with
no particular semantic representation. The evidence_mode can be:
1. "table": There will be HTML markup telling that there is a table
within the evidence.
2. "chatbot": There will be HTML markup telling that there is a chatbot.
This chatbot is a scenario, extracted from an Excel file, where each
row corresponds to an interaction.
Args:
question: the original question posed by user
evidence: the text that contain relevant information to answer the question
(determined by retrieval pipeline)
evidence_mode: the mode of evidence, 0 for text, 1 for table, 2 for chatbot
"""
raise NotImplementedError
def stream( # type: ignore
self,
question: str,
evidence: str,
evidence_mode: int = 0,
images: list[str] = [],
**kwargs,
) -> Generator[Document, None, Document]:
history = kwargs.get("history", [])
print(f"Got {len(images)} images")
# check if evidence exists, use QA prompt
if evidence:
prompt, evidence = self.get_prompt(question, evidence, evidence_mode)
else:
prompt = question
# retrieve the citation
citation = None
mindmap = None
def citation_call():
nonlocal citation
citation = self.citation_pipeline(context=evidence, question=question)
def mindmap_call():
nonlocal mindmap
mindmap = self.create_mindmap_pipeline(context=evidence, question=question)
citation_thread = None
mindmap_thread = None
# execute function call in thread
if evidence:
if self.enable_citation:
citation_thread = threading.Thread(target=citation_call)
citation_thread.start()
if self.enable_mindmap:
mindmap_thread = threading.Thread(target=mindmap_call)
mindmap_thread.start()
output = ""
logprobs = []
messages = []
if self.system_prompt:
messages.append(SystemMessage(content=self.system_prompt))
for human, ai in history[-self.n_last_interactions :]:
messages.append(HumanMessage(content=human))
messages.append(AIMessage(content=ai))
if self.use_multimodal and evidence_mode == EVIDENCE_MODE_FIGURE:
# create image message:
messages.append(
HumanMessage(
content=[
{"type": "text", "text": prompt},
]
+ [
{
"type": "image_url",
"image_url": {"url": image},
}
for image in images[:MAX_IMAGES]
],
)
)
else:
# append main prompt
messages.append(HumanMessage(content=prompt))
try:
# try streaming first
print("Trying LLM streaming")
for out_msg in self.llm.stream(messages):
output += out_msg.text
logprobs += out_msg.logprobs
yield Document(channel="chat", content=out_msg.text)
except NotImplementedError:
print("Streaming is not supported, falling back to normal processing")
output = self.llm(messages).text
yield Document(channel="chat", content=output)
if logprobs:
qa_score = np.exp(np.average(logprobs))
else:
qa_score = None
if citation_thread:
citation_thread.join(timeout=CITATION_TIMEOUT)
if mindmap_thread:
mindmap_thread.join(timeout=CITATION_TIMEOUT)
answer = Document(
text=output,
metadata={
"citation_viz": self.enable_citation_viz,
"mindmap": mindmap,
"citation": citation,
"qa_score": qa_score,
},
)
return answer
class AddQueryContextPipeline(BaseComponent):

@@ -481,7 +94,7 @@ class FullQAPipeline(BaseReasoning):
    retrievers: list[BaseComponent]

    evidence_pipeline: PrepareEvidencePipeline = PrepareEvidencePipeline.withx()
-    answering_pipeline: AnswerWithContextPipeline = AnswerWithContextPipeline.withx()
+    answering_pipeline: AnswerWithContextPipeline
    rewrite_pipeline: RewriteQuestionPipeline | None = None
    create_citation_viz_pipeline: CreateCitationVizPipeline = Node(
        default_callback=lambda _: CreateCitationVizPipeline(
@@ -548,104 +161,35 @@ class FullQAPipeline(BaseReasoning):
        return docs, info

    def prepare_citations(self, answer, docs) -> tuple[list[Document], list[Document]]:
        """Prepare the citations to show on the UI"""
        with_citation, without_citation = [], []
        spans = defaultdict(list)
        has_llm_score = any("llm_trulens_score" in doc.metadata for doc in docs)

        if answer.metadata["citation"]:
            evidences = answer.metadata["citation"].evidences
            for quote in evidences:
                matched_excerpts = []
                for doc in docs:
                    matches = find_text(quote, doc.text)
                    for start, end in matches:
                        if "|" not in doc.text[start:end]:
                            spans[doc.doc_id].append(
                                {
                                    "start": start,
                                    "end": end,
                                }
                            )
                            matched_excerpts.append(doc.text[start:end])

                # print("Matched citation:", quote, matched_excerpts),

        id2docs = {doc.doc_id: doc for doc in docs}
        not_detected = set(id2docs.keys()) - set(spans.keys())

        # render highlight spans
        for _id, ss in spans.items():
            if not ss:
                not_detected.add(_id)
                continue

            cur_doc = id2docs[_id]
            highlight_text = ""

            ss = sorted(ss, key=lambda x: x["start"])
            text = cur_doc.text[: ss[0]["start"]]
            for idx, span in enumerate(ss):
                to_highlight = cur_doc.text[span["start"] : span["end"]]
                if len(to_highlight) > len(highlight_text):
                    highlight_text = to_highlight
                text += Render.highlight(to_highlight)
                if idx < len(ss) - 1:
                    text += cur_doc.text[span["end"] : ss[idx + 1]["start"]]
            text += cur_doc.text[ss[-1]["end"] :]

            # add to display list
            with_citation.append(
                Document(
                    channel="info",
                    content=Render.collapsible_with_header_score(
                        cur_doc,
                        override_text=text,
                        highlight_text=highlight_text,
                        open_collapsible=True,
                    ),
                )
            )

        print("Got {} cited docs".format(len(with_citation)))

        sorted_not_detected_items_with_scores = [
            (id_, id2docs[id_].metadata.get("llm_trulens_score", 0.0))
            for id_ in not_detected
        ]
        sorted_not_detected_items_with_scores.sort(key=lambda x: x[1], reverse=True)

        for id_, _ in sorted_not_detected_items_with_scores:
            doc = id2docs[id_]
            doc_score = doc.metadata.get("llm_trulens_score", 0.0)
            is_open = not has_llm_score or (
                doc_score > CONTEXT_RELEVANT_WARNING_SCORE and len(with_citation) == 0
            )
            without_citation.append(
                Document(
                    channel="info",
                    content=Render.collapsible_with_header_score(
                        doc, open_collapsible=is_open
                    ),
                )
            )

        return with_citation, without_citation
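The `prepare_citations` helper above locates each quoted evidence string inside the retrieved documents with `find_text`, collects the matched character spans per document, and rebuilds the document text with `<mark>` highlights around them (later hunks call it as `self.answering_pipeline.prepare_citations`, so the logic now lives on the answering pipeline). A stripped-down sketch of just the span-rebuilding step, using a plain `highlight` function as a stand-in for `Render.highlight`:

def highlight(text: str) -> str:
    # stand-in for Render.highlight
    return f"<mark>{text}</mark>"


def render_with_spans(text: str, spans: list[dict]) -> str:
    """Rebuild `text` with each (start, end) span wrapped in <mark>.

    Mirrors the inner loop of prepare_citations; assumes at least one span.
    """
    spans = sorted(spans, key=lambda s: s["start"])
    out = text[: spans[0]["start"]]
    for idx, span in enumerate(spans):
        out += highlight(text[span["start"] : span["end"]])
        if idx < len(spans) - 1:
            out += text[span["end"] : spans[idx + 1]["start"]]
    out += text[spans[-1]["end"] :]
    return out


if __name__ == "__main__":
    doc = "Paris is the capital of France and a major European city."
    start = doc.index("France")
    print(render_with_spans(doc, [{"start": 0, "end": 5}, {"start": start, "end": start + 6}]))
    # -> <mark>Paris</mark> is the capital of <mark>France</mark> and a major European city.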
    def prepare_mindmap(self, answer) -> Document | None:
        mindmap = answer.metadata["mindmap"]
        if mindmap:
            mindmap_text = mindmap.text
            uml_renderer = PlantUML()
-           mindmap_svg = uml_renderer.process(mindmap_text)
+           try:
+               mindmap_svg = uml_renderer.process(mindmap_text)
+           except Exception as e:
+               print("Failed to process mindmap:", e)
+               mindmap_svg = "<svg></svg>"
+           # post-process the mindmap SVG
+           mindmap_svg = (
+               mindmap_svg.replace("sans-serif", "Quicksand, sans-serif")
+               .replace("#181818", "#cecece")
+               .replace("background:#FFFFF", "background:none")
+               .replace("stroke-width:1", "stroke-width:2")
+           )

            mindmap_content = Document(
                channel="info",
                content=Render.collapsible(
                    header="""
                    <i>Mindmap</i>
-                   <a href="#" id='mindmap-toggle'">
-                   [Expand]
-                   </a>""",
+                   <a href="#" id='mindmap-toggle'>
+                   [Expand]</a>
+                   <a href="#" id='mindmap-export'>
+                   [Export]</a>""",
                    content=mindmap_svg,
                    open=True,
                ),
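The added try/except keeps a malformed PlantUML mindmap from breaking the whole answer (it falls back to an empty SVG), and the chained `str.replace` calls retheme the generated SVG for the UI. A quick illustration of what those replacements do, run on a hand-made SVG fragment rather than real PlantUML output:

svg = (
    "<svg style='background:#FFFFF'>"
    "<text style='font-family:sans-serif;fill:#181818;stroke-width:1'>root</text>"
    "</svg>"
)

themed = (
    svg.replace("sans-serif", "Quicksand, sans-serif")
    .replace("#181818", "#cecece")
    .replace("background:#FFFFF", "background:none")
    .replace("stroke-width:1", "stroke-width:2")
)
print(themed)
# <svg style='background:none'><text style='font-family:Quicksand, sans-serif;
# fill:#cecece;stroke-width:2'>root</text></svg>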
@@ -674,7 +218,9 @@ class FullQAPipeline(BaseReasoning):
    def show_citations_and_addons(self, answer, docs, question):
        # show the evidence
-        with_citation, without_citation = self.prepare_citations(answer, docs)
+        with_citation, without_citation = self.answering_pipeline.prepare_citations(
+            answer, docs
+        )
        mindmap_output = self.prepare_mindmap(answer)
        citation_plot_output = self.prepare_citation_viz(answer, question, docs)
@@ -773,6 +319,13 @@ class FullQAPipeline(BaseReasoning):
        return answer

+    @classmethod
+    def prepare_pipeline_instance(cls, settings, retrievers):
+        return cls(
+            retrievers=retrievers,
+            rewrite_pipeline=RewriteQuestionPipeline(),
+        )
+
    @classmethod
    def get_pipeline(cls, settings, states, retrievers):
        """Get the reasoning pipeline
@@ -783,10 +336,7 @@ class FullQAPipeline(BaseReasoning):
        """
        max_context_length_setting = settings.get("reasoning.max_context_length", 32000)
-        pipeline = cls(
-            retrievers=retrievers,
-            rewrite_pipeline=RewriteQuestionPipeline(),
-        )
+        pipeline = cls.prepare_pipeline_instance(settings, retrievers)

        prefix = f"reasoning.options.{cls.get_info()['id']}"
        llm_name = settings.get(f"{prefix}.llm", None)

@@ -797,13 +347,22 @@ class FullQAPipeline(BaseReasoning):
        evidence_pipeline.max_context_length = max_context_length_setting

        # answering pipeline configuration
-        answer_pipeline = pipeline.answering_pipeline
+        use_inline_citation = settings[f"{prefix}.highlight_citation"] == "inline"
+        if use_inline_citation:
+            answer_pipeline = pipeline.answering_pipeline = AnswerWithInlineCitation()
+        else:
+            answer_pipeline = pipeline.answering_pipeline = AnswerWithContextPipeline()
        answer_pipeline.llm = llm
        answer_pipeline.citation_pipeline.llm = llm
        answer_pipeline.n_last_interactions = settings[f"{prefix}.n_last_interactions"]
-        answer_pipeline.enable_citation = settings[f"{prefix}.highlight_citation"]
+        answer_pipeline.enable_citation = (
+            settings[f"{prefix}.highlight_citation"] != "off"
+        )
        answer_pipeline.enable_mindmap = settings[f"{prefix}.create_mindmap"]
        answer_pipeline.enable_citation_viz = settings[f"{prefix}.create_citation_viz"]
+        answer_pipeline.use_multimodal = settings[f"{prefix}.use_multimodal"]
        answer_pipeline.system_prompt = settings[f"{prefix}.system_prompt"]
        answer_pipeline.qa_template = settings[f"{prefix}.qa_prompt"]
        answer_pipeline.lang = SUPPORTED_LANGUAGE_MAP.get(
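With `highlight_citation` now a three-way choice, the hunk above turns the setting into two decisions: which answering class to instantiate (`AnswerWithInlineCitation` only for "inline") and whether citation is enabled at all (anything other than "off"). A tiny sketch of that mapping, using placeholder classes in place of the real pipelines:

class AnswerWithContextPipeline:  # placeholder
    pass


class AnswerWithInlineCitation(AnswerWithContextPipeline):  # placeholder
    pass


def build_answer_pipeline(citation_style: str) -> AnswerWithContextPipeline:
    """Mirror the selection logic in get_pipeline above."""
    if citation_style == "inline":
        pipeline = AnswerWithInlineCitation()
    else:
        pipeline = AnswerWithContextPipeline()
    pipeline.enable_citation = citation_style != "off"
    return pipeline


for style in ("highlight", "inline", "off"):
    p = build_answer_pipeline(style)
    print(style, type(p).__name__, p.enable_citation)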
@@ -848,9 +407,10 @@ class FullQAPipeline(BaseReasoning):
                ),
            },
            "highlight_citation": {
-                "name": "Highlight Citation",
-                "value": True,
-                "component": "checkbox",
+                "name": "Citation style",
+                "value": "highlight",
+                "component": "radio",
+                "choices": ["highlight", "inline", "off"],
            },
            "create_mindmap": {
                "name": "Create Mindmap",

@@ -862,6 +422,11 @@ class FullQAPipeline(BaseReasoning):
                "value": False,
                "component": "checkbox",
            },
+            "use_multimodal": {
+                "name": "Use Multimodal Input",
+                "value": False,
+                "component": "checkbox",
+            },
            "system_prompt": {
                "name": "System Prompt",
                "value": "This is a question answering system",
@@ -979,7 +544,9 @@ class FullDecomposeQAPipeline(FullQAPipeline):
            )

        # show the evidence
-        with_citation, without_citation = self.prepare_citations(answer, docs)
+        with_citation, without_citation = self.answering_pipeline.prepare_citations(
+            answer, docs
+        )
        if not with_citation and not without_citation:
            yield Document(channel="info", content="<h5><b>No evidence found.</b></h5>")
        else:
@@ -999,13 +566,7 @@ class FullDecomposeQAPipeline(FullQAPipeline):
        return user_settings

    @classmethod
-    def get_pipeline(cls, settings, states, retrievers):
-        """Get the reasoning pipeline
-        Args:
-            settings: the settings for the pipeline
-            retrievers: the retrievers to use
-        """
+    def prepare_pipeline_instance(cls, settings, retrievers):
        prefix = f"reasoning.options.{cls.get_info()['id']}"
        pipeline = cls(
            retrievers=retrievers,

@@ -1013,31 +574,6 @@ class FullDecomposeQAPipeline(FullQAPipeline):
                prompt_template=settings.get(f"{prefix}.decompose_prompt")
            ),
        )
-        llm_name = settings.get(f"{prefix}.llm", None)
-        llm = llms.get(llm_name, llms.get_default())
-        # answering pipeline configuration
-        answer_pipeline = pipeline.answering_pipeline
-        answer_pipeline.llm = llm
-        answer_pipeline.citation_pipeline.llm = llm
-        answer_pipeline.n_last_interactions = settings[f"{prefix}.n_last_interactions"]
-        answer_pipeline.enable_citation = settings[f"{prefix}.highlight_citation"]
-        answer_pipeline.system_prompt = settings[f"{prefix}.system_prompt"]
-        answer_pipeline.qa_template = settings[f"{prefix}.qa_prompt"]
-        answer_pipeline.lang = SUPPORTED_LANGUAGE_MAP.get(
-            settings["reasoning.lang"], "English"
-        )
-        pipeline.add_query_context.llm = llm
-        pipeline.add_query_context.n_last_interactions = settings[
-            f"{prefix}.n_last_interactions"
-        ]
-        pipeline.trigger_context = settings[f"{prefix}.trigger_context"]
-        pipeline.use_rewrite = states.get("app", {}).get("regen", False)
-        if pipeline.rewrite_pipeline:
-            pipeline.rewrite_pipeline.llm = llm
        return pipeline

    @classmethod

View File

@@ -40,7 +40,10 @@ class Render:
    def collapsible(header, content, open: bool = False) -> str:
        """Render an HTML friendly collapsible section"""
        o = " open" if open else ""
-        return f"<details{o}><summary>{header}</summary>{content}</details><br>"
+        return (
+            f"<details class='evidence' {o}><summary>"
+            f"{header}</summary>{content}</details><br>"
+        )

    @staticmethod
    def table(text: str) -> str:

@@ -103,9 +106,10 @@ class Render:
        """  # noqa

    @staticmethod
-    def highlight(text: str) -> str:
+    def highlight(text: str, elem_id: str | None = None) -> str:
        """Highlight text"""
-        return f"<mark>{text}</mark>"
+        id_text = f" id='mark-{elem_id}'" if elem_id else ""
+        return f"<mark{id_text}>{text}</mark>"

    @staticmethod
    def image(url: str, text: str = "") -> str:
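The new optional `elem_id` parameter lets each highlight carry a stable `mark-<id>` anchor, which the inline citation style can point at from the answer text (an assumption based on this diff; the linking itself happens elsewhere in the commit). A quick standalone check of the behaviour, copying the updated function body shown above:

# behaves like the updated Render.highlight above
def highlight(text: str, elem_id: str | None = None) -> str:
    id_text = f" id='mark-{elem_id}'" if elem_id else ""
    return f"<mark{id_text}>{text}</mark>"


print(highlight("cited sentence"))                # <mark>cited sentence</mark>
print(highlight("cited sentence", elem_id="0"))   # <mark id='mark-0'>cited sentence</mark>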