Provide type hints for pass-through Langchain and Llama-index objects (#95)

Duc Nguyen (john)
2023-12-04 10:59:13 +07:00
committed by GitHub
parent e34b1e4c6d
commit 0ce3a8832f
34 changed files with 641 additions and 310 deletions

View File

@@ -54,7 +54,7 @@ class LangchainAgent(BaseAgent):
# reinit Langchain AgentExecutor
self.agent = initialize_agent(
langchain_plugins,
self.llm.agent,
self.llm._obj,
agent=self.AGENT_TYPE_MAP[self.agent_type],
handle_parsing_errors=True,
verbose=True,
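With this refactor, kotaemon wrappers expose the underlying Langchain object through a typed _obj attribute instead of the old agent property, so call sites that need the raw object (such as initialize_agent above) now pass self.llm._obj. A minimal sketch of the convention, with placeholder credentials:

from kotaemon.llms import AzureChatOpenAI

llm = AzureChatOpenAI(
    azure_endpoint="https://<resource>.openai.azure.com/",
    openai_api_key="...",
)
raw = llm._obj  # the wrapped langchain.chat_models.AzureChatOpenAI instance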

View File

@@ -21,7 +21,7 @@ class LLMTool(BaseTool):
"are confident in solving the problem "
"yourself. Input can be any instruction."
)
llm: BaseLLM = AzureChatOpenAI()
llm: BaseLLM = AzureChatOpenAI.withx()
args_schema: Optional[Type[BaseModel]] = LLMArgs
def _run_tool(self, query: AnyStr) -> str:
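The withx() call is theflow's deferred constructor: it records the keyword arguments without building the model, so the AzureChatOpenAI default is created lazily per LLMTool instance rather than once at class-definition time. A hedged sketch of the semantics:

from kotaemon.llms import AzureChatOpenAI

# Nothing is instantiated here; the arguments are only recorded and the
# component is constructed later, when the field is actually used.
lazy_default = AzureChatOpenAI.withx(temperature=0.0)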

View File

@@ -1,4 +1,15 @@
from .base import BaseEmbeddings
from .openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
from .langchain_based import (
AzureOpenAIEmbeddings,
CohereEmbdeddings,
HuggingFaceEmbeddings,
OpenAIEmbeddings,
)
__all__ = ["BaseEmbeddings", "OpenAIEmbeddings", "AzureOpenAIEmbeddings"]
__all__ = [
"BaseEmbeddings",
"OpenAIEmbeddings",
"AzureOpenAIEmbeddings",
"CohereEmbdeddings",
"HuggingFaceEmbeddings",
]
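After this change the whole Langchain-backed family is re-exported from the package root (class names as spelled in this commit):

from kotaemon.embeddings import (
    AzureOpenAIEmbeddings,
    CohereEmbdeddings,
    HuggingFaceEmbeddings,
    OpenAIEmbeddings,
)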

View File

@@ -1,10 +1,6 @@
from __future__ import annotations
from abc import abstractmethod
from typing import Type
from langchain.schema.embeddings import Embeddings as LCEmbeddings
from theflow import Param
from kotaemon.base import BaseComponent, Document, DocumentWithEmbedding
@@ -15,52 +11,3 @@ class BaseEmbeddings(BaseComponent):
self, text: str | list[str] | Document | list[Document]
) -> list[DocumentWithEmbedding]:
...
class LangchainEmbeddings(BaseEmbeddings):
_lc_class: Type[LCEmbeddings]
def __init__(self, **params):
if self._lc_class is None:
raise AttributeError(
"Should set _lc_class attribute to the LLM class from Langchain "
"if using LLM from Langchain"
)
self._kwargs: dict = {}
for param in list(params.keys()):
if param in self._lc_class.__fields__: # type: ignore
self._kwargs[param] = params.pop(param)
super().__init__(**params)
def __setattr__(self, name, value):
if name in self._lc_class.__fields__:
self._kwargs[name] = value
else:
super().__setattr__(name, value)
@Param.auto(cache=False)
def agent(self):
return self._lc_class(**self._kwargs)
def run(self, text):
input_: list[str] = []
if not isinstance(text, list):
text = [text]
for item in text:
if isinstance(item, str):
input_.append(item)
elif isinstance(item, Document):
input_.append(item.text)
else:
raise ValueError(
f"Invalid input type {type(item)}, should be str or Document"
)
embeddings = self.agent.embed_documents(input_)
return [
DocumentWithEmbedding(text=each_text, embedding=each_embedding)
for each_text, each_embedding in zip(input_, embeddings)
]

View File

@@ -1,12 +0,0 @@
from langchain.embeddings import CohereEmbeddings as LCCohereEmbeddings
from kotaemon.embeddings.base import LangchainEmbeddings
class CohereEmbdeddings(LangchainEmbeddings):
"""Cohere embeddings.
This class wraps around the Langchain CohereEmbeddings class.
"""
_lc_class = LCCohereEmbeddings

View File

@@ -1,12 +0,0 @@
from langchain.embeddings import HuggingFaceBgeEmbeddings as LCHuggingFaceEmbeddings
from kotaemon.embeddings.base import LangchainEmbeddings
class HuggingFaceEmbeddings(LangchainEmbeddings):
"""HuggingFace embeddings
This class wraps around the Langchain HuggingFaceBgeEmbeddings class
"""
_lc_class = LCHuggingFaceEmbeddings

View File

@@ -0,0 +1,194 @@
from typing import Optional
from kotaemon.base import Document, DocumentWithEmbedding
from .base import BaseEmbeddings
class LCEmbeddingMixin:
def _get_lc_class(self):
raise NotImplementedError(
"Please return the relevant Langchain class in in _get_lc_class"
)
def __init__(self, **params):
self._lc_class = self._get_lc_class()
self._obj = self._lc_class(**params)
self._kwargs: dict = params
super().__init__()
def run(self, text):
input_: list[str] = []
if not isinstance(text, list):
text = [text]
for item in text:
if isinstance(item, str):
input_.append(item)
elif isinstance(item, Document):
input_.append(item.text)
else:
raise ValueError(
f"Invalid input type {type(item)}, should be str or Document"
)
embeddings = self._obj.embed_documents(input_)
return [
DocumentWithEmbedding(text=each_text, embedding=each_embedding)
for each_text, each_embedding in zip(input_, embeddings)
]
def __repr__(self):
kwargs = []
for key, value_obj in self._kwargs.items():
value = repr(value_obj)
kwargs.append(f"{key}={value}")
kwargs_repr = ", ".join(kwargs)
return f"{self.__class__.__name__}({kwargs_repr})"
def __str__(self):
kwargs = []
for key, value_obj in self._kwargs.items():
value = str(value_obj)
if len(value) > 20:
value = f"{value[:15]}..."
kwargs.append(f"{key}={value}")
kwargs_repr = ", ".join(kwargs)
return f"{self.__class__.__name__}({kwargs_repr})"
def __setattr__(self, name, value):
if name == "_lc_class":
return super().__setattr__(name, value)
if name in self._lc_class.__fields__:
self._kwargs[name] = value
self._obj = self._lc_class(**self._kwargs)
else:
super().__setattr__(name, value)
def __getattr__(self, name):
if name in self._kwargs:
return self._kwargs[name]
return getattr(self._obj, name)
def dump(self):
return {
"__type__": f"{self.__module__}.{self.__class__.__qualname__}",
**self._kwargs,
}
def specs(self, path: str):
path = path.strip(".")
if "." in path:
raise ValueError("path should not contain '.'")
if path in self._lc_class.__fields__:
return {
"__type__": "theflow.base.ParamAttr",
"refresh_on_set": True,
"strict_type": True,
}
raise ValueError(f"Invalid param {path}")
class OpenAIEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
"""Wrapper around Langchain's OpenAI embedding, focusing on key parameters"""
def __init__(
self,
model: str = "text-embedding-ada-002",
openai_api_version: Optional[str] = None,
openai_api_base: Optional[str] = None,
openai_api_type: Optional[str] = None,
openai_api_key: Optional[str] = None,
request_timeout: Optional[float] = None,
**params,
):
super().__init__(
model=model,
openai_api_version=openai_api_version,
openai_api_base=openai_api_base,
openai_api_type=openai_api_type,
openai_api_key=openai_api_key,
request_timeout=request_timeout,
**params,
)
def _get_lc_class(self):
import langchain.embeddings
return langchain.embeddings.OpenAIEmbeddings
class AzureOpenAIEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
"""Wrapper around Langchain's AzureOpenAI embedding, focusing on key parameters"""
def __init__(
self,
azure_endpoint: Optional[str] = None,
deployment: Optional[str] = None,
openai_api_key: Optional[str] = None,
openai_api_version: Optional[str] = None,
request_timeout: Optional[float] = None,
**params,
):
super().__init__(
azure_endpoint=azure_endpoint,
deployment=deployment,
openai_api_version=openai_api_version,
openai_api_key=openai_api_key,
request_timeout=request_timeout,
**params,
)
def _get_lc_class(self):
import langchain.embeddings
return langchain.embeddings.AzureOpenAIEmbeddings
class CohereEmbdeddings(LCEmbeddingMixin, BaseEmbeddings):
"""Wrapper around Langchain's Cohere embedding, focusing on key parameters"""
def __init__(
self,
model: str = "embed-english-v2.0",
cohere_api_key: Optional[str] = None,
truncate: Optional[str] = None,
request_timeout: Optional[float] = None,
**params,
):
super().__init__(
model=model,
cohere_api_key=cohere_api_key,
truncate=truncate,
request_timeout=request_timeout,
**params,
)
def _get_lc_class(self):
import langchain.embeddings
return langchain.embeddings.CohereEmbeddings
class HuggingFaceEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
"""Wrapper around Langchain's HuggingFace embedding, focusing on key parameters"""
def __init__(
self,
model_name: str = "sentence-transformers/all-mpnet-base-v2",
**params,
):
super().__init__(
model_name=model_name,
**params,
)
def _get_lc_class(self):
import langchain.embeddings
return langchain.embeddings.HuggingFaceBgeEmbeddings
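A minimal usage sketch of the mixin-based wrappers, with placeholder credentials: constructor parameters are forwarded to the Langchain class, run accepts strings or Documents, and assigning to a declared parameter rebuilds the wrapped _obj:

from kotaemon.embeddings import AzureOpenAIEmbeddings

embedder = AzureOpenAIEmbeddings(
    azure_endpoint="https://<resource>.openai.azure.com/",
    deployment="<embedding-deployment>",
    openai_api_key="...",
    openai_api_version="2023-05-15",
)
docs = embedder.run(["hello world"])  # -> list[DocumentWithEmbedding]
embedder.request_timeout = 30.0       # re-instantiates embedder._obj
state = embedder.dump()               # {"__type__": "...", "azure_endpoint": ..., ...}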

View File

@@ -1,21 +0,0 @@
from langchain import embeddings as lcembeddings
from .base import LangchainEmbeddings
class OpenAIEmbeddings(LangchainEmbeddings):
"""OpenAI embeddings.
This class wraps around the Langchain OpenAIEmbeddings class.
"""
_lc_class = lcembeddings.OpenAIEmbeddings
class AzureOpenAIEmbeddings(LangchainEmbeddings):
"""Azure OpenAI embeddings.
This class wraps around the Langchain AzureOpenAIEmbeddings class.
"""
_lc_class = lcembeddings.AzureOpenAIEmbeddings

View File

@@ -46,20 +46,48 @@ class LlamaIndexDocTransformerMixin:
"Please return the relevant LlamaIndex class in _get_li_class"
)
def __init__(self, *args, **kwargs):
_li_cls = self._get_li_class()
self._obj = _li_cls(*args, **kwargs)
def __init__(self, **params):
self._li_cls = self._get_li_class()
self._obj = self._li_cls(**params)
self._kwargs = params
super().__init__()
def __repr__(self):
kwargs = []
for key, value_obj in self._kwargs.items():
value = repr(value_obj)
kwargs.append(f"{key}={value}")
kwargs_repr = ", ".join(kwargs)
return f"{self.__class__.__name__}({kwargs_repr})"
def __str__(self):
kwargs = []
for key, value_obj in self._kwargs.items():
value = str(value_obj)
if len(value) > 20:
value = f"{value[:15]}..."
kwargs.append(f"{key}={value}")
kwargs_repr = ", ".join(kwargs)
return f"{self.__class__.__name__}({kwargs_repr})"
def __setattr__(self, name: str, value: Any) -> None:
if name.startswith("_") or name in self._protected_keywords():
return super().__setattr__(name, value)
self._kwargs[name] = value
return setattr(self._obj, name, value)
def __getattr__(self, name: str) -> Any:
if name in self._kwargs:
return self._kwargs[name]
return getattr(self._obj, name)
def dump(self):
return {
"__type__": f"{self.__module__}.{self.__class__.__qualname__}",
**self._kwargs,
}
def run(
self,
documents: list[Document],
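The same contract applies to new transformers: forward the declared parameters to __init__ and return the LlamaIndex class lazily from _get_li_class. A hypothetical subclass as illustration (this KeywordExtractor wrapper is an example, not part of this commit):

class KeywordExtractor(LlamaIndexDocTransformerMixin, BaseDocParser):
    def __init__(self, llm=None, keywords: int = 5, **params):
        super().__init__(llm=llm, keywords=keywords, **params)

    def _get_li_class(self):
        # Lazy import: llama_index is only required when the object is built.
        from llama_index.extractors import KeywordExtractor

        return KeywordExtractor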

View File

@@ -6,6 +6,14 @@ class BaseDocParser(DocTransformer):
class TitleExtractor(LlamaIndexDocTransformerMixin, BaseDocParser):
def __init__(
self,
llm=None,
nodes: int = 5,
**params,
):
super().__init__(llm=llm, nodes=nodes, **params)
def _get_li_class(self):
from llama_index.extractors import TitleExtractor
@@ -13,6 +21,14 @@ class TitleExtractor(LlamaIndexDocTransformerMixin, BaseDocParser):
class SummaryExtractor(LlamaIndexDocTransformerMixin, BaseDocParser):
def __init__(
self,
llm=None,
summaries: list[str] = ["self"],
**params,
):
super().__init__(llm=llm, summaries=summaries, **params)
def _get_li_class(self):
from llama_index.extractors import SummaryExtractor
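A minimal usage sketch for these extractors; with llm=None, LlamaIndex falls back to its default LLM, which needs credentials in the environment:

from kotaemon.base import Document

extractor = TitleExtractor(nodes=5)
enriched = extractor.run([Document(text="...a long input document...")])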

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
from abc import abstractmethod
from ...base import BaseComponent, Document
from kotaemon.base import BaseComponent, Document
class BaseReranking(BaseComponent):

View File

@@ -2,7 +2,8 @@ from __future__ import annotations
import os
from ...base import Document
from kotaemon.base import Document
from .base import BaseReranking

View File

@@ -1,17 +1,13 @@
from __future__ import annotations
from concurrent.futures import ThreadPoolExecutor
from typing import Union
from langchain.output_parsers.boolean import BooleanOutputParser
from ...base import Document
from ...llms import PromptTemplate
from ...llms.chats.base import ChatLLM
from ...llms.completions.base import LLM
from .base import BaseReranking
from kotaemon.base import Document
from kotaemon.llms import BaseLLM, PromptTemplate
BaseLLM = Union[ChatLLM, LLM]
from .base import BaseReranking
RERANK_PROMPT_TEMPLATE = """Given the following question and context,
return YES if the context is relevant to the question and NO if it isn't.
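This template drives an LLM-as-judge reranker: each candidate document is scored YES/NO against the query, with BooleanOutputParser reading the verdict. A hypothetical call sketch; the class and parameter names are assumed from the surrounding file, which this hunk truncates:

from kotaemon.base import Document
from kotaemon.llms import AzureChatOpenAI

reranker = LLMReranking(llm=AzureChatOpenAI(azure_endpoint="...", openai_api_key="..."))
kept = reranker.run(
    documents=[Document(text="context A"), Document(text="context B")],
    query="What is A about?",
)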

View File

@@ -8,6 +8,20 @@ class BaseSplitter(DocTransformer):
class TokenSplitter(LlamaIndexDocTransformerMixin, BaseSplitter):
def __init__(
self,
chunk_size: int = 1024,
chunk_overlap: int = 20,
separator: str = " ",
**params,
):
super().__init__(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
separator=separator,
**params,
)
def _get_li_class(self):
from llama_index.text_splitter import TokenTextSplitter
@@ -15,6 +29,9 @@ class TokenSplitter(LlamaIndexDocTransformerMixin, BaseSplitter):
class SentenceWindowSplitter(LlamaIndexDocTransformerMixin, BaseSplitter):
def __init__(self, window_size: int = 3, **params):
super().__init__(window_size=window_size, **params)
def _get_li_class(self):
from llama_index.node_parser import SentenceWindowNodeParser
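A minimal splitting sketch; the parameters are forwarded by the mixin to llama_index.text_splitter.TokenTextSplitter:

from kotaemon.base import Document

splitter = TokenSplitter(chunk_size=1024, chunk_overlap=20, separator=" ")
chunks = splitter.run([Document(text="a long document ...")])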

View File

@@ -154,8 +154,7 @@ class GatedBranchingPipeline(SimpleBranchingPipeline):
if __name__ == "__main__":
import dotenv
from kotaemon.llms import BasePromptComponent
from kotaemon.llms.chats.openai import AzureChatOpenAI
from kotaemon.llms import AzureChatOpenAI, BasePromptComponent
from kotaemon.parsers import RegexExtractor
def identity(x):

View File

@@ -1,4 +1,4 @@
from .base import BaseMessage, ChatLLM, HumanMessage
from .openai import AzureChatOpenAI
from .base import ChatLLM
from .langchain_based import AzureChatOpenAI
__all__ = ["ChatLLM", "AzureChatOpenAI", "BaseMessage", "HumanMessage"]
__all__ = ["ChatLLM", "AzureChatOpenAI"]

View File

@@ -1,12 +1,8 @@
from __future__ import annotations
import logging
from typing import Type
from langchain.chat_models.base import BaseChatModel
from theflow.base import Param
from kotaemon.base import BaseComponent, BaseMessage, HumanMessage, LLMInterface
from kotaemon.base import BaseComponent
logger = logging.getLogger(__name__)
@@ -23,83 +19,3 @@ class ChatLLM(BaseComponent):
text = self.inflow.flow().text
return self.__call__(text)
class LangchainChatLLM(ChatLLM):
_lc_class: Type[BaseChatModel]
def __init__(self, **params):
if self._lc_class is None:
raise AttributeError(
"Should set _lc_class attribute to the LLM class from Langchain "
"if using LLM from Langchain"
)
self._kwargs: dict = {}
for param in list(params.keys()):
if param in self._lc_class.__fields__:
self._kwargs[param] = params.pop(param)
super().__init__(**params)
@Param.auto(cache=False)
def agent(self) -> BaseChatModel:
return self._lc_class(**self._kwargs)
def run(
self, messages: str | BaseMessage | list[BaseMessage], **kwargs
) -> LLMInterface:
"""Generate response from messages
Args:
messages: history of messages to generate response from
**kwargs: additional arguments to pass to the langchain chat model
Returns:
LLMInterface: generated response
"""
input_: list[BaseMessage] = []
if isinstance(messages, str):
input_ = [HumanMessage(content=messages)]
elif isinstance(messages, BaseMessage):
input_ = [messages]
else:
input_ = messages
pred = self.agent.generate(messages=[input_], **kwargs)
all_text = [each.text for each in pred.generations[0]]
all_messages = [each.message for each in pred.generations[0]]
completion_tokens, total_tokens, prompt_tokens = 0, 0, 0
try:
if pred.llm_output is not None:
completion_tokens = pred.llm_output["token_usage"]["completion_tokens"]
total_tokens = pred.llm_output["token_usage"]["total_tokens"]
prompt_tokens = pred.llm_output["token_usage"]["prompt_tokens"]
except Exception:
logger.warning(
f"Cannot get token usage from LLM output for {self._lc_class.__name__}"
)
return LLMInterface(
text=all_text[0] if len(all_text) > 0 else "",
candidates=all_text,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
prompt_tokens=prompt_tokens,
messages=all_messages,
logits=[],
)
def __setattr__(self, name, value):
if name in self._lc_class.__fields__:
self._kwargs[name] = value
setattr(self.agent, name, value)
else:
super().__setattr__(name, value)
def __getattr__(self, name):
if name in self._lc_class.__fields__:
return getattr(self.agent, name)
return super().__getattr__(name) # type: ignore

View File

@@ -0,0 +1,149 @@
from __future__ import annotations
import logging
from kotaemon.base import BaseMessage, HumanMessage, LLMInterface
from .base import ChatLLM
logger = logging.getLogger(__name__)
class LCChatMixin:
def _get_lc_class(self):
raise NotImplementedError(
"Please return the relevant Langchain class in in _get_lc_class"
)
def __init__(self, **params):
self._lc_class = self._get_lc_class()
self._obj = self._lc_class(**params)
self._kwargs: dict = params
super().__init__()
def run(
self, messages: str | BaseMessage | list[BaseMessage], **kwargs
) -> LLMInterface:
"""Generate response from messages
Args:
messages: history of messages to generate response from
**kwargs: additional arguments to pass to the langchain chat model
Returns:
LLMInterface: generated response
"""
input_: list[BaseMessage] = []
if isinstance(messages, str):
input_ = [HumanMessage(content=messages)]
elif isinstance(messages, BaseMessage):
input_ = [messages]
else:
input_ = messages
pred = self._obj.generate(messages=[input_], **kwargs)
all_text = [each.text for each in pred.generations[0]]
all_messages = [each.message for each in pred.generations[0]]
completion_tokens, total_tokens, prompt_tokens = 0, 0, 0
try:
if pred.llm_output is not None:
completion_tokens = pred.llm_output["token_usage"]["completion_tokens"]
total_tokens = pred.llm_output["token_usage"]["total_tokens"]
prompt_tokens = pred.llm_output["token_usage"]["prompt_tokens"]
except Exception:
logger.warning(
f"Cannot get token usage from LLM output for {self._lc_class.__name__}"
)
return LLMInterface(
text=all_text[0] if len(all_text) > 0 else "",
candidates=all_text,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
prompt_tokens=prompt_tokens,
messages=all_messages,
logits=[],
)
def __repr__(self):
kwargs = []
for key, value_obj in self._kwargs.items():
value = repr(value_obj)
kwargs.append(f"{key}={value}")
kwargs_repr = ", ".join(kwargs)
return f"{self.__class__.__name__}({kwargs_repr})"
def __str__(self):
kwargs = []
for key, value_obj in self._kwargs.items():
value = str(value_obj)
if len(value) > 20:
value = f"{value[:15]}..."
kwargs.append(f"{key}={value}")
kwargs_repr = ", ".join(kwargs)
return f"{self.__class__.__name__}({kwargs_repr})"
def __setattr__(self, name, value):
if name == "_lc_class":
return super().__setattr__(name, value)
if name in self._lc_class.__fields__:
self._kwargs[name] = value
self._obj = self._lc_class(**self._kwargs)
else:
super().__setattr__(name, value)
def __getattr__(self, name):
if name in self._kwargs:
return self._kwargs[name]
return getattr(self._obj, name)
def dump(self):
return {
"__type__": f"{self.__module__}.{self.__class__.__qualname__}",
**self._kwargs,
}
def specs(self, path: str):
path = path.strip(".")
if "." in path:
raise ValueError("path should not contain '.'")
if path in self._lc_class.__fields__:
return {
"__type__": "theflow.base.ParamAttr",
"refresh_on_set": True,
"strict_type": True,
}
raise ValueError(f"Invalid param {path}")
class AzureChatOpenAI(LCChatMixin, ChatLLM):
def __init__(
self,
azure_endpoint: str | None = None,
openai_api_key: str | None = None,
openai_api_version: str = "",
deployment_name: str | None = None,
temperature: float = 0.7,
request_timeout: float | None = None,
**params,
):
super().__init__(
azure_endpoint=azure_endpoint,
openai_api_key=openai_api_key,
openai_api_version=openai_api_version,
deployment_name=deployment_name,
temperature=temperature,
request_timeout=request_timeout,
**params,
)
def _get_lc_class(self):
import langchain.chat_models
return langchain.chat_models.AzureChatOpenAI
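A minimal chat sketch with placeholder credentials: run wraps a plain string into a HumanMessage, and the result is an LLMInterface carrying the text plus token usage; assigning a declared parameter rebuilds the wrapped object:

from kotaemon.llms import AzureChatOpenAI

llm = AzureChatOpenAI(
    azure_endpoint="https://<resource>.openai.azure.com/",
    openai_api_key="...",
    openai_api_version="2023-07-01-preview",
    deployment_name="<chat-deployment>",
    temperature=0.0,
)
result = llm("Hello!")
print(result.text, result.total_tokens)
llm.temperature = 0.5  # rebuilds llm._obj with the new value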

View File

@@ -1,7 +0,0 @@
from langchain.chat_models import AzureChatOpenAI as AzureChatOpenAILC
from .base import LangchainChatLLM
class AzureChatOpenAI(LangchainChatLLM):
_lc_class = AzureChatOpenAILC

View File

@@ -1,4 +1,4 @@
from .base import LLM
from .openai import AzureOpenAI, OpenAI
from .langchain_based import AzureOpenAI, OpenAI
__all__ = ["LLM", "OpenAI", "AzureOpenAI"]

View File

@@ -1,66 +1,5 @@
import logging
from typing import Type
from langchain.llms.base import BaseLLM
from theflow.base import Param
from ...base import BaseComponent
from ...base.schema import LLMInterface
logger = logging.getLogger(__name__)
from kotaemon.base import BaseComponent
class LLM(BaseComponent):
pass
class LangchainLLM(LLM):
_lc_class: Type[BaseLLM]
def __init__(self, **params):
if self._lc_class is None:
raise AttributeError(
"Should set _lc_class attribute to the LLM class from Langchain "
"if using LLM from Langchain"
)
self._kwargs: dict = {}
for param in list(params.keys()):
if param in self._lc_class.__fields__:
self._kwargs[param] = params.pop(param)
super().__init__(**params)
@Param.auto(cache=False)
def agent(self):
return self._lc_class(**self._kwargs)
def run(self, text: str) -> LLMInterface:
pred = self.agent.generate([text])
all_text = [each.text for each in pred.generations[0]]
completion_tokens, total_tokens, prompt_tokens = 0, 0, 0
try:
if pred.llm_output is not None:
completion_tokens = pred.llm_output["token_usage"]["completion_tokens"]
total_tokens = pred.llm_output["token_usage"]["total_tokens"]
prompt_tokens = pred.llm_output["token_usage"]["prompt_tokens"]
except Exception:
logger.warning(
f"Cannot get token usage from LLM output for {self._lc_class.__name__}"
)
return LLMInterface(
text=all_text[0] if len(all_text) > 0 else "",
candidates=all_text,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
prompt_tokens=prompt_tokens,
logits=[],
)
def __setattr__(self, name, value):
if name in self._lc_class.__fields__:
self._kwargs[name] = value
setattr(self.agent, name, value)
else:
super().__setattr__(name, value)

View File

@@ -0,0 +1,185 @@
import logging
from typing import Optional
from kotaemon.base import LLMInterface
from .base import LLM
logger = logging.getLogger(__name__)
class LCCompletionMixin:
def _get_lc_class(self):
raise NotImplementedError(
"Please return the relevant Langchain class in in _get_lc_class"
)
def __init__(self, **params):
self._lc_class = self._get_lc_class()
self._obj = self._lc_class(**params)
self._kwargs: dict = params
super().__init__()
def run(self, text: str) -> LLMInterface:
pred = self._obj.generate([text])
all_text = [each.text for each in pred.generations[0]]
completion_tokens, total_tokens, prompt_tokens = 0, 0, 0
try:
if pred.llm_output is not None:
completion_tokens = pred.llm_output["token_usage"]["completion_tokens"]
total_tokens = pred.llm_output["token_usage"]["total_tokens"]
prompt_tokens = pred.llm_output["token_usage"]["prompt_tokens"]
except Exception:
logger.warning(
f"Cannot get token usage from LLM output for {self._lc_class.__name__}"
)
return LLMInterface(
text=all_text[0] if len(all_text) > 0 else "",
candidates=all_text,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
prompt_tokens=prompt_tokens,
logits=[],
)
def __repr__(self):
kwargs = []
for key, value_obj in self._kwargs.items():
value = repr(value_obj)
kwargs.append(f"{key}={value}")
kwargs_repr = ", ".join(kwargs)
return f"{self.__class__.__name__}({kwargs_repr})"
def __str__(self):
kwargs = []
for key, value_obj in self._kwargs.items():
value = str(value_obj)
if len(value) > 20:
value = f"{value[:15]}..."
kwargs.append(f"{key}={value}")
kwargs_repr = ", ".join(kwargs)
return f"{self.__class__.__name__}({kwargs_repr})"
def __setattr__(self, name, value):
if name == "_lc_class":
return super().__setattr__(name, value)
if name in self._lc_class.__fields__:
self._kwargs[name] = value
self._obj = self._lc_class(**self._kwargs)
else:
super().__setattr__(name, value)
def __getattr__(self, name):
if name in self._kwargs:
return self._kwargs[name]
return getattr(self._obj, name)
def dump(self):
return {
"__type__": f"{self.__module__}.{self.__class__.__qualname__}",
**self._kwargs,
}
def specs(self, path: str):
path = path.strip(".")
if "." in path:
raise ValueError("path should not contain '.'")
if path in self._lc_class.__fields__:
return {
"__type__": "theflow.base.ParamAttr",
"refresh_on_set": True,
"strict_type": True,
}
raise ValueError(f"Invalid param {path}")
class OpenAI(LCCompletionMixin, LLM):
"""Wrapper around Langchain's OpenAI class, focusing on key parameters"""
def __init__(
self,
openai_api_key: Optional[str] = None,
openai_api_base: Optional[str] = None,
model_name: str = "text-davinci-003",
temperature: float = 0.7,
max_tokens: int = 256,
top_p: float = 1,
frequency_penalty: float = 0,
n: int = 1,
best_of: int = 1,
request_timeout: Optional[float] = None,
max_retries: int = 2,
streaming: bool = False,
**params,
):
super().__init__(
openai_api_key=openai_api_key,
openai_api_base=openai_api_base,
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
frequency_penalty=frequency_penalty,
n=n,
best_of=best_of,
request_timeout=request_timeout,
max_retries=max_retries,
streaming=streaming,
**params,
)
def _get_lc_class(self):
import langchain.llms as langchain_llms
return langchain_llms.OpenAI
class AzureOpenAI(LCCompletionMixin, LLM):
"""Wrapper around Langchain's AzureOpenAI class, focusing on key parameters"""
def __init__(
self,
azure_endpoint: Optional[str] = None,
deployment_name: Optional[str] = None,
openai_api_version: str = "",
openai_api_key: Optional[str] = None,
model_name: str = "text-davinci-003",
temperature: float = 0.7,
max_tokens: int = 256,
top_p: float = 1,
frequency_penalty: float = 0,
n: int = 1,
best_of: int = 1,
request_timeout: Optional[float] = None,
max_retries: int = 2,
streaming: bool = False,
**params,
):
super().__init__(
azure_endpoint=azure_endpoint,
deployment_name=deployment_name,
openai_api_version=openai_api_version,
openai_api_key=openai_api_key,
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
frequency_penalty=frequency_penalty,
n=n,
best_of=best_of,
request_timeout=request_timeout,
max_retries=max_retries,
streaming=streaming,
**params,
)
def _get_lc_class(self):
import langchain.llms as langchain_llms
return langchain_llms.AzureOpenAI
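The completion wrappers follow the same pattern; a minimal sketch with placeholder credentials:

from kotaemon.llms import AzureOpenAI

llm = AzureOpenAI(
    azure_endpoint="https://<resource>.openai.azure.com/",
    openai_api_key="...",
    openai_api_version="2023-07-01-preview",
    deployment_name="<completion-deployment>",
)
result = llm.run("Write a one-line summary of retrieval-augmented generation.")
print(result.text)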

View File

@@ -1,15 +0,0 @@
import langchain.llms as langchain_llms
from .base import LangchainLLM
class OpenAI(LangchainLLM):
"""Wrapper around Langchain's OpenAI class"""
_lc_class = langchain_llms.OpenAI
class AzureOpenAI(LangchainLLM):
"""Wrapper around Langchain's AzureOpenAI class"""
_lc_class = langchain_llms.AzureOpenAI

View File

@@ -21,8 +21,7 @@ class SimpleLinearPipeline(BaseComponent):
post-processor component or function.
Example Usage:
from kotaemon.llms.chats.openai import AzureChatOpenAI
from kotaemon.llms import BasePromptComponent
from kotaemon.llms import AzureChatOpenAI, BasePromptComponent
def identity(x):
return x
@@ -87,8 +86,7 @@ class GatedLinearPipeline(SimpleLinearPipeline):
condition.
Example Usage:
from kotaemon.llms.chats.openai import AzureChatOpenAI
from kotaemon.llms import BasePromptComponent
from kotaemon.llms import AzureChatOpenAI, BasePromptComponent
from kotaemon.parsers import RegexExtractor
def identity(x):