feat: add mindmap visualization (#405) bump:minor

Authored by Tuan Anh Nguyen Dang (Tadashi_Cin) on 2024-10-17 14:35:28 +07:00; committed by GitHub
parent 4764b0e82a
commit e6fa1af404
7 changed files with 263 additions and 9 deletions


@@ -55,6 +55,8 @@ class BaseApp:
"PDFJS_PREBUILT_DIR",
pdf_js_dist_dir,
)
with (dir_assets / "js" / "svg-pan-zoom.min.js").open() as fi:
self._svg_js = fi.read()
self._favicon = str(dir_assets / "img" / "favicon.svg")
@@ -172,6 +174,9 @@ class BaseApp:
"<script type='module' "
"src='https://cdn.skypack.dev/pdfjs-viewer-element'>"
"</script>"
"<script>"
f"{self._svg_js}"
"</script>"
)
with gr.Blocks(

File diff suppressed because one or more lines are too long


@@ -39,6 +39,29 @@ function() {
for (var i = 0; i < links.length; i++) {
links[i].onclick = openModal;
}
var mindmap_el = document.getElementById('mindmap');
if (mindmap_el) {
var output = svgPanZoom(mindmap_el);
}
var link = document.getElementById("mindmap-toggle");
if (link) {
link.onclick = function(event) {
event.preventDefault(); // Prevent the default link behavior
var div = document.getElementById("mindmap-wrapper");
if (div) {
var currentHeight = div.style.height;
if (currentHeight === '400px') {
var contentHeight = div.scrollHeight;
div.style.height = contentHeight + 'px';
} else {
div.style.height = '400px';
}
}
};
}
return [links.length]
}
"""


@@ -1,9 +1,11 @@
from .decompose_question import DecomposeQuestionPipeline
from .fewshot_rewrite_question import FewshotRewriteQuestionPipeline
from .mindmap import CreateMindmapPipeline
from .rewrite_question import RewriteQuestionPipeline
__all__ = [
"DecomposeQuestionPipeline",
"FewshotRewriteQuestionPipeline",
"RewriteQuestionPipeline",
"CreateMindmapPipeline",
]


@@ -0,0 +1,52 @@
import logging
from ktem.llms.manager import llms
from kotaemon.base import BaseComponent, Document, HumanMessage, Node, SystemMessage
from kotaemon.llms import ChatLLM, PromptTemplate
logger = logging.getLogger(__name__)
class CreateMindmapPipeline(BaseComponent):
"""Create a mindmap from the question and context"""
llm: ChatLLM = Node(default_callback=lambda _: llms.get_default())
SYSTEM_PROMPT = """
From now on you will behave as "MapGPT" and, for every text the user submits, you will create a PlantUML mind map that best describes its main ideas. Format it as code and remember that the mind map should be in the same language as the input context. You don't have to provide a general example of the mind map format before the user inputs the text.
""" # noqa: E501
MINDMAP_PROMPT_TEMPLATE = """
Question:
{question}
Context:
{context}
Generate a PlantUML mindmap based on the provided question and context above. Only include context relevant to the question when producing the mindmap.
Use a template like this:
@startmindmap
* Title
** Item A
*** Item B
**** Item C
*** Item D
@endmindmap
""" # noqa: E501
prompt_template: str = MINDMAP_PROMPT_TEMPLATE
def run(self, question: str, context: str) -> Document: # type: ignore
prompt_template = PromptTemplate(self.prompt_template)
prompt = prompt_template.populate(
question=question,
context=context,
)
messages = [
SystemMessage(content=self.SYSTEM_PROMPT),
HumanMessage(content=prompt),
]
return self.llm(messages)
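A minimal usage sketch of the new pipeline, assuming a default LLM is already registered with ktem's llms manager; the question and context strings below are placeholders:

from ktem.reasoning.prompt_optimization import CreateMindmapPipeline

pipeline = CreateMindmapPipeline()  # the Node callback picks up llms.get_default()
mindmap_doc = pipeline.run(
    question="What storage backends does the project support?",
    context="...retrieved evidence goes here...",
)
# The returned chat output is expected to carry PlantUML markup in .text,
# starting with @startmindmap and ending with @endmindmap.
print(mindmap_doc.text)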


@@ -10,9 +10,11 @@ import numpy as np
import tiktoken
from ktem.llms.manager import llms
from ktem.reasoning.prompt_optimization import (
CreateMindmapPipeline,
DecomposeQuestionPipeline,
RewriteQuestionPipeline,
)
from ktem.utils.plantuml import PlantUML
from ktem.utils.render import Render
from theflow.settings import settings as flowsettings
@@ -227,6 +229,9 @@ class AnswerWithContextPipeline(BaseComponent):
citation_pipeline: CitationPipeline = Node(
default_callback=lambda _: CitationPipeline(llm=llms.get_default())
)
create_mindmap_pipeline: CreateMindmapPipeline = Node(
default_callback=lambda _: CreateMindmapPipeline(llm=llms.get_default())
)
qa_template: str = DEFAULT_QA_TEXT_PROMPT
qa_table_template: str = DEFAULT_QA_TABLE_PROMPT
@@ -234,6 +239,8 @@ class AnswerWithContextPipeline(BaseComponent):
qa_figure_template: str = DEFAULT_QA_FIGURE_PROMPT
enable_citation: bool = False
enable_mindmap: bool = False
system_prompt: str = ""
lang: str = "English" # support English and Japanese
n_last_interactions: int = 5
@@ -325,17 +332,28 @@ class AnswerWithContextPipeline(BaseComponent):
# retrieve the citation
citation = None
mindmap = None
def citation_call():
nonlocal citation
citation = self.citation_pipeline(context=evidence, question=question)
if evidence and self.enable_citation:
# execute function call in thread
citation_thread = threading.Thread(target=citation_call)
citation_thread.start()
else:
citation_thread = None
def mindmap_call():
nonlocal mindmap
mindmap = self.create_mindmap_pipeline(context=evidence, question=question)
citation_thread = None
mindmap_thread = None
# execute function call in thread
if evidence:
if self.enable_citation:
citation_thread = threading.Thread(target=citation_call)
citation_thread.start()
if self.enable_mindmap:
mindmap_thread = threading.Thread(target=mindmap_call)
mindmap_thread.start()
output = ""
logprobs = []
@@ -386,10 +404,12 @@ class AnswerWithContextPipeline(BaseComponent):
if citation_thread:
citation_thread.join(timeout=CITATION_TIMEOUT)
if mindmap_thread:
mindmap_thread.join(timeout=CITATION_TIMEOUT)
answer = Document(
text=output,
metadata={"citation": citation, "qa_score": qa_score},
metadata={"mindmap": mindmap, "citation": citation, "qa_score": qa_score},
)
return answer
@@ -597,9 +617,35 @@ class FullQAPipeline(BaseReasoning):
)
return with_citation, without_citation
def show_citations(self, answer, docs):
def prepare_mindmap(self, answer) -> Document | None:
mindmap = answer.metadata["mindmap"]
if mindmap:
mindmap_text = mindmap.text
uml_renderer = PlantUML()
mindmap_svg = uml_renderer.process(mindmap_text)
mindmap_content = Document(
channel="info",
content=Render.collapsible(
header="""
<i>Mindmap</i>
<a href="#" id="mindmap-toggle">
[Expand]
</a>""",
content=mindmap_svg,
open=True,
),
)
else:
mindmap_content = None
return mindmap_content
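To make the data flow concrete, an illustrative sketch of the answer object that prepare_mindmap consumes; the values are placeholders, and in practice the Document is assembled by AnswerWithContextPipeline above:

from kotaemon.base import Document
from ktem.reasoning.prompt_optimization import CreateMindmapPipeline

mindmap_doc = CreateMindmapPipeline().run(
    question="What is covered in chapter 2?",
    context="...retrieved evidence...",
)
answer = Document(
    text="...streamed answer text...",
    metadata={"mindmap": mindmap_doc, "citation": None, "qa_score": None},
)
# prepare_mindmap(answer) renders mindmap_doc.text through the PlantUML helper and wraps
# the SVG in a collapsible "info" panel; it returns None when the mindmap metadata is empty.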
def show_citations_and_addons(self, answer, docs):
# show the evidence
with_citation, without_citation = self.prepare_citations(answer, docs)
mindmap_output = self.prepare_mindmap(answer)
if not with_citation and not without_citation:
yield Document(channel="info", content="<h5><b>No evidence found.</b></h5>")
else:
@@ -611,6 +657,10 @@ class FullQAPipeline(BaseReasoning):
# clear previous info
yield Document(channel="info", content=None)
# yield mindmap output
if mindmap_output:
yield mindmap_output
# yield warning message
if has_llm_score and max_llm_rerank_score < CONTEXT_RELEVANT_WARNING_SCORE:
yield Document(
@@ -683,7 +733,7 @@ class FullQAPipeline(BaseReasoning):
if scoring_thread:
scoring_thread.join()
yield from self.show_citations(answer, docs)
yield from self.show_citations_and_addons(answer, docs)
return answer
@@ -716,6 +766,7 @@ class FullQAPipeline(BaseReasoning):
answer_pipeline.citation_pipeline.llm = llm
answer_pipeline.n_last_interactions = settings[f"{prefix}.n_last_interactions"]
answer_pipeline.enable_citation = settings[f"{prefix}.highlight_citation"]
answer_pipeline.enable_mindmap = settings[f"{prefix}.create_mindmap"]
answer_pipeline.system_prompt = settings[f"{prefix}.system_prompt"]
answer_pipeline.qa_template = settings[f"{prefix}.qa_prompt"]
answer_pipeline.lang = SUPPORTED_LANGUAGE_MAP.get(
@@ -764,6 +815,11 @@ class FullQAPipeline(BaseReasoning):
"value": True,
"component": "checkbox",
},
"create_mindmap": {
"name": "Create Mindmap",
"value": False,
"component": "checkbox",
},
"system_prompt": {
"name": "System Prompt",
"value": "This is a question answering system",


@@ -0,0 +1,113 @@
#!/usr/bin/env python
from __future__ import print_function
import base64
import string
from zlib import compress
import httplib2
import six # type: ignore
if six.PY2:
from string import maketrans
else:
maketrans = bytes.maketrans
plantuml_alphabet = (
string.digits + string.ascii_uppercase + string.ascii_lowercase + "-_"
)
base64_alphabet = string.ascii_uppercase + string.ascii_lowercase + string.digits + "+/"
b64_to_plantuml = maketrans(
base64_alphabet.encode("utf-8"), plantuml_alphabet.encode("utf-8")
)
class PlantUMLError(Exception):
"""
Error in processing.
"""
class PlantUMLConnectionError(PlantUMLError):
"""
Error connecting or talking to PlantUML Server.
"""
class PlantUMLHTTPError(PlantUMLConnectionError):
"""
Request to PlantUML server returned HTTP Error.
"""
def __init__(self, response, content, *args, **kwdargs):
self.response = response
self.content = content
message = "%d: %s" % (self.response.status, self.response.reason)
if not getattr(self, "message", None):
self.message = message
super(PlantUMLHTTPError, self).__init__(message, *args, **kwdargs)
def deflate_and_encode(plantuml_text):
"""zlib compress the plantuml text and encode it for the plantuml server."""
zlibbed_str = compress(plantuml_text.encode("utf-8"))
compressed_string = zlibbed_str[2:-4]
return (
base64.b64encode(compressed_string).translate(b64_to_plantuml).decode("utf-8")
)
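A quick sketch of how this encoding is meant to be used, assuming the public PlantUML server that the class below defaults to; pasting the printed URL into a browser should return the rendered SVG:

token = deflate_and_encode("@startmindmap\n* Title\n** Item A\n@endmindmap")
print("http://www.plantuml.com/plantuml/svg/" + token)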
class PlantUML(object):
"""Connection to a PlantUML server with optional authentication.
All parameters are optional.
:param str url: URL to the PlantUML server image CGI. defaults to
http://www.plantuml.com/plantuml/svg/
:param dict request_opts: Extra options to be passed off to the
httplib2.Http().request() call.
"""
def __init__(self, url="http://www.plantuml.com/plantuml/svg/", request_opts={}):
self.HttpLib2Error = httplib2.HttpLib2Error
self.http = httplib2.Http()
self.url = url
self.request_opts = request_opts
def get_url(self, plantuml_text):
"""Return the server URL for the image.
You can use this URL in an IMG HTML tag.
:param str plantuml_text: The plantuml markup to render
:returns: the plantuml server image URL
"""
return self.url + deflate_and_encode(plantuml_text)
def process(self, plantuml_text):
"""Processes the plantuml text into the raw PNG image data.
:param str plantuml_text: The plantuml markup to render
:returns: the raw image data
"""
url = self.get_url(plantuml_text)
try:
response, content = self.http.request(url, **self.request_opts)
except self.HttpLib2Error as e:
raise PlantUMLConnectionError(e)
if response.status != 200:
raise PlantUMLHTTPError(response, content)
svg_content = content.decode("utf-8")
svg_content = svg_content.replace("<svg ", "<svg id='mindmap' ")
# wrap in fixed height div
svg_content = (
"<div id='mindmap-wrapper' "
"style='height: 400px; overflow: hidden;'>"
f"{svg_content}</div>"
)
return svg_content
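Finally, a minimal end-to-end sketch of the helper class above; it assumes network access to the default PlantUML server, and the markup is illustrative:

uml = PlantUML()  # defaults to http://www.plantuml.com/plantuml/svg/
markup = "@startmindmap\n* Kotaemon\n** Reasoning\n*** Mindmap\n@endmindmap"
print(uml.get_url(markup))  # direct SVG URL, usable in an <img> tag
html_snippet = uml.process(markup)  # SVG tagged id='mindmap', wrapped in the #mindmap-wrapper div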