Provide type hints for pass-through Langchain and Llama-index objects (#95)

Duc Nguyen (john) 2023-12-04 10:59:13 +07:00 committed by GitHub
parent e34b1e4c6d
commit 0ce3a8832f
34 changed files with 641 additions and 310 deletions

View File

@@ -54,7 +54,7 @@ class LangchainAgent(BaseAgent):
         # reinit Langchain AgentExecutor
         self.agent = initialize_agent(
             langchain_plugins,
-            self.llm.agent,
+            self.llm._obj,
             agent=self.AGENT_TYPE_MAP[self.agent_type],
             handle_parsing_errors=True,
             verbose=True,
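
A hedged sketch of the pass-through idea this hunk relies on: the kotaemon wrapper keeps the native Langchain model in `_obj`, so Langchain APIs such as `initialize_agent` receive a genuine Langchain object instead of the kotaemon wrapper. `Wrapper` below is illustrative only; `FakeListLLM` is Langchain's built-in deterministic test model.

    # Illustrative sketch only: expose the raw Langchain model as `_obj`.
    from langchain.llms.fake import FakeListLLM

    class Wrapper:
        def __init__(self):
            # The genuine Langchain object lives here; Langchain APIs get it
            # directly, e.g. initialize_agent(tools, llm._obj, ...).
            self._obj = FakeListLLM(responses=["hello"])

    llm = Wrapper()
    print(llm._obj("ping"))  # -> "hello"; Langchain sees a real BaseLLM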

View File

@@ -21,7 +21,7 @@ class LLMTool(BaseTool):
         "are confident in solving the problem "
         "yourself. Input can be any instruction."
     )
-    llm: BaseLLM = AzureChatOpenAI()
+    llm: BaseLLM = AzureChatOpenAI.withx()
     args_schema: Optional[Type[BaseModel]] = LLMArgs

     def _run_tool(self, query: AnyStr) -> str:
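
The class-level default changes from an eagerly constructed `AzureChatOpenAI()` to `AzureChatOpenAI.withx()`. Assuming `withx` is theflow's deferred-initialization helper, the point is to delay building the Langchain object until the field is actually used. A dependency-free analogue of that behaviour (`LazyDefault` is a hypothetical name, not theflow's API):

    # Hedged analogue of a deferred default: capture constructor args now,
    # build the object only on first access.
    class LazyDefault:
        def __init__(self, factory, **kwargs):
            self._factory = factory
            self._kwargs = kwargs
            self._instance = None

        def get(self):
            if self._instance is None:  # construct lazily, exactly once
                self._instance = self._factory(**self._kwargs)
            return self._instance

    default = LazyDefault(dict, model="gpt-35-turbo")  # dict as a stand-in class
    print(default.get()["model"])  # -> gpt-35-turbo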

View File

@@ -1,4 +1,15 @@
 from .base import BaseEmbeddings
-from .openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
+from .langchain_based import (
+    AzureOpenAIEmbeddings,
+    CohereEmbdeddings,
+    HuggingFaceEmbeddings,
+    OpenAIEmbeddings,
+)

-__all__ = ["BaseEmbeddings", "OpenAIEmbeddings", "AzureOpenAIEmbeddings"]
+__all__ = [
+    "BaseEmbeddings",
+    "OpenAIEmbeddings",
+    "AzureOpenAIEmbeddings",
+    "CohereEmbdeddings",
+    "HuggingFaceEmbeddings",
+]

View File

@@ -1,10 +1,6 @@
 from __future__ import annotations

 from abc import abstractmethod
-from typing import Type
-
-from langchain.schema.embeddings import Embeddings as LCEmbeddings
-from theflow import Param

 from kotaemon.base import BaseComponent, Document, DocumentWithEmbedding
@@ -15,52 +11,3 @@ class BaseEmbeddings(BaseComponent):
         self, text: str | list[str] | Document | list[Document]
     ) -> list[DocumentWithEmbedding]:
         ...
-
-
-class LangchainEmbeddings(BaseEmbeddings):
-    _lc_class: Type[LCEmbeddings]
-
-    def __init__(self, **params):
-        if self._lc_class is None:
-            raise AttributeError(
-                "Should set _lc_class attribute to the LLM class from Langchain "
-                "if using LLM from Langchain"
-            )
-
-        self._kwargs: dict = {}
-        for param in list(params.keys()):
-            if param in self._lc_class.__fields__:  # type: ignore
-                self._kwargs[param] = params.pop(param)
-
-        super().__init__(**params)
-
-    def __setattr__(self, name, value):
-        if name in self._lc_class.__fields__:
-            self._kwargs[name] = value
-        else:
-            super().__setattr__(name, value)
-
-    @Param.auto(cache=False)
-    def agent(self):
-        return self._lc_class(**self._kwargs)
-
-    def run(self, text):
-        input_: list[str] = []
-        if not isinstance(text, list):
-            text = [text]
-
-        for item in text:
-            if isinstance(item, str):
-                input_.append(item)
-            elif isinstance(item, Document):
-                input_.append(item.text)
-            else:
-                raise ValueError(
-                    f"Invalid input type {type(item)}, should be str or Document"
-                )
-
-        embeddings = self.agent.embed_documents(input_)
-
-        return [
-            DocumentWithEmbedding(text=each_text, embedding=each_embedding)
-            for each_text, each_embedding in zip(input_, embeddings)
-        ]

View File

@@ -1,12 +0,0 @@
-from langchain.embeddings import CohereEmbeddings as LCCohereEmbeddings
-
-from kotaemon.embeddings.base import LangchainEmbeddings
-
-
-class CohereEmbdeddings(LangchainEmbeddings):
-    """Cohere embeddings.
-
-    This class wraps around the Langchain CohereEmbeddings class.
-    """
-
-    _lc_class = LCCohereEmbeddings

View File

@@ -1,12 +0,0 @@
-from langchain.embeddings import HuggingFaceBgeEmbeddings as LCHuggingFaceEmbeddings
-
-from kotaemon.embeddings.base import LangchainEmbeddings
-
-
-class HuggingFaceEmbeddings(LangchainEmbeddings):
-    """HuggingFace embeddings
-
-    This class wraps around the Langchain HuggingFaceEmbeddings class
-    """
-
-    _lc_class = LCHuggingFaceEmbeddings

View File

@@ -0,0 +1,194 @@
+from typing import Optional
+
+from kotaemon.base import Document, DocumentWithEmbedding
+
+from .base import BaseEmbeddings
+
+
+class LCEmbeddingMixin:
+    def _get_lc_class(self):
+        raise NotImplementedError(
+            "Please return the relevant Langchain class in _get_lc_class"
+        )
+
+    def __init__(self, **params):
+        self._lc_class = self._get_lc_class()
+        self._obj = self._lc_class(**params)
+        self._kwargs: dict = params
+
+        super().__init__()
+
+    def run(self, text):
+        input_: list[str] = []
+        if not isinstance(text, list):
+            text = [text]
+
+        for item in text:
+            if isinstance(item, str):
+                input_.append(item)
+            elif isinstance(item, Document):
+                input_.append(item.text)
+            else:
+                raise ValueError(
+                    f"Invalid input type {type(item)}, should be str or Document"
+                )
+
+        embeddings = self._obj.embed_documents(input_)
+
+        return [
+            DocumentWithEmbedding(text=each_text, embedding=each_embedding)
+            for each_text, each_embedding in zip(input_, embeddings)
+        ]
+
+    def __repr__(self):
+        kwargs = []
+        for key, value_obj in self._kwargs.items():
+            value = repr(value_obj)
+            kwargs.append(f"{key}={value}")
+        kwargs_repr = ", ".join(kwargs)
+        return f"{self.__class__.__name__}({kwargs_repr})"
+
+    def __str__(self):
+        kwargs = []
+        for key, value_obj in self._kwargs.items():
+            value = str(value_obj)
+            if len(value) > 20:
+                value = f"{value[:15]}..."
+            kwargs.append(f"{key}={value}")
+        kwargs_repr = ", ".join(kwargs)
+        return f"{self.__class__.__name__}({kwargs_repr})"
+
+    def __setattr__(self, name, value):
+        if name == "_lc_class":
+            return super().__setattr__(name, value)
+
+        if name in self._lc_class.__fields__:
+            self._kwargs[name] = value
+            self._obj = self._lc_class(**self._kwargs)
+        else:
+            super().__setattr__(name, value)
+
+    def __getattr__(self, name):
+        if name in self._kwargs:
+            return self._kwargs[name]
+        return getattr(self._obj, name)
+
+    def dump(self):
+        return {
+            "__type__": f"{self.__module__}.{self.__class__.__qualname__}",
+            **self._kwargs,
+        }
+
+    def specs(self, path: str):
+        path = path.strip(".")
+        if "." in path:
+            raise ValueError("path should not contain '.'")
+
+        if path in self._lc_class.__fields__:
+            return {
+                "__type__": "theflow.base.ParamAttr",
+                "refresh_on_set": True,
+                "strict_type": True,
+            }
+
+        raise ValueError(f"Invalid param {path}")
+
+
+class OpenAIEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
+    """Wrapper around Langchain's OpenAI embedding, focusing on key parameters"""
+
+    def __init__(
+        self,
+        model: str = "text-embedding-ada-002",
+        openai_api_version: Optional[str] = None,
+        openai_api_base: Optional[str] = None,
+        openai_api_type: Optional[str] = None,
+        openai_api_key: Optional[str] = None,
+        request_timeout: Optional[float] = None,
+        **params,
+    ):
+        super().__init__(
+            model=model,
+            openai_api_version=openai_api_version,
+            openai_api_base=openai_api_base,
+            openai_api_type=openai_api_type,
+            openai_api_key=openai_api_key,
+            request_timeout=request_timeout,
+            **params,
+        )
+
+    def _get_lc_class(self):
+        import langchain.embeddings
+
+        return langchain.embeddings.OpenAIEmbeddings
+
+
+class AzureOpenAIEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
+    """Wrapper around Langchain's AzureOpenAI embedding, focusing on key parameters"""
+
+    def __init__(
+        self,
+        azure_endpoint: Optional[str] = None,
+        deployment: Optional[str] = None,
+        openai_api_key: Optional[str] = None,
+        openai_api_version: Optional[str] = None,
+        request_timeout: Optional[float] = None,
+        **params,
+    ):
+        super().__init__(
+            azure_endpoint=azure_endpoint,
+            deployment=deployment,
+            openai_api_version=openai_api_version,
+            openai_api_key=openai_api_key,
+            request_timeout=request_timeout,
+            **params,
+        )
+
+    def _get_lc_class(self):
+        import langchain.embeddings
+
+        return langchain.embeddings.AzureOpenAIEmbeddings
+
+
+class CohereEmbdeddings(LCEmbeddingMixin, BaseEmbeddings):
+    """Wrapper around Langchain's Cohere embedding, focusing on key parameters"""
+
+    def __init__(
+        self,
+        model: str = "embed-english-v2.0",
+        cohere_api_key: Optional[str] = None,
+        truncate: Optional[str] = None,
+        request_timeout: Optional[float] = None,
+        **params,
+    ):
+        super().__init__(
+            model=model,
+            cohere_api_key=cohere_api_key,
+            truncate=truncate,
+            request_timeout=request_timeout,
+            **params,
+        )
+
+    def _get_lc_class(self):
+        import langchain.embeddings
+
+        return langchain.embeddings.CohereEmbeddings
+
+
+class HuggingFaceEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
+    """Wrapper around Langchain's HuggingFace embedding, focusing on key parameters"""
+
+    def __init__(
+        self,
+        model_name: str = "sentence-transformers/all-mpnet-base-v2",
+        **params,
+    ):
+        super().__init__(
+            model_name=model_name,
+            **params,
+        )
+
+    def _get_lc_class(self):
+        import langchain.embeddings
+
+        return langchain.embeddings.HuggingFaceBgeEmbeddings
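
A hedged usage sketch of the wrappers defined in this new file, assuming the package layout shown elsewhere in the diff; endpoint, deployment, and key values are placeholders:

    # Hedged usage sketch; credential values are placeholders, not real config.
    from kotaemon.embeddings import AzureOpenAIEmbeddings

    embedder = AzureOpenAIEmbeddings(
        azure_endpoint="https://example.openai.azure.com/",  # placeholder
        deployment="embedding-deployment",                   # placeholder
        openai_api_key="...",                                # placeholder
        openai_api_version="2023-05-15",
    )
    docs = embedder.run(["hello", "world"])  # -> list[DocumentWithEmbedding]
    # Setting a Langchain field goes through __setattr__ above and rebuilds
    # the wrapped object with the updated kwargs:
    embedder.request_timeout = 30.0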

View File

@@ -1,21 +0,0 @@
-from langchain import embeddings as lcembeddings
-
-from .base import LangchainEmbeddings
-
-
-class OpenAIEmbeddings(LangchainEmbeddings):
-    """OpenAI embeddings.
-
-    This method is wrapped around the Langchain OpenAIEmbeddings class.
-    """
-
-    _lc_class = lcembeddings.OpenAIEmbeddings
-
-
-class AzureOpenAIEmbeddings(LangchainEmbeddings):
-    """Azure OpenAI embeddings.
-
-    This method is wrapped around the Langchain AzureOpenAIEmbeddings class.
-    """
-
-    _lc_class = lcembeddings.AzureOpenAIEmbeddings

View File

@@ -46,20 +46,48 @@ class LlamaIndexDocTransformerMixin:
             "Please return the relevant LlamaIndex class in _get_li_class"
         )

-    def __init__(self, *args, **kwargs):
-        _li_cls = self._get_li_class()
-        self._obj = _li_cls(*args, **kwargs)
+    def __init__(self, **params):
+        self._li_cls = self._get_li_class()
+        self._obj = self._li_cls(**params)
+        self._kwargs = params
         super().__init__()

+    def __repr__(self):
+        kwargs = []
+        for key, value_obj in self._kwargs.items():
+            value = repr(value_obj)
+            kwargs.append(f"{key}={value}")
+        kwargs_repr = ", ".join(kwargs)
+        return f"{self.__class__.__name__}({kwargs_repr})"
+
+    def __str__(self):
+        kwargs = []
+        for key, value_obj in self._kwargs.items():
+            value = str(value_obj)
+            if len(value) > 20:
+                value = f"{value[:15]}..."
+            kwargs.append(f"{key}={value}")
+        kwargs_repr = ", ".join(kwargs)
+        return f"{self.__class__.__name__}({kwargs_repr})"
+
     def __setattr__(self, name: str, value: Any) -> None:
         if name.startswith("_") or name in self._protected_keywords():
             return super().__setattr__(name, value)

+        self._kwargs[name] = value
         return setattr(self._obj, name, value)

     def __getattr__(self, name: str) -> Any:
+        if name in self._kwargs:
+            return self._kwargs[name]
         return getattr(self._obj, name)

+    def dump(self):
+        return {
+            "__type__": f"{self.__module__}.{self.__class__.__qualname__}",
+            **self._kwargs,
+        }
+
     def run(
         self,
         documents: list[Document],
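
The mixin above leans on Python attribute interception to keep the wrapper and the wrapped LlamaIndex object in sync. A dependency-free sketch of that delegation pattern (names here are illustrative, not the kotaemon classes):

    # Minimal sketch of __getattr__/__setattr__ delegation to a wrapped object.
    class Inner:
        def __init__(self, chunk_size: int = 1024):
            self.chunk_size = chunk_size

    class Outer:
        def __init__(self, **params):
            # Use object.__setattr__ so we don't trigger our own hook.
            object.__setattr__(self, "_kwargs", dict(params))
            object.__setattr__(self, "_obj", Inner(**params))

        def __setattr__(self, name, value):
            self._kwargs[name] = value       # remember for __repr__/dump()
            setattr(self._obj, name, value)  # forward to the wrapped object

        def __getattr__(self, name):
            # Runs only when normal lookup fails: fall through to the inner object.
            return getattr(self._obj, name)

    o = Outer(chunk_size=256)
    o.chunk_size = 512
    print(o.chunk_size, o._kwargs)  # -> 512 {'chunk_size': 512}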

View File

@@ -6,6 +6,14 @@ class BaseDocParser(DocTransformer):


 class TitleExtractor(LlamaIndexDocTransformerMixin, BaseDocParser):
+    def __init__(
+        self,
+        llm=None,
+        nodes: int = 5,
+        **params,
+    ):
+        super().__init__(llm=llm, nodes=nodes, **params)
+
     def _get_li_class(self):
         from llama_index.extractors import TitleExtractor

@@ -13,6 +21,14 @@ class TitleExtractor(LlamaIndexDocTransformerMixin, BaseDocParser):


 class SummaryExtractor(LlamaIndexDocTransformerMixin, BaseDocParser):
+    def __init__(
+        self,
+        llm=None,
+        summaries: list[str] = ["self"],
+        **params,
+    ):
+        super().__init__(llm=llm, summaries=summaries, **params)
+
     def _get_li_class(self):
         from llama_index.extractors import SummaryExtractor

View File

@@ -2,7 +2,7 @@ from __future__ import annotations

 from abc import abstractmethod

-from ...base import BaseComponent, Document
+from kotaemon.base import BaseComponent, Document


 class BaseReranking(BaseComponent):

View File

@@ -2,7 +2,8 @@ from __future__ import annotations

 import os

-from ...base import Document
+from kotaemon.base import Document
+
 from .base import BaseReranking

View File

@@ -1,17 +1,13 @@
 from __future__ import annotations

 from concurrent.futures import ThreadPoolExecutor
-from typing import Union

 from langchain.output_parsers.boolean import BooleanOutputParser

-from ...base import Document
-from ...llms import PromptTemplate
-from ...llms.chats.base import ChatLLM
-from ...llms.completions.base import LLM
-from .base import BaseReranking
+from kotaemon.base import Document
+from kotaemon.llms import BaseLLM, PromptTemplate

-BaseLLM = Union[ChatLLM, LLM]
+from .base import BaseReranking

 RERANK_PROMPT_TEMPLATE = """Given the following question and context,
 return YES if the context is relevant to the question and NO if it isn't.

View File

@@ -8,6 +8,20 @@ class BaseSplitter(DocTransformer):


 class TokenSplitter(LlamaIndexDocTransformerMixin, BaseSplitter):
+    def __init__(
+        self,
+        chunk_size: int = 1024,
+        chunk_overlap: int = 20,
+        separator: str = " ",
+        **params,
+    ):
+        super().__init__(
+            chunk_size=chunk_size,
+            chunk_overlap=chunk_overlap,
+            separator=separator,
+            **params,
+        )
+
     def _get_li_class(self):
         from llama_index.text_splitter import TokenTextSplitter

@@ -15,6 +29,9 @@ class TokenSplitter(LlamaIndexDocTransformerMixin, BaseSplitter):


 class SentenceWindowSplitter(LlamaIndexDocTransformerMixin, BaseSplitter):
+    def __init__(self, window_size: int = 3, **params):
+        super().__init__(window_size=window_size, **params)
+
     def _get_li_class(self):
         from llama_index.node_parser import SentenceWindowNodeParser

View File

@@ -154,8 +154,7 @@ class GatedBranchingPipeline(SimpleBranchingPipeline):
 if __name__ == "__main__":
     import dotenv

-    from kotaemon.llms import BasePromptComponent
-    from kotaemon.llms.chats.openai import AzureChatOpenAI
+    from kotaemon.llms import AzureChatOpenAI, BasePromptComponent
     from kotaemon.parsers import RegexExtractor

     def identity(x):

View File

@@ -1,4 +1,4 @@
-from .base import BaseMessage, ChatLLM, HumanMessage
-from .openai import AzureChatOpenAI
+from .base import ChatLLM
+from .langchain_based import AzureChatOpenAI

-__all__ = ["ChatLLM", "AzureChatOpenAI", "BaseMessage", "HumanMessage"]
+__all__ = ["ChatLLM", "AzureChatOpenAI"]

View File

@@ -1,12 +1,8 @@
 from __future__ import annotations

 import logging
-from typing import Type

-from langchain.chat_models.base import BaseChatModel
-from theflow.base import Param
-
-from kotaemon.base import BaseComponent, BaseMessage, HumanMessage, LLMInterface
+from kotaemon.base import BaseComponent

 logger = logging.getLogger(__name__)

@@ -23,83 +19,3 @@ class ChatLLM(BaseComponent):
         text = self.inflow.flow().text
         return self.__call__(text)
-
-
-class LangchainChatLLM(ChatLLM):
-    _lc_class: Type[BaseChatModel]
-
-    def __init__(self, **params):
-        if self._lc_class is None:
-            raise AttributeError(
-                "Should set _lc_class attribute to the LLM class from Langchain "
-                "if using LLM from Langchain"
-            )
-
-        self._kwargs: dict = {}
-        for param in list(params.keys()):
-            if param in self._lc_class.__fields__:
-                self._kwargs[param] = params.pop(param)
-
-        super().__init__(**params)
-
-    @Param.auto(cache=False)
-    def agent(self) -> BaseChatModel:
-        return self._lc_class(**self._kwargs)
-
-    def run(
-        self, messages: str | BaseMessage | list[BaseMessage], **kwargs
-    ) -> LLMInterface:
-        """Generate response from messages
-
-        Args:
-            messages: history of messages to generate response from
-            **kwargs: additional arguments to pass to the langchain chat model
-
-        Returns:
-            LLMInterface: generated response
-        """
-        input_: list[BaseMessage] = []
-        if isinstance(messages, str):
-            input_ = [HumanMessage(content=messages)]
-        elif isinstance(messages, BaseMessage):
-            input_ = [messages]
-        else:
-            input_ = messages
-
-        pred = self.agent.generate(messages=[input_], **kwargs)
-        all_text = [each.text for each in pred.generations[0]]
-        all_messages = [each.message for each in pred.generations[0]]
-
-        completion_tokens, total_tokens, prompt_tokens = 0, 0, 0
-        try:
-            if pred.llm_output is not None:
-                completion_tokens = pred.llm_output["token_usage"]["completion_tokens"]
-                total_tokens = pred.llm_output["token_usage"]["total_tokens"]
-                prompt_tokens = pred.llm_output["token_usage"]["prompt_tokens"]
-        except Exception:
-            logger.warning(
-                f"Cannot get token usage from LLM output for {self._lc_class.__name__}"
-            )
-
-        return LLMInterface(
-            text=all_text[0] if len(all_text) > 0 else "",
-            candidates=all_text,
-            completion_tokens=completion_tokens,
-            total_tokens=total_tokens,
-            prompt_tokens=prompt_tokens,
-            messages=all_messages,
-            logits=[],
-        )
-
-    def __setattr__(self, name, value):
-        if name in self._lc_class.__fields__:
-            self._kwargs[name] = value
-            setattr(self.agent, name, value)
-        else:
-            super().__setattr__(name, value)
-
-    def __getattr__(self, name):
-        if name in self._lc_class.__fields__:
-            return getattr(self.agent, name)
-
-        return super().__getattr__(name)  # type: ignore

View File

@@ -0,0 +1,149 @@
+from __future__ import annotations
+
+import logging
+
+from kotaemon.base import BaseMessage, HumanMessage, LLMInterface
+
+from .base import ChatLLM
+
+logger = logging.getLogger(__name__)
+
+
+class LCChatMixin:
+    def _get_lc_class(self):
+        raise NotImplementedError(
+            "Please return the relevant Langchain class in _get_lc_class"
+        )
+
+    def __init__(self, **params):
+        self._lc_class = self._get_lc_class()
+        self._obj = self._lc_class(**params)
+        self._kwargs: dict = params
+
+        super().__init__()
+
+    def run(
+        self, messages: str | BaseMessage | list[BaseMessage], **kwargs
+    ) -> LLMInterface:
+        """Generate response from messages
+
+        Args:
+            messages: history of messages to generate response from
+            **kwargs: additional arguments to pass to the langchain chat model
+
+        Returns:
+            LLMInterface: generated response
+        """
+        input_: list[BaseMessage] = []
+        if isinstance(messages, str):
+            input_ = [HumanMessage(content=messages)]
+        elif isinstance(messages, BaseMessage):
+            input_ = [messages]
+        else:
+            input_ = messages
+
+        pred = self._obj.generate(messages=[input_], **kwargs)
+        all_text = [each.text for each in pred.generations[0]]
+        all_messages = [each.message for each in pred.generations[0]]
+
+        completion_tokens, total_tokens, prompt_tokens = 0, 0, 0
+        try:
+            if pred.llm_output is not None:
+                completion_tokens = pred.llm_output["token_usage"]["completion_tokens"]
+                total_tokens = pred.llm_output["token_usage"]["total_tokens"]
+                prompt_tokens = pred.llm_output["token_usage"]["prompt_tokens"]
+        except Exception:
+            logger.warning(
+                f"Cannot get token usage from LLM output for {self._lc_class.__name__}"
+            )
+
+        return LLMInterface(
+            text=all_text[0] if len(all_text) > 0 else "",
+            candidates=all_text,
+            completion_tokens=completion_tokens,
+            total_tokens=total_tokens,
+            prompt_tokens=prompt_tokens,
+            messages=all_messages,
+            logits=[],
+        )
+
+    def __repr__(self):
+        kwargs = []
+        for key, value_obj in self._kwargs.items():
+            value = repr(value_obj)
+            kwargs.append(f"{key}={value}")
+        kwargs_repr = ", ".join(kwargs)
+        return f"{self.__class__.__name__}({kwargs_repr})"
+
+    def __str__(self):
+        kwargs = []
+        for key, value_obj in self._kwargs.items():
+            value = str(value_obj)
+            if len(value) > 20:
+                value = f"{value[:15]}..."
+            kwargs.append(f"{key}={value}")
+        kwargs_repr = ", ".join(kwargs)
+        return f"{self.__class__.__name__}({kwargs_repr})"
+
+    def __setattr__(self, name, value):
+        if name == "_lc_class":
+            return super().__setattr__(name, value)
+
+        if name in self._lc_class.__fields__:
+            self._kwargs[name] = value
+            self._obj = self._lc_class(**self._kwargs)
+        else:
+            super().__setattr__(name, value)
+
+    def __getattr__(self, name):
+        if name in self._kwargs:
+            return self._kwargs[name]
+        return getattr(self._obj, name)
+
+    def dump(self):
+        return {
+            "__type__": f"{self.__module__}.{self.__class__.__qualname__}",
+            **self._kwargs,
+        }
+
+    def specs(self, path: str):
+        path = path.strip(".")
+        if "." in path:
+            raise ValueError("path should not contain '.'")
+
+        if path in self._lc_class.__fields__:
+            return {
+                "__type__": "theflow.base.ParamAttr",
+                "refresh_on_set": True,
+                "strict_type": True,
+            }
+
+        raise ValueError(f"Invalid param {path}")
+
+
+class AzureChatOpenAI(LCChatMixin, ChatLLM):
+    def __init__(
+        self,
+        azure_endpoint: str | None = None,
+        openai_api_key: str | None = None,
+        openai_api_version: str = "",
+        deployment_name: str | None = None,
+        temperature: float = 0.7,
+        request_timeout: float | None = None,
+        **params,
+    ):
+        super().__init__(
+            azure_endpoint=azure_endpoint,
+            openai_api_key=openai_api_key,
+            openai_api_version=openai_api_version,
+            deployment_name=deployment_name,
+            temperature=temperature,
+            request_timeout=request_timeout,
+            **params,
+        )
+
+    def _get_lc_class(self):
+        import langchain.chat_models
+
+        return langchain.chat_models.AzureChatOpenAI
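
A hedged usage sketch of the chat wrapper above; all credential values are placeholders:

    # Hedged usage sketch; endpoint, key and deployment name are placeholders.
    from kotaemon.llms import AzureChatOpenAI

    llm = AzureChatOpenAI(
        azure_endpoint="https://example.openai.azure.com/",  # placeholder
        openai_api_key="...",                                # placeholder
        openai_api_version="2023-05-15",
        deployment_name="gpt-35-turbo",                      # placeholder
        temperature=0,
    )
    result = llm("What is the capital of France?")  # run() -> LLMInterface
    print(result.text, result.total_tokens)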

View File

@@ -1,7 +0,0 @@
-from langchain.chat_models import AzureChatOpenAI as AzureChatOpenAILC
-
-from .base import LangchainChatLLM
-
-
-class AzureChatOpenAI(LangchainChatLLM):
-    _lc_class = AzureChatOpenAILC

View File

@@ -1,4 +1,4 @@
 from .base import LLM
-from .openai import AzureOpenAI, OpenAI
+from .langchain_based import AzureOpenAI, OpenAI

 __all__ = ["LLM", "OpenAI", "AzureOpenAI"]

View File

@@ -1,66 +1,5 @@
-import logging
-from typing import Type
-
-from langchain.llms.base import BaseLLM
-from theflow.base import Param
-
-from ...base import BaseComponent
-from ...base.schema import LLMInterface
-
-logger = logging.getLogger(__name__)
+from kotaemon.base import BaseComponent


 class LLM(BaseComponent):
     pass
-
-
-class LangchainLLM(LLM):
-    _lc_class: Type[BaseLLM]
-
-    def __init__(self, **params):
-        if self._lc_class is None:
-            raise AttributeError(
-                "Should set _lc_class attribute to the LLM class from Langchain "
-                "if using LLM from Langchain"
-            )
-
-        self._kwargs: dict = {}
-        for param in list(params.keys()):
-            if param in self._lc_class.__fields__:
-                self._kwargs[param] = params.pop(param)
-
-        super().__init__(**params)
-
-    @Param.auto(cache=False)
-    def agent(self):
-        return self._lc_class(**self._kwargs)
-
-    def run(self, text: str) -> LLMInterface:
-        pred = self.agent.generate([text])
-        all_text = [each.text for each in pred.generations[0]]
-
-        completion_tokens, total_tokens, prompt_tokens = 0, 0, 0
-        try:
-            if pred.llm_output is not None:
-                completion_tokens = pred.llm_output["token_usage"]["completion_tokens"]
-                total_tokens = pred.llm_output["token_usage"]["total_tokens"]
-                prompt_tokens = pred.llm_output["token_usage"]["prompt_tokens"]
-        except Exception:
-            logger.warning(
-                f"Cannot get token usage from LLM output for {self._lc_class.__name__}"
-            )
-
-        return LLMInterface(
-            text=all_text[0] if len(all_text) > 0 else "",
-            candidates=all_text,
-            completion_tokens=completion_tokens,
-            total_tokens=total_tokens,
-            prompt_tokens=prompt_tokens,
-            logits=[],
-        )
-
-    def __setattr__(self, name, value):
-        if name in self._lc_class.__fields__:
-            self._kwargs[name] = value
-            setattr(self.agent, name, value)
-        else:
-            super().__setattr__(name, value)

View File

@@ -0,0 +1,185 @@
+import logging
+from typing import Optional
+
+from kotaemon.base import LLMInterface
+
+from .base import LLM
+
+logger = logging.getLogger(__name__)
+
+
+class LCCompletionMixin:
+    def _get_lc_class(self):
+        raise NotImplementedError(
+            "Please return the relevant Langchain class in _get_lc_class"
+        )
+
+    def __init__(self, **params):
+        self._lc_class = self._get_lc_class()
+        self._obj = self._lc_class(**params)
+        self._kwargs: dict = params
+
+        super().__init__()
+
+    def run(self, text: str) -> LLMInterface:
+        pred = self._obj.generate([text])
+        all_text = [each.text for each in pred.generations[0]]
+
+        completion_tokens, total_tokens, prompt_tokens = 0, 0, 0
+        try:
+            if pred.llm_output is not None:
+                completion_tokens = pred.llm_output["token_usage"]["completion_tokens"]
+                total_tokens = pred.llm_output["token_usage"]["total_tokens"]
+                prompt_tokens = pred.llm_output["token_usage"]["prompt_tokens"]
+        except Exception:
+            logger.warning(
+                f"Cannot get token usage from LLM output for {self._lc_class.__name__}"
+            )
+
+        return LLMInterface(
+            text=all_text[0] if len(all_text) > 0 else "",
+            candidates=all_text,
+            completion_tokens=completion_tokens,
+            total_tokens=total_tokens,
+            prompt_tokens=prompt_tokens,
+            logits=[],
+        )
+
+    def __repr__(self):
+        kwargs = []
+        for key, value_obj in self._kwargs.items():
+            value = repr(value_obj)
+            kwargs.append(f"{key}={value}")
+        kwargs_repr = ", ".join(kwargs)
+        return f"{self.__class__.__name__}({kwargs_repr})"
+
+    def __str__(self):
+        kwargs = []
+        for key, value_obj in self._kwargs.items():
+            value = str(value_obj)
+            if len(value) > 20:
+                value = f"{value[:15]}..."
+            kwargs.append(f"{key}={value}")
+        kwargs_repr = ", ".join(kwargs)
+        return f"{self.__class__.__name__}({kwargs_repr})"
+
+    def __setattr__(self, name, value):
+        if name == "_lc_class":
+            return super().__setattr__(name, value)
+
+        if name in self._lc_class.__fields__:
+            self._kwargs[name] = value
+            self._obj = self._lc_class(**self._kwargs)
+        else:
+            super().__setattr__(name, value)
+
+    def __getattr__(self, name):
+        if name in self._kwargs:
+            return self._kwargs[name]
+        return getattr(self._obj, name)
+
+    def dump(self):
+        return {
+            "__type__": f"{self.__module__}.{self.__class__.__qualname__}",
+            **self._kwargs,
+        }
+
+    def specs(self, path: str):
+        path = path.strip(".")
+        if "." in path:
+            raise ValueError("path should not contain '.'")
+
+        if path in self._lc_class.__fields__:
+            return {
+                "__type__": "theflow.base.ParamAttr",
+                "refresh_on_set": True,
+                "strict_type": True,
+            }
+
+        raise ValueError(f"Invalid param {path}")
+
+
+class OpenAI(LCCompletionMixin, LLM):
+    """Wrapper around Langchain's OpenAI class, focusing on key parameters"""
+
+    def __init__(
+        self,
+        openai_api_key: Optional[str] = None,
+        openai_api_base: Optional[str] = None,
+        model_name: str = "text-davinci-003",
+        temperature: float = 0.7,
+        max_token: int = 256,
+        top_p: float = 1,
+        frequency_penalty: float = 0,
+        n: int = 1,
+        best_of: int = 1,
+        request_timeout: Optional[float] = None,
+        max_retries: int = 2,
+        streaming: bool = False,
+        **params,
+    ):
+        super().__init__(
+            openai_api_key=openai_api_key,
+            openai_api_base=openai_api_base,
+            model_name=model_name,
+            temperature=temperature,
+            max_token=max_token,
+            top_p=top_p,
+            frequency_penalty=frequency_penalty,
+            n=n,
+            best_of=best_of,
+            request_timeout=request_timeout,
+            max_retries=max_retries,
+            streaming=streaming,
+            **params,
+        )
+
+    def _get_lc_class(self):
+        import langchain.llms as langchain_llms
+
+        return langchain_llms.OpenAI
+
+
+class AzureOpenAI(LCCompletionMixin, LLM):
+    """Wrapper around Langchain's AzureOpenAI class, focusing on key parameters"""
+
+    def __init__(
+        self,
+        azure_endpoint: Optional[str] = None,
+        deployment_name: Optional[str] = None,
+        openai_api_version: str = "",
+        openai_api_key: Optional[str] = None,
+        model_name: str = "text-davinci-003",
+        temperature: float = 0.7,
+        max_token: int = 256,
+        top_p: float = 1,
+        frequency_penalty: float = 0,
+        n: int = 1,
+        best_of: int = 1,
+        request_timeout: Optional[float] = None,
+        max_retries: int = 2,
+        streaming: bool = False,
+        **params,
+    ):
+        super().__init__(
+            azure_endpoint=azure_endpoint,
+            deployment_name=deployment_name,
+            openai_api_version=openai_api_version,
+            openai_api_key=openai_api_key,
+            model_name=model_name,
+            temperature=temperature,
+            max_token=max_token,
+            top_p=top_p,
+            frequency_penalty=frequency_penalty,
+            n=n,
+            best_of=best_of,
+            request_timeout=request_timeout,
+            max_retries=max_retries,
+            streaming=streaming,
+            **params,
+        )
+
+    def _get_lc_class(self):
+        import langchain.llms as langchain_llms
+
+        return langchain_llms.AzureOpenAI
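
A hedged sketch of the serialization hooks these mixins share: `dump()` records the dotted type path plus the captured constructor kwargs, and `specs()` reports Langchain fields as theflow params. The printed shapes follow the code above; concrete values are placeholders:

    # Hedged sketch of dump()/specs(); values are placeholders.
    from kotaemon.llms import AzureOpenAI

    model = AzureOpenAI(
        azure_endpoint="https://example.openai.azure.com/",  # placeholder
        deployment_name="text-davinci-003",                  # placeholder
        openai_api_key="...",                                # placeholder
        request_timeout=60,
    )
    print(model.dump())
    # {'__type__': '...AzureOpenAI', 'azure_endpoint': '...',
    #  'deployment_name': '...', ...}
    print(model.specs("temperature"))
    # {'__type__': 'theflow.base.ParamAttr', 'refresh_on_set': True,
    #  'strict_type': True}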

View File

@@ -1,15 +0,0 @@
-import langchain.llms as langchain_llms
-
-from .base import LangchainLLM
-
-
-class OpenAI(LangchainLLM):
-    """Wrapper around Langchain's OpenAI class"""
-
-    _lc_class = langchain_llms.OpenAI
-
-
-class AzureOpenAI(LangchainLLM):
-    """Wrapper around Langchain's AzureOpenAI class"""
-
-    _lc_class = langchain_llms.AzureOpenAI

View File

@@ -21,8 +21,7 @@ class SimpleLinearPipeline(BaseComponent):
         post-processor component or function.

     Example Usage:
-        from kotaemon.llms.chats.openai import AzureChatOpenAI
-        from kotaemon.llms import BasePromptComponent
+        from kotaemon.llms import AzureChatOpenAI, BasePromptComponent

         def identity(x):
             return x

@@ -87,8 +86,7 @@ class GatedLinearPipeline(SimpleLinearPipeline):
         condition.

     Example Usage:
-        from kotaemon.llms.chats.openai import AzureChatOpenAI
-        from kotaemon.llms import BasePromptComponent
+        from kotaemon.llms import AzureChatOpenAI, BasePromptComponent
         from kotaemon.parsers import RegexExtractor

         def identity(x):

View File

@@ -6,7 +6,7 @@ from theflow.utils.modules import ObjectInitDeclaration as _

 from kotaemon.base import BaseComponent
 from kotaemon.embeddings import AzureOpenAIEmbeddings
-from kotaemon.llms.completions.openai import AzureOpenAI
+from kotaemon.llms import AzureOpenAI
 from kotaemon.pipelines.indexing import IndexVectorStoreFromDocumentPipeline
 from kotaemon.pipelines.retrieving import RetrieveDocumentFromVectorStorePipeline
 from kotaemon.storages import ChromaVectorStore, InMemoryDocumentStore

View File

@@ -6,7 +6,7 @@ from theflow.utils.modules import ObjectInitDeclaration as _

 from kotaemon.base import BaseComponent
 from kotaemon.embeddings import AzureOpenAIEmbeddings
 from kotaemon.indices import VectorRetrieval
-from kotaemon.llms.completions.openai import AzureOpenAI
+from kotaemon.llms import AzureOpenAI
 from kotaemon.storages import ChromaVectorStore

View File

@@ -8,7 +8,7 @@ from kotaemon.agents.langchain import LangchainAgent
 from kotaemon.agents.react import ReactAgent
 from kotaemon.agents.rewoo import RewooAgent
 from kotaemon.agents.tools import BaseTool, GoogleSearchTool, LLMTool, WikipediaTool
-from kotaemon.llms.chats.openai import AzureChatOpenAI
+from kotaemon.llms import AzureChatOpenAI

 FINAL_RESPONSE_TEXT = "Final Answer: Hello Cinnamon AI!"

@@ -195,7 +195,7 @@ def test_react_agent_langchain(openai_completion, llm, mock_google_search):
     langchain_plugins = [tool.to_langchain_format() for tool in plugins]
     agent = initialize_agent(
         langchain_plugins,
-        llm.agent,
+        llm._obj,
         agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
         verbose=True,
     )

View File

@@ -5,7 +5,7 @@ import pytest
 from openai.types.chat.chat_completion import ChatCompletion

 from kotaemon.indices.qa import CitationPipeline
-from kotaemon.llms.chats.openai import AzureChatOpenAI
+from kotaemon.llms import AzureChatOpenAI

 function_output = '{\n "question": "What is the provided _example_ benefits?",\n "answer": [\n {\n "fact": "特約死亡保険金: 被保険者がこの特約の保険期間中に死亡したときに支払います。",\n "substring_quote": ["特約死亡保険金"]\n },\n {\n "fact": "特約特定疾病保険金: 被保険者がこの特約の保険期間中に特定の疾病(悪性新生物(がん)、急性心筋梗塞または脳卒中)により所定の状態に該当したときに支払います。",\n "substring_quote": ["特約特定疾病保険金"]\n },\n {\n "fact": "特約障害保険金: 被保険者がこの特約の保険期間中に傷害もしくは疾病により所定の身体障害の状態に該当したとき、または不慮の事故により所定の身体障害の状態に該当したときに支払います。",\n "substring_quote": ["特約障害保険金"]\n },\n {\n "fact": "特約介護保険金: 被保険者がこの特約の保険期間中に傷害または疾病により所定の要介護状態に該当したときに支払います。",\n "substring_quote": ["特約介護保険金"]\n }\n ]\n}'

View File

@@ -3,9 +3,11 @@ from pathlib import Path
 from unittest.mock import patch

 from kotaemon.base import Document
-from kotaemon.embeddings.cohere import CohereEmbdeddings
-from kotaemon.embeddings.huggingface import HuggingFaceEmbeddings
-from kotaemon.embeddings.openai import AzureOpenAIEmbeddings
+from kotaemon.embeddings import (
+    AzureOpenAIEmbeddings,
+    CohereEmbdeddings,
+    HuggingFaceEmbeddings,
+)

 with open(Path(__file__).parent / "resources" / "embedding_openai_batch.json") as f:
     openai_embedding_batch = json.load(f)

@@ -60,7 +62,7 @@ def test_azureopenai_embeddings_batch_raw(openai_embedding_call):
     "langchain.embeddings.huggingface.HuggingFaceBgeEmbeddings.embed_documents",
     side_effect=lambda *args, **kwargs: [[1.0, 2.1, 3.2]],
 )
-def test_huggingface_embddings(
+def test_huggingface_embeddings(
     langchain_huggingface_embedding_call, sentence_transformers_init
 ):
     model = HuggingFaceEmbeddings(

View File

@@ -6,7 +6,7 @@ import pytest
 from openai.resources.embeddings import Embeddings

 from kotaemon.base import Document
-from kotaemon.embeddings.openai import AzureOpenAIEmbeddings
+from kotaemon.embeddings import AzureOpenAIEmbeddings
 from kotaemon.indices import VectorIndexing, VectorRetrieval
 from kotaemon.storages import ChromaVectorStore, InMemoryDocumentStore

View File

@@ -9,7 +9,7 @@ from kotaemon.base.schema import (
     LLMInterface,
     SystemMessage,
 )
-from kotaemon.llms.chats.openai import AzureChatOpenAI
+from kotaemon.llms import AzureChatOpenAI

 _openai_chat_completion_response = ChatCompletion.parse_obj(
     {

@@ -48,7 +48,7 @@ def test_azureopenai_model(openai_completion):
         temperature=0,
     )
     assert isinstance(
-        model.agent, AzureChatOpenAILC
+        model._obj, AzureChatOpenAILC
     ), "Agent not wrapped in Langchain's AzureChatOpenAI"

     # test for str input - stream mode

View File

@@ -5,7 +5,7 @@ from langchain.llms import OpenAI as OpenAILC
 from openai.types.completion import Completion

 from kotaemon.base.schema import LLMInterface
-from kotaemon.llms.completions.openai import AzureOpenAI, OpenAI
+from kotaemon.llms import AzureOpenAI, OpenAI

 _openai_completion_response = Completion.parse_obj(
     {

@@ -41,7 +41,7 @@ def test_azureopenai_model(openai_completion):
         request_timeout=60,
     )
     assert isinstance(
-        model.agent, AzureOpenAILC
+        model._obj, AzureOpenAILC
     ), "Agent not wrapped in Langchain's AzureOpenAI"

     output = model("hello world")

@@ -64,7 +64,7 @@ def test_openai_model(openai_completion):
         request_timeout=60,
     )
     assert isinstance(
-        model.agent, OpenAILC
+        model._obj, OpenAILC
     ), "Agent is not wrapped in Langchain's OpenAI"

     output = model("hello world")

View File

@@ -5,7 +5,7 @@ from openai.types.chat.chat_completion import ChatCompletion

 from kotaemon.base import Document
 from kotaemon.indices.rankings import LLMReranking
-from kotaemon.llms.chats.openai import AzureChatOpenAI
+from kotaemon.llms import AzureChatOpenAI

 _openai_chat_completion_responses = [
     ChatCompletion.parse_obj(

View File

@@ -6,7 +6,7 @@ from openai.resources.embeddings import Embeddings

 from kotaemon.agents.tools import ComponentTool, GoogleSearchTool, WikipediaTool
 from kotaemon.base import Document
-from kotaemon.embeddings.openai import AzureOpenAIEmbeddings
+from kotaemon.embeddings import AzureOpenAIEmbeddings
 from kotaemon.indices.vectorindex import VectorIndexing, VectorRetrieval
 from kotaemon.storages import ChromaVectorStore, InMemoryDocumentStore