Feat/add multimodal loader (#5)
* Add Adobe reader as the multimodal loader
* Allow FullQAPipeline to reason on figures
* fix: move the adobe import to avoid ImportError; notify users whenever they run the AdobeReader

Co-authored-by: cin-albert <albert@cinnamon.is>
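A minimal usage sketch of the new loader in isolation, assuming kotaemon's usual reader interface; the `load_data` call and its arguments are assumptions, since this diff only shows the pipeline-side wiring, not the loader itself:

```python
# Hypothetical sketch: parse a PDF with the new multimodal loader.
# AdobeReader's exact constructor and load_data signature are assumed
# from kotaemon's loader conventions, not confirmed by this diff.
from kotaemon.loaders import AdobeReader

reader = AdobeReader()  # requires Adobe PDF Services credentials configured
docs = reader.load_data("sample.pdf")  # hypothetical file name
for doc in docs:
    # figure chunks carry type "image" and a base64 data URI in "image_origin",
    # matching the metadata read by the reasoning pipeline in the hunks below
    print(doc.metadata.get("type", "text"))
```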
@@ -124,6 +124,11 @@ if config("LOCAL_MODEL", default=""):
 KH_REASONINGS = ["ktem.reasoning.simple.FullQAPipeline"]
+KH_VLM_ENDPOINT = "{0}/openai/deployments/{1}/chat/completions?api-version={2}".format(
+    config("AZURE_OPENAI_ENDPOINT", default=""),
+    config("OPENAI_VISION_DEPLOYMENT_NAME", default="gpt-4-vision"),
+    config("OPENAI_API_VERSION", default=""),
+)
 
 
 SETTINGS_APP = {
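With the three config values filled in, the format string resolves to a standard Azure OpenAI chat-completions URL for the vision deployment. The resource endpoint and API version below are hypothetical placeholders; only `gpt-4-vision` is an actual default from the hunk above:

```python
# Hypothetical values, shown only to illustrate the resolved URL shape.
endpoint = "{0}/openai/deployments/{1}/chat/completions?api-version={2}".format(
    "https://my-resource.openai.azure.com",  # AZURE_OPENAI_ENDPOINT (placeholder)
    "gpt-4-vision",                          # OPENAI_VISION_DEPLOYMENT_NAME (default)
    "2023-12-01-preview",                    # OPENAI_API_VERSION (placeholder)
)
print(endpoint)
# https://my-resource.openai.azure.com/openai/deployments/gpt-4-vision/chat/completions?api-version=2023-12-01-preview
```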
@@ -378,6 +378,7 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
                     ("PDF text parser", "normal"),
                     ("Mathpix", "mathpix"),
                     ("Advanced ocr", "ocr"),
+                    ("Multimodal parser", "multimodal"),
                 ],
                 "component": "dropdown",
             },
@@ -1,11 +1,14 @@
 import asyncio
+import html
 import logging
+import re
 from collections import defaultdict
 from functools import partial
 
 import tiktoken
 from ktem.components import llms
 from ktem.reasoning.base import BaseReasoning
+from theflow.settings import settings as flowsettings
 
 from kotaemon.base import (
     BaseComponent,
@@ -18,9 +21,15 @@ from kotaemon.base import (
 from kotaemon.indices.qa.citation import CitationPipeline
 from kotaemon.indices.splitters import TokenSplitter
 from kotaemon.llms import ChatLLM, PromptTemplate
+from kotaemon.loaders.utils.gpt4v import stream_gpt4v
 
 logger = logging.getLogger(__name__)
 
+EVIDENCE_MODE_TEXT = 0
+EVIDENCE_MODE_TABLE = 1
+EVIDENCE_MODE_CHATBOT = 2
+EVIDENCE_MODE_FIGURE = 3
+
 
 class PrepareEvidencePipeline(BaseComponent):
     """Prepare the evidence text from the list of retrieved documents
@@ -46,7 +55,7 @@ class PrepareEvidencePipeline(BaseComponent):
     def run(self, docs: list[RetrievedDocument]) -> Document:
         evidence = ""
         table_found = 0
-        evidence_mode = 0
+        evidence_mode = EVIDENCE_MODE_TEXT
 
         for _id, retrieved_item in enumerate(docs):
             retrieved_content = ""
@@ -55,7 +64,7 @@ class PrepareEvidencePipeline(BaseComponent):
             if page:
                 source += f" (Page {page})"
             if retrieved_item.metadata.get("type", "") == "table":
-                evidence_mode = 1  # table
+                evidence_mode = EVIDENCE_MODE_TABLE
                 if table_found < 5:
                     retrieved_content = retrieved_item.metadata.get("table_origin", "")
                     if retrieved_content not in evidence:
@@ -66,13 +75,23 @@ class PrepareEvidencePipeline(BaseComponent):
                             + "\n<br>"
                         )
             elif retrieved_item.metadata.get("type", "") == "chatbot":
-                evidence_mode = 2  # chatbot
+                evidence_mode = EVIDENCE_MODE_CHATBOT
                 retrieved_content = retrieved_item.metadata["window"]
                 evidence += (
                     f"<br><b>Chatbot scenario from {filename} (Row {page})</b>\n"
                     + retrieved_content
                     + "\n<br>"
                 )
+            elif retrieved_item.metadata.get("type", "") == "image":
+                evidence_mode = EVIDENCE_MODE_FIGURE
+                retrieved_content = retrieved_item.metadata.get("image_origin", "")
+                retrieved_caption = html.escape(retrieved_item.get_content())
+                evidence += (
+                    f"<br><b>Figure from {source}</b>\n"
+                    + f"<img width='85%' src='{retrieved_content}' "
+                    + f"alt='{retrieved_caption}'/>"
+                    + "\n<br>"
+                )
             else:
                 if "window" in retrieved_item.metadata:
                     retrieved_content = retrieved_item.metadata["window"]
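To make the new branch concrete, here is the evidence fragment it emits for an image chunk. The values are toy placeholders (a real `image_origin` holds a full base64 data URI); the string construction is the same as in the hunk above:

```python
import html

# Toy stand-ins for the values the branch reads from chunk metadata
source = "report.pdf (Page 3)"                              # hypothetical
retrieved_content = "data:image/png;base64,iVBORw0KGgo..."  # truncated, hypothetical
retrieved_caption = html.escape("Revenue by quarter")       # hypothetical

# Image inlined via src, escaped caption kept as alt text
evidence = (
    f"<br><b>Figure from {source}</b>\n"
    + f"<img width='85%' src='{retrieved_content}' "
    + f"alt='{retrieved_caption}'/>"
    + "\n<br>"
)
print(evidence)
```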
@@ -90,12 +109,13 @@ class PrepareEvidencePipeline(BaseComponent):
             print(retrieved_item.metadata)
             print("Score", retrieved_item.metadata.get("relevance_score", None))
 
-        # trim context by trim_len
-        print("len (original)", len(evidence))
-        if evidence:
-            texts = self.trim_func([Document(text=evidence)])
-            evidence = texts[0].text
-            print("len (trimmed)", len(evidence))
+        if evidence_mode != EVIDENCE_MODE_FIGURE:
+            # trim context by trim_len
+            print("len (original)", len(evidence))
+            if evidence:
+                texts = self.trim_func([Document(text=evidence)])
+                evidence = texts[0].text
+                print("len (trimmed)", len(evidence))
 
         print(f"PrepareEvidence with input {docs}\nOutput: {evidence}\n")
 
@@ -134,6 +154,16 @@ DEFAULT_QA_CHATBOT_PROMPT = (
     "Answer:"
 )
 
+DEFAULT_QA_FIGURE_PROMPT = (
+    "Use the given context: texts, tables, and figures below to answer the question. "
+    "If you don't know the answer, just say that you don't know. "
+    "Give answer in {lang}.\n\n"
+    "Context: \n"
+    "{context}\n"
+    "Question: {question}\n"
+    "Answer: "
+)
+
 
 class AnswerWithContextPipeline(BaseComponent):
     """Answer the question based on the evidence
@@ -151,6 +181,7 @@ class AnswerWithContextPipeline(BaseComponent):
     """
 
     llm: ChatLLM = Node(default_callback=lambda _: llms.get_highest_accuracy())
+    vlm_endpoint: str = flowsettings.KH_VLM_ENDPOINT
     citation_pipeline: CitationPipeline = Node(
         default_callback=lambda _: CitationPipeline(llm=llms.get_lowest_cost())
     )
@@ -158,6 +189,7 @@ class AnswerWithContextPipeline(BaseComponent):
     qa_template: str = DEFAULT_QA_TEXT_PROMPT
     qa_table_template: str = DEFAULT_QA_TABLE_PROMPT
     qa_chatbot_template: str = DEFAULT_QA_CHATBOT_PROMPT
+    qa_figure_template: str = DEFAULT_QA_FIGURE_PROMPT
 
     enable_citation: bool = False
     system_prompt: str = ""
@@ -188,18 +220,30 @@ class AnswerWithContextPipeline(BaseComponent):
                 (determined by retrieval pipeline)
             evidence_mode: the mode of evidence, 0 for text, 1 for table, 2 for chatbot
         """
-        if evidence_mode == 0:
+        if evidence_mode == EVIDENCE_MODE_TEXT:
             prompt_template = PromptTemplate(self.qa_template)
-        elif evidence_mode == 1:
+        elif evidence_mode == EVIDENCE_MODE_TABLE:
             prompt_template = PromptTemplate(self.qa_table_template)
+        elif evidence_mode == EVIDENCE_MODE_FIGURE:
+            prompt_template = PromptTemplate(self.qa_figure_template)
         else:
             prompt_template = PromptTemplate(self.qa_chatbot_template)
 
-        prompt = prompt_template.populate(
-            context=evidence,
-            question=question,
-            lang=self.lang,
-        )
+        images = []
+        if evidence_mode == EVIDENCE_MODE_FIGURE:
+            # isolate image from evidence
+            evidence, images = self.extract_evidence_images(evidence)
+            prompt = prompt_template.populate(
+                context=evidence,
+                question=question,
+                lang=self.lang,
+            )
+        else:
+            prompt = prompt_template.populate(
+                context=evidence,
+                question=question,
+                lang=self.lang,
+            )
 
         citation_task = None
         if evidence and self.enable_citation:
@@ -208,23 +252,29 @@ class AnswerWithContextPipeline(BaseComponent):
             )
             print("Citation task created")
 
-        messages = []
-        if self.system_prompt:
-            messages.append(SystemMessage(content=self.system_prompt))
-        messages.append(HumanMessage(content=prompt))
-
         output = ""
-        try:
-            # try streaming first
-            print("Trying LLM streaming")
-            for text in self.llm.stream(messages):
-                output += text.text
-                self.report_output({"output": text.text})
+        if evidence_mode == EVIDENCE_MODE_FIGURE:
+            for text in stream_gpt4v(self.vlm_endpoint, images, prompt, max_tokens=768):
+                output += text
+                self.report_output({"output": text})
                 await asyncio.sleep(0)
-        except NotImplementedError:
-            print("Streaming is not supported, falling back to normal processing")
-            output = self.llm(messages).text
-            self.report_output({"output": output})
+        else:
+            messages = []
+            if self.system_prompt:
+                messages.append(SystemMessage(content=self.system_prompt))
+            messages.append(HumanMessage(content=prompt))
+
+            try:
+                # try streaming first
+                print("Trying LLM streaming")
+                for text in self.llm.stream(messages):
+                    output += text.text
+                    self.report_output({"output": text.text})
+                    await asyncio.sleep(0)
+            except NotImplementedError:
+                print("Streaming is not supported, falling back to normal processing")
+                output = self.llm(messages).text
+                self.report_output({"output": output})
 
         # retrieve the citation
         print("Waiting for citation task")
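The figure branch bypasses the chat LLM entirely and streams from the VLM endpoint instead. `stream_gpt4v` is imported from `kotaemon.loaders.utils.gpt4v`, which this diff does not show; the following is a rough sketch of what such a helper has to do, assuming the Azure OpenAI vision chat API with SSE streaming. The payload shape, headers, and function body are assumptions, not the repo's actual implementation:

```python
import json
import requests


def stream_gpt4v_sketch(endpoint: str, images: list[str], prompt: str, max_tokens: int = 768):
    """Yield text deltas from a vision chat deployment (assumed API shape)."""
    payload = {
        "messages": [
            {
                "role": "user",
                "content": [{"type": "text", "text": prompt}]
                + [
                    # images are data:image/...;base64 URIs, as produced by the
                    # loader and pulled out by extract_evidence_images below
                    {"type": "image_url", "image_url": {"url": image}}
                    for image in images
                ],
            }
        ],
        "max_tokens": max_tokens,
        "stream": True,
    }
    headers = {"Content-Type": "application/json", "api-key": "<AZURE_OPENAI_API_KEY>"}
    with requests.post(endpoint, json=payload, headers=headers, stream=True) as resp:
        for line in resp.iter_lines():
            # server-sent events: b"data: {...}" per chunk, b"data: [DONE]" at the end
            if not line or not line.startswith(b"data: ") or line == b"data: [DONE]":
                continue
            chunk = json.loads(line[len(b"data: "):])
            if chunk.get("choices"):
                yield chunk["choices"][0].get("delta", {}).get("content") or ""
```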
@@ -237,6 +287,13 @@ class AnswerWithContextPipeline(BaseComponent):
 
         return answer
 
+    def extract_evidence_images(self, evidence: str):
+        """Util function to extract and isolate images from context/evidence"""
+        image_pattern = r"src='(data:image\/[^;]+;base64[^']+)'"
+        matches = re.findall(image_pattern, evidence)
+        context = re.sub(image_pattern, "", evidence)
+        return context, matches
+
 
 class FullQAPipeline(BaseReasoning):
     """Question answering pipeline. Handle from question to answer"""
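A quick check of the regex above on a toy evidence string (base64 payload truncated): `re.findall` pulls out the data URIs for the VLM call, while `re.sub` strips the bulky `src` attribute so the text context passed to the prompt stays small:

```python
import re

# Same pattern as extract_evidence_images; the evidence string is a toy example
image_pattern = r"src='(data:image\/[^;]+;base64[^']+)'"
evidence = (
    "<br><b>Figure from report.pdf (Page 3)</b>\n"
    "<img width='85%' src='data:image/png;base64,iVBORw0KGgo' alt='Revenue by quarter'/>\n<br>"
)

matches = re.findall(image_pattern, evidence)
# ['data:image/png;base64,iVBORw0KGgo']
context = re.sub(image_pattern, "", evidence)
# the <img> tag survives with alt text only, so the caption stays in the context
print(matches)
print(context)
```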