From a203fc0f7c854c33af6f95b946cc3265b8cb4bcc Mon Sep 17 00:00:00 2001 From: "Duc Nguyen (john)" Date: Sat, 6 Apr 2024 11:53:17 +0700 Subject: [PATCH] Allow users to add LLM within the UI (#6) * Rename AzureChatOpenAI to LCAzureChatOpenAI * Provide vanilla ChatOpenAI and AzureChatOpenAI * Remove the highest accuracy, lowest cost criteria These criteria are unnecessary. The users, not pipeline creators, should choose which LLM to use. Furthermore, it's cumbersome to input this information, really degrades user experience. * Remove the LLM selection in simple reasoning pipeline * Provide a dedicated stream method to generate the output * Return placeholder message to chat if the text is empty --- .gitignore | 1 + docs/development/create-a-component.md | 6 +- docs/pages/app/customize-flows.md | 7 +- libs/kotaemon/kotaemon/base/component.py | 2 +- libs/kotaemon/kotaemon/base/schema.py | 24 +- .../kotaemon/indices/qa/text_based.py | 4 +- libs/kotaemon/kotaemon/llms/__init__.py | 14 +- libs/kotaemon/kotaemon/llms/base.py | 5 +- libs/kotaemon/kotaemon/llms/branching.py | 12 +- libs/kotaemon/kotaemon/llms/chats/__init__.py | 8 +- .../kotaemon/llms/chats/endpoint_based.py | 5 +- .../kotaemon/llms/chats/langchain_based.py | 4 +- libs/kotaemon/kotaemon/llms/chats/llamacpp.py | 73 +++- libs/kotaemon/kotaemon/llms/chats/openai.py | 356 ++++++++++++++++++ libs/kotaemon/kotaemon/llms/cot.py | 10 +- libs/kotaemon/kotaemon/llms/linear.py | 8 +- libs/kotaemon/pyproject.toml | 2 +- libs/kotaemon/tests/test_agent.py | 4 +- libs/kotaemon/tests/test_composite.py | 4 +- libs/kotaemon/tests/test_cot.py | 8 +- libs/kotaemon/tests/test_llms_chat_models.py | 4 +- libs/kotaemon/tests/test_reranking.py | 4 +- libs/ktem/flowsettings.py | 19 +- libs/ktem/ktem/components.py | 20 +- libs/ktem/ktem/index/file/pipelines.py | 4 +- libs/ktem/ktem/llms/__init__.py | 0 libs/ktem/ktem/llms/db.py | 36 ++ libs/ktem/ktem/llms/manager.py | 191 ++++++++++ libs/ktem/ktem/llms/ui.py | 318 ++++++++++++++++ libs/ktem/ktem/pages/admin/__init__.py | 4 + libs/ktem/ktem/pages/chat/__init__.py | 58 ++- libs/ktem/ktem/reasoning/simple.py | 283 +++++++++++--- libs/ktem/ktem_tests/test_qa.py | 4 +- libs/ktem/launch.py | 2 +- .../{{cookiecutter.project_name}}/pipeline.py | 4 +- 35 files changed, 1339 insertions(+), 169 deletions(-) create mode 100644 libs/kotaemon/kotaemon/llms/chats/openai.py create mode 100644 libs/ktem/ktem/llms/__init__.py create mode 100644 libs/ktem/ktem/llms/db.py create mode 100644 libs/ktem/ktem/llms/manager.py create mode 100644 libs/ktem/ktem/llms/ui.py diff --git a/.gitignore b/.gitignore index 5c91c3e..711a741 100644 --- a/.gitignore +++ b/.gitignore @@ -458,6 +458,7 @@ logs/ .gitsecret/keys/random_seed !*.secret .envrc +.env S.gpg-agent* .vscode/settings.json diff --git a/docs/development/create-a-component.md b/docs/development/create-a-component.md index 029bbe9..f831259 100644 --- a/docs/development/create-a-component.md +++ b/docs/development/create-a-component.md @@ -22,7 +22,7 @@ The syntax of a component is as follow: ```python from kotaemon.base import BaseComponent -from kotaemon.llms import AzureChatOpenAI +from kotaemon.llms import LCAzureChatOpenAI from kotaemon.parsers import RegexExtractor @@ -32,7 +32,7 @@ class FancyPipeline(BaseComponent): param3: float node1: BaseComponent # this is a node because of BaseComponent type annotation - node2: AzureChatOpenAI # this is also a node because AzureChatOpenAI subclasses BaseComponent + node2: LCAzureChatOpenAI # this is also a node because LCAzureChatOpenAI subclasses BaseComponent node3: RegexExtractor # this is also a node bceause RegexExtractor subclasses BaseComponent def run(self, some_text: str): @@ -45,7 +45,7 @@ class FancyPipeline(BaseComponent): Then this component can be used as follow: ```python -llm = AzureChatOpenAI(endpoint="some-endpont") +llm = LCAzureChatOpenAI(endpoint="some-endpont") extractor = RegexExtractor(pattern=["yes", "Yes"]) component = FancyPipeline( diff --git a/docs/pages/app/customize-flows.md b/docs/pages/app/customize-flows.md index 1277e34..3dd005e 100644 --- a/docs/pages/app/customize-flows.md +++ b/docs/pages/app/customize-flows.md @@ -193,7 +193,8 @@ information panel. You can access users' collections of LLMs and embedding models with: ```python -from ktem.components import llms, embeddings +from ktem.components import embeddings +from ktem.llms.manager import llms llm = llms.get_default() @@ -206,12 +207,12 @@ models they want to use through the settings. ```python @classmethod def get_user_settings(cls) -> dict: - from ktem.components import llms + from ktem.llms.manager import llms return { "citation_llm": { "name": "LLM for citation", - "value": llms.get_lowest_cost_name(), + "value": llms.get_default(), "component: "dropdown", "choices": list(llms.options().keys()), }, diff --git a/libs/kotaemon/kotaemon/base/component.py b/libs/kotaemon/kotaemon/base/component.py index 9acd39f..6936b2a 100644 --- a/libs/kotaemon/kotaemon/base/component.py +++ b/libs/kotaemon/kotaemon/base/component.py @@ -52,7 +52,7 @@ class BaseComponent(Function): def stream(self, *args, **kwargs) -> Iterator[Document] | None: ... - async def astream(self, *args, **kwargs) -> AsyncGenerator[Document, None] | None: + def astream(self, *args, **kwargs) -> AsyncGenerator[Document, None] | None: ... @abstractmethod diff --git a/libs/kotaemon/kotaemon/base/schema.py b/libs/kotaemon/kotaemon/base/schema.py index 1d0e622..07fe9f5 100644 --- a/libs/kotaemon/kotaemon/base/schema.py +++ b/libs/kotaemon/kotaemon/base/schema.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Optional, TypeVar +from typing import TYPE_CHECKING, Any, Literal, Optional, TypeVar from langchain.schema.messages import AIMessage as LCAIMessage from langchain.schema.messages import HumanMessage as LCHumanMessage @@ -10,6 +10,9 @@ from llama_index.schema import Document as BaseDocument if TYPE_CHECKING: from haystack.schema import Document as HaystackDocument + from openai.types.chat.chat_completion_message_param import ( + ChatCompletionMessageParam, + ) IO_Type = TypeVar("IO_Type", "Document", str) SAMPLE_TEXT = "A sample Document from kotaemon" @@ -26,10 +29,15 @@ class Document(BaseDocument): Attributes: content: raw content of the document, can be anything source: id of the source of the Document. Optional. + channel: the channel to show the document. Optional.: + - chat: show in chat message + - info: show in information panel + - debug: show in debug panel """ - content: Any + content: Any = None source: Optional[str] = None + channel: Optional[Literal["chat", "info", "debug"]] = None def __init__(self, content: Optional[Any] = None, *args, **kwargs): if content is None: @@ -87,17 +95,23 @@ class BaseMessage(Document): def __add__(self, other: Any): raise NotImplementedError + def to_openai_format(self) -> "ChatCompletionMessageParam": + raise NotImplementedError + class SystemMessage(BaseMessage, LCSystemMessage): - pass + def to_openai_format(self) -> "ChatCompletionMessageParam": + return {"role": "system", "content": self.content} class AIMessage(BaseMessage, LCAIMessage): - pass + def to_openai_format(self) -> "ChatCompletionMessageParam": + return {"role": "assistant", "content": self.content} class HumanMessage(BaseMessage, LCHumanMessage): - pass + def to_openai_format(self) -> "ChatCompletionMessageParam": + return {"role": "user", "content": self.content} class RetrievedDocument(Document): diff --git a/libs/kotaemon/kotaemon/indices/qa/text_based.py b/libs/kotaemon/kotaemon/indices/qa/text_based.py index 5b1f6e3..e0b49be 100644 --- a/libs/kotaemon/kotaemon/indices/qa/text_based.py +++ b/libs/kotaemon/kotaemon/indices/qa/text_based.py @@ -1,7 +1,7 @@ import os from kotaemon.base import BaseComponent, Document, Node, RetrievedDocument -from kotaemon.llms import AzureChatOpenAI, BaseLLM, PromptTemplate +from kotaemon.llms import BaseLLM, LCAzureChatOpenAI, PromptTemplate from .citation import CitationPipeline @@ -13,7 +13,7 @@ class CitationQAPipeline(BaseComponent): 'Answer the following question: "{question}". ' "The context is: \n{context}\nAnswer: " ) - llm: BaseLLM = AzureChatOpenAI.withx( + llm: BaseLLM = LCAzureChatOpenAI.withx( azure_endpoint="https://bleh-dummy.openai.azure.com/", openai_api_key=os.environ.get("OPENAI_API_KEY", ""), openai_api_version="2023-07-01-preview", diff --git a/libs/kotaemon/kotaemon/llms/__init__.py b/libs/kotaemon/kotaemon/llms/__init__.py index d7547a6..266e391 100644 --- a/libs/kotaemon/kotaemon/llms/__init__.py +++ b/libs/kotaemon/kotaemon/llms/__init__.py @@ -2,7 +2,15 @@ from kotaemon.base.schema import AIMessage, BaseMessage, HumanMessage, SystemMes from .base import BaseLLM from .branching import GatedBranchingPipeline, SimpleBranchingPipeline -from .chats import AzureChatOpenAI, ChatLLM, ChatOpenAI, EndpointChatLLM, LlamaCppChat +from .chats import ( + AzureChatOpenAI, + ChatLLM, + ChatOpenAI, + EndpointChatLLM, + LCAzureChatOpenAI, + LCChatOpenAI, + LlamaCppChat, +) from .completions import LLM, AzureOpenAI, LlamaCpp, OpenAI from .cot import ManualSequentialChainOfThought, Thought from .linear import GatedLinearPipeline, SimpleLinearPipeline @@ -17,8 +25,10 @@ __all__ = [ "HumanMessage", "AIMessage", "SystemMessage", - "ChatOpenAI", "AzureChatOpenAI", + "ChatOpenAI", + "LCAzureChatOpenAI", + "LCChatOpenAI", "LlamaCppChat", # completion-specific components "LLM", diff --git a/libs/kotaemon/kotaemon/llms/base.py b/libs/kotaemon/kotaemon/llms/base.py index 6ef7afc..374d139 100644 --- a/libs/kotaemon/kotaemon/llms/base.py +++ b/libs/kotaemon/kotaemon/llms/base.py @@ -18,5 +18,8 @@ class BaseLLM(BaseComponent): def stream(self, *args, **kwargs) -> Iterator[LLMInterface]: raise NotImplementedError - async def astream(self, *args, **kwargs) -> AsyncGenerator[LLMInterface, None]: + def astream(self, *args, **kwargs) -> AsyncGenerator[LLMInterface, None]: raise NotImplementedError + + def run(self, *args, **kwargs): + return self.invoke(*args, **kwargs) diff --git a/libs/kotaemon/kotaemon/llms/branching.py b/libs/kotaemon/kotaemon/llms/branching.py index a9cbbe8..ee49dc5 100644 --- a/libs/kotaemon/kotaemon/llms/branching.py +++ b/libs/kotaemon/kotaemon/llms/branching.py @@ -15,7 +15,7 @@ class SimpleBranchingPipeline(BaseComponent): Example: ```python from kotaemon.llms import ( - AzureChatOpenAI, + LCAzureChatOpenAI, BasePromptComponent, GatedLinearPipeline, ) @@ -25,7 +25,7 @@ class SimpleBranchingPipeline(BaseComponent): return x pipeline = SimpleBranchingPipeline() - llm = AzureChatOpenAI( + llm = LCAzureChatOpenAI( openai_api_base="your openai api base", openai_api_key="your openai api key", openai_api_version="your openai api version", @@ -92,7 +92,7 @@ class GatedBranchingPipeline(SimpleBranchingPipeline): Example: ```python from kotaemon.llms import ( - AzureChatOpenAI, + LCAzureChatOpenAI, BasePromptComponent, GatedLinearPipeline, ) @@ -102,7 +102,7 @@ class GatedBranchingPipeline(SimpleBranchingPipeline): return x pipeline = GatedBranchingPipeline() - llm = AzureChatOpenAI( + llm = LCAzureChatOpenAI( openai_api_base="your openai api base", openai_api_key="your openai api key", openai_api_version="your openai api version", @@ -157,7 +157,7 @@ class GatedBranchingPipeline(SimpleBranchingPipeline): if __name__ == "__main__": import dotenv - from kotaemon.llms import AzureChatOpenAI, BasePromptComponent + from kotaemon.llms import BasePromptComponent, LCAzureChatOpenAI from kotaemon.parsers import RegexExtractor def identity(x): @@ -166,7 +166,7 @@ if __name__ == "__main__": secrets = dotenv.dotenv_values(".env") pipeline = GatedBranchingPipeline() - llm = AzureChatOpenAI( + llm = LCAzureChatOpenAI( openai_api_base=secrets.get("OPENAI_API_BASE", ""), openai_api_key=secrets.get("OPENAI_API_KEY", ""), openai_api_version=secrets.get("OPENAI_API_VERSION", ""), diff --git a/libs/kotaemon/kotaemon/llms/chats/__init__.py b/libs/kotaemon/kotaemon/llms/chats/__init__.py index 5b50317..7fc1c40 100644 --- a/libs/kotaemon/kotaemon/llms/chats/__init__.py +++ b/libs/kotaemon/kotaemon/llms/chats/__init__.py @@ -1,13 +1,17 @@ from .base import ChatLLM from .endpoint_based import EndpointChatLLM -from .langchain_based import AzureChatOpenAI, ChatOpenAI, LCChatMixin +from .langchain_based import LCAzureChatOpenAI, LCChatMixin, LCChatOpenAI from .llamacpp import LlamaCppChat +from .openai import AzureChatOpenAI, ChatOpenAI __all__ = [ + "ChatOpenAI", + "AzureChatOpenAI", "ChatLLM", "EndpointChatLLM", "ChatOpenAI", - "AzureChatOpenAI", + "LCChatOpenAI", + "LCAzureChatOpenAI", "LCChatMixin", "LlamaCppChat", ] diff --git a/libs/kotaemon/kotaemon/llms/chats/endpoint_based.py b/libs/kotaemon/kotaemon/llms/chats/endpoint_based.py index 170ec8b..5ab1835 100644 --- a/libs/kotaemon/kotaemon/llms/chats/endpoint_based.py +++ b/libs/kotaemon/kotaemon/llms/chats/endpoint_based.py @@ -5,6 +5,7 @@ from kotaemon.base import ( BaseMessage, HumanMessage, LLMInterface, + Param, SystemMessage, ) @@ -20,7 +21,9 @@ class EndpointChatLLM(ChatLLM): endpoint_url (str): The url of a OpenAI API compatible endpoint. """ - endpoint_url: str + endpoint_url: str = Param( + help="URL of the OpenAI API compatible endpoint", required=True + ) def run( self, messages: str | BaseMessage | list[BaseMessage], **kwargs diff --git a/libs/kotaemon/kotaemon/llms/chats/langchain_based.py b/libs/kotaemon/kotaemon/llms/chats/langchain_based.py index 526eaf8..fca78dc 100644 --- a/libs/kotaemon/kotaemon/llms/chats/langchain_based.py +++ b/libs/kotaemon/kotaemon/llms/chats/langchain_based.py @@ -165,7 +165,7 @@ class LCChatMixin: raise ValueError(f"Invalid param {path}") -class ChatOpenAI(LCChatMixin, ChatLLM): # type: ignore +class LCChatOpenAI(LCChatMixin, ChatLLM): # type: ignore def __init__( self, openai_api_base: str | None = None, @@ -193,7 +193,7 @@ class ChatOpenAI(LCChatMixin, ChatLLM): # type: ignore return ChatOpenAI -class AzureChatOpenAI(LCChatMixin, ChatLLM): # type: ignore +class LCAzureChatOpenAI(LCChatMixin, ChatLLM): # type: ignore def __init__( self, azure_endpoint: str | None = None, diff --git a/libs/kotaemon/kotaemon/llms/chats/llamacpp.py b/libs/kotaemon/kotaemon/llms/chats/llamacpp.py index 62ee0ea..7b8bee4 100644 --- a/libs/kotaemon/kotaemon/llms/chats/llamacpp.py +++ b/libs/kotaemon/kotaemon/llms/chats/llamacpp.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Optional, cast +from typing import TYPE_CHECKING, Iterator, Optional, cast from kotaemon.base import BaseMessage, HumanMessage, LLMInterface, Param @@ -12,13 +12,32 @@ if TYPE_CHECKING: class LlamaCppChat(ChatLLM): """Wrapper around the llama-cpp-python's Llama model""" - model_path: Optional[str] = None - chat_format: Optional[str] = None - lora_base: Optional[str] = None - n_ctx: int = 512 - n_gpu_layers: int = 0 - use_mmap: bool = True - vocab_only: bool = False + model_path: str = Param( + help="Path to the model file. This is required to load the model.", + required=True, + ) + chat_format: str = Param( + help=( + "Chat format to use. Please refer to llama_cpp.llama_chat_format for a " + "list of supported formats. If blank, the chat format will be auto-" + "inferred." + ), + required=True, + ) + lora_base: Optional[str] = Param(None, help="Path to the base Lora model") + n_ctx: Optional[int] = Param(512, help="Text context, 0 = from model") + n_gpu_layers: Optional[int] = Param( + 0, + help=("Number of layers to offload to GPU. If -1, all layers are offloaded"), + ) + use_mmap: Optional[bool] = Param( + True, + help=(), + ) + vocab_only: Optional[bool] = Param( + False, + help=("If True, only the vocabulary is loaded. This is useful for debugging."), + ) _role_mapper: dict[str, str] = { "human": "user", @@ -60,9 +79,9 @@ class LlamaCppChat(ChatLLM): vocab_only=self.vocab_only, ) - def run( - self, messages: str | BaseMessage | list[BaseMessage], **kwargs - ) -> LLMInterface: + def prepare_message( + self, messages: str | BaseMessage | list[BaseMessage] + ) -> list[dict]: input_: list[BaseMessage] = [] if isinstance(messages, str): @@ -72,11 +91,19 @@ class LlamaCppChat(ChatLLM): else: input_ = messages + output_ = [ + {"role": self._role_mapper[each.type], "content": each.content} + for each in input_ + ] + + return output_ + + def invoke( + self, messages: str | BaseMessage | list[BaseMessage], **kwargs + ) -> LLMInterface: + pred: "CCCR" = self.client_object.create_chat_completion( - messages=[ - {"role": self._role_mapper[each.type], "content": each.content} - for each in input_ - ], # type: ignore + messages=self.prepare_message(messages), stream=False, ) @@ -91,3 +118,19 @@ class LlamaCppChat(ChatLLM): total_tokens=pred["usage"]["total_tokens"], prompt_tokens=pred["usage"]["prompt_tokens"], ) + + def stream( + self, messages: str | BaseMessage | list[BaseMessage], **kwargs + ) -> Iterator[LLMInterface]: + pred = self.client_object.create_chat_completion( + messages=self.prepare_message(messages), + stream=True, + ) + for chunk in pred: + if not chunk["choices"]: + continue + + if "content" not in chunk["choices"][0]["delta"]: + continue + + yield LLMInterface(content=chunk["choices"][0]["delta"]["content"]) diff --git a/libs/kotaemon/kotaemon/llms/chats/openai.py b/libs/kotaemon/kotaemon/llms/chats/openai.py new file mode 100644 index 0000000..6f492c7 --- /dev/null +++ b/libs/kotaemon/kotaemon/llms/chats/openai.py @@ -0,0 +1,356 @@ +from typing import TYPE_CHECKING, AsyncGenerator, Iterator, Optional + +from theflow.utils.modules import import_dotted_string + +from kotaemon.base import AIMessage, BaseMessage, HumanMessage, LLMInterface, Param + +from .base import ChatLLM + +if TYPE_CHECKING: + from openai.types.chat.chat_completion_message_param import ( + ChatCompletionMessageParam, + ) + + +class BaseChatOpenAI(ChatLLM): + """Base interface for OpenAI chat model, using the openai library + + This class exposes the parameters in resources.Chat. To subclass this class: + + - Implement the `prepare_client` method to return the OpenAI client + - Implement the `openai_response` method to return the OpenAI response + - Implement the params relate to the OpenAI client + """ + + _dependencies = ["openai"] + _capabilities = ["chat", "text"] # consider as mixin + + api_key: str = Param(help="API key", required=True) + timeout: Optional[float] = Param(None, help="Timeout for the API request") + max_retries: Optional[int] = Param( + None, help="Maximum number of retries for the API request" + ) + + temperature: Optional[float] = Param( + None, + help=( + "Number between 0 and 2 that controls the randomness of the generated " + "tokens. Lower values make the model more deterministic, while higher " + "values make the model more random." + ), + ) + max_tokens: Optional[int] = Param( + None, + help=( + "Maximum number of tokens to generate. The total length of input tokens " + "and generated tokens is limited by the model's context length." + ), + ) + n: int = Param( + 1, + help=( + "Number of completions to generate. The API will generate n completion " + "for each prompt." + ), + ) + stop: Optional[str | list[str]] = Param( + None, + help=( + "Stop sequence. If a stop sequence is detected, generation will stop " + "at that point. If not specified, generation will continue until the " + "maximum token length is reached." + ), + ) + frequency_penalty: Optional[float] = Param( + None, + help=( + "Number between -2.0 and 2.0. Positive values penalize new tokens " + "based on their existing frequency in the text so far, decrearsing the " + "model's likelihood of repeating the same text." + ), + ) + presence_penalty: Optional[float] = Param( + None, + help=( + "Number between -2.0 and 2.0. Positive values penalize new tokens " + "based on their existing presence in the text so far, decrearsing the " + "model's likelihood of repeating the same text." + ), + ) + tool_choice: Optional[str] = Param( + None, + help=( + "Choice of tool to use for the completion. Available choices are: " + "auto, default." + ), + ) + tools: Optional[list[str]] = Param( + None, + help="List of tools to use for the completion.", + ) + logprobs: Optional[bool] = Param( + None, + help=( + "Include log probabilities on the logprobs most likely tokens, " + "as well as the chosen token." + ), + ) + logit_bias: Optional[dict] = Param( + None, + help=( + "Dictionary of logit bias values to add to the logits of the tokens " + "in the vocabulary." + ), + ) + top_logprobs: Optional[int] = Param( + None, + help=( + "An integer between 0 and 5 specifying the number of most likely tokens " + "to return at each token position, each with an associated log " + "probability. `logprobs` must also be set to `true` if this parameter " + "is used." + ), + ) + top_p: Optional[float] = Param( + None, + help=( + "An alternative to sampling with temperature, called nucleus sampling, " + "where the model considers the results of the token with top_p " + "probability mass. So 0.1 means that only the tokens comprising the " + "top 10% probability mass are considered." + ), + ) + + @Param.auto(depends_on=["max_retries"]) + def max_retries_(self): + if self.max_retries is None: + from openai._constants import DEFAULT_MAX_RETRIES + + return DEFAULT_MAX_RETRIES + return self.max_retries + + def prepare_message( + self, messages: str | BaseMessage | list[BaseMessage] + ) -> list["ChatCompletionMessageParam"]: + """Prepare the message into OpenAI format + + Returns: + list[dict]: List of messages in OpenAI format + """ + input_: list[BaseMessage] = [] + output_: list["ChatCompletionMessageParam"] = [] + + if isinstance(messages, str): + input_ = [HumanMessage(content=messages)] + elif isinstance(messages, BaseMessage): + input_ = [messages] + else: + input_ = messages + + for message in input_: + output_.append(message.to_openai_format()) + + return output_ + + def prepare_client(self, async_version: bool = False): + """Get the OpenAI client + + Args: + async_version (bool): Whether to get the async version of the client + """ + raise NotImplementedError + + def openai_response(self, client, **kwargs): + """Get the openai response""" + raise NotImplementedError + + def invoke( + self, messages: str | BaseMessage | list[BaseMessage], *args, **kwargs + ) -> LLMInterface: + client = self.prepare_client(async_version=False) + input_messages = self.prepare_message(messages) + resp = self.openai_response( + client, messages=input_messages, stream=False, **kwargs + ).dict() + + output = LLMInterface( + candidates=[_["message"]["content"] for _ in resp["choices"]], + content=resp["choices"][0]["message"]["content"], + total_tokens=resp["usage"]["total_tokens"], + prompt_tokens=resp["usage"]["prompt_tokens"], + completion_tokens=resp["usage"]["completion_tokens"], + messages=[ + AIMessage(content=_["message"]["content"]) for _ in resp["choices"] + ], + ) + + return output + + async def ainvoke( + self, messages: str | BaseMessage | list[BaseMessage], *args, **kwargs + ) -> LLMInterface: + client = self.prepare_client(async_version=True) + input_messages = self.prepare_message(messages) + resp = await self.openai_response( + client, messages=input_messages, stream=False, **kwargs + ).dict() + + output = LLMInterface( + candidates=[_["message"]["content"] for _ in resp["choices"]], + content=resp["choices"][0]["message"]["content"], + total_tokens=resp["usage"]["total_tokens"], + prompt_tokens=resp["usage"]["prompt_tokens"], + completion_tokens=resp["usage"]["completion_tokens"], + messages=[ + AIMessage(content=_["message"]["content"]) for _ in resp["choices"] + ], + ) + + return output + + def stream( + self, messages: str | BaseMessage | list[BaseMessage], *args, **kwargs + ) -> Iterator[LLMInterface]: + client = self.prepare_client(async_version=False) + input_messages = self.prepare_message(messages) + resp = self.openai_response( + client, messages=input_messages, stream=True, **kwargs + ) + + for chunk in resp: + if not chunk.choices: + continue + if chunk.choices[0].delta.content is not None: + yield LLMInterface(content=chunk.choices[0].delta.content) + + async def astream( + self, messages: str | BaseMessage | list[BaseMessage], *args, **kwargs + ) -> AsyncGenerator[LLMInterface, None]: + client = self.prepare_client(async_version=True) + input_messages = self.prepare_message(messages) + resp = self.openai_response( + client, messages=input_messages, stream=True, **kwargs + ) + + async for chunk in resp: + if not chunk.choices: + continue + if chunk.choices[0].delta.content is not None: + yield LLMInterface(content=chunk.choices[0].delta.content) + + +class ChatOpenAI(BaseChatOpenAI): + """OpenAI chat model""" + + base_url: Optional[str] = Param(None, help="OpenAI base URL") + organization: Optional[str] = Param(None, help="OpenAI organization") + model: str = Param(help="OpenAI model", required=True) + + def prepare_client(self, async_version: bool = False): + """Get the OpenAI client + + Args: + async_version (bool): Whether to get the async version of the client + """ + params = { + "api_key": self.api_key, + "organization": self.organization, + "base_url": self.base_url, + "timeout": self.timeout, + "max_retries": self.max_retries_, + } + if async_version: + from openai import AsyncOpenAI + + return AsyncOpenAI(**params) + + from openai import OpenAI + + return OpenAI(**params) + + def openai_response(self, client, **kwargs): + """Get the openai response""" + params = { + "model": self.model, + "temperature": self.temperature, + "max_tokens": self.max_tokens, + "n": self.n, + "stop": self.stop, + "frequency_penalty": self.frequency_penalty, + "presence_penalty": self.presence_penalty, + "tool_choice": self.tool_choice, + "tools": self.tools, + "logprobs": self.logprobs, + "logit_bias": self.logit_bias, + "top_logprobs": self.top_logprobs, + "top_p": self.top_p, + } + params.update(kwargs) + + return client.chat.completions.create(**params) + + +class AzureChatOpenAI(BaseChatOpenAI): + """OpenAI chat model provided by Microsoft Azure""" + + azure_endpoint: str = Param( + help=( + "HTTPS endpoint for the Azure OpenAI model. The azure_endpoint, " + "azure_deployment, and api_version parameters are used to construct " + "the full URL for the Azure OpenAI model." + ) + ) + azure_deployment: str = Param(help="Azure deployment name", required=True) + api_version: str = Param(help="Azure model version", required=True) + azure_ad_token: Optional[str] = Param(None, help="Azure AD token") + azure_ad_token_provider: Optional[str] = Param(None, help="Azure AD token provider") + + @Param.auto(depends_on=["azure_ad_token_provider"]) + def azure_ad_token_provider_(self): + if isinstance(self.azure_ad_token_provider, str): + return import_dotted_string(self.azure_ad_token_provider, safe=False) + + def prepare_client(self, async_version: bool = False): + """Get the OpenAI client + + Args: + async_version (bool): Whether to get the async version of the client + """ + params = { + "azure_endpoint": self.azure_endpoint, + "api_version": self.api_version, + "api_key": self.api_key, + "azure_ad_token": self.azure_ad_token, + "azure_ad_token_provider": self.azure_ad_token_provider_, + "timeout": self.timeout, + "max_retries": self.max_retries_, + } + if async_version: + from openai import AsyncAzureOpenAI + + return AsyncAzureOpenAI(**params) + + from openai import AzureOpenAI + + return AzureOpenAI(**params) + + def openai_response(self, client, **kwargs): + """Get the openai response""" + params = { + "model": self.azure_deployment, + "temperature": self.temperature, + "max_tokens": self.max_tokens, + "n": self.n, + "stop": self.stop, + "frequency_penalty": self.frequency_penalty, + "presence_penalty": self.presence_penalty, + "tool_choice": self.tool_choice, + "tools": self.tools, + "logprobs": self.logprobs, + "logit_bias": self.logit_bias, + "top_logprobs": self.top_logprobs, + "top_p": self.top_p, + } + params.update(kwargs) + + return client.chat.completions.create(**params) diff --git a/libs/kotaemon/kotaemon/llms/cot.py b/libs/kotaemon/kotaemon/llms/cot.py index 7eaf5d1..a52f9bd 100644 --- a/libs/kotaemon/kotaemon/llms/cot.py +++ b/libs/kotaemon/kotaemon/llms/cot.py @@ -5,7 +5,7 @@ from theflow import Function, Node, Param from kotaemon.base import BaseComponent, Document -from .chats import AzureChatOpenAI +from .chats import LCAzureChatOpenAI from .completions import LLM from .prompts import BasePromptComponent @@ -25,7 +25,7 @@ class Thought(BaseComponent): >> from kotaemon.pipelines.cot import Thought >> thought = Thought( prompt="How to {action} {object}?", - llm=AzureChatOpenAI(...), + llm=LCAzureChatOpenAI(...), post_process=lambda string: {"tutorial": string}, ) >> output = thought(action="install", object="python") @@ -42,7 +42,7 @@ class Thought(BaseComponent): This `Thought` allows chaining sequentially with the + operator. For example: ```python - >> llm = AzureChatOpenAI(...) + >> llm = LCAzureChatOpenAI(...) >> thought1 = Thought( prompt="Word {word} in {language} is ", llm=llm, @@ -73,7 +73,7 @@ class Thought(BaseComponent): " component is executed" ) ) - llm: LLM = Node(AzureChatOpenAI, help="The LLM model to execute the input prompt") + llm: LLM = Node(LCAzureChatOpenAI, help="The LLM model to execute the input prompt") post_process: Function = Node( help=( "The function post-processor that post-processes LLM output prediction ." @@ -117,7 +117,7 @@ class ManualSequentialChainOfThought(BaseComponent): ```pycon >>> from kotaemon.pipelines.cot import Thought, ManualSequentialChainOfThought - >>> llm = AzureChatOpenAI(...) + >>> llm = LCAzureChatOpenAI(...) >>> thought1 = Thought( >>> prompt="Word {word} in {language} is ", >>> post_process=lambda string: {"translated": string}, diff --git a/libs/kotaemon/kotaemon/llms/linear.py b/libs/kotaemon/kotaemon/llms/linear.py index ac8605a..4c61597 100644 --- a/libs/kotaemon/kotaemon/llms/linear.py +++ b/libs/kotaemon/kotaemon/llms/linear.py @@ -22,12 +22,12 @@ class SimpleLinearPipeline(BaseComponent): Example Usage: ```python - from kotaemon.llms import AzureChatOpenAI, BasePromptComponent + from kotaemon.llms import LCAzureChatOpenAI, BasePromptComponent def identity(x): return x - llm = AzureChatOpenAI( + llm = LCAzureChatOpenAI( openai_api_base="your openai api base", openai_api_key="your openai api key", openai_api_version="your openai api version", @@ -89,13 +89,13 @@ class GatedLinearPipeline(SimpleLinearPipeline): Usage: ```{.py3 title="Example Usage"} - from kotaemon.llms import AzureChatOpenAI, BasePromptComponent + from kotaemon.llms import LCAzureChatOpenAI, BasePromptComponent from kotaemon.parsers import RegexExtractor def identity(x): return x - llm = AzureChatOpenAI( + llm = LCAzureChatOpenAI( openai_api_base="your openai api base", openai_api_key="your openai api key", openai_api_version="your openai api version", diff --git a/libs/kotaemon/pyproject.toml b/libs/kotaemon/pyproject.toml index 73c3e8a..a02337e 100644 --- a/libs/kotaemon/pyproject.toml +++ b/libs/kotaemon/pyproject.toml @@ -11,7 +11,7 @@ packages.find.exclude = ["tests*", "env*"] # metadata and dependencies [project] name = "kotaemon" -version = "0.3.8" +version = "0.3.9" requires-python = ">= 3.10" description = "Kotaemon core library for AI development." dependencies = [ diff --git a/libs/kotaemon/tests/test_agent.py b/libs/kotaemon/tests/test_agent.py index 0cc65fa..d489af9 100644 --- a/libs/kotaemon/tests/test_agent.py +++ b/libs/kotaemon/tests/test_agent.py @@ -13,7 +13,7 @@ from kotaemon.agents import ( RewooAgent, WikipediaTool, ) -from kotaemon.llms import AzureChatOpenAI +from kotaemon.llms import LCAzureChatOpenAI FINAL_RESPONSE_TEXT = "Final Answer: Hello Cinnamon AI!" REWOO_VALID_PLAN = ( @@ -112,7 +112,7 @@ _openai_chat_completion_responses_react_langchain_tool = [ @pytest.fixture def llm(): - return AzureChatOpenAI( + return LCAzureChatOpenAI( azure_endpoint="https://dummy.openai.azure.com/", openai_api_key="dummy", openai_api_version="2023-03-15-preview", diff --git a/libs/kotaemon/tests/test_composite.py b/libs/kotaemon/tests/test_composite.py index 464a456..38e79bd 100644 --- a/libs/kotaemon/tests/test_composite.py +++ b/libs/kotaemon/tests/test_composite.py @@ -4,10 +4,10 @@ import pytest from openai.types.chat.chat_completion import ChatCompletion from kotaemon.llms import ( - AzureChatOpenAI, BasePromptComponent, GatedBranchingPipeline, GatedLinearPipeline, + LCAzureChatOpenAI, SimpleBranchingPipeline, SimpleLinearPipeline, ) @@ -40,7 +40,7 @@ _openai_chat_completion_response = ChatCompletion.parse_obj( @pytest.fixture def mock_llm(): - return AzureChatOpenAI( + return LCAzureChatOpenAI( azure_endpoint="OPENAI_API_BASE", openai_api_key="OPENAI_API_KEY", openai_api_version="OPENAI_API_VERSION", diff --git a/libs/kotaemon/tests/test_cot.py b/libs/kotaemon/tests/test_cot.py index aef8a69..5fd1344 100644 --- a/libs/kotaemon/tests/test_cot.py +++ b/libs/kotaemon/tests/test_cot.py @@ -2,7 +2,7 @@ from unittest.mock import patch from openai.types.chat.chat_completion import ChatCompletion -from kotaemon.llms import AzureChatOpenAI +from kotaemon.llms import LCAzureChatOpenAI from kotaemon.llms.cot import ManualSequentialChainOfThought, Thought _openai_chat_completion_response = [ @@ -38,7 +38,7 @@ _openai_chat_completion_response = [ side_effect=_openai_chat_completion_response, ) def test_cot_plus_operator(openai_completion): - llm = AzureChatOpenAI( + llm = LCAzureChatOpenAI( azure_endpoint="https://dummy.openai.azure.com/", openai_api_key="dummy", openai_api_version="2023-03-15-preview", @@ -70,7 +70,7 @@ def test_cot_plus_operator(openai_completion): side_effect=_openai_chat_completion_response, ) def test_cot_manual(openai_completion): - llm = AzureChatOpenAI( + llm = LCAzureChatOpenAI( azure_endpoint="https://dummy.openai.azure.com/", openai_api_key="dummy", openai_api_version="2023-03-15-preview", @@ -100,7 +100,7 @@ def test_cot_manual(openai_completion): side_effect=_openai_chat_completion_response, ) def test_cot_with_termination_callback(openai_completion): - llm = AzureChatOpenAI( + llm = LCAzureChatOpenAI( azure_endpoint="https://dummy.openai.azure.com/", openai_api_key="dummy", openai_api_version="2023-03-15-preview", diff --git a/libs/kotaemon/tests/test_llms_chat_models.py b/libs/kotaemon/tests/test_llms_chat_models.py index a6a2a24..3758f76 100644 --- a/libs/kotaemon/tests/test_llms_chat_models.py +++ b/libs/kotaemon/tests/test_llms_chat_models.py @@ -4,7 +4,7 @@ from unittest.mock import patch import pytest from kotaemon.base.schema import AIMessage, HumanMessage, LLMInterface, SystemMessage -from kotaemon.llms import AzureChatOpenAI, LlamaCppChat +from kotaemon.llms import LCAzureChatOpenAI, LlamaCppChat try: from langchain_openai import AzureChatOpenAI as AzureChatOpenAILC @@ -43,7 +43,7 @@ _openai_chat_completion_response = ChatCompletion.parse_obj( side_effect=lambda *args, **kwargs: _openai_chat_completion_response, ) def test_azureopenai_model(openai_completion): - model = AzureChatOpenAI( + model = LCAzureChatOpenAI( azure_endpoint="https://test.openai.azure.com/", openai_api_key="some-key", openai_api_version="2023-03-15-preview", diff --git a/libs/kotaemon/tests/test_reranking.py b/libs/kotaemon/tests/test_reranking.py index d4f7be8..ee37d3c 100644 --- a/libs/kotaemon/tests/test_reranking.py +++ b/libs/kotaemon/tests/test_reranking.py @@ -5,7 +5,7 @@ from openai.types.chat.chat_completion import ChatCompletion from kotaemon.base import Document from kotaemon.indices.rankings import LLMReranking -from kotaemon.llms import AzureChatOpenAI +from kotaemon.llms import LCAzureChatOpenAI _openai_chat_completion_responses = [ ChatCompletion.parse_obj( @@ -41,7 +41,7 @@ _openai_chat_completion_responses = [ @pytest.fixture def llm(): - return AzureChatOpenAI( + return LCAzureChatOpenAI( azure_endpoint="https://dummy.openai.azure.com/", openai_api_key="dummy", openai_api_version="2023-03-15-preview", diff --git a/libs/ktem/flowsettings.py b/libs/ktem/flowsettings.py index 33ba88f..2e26cff 100644 --- a/libs/ktem/flowsettings.py +++ b/libs/ktem/flowsettings.py @@ -40,16 +40,15 @@ if config("AZURE_OPENAI_API_KEY", default="") and config( ): if config("AZURE_OPENAI_CHAT_DEPLOYMENT", default=""): KH_LLMS["azure"] = { - "def": { + "spec": { "__type__": "kotaemon.llms.AzureChatOpenAI", "temperature": 0, "azure_endpoint": config("AZURE_OPENAI_ENDPOINT", default=""), - "openai_api_key": config("AZURE_OPENAI_API_KEY", default=""), + "api_key": config("AZURE_OPENAI_API_KEY", default=""), "api_version": config("OPENAI_API_VERSION", default="") or "2024-02-15-preview", - "deployment_name": config("AZURE_OPENAI_CHAT_DEPLOYMENT", default=""), - "request_timeout": 10, - "stream": False, + "azure_deployment": config("AZURE_OPENAI_CHAT_DEPLOYMENT", default=""), + "timeout": 20, }, "default": False, "accuracy": 5, @@ -57,7 +56,7 @@ if config("AZURE_OPENAI_API_KEY", default="") and config( } if config("AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT", default=""): KH_EMBEDDINGS["azure"] = { - "def": { + "spec": { "__type__": "kotaemon.embeddings.AzureOpenAIEmbeddings", "azure_endpoint": config("AZURE_OPENAI_ENDPOINT", default=""), "openai_api_key": config("AZURE_OPENAI_API_KEY", default=""), @@ -164,5 +163,11 @@ KH_INDICES = [ "name": "File", "config": {}, "index_type": "ktem.index.file.FileIndex", - } + }, + { + "id": 2, + "name": "Sample", + "config": {}, + "index_type": "ktem.index.file.FileIndex", + }, ] diff --git a/libs/ktem/ktem/components.py b/libs/ktem/ktem/components.py index 6cfb2e3..182cb91 100644 --- a/libs/ktem/ktem/components.py +++ b/libs/ktem/ktem/components.py @@ -3,6 +3,7 @@ import logging from functools import cache from pathlib import Path +from typing import Optional from theflow.settings import settings from theflow.utils.modules import deserialize @@ -48,7 +49,7 @@ class ModelPool: self._default: list[str] = [] for name, model in conf.items(): - self._models[name] = deserialize(model["def"], safe=False) + self._models[name] = deserialize(model["spec"], safe=False) if model.get("default", False): self._default.append(name) @@ -58,11 +59,27 @@ class ModelPool: self._cost = list(sorted(conf, key=lambda x: conf[x].get("cost", float("inf")))) def __getitem__(self, key: str) -> BaseComponent: + """Get model by name""" return self._models[key] def __setitem__(self, key: str, value: BaseComponent): + """Set model by name""" self._models[key] = value + def __delitem__(self, key: str): + """Delete model by name""" + del self._models[key] + + def __contains__(self, key: str) -> bool: + """Check if model exists""" + return key in self._models + + def get( + self, key: str, default: Optional[BaseComponent] = None + ) -> Optional[BaseComponent]: + """Get model by name with default value""" + return self._models.get(key, default) + def settings(self) -> dict: """Present model pools option for gradio""" return { @@ -169,4 +186,3 @@ llms = ModelPool("LLMs", settings.KH_LLMS) embeddings = ModelPool("Embeddings", settings.KH_EMBEDDINGS) reasonings: dict = {} tools = ModelPool("Tools", {}) -indices = ModelPool("Indices", {}) diff --git a/libs/ktem/ktem/index/file/pipelines.py b/libs/ktem/ktem/index/file/pipelines.py index b63d89c..13036f3 100644 --- a/libs/ktem/ktem/index/file/pipelines.py +++ b/libs/ktem/ktem/index/file/pipelines.py @@ -157,10 +157,10 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever): @classmethod def get_user_settings(cls) -> dict: - from ktem.components import llms + from ktem.llms.manager import llms try: - reranking_llm = llms.get_lowest_cost_name() + reranking_llm = llms.get_default_name() reranking_llm_choices = list(llms.options().keys()) except Exception as e: logger.error(e) diff --git a/libs/ktem/ktem/llms/__init__.py b/libs/ktem/ktem/llms/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/libs/ktem/ktem/llms/db.py b/libs/ktem/ktem/llms/db.py new file mode 100644 index 0000000..628ebb7 --- /dev/null +++ b/libs/ktem/ktem/llms/db.py @@ -0,0 +1,36 @@ +from typing import Type + +from ktem.db.engine import engine +from sqlalchemy import JSON, Boolean, Column, String +from sqlalchemy.orm import DeclarativeBase +from theflow.settings import settings as flowsettings +from theflow.utils.modules import import_dotted_string + + +class Base(DeclarativeBase): + pass + + +class BaseLLMTable(Base): + """Base table to store language model""" + + __abstract__ = True + + name = Column(String, primary_key=True, unique=True) + spec = Column(JSON, default={}) + default = Column(Boolean, default=False) + + +_base_llm: Type[BaseLLMTable] = ( + import_dotted_string(flowsettings.KH_TABLE_LLM, safe=False) + if hasattr(flowsettings, "KH_TABLE_LLM") + else BaseLLMTable +) + + +class LLMTable(_base_llm): # type: ignore + __tablename__ = "llm_table" + + +if not getattr(flowsettings, "KH_ENABLE_ALEMBIC", False): + LLMTable.metadata.create_all(engine) diff --git a/libs/ktem/ktem/llms/manager.py b/libs/ktem/ktem/llms/manager.py new file mode 100644 index 0000000..f9ad763 --- /dev/null +++ b/libs/ktem/ktem/llms/manager.py @@ -0,0 +1,191 @@ +from typing import Optional, Type + +from sqlalchemy import select +from sqlalchemy.orm import Session +from theflow.settings import settings as flowsettings +from theflow.utils.modules import deserialize + +from kotaemon.base import BaseComponent + +from .db import LLMTable, engine + + +class LLMManager: + """Represent a pool of models""" + + def __init__(self): + self._models: dict[str, BaseComponent] = {} + self._info: dict[str, dict] = {} + self._default: str = "" + self._vendors: list[Type] = [] + + if hasattr(flowsettings, "KH_LLMS"): + for name, model in flowsettings.KH_LLMS.items(): + with Session(engine) as session: + stmt = select(LLMTable).where(LLMTable.name == name) + result = session.execute(stmt) + if not result.first(): + item = LLMTable( + name=name, + spec=model["spec"], + default=model.get("default", False), + ) + session.add(item) + session.commit() + + self.load() + self.load_vendors() + + def load(self): + """Load the model pool from database""" + self._models, self._info, self._defaut = {}, {}, "" + with Session(engine) as session: + stmt = select(LLMTable) + items = session.execute(stmt) + + for (item,) in items: + self._models[item.name] = deserialize(item.spec, safe=False) + self._info[item.name] = { + "name": item.name, + "spec": item.spec, + "default": item.default, + } + if item.default: + self._default = item.name + + def load_vendors(self): + from kotaemon.llms import ( + AzureChatOpenAI, + ChatOpenAI, + EndpointChatLLM, + LlamaCppChat, + ) + + self._vendors = [ChatOpenAI, AzureChatOpenAI, LlamaCppChat, EndpointChatLLM] + + def __getitem__(self, key: str) -> BaseComponent: + """Get model by name""" + return self._models[key] + + def __contains__(self, key: str) -> bool: + """Check if model exists""" + return key in self._models + + def get( + self, key: str, default: Optional[BaseComponent] = None + ) -> Optional[BaseComponent]: + """Get model by name with default value""" + return self._models.get(key, default) + + def settings(self) -> dict: + """Present model pools option for gradio""" + return { + "label": "LLM", + "choices": list(self._models.keys()), + "value": self.get_default_name(), + } + + def options(self) -> dict: + """Present a dict of models""" + return self._models + + def get_random_name(self) -> str: + """Get the name of random model + + Returns: + str: random model name in the pool + """ + import random + + if not self._models: + raise ValueError("No models in pool") + + return random.choice(list(self._models.keys())) + + def get_default_name(self) -> str: + """Get the name of default model + + In case there is no default model, choose random model from pool. In + case there are multiple default models, choose random from them. + + Returns: + str: model name + """ + if not self._models: + raise ValueError("No models in pool") + + if not self._default: + return self.get_random_name() + + return self._default + + def get_random(self) -> BaseComponent: + """Get random model""" + return self._models[self.get_random_name()] + + def get_default(self) -> BaseComponent: + """Get default model + + In case there is no default model, choose random model from pool. In + case there are multiple default models, choose random from them. + + Returns: + BaseComponent: model + """ + return self._models[self.get_default_name()] + + def info(self) -> dict: + """List all models""" + return self._info + + def add(self, name: str, spec: dict, default: bool): + """Add a new model to the pool""" + try: + with Session(engine) as session: + item = LLMTable(name=name, spec=spec, default=default) + session.add(item) + session.commit() + except Exception as e: + raise ValueError(f"Failed to add model {name}: {e}") + + self.load() + + def delete(self, name: str): + """Delete a model from the pool""" + try: + with Session(engine) as session: + item = session.query(LLMTable).filter_by(name=name).first() + session.delete(item) + session.commit() + except Exception as e: + raise ValueError(f"Failed to delete model {name}: {e}") + + self.load() + + def update(self, name: str, spec: dict, default: bool): + """Update a model in the pool""" + try: + with Session(engine) as session: + + if default: + # turn all models to non-default + session.query(LLMTable).update({"default": False}) + session.commit() + + item = session.query(LLMTable).filter_by(name=name).first() + if not item: + raise ValueError(f"Model {name} not found") + item.spec = spec + item.default = default + session.commit() + except Exception as e: + raise ValueError(f"Failed to update model {name}: {e}") + + self.load() + + def vendors(self) -> dict: + """Return list of vendors""" + return {vendor.__qualname__: vendor for vendor in self._vendors} + + +llms = LLMManager() diff --git a/libs/ktem/ktem/llms/ui.py b/libs/ktem/ktem/llms/ui.py new file mode 100644 index 0000000..dd8f2bd --- /dev/null +++ b/libs/ktem/ktem/llms/ui.py @@ -0,0 +1,318 @@ +from copy import deepcopy + +import gradio as gr +import pandas as pd +import yaml +from ktem.app import BasePage + +from .manager import llms + + +def format_description(cls): + params = cls.describe()["params"] + params_lines = ["| Name | Type | Description |", "| --- | --- | --- |"] + for key, value in params.items(): + if isinstance(value["auto_callback"], str): + continue + params_lines.append(f"| {key} | {value['type']} | {value['help']} |") + return f"{cls.__doc__}\n\n" + "\n".join(params_lines) + + +class LLMManagement(BasePage): + def __init__(self, app): + self._app = app + self.spec_desc_default = ( + "# Spec description\n\nSelect an LLM to view the spec description." + ) + self.on_building_ui() + + def on_building_ui(self): + with gr.Tab(label="View"): + self.llm_list = gr.DataFrame( + headers=["name", "vendor", "default"], + interactive=False, + ) + + with gr.Column(visible=False) as self._selected_panel: + self.selected_llm_name = gr.Textbox(value="", visible=False) + with gr.Row(): + with gr.Column(): + self.edit_default = gr.Checkbox( + label="Set default", + info=( + "Set this LLM as default. If no default is set, a " + "random LLM will be used." + ), + ) + self.edit_spec = gr.Textbox( + label="Specification", + info="Specification of the LLM in YAML format", + lines=10, + ) + + with gr.Row(visible=False) as self._selected_panel_btn: + with gr.Column(): + self.btn_edit_save = gr.Button("Save", min_width=10) + with gr.Column(): + self.btn_delete = gr.Button("Delete", min_width=10) + with gr.Row(): + self.btn_delete_yes = gr.Button( + "Confirm delete", + variant="primary", + visible=False, + min_width=10, + ) + self.btn_delete_no = gr.Button( + "Cancel", visible=False, min_width=10 + ) + with gr.Column(): + self.btn_close = gr.Button("Close", min_width=10) + + with gr.Column(): + self.edit_spec_desc = gr.Markdown("# Spec description") + + with gr.Tab(label="Add"): + with gr.Row(): + with gr.Column(scale=2): + self.name = gr.Textbox( + label="LLM name", + info=( + "Must be unique. The name will be used to identify the LLM." + ), + ) + self.llm_choices = gr.Dropdown( + label="LLM vendors", + info=( + "Choose the vendor for the LLM. Each vendor has different " + "specification." + ), + ) + self.spec = gr.Textbox( + label="Specification", + info="Specification of the LLM in YAML format", + ) + self.default = gr.Checkbox( + label="Set default", + info=( + "Set this LLM as default. This default LLM will be used " + "by default across the application." + ), + ) + self.btn_new = gr.Button("Create LLM") + + with gr.Column(scale=3): + self.spec_desc = gr.Markdown(self.spec_desc_default) + + def _on_app_created(self): + """Called when the app is created""" + self._app.app.load( + self.list_llms, + inputs=None, + outputs=[self.llm_list], + ) + self._app.app.load( + lambda: gr.update(choices=list(llms.vendors().keys())), + outputs=[self.llm_choices], + ) + + def on_llm_vendor_change(self, vendor): + vendor = llms.vendors()[vendor] + + required: dict = {} + desc = vendor.describe() + for key, value in desc["params"].items(): + if value.get("required", False): + required[key] = None + + return yaml.dump(required), format_description(vendor) + + def on_register_events(self): + self.llm_choices.select( + self.on_llm_vendor_change, + inputs=[self.llm_choices], + outputs=[self.spec, self.spec_desc], + ) + self.btn_new.click( + self.create_llm, + inputs=[self.name, self.llm_choices, self.spec, self.default], + outputs=None, + ).then(self.list_llms, inputs=None, outputs=[self.llm_list],).then( + lambda: ("", None, "", False, self.spec_desc_default), + outputs=[ + self.name, + self.llm_choices, + self.spec, + self.default, + self.spec_desc, + ], + ) + self.llm_list.select( + self.select_llm, + inputs=self.llm_list, + outputs=[self.selected_llm_name], + show_progress="hidden", + ) + self.selected_llm_name.change( + self.on_selected_llm_change, + inputs=[self.selected_llm_name], + outputs=[ + self._selected_panel, + self._selected_panel_btn, + # delete section + self.btn_delete, + self.btn_delete_yes, + self.btn_delete_no, + # edit section + self.edit_spec, + self.edit_spec_desc, + self.edit_default, + ], + show_progress="hidden", + ) + self.btn_delete.click( + self.on_btn_delete_click, + inputs=None, + outputs=[self.btn_delete, self.btn_delete_yes, self.btn_delete_no], + show_progress="hidden", + ) + self.btn_delete_yes.click( + self.delete_llm, + inputs=[self.selected_llm_name], + outputs=[self.selected_llm_name], + show_progress="hidden", + ).then( + self.list_llms, + inputs=None, + outputs=[self.llm_list], + ) + self.btn_delete_no.click( + lambda: ( + gr.update(visible=True), + gr.update(visible=False), + gr.update(visible=False), + ), + inputs=None, + outputs=[self.btn_delete, self.btn_delete_yes, self.btn_delete_no], + show_progress="hidden", + ) + self.btn_edit_save.click( + self.save_llm, + inputs=[ + self.selected_llm_name, + self.edit_default, + self.edit_spec, + ], + show_progress="hidden", + ).then( + self.list_llms, + inputs=None, + outputs=[self.llm_list], + ) + self.btn_close.click( + lambda: "", + outputs=[self.selected_llm_name], + ) + + def create_llm(self, name, choices, spec, default): + try: + spec = yaml.safe_load(spec) + spec["__type__"] = ( + llms.vendors()[choices].__module__ + + "." + + llms.vendors()[choices].__qualname__ + ) + + llms.add(name, spec=spec, default=default) + gr.Info(f"LLM {name} created successfully") + except Exception as e: + gr.Error(f"Failed to create LLM {name}: {e}") + + def list_llms(self): + """List the LLMs""" + items = [] + for item in llms.info().values(): + record = {} + record["name"] = item["name"] + record["vendor"] = item["spec"].get("__type__", "-").split(".")[-1] + record["default"] = item["default"] + items.append(record) + + if items: + llm_list = pd.DataFrame.from_records(items) + else: + llm_list = pd.DataFrame.from_records( + [{"name": "-", "vendor": "-", "default": "-"}] + ) + + return llm_list + + def select_llm(self, llm_list, ev: gr.SelectData): + if ev.value == "-" and ev.index[0] == 0: + gr.Info("No LLM is loaded. Please add LLM first") + return "" + + if not ev.selected: + return "" + + return llm_list["name"][ev.index[0]] + + def on_selected_llm_change(self, selected_llm_name): + if selected_llm_name == "": + _selected_panel = gr.update(visible=False) + _selected_panel_btn = gr.update(visible=False) + btn_delete = gr.update(visible=True) + btn_delete_yes = gr.update(visible=False) + btn_delete_no = gr.update(visible=False) + edit_spec = gr.update(value="") + edit_spec_desc = gr.update(value="") + edit_default = gr.update(value=False) + else: + _selected_panel = gr.update(visible=True) + _selected_panel_btn = gr.update(visible=True) + btn_delete = gr.update(visible=True) + btn_delete_yes = gr.update(visible=False) + btn_delete_no = gr.update(visible=False) + + info = deepcopy(llms.info()[selected_llm_name]) + vendor_str = info["spec"].pop("__type__", "-").split(".")[-1] + vendor = llms.vendors()[vendor_str] + + edit_spec = yaml.dump(info["spec"]) + edit_spec_desc = format_description(vendor) + edit_default = info["default"] + + return ( + _selected_panel, + _selected_panel_btn, + btn_delete, + btn_delete_yes, + btn_delete_no, + edit_spec, + edit_spec_desc, + edit_default, + ) + + def on_btn_delete_click(self): + btn_delete = gr.update(visible=False) + btn_delete_yes = gr.update(visible=True) + btn_delete_no = gr.update(visible=True) + + return btn_delete, btn_delete_yes, btn_delete_no + + def save_llm(self, selected_llm_name, default, spec): + try: + spec = yaml.safe_load(spec) + spec["__type__"] = llms.info()[selected_llm_name]["spec"]["__type__"] + llms.update(selected_llm_name, spec=spec, default=default) + gr.Info(f"LLM {selected_llm_name} saved successfully") + except Exception as e: + gr.Error(f"Failed to save LLM {selected_llm_name}: {e}") + + def delete_llm(self, selected_llm_name): + try: + llms.delete(selected_llm_name) + except Exception as e: + gr.Error(f"Failed to delete LLM {selected_llm_name}: {e}") + return selected_llm_name + + return "" diff --git a/libs/ktem/ktem/pages/admin/__init__.py b/libs/ktem/ktem/pages/admin/__init__.py index 1cc58c7..b32d816 100644 --- a/libs/ktem/ktem/pages/admin/__init__.py +++ b/libs/ktem/ktem/pages/admin/__init__.py @@ -1,6 +1,7 @@ import gradio as gr from ktem.app import BasePage from ktem.db.models import User, engine +from ktem.llms.ui import LLMManagement from sqlmodel import Session, select from .user import UserManagement @@ -16,6 +17,9 @@ class AdminPage(BasePage): with gr.Tab("User Management", visible=False) as self.user_management_tab: self.user_management = UserManagement(self._app) + with gr.Tab("LLM Management") as self.llm_management_tab: + self.llm_management = LLMManagement(self._app) + def on_subscribe_public_events(self): if self._app.f_user_management: self._app.subscribe_event( diff --git a/libs/ktem/ktem/pages/chat/__init__.py b/libs/ktem/ktem/pages/chat/__init__.py index 04c78eb..3f01571 100644 --- a/libs/ktem/ktem/pages/chat/__init__.py +++ b/libs/ktem/ktem/pages/chat/__init__.py @@ -9,6 +9,8 @@ from ktem.db.models import Conversation, engine from sqlmodel import Session, select from theflow.settings import settings as flowsettings +from kotaemon.base import Document + from .chat_panel import ChatPanel from .chat_suggestion import ChatSuggestion from .common import STATE @@ -189,6 +191,7 @@ class ChatPage(BasePage): self.chat_control.conversation_rn, self.chat_panel.chatbot, self.info_panel, + self.chat_state, ] + self._indices_input, show_progress="hidden", @@ -220,6 +223,7 @@ class ChatPage(BasePage): self.chat_control.conversation_rn, self.chat_panel.chatbot, self.info_panel, + self.chat_state, ] + self._indices_input, show_progress="hidden", @@ -392,7 +396,7 @@ class ChatPage(BasePage): return pipeline, reasoning_state - async def chat_fn(self, conversation_id, chat_history, settings, state, *selecteds): + def chat_fn(self, conversation_id, chat_history, settings, state, *selecteds): """Chat function""" chat_input = chat_history[-1][0] chat_history = chat_history[:-1] @@ -403,52 +407,43 @@ class ChatPage(BasePage): pipeline, reasoning_state = self.create_pipeline(settings, state, *selecteds) pipeline.set_output_queue(queue) - asyncio.create_task(pipeline(chat_input, conversation_id, chat_history)) text, refs = "", "" - - len_ref = -1 # for logging purpose msg_placeholder = getattr( flowsettings, "KH_CHAT_MSG_PLACEHOLDER", "Thinking ..." ) - print(msg_placeholder) - while True: - try: - response = queue.get_nowait() - except Exception: - state[pipeline.get_info()["id"]] = reasoning_state["pipeline"] - yield chat_history + [ - (chat_input, text or msg_placeholder) - ], refs, state + yield chat_history + [(chat_input, text or msg_placeholder)], refs, state + + len_ref = -1 # for logging purpose + + for response in pipeline.stream(chat_input, conversation_id, chat_history): + + if not isinstance(response, Document): continue - if response is None: - queue.task_done() - print("Chat completed") - break + if response.channel is None: + continue - if "output" in response: - if response["output"] is None: + if response.channel == "chat": + if response.content is None: text = "" else: - text += response["output"] + text += response.content - if "evidence" in response: - if response["evidence"] is None: + if response.channel == "info": + if response.content is None: refs = "" else: - refs += response["evidence"] + refs += response.content if len(refs) > len_ref: print(f"Len refs: {len(refs)}") len_ref = len(refs) - state[pipeline.get_info()["id"]] = reasoning_state["pipeline"] - yield chat_history + [(chat_input, text)], refs, state + state[pipeline.get_info()["id"]] = reasoning_state["pipeline"] + yield chat_history + [(chat_input, text or msg_placeholder)], refs, state - async def regen_fn( - self, conversation_id, chat_history, settings, state, *selecteds - ): + def regen_fn(self, conversation_id, chat_history, settings, state, *selecteds): """Regen function""" if not chat_history: gr.Warning("Empty chat") @@ -456,12 +451,11 @@ class ChatPage(BasePage): return state["app"]["regen"] = True - async for chat, refs, state in self.chat_fn( + for chat, refs, state in self.chat_fn( conversation_id, chat_history, settings, state, *selecteds ): new_state = deepcopy(state) new_state["app"]["regen"] = False yield chat, refs, new_state - else: - state["app"]["regen"] = False - yield chat_history, "", state + + state["app"]["regen"] = False diff --git a/libs/ktem/ktem/reasoning/simple.py b/libs/ktem/ktem/reasoning/simple.py index 082c20f..3397250 100644 --- a/libs/ktem/ktem/reasoning/simple.py +++ b/libs/ktem/ktem/reasoning/simple.py @@ -4,10 +4,10 @@ import logging import re from collections import defaultdict from functools import partial +from typing import Generator import tiktoken -from ktem.components import llms -from theflow.settings import settings as flowsettings +from ktem.llms.manager import llms from kotaemon.base import ( BaseComponent, @@ -190,10 +190,10 @@ class AnswerWithContextPipeline(BaseComponent): lang: the language of the answer. Currently support English and Japanese """ - llm: ChatLLM = Node(default_callback=lambda _: llms.get_highest_accuracy()) - vlm_endpoint: str = flowsettings.KH_VLM_ENDPOINT + llm: ChatLLM = Node(default_callback=lambda _: llms.get_default()) + vlm_endpoint: str = "" citation_pipeline: CitationPipeline = Node( - default_callback=lambda _: CitationPipeline(llm=llms.get_lowest_cost()) + default_callback=lambda _: CitationPipeline(llm=llms.get_default()) ) qa_template: str = DEFAULT_QA_TEXT_PROMPT @@ -297,13 +297,95 @@ class AnswerWithContextPipeline(BaseComponent): return answer + def stream( # type: ignore + self, question: str, evidence: str, evidence_mode: int = 0, **kwargs + ) -> Generator[Document, None, Document]: + """Answer the question based on the evidence -def extract_evidence_images(self, evidence: str): - """Util function to extract and isolate images from context/evidence""" - image_pattern = r"src='(data:image\/[^;]+;base64[^']+)'" - matches = re.findall(image_pattern, evidence) - context = re.sub(image_pattern, "", evidence) - return context, matches + In addition to the question and the evidence, this method also take into + account evidence_mode. The evidence_mode tells which kind of evidence is. + The kind of evidence affects: + 1. How the evidence is represented. + 2. The prompt to generate the answer. + + By default, the evidence_mode is 0, which means the evidence is plain text with + no particular semantic representation. The evidence_mode can be: + 1. "table": There will be HTML markup telling that there is a table + within the evidence. + 2. "chatbot": There will be HTML markup telling that there is a chatbot. + This chatbot is a scenario, extracted from an Excel file, where each + row corresponds to an interaction. + + Args: + question: the original question posed by user + evidence: the text that contain relevant information to answer the question + (determined by retrieval pipeline) + evidence_mode: the mode of evidence, 0 for text, 1 for table, 2 for chatbot + """ + if evidence_mode == EVIDENCE_MODE_TEXT: + prompt_template = PromptTemplate(self.qa_template) + elif evidence_mode == EVIDENCE_MODE_TABLE: + prompt_template = PromptTemplate(self.qa_table_template) + elif evidence_mode == EVIDENCE_MODE_FIGURE: + prompt_template = PromptTemplate(self.qa_figure_template) + else: + prompt_template = PromptTemplate(self.qa_chatbot_template) + + images = [] + if evidence_mode == EVIDENCE_MODE_FIGURE: + # isolate image from evidence + evidence, images = self.extract_evidence_images(evidence) + prompt = prompt_template.populate( + context=evidence, + question=question, + lang=self.lang, + ) + else: + prompt = prompt_template.populate( + context=evidence, + question=question, + lang=self.lang, + ) + + output = "" + if evidence_mode == EVIDENCE_MODE_FIGURE: + for text in stream_gpt4v(self.vlm_endpoint, images, prompt, max_tokens=768): + output += text + yield Document(channel="chat", content=text) + else: + messages = [] + if self.system_prompt: + messages.append(SystemMessage(content=self.system_prompt)) + messages.append(HumanMessage(content=prompt)) + + try: + # try streaming first + print("Trying LLM streaming") + for text in self.llm.stream(messages): + output += text.text + yield Document(channel="chat", content=text.text) + except NotImplementedError: + print("Streaming is not supported, falling back to normal processing") + output = self.llm(messages).text + yield Document(channel="chat", content=output) + + # retrieve the citation + citation = None + if evidence and self.enable_citation: + citation = self.citation_pipeline.invoke( + context=evidence, question=question + ) + + answer = Document(text=output, metadata={"citation": citation}) + + return answer + + def extract_evidence_images(self, evidence: str): + """Util function to extract and isolate images from context/evidence""" + image_pattern = r"src='(data:image\/[^;]+;base64[^']+)'" + matches = re.findall(image_pattern, evidence) + context = re.sub(image_pattern, "", evidence) + return context, matches class RewriteQuestionPipeline(BaseComponent): @@ -315,27 +397,19 @@ class RewriteQuestionPipeline(BaseComponent): lang: the language of the answer. Currently support English and Japanese """ - llm: ChatLLM = Node(default_callback=lambda _: llms.get_lowest_cost()) + llm: ChatLLM = Node(default_callback=lambda _: llms.get_default()) rewrite_template: str = DEFAULT_REWRITE_PROMPT lang: str = "English" - async def run(self, question: str) -> Document: # type: ignore + def run(self, question: str) -> Document: # type: ignore prompt_template = PromptTemplate(self.rewrite_template) prompt = prompt_template.populate(question=question, lang=self.lang) messages = [ SystemMessage(content="You are a helpful assistant"), HumanMessage(content=prompt), ] - output = "" - for text in self.llm(messages): - if "content" in text: - output += text[1] - self.report_output({"chat_input": text[1]}) - break - await asyncio.sleep(0) - - return Document(text=output) + return self.llm(messages) class FullQAPipeline(BaseReasoning): @@ -351,7 +425,7 @@ class FullQAPipeline(BaseReasoning): rewrite_pipeline: RewriteQuestionPipeline = RewriteQuestionPipeline.withx() use_rewrite: bool = False - async def run( # type: ignore + async def ainvoke( # type: ignore self, message: str, conv_id: str, history: list, **kwargs # type: ignore ) -> Document: # type: ignore import markdown @@ -482,6 +556,132 @@ class FullQAPipeline(BaseReasoning): self.report_output(None) return answer + def stream( # type: ignore + self, message: str, conv_id: str, history: list, **kwargs # type: ignore + ) -> Generator[Document, None, Document]: + import markdown + + docs = [] + doc_ids = [] + if self.use_rewrite: + message = self.rewrite_pipeline(question=message).text + + for retriever in self.retrievers: + for doc in retriever(text=message): + if doc.doc_id not in doc_ids: + docs.append(doc) + doc_ids.append(doc.doc_id) + for doc in docs: + # TODO: a better approach to show the information + text = markdown.markdown( + doc.text, extensions=["markdown.extensions.tables"] + ) + yield Document( + content=( + "
" + f"{doc.metadata['file_name']}" + f"{text}" + "

" + ), + channel="info", + ) + + evidence_mode, evidence = self.evidence_pipeline(docs).content + answer = yield from self.answering_pipeline.stream( + question=message, + history=history, + evidence=evidence, + evidence_mode=evidence_mode, + conv_id=conv_id, + **kwargs, + ) + + # prepare citation + spans = defaultdict(list) + if answer.metadata["citation"] is not None: + for fact_with_evidence in answer.metadata["citation"].answer: + for quote in fact_with_evidence.substring_quote: + for doc in docs: + start_idx = doc.text.find(quote) + if start_idx == -1: + continue + + end_idx = start_idx + len(quote) + + current_idx = start_idx + if "|" not in doc.text[start_idx:end_idx]: + spans[doc.doc_id].append( + {"start": start_idx, "end": end_idx} + ) + else: + while doc.text[current_idx:end_idx].find("|") != -1: + match_idx = doc.text[current_idx:end_idx].find("|") + spans[doc.doc_id].append( + { + "start": current_idx, + "end": current_idx + match_idx, + } + ) + current_idx += match_idx + 2 + if current_idx > end_idx: + break + break + + id2docs = {doc.doc_id: doc for doc in docs} + lack_evidence = True + not_detected = set(id2docs.keys()) - set(spans.keys()) + yield Document(channel="info", content=None) + for id, ss in spans.items(): + if not ss: + not_detected.add(id) + continue + ss = sorted(ss, key=lambda x: x["start"]) + text = id2docs[id].text[: ss[0]["start"]] + for idx, span in enumerate(ss): + text += ( + "" + id2docs[id].text[span["start"] : span["end"]] + "" + ) + if idx < len(ss) - 1: + text += id2docs[id].text[span["end"] : ss[idx + 1]["start"]] + text += id2docs[id].text[ss[-1]["end"] :] + text_out = markdown.markdown( + text, extensions=["markdown.extensions.tables"] + ) + yield Document( + content=( + "
" + f"{id2docs[id].metadata['file_name']}" + f"{text_out}" + "

" + ), + channel="info", + ) + lack_evidence = False + + if lack_evidence: + yield Document(channel="info", content="No evidence found.\n") + + if not_detected: + yield Document( + channel="info", + content="Retrieved segments without matching evidence:\n", + ) + for id in list(not_detected): + text_out = markdown.markdown( + id2docs[id].text, extensions=["markdown.extensions.tables"] + ) + yield Document( + content=( + "
" + f"{id2docs[id].metadata['file_name']}" + f"{text_out}" + "

" + ), + channel="info", + ) + + return answer + @classmethod def get_pipeline(cls, settings, states, retrievers): """Get the reasoning pipeline @@ -493,12 +693,9 @@ class FullQAPipeline(BaseReasoning): _id = cls.get_info()["id"] pipeline = FullQAPipeline(retrievers=retrievers) - pipeline.answering_pipeline.llm = llms[ - settings[f"reasoning.options.{_id}.main_llm"] - ] - pipeline.answering_pipeline.citation_pipeline.llm = llms[ - settings[f"reasoning.options.{_id}.citation_llm"] - ] + pipeline.answering_pipeline.llm = llms.get_default() + pipeline.answering_pipeline.citation_pipeline.llm = llms.get_default() + pipeline.answering_pipeline.enable_citation = settings[ f"reasoning.options.{_id}.highlight_citation" ] @@ -512,7 +709,7 @@ class FullQAPipeline(BaseReasoning): f"reasoning.options.{_id}.qa_prompt" ] pipeline.use_rewrite = states.get("app", {}).get("regen", False) - pipeline.rewrite_pipeline.llm = llms.get_lowest_cost() + pipeline.rewrite_pipeline.llm = llms.get_default() pipeline.rewrite_pipeline.lang = {"en": "English", "ja": "Japanese"}.get( settings["reasoning.lang"], "English" ) @@ -520,38 +717,12 @@ class FullQAPipeline(BaseReasoning): @classmethod def get_user_settings(cls) -> dict: - from ktem.components import llms - - try: - citation_llm = llms.get_lowest_cost_name() - citation_llm_choices = list(llms.options().keys()) - main_llm = llms.get_highest_accuracy_name() - main_llm_choices = list(llms.options().keys()) - except Exception as e: - logger.error(e) - citation_llm = None - citation_llm_choices = [] - main_llm = None - main_llm_choices = [] - return { "highlight_citation": { "name": "Highlight Citation", "value": False, "component": "checkbox", }, - "citation_llm": { - "name": "LLM for citation", - "value": citation_llm, - "component": "dropdown", - "choices": citation_llm_choices, - }, - "main_llm": { - "name": "LLM for main generation", - "value": main_llm, - "component": "dropdown", - "choices": main_llm_choices, - }, "system_prompt": { "name": "System Prompt", "value": "This is a question answering system", diff --git a/libs/ktem/ktem_tests/test_qa.py b/libs/ktem/ktem_tests/test_qa.py index a3993ee..80ee68b 100644 --- a/libs/ktem/ktem_tests/test_qa.py +++ b/libs/ktem/ktem_tests/test_qa.py @@ -7,7 +7,7 @@ from index import ReaderIndexingPipeline from openai.resources.embeddings import Embeddings from openai.types.chat.chat_completion import ChatCompletion -from kotaemon.llms import AzureChatOpenAI +from kotaemon.llms import LCAzureChatOpenAI with open(Path(__file__).parent / "resources" / "embedding_openai.json") as f: openai_embedding = json.load(f) @@ -61,7 +61,7 @@ def test_ingest_pipeline(patch, mock_openai_embedding, tmp_path): assert len(results) == 1 # create llm - llm = AzureChatOpenAI( + llm = LCAzureChatOpenAI( openai_api_base="https://test.openai.azure.com/", openai_api_key="some-key", openai_api_version="2023-03-15-preview", diff --git a/libs/ktem/launch.py b/libs/ktem/launch.py index 2ac7a1a..1f436c5 100644 --- a/libs/ktem/launch.py +++ b/libs/ktem/launch.py @@ -2,4 +2,4 @@ from ktem.main import App app = App() demo = app.make() -demo.queue().launch(favicon_path=app._favicon, inbrowser=True) +demo.queue().launch(favicon_path=app._favicon) diff --git a/templates/project-default/{{cookiecutter.project_name}}/{{cookiecutter.project_name}}/pipeline.py b/templates/project-default/{{cookiecutter.project_name}}/{{cookiecutter.project_name}}/pipeline.py index 1739ca8..db2fa0b 100644 --- a/templates/project-default/{{cookiecutter.project_name}}/{{cookiecutter.project_name}}/pipeline.py +++ b/templates/project-default/{{cookiecutter.project_name}}/{{cookiecutter.project_name}}/pipeline.py @@ -5,7 +5,7 @@ from kotaemon.base import BaseComponent, Document, LLMInterface, Node, Param, la from kotaemon.contribs.promptui.logs import ResultLog from kotaemon.embeddings import AzureOpenAIEmbeddings from kotaemon.indices import VectorIndexing, VectorRetrieval -from kotaemon.llms import AzureChatOpenAI +from kotaemon.llms import LCAzureChatOpenAI from kotaemon.storages import ChromaVectorStore, SimpleFileDocumentStore @@ -34,7 +34,7 @@ class QuestionAnsweringPipeline(BaseComponent): ] retrieval_top_k: int = 1 - llm: AzureChatOpenAI = AzureChatOpenAI.withx( + llm: LCAzureChatOpenAI = LCAzureChatOpenAI.withx( azure_endpoint="https://bleh-dummy-2.openai.azure.com/", openai_api_key=os.environ.get("OPENAI_API_KEY", "default-key"), openai_api_version="2023-03-15-preview",