Enforce all IO objects to be subclassed from Document (#88)
* enforce Document as IO * Separate rerankers, splitters and extractors (#85) * partially refractor importing * add text to embedding outputs --------- Co-authored-by: Nguyen Trung Duc (john) <trungduc1992@gmail.com>
This commit is contained in:
parent
2186c5558f
commit
8e0779a22d
|
@ -1,4 +1,25 @@
|
|||
from .component import BaseComponent
|
||||
from .schema import Document
|
||||
from .schema import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
Document,
|
||||
DocumentWithEmbedding,
|
||||
ExtractorOutput,
|
||||
HumanMessage,
|
||||
LLMInterface,
|
||||
RetrievedDocument,
|
||||
SystemMessage,
|
||||
)
|
||||
|
||||
__all__ = ["BaseComponent", "Document"]
|
||||
__all__ = [
|
||||
"BaseComponent",
|
||||
"Document",
|
||||
"DocumentWithEmbedding",
|
||||
"BaseMessage",
|
||||
"SystemMessage",
|
||||
"AIMessage",
|
||||
"HumanMessage",
|
||||
"RetrievedDocument",
|
||||
"LLMInterface",
|
||||
"ExtractorOutput",
|
||||
]
|
||||
|
|
|
@ -2,6 +2,8 @@ from abc import abstractmethod
|
|||
|
||||
from theflow.base import Function
|
||||
|
||||
from kotaemon.base.schema import Document
|
||||
|
||||
|
||||
class BaseComponent(Function):
|
||||
"""A component is a class that can be used to compose a pipeline
|
||||
|
@ -30,7 +32,6 @@ class BaseComponent(Function):
|
|||
return self.__call__(self.inflow.flow())
|
||||
|
||||
@abstractmethod
|
||||
def run(self, *args, **kwargs):
|
||||
# enforce output type to be compatible with Document
|
||||
def run(self, *args, **kwargs) -> Document | list[Document] | None:
|
||||
"""Run the component."""
|
||||
...
|
||||
|
|
|
@ -32,6 +32,8 @@ class Document(BaseDocument):
|
|||
kwargs["content"] = kwargs["text"]
|
||||
elif kwargs.get("embedding", None) is not None:
|
||||
kwargs["content"] = kwargs["embedding"]
|
||||
# default text indicating this document only contains embedding
|
||||
kwargs["text"] = "<EMBEDDING>"
|
||||
elif isinstance(content, Document):
|
||||
kwargs = content.dict()
|
||||
else:
|
||||
|
@ -65,6 +67,17 @@ class Document(BaseDocument):
|
|||
return str(self.content)
|
||||
|
||||
|
||||
class DocumentWithEmbedding(Document):
|
||||
"""Subclass of Document which must contains embedding
|
||||
|
||||
Use this if you want to enforce component's IOs to must contain embedding.
|
||||
"""
|
||||
|
||||
def __init__(self, embedding: list[float], *args, **kwargs):
|
||||
kwargs["embedding"] = embedding
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class BaseMessage(Document):
|
||||
def __add__(self, other: Any):
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -3,11 +3,8 @@ from typing import List, Optional
|
|||
|
||||
from theflow import SessionFunction
|
||||
|
||||
from kotaemon.base.schema import AIMessage, SystemMessage
|
||||
|
||||
from ..base import BaseComponent
|
||||
from ..base.schema import LLMInterface
|
||||
from ..llms.chats.base import BaseMessage, HumanMessage
|
||||
from kotaemon.base import BaseComponent, LLMInterface
|
||||
from kotaemon.base.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||
|
||||
|
||||
class BaseChatBot(BaseComponent):
|
||||
|
|
|
@ -5,8 +5,9 @@ from typing import Any, Dict, Optional, Type, Union
|
|||
|
||||
import yaml
|
||||
|
||||
from ...base import BaseComponent
|
||||
from ...chatbot import BaseChatBot
|
||||
from kotaemon.base import BaseComponent
|
||||
from kotaemon.chatbot import BaseChatBot
|
||||
|
||||
from .base import DEFAULT_COMPONENT_BY_TYPES
|
||||
|
||||
|
||||
|
|
|
@ -6,14 +6,14 @@ from typing import Type
|
|||
from langchain.schema.embeddings import Embeddings as LCEmbeddings
|
||||
from theflow import Param
|
||||
|
||||
from ..base import BaseComponent, Document
|
||||
from kotaemon.base import BaseComponent, Document, DocumentWithEmbedding
|
||||
|
||||
|
||||
class BaseEmbeddings(BaseComponent):
|
||||
@abstractmethod
|
||||
def run(
|
||||
self, text: str | list[str] | Document | list[Document]
|
||||
) -> list[list[float]]:
|
||||
) -> list[DocumentWithEmbedding]:
|
||||
...
|
||||
|
||||
|
||||
|
@ -43,7 +43,7 @@ class LangchainEmbeddings(BaseEmbeddings):
|
|||
def agent(self):
|
||||
return self._lc_class(**self._kwargs)
|
||||
|
||||
def run(self, text) -> list[list[float]]:
|
||||
def run(self, text):
|
||||
input_: list[str] = []
|
||||
if not isinstance(text, list):
|
||||
text = [text]
|
||||
|
@ -58,4 +58,9 @@ class LangchainEmbeddings(BaseEmbeddings):
|
|||
f"Invalid input type {type(item)}, should be str or Document"
|
||||
)
|
||||
|
||||
return self.agent.embed_documents(input_)
|
||||
embeddings = self.agent.embed_documents(input_)
|
||||
|
||||
return [
|
||||
DocumentWithEmbedding(text=each_text, embedding=each_embedding)
|
||||
for each_text, each_embedding in zip(input_, embeddings)
|
||||
]
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Sequence, Type
|
||||
from typing import Any, Type
|
||||
|
||||
from llama_index.node_parser.interface import NodeParser
|
||||
|
||||
|
@ -20,9 +20,9 @@ class DocTransformer(BaseComponent):
|
|||
@abstractmethod
|
||||
def run(
|
||||
self,
|
||||
documents: Sequence[Document],
|
||||
documents: list[Document],
|
||||
**kwargs,
|
||||
) -> Sequence[Document]:
|
||||
) -> list[Document]:
|
||||
...
|
||||
|
||||
|
||||
|
@ -62,9 +62,9 @@ class LlamaIndexMixin:
|
|||
|
||||
def run(
|
||||
self,
|
||||
documents: Sequence[Document],
|
||||
documents: list[Document],
|
||||
**kwargs,
|
||||
) -> Sequence[Document]:
|
||||
) -> list[Document]:
|
||||
"""Run Llama-index node parser and convert the output to Document from
|
||||
kotaemon
|
||||
"""
|
||||
|
|
|
@ -6,10 +6,7 @@ from typing import Type
|
|||
from langchain.chat_models.base import BaseChatModel
|
||||
from theflow.base import Param
|
||||
|
||||
from kotaemon.base.schema import BaseMessage, HumanMessage
|
||||
|
||||
from ...base import BaseComponent
|
||||
from ...base.schema import LLMInterface
|
||||
from kotaemon.base import BaseComponent, BaseMessage, HumanMessage, LLMInterface
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@ from typing import Callable, List
|
|||
|
||||
from theflow import Function, Node, Param
|
||||
|
||||
from kotaemon.base import BaseComponent
|
||||
from kotaemon.base import BaseComponent, Document
|
||||
from kotaemon.llms import LLM, BasePromptComponent
|
||||
from kotaemon.llms.chats.openai import AzureChatOpenAI
|
||||
|
||||
|
@ -65,27 +65,33 @@ class Thought(BaseComponent):
|
|||
"""
|
||||
|
||||
prompt: str = Param(
|
||||
help="The prompt template string. This prompt template has Python-like "
|
||||
help=(
|
||||
"The prompt template string. This prompt template has Python-like "
|
||||
"variable placeholders, that then will be subsituted with real values when "
|
||||
"this component is executed"
|
||||
)
|
||||
)
|
||||
llm: LLM = Node(AzureChatOpenAI, help="The LLM model to execute the input prompt")
|
||||
post_process: Function = Node(
|
||||
help="The function post-processor that post-processes LLM output prediction ."
|
||||
help=(
|
||||
"The function post-processor that post-processes LLM output prediction ."
|
||||
"It should take a string as input (this is the LLM output text) and return "
|
||||
"a dictionary, where the key should"
|
||||
)
|
||||
)
|
||||
|
||||
@Node.auto(depends_on="prompt")
|
||||
def prompt_template(self):
|
||||
"""Automatically wrap around param prompt. Can ignore"""
|
||||
return BasePromptComponent(self.prompt)
|
||||
|
||||
def run(self, **kwargs) -> dict:
|
||||
def run(self, **kwargs) -> Document:
|
||||
"""Run the chain of thought"""
|
||||
prompt = self.prompt_template(**kwargs).text
|
||||
response = self.llm(prompt).text
|
||||
return self.post_process(response)
|
||||
response = self.post_process(response)
|
||||
|
||||
return Document(response)
|
||||
|
||||
def get_variables(self) -> List[str]:
|
||||
return []
|
||||
|
@ -146,7 +152,7 @@ class ManualSequentialChainOfThought(BaseComponent):
|
|||
help="Callback on terminate condition. Default to always return False",
|
||||
)
|
||||
|
||||
def run(self, **kwargs) -> dict:
|
||||
def run(self, **kwargs) -> Document:
|
||||
"""Run the manual chain of thought"""
|
||||
|
||||
inputs = deepcopy(kwargs)
|
||||
|
@ -156,11 +162,11 @@ class ManualSequentialChainOfThought(BaseComponent):
|
|||
self._prepare_child(thought, f"thought{idx}")
|
||||
|
||||
output = thought(**inputs)
|
||||
inputs.update(output)
|
||||
inputs.update(output.content)
|
||||
if self.terminate(inputs):
|
||||
break
|
||||
|
||||
return inputs
|
||||
return Document(inputs)
|
||||
|
||||
def __add__(self, next_thought: Thought) -> "ManualSequentialChainOfThought":
|
||||
return ManualSequentialChainOfThought(
|
||||
|
|
|
@ -3,12 +3,10 @@ from __future__ import annotations
|
|||
from pathlib import Path
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from kotaemon.base import BaseComponent, Document, RetrievedDocument
|
||||
from kotaemon.embeddings import BaseEmbeddings
|
||||
from kotaemon.indices.rankings import BaseReranking
|
||||
|
||||
from ..base import BaseComponent
|
||||
from ..base.schema import Document, RetrievedDocument
|
||||
from ..embeddings import BaseEmbeddings
|
||||
from ..storages import BaseDocumentStore, BaseVectorStore
|
||||
from kotaemon.storages import BaseDocumentStore, BaseVectorStore
|
||||
|
||||
VECTOR_STORE_FNAME = "vectorstore"
|
||||
DOC_STORE_FNAME = "docstore"
|
||||
|
@ -45,7 +43,7 @@ class RetrieveDocumentFromVectorStorePipeline(BaseComponent):
|
|||
"retrieve the documents"
|
||||
)
|
||||
|
||||
emb: list[float] = self.embedding(text)[0]
|
||||
emb: list[float] = self.embedding(text)[0].embedding
|
||||
_, scores, ids = self.vector_store.query(embedding=emb, top_k=top_k)
|
||||
docs = self.doc_store.get(ids)
|
||||
result = [
|
||||
|
|
|
@ -6,7 +6,7 @@ from llama_index.vector_stores.types import BasePydanticVectorStore
|
|||
from llama_index.vector_stores.types import VectorStore as LIVectorStore
|
||||
from llama_index.vector_stores.types import VectorStoreQuery
|
||||
|
||||
from ...base import Document
|
||||
from kotaemon.base import Document, DocumentWithEmbedding
|
||||
|
||||
|
||||
class BaseVectorStore(ABC):
|
||||
|
@ -17,7 +17,7 @@ class BaseVectorStore(ABC):
|
|||
@abstractmethod
|
||||
def add(
|
||||
self,
|
||||
embeddings: List[List[float]],
|
||||
embeddings: List[List[float]] | List[DocumentWithEmbedding],
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
ids: Optional[List[str]] = None,
|
||||
) -> List[str]:
|
||||
|
@ -104,11 +104,16 @@ class LlamaIndexVectorStore(BaseVectorStore):
|
|||
|
||||
def add(
|
||||
self,
|
||||
embeddings: List[List[float]],
|
||||
embeddings: List[List[float]] | List[DocumentWithEmbedding],
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
ids: Optional[List[str]] = None,
|
||||
):
|
||||
nodes = [Document(embedding=embedding) for embedding in embeddings]
|
||||
if isinstance(embeddings[0], list):
|
||||
nodes = [
|
||||
DocumentWithEmbedding(embedding=embedding) for embedding in embeddings
|
||||
]
|
||||
else:
|
||||
nodes = embeddings # type: ignore
|
||||
if metadatas is not None:
|
||||
for node, metadata in zip(nodes, metadatas):
|
||||
node.metadata = metadata
|
||||
|
@ -119,10 +124,10 @@ class LlamaIndexVectorStore(BaseVectorStore):
|
|||
NodeRelationship.SOURCE: RelatedNodeInfo(node_id=id)
|
||||
}
|
||||
|
||||
return self._client.add(nodes=nodes) # type: ignore
|
||||
return self._client.add(nodes=nodes)
|
||||
|
||||
def add_from_docs(self, docs: List[Document]):
|
||||
return self._client.add(nodes=docs) # type: ignore
|
||||
return self._client.add(nodes=docs)
|
||||
|
||||
def delete(self, ids: List[str], **kwargs):
|
||||
for id_ in ids:
|
||||
|
|
|
@ -56,7 +56,7 @@ def test_cot_plus_operator(openai_completion):
|
|||
)
|
||||
thought = thought1 + thought2
|
||||
output = thought(word="hello", language="French")
|
||||
assert output == {
|
||||
assert output.content == {
|
||||
"word": "hello",
|
||||
"language": "French",
|
||||
"translated": "Bonjour",
|
||||
|
@ -86,7 +86,7 @@ def test_cot_manual(openai_completion):
|
|||
)
|
||||
thought = ManualSequentialChainOfThought(thoughts=[thought1, thought2], llm=llm)
|
||||
output = thought(word="hello", language="French")
|
||||
assert output == {
|
||||
assert output.content == {
|
||||
"word": "hello",
|
||||
"language": "French",
|
||||
"translated": "Bonjour",
|
||||
|
@ -120,7 +120,7 @@ def test_cot_with_termination_callback(openai_completion):
|
|||
terminate=lambda d: True if d.get("translated", "") == "Bonjour" else False,
|
||||
)
|
||||
output = thought(word="hallo", language="French")
|
||||
assert output == {
|
||||
assert output.content == {
|
||||
"word": "hallo",
|
||||
"language": "French",
|
||||
"translated": "Bonjour",
|
||||
|
|
|
@ -2,6 +2,7 @@ import json
|
|||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from kotaemon.base import Document
|
||||
from kotaemon.embeddings.cohere import CohereEmbdeddings
|
||||
from kotaemon.embeddings.huggingface import HuggingFaceEmbeddings
|
||||
from kotaemon.embeddings.openai import AzureOpenAIEmbeddings
|
||||
|
@ -26,8 +27,9 @@ def test_azureopenai_embeddings_raw(openai_embedding_call):
|
|||
)
|
||||
output = model("Hello world")
|
||||
assert isinstance(output, list)
|
||||
assert isinstance(output[0], list)
|
||||
assert isinstance(output[0][0], float)
|
||||
assert isinstance(output[0], Document)
|
||||
assert isinstance(output[0].embedding, list)
|
||||
assert isinstance(output[0].embedding[0], float)
|
||||
openai_embedding_call.assert_called()
|
||||
|
||||
|
||||
|
@ -44,8 +46,9 @@ def test_azureopenai_embeddings_batch_raw(openai_embedding_call):
|
|||
)
|
||||
output = model(["Hello world", "Goodbye world"])
|
||||
assert isinstance(output, list)
|
||||
assert isinstance(output[0], list)
|
||||
assert isinstance(output[0][0], float)
|
||||
assert isinstance(output[0], Document)
|
||||
assert isinstance(output[0].embedding, list)
|
||||
assert isinstance(output[0].embedding[0], float)
|
||||
openai_embedding_call.assert_called()
|
||||
|
||||
|
||||
|
@ -68,8 +71,9 @@ def test_huggingface_embddings(
|
|||
|
||||
output = model("Hello World")
|
||||
assert isinstance(output, list)
|
||||
assert isinstance(output[0], list)
|
||||
assert isinstance(output[0][0], float)
|
||||
assert isinstance(output[0], Document)
|
||||
assert isinstance(output[0].embedding, list)
|
||||
assert isinstance(output[0].embedding[0], float)
|
||||
sentence_transformers_init.assert_called()
|
||||
langchain_huggingface_embedding_call.assert_called()
|
||||
|
||||
|
@ -85,6 +89,7 @@ def test_cohere_embeddings(langchain_cohere_embedding_call):
|
|||
|
||||
output = model("Hello World")
|
||||
assert isinstance(output, list)
|
||||
assert isinstance(output[0], list)
|
||||
assert isinstance(output[0][0], float)
|
||||
assert isinstance(output[0], Document)
|
||||
assert isinstance(output[0].embedding, list)
|
||||
assert isinstance(output[0].embedding[0], float)
|
||||
langchain_cohere_embedding_call.assert_called()
|
||||
|
|
Loading…
Reference in New Issue
Block a user