Update retrieving + agent pipeline (#71)

Authored by Tuan Anh Nguyen Dang (Tadashi_Cin) on 2023-11-14 16:40:13 +07:00, committed by GitHub
parent 693ed39de4
commit 640962e916
8 changed files with 65 additions and 21 deletions

View File

@@ -1,5 +1,8 @@
 from typing import Any, List, Sequence, Type
+from llama_index.node_parser import (
+    SentenceWindowNodeParser as LISentenceWindowNodeParser,
+)
 from llama_index.node_parser import SimpleNodeParser as LISimpleNodeParser
 from llama_index.node_parser.interface import NodeParser
 from llama_index.text_splitter import TokenTextSplitter
@@ -61,3 +64,7 @@ class SimpleNodeParser(LINodeParser):
             chunk_size=chunk_size, chunk_overlap=chunk_overlap
         )
         super().__init__(*args, **kwargs)
+
+
+class SentenceWindowNodeParser(LINodeParser):
+    _parser_class = LISentenceWindowNodeParser

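Example usage of the new parser, as a sketch (untested; it assumes the kotaemon wrapper is callable on a list of documents like the other components in this repo, and that `Document` is importable from `kotaemon.base`; neither is shown in this diff):

    from kotaemon.base import Document  # import path assumed
    from kotaemon.parsers.splitter import SentenceWindowNodeParser

    # SentenceWindowNodeParser wraps llama_index's parser through the shared
    # _parser_class hook, so it is constructed just like SimpleNodeParser above.
    parser = SentenceWindowNodeParser()
    nodes = parser([Document(text="One sentence. Another sentence. A third.")])
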
View File

@@ -4,7 +4,7 @@ from kotaemon.llms import PromptTemplate
 zero_shot_react_prompt = PromptTemplate(
     template="""Answer the following questions as best you can. You have access to the following tools:
-{tool_description}.
+{tool_description}
 
 Use the following format:
 
 Question: the input question you must answer

View File

@@ -4,7 +4,7 @@ from ....base import BaseComponent
 from ....llms import PromptTemplate
 from ..base import BaseLLM, BaseTool
 from ..output.base import BaseScratchPad
-from .prompt import zero_shot_planner_prompt
+from .prompt import few_shot_planner_prompt, zero_shot_planner_prompt
 
 
 class Planner(BaseComponent):
@@ -56,7 +56,7 @@ class Planner(BaseComponent):
             )
         else:
             if self.examples is not None:
-                return zero_shot_planner_prompt.populate(
+                return few_shot_planner_prompt.populate(
                     tool_description=worker_desctription,
                     fewshot=fewshot,
                     task=instruction,

View File

@@ -7,19 +7,20 @@ from theflow import Node
 from theflow.utils.modules import ObjectInitDeclaration as _
 
 from kotaemon.base import BaseComponent
-from kotaemon.docstores import InMemoryDocumentStore
+from kotaemon.docstores import BaseDocumentStore, InMemoryDocumentStore
 from kotaemon.embeddings import AzureOpenAIEmbeddings
 from kotaemon.loaders import (
     AutoReader,
     DirectoryReader,
     MathpixPDFReader,
+    OCRReader,
     PandasExcelReader,
 )
 from kotaemon.parsers.splitter import SimpleNodeParser
 from kotaemon.pipelines.agents import BaseAgent
 from kotaemon.pipelines.indexing import IndexVectorStoreFromDocumentPipeline
 from kotaemon.pipelines.retrieving import RetrieveDocumentFromVectorStorePipeline
-from kotaemon.vectorstores import InMemoryVectorStore
+from kotaemon.vectorstores import BaseVectorStore, InMemoryVectorStore
 
 from .qa import AgentQAPipeline, QuestionAnsweringPipeline
 from .utils import file_names_to_collection_name
@@ -33,12 +34,11 @@ class ReaderIndexingPipeline(BaseComponent):
     # Expose variables for users to switch in prompt ui
     storage_path: Path = Path("./storage")
-    reader_name: str = "normal"  # "normal" or "mathpix"
+    reader_name: str = "normal"  # "normal", "mathpix" or "ocr"
     chunk_size: int = 1024
     chunk_overlap: int = 256
-    file_name_list: List[str] = list()
-    vector_store: _[InMemoryVectorStore] = _(InMemoryVectorStore)
-    doc_store: _[InMemoryDocumentStore] = _(InMemoryDocumentStore)
+    vector_store: _[BaseVectorStore] = _(InMemoryVectorStore)
+    doc_store: _[BaseDocumentStore] = _(InMemoryDocumentStore)
 
     embedding: AzureOpenAIEmbeddings = AzureOpenAIEmbeddings.withx(
         model="text-embedding-ada-002",
@@ -54,6 +54,8 @@ class ReaderIndexingPipeline(BaseComponent):
         }
         if self.reader_name == "normal":
             file_extractor[".pdf"] = AutoReader("UnstructuredReader")
+        elif self.reader_name == "ocr":
+            file_extractor[".pdf"] = OCRReader()
         else:
             file_extractor[".pdf"] = MathpixPDFReader()
 
         main_reader = DirectoryReader(
@@ -105,11 +107,12 @@ class ReaderIndexingPipeline(BaseComponent):
         else:
             self.indexing_vector_pipeline.load(file_storage_path)
 
-    def to_retrieving_pipeline(self):
+    def to_retrieving_pipeline(self, top_k=3):
         retrieving_pipeline = RetrieveDocumentFromVectorStorePipeline(
             vector_store=self.vector_store,
             doc_store=self.doc_store,
             embedding=self.embedding,
+            top_k=top_k,
         )
         return retrieving_pipeline
@@ -118,7 +121,7 @@ class ReaderIndexingPipeline(BaseComponent):
             storage_path=self.storage_path,
             file_name_list=self.file_name_list,
             vector_store=self.vector_store,
-            doc_score=self.doc_store,
+            doc_store=self.doc_store,
             embedding=self.embedding,
             llm=llm,
             **kwargs
@@ -130,7 +133,7 @@ class ReaderIndexingPipeline(BaseComponent):
             storage_path=self.storage_path,
             file_name_list=self.file_name_list,
             vector_store=self.vector_store,
-            doc_score=self.doc_store,
+            doc_store=self.doc_store,
             embedding=self.embedding,
             agent=agent,
             **kwargs

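A usage sketch covering the two behavioural changes above, the new "ocr" reader branch and the pass-through top_k (the module path and the run signature are assumptions; only `reader_name` and `to_retrieving_pipeline` appear in this diff):

    from kotaemon.pipelines.ingest import ReaderIndexingPipeline  # module path assumed

    indexing_pipeline = ReaderIndexingPipeline(reader_name="ocr")  # routes .pdf files to OCRReader
    indexing_pipeline(["report.pdf"])  # run signature assumed: index the listed files

    # top_k is now forwarded to the retriever instead of silently using its default
    retriever = indexing_pipeline.to_retrieving_pipeline(top_k=5)
    docs = retriever("What does the report conclude?")
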
View File

@@ -7,14 +7,14 @@ from theflow.utils.modules import ObjectInitDeclaration as _
 from kotaemon.base import BaseComponent
 from kotaemon.base.schema import RetrievedDocument
-from kotaemon.docstores import InMemoryDocumentStore
+from kotaemon.docstores import BaseDocumentStore, InMemoryDocumentStore
 from kotaemon.embeddings import AzureOpenAIEmbeddings
 from kotaemon.llms import PromptTemplate
 from kotaemon.llms.chats.openai import AzureChatOpenAI
 from kotaemon.pipelines.agents import BaseAgent
 from kotaemon.pipelines.retrieving import RetrieveDocumentFromVectorStorePipeline
 from kotaemon.pipelines.tools import ComponentTool
-from kotaemon.vectorstores import InMemoryVectorStore
+from kotaemon.vectorstores import BaseVectorStore, InMemoryVectorStore
 
 from .utils import file_names_to_collection_name
@@ -29,7 +29,7 @@ class QuestionAnsweringPipeline(BaseComponent):
     file_name_list: List[str]
     """List of filename, incombination with storage_path to
     create persistent path of vectorstore"""
-    prompt_template: PromptTemplate = PromptTemplate(
+    qa_prompt_template: PromptTemplate = PromptTemplate(
         'Answer the following question: "{question}". '
         "The context is: \n{context}\nAnswer: "
     )
@@ -43,8 +43,8 @@ class QuestionAnsweringPipeline(BaseComponent):
         request_timeout=60,
     )
 
-    vector_store: _[InMemoryVectorStore] = _(InMemoryVectorStore)
-    doc_store: _[InMemoryDocumentStore] = _(InMemoryDocumentStore)
+    vector_store: _[BaseVectorStore] = _(InMemoryVectorStore)
+    doc_store: _[BaseDocumentStore] = _(InMemoryDocumentStore)
 
     embedding: AzureOpenAIEmbeddings = AzureOpenAIEmbeddings.withx(
         model="text-embedding-ada-002",
@@ -53,12 +53,21 @@ class QuestionAnsweringPipeline(BaseComponent):
         openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
     )
 
-    @Node.default()
+    @Node.auto(
+        depends_on=[
+            "vector_store",
+            "doc_store",
+            "embedding",
+            "file_name_list",
+            "retrieval_top_k",
+        ]
+    )
     def retrieving_pipeline(self) -> RetrieveDocumentFromVectorStorePipeline:
         retrieving_pipeline = RetrieveDocumentFromVectorStorePipeline(
             vector_store=self.vector_store,
             doc_store=self.doc_store,
             embedding=self.embedding,
+            top_k=self.retrieval_top_k,
         )
         # load persistent from selected path
         collection_name = file_names_to_collection_name(self.file_name_list)
@@ -81,7 +90,7 @@ class QuestionAnsweringPipeline(BaseComponent):
         self.log_progress(".context", context=context)
 
         # generate the answer
-        prompt = self.prompt_template.populate(
+        prompt = self.qa_prompt_template.populate(
             context=context,
             question=question,
         )

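Because `prompt_template` was renamed to `qa_prompt_template`, callers that customize the QA prompt must switch to the new field name. A sketch (values illustrative; `retrieval_top_k` is referenced by the `depends_on` list above and is assumed to be a field declared elsewhere in the class, and passing field overrides as constructor kwargs is assumed to be supported by theflow):

    from kotaemon.llms import PromptTemplate

    qa_pipeline = QuestionAnsweringPipeline(
        file_name_list=["report.pdf"],
        qa_prompt_template=PromptTemplate(  # formerly `prompt_template`
            'Context: {context}\nQuestion: "{question}"\nAnswer: '
        ),
        retrieval_top_k=3,  # consumed by the auto-built retrieving_pipeline node
    )
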
View File

@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 from pathlib import Path
+from typing import Optional
 
 from theflow import Node, Param
@@ -20,17 +21,24 @@ class RetrieveDocumentFromVectorStorePipeline(BaseComponent):
     vector_store: Param[BaseVectorStore] = Param()
     doc_store: Param[BaseDocumentStore] = Param()
     embedding: Node[BaseEmbeddings] = Node()
+    top_k: int = 1
     # TODO: refer to llama_index's storage as well
 
-    def run(self, text: str | Document, top_k: int = 1) -> list[RetrievedDocument]:
+    def run(
+        self, text: str | Document, top_k: Optional[int] = None
+    ) -> list[RetrievedDocument]:
         """Retrieve a list of documents from vector store
 
         Args:
             text: the text to retrieve similar documents
+            top_k: number of top similar documents to return
 
         Returns:
             list[RetrievedDocument]: list of retrieved documents
         """
+        if top_k is None:
+            top_k = self.top_k
+
         if self.doc_store is None:
             raise ValueError(
                 "doc_store is not provided. Please provide a doc_store to "

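In short, the pipeline-level `top_k` is now the default and a per-call argument still overrides it. A sketch, assuming the stores and embedding are already configured:

    pipeline = RetrieveDocumentFromVectorStorePipeline(
        vector_store=my_vector_store,  # assumed to exist
        doc_store=my_doc_store,        # assumed to exist
        embedding=my_embedding,        # assumed to exist
        top_k=3,                       # new pipeline-level default
    )
    docs = pipeline("some query")            # returns up to 3 documents
    docs = pipeline("some query", top_k=10)  # explicit argument wins over the default
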
View File

@@ -1,5 +1,6 @@
 from typing import AnyStr, Optional, Type
 
+from langchain.utilities import SerpAPIWrapper
 from pydantic import BaseModel, Field
 
 from .base import BaseTool
@@ -33,3 +34,18 @@ class GoogleSearchTool(BaseTool):
         )
         return output
+
+
+class SerpTool(BaseTool):
+    name = "google_search"
+    description = (
+        "Worker that searches results from Google. Useful when you need to find short "
+        "and succinct answers about a specific topic. Input should be a search query."
+    )
+    args_schema: Optional[Type[BaseModel]] = GoogleSearchArgs
+
+    def _run_tool(self, query: AnyStr) -> str:
+        tool = SerpAPIWrapper()
+        evidence = tool.run(query)
+        return evidence

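Usage sketch for the new tool. langchain's SerpAPIWrapper reads its key from the SERPAPI_API_KEY environment variable, so that must be set; the export path and the run() entry point come from BaseTool conventions outside this diff and are assumptions:

    import os

    from kotaemon.pipelines.tools import SerpTool  # export path assumed

    os.environ.setdefault("SERPAPI_API_KEY", "<your-serpapi-key>")

    search = SerpTool()
    evidence = search.run("capital of Vietnam")  # run() assumed from BaseTool
    print(evidence)
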
View File

@@ -52,7 +52,8 @@ class WikipediaTool(BaseTool):
     description = (
         "Search engine from Wikipedia, retrieving relevant wiki page. "
         "Useful when you need to get holistic knowledge about people, "
-        "places, companies, historical events, or other subjects."
+        "places, companies, historical events, or other subjects. "
+        "Input should be a search query."
     )
     args_schema: Optional[Type[BaseModel]] = WikipediaArgs
     doc_store: Any = None
doc_store: Any = None doc_store: Any = None