From 8e0779a22dd1d04a1d532dccfacf89113ed52010 Mon Sep 17 00:00:00 2001 From: ian_Cin Date: Mon, 27 Nov 2023 16:35:09 +0700 Subject: [PATCH] Enforce all IO objects to be subclassed from Document (#88) * enforce Document as IO * Separate rerankers, splitters and extractors (#85) * partially refractor importing * add text to embedding outputs --------- Co-authored-by: Nguyen Trung Duc (john) --- knowledgehub/base/__init__.py | 25 ++++++++++++++++-- knowledgehub/base/component.py | 5 ++-- knowledgehub/base/schema.py | 13 ++++++++++ knowledgehub/chatbot/base.py | 7 ++--- knowledgehub/contribs/promptui/config.py | 5 ++-- knowledgehub/embeddings/base.py | 13 +++++++--- knowledgehub/indices/base.py | 10 ++++---- knowledgehub/llms/chats/base.py | 5 +--- knowledgehub/pipelines/cot.py | 30 +++++++++++++--------- knowledgehub/pipelines/retrieving.py | 10 +++----- knowledgehub/storages/vectorstores/base.py | 17 +++++++----- tests/test_cot.py | 6 ++--- tests/test_embedding_models.py | 21 +++++++++------ 13 files changed, 108 insertions(+), 59 deletions(-) diff --git a/knowledgehub/base/__init__.py b/knowledgehub/base/__init__.py index 7c600cc..11e9b76 100644 --- a/knowledgehub/base/__init__.py +++ b/knowledgehub/base/__init__.py @@ -1,4 +1,25 @@ from .component import BaseComponent -from .schema import Document +from .schema import ( + AIMessage, + BaseMessage, + Document, + DocumentWithEmbedding, + ExtractorOutput, + HumanMessage, + LLMInterface, + RetrievedDocument, + SystemMessage, +) -__all__ = ["BaseComponent", "Document"] +__all__ = [ + "BaseComponent", + "Document", + "DocumentWithEmbedding", + "BaseMessage", + "SystemMessage", + "AIMessage", + "HumanMessage", + "RetrievedDocument", + "LLMInterface", + "ExtractorOutput", +] diff --git a/knowledgehub/base/component.py b/knowledgehub/base/component.py index 5ae1fde..71da362 100644 --- a/knowledgehub/base/component.py +++ b/knowledgehub/base/component.py @@ -2,6 +2,8 @@ from abc import abstractmethod from theflow.base import Function +from kotaemon.base.schema import Document + class BaseComponent(Function): """A component is a class that can be used to compose a pipeline @@ -30,7 +32,6 @@ class BaseComponent(Function): return self.__call__(self.inflow.flow()) @abstractmethod - def run(self, *args, **kwargs): - # enforce output type to be compatible with Document + def run(self, *args, **kwargs) -> Document | list[Document] | None: """Run the component.""" ... diff --git a/knowledgehub/base/schema.py b/knowledgehub/base/schema.py index 9a4ced7..a767e7f 100644 --- a/knowledgehub/base/schema.py +++ b/knowledgehub/base/schema.py @@ -32,6 +32,8 @@ class Document(BaseDocument): kwargs["content"] = kwargs["text"] elif kwargs.get("embedding", None) is not None: kwargs["content"] = kwargs["embedding"] + # default text indicating this document only contains embedding + kwargs["text"] = "" elif isinstance(content, Document): kwargs = content.dict() else: @@ -65,6 +67,17 @@ class Document(BaseDocument): return str(self.content) +class DocumentWithEmbedding(Document): + """Subclass of Document which must contains embedding + + Use this if you want to enforce component's IOs to must contain embedding. + """ + + def __init__(self, embedding: list[float], *args, **kwargs): + kwargs["embedding"] = embedding + super().__init__(*args, **kwargs) + + class BaseMessage(Document): def __add__(self, other: Any): raise NotImplementedError diff --git a/knowledgehub/chatbot/base.py b/knowledgehub/chatbot/base.py index 3c305c0..b6a3baf 100644 --- a/knowledgehub/chatbot/base.py +++ b/knowledgehub/chatbot/base.py @@ -3,11 +3,8 @@ from typing import List, Optional from theflow import SessionFunction -from kotaemon.base.schema import AIMessage, SystemMessage - -from ..base import BaseComponent -from ..base.schema import LLMInterface -from ..llms.chats.base import BaseMessage, HumanMessage +from kotaemon.base import BaseComponent, LLMInterface +from kotaemon.base.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage class BaseChatBot(BaseComponent): diff --git a/knowledgehub/contribs/promptui/config.py b/knowledgehub/contribs/promptui/config.py index ac95024..478f859 100644 --- a/knowledgehub/contribs/promptui/config.py +++ b/knowledgehub/contribs/promptui/config.py @@ -5,8 +5,9 @@ from typing import Any, Dict, Optional, Type, Union import yaml -from ...base import BaseComponent -from ...chatbot import BaseChatBot +from kotaemon.base import BaseComponent +from kotaemon.chatbot import BaseChatBot + from .base import DEFAULT_COMPONENT_BY_TYPES diff --git a/knowledgehub/embeddings/base.py b/knowledgehub/embeddings/base.py index 688eb1c..632a480 100644 --- a/knowledgehub/embeddings/base.py +++ b/knowledgehub/embeddings/base.py @@ -6,14 +6,14 @@ from typing import Type from langchain.schema.embeddings import Embeddings as LCEmbeddings from theflow import Param -from ..base import BaseComponent, Document +from kotaemon.base import BaseComponent, Document, DocumentWithEmbedding class BaseEmbeddings(BaseComponent): @abstractmethod def run( self, text: str | list[str] | Document | list[Document] - ) -> list[list[float]]: + ) -> list[DocumentWithEmbedding]: ... @@ -43,7 +43,7 @@ class LangchainEmbeddings(BaseEmbeddings): def agent(self): return self._lc_class(**self._kwargs) - def run(self, text) -> list[list[float]]: + def run(self, text): input_: list[str] = [] if not isinstance(text, list): text = [text] @@ -58,4 +58,9 @@ class LangchainEmbeddings(BaseEmbeddings): f"Invalid input type {type(item)}, should be str or Document" ) - return self.agent.embed_documents(input_) + embeddings = self.agent.embed_documents(input_) + + return [ + DocumentWithEmbedding(text=each_text, embedding=each_embedding) + for each_text, each_embedding in zip(input_, embeddings) + ] diff --git a/knowledgehub/indices/base.py b/knowledgehub/indices/base.py index dfdf9aa..02cfbc7 100644 --- a/knowledgehub/indices/base.py +++ b/knowledgehub/indices/base.py @@ -1,7 +1,7 @@ from __future__ import annotations from abc import abstractmethod -from typing import Any, Sequence, Type +from typing import Any, Type from llama_index.node_parser.interface import NodeParser @@ -20,9 +20,9 @@ class DocTransformer(BaseComponent): @abstractmethod def run( self, - documents: Sequence[Document], + documents: list[Document], **kwargs, - ) -> Sequence[Document]: + ) -> list[Document]: ... @@ -62,9 +62,9 @@ class LlamaIndexMixin: def run( self, - documents: Sequence[Document], + documents: list[Document], **kwargs, - ) -> Sequence[Document]: + ) -> list[Document]: """Run Llama-index node parser and convert the output to Document from kotaemon """ diff --git a/knowledgehub/llms/chats/base.py b/knowledgehub/llms/chats/base.py index 7fdcd62..0804dff 100644 --- a/knowledgehub/llms/chats/base.py +++ b/knowledgehub/llms/chats/base.py @@ -6,10 +6,7 @@ from typing import Type from langchain.chat_models.base import BaseChatModel from theflow.base import Param -from kotaemon.base.schema import BaseMessage, HumanMessage - -from ...base import BaseComponent -from ...base.schema import LLMInterface +from kotaemon.base import BaseComponent, BaseMessage, HumanMessage, LLMInterface logger = logging.getLogger(__name__) diff --git a/knowledgehub/pipelines/cot.py b/knowledgehub/pipelines/cot.py index 4ef4416..2a39679 100644 --- a/knowledgehub/pipelines/cot.py +++ b/knowledgehub/pipelines/cot.py @@ -3,7 +3,7 @@ from typing import Callable, List from theflow import Function, Node, Param -from kotaemon.base import BaseComponent +from kotaemon.base import BaseComponent, Document from kotaemon.llms import LLM, BasePromptComponent from kotaemon.llms.chats.openai import AzureChatOpenAI @@ -65,15 +65,19 @@ class Thought(BaseComponent): """ prompt: str = Param( - help="The prompt template string. This prompt template has Python-like " - "variable placeholders, that then will be subsituted with real values when " - "this component is executed" + help=( + "The prompt template string. This prompt template has Python-like " + "variable placeholders, that then will be subsituted with real values when " + "this component is executed" + ) ) llm: LLM = Node(AzureChatOpenAI, help="The LLM model to execute the input prompt") post_process: Function = Node( - help="The function post-processor that post-processes LLM output prediction ." - "It should take a string as input (this is the LLM output text) and return " - "a dictionary, where the key should" + help=( + "The function post-processor that post-processes LLM output prediction ." + "It should take a string as input (this is the LLM output text) and return " + "a dictionary, where the key should" + ) ) @Node.auto(depends_on="prompt") @@ -81,11 +85,13 @@ class Thought(BaseComponent): """Automatically wrap around param prompt. Can ignore""" return BasePromptComponent(self.prompt) - def run(self, **kwargs) -> dict: + def run(self, **kwargs) -> Document: """Run the chain of thought""" prompt = self.prompt_template(**kwargs).text response = self.llm(prompt).text - return self.post_process(response) + response = self.post_process(response) + + return Document(response) def get_variables(self) -> List[str]: return [] @@ -146,7 +152,7 @@ class ManualSequentialChainOfThought(BaseComponent): help="Callback on terminate condition. Default to always return False", ) - def run(self, **kwargs) -> dict: + def run(self, **kwargs) -> Document: """Run the manual chain of thought""" inputs = deepcopy(kwargs) @@ -156,11 +162,11 @@ class ManualSequentialChainOfThought(BaseComponent): self._prepare_child(thought, f"thought{idx}") output = thought(**inputs) - inputs.update(output) + inputs.update(output.content) if self.terminate(inputs): break - return inputs + return Document(inputs) def __add__(self, next_thought: Thought) -> "ManualSequentialChainOfThought": return ManualSequentialChainOfThought( diff --git a/knowledgehub/pipelines/retrieving.py b/knowledgehub/pipelines/retrieving.py index 1af6854..941ff01 100644 --- a/knowledgehub/pipelines/retrieving.py +++ b/knowledgehub/pipelines/retrieving.py @@ -3,12 +3,10 @@ from __future__ import annotations from pathlib import Path from typing import Optional, Sequence +from kotaemon.base import BaseComponent, Document, RetrievedDocument +from kotaemon.embeddings import BaseEmbeddings from kotaemon.indices.rankings import BaseReranking - -from ..base import BaseComponent -from ..base.schema import Document, RetrievedDocument -from ..embeddings import BaseEmbeddings -from ..storages import BaseDocumentStore, BaseVectorStore +from kotaemon.storages import BaseDocumentStore, BaseVectorStore VECTOR_STORE_FNAME = "vectorstore" DOC_STORE_FNAME = "docstore" @@ -45,7 +43,7 @@ class RetrieveDocumentFromVectorStorePipeline(BaseComponent): "retrieve the documents" ) - emb: list[float] = self.embedding(text)[0] + emb: list[float] = self.embedding(text)[0].embedding _, scores, ids = self.vector_store.query(embedding=emb, top_k=top_k) docs = self.doc_store.get(ids) result = [ diff --git a/knowledgehub/storages/vectorstores/base.py b/knowledgehub/storages/vectorstores/base.py index 1ddaef0..2213b85 100644 --- a/knowledgehub/storages/vectorstores/base.py +++ b/knowledgehub/storages/vectorstores/base.py @@ -6,7 +6,7 @@ from llama_index.vector_stores.types import BasePydanticVectorStore from llama_index.vector_stores.types import VectorStore as LIVectorStore from llama_index.vector_stores.types import VectorStoreQuery -from ...base import Document +from kotaemon.base import Document, DocumentWithEmbedding class BaseVectorStore(ABC): @@ -17,7 +17,7 @@ class BaseVectorStore(ABC): @abstractmethod def add( self, - embeddings: List[List[float]], + embeddings: List[List[float]] | List[DocumentWithEmbedding], metadatas: Optional[List[dict]] = None, ids: Optional[List[str]] = None, ) -> List[str]: @@ -104,11 +104,16 @@ class LlamaIndexVectorStore(BaseVectorStore): def add( self, - embeddings: List[List[float]], + embeddings: List[List[float]] | List[DocumentWithEmbedding], metadatas: Optional[List[dict]] = None, ids: Optional[List[str]] = None, ): - nodes = [Document(embedding=embedding) for embedding in embeddings] + if isinstance(embeddings[0], list): + nodes = [ + DocumentWithEmbedding(embedding=embedding) for embedding in embeddings + ] + else: + nodes = embeddings # type: ignore if metadatas is not None: for node, metadata in zip(nodes, metadatas): node.metadata = metadata @@ -119,10 +124,10 @@ class LlamaIndexVectorStore(BaseVectorStore): NodeRelationship.SOURCE: RelatedNodeInfo(node_id=id) } - return self._client.add(nodes=nodes) # type: ignore + return self._client.add(nodes=nodes) def add_from_docs(self, docs: List[Document]): - return self._client.add(nodes=docs) # type: ignore + return self._client.add(nodes=docs) def delete(self, ids: List[str], **kwargs): for id_ in ids: diff --git a/tests/test_cot.py b/tests/test_cot.py index 6583732..304d481 100644 --- a/tests/test_cot.py +++ b/tests/test_cot.py @@ -56,7 +56,7 @@ def test_cot_plus_operator(openai_completion): ) thought = thought1 + thought2 output = thought(word="hello", language="French") - assert output == { + assert output.content == { "word": "hello", "language": "French", "translated": "Bonjour", @@ -86,7 +86,7 @@ def test_cot_manual(openai_completion): ) thought = ManualSequentialChainOfThought(thoughts=[thought1, thought2], llm=llm) output = thought(word="hello", language="French") - assert output == { + assert output.content == { "word": "hello", "language": "French", "translated": "Bonjour", @@ -120,7 +120,7 @@ def test_cot_with_termination_callback(openai_completion): terminate=lambda d: True if d.get("translated", "") == "Bonjour" else False, ) output = thought(word="hallo", language="French") - assert output == { + assert output.content == { "word": "hallo", "language": "French", "translated": "Bonjour", diff --git a/tests/test_embedding_models.py b/tests/test_embedding_models.py index 2b29538..d530964 100644 --- a/tests/test_embedding_models.py +++ b/tests/test_embedding_models.py @@ -2,6 +2,7 @@ import json from pathlib import Path from unittest.mock import patch +from kotaemon.base import Document from kotaemon.embeddings.cohere import CohereEmbdeddings from kotaemon.embeddings.huggingface import HuggingFaceEmbeddings from kotaemon.embeddings.openai import AzureOpenAIEmbeddings @@ -26,8 +27,9 @@ def test_azureopenai_embeddings_raw(openai_embedding_call): ) output = model("Hello world") assert isinstance(output, list) - assert isinstance(output[0], list) - assert isinstance(output[0][0], float) + assert isinstance(output[0], Document) + assert isinstance(output[0].embedding, list) + assert isinstance(output[0].embedding[0], float) openai_embedding_call.assert_called() @@ -44,8 +46,9 @@ def test_azureopenai_embeddings_batch_raw(openai_embedding_call): ) output = model(["Hello world", "Goodbye world"]) assert isinstance(output, list) - assert isinstance(output[0], list) - assert isinstance(output[0][0], float) + assert isinstance(output[0], Document) + assert isinstance(output[0].embedding, list) + assert isinstance(output[0].embedding[0], float) openai_embedding_call.assert_called() @@ -68,8 +71,9 @@ def test_huggingface_embddings( output = model("Hello World") assert isinstance(output, list) - assert isinstance(output[0], list) - assert isinstance(output[0][0], float) + assert isinstance(output[0], Document) + assert isinstance(output[0].embedding, list) + assert isinstance(output[0].embedding[0], float) sentence_transformers_init.assert_called() langchain_huggingface_embedding_call.assert_called() @@ -85,6 +89,7 @@ def test_cohere_embeddings(langchain_cohere_embedding_call): output = model("Hello World") assert isinstance(output, list) - assert isinstance(output[0], list) - assert isinstance(output[0][0], float) + assert isinstance(output[0], Document) + assert isinstance(output[0].embedding, list) + assert isinstance(output[0].embedding[0], float) langchain_cohere_embedding_call.assert_called()