Refactor the index component and update the MVP insurance accordingly (#90)
Refactor the `kotaemon/pipelines` module into `kotaemon/indices` and create the VectorIndex. Note: for now, `qa` is placed inside `kotaemon/indices`, since `qa` is currently the only RAG component of this kind. Alternatively, `qa` could become an independent `kotaemon/qa` module; since this is easy to change later, I am going with the first option for now and will revisit it if needed.
Commit e34b1e4c6d (parent 8e3a1d193f), committed via GitHub.
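For downstream code, the practical effect of this refactor is that the QA pipelines now live under `kotaemon.indices.qa` (exported by the new `__init__.py` below). A minimal sketch of the new import path, assuming the package is installed under the `kotaemon` namespace as the files' own imports suggest:

```python
# QA pipelines after the refactor; previously under kotaemon.pipelines.
from kotaemon.indices.qa import CitationPipeline, CitationQAPipeline
```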
knowledgehub/indices/qa/__init__.py (new file, 7 lines)
@@ -0,0 +1,7 @@
from .citation import CitationPipeline
from .text_based import CitationQAPipeline

__all__ = [
    "CitationPipeline",
    "CitationQAPipeline",
]
knowledgehub/indices/qa/citation.py (new file, 106 lines)
@@ -0,0 +1,106 @@
from typing import Iterator, List, Tuple

from pydantic import BaseModel, Field

from kotaemon.base import BaseComponent
from kotaemon.base.schema import HumanMessage, SystemMessage
from kotaemon.llms import BaseLLM


class FactWithEvidence(BaseModel):
    """Class representing a single statement.

    Each fact has a body and a list of sources.
    If there are multiple facts, make sure to break them apart
    such that each one only uses a set of sources that are relevant to it.
    """

    fact: str = Field(..., description="Body of the sentence, as part of a response")
    substring_quote: List[str] = Field(
        ...,
        description=(
            "Each source should be a direct quote from the context, "
            "as a substring of the original content"
        ),
    )

    def _get_span(
        self, quote: str, context: str, errs: int = 100
    ) -> Iterator[Tuple[int, int]]:
        import regex

        minor = quote
        major = context

        # Fuzzy-search the quote inside the context, raising the allowed
        # edit distance one step at a time until a match is found or the
        # budget `errs` is exhausted. Note that `minor` is interpolated
        # into the pattern verbatim, so regex metacharacters in the quote
        # are interpreted as part of the pattern.
        errs_ = 0
        s = regex.search(f"({minor}){{e<={errs_}}}", major)
        while s is None and errs_ <= errs:
            errs_ += 1
            s = regex.search(f"({minor}){{e<={errs_}}}", major)

        if s is not None:
            yield from s.spans()

    def get_spans(self, context: str) -> Iterator[Tuple[int, int]]:
        for quote in self.substring_quote:
            yield from self._get_span(quote, context)


class QuestionAnswer(BaseModel):
    """A question and its answer as a list of facts, each of which should
    have a source. Each sentence contains a body and a list of sources."""

    question: str = Field(..., description="Question that was asked")
    answer: List[FactWithEvidence] = Field(
        ...,
        description=(
            "Body of the answer, each fact should be "
            "its separate object with a body and a list of sources"
        ),
    )


class CitationPipeline(BaseComponent):
    """Citation pipeline to extract cited evidence from the source
    (based on the input question)"""

    llm: BaseLLM

    def run(
        self,
        context: str,
        question: str,
    ) -> QuestionAnswer:
        # Expose the QuestionAnswer schema as an OpenAI function so the
        # LLM returns a structured, parseable answer.
        schema = QuestionAnswer.schema()
        function = {
            "name": schema["title"],
            "description": schema["description"],
            "parameters": schema,
        }
        llm_kwargs = {
            "functions": [function],
            "function_call": {"name": function["name"]},
        }
        messages = [
            SystemMessage(
                content=(
                    "You are a world class algorithm to answer "
                    "questions with correct and exact citations."
                )
            ),
            HumanMessage(content="Answer question using the following context"),
            HumanMessage(content=context),
            HumanMessage(content=f"Question: {question}"),
            HumanMessage(
                content=(
                    "Tips: Make sure to cite your sources, "
                    "and use the exact words from the context."
                )
            ),
        ]

        llm_output = self.llm(messages, **llm_kwargs)
        function_output = llm_output.messages[0].additional_kwargs["function_call"][
            "arguments"
        ]
        output = QuestionAnswer.parse_raw(function_output)

        return output
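For orientation, a minimal usage sketch of `CitationPipeline`. The `AzureChatOpenAI` arguments mirror the defaults in `text_based.py` below but the endpoint, key, and deployment name are placeholders; any `BaseLLM` whose backend supports OpenAI function calling should work:

```python
from kotaemon.indices.qa import CitationPipeline
from kotaemon.llms import AzureChatOpenAI

# Placeholder Azure OpenAI setup; substitute your own endpoint/deployment.
llm = AzureChatOpenAI(
    azure_endpoint="https://<your-resource>.openai.azure.com/",
    openai_api_key="<api-key>",
    openai_api_version="2023-07-01-preview",
    deployment_name="<chat-deployment>",
    temperature=0,
)

pipeline = CitationPipeline(llm=llm)
qa = pipeline(
    context="Kotaemon is a library for building RAG pipelines.",
    question="What is kotaemon?",
)
for fact in qa.answer:
    print(fact.fact, "->", fact.substring_quote)
```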
knowledgehub/indices/qa/text_based.py (new file, 62 lines)
@@ -0,0 +1,62 @@
import os

from kotaemon.base import BaseComponent, Document, RetrievedDocument
from kotaemon.llms import AzureChatOpenAI, BaseLLM, PromptTemplate

from .citation import CitationPipeline


class CitationQAPipeline(BaseComponent):
    """Answer a question from a text corpus, with citations"""

    qa_prompt_template: PromptTemplate = PromptTemplate(
        'Answer the following question: "{question}". '
        "The context is: \n{context}\nAnswer: "
    )
    llm: BaseLLM = AzureChatOpenAI.withx(
        azure_endpoint="https://bleh-dummy.openai.azure.com/",
        openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
        openai_api_version="2023-07-01-preview",
        deployment_name="dummy-q2-16k",
        temperature=0,
        request_timeout=60,
    )

    def _format_doc_text(self, text: str) -> str:
        """Format the text of each document"""
        return text.replace("\n", " ")

    def _format_retrieved_context(self, documents: list[RetrievedDocument]) -> str:
        """Format and join the texts of all retrieved documents"""
        matched_texts: list[str] = [
            self._format_doc_text(doc.text) for doc in documents
        ]
        return "\n\n".join(matched_texts)

    def run(
        self,
        question: str,
        documents: list[RetrievedDocument],
        use_citation: bool = False,
        **kwargs
    ) -> Document:
        # format the retrieved documents into a single context string
        context = self._format_retrieved_context(documents)
        self.log_progress(".context", context=context)

        # generate the answer
        prompt = self.qa_prompt_template.populate(
            context=context,
            question=question,
        )
        self.log_progress(".prompt", prompt=prompt)
        answer_text = self.llm(prompt).text
        if use_citation:
            # run citation pipeline
            citation_pipeline = CitationPipeline(llm=self.llm)
            citation = citation_pipeline(context=context, question=question)
        else:
            citation = None

        answer = Document(text=answer_text, metadata={"citation": citation})
        return answer
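And a quick sketch of running `CitationQAPipeline` end to end. The `RetrievedDocument` fields used here (`text`, `score`) are assumptions for illustration; in practice the documents would come from a retriever such as the new VectorIndex:

```python
from kotaemon.base import RetrievedDocument
from kotaemon.indices.qa import CitationQAPipeline

# Hypothetical retrieved documents; normally produced by a retrieval pipeline.
docs = [
    RetrievedDocument(
        text="Kotaemon is a library for building RAG pipelines.",
        score=0.92,
    ),
]

qa_pipeline = CitationQAPipeline()  # relies on the default LLM declared above
answer = qa_pipeline(question="What is kotaemon?", documents=docs, use_citation=True)
print(answer.text)
print(answer.metadata["citation"])
```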