Enforce all IO objects to be subclassed from Document (#88)

* enforce Document as IO

* Separate rerankers, splitters and extractors (#85)

* partially refactor importing

* add text to embedding outputs

---------

Co-authored-by: Nguyen Trung Duc (john) <trungduc1992@gmail.com>
ian_Cin 2023-11-27 16:35:09 +07:00 committed by GitHub
parent 2186c5558f
commit 8e0779a22d
13 changed files with 108 additions and 59 deletions
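
Taken together, the diff below enforces one IO contract: every component's run() returns Document objects (or a Document subclass). A minimal sketch of a component written against the new contract, using only names visible in this commit (EchoUpper itself is illustrative, not part of the codebase):

from kotaemon.base import BaseComponent, Document

class EchoUpper(BaseComponent):
    # run() must now return Document | list[Document] | None
    def run(self, text: str) -> Document:
        return Document(text=text.upper())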

View File

@@ -1,4 +1,25 @@
from .component import BaseComponent
from .schema import Document
from .schema import (
AIMessage,
BaseMessage,
Document,
DocumentWithEmbedding,
ExtractorOutput,
HumanMessage,
LLMInterface,
RetrievedDocument,
SystemMessage,
)
__all__ = ["BaseComponent", "Document"]
__all__ = [
"BaseComponent",
"Document",
"DocumentWithEmbedding",
"BaseMessage",
"SystemMessage",
"AIMessage",
"HumanMessage",
"RetrievedDocument",
"LLMInterface",
"ExtractorOutput",
]

View File

@@ -2,6 +2,8 @@ from abc import abstractmethod
from theflow.base import Function
from kotaemon.base.schema import Document
class BaseComponent(Function):
"""A component is a class that can be used to compose a pipeline
@@ -30,7 +32,6 @@ class BaseComponent(Function):
return self.__call__(self.inflow.flow())
@abstractmethod
def run(self, *args, **kwargs):
# enforce output type to be compatible with Document
def run(self, *args, **kwargs) -> Document | list[Document] | None:
"""Run the component."""
...
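
With the tightened signature, a type checker can flag components that still return raw strings or lists of floats. A hedged sketch of a conforming subclass (SentenceSplitterStub is hypothetical):

from kotaemon.base import BaseComponent, Document

class SentenceSplitterStub(BaseComponent):
    def run(self, document: Document) -> list[Document]:
        # wrap every output piece back into a Document, per the new contract
        return [Document(text=piece) for piece in document.text.split(".") if piece]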

View File

@@ -32,6 +32,8 @@ class Document(BaseDocument):
kwargs["content"] = kwargs["text"]
elif kwargs.get("embedding", None) is not None:
kwargs["content"] = kwargs["embedding"]
# default text indicating this document only contains embedding
kwargs["text"] = "<EMBEDDING>"
elif isinstance(content, Document):
kwargs = content.dict()
else:
@@ -65,6 +67,17 @@ class Document(BaseDocument):
return str(self.content)
class DocumentWithEmbedding(Document):
"""Subclass of Document which must contains embedding
Use this if you want to enforce component's IOs to must contain embedding.
"""
def __init__(self, embedding: list[float], *args, **kwargs):
kwargs["embedding"] = embedding
super().__init__(*args, **kwargs)
class BaseMessage(Document):
def __add__(self, other: Any):
raise NotImplementedError
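
Following the constructor logic above, an embedding-only document falls back to the "<EMBEDDING>" placeholder text, while an explicit text keyword is kept as-is:

from kotaemon.base import DocumentWithEmbedding

doc = DocumentWithEmbedding(embedding=[0.1, 0.2, 0.3])
print(doc.text)       # "<EMBEDDING>" - placeholder set by Document.__init__
print(doc.embedding)  # [0.1, 0.2, 0.3]

doc = DocumentWithEmbedding(embedding=[0.1, 0.2, 0.3], text="hello")
print(doc.text)       # "hello" - explicit text is preserved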

View File

@@ -3,11 +3,8 @@ from typing import List, Optional
from theflow import SessionFunction
from kotaemon.base.schema import AIMessage, SystemMessage
from ..base import BaseComponent
from ..base.schema import LLMInterface
from ..llms.chats.base import BaseMessage, HumanMessage
from kotaemon.base import BaseComponent, LLMInterface
from kotaemon.base.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage
class BaseChatBot(BaseComponent):

View File

@@ -5,8 +5,9 @@ from typing import Any, Dict, Optional, Type, Union
import yaml
from ...base import BaseComponent
from ...chatbot import BaseChatBot
from kotaemon.base import BaseComponent
from kotaemon.chatbot import BaseChatBot
from .base import DEFAULT_COMPONENT_BY_TYPES

View File

@@ -6,14 +6,14 @@ from typing import Type
from langchain.schema.embeddings import Embeddings as LCEmbeddings
from theflow import Param
from ..base import BaseComponent, Document
from kotaemon.base import BaseComponent, Document, DocumentWithEmbedding
class BaseEmbeddings(BaseComponent):
@abstractmethod
def run(
self, text: str | list[str] | Document | list[Document]
) -> list[list[float]]:
) -> list[DocumentWithEmbedding]:
...
@@ -43,7 +43,7 @@ class LangchainEmbeddings(BaseEmbeddings):
def agent(self):
return self._lc_class(**self._kwargs)
def run(self, text) -> list[list[float]]:
def run(self, text):
input_: list[str] = []
if not isinstance(text, list):
text = [text]
@@ -58,4 +58,9 @@ class LangchainEmbeddings(BaseEmbeddings):
f"Invalid input type {type(item)}, should be str or Document"
)
return self.agent.embed_documents(input_)
embeddings = self.agent.embed_documents(input_)
return [
DocumentWithEmbedding(text=each_text, embedding=each_embedding)
for each_text, each_embedding in zip(input_, embeddings)
]
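
This is the "add text to embedding outputs" item from the commit message: instead of a bare list[list[float]], callers now get the input text back alongside each vector. Usage under the new interface (model stands for any configured BaseEmbeddings instance; construction elided):

output = model(["Hello world", "Goodbye world"])
for doc in output:               # each item is a DocumentWithEmbedding
    print(doc.text)              # the original input string
    print(len(doc.embedding))    # the raw vector is still there as list[float]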

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
from abc import abstractmethod
from typing import Any, Sequence, Type
from typing import Any, Type
from llama_index.node_parser.interface import NodeParser
@@ -20,9 +20,9 @@ class DocTransformer(BaseComponent):
@abstractmethod
def run(
self,
documents: Sequence[Document],
documents: list[Document],
**kwargs,
) -> Sequence[Document]:
) -> list[Document]:
...
@@ -62,9 +62,9 @@ class LlamaIndexMixin:
def run(
self,
documents: Sequence[Document],
documents: list[Document],
**kwargs,
) -> Sequence[Document]:
) -> list[Document]:
"""Run Llama-index node parser and convert the output to Document from
kotaemon
"""

View File

@@ -6,10 +6,7 @@ from typing import Type
from langchain.chat_models.base import BaseChatModel
from theflow.base import Param
from kotaemon.base.schema import BaseMessage, HumanMessage
from ...base import BaseComponent
from ...base.schema import LLMInterface
from kotaemon.base import BaseComponent, BaseMessage, HumanMessage, LLMInterface
logger = logging.getLogger(__name__)

View File

@@ -3,7 +3,7 @@ from typing import Callable, List
from theflow import Function, Node, Param
from kotaemon.base import BaseComponent
from kotaemon.base import BaseComponent, Document
from kotaemon.llms import LLM, BasePromptComponent
from kotaemon.llms.chats.openai import AzureChatOpenAI
@@ -65,27 +65,33 @@ class Thought(BaseComponent):
"""
prompt: str = Param(
help="The prompt template string. This prompt template has Python-like "
help=(
"The prompt template string. This prompt template has Python-like "
"variable placeholders, that then will be subsituted with real values when "
"this component is executed"
)
)
llm: LLM = Node(AzureChatOpenAI, help="The LLM model to execute the input prompt")
post_process: Function = Node(
help="The function post-processor that post-processes LLM output prediction ."
help=(
"The function post-processor that post-processes LLM output prediction ."
"It should take a string as input (this is the LLM output text) and return "
"a dictionary, where the key should"
)
)
@Node.auto(depends_on="prompt")
def prompt_template(self):
"""Automatically wrap around param prompt. Can ignore"""
return BasePromptComponent(self.prompt)
def run(self, **kwargs) -> dict:
def run(self, **kwargs) -> Document:
"""Run the chain of thought"""
prompt = self.prompt_template(**kwargs).text
response = self.llm(prompt).text
return self.post_process(response)
response = self.post_process(response)
return Document(response)
def get_variables(self) -> List[str]:
return []
@@ -146,7 +152,7 @@ class ManualSequentialChainOfThought(BaseComponent):
help="Callback on terminate condition. Default to always return False",
)
def run(self, **kwargs) -> dict:
def run(self, **kwargs) -> Document:
"""Run the manual chain of thought"""
inputs = deepcopy(kwargs)
@@ -156,11 +162,11 @@ class ManualSequentialChainOfThought(BaseComponent):
self._prepare_child(thought, f"thought{idx}")
output = thought(**inputs)
inputs.update(output)
inputs.update(output.content)
if self.terminate(inputs):
break
return inputs
return Document(inputs)
def __add__(self, next_thought: Thought) -> "ManualSequentialChainOfThought":
return ManualSequentialChainOfThought(
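
Since run() now returns a Document whose content carries the accumulated inputs dict, callers read results from output.content instead of treating the output itself as a dict; the test updates further below follow the same pattern:

output = chain(word="hello", language="French")  # chain: a ManualSequentialChainOfThought
print(output.content["translated"])              # results live under .content now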

View File

@@ -3,12 +3,10 @@ from __future__ import annotations
from pathlib import Path
from typing import Optional, Sequence
from kotaemon.base import BaseComponent, Document, RetrievedDocument
from kotaemon.embeddings import BaseEmbeddings
from kotaemon.indices.rankings import BaseReranking
from ..base import BaseComponent
from ..base.schema import Document, RetrievedDocument
from ..embeddings import BaseEmbeddings
from ..storages import BaseDocumentStore, BaseVectorStore
from kotaemon.storages import BaseDocumentStore, BaseVectorStore
VECTOR_STORE_FNAME = "vectorstore"
DOC_STORE_FNAME = "docstore"
@@ -45,7 +43,7 @@ class RetrieveDocumentFromVectorStorePipeline(BaseComponent):
"retrieve the documents"
)
emb: list[float] = self.embedding(text)[0]
emb: list[float] = self.embedding(text)[0].embedding
_, scores, ids = self.vector_store.query(embedding=emb, top_k=top_k)
docs = self.doc_store.get(ids)
result = [
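
Because the embedding component now returns DocumentWithEmbedding objects rather than raw vectors, the pipeline unwraps the vector via .embedding before querying. The same access pattern outside the pipeline (object names illustrative):

emb = embedding("some query")[0].embedding                   # list[float]
_, scores, ids = vector_store.query(embedding=emb, top_k=1)
docs = doc_store.get(ids)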

View File

@@ -6,7 +6,7 @@ from llama_index.vector_stores.types import BasePydanticVectorStore
from llama_index.vector_stores.types import VectorStore as LIVectorStore
from llama_index.vector_stores.types import VectorStoreQuery
from ...base import Document
from kotaemon.base import Document, DocumentWithEmbedding
class BaseVectorStore(ABC):
@@ -17,7 +17,7 @@ class BaseVectorStore(ABC):
@abstractmethod
def add(
self,
embeddings: List[List[float]],
embeddings: List[List[float]] | List[DocumentWithEmbedding],
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
) -> List[str]:
@@ -104,11 +104,16 @@ class LlamaIndexVectorStore(BaseVectorStore):
def add(
self,
embeddings: List[List[float]],
embeddings: List[List[float]] | List[DocumentWithEmbedding],
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
):
nodes = [Document(embedding=embedding) for embedding in embeddings]
if isinstance(embeddings[0], list):
nodes = [
DocumentWithEmbedding(embedding=embedding) for embedding in embeddings
]
else:
nodes = embeddings # type: ignore
if metadatas is not None:
for node, metadata in zip(nodes, metadatas):
node.metadata = metadata
@@ -119,10 +124,10 @@ class LlamaIndexVectorStore(BaseVectorStore):
NodeRelationship.SOURCE: RelatedNodeInfo(node_id=id)
}
return self._client.add(nodes=nodes) # type: ignore
return self._client.add(nodes=nodes)
def add_from_docs(self, docs: List[Document]):
return self._client.add(nodes=docs) # type: ignore
return self._client.add(nodes=docs)
def delete(self, ids: List[str], **kwargs):
for id_ in ids:
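
After this change, add() accepts either raw vectors (wrapped into DocumentWithEmbedding on the fly) or DocumentWithEmbedding objects directly, e.g. straight from a BaseEmbeddings output. Both call styles, sketched against the signature above (vector_store stands for any concrete LlamaIndexVectorStore):

ids = vector_store.add(embeddings=[[0.1, 0.2], [0.3, 0.4]])  # raw vectors, wrapped internally

docs = model(["hello", "goodbye"])       # model: a BaseEmbeddings returning DocumentWithEmbedding
ids = vector_store.add(embeddings=docs)  # documents passed through to the client as-is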

View File

@@ -56,7 +56,7 @@ def test_cot_plus_operator(openai_completion):
)
thought = thought1 + thought2
output = thought(word="hello", language="French")
assert output == {
assert output.content == {
"word": "hello",
"language": "French",
"translated": "Bonjour",
@@ -86,7 +86,7 @@ def test_cot_manual(openai_completion):
)
thought = ManualSequentialChainOfThought(thoughts=[thought1, thought2], llm=llm)
output = thought(word="hello", language="French")
assert output == {
assert output.content == {
"word": "hello",
"language": "French",
"translated": "Bonjour",
@@ -120,7 +120,7 @@ def test_cot_with_termination_callback(openai_completion):
terminate=lambda d: True if d.get("translated", "") == "Bonjour" else False,
)
output = thought(word="hallo", language="French")
assert output == {
assert output.content == {
"word": "hallo",
"language": "French",
"translated": "Bonjour",

View File

@@ -2,6 +2,7 @@ import json
from pathlib import Path
from unittest.mock import patch
from kotaemon.base import Document
from kotaemon.embeddings.cohere import CohereEmbdeddings
from kotaemon.embeddings.huggingface import HuggingFaceEmbeddings
from kotaemon.embeddings.openai import AzureOpenAIEmbeddings
@@ -26,8 +27,9 @@ def test_azureopenai_embeddings_raw(openai_embedding_call):
)
output = model("Hello world")
assert isinstance(output, list)
assert isinstance(output[0], list)
assert isinstance(output[0][0], float)
assert isinstance(output[0], Document)
assert isinstance(output[0].embedding, list)
assert isinstance(output[0].embedding[0], float)
openai_embedding_call.assert_called()
@@ -44,8 +46,9 @@ def test_azureopenai_embeddings_batch_raw(openai_embedding_call):
)
output = model(["Hello world", "Goodbye world"])
assert isinstance(output, list)
assert isinstance(output[0], list)
assert isinstance(output[0][0], float)
assert isinstance(output[0], Document)
assert isinstance(output[0].embedding, list)
assert isinstance(output[0].embedding[0], float)
openai_embedding_call.assert_called()
@@ -68,8 +71,9 @@ def test_huggingface_embddings(
output = model("Hello World")
assert isinstance(output, list)
assert isinstance(output[0], list)
assert isinstance(output[0][0], float)
assert isinstance(output[0], Document)
assert isinstance(output[0].embedding, list)
assert isinstance(output[0].embedding[0], float)
sentence_transformers_init.assert_called()
langchain_huggingface_embedding_call.assert_called()
@@ -85,6 +89,7 @@ def test_cohere_embeddings(langchain_cohere_embedding_call):
output = model("Hello World")
assert isinstance(output, list)
assert isinstance(output[0], list)
assert isinstance(output[0][0], float)
assert isinstance(output[0], Document)
assert isinstance(output[0].embedding, list)
assert isinstance(output[0].embedding[0], float)
langchain_cohere_embedding_call.assert_called()