Feat/add multimodal loader (#5)

* Add AdobeReader as the multimodal loader

* Allow FullQAPipeline to reason over figures

* fix: move the Adobe import to avoid ImportError; notify users whenever they run the AdobeReader

---------

Co-authored-by: cin-albert <albert@cinnamon.is>
ian_Cin 2024-04-03 14:52:40 +07:00 committed by GitHub
parent a3bf728400
commit e67a25c0bd
11 changed files with 654 additions and 32 deletions

View File

@ -7,6 +7,7 @@ from kotaemon.base import BaseComponent, Document, Param
from kotaemon.indices.extractors import BaseDocParser
from kotaemon.indices.splitters import BaseSplitter, TokenSplitter
from kotaemon.loaders import (
AdobeReader,
DirectoryReader,
MathpixPDFReader,
OCRReader,
@ -41,7 +42,7 @@ class DocumentIngestor(BaseComponent):
The default file extractors are stored in `KH_DEFAULT_FILE_EXTRACTORS`
"""
pdf_mode: str = "normal" # "normal", "mathpix", "ocr"
pdf_mode: str = "normal" # "normal", "mathpix", "ocr", "multimodal"
doc_parsers: list[BaseDocParser] = Param(default_callback=lambda _: [])
text_splitter: BaseSplitter = TokenSplitter.withx(
chunk_size=1024,
@ -61,6 +62,8 @@ class DocumentIngestor(BaseComponent):
pass # use default loader of llama-index which is pypdf
elif self.pdf_mode == "ocr":
file_extractors[".pdf"] = OCRReader()
elif self.pdf_mode == "multimodal":
file_extractors[".pdf"] = AdobeReader()
else:
file_extractors[".pdf"] = MathpixPDFReader()

View File

@ -1,3 +1,4 @@
from .adobe_loader import AdobeReader
from .base import AutoReader, BaseReader
from .composite_loader import DirectoryReader
from .docx_loader import DocxReader
@ -17,4 +18,5 @@ __all__ = [
"UnstructuredReader",
"DocxReader",
"HtmlReader",
"AdobeReader",
]

View File

@ -0,0 +1,187 @@
import logging
import os
import re
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Optional
from decouple import config
from llama_index.readers.base import BaseReader
from kotaemon.base import Document
from .utils.adobe import (
generate_figure_captions,
load_json,
parse_figure_paths,
parse_table_paths,
request_adobe_service,
)
logger = logging.getLogger(__name__)
DEFAULT_VLM_ENDPOINT = (
"{0}openai/deployments/{1}/chat/completions?api-version={2}".format(
config("AZURE_OPENAI_ENDPOINT", default=""),
"gpt-4-vision",
config("OPENAI_API_VERSION", default=""),
)
)
class AdobeReader(BaseReader):
"""Read PDF using the Adobe's PDF Services.
Be able to extract text, table, and figure with high accuracy
Example:
```python
>> from kotaemon.loaders import AdobeReader
>> reader = AdobeReader()
>> documents = reader.load_data("path/to/pdf")
```
Args:
endpoint: URL to the Vision Language Model endpoint. If not provided,
will use the default `kotaemon.loaders.adobe_loader.DEFAULT_VLM_ENDPOINT`
max_figures_to_caption: an int decides how many figured will be captioned.
The rest will be ignored (are indexed without captions).
"""
def __init__(
self,
vlm_endpoint: Optional[str] = None,
max_figures_to_caption: int = 100,
*args: Any,
**kwargs: Any,
) -> None:
"""Init params"""
super().__init__(*args)
self.table_regex = r"/Table(\[\d+\])?$"
self.figure_regex = r"/Figure(\[\d+\])?$"
self.vlm_endpoint = vlm_endpoint or DEFAULT_VLM_ENDPOINT
self.max_figures_to_caption = max_figures_to_caption
def load_data(
self, file: Path, extra_info: Optional[Dict] = None, **kwargs
) -> List[Document]:
"""Load data by calling to the Adobe's API
Args:
file (Path): Path to the PDF file
Returns:
List[Document]: list of documents extracted from the PDF file,
including 3 types: text, table, and image
"""
filename = file.name
filepath = str(Path(file).resolve())
output_path = request_adobe_service(file_path=str(file), output_path="")
results_path = os.path.join(output_path, "structuredData.json")
if not os.path.exists(results_path):
logger.error("Failed to parse the document.")
return []
data = load_json(results_path)
texts = defaultdict(list)
tables = []
figures = []
elements = data["elements"]
for item_id, item in enumerate(elements):
page_number = item.get("Page", -1) + 1
item_path = item["Path"]
item_text = item.get("Text", "")
file_paths = [
Path(output_path) / path for path in item.get("filePaths", [])
]
prev_item = elements[item_id - 1]
title = prev_item.get("Text", "")
if re.search(self.table_regex, item_path):
table_content = parse_table_paths(file_paths)
if not table_content:
continue
table_caption = (
table_content.replace("|", "").replace("---", "")
+ f"\n(Table in Page {page_number}. {title})"
)
tables.append((page_number, table_content, table_caption))
elif re.search(self.figure_regex, item_path):
figure_caption = (
item_text + f"\n(Figure in Page {page_number}. {title})"
)
figure_content = parse_figure_paths(file_paths)
if not figure_content:
continue
figures.append([page_number, figure_content, figure_caption])
else:
if item_text and "Table" not in item_path and "Figure" not in item_path:
texts[page_number].append(item_text)
# get figure caption using GPT-4V
figure_captions = generate_figure_captions(
self.vlm_endpoint,
[item[1] for item in figures],
self.max_figures_to_caption,
)
for item, caption in zip(figures, figure_captions):
# update figure caption
item[2] += " " + caption
# Wrap elements with Document
documents = []
# join plain text elements
for page_number, txts in texts.items():
documents.append(
Document(
text="\n".join(txts),
metadata={
"page_label": page_number,
"file_name": filename,
"file_path": filepath,
},
)
)
# table elements
for page_number, table_content, table_caption in tables:
documents.append(
Document(
text=table_caption,
metadata={
"table_origin": table_content,
"type": "table",
"page_label": page_number,
"file_name": filename,
"file_path": filepath,
},
metadata_template="",
metadata_seperator="",
)
)
# figure elements
for page_number, figure_content, figure_caption in figures:
documents.append(
Document(
text=figure_caption,
metadata={
"image_origin": figure_content,
"type": "image",
"page_label": page_number,
"file_name": filename,
"file_path": filepath,
},
metadata_template="",
metadata_seperator="",
)
)
return documents
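
A short usage sketch for the reader added above, assuming `PDF_SERVICES_CLIENT_ID`, `PDF_SERVICES_CLIENT_SECRET`, and the Azure OpenAI variables are set in the environment; the PDF path is hypothetical.

```python
from pathlib import Path

from kotaemon.loaders import AdobeReader

reader = AdobeReader(max_figures_to_caption=10)
docs = reader.load_data(Path("path/to/report.pdf"))  # hypothetical path

# Plain text documents carry no "type"; tables carry "table" plus the raw
# markdown in "table_origin"; figures carry "image" plus a base64 data URL
# in "image_origin".
tables = [d for d in docs if d.metadata.get("type") == "table"]
figures = [d for d in docs if d.metadata.get("type") == "image"]
print(len(tables), len(figures))
```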

View File

@ -0,0 +1,248 @@
# need pip install pdfservices-sdk==2.3.0
import base64
import json
import logging
import os
import tempfile
import zipfile
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import List, Union
import pandas as pd
from decouple import config
from kotaemon.loaders.utils.gpt4v import generate_gpt4v
logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))
def request_adobe_service(file_path: str, output_path: str = "") -> str:
"""Main function to call the adobe service, and unzip the results.
Args:
file_path (str): path to the pdf file
output_path (str): path to store the results
Returns:
output_path (str): path to the results
"""
try:
from adobe.pdfservices.operation.auth.credentials import Credentials
from adobe.pdfservices.operation.exception.exceptions import (
SdkException,
ServiceApiException,
ServiceUsageException,
)
from adobe.pdfservices.operation.execution_context import ExecutionContext
from adobe.pdfservices.operation.io.file_ref import FileRef
from adobe.pdfservices.operation.pdfops.extract_pdf_operation import (
ExtractPDFOperation,
)
from adobe.pdfservices.operation.pdfops.options.extractpdf.extract_element_type import ( # noqa: E501
ExtractElementType,
)
from adobe.pdfservices.operation.pdfops.options.extractpdf.extract_pdf_options import ( # noqa: E501
ExtractPDFOptions,
)
from adobe.pdfservices.operation.pdfops.options.extractpdf.extract_renditions_element_type import ( # noqa: E501
ExtractRenditionsElementType,
)
except ImportError:
raise ImportError(
"pdfservices-sdk is not installed. "
"Please install it by running `pip install pdfservices-sdk"
"@git+https://github.com/niallcm/pdfservices-python-sdk.git"
"@bump-and-unfreeze-requirements`"
)
if not output_path:
output_path = tempfile.mkdtemp()
try:
# Initial setup, create credentials instance.
credentials = (
Credentials.service_principal_credentials_builder()
.with_client_id(config("PDF_SERVICES_CLIENT_ID", default=""))
.with_client_secret(config("PDF_SERVICES_CLIENT_SECRET", default=""))
.build()
)
# Create an ExecutionContext using credentials
# and create a new operation instance.
execution_context = ExecutionContext.create(credentials)
extract_pdf_operation = ExtractPDFOperation.create_new()
# Set operation input from a source file.
source = FileRef.create_from_local_file(file_path)
extract_pdf_operation.set_input(source)
# Build ExtractPDF options and set them into the operation
extract_pdf_options: ExtractPDFOptions = (
ExtractPDFOptions.builder()
.with_elements_to_extract(
[ExtractElementType.TEXT, ExtractElementType.TABLES]
)
.with_elements_to_extract_renditions(
[
ExtractRenditionsElementType.TABLES,
ExtractRenditionsElementType.FIGURES,
]
)
.build()
)
extract_pdf_operation.set_options(extract_pdf_options)
# Execute the operation.
result: FileRef = extract_pdf_operation.execute(execution_context)
# Save the result to the specified location.
zip_file_path = os.path.join(
output_path, "ExtractTextTableWithFigureTableRendition.zip"
)
result.save_as(zip_file_path)
# Open the ZIP file
with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
# Extract all contents to the destination folder
zip_ref.extractall(output_path)
except (ServiceApiException, ServiceUsageException, SdkException):
logging.exception("Exception encountered while executing operation")
return output_path
def make_markdown_table(table_as_list: List[str]) -> str:
"""
Convert a table from its Python list representation to markdown format.
The input list consists of table rows; the first row is the header.
Args:
table_as_list: list of table rows
Example: [["Name", "Age", "Height"],
["Jake", 20, "5'10"],
["Mary", 21, "5'7"]]
Returns:
markdown representation of the table
"""
markdown = "\n" + str("| ")
for e in table_as_list[0]:
to_add = " " + str(e) + str(" |")
markdown += to_add
markdown += "\n"
markdown += "| "
for i in range(len(table_as_list[0])):
markdown += str("--- | ")
markdown += "\n"
for entry in table_as_list[1:]:
markdown += str("| ")
for e in entry:
to_add = str(e) + str(" | ")
markdown += to_add
markdown += "\n"
return markdown + "\n"
def load_json(input_path: Union[str, Path]) -> dict:
"""Load json file"""
with open(input_path, "r") as fi:
data = json.load(fi)
return data
def load_excel(input_path: Union[str, Path]) -> str:
"""Load excel file and convert to markdown"""
df = pd.read_excel(input_path).fillna("")
# Convert dataframe to a list of rows
row_list = [df.columns.values.tolist()] + df.values.tolist()
for item_id, item in enumerate(row_list[0]):
if "Unnamed" in item:
row_list[0][item_id] = ""
for row in row_list:
for item_id, item in enumerate(row):
row[item_id] = str(item).replace("_x000D_", " ").replace("\n", " ").strip()
markdown_str = make_markdown_table(row_list)
return markdown_str
def encode_image_base64(image_path: Union[str, Path]) -> Union[bytes, str]:
"""Convert image to base64"""
with open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")
def parse_table_paths(file_paths: List[Path]) -> str:
"""Read the table stored in an excel file given the file path"""
content = ""
for path in file_paths:
if path.suffix == ".xlsx":
content = load_excel(path)
break
return content
def parse_figure_paths(file_paths: List[Path]) -> Union[bytes, str]:
"""Read and convert an image to base64 given the image path"""
content = ""
for path in file_paths:
if path.suffix == ".png":
base64_image = encode_image_base64(path)
content = f"data:image/png;base64,{base64_image}" # type: ignore
break
return content
def generate_single_figure_caption(vlm_endpoint: str, figure: str) -> str:
"""Summarize a single figure using GPT-4V"""
if figure:
output = generate_gpt4v(
endpoint=vlm_endpoint,
prompt="Provide a short 2 sentence summary of this image?",
images=figure,
)
if "sorry" in output.lower():
output = ""
else:
output = ""
return output
def generate_figure_captions(
vlm_endpoint: str, figures: List, max_figures_to_process: int
) -> List:
"""Summarize several figures using GPT-4V.
Args:
vlm_endpoint (str): endpoint to the vision language model service
figures (List): list of base64 images
max_figures_to_process (int): the maximum number of figures that will be summarized;
the rest are ignored.
Returns:
results (List[str]): list of all figure captions and empty strings for
ignored figures.
"""
to_gen_figures = figures[:max_figures_to_process]
other_figures = figures[max_figures_to_process:]
with ThreadPoolExecutor() as executor:
futures = [
executor.submit(
lambda: generate_single_figure_caption(vlm_endpoint, figure)
)
for figure in to_gen_figures
]
results = [future.result() for future in futures]
return results + [""] * len(other_figures)
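
As a quick, self-contained illustration of the markdown helper above (pure Python, no Adobe or Azure calls involved); the import path is assumed from how `adobe_loader.py` imports these utilities:

```python
from kotaemon.loaders.utils.adobe import make_markdown_table

rows = [
    ["Name", "Age", "Height"],
    ["Jake", 20, "5'10"],
    ["Mary", 21, "5'7"],
]
# Produces a header row, a "---" separator row, and one row per entry
print(make_markdown_table(rows))
```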

View File

@ -0,0 +1,96 @@
import json
from typing import Any, List
import requests
from decouple import config
def generate_gpt4v(
endpoint: str, images: str | List[str], prompt: str, max_tokens: int = 512
) -> str:
# OpenAI API Key
api_key = config("AZURE_OPENAI_API_KEY", default="")
headers = {"Content-Type": "application/json", "api-key": api_key}
if isinstance(images, str):
images = [images]
payload = {
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": prompt},
]
+ [
{
"type": "image_url",
"image_url": {"url": image},
}
for image in images
],
}
],
"max_tokens": max_tokens,
}
try:
response = requests.post(endpoint, headers=headers, json=payload)
output = response.json()
output = output["choices"][0]["message"]["content"]
except Exception:
output = ""
return output
def stream_gpt4v(
endpoint: str, images: str | List[str], prompt: str, max_tokens: int = 512
) -> Any:
# OpenAI API Key
api_key = config("AZURE_OPENAI_API_KEY", default="")
headers = {"Content-Type": "application/json", "api-key": api_key}
if isinstance(images, str):
images = [images]
payload = {
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": prompt},
]
+ [
{
"type": "image_url",
"image_url": {"url": image},
}
for image in images
],
}
],
"max_tokens": max_tokens,
"stream": True,
}
try:
response = requests.post(endpoint, headers=headers, json=payload, stream=True)
assert response.status_code == 200, str(response.content)
output = ""
for line in response.iter_lines():
if line:
if line.startswith(b"\xef\xbb\xbf"):
line = line[9:]
else:
line = line[6:]
try:
if line == "[DONE]":
break
line = json.loads(line.decode("utf-8"))
except Exception:
break
if len(line["choices"]):
output += line["choices"][0]["delta"].get("content", "")
yield line["choices"][0]["delta"].get("content", "")
except Exception:
output = ""
return output
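
A hedged sketch of calling the GPT-4V helper directly, assuming `AZURE_OPENAI_API_KEY` is set; the endpoint URL and the base64 payload below are placeholders, not real values.

```python
from kotaemon.loaders.utils.gpt4v import generate_gpt4v

endpoint = (
    "https://<your-resource>.openai.azure.com/openai/deployments/"
    "gpt-4-vision/chat/completions?api-version=2024-02-15-preview"
)  # placeholder resource name, deployment, and API version
image = "data:image/png;base64,iVBORw0KGgo"  # stand-in data URL

# Returns the model's text, or "" if the request fails
caption = generate_gpt4v(
    endpoint, images=image, prompt="Describe this figure in one sentence."
)
print(caption)
```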

View File

@ -60,6 +60,7 @@ adv = [
"cohere",
"elasticsearch",
"llama-cpp-python",
"pdfservices-sdk @ git+https://github.com/niallcm/pdfservices-python-sdk.git@bump-and-unfreeze-requirements",
]
dev = [
"ipython",
@ -69,6 +70,7 @@ dev = [
"flake8",
"sphinx",
"coverage",
"python-decouple"
]
all = ["kotaemon[adv,dev]"]

View File

@ -0,0 +1,21 @@
# TODO: This test is broken and should be rewritten
from pathlib import Path
from kotaemon.loaders import AdobeReader
# from dotenv import load_dotenv
input_file = Path(__file__).parent / "resources" / "multimodal.pdf"
# load_dotenv()
def test_adobe_reader():
reader = AdobeReader()
documents = reader.load_data(input_file)
table_docs = [doc for doc in documents if doc.metadata.get("type", "") == "table"]
assert len(table_docs) == 2
figure_docs = [doc for doc in documents if doc.metadata.get("type", "") == "image"]
assert len(figure_docs) == 2

Binary file not shown.

View File

@ -124,6 +124,11 @@ if config("LOCAL_MODEL", default=""):
KH_REASONINGS = ["ktem.reasoning.simple.FullQAPipeline"]
KH_VLM_ENDPOINT = "{0}/openai/deployments/{1}/chat/completions?api-version={2}".format(
config("AZURE_OPENAI_ENDPOINT", default=""),
config("OPENAI_VISION_DEPLOYMENT_NAME", default="gpt-4-vision"),
config("OPENAI_API_VERSION", default=""),
)
SETTINGS_APP = {

View File

@ -378,6 +378,7 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
("PDF text parser", "normal"),
("Mathpix", "mathpix"),
("Advanced ocr", "ocr"),
("Multimodal parser", "multimodal"),
],
"component": "dropdown",
},

View File

@ -1,11 +1,14 @@
import asyncio
import html
import logging
import re
from collections import defaultdict
from functools import partial
import tiktoken
from ktem.components import llms
from ktem.reasoning.base import BaseReasoning
from theflow.settings import settings as flowsettings
from kotaemon.base import (
BaseComponent,
@ -18,9 +21,15 @@ from kotaemon.base import (
from kotaemon.indices.qa.citation import CitationPipeline
from kotaemon.indices.splitters import TokenSplitter
from kotaemon.llms import ChatLLM, PromptTemplate
from kotaemon.loaders.utils.gpt4v import stream_gpt4v
logger = logging.getLogger(__name__)
EVIDENCE_MODE_TEXT = 0
EVIDENCE_MODE_TABLE = 1
EVIDENCE_MODE_CHATBOT = 2
EVIDENCE_MODE_FIGURE = 3
class PrepareEvidencePipeline(BaseComponent):
"""Prepare the evidence text from the list of retrieved documents
@ -46,7 +55,7 @@ class PrepareEvidencePipeline(BaseComponent):
def run(self, docs: list[RetrievedDocument]) -> Document:
evidence = ""
table_found = 0
evidence_mode = 0
evidence_mode = EVIDENCE_MODE_TEXT
for _id, retrieved_item in enumerate(docs):
retrieved_content = ""
@ -55,7 +64,7 @@ class PrepareEvidencePipeline(BaseComponent):
if page:
source += f" (Page {page})"
if retrieved_item.metadata.get("type", "") == "table":
evidence_mode = 1 # table
evidence_mode = EVIDENCE_MODE_TABLE
if table_found < 5:
retrieved_content = retrieved_item.metadata.get("table_origin", "")
if retrieved_content not in evidence:
@ -66,13 +75,23 @@ class PrepareEvidencePipeline(BaseComponent):
+ "\n<br>"
)
elif retrieved_item.metadata.get("type", "") == "chatbot":
evidence_mode = 2 # chatbot
evidence_mode = EVIDENCE_MODE_CHATBOT
retrieved_content = retrieved_item.metadata["window"]
evidence += (
f"<br><b>Chatbot scenario from {filename} (Row {page})</b>\n"
+ retrieved_content
+ "\n<br>"
)
elif retrieved_item.metadata.get("type", "") == "image":
evidence_mode = EVIDENCE_MODE_FIGURE
retrieved_content = retrieved_item.metadata.get("image_origin", "")
retrieved_caption = html.escape(retrieved_item.get_content())
evidence += (
f"<br><b>Figure from {source}</b>\n"
+ f"<img width='85%' src='{retrieved_content}' "
+ f"alt='{retrieved_caption}'/>"
+ "\n<br>"
)
else:
if "window" in retrieved_item.metadata:
retrieved_content = retrieved_item.metadata["window"]
@ -90,12 +109,13 @@ class PrepareEvidencePipeline(BaseComponent):
print(retrieved_item.metadata)
print("Score", retrieved_item.metadata.get("relevance_score", None))
# trim context by trim_len
print("len (original)", len(evidence))
if evidence:
texts = self.trim_func([Document(text=evidence)])
evidence = texts[0].text
print("len (trimmed)", len(evidence))
if evidence_mode != EVIDENCE_MODE_FIGURE:
# trim context by trim_len
print("len (original)", len(evidence))
if evidence:
texts = self.trim_func([Document(text=evidence)])
evidence = texts[0].text
print("len (trimmed)", len(evidence))
print(f"PrepareEvidence with input {docs}\nOutput: {evidence}\n")
@ -134,6 +154,16 @@ DEFAULT_QA_CHATBOT_PROMPT = (
"Answer:"
)
DEFAULT_QA_FIGURE_PROMPT = (
"Use the given context: texts, tables, and figures below to answer the question. "
"If you don't know the answer, just say that you don't know. "
"Give answer in {lang}.\n\n"
"Context: \n"
"{context}\n"
"Question: {question}\n"
"Answer: "
)
class AnswerWithContextPipeline(BaseComponent):
"""Answer the question based on the evidence
@ -151,6 +181,7 @@ class AnswerWithContextPipeline(BaseComponent):
"""
llm: ChatLLM = Node(default_callback=lambda _: llms.get_highest_accuracy())
vlm_endpoint: str = flowsettings.KH_VLM_ENDPOINT
citation_pipeline: CitationPipeline = Node(
default_callback=lambda _: CitationPipeline(llm=llms.get_lowest_cost())
)
@ -158,6 +189,7 @@ class AnswerWithContextPipeline(BaseComponent):
qa_template: str = DEFAULT_QA_TEXT_PROMPT
qa_table_template: str = DEFAULT_QA_TABLE_PROMPT
qa_chatbot_template: str = DEFAULT_QA_CHATBOT_PROMPT
qa_figure_template: str = DEFAULT_QA_FIGURE_PROMPT
enable_citation: bool = False
system_prompt: str = ""
@ -188,18 +220,30 @@ class AnswerWithContextPipeline(BaseComponent):
(determined by retrieval pipeline)
evidence_mode: the mode of evidence, 0 for text, 1 for table, 2 for chatbot, 3 for figure
"""
if evidence_mode == 0:
if evidence_mode == EVIDENCE_MODE_TEXT:
prompt_template = PromptTemplate(self.qa_template)
elif evidence_mode == 1:
elif evidence_mode == EVIDENCE_MODE_TABLE:
prompt_template = PromptTemplate(self.qa_table_template)
elif evidence_mode == EVIDENCE_MODE_FIGURE:
prompt_template = PromptTemplate(self.qa_figure_template)
else:
prompt_template = PromptTemplate(self.qa_chatbot_template)
prompt = prompt_template.populate(
context=evidence,
question=question,
lang=self.lang,
)
images = []
if evidence_mode == EVIDENCE_MODE_FIGURE:
# isolate image from evidence
evidence, images = self.extract_evidence_images(evidence)
prompt = prompt_template.populate(
context=evidence,
question=question,
lang=self.lang,
)
else:
prompt = prompt_template.populate(
context=evidence,
question=question,
lang=self.lang,
)
citation_task = None
if evidence and self.enable_citation:
@ -208,23 +252,29 @@ class AnswerWithContextPipeline(BaseComponent):
)
print("Citation task created")
messages = []
if self.system_prompt:
messages.append(SystemMessage(content=self.system_prompt))
messages.append(HumanMessage(content=prompt))
output = ""
try:
# try streaming first
print("Trying LLM streaming")
for text in self.llm.stream(messages):
output += text.text
self.report_output({"output": text.text})
if evidence_mode == EVIDENCE_MODE_FIGURE:
for text in stream_gpt4v(self.vlm_endpoint, images, prompt, max_tokens=768):
output += text
self.report_output({"output": text})
await asyncio.sleep(0)
except NotImplementedError:
print("Streaming is not supported, falling back to normal processing")
output = self.llm(messages).text
self.report_output({"output": output})
else:
messages = []
if self.system_prompt:
messages.append(SystemMessage(content=self.system_prompt))
messages.append(HumanMessage(content=prompt))
try:
# try streaming first
print("Trying LLM streaming")
for text in self.llm.stream(messages):
output += text.text
self.report_output({"output": text.text})
await asyncio.sleep(0)
except NotImplementedError:
print("Streaming is not supported, falling back to normal processing")
output = self.llm(messages).text
self.report_output({"output": output})
# retrieve the citation
print("Waiting for citation task")
@ -237,6 +287,13 @@ class AnswerWithContextPipeline(BaseComponent):
return answer
def extract_evidence_images(self, evidence: str):
"""Util function to extract and isolate images from context/evidence"""
image_pattern = r"src='(data:image\/[^;]+;base64[^']+)'"
matches = re.findall(image_pattern, evidence)
context = re.sub(image_pattern, "", evidence)
return context, matches
class FullQAPipeline(BaseReasoning):
"""Question answering pipeline. Handle from question to answer"""