Upgrade the declarative pipeline for cleaner interface (#51)

This commit is contained in:
Nguyen Trung Duc (john)
2023-10-24 11:12:22 +07:00
committed by GitHub
parent aab982ddc4
commit 9035e25666
26 changed files with 365 additions and 169 deletions

View File

@@ -1,7 +1,7 @@
import tempfile
from typing import List
from theflow import Node
from theflow.utils.modules import ObjectInitDeclaration as _
from kotaemon.base import BaseComponent
from kotaemon.embeddings import AzureOpenAIEmbeddings
@@ -11,33 +11,27 @@ from kotaemon.vectorstores import ChromaVectorStore
class Pipeline(BaseComponent):
# Example pipeline: run_raw() looks up texts matching the input through
# `retrieving_pipeline` and sends them, joined by newlines, to `llm`.
#
# NOTE(review): this span is a rendered diff hunk (see the "@@ -11,33 +11,27 @@"
# header just above) whose +/- markers and indentation were stripped, so two
# versions of several declarations appear back to back. The "older"/"newer"
# labels below are inferred from the commit title and the hunk line counts;
# confirm against the repository before relying on them.
vectorstore_path: str = str(tempfile.mkdtemp())
# Presumably the older declaration: the LLM is wrapped in a theflow Node whose
# constructor arguments are supplied through `default_kwargs`.
llm: Node[AzureOpenAI] = Node(
default=AzureOpenAI,
default_kwargs={
"openai_api_base": "https://test.openai.azure.com/",
"openai_api_key": "some-key",
"openai_api_version": "2023-03-15-preview",
"deployment_name": "gpt35turbo",
"temperature": 0,
"request_timeout": 60,
},
# Presumably the newer declaration: the same six settings passed as keyword
# arguments to AzureOpenAI.withx(...), annotated with the component type itself.
llm: AzureOpenAI = AzureOpenAI.withx(
openai_api_base="https://test.openai.azure.com/",
openai_api_key="some-key",
openai_api_version="2023-03-15-preview",
deployment_name="gpt35turbo",
temperature=0,
request_timeout=60,
)
# Presumably the older form: a Node-decorated method (declared as depending on
# `vectorstore_path`) that builds the vector store and embedding imperatively
# and returns the retrieval pipeline.
@Node.decorate(depends_on=["vectorstore_path"])
def retrieving_pipeline(self):
vector_store = ChromaVectorStore(self.vectorstore_path)
embedding = AzureOpenAIEmbeddings(
model="text-embedding-ada-002",
deployment="embedding-deployment",
openai_api_base="https://test.openai.azure.com/",
openai_api_key="some-key",
)
return RetrieveDocumentFromVectorStorePipeline(
vector_store=vector_store, embedding=embedding
# Presumably the newer form: the retrieval pipeline is declared as an attribute
# through nested withx(...) calls. `_` is the ObjectInitDeclaration alias
# imported at the top of the file. Note that this form passes a fresh
# tempfile.mkdtemp() path to ChromaVectorStore instead of reading
# `self.vectorstore_path`.
retrieving_pipeline: RetrieveDocumentFromVectorStorePipeline = (
RetrieveDocumentFromVectorStorePipeline.withx(
vector_store=_(ChromaVectorStore).withx(path=str(tempfile.mkdtemp())),
embedding=AzureOpenAIEmbeddings.withx(
model="text-embedding-ada-002",
deployment="embedding-deployment",
openai_api_base="https://test.openai.azure.com/",
openai_api_key="some-key",
),
)
)
# Entry point: calls the retrieval pipeline with `text`, joins the returned
# texts with "\n", calls the LLM on the joined string and returns its `.text`.
def run_raw(self, text: str) -> str:
matched_texts: List[str] = self.retrieving_pipeline(text)
# The two returns below are presumably the older (`.text[0]`) and newer
# (`.text`) versions of the same line; as rendered, only the first is reachable.
return self.llm("\n".join(matched_texts)).text[0]
return self.llm("\n".join(matched_texts)).text