Simplify the BaseComponent interface (#64)

This change removes the following methods from `BaseComponent`:

- `run_raw`
- `run_batch_raw`
- `run_document`
- `run_batch_document`
- `is_document`
- `is_batch`

Each component is expected to accept multiple input types but produce a single output type. Accepting multiple input types lets a component work out of the box in both standardized and customized use cases, while restricting each component to a single output type keeps it easy to understand and compose.
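
To make the contract concrete, here is a minimal sketch of a component under the new interface (the `UpperCase` component is hypothetical, and the import paths are inferred from the diffs below, not guaranteed):

```python
from kotaemon.base import BaseComponent  # path inferred from this PR's new base/__init__.py
from kotaemon.documents.base import Document  # path inferred from the relative imports below


class UpperCase(BaseComponent):
    """Hypothetical component: tolerant about inputs, strict about output."""

    def run(
        self, text: str | Document | list[str] | list[Document]
    ) -> list[Document]:
        # Tolerate every accepted input shape by normalizing to list[str]...
        items = text if isinstance(text, list) else [text]
        texts = [t.text if isinstance(t, Document) else t for t in items]
        # ...and always return the single, most generic output type.
        return [Document(text=t.upper()) for t in texts]
```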

To accommodate these changes, we also refactor the affected components to remove their `run_raw`, `run_batch_raw`, etc. methods, and to settle on a common output interface for each of them.
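
In practice, callers no longer pick between `run_raw`, `run_document`, and friends; they call the component once and always get a list back. A sketch mirroring the updated tests (import paths inferred from the test files in this diff):

```python
from kotaemon.documents.base import Document
from kotaemon.post_processing.extractor import RegexExtractor

extractor = RegexExtractor(pattern=r"\d+")

# One entry point tolerates str, Document, or lists of either,
# and the output is uniformly list[ExtractorOutput]:
single = extractor("This is a test. 123")[0]
assert single.matches == ["123"]

batch = extractor([Document(text="price: 42"), "year 2023"])
assert [each.matches for each in batch] == [["42"], ["2023"]]
```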

Tests are updated accordingly.

Commit changes:

* Add kwargs to vector store's `query` (see the sketch after this list)
* Simplify the BaseComponent
* Update tests
* Remove support for Python 3.8 and 3.9
* Bump version 0.3.0
* Fix GitHub PR caching still using the old environment after the version bump
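
Regarding the first item above: `BaseVectorStore.query` now accepts `**kwargs` (see the corresponding diff below), so concrete stores can take backend-specific options without widening the shared signature. A hypothetical sketch, not code from this PR (`InMemoryVectorStore` and its `filter` option are illustrative only):

```python
from typing import List, Optional, Tuple


class InMemoryVectorStore:
    """Hypothetical backend showing how **kwargs flows through query()."""

    def query(
        self,
        embedding: List[float],
        top_k: int = 1,
        ids: Optional[List[str]] = None,
        **kwargs,
    ) -> Tuple[List[List[float]], List[float], List[str]]:
        # Backend-specific options (e.g. a metadata filter) travel through
        # **kwargs, so the shared base signature stays unchanged.
        metadata_filter = kwargs.get("filter", {})
        print(f"top_k={top_k}, filter={metadata_filter}")  # stand-in for a real search
        return [], [], []  # embeddings, scores, ids


# Callers can pass extra options straight through:
store = InMemoryVectorStore()
store.query(embedding=[0.1, 0.2], top_k=3, filter={"source": "docs"})
```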

---------

Co-authored-by: ian <ian@cinnamon.is>
Nguyen Trung Duc (john) 2023-11-13 15:10:18 +07:00 committed by GitHub
parent 6095526dc7
commit d79b3744cb
25 changed files with 280 additions and 458 deletions

View File

@@ -16,7 +16,7 @@ jobs:
         shell: ${{ matrix.shell }}
     strategy:
       matrix:
-        python-version: ["3.8", "3.9", "3.10", "3.11"]
+        python-version: ["3.10", "3.11"]
         include:
           - os: ubuntu-latest
             shell: bash
@@ -81,7 +81,7 @@ jobs:
           steps.check-cache-hit.outputs.check != 'true'
         run: |
           python -m pip install --upgrade pip
-          pip install -e .[dev]
+          pip install --ignore-installed -e .[dev]
       - name: New dependencies cache for key ${{ steps.restore-dependencies.outputs.cache-primary-key }}
         if: |

View File

@@ -22,4 +22,4 @@ try:
 except ImportError:
     pass

-__version__ = "0.2.0"
+__version__ = "0.3.0"

View File

@@ -1,70 +0,0 @@
-from abc import abstractmethod
-
-from theflow.base import Compose
-
-
-class BaseComponent(Compose):
-    """Base class for component
-
-    A component is a class that can be used to compose a pipeline. To use the
-    component, you should implement the following methods:
-        - run_raw: run on raw input
-        - run_batch_raw: run on batch of raw input
-        - run_document: run on document
-        - run_batch_document: run on batch of documents
-        - is_document: check if input is document
-        - is_batch: check if input is batch
-    """
-
-    inflow = None
-
-    def flow(self):
-        if self.inflow is None:
-            raise ValueError("No inflow provided.")
-        if not isinstance(self.inflow, BaseComponent):
-            raise ValueError(
-                f"inflow must be a BaseComponent, found {type(self.inflow)}"
-            )
-        return self.__call__(self.inflow.flow())
-
-    @abstractmethod
-    def run_raw(self, *args, **kwargs):
-        ...
-
-    @abstractmethod
-    def run_batch_raw(self, *args, **kwargs):
-        ...
-
-    @abstractmethod
-    def run_document(self, *args, **kwargs):
-        ...
-
-    @abstractmethod
-    def run_batch_document(self, *args, **kwargs):
-        ...
-
-    @abstractmethod
-    def is_document(self, *args, **kwargs) -> bool:
-        ...
-
-    @abstractmethod
-    def is_batch(self, *args, **kwargs) -> bool:
-        ...
-
-    def run(self, *args, **kwargs):
-        """Run the component."""
-        is_document = self.is_document(*args, **kwargs)
-        is_batch = self.is_batch(*args, **kwargs)
-
-        if is_document and is_batch:
-            return self.run_batch_document(*args, **kwargs)
-        elif is_document and not is_batch:
-            return self.run_document(*args, **kwargs)
-        elif not is_document and is_batch:
-            return self.run_batch_raw(*args, **kwargs)
-        else:
-            return self.run_raw(*args, **kwargs)

View File

@@ -0,0 +1,3 @@
+from .component import BaseComponent
+
+__all__ = ["BaseComponent"]

View File

@@ -0,0 +1,35 @@
+from abc import abstractmethod
+
+from theflow.base import Compose
+
+
+class BaseComponent(Compose):
+    """A component is a class that can be used to compose a pipeline
+
+    Benefits of component:
+        - Auto caching, logging
+        - Allow deployment
+
+    For each component, the spirit is:
+        - Tolerate multiple input types, e.g. str, Document, List[str],
+          List[Document]
+        - Enforce single output type. Hence, the output type of a component
+          should be as generic as possible.
+    """
+
+    inflow = None
+
+    def flow(self):
+        if self.inflow is None:
+            raise ValueError("No inflow provided.")
+        if not isinstance(self.inflow, BaseComponent):
+            raise ValueError(
+                f"inflow must be a BaseComponent, found {type(self.inflow)}"
+            )
+        return self.__call__(self.inflow.flow())
+
+    @abstractmethod
+    def run(self, *args, **kwargs):
+        """Run the component."""
+        ...

View File

@@ -70,7 +70,7 @@ class SimpleLinearPipeline(BaseComponent):
         prompt = self.prompt(**prompt_kwargs)
         llm_output = self.llm(prompt.text, **llm_kwargs)
         if self.post_processor is not None:
-            final_output = self.post_processor(llm_output, **post_processor_kwargs)
+            final_output = self.post_processor(llm_output, **post_processor_kwargs)[0]
         else:
             final_output = llm_output
@@ -143,7 +143,7 @@ class GatedLinearPipeline(SimpleLinearPipeline):
         if condition_text is None:
             raise ValueError("`condition_text` must be provided")
-        if self.condition(condition_text):
+        if self.condition(condition_text)[0]:
             return super().run(
                 llm_kwargs=llm_kwargs,
                 post_processor_kwargs=post_processor_kwargs,

View File

View File

@@ -1,5 +1,7 @@
+from __future__ import annotations
+
 from abc import abstractmethod
-from typing import List, Type
+from typing import Type

 from langchain.schema.embeddings import Embeddings as LCEmbeddings
 from theflow import Param
@@ -10,33 +12,11 @@ from ..documents.base import Document

 class BaseEmbeddings(BaseComponent):
     @abstractmethod
-    def run_raw(self, text: str) -> List[float]:
-        ...
-
-    @abstractmethod
-    def run_batch_raw(self, text: List[str]) -> List[List[float]]:
-        ...
-
-    @abstractmethod
-    def run_document(self, text: Document) -> List[float]:
-        ...
-
-    @abstractmethod
-    def run_batch_document(self, text: List[Document]) -> List[List[float]]:
-        ...
-
-    def is_document(self, text) -> bool:
-        if isinstance(text, Document):
-            return True
-        elif isinstance(text, List) and isinstance(text[0], Document):
-            return True
-        return False
-
-    def is_batch(self, text) -> bool:
-        if isinstance(text, list):
-            return True
-        return False
+    def run(
+        self, text: str | list[str] | Document | list[Document]
+    ) -> list[list[float]]:
+        ...


 class LangchainEmbeddings(BaseEmbeddings):
     _lc_class: Type[LCEmbeddings]
@@ -64,14 +44,19 @@ class LangchainEmbeddings(BaseEmbeddings):
     def agent(self):
         return self._lc_class(**self._kwargs)

-    def run_raw(self, text: str) -> List[float]:
-        return self.agent.embed_query(text)  # type: ignore
-
-    def run_batch_raw(self, text: List[str]) -> List[List[float]]:
-        return self.agent.embed_documents(text)  # type: ignore
-
-    def run_document(self, text: Document) -> List[float]:
-        return self.agent.embed_query(text.text)  # type: ignore
-
-    def run_batch_document(self, text: List[Document]) -> List[List[float]]:
-        return self.agent.embed_documents([each.text for each in text])  # type: ignore
+    def run(self, text) -> list[list[float]]:
+        input_: list[str] = []
+        if not isinstance(text, list):
+            text = [text]
+
+        for item in text:
+            if isinstance(item, str):
+                input_.append(item)
+            elif isinstance(item, Document):
+                input_.append(item.text)
+            else:
+                raise ValueError(
+                    f"Invalid input type {type(item)}, should be str or Document"
+                )
+
+        return self.agent.embed_documents(input_)

View File

@@ -1,13 +1,16 @@
-from typing import List, Type, TypeVar
+from __future__ import annotations
+
+import logging
+from typing import Type

-from langchain.schema.language_model import BaseLanguageModel
+from langchain.chat_models.base import BaseChatModel
 from langchain.schema.messages import BaseMessage, HumanMessage
 from theflow.base import Param

 from ...base import BaseComponent
 from ..base import LLMInterface

-Message = TypeVar("Message", bound=BaseMessage)
+logger = logging.getLogger(__name__)


 class ChatLLM(BaseComponent):
@@ -25,7 +28,7 @@ class ChatLLM(BaseComponent):

 class LangchainChatLLM(ChatLLM):
-    _lc_class: Type[BaseLanguageModel]
+    _lc_class: Type[BaseChatModel]

     def __init__(self, **params):
         if self._lc_class is None:
@@ -41,60 +44,62 @@ class LangchainChatLLM(ChatLLM):
         super().__init__(**params)

     @Param.auto(cache=False)
-    def agent(self) -> BaseLanguageModel:
+    def agent(self) -> BaseChatModel:
         return self._lc_class(**self._kwargs)

-    def run_raw(self, text: str, **kwargs) -> LLMInterface:
-        message = HumanMessage(content=text)
-        return self.run_document([message], **kwargs)
-
-    def run_batch_raw(self, text: List[str], **kwargs) -> List[LLMInterface]:
-        inputs = [[HumanMessage(content=each)] for each in text]
-        return self.run_batch_document(inputs, **kwargs)
-
-    def run_document(self, text: List[Message], **kwargs) -> LLMInterface:
-        pred = self.agent.generate([text], **kwargs)  # type: ignore
+    def run(
+        self, messages: str | BaseMessage | list[BaseMessage], **kwargs
+    ) -> LLMInterface:
+        """Generate response from messages
+
+        Args:
+            messages: history of messages to generate response from
+            **kwargs: additional arguments to pass to the langchain chat model
+
+        Returns:
+            LLMInterface: generated response
+        """
+        input_: list[BaseMessage] = []
+        if isinstance(messages, str):
+            input_ = [HumanMessage(content=messages)]
+        elif isinstance(messages, BaseMessage):
+            input_ = [messages]
+        else:
+            input_ = messages
+
+        pred = self.agent.generate(messages=[input_], **kwargs)
         all_text = [each.text for each in pred.generations[0]]

+        completion_tokens, total_tokens, prompt_tokens = 0, 0, 0
+        try:
+            if pred.llm_output is not None:
+                completion_tokens = pred.llm_output["token_usage"]["completion_tokens"]
+                total_tokens = pred.llm_output["token_usage"]["total_tokens"]
+                prompt_tokens = pred.llm_output["token_usage"]["prompt_tokens"]
+        except Exception:
+            logger.warning(
+                f"Cannot get token usage from LLM output for {self._lc_class.__name__}"
+            )
+
         return LLMInterface(
             text=all_text[0] if len(all_text) > 0 else "",
             candidates=all_text,
-            completion_tokens=pred.llm_output["token_usage"]["completion_tokens"],
-            total_tokens=pred.llm_output["token_usage"]["total_tokens"],
-            prompt_tokens=pred.llm_output["token_usage"]["prompt_tokens"],
+            completion_tokens=completion_tokens,
+            total_tokens=total_tokens,
+            prompt_tokens=prompt_tokens,
             logits=[],
         )

-    def run_batch_document(
-        self, text: List[List[Message]], **kwargs
-    ) -> List[LLMInterface]:
-        outputs = []
-        for each_text in text:
-            outputs.append(self.run_document(each_text, **kwargs))
-        return outputs
-
-    def is_document(self, text, **kwargs) -> bool:
-        if isinstance(text, str):
-            return False
-        elif isinstance(text, List) and isinstance(text[0], str):
-            return False
-        return True
-
-    def is_batch(self, text, **kwargs) -> bool:
-        if isinstance(text, str):
-            return False
-        elif isinstance(text, List):
-            if isinstance(text[0], BaseMessage):
-                return False
-        return True
-
     def __setattr__(self, name, value):
         if name in self._lc_class.__fields__:
+            self._kwargs[name] = value
             setattr(self.agent, name, value)
         else:
             super().__setattr__(name, value)

     def __getattr__(self, name):
         if name in self._lc_class.__fields__:
-            getattr(self.agent, name)
-        else:
-            super().__getattr__(name)
+            return getattr(self.agent, name)
+
+        return super().__getattr__(name)  # type: ignore

View File

@@ -1,18 +1,21 @@
-from typing import List, Type
+import logging
+from typing import Type

-from langchain.schema.language_model import BaseLanguageModel
+from langchain.llms.base import BaseLLM
 from theflow.base import Param

 from ...base import BaseComponent
 from ..base import LLMInterface

+logger = logging.getLogger(__name__)
+

 class LLM(BaseComponent):
     pass


 class LangchainLLM(LLM):
-    _lc_class: Type[BaseLanguageModel]
+    _lc_class: Type[BaseLLM]

     def __init__(self, **params):
         if self._lc_class is None:
@@ -31,38 +34,33 @@ class LangchainLLM(LLM):
     def agent(self):
         return self._lc_class(**self._kwargs)

-    def run_raw(self, text: str) -> LLMInterface:
+    def run(self, text: str) -> LLMInterface:
         pred = self.agent.generate([text])
         all_text = [each.text for each in pred.generations[0]]

+        completion_tokens, total_tokens, prompt_tokens = 0, 0, 0
+        try:
+            if pred.llm_output is not None:
+                completion_tokens = pred.llm_output["token_usage"]["completion_tokens"]
+                total_tokens = pred.llm_output["token_usage"]["total_tokens"]
+                prompt_tokens = pred.llm_output["token_usage"]["prompt_tokens"]
+        except Exception:
+            logger.warning(
+                f"Cannot get token usage from LLM output for {self._lc_class.__name__}"
+            )
+
         return LLMInterface(
             text=all_text[0] if len(all_text) > 0 else "",
             candidates=all_text,
-            completion_tokens=pred.llm_output["token_usage"]["completion_tokens"],
-            total_tokens=pred.llm_output["token_usage"]["total_tokens"],
-            prompt_tokens=pred.llm_output["token_usage"]["prompt_tokens"],
+            completion_tokens=completion_tokens,
+            total_tokens=total_tokens,
+            prompt_tokens=prompt_tokens,
             logits=[],
         )

-    def run_batch_raw(self, text: List[str]) -> List[LLMInterface]:
-        outputs = []
-        for each_text in text:
-            outputs.append(self.run_raw(each_text))
-        return outputs
-
-    def run_document(self, text: str) -> LLMInterface:
-        return self.run_raw(text)
-
-    def run_batch_document(self, text: List[str]) -> List[LLMInterface]:
-        return self.run_batch_raw(text)
-
-    def is_document(self, text) -> bool:
-        return False
-
-    def is_batch(self, text) -> bool:
-        return False if isinstance(text, str) else True
-
     def __setattr__(self, name, value):
         if name in self._lc_class.__fields__:
+            self._kwargs[name] = value
             setattr(self.agent, name, value)
         else:
             super().__setattr__(name, value)

View File

@@ -1,6 +1,7 @@
+from __future__ import annotations
+
 import uuid
 from pathlib import Path
-from typing import List, Union

 from theflow import Node, Param
@@ -26,44 +27,34 @@ class IndexVectorStoreFromDocumentPipeline(BaseComponent):
     vector_store: Param[BaseVectorStore] = Param()
     doc_store: Param[BaseDocumentStore] = Param()
     embedding: Node[BaseEmbeddings] = Node()
     # TODO: refer to llama_index's storage as well

-    def run_raw(self, text: str) -> None:
-        document = Document(text=text, id_=str(uuid.uuid4()))
-        self.run_batch_document([document])
-
-    def run_batch_raw(self, text: List[str]) -> None:
-        documents = [Document(text=t, id_=str(uuid.uuid4())) for t in text]
-        self.run_batch_document(documents)
-
-    def run_document(self, text: Document) -> None:
-        self.run_batch_document([text])
-
-    def run_batch_document(self, text: List[Document]) -> None:
-        embeddings = self.embedding(text)
+    def run(self, text: str | list[str] | Document | list[Document]) -> None:
+        input_: list[Document] = []
+        if not isinstance(text, list):
+            text = [text]
+
+        for item in text:
+            if isinstance(item, str):
+                input_.append(Document(text=item, id_=str(uuid.uuid4())))
+            elif isinstance(item, Document):
+                input_.append(item)
+            else:
+                raise ValueError(
+                    f"Invalid input type {type(item)}, should be str or Document"
+                )
+
+        embeddings = self.embedding(input_)
         self.vector_store.add(
             embeddings=embeddings,
-            ids=[t.id_ for t in text],
+            ids=[t.id_ for t in input_],
         )
         if self.doc_store:
-            self.doc_store.add(text)
+            self.doc_store.add(input_)

-    def is_document(self, text) -> bool:
-        if isinstance(text, Document):
-            return True
-        elif isinstance(text, List) and isinstance(text[0], Document):
-            return True
-        return False
-
-    def is_batch(self, text) -> bool:
-        if isinstance(text, list):
-            return True
-        return False
-
     def save(
         self,
-        path: Union[str, Path],
+        path: str | Path,
         vectorstore_fname: str = VECTOR_STORE_FNAME,
         docstore_fname: str = DOC_STORE_FNAME,
     ):
@@ -80,7 +71,7 @@ class IndexVectorStoreFromDocumentPipeline(BaseComponent):
     def load(
         self,
-        path: Union[str, Path],
+        path: str | Path,
         vectorstore_fname: str = VECTOR_STORE_FNAME,
         docstore_fname: str = DOC_STORE_FNAME,
     ):

View File

@@ -1,6 +1,6 @@
-from abc import abstractmethod
+from __future__ import annotations
+
 from pathlib import Path
-from typing import List, Union

 from theflow import Node, Param
@@ -14,31 +14,7 @@ VECTOR_STORE_FNAME = "vectorstore"
 DOC_STORE_FNAME = "docstore"


-class BaseRetrieval(BaseComponent):
-    """Define the base interface of a retrieval pipeline"""
-
-    @abstractmethod
-    def run_raw(self, text: str, top_k: int = 1) -> List[RetrievedDocument]:
-        ...
-
-    @abstractmethod
-    def run_batch_raw(
-        self, text: List[str], top_k: int = 1
-    ) -> List[List[RetrievedDocument]]:
-        ...
-
-    @abstractmethod
-    def run_document(self, text: Document, top_k: int = 1) -> List[RetrievedDocument]:
-        ...
-
-    @abstractmethod
-    def run_batch_document(
-        self, text: List[Document], top_k: int = 1
-    ) -> List[List[RetrievedDocument]]:
-        ...
-
-
-class RetrieveDocumentFromVectorStorePipeline(BaseRetrieval):
+class RetrieveDocumentFromVectorStorePipeline(BaseComponent):
     """Retrieve list of documents from vector store"""

     vector_store: Param[BaseVectorStore] = Param()
@@ -46,53 +22,33 @@ class RetrieveDocumentFromVectorStorePipeline(BaseRetrieval):
     embedding: Node[BaseEmbeddings] = Node()
     # TODO: refer to llama_index's storage as well

-    def run_raw(self, text: str, top_k: int = 1) -> List[RetrievedDocument]:
-        return self.run_batch_raw([text], top_k=top_k)[0]
-
-    def run_batch_raw(
-        self, text: List[str], top_k: int = 1
-    ) -> List[List[RetrievedDocument]]:
+    def run(self, text: str | Document, top_k: int = 1) -> list[RetrievedDocument]:
+        """Retrieve a list of documents from vector store
+
+        Args:
+            text: the text to retrieve similar documents
+
+        Returns:
+            list[RetrievedDocument]: list of retrieved documents
+        """
         if self.doc_store is None:
             raise ValueError(
                 "doc_store is not provided. Please provide a doc_store to "
                 "retrieve the documents"
             )

-        result = []
-        for each_text in text:
-            emb = self.embedding(each_text)
-            _, scores, ids = self.vector_store.query(embedding=emb, top_k=top_k)
-            docs = self.doc_store.get(ids)
-            each_result = [
-                RetrievedDocument(**doc.to_dict(), score=score)
-                for doc, score in zip(docs, scores)
-            ]
-            result.append(each_result)
+        emb: list[float] = self.embedding(text)[0]
+        _, scores, ids = self.vector_store.query(embedding=emb, top_k=top_k)
+        docs = self.doc_store.get(ids)
+        result = [
+            RetrievedDocument(**doc.to_dict(), score=score)
+            for doc, score in zip(docs, scores)
+        ]
         return result

-    def run_document(self, text: Document, top_k: int = 1) -> List[RetrievedDocument]:
-        return self.run_raw(text.text, top_k)
-
-    def run_batch_document(
-        self, text: List[Document], top_k: int = 1
-    ) -> List[List[RetrievedDocument]]:
-        return self.run_batch_raw(text=[t.text for t in text], top_k=top_k)
-
-    def is_document(self, text, *args, **kwargs) -> bool:
-        if isinstance(text, Document):
-            return True
-        elif isinstance(text, List) and isinstance(text[0], Document):
-            return True
-        return False
-
-    def is_batch(self, text, *args, **kwargs) -> bool:
-        if isinstance(text, list):
-            return True
-        return False
-
     def save(
         self,
-        path: Union[str, Path],
+        path: str | Path,
         vectorstore_fname: str = VECTOR_STORE_FNAME,
         docstore_fname: str = DOC_STORE_FNAME,
     ):
@@ -109,7 +65,7 @@ class RetrieveDocumentFromVectorStorePipeline(BaseRetrieval):
     def load(
         self,
-        path: Union[str, Path],
+        path: str | Path,
         vectorstore_fname: str = VECTOR_STORE_FNAME,
         docstore_fname: str = DOC_STORE_FNAME,
     ):

View File

@@ -92,7 +92,7 @@ class BaseTool(BaseComponent):
         """Convert this tool to Langchain format to use with its agent"""
         return LCTool(name=self.name, description=self.description, func=self.run)

-    def run_raw(
+    def run(
         self,
         tool_input: Union[str, Dict],
         verbose: Optional[bool] = None,
@@ -110,23 +110,6 @@ class BaseTool(BaseComponent):
         else:
             return observation

-    def run_document(self, *args, **kwargs):
-        pass
-
-    def run_batch_raw(self, *args, **kwargs):
-        pass
-
-    def run_batch_document(self, *args, **kwargs):
-        pass
-
-    def is_document(self, *args, **kwargs) -> bool:
-        """Tool does not support processing document"""
-        return False
-
-    def is_batch(self, *args, **kwargs) -> bool:
-        """Tool does not support processing batch"""
-        return False
-
     @classmethod
     def from_langchain_format(cls, langchain_tool: LCTool) -> "BaseTool":
         """Wrapper for Langchain Tool"""

View File

@@ -1,5 +1,7 @@
+from __future__ import annotations
+
 import re
-from typing import Callable, Dict, List, Union
+from typing import Callable

 from theflow import Param
@@ -12,7 +14,7 @@ class ExtractorOutput(Document):
     Represents the output of an extractor.
     """

-    matches: List[str]
+    matches: list[str]


 class RegexExtractor(BaseComponent):
@@ -28,18 +30,18 @@ class RegexExtractor(BaseComponent):
     class Config:
         middleware_switches = {"theflow.middleware.CachingMiddleware": False}

-    pattern: List[str]
-    output_map: Union[Dict[str, str], Callable[[str], str]] = Param(
+    pattern: list[str]
+    output_map: dict[str, str] | Callable[[str], str] = Param(
         default_callback=lambda *_: {}
     )

-    def __init__(self, pattern: Union[str, List[str]], **kwargs):
+    def __init__(self, pattern: str | list[str], **kwargs):
         if isinstance(pattern, str):
             pattern = [pattern]
         super().__init__(pattern=pattern, **kwargs)

     @staticmethod
-    def run_raw_static(pattern: str, text: str) -> List[str]:
+    def run_raw_static(pattern: str, text: str) -> list[str]:
         """
         Finds all non-overlapping occurrences of a pattern in a string.
@@ -86,9 +88,9 @@ class RegexExtractor(BaseComponent):
         Returns:
             ExtractorOutput: The processed output as a list of ExtractorOutput.
         """
-        output = sum(
+        output: list[str] = sum(
             [self.run_raw_static(p, text) for p in self.pattern], []
-        )  # type: List[str]
+        )
         output = [self.map_output(text, self.output_map) for text in output]

         return ExtractorOutput(
@@ -97,100 +99,48 @@ class RegexExtractor(BaseComponent):
             metadata={"origin": "RegexExtractor"},
         )

-    def run_batch_raw(self, text_batch: List[str]) -> List[ExtractorOutput]:
-        """
-        Runs a batch of raw text inputs through the `run_raw()` method and returns the
-        output for each input.
-
-        Parameters:
-            text_batch (List[str]): A list of raw text inputs to process.
-
-        Returns:
-            List[ExtractorOutput]: A list containing the output for each input in the
-                batch.
-        """
-        batch_output = [self.run_raw(each_text) for each_text in text_batch]
-
-        return batch_output
-
-    def run_document(self, document: Document) -> ExtractorOutput:
-        """
-        Run the document through the regex extractor and return an extracted document.
-
-        Args:
-            document (Document): The input document.
-
-        Returns:
-            ExtractorOutput: The extracted content.
-        """
-        return self.run_raw(document.text)
-
-    def run_batch_document(
-        self, document_batch: List[Document]
-    ) -> List[ExtractorOutput]:
-        """
-        Runs a batch of documents through the `run_document` function and returns the
-        output for each document.
-
-        Parameters:
-            document_batch (List[Document]): A list of Document objects representing the
-                batch of documents to process.
-
-        Returns:
-            List[ExtractorOutput]: A list contains the output ExtractorOutput for each
-                input Document in the batch.
-
-        Example:
-            document1 = Document(...)
-            document2 = Document(...)
-            document_batch = [document1, document2]
-            batch_output = self.run_batch_document(document_batch)
-            # batch_output will be [output1_document1, output1_document2]
-        """
-        batch_output = [
-            self.run_document(each_document) for each_document in document_batch
-        ]
-
-        return batch_output
-
-    def is_document(self, text) -> bool:
-        """
-        Check if the given text is an instance of the Document class.
-
-        Args:
-            text: The text to check.
-
-        Returns:
-            bool: True if the text is an instance of Document, False otherwise.
-        """
-        if isinstance(text, Document):
-            return True
-        return False
-
-    def is_batch(self, text) -> bool:
-        """
-        Check if the given text is a batch of documents.
-
-        Parameters:
-            text (List): The text to be checked.
-
-        Returns:
-            bool: True if the text is a batch of documents, False otherwise.
-        """
-        if not isinstance(text, List):
-            return False
-        if len(set(self.is_document(each_text) for each_text in text)) <= 1:
-            return True
-        return False
+    def run(
+        self, text: str | list[str] | Document | list[Document]
+    ) -> list[ExtractorOutput]:
+        """Match the input against a pattern and return the output for each input
+
+        Parameters:
+            text: contains the input string to be processed
+
+        Returns:
+            A list contains the output ExtractorOutput for each input
+
+        Example:
+            document1 = Document(...)
+            document2 = Document(...)
+            document_batch = [document1, document2]
+            batch_output = self(document_batch)
+            # batch_output will be [output1_document1, output1_document2]
+        """
+        # TODO: this conversion seems common
+        input_: list[str] = []
+        if not isinstance(text, list):
+            text = [text]
+
+        for item in text:
+            if isinstance(item, str):
+                input_.append(item)
+            elif isinstance(item, Document):
+                input_.append(item.text)
+            else:
+                raise ValueError(
+                    f"Invalid input type {type(item)}, should be str or Document"
+                )
+
+        output = []
+        for each_input in input_:
+            output.append(self.run_raw(each_input))
+
+        return output


 class FirstMatchRegexExtractor(RegexExtractor):
-    pattern: List[str]
+    pattern: list[str]

     def run_raw(self, text: str) -> ExtractorOutput:
         for p in self.pattern:

View File

@@ -174,23 +174,5 @@ class BasePromptComponent(BaseComponent):
         text = self.template.populate(**prepared_kwargs)
         return Document(text=text, metadata={"origin": "PromptComponent"})

-    def run_raw(self, *args, **kwargs):
-        pass
-
-    def run_batch_raw(self, *args, **kwargs):
-        pass
-
-    def run_document(self, *args, **kwargs):
-        pass
-
-    def run_batch_document(self, *args, **kwargs):
-        pass
-
-    def is_document(self, *args, **kwargs):
-        pass
-
-    def is_batch(self, *args, **kwargs):
-        pass
-
     def flow(self):
         return self.__call__()

View File

View File

@@ -59,6 +59,7 @@ class BaseVectorStore(ABC):
         embedding: List[float],
         top_k: int = 1,
         ids: Optional[List[str]] = None,
+        **kwargs,
     ) -> Tuple[List[List[float]], List[float], List[str]]:
         """Return the top k most similar vector embeddings

View File

@@ -65,7 +65,7 @@ setuptools.setup(
         ],
     },
     entry_points={"console_scripts": ["kh=kotaemon.cli:main"]},
-    python_requires=">=3.8",
+    python_requires=">=3.10",
     classifiers=[
         "Programming Language :: Python :: 3",
         "License :: OSI Approved :: MIT License",

View File

@@ -1,4 +1,7 @@
+from copy import deepcopy
+
 import pytest
+from openai.types.chat.chat_completion import ChatCompletion

 from kotaemon.composite import (
     GatedBranchingPipeline,
@@ -10,6 +13,29 @@ from kotaemon.llms.chats.openai import AzureChatOpenAI
 from kotaemon.post_processing.extractor import RegexExtractor
 from kotaemon.prompt.base import BasePromptComponent

+_openai_chat_completion_response = ChatCompletion.parse_obj(
+    {
+        "id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x",
+        "object": "chat.completion",
+        "created": 1692338378,
+        "model": "gpt-35-turbo",
+        "system_fingerprint": None,
+        "choices": [
+            {
+                "index": 0,
+                "finish_reason": "stop",
+                "message": {
+                    "role": "assistant",
+                    "content": "This is a test 123",
+                    "finish_reason": "length",
+                    "logprobs": None,
+                },
+            }
+        ],
+        "usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19},
+    }
+)
+

 @pytest.fixture
 def mock_llm():
@@ -19,7 +45,6 @@ def mock_llm():
         openai_api_version="OPENAI_API_VERSION",
         deployment_name="dummy-q2-gpt35",
         temperature=0,
-        request_timeout=600,
     )
@@ -61,11 +86,12 @@ def mock_gated_linear_pipeline_negative(mock_prompt, mock_llm, mock_post_process

 def test_simple_linear_pipeline_run(mocker, mock_simple_linear_pipeline):
-    openai_mocker = mocker.patch.object(
-        AzureChatOpenAI, "run", return_value="This is a test 123"
+    openai_mocker = mocker.patch(
+        "openai.resources.chat.completions.Completions.create",
+        return_value=_openai_chat_completion_response,
     )

-    result = mock_simple_linear_pipeline.run(value="abc")
+    result = mock_simple_linear_pipeline(value="abc")

     assert result.text == "123"
     assert openai_mocker.call_count == 1
@@ -74,11 +100,12 @@ def test_simple_linear_pipeline_run(mocker, mock_simple_linear_pipeline):
 def test_gated_linear_pipeline_run_positive(
     mocker, mock_gated_linear_pipeline_positive
 ):
-    openai_mocker = mocker.patch.object(
-        AzureChatOpenAI, "run", return_value="This is a test 123."
+    openai_mocker = mocker.patch(
+        "openai.resources.chat.completions.Completions.create",
+        return_value=_openai_chat_completion_response,
     )

-    result = mock_gated_linear_pipeline_positive.run(
+    result = mock_gated_linear_pipeline_positive(
         value="abc", condition_text="positive condition"
     )
@@ -89,11 +116,12 @@ def test_gated_linear_pipeline_run_positive(
 def test_gated_linear_pipeline_run_negative(
     mocker, mock_gated_linear_pipeline_positive
 ):
-    openai_mocker = mocker.patch.object(
-        AzureChatOpenAI, "run", return_value="This is a test 123."
+    openai_mocker = mocker.patch(
+        "openai.resources.chat.completions.Completions.create",
+        return_value=_openai_chat_completion_response,
     )

-    result = mock_gated_linear_pipeline_positive.run(
+    result = mock_gated_linear_pipeline_positive(
         value="abc", condition_text="negative condition"
     )
@@ -102,14 +130,14 @@ def test_gated_linear_pipeline_run_negative(

 def test_simple_branching_pipeline_run(mocker, mock_simple_linear_pipeline):
-    openai_mocker = mocker.patch.object(
-        AzureChatOpenAI,
-        "run",
-        side_effect=[
-            "This is a test 123.",
-            "a quick brown fox",
-            "jumps over the lazy dog 456",
-        ],
+    response0: ChatCompletion = _openai_chat_completion_response
+    response1: ChatCompletion = deepcopy(_openai_chat_completion_response)
+    response1.choices[0].message.content = "a quick brown fox"
+    response2: ChatCompletion = deepcopy(_openai_chat_completion_response)
+    response2.choices[0].message.content = "jumps over the lazy dog 456"
+    openai_mocker = mocker.patch(
+        "openai.resources.chat.completions.Completions.create",
+        side_effect=[response0, response1, response2],
     )
     pipeline = SimpleBranchingPipeline()
     for _ in range(3):
@@ -126,8 +154,11 @@ def test_simple_branching_pipeline_run(mocker, mock_simple_linear_pipeline):
 def test_simple_gated_branching_pipeline_run(
     mocker, mock_gated_linear_pipeline_positive, mock_gated_linear_pipeline_negative
 ):
-    openai_mocker = mocker.patch.object(
-        AzureChatOpenAI, "run", return_value="a quick brown fox"
+    response0: ChatCompletion = deepcopy(_openai_chat_completion_response)
+    response0.choices[0].message.content = "a quick brown fox"
+    openai_mocker = mocker.patch(
+        "openai.resources.chat.completions.Completions.create",
+        return_value=response0,
     )

     pipeline = GatedBranchingPipeline()

View File

@@ -26,7 +26,8 @@ def test_azureopenai_embeddings_raw(openai_embedding_call):
     )
     output = model("Hello world")
     assert isinstance(output, list)
-    assert isinstance(output[0], float)
+    assert isinstance(output[0], list)
+    assert isinstance(output[0][0], float)
     openai_embedding_call.assert_called()
@@ -53,8 +54,8 @@ def test_azureopenai_embeddings_batch_raw(openai_embedding_call):
     side_effect=lambda *args, **kwargs: None,
 )
 @patch(
-    "langchain.embeddings.huggingface.HuggingFaceBgeEmbeddings.embed_query",
-    side_effect=lambda *args, **kwargs: [1.0, 2.1, 3.2],
+    "langchain.embeddings.huggingface.HuggingFaceBgeEmbeddings.embed_documents",
+    side_effect=lambda *args, **kwargs: [[1.0, 2.1, 3.2]],
 )
 def test_huggingface_embddings(
     langchain_huggingface_embedding_call, sentence_transformers_init
@@ -67,21 +68,23 @@ def test_huggingface_embddings(
     output = model("Hello World")
     assert isinstance(output, list)
-    assert isinstance(output[0], float)
+    assert isinstance(output[0], list)
+    assert isinstance(output[0][0], float)
     sentence_transformers_init.assert_called()
     langchain_huggingface_embedding_call.assert_called()


 @patch(
-    "langchain.embeddings.cohere.CohereEmbeddings.embed_query",
-    side_effect=lambda *args, **kwargs: [1.0, 2.1, 3.2],
+    "langchain.embeddings.cohere.CohereEmbeddings.embed_documents",
+    side_effect=lambda *args, **kwargs: [[1.0, 2.1, 3.2]],
 )
-def test_cohere_embddings(langchain_cohere_embedding_call):
+def test_cohere_embeddings(langchain_cohere_embedding_call):
     model = CohereEmbdeddings(
         model="embed-english-light-v2.0", cohere_api_key="my-api-key"
     )
     output = model("Hello World")
     assert isinstance(output, list)
-    assert isinstance(output[0], float)
+    assert isinstance(output[0], list)
+    assert isinstance(output[0][0], float)
     langchain_cohere_embedding_call.assert_called()

View File

@@ -60,7 +60,8 @@ def test_retrieving(mock_openai_embedding, tmp_path):
     )

     index_pipeline(text=Document(text="Hello world"))
-    output = retrieval_pipeline(text=["Hello world", "Hello world"])
+    output = retrieval_pipeline(text="Hello world")
+    output1 = retrieval_pipeline(text="Hello world")

-    assert len(output) == 2, "Expect 2 results"
-    assert output[0] == output[1], "Expect identical results"
+    assert len(output) == 1, "Expect 1 results"
+    assert output == output1, "Expect identical results"

View File

@@ -54,12 +54,6 @@ def test_azureopenai_model(openai_completion):
     ), "Output for single text is not LLMInterface"
     openai_completion.assert_called()

-    # test for list[str] input - batch mode
-    output = model(["hello world"])
-    assert isinstance(output, list), "Output for batch string is not a list"
-    assert isinstance(output[0], LLMInterface), "Output for text is not LLMInterface"
-    openai_completion.assert_called()
-
     # test for list[message] input - stream mode
     messages = [
         SystemMessage(content="You are a philosohper"),
@@ -73,9 +67,3 @@ def test_azureopenai_model(openai_completion):
         output, LLMInterface
     ), "Output for single text is not LLMInterface"
     openai_completion.assert_called()
-
-    # test for list[list[message]] input - batch mode
-    output = model([messages])
-    assert isinstance(output, list), "Output for batch string is not a list"
-    assert isinstance(output[0], LLMInterface), "Output for text is not LLMInterface"
-    openai_completion.assert_called()

View File

@@ -44,11 +44,6 @@ def test_azureopenai_model(openai_completion):
         model.agent, AzureOpenAILC
     ), "Agent not wrapped in Langchain's AzureOpenAI"

-    output = model(["hello world"])
-    assert isinstance(output, list), "Output for batch is not a list"
-    assert isinstance(output[0], LLMInterface), "Output for text is not LLMInterface"
-    openai_completion.assert_called()
-
     output = model("hello world")
     assert isinstance(
         output, LLMInterface
@@ -72,11 +67,6 @@ def test_openai_model(openai_completion):
         model.agent, OpenAILC
     ), "Agent is not wrapped in Langchain's OpenAI"

-    output = model(["hello world"])
-    assert isinstance(output, list), "Output for batch is not a list"
-    assert isinstance(output[0], LLMInterface), "Output for text is not LLMInterface"
-    openai_completion.assert_called()
-
     output = model("hello world")
     assert isinstance(
         output, LLMInterface

View File

@@ -13,23 +13,13 @@ def regex_extractor():

 def test_run_document(regex_extractor):
     document = Document(text="This is a test. 1 2 3")
-    extracted_document = regex_extractor(document)
+    extracted_document = regex_extractor(document)[0]
     assert extracted_document.text == "One"
     assert extracted_document.matches == ["One", "Two", "Three"]


-def test_is_document(regex_extractor):
-    assert regex_extractor.is_document(Document(text="Test"))
-    assert not regex_extractor.is_document("Test")
-
-
-def test_is_batch(regex_extractor):
-    assert regex_extractor.is_batch([Document(text="Test")])
-    assert not regex_extractor.is_batch(Document(text="Test"))
-
-
 def test_run_raw(regex_extractor):
-    output = regex_extractor("This is a test. 123")
+    output = regex_extractor("This is a test. 123")[0]
     assert output.text == "123"
     assert output.matches == ["123"]

View File

@@ -54,7 +54,7 @@ def test_run():
     result = prompt()

-    assert result.text == "str = Alice, int = 30, doc = Helloo, Alice!, comp = One"
+    assert result.text == "str = Alice, int = 30, doc = Helloo, Alice!, comp = ['One']"


 def test_set_method():