Allow users to add LLM within the UI (#6)
* Rename AzureChatOpenAI to LCAzureChatOpenAI
* Provide vanilla ChatOpenAI and AzureChatOpenAI
* Remove the highest-accuracy / lowest-cost criteria. These criteria are unnecessary: the users, not the pipeline creators, should choose which LLM to use. Furthermore, it is cumbersome to input this information and it degrades the user experience.
* Remove the LLM selection in the simple reasoning pipeline
* Provide a dedicated stream method to generate the output
* Return a placeholder message to the chat if the text is empty
committed by GitHub
parent e187e23dd1
commit a203fc0f7c
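To illustrate the new dedicated `stream` method mentioned above, here is a minimal usage sketch (the model name and API key are placeholders, not part of this commit):

```python
from kotaemon.llms import ChatOpenAI

llm = ChatOpenAI(api_key="sk-...", model="gpt-3.5-turbo")  # placeholder credentials

# invoke() returns a single LLMInterface; stream() yields LLMInterface chunks
# so the UI can render the answer incrementally.
for chunk in llm.stream("Hello, who are you?"):
    print(chunk.content, end="", flush=True)
```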
@@ -52,7 +52,7 @@ class BaseComponent(Function):
    def stream(self, *args, **kwargs) -> Iterator[Document] | None:
        ...

    async def astream(self, *args, **kwargs) -> AsyncGenerator[Document, None] | None:
    def astream(self, *args, **kwargs) -> AsyncGenerator[Document, None] | None:
        ...

    @abstractmethod
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Optional, TypeVar
from typing import TYPE_CHECKING, Any, Literal, Optional, TypeVar

from langchain.schema.messages import AIMessage as LCAIMessage
from langchain.schema.messages import HumanMessage as LCHumanMessage
@@ -10,6 +10,9 @@ from llama_index.schema import Document as BaseDocument

if TYPE_CHECKING:
    from haystack.schema import Document as HaystackDocument
    from openai.types.chat.chat_completion_message_param import (
        ChatCompletionMessageParam,
    )

IO_Type = TypeVar("IO_Type", "Document", str)
SAMPLE_TEXT = "A sample Document from kotaemon"
@@ -26,10 +29,15 @@ class Document(BaseDocument):
    Attributes:
        content: raw content of the document, can be anything
        source: id of the source of the Document. Optional.
        channel: the channel to show the document. Optional.:
            - chat: show in chat message
            - info: show in information panel
            - debug: show in debug panel
    """

    content: Any
    content: Any = None
    source: Optional[str] = None
    channel: Optional[Literal["chat", "info", "debug"]] = None

    def __init__(self, content: Optional[Any] = None, *args, **kwargs):
        if content is None:
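A quick sketch of the new `channel` field and the `content` default (assumed usage, not taken from the diff):

```python
from kotaemon.base import Document

# channel controls where the UI renders the document:
# "chat" -> chat message, "info" -> information panel, "debug" -> debug panel.
answer = Document(content="The answer is 42.", channel="chat")
trace = Document(content={"prompt_tokens": 12}, channel="debug")
empty = Document()  # content now defaults to None instead of being required
```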
@@ -87,17 +95,23 @@ class BaseMessage(Document):
    def __add__(self, other: Any):
        raise NotImplementedError

    def to_openai_format(self) -> "ChatCompletionMessageParam":
        raise NotImplementedError


class SystemMessage(BaseMessage, LCSystemMessage):
    pass
    def to_openai_format(self) -> "ChatCompletionMessageParam":
        return {"role": "system", "content": self.content}


class AIMessage(BaseMessage, LCAIMessage):
    pass
    def to_openai_format(self) -> "ChatCompletionMessageParam":
        return {"role": "assistant", "content": self.content}


class HumanMessage(BaseMessage, LCHumanMessage):
    pass
    def to_openai_format(self) -> "ChatCompletionMessageParam":
        return {"role": "user", "content": self.content}


class RetrievedDocument(Document):
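The new `to_openai_format` hooks let the native OpenAI client added later in this commit consume kotaemon messages directly; a minimal sketch:

```python
from kotaemon.base.schema import AIMessage, HumanMessage, SystemMessage

history = [
    SystemMessage(content="You are a helpful assistant."),
    HumanMessage(content="Hi there!"),
    AIMessage(content="Hello! How can I help?"),
]

# Each message now converts itself into the OpenAI chat-completion dict format.
print([m.to_openai_format() for m in history])
# -> [{'role': 'system', ...}, {'role': 'user', ...}, {'role': 'assistant', ...}]
```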
@@ -1,7 +1,7 @@
import os

from kotaemon.base import BaseComponent, Document, Node, RetrievedDocument
from kotaemon.llms import AzureChatOpenAI, BaseLLM, PromptTemplate
from kotaemon.llms import BaseLLM, LCAzureChatOpenAI, PromptTemplate

from .citation import CitationPipeline

@@ -13,7 +13,7 @@ class CitationQAPipeline(BaseComponent):
        'Answer the following question: "{question}". '
        "The context is: \n{context}\nAnswer: "
    )
    llm: BaseLLM = AzureChatOpenAI.withx(
    llm: BaseLLM = LCAzureChatOpenAI.withx(
        azure_endpoint="https://bleh-dummy.openai.azure.com/",
        openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
        openai_api_version="2023-07-01-preview",
@@ -2,7 +2,15 @@ from kotaemon.base.schema import AIMessage, BaseMessage, HumanMessage, SystemMes

from .base import BaseLLM
from .branching import GatedBranchingPipeline, SimpleBranchingPipeline
from .chats import AzureChatOpenAI, ChatLLM, ChatOpenAI, EndpointChatLLM, LlamaCppChat
from .chats import (
    AzureChatOpenAI,
    ChatLLM,
    ChatOpenAI,
    EndpointChatLLM,
    LCAzureChatOpenAI,
    LCChatOpenAI,
    LlamaCppChat,
)
from .completions import LLM, AzureOpenAI, LlamaCpp, OpenAI
from .cot import ManualSequentialChainOfThought, Thought
from .linear import GatedLinearPipeline, SimpleLinearPipeline
@@ -17,8 +25,10 @@ __all__ = [
    "HumanMessage",
    "AIMessage",
    "SystemMessage",
    "ChatOpenAI",
    "AzureChatOpenAI",
    "ChatOpenAI",
    "LCAzureChatOpenAI",
    "LCChatOpenAI",
    "LlamaCppChat",
    # completion-specific components
    "LLM",
@@ -18,5 +18,8 @@ class BaseLLM(BaseComponent):
    def stream(self, *args, **kwargs) -> Iterator[LLMInterface]:
        raise NotImplementedError

    async def astream(self, *args, **kwargs) -> AsyncGenerator[LLMInterface, None]:
    def astream(self, *args, **kwargs) -> AsyncGenerator[LLMInterface, None]:
        raise NotImplementedError

    def run(self, *args, **kwargs):
        return self.invoke(*args, **kwargs)
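With `astream` declared as a plain method returning an async generator, a caller consumes it with `async for`; a minimal sketch under assumed placeholder credentials:

```python
import asyncio

from kotaemon.llms import ChatOpenAI

llm = ChatOpenAI(api_key="sk-...", model="gpt-3.5-turbo")  # placeholder values

async def main() -> None:
    # astream() yields LLMInterface chunks asynchronously.
    async for chunk in llm.astream("Stream me an answer"):
        print(chunk.content, end="", flush=True)

asyncio.run(main())
```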
@@ -15,7 +15,7 @@ class SimpleBranchingPipeline(BaseComponent):
    Example:
        ```python
        from kotaemon.llms import (
            AzureChatOpenAI,
            LCAzureChatOpenAI,
            BasePromptComponent,
            GatedLinearPipeline,
        )
@@ -25,7 +25,7 @@ class SimpleBranchingPipeline(BaseComponent):
            return x

        pipeline = SimpleBranchingPipeline()
        llm = AzureChatOpenAI(
        llm = LCAzureChatOpenAI(
            openai_api_base="your openai api base",
            openai_api_key="your openai api key",
            openai_api_version="your openai api version",
@@ -92,7 +92,7 @@ class GatedBranchingPipeline(SimpleBranchingPipeline):
    Example:
        ```python
        from kotaemon.llms import (
            AzureChatOpenAI,
            LCAzureChatOpenAI,
            BasePromptComponent,
            GatedLinearPipeline,
        )
@@ -102,7 +102,7 @@ class GatedBranchingPipeline(SimpleBranchingPipeline):
            return x

        pipeline = GatedBranchingPipeline()
        llm = AzureChatOpenAI(
        llm = LCAzureChatOpenAI(
            openai_api_base="your openai api base",
            openai_api_key="your openai api key",
            openai_api_version="your openai api version",
@@ -157,7 +157,7 @@ class GatedBranchingPipeline(SimpleBranchingPipeline):
if __name__ == "__main__":
    import dotenv

    from kotaemon.llms import AzureChatOpenAI, BasePromptComponent
    from kotaemon.llms import BasePromptComponent, LCAzureChatOpenAI
    from kotaemon.parsers import RegexExtractor

    def identity(x):
@@ -166,7 +166,7 @@ if __name__ == "__main__":
    secrets = dotenv.dotenv_values(".env")

    pipeline = GatedBranchingPipeline()
    llm = AzureChatOpenAI(
    llm = LCAzureChatOpenAI(
        openai_api_base=secrets.get("OPENAI_API_BASE", ""),
        openai_api_key=secrets.get("OPENAI_API_KEY", ""),
        openai_api_version=secrets.get("OPENAI_API_VERSION", ""),
@@ -1,13 +1,17 @@
from .base import ChatLLM
from .endpoint_based import EndpointChatLLM
from .langchain_based import AzureChatOpenAI, ChatOpenAI, LCChatMixin
from .langchain_based import LCAzureChatOpenAI, LCChatMixin, LCChatOpenAI
from .llamacpp import LlamaCppChat
from .openai import AzureChatOpenAI, ChatOpenAI

__all__ = [
    "ChatOpenAI",
    "AzureChatOpenAI",
    "ChatLLM",
    "EndpointChatLLM",
    "ChatOpenAI",
    "AzureChatOpenAI",
    "LCChatOpenAI",
    "LCAzureChatOpenAI",
    "LCChatMixin",
    "LlamaCppChat",
]
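After this rename, both backends are importable side by side; a hedged sketch of the intended split:

```python
# Native OpenAI-library clients (new in this commit)
from kotaemon.llms import AzureChatOpenAI, ChatOpenAI

# LangChain-backed wrappers, now prefixed with LC
from kotaemon.llms import LCAzureChatOpenAI, LCChatOpenAI
```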
@@ -5,6 +5,7 @@ from kotaemon.base import (
    BaseMessage,
    HumanMessage,
    LLMInterface,
    Param,
    SystemMessage,
)

@@ -20,7 +21,9 @@ class EndpointChatLLM(ChatLLM):
        endpoint_url (str): The url of a OpenAI API compatible endpoint.
    """

    endpoint_url: str
    endpoint_url: str = Param(
        help="URL of the OpenAI API compatible endpoint", required=True
    )

    def run(
        self, messages: str | BaseMessage | list[BaseMessage], **kwargs
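Declaring `endpoint_url` as a required `Param` lets the UI prompt for it when a user adds this LLM; a minimal sketch (the URL is a placeholder and the return type of `run` is assumed to be an LLMInterface-like object):

```python
from kotaemon.llms import EndpointChatLLM

# endpoint_url is now a required, documented Param rather than a bare attribute.
llm = EndpointChatLLM(endpoint_url="http://localhost:8000/v1/chat/completions")

response = llm.run("Hello")   # assumed: returns an object with a .content field
print(response.content)
```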
@@ -165,7 +165,7 @@ class LCChatMixin:
        raise ValueError(f"Invalid param {path}")


class ChatOpenAI(LCChatMixin, ChatLLM):  # type: ignore
class LCChatOpenAI(LCChatMixin, ChatLLM):  # type: ignore
    def __init__(
        self,
        openai_api_base: str | None = None,
@@ -193,7 +193,7 @@ class ChatOpenAI(LCChatMixin, ChatLLM):  # type: ignore
        return ChatOpenAI


class AzureChatOpenAI(LCChatMixin, ChatLLM):  # type: ignore
class LCAzureChatOpenAI(LCChatMixin, ChatLLM):  # type: ignore
    def __init__(
        self,
        azure_endpoint: str | None = None,
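The renamed LangChain wrappers keep their LangChain-style constructor arguments; a minimal sketch using the placeholder values that appear in the test fixtures further down:

```python
from kotaemon.llms import LCAzureChatOpenAI

llm = LCAzureChatOpenAI(
    azure_endpoint="https://dummy.openai.azure.com/",
    openai_api_key="dummy",
    openai_api_version="2023-03-15-preview",
)
```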
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Optional, cast
from typing import TYPE_CHECKING, Iterator, Optional, cast

from kotaemon.base import BaseMessage, HumanMessage, LLMInterface, Param

@@ -12,13 +12,32 @@ if TYPE_CHECKING:
class LlamaCppChat(ChatLLM):
    """Wrapper around the llama-cpp-python's Llama model"""

    model_path: Optional[str] = None
    chat_format: Optional[str] = None
    lora_base: Optional[str] = None
    n_ctx: int = 512
    n_gpu_layers: int = 0
    use_mmap: bool = True
    vocab_only: bool = False
    model_path: str = Param(
        help="Path to the model file. This is required to load the model.",
        required=True,
    )
    chat_format: str = Param(
        help=(
            "Chat format to use. Please refer to llama_cpp.llama_chat_format for a "
            "list of supported formats. If blank, the chat format will be auto-"
            "inferred."
        ),
        required=True,
    )
    lora_base: Optional[str] = Param(None, help="Path to the base Lora model")
    n_ctx: Optional[int] = Param(512, help="Text context, 0 = from model")
    n_gpu_layers: Optional[int] = Param(
        0,
        help=("Number of layers to offload to GPU. If -1, all layers are offloaded"),
    )
    use_mmap: Optional[bool] = Param(
        True,
        help=(),
    )
    vocab_only: Optional[bool] = Param(
        False,
        help=("If True, only the vocabulary is loaded. This is useful for debugging."),
    )

    _role_mapper: dict[str, str] = {
        "human": "user",
@@ -60,9 +79,9 @@ class LlamaCppChat(ChatLLM):
            vocab_only=self.vocab_only,
        )

    def run(
        self, messages: str | BaseMessage | list[BaseMessage], **kwargs
    ) -> LLMInterface:
    def prepare_message(
        self, messages: str | BaseMessage | list[BaseMessage]
    ) -> list[dict]:
        input_: list[BaseMessage] = []

        if isinstance(messages, str):
@@ -72,11 +91,19 @@ class LlamaCppChat(ChatLLM):
        else:
            input_ = messages

        output_ = [
            {"role": self._role_mapper[each.type], "content": each.content}
            for each in input_
        ]

        return output_

    def invoke(
        self, messages: str | BaseMessage | list[BaseMessage], **kwargs
    ) -> LLMInterface:

        pred: "CCCR" = self.client_object.create_chat_completion(
            messages=[
                {"role": self._role_mapper[each.type], "content": each.content}
                for each in input_
            ],  # type: ignore
            messages=self.prepare_message(messages),
            stream=False,
        )

@@ -91,3 +118,19 @@ class LlamaCppChat(ChatLLM):
            total_tokens=pred["usage"]["total_tokens"],
            prompt_tokens=pred["usage"]["prompt_tokens"],
        )

    def stream(
        self, messages: str | BaseMessage | list[BaseMessage], **kwargs
    ) -> Iterator[LLMInterface]:
        pred = self.client_object.create_chat_completion(
            messages=self.prepare_message(messages),
            stream=True,
        )
        for chunk in pred:
            if not chunk["choices"]:
                continue

            if "content" not in chunk["choices"][0]["delta"]:
                continue

            yield LLMInterface(content=chunk["choices"][0]["delta"]["content"])
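A hedged usage sketch of the refactored LlamaCppChat (the model path and chat format below are placeholders; any GGUF model path and llama_cpp chat format would do):

```python
from kotaemon.llms import LlamaCppChat

llm = LlamaCppChat(
    model_path="path/to/model.gguf",   # placeholder path
    chat_format="llama-2",             # placeholder llama_cpp chat format
)

# prepare_message() is now a separate, reusable step:
print(llm.prepare_message("Hello"))  # [{'role': 'user', 'content': 'Hello'}]

# and stream() yields LLMInterface chunks as llama.cpp produces them:
for chunk in llm.stream("Hello"):
    print(chunk.content, end="")
```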
libs/kotaemon/kotaemon/llms/chats/openai.py (new file, 356 lines)
@@ -0,0 +1,356 @@
from typing import TYPE_CHECKING, AsyncGenerator, Iterator, Optional

from theflow.utils.modules import import_dotted_string

from kotaemon.base import AIMessage, BaseMessage, HumanMessage, LLMInterface, Param

from .base import ChatLLM

if TYPE_CHECKING:
    from openai.types.chat.chat_completion_message_param import (
        ChatCompletionMessageParam,
    )


class BaseChatOpenAI(ChatLLM):
    """Base interface for OpenAI chat model, using the openai library

    This class exposes the parameters in resources.Chat. To subclass this class:

        - Implement the `prepare_client` method to return the OpenAI client
        - Implement the `openai_response` method to return the OpenAI response
        - Implement the params relate to the OpenAI client
    """

    _dependencies = ["openai"]
    _capabilities = ["chat", "text"]  # consider as mixin

    api_key: str = Param(help="API key", required=True)
    timeout: Optional[float] = Param(None, help="Timeout for the API request")
    max_retries: Optional[int] = Param(
        None, help="Maximum number of retries for the API request"
    )

    temperature: Optional[float] = Param(
        None,
        help=(
            "Number between 0 and 2 that controls the randomness of the generated "
            "tokens. Lower values make the model more deterministic, while higher "
            "values make the model more random."
        ),
    )
    max_tokens: Optional[int] = Param(
        None,
        help=(
            "Maximum number of tokens to generate. The total length of input tokens "
            "and generated tokens is limited by the model's context length."
        ),
    )
    n: int = Param(
        1,
        help=(
            "Number of completions to generate. The API will generate n completion "
            "for each prompt."
        ),
    )
    stop: Optional[str | list[str]] = Param(
        None,
        help=(
            "Stop sequence. If a stop sequence is detected, generation will stop "
            "at that point. If not specified, generation will continue until the "
            "maximum token length is reached."
        ),
    )
    frequency_penalty: Optional[float] = Param(
        None,
        help=(
            "Number between -2.0 and 2.0. Positive values penalize new tokens "
            "based on their existing frequency in the text so far, decrearsing the "
            "model's likelihood of repeating the same text."
        ),
    )
    presence_penalty: Optional[float] = Param(
        None,
        help=(
            "Number between -2.0 and 2.0. Positive values penalize new tokens "
            "based on their existing presence in the text so far, decrearsing the "
            "model's likelihood of repeating the same text."
        ),
    )
    tool_choice: Optional[str] = Param(
        None,
        help=(
            "Choice of tool to use for the completion. Available choices are: "
            "auto, default."
        ),
    )
    tools: Optional[list[str]] = Param(
        None,
        help="List of tools to use for the completion.",
    )
    logprobs: Optional[bool] = Param(
        None,
        help=(
            "Include log probabilities on the logprobs most likely tokens, "
            "as well as the chosen token."
        ),
    )
    logit_bias: Optional[dict] = Param(
        None,
        help=(
            "Dictionary of logit bias values to add to the logits of the tokens "
            "in the vocabulary."
        ),
    )
    top_logprobs: Optional[int] = Param(
        None,
        help=(
            "An integer between 0 and 5 specifying the number of most likely tokens "
            "to return at each token position, each with an associated log "
            "probability. `logprobs` must also be set to `true` if this parameter "
            "is used."
        ),
    )
    top_p: Optional[float] = Param(
        None,
        help=(
            "An alternative to sampling with temperature, called nucleus sampling, "
            "where the model considers the results of the token with top_p "
            "probability mass. So 0.1 means that only the tokens comprising the "
            "top 10% probability mass are considered."
        ),
    )
    @Param.auto(depends_on=["max_retries"])
    def max_retries_(self):
        if self.max_retries is None:
            from openai._constants import DEFAULT_MAX_RETRIES

            return DEFAULT_MAX_RETRIES
        return self.max_retries

    def prepare_message(
        self, messages: str | BaseMessage | list[BaseMessage]
    ) -> list["ChatCompletionMessageParam"]:
        """Prepare the message into OpenAI format

        Returns:
            list[dict]: List of messages in OpenAI format
        """
        input_: list[BaseMessage] = []
        output_: list["ChatCompletionMessageParam"] = []

        if isinstance(messages, str):
            input_ = [HumanMessage(content=messages)]
        elif isinstance(messages, BaseMessage):
            input_ = [messages]
        else:
            input_ = messages

        for message in input_:
            output_.append(message.to_openai_format())

        return output_

    def prepare_client(self, async_version: bool = False):
        """Get the OpenAI client

        Args:
            async_version (bool): Whether to get the async version of the client
        """
        raise NotImplementedError

    def openai_response(self, client, **kwargs):
        """Get the openai response"""
        raise NotImplementedError

    def invoke(
        self, messages: str | BaseMessage | list[BaseMessage], *args, **kwargs
    ) -> LLMInterface:
        client = self.prepare_client(async_version=False)
        input_messages = self.prepare_message(messages)
        resp = self.openai_response(
            client, messages=input_messages, stream=False, **kwargs
        ).dict()

        output = LLMInterface(
            candidates=[_["message"]["content"] for _ in resp["choices"]],
            content=resp["choices"][0]["message"]["content"],
            total_tokens=resp["usage"]["total_tokens"],
            prompt_tokens=resp["usage"]["prompt_tokens"],
            completion_tokens=resp["usage"]["completion_tokens"],
            messages=[
                AIMessage(content=_["message"]["content"]) for _ in resp["choices"]
            ],
        )

        return output

    async def ainvoke(
        self, messages: str | BaseMessage | list[BaseMessage], *args, **kwargs
    ) -> LLMInterface:
        client = self.prepare_client(async_version=True)
        input_messages = self.prepare_message(messages)
        resp = await self.openai_response(
            client, messages=input_messages, stream=False, **kwargs
        ).dict()

        output = LLMInterface(
            candidates=[_["message"]["content"] for _ in resp["choices"]],
            content=resp["choices"][0]["message"]["content"],
            total_tokens=resp["usage"]["total_tokens"],
            prompt_tokens=resp["usage"]["prompt_tokens"],
            completion_tokens=resp["usage"]["completion_tokens"],
            messages=[
                AIMessage(content=_["message"]["content"]) for _ in resp["choices"]
            ],
        )

        return output

    def stream(
        self, messages: str | BaseMessage | list[BaseMessage], *args, **kwargs
    ) -> Iterator[LLMInterface]:
        client = self.prepare_client(async_version=False)
        input_messages = self.prepare_message(messages)
        resp = self.openai_response(
            client, messages=input_messages, stream=True, **kwargs
        )

        for chunk in resp:
            if not chunk.choices:
                continue
            if chunk.choices[0].delta.content is not None:
                yield LLMInterface(content=chunk.choices[0].delta.content)

    async def astream(
        self, messages: str | BaseMessage | list[BaseMessage], *args, **kwargs
    ) -> AsyncGenerator[LLMInterface, None]:
        client = self.prepare_client(async_version=True)
        input_messages = self.prepare_message(messages)
        resp = self.openai_response(
            client, messages=input_messages, stream=True, **kwargs
        )

        async for chunk in resp:
            if not chunk.choices:
                continue
            if chunk.choices[0].delta.content is not None:
                yield LLMInterface(content=chunk.choices[0].delta.content)

class ChatOpenAI(BaseChatOpenAI):
    """OpenAI chat model"""

    base_url: Optional[str] = Param(None, help="OpenAI base URL")
    organization: Optional[str] = Param(None, help="OpenAI organization")
    model: str = Param(help="OpenAI model", required=True)

    def prepare_client(self, async_version: bool = False):
        """Get the OpenAI client

        Args:
            async_version (bool): Whether to get the async version of the client
        """
        params = {
            "api_key": self.api_key,
            "organization": self.organization,
            "base_url": self.base_url,
            "timeout": self.timeout,
            "max_retries": self.max_retries_,
        }
        if async_version:
            from openai import AsyncOpenAI

            return AsyncOpenAI(**params)

        from openai import OpenAI

        return OpenAI(**params)

    def openai_response(self, client, **kwargs):
        """Get the openai response"""
        params = {
            "model": self.model,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "n": self.n,
            "stop": self.stop,
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
            "tool_choice": self.tool_choice,
            "tools": self.tools,
            "logprobs": self.logprobs,
            "logit_bias": self.logit_bias,
            "top_logprobs": self.top_logprobs,
            "top_p": self.top_p,
        }
        params.update(kwargs)

        return client.chat.completions.create(**params)


class AzureChatOpenAI(BaseChatOpenAI):
    """OpenAI chat model provided by Microsoft Azure"""

    azure_endpoint: str = Param(
        help=(
            "HTTPS endpoint for the Azure OpenAI model. The azure_endpoint, "
            "azure_deployment, and api_version parameters are used to construct "
            "the full URL for the Azure OpenAI model."
        )
    )
    azure_deployment: str = Param(help="Azure deployment name", required=True)
    api_version: str = Param(help="Azure model version", required=True)
    azure_ad_token: Optional[str] = Param(None, help="Azure AD token")
    azure_ad_token_provider: Optional[str] = Param(None, help="Azure AD token provider")

    @Param.auto(depends_on=["azure_ad_token_provider"])
    def azure_ad_token_provider_(self):
        if isinstance(self.azure_ad_token_provider, str):
            return import_dotted_string(self.azure_ad_token_provider, safe=False)

    def prepare_client(self, async_version: bool = False):
        """Get the OpenAI client

        Args:
            async_version (bool): Whether to get the async version of the client
        """
        params = {
            "azure_endpoint": self.azure_endpoint,
            "api_version": self.api_version,
            "api_key": self.api_key,
            "azure_ad_token": self.azure_ad_token,
            "azure_ad_token_provider": self.azure_ad_token_provider_,
            "timeout": self.timeout,
            "max_retries": self.max_retries_,
        }
        if async_version:
            from openai import AsyncAzureOpenAI

            return AsyncAzureOpenAI(**params)

        from openai import AzureOpenAI

        return AzureOpenAI(**params)

    def openai_response(self, client, **kwargs):
        """Get the openai response"""
        params = {
            "model": self.azure_deployment,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "n": self.n,
            "stop": self.stop,
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
            "tool_choice": self.tool_choice,
            "tools": self.tools,
            "logprobs": self.logprobs,
            "logit_bias": self.logit_bias,
            "top_logprobs": self.top_logprobs,
            "top_p": self.top_p,
        }
        params.update(kwargs)

        return client.chat.completions.create(**params)
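A hedged sketch of how the two new classes are meant to be constructed (all credential and deployment values below are placeholders, not part of the commit):

```python
from kotaemon.llms import AzureChatOpenAI, ChatOpenAI

# Vanilla OpenAI: api_key and model are the required Params.
openai_llm = ChatOpenAI(api_key="sk-...", model="gpt-3.5-turbo")

# Azure-hosted OpenAI: azure_deployment and api_version are required as well.
azure_llm = AzureChatOpenAI(
    api_key="azure-key",
    azure_endpoint="https://my-resource.openai.azure.com/",
    azure_deployment="my-gpt-35-deployment",
    api_version="2023-07-01-preview",
)

print(openai_llm.invoke("Ping").content)
```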
@@ -5,7 +5,7 @@ from theflow import Function, Node, Param

from kotaemon.base import BaseComponent, Document

from .chats import AzureChatOpenAI
from .chats import LCAzureChatOpenAI
from .completions import LLM
from .prompts import BasePromptComponent

@@ -25,7 +25,7 @@ class Thought(BaseComponent):
        >> from kotaemon.pipelines.cot import Thought
        >> thought = Thought(
             prompt="How to {action} {object}?",
             llm=AzureChatOpenAI(...),
             llm=LCAzureChatOpenAI(...),
             post_process=lambda string: {"tutorial": string},
           )
        >> output = thought(action="install", object="python")
@@ -42,7 +42,7 @@ class Thought(BaseComponent):
    This `Thought` allows chaining sequentially with the + operator. For example:

    ```python
    >> llm = AzureChatOpenAI(...)
    >> llm = LCAzureChatOpenAI(...)
    >> thought1 = Thought(
           prompt="Word {word} in {language} is ",
           llm=llm,
@@ -73,7 +73,7 @@ class Thought(BaseComponent):
            " component is executed"
        )
    )
    llm: LLM = Node(AzureChatOpenAI, help="The LLM model to execute the input prompt")
    llm: LLM = Node(LCAzureChatOpenAI, help="The LLM model to execute the input prompt")
    post_process: Function = Node(
        help=(
            "The function post-processor that post-processes LLM output prediction ."
@@ -117,7 +117,7 @@ class ManualSequentialChainOfThought(BaseComponent):

    ```pycon
    >>> from kotaemon.pipelines.cot import Thought, ManualSequentialChainOfThought
    >>> llm = AzureChatOpenAI(...)
    >>> llm = LCAzureChatOpenAI(...)
    >>> thought1 = Thought(
    >>>     prompt="Word {word} in {language} is ",
    >>>     post_process=lambda string: {"translated": string},
@@ -22,12 +22,12 @@ class SimpleLinearPipeline(BaseComponent):

    Example Usage:
        ```python
        from kotaemon.llms import AzureChatOpenAI, BasePromptComponent
        from kotaemon.llms import LCAzureChatOpenAI, BasePromptComponent

        def identity(x):
            return x

        llm = AzureChatOpenAI(
        llm = LCAzureChatOpenAI(
            openai_api_base="your openai api base",
            openai_api_key="your openai api key",
            openai_api_version="your openai api version",
@@ -89,13 +89,13 @@ class GatedLinearPipeline(SimpleLinearPipeline):

    Usage:
        ```{.py3 title="Example Usage"}
        from kotaemon.llms import AzureChatOpenAI, BasePromptComponent
        from kotaemon.llms import LCAzureChatOpenAI, BasePromptComponent
        from kotaemon.parsers import RegexExtractor

        def identity(x):
            return x

        llm = AzureChatOpenAI(
        llm = LCAzureChatOpenAI(
            openai_api_base="your openai api base",
            openai_api_key="your openai api key",
            openai_api_version="your openai api version",
@@ -11,7 +11,7 @@ packages.find.exclude = ["tests*", "env*"]
# metadata and dependencies
[project]
name = "kotaemon"
version = "0.3.8"
version = "0.3.9"
requires-python = ">= 3.10"
description = "Kotaemon core library for AI development."
dependencies = [
@@ -13,7 +13,7 @@ from kotaemon.agents import (
    RewooAgent,
    WikipediaTool,
)
from kotaemon.llms import AzureChatOpenAI
from kotaemon.llms import LCAzureChatOpenAI

FINAL_RESPONSE_TEXT = "Final Answer: Hello Cinnamon AI!"
REWOO_VALID_PLAN = (
@@ -112,7 +112,7 @@ _openai_chat_completion_responses_react_langchain_tool = [

@pytest.fixture
def llm():
    return AzureChatOpenAI(
    return LCAzureChatOpenAI(
        azure_endpoint="https://dummy.openai.azure.com/",
        openai_api_key="dummy",
        openai_api_version="2023-03-15-preview",
@@ -4,10 +4,10 @@ import pytest
from openai.types.chat.chat_completion import ChatCompletion

from kotaemon.llms import (
    AzureChatOpenAI,
    BasePromptComponent,
    GatedBranchingPipeline,
    GatedLinearPipeline,
    LCAzureChatOpenAI,
    SimpleBranchingPipeline,
    SimpleLinearPipeline,
)
@@ -40,7 +40,7 @@ _openai_chat_completion_response = ChatCompletion.parse_obj(

@pytest.fixture
def mock_llm():
    return AzureChatOpenAI(
    return LCAzureChatOpenAI(
        azure_endpoint="OPENAI_API_BASE",
        openai_api_key="OPENAI_API_KEY",
        openai_api_version="OPENAI_API_VERSION",
@@ -2,7 +2,7 @@ from unittest.mock import patch

from openai.types.chat.chat_completion import ChatCompletion

from kotaemon.llms import AzureChatOpenAI
from kotaemon.llms import LCAzureChatOpenAI
from kotaemon.llms.cot import ManualSequentialChainOfThought, Thought

_openai_chat_completion_response = [
@@ -38,7 +38,7 @@ _openai_chat_completion_response = [
    side_effect=_openai_chat_completion_response,
)
def test_cot_plus_operator(openai_completion):
    llm = AzureChatOpenAI(
    llm = LCAzureChatOpenAI(
        azure_endpoint="https://dummy.openai.azure.com/",
        openai_api_key="dummy",
        openai_api_version="2023-03-15-preview",
@@ -70,7 +70,7 @@ def test_cot_plus_operator(openai_completion):
    side_effect=_openai_chat_completion_response,
)
def test_cot_manual(openai_completion):
    llm = AzureChatOpenAI(
    llm = LCAzureChatOpenAI(
        azure_endpoint="https://dummy.openai.azure.com/",
        openai_api_key="dummy",
        openai_api_version="2023-03-15-preview",
@@ -100,7 +100,7 @@ def test_cot_manual(openai_completion):
    side_effect=_openai_chat_completion_response,
)
def test_cot_with_termination_callback(openai_completion):
    llm = AzureChatOpenAI(
    llm = LCAzureChatOpenAI(
        azure_endpoint="https://dummy.openai.azure.com/",
        openai_api_key="dummy",
        openai_api_version="2023-03-15-preview",
@@ -4,7 +4,7 @@ from unittest.mock import patch
import pytest

from kotaemon.base.schema import AIMessage, HumanMessage, LLMInterface, SystemMessage
from kotaemon.llms import AzureChatOpenAI, LlamaCppChat
from kotaemon.llms import LCAzureChatOpenAI, LlamaCppChat

try:
    from langchain_openai import AzureChatOpenAI as AzureChatOpenAILC
@@ -43,7 +43,7 @@ _openai_chat_completion_response = ChatCompletion.parse_obj(
    side_effect=lambda *args, **kwargs: _openai_chat_completion_response,
)
def test_azureopenai_model(openai_completion):
    model = AzureChatOpenAI(
    model = LCAzureChatOpenAI(
        azure_endpoint="https://test.openai.azure.com/",
        openai_api_key="some-key",
        openai_api_version="2023-03-15-preview",
@@ -5,7 +5,7 @@ from openai.types.chat.chat_completion import ChatCompletion

from kotaemon.base import Document
from kotaemon.indices.rankings import LLMReranking
from kotaemon.llms import AzureChatOpenAI
from kotaemon.llms import LCAzureChatOpenAI

_openai_chat_completion_responses = [
    ChatCompletion.parse_obj(
@@ -41,7 +41,7 @@ _openai_chat_completion_responses = [

@pytest.fixture
def llm():
    return AzureChatOpenAI(
    return LCAzureChatOpenAI(
        azure_endpoint="https://dummy.openai.azure.com/",
        openai_api_key="dummy",
        openai_api_version="2023-03-15-preview",