diff --git a/libs/kotaemon/kotaemon/indices/ingests/files.py b/libs/kotaemon/kotaemon/indices/ingests/files.py
index ed00e5c..75f944e 100644
--- a/libs/kotaemon/kotaemon/indices/ingests/files.py
+++ b/libs/kotaemon/kotaemon/indices/ingests/files.py
@@ -7,6 +7,7 @@ from kotaemon.base import BaseComponent, Document, Param
 from kotaemon.indices.extractors import BaseDocParser
 from kotaemon.indices.splitters import BaseSplitter, TokenSplitter
 from kotaemon.loaders import (
+    AdobeReader,
     DirectoryReader,
     MathpixPDFReader,
     OCRReader,
@@ -41,7 +42,7 @@ class DocumentIngestor(BaseComponent):
     The default file extractors are stored in `KH_DEFAULT_FILE_EXTRACTORS`
     """
 
-    pdf_mode: str = "normal"  # "normal", "mathpix", "ocr"
+    pdf_mode: str = "normal"  # "normal", "mathpix", "ocr", "multimodal"
     doc_parsers: list[BaseDocParser] = Param(default_callback=lambda _: [])
     text_splitter: BaseSplitter = TokenSplitter.withx(
         chunk_size=1024,
@@ -61,6 +62,8 @@ class DocumentIngestor(BaseComponent):
             pass  # use default loader of llama-index which is pypdf
         elif self.pdf_mode == "ocr":
             file_extractors[".pdf"] = OCRReader()
+        elif self.pdf_mode == "multimodal":
+            file_extractors[".pdf"] = AdobeReader()
         else:
             file_extractors[".pdf"] = MathpixPDFReader()
 
diff --git a/libs/kotaemon/kotaemon/loaders/__init__.py b/libs/kotaemon/kotaemon/loaders/__init__.py
index d742b52..28cb5f3 100644
--- a/libs/kotaemon/kotaemon/loaders/__init__.py
+++ b/libs/kotaemon/kotaemon/loaders/__init__.py
@@ -1,3 +1,4 @@
+from .adobe_loader import AdobeReader
 from .base import AutoReader, BaseReader
 from .composite_loader import DirectoryReader
 from .docx_loader import DocxReader
@@ -17,4 +18,5 @@ __all__ = [
     "UnstructuredReader",
     "DocxReader",
     "HtmlReader",
+    "AdobeReader",
 ]
diff --git a/libs/kotaemon/kotaemon/loaders/adobe_loader.py b/libs/kotaemon/kotaemon/loaders/adobe_loader.py
new file mode 100644
index 0000000..dd8cbc9
--- /dev/null
+++ b/libs/kotaemon/kotaemon/loaders/adobe_loader.py
@@ -0,0 +1,187 @@
+import logging
+import os
+import re
+from collections import defaultdict
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+from decouple import config
+from llama_index.readers.base import BaseReader
+
+from kotaemon.base import Document
+
+from .utils.adobe import (
+    generate_figure_captions,
+    load_json,
+    parse_figure_paths,
+    parse_table_paths,
+    request_adobe_service,
+)
+
+logger = logging.getLogger(__name__)
+
+DEFAULT_VLM_ENDPOINT = (
+    "{0}openai/deployments/{1}/chat/completions?api-version={2}".format(
+        config("AZURE_OPENAI_ENDPOINT", default=""),
+        "gpt-4-vision",
+        config("OPENAI_API_VERSION", default=""),
+    )
+)
+
+
+class AdobeReader(BaseReader):
+    """Read PDF using Adobe's PDF Services.
+    Able to extract text, tables, and figures with high accuracy.
+
+    Example:
+        ```python
+        >>> from kotaemon.loaders import AdobeReader
+        >>> reader = AdobeReader()
+        >>> documents = reader.load_data("path/to/pdf")
+        ```
+    Args:
+        vlm_endpoint: URL to the Vision Language Model endpoint. If not provided,
+        will use the default `kotaemon.loaders.adobe_loader.DEFAULT_VLM_ENDPOINT`
+
+        max_figures_to_caption: an int that decides how many figures will be captioned.
+        The rest will be ignored (indexed without captions).
+ """ + + def __init__( + self, + vlm_endpoint: Optional[str] = None, + max_figures_to_caption: int = 100, + *args: Any, + **kwargs: Any, + ) -> None: + """Init params""" + super().__init__(*args) + self.table_regex = r"/Table(\[\d+\])?$" + self.figure_regex = r"/Figure(\[\d+\])?$" + self.vlm_endpoint = vlm_endpoint or DEFAULT_VLM_ENDPOINT + self.max_figures_to_caption = max_figures_to_caption + + def load_data( + self, file: Path, extra_info: Optional[Dict] = None, **kwargs + ) -> List[Document]: + """Load data by calling to the Adobe's API + + Args: + file (Path): Path to the PDF file + + Returns: + List[Document]: list of documents extracted from the PDF file, + includes 3 types: text, table, and image + + """ + + filename = file.name + filepath = str(Path(file).resolve()) + output_path = request_adobe_service(file_path=str(file), output_path="") + results_path = os.path.join(output_path, "structuredData.json") + + if not os.path.exists(results_path): + logger.exception("Fail to parse the document.") + return [] + + data = load_json(results_path) + + texts = defaultdict(list) + tables = [] + figures = [] + + elements = data["elements"] + for item_id, item in enumerate(elements): + page_number = item.get("Page", -1) + 1 + item_path = item["Path"] + item_text = item.get("Text", "") + + file_paths = [ + Path(output_path) / path for path in item.get("filePaths", []) + ] + prev_item = elements[item_id - 1] + title = prev_item.get("Text", "") + + if re.search(self.table_regex, item_path): + table_content = parse_table_paths(file_paths) + if not table_content: + continue + table_caption = ( + table_content.replace("|", "").replace("---", "") + + f"\n(Table in Page {page_number}. {title})" + ) + tables.append((page_number, table_content, table_caption)) + + elif re.search(self.figure_regex, item_path): + figure_caption = ( + item_text + f"\n(Figure in Page {page_number}. 
{title})" + ) + figure_content = parse_figure_paths(file_paths) + if not figure_content: + continue + figures.append([page_number, figure_content, figure_caption]) + + else: + if item_text and "Table" not in item_path and "Figure" not in item_path: + texts[page_number].append(item_text) + + # get figure caption using GPT-4V + figure_captions = generate_figure_captions( + self.vlm_endpoint, + [item[1] for item in figures], + self.max_figures_to_caption, + ) + for item, caption in zip(figures, figure_captions): + # update figure caption + item[2] += " " + caption + + # Wrap elements with Document + documents = [] + + # join plain text elements + for page_number, txts in texts.items(): + documents.append( + Document( + text="\n".join(txts), + metadata={ + "page_label": page_number, + "file_name": filename, + "file_path": filepath, + }, + ) + ) + + # table elements + for page_number, table_content, table_caption in tables: + documents.append( + Document( + text=table_caption, + metadata={ + "table_origin": table_content, + "type": "table", + "page_label": page_number, + "file_name": filename, + "file_path": filepath, + }, + metadata_template="", + metadata_seperator="", + ) + ) + + # figure elements + for page_number, figure_content, figure_caption in figures: + documents.append( + Document( + text=figure_caption, + metadata={ + "image_origin": figure_content, + "type": "image", + "page_label": page_number, + "file_name": filename, + "file_path": filepath, + }, + metadata_template="", + metadata_seperator="", + ) + ) + return documents diff --git a/libs/kotaemon/kotaemon/loaders/utils/adobe.py b/libs/kotaemon/kotaemon/loaders/utils/adobe.py new file mode 100644 index 0000000..a780c45 --- /dev/null +++ b/libs/kotaemon/kotaemon/loaders/utils/adobe.py @@ -0,0 +1,248 @@ +# need pip install pdfservices-sdk==2.3.0 + +import base64 +import json +import logging +import os +import tempfile +import zipfile +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path +from typing import List, Union + +import pandas as pd +from decouple import config + +from kotaemon.loaders.utils.gpt4v import generate_gpt4v + +logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO")) + + +def request_adobe_service(file_path: str, output_path: str = "") -> str: + """Main function to call the adobe service, and unzip the results. + Args: + file_path (str): path to the pdf file + output_path (str): path to store the results + + Returns: + output_path (str): path to the results + + """ + try: + from adobe.pdfservices.operation.auth.credentials import Credentials + from adobe.pdfservices.operation.exception.exceptions import ( + SdkException, + ServiceApiException, + ServiceUsageException, + ) + from adobe.pdfservices.operation.execution_context import ExecutionContext + from adobe.pdfservices.operation.io.file_ref import FileRef + from adobe.pdfservices.operation.pdfops.extract_pdf_operation import ( + ExtractPDFOperation, + ) + from adobe.pdfservices.operation.pdfops.options.extractpdf.extract_element_type import ( # noqa: E501 + ExtractElementType, + ) + from adobe.pdfservices.operation.pdfops.options.extractpdf.extract_pdf_options import ( # noqa: E501 + ExtractPDFOptions, + ) + from adobe.pdfservices.operation.pdfops.options.extractpdf.extract_renditions_element_type import ( # noqa: E501 + ExtractRenditionsElementType, + ) + except ImportError: + raise ImportError( + "pdfservices-sdk is not installed. 
" + "Please install it by running `pip install pdfservices-sdk" + "@git+https://github.com/niallcm/pdfservices-python-sdk.git" + "@bump-and-unfreeze-requirements`" + ) + + if not output_path: + output_path = tempfile.mkdtemp() + + try: + # Initial setup, create credentials instance. + credentials = ( + Credentials.service_principal_credentials_builder() + .with_client_id(config("PDF_SERVICES_CLIENT_ID", default="")) + .with_client_secret(config("PDF_SERVICES_CLIENT_SECRET", default="")) + .build() + ) + + # Create an ExecutionContext using credentials + # and create a new operation instance. + execution_context = ExecutionContext.create(credentials) + extract_pdf_operation = ExtractPDFOperation.create_new() + + # Set operation input from a source file. + source = FileRef.create_from_local_file(file_path) + extract_pdf_operation.set_input(source) + + # Build ExtractPDF options and set them into the operation + extract_pdf_options: ExtractPDFOptions = ( + ExtractPDFOptions.builder() + .with_elements_to_extract( + [ExtractElementType.TEXT, ExtractElementType.TABLES] + ) + .with_elements_to_extract_renditions( + [ + ExtractRenditionsElementType.TABLES, + ExtractRenditionsElementType.FIGURES, + ] + ) + .build() + ) + extract_pdf_operation.set_options(extract_pdf_options) + + # Execute the operation. + result: FileRef = extract_pdf_operation.execute(execution_context) + + # Save the result to the specified location. + zip_file_path = os.path.join( + output_path, "ExtractTextTableWithFigureTableRendition.zip" + ) + result.save_as(zip_file_path) + # Open the ZIP file + with zipfile.ZipFile(zip_file_path, "r") as zip_ref: + # Extract all contents to the destination folder + zip_ref.extractall(output_path) + except (ServiceApiException, ServiceUsageException, SdkException): + logging.exception("Exception encountered while executing operation") + + return output_path + + +def make_markdown_table(table_as_list: List[str]) -> str: + """ + Convert table from python list representation to markdown format. + The input list consists of rows of tables, the first row is the header. 
+ + Args: + table_as_list: list of table rows + Example: [["Name", "Age", "Height"], + ["Jake", 20, 5'10], + ["Mary", 21, 5'7]] + Returns: + markdown representation of the table + """ + markdown = "\n" + str("| ") + + for e in table_as_list[0]: + to_add = " " + str(e) + str(" |") + markdown += to_add + markdown += "\n" + + markdown += "| " + for i in range(len(table_as_list[0])): + markdown += str("--- | ") + markdown += "\n" + + for entry in table_as_list[1:]: + markdown += str("| ") + for e in entry: + to_add = str(e) + str(" | ") + markdown += to_add + markdown += "\n" + + return markdown + "\n" + + +def load_json(input_path: Union[str | Path]) -> dict: + """Load json file""" + with open(input_path, "r") as fi: + data = json.load(fi) + + return data + + +def load_excel(input_path: Union[str | Path]) -> str: + """Load excel file and convert to markdown""" + + df = pd.read_excel(input_path).fillna("") + # Convert dataframe to a list of rows + row_list = [df.columns.values.tolist()] + df.values.tolist() + + for item_id, item in enumerate(row_list[0]): + if "Unnamed" in item: + row_list[0][item_id] = "" + + for row in row_list: + for item_id, item in enumerate(row): + row[item_id] = str(item).replace("_x000D_", " ").replace("\n", " ").strip() + + markdown_str = make_markdown_table(row_list) + return markdown_str + + +def encode_image_base64(image_path: Union[str | Path]) -> Union[bytes, str]: + """Convert image to base64""" + + with open(image_path, "rb") as image_file: + return base64.b64encode(image_file.read()).decode("utf-8") + + +def parse_table_paths(file_paths: List[Path]) -> str: + """Read the table stored in an excel file given the file path""" + + content = "" + for path in file_paths: + if path.suffix == ".xlsx": + content = load_excel(path) + break + return content + + +def parse_figure_paths(file_paths: List[Path]) -> Union[bytes, str]: + """Read and convert an image to base64 given the image path""" + + content = "" + for path in file_paths: + if path.suffix == ".png": + base64_image = encode_image_base64(path) + content = f"data:image/png;base64,{base64_image}" # type: ignore + break + return content + + +def generate_single_figure_caption(vlm_endpoint: str, figure: str) -> str: + """Summarize a single figure using GPT-4V""" + if figure: + output = generate_gpt4v( + endpoint=vlm_endpoint, + prompt="Provide a short 2 sentence summary of this image?", + images=figure, + ) + if "sorry" in output.lower(): + output = "" + else: + output = "" + return output + + +def generate_figure_captions( + vlm_endpoint: str, figures: List, max_figures_to_process: int +) -> List: + """Summarize several figures using GPT-4V. + Args: + vlm_endpoint (str): endpoint to the vision language model service + figures (List): list of base64 images + max_figures_to_process (int): the maximum number of figures will be summarized, + the rest are ignored. + + Returns: + results (List[str]): list of all figure captions and empty strings for + ignored figures. 
+ """ + to_gen_figures = figures[:max_figures_to_process] + other_figures = figures[max_figures_to_process:] + + with ThreadPoolExecutor() as executor: + futures = [ + executor.submit( + lambda: generate_single_figure_caption(vlm_endpoint, figure) + ) + for figure in to_gen_figures + ] + + results = [future.result() for future in futures] + return results + [""] * len(other_figures) diff --git a/libs/kotaemon/kotaemon/loaders/utils/gpt4v.py b/libs/kotaemon/kotaemon/loaders/utils/gpt4v.py new file mode 100644 index 0000000..1e219d6 --- /dev/null +++ b/libs/kotaemon/kotaemon/loaders/utils/gpt4v.py @@ -0,0 +1,96 @@ +import json +from typing import Any, List + +import requests +from decouple import config + + +def generate_gpt4v( + endpoint: str, images: str | List[str], prompt: str, max_tokens: int = 512 +) -> str: + # OpenAI API Key + api_key = config("AZURE_OPENAI_API_KEY", default="") + headers = {"Content-Type": "application/json", "api-key": api_key} + + if isinstance(images, str): + images = [images] + + payload = { + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": prompt}, + ] + + [ + { + "type": "image_url", + "image_url": {"url": image}, + } + for image in images + ], + } + ], + "max_tokens": max_tokens, + } + + try: + response = requests.post(endpoint, headers=headers, json=payload) + output = response.json() + output = output["choices"][0]["message"]["content"] + except Exception: + output = "" + return output + + +def stream_gpt4v( + endpoint: str, images: str | List[str], prompt: str, max_tokens: int = 512 +) -> Any: + # OpenAI API Key + api_key = config("AZURE_OPENAI_API_KEY", default="") + headers = {"Content-Type": "application/json", "api-key": api_key} + + if isinstance(images, str): + images = [images] + + payload = { + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": prompt}, + ] + + [ + { + "type": "image_url", + "image_url": {"url": image}, + } + for image in images + ], + } + ], + "max_tokens": max_tokens, + "stream": True, + } + try: + response = requests.post(endpoint, headers=headers, json=payload, stream=True) + assert response.status_code == 200, str(response.content) + output = "" + for line in response.iter_lines(): + if line: + if line.startswith(b"\xef\xbb\xbf"): + line = line[9:] + else: + line = line[6:] + try: + if line == "[DONE]": + break + line = json.loads(line.decode("utf-8")) + except Exception: + break + if len(line["choices"]): + output += line["choices"][0]["delta"].get("content", "") + yield line["choices"][0]["delta"].get("content", "") + except Exception: + output = "" + return output diff --git a/libs/kotaemon/pyproject.toml b/libs/kotaemon/pyproject.toml index e1e3028..73c3e8a 100644 --- a/libs/kotaemon/pyproject.toml +++ b/libs/kotaemon/pyproject.toml @@ -60,6 +60,7 @@ adv = [ "cohere", "elasticsearch", "llama-cpp-python", + "pdfservices-sdk @ git+https://github.com/niallcm/pdfservices-python-sdk.git@bump-and-unfreeze-requirements", ] dev = [ "ipython", @@ -69,6 +70,7 @@ dev = [ "flake8", "sphinx", "coverage", + "python-decouple" ] all = ["kotaemon[adv,dev]"] diff --git a/libs/kotaemon/tests/_test_multimodal_reader.py b/libs/kotaemon/tests/_test_multimodal_reader.py new file mode 100644 index 0000000..b07786f --- /dev/null +++ b/libs/kotaemon/tests/_test_multimodal_reader.py @@ -0,0 +1,21 @@ +# TODO: This test is broken and should be rewritten +from pathlib import Path + +from kotaemon.loaders import AdobeReader + +# from dotenv import load_dotenv + + +input_file = 
Path(__file__).parent / "resources" / "multimodal.pdf" + +# load_dotenv() + + +def test_adobe_reader(): + reader = AdobeReader() + documents = reader.load_data(input_file) + table_docs = [doc for doc in documents if doc.metadata.get("type", "") == "table"] + assert len(table_docs) == 2 + + figure_docs = [doc for doc in documents if doc.metadata.get("type", "") == "image"] + assert len(figure_docs) == 2 diff --git a/libs/kotaemon/tests/resources/multimodal.pdf b/libs/kotaemon/tests/resources/multimodal.pdf new file mode 100644 index 0000000..29c2bdc Binary files /dev/null and b/libs/kotaemon/tests/resources/multimodal.pdf differ diff --git a/libs/ktem/flowsettings.py b/libs/ktem/flowsettings.py index a3589fe..33ba88f 100644 --- a/libs/ktem/flowsettings.py +++ b/libs/ktem/flowsettings.py @@ -124,6 +124,11 @@ if config("LOCAL_MODEL", default=""): KH_REASONINGS = ["ktem.reasoning.simple.FullQAPipeline"] +KH_VLM_ENDPOINT = "{0}/openai/deployments/{1}/chat/completions?api-version={2}".format( + config("AZURE_OPENAI_ENDPOINT", default=""), + config("OPENAI_VISION_DEPLOYMENT_NAME", default="gpt-4-vision"), + config("OPENAI_API_VERSION", default=""), +) SETTINGS_APP = { diff --git a/libs/ktem/ktem/index/file/pipelines.py b/libs/ktem/ktem/index/file/pipelines.py index 1d813f5..68b3a4d 100644 --- a/libs/ktem/ktem/index/file/pipelines.py +++ b/libs/ktem/ktem/index/file/pipelines.py @@ -378,6 +378,7 @@ class IndexDocumentPipeline(BaseFileIndexIndexing): ("PDF text parser", "normal"), ("Mathpix", "mathpix"), ("Advanced ocr", "ocr"), + ("Multimodal parser", "multimodal"), ], "component": "dropdown", }, diff --git a/libs/ktem/ktem/reasoning/simple.py b/libs/ktem/ktem/reasoning/simple.py index 47f7c92..1627522 100644 --- a/libs/ktem/ktem/reasoning/simple.py +++ b/libs/ktem/ktem/reasoning/simple.py @@ -1,11 +1,14 @@ import asyncio +import html import logging +import re from collections import defaultdict from functools import partial import tiktoken from ktem.components import llms from ktem.reasoning.base import BaseReasoning +from theflow.settings import settings as flowsettings from kotaemon.base import ( BaseComponent, @@ -18,9 +21,15 @@ from kotaemon.base import ( from kotaemon.indices.qa.citation import CitationPipeline from kotaemon.indices.splitters import TokenSplitter from kotaemon.llms import ChatLLM, PromptTemplate +from kotaemon.loaders.utils.gpt4v import stream_gpt4v logger = logging.getLogger(__name__) +EVIDENCE_MODE_TEXT = 0 +EVIDENCE_MODE_TABLE = 1 +EVIDENCE_MODE_CHATBOT = 2 +EVIDENCE_MODE_FIGURE = 3 + class PrepareEvidencePipeline(BaseComponent): """Prepare the evidence text from the list of retrieved documents @@ -46,7 +55,7 @@ class PrepareEvidencePipeline(BaseComponent): def run(self, docs: list[RetrievedDocument]) -> Document: evidence = "" table_found = 0 - evidence_mode = 0 + evidence_mode = EVIDENCE_MODE_TEXT for _id, retrieved_item in enumerate(docs): retrieved_content = "" @@ -55,7 +64,7 @@ class PrepareEvidencePipeline(BaseComponent): if page: source += f" (Page {page})" if retrieved_item.metadata.get("type", "") == "table": - evidence_mode = 1 # table + evidence_mode = EVIDENCE_MODE_TABLE if table_found < 5: retrieved_content = retrieved_item.metadata.get("table_origin", "") if retrieved_content not in evidence: @@ -66,13 +75,23 @@ class PrepareEvidencePipeline(BaseComponent): + "\n
" ) elif retrieved_item.metadata.get("type", "") == "chatbot": - evidence_mode = 2 # chatbot + evidence_mode = EVIDENCE_MODE_CHATBOT retrieved_content = retrieved_item.metadata["window"] evidence += ( f"
<br><b>Chatbot scenario from {filename} (Row {page})</b>\n"
                     + retrieved_content
                     + "\n<br>"
                 )
+            elif retrieved_item.metadata.get("type", "") == "image":
+                evidence_mode = EVIDENCE_MODE_FIGURE
+                retrieved_content = retrieved_item.metadata.get("image_origin", "")
+                retrieved_caption = html.escape(retrieved_item.get_content())
+                evidence += (
+                    f"<br><b>Figure from {source}</b>\n"
+                    + f"<img width='85%' src='{retrieved_content}' "
+                    + f"alt='{retrieved_caption}'/>"
+                    + "\n<br>
" + ) else: if "window" in retrieved_item.metadata: retrieved_content = retrieved_item.metadata["window"] @@ -90,12 +109,13 @@ class PrepareEvidencePipeline(BaseComponent): print(retrieved_item.metadata) print("Score", retrieved_item.metadata.get("relevance_score", None)) - # trim context by trim_len - print("len (original)", len(evidence)) - if evidence: - texts = self.trim_func([Document(text=evidence)]) - evidence = texts[0].text - print("len (trimmed)", len(evidence)) + if evidence_mode != EVIDENCE_MODE_FIGURE: + # trim context by trim_len + print("len (original)", len(evidence)) + if evidence: + texts = self.trim_func([Document(text=evidence)]) + evidence = texts[0].text + print("len (trimmed)", len(evidence)) print(f"PrepareEvidence with input {docs}\nOutput: {evidence}\n") @@ -134,6 +154,16 @@ DEFAULT_QA_CHATBOT_PROMPT = ( "Answer:" ) +DEFAULT_QA_FIGURE_PROMPT = ( + "Use the given context: texts, tables, and figures below to answer the question. " + "If you don't know the answer, just say that you don't know. " + "Give answer in {lang}.\n\n" + "Context: \n" + "{context}\n" + "Question: {question}\n" + "Answer: " +) + class AnswerWithContextPipeline(BaseComponent): """Answer the question based on the evidence @@ -151,6 +181,7 @@ class AnswerWithContextPipeline(BaseComponent): """ llm: ChatLLM = Node(default_callback=lambda _: llms.get_highest_accuracy()) + vlm_endpoint: str = flowsettings.KH_VLM_ENDPOINT citation_pipeline: CitationPipeline = Node( default_callback=lambda _: CitationPipeline(llm=llms.get_lowest_cost()) ) @@ -158,6 +189,7 @@ class AnswerWithContextPipeline(BaseComponent): qa_template: str = DEFAULT_QA_TEXT_PROMPT qa_table_template: str = DEFAULT_QA_TABLE_PROMPT qa_chatbot_template: str = DEFAULT_QA_CHATBOT_PROMPT + qa_figure_template: str = DEFAULT_QA_FIGURE_PROMPT enable_citation: bool = False system_prompt: str = "" @@ -188,18 +220,30 @@ class AnswerWithContextPipeline(BaseComponent): (determined by retrieval pipeline) evidence_mode: the mode of evidence, 0 for text, 1 for table, 2 for chatbot """ - if evidence_mode == 0: + if evidence_mode == EVIDENCE_MODE_TEXT: prompt_template = PromptTemplate(self.qa_template) - elif evidence_mode == 1: + elif evidence_mode == EVIDENCE_MODE_TABLE: prompt_template = PromptTemplate(self.qa_table_template) + elif evidence_mode == EVIDENCE_MODE_FIGURE: + prompt_template = PromptTemplate(self.qa_figure_template) else: prompt_template = PromptTemplate(self.qa_chatbot_template) - prompt = prompt_template.populate( - context=evidence, - question=question, - lang=self.lang, - ) + images = [] + if evidence_mode == EVIDENCE_MODE_FIGURE: + # isolate image from evidence + evidence, images = self.extract_evidence_images(evidence) + prompt = prompt_template.populate( + context=evidence, + question=question, + lang=self.lang, + ) + else: + prompt = prompt_template.populate( + context=evidence, + question=question, + lang=self.lang, + ) citation_task = None if evidence and self.enable_citation: @@ -208,23 +252,29 @@ class AnswerWithContextPipeline(BaseComponent): ) print("Citation task created") - messages = [] - if self.system_prompt: - messages.append(SystemMessage(content=self.system_prompt)) - messages.append(HumanMessage(content=prompt)) - output = "" - try: - # try streaming first - print("Trying LLM streaming") - for text in self.llm.stream(messages): - output += text.text - self.report_output({"output": text.text}) + if evidence_mode == EVIDENCE_MODE_FIGURE: + for text in stream_gpt4v(self.vlm_endpoint, images, prompt, max_tokens=768): + 
output += text + self.report_output({"output": text}) await asyncio.sleep(0) - except NotImplementedError: - print("Streaming is not supported, falling back to normal processing") - output = self.llm(messages).text - self.report_output({"output": output}) + else: + messages = [] + if self.system_prompt: + messages.append(SystemMessage(content=self.system_prompt)) + messages.append(HumanMessage(content=prompt)) + + try: + # try streaming first + print("Trying LLM streaming") + for text in self.llm.stream(messages): + output += text.text + self.report_output({"output": text.text}) + await asyncio.sleep(0) + except NotImplementedError: + print("Streaming is not supported, falling back to normal processing") + output = self.llm(messages).text + self.report_output({"output": output}) # retrieve the citation print("Waiting for citation task") @@ -237,6 +287,13 @@ class AnswerWithContextPipeline(BaseComponent): return answer + def extract_evidence_images(self, evidence: str): + """Util function to extract and isolate images from context/evidence""" + image_pattern = r"src='(data:image\/[^;]+;base64[^']+)'" + matches = re.findall(image_pattern, evidence) + context = re.sub(image_pattern, "", evidence) + return context, matches + class FullQAPipeline(BaseReasoning): """Question answering pipeline. Handle from question to answer"""
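
Usage sketch (illustrative, not part of the diff): the change wires a "multimodal" path in which
DocumentIngestor routes ".pdf" files to AdobeReader, Adobe PDF Services extracts text, tables, and
figure renditions, and GPT-4V captions the figures. Only the environment-variable names come from
the code above; the placeholder values, the file name, and the assumption that a DocumentIngestor
instance can be called directly with a single path are illustrative.

```python
import os
from pathlib import Path

# Credentials are read through python-decouple, which falls back to os.environ
# (placeholder values below; both services must be provisioned separately).
os.environ.setdefault("PDF_SERVICES_CLIENT_ID", "<adobe-client-id>")
os.environ.setdefault("PDF_SERVICES_CLIENT_SECRET", "<adobe-client-secret>")
os.environ.setdefault("AZURE_OPENAI_ENDPOINT", "https://<resource>.openai.azure.com/")
os.environ.setdefault("AZURE_OPENAI_API_KEY", "<azure-openai-key>")
os.environ.setdefault("OPENAI_API_VERSION", "<api-version>")

# Imports come after the env setup so DEFAULT_VLM_ENDPOINT picks up the endpoint
# when adobe_loader is imported.
from kotaemon.indices.ingests.files import DocumentIngestor
from kotaemon.loaders import AdobeReader

# Direct use of the reader: returns text, table, and image Documents.
reader = AdobeReader(max_figures_to_caption=10)
docs = reader.load_data(Path("report.pdf"))
tables = [d for d in docs if d.metadata.get("type") == "table"]
figures = [d for d in docs if d.metadata.get("type") == "image"]

# Through the ingestor: "multimodal" mode now maps ".pdf" to AdobeReader.
ingestor = DocumentIngestor(pdf_mode="multimodal")
nodes = ingestor(Path("report.pdf"))
```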
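
The figure path also relies on a small format contract between the loader and the reasoning
pipeline: parse_figure_paths emits a "data:image/png;base64,..." URI, PrepareEvidencePipeline
embeds it in the evidence text as an HTML img tag, and AnswerWithContextPipeline's
extract_evidence_images recovers it with the regex shown in the diff before the GPT-4V call.
A self-contained sketch of that round trip (the caption text and the exact img attribute layout
are assumptions; only the src='...' form is required by the regex):

```python
import html
import re

# Same pattern as extract_evidence_images in ktem/ktem/reasoning/simple.py.
image_pattern = r"src='(data:image\/[^;]+;base64[^']+)'"

# Shape produced by parse_figure_paths; the payload here is a dummy value.
figure_content = "data:image/png;base64,iVBORw0KGgoAAA"
figure_caption = html.escape("Quarterly revenue by region")  # hypothetical caption

# Assumed evidence markup: an img tag whose src carries the data URI.
evidence = (
    "<br><b>Figure from report.pdf (Page 3)</b>\n"
    + f"<img src='{figure_content}' alt='{figure_caption}'/>"
    + "\n<br>"
)

# The pipeline pulls the images out and strips them from the textual context.
images = re.findall(image_pattern, evidence)
context = re.sub(image_pattern, "", evidence)

assert images == [figure_content]
assert "base64" not in context
```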