Update retrieving + agent pipeline (#71)

Author: Tuan Anh Nguyen Dang (Tadashi_Cin)
Date: 2023-11-14 16:40:13 +07:00
Committed by: GitHub
Parent: 693ed39de4
Commit: 640962e916
8 changed files with 65 additions and 21 deletions


@@ -7,14 +7,14 @@ from theflow.utils.modules import ObjectInitDeclaration as _
 from kotaemon.base import BaseComponent
 from kotaemon.base.schema import RetrievedDocument
-from kotaemon.docstores import InMemoryDocumentStore
+from kotaemon.docstores import BaseDocumentStore, InMemoryDocumentStore
 from kotaemon.embeddings import AzureOpenAIEmbeddings
 from kotaemon.llms import PromptTemplate
 from kotaemon.llms.chats.openai import AzureChatOpenAI
 from kotaemon.pipelines.agents import BaseAgent
 from kotaemon.pipelines.retrieving import RetrieveDocumentFromVectorStorePipeline
 from kotaemon.pipelines.tools import ComponentTool
-from kotaemon.vectorstores import InMemoryVectorStore
+from kotaemon.vectorstores import BaseVectorStore, InMemoryVectorStore

 from .utils import file_names_to_collection_name
@@ -29,7 +29,7 @@ class QuestionAnsweringPipeline(BaseComponent):
     file_name_list: List[str]
     """List of filename, incombination with storage_path to
     create persistent path of vectorstore"""
-    prompt_template: PromptTemplate = PromptTemplate(
+    qa_prompt_template: PromptTemplate = PromptTemplate(
         'Answer the following question: "{question}". '
         "The context is: \n{context}\nAnswer: "
     )
@@ -43,8 +43,8 @@ class QuestionAnsweringPipeline(BaseComponent):
         request_timeout=60,
     )
-    vector_store: _[InMemoryVectorStore] = _(InMemoryVectorStore)
-    doc_store: _[InMemoryDocumentStore] = _(InMemoryDocumentStore)
+    vector_store: _[BaseVectorStore] = _(InMemoryVectorStore)
+    doc_store: _[BaseDocumentStore] = _(InMemoryDocumentStore)
     embedding: AzureOpenAIEmbeddings = AzureOpenAIEmbeddings.withx(
         model="text-embedding-ada-002",
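
Note on this hunk: widening the node declarations from the concrete in-memory classes to the `BaseVectorStore` / `BaseDocumentStore` interfaces lets callers inject any conforming backend while the in-memory stores remain the defaults. Below is a minimal sketch of the resulting injection pattern, assuming the in-memory stores construct without arguments (as the `_()` declarations above suggest); the `build_qa_stores` helper is hypothetical, and only the imported names come from this diff:

```python
from typing import Optional, Tuple

from kotaemon.docstores import BaseDocumentStore, InMemoryDocumentStore
from kotaemon.vectorstores import BaseVectorStore, InMemoryVectorStore


def build_qa_stores(
    vector_store: Optional[BaseVectorStore] = None,
    doc_store: Optional[BaseDocumentStore] = None,
) -> Tuple[BaseVectorStore, BaseDocumentStore]:
    """Accept any subclass of the base store interfaces, defaulting to in-memory."""
    return (
        vector_store or InMemoryVectorStore(),
        doc_store or InMemoryDocumentStore(),
    )
```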
@@ -53,12 +53,21 @@ class QuestionAnsweringPipeline(BaseComponent):
         openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
     )

-    @Node.default()
+    @Node.auto(
+        depends_on=[
+            "vector_store",
+            "doc_store",
+            "embedding",
+            "file_name_list",
+            "retrieval_top_k",
+        ]
+    )
     def retrieving_pipeline(self) -> RetrieveDocumentFromVectorStorePipeline:
         retrieving_pipeline = RetrieveDocumentFromVectorStorePipeline(
             vector_store=self.vector_store,
             doc_store=self.doc_store,
             embedding=self.embedding,
+            top_k=self.retrieval_top_k,
         )
         # load persistent from selected path
         collection_name = file_names_to_collection_name(self.file_name_list)
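
Note on this hunk: replacing `@Node.default()` with `@Node.auto(depends_on=[...])` makes the retrieving-pipeline node list the attributes it is built from, presumably so the framework can rebuild it whenever one of those inputs changes; the new `top_k=self.retrieval_top_k` argument is what ties `retrieval_top_k` into that list. The exact semantics of `Node.auto` belong to theflow and are not shown here; the sketch below is a hypothetical, self-contained illustration of the rebuild-on-dependency-change idea, not the library's implementation:

```python
class auto_node:
    """Hypothetical dependency-aware node: cache the built object and
    rebuild it when any attribute named in depends_on changes."""

    def __init__(self, *, depends_on):
        self.depends_on = list(depends_on)

    def __call__(self, factory):
        self.factory = factory
        return self

    def __set_name__(self, owner, name):
        self.name = name

    def __get__(self, obj, objtype=None):
        if obj is None:
            return self
        deps = tuple(getattr(obj, dep, None) for dep in self.depends_on)
        cache = obj.__dict__.setdefault("_node_cache", {})
        hit = cache.get(self.name)
        if hit is None or hit[0] != deps:
            cache[self.name] = (deps, self.factory(obj))  # (re)build the node
        return cache[self.name][1]


class ToyPipeline:
    retrieval_top_k: int = 1

    @auto_node(depends_on=["retrieval_top_k"])
    def retriever(self):
        return {"top_k": self.retrieval_top_k}


p = ToyPipeline()
assert p.retriever == {"top_k": 1}  # built on first access
p.retrieval_top_k = 5
assert p.retriever == {"top_k": 5}  # rebuilt because a dependency changed
```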
@@ -81,7 +90,7 @@ class QuestionAnsweringPipeline(BaseComponent):
         self.log_progress(".context", context=context)

         # generate the answer
-        prompt = self.prompt_template.populate(
+        prompt = self.qa_prompt_template.populate(
             context=context,
             question=question,
         )
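
Note on this hunk: the rename from `prompt_template` to `qa_prompt_template` is mechanical, presumably to make room for other prompt attributes on the pipeline; the template text and its `populate()` call are unchanged. A standalone usage sketch built only from the calls visible in this diff (the context and question strings are made-up examples):

```python
from kotaemon.llms import PromptTemplate

qa_prompt_template = PromptTemplate(
    'Answer the following question: "{question}". '
    "The context is: \n{context}\nAnswer: "
)

# populate() fills the named placeholders, mirroring the call in run()
prompt = qa_prompt_template.populate(
    context="The retrieving pipeline returns the top-k matching documents.",
    question="How many documents does the retriever return?",
)
print(prompt)
```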