feat: add mindmap visualization (#405) bump:minor
commit e6fa1af404, parent 4764b0e82a
@@ -55,6 +55,8 @@ class BaseApp:
            "PDFJS_PREBUILT_DIR",
            pdf_js_dist_dir,
        )
+        with (dir_assets / "js" / "svg-pan-zoom.min.js").open() as fi:
+            self._svg_js = fi.read()

        self._favicon = str(dir_assets / "img" / "favicon.svg")

@@ -172,6 +174,9 @@ class BaseApp:
            "<script type='module' "
            "src='https://cdn.skypack.dev/pdfjs-viewer-element'>"
            "</script>"
+            "<script>"
+            f"{self._svg_js}"
+            "</script>"
        )

        with gr.Blocks(
libs/ktem/ktem/assets/js/svg-pan-zoom.min.js (new vendored file, 3 lines)
File diff suppressed because one or more lines are too long
@@ -39,6 +39,29 @@ function() {
  for (var i = 0; i < links.length; i++) {
    links[i].onclick = openModal;
  }
+
+  var mindmap_el = document.getElementById('mindmap');
+  if (mindmap_el) {
+    var output = svgPanZoom(mindmap_el);
+  }
+
+  var link = document.getElementById("mindmap-toggle");
+  if (link) {
+    link.onclick = function(event) {
+      event.preventDefault(); // Prevent the default link behavior
+      var div = document.getElementById("mindmap-wrapper");
+      if (div) {
+        var currentHeight = div.style.height;
+        if (currentHeight === '400px') {
+          var contentHeight = div.scrollHeight;
+          div.style.height = contentHeight + 'px';
+        } else {
+          div.style.height = '400px'
+        }
+      }
+    };
+  }
+
  return [links.length]
}
"""
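For orientation (illustrative, not part of the commit): the three element ids this script looks up are produced on the Python side. PlantUML.process() tags the SVG with id='mindmap' and wraps it in the fixed-height #mindmap-wrapper div, while prepare_mindmap() further down emits the #mindmap-toggle link. A rough Python sketch of the assembled markup, with the SVG body elided:

# Rough, illustrative markup only; the real strings are assembled by PlantUML.process()
# and Render.collapsible() in the Python changes further down this diff.
mindmap_markup = (
    "<i>Mindmap</i> <a href='#' id='mindmap-toggle'>[Expand]</a>"
    "<div id='mindmap-wrapper' style='height: 400px; overflow: hidden;'>"
    "<svg id='mindmap'><!-- PlantUML SVG body elided --></svg>"
    "</div>"
)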
@@ -1,9 +1,11 @@
from .decompose_question import DecomposeQuestionPipeline
from .fewshot_rewrite_question import FewshotRewriteQuestionPipeline
+from .mindmap import CreateMindmapPipeline
from .rewrite_question import RewriteQuestionPipeline

__all__ = [
    "DecomposeQuestionPipeline",
    "FewshotRewriteQuestionPipeline",
    "RewriteQuestionPipeline",
+    "CreateMindmapPipeline",
]
libs/ktem/ktem/reasoning/prompt_optimization/mindmap.py (new file, 52 lines)
@@ -0,0 +1,52 @@
+import logging
+
+from ktem.llms.manager import llms
+
+from kotaemon.base import BaseComponent, Document, HumanMessage, Node, SystemMessage
+from kotaemon.llms import ChatLLM, PromptTemplate
+
+logger = logging.getLogger(__name__)
+
+
+class CreateMindmapPipeline(BaseComponent):
+    """Create a mindmap from the question and context"""
+
+    llm: ChatLLM = Node(default_callback=lambda _: llms.get_default())
+
+    SYSTEM_PROMPT = """
+From now on you will behave as "MapGPT". For every text the user submits, create a PlantUML mind map that best describes its main ideas. Format it as code, and keep the mind map in the same language as the input context. Do not provide a general example of the mind map format before the user submits any text.
+"""  # noqa: E501
+    MINDMAP_PROMPT_TEMPLATE = """
+Question:
+{question}
+
+Context:
+{context}
+
+Generate a PlantUML mindmap based on the question and context provided above. Only include context relevant to the question when producing the mindmap.
+
+Use a template like this:
+
+@startmindmap
+* Title
+** Item A
+*** Item B
+**** Item C
+*** Item D
+@endmindmap
+"""  # noqa: E501
+    prompt_template: str = MINDMAP_PROMPT_TEMPLATE
+
+    def run(self, question: str, context: str) -> Document:  # type: ignore
+        prompt_template = PromptTemplate(self.prompt_template)
+        prompt = prompt_template.populate(
+            question=question,
+            context=context,
+        )
+
+        messages = [
+            SystemMessage(content=self.SYSTEM_PROMPT),
+            HumanMessage(content=prompt),
+        ]
+
+        return self.llm(messages)
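A minimal usage sketch for this pipeline (illustrative, not part of the commit; the question and context values are made up, and a default LLM is assumed to be registered with ktem's llms manager):

# Illustrative sketch, not part of the commit.
from ktem.reasoning.prompt_optimization import CreateMindmapPipeline

pipeline = CreateMindmapPipeline()  # llm falls back to llms.get_default()
mindmap = pipeline.run(
    question="What is this document about?",
    context="The document describes a retrieval-augmented QA system.",
)
print(mindmap.text)  # PlantUML markup, e.g. "@startmindmap ... @endmindmap"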
@@ -10,9 +10,11 @@ import numpy as np
import tiktoken
from ktem.llms.manager import llms
from ktem.reasoning.prompt_optimization import (
+    CreateMindmapPipeline,
    DecomposeQuestionPipeline,
    RewriteQuestionPipeline,
)
+from ktem.utils.plantuml import PlantUML
from ktem.utils.render import Render
from theflow.settings import settings as flowsettings

@@ -227,6 +229,9 @@ class AnswerWithContextPipeline(BaseComponent):
    citation_pipeline: CitationPipeline = Node(
        default_callback=lambda _: CitationPipeline(llm=llms.get_default())
    )
+    create_mindmap_pipeline: CreateMindmapPipeline = Node(
+        default_callback=lambda _: CreateMindmapPipeline(llm=llms.get_default())
+    )

    qa_template: str = DEFAULT_QA_TEXT_PROMPT
    qa_table_template: str = DEFAULT_QA_TABLE_PROMPT
@@ -234,6 +239,8 @@ class AnswerWithContextPipeline(BaseComponent):
    qa_figure_template: str = DEFAULT_QA_FIGURE_PROMPT

    enable_citation: bool = False
+    enable_mindmap: bool = False
+
    system_prompt: str = ""
    lang: str = "English"  # support English and Japanese
    n_last_interactions: int = 5
@@ -325,17 +332,28 @@ class AnswerWithContextPipeline(BaseComponent):

        # retrieve the citation
        citation = None
+        mindmap = None

        def citation_call():
            nonlocal citation
            citation = self.citation_pipeline(context=evidence, question=question)

-        if evidence and self.enable_citation:
-            # execute function call in thread
-            citation_thread = threading.Thread(target=citation_call)
-            citation_thread.start()
-        else:
-            citation_thread = None
+        def mindmap_call():
+            nonlocal mindmap
+            mindmap = self.create_mindmap_pipeline(context=evidence, question=question)
+
+        citation_thread = None
+        mindmap_thread = None
+
+        # execute function call in thread
+        if evidence:
+            if self.enable_citation:
+                citation_thread = threading.Thread(target=citation_call)
+                citation_thread.start()
+
+            if self.enable_mindmap:
+                mindmap_thread = threading.Thread(target=mindmap_call)
+                mindmap_thread.start()

        output = ""
        logprobs = []
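The change above moves citation and mindmap generation into background threads that are later joined with a timeout, so a slow side task cannot stall the streamed answer. A stripped-down sketch of that pattern (illustrative only; the worker bodies are placeholders and TIMEOUT stands in for the commit's CITATION_TIMEOUT):

# Illustrative sketch of the background-work pattern used above.
import threading

TIMEOUT = 5.0  # placeholder; the commit uses CITATION_TIMEOUT

results = {"citation": None, "mindmap": None}

def make_worker(key):
    # Stand-in for the real citation / mindmap pipeline calls.
    def worker():
        results[key] = f"computed {key}"
    return worker

threads = [threading.Thread(target=make_worker(key)) for key in results]
for t in threads:
    t.start()

# ... stream the main LLM answer here while the side tasks run ...

for t in threads:
    # A bounded join keeps a slow side task from delaying the final answer.
    t.join(timeout=TIMEOUT)

print(results)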
@@ -386,10 +404,12 @@ class AnswerWithContextPipeline(BaseComponent):

        if citation_thread:
            citation_thread.join(timeout=CITATION_TIMEOUT)
+        if mindmap_thread:
+            mindmap_thread.join(timeout=CITATION_TIMEOUT)

        answer = Document(
            text=output,
-            metadata={"citation": citation, "qa_score": qa_score},
+            metadata={"mindmap": mindmap, "citation": citation, "qa_score": qa_score},
        )

        return answer
@@ -597,9 +617,35 @@ class FullQAPipeline(BaseReasoning):
        )
        return with_citation, without_citation

-    def show_citations(self, answer, docs):
+    def prepare_mindmap(self, answer) -> Document | None:
+        mindmap = answer.metadata["mindmap"]
+        if mindmap:
+            mindmap_text = mindmap.text
+            uml_renderer = PlantUML()
+            mindmap_svg = uml_renderer.process(mindmap_text)
+
+            mindmap_content = Document(
+                channel="info",
+                content=Render.collapsible(
+                    header="""
+                    <i>Mindmap</i>
+                    <a href="#" id='mindmap-toggle'>
+                    [Expand]
+                    </a>""",
+                    content=mindmap_svg,
+                    open=True,
+                ),
+            )
+        else:
+            mindmap_content = None
+
+        return mindmap_content
+
+    def show_citations_and_addons(self, answer, docs):
        # show the evidence
        with_citation, without_citation = self.prepare_citations(answer, docs)
+        mindmap_output = self.prepare_mindmap(answer)
+
        if not with_citation and not without_citation:
            yield Document(channel="info", content="<h5><b>No evidence found.</b></h5>")
        else:
@@ -611,6 +657,10 @@ class FullQAPipeline(BaseReasoning):
            # clear previous info
            yield Document(channel="info", content=None)

+            # yield mindmap output
+            if mindmap_output:
+                yield mindmap_output
+
            # yield warning message
            if has_llm_score and max_llm_rerank_score < CONTEXT_RELEVANT_WARNING_SCORE:
                yield Document(
@@ -683,7 +733,7 @@ class FullQAPipeline(BaseReasoning):
        if scoring_thread:
            scoring_thread.join()

-        yield from self.show_citations(answer, docs)
+        yield from self.show_citations_and_addons(answer, docs)

        return answer

@@ -716,6 +766,7 @@ class FullQAPipeline(BaseReasoning):
        answer_pipeline.citation_pipeline.llm = llm
        answer_pipeline.n_last_interactions = settings[f"{prefix}.n_last_interactions"]
        answer_pipeline.enable_citation = settings[f"{prefix}.highlight_citation"]
+        answer_pipeline.enable_mindmap = settings[f"{prefix}.create_mindmap"]
        answer_pipeline.system_prompt = settings[f"{prefix}.system_prompt"]
        answer_pipeline.qa_template = settings[f"{prefix}.qa_prompt"]
        answer_pipeline.lang = SUPPORTED_LANGUAGE_MAP.get(
@@ -764,6 +815,11 @@ class FullQAPipeline(BaseReasoning):
                "value": True,
                "component": "checkbox",
            },
+            "create_mindmap": {
+                "name": "Create Mindmap",
+                "value": False,
+                "component": "checkbox",
+            },
            "system_prompt": {
                "name": "System Prompt",
                "value": "This is a question answering system",
libs/ktem/ktem/utils/plantuml.py (new file, 113 lines)
@@ -0,0 +1,113 @@
+#!/usr/bin/env python
+
+from __future__ import print_function
+
+import base64
+import string
+from zlib import compress
+
+import httplib2
+import six  # type: ignore
+
+if six.PY2:
+    from string import maketrans
+else:
+    maketrans = bytes.maketrans
+
+
+plantuml_alphabet = (
+    string.digits + string.ascii_uppercase + string.ascii_lowercase + "-_"
+)
+base64_alphabet = string.ascii_uppercase + string.ascii_lowercase + string.digits + "+/"
+b64_to_plantuml = maketrans(
+    base64_alphabet.encode("utf-8"), plantuml_alphabet.encode("utf-8")
+)
+
+
+class PlantUMLError(Exception):
+    """
+    Error in processing.
+    """
+
+
+class PlantUMLConnectionError(PlantUMLError):
+    """
+    Error connecting or talking to PlantUML Server.
+    """
+
+
+class PlantUMLHTTPError(PlantUMLConnectionError):
+    """
+    Request to PlantUML server returned HTTP Error.
+    """
+
+    def __init__(self, response, content, *args, **kwdargs):
+        self.response = response
+        self.content = content
+        message = "%d: %s" % (self.response.status, self.response.reason)
+        if not getattr(self, "message", None):
+            self.message = message
+        super(PlantUMLHTTPError, self).__init__(message, *args, **kwdargs)
+
+
+def deflate_and_encode(plantuml_text):
+    """zlib compress the plantuml text and encode it for the plantuml server."""
+    zlibbed_str = compress(plantuml_text.encode("utf-8"))
+    compressed_string = zlibbed_str[2:-4]
+    return (
+        base64.b64encode(compressed_string).translate(b64_to_plantuml).decode("utf-8")
+    )
+
+
+class PlantUML(object):
+    """Connection to a PlantUML server.
+
+    All parameters are optional.
+
+    :param str url: URL to the PlantUML server image CGI. Defaults to
+                    http://www.plantuml.com/plantuml/svg/
+    :param dict request_opts: Extra options to be passed off to the
+                              httplib2.Http().request() call.
+    """
+
+    def __init__(self, url="http://www.plantuml.com/plantuml/svg/", request_opts={}):
+        self.HttpLib2Error = httplib2.HttpLib2Error
+        self.http = httplib2.Http()
+
+        self.url = url
+        self.request_opts = request_opts
+
+    def get_url(self, plantuml_text):
+        """Return the server URL for the image.
+        You can use this URL in an IMG HTML tag.
+
+        :param str plantuml_text: The plantuml markup to render
+        :returns: the plantuml server image URL
+        """
+        return self.url + deflate_and_encode(plantuml_text)
+
+    def process(self, plantuml_text):
+        """Process the plantuml text into rendered SVG markup.
+
+        :param str plantuml_text: The plantuml markup to render
+        :returns: the SVG content, wrapped in a fixed-height div
+        """
+        url = self.get_url(plantuml_text)
+        try:
+            response, content = self.http.request(url, **self.request_opts)
+        except self.HttpLib2Error as e:
+            raise PlantUMLConnectionError(e)
+        if response.status != 200:
+            raise PlantUMLHTTPError(response, content)
+
+        svg_content = content.decode("utf-8")
+        svg_content = svg_content.replace("<svg ", "<svg id='mindmap' ")
+
+        # wrap in fixed height div
+        svg_content = (
+            "<div id='mindmap-wrapper' "
+            "style='height: 400px; overflow: hidden;'>"
+            f"{svg_content}</div>"
+        )
+
+        return svg_content
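A short usage sketch for this helper (illustrative, not part of the commit; it assumes outbound access to the public plantuml.com server):

# Illustrative sketch, not part of the commit.
from ktem.utils.plantuml import PlantUML

uml_text = "@startmindmap\n* Title\n** Item A\n@endmindmap"

renderer = PlantUML()  # defaults to http://www.plantuml.com/plantuml/svg/
print(renderer.get_url(uml_text))  # deflated, base64-with-PlantUML-alphabet text appended to the base URL
svg_html = renderer.process(uml_text)
# svg_html looks like "<div id='mindmap-wrapper' ...><svg id='mindmap' ...>...</svg></div>"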