Enforce all IO objects to be subclassed from Document (#88)
* enforce Document as IO * Separate rerankers, splitters and extractors (#85) * partially refractor importing * add text to embedding outputs --------- Co-authored-by: Nguyen Trung Duc (john) <trungduc1992@gmail.com>
This commit is contained in:
parent
2186c5558f
commit
8e0779a22d
|
@ -1,4 +1,25 @@
|
||||||
from .component import BaseComponent
|
from .component import BaseComponent
|
||||||
from .schema import Document
|
from .schema import (
|
||||||
|
AIMessage,
|
||||||
|
BaseMessage,
|
||||||
|
Document,
|
||||||
|
DocumentWithEmbedding,
|
||||||
|
ExtractorOutput,
|
||||||
|
HumanMessage,
|
||||||
|
LLMInterface,
|
||||||
|
RetrievedDocument,
|
||||||
|
SystemMessage,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = ["BaseComponent", "Document"]
|
__all__ = [
|
||||||
|
"BaseComponent",
|
||||||
|
"Document",
|
||||||
|
"DocumentWithEmbedding",
|
||||||
|
"BaseMessage",
|
||||||
|
"SystemMessage",
|
||||||
|
"AIMessage",
|
||||||
|
"HumanMessage",
|
||||||
|
"RetrievedDocument",
|
||||||
|
"LLMInterface",
|
||||||
|
"ExtractorOutput",
|
||||||
|
]
|
||||||
|
|
|
@ -2,6 +2,8 @@ from abc import abstractmethod
|
||||||
|
|
||||||
from theflow.base import Function
|
from theflow.base import Function
|
||||||
|
|
||||||
|
from kotaemon.base.schema import Document
|
||||||
|
|
||||||
|
|
||||||
class BaseComponent(Function):
|
class BaseComponent(Function):
|
||||||
"""A component is a class that can be used to compose a pipeline
|
"""A component is a class that can be used to compose a pipeline
|
||||||
|
@ -30,7 +32,6 @@ class BaseComponent(Function):
|
||||||
return self.__call__(self.inflow.flow())
|
return self.__call__(self.inflow.flow())
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def run(self, *args, **kwargs):
|
def run(self, *args, **kwargs) -> Document | list[Document] | None:
|
||||||
# enforce output type to be compatible with Document
|
|
||||||
"""Run the component."""
|
"""Run the component."""
|
||||||
...
|
...
|
||||||
|
|
|
@ -32,6 +32,8 @@ class Document(BaseDocument):
|
||||||
kwargs["content"] = kwargs["text"]
|
kwargs["content"] = kwargs["text"]
|
||||||
elif kwargs.get("embedding", None) is not None:
|
elif kwargs.get("embedding", None) is not None:
|
||||||
kwargs["content"] = kwargs["embedding"]
|
kwargs["content"] = kwargs["embedding"]
|
||||||
|
# default text indicating this document only contains embedding
|
||||||
|
kwargs["text"] = "<EMBEDDING>"
|
||||||
elif isinstance(content, Document):
|
elif isinstance(content, Document):
|
||||||
kwargs = content.dict()
|
kwargs = content.dict()
|
||||||
else:
|
else:
|
||||||
|
@ -65,6 +67,17 @@ class Document(BaseDocument):
|
||||||
return str(self.content)
|
return str(self.content)
|
||||||
|
|
||||||
|
|
||||||
|
class DocumentWithEmbedding(Document):
|
||||||
|
"""Subclass of Document which must contains embedding
|
||||||
|
|
||||||
|
Use this if you want to enforce component's IOs to must contain embedding.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, embedding: list[float], *args, **kwargs):
|
||||||
|
kwargs["embedding"] = embedding
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class BaseMessage(Document):
|
class BaseMessage(Document):
|
||||||
def __add__(self, other: Any):
|
def __add__(self, other: Any):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
|
@ -3,11 +3,8 @@ from typing import List, Optional
|
||||||
|
|
||||||
from theflow import SessionFunction
|
from theflow import SessionFunction
|
||||||
|
|
||||||
from kotaemon.base.schema import AIMessage, SystemMessage
|
from kotaemon.base import BaseComponent, LLMInterface
|
||||||
|
from kotaemon.base.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||||
from ..base import BaseComponent
|
|
||||||
from ..base.schema import LLMInterface
|
|
||||||
from ..llms.chats.base import BaseMessage, HumanMessage
|
|
||||||
|
|
||||||
|
|
||||||
class BaseChatBot(BaseComponent):
|
class BaseChatBot(BaseComponent):
|
||||||
|
|
|
@ -5,8 +5,9 @@ from typing import Any, Dict, Optional, Type, Union
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from ...base import BaseComponent
|
from kotaemon.base import BaseComponent
|
||||||
from ...chatbot import BaseChatBot
|
from kotaemon.chatbot import BaseChatBot
|
||||||
|
|
||||||
from .base import DEFAULT_COMPONENT_BY_TYPES
|
from .base import DEFAULT_COMPONENT_BY_TYPES
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -6,14 +6,14 @@ from typing import Type
|
||||||
from langchain.schema.embeddings import Embeddings as LCEmbeddings
|
from langchain.schema.embeddings import Embeddings as LCEmbeddings
|
||||||
from theflow import Param
|
from theflow import Param
|
||||||
|
|
||||||
from ..base import BaseComponent, Document
|
from kotaemon.base import BaseComponent, Document, DocumentWithEmbedding
|
||||||
|
|
||||||
|
|
||||||
class BaseEmbeddings(BaseComponent):
|
class BaseEmbeddings(BaseComponent):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def run(
|
def run(
|
||||||
self, text: str | list[str] | Document | list[Document]
|
self, text: str | list[str] | Document | list[Document]
|
||||||
) -> list[list[float]]:
|
) -> list[DocumentWithEmbedding]:
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@ -43,7 +43,7 @@ class LangchainEmbeddings(BaseEmbeddings):
|
||||||
def agent(self):
|
def agent(self):
|
||||||
return self._lc_class(**self._kwargs)
|
return self._lc_class(**self._kwargs)
|
||||||
|
|
||||||
def run(self, text) -> list[list[float]]:
|
def run(self, text):
|
||||||
input_: list[str] = []
|
input_: list[str] = []
|
||||||
if not isinstance(text, list):
|
if not isinstance(text, list):
|
||||||
text = [text]
|
text = [text]
|
||||||
|
@ -58,4 +58,9 @@ class LangchainEmbeddings(BaseEmbeddings):
|
||||||
f"Invalid input type {type(item)}, should be str or Document"
|
f"Invalid input type {type(item)}, should be str or Document"
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.agent.embed_documents(input_)
|
embeddings = self.agent.embed_documents(input_)
|
||||||
|
|
||||||
|
return [
|
||||||
|
DocumentWithEmbedding(text=each_text, embedding=each_embedding)
|
||||||
|
for each_text, each_embedding in zip(input_, embeddings)
|
||||||
|
]
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from typing import Any, Sequence, Type
|
from typing import Any, Type
|
||||||
|
|
||||||
from llama_index.node_parser.interface import NodeParser
|
from llama_index.node_parser.interface import NodeParser
|
||||||
|
|
||||||
|
@ -20,9 +20,9 @@ class DocTransformer(BaseComponent):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
documents: Sequence[Document],
|
documents: list[Document],
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Sequence[Document]:
|
) -> list[Document]:
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@ -62,9 +62,9 @@ class LlamaIndexMixin:
|
||||||
|
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
documents: Sequence[Document],
|
documents: list[Document],
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Sequence[Document]:
|
) -> list[Document]:
|
||||||
"""Run Llama-index node parser and convert the output to Document from
|
"""Run Llama-index node parser and convert the output to Document from
|
||||||
kotaemon
|
kotaemon
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -6,10 +6,7 @@ from typing import Type
|
||||||
from langchain.chat_models.base import BaseChatModel
|
from langchain.chat_models.base import BaseChatModel
|
||||||
from theflow.base import Param
|
from theflow.base import Param
|
||||||
|
|
||||||
from kotaemon.base.schema import BaseMessage, HumanMessage
|
from kotaemon.base import BaseComponent, BaseMessage, HumanMessage, LLMInterface
|
||||||
|
|
||||||
from ...base import BaseComponent
|
|
||||||
from ...base.schema import LLMInterface
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
|
@ -3,7 +3,7 @@ from typing import Callable, List
|
||||||
|
|
||||||
from theflow import Function, Node, Param
|
from theflow import Function, Node, Param
|
||||||
|
|
||||||
from kotaemon.base import BaseComponent
|
from kotaemon.base import BaseComponent, Document
|
||||||
from kotaemon.llms import LLM, BasePromptComponent
|
from kotaemon.llms import LLM, BasePromptComponent
|
||||||
from kotaemon.llms.chats.openai import AzureChatOpenAI
|
from kotaemon.llms.chats.openai import AzureChatOpenAI
|
||||||
|
|
||||||
|
@ -65,15 +65,19 @@ class Thought(BaseComponent):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
prompt: str = Param(
|
prompt: str = Param(
|
||||||
help="The prompt template string. This prompt template has Python-like "
|
help=(
|
||||||
"variable placeholders, that then will be subsituted with real values when "
|
"The prompt template string. This prompt template has Python-like "
|
||||||
"this component is executed"
|
"variable placeholders, that then will be subsituted with real values when "
|
||||||
|
"this component is executed"
|
||||||
|
)
|
||||||
)
|
)
|
||||||
llm: LLM = Node(AzureChatOpenAI, help="The LLM model to execute the input prompt")
|
llm: LLM = Node(AzureChatOpenAI, help="The LLM model to execute the input prompt")
|
||||||
post_process: Function = Node(
|
post_process: Function = Node(
|
||||||
help="The function post-processor that post-processes LLM output prediction ."
|
help=(
|
||||||
"It should take a string as input (this is the LLM output text) and return "
|
"The function post-processor that post-processes LLM output prediction ."
|
||||||
"a dictionary, where the key should"
|
"It should take a string as input (this is the LLM output text) and return "
|
||||||
|
"a dictionary, where the key should"
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@Node.auto(depends_on="prompt")
|
@Node.auto(depends_on="prompt")
|
||||||
|
@ -81,11 +85,13 @@ class Thought(BaseComponent):
|
||||||
"""Automatically wrap around param prompt. Can ignore"""
|
"""Automatically wrap around param prompt. Can ignore"""
|
||||||
return BasePromptComponent(self.prompt)
|
return BasePromptComponent(self.prompt)
|
||||||
|
|
||||||
def run(self, **kwargs) -> dict:
|
def run(self, **kwargs) -> Document:
|
||||||
"""Run the chain of thought"""
|
"""Run the chain of thought"""
|
||||||
prompt = self.prompt_template(**kwargs).text
|
prompt = self.prompt_template(**kwargs).text
|
||||||
response = self.llm(prompt).text
|
response = self.llm(prompt).text
|
||||||
return self.post_process(response)
|
response = self.post_process(response)
|
||||||
|
|
||||||
|
return Document(response)
|
||||||
|
|
||||||
def get_variables(self) -> List[str]:
|
def get_variables(self) -> List[str]:
|
||||||
return []
|
return []
|
||||||
|
@ -146,7 +152,7 @@ class ManualSequentialChainOfThought(BaseComponent):
|
||||||
help="Callback on terminate condition. Default to always return False",
|
help="Callback on terminate condition. Default to always return False",
|
||||||
)
|
)
|
||||||
|
|
||||||
def run(self, **kwargs) -> dict:
|
def run(self, **kwargs) -> Document:
|
||||||
"""Run the manual chain of thought"""
|
"""Run the manual chain of thought"""
|
||||||
|
|
||||||
inputs = deepcopy(kwargs)
|
inputs = deepcopy(kwargs)
|
||||||
|
@ -156,11 +162,11 @@ class ManualSequentialChainOfThought(BaseComponent):
|
||||||
self._prepare_child(thought, f"thought{idx}")
|
self._prepare_child(thought, f"thought{idx}")
|
||||||
|
|
||||||
output = thought(**inputs)
|
output = thought(**inputs)
|
||||||
inputs.update(output)
|
inputs.update(output.content)
|
||||||
if self.terminate(inputs):
|
if self.terminate(inputs):
|
||||||
break
|
break
|
||||||
|
|
||||||
return inputs
|
return Document(inputs)
|
||||||
|
|
||||||
def __add__(self, next_thought: Thought) -> "ManualSequentialChainOfThought":
|
def __add__(self, next_thought: Thought) -> "ManualSequentialChainOfThought":
|
||||||
return ManualSequentialChainOfThought(
|
return ManualSequentialChainOfThought(
|
||||||
|
|
|
@ -3,12 +3,10 @@ from __future__ import annotations
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Sequence
|
from typing import Optional, Sequence
|
||||||
|
|
||||||
|
from kotaemon.base import BaseComponent, Document, RetrievedDocument
|
||||||
|
from kotaemon.embeddings import BaseEmbeddings
|
||||||
from kotaemon.indices.rankings import BaseReranking
|
from kotaemon.indices.rankings import BaseReranking
|
||||||
|
from kotaemon.storages import BaseDocumentStore, BaseVectorStore
|
||||||
from ..base import BaseComponent
|
|
||||||
from ..base.schema import Document, RetrievedDocument
|
|
||||||
from ..embeddings import BaseEmbeddings
|
|
||||||
from ..storages import BaseDocumentStore, BaseVectorStore
|
|
||||||
|
|
||||||
VECTOR_STORE_FNAME = "vectorstore"
|
VECTOR_STORE_FNAME = "vectorstore"
|
||||||
DOC_STORE_FNAME = "docstore"
|
DOC_STORE_FNAME = "docstore"
|
||||||
|
@ -45,7 +43,7 @@ class RetrieveDocumentFromVectorStorePipeline(BaseComponent):
|
||||||
"retrieve the documents"
|
"retrieve the documents"
|
||||||
)
|
)
|
||||||
|
|
||||||
emb: list[float] = self.embedding(text)[0]
|
emb: list[float] = self.embedding(text)[0].embedding
|
||||||
_, scores, ids = self.vector_store.query(embedding=emb, top_k=top_k)
|
_, scores, ids = self.vector_store.query(embedding=emb, top_k=top_k)
|
||||||
docs = self.doc_store.get(ids)
|
docs = self.doc_store.get(ids)
|
||||||
result = [
|
result = [
|
||||||
|
|
|
@ -6,7 +6,7 @@ from llama_index.vector_stores.types import BasePydanticVectorStore
|
||||||
from llama_index.vector_stores.types import VectorStore as LIVectorStore
|
from llama_index.vector_stores.types import VectorStore as LIVectorStore
|
||||||
from llama_index.vector_stores.types import VectorStoreQuery
|
from llama_index.vector_stores.types import VectorStoreQuery
|
||||||
|
|
||||||
from ...base import Document
|
from kotaemon.base import Document, DocumentWithEmbedding
|
||||||
|
|
||||||
|
|
||||||
class BaseVectorStore(ABC):
|
class BaseVectorStore(ABC):
|
||||||
|
@ -17,7 +17,7 @@ class BaseVectorStore(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def add(
|
def add(
|
||||||
self,
|
self,
|
||||||
embeddings: List[List[float]],
|
embeddings: List[List[float]] | List[DocumentWithEmbedding],
|
||||||
metadatas: Optional[List[dict]] = None,
|
metadatas: Optional[List[dict]] = None,
|
||||||
ids: Optional[List[str]] = None,
|
ids: Optional[List[str]] = None,
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
|
@ -104,11 +104,16 @@ class LlamaIndexVectorStore(BaseVectorStore):
|
||||||
|
|
||||||
def add(
|
def add(
|
||||||
self,
|
self,
|
||||||
embeddings: List[List[float]],
|
embeddings: List[List[float]] | List[DocumentWithEmbedding],
|
||||||
metadatas: Optional[List[dict]] = None,
|
metadatas: Optional[List[dict]] = None,
|
||||||
ids: Optional[List[str]] = None,
|
ids: Optional[List[str]] = None,
|
||||||
):
|
):
|
||||||
nodes = [Document(embedding=embedding) for embedding in embeddings]
|
if isinstance(embeddings[0], list):
|
||||||
|
nodes = [
|
||||||
|
DocumentWithEmbedding(embedding=embedding) for embedding in embeddings
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
nodes = embeddings # type: ignore
|
||||||
if metadatas is not None:
|
if metadatas is not None:
|
||||||
for node, metadata in zip(nodes, metadatas):
|
for node, metadata in zip(nodes, metadatas):
|
||||||
node.metadata = metadata
|
node.metadata = metadata
|
||||||
|
@ -119,10 +124,10 @@ class LlamaIndexVectorStore(BaseVectorStore):
|
||||||
NodeRelationship.SOURCE: RelatedNodeInfo(node_id=id)
|
NodeRelationship.SOURCE: RelatedNodeInfo(node_id=id)
|
||||||
}
|
}
|
||||||
|
|
||||||
return self._client.add(nodes=nodes) # type: ignore
|
return self._client.add(nodes=nodes)
|
||||||
|
|
||||||
def add_from_docs(self, docs: List[Document]):
|
def add_from_docs(self, docs: List[Document]):
|
||||||
return self._client.add(nodes=docs) # type: ignore
|
return self._client.add(nodes=docs)
|
||||||
|
|
||||||
def delete(self, ids: List[str], **kwargs):
|
def delete(self, ids: List[str], **kwargs):
|
||||||
for id_ in ids:
|
for id_ in ids:
|
||||||
|
|
|
@ -56,7 +56,7 @@ def test_cot_plus_operator(openai_completion):
|
||||||
)
|
)
|
||||||
thought = thought1 + thought2
|
thought = thought1 + thought2
|
||||||
output = thought(word="hello", language="French")
|
output = thought(word="hello", language="French")
|
||||||
assert output == {
|
assert output.content == {
|
||||||
"word": "hello",
|
"word": "hello",
|
||||||
"language": "French",
|
"language": "French",
|
||||||
"translated": "Bonjour",
|
"translated": "Bonjour",
|
||||||
|
@ -86,7 +86,7 @@ def test_cot_manual(openai_completion):
|
||||||
)
|
)
|
||||||
thought = ManualSequentialChainOfThought(thoughts=[thought1, thought2], llm=llm)
|
thought = ManualSequentialChainOfThought(thoughts=[thought1, thought2], llm=llm)
|
||||||
output = thought(word="hello", language="French")
|
output = thought(word="hello", language="French")
|
||||||
assert output == {
|
assert output.content == {
|
||||||
"word": "hello",
|
"word": "hello",
|
||||||
"language": "French",
|
"language": "French",
|
||||||
"translated": "Bonjour",
|
"translated": "Bonjour",
|
||||||
|
@ -120,7 +120,7 @@ def test_cot_with_termination_callback(openai_completion):
|
||||||
terminate=lambda d: True if d.get("translated", "") == "Bonjour" else False,
|
terminate=lambda d: True if d.get("translated", "") == "Bonjour" else False,
|
||||||
)
|
)
|
||||||
output = thought(word="hallo", language="French")
|
output = thought(word="hallo", language="French")
|
||||||
assert output == {
|
assert output.content == {
|
||||||
"word": "hallo",
|
"word": "hallo",
|
||||||
"language": "French",
|
"language": "French",
|
||||||
"translated": "Bonjour",
|
"translated": "Bonjour",
|
||||||
|
|
|
@ -2,6 +2,7 @@ import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from kotaemon.base import Document
|
||||||
from kotaemon.embeddings.cohere import CohereEmbdeddings
|
from kotaemon.embeddings.cohere import CohereEmbdeddings
|
||||||
from kotaemon.embeddings.huggingface import HuggingFaceEmbeddings
|
from kotaemon.embeddings.huggingface import HuggingFaceEmbeddings
|
||||||
from kotaemon.embeddings.openai import AzureOpenAIEmbeddings
|
from kotaemon.embeddings.openai import AzureOpenAIEmbeddings
|
||||||
|
@ -26,8 +27,9 @@ def test_azureopenai_embeddings_raw(openai_embedding_call):
|
||||||
)
|
)
|
||||||
output = model("Hello world")
|
output = model("Hello world")
|
||||||
assert isinstance(output, list)
|
assert isinstance(output, list)
|
||||||
assert isinstance(output[0], list)
|
assert isinstance(output[0], Document)
|
||||||
assert isinstance(output[0][0], float)
|
assert isinstance(output[0].embedding, list)
|
||||||
|
assert isinstance(output[0].embedding[0], float)
|
||||||
openai_embedding_call.assert_called()
|
openai_embedding_call.assert_called()
|
||||||
|
|
||||||
|
|
||||||
|
@ -44,8 +46,9 @@ def test_azureopenai_embeddings_batch_raw(openai_embedding_call):
|
||||||
)
|
)
|
||||||
output = model(["Hello world", "Goodbye world"])
|
output = model(["Hello world", "Goodbye world"])
|
||||||
assert isinstance(output, list)
|
assert isinstance(output, list)
|
||||||
assert isinstance(output[0], list)
|
assert isinstance(output[0], Document)
|
||||||
assert isinstance(output[0][0], float)
|
assert isinstance(output[0].embedding, list)
|
||||||
|
assert isinstance(output[0].embedding[0], float)
|
||||||
openai_embedding_call.assert_called()
|
openai_embedding_call.assert_called()
|
||||||
|
|
||||||
|
|
||||||
|
@ -68,8 +71,9 @@ def test_huggingface_embddings(
|
||||||
|
|
||||||
output = model("Hello World")
|
output = model("Hello World")
|
||||||
assert isinstance(output, list)
|
assert isinstance(output, list)
|
||||||
assert isinstance(output[0], list)
|
assert isinstance(output[0], Document)
|
||||||
assert isinstance(output[0][0], float)
|
assert isinstance(output[0].embedding, list)
|
||||||
|
assert isinstance(output[0].embedding[0], float)
|
||||||
sentence_transformers_init.assert_called()
|
sentence_transformers_init.assert_called()
|
||||||
langchain_huggingface_embedding_call.assert_called()
|
langchain_huggingface_embedding_call.assert_called()
|
||||||
|
|
||||||
|
@ -85,6 +89,7 @@ def test_cohere_embeddings(langchain_cohere_embedding_call):
|
||||||
|
|
||||||
output = model("Hello World")
|
output = model("Hello World")
|
||||||
assert isinstance(output, list)
|
assert isinstance(output, list)
|
||||||
assert isinstance(output[0], list)
|
assert isinstance(output[0], Document)
|
||||||
assert isinstance(output[0][0], float)
|
assert isinstance(output[0].embedding, list)
|
||||||
|
assert isinstance(output[0].embedding[0], float)
|
||||||
langchain_cohere_embedding_call.assert_called()
|
langchain_cohere_embedding_call.assert_called()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user