Update retrieving + agent pipeline (#71)
This commit is contained in:
parent
693ed39de4
commit
640962e916
|
@ -1,5 +1,8 @@
|
|||
from typing import Any, List, Sequence, Type
|
||||
|
||||
from llama_index.node_parser import (
|
||||
SentenceWindowNodeParser as LISentenceWindowNodeParser,
|
||||
)
|
||||
from llama_index.node_parser import SimpleNodeParser as LISimpleNodeParser
|
||||
from llama_index.node_parser.interface import NodeParser
|
||||
from llama_index.text_splitter import TokenTextSplitter
|
||||
|
@ -61,3 +64,7 @@ class SimpleNodeParser(LINodeParser):
|
|||
chunk_size=chunk_size, chunk_overlap=chunk_overlap
|
||||
)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class SentenceWindowNodeParser(LINodeParser):
    """Node-parser component backed by llama_index's SentenceWindowNodeParser.

    The only thing this subclass configures is ``_parser_class``, which points
    at the llama_index class imported in this module as
    ``LISentenceWindowNodeParser``.
    """

    _parser_class = LISentenceWindowNodeParser
|
||||
|
|
|
@ -4,7 +4,7 @@ from kotaemon.llms import PromptTemplate
|
|||
|
||||
zero_shot_react_prompt = PromptTemplate(
|
||||
template="""Answer the following questions as best you can. You have access to the following tools:
|
||||
{tool_description}.
|
||||
{tool_description}
|
||||
Use the following format:
|
||||
|
||||
Question: the input question you must answer
|
||||
|
|
|
@ -4,7 +4,7 @@ from ....base import BaseComponent
|
|||
from ....llms import PromptTemplate
|
||||
from ..base import BaseLLM, BaseTool
|
||||
from ..output.base import BaseScratchPad
|
||||
from .prompt import zero_shot_planner_prompt
|
||||
from .prompt import few_shot_planner_prompt, zero_shot_planner_prompt
|
||||
|
||||
|
||||
class Planner(BaseComponent):
|
||||
|
@ -56,7 +56,7 @@ class Planner(BaseComponent):
|
|||
)
|
||||
else:
|
||||
if self.examples is not None:
|
||||
return zero_shot_planner_prompt.populate(
|
||||
return few_shot_planner_prompt.populate(
|
||||
tool_description=worker_desctription,
|
||||
fewshot=fewshot,
|
||||
task=instruction,
|
||||
|
|
|
@ -7,19 +7,20 @@ from theflow import Node
|
|||
from theflow.utils.modules import ObjectInitDeclaration as _
|
||||
|
||||
from kotaemon.base import BaseComponent
|
||||
from kotaemon.docstores import InMemoryDocumentStore
|
||||
from kotaemon.docstores import BaseDocumentStore, InMemoryDocumentStore
|
||||
from kotaemon.embeddings import AzureOpenAIEmbeddings
|
||||
from kotaemon.loaders import (
|
||||
AutoReader,
|
||||
DirectoryReader,
|
||||
MathpixPDFReader,
|
||||
OCRReader,
|
||||
PandasExcelReader,
|
||||
)
|
||||
from kotaemon.parsers.splitter import SimpleNodeParser
|
||||
from kotaemon.pipelines.agents import BaseAgent
|
||||
from kotaemon.pipelines.indexing import IndexVectorStoreFromDocumentPipeline
|
||||
from kotaemon.pipelines.retrieving import RetrieveDocumentFromVectorStorePipeline
|
||||
from kotaemon.vectorstores import InMemoryVectorStore
|
||||
from kotaemon.vectorstores import BaseVectorStore, InMemoryVectorStore
|
||||
|
||||
from .qa import AgentQAPipeline, QuestionAnsweringPipeline
|
||||
from .utils import file_names_to_collection_name
|
||||
|
@ -33,12 +34,11 @@ class ReaderIndexingPipeline(BaseComponent):
|
|||
|
||||
# Expose variables for users to switch in prompt ui
|
||||
storage_path: Path = Path("./storage")
|
||||
reader_name: str = "normal" # "normal" or "mathpix"
|
||||
reader_name: str = "normal" # "normal", "mathpix" or "ocr"
|
||||
chunk_size: int = 1024
|
||||
chunk_overlap: int = 256
|
||||
file_name_list: List[str] = list()
|
||||
vector_store: _[InMemoryVectorStore] = _(InMemoryVectorStore)
|
||||
doc_store: _[InMemoryDocumentStore] = _(InMemoryDocumentStore)
|
||||
vector_store: _[BaseVectorStore] = _(InMemoryVectorStore)
|
||||
doc_store: _[BaseDocumentStore] = _(InMemoryDocumentStore)
|
||||
|
||||
embedding: AzureOpenAIEmbeddings = AzureOpenAIEmbeddings.withx(
|
||||
model="text-embedding-ada-002",
|
||||
|
@ -54,6 +54,8 @@ class ReaderIndexingPipeline(BaseComponent):
|
|||
}
|
||||
if self.reader_name == "normal":
|
||||
file_extractor[".pdf"] = AutoReader("UnstructuredReader")
|
||||
elif self.reader_name == "ocr":
|
||||
file_extractor[".pdf"] = OCRReader()
|
||||
else:
|
||||
file_extractor[".pdf"] = MathpixPDFReader()
|
||||
main_reader = DirectoryReader(
|
||||
|
@ -105,11 +107,12 @@ class ReaderIndexingPipeline(BaseComponent):
|
|||
else:
|
||||
self.indexing_vector_pipeline.load(file_storage_path)
|
||||
|
||||
def to_retrieving_pipeline(self):
|
||||
def to_retrieving_pipeline(self, top_k=3):
    """Create a retrieving pipeline that shares this pipeline's components.

    Args:
        top_k: value passed through as the ``top_k`` argument of the
            retrieving pipeline (defaults to 3).

    Returns:
        A ``RetrieveDocumentFromVectorStorePipeline`` built from this
        object's ``vector_store``, ``doc_store`` and ``embedding``.
    """
    # Hand the same store/embedding objects to the retriever; nothing is copied.
    return RetrieveDocumentFromVectorStorePipeline(
        vector_store=self.vector_store,
        doc_store=self.doc_store,
        embedding=self.embedding,
        top_k=top_k,
    )
|
||||
|
||||
|
@ -118,7 +121,7 @@ class ReaderIndexingPipeline(BaseComponent):
|
|||
storage_path=self.storage_path,
|
||||
file_name_list=self.file_name_list,
|
||||
vector_store=self.vector_store,
|
||||
doc_score=self.doc_store,
|
||||
doc_store=self.doc_store,
|
||||
embedding=self.embedding,
|
||||
llm=llm,
|
||||
**kwargs
|
||||
|
@ -130,7 +133,7 @@ class ReaderIndexingPipeline(BaseComponent):
|
|||
storage_path=self.storage_path,
|
||||
file_name_list=self.file_name_list,
|
||||
vector_store=self.vector_store,
|
||||
doc_score=self.doc_store,
|
||||
doc_store=self.doc_store,
|
||||
embedding=self.embedding,
|
||||
agent=agent,
|
||||
**kwargs
|
||||
|
|
|
@ -7,14 +7,14 @@ from theflow.utils.modules import ObjectInitDeclaration as _
|
|||
|
||||
from kotaemon.base import BaseComponent
|
||||
from kotaemon.base.schema import RetrievedDocument
|
||||
from kotaemon.docstores import InMemoryDocumentStore
|
||||
from kotaemon.docstores import BaseDocumentStore, InMemoryDocumentStore
|
||||
from kotaemon.embeddings import AzureOpenAIEmbeddings
|
||||
from kotaemon.llms import PromptTemplate
|
||||
from kotaemon.llms.chats.openai import AzureChatOpenAI
|
||||
from kotaemon.pipelines.agents import BaseAgent
|
||||
from kotaemon.pipelines.retrieving import RetrieveDocumentFromVectorStorePipeline
|
||||
from kotaemon.pipelines.tools import ComponentTool
|
||||
from kotaemon.vectorstores import InMemoryVectorStore
|
||||
from kotaemon.vectorstores import BaseVectorStore, InMemoryVectorStore
|
||||
|
||||
from .utils import file_names_to_collection_name
|
||||
|
||||
|
@ -29,7 +29,7 @@ class QuestionAnsweringPipeline(BaseComponent):
|
|||
file_name_list: List[str]
|
||||
"""List of filename, incombination with storage_path to
|
||||
create persistent path of vectorstore"""
|
||||
prompt_template: PromptTemplate = PromptTemplate(
|
||||
qa_prompt_template: PromptTemplate = PromptTemplate(
|
||||
'Answer the following question: "{question}". '
|
||||
"The context is: \n{context}\nAnswer: "
|
||||
)
|
||||
|
@ -43,8 +43,8 @@ class QuestionAnsweringPipeline(BaseComponent):
|
|||
request_timeout=60,
|
||||
)
|
||||
|
||||
vector_store: _[InMemoryVectorStore] = _(InMemoryVectorStore)
|
||||
doc_store: _[InMemoryDocumentStore] = _(InMemoryDocumentStore)
|
||||
vector_store: _[BaseVectorStore] = _(InMemoryVectorStore)
|
||||
doc_store: _[BaseDocumentStore] = _(InMemoryDocumentStore)
|
||||
|
||||
embedding: AzureOpenAIEmbeddings = AzureOpenAIEmbeddings.withx(
|
||||
model="text-embedding-ada-002",
|
||||
|
@ -53,12 +53,21 @@ class QuestionAnsweringPipeline(BaseComponent):
|
|||
openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
|
||||
)
|
||||
|
||||
@Node.default()
|
||||
@Node.auto(
|
||||
depends_on=[
|
||||
"vector_store",
|
||||
"doc_store",
|
||||
"embedding",
|
||||
"file_name_list",
|
||||
"retrieval_top_k",
|
||||
]
|
||||
)
|
||||
def retrieving_pipeline(self) -> RetrieveDocumentFromVectorStorePipeline:
|
||||
retrieving_pipeline = RetrieveDocumentFromVectorStorePipeline(
|
||||
vector_store=self.vector_store,
|
||||
doc_store=self.doc_store,
|
||||
embedding=self.embedding,
|
||||
top_k=self.retrieval_top_k,
|
||||
)
|
||||
# load persistent from selected path
|
||||
collection_name = file_names_to_collection_name(self.file_name_list)
|
||||
|
@ -81,7 +90,7 @@ class QuestionAnsweringPipeline(BaseComponent):
|
|||
self.log_progress(".context", context=context)
|
||||
|
||||
# generate the answer
|
||||
prompt = self.prompt_template.populate(
|
||||
prompt = self.qa_prompt_template.populate(
|
||||
context=context,
|
||||
question=question,
|
||||
)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from theflow import Node, Param
|
||||
|
||||
|
@ -20,17 +21,24 @@ class RetrieveDocumentFromVectorStorePipeline(BaseComponent):
|
|||
vector_store: Param[BaseVectorStore] = Param()
|
||||
doc_store: Param[BaseDocumentStore] = Param()
|
||||
embedding: Node[BaseEmbeddings] = Node()
|
||||
top_k: int = 1
|
||||
# TODO: refer to llama_index's storage as well
|
||||
|
||||
def run(self, text: str | Document, top_k: int = 1) -> list[RetrievedDocument]:
|
||||
def run(
|
||||
self, text: str | Document, top_k: Optional[int] = None
|
||||
) -> list[RetrievedDocument]:
|
||||
"""Retrieve a list of documents from vector store
|
||||
|
||||
Args:
|
||||
text: the text to retrieve similar documents
|
||||
top_k: number of top similar documents to return
|
||||
|
||||
Returns:
|
||||
list[RetrievedDocument]: list of retrieved documents
|
||||
"""
|
||||
if top_k is None:
|
||||
top_k = self.top_k
|
||||
|
||||
if self.doc_store is None:
|
||||
raise ValueError(
|
||||
"doc_store is not provided. Please provide a doc_store to "
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from typing import AnyStr, Optional, Type
|
||||
|
||||
from langchain.utilities import SerpAPIWrapper
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from .base import BaseTool
|
||||
|
@ -33,3 +34,18 @@ class GoogleSearchTool(BaseTool):
|
|||
)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class SerpTool(BaseTool):
    """Tool that answers a search query through langchain's SerpAPIWrapper."""

    name = "google_search"
    description = (
        "Worker that searches results from Google. Useful when you need to find short "
        "and succinct answers about a specific topic. Input should be a search query."
    )
    args_schema: Optional[Type[BaseModel]] = GoogleSearchArgs

    def _run_tool(self, query: AnyStr) -> str:
        """Run ``query`` through a new SerpAPIWrapper and return its output.

        Args:
            query: the search query to submit.

        Returns:
            Whatever ``SerpAPIWrapper.run`` returns for the query.
        """
        # A wrapper is created on every call; no client instance is kept around.
        return SerpAPIWrapper().run(query)
|
||||
|
|
|
@ -53,6 +53,7 @@ class WikipediaTool(BaseTool):
|
|||
"Search engine from Wikipedia, retrieving relevant wiki page. "
|
||||
"Useful when you need to get holistic knowledge about people, "
|
||||
"places, companies, historical events, or other subjects. "
|
||||
"Input should be a search query."
|
||||
)
|
||||
args_schema: Optional[Type[BaseModel]] = WikipediaArgs
|
||||
doc_store: Any = None
|
||||
|
|
Loading…
Reference in New Issue
Block a user