diff --git a/.gitignore b/.gitignore
index 5c91c3e..711a741 100644
--- a/.gitignore
+++ b/.gitignore
@@ -458,6 +458,7 @@ logs/
.gitsecret/keys/random_seed
!*.secret
.envrc
+.env
S.gpg-agent*
.vscode/settings.json
diff --git a/docs/development/create-a-component.md b/docs/development/create-a-component.md
index 029bbe9..f831259 100644
--- a/docs/development/create-a-component.md
+++ b/docs/development/create-a-component.md
@@ -22,7 +22,7 @@ The syntax of a component is as follow:
```python
from kotaemon.base import BaseComponent
-from kotaemon.llms import AzureChatOpenAI
+from kotaemon.llms import LCAzureChatOpenAI
from kotaemon.parsers import RegexExtractor
@@ -32,7 +32,7 @@ class FancyPipeline(BaseComponent):
param3: float
node1: BaseComponent # this is a node because of BaseComponent type annotation
- node2: AzureChatOpenAI # this is also a node because AzureChatOpenAI subclasses BaseComponent
+ node2: LCAzureChatOpenAI # this is also a node because LCAzureChatOpenAI subclasses BaseComponent
node3: RegexExtractor # this is also a node because RegexExtractor subclasses BaseComponent
def run(self, some_text: str):
@@ -45,7 +45,7 @@ class FancyPipeline(BaseComponent):
Then this component can be used as follows:
```python
-llm = AzureChatOpenAI(endpoint="some-endpont")
+llm = LCAzureChatOpenAI(endpoint="some-endpoint")
extractor = RegexExtractor(pattern=["yes", "Yes"])
component = FancyPipeline(
diff --git a/docs/pages/app/customize-flows.md b/docs/pages/app/customize-flows.md
index 1277e34..3dd005e 100644
--- a/docs/pages/app/customize-flows.md
+++ b/docs/pages/app/customize-flows.md
@@ -193,7 +193,8 @@ information panel.
You can access users' collections of LLMs and embedding models with:
```python
-from ktem.components import llms, embeddings
+from ktem.components import embeddings
+from ktem.llms.manager import llms
llm = llms.get_default()
@@ -206,12 +207,12 @@ models they want to use through the settings.
```python
@classmethod
def get_user_settings(cls) -> dict:
- from ktem.components import llms
+ from ktem.llms.manager import llms
return {
"citation_llm": {
"name": "LLM for citation",
- "value": llms.get_lowest_cost_name(),
+ "value": llms.get_default(),
"component: "dropdown",
"choices": list(llms.options().keys()),
},
diff --git a/libs/kotaemon/kotaemon/base/component.py b/libs/kotaemon/kotaemon/base/component.py
index 9acd39f..6936b2a 100644
--- a/libs/kotaemon/kotaemon/base/component.py
+++ b/libs/kotaemon/kotaemon/base/component.py
@@ -52,7 +52,7 @@ class BaseComponent(Function):
def stream(self, *args, **kwargs) -> Iterator[Document] | None:
...
- async def astream(self, *args, **kwargs) -> AsyncGenerator[Document, None] | None:
+ def astream(self, *args, **kwargs) -> AsyncGenerator[Document, None] | None:
...
@abstractmethod
diff --git a/libs/kotaemon/kotaemon/base/schema.py b/libs/kotaemon/kotaemon/base/schema.py
index 1d0e622..07fe9f5 100644
--- a/libs/kotaemon/kotaemon/base/schema.py
+++ b/libs/kotaemon/kotaemon/base/schema.py
@@ -1,6 +1,6 @@
from __future__ import annotations
-from typing import TYPE_CHECKING, Any, Optional, TypeVar
+from typing import TYPE_CHECKING, Any, Literal, Optional, TypeVar
from langchain.schema.messages import AIMessage as LCAIMessage
from langchain.schema.messages import HumanMessage as LCHumanMessage
@@ -10,6 +10,9 @@ from llama_index.schema import Document as BaseDocument
if TYPE_CHECKING:
from haystack.schema import Document as HaystackDocument
+ from openai.types.chat.chat_completion_message_param import (
+ ChatCompletionMessageParam,
+ )
IO_Type = TypeVar("IO_Type", "Document", str)
SAMPLE_TEXT = "A sample Document from kotaemon"
@@ -26,10 +29,15 @@ class Document(BaseDocument):
Attributes:
content: raw content of the document, can be anything
source: id of the source of the Document. Optional.
+    channel: the channel to show the document. Optional. One of:
+ - chat: show in chat message
+ - info: show in information panel
+ - debug: show in debug panel
"""
- content: Any
+ content: Any = None
source: Optional[str] = None
+ channel: Optional[Literal["chat", "info", "debug"]] = None
def __init__(self, content: Optional[Any] = None, *args, **kwargs):
if content is None:
@@ -87,17 +95,23 @@ class BaseMessage(Document):
def __add__(self, other: Any):
raise NotImplementedError
+ def to_openai_format(self) -> "ChatCompletionMessageParam":
+ raise NotImplementedError
+
class SystemMessage(BaseMessage, LCSystemMessage):
- pass
+ def to_openai_format(self) -> "ChatCompletionMessageParam":
+ return {"role": "system", "content": self.content}
class AIMessage(BaseMessage, LCAIMessage):
- pass
+ def to_openai_format(self) -> "ChatCompletionMessageParam":
+ return {"role": "assistant", "content": self.content}
class HumanMessage(BaseMessage, LCHumanMessage):
- pass
+ def to_openai_format(self) -> "ChatCompletionMessageParam":
+ return {"role": "user", "content": self.content}
class RetrievedDocument(Document):
diff --git a/libs/kotaemon/kotaemon/indices/qa/text_based.py b/libs/kotaemon/kotaemon/indices/qa/text_based.py
index 5b1f6e3..e0b49be 100644
--- a/libs/kotaemon/kotaemon/indices/qa/text_based.py
+++ b/libs/kotaemon/kotaemon/indices/qa/text_based.py
@@ -1,7 +1,7 @@
import os
from kotaemon.base import BaseComponent, Document, Node, RetrievedDocument
-from kotaemon.llms import AzureChatOpenAI, BaseLLM, PromptTemplate
+from kotaemon.llms import BaseLLM, LCAzureChatOpenAI, PromptTemplate
from .citation import CitationPipeline
@@ -13,7 +13,7 @@ class CitationQAPipeline(BaseComponent):
'Answer the following question: "{question}". '
"The context is: \n{context}\nAnswer: "
)
- llm: BaseLLM = AzureChatOpenAI.withx(
+ llm: BaseLLM = LCAzureChatOpenAI.withx(
azure_endpoint="https://bleh-dummy.openai.azure.com/",
openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
openai_api_version="2023-07-01-preview",
diff --git a/libs/kotaemon/kotaemon/llms/__init__.py b/libs/kotaemon/kotaemon/llms/__init__.py
index d7547a6..266e391 100644
--- a/libs/kotaemon/kotaemon/llms/__init__.py
+++ b/libs/kotaemon/kotaemon/llms/__init__.py
@@ -2,7 +2,15 @@ from kotaemon.base.schema import AIMessage, BaseMessage, HumanMessage, SystemMes
from .base import BaseLLM
from .branching import GatedBranchingPipeline, SimpleBranchingPipeline
-from .chats import AzureChatOpenAI, ChatLLM, ChatOpenAI, EndpointChatLLM, LlamaCppChat
+from .chats import (
+ AzureChatOpenAI,
+ ChatLLM,
+ ChatOpenAI,
+ EndpointChatLLM,
+ LCAzureChatOpenAI,
+ LCChatOpenAI,
+ LlamaCppChat,
+)
from .completions import LLM, AzureOpenAI, LlamaCpp, OpenAI
from .cot import ManualSequentialChainOfThought, Thought
from .linear import GatedLinearPipeline, SimpleLinearPipeline
@@ -17,8 +25,10 @@ __all__ = [
"HumanMessage",
"AIMessage",
"SystemMessage",
- "ChatOpenAI",
"AzureChatOpenAI",
+ "ChatOpenAI",
+ "LCAzureChatOpenAI",
+ "LCChatOpenAI",
"LlamaCppChat",
# completion-specific components
"LLM",
diff --git a/libs/kotaemon/kotaemon/llms/base.py b/libs/kotaemon/kotaemon/llms/base.py
index 6ef7afc..374d139 100644
--- a/libs/kotaemon/kotaemon/llms/base.py
+++ b/libs/kotaemon/kotaemon/llms/base.py
@@ -18,5 +18,8 @@ class BaseLLM(BaseComponent):
def stream(self, *args, **kwargs) -> Iterator[LLMInterface]:
raise NotImplementedError
- async def astream(self, *args, **kwargs) -> AsyncGenerator[LLMInterface, None]:
+ def astream(self, *args, **kwargs) -> AsyncGenerator[LLMInterface, None]:
raise NotImplementedError
+
+ def run(self, *args, **kwargs):
+ return self.invoke(*args, **kwargs)
diff --git a/libs/kotaemon/kotaemon/llms/branching.py b/libs/kotaemon/kotaemon/llms/branching.py
index a9cbbe8..ee49dc5 100644
--- a/libs/kotaemon/kotaemon/llms/branching.py
+++ b/libs/kotaemon/kotaemon/llms/branching.py
@@ -15,7 +15,7 @@ class SimpleBranchingPipeline(BaseComponent):
Example:
```python
from kotaemon.llms import (
- AzureChatOpenAI,
+ LCAzureChatOpenAI,
BasePromptComponent,
GatedLinearPipeline,
)
@@ -25,7 +25,7 @@ class SimpleBranchingPipeline(BaseComponent):
return x
pipeline = SimpleBranchingPipeline()
- llm = AzureChatOpenAI(
+ llm = LCAzureChatOpenAI(
openai_api_base="your openai api base",
openai_api_key="your openai api key",
openai_api_version="your openai api version",
@@ -92,7 +92,7 @@ class GatedBranchingPipeline(SimpleBranchingPipeline):
Example:
```python
from kotaemon.llms import (
- AzureChatOpenAI,
+ LCAzureChatOpenAI,
BasePromptComponent,
GatedLinearPipeline,
)
@@ -102,7 +102,7 @@ class GatedBranchingPipeline(SimpleBranchingPipeline):
return x
pipeline = GatedBranchingPipeline()
- llm = AzureChatOpenAI(
+ llm = LCAzureChatOpenAI(
openai_api_base="your openai api base",
openai_api_key="your openai api key",
openai_api_version="your openai api version",
@@ -157,7 +157,7 @@ class GatedBranchingPipeline(SimpleBranchingPipeline):
if __name__ == "__main__":
import dotenv
- from kotaemon.llms import AzureChatOpenAI, BasePromptComponent
+ from kotaemon.llms import BasePromptComponent, LCAzureChatOpenAI
from kotaemon.parsers import RegexExtractor
def identity(x):
@@ -166,7 +166,7 @@ if __name__ == "__main__":
secrets = dotenv.dotenv_values(".env")
pipeline = GatedBranchingPipeline()
- llm = AzureChatOpenAI(
+ llm = LCAzureChatOpenAI(
openai_api_base=secrets.get("OPENAI_API_BASE", ""),
openai_api_key=secrets.get("OPENAI_API_KEY", ""),
openai_api_version=secrets.get("OPENAI_API_VERSION", ""),
diff --git a/libs/kotaemon/kotaemon/llms/chats/__init__.py b/libs/kotaemon/kotaemon/llms/chats/__init__.py
index 5b50317..7fc1c40 100644
--- a/libs/kotaemon/kotaemon/llms/chats/__init__.py
+++ b/libs/kotaemon/kotaemon/llms/chats/__init__.py
@@ -1,13 +1,17 @@
from .base import ChatLLM
from .endpoint_based import EndpointChatLLM
-from .langchain_based import AzureChatOpenAI, ChatOpenAI, LCChatMixin
+from .langchain_based import LCAzureChatOpenAI, LCChatMixin, LCChatOpenAI
from .llamacpp import LlamaCppChat
+from .openai import AzureChatOpenAI, ChatOpenAI
__all__ = [
+ "ChatOpenAI",
+ "AzureChatOpenAI",
"ChatLLM",
"EndpointChatLLM",
"ChatOpenAI",
- "AzureChatOpenAI",
+ "LCChatOpenAI",
+ "LCAzureChatOpenAI",
"LCChatMixin",
"LlamaCppChat",
]
diff --git a/libs/kotaemon/kotaemon/llms/chats/endpoint_based.py b/libs/kotaemon/kotaemon/llms/chats/endpoint_based.py
index 170ec8b..5ab1835 100644
--- a/libs/kotaemon/kotaemon/llms/chats/endpoint_based.py
+++ b/libs/kotaemon/kotaemon/llms/chats/endpoint_based.py
@@ -5,6 +5,7 @@ from kotaemon.base import (
BaseMessage,
HumanMessage,
LLMInterface,
+ Param,
SystemMessage,
)
@@ -20,7 +21,9 @@ class EndpointChatLLM(ChatLLM):
        endpoint_url (str): The URL of an OpenAI API compatible endpoint.
"""
- endpoint_url: str
+ endpoint_url: str = Param(
+ help="URL of the OpenAI API compatible endpoint", required=True
+ )
def run(
self, messages: str | BaseMessage | list[BaseMessage], **kwargs
diff --git a/libs/kotaemon/kotaemon/llms/chats/langchain_based.py b/libs/kotaemon/kotaemon/llms/chats/langchain_based.py
index 526eaf8..fca78dc 100644
--- a/libs/kotaemon/kotaemon/llms/chats/langchain_based.py
+++ b/libs/kotaemon/kotaemon/llms/chats/langchain_based.py
@@ -165,7 +165,7 @@ class LCChatMixin:
raise ValueError(f"Invalid param {path}")
-class ChatOpenAI(LCChatMixin, ChatLLM): # type: ignore
+class LCChatOpenAI(LCChatMixin, ChatLLM): # type: ignore
def __init__(
self,
openai_api_base: str | None = None,
@@ -193,7 +193,7 @@ class ChatOpenAI(LCChatMixin, ChatLLM): # type: ignore
return ChatOpenAI
-class AzureChatOpenAI(LCChatMixin, ChatLLM): # type: ignore
+class LCAzureChatOpenAI(LCChatMixin, ChatLLM): # type: ignore
def __init__(
self,
azure_endpoint: str | None = None,
diff --git a/libs/kotaemon/kotaemon/llms/chats/llamacpp.py b/libs/kotaemon/kotaemon/llms/chats/llamacpp.py
index 62ee0ea..7b8bee4 100644
--- a/libs/kotaemon/kotaemon/llms/chats/llamacpp.py
+++ b/libs/kotaemon/kotaemon/llms/chats/llamacpp.py
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Optional, cast
+from typing import TYPE_CHECKING, Iterator, Optional, cast
from kotaemon.base import BaseMessage, HumanMessage, LLMInterface, Param
@@ -12,13 +12,32 @@ if TYPE_CHECKING:
class LlamaCppChat(ChatLLM):
"""Wrapper around the llama-cpp-python's Llama model"""
- model_path: Optional[str] = None
- chat_format: Optional[str] = None
- lora_base: Optional[str] = None
- n_ctx: int = 512
- n_gpu_layers: int = 0
- use_mmap: bool = True
- vocab_only: bool = False
+ model_path: str = Param(
+ help="Path to the model file. This is required to load the model.",
+ required=True,
+ )
+ chat_format: str = Param(
+ help=(
+ "Chat format to use. Please refer to llama_cpp.llama_chat_format for a "
+ "list of supported formats. If blank, the chat format will be auto-"
+ "inferred."
+ ),
+ required=True,
+ )
+ lora_base: Optional[str] = Param(None, help="Path to the base Lora model")
+ n_ctx: Optional[int] = Param(512, help="Text context, 0 = from model")
+ n_gpu_layers: Optional[int] = Param(
+ 0,
+ help=("Number of layers to offload to GPU. If -1, all layers are offloaded"),
+ )
+ use_mmap: Optional[bool] = Param(
+ True,
+        help="Whether to load the model using mmap, if supported",
+ )
+ vocab_only: Optional[bool] = Param(
+ False,
+ help=("If True, only the vocabulary is loaded. This is useful for debugging."),
+ )
_role_mapper: dict[str, str] = {
"human": "user",
@@ -60,9 +79,9 @@ class LlamaCppChat(ChatLLM):
vocab_only=self.vocab_only,
)
- def run(
- self, messages: str | BaseMessage | list[BaseMessage], **kwargs
- ) -> LLMInterface:
+ def prepare_message(
+ self, messages: str | BaseMessage | list[BaseMessage]
+ ) -> list[dict]:
input_: list[BaseMessage] = []
if isinstance(messages, str):
@@ -72,11 +91,19 @@ class LlamaCppChat(ChatLLM):
else:
input_ = messages
+ output_ = [
+ {"role": self._role_mapper[each.type], "content": each.content}
+ for each in input_
+ ]
+
+ return output_
+
+ def invoke(
+ self, messages: str | BaseMessage | list[BaseMessage], **kwargs
+ ) -> LLMInterface:
+
pred: "CCCR" = self.client_object.create_chat_completion(
- messages=[
- {"role": self._role_mapper[each.type], "content": each.content}
- for each in input_
- ], # type: ignore
+ messages=self.prepare_message(messages),
stream=False,
)
@@ -91,3 +118,19 @@ class LlamaCppChat(ChatLLM):
total_tokens=pred["usage"]["total_tokens"],
prompt_tokens=pred["usage"]["prompt_tokens"],
)
+
+ def stream(
+ self, messages: str | BaseMessage | list[BaseMessage], **kwargs
+ ) -> Iterator[LLMInterface]:
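+        # Stream partial deltas from llama-cpp-python's chat completion API as they arrive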
+ pred = self.client_object.create_chat_completion(
+ messages=self.prepare_message(messages),
+ stream=True,
+ )
+ for chunk in pred:
+ if not chunk["choices"]:
+ continue
+
+ if "content" not in chunk["choices"][0]["delta"]:
+ continue
+
+ yield LLMInterface(content=chunk["choices"][0]["delta"]["content"])
diff --git a/libs/kotaemon/kotaemon/llms/chats/openai.py b/libs/kotaemon/kotaemon/llms/chats/openai.py
new file mode 100644
index 0000000..6f492c7
--- /dev/null
+++ b/libs/kotaemon/kotaemon/llms/chats/openai.py
@@ -0,0 +1,356 @@
+from typing import TYPE_CHECKING, AsyncGenerator, Iterator, Optional
+
+from theflow.utils.modules import import_dotted_string
+
+from kotaemon.base import AIMessage, BaseMessage, HumanMessage, LLMInterface, Param
+
+from .base import ChatLLM
+
+if TYPE_CHECKING:
+ from openai.types.chat.chat_completion_message_param import (
+ ChatCompletionMessageParam,
+ )
+
+
+class BaseChatOpenAI(ChatLLM):
+ """Base interface for OpenAI chat model, using the openai library
+
+ This class exposes the parameters in resources.Chat. To subclass this class:
+
+ - Implement the `prepare_client` method to return the OpenAI client
+ - Implement the `openai_response` method to return the OpenAI response
+    - Implement the params related to the OpenAI client
+ """
+
+ _dependencies = ["openai"]
+ _capabilities = ["chat", "text"] # consider as mixin
+
+ api_key: str = Param(help="API key", required=True)
+ timeout: Optional[float] = Param(None, help="Timeout for the API request")
+ max_retries: Optional[int] = Param(
+ None, help="Maximum number of retries for the API request"
+ )
+
+ temperature: Optional[float] = Param(
+ None,
+ help=(
+ "Number between 0 and 2 that controls the randomness of the generated "
+ "tokens. Lower values make the model more deterministic, while higher "
+ "values make the model more random."
+ ),
+ )
+ max_tokens: Optional[int] = Param(
+ None,
+ help=(
+ "Maximum number of tokens to generate. The total length of input tokens "
+ "and generated tokens is limited by the model's context length."
+ ),
+ )
+ n: int = Param(
+ 1,
+ help=(
+ "Number of completions to generate. The API will generate n completion "
+ "for each prompt."
+ ),
+ )
+ stop: Optional[str | list[str]] = Param(
+ None,
+ help=(
+ "Stop sequence. If a stop sequence is detected, generation will stop "
+ "at that point. If not specified, generation will continue until the "
+ "maximum token length is reached."
+ ),
+ )
+ frequency_penalty: Optional[float] = Param(
+ None,
+ help=(
+ "Number between -2.0 and 2.0. Positive values penalize new tokens "
+ "based on their existing frequency in the text so far, decrearsing the "
+ "model's likelihood of repeating the same text."
+ ),
+ )
+ presence_penalty: Optional[float] = Param(
+ None,
+ help=(
+ "Number between -2.0 and 2.0. Positive values penalize new tokens "
+ "based on their existing presence in the text so far, decrearsing the "
+ "model's likelihood of repeating the same text."
+ ),
+ )
+ tool_choice: Optional[str] = Param(
+ None,
+ help=(
+ "Choice of tool to use for the completion. Available choices are: "
+ "auto, default."
+ ),
+ )
+ tools: Optional[list[str]] = Param(
+ None,
+ help="List of tools to use for the completion.",
+ )
+ logprobs: Optional[bool] = Param(
+ None,
+ help=(
+ "Include log probabilities on the logprobs most likely tokens, "
+ "as well as the chosen token."
+ ),
+ )
+ logit_bias: Optional[dict] = Param(
+ None,
+ help=(
+ "Dictionary of logit bias values to add to the logits of the tokens "
+ "in the vocabulary."
+ ),
+ )
+ top_logprobs: Optional[int] = Param(
+ None,
+ help=(
+ "An integer between 0 and 5 specifying the number of most likely tokens "
+ "to return at each token position, each with an associated log "
+ "probability. `logprobs` must also be set to `true` if this parameter "
+ "is used."
+ ),
+ )
+ top_p: Optional[float] = Param(
+ None,
+ help=(
+ "An alternative to sampling with temperature, called nucleus sampling, "
+ "where the model considers the results of the token with top_p "
+ "probability mass. So 0.1 means that only the tokens comprising the "
+ "top 10% probability mass are considered."
+ ),
+ )
+
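+    # Resolve the effective retry count: fall back to the openai library's
+    # DEFAULT_MAX_RETRIES when max_retries is left unset.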
+ @Param.auto(depends_on=["max_retries"])
+ def max_retries_(self):
+ if self.max_retries is None:
+ from openai._constants import DEFAULT_MAX_RETRIES
+
+ return DEFAULT_MAX_RETRIES
+ return self.max_retries
+
+ def prepare_message(
+ self, messages: str | BaseMessage | list[BaseMessage]
+ ) -> list["ChatCompletionMessageParam"]:
+ """Prepare the message into OpenAI format
+
+ Returns:
+ list[dict]: List of messages in OpenAI format
+ """
+ input_: list[BaseMessage] = []
+ output_: list["ChatCompletionMessageParam"] = []
+
+ if isinstance(messages, str):
+ input_ = [HumanMessage(content=messages)]
+ elif isinstance(messages, BaseMessage):
+ input_ = [messages]
+ else:
+ input_ = messages
+
+ for message in input_:
+ output_.append(message.to_openai_format())
+
+ return output_
+
+ def prepare_client(self, async_version: bool = False):
+ """Get the OpenAI client
+
+ Args:
+ async_version (bool): Whether to get the async version of the client
+ """
+ raise NotImplementedError
+
+ def openai_response(self, client, **kwargs):
+ """Get the openai response"""
+ raise NotImplementedError
+
+ def invoke(
+ self, messages: str | BaseMessage | list[BaseMessage], *args, **kwargs
+ ) -> LLMInterface:
+ client = self.prepare_client(async_version=False)
+ input_messages = self.prepare_message(messages)
+ resp = self.openai_response(
+ client, messages=input_messages, stream=False, **kwargs
+ ).dict()
+
+ output = LLMInterface(
+ candidates=[_["message"]["content"] for _ in resp["choices"]],
+ content=resp["choices"][0]["message"]["content"],
+ total_tokens=resp["usage"]["total_tokens"],
+ prompt_tokens=resp["usage"]["prompt_tokens"],
+ completion_tokens=resp["usage"]["completion_tokens"],
+ messages=[
+ AIMessage(content=_["message"]["content"]) for _ in resp["choices"]
+ ],
+ )
+
+ return output
+
+ async def ainvoke(
+ self, messages: str | BaseMessage | list[BaseMessage], *args, **kwargs
+ ) -> LLMInterface:
+ client = self.prepare_client(async_version=True)
+ input_messages = self.prepare_message(messages)
+        resp = (
+            await self.openai_response(
+                client, messages=input_messages, stream=False, **kwargs
+            )
+        ).dict()
+
+ output = LLMInterface(
+ candidates=[_["message"]["content"] for _ in resp["choices"]],
+ content=resp["choices"][0]["message"]["content"],
+ total_tokens=resp["usage"]["total_tokens"],
+ prompt_tokens=resp["usage"]["prompt_tokens"],
+ completion_tokens=resp["usage"]["completion_tokens"],
+ messages=[
+ AIMessage(content=_["message"]["content"]) for _ in resp["choices"]
+ ],
+ )
+
+ return output
+
+ def stream(
+ self, messages: str | BaseMessage | list[BaseMessage], *args, **kwargs
+ ) -> Iterator[LLMInterface]:
+ client = self.prepare_client(async_version=False)
+ input_messages = self.prepare_message(messages)
+ resp = self.openai_response(
+ client, messages=input_messages, stream=True, **kwargs
+ )
+
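+        # Skip empty chunks and chunks whose delta carries no content (e.g. role-only deltas)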
+ for chunk in resp:
+ if not chunk.choices:
+ continue
+ if chunk.choices[0].delta.content is not None:
+ yield LLMInterface(content=chunk.choices[0].delta.content)
+
+ async def astream(
+ self, messages: str | BaseMessage | list[BaseMessage], *args, **kwargs
+ ) -> AsyncGenerator[LLMInterface, None]:
+ client = self.prepare_client(async_version=True)
+ input_messages = self.prepare_message(messages)
+        resp = await self.openai_response(
+ client, messages=input_messages, stream=True, **kwargs
+ )
+
+ async for chunk in resp:
+ if not chunk.choices:
+ continue
+ if chunk.choices[0].delta.content is not None:
+ yield LLMInterface(content=chunk.choices[0].delta.content)
+
+
+class ChatOpenAI(BaseChatOpenAI):
+ """OpenAI chat model"""
+
+ base_url: Optional[str] = Param(None, help="OpenAI base URL")
+ organization: Optional[str] = Param(None, help="OpenAI organization")
+ model: str = Param(help="OpenAI model", required=True)
+
+ def prepare_client(self, async_version: bool = False):
+ """Get the OpenAI client
+
+ Args:
+ async_version (bool): Whether to get the async version of the client
+ """
+ params = {
+ "api_key": self.api_key,
+ "organization": self.organization,
+ "base_url": self.base_url,
+ "timeout": self.timeout,
+ "max_retries": self.max_retries_,
+ }
+ if async_version:
+ from openai import AsyncOpenAI
+
+ return AsyncOpenAI(**params)
+
+ from openai import OpenAI
+
+ return OpenAI(**params)
+
+ def openai_response(self, client, **kwargs):
+ """Get the openai response"""
+ params = {
+ "model": self.model,
+ "temperature": self.temperature,
+ "max_tokens": self.max_tokens,
+ "n": self.n,
+ "stop": self.stop,
+ "frequency_penalty": self.frequency_penalty,
+ "presence_penalty": self.presence_penalty,
+ "tool_choice": self.tool_choice,
+ "tools": self.tools,
+ "logprobs": self.logprobs,
+ "logit_bias": self.logit_bias,
+ "top_logprobs": self.top_logprobs,
+ "top_p": self.top_p,
+ }
+ params.update(kwargs)
+
+ return client.chat.completions.create(**params)
+
+
+class AzureChatOpenAI(BaseChatOpenAI):
+ """OpenAI chat model provided by Microsoft Azure"""
+
+ azure_endpoint: str = Param(
+ help=(
+ "HTTPS endpoint for the Azure OpenAI model. The azure_endpoint, "
+ "azure_deployment, and api_version parameters are used to construct "
+ "the full URL for the Azure OpenAI model."
+ )
+ )
+ azure_deployment: str = Param(help="Azure deployment name", required=True)
+ api_version: str = Param(help="Azure model version", required=True)
+ azure_ad_token: Optional[str] = Param(None, help="Azure AD token")
+ azure_ad_token_provider: Optional[str] = Param(None, help="Azure AD token provider")
+
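+    # Resolve the dotted-string token provider into the callable expected by the
+    # Azure OpenAI client.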
+ @Param.auto(depends_on=["azure_ad_token_provider"])
+ def azure_ad_token_provider_(self):
+ if isinstance(self.azure_ad_token_provider, str):
+ return import_dotted_string(self.azure_ad_token_provider, safe=False)
+
+ def prepare_client(self, async_version: bool = False):
+ """Get the OpenAI client
+
+ Args:
+ async_version (bool): Whether to get the async version of the client
+ """
+ params = {
+ "azure_endpoint": self.azure_endpoint,
+ "api_version": self.api_version,
+ "api_key": self.api_key,
+ "azure_ad_token": self.azure_ad_token,
+ "azure_ad_token_provider": self.azure_ad_token_provider_,
+ "timeout": self.timeout,
+ "max_retries": self.max_retries_,
+ }
+ if async_version:
+ from openai import AsyncAzureOpenAI
+
+ return AsyncAzureOpenAI(**params)
+
+ from openai import AzureOpenAI
+
+ return AzureOpenAI(**params)
+
+ def openai_response(self, client, **kwargs):
+ """Get the openai response"""
+ params = {
+ "model": self.azure_deployment,
+ "temperature": self.temperature,
+ "max_tokens": self.max_tokens,
+ "n": self.n,
+ "stop": self.stop,
+ "frequency_penalty": self.frequency_penalty,
+ "presence_penalty": self.presence_penalty,
+ "tool_choice": self.tool_choice,
+ "tools": self.tools,
+ "logprobs": self.logprobs,
+ "logit_bias": self.logit_bias,
+ "top_logprobs": self.top_logprobs,
+ "top_p": self.top_p,
+ }
+ params.update(kwargs)
+
+ return client.chat.completions.create(**params)
diff --git a/libs/kotaemon/kotaemon/llms/cot.py b/libs/kotaemon/kotaemon/llms/cot.py
index 7eaf5d1..a52f9bd 100644
--- a/libs/kotaemon/kotaemon/llms/cot.py
+++ b/libs/kotaemon/kotaemon/llms/cot.py
@@ -5,7 +5,7 @@ from theflow import Function, Node, Param
from kotaemon.base import BaseComponent, Document
-from .chats import AzureChatOpenAI
+from .chats import LCAzureChatOpenAI
from .completions import LLM
from .prompts import BasePromptComponent
@@ -25,7 +25,7 @@ class Thought(BaseComponent):
>> from kotaemon.pipelines.cot import Thought
>> thought = Thought(
prompt="How to {action} {object}?",
- llm=AzureChatOpenAI(...),
+ llm=LCAzureChatOpenAI(...),
post_process=lambda string: {"tutorial": string},
)
>> output = thought(action="install", object="python")
@@ -42,7 +42,7 @@ class Thought(BaseComponent):
This `Thought` allows chaining sequentially with the + operator. For example:
```python
- >> llm = AzureChatOpenAI(...)
+ >> llm = LCAzureChatOpenAI(...)
>> thought1 = Thought(
prompt="Word {word} in {language} is ",
llm=llm,
@@ -73,7 +73,7 @@ class Thought(BaseComponent):
" component is executed"
)
)
- llm: LLM = Node(AzureChatOpenAI, help="The LLM model to execute the input prompt")
+ llm: LLM = Node(LCAzureChatOpenAI, help="The LLM model to execute the input prompt")
post_process: Function = Node(
help=(
"The function post-processor that post-processes LLM output prediction ."
@@ -117,7 +117,7 @@ class ManualSequentialChainOfThought(BaseComponent):
```pycon
>>> from kotaemon.pipelines.cot import Thought, ManualSequentialChainOfThought
- >>> llm = AzureChatOpenAI(...)
+ >>> llm = LCAzureChatOpenAI(...)
>>> thought1 = Thought(
>>> prompt="Word {word} in {language} is ",
>>> post_process=lambda string: {"translated": string},
diff --git a/libs/kotaemon/kotaemon/llms/linear.py b/libs/kotaemon/kotaemon/llms/linear.py
index ac8605a..4c61597 100644
--- a/libs/kotaemon/kotaemon/llms/linear.py
+++ b/libs/kotaemon/kotaemon/llms/linear.py
@@ -22,12 +22,12 @@ class SimpleLinearPipeline(BaseComponent):
Example Usage:
```python
- from kotaemon.llms import AzureChatOpenAI, BasePromptComponent
+ from kotaemon.llms import LCAzureChatOpenAI, BasePromptComponent
def identity(x):
return x
- llm = AzureChatOpenAI(
+ llm = LCAzureChatOpenAI(
openai_api_base="your openai api base",
openai_api_key="your openai api key",
openai_api_version="your openai api version",
@@ -89,13 +89,13 @@ class GatedLinearPipeline(SimpleLinearPipeline):
Usage:
```{.py3 title="Example Usage"}
- from kotaemon.llms import AzureChatOpenAI, BasePromptComponent
+ from kotaemon.llms import LCAzureChatOpenAI, BasePromptComponent
from kotaemon.parsers import RegexExtractor
def identity(x):
return x
- llm = AzureChatOpenAI(
+ llm = LCAzureChatOpenAI(
openai_api_base="your openai api base",
openai_api_key="your openai api key",
openai_api_version="your openai api version",
diff --git a/libs/kotaemon/pyproject.toml b/libs/kotaemon/pyproject.toml
index 73c3e8a..a02337e 100644
--- a/libs/kotaemon/pyproject.toml
+++ b/libs/kotaemon/pyproject.toml
@@ -11,7 +11,7 @@ packages.find.exclude = ["tests*", "env*"]
# metadata and dependencies
[project]
name = "kotaemon"
-version = "0.3.8"
+version = "0.3.9"
requires-python = ">= 3.10"
description = "Kotaemon core library for AI development."
dependencies = [
diff --git a/libs/kotaemon/tests/test_agent.py b/libs/kotaemon/tests/test_agent.py
index 0cc65fa..d489af9 100644
--- a/libs/kotaemon/tests/test_agent.py
+++ b/libs/kotaemon/tests/test_agent.py
@@ -13,7 +13,7 @@ from kotaemon.agents import (
RewooAgent,
WikipediaTool,
)
-from kotaemon.llms import AzureChatOpenAI
+from kotaemon.llms import LCAzureChatOpenAI
FINAL_RESPONSE_TEXT = "Final Answer: Hello Cinnamon AI!"
REWOO_VALID_PLAN = (
@@ -112,7 +112,7 @@ _openai_chat_completion_responses_react_langchain_tool = [
@pytest.fixture
def llm():
- return AzureChatOpenAI(
+ return LCAzureChatOpenAI(
azure_endpoint="https://dummy.openai.azure.com/",
openai_api_key="dummy",
openai_api_version="2023-03-15-preview",
diff --git a/libs/kotaemon/tests/test_composite.py b/libs/kotaemon/tests/test_composite.py
index 464a456..38e79bd 100644
--- a/libs/kotaemon/tests/test_composite.py
+++ b/libs/kotaemon/tests/test_composite.py
@@ -4,10 +4,10 @@ import pytest
from openai.types.chat.chat_completion import ChatCompletion
from kotaemon.llms import (
- AzureChatOpenAI,
BasePromptComponent,
GatedBranchingPipeline,
GatedLinearPipeline,
+ LCAzureChatOpenAI,
SimpleBranchingPipeline,
SimpleLinearPipeline,
)
@@ -40,7 +40,7 @@ _openai_chat_completion_response = ChatCompletion.parse_obj(
@pytest.fixture
def mock_llm():
- return AzureChatOpenAI(
+ return LCAzureChatOpenAI(
azure_endpoint="OPENAI_API_BASE",
openai_api_key="OPENAI_API_KEY",
openai_api_version="OPENAI_API_VERSION",
diff --git a/libs/kotaemon/tests/test_cot.py b/libs/kotaemon/tests/test_cot.py
index aef8a69..5fd1344 100644
--- a/libs/kotaemon/tests/test_cot.py
+++ b/libs/kotaemon/tests/test_cot.py
@@ -2,7 +2,7 @@ from unittest.mock import patch
from openai.types.chat.chat_completion import ChatCompletion
-from kotaemon.llms import AzureChatOpenAI
+from kotaemon.llms import LCAzureChatOpenAI
from kotaemon.llms.cot import ManualSequentialChainOfThought, Thought
_openai_chat_completion_response = [
@@ -38,7 +38,7 @@ _openai_chat_completion_response = [
side_effect=_openai_chat_completion_response,
)
def test_cot_plus_operator(openai_completion):
- llm = AzureChatOpenAI(
+ llm = LCAzureChatOpenAI(
azure_endpoint="https://dummy.openai.azure.com/",
openai_api_key="dummy",
openai_api_version="2023-03-15-preview",
@@ -70,7 +70,7 @@ def test_cot_plus_operator(openai_completion):
side_effect=_openai_chat_completion_response,
)
def test_cot_manual(openai_completion):
- llm = AzureChatOpenAI(
+ llm = LCAzureChatOpenAI(
azure_endpoint="https://dummy.openai.azure.com/",
openai_api_key="dummy",
openai_api_version="2023-03-15-preview",
@@ -100,7 +100,7 @@ def test_cot_manual(openai_completion):
side_effect=_openai_chat_completion_response,
)
def test_cot_with_termination_callback(openai_completion):
- llm = AzureChatOpenAI(
+ llm = LCAzureChatOpenAI(
azure_endpoint="https://dummy.openai.azure.com/",
openai_api_key="dummy",
openai_api_version="2023-03-15-preview",
diff --git a/libs/kotaemon/tests/test_llms_chat_models.py b/libs/kotaemon/tests/test_llms_chat_models.py
index a6a2a24..3758f76 100644
--- a/libs/kotaemon/tests/test_llms_chat_models.py
+++ b/libs/kotaemon/tests/test_llms_chat_models.py
@@ -4,7 +4,7 @@ from unittest.mock import patch
import pytest
from kotaemon.base.schema import AIMessage, HumanMessage, LLMInterface, SystemMessage
-from kotaemon.llms import AzureChatOpenAI, LlamaCppChat
+from kotaemon.llms import LCAzureChatOpenAI, LlamaCppChat
try:
from langchain_openai import AzureChatOpenAI as AzureChatOpenAILC
@@ -43,7 +43,7 @@ _openai_chat_completion_response = ChatCompletion.parse_obj(
side_effect=lambda *args, **kwargs: _openai_chat_completion_response,
)
def test_azureopenai_model(openai_completion):
- model = AzureChatOpenAI(
+ model = LCAzureChatOpenAI(
azure_endpoint="https://test.openai.azure.com/",
openai_api_key="some-key",
openai_api_version="2023-03-15-preview",
diff --git a/libs/kotaemon/tests/test_reranking.py b/libs/kotaemon/tests/test_reranking.py
index d4f7be8..ee37d3c 100644
--- a/libs/kotaemon/tests/test_reranking.py
+++ b/libs/kotaemon/tests/test_reranking.py
@@ -5,7 +5,7 @@ from openai.types.chat.chat_completion import ChatCompletion
from kotaemon.base import Document
from kotaemon.indices.rankings import LLMReranking
-from kotaemon.llms import AzureChatOpenAI
+from kotaemon.llms import LCAzureChatOpenAI
_openai_chat_completion_responses = [
ChatCompletion.parse_obj(
@@ -41,7 +41,7 @@ _openai_chat_completion_responses = [
@pytest.fixture
def llm():
- return AzureChatOpenAI(
+ return LCAzureChatOpenAI(
azure_endpoint="https://dummy.openai.azure.com/",
openai_api_key="dummy",
openai_api_version="2023-03-15-preview",
diff --git a/libs/ktem/flowsettings.py b/libs/ktem/flowsettings.py
index 33ba88f..2e26cff 100644
--- a/libs/ktem/flowsettings.py
+++ b/libs/ktem/flowsettings.py
@@ -40,16 +40,15 @@ if config("AZURE_OPENAI_API_KEY", default="") and config(
):
if config("AZURE_OPENAI_CHAT_DEPLOYMENT", default=""):
KH_LLMS["azure"] = {
- "def": {
+ "spec": {
"__type__": "kotaemon.llms.AzureChatOpenAI",
"temperature": 0,
"azure_endpoint": config("AZURE_OPENAI_ENDPOINT", default=""),
- "openai_api_key": config("AZURE_OPENAI_API_KEY", default=""),
+ "api_key": config("AZURE_OPENAI_API_KEY", default=""),
"api_version": config("OPENAI_API_VERSION", default="")
or "2024-02-15-preview",
- "deployment_name": config("AZURE_OPENAI_CHAT_DEPLOYMENT", default=""),
- "request_timeout": 10,
- "stream": False,
+ "azure_deployment": config("AZURE_OPENAI_CHAT_DEPLOYMENT", default=""),
+ "timeout": 20,
},
"default": False,
"accuracy": 5,
@@ -57,7 +56,7 @@ if config("AZURE_OPENAI_API_KEY", default="") and config(
}
if config("AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT", default=""):
KH_EMBEDDINGS["azure"] = {
- "def": {
+ "spec": {
"__type__": "kotaemon.embeddings.AzureOpenAIEmbeddings",
"azure_endpoint": config("AZURE_OPENAI_ENDPOINT", default=""),
"openai_api_key": config("AZURE_OPENAI_API_KEY", default=""),
@@ -164,5 +163,11 @@ KH_INDICES = [
"name": "File",
"config": {},
"index_type": "ktem.index.file.FileIndex",
- }
+ },
+ {
+ "id": 2,
+ "name": "Sample",
+ "config": {},
+ "index_type": "ktem.index.file.FileIndex",
+ },
]
diff --git a/libs/ktem/ktem/components.py b/libs/ktem/ktem/components.py
index 6cfb2e3..182cb91 100644
--- a/libs/ktem/ktem/components.py
+++ b/libs/ktem/ktem/components.py
@@ -3,6 +3,7 @@
import logging
from functools import cache
from pathlib import Path
+from typing import Optional
from theflow.settings import settings
from theflow.utils.modules import deserialize
@@ -48,7 +49,7 @@ class ModelPool:
self._default: list[str] = []
for name, model in conf.items():
- self._models[name] = deserialize(model["def"], safe=False)
+ self._models[name] = deserialize(model["spec"], safe=False)
if model.get("default", False):
self._default.append(name)
@@ -58,11 +59,27 @@ class ModelPool:
self._cost = list(sorted(conf, key=lambda x: conf[x].get("cost", float("inf"))))
def __getitem__(self, key: str) -> BaseComponent:
+ """Get model by name"""
return self._models[key]
def __setitem__(self, key: str, value: BaseComponent):
+ """Set model by name"""
self._models[key] = value
+ def __delitem__(self, key: str):
+ """Delete model by name"""
+ del self._models[key]
+
+ def __contains__(self, key: str) -> bool:
+ """Check if model exists"""
+ return key in self._models
+
+ def get(
+ self, key: str, default: Optional[BaseComponent] = None
+ ) -> Optional[BaseComponent]:
+ """Get model by name with default value"""
+ return self._models.get(key, default)
+
def settings(self) -> dict:
"""Present model pools option for gradio"""
return {
@@ -169,4 +186,3 @@ llms = ModelPool("LLMs", settings.KH_LLMS)
embeddings = ModelPool("Embeddings", settings.KH_EMBEDDINGS)
reasonings: dict = {}
tools = ModelPool("Tools", {})
-indices = ModelPool("Indices", {})
diff --git a/libs/ktem/ktem/index/file/pipelines.py b/libs/ktem/ktem/index/file/pipelines.py
index b63d89c..13036f3 100644
--- a/libs/ktem/ktem/index/file/pipelines.py
+++ b/libs/ktem/ktem/index/file/pipelines.py
@@ -157,10 +157,10 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
@classmethod
def get_user_settings(cls) -> dict:
- from ktem.components import llms
+ from ktem.llms.manager import llms
try:
- reranking_llm = llms.get_lowest_cost_name()
+ reranking_llm = llms.get_default_name()
reranking_llm_choices = list(llms.options().keys())
except Exception as e:
logger.error(e)
diff --git a/libs/ktem/ktem/llms/__init__.py b/libs/ktem/ktem/llms/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/libs/ktem/ktem/llms/db.py b/libs/ktem/ktem/llms/db.py
new file mode 100644
index 0000000..628ebb7
--- /dev/null
+++ b/libs/ktem/ktem/llms/db.py
@@ -0,0 +1,36 @@
+from typing import Type
+
+from ktem.db.engine import engine
+from sqlalchemy import JSON, Boolean, Column, String
+from sqlalchemy.orm import DeclarativeBase
+from theflow.settings import settings as flowsettings
+from theflow.utils.modules import import_dotted_string
+
+
+class Base(DeclarativeBase):
+ pass
+
+
+class BaseLLMTable(Base):
+ """Base table to store language model"""
+
+ __abstract__ = True
+
+ name = Column(String, primary_key=True, unique=True)
+ spec = Column(JSON, default={})
+ default = Column(Boolean, default=False)
+
+
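+# Allow deployments to override the LLM table via flowsettings.KH_TABLE_LLM;
+# otherwise fall back to the BaseLLMTable defined above.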
+_base_llm: Type[BaseLLMTable] = (
+ import_dotted_string(flowsettings.KH_TABLE_LLM, safe=False)
+ if hasattr(flowsettings, "KH_TABLE_LLM")
+ else BaseLLMTable
+)
+
+
+class LLMTable(_base_llm): # type: ignore
+ __tablename__ = "llm_table"
+
+
+if not getattr(flowsettings, "KH_ENABLE_ALEMBIC", False):
+ LLMTable.metadata.create_all(engine)
diff --git a/libs/ktem/ktem/llms/manager.py b/libs/ktem/ktem/llms/manager.py
new file mode 100644
index 0000000..f9ad763
--- /dev/null
+++ b/libs/ktem/ktem/llms/manager.py
@@ -0,0 +1,191 @@
+from typing import Optional, Type
+
+from sqlalchemy import select
+from sqlalchemy.orm import Session
+from theflow.settings import settings as flowsettings
+from theflow.utils.modules import deserialize
+
+from kotaemon.base import BaseComponent
+
+from .db import LLMTable, engine
+
+
+class LLMManager:
+ """Represent a pool of models"""
+
+ def __init__(self):
+ self._models: dict[str, BaseComponent] = {}
+ self._info: dict[str, dict] = {}
+ self._default: str = ""
+ self._vendors: list[Type] = []
+
+ if hasattr(flowsettings, "KH_LLMS"):
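+            # Seed the LLM table from flowsettings.KH_LLMS, skipping entries that
+            # already exist in the database.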
+ for name, model in flowsettings.KH_LLMS.items():
+ with Session(engine) as session:
+ stmt = select(LLMTable).where(LLMTable.name == name)
+ result = session.execute(stmt)
+ if not result.first():
+ item = LLMTable(
+ name=name,
+ spec=model["spec"],
+ default=model.get("default", False),
+ )
+ session.add(item)
+ session.commit()
+
+ self.load()
+ self.load_vendors()
+
+ def load(self):
+ """Load the model pool from database"""
+        self._models, self._info, self._default = {}, {}, ""
+ with Session(engine) as session:
+ stmt = select(LLMTable)
+ items = session.execute(stmt)
+
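+            # session.execute() yields Row tuples; unpack the single LLMTable entity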
+ for (item,) in items:
+ self._models[item.name] = deserialize(item.spec, safe=False)
+ self._info[item.name] = {
+ "name": item.name,
+ "spec": item.spec,
+ "default": item.default,
+ }
+ if item.default:
+ self._default = item.name
+
+ def load_vendors(self):
+ from kotaemon.llms import (
+ AzureChatOpenAI,
+ ChatOpenAI,
+ EndpointChatLLM,
+ LlamaCppChat,
+ )
+
+ self._vendors = [ChatOpenAI, AzureChatOpenAI, LlamaCppChat, EndpointChatLLM]
+
+ def __getitem__(self, key: str) -> BaseComponent:
+ """Get model by name"""
+ return self._models[key]
+
+ def __contains__(self, key: str) -> bool:
+ """Check if model exists"""
+ return key in self._models
+
+ def get(
+ self, key: str, default: Optional[BaseComponent] = None
+ ) -> Optional[BaseComponent]:
+ """Get model by name with default value"""
+ return self._models.get(key, default)
+
+ def settings(self) -> dict:
+ """Present model pools option for gradio"""
+ return {
+ "label": "LLM",
+ "choices": list(self._models.keys()),
+ "value": self.get_default_name(),
+ }
+
+ def options(self) -> dict:
+ """Present a dict of models"""
+ return self._models
+
+ def get_random_name(self) -> str:
+ """Get the name of random model
+
+ Returns:
+ str: random model name in the pool
+ """
+ import random
+
+ if not self._models:
+ raise ValueError("No models in pool")
+
+ return random.choice(list(self._models.keys()))
+
+ def get_default_name(self) -> str:
+ """Get the name of default model
+
+ In case there is no default model, choose random model from pool. In
+ case there are multiple default models, choose random from them.
+
+ Returns:
+ str: model name
+ """
+ if not self._models:
+ raise ValueError("No models in pool")
+
+ if not self._default:
+ return self.get_random_name()
+
+ return self._default
+
+ def get_random(self) -> BaseComponent:
+ """Get random model"""
+ return self._models[self.get_random_name()]
+
+ def get_default(self) -> BaseComponent:
+ """Get default model
+
+ In case there is no default model, choose random model from pool. In
+ case there are multiple default models, choose random from them.
+
+ Returns:
+ BaseComponent: model
+ """
+ return self._models[self.get_default_name()]
+
+ def info(self) -> dict:
+ """List all models"""
+ return self._info
+
+ def add(self, name: str, spec: dict, default: bool):
+ """Add a new model to the pool"""
+ try:
+ with Session(engine) as session:
+ item = LLMTable(name=name, spec=spec, default=default)
+ session.add(item)
+ session.commit()
+ except Exception as e:
+ raise ValueError(f"Failed to add model {name}: {e}")
+
+ self.load()
+
+ def delete(self, name: str):
+ """Delete a model from the pool"""
+ try:
+ with Session(engine) as session:
+ item = session.query(LLMTable).filter_by(name=name).first()
+ session.delete(item)
+ session.commit()
+ except Exception as e:
+ raise ValueError(f"Failed to delete model {name}: {e}")
+
+ self.load()
+
+ def update(self, name: str, spec: dict, default: bool):
+ """Update a model in the pool"""
+ try:
+ with Session(engine) as session:
+
+ if default:
+ # turn all models to non-default
+ session.query(LLMTable).update({"default": False})
+ session.commit()
+
+ item = session.query(LLMTable).filter_by(name=name).first()
+ if not item:
+ raise ValueError(f"Model {name} not found")
+ item.spec = spec
+ item.default = default
+ session.commit()
+ except Exception as e:
+ raise ValueError(f"Failed to update model {name}: {e}")
+
+ self.load()
+
+ def vendors(self) -> dict:
+ """Return list of vendors"""
+ return {vendor.__qualname__: vendor for vendor in self._vendors}
+
+
+llms = LLMManager()
diff --git a/libs/ktem/ktem/llms/ui.py b/libs/ktem/ktem/llms/ui.py
new file mode 100644
index 0000000..dd8f2bd
--- /dev/null
+++ b/libs/ktem/ktem/llms/ui.py
@@ -0,0 +1,318 @@
+from copy import deepcopy
+
+import gradio as gr
+import pandas as pd
+import yaml
+from ktem.app import BasePage
+
+from .manager import llms
+
+
+def format_description(cls):
+ params = cls.describe()["params"]
+ params_lines = ["| Name | Type | Description |", "| --- | --- | --- |"]
+ for key, value in params.items():
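+        # Params with a string auto_callback are auto-computed; hide them from the table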
+ if isinstance(value["auto_callback"], str):
+ continue
+ params_lines.append(f"| {key} | {value['type']} | {value['help']} |")
+ return f"{cls.__doc__}\n\n" + "\n".join(params_lines)
+
+
+class LLMManagement(BasePage):
+ def __init__(self, app):
+ self._app = app
+ self.spec_desc_default = (
+ "# Spec description\n\nSelect an LLM to view the spec description."
+ )
+ self.on_building_ui()
+
+ def on_building_ui(self):
+ with gr.Tab(label="View"):
+ self.llm_list = gr.DataFrame(
+ headers=["name", "vendor", "default"],
+ interactive=False,
+ )
+
+ with gr.Column(visible=False) as self._selected_panel:
+ self.selected_llm_name = gr.Textbox(value="", visible=False)
+ with gr.Row():
+ with gr.Column():
+ self.edit_default = gr.Checkbox(
+ label="Set default",
+ info=(
+ "Set this LLM as default. If no default is set, a "
+ "random LLM will be used."
+ ),
+ )
+ self.edit_spec = gr.Textbox(
+ label="Specification",
+ info="Specification of the LLM in YAML format",
+ lines=10,
+ )
+
+ with gr.Row(visible=False) as self._selected_panel_btn:
+ with gr.Column():
+ self.btn_edit_save = gr.Button("Save", min_width=10)
+ with gr.Column():
+ self.btn_delete = gr.Button("Delete", min_width=10)
+ with gr.Row():
+ self.btn_delete_yes = gr.Button(
+ "Confirm delete",
+ variant="primary",
+ visible=False,
+ min_width=10,
+ )
+ self.btn_delete_no = gr.Button(
+ "Cancel", visible=False, min_width=10
+ )
+ with gr.Column():
+ self.btn_close = gr.Button("Close", min_width=10)
+
+ with gr.Column():
+ self.edit_spec_desc = gr.Markdown("# Spec description")
+
+ with gr.Tab(label="Add"):
+ with gr.Row():
+ with gr.Column(scale=2):
+ self.name = gr.Textbox(
+ label="LLM name",
+ info=(
+ "Must be unique. The name will be used to identify the LLM."
+ ),
+ )
+ self.llm_choices = gr.Dropdown(
+ label="LLM vendors",
+ info=(
+ "Choose the vendor for the LLM. Each vendor has different "
+ "specification."
+ ),
+ )
+ self.spec = gr.Textbox(
+ label="Specification",
+ info="Specification of the LLM in YAML format",
+ )
+ self.default = gr.Checkbox(
+ label="Set default",
+ info=(
+ "Set this LLM as default. This default LLM will be used "
+ "by default across the application."
+ ),
+ )
+ self.btn_new = gr.Button("Create LLM")
+
+ with gr.Column(scale=3):
+ self.spec_desc = gr.Markdown(self.spec_desc_default)
+
+ def _on_app_created(self):
+ """Called when the app is created"""
+ self._app.app.load(
+ self.list_llms,
+ inputs=None,
+ outputs=[self.llm_list],
+ )
+ self._app.app.load(
+ lambda: gr.update(choices=list(llms.vendors().keys())),
+ outputs=[self.llm_choices],
+ )
+
+ def on_llm_vendor_change(self, vendor):
+ vendor = llms.vendors()[vendor]
+
+ required: dict = {}
+ desc = vendor.describe()
+ for key, value in desc["params"].items():
+ if value.get("required", False):
+ required[key] = None
+
+ return yaml.dump(required), format_description(vendor)
+
+ def on_register_events(self):
+ self.llm_choices.select(
+ self.on_llm_vendor_change,
+ inputs=[self.llm_choices],
+ outputs=[self.spec, self.spec_desc],
+ )
+ self.btn_new.click(
+ self.create_llm,
+ inputs=[self.name, self.llm_choices, self.spec, self.default],
+ outputs=None,
+ ).then(self.list_llms, inputs=None, outputs=[self.llm_list],).then(
+ lambda: ("", None, "", False, self.spec_desc_default),
+ outputs=[
+ self.name,
+ self.llm_choices,
+ self.spec,
+ self.default,
+ self.spec_desc,
+ ],
+ )
+ self.llm_list.select(
+ self.select_llm,
+ inputs=self.llm_list,
+ outputs=[self.selected_llm_name],
+ show_progress="hidden",
+ )
+ self.selected_llm_name.change(
+ self.on_selected_llm_change,
+ inputs=[self.selected_llm_name],
+ outputs=[
+ self._selected_panel,
+ self._selected_panel_btn,
+ # delete section
+ self.btn_delete,
+ self.btn_delete_yes,
+ self.btn_delete_no,
+ # edit section
+ self.edit_spec,
+ self.edit_spec_desc,
+ self.edit_default,
+ ],
+ show_progress="hidden",
+ )
+ self.btn_delete.click(
+ self.on_btn_delete_click,
+ inputs=None,
+ outputs=[self.btn_delete, self.btn_delete_yes, self.btn_delete_no],
+ show_progress="hidden",
+ )
+ self.btn_delete_yes.click(
+ self.delete_llm,
+ inputs=[self.selected_llm_name],
+ outputs=[self.selected_llm_name],
+ show_progress="hidden",
+ ).then(
+ self.list_llms,
+ inputs=None,
+ outputs=[self.llm_list],
+ )
+ self.btn_delete_no.click(
+ lambda: (
+ gr.update(visible=True),
+ gr.update(visible=False),
+ gr.update(visible=False),
+ ),
+ inputs=None,
+ outputs=[self.btn_delete, self.btn_delete_yes, self.btn_delete_no],
+ show_progress="hidden",
+ )
+ self.btn_edit_save.click(
+ self.save_llm,
+ inputs=[
+ self.selected_llm_name,
+ self.edit_default,
+ self.edit_spec,
+ ],
+ show_progress="hidden",
+ ).then(
+ self.list_llms,
+ inputs=None,
+ outputs=[self.llm_list],
+ )
+ self.btn_close.click(
+ lambda: "",
+ outputs=[self.selected_llm_name],
+ )
+
+ def create_llm(self, name, choices, spec, default):
+ try:
+ spec = yaml.safe_load(spec)
+ spec["__type__"] = (
+ llms.vendors()[choices].__module__
+ + "."
+ + llms.vendors()[choices].__qualname__
+ )
+
+ llms.add(name, spec=spec, default=default)
+ gr.Info(f"LLM {name} created successfully")
+ except Exception as e:
+ gr.Error(f"Failed to create LLM {name}: {e}")
+
+ def list_llms(self):
+ """List the LLMs"""
+ items = []
+ for item in llms.info().values():
+ record = {}
+ record["name"] = item["name"]
+ record["vendor"] = item["spec"].get("__type__", "-").split(".")[-1]
+ record["default"] = item["default"]
+ items.append(record)
+
+ if items:
+ llm_list = pd.DataFrame.from_records(items)
+ else:
+ llm_list = pd.DataFrame.from_records(
+ [{"name": "-", "vendor": "-", "default": "-"}]
+ )
+
+ return llm_list
+
+ def select_llm(self, llm_list, ev: gr.SelectData):
+ if ev.value == "-" and ev.index[0] == 0:
+ gr.Info("No LLM is loaded. Please add LLM first")
+ return ""
+
+ if not ev.selected:
+ return ""
+
+ return llm_list["name"][ev.index[0]]
+
+ def on_selected_llm_change(self, selected_llm_name):
+ if selected_llm_name == "":
+ _selected_panel = gr.update(visible=False)
+ _selected_panel_btn = gr.update(visible=False)
+ btn_delete = gr.update(visible=True)
+ btn_delete_yes = gr.update(visible=False)
+ btn_delete_no = gr.update(visible=False)
+ edit_spec = gr.update(value="")
+ edit_spec_desc = gr.update(value="")
+ edit_default = gr.update(value=False)
+ else:
+ _selected_panel = gr.update(visible=True)
+ _selected_panel_btn = gr.update(visible=True)
+ btn_delete = gr.update(visible=True)
+ btn_delete_yes = gr.update(visible=False)
+ btn_delete_no = gr.update(visible=False)
+
+ info = deepcopy(llms.info()[selected_llm_name])
+ vendor_str = info["spec"].pop("__type__", "-").split(".")[-1]
+ vendor = llms.vendors()[vendor_str]
+
+ edit_spec = yaml.dump(info["spec"])
+ edit_spec_desc = format_description(vendor)
+ edit_default = info["default"]
+
+ return (
+ _selected_panel,
+ _selected_panel_btn,
+ btn_delete,
+ btn_delete_yes,
+ btn_delete_no,
+ edit_spec,
+ edit_spec_desc,
+ edit_default,
+ )
+
+ def on_btn_delete_click(self):
+ btn_delete = gr.update(visible=False)
+ btn_delete_yes = gr.update(visible=True)
+ btn_delete_no = gr.update(visible=True)
+
+ return btn_delete, btn_delete_yes, btn_delete_no
+
+ def save_llm(self, selected_llm_name, default, spec):
+ try:
+ spec = yaml.safe_load(spec)
+ spec["__type__"] = llms.info()[selected_llm_name]["spec"]["__type__"]
+ llms.update(selected_llm_name, spec=spec, default=default)
+ gr.Info(f"LLM {selected_llm_name} saved successfully")
+ except Exception as e:
+ gr.Error(f"Failed to save LLM {selected_llm_name}: {e}")
+
+ def delete_llm(self, selected_llm_name):
+ try:
+ llms.delete(selected_llm_name)
+ except Exception as e:
+ gr.Error(f"Failed to delete LLM {selected_llm_name}: {e}")
+ return selected_llm_name
+
+ return ""
diff --git a/libs/ktem/ktem/pages/admin/__init__.py b/libs/ktem/ktem/pages/admin/__init__.py
index 1cc58c7..b32d816 100644
--- a/libs/ktem/ktem/pages/admin/__init__.py
+++ b/libs/ktem/ktem/pages/admin/__init__.py
@@ -1,6 +1,7 @@
import gradio as gr
from ktem.app import BasePage
from ktem.db.models import User, engine
+from ktem.llms.ui import LLMManagement
from sqlmodel import Session, select
from .user import UserManagement
@@ -16,6 +17,9 @@ class AdminPage(BasePage):
with gr.Tab("User Management", visible=False) as self.user_management_tab:
self.user_management = UserManagement(self._app)
+ with gr.Tab("LLM Management") as self.llm_management_tab:
+ self.llm_management = LLMManagement(self._app)
+
def on_subscribe_public_events(self):
if self._app.f_user_management:
self._app.subscribe_event(
diff --git a/libs/ktem/ktem/pages/chat/__init__.py b/libs/ktem/ktem/pages/chat/__init__.py
index 04c78eb..3f01571 100644
--- a/libs/ktem/ktem/pages/chat/__init__.py
+++ b/libs/ktem/ktem/pages/chat/__init__.py
@@ -9,6 +9,8 @@ from ktem.db.models import Conversation, engine
from sqlmodel import Session, select
from theflow.settings import settings as flowsettings
+from kotaemon.base import Document
+
from .chat_panel import ChatPanel
from .chat_suggestion import ChatSuggestion
from .common import STATE
@@ -189,6 +191,7 @@ class ChatPage(BasePage):
self.chat_control.conversation_rn,
self.chat_panel.chatbot,
self.info_panel,
+ self.chat_state,
]
+ self._indices_input,
show_progress="hidden",
@@ -220,6 +223,7 @@ class ChatPage(BasePage):
self.chat_control.conversation_rn,
self.chat_panel.chatbot,
self.info_panel,
+ self.chat_state,
]
+ self._indices_input,
show_progress="hidden",
@@ -392,7 +396,7 @@ class ChatPage(BasePage):
return pipeline, reasoning_state
- async def chat_fn(self, conversation_id, chat_history, settings, state, *selecteds):
+ def chat_fn(self, conversation_id, chat_history, settings, state, *selecteds):
"""Chat function"""
chat_input = chat_history[-1][0]
chat_history = chat_history[:-1]
@@ -403,52 +407,43 @@ class ChatPage(BasePage):
pipeline, reasoning_state = self.create_pipeline(settings, state, *selecteds)
pipeline.set_output_queue(queue)
- asyncio.create_task(pipeline(chat_input, conversation_id, chat_history))
text, refs = "", ""
-
- len_ref = -1 # for logging purpose
msg_placeholder = getattr(
flowsettings, "KH_CHAT_MSG_PLACEHOLDER", "Thinking ..."
)
-
print(msg_placeholder)
- while True:
- try:
- response = queue.get_nowait()
- except Exception:
- state[pipeline.get_info()["id"]] = reasoning_state["pipeline"]
- yield chat_history + [
- (chat_input, text or msg_placeholder)
- ], refs, state
+ yield chat_history + [(chat_input, text or msg_placeholder)], refs, state
+
+ len_ref = -1 # for logging purpose
+
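+        # Consume the pipeline synchronously; route each yielded Document by its
+        # channel ("chat" -> chatbot, "info" -> information panel).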
+ for response in pipeline.stream(chat_input, conversation_id, chat_history):
+
+ if not isinstance(response, Document):
continue
- if response is None:
- queue.task_done()
- print("Chat completed")
- break
+ if response.channel is None:
+ continue
- if "output" in response:
- if response["output"] is None:
+ if response.channel == "chat":
+ if response.content is None:
text = ""
else:
- text += response["output"]
+ text += response.content
- if "evidence" in response:
- if response["evidence"] is None:
+ if response.channel == "info":
+ if response.content is None:
refs = ""
else:
- refs += response["evidence"]
+ refs += response.content
if len(refs) > len_ref:
print(f"Len refs: {len(refs)}")
len_ref = len(refs)
- state[pipeline.get_info()["id"]] = reasoning_state["pipeline"]
- yield chat_history + [(chat_input, text)], refs, state
+ state[pipeline.get_info()["id"]] = reasoning_state["pipeline"]
+ yield chat_history + [(chat_input, text or msg_placeholder)], refs, state
- async def regen_fn(
- self, conversation_id, chat_history, settings, state, *selecteds
- ):
+ def regen_fn(self, conversation_id, chat_history, settings, state, *selecteds):
"""Regen function"""
if not chat_history:
gr.Warning("Empty chat")
@@ -456,12 +451,11 @@ class ChatPage(BasePage):
return
state["app"]["regen"] = True
- async for chat, refs, state in self.chat_fn(
+ for chat, refs, state in self.chat_fn(
conversation_id, chat_history, settings, state, *selecteds
):
new_state = deepcopy(state)
new_state["app"]["regen"] = False
yield chat, refs, new_state
- else:
- state["app"]["regen"] = False
- yield chat_history, "", state
+
+ state["app"]["regen"] = False
diff --git a/libs/ktem/ktem/reasoning/simple.py b/libs/ktem/ktem/reasoning/simple.py
index 082c20f..3397250 100644
--- a/libs/ktem/ktem/reasoning/simple.py
+++ b/libs/ktem/ktem/reasoning/simple.py
@@ -4,10 +4,10 @@ import logging
import re
from collections import defaultdict
from functools import partial
+from typing import Generator
import tiktoken
-from ktem.components import llms
-from theflow.settings import settings as flowsettings
+from ktem.llms.manager import llms
from kotaemon.base import (
BaseComponent,
@@ -190,10 +190,10 @@ class AnswerWithContextPipeline(BaseComponent):
lang: the language of the answer. Currently support English and Japanese
"""
- llm: ChatLLM = Node(default_callback=lambda _: llms.get_highest_accuracy())
- vlm_endpoint: str = flowsettings.KH_VLM_ENDPOINT
+ llm: ChatLLM = Node(default_callback=lambda _: llms.get_default())
+ vlm_endpoint: str = ""
citation_pipeline: CitationPipeline = Node(
- default_callback=lambda _: CitationPipeline(llm=llms.get_lowest_cost())
+ default_callback=lambda _: CitationPipeline(llm=llms.get_default())
)
qa_template: str = DEFAULT_QA_TEXT_PROMPT
@@ -297,13 +297,95 @@ class AnswerWithContextPipeline(BaseComponent):
return answer
+ def stream( # type: ignore
+ self, question: str, evidence: str, evidence_mode: int = 0, **kwargs
+ ) -> Generator[Document, None, Document]:
+ """Answer the question based on the evidence
-def extract_evidence_images(self, evidence: str):
- """Util function to extract and isolate images from context/evidence"""
- image_pattern = r"src='(data:image\/[^;]+;base64[^']+)'"
- matches = re.findall(image_pattern, evidence)
- context = re.sub(image_pattern, "", evidence)
- return context, matches
+        In addition to the question and the evidence, this method also takes into
+        account evidence_mode. The evidence_mode tells what kind of evidence is provided.
+ The kind of evidence affects:
+ 1. How the evidence is represented.
+ 2. The prompt to generate the answer.
+
+ By default, the evidence_mode is 0, which means the evidence is plain text with
+ no particular semantic representation. The evidence_mode can be:
+ 1. "table": There will be HTML markup telling that there is a table
+ within the evidence.
+ 2. "chatbot": There will be HTML markup telling that there is a chatbot.
+ This chatbot is a scenario, extracted from an Excel file, where each
+ row corresponds to an interaction.
+
+ Args:
+            question: the original question posed by the user
+            evidence: the text that contains relevant information to answer the question
+                (determined by the retrieval pipeline)
+ evidence_mode: the mode of evidence, 0 for text, 1 for table, 2 for chatbot
+ """
+ if evidence_mode == EVIDENCE_MODE_TEXT:
+ prompt_template = PromptTemplate(self.qa_template)
+ elif evidence_mode == EVIDENCE_MODE_TABLE:
+ prompt_template = PromptTemplate(self.qa_table_template)
+ elif evidence_mode == EVIDENCE_MODE_FIGURE:
+ prompt_template = PromptTemplate(self.qa_figure_template)
+ else:
+ prompt_template = PromptTemplate(self.qa_chatbot_template)
+
+ images = []
+ if evidence_mode == EVIDENCE_MODE_FIGURE:
+ # isolate image from evidence
+ evidence, images = self.extract_evidence_images(evidence)
+ prompt = prompt_template.populate(
+ context=evidence,
+ question=question,
+ lang=self.lang,
+ )
+ else:
+ prompt = prompt_template.populate(
+ context=evidence,
+ question=question,
+ lang=self.lang,
+ )
+
+ output = ""
+ if evidence_mode == EVIDENCE_MODE_FIGURE:
+ for text in stream_gpt4v(self.vlm_endpoint, images, prompt, max_tokens=768):
+ output += text
+ yield Document(channel="chat", content=text)
+ else:
+ messages = []
+ if self.system_prompt:
+ messages.append(SystemMessage(content=self.system_prompt))
+ messages.append(HumanMessage(content=prompt))
+
+ try:
+ # try streaming first
+ print("Trying LLM streaming")
+ for text in self.llm.stream(messages):
+ output += text.text
+ yield Document(channel="chat", content=text.text)
+ except NotImplementedError:
+ print("Streaming is not supported, falling back to normal processing")
+ output = self.llm(messages).text
+ yield Document(channel="chat", content=output)
+
+ # retrieve the citation
+ citation = None
+ if evidence and self.enable_citation:
+ citation = self.citation_pipeline.invoke(
+ context=evidence, question=question
+ )
+
+ answer = Document(text=output, metadata={"citation": citation})
+
+ return answer
+
+ def extract_evidence_images(self, evidence: str):
+ """Util function to extract and isolate images from context/evidence"""
+ image_pattern = r"src='(data:image\/[^;]+;base64[^']+)'"
+ matches = re.findall(image_pattern, evidence)
+ context = re.sub(image_pattern, "", evidence)
+ return context, matches
class RewriteQuestionPipeline(BaseComponent):
@@ -315,27 +397,19 @@ class RewriteQuestionPipeline(BaseComponent):
         lang: the language of the answer. Currently supports English and Japanese
"""
- llm: ChatLLM = Node(default_callback=lambda _: llms.get_lowest_cost())
+ llm: ChatLLM = Node(default_callback=lambda _: llms.get_default())
rewrite_template: str = DEFAULT_REWRITE_PROMPT
lang: str = "English"
- async def run(self, question: str) -> Document: # type: ignore
+ def run(self, question: str) -> Document: # type: ignore
prompt_template = PromptTemplate(self.rewrite_template)
prompt = prompt_template.populate(question=question, lang=self.lang)
messages = [
SystemMessage(content="You are a helpful assistant"),
HumanMessage(content=prompt),
]
- output = ""
- for text in self.llm(messages):
- if "content" in text:
- output += text[1]
- self.report_output({"chat_input": text[1]})
- break
- await asyncio.sleep(0)
-
- return Document(text=output)
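+        # single non-streaming LLM call; its output is returned directly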
+ return self.llm(messages)
class FullQAPipeline(BaseReasoning):
@@ -351,7 +425,7 @@ class FullQAPipeline(BaseReasoning):
rewrite_pipeline: RewriteQuestionPipeline = RewriteQuestionPipeline.withx()
use_rewrite: bool = False
- async def run( # type: ignore
+ async def ainvoke( # type: ignore
self, message: str, conv_id: str, history: list, **kwargs # type: ignore
) -> Document: # type: ignore
import markdown
@@ -482,6 +556,132 @@ class FullQAPipeline(BaseReasoning):
self.report_output(None)
return answer
+ def stream( # type: ignore
+ self, message: str, conv_id: str, history: list, **kwargs # type: ignore
+ ) -> Generator[Document, None, Document]:
+ import markdown
+
+ docs = []
+ doc_ids = []
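+        # when regenerating an answer, rewrite the question before retrieval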
+ if self.use_rewrite:
+ message = self.rewrite_pipeline(question=message).text
+
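+        # gather documents from every retriever, de-duplicated by doc_id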
+ for retriever in self.retrievers:
+ for doc in retriever(text=message):
+ if doc.doc_id not in doc_ids:
+ docs.append(doc)
+ doc_ids.append(doc.doc_id)
+ for doc in docs:
+ # TODO: a better approach to show the information
+ text = markdown.markdown(
+ doc.text, extensions=["markdown.extensions.tables"]
+ )
+            yield Document(
+                content=(
+                    "<details open>"
+                    f"<summary>{doc.metadata['file_name']}</summary>"
+                    f"{text}"
+                    "</details><br>"
+                ),
+                channel="info",
+            )
+
+ evidence_mode, evidence = self.evidence_pipeline(docs).content
+ answer = yield from self.answering_pipeline.stream(
+ question=message,
+ history=history,
+ evidence=evidence,
+ evidence_mode=evidence_mode,
+ conv_id=conv_id,
+ **kwargs,
+ )
+
+ # prepare citation
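+        # locate each cited quote inside the retrieved documents; quotes that cross
+        # table cells (i.e. contain "|") are recorded as one span per cell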
+ spans = defaultdict(list)
+ if answer.metadata["citation"] is not None:
+ for fact_with_evidence in answer.metadata["citation"].answer:
+ for quote in fact_with_evidence.substring_quote:
+ for doc in docs:
+ start_idx = doc.text.find(quote)
+ if start_idx == -1:
+ continue
+
+ end_idx = start_idx + len(quote)
+
+ current_idx = start_idx
+ if "|" not in doc.text[start_idx:end_idx]:
+ spans[doc.doc_id].append(
+ {"start": start_idx, "end": end_idx}
+ )
+ else:
+ while doc.text[current_idx:end_idx].find("|") != -1:
+ match_idx = doc.text[current_idx:end_idx].find("|")
+ spans[doc.doc_id].append(
+ {
+ "start": current_idx,
+ "end": current_idx + match_idx,
+ }
+ )
+ current_idx += match_idx + 2
+ if current_idx > end_idx:
+ break
+ break
+
+ id2docs = {doc.doc_id: doc for doc in docs}
+ lack_evidence = True
+ not_detected = set(id2docs.keys()) - set(spans.keys())
+ yield Document(channel="info", content=None)
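+        # re-render each document with its cited spans highlighted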
+ for id, ss in spans.items():
+ if not ss:
+ not_detected.add(id)
+ continue
+ ss = sorted(ss, key=lambda x: x["start"])
+ text = id2docs[id].text[: ss[0]["start"]]
+ for idx, span in enumerate(ss):
+                text += (
+                    "<mark>" + id2docs[id].text[span["start"] : span["end"]] + "</mark>"
+                )
+ if idx < len(ss) - 1:
+ text += id2docs[id].text[span["end"] : ss[idx + 1]["start"]]
+ text += id2docs[id].text[ss[-1]["end"] :]
+ text_out = markdown.markdown(
+ text, extensions=["markdown.extensions.tables"]
+ )
+            yield Document(
+                content=(
+                    "<details open>"
+                    f"<summary>{id2docs[id].metadata['file_name']}</summary>"
+                    f"{text_out}"
+                    "</details><br>"
+                ),
+                channel="info",
+            )
+ lack_evidence = False
+
+ if lack_evidence:
+ yield Document(channel="info", content="No evidence found.\n")
+
+ if not_detected:
+ yield Document(
+ channel="info",
+ content="Retrieved segments without matching evidence:\n",
+ )
+ for id in list(not_detected):
+ text_out = markdown.markdown(
+ id2docs[id].text, extensions=["markdown.extensions.tables"]
+ )
+                yield Document(
+                    content=(
+                        "<details>"
+                        f"<summary>{id2docs[id].metadata['file_name']}</summary>"
+                        f"{text_out}"
+                        "</details><br>"
+                    ),
+                    channel="info",
+                )
+
+ return answer
+
@classmethod
def get_pipeline(cls, settings, states, retrievers):
"""Get the reasoning pipeline
@@ -493,12 +693,9 @@ class FullQAPipeline(BaseReasoning):
_id = cls.get_info()["id"]
pipeline = FullQAPipeline(retrievers=retrievers)
- pipeline.answering_pipeline.llm = llms[
- settings[f"reasoning.options.{_id}.main_llm"]
- ]
- pipeline.answering_pipeline.citation_pipeline.llm = llms[
- settings[f"reasoning.options.{_id}.citation_llm"]
- ]
+ pipeline.answering_pipeline.llm = llms.get_default()
+ pipeline.answering_pipeline.citation_pipeline.llm = llms.get_default()
+
pipeline.answering_pipeline.enable_citation = settings[
f"reasoning.options.{_id}.highlight_citation"
]
@@ -512,7 +709,7 @@ class FullQAPipeline(BaseReasoning):
f"reasoning.options.{_id}.qa_prompt"
]
pipeline.use_rewrite = states.get("app", {}).get("regen", False)
- pipeline.rewrite_pipeline.llm = llms.get_lowest_cost()
+ pipeline.rewrite_pipeline.llm = llms.get_default()
pipeline.rewrite_pipeline.lang = {"en": "English", "ja": "Japanese"}.get(
settings["reasoning.lang"], "English"
)
@@ -520,38 +717,12 @@ class FullQAPipeline(BaseReasoning):
@classmethod
def get_user_settings(cls) -> dict:
- from ktem.components import llms
-
- try:
- citation_llm = llms.get_lowest_cost_name()
- citation_llm_choices = list(llms.options().keys())
- main_llm = llms.get_highest_accuracy_name()
- main_llm_choices = list(llms.options().keys())
- except Exception as e:
- logger.error(e)
- citation_llm = None
- citation_llm_choices = []
- main_llm = None
- main_llm_choices = []
-
return {
"highlight_citation": {
"name": "Highlight Citation",
"value": False,
"component": "checkbox",
},
- "citation_llm": {
- "name": "LLM for citation",
- "value": citation_llm,
- "component": "dropdown",
- "choices": citation_llm_choices,
- },
- "main_llm": {
- "name": "LLM for main generation",
- "value": main_llm,
- "component": "dropdown",
- "choices": main_llm_choices,
- },
"system_prompt": {
"name": "System Prompt",
"value": "This is a question answering system",
diff --git a/libs/ktem/ktem_tests/test_qa.py b/libs/ktem/ktem_tests/test_qa.py
index a3993ee..80ee68b 100644
--- a/libs/ktem/ktem_tests/test_qa.py
+++ b/libs/ktem/ktem_tests/test_qa.py
@@ -7,7 +7,7 @@ from index import ReaderIndexingPipeline
from openai.resources.embeddings import Embeddings
from openai.types.chat.chat_completion import ChatCompletion
-from kotaemon.llms import AzureChatOpenAI
+from kotaemon.llms import LCAzureChatOpenAI
with open(Path(__file__).parent / "resources" / "embedding_openai.json") as f:
openai_embedding = json.load(f)
@@ -61,7 +61,7 @@ def test_ingest_pipeline(patch, mock_openai_embedding, tmp_path):
assert len(results) == 1
# create llm
- llm = AzureChatOpenAI(
+ llm = LCAzureChatOpenAI(
openai_api_base="https://test.openai.azure.com/",
openai_api_key="some-key",
openai_api_version="2023-03-15-preview",
diff --git a/libs/ktem/launch.py b/libs/ktem/launch.py
index 2ac7a1a..1f436c5 100644
--- a/libs/ktem/launch.py
+++ b/libs/ktem/launch.py
@@ -2,4 +2,4 @@ from ktem.main import App
app = App()
demo = app.make()
-demo.queue().launch(favicon_path=app._favicon, inbrowser=True)
+demo.queue().launch(favicon_path=app._favicon)
diff --git a/templates/project-default/{{cookiecutter.project_name}}/{{cookiecutter.project_name}}/pipeline.py b/templates/project-default/{{cookiecutter.project_name}}/{{cookiecutter.project_name}}/pipeline.py
index 1739ca8..db2fa0b 100644
--- a/templates/project-default/{{cookiecutter.project_name}}/{{cookiecutter.project_name}}/pipeline.py
+++ b/templates/project-default/{{cookiecutter.project_name}}/{{cookiecutter.project_name}}/pipeline.py
@@ -5,7 +5,7 @@ from kotaemon.base import BaseComponent, Document, LLMInterface, Node, Param, la
from kotaemon.contribs.promptui.logs import ResultLog
from kotaemon.embeddings import AzureOpenAIEmbeddings
from kotaemon.indices import VectorIndexing, VectorRetrieval
-from kotaemon.llms import AzureChatOpenAI
+from kotaemon.llms import LCAzureChatOpenAI
from kotaemon.storages import ChromaVectorStore, SimpleFileDocumentStore
@@ -34,7 +34,7 @@ class QuestionAnsweringPipeline(BaseComponent):
]
retrieval_top_k: int = 1
- llm: AzureChatOpenAI = AzureChatOpenAI.withx(
+ llm: LCAzureChatOpenAI = LCAzureChatOpenAI.withx(
azure_endpoint="https://bleh-dummy-2.openai.azure.com/",
openai_api_key=os.environ.get("OPENAI_API_KEY", "default-key"),
openai_api_version="2023-03-15-preview",