feat: add mindmap visualization (#405) bump:minor
This commit is contained in:
parent
4764b0e82a
commit
e6fa1af404
|
@ -55,6 +55,8 @@ class BaseApp:
|
||||||
"PDFJS_PREBUILT_DIR",
|
"PDFJS_PREBUILT_DIR",
|
||||||
pdf_js_dist_dir,
|
pdf_js_dist_dir,
|
||||||
)
|
)
|
||||||
|
with (dir_assets / "js" / "svg-pan-zoom.min.js").open() as fi:
|
||||||
|
self._svg_js = fi.read()
|
||||||
|
|
||||||
self._favicon = str(dir_assets / "img" / "favicon.svg")
|
self._favicon = str(dir_assets / "img" / "favicon.svg")
|
||||||
|
|
||||||
|
@ -172,6 +174,9 @@ class BaseApp:
|
||||||
"<script type='module' "
|
"<script type='module' "
|
||||||
"src='https://cdn.skypack.dev/pdfjs-viewer-element'>"
|
"src='https://cdn.skypack.dev/pdfjs-viewer-element'>"
|
||||||
"</script>"
|
"</script>"
|
||||||
|
"<script>"
|
||||||
|
f"{self._svg_js}"
|
||||||
|
"</script>"
|
||||||
)
|
)
|
||||||
|
|
||||||
with gr.Blocks(
|
with gr.Blocks(
|
||||||
|
|
3
libs/ktem/ktem/assets/js/svg-pan-zoom.min.js
vendored
Normal file
3
libs/ktem/ktem/assets/js/svg-pan-zoom.min.js
vendored
Normal file
File diff suppressed because one or more lines are too long
|
@ -39,6 +39,29 @@ function() {
|
||||||
for (var i = 0; i < links.length; i++) {
|
for (var i = 0; i < links.length; i++) {
|
||||||
links[i].onclick = openModal;
|
links[i].onclick = openModal;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var mindmap_el = document.getElementById('mindmap');
|
||||||
|
if (mindmap_el) {
|
||||||
|
var output = svgPanZoom(mindmap_el);
|
||||||
|
}
|
||||||
|
|
||||||
|
var link = document.getElementById("mindmap-toggle");
|
||||||
|
if (link) {
|
||||||
|
link.onclick = function(event) {
|
||||||
|
event.preventDefault(); // Prevent the default link behavior
|
||||||
|
var div = document.getElementById("mindmap-wrapper");
|
||||||
|
if (div) {
|
||||||
|
var currentHeight = div.style.height;
|
||||||
|
if (currentHeight === '400px') {
|
||||||
|
var contentHeight = div.scrollHeight;
|
||||||
|
div.style.height = contentHeight + 'px';
|
||||||
|
} else {
|
||||||
|
div.style.height = '400px'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
return [links.length]
|
return [links.length]
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -1,9 +1,11 @@
|
||||||
from .decompose_question import DecomposeQuestionPipeline
|
from .decompose_question import DecomposeQuestionPipeline
|
||||||
from .fewshot_rewrite_question import FewshotRewriteQuestionPipeline
|
from .fewshot_rewrite_question import FewshotRewriteQuestionPipeline
|
||||||
|
from .mindmap import CreateMindmapPipeline
|
||||||
from .rewrite_question import RewriteQuestionPipeline
|
from .rewrite_question import RewriteQuestionPipeline
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"DecomposeQuestionPipeline",
|
"DecomposeQuestionPipeline",
|
||||||
"FewshotRewriteQuestionPipeline",
|
"FewshotRewriteQuestionPipeline",
|
||||||
"RewriteQuestionPipeline",
|
"RewriteQuestionPipeline",
|
||||||
|
"CreateMindmapPipeline",
|
||||||
]
|
]
|
||||||
|
|
52
libs/ktem/ktem/reasoning/prompt_optimization/mindmap.py
Normal file
52
libs/ktem/ktem/reasoning/prompt_optimization/mindmap.py
Normal file
|
@ -0,0 +1,52 @@
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from ktem.llms.manager import llms
|
||||||
|
|
||||||
|
from kotaemon.base import BaseComponent, Document, HumanMessage, Node, SystemMessage
|
||||||
|
from kotaemon.llms import ChatLLM, PromptTemplate
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CreateMindmapPipeline(BaseComponent):
    """Ask a chat LLM to write a PlantUML mindmap for a question and its context.

    The pipeline fills ``prompt_template`` with the question and the context,
    sends it as the human message after ``SYSTEM_PROMPT`` and returns the LLM
    response unchanged. Rendering the PlantUML text is left to the caller.
    """

    # Chat model used to produce the mindmap; falls back to the default LLM
    # registered in the ``llms`` manager.
    llm: ChatLLM = Node(default_callback=lambda _: llms.get_default())

    # NOTE(review): the prompt asks the model to "format it as a code", so the
    # reply may be wrapped in a fenced code block - confirm that the consumer
    # of the output copes with that.
    SYSTEM_PROMPT = """
From now on you will behave as "MapGPT" and, for every text the user will submit, you are going to create a PlantUML mind map file for the inputted text to best describe main ideas. Format it as a code and remember that the mind map should be in the same language as the inputted context. You don't have to provide a general example for the mind map format before the user inputs the text.
    """  # noqa: E501
    MINDMAP_PROMPT_TEMPLATE = """
Question:
{question}

Context:
{context}

Generate a sample PlantUML mindmap based on the provided question and context above. Only include context relevant to the question to produce the mindmap.

Use the template like this:

@startmindmap
* Title
** Item A
*** Item B
**** Item C
*** Item D
@endmindmap
    """  # noqa: E501
    # Template actually used by ``run``; override it to customise the prompt.
    # It must keep the ``{question}`` and ``{context}`` placeholders.
    prompt_template: str = MINDMAP_PROMPT_TEMPLATE

    def run(self, question: str, context: str) -> Document:  # type: ignore
        """Generate the mindmap source for ``question`` given ``context``.

        Args:
            question: the user question the mindmap should focus on.
            context: the evidence text the mindmap is built from.

        Returns:
            Whatever ``self.llm`` returns for the two-message conversation
            (system prompt followed by the populated user prompt).
        """
        user_prompt = PromptTemplate(self.prompt_template).populate(
            question=question,
            context=context,
        )

        return self.llm(
            [
                SystemMessage(content=self.SYSTEM_PROMPT),
                HumanMessage(content=user_prompt),
            ]
        )
|
|
@ -10,9 +10,11 @@ import numpy as np
|
||||||
import tiktoken
|
import tiktoken
|
||||||
from ktem.llms.manager import llms
|
from ktem.llms.manager import llms
|
||||||
from ktem.reasoning.prompt_optimization import (
|
from ktem.reasoning.prompt_optimization import (
|
||||||
|
CreateMindmapPipeline,
|
||||||
DecomposeQuestionPipeline,
|
DecomposeQuestionPipeline,
|
||||||
RewriteQuestionPipeline,
|
RewriteQuestionPipeline,
|
||||||
)
|
)
|
||||||
|
from ktem.utils.plantuml import PlantUML
|
||||||
from ktem.utils.render import Render
|
from ktem.utils.render import Render
|
||||||
from theflow.settings import settings as flowsettings
|
from theflow.settings import settings as flowsettings
|
||||||
|
|
||||||
|
@ -227,6 +229,9 @@ class AnswerWithContextPipeline(BaseComponent):
|
||||||
citation_pipeline: CitationPipeline = Node(
|
citation_pipeline: CitationPipeline = Node(
|
||||||
default_callback=lambda _: CitationPipeline(llm=llms.get_default())
|
default_callback=lambda _: CitationPipeline(llm=llms.get_default())
|
||||||
)
|
)
|
||||||
|
create_mindmap_pipeline: CreateMindmapPipeline = Node(
|
||||||
|
default_callback=lambda _: CreateMindmapPipeline(llm=llms.get_default())
|
||||||
|
)
|
||||||
|
|
||||||
qa_template: str = DEFAULT_QA_TEXT_PROMPT
|
qa_template: str = DEFAULT_QA_TEXT_PROMPT
|
||||||
qa_table_template: str = DEFAULT_QA_TABLE_PROMPT
|
qa_table_template: str = DEFAULT_QA_TABLE_PROMPT
|
||||||
|
@ -234,6 +239,8 @@ class AnswerWithContextPipeline(BaseComponent):
|
||||||
qa_figure_template: str = DEFAULT_QA_FIGURE_PROMPT
|
qa_figure_template: str = DEFAULT_QA_FIGURE_PROMPT
|
||||||
|
|
||||||
enable_citation: bool = False
|
enable_citation: bool = False
|
||||||
|
enable_mindmap: bool = False
|
||||||
|
|
||||||
system_prompt: str = ""
|
system_prompt: str = ""
|
||||||
lang: str = "English" # support English and Japanese
|
lang: str = "English" # support English and Japanese
|
||||||
n_last_interactions: int = 5
|
n_last_interactions: int = 5
|
||||||
|
@ -325,17 +332,28 @@ class AnswerWithContextPipeline(BaseComponent):
|
||||||
|
|
||||||
# retrieve the citation
|
# retrieve the citation
|
||||||
citation = None
|
citation = None
|
||||||
|
mindmap = None
|
||||||
|
|
||||||
def citation_call():
|
def citation_call():
|
||||||
nonlocal citation
|
nonlocal citation
|
||||||
citation = self.citation_pipeline(context=evidence, question=question)
|
citation = self.citation_pipeline(context=evidence, question=question)
|
||||||
|
|
||||||
if evidence and self.enable_citation:
|
def mindmap_call():
|
||||||
# execute function call in thread
|
nonlocal mindmap
|
||||||
citation_thread = threading.Thread(target=citation_call)
|
mindmap = self.create_mindmap_pipeline(context=evidence, question=question)
|
||||||
citation_thread.start()
|
|
||||||
else:
|
citation_thread = None
|
||||||
citation_thread = None
|
mindmap_thread = None
|
||||||
|
|
||||||
|
# execute function call in thread
|
||||||
|
if evidence:
|
||||||
|
if self.enable_citation:
|
||||||
|
citation_thread = threading.Thread(target=citation_call)
|
||||||
|
citation_thread.start()
|
||||||
|
|
||||||
|
if self.enable_mindmap:
|
||||||
|
mindmap_thread = threading.Thread(target=mindmap_call)
|
||||||
|
mindmap_thread.start()
|
||||||
|
|
||||||
output = ""
|
output = ""
|
||||||
logprobs = []
|
logprobs = []
|
||||||
|
@ -386,10 +404,12 @@ class AnswerWithContextPipeline(BaseComponent):
|
||||||
|
|
||||||
if citation_thread:
|
if citation_thread:
|
||||||
citation_thread.join(timeout=CITATION_TIMEOUT)
|
citation_thread.join(timeout=CITATION_TIMEOUT)
|
||||||
|
if mindmap_thread:
|
||||||
|
mindmap_thread.join(timeout=CITATION_TIMEOUT)
|
||||||
|
|
||||||
answer = Document(
|
answer = Document(
|
||||||
text=output,
|
text=output,
|
||||||
metadata={"citation": citation, "qa_score": qa_score},
|
metadata={"mindmap": mindmap, "citation": citation, "qa_score": qa_score},
|
||||||
)
|
)
|
||||||
|
|
||||||
return answer
|
return answer
|
||||||
|
@ -597,9 +617,35 @@ class FullQAPipeline(BaseReasoning):
|
||||||
)
|
)
|
||||||
return with_citation, without_citation
|
return with_citation, without_citation
|
||||||
|
|
||||||
def show_citations(self, answer, docs):
|
def prepare_mindmap(self, answer) -> Document | None:
    """Turn the mindmap attached to ``answer`` into an info-panel document.

    Args:
        answer: object whose ``metadata`` mapping may contain a ``"mindmap"``
            entry exposing the PlantUML source through ``.text``.

    Returns:
        A ``Document`` on the ``"info"`` channel holding a collapsible block
        with the rendered diagram, or ``None`` when there is no mindmap or the
        diagram could not be rendered.
    """
    mindmap = answer.metadata.get("mindmap")
    if not mindmap:
        return None

    # Imported here so this method stays self-contained; both modules are
    # only needed on the failure path of an optional feature.
    import logging

    from ktem.utils.plantuml import PlantUMLError

    try:
        mindmap_svg = PlantUML().process(mindmap.text)
    except (PlantUMLError, OSError):
        # The mindmap is an optional add-on rendered by a remote server: a
        # server or network failure must not prevent the answer and its
        # citations from being shown.
        logging.getLogger(__name__).warning(
            "Could not render the mindmap; skipping it", exc_info=True
        )
        return None

    # The ``mindmap-toggle`` id is looked up by the front-end script that
    # expands and collapses the diagram.
    return Document(
        channel="info",
        content=Render.collapsible(
            header="""
            <i>Mindmap</i>
            <a href="#" id='mindmap-toggle'>
                [Expand]
            </a>""",
            content=mindmap_svg,
            open=True,
        ),
    )
|
||||||
|
|
||||||
|
def show_citations_and_addons(self, answer, docs):
|
||||||
# show the evidence
|
# show the evidence
|
||||||
with_citation, without_citation = self.prepare_citations(answer, docs)
|
with_citation, without_citation = self.prepare_citations(answer, docs)
|
||||||
|
mindmap_output = self.prepare_mindmap(answer)
|
||||||
|
|
||||||
if not with_citation and not without_citation:
|
if not with_citation and not without_citation:
|
||||||
yield Document(channel="info", content="<h5><b>No evidence found.</b></h5>")
|
yield Document(channel="info", content="<h5><b>No evidence found.</b></h5>")
|
||||||
else:
|
else:
|
||||||
|
@ -611,6 +657,10 @@ class FullQAPipeline(BaseReasoning):
|
||||||
# clear previous info
|
# clear previous info
|
||||||
yield Document(channel="info", content=None)
|
yield Document(channel="info", content=None)
|
||||||
|
|
||||||
|
# yield mindmap output
|
||||||
|
if mindmap_output:
|
||||||
|
yield mindmap_output
|
||||||
|
|
||||||
# yield warning message
|
# yield warning message
|
||||||
if has_llm_score and max_llm_rerank_score < CONTEXT_RELEVANT_WARNING_SCORE:
|
if has_llm_score and max_llm_rerank_score < CONTEXT_RELEVANT_WARNING_SCORE:
|
||||||
yield Document(
|
yield Document(
|
||||||
|
@ -683,7 +733,7 @@ class FullQAPipeline(BaseReasoning):
|
||||||
if scoring_thread:
|
if scoring_thread:
|
||||||
scoring_thread.join()
|
scoring_thread.join()
|
||||||
|
|
||||||
yield from self.show_citations(answer, docs)
|
yield from self.show_citations_and_addons(answer, docs)
|
||||||
|
|
||||||
return answer
|
return answer
|
||||||
|
|
||||||
|
@ -716,6 +766,7 @@ class FullQAPipeline(BaseReasoning):
|
||||||
answer_pipeline.citation_pipeline.llm = llm
|
answer_pipeline.citation_pipeline.llm = llm
|
||||||
answer_pipeline.n_last_interactions = settings[f"{prefix}.n_last_interactions"]
|
answer_pipeline.n_last_interactions = settings[f"{prefix}.n_last_interactions"]
|
||||||
answer_pipeline.enable_citation = settings[f"{prefix}.highlight_citation"]
|
answer_pipeline.enable_citation = settings[f"{prefix}.highlight_citation"]
|
||||||
|
answer_pipeline.enable_mindmap = settings[f"{prefix}.create_mindmap"]
|
||||||
answer_pipeline.system_prompt = settings[f"{prefix}.system_prompt"]
|
answer_pipeline.system_prompt = settings[f"{prefix}.system_prompt"]
|
||||||
answer_pipeline.qa_template = settings[f"{prefix}.qa_prompt"]
|
answer_pipeline.qa_template = settings[f"{prefix}.qa_prompt"]
|
||||||
answer_pipeline.lang = SUPPORTED_LANGUAGE_MAP.get(
|
answer_pipeline.lang = SUPPORTED_LANGUAGE_MAP.get(
|
||||||
|
@ -764,6 +815,11 @@ class FullQAPipeline(BaseReasoning):
|
||||||
"value": True,
|
"value": True,
|
||||||
"component": "checkbox",
|
"component": "checkbox",
|
||||||
},
|
},
|
||||||
|
"create_mindmap": {
|
||||||
|
"name": "Create Mindmap",
|
||||||
|
"value": False,
|
||||||
|
"component": "checkbox",
|
||||||
|
},
|
||||||
"system_prompt": {
|
"system_prompt": {
|
||||||
"name": "System Prompt",
|
"name": "System Prompt",
|
||||||
"value": "This is a question answering system",
|
"value": "This is a question answering system",
|
||||||
|
|
113
libs/ktem/ktem/utils/plantuml.py
Normal file
113
libs/ktem/ktem/utils/plantuml.py
Normal file
|
@ -0,0 +1,113 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import string
|
||||||
|
from zlib import compress
|
||||||
|
|
||||||
|
import httplib2
|
||||||
|
import six # type: ignore
|
||||||
|
|
||||||
|
# Pick the byte-translation-table factory for the running interpreter.
if not six.PY2:
    maketrans = bytes.maketrans
else:
    from string import maketrans


# Alphabet used by the PlantUML server, and the standard base64 alphabet it
# is mapped from. Both are 64 characters long and are matched position by
# position when the translation table is built.
plantuml_alphabet = "".join(
    (string.digits, string.ascii_uppercase, string.ascii_lowercase, "-_")
)
base64_alphabet = "".join(
    (string.ascii_uppercase, string.ascii_lowercase, string.digits, "+/")
)

# Translation table turning base64 output into PlantUML's encoding.
_b64_bytes = base64_alphabet.encode("utf-8")
_plantuml_bytes = plantuml_alphabet.encode("utf-8")
b64_to_plantuml = maketrans(_b64_bytes, _plantuml_bytes)
|
||||||
|
|
||||||
|
|
||||||
|
class PlantUMLError(Exception):
    """Base class for every error raised while processing PlantUML markup."""


class PlantUMLConnectionError(PlantUMLError):
    """Error connecting or talking to the PlantUML server."""


class PlantUMLHTTPError(PlantUMLConnectionError):
    """The request to the PlantUML server returned an HTTP error.

    Attributes:
        response: the response object; its ``status`` and ``reason`` are used
            to build the error message.
        content: the body that came back with the response.
        message: ``"<status>: <reason>"`` unless a truthy ``message``
            attribute already exists on the instance or its class.
    """

    def __init__(self, response, content, *args, **kwdargs):
        """Store the failed response and build the ``status: reason`` message.

        Args:
            response: object exposing ``status`` and ``reason``.
            content: body returned by the server.
            *args: extra positional arguments forwarded to ``Exception``.
            **kwdargs: forwarded to the parent constructor as-is.
        """
        self.response = response
        self.content = content
        message = f"{response.status}: {response.reason}"
        # Respect a message pre-set by a subclass instead of overwriting it.
        if not getattr(self, "message", None):
            self.message = message
        super().__init__(message, *args, **kwdargs)
|
||||||
|
|
||||||
|
|
||||||
|
def deflate_and_encode(plantuml_text):
    """Compress PlantUML markup and encode it the way the server expects.

    The text is UTF-8 encoded and zlib-compressed, the first two and last
    four bytes of the compressed stream (the zlib wrapper around the deflate
    data) are dropped, and the rest is base64-encoded and translated into
    the PlantUML alphabet.

    :param str plantuml_text: The plantuml markup to encode
    :returns: the encoded text, ready to be appended to the server URL
    """
    utf8_bytes = plantuml_text.encode("utf-8")
    deflated = compress(utf8_bytes)[2:-4]
    as_base64 = base64.b64encode(deflated)
    encoded = as_base64.translate(b64_to_plantuml)
    return encoded.decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
class PlantUML:
    """Connection to a PlantUML server.

    All parameters are optional.

    :param str url: URL to the PlantUML server image CGI. defaults to
                    http://www.plantuml.com/plantuml/svg/
    :param dict request_opts: Extra options to be passed off to the
                    httplib2.Http().request() call.
    """

    # NOTE(review): the default URL is plain HTTP, so the encoded diagram
    # text travels unencrypted to that server - consider an https or
    # self-hosted endpoint for sensitive content.
    def __init__(
        self, url="http://www.plantuml.com/plantuml/svg/", request_opts=None
    ):
        self.HttpLib2Error = httplib2.HttpLib2Error
        self.http = httplib2.Http()

        self.url = url
        # A fresh dict per instance: a ``{}`` default in the signature would
        # be shared by every instance created without ``request_opts``.
        self.request_opts = {} if request_opts is None else request_opts

    def get_url(self, plantuml_text):
        """Return the server URL for the image.
        You can use this URL in an IMG HTML tag.

        :param str plantuml_text: The plantuml markup to render
        :returns: the plantuml server image URL
        """
        return self.url + deflate_and_encode(plantuml_text)

    def process(self, plantuml_text):
        """Render the plantuml text and return it as an embeddable HTML snippet.

        The server response is decoded as UTF-8, its first ``<svg `` tag gets
        ``id='mindmap'`` and the result is wrapped in a fixed-height
        ``<div id='mindmap-wrapper'>``.

        :param str plantuml_text: The plantuml markup to render
        :returns: the SVG markup wrapped in the ``mindmap-wrapper`` div
        :raises PlantUMLConnectionError: if httplib2 fails to perform the
            request
        :raises PlantUMLHTTPError: if the server answers with a status other
            than 200
        """
        url = self.get_url(plantuml_text)
        try:
            response, content = self.http.request(url, **self.request_opts)
        except self.HttpLib2Error as e:
            # Chain the cause so the original transport error stays visible.
            raise PlantUMLConnectionError(e) from e
        if response.status != 200:
            raise PlantUMLHTTPError(response, content)

        # Tag only the first <svg> element: ids must be unique in the page.
        svg_content = content.decode("utf-8").replace(
            "<svg ", "<svg id='mindmap' ", 1
        )

        # wrap in fixed height div
        return (
            "<div id='mindmap-wrapper' "
            "style='height: 400px; overflow: hidden;'>"
            f"{svg_content}</div>"
        )
|
Loading…
Reference in New Issue
Block a user