Allow users to add LLM within the UI (#6)

* Rename AzureChatOpenAI to LCAzureChatOpenAI
* Provide vanilla ChatOpenAI and AzureChatOpenAI
* Remove the "highest accuracy" / "lowest cost" criteria: these criteria are unnecessary. The users, not the pipeline creators, should choose which LLM to use, and entering this information is cumbersome and degrades the user experience.
* Remove the LLM selection in the simple reasoning pipeline
* Provide a dedicated stream method to generate the output
* Return a placeholder message to the chat if the text is empty
Parent: e187e23dd1
Commit: a203fc0f7c
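For illustration only (not part of the diff): a minimal sketch of how the renamed and newly added classes are intended to be used after this change. The model name and API key below are placeholders.

    from kotaemon.llms import ChatOpenAI  # native client; LCChatOpenAI / LCAzureChatOpenAI keep the LangChain wrappers

    llm = ChatOpenAI(api_key="sk-...", model="gpt-3.5-turbo")  # placeholder credentials and model

    # dedicated streaming entry point introduced in this commit
    for chunk in llm.stream("Hello"):
        print(chunk.content, end="")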
.gitignore (vendored, +1)
@@ -458,6 +458,7 @@ logs/
.gitsecret/keys/random_seed
!*.secret
.envrc
+.env

S.gpg-agent*
.vscode/settings.json
@@ -22,7 +22,7 @@ The syntax of a component is as follow:

```python
from kotaemon.base import BaseComponent
-from kotaemon.llms import AzureChatOpenAI
+from kotaemon.llms import LCAzureChatOpenAI
from kotaemon.parsers import RegexExtractor

@@ -32,7 +32,7 @@ class FancyPipeline(BaseComponent):
    param3: float

    node1: BaseComponent  # this is a node because of BaseComponent type annotation
-    node2: AzureChatOpenAI  # this is also a node because AzureChatOpenAI subclasses BaseComponent
+    node2: LCAzureChatOpenAI  # this is also a node because LCAzureChatOpenAI subclasses BaseComponent
    node3: RegexExtractor  # this is also a node bceause RegexExtractor subclasses BaseComponent

    def run(self, some_text: str):
@@ -45,7 +45,7 @@ class FancyPipeline(BaseComponent):
Then this component can be used as follow:

```python
-llm = AzureChatOpenAI(endpoint="some-endpont")
+llm = LCAzureChatOpenAI(endpoint="some-endpont")
extractor = RegexExtractor(pattern=["yes", "Yes"])

component = FancyPipeline(
@@ -193,7 +193,8 @@ information panel.
You can access users' collections of LLMs and embedding models with:

```python
-from ktem.components import llms, embeddings
+from ktem.components import embeddings
+from ktem.llms.manager import llms


llm = llms.get_default()
@@ -206,12 +207,12 @@ models they want to use through the settings.
```python
@classmethod
def get_user_settings(cls) -> dict:
-    from ktem.components import llms
+    from ktem.llms.manager import llms

    return {
        "citation_llm": {
            "name": "LLM for citation",
-            "value": llms.get_lowest_cost_name(),
+            "value": llms.get_default(),
            "component: "dropdown",
            "choices": list(llms.options().keys()),
        },
@@ -52,7 +52,7 @@ class BaseComponent(Function):
    def stream(self, *args, **kwargs) -> Iterator[Document] | None:
        ...

-    async def astream(self, *args, **kwargs) -> AsyncGenerator[Document, None] | None:
+    def astream(self, *args, **kwargs) -> AsyncGenerator[Document, None] | None:
        ...

    @abstractmethod
@@ -1,6 +1,6 @@
from __future__ import annotations

-from typing import TYPE_CHECKING, Any, Optional, TypeVar
+from typing import TYPE_CHECKING, Any, Literal, Optional, TypeVar

from langchain.schema.messages import AIMessage as LCAIMessage
from langchain.schema.messages import HumanMessage as LCHumanMessage
@@ -10,6 +10,9 @@ from llama_index.schema import Document as BaseDocument

if TYPE_CHECKING:
    from haystack.schema import Document as HaystackDocument
+    from openai.types.chat.chat_completion_message_param import (
+        ChatCompletionMessageParam,
+    )

IO_Type = TypeVar("IO_Type", "Document", str)
SAMPLE_TEXT = "A sample Document from kotaemon"
@@ -26,10 +29,15 @@ class Document(BaseDocument):
    Attributes:
        content: raw content of the document, can be anything
        source: id of the source of the Document. Optional.
+        channel: the channel to show the document. Optional.:
+            - chat: show in chat message
+            - info: show in information panel
+            - debug: show in debug panel
    """

-    content: Any
+    content: Any = None
    source: Optional[str] = None
+    channel: Optional[Literal["chat", "info", "debug"]] = None

    def __init__(self, content: Optional[Any] = None, *args, **kwargs):
        if content is None:
@@ -87,17 +95,23 @@ class BaseMessage(Document):
    def __add__(self, other: Any):
        raise NotImplementedError

+    def to_openai_format(self) -> "ChatCompletionMessageParam":
+        raise NotImplementedError


class SystemMessage(BaseMessage, LCSystemMessage):
-    pass
+    def to_openai_format(self) -> "ChatCompletionMessageParam":
+        return {"role": "system", "content": self.content}


class AIMessage(BaseMessage, LCAIMessage):
-    pass
+    def to_openai_format(self) -> "ChatCompletionMessageParam":
+        return {"role": "assistant", "content": self.content}


class HumanMessage(BaseMessage, LCHumanMessage):
-    pass
+    def to_openai_format(self) -> "ChatCompletionMessageParam":
+        return {"role": "user", "content": self.content}


class RetrievedDocument(Document):
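For illustration only (not part of the diff), the new to_openai_format helpers turn kotaemon messages into plain OpenAI-style dicts:

    from kotaemon.base.schema import HumanMessage, SystemMessage

    msgs = [SystemMessage(content="You are helpful."), HumanMessage(content="Hi")]
    payload = [m.to_openai_format() for m in msgs]
    # -> [{"role": "system", "content": "You are helpful."}, {"role": "user", "content": "Hi"}]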
@@ -1,7 +1,7 @@
import os

from kotaemon.base import BaseComponent, Document, Node, RetrievedDocument
-from kotaemon.llms import AzureChatOpenAI, BaseLLM, PromptTemplate
+from kotaemon.llms import BaseLLM, LCAzureChatOpenAI, PromptTemplate

from .citation import CitationPipeline

@@ -13,7 +13,7 @@ class CitationQAPipeline(BaseComponent):
        'Answer the following question: "{question}". '
        "The context is: \n{context}\nAnswer: "
    )
-    llm: BaseLLM = AzureChatOpenAI.withx(
+    llm: BaseLLM = LCAzureChatOpenAI.withx(
        azure_endpoint="https://bleh-dummy.openai.azure.com/",
        openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
        openai_api_version="2023-07-01-preview",
@@ -2,7 +2,15 @@ from kotaemon.base.schema import AIMessage, BaseMessage, HumanMessage, SystemMes

from .base import BaseLLM
from .branching import GatedBranchingPipeline, SimpleBranchingPipeline
-from .chats import AzureChatOpenAI, ChatLLM, ChatOpenAI, EndpointChatLLM, LlamaCppChat
+from .chats import (
+    AzureChatOpenAI,
+    ChatLLM,
+    ChatOpenAI,
+    EndpointChatLLM,
+    LCAzureChatOpenAI,
+    LCChatOpenAI,
+    LlamaCppChat,
+)
from .completions import LLM, AzureOpenAI, LlamaCpp, OpenAI
from .cot import ManualSequentialChainOfThought, Thought
from .linear import GatedLinearPipeline, SimpleLinearPipeline
@@ -17,8 +25,10 @@ __all__ = [
    "HumanMessage",
    "AIMessage",
    "SystemMessage",
-    "ChatOpenAI",
    "AzureChatOpenAI",
+    "ChatOpenAI",
+    "LCAzureChatOpenAI",
+    "LCChatOpenAI",
    "LlamaCppChat",
    # completion-specific components
    "LLM",
@@ -18,5 +18,8 @@ class BaseLLM(BaseComponent):
    def stream(self, *args, **kwargs) -> Iterator[LLMInterface]:
        raise NotImplementedError

-    async def astream(self, *args, **kwargs) -> AsyncGenerator[LLMInterface, None]:
+    def astream(self, *args, **kwargs) -> AsyncGenerator[LLMInterface, None]:
        raise NotImplementedError
+
+    def run(self, *args, **kwargs):
+        return self.invoke(*args, **kwargs)
@@ -15,7 +15,7 @@ class SimpleBranchingPipeline(BaseComponent):
    Example:
        ```python
        from kotaemon.llms import (
-            AzureChatOpenAI,
+            LCAzureChatOpenAI,
            BasePromptComponent,
            GatedLinearPipeline,
        )
@@ -25,7 +25,7 @@ class SimpleBranchingPipeline(BaseComponent):
            return x

        pipeline = SimpleBranchingPipeline()
-        llm = AzureChatOpenAI(
+        llm = LCAzureChatOpenAI(
            openai_api_base="your openai api base",
            openai_api_key="your openai api key",
            openai_api_version="your openai api version",
@@ -92,7 +92,7 @@ class GatedBranchingPipeline(SimpleBranchingPipeline):
    Example:
        ```python
        from kotaemon.llms import (
-            AzureChatOpenAI,
+            LCAzureChatOpenAI,
            BasePromptComponent,
            GatedLinearPipeline,
        )
@@ -102,7 +102,7 @@ class GatedBranchingPipeline(SimpleBranchingPipeline):
            return x

        pipeline = GatedBranchingPipeline()
-        llm = AzureChatOpenAI(
+        llm = LCAzureChatOpenAI(
            openai_api_base="your openai api base",
            openai_api_key="your openai api key",
            openai_api_version="your openai api version",
@@ -157,7 +157,7 @@ class GatedBranchingPipeline(SimpleBranchingPipeline):
if __name__ == "__main__":
    import dotenv

-    from kotaemon.llms import AzureChatOpenAI, BasePromptComponent
+    from kotaemon.llms import BasePromptComponent, LCAzureChatOpenAI
    from kotaemon.parsers import RegexExtractor

    def identity(x):
@@ -166,7 +166,7 @@ if __name__ == "__main__":
    secrets = dotenv.dotenv_values(".env")

    pipeline = GatedBranchingPipeline()
-    llm = AzureChatOpenAI(
+    llm = LCAzureChatOpenAI(
        openai_api_base=secrets.get("OPENAI_API_BASE", ""),
        openai_api_key=secrets.get("OPENAI_API_KEY", ""),
        openai_api_version=secrets.get("OPENAI_API_VERSION", ""),
@@ -1,13 +1,17 @@
from .base import ChatLLM
from .endpoint_based import EndpointChatLLM
-from .langchain_based import AzureChatOpenAI, ChatOpenAI, LCChatMixin
+from .langchain_based import LCAzureChatOpenAI, LCChatMixin, LCChatOpenAI
from .llamacpp import LlamaCppChat
+from .openai import AzureChatOpenAI, ChatOpenAI

__all__ = [
+    "ChatOpenAI",
+    "AzureChatOpenAI",
    "ChatLLM",
    "EndpointChatLLM",
    "ChatOpenAI",
-    "AzureChatOpenAI",
+    "LCChatOpenAI",
+    "LCAzureChatOpenAI",
    "LCChatMixin",
    "LlamaCppChat",
]
@@ -5,6 +5,7 @@ from kotaemon.base import (
    BaseMessage,
    HumanMessage,
    LLMInterface,
+    Param,
    SystemMessage,
)

@@ -20,7 +21,9 @@ class EndpointChatLLM(ChatLLM):
        endpoint_url (str): The url of a OpenAI API compatible endpoint.
    """

-    endpoint_url: str
+    endpoint_url: str = Param(
+        help="URL of the OpenAI API compatible endpoint", required=True
+    )

    def run(
        self, messages: str | BaseMessage | list[BaseMessage], **kwargs
@@ -165,7 +165,7 @@ class LCChatMixin:
            raise ValueError(f"Invalid param {path}")


-class ChatOpenAI(LCChatMixin, ChatLLM):  # type: ignore
+class LCChatOpenAI(LCChatMixin, ChatLLM):  # type: ignore
    def __init__(
        self,
        openai_api_base: str | None = None,
@@ -193,7 +193,7 @@ class ChatOpenAI(LCChatMixin, ChatLLM):  # type: ignore
        return ChatOpenAI


-class AzureChatOpenAI(LCChatMixin, ChatLLM):  # type: ignore
+class LCAzureChatOpenAI(LCChatMixin, ChatLLM):  # type: ignore
    def __init__(
        self,
        azure_endpoint: str | None = None,
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Optional, cast
+from typing import TYPE_CHECKING, Iterator, Optional, cast

from kotaemon.base import BaseMessage, HumanMessage, LLMInterface, Param

@@ -12,13 +12,32 @@ if TYPE_CHECKING:
class LlamaCppChat(ChatLLM):
    """Wrapper around the llama-cpp-python's Llama model"""

-    model_path: Optional[str] = None
-    chat_format: Optional[str] = None
-    lora_base: Optional[str] = None
-    n_ctx: int = 512
-    n_gpu_layers: int = 0
-    use_mmap: bool = True
-    vocab_only: bool = False
+    model_path: str = Param(
+        help="Path to the model file. This is required to load the model.",
+        required=True,
+    )
+    chat_format: str = Param(
+        help=(
+            "Chat format to use. Please refer to llama_cpp.llama_chat_format for a "
+            "list of supported formats. If blank, the chat format will be auto-"
+            "inferred."
+        ),
+        required=True,
+    )
+    lora_base: Optional[str] = Param(None, help="Path to the base Lora model")
+    n_ctx: Optional[int] = Param(512, help="Text context, 0 = from model")
+    n_gpu_layers: Optional[int] = Param(
+        0,
+        help=("Number of layers to offload to GPU. If -1, all layers are offloaded"),
+    )
+    use_mmap: Optional[bool] = Param(
+        True,
+        help=(),
+    )
+    vocab_only: Optional[bool] = Param(
+        False,
+        help=("If True, only the vocabulary is loaded. This is useful for debugging."),
+    )

    _role_mapper: dict[str, str] = {
        "human": "user",
@@ -60,9 +79,9 @@ class LlamaCppChat(ChatLLM):
            vocab_only=self.vocab_only,
        )

-    def run(
-        self, messages: str | BaseMessage | list[BaseMessage], **kwargs
-    ) -> LLMInterface:
+    def prepare_message(
+        self, messages: str | BaseMessage | list[BaseMessage]
+    ) -> list[dict]:
        input_: list[BaseMessage] = []

        if isinstance(messages, str):
@@ -72,11 +91,19 @@ class LlamaCppChat(ChatLLM):
        else:
            input_ = messages

+        output_ = [
+            {"role": self._role_mapper[each.type], "content": each.content}
+            for each in input_
+        ]
+
+        return output_
+
+    def invoke(
+        self, messages: str | BaseMessage | list[BaseMessage], **kwargs
+    ) -> LLMInterface:
        pred: "CCCR" = self.client_object.create_chat_completion(
-            messages=[
-                {"role": self._role_mapper[each.type], "content": each.content}
-                for each in input_
-            ],  # type: ignore
+            messages=self.prepare_message(messages),
            stream=False,
        )
@@ -91,3 +118,19 @@ class LlamaCppChat(ChatLLM):
            total_tokens=pred["usage"]["total_tokens"],
            prompt_tokens=pred["usage"]["prompt_tokens"],
        )
+
+    def stream(
+        self, messages: str | BaseMessage | list[BaseMessage], **kwargs
+    ) -> Iterator[LLMInterface]:
+        pred = self.client_object.create_chat_completion(
+            messages=self.prepare_message(messages),
+            stream=True,
+        )
+        for chunk in pred:
+            if not chunk["choices"]:
+                continue
+
+            if "content" not in chunk["choices"][0]["delta"]:
+                continue
+
+            yield LLMInterface(content=chunk["choices"][0]["delta"]["content"])
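For illustration only (not part of the diff), a usage sketch of the reworked LlamaCppChat; the model path and chat format values are placeholders:

    from kotaemon.llms import LlamaCppChat

    llm = LlamaCppChat(model_path="/path/to/model.gguf", chat_format="llama-2")  # placeholders
    for chunk in llm.stream("Tell me a joke"):  # yields LLMInterface chunks as tokens arrive
        print(chunk.content, end="")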
libs/kotaemon/kotaemon/llms/chats/openai.py (new file, 356 lines)

from typing import TYPE_CHECKING, AsyncGenerator, Iterator, Optional

from theflow.utils.modules import import_dotted_string

from kotaemon.base import AIMessage, BaseMessage, HumanMessage, LLMInterface, Param

from .base import ChatLLM

if TYPE_CHECKING:
    from openai.types.chat.chat_completion_message_param import (
        ChatCompletionMessageParam,
    )


class BaseChatOpenAI(ChatLLM):
    """Base interface for OpenAI chat model, using the openai library

    This class exposes the parameters in resources.Chat. To subclass this class:

        - Implement the `prepare_client` method to return the OpenAI client
        - Implement the `openai_response` method to return the OpenAI response
        - Implement the params relate to the OpenAI client
    """

    _dependencies = ["openai"]
    _capabilities = ["chat", "text"]  # consider as mixin

    api_key: str = Param(help="API key", required=True)
    timeout: Optional[float] = Param(None, help="Timeout for the API request")
    max_retries: Optional[int] = Param(
        None, help="Maximum number of retries for the API request"
    )

    temperature: Optional[float] = Param(
        None,
        help=(
            "Number between 0 and 2 that controls the randomness of the generated "
            "tokens. Lower values make the model more deterministic, while higher "
            "values make the model more random."
        ),
    )
    max_tokens: Optional[int] = Param(
        None,
        help=(
            "Maximum number of tokens to generate. The total length of input tokens "
            "and generated tokens is limited by the model's context length."
        ),
    )
    n: int = Param(
        1,
        help=(
            "Number of completions to generate. The API will generate n completion "
            "for each prompt."
        ),
    )
    stop: Optional[str | list[str]] = Param(
        None,
        help=(
            "Stop sequence. If a stop sequence is detected, generation will stop "
            "at that point. If not specified, generation will continue until the "
            "maximum token length is reached."
        ),
    )
    frequency_penalty: Optional[float] = Param(
        None,
        help=(
            "Number between -2.0 and 2.0. Positive values penalize new tokens "
            "based on their existing frequency in the text so far, decrearsing the "
            "model's likelihood of repeating the same text."
        ),
    )
    presence_penalty: Optional[float] = Param(
        None,
        help=(
            "Number between -2.0 and 2.0. Positive values penalize new tokens "
            "based on their existing presence in the text so far, decrearsing the "
            "model's likelihood of repeating the same text."
        ),
    )
    tool_choice: Optional[str] = Param(
        None,
        help=(
            "Choice of tool to use for the completion. Available choices are: "
            "auto, default."
        ),
    )
    tools: Optional[list[str]] = Param(
        None,
        help="List of tools to use for the completion.",
    )
    logprobs: Optional[bool] = Param(
        None,
        help=(
            "Include log probabilities on the logprobs most likely tokens, "
            "as well as the chosen token."
        ),
    )
    logit_bias: Optional[dict] = Param(
        None,
        help=(
            "Dictionary of logit bias values to add to the logits of the tokens "
            "in the vocabulary."
        ),
    )
    top_logprobs: Optional[int] = Param(
        None,
        help=(
            "An integer between 0 and 5 specifying the number of most likely tokens "
            "to return at each token position, each with an associated log "
            "probability. `logprobs` must also be set to `true` if this parameter "
            "is used."
        ),
    )
    top_p: Optional[float] = Param(
        None,
        help=(
            "An alternative to sampling with temperature, called nucleus sampling, "
            "where the model considers the results of the token with top_p "
            "probability mass. So 0.1 means that only the tokens comprising the "
            "top 10% probability mass are considered."
        ),
    )

    @Param.auto(depends_on=["max_retries"])
    def max_retries_(self):
        if self.max_retries is None:
            from openai._constants import DEFAULT_MAX_RETRIES

            return DEFAULT_MAX_RETRIES
        return self.max_retries

    def prepare_message(
        self, messages: str | BaseMessage | list[BaseMessage]
    ) -> list["ChatCompletionMessageParam"]:
        """Prepare the message into OpenAI format

        Returns:
            list[dict]: List of messages in OpenAI format
        """
        input_: list[BaseMessage] = []
        output_: list["ChatCompletionMessageParam"] = []

        if isinstance(messages, str):
            input_ = [HumanMessage(content=messages)]
        elif isinstance(messages, BaseMessage):
            input_ = [messages]
        else:
            input_ = messages

        for message in input_:
            output_.append(message.to_openai_format())

        return output_

    def prepare_client(self, async_version: bool = False):
        """Get the OpenAI client

        Args:
            async_version (bool): Whether to get the async version of the client
        """
        raise NotImplementedError

    def openai_response(self, client, **kwargs):
        """Get the openai response"""
        raise NotImplementedError

    def invoke(
        self, messages: str | BaseMessage | list[BaseMessage], *args, **kwargs
    ) -> LLMInterface:
        client = self.prepare_client(async_version=False)
        input_messages = self.prepare_message(messages)
        resp = self.openai_response(
            client, messages=input_messages, stream=False, **kwargs
        ).dict()

        output = LLMInterface(
            candidates=[_["message"]["content"] for _ in resp["choices"]],
            content=resp["choices"][0]["message"]["content"],
            total_tokens=resp["usage"]["total_tokens"],
            prompt_tokens=resp["usage"]["prompt_tokens"],
            completion_tokens=resp["usage"]["completion_tokens"],
            messages=[
                AIMessage(content=_["message"]["content"]) for _ in resp["choices"]
            ],
        )

        return output

    async def ainvoke(
        self, messages: str | BaseMessage | list[BaseMessage], *args, **kwargs
    ) -> LLMInterface:
        client = self.prepare_client(async_version=True)
        input_messages = self.prepare_message(messages)
        resp = await self.openai_response(
            client, messages=input_messages, stream=False, **kwargs
        ).dict()

        output = LLMInterface(
            candidates=[_["message"]["content"] for _ in resp["choices"]],
            content=resp["choices"][0]["message"]["content"],
            total_tokens=resp["usage"]["total_tokens"],
            prompt_tokens=resp["usage"]["prompt_tokens"],
            completion_tokens=resp["usage"]["completion_tokens"],
            messages=[
                AIMessage(content=_["message"]["content"]) for _ in resp["choices"]
            ],
        )

        return output

    def stream(
        self, messages: str | BaseMessage | list[BaseMessage], *args, **kwargs
    ) -> Iterator[LLMInterface]:
        client = self.prepare_client(async_version=False)
        input_messages = self.prepare_message(messages)
        resp = self.openai_response(
            client, messages=input_messages, stream=True, **kwargs
        )

        for chunk in resp:
            if not chunk.choices:
                continue
            if chunk.choices[0].delta.content is not None:
                yield LLMInterface(content=chunk.choices[0].delta.content)

    async def astream(
        self, messages: str | BaseMessage | list[BaseMessage], *args, **kwargs
    ) -> AsyncGenerator[LLMInterface, None]:
        client = self.prepare_client(async_version=True)
        input_messages = self.prepare_message(messages)
        resp = self.openai_response(
            client, messages=input_messages, stream=True, **kwargs
        )

        async for chunk in resp:
            if not chunk.choices:
                continue
            if chunk.choices[0].delta.content is not None:
                yield LLMInterface(content=chunk.choices[0].delta.content)


class ChatOpenAI(BaseChatOpenAI):
    """OpenAI chat model"""

    base_url: Optional[str] = Param(None, help="OpenAI base URL")
    organization: Optional[str] = Param(None, help="OpenAI organization")
    model: str = Param(help="OpenAI model", required=True)

    def prepare_client(self, async_version: bool = False):
        """Get the OpenAI client

        Args:
            async_version (bool): Whether to get the async version of the client
        """
        params = {
            "api_key": self.api_key,
            "organization": self.organization,
            "base_url": self.base_url,
            "timeout": self.timeout,
            "max_retries": self.max_retries_,
        }
        if async_version:
            from openai import AsyncOpenAI

            return AsyncOpenAI(**params)

        from openai import OpenAI

        return OpenAI(**params)

    def openai_response(self, client, **kwargs):
        """Get the openai response"""
        params = {
            "model": self.model,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "n": self.n,
            "stop": self.stop,
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
            "tool_choice": self.tool_choice,
            "tools": self.tools,
            "logprobs": self.logprobs,
            "logit_bias": self.logit_bias,
            "top_logprobs": self.top_logprobs,
            "top_p": self.top_p,
        }
        params.update(kwargs)

        return client.chat.completions.create(**params)


class AzureChatOpenAI(BaseChatOpenAI):
    """OpenAI chat model provided by Microsoft Azure"""

    azure_endpoint: str = Param(
        help=(
            "HTTPS endpoint for the Azure OpenAI model. The azure_endpoint, "
            "azure_deployment, and api_version parameters are used to construct "
            "the full URL for the Azure OpenAI model."
        )
    )
    azure_deployment: str = Param(help="Azure deployment name", required=True)
    api_version: str = Param(help="Azure model version", required=True)
    azure_ad_token: Optional[str] = Param(None, help="Azure AD token")
    azure_ad_token_provider: Optional[str] = Param(None, help="Azure AD token provider")

    @Param.auto(depends_on=["azure_ad_token_provider"])
    def azure_ad_token_provider_(self):
        if isinstance(self.azure_ad_token_provider, str):
            return import_dotted_string(self.azure_ad_token_provider, safe=False)

    def prepare_client(self, async_version: bool = False):
        """Get the OpenAI client

        Args:
            async_version (bool): Whether to get the async version of the client
        """
        params = {
            "azure_endpoint": self.azure_endpoint,
            "api_version": self.api_version,
            "api_key": self.api_key,
            "azure_ad_token": self.azure_ad_token,
            "azure_ad_token_provider": self.azure_ad_token_provider_,
            "timeout": self.timeout,
            "max_retries": self.max_retries_,
        }
        if async_version:
            from openai import AsyncAzureOpenAI

            return AsyncAzureOpenAI(**params)

        from openai import AzureOpenAI

        return AzureOpenAI(**params)

    def openai_response(self, client, **kwargs):
        """Get the openai response"""
        params = {
            "model": self.azure_deployment,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "n": self.n,
            "stop": self.stop,
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
            "tool_choice": self.tool_choice,
            "tools": self.tools,
            "logprobs": self.logprobs,
            "logit_bias": self.logit_bias,
            "top_logprobs": self.top_logprobs,
            "top_p": self.top_p,
        }
        params.update(kwargs)

        return client.chat.completions.create(**params)
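For illustration only (not part of the diff), a configuration sketch for the new native Azure wrapper defined above; endpoint, deployment and key values are placeholders:

    from kotaemon.llms import AzureChatOpenAI

    llm = AzureChatOpenAI(
        azure_endpoint="https://<your-resource>.openai.azure.com/",  # placeholder
        azure_deployment="gpt-35-turbo",                             # placeholder deployment name
        api_version="2024-02-15-preview",
        api_key="...",
    )
    print(llm.invoke("Summarize this sentence.").content)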
@@ -5,7 +5,7 @@ from theflow import Function, Node, Param

from kotaemon.base import BaseComponent, Document

-from .chats import AzureChatOpenAI
+from .chats import LCAzureChatOpenAI
from .completions import LLM
from .prompts import BasePromptComponent

@@ -25,7 +25,7 @@ class Thought(BaseComponent):
        >> from kotaemon.pipelines.cot import Thought
        >> thought = Thought(
             prompt="How to {action} {object}?",
-             llm=AzureChatOpenAI(...),
+             llm=LCAzureChatOpenAI(...),
             post_process=lambda string: {"tutorial": string},
         )
        >> output = thought(action="install", object="python")
@@ -42,7 +42,7 @@ class Thought(BaseComponent):
    This `Thought` allows chaining sequentially with the + operator. For example:

    ```python
-    >> llm = AzureChatOpenAI(...)
+    >> llm = LCAzureChatOpenAI(...)
    >> thought1 = Thought(
           prompt="Word {word} in {language} is ",
           llm=llm,
@@ -73,7 +73,7 @@ class Thought(BaseComponent):
            " component is executed"
        )
    )
-    llm: LLM = Node(AzureChatOpenAI, help="The LLM model to execute the input prompt")
+    llm: LLM = Node(LCAzureChatOpenAI, help="The LLM model to execute the input prompt")
    post_process: Function = Node(
        help=(
            "The function post-processor that post-processes LLM output prediction ."
@@ -117,7 +117,7 @@ class ManualSequentialChainOfThought(BaseComponent):

    ```pycon
    >>> from kotaemon.pipelines.cot import Thought, ManualSequentialChainOfThought
-    >>> llm = AzureChatOpenAI(...)
+    >>> llm = LCAzureChatOpenAI(...)
    >>> thought1 = Thought(
    >>>    prompt="Word {word} in {language} is ",
    >>>    post_process=lambda string: {"translated": string},
@@ -22,12 +22,12 @@ class SimpleLinearPipeline(BaseComponent):

    Example Usage:
        ```python
-        from kotaemon.llms import AzureChatOpenAI, BasePromptComponent
+        from kotaemon.llms import LCAzureChatOpenAI, BasePromptComponent

        def identity(x):
            return x

-        llm = AzureChatOpenAI(
+        llm = LCAzureChatOpenAI(
            openai_api_base="your openai api base",
            openai_api_key="your openai api key",
            openai_api_version="your openai api version",
@@ -89,13 +89,13 @@ class GatedLinearPipeline(SimpleLinearPipeline):

    Usage:
        ```{.py3 title="Example Usage"}
-        from kotaemon.llms import AzureChatOpenAI, BasePromptComponent
+        from kotaemon.llms import LCAzureChatOpenAI, BasePromptComponent
        from kotaemon.parsers import RegexExtractor

        def identity(x):
            return x

-        llm = AzureChatOpenAI(
+        llm = LCAzureChatOpenAI(
            openai_api_base="your openai api base",
            openai_api_key="your openai api key",
            openai_api_version="your openai api version",
@@ -11,7 +11,7 @@ packages.find.exclude = ["tests*", "env*"]
# metadata and dependencies
[project]
name = "kotaemon"
-version = "0.3.8"
+version = "0.3.9"
requires-python = ">= 3.10"
description = "Kotaemon core library for AI development."
dependencies = [
@@ -13,7 +13,7 @@ from kotaemon.agents import (
    RewooAgent,
    WikipediaTool,
)
-from kotaemon.llms import AzureChatOpenAI
+from kotaemon.llms import LCAzureChatOpenAI

FINAL_RESPONSE_TEXT = "Final Answer: Hello Cinnamon AI!"
REWOO_VALID_PLAN = (
@@ -112,7 +112,7 @@ _openai_chat_completion_responses_react_langchain_tool = [

@pytest.fixture
def llm():
-    return AzureChatOpenAI(
+    return LCAzureChatOpenAI(
        azure_endpoint="https://dummy.openai.azure.com/",
        openai_api_key="dummy",
        openai_api_version="2023-03-15-preview",
@@ -4,10 +4,10 @@ import pytest
from openai.types.chat.chat_completion import ChatCompletion

from kotaemon.llms import (
-    AzureChatOpenAI,
    BasePromptComponent,
    GatedBranchingPipeline,
    GatedLinearPipeline,
+    LCAzureChatOpenAI,
    SimpleBranchingPipeline,
    SimpleLinearPipeline,
)
@@ -40,7 +40,7 @@ _openai_chat_completion_response = ChatCompletion.parse_obj(

@pytest.fixture
def mock_llm():
-    return AzureChatOpenAI(
+    return LCAzureChatOpenAI(
        azure_endpoint="OPENAI_API_BASE",
        openai_api_key="OPENAI_API_KEY",
        openai_api_version="OPENAI_API_VERSION",
@@ -2,7 +2,7 @@ from unittest.mock import patch

from openai.types.chat.chat_completion import ChatCompletion

-from kotaemon.llms import AzureChatOpenAI
+from kotaemon.llms import LCAzureChatOpenAI
from kotaemon.llms.cot import ManualSequentialChainOfThought, Thought

_openai_chat_completion_response = [
@@ -38,7 +38,7 @@ _openai_chat_completion_response = [
    side_effect=_openai_chat_completion_response,
)
def test_cot_plus_operator(openai_completion):
-    llm = AzureChatOpenAI(
+    llm = LCAzureChatOpenAI(
        azure_endpoint="https://dummy.openai.azure.com/",
        openai_api_key="dummy",
        openai_api_version="2023-03-15-preview",
@@ -70,7 +70,7 @@ def test_cot_plus_operator(openai_completion):
    side_effect=_openai_chat_completion_response,
)
def test_cot_manual(openai_completion):
-    llm = AzureChatOpenAI(
+    llm = LCAzureChatOpenAI(
        azure_endpoint="https://dummy.openai.azure.com/",
        openai_api_key="dummy",
        openai_api_version="2023-03-15-preview",
@@ -100,7 +100,7 @@ def test_cot_manual(openai_completion):
    side_effect=_openai_chat_completion_response,
)
def test_cot_with_termination_callback(openai_completion):
-    llm = AzureChatOpenAI(
+    llm = LCAzureChatOpenAI(
        azure_endpoint="https://dummy.openai.azure.com/",
        openai_api_key="dummy",
        openai_api_version="2023-03-15-preview",
@@ -4,7 +4,7 @@ from unittest.mock import patch
import pytest

from kotaemon.base.schema import AIMessage, HumanMessage, LLMInterface, SystemMessage
-from kotaemon.llms import AzureChatOpenAI, LlamaCppChat
+from kotaemon.llms import LCAzureChatOpenAI, LlamaCppChat

try:
    from langchain_openai import AzureChatOpenAI as AzureChatOpenAILC
@@ -43,7 +43,7 @@ _openai_chat_completion_response = ChatCompletion.parse_obj(
    side_effect=lambda *args, **kwargs: _openai_chat_completion_response,
)
def test_azureopenai_model(openai_completion):
-    model = AzureChatOpenAI(
+    model = LCAzureChatOpenAI(
        azure_endpoint="https://test.openai.azure.com/",
        openai_api_key="some-key",
        openai_api_version="2023-03-15-preview",
@@ -5,7 +5,7 @@ from openai.types.chat.chat_completion import ChatCompletion

from kotaemon.base import Document
from kotaemon.indices.rankings import LLMReranking
-from kotaemon.llms import AzureChatOpenAI
+from kotaemon.llms import LCAzureChatOpenAI

_openai_chat_completion_responses = [
    ChatCompletion.parse_obj(
@@ -41,7 +41,7 @@ _openai_chat_completion_responses = [

@pytest.fixture
def llm():
-    return AzureChatOpenAI(
+    return LCAzureChatOpenAI(
        azure_endpoint="https://dummy.openai.azure.com/",
        openai_api_key="dummy",
        openai_api_version="2023-03-15-preview",
@@ -40,16 +40,15 @@ if config("AZURE_OPENAI_API_KEY", default="") and config(
):
    if config("AZURE_OPENAI_CHAT_DEPLOYMENT", default=""):
        KH_LLMS["azure"] = {
-            "def": {
+            "spec": {
                "__type__": "kotaemon.llms.AzureChatOpenAI",
                "temperature": 0,
                "azure_endpoint": config("AZURE_OPENAI_ENDPOINT", default=""),
-                "openai_api_key": config("AZURE_OPENAI_API_KEY", default=""),
+                "api_key": config("AZURE_OPENAI_API_KEY", default=""),
                "api_version": config("OPENAI_API_VERSION", default="")
                or "2024-02-15-preview",
-                "deployment_name": config("AZURE_OPENAI_CHAT_DEPLOYMENT", default=""),
-                "request_timeout": 10,
-                "stream": False,
+                "azure_deployment": config("AZURE_OPENAI_CHAT_DEPLOYMENT", default=""),
+                "timeout": 20,
            },
            "default": False,
            "accuracy": 5,
@@ -57,7 +56,7 @@ if config("AZURE_OPENAI_API_KEY", default="") and config(
        }
    if config("AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT", default=""):
        KH_EMBEDDINGS["azure"] = {
-            "def": {
+            "spec": {
                "__type__": "kotaemon.embeddings.AzureOpenAIEmbeddings",
                "azure_endpoint": config("AZURE_OPENAI_ENDPOINT", default=""),
                "openai_api_key": config("AZURE_OPENAI_API_KEY", default=""),
@@ -164,5 +163,11 @@ KH_INDICES = [
        "name": "File",
        "config": {},
        "index_type": "ktem.index.file.FileIndex",
-    }
+    },
+    {
+        "id": 2,
+        "name": "Sample",
+        "config": {},
+        "index_type": "ktem.index.file.FileIndex",
+    },
]
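For illustration only (not part of the diff), a KH_LLMS entry for the new native ChatOpenAI would follow the same "spec" layout; the entry name, environment variable and model below are placeholders:

    KH_LLMS["openai"] = {
        "spec": {
            "__type__": "kotaemon.llms.ChatOpenAI",
            "temperature": 0,
            "api_key": config("OPENAI_API_KEY", default=""),  # placeholder env var
            "model": "gpt-3.5-turbo",                         # placeholder model name
        },
        "default": False,
    }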
@@ -3,6 +3,7 @@
import logging
from functools import cache
from pathlib import Path
+from typing import Optional

from theflow.settings import settings
from theflow.utils.modules import deserialize
@@ -48,7 +49,7 @@ class ModelPool:
        self._default: list[str] = []

        for name, model in conf.items():
-            self._models[name] = deserialize(model["def"], safe=False)
+            self._models[name] = deserialize(model["spec"], safe=False)
            if model.get("default", False):
                self._default.append(name)

@@ -58,11 +59,27 @@ class ModelPool:
        self._cost = list(sorted(conf, key=lambda x: conf[x].get("cost", float("inf"))))

    def __getitem__(self, key: str) -> BaseComponent:
+        """Get model by name"""
        return self._models[key]

    def __setitem__(self, key: str, value: BaseComponent):
+        """Set model by name"""
        self._models[key] = value

+    def __delitem__(self, key: str):
+        """Delete model by name"""
+        del self._models[key]
+
+    def __contains__(self, key: str) -> bool:
+        """Check if model exists"""
+        return key in self._models
+
+    def get(
+        self, key: str, default: Optional[BaseComponent] = None
+    ) -> Optional[BaseComponent]:
+        """Get model by name with default value"""
+        return self._models.get(key, default)
+
    def settings(self) -> dict:
        """Present model pools option for gradio"""
        return {
@@ -169,4 +186,3 @@ llms = ModelPool("LLMs", settings.KH_LLMS)
embeddings = ModelPool("Embeddings", settings.KH_EMBEDDINGS)
reasonings: dict = {}
tools = ModelPool("Tools", {})
-indices = ModelPool("Indices", {})
@@ -157,10 +157,10 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):

    @classmethod
    def get_user_settings(cls) -> dict:
-        from ktem.components import llms
+        from ktem.llms.manager import llms

        try:
-            reranking_llm = llms.get_lowest_cost_name()
+            reranking_llm = llms.get_default_name()
            reranking_llm_choices = list(llms.options().keys())
        except Exception as e:
            logger.error(e)
libs/ktem/ktem/llms/__init__.py (new file, empty)

libs/ktem/ktem/llms/db.py (new file, 36 lines)

from typing import Type

from ktem.db.engine import engine
from sqlalchemy import JSON, Boolean, Column, String
from sqlalchemy.orm import DeclarativeBase
from theflow.settings import settings as flowsettings
from theflow.utils.modules import import_dotted_string


class Base(DeclarativeBase):
    pass


class BaseLLMTable(Base):
    """Base table to store language model"""

    __abstract__ = True

    name = Column(String, primary_key=True, unique=True)
    spec = Column(JSON, default={})
    default = Column(Boolean, default=False)


_base_llm: Type[BaseLLMTable] = (
    import_dotted_string(flowsettings.KH_TABLE_LLM, safe=False)
    if hasattr(flowsettings, "KH_TABLE_LLM")
    else BaseLLMTable
)


class LLMTable(_base_llm):  # type: ignore
    __tablename__ = "llm_table"


if not getattr(flowsettings, "KH_ENABLE_ALEMBIC", False):
    LLMTable.metadata.create_all(engine)
191
libs/ktem/ktem/llms/manager.py
Normal file
191
libs/ktem/ktem/llms/manager.py
Normal file
|
@ -0,0 +1,191 @@
|
||||||
|
from typing import Optional, Type
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from theflow.settings import settings as flowsettings
|
||||||
|
from theflow.utils.modules import deserialize
|
||||||
|
|
||||||
|
from kotaemon.base import BaseComponent
|
||||||
|
|
||||||
|
from .db import LLMTable, engine
|
||||||
|
|
||||||
|
|
||||||
|
class LLMManager:
|
||||||
|
"""Represent a pool of models"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._models: dict[str, BaseComponent] = {}
|
||||||
|
self._info: dict[str, dict] = {}
|
||||||
|
self._default: str = ""
|
||||||
|
self._vendors: list[Type] = []
|
||||||
|
|
||||||
|
if hasattr(flowsettings, "KH_LLMS"):
|
||||||
|
for name, model in flowsettings.KH_LLMS.items():
|
||||||
|
with Session(engine) as session:
|
||||||
|
stmt = select(LLMTable).where(LLMTable.name == name)
|
||||||
|
result = session.execute(stmt)
|
||||||
|
if not result.first():
|
||||||
|
item = LLMTable(
|
||||||
|
name=name,
|
||||||
|
spec=model["spec"],
|
||||||
|
default=model.get("default", False),
|
||||||
|
)
|
||||||
|
session.add(item)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
self.load()
|
||||||
|
self.load_vendors()
|
||||||
|
|
||||||
|
def load(self):
|
||||||
|
"""Load the model pool from database"""
|
||||||
|
self._models, self._info, self._defaut = {}, {}, ""
        with Session(engine) as session:
            stmt = select(LLMTable)
            items = session.execute(stmt)

            for (item,) in items:
                self._models[item.name] = deserialize(item.spec, safe=False)
                self._info[item.name] = {
                    "name": item.name,
                    "spec": item.spec,
                    "default": item.default,
                }
                if item.default:
                    self._default = item.name

    def load_vendors(self):
        from kotaemon.llms import (
            AzureChatOpenAI,
            ChatOpenAI,
            EndpointChatLLM,
            LlamaCppChat,
        )

        self._vendors = [ChatOpenAI, AzureChatOpenAI, LlamaCppChat, EndpointChatLLM]

    def __getitem__(self, key: str) -> BaseComponent:
        """Get model by name"""
        return self._models[key]

    def __contains__(self, key: str) -> bool:
        """Check if model exists"""
        return key in self._models

    def get(
        self, key: str, default: Optional[BaseComponent] = None
    ) -> Optional[BaseComponent]:
        """Get model by name with default value"""
        return self._models.get(key, default)

    def settings(self) -> dict:
        """Present model pools option for gradio"""
        return {
            "label": "LLM",
            "choices": list(self._models.keys()),
            "value": self.get_default_name(),
        }

    def options(self) -> dict:
        """Present a dict of models"""
        return self._models

    def get_random_name(self) -> str:
        """Get the name of random model

        Returns:
            str: random model name in the pool
        """
        import random

        if not self._models:
            raise ValueError("No models in pool")

        return random.choice(list(self._models.keys()))

    def get_default_name(self) -> str:
        """Get the name of default model

        In case there is no default model, choose random model from pool. In
        case there are multiple default models, choose random from them.

        Returns:
            str: model name
        """
        if not self._models:
            raise ValueError("No models in pool")

        if not self._default:
            return self.get_random_name()

        return self._default

    def get_random(self) -> BaseComponent:
        """Get random model"""
        return self._models[self.get_random_name()]

    def get_default(self) -> BaseComponent:
        """Get default model

        In case there is no default model, choose random model from pool. In
        case there are multiple default models, choose random from them.

        Returns:
            BaseComponent: model
        """
        return self._models[self.get_default_name()]

    def info(self) -> dict:
        """List all models"""
        return self._info

    def add(self, name: str, spec: dict, default: bool):
        """Add a new model to the pool"""
        try:
            with Session(engine) as session:
                item = LLMTable(name=name, spec=spec, default=default)
                session.add(item)
                session.commit()
        except Exception as e:
            raise ValueError(f"Failed to add model {name}: {e}")

        self.load()

    def delete(self, name: str):
        """Delete a model from the pool"""
        try:
            with Session(engine) as session:
                item = session.query(LLMTable).filter_by(name=name).first()
                session.delete(item)
                session.commit()
        except Exception as e:
            raise ValueError(f"Failed to delete model {name}: {e}")

        self.load()

    def update(self, name: str, spec: dict, default: bool):
        """Update a model in the pool"""
        try:
            with Session(engine) as session:
                if default:
                    # turn all models to non-default
                    session.query(LLMTable).update({"default": False})
                    session.commit()

                item = session.query(LLMTable).filter_by(name=name).first()
                if not item:
                    raise ValueError(f"Model {name} not found")
                item.spec = spec
                item.default = default
                session.commit()
        except Exception as e:
            raise ValueError(f"Failed to update model {name}: {e}")

        self.load()

    def vendors(self) -> dict:
        """Return list of vendors"""
        return {vendor.__qualname__: vendor for vendor in self._vendors}


llms = LLMManager()
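For context, a minimal sketch of how this pool might be used elsewhere in ktem. It is not part of the commit: the model name, the dotted `__type__` path, and the spec fields are illustrative assumptions for a ChatOpenAI-style vendor.

```python
from ktem.llms.manager import llms

# Register a model. The spec is the dict that deserialize() later consumes,
# so "__type__" must be a dotted path to one of the vendor classes that
# load_vendors() knows about (exact module path is an assumption here).
llms.add(
    name="openai-gpt35",                         # hypothetical name
    spec={
        "__type__": "kotaemon.llms.ChatOpenAI",  # assumed dotted path
        "model": "gpt-3.5-turbo",                # assumed vendor parameters
        "api_key": "<OPENAI_API_KEY>",
    },
    default=True,
)

llm = llms.get_default()       # the entry flagged as default, else a random one
assert "openai-gpt35" in llms  # __contains__ checks membership by name
```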
libs/ktem/ktem/llms/ui.py (new file, 318 lines)
@@ -0,0 +1,318 @@
from copy import deepcopy

import gradio as gr
import pandas as pd
import yaml
from ktem.app import BasePage

from .manager import llms


def format_description(cls):
    params = cls.describe()["params"]
    params_lines = ["| Name | Type | Description |", "| --- | --- | --- |"]
    for key, value in params.items():
        if isinstance(value["auto_callback"], str):
            continue
        params_lines.append(f"| {key} | {value['type']} | {value['help']} |")
    return f"{cls.__doc__}\n\n" + "\n".join(params_lines)


class LLMManagement(BasePage):
    def __init__(self, app):
        self._app = app
        self.spec_desc_default = (
            "# Spec description\n\nSelect an LLM to view the spec description."
        )
        self.on_building_ui()

    def on_building_ui(self):
        with gr.Tab(label="View"):
            self.llm_list = gr.DataFrame(
                headers=["name", "vendor", "default"],
                interactive=False,
            )

            with gr.Column(visible=False) as self._selected_panel:
                self.selected_llm_name = gr.Textbox(value="", visible=False)
                with gr.Row():
                    with gr.Column():
                        self.edit_default = gr.Checkbox(
                            label="Set default",
                            info=(
                                "Set this LLM as default. If no default is set, a "
                                "random LLM will be used."
                            ),
                        )
                        self.edit_spec = gr.Textbox(
                            label="Specification",
                            info="Specification of the LLM in YAML format",
                            lines=10,
                        )

                        with gr.Row(visible=False) as self._selected_panel_btn:
                            with gr.Column():
                                self.btn_edit_save = gr.Button("Save", min_width=10)
                            with gr.Column():
                                self.btn_delete = gr.Button("Delete", min_width=10)
                                with gr.Row():
                                    self.btn_delete_yes = gr.Button(
                                        "Confirm delete",
                                        variant="primary",
                                        visible=False,
                                        min_width=10,
                                    )
                                    self.btn_delete_no = gr.Button(
                                        "Cancel", visible=False, min_width=10
                                    )
                            with gr.Column():
                                self.btn_close = gr.Button("Close", min_width=10)

                    with gr.Column():
                        self.edit_spec_desc = gr.Markdown("# Spec description")

        with gr.Tab(label="Add"):
            with gr.Row():
                with gr.Column(scale=2):
                    self.name = gr.Textbox(
                        label="LLM name",
                        info=(
                            "Must be unique. The name will be used to identify the LLM."
                        ),
                    )
                    self.llm_choices = gr.Dropdown(
                        label="LLM vendors",
                        info=(
                            "Choose the vendor for the LLM. Each vendor has different "
                            "specification."
                        ),
                    )
                    self.spec = gr.Textbox(
                        label="Specification",
                        info="Specification of the LLM in YAML format",
                    )
                    self.default = gr.Checkbox(
                        label="Set default",
                        info=(
                            "Set this LLM as default. This default LLM will be used "
                            "by default across the application."
                        ),
                    )
                    self.btn_new = gr.Button("Create LLM")

                with gr.Column(scale=3):
                    self.spec_desc = gr.Markdown(self.spec_desc_default)

    def _on_app_created(self):
        """Called when the app is created"""
        self._app.app.load(
            self.list_llms,
            inputs=None,
            outputs=[self.llm_list],
        )
        self._app.app.load(
            lambda: gr.update(choices=list(llms.vendors().keys())),
            outputs=[self.llm_choices],
        )

    def on_llm_vendor_change(self, vendor):
        vendor = llms.vendors()[vendor]

        required: dict = {}
        desc = vendor.describe()
        for key, value in desc["params"].items():
            if value.get("required", False):
                required[key] = None

        return yaml.dump(required), format_description(vendor)

    def on_register_events(self):
        self.llm_choices.select(
            self.on_llm_vendor_change,
            inputs=[self.llm_choices],
            outputs=[self.spec, self.spec_desc],
        )
        self.btn_new.click(
            self.create_llm,
            inputs=[self.name, self.llm_choices, self.spec, self.default],
            outputs=None,
        ).then(
            self.list_llms,
            inputs=None,
            outputs=[self.llm_list],
        ).then(
            lambda: ("", None, "", False, self.spec_desc_default),
            outputs=[
                self.name,
                self.llm_choices,
                self.spec,
                self.default,
                self.spec_desc,
            ],
        )
        self.llm_list.select(
            self.select_llm,
            inputs=self.llm_list,
            outputs=[self.selected_llm_name],
            show_progress="hidden",
        )
        self.selected_llm_name.change(
            self.on_selected_llm_change,
            inputs=[self.selected_llm_name],
            outputs=[
                self._selected_panel,
                self._selected_panel_btn,
                # delete section
                self.btn_delete,
                self.btn_delete_yes,
                self.btn_delete_no,
                # edit section
                self.edit_spec,
                self.edit_spec_desc,
                self.edit_default,
            ],
            show_progress="hidden",
        )
        self.btn_delete.click(
            self.on_btn_delete_click,
            inputs=None,
            outputs=[self.btn_delete, self.btn_delete_yes, self.btn_delete_no],
            show_progress="hidden",
        )
        self.btn_delete_yes.click(
            self.delete_llm,
            inputs=[self.selected_llm_name],
            outputs=[self.selected_llm_name],
            show_progress="hidden",
        ).then(
            self.list_llms,
            inputs=None,
            outputs=[self.llm_list],
        )
        self.btn_delete_no.click(
            lambda: (
                gr.update(visible=True),
                gr.update(visible=False),
                gr.update(visible=False),
            ),
            inputs=None,
            outputs=[self.btn_delete, self.btn_delete_yes, self.btn_delete_no],
            show_progress="hidden",
        )
        self.btn_edit_save.click(
            self.save_llm,
            inputs=[
                self.selected_llm_name,
                self.edit_default,
                self.edit_spec,
            ],
            show_progress="hidden",
        ).then(
            self.list_llms,
            inputs=None,
            outputs=[self.llm_list],
        )
        self.btn_close.click(
            lambda: "",
            outputs=[self.selected_llm_name],
        )

    def create_llm(self, name, choices, spec, default):
        try:
            spec = yaml.safe_load(spec)
            spec["__type__"] = (
                llms.vendors()[choices].__module__
                + "."
                + llms.vendors()[choices].__qualname__
            )

            llms.add(name, spec=spec, default=default)
            gr.Info(f"LLM {name} created successfully")
        except Exception as e:
            gr.Error(f"Failed to create LLM {name}: {e}")

    def list_llms(self):
        """List the LLMs"""
        items = []
        for item in llms.info().values():
            record = {}
            record["name"] = item["name"]
            record["vendor"] = item["spec"].get("__type__", "-").split(".")[-1]
            record["default"] = item["default"]
            items.append(record)

        if items:
            llm_list = pd.DataFrame.from_records(items)
        else:
            llm_list = pd.DataFrame.from_records(
                [{"name": "-", "vendor": "-", "default": "-"}]
            )

        return llm_list

    def select_llm(self, llm_list, ev: gr.SelectData):
        if ev.value == "-" and ev.index[0] == 0:
            gr.Info("No LLM is loaded. Please add LLM first")
            return ""

        if not ev.selected:
            return ""

        return llm_list["name"][ev.index[0]]

    def on_selected_llm_change(self, selected_llm_name):
        if selected_llm_name == "":
            _selected_panel = gr.update(visible=False)
            _selected_panel_btn = gr.update(visible=False)
            btn_delete = gr.update(visible=True)
            btn_delete_yes = gr.update(visible=False)
            btn_delete_no = gr.update(visible=False)
            edit_spec = gr.update(value="")
            edit_spec_desc = gr.update(value="")
            edit_default = gr.update(value=False)
        else:
            _selected_panel = gr.update(visible=True)
            _selected_panel_btn = gr.update(visible=True)
            btn_delete = gr.update(visible=True)
            btn_delete_yes = gr.update(visible=False)
            btn_delete_no = gr.update(visible=False)

            info = deepcopy(llms.info()[selected_llm_name])
            vendor_str = info["spec"].pop("__type__", "-").split(".")[-1]
            vendor = llms.vendors()[vendor_str]

            edit_spec = yaml.dump(info["spec"])
            edit_spec_desc = format_description(vendor)
            edit_default = info["default"]

        return (
            _selected_panel,
            _selected_panel_btn,
            btn_delete,
            btn_delete_yes,
            btn_delete_no,
            edit_spec,
            edit_spec_desc,
            edit_default,
        )

    def on_btn_delete_click(self):
        btn_delete = gr.update(visible=False)
        btn_delete_yes = gr.update(visible=True)
        btn_delete_no = gr.update(visible=True)

        return btn_delete, btn_delete_yes, btn_delete_no

    def save_llm(self, selected_llm_name, default, spec):
        try:
            spec = yaml.safe_load(spec)
            spec["__type__"] = llms.info()[selected_llm_name]["spec"]["__type__"]
            llms.update(selected_llm_name, spec=spec, default=default)
            gr.Info(f"LLM {selected_llm_name} saved successfully")
        except Exception as e:
            gr.Error(f"Failed to save LLM {selected_llm_name}: {e}")

    def delete_llm(self, selected_llm_name):
        try:
            llms.delete(selected_llm_name)
        except Exception as e:
            gr.Error(f"Failed to delete LLM {selected_llm_name}: {e}")
            return selected_llm_name

        return ""
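The Add form above takes the specification as YAML for the selected vendor, and `create_llm()` then stamps the vendor class onto the spec as a dotted `__type__` path before handing it to the manager. A rough sketch of that flow; the field names and the dotted path here are assumptions, not values from this commit.

```python
import yaml

# What a user might paste into the "Specification" textbox for a
# ChatOpenAI-style vendor (field names are assumptions).
user_spec = """
model: gpt-3.5-turbo
api_key: <OPENAI_API_KEY>
temperature: 0
"""

spec = yaml.safe_load(user_spec)
# create_llm() appends the selected vendor class as a dotted import path,
# e.g. something like "kotaemon.llms.ChatOpenAI" (exact module path may differ)
spec["__type__"] = "kotaemon.llms.ChatOpenAI"
```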
@@ -1,6 +1,7 @@
 import gradio as gr
 from ktem.app import BasePage
 from ktem.db.models import User, engine
+from ktem.llms.ui import LLMManagement
 from sqlmodel import Session, select

 from .user import UserManagement
@@ -16,6 +17,9 @@ class AdminPage(BasePage):
         with gr.Tab("User Management", visible=False) as self.user_management_tab:
             self.user_management = UserManagement(self._app)

+        with gr.Tab("LLM Management") as self.llm_management_tab:
+            self.llm_management = LLMManagement(self._app)
+
     def on_subscribe_public_events(self):
         if self._app.f_user_management:
             self._app.subscribe_event(
@@ -9,6 +9,8 @@ from ktem.db.models import Conversation, engine
 from sqlmodel import Session, select
 from theflow.settings import settings as flowsettings

+from kotaemon.base import Document
+
 from .chat_panel import ChatPanel
 from .chat_suggestion import ChatSuggestion
 from .common import STATE
@@ -189,6 +191,7 @@ class ChatPage(BasePage):
                 self.chat_control.conversation_rn,
                 self.chat_panel.chatbot,
                 self.info_panel,
+                self.chat_state,
             ]
             + self._indices_input,
             show_progress="hidden",
@@ -220,6 +223,7 @@ class ChatPage(BasePage):
                 self.chat_control.conversation_rn,
                 self.chat_panel.chatbot,
                 self.info_panel,
+                self.chat_state,
             ]
             + self._indices_input,
             show_progress="hidden",
@@ -392,7 +396,7 @@ class ChatPage(BasePage):

         return pipeline, reasoning_state

-    async def chat_fn(self, conversation_id, chat_history, settings, state, *selecteds):
+    def chat_fn(self, conversation_id, chat_history, settings, state, *selecteds):
         """Chat function"""
         chat_input = chat_history[-1][0]
         chat_history = chat_history[:-1]
@@ -403,52 +407,43 @@ class ChatPage(BasePage):
         pipeline, reasoning_state = self.create_pipeline(settings, state, *selecteds)
         pipeline.set_output_queue(queue)

-        asyncio.create_task(pipeline(chat_input, conversation_id, chat_history))
         text, refs = "", ""

-        len_ref = -1  # for logging purpose
         msg_placeholder = getattr(
             flowsettings, "KH_CHAT_MSG_PLACEHOLDER", "Thinking ..."
         )

         print(msg_placeholder)
-        while True:
-            try:
-                response = queue.get_nowait()
-            except Exception:
-                state[pipeline.get_info()["id"]] = reasoning_state["pipeline"]
-                yield chat_history + [
-                    (chat_input, text or msg_placeholder)
-                ], refs, state
+        yield chat_history + [(chat_input, text or msg_placeholder)], refs, state
+        len_ref = -1  # for logging purpose
+        for response in pipeline.stream(chat_input, conversation_id, chat_history):
+            if not isinstance(response, Document):
                 continue

-            if response is None:
-                queue.task_done()
-                print("Chat completed")
-                break
+            if response.channel is None:
+                continue

-            if "output" in response:
-                if response["output"] is None:
+            if response.channel == "chat":
+                if response.content is None:
                     text = ""
                 else:
-                    text += response["output"]
+                    text += response.content

-            if "evidence" in response:
-                if response["evidence"] is None:
+            if response.channel == "info":
+                if response.content is None:
                     refs = ""
                 else:
-                    refs += response["evidence"]
+                    refs += response.content

             if len(refs) > len_ref:
                 print(f"Len refs: {len(refs)}")
                 len_ref = len(refs)

-        state[pipeline.get_info()["id"]] = reasoning_state["pipeline"]
-        yield chat_history + [(chat_input, text)], refs, state
+            state[pipeline.get_info()["id"]] = reasoning_state["pipeline"]
+            yield chat_history + [(chat_input, text or msg_placeholder)], refs, state

-    async def regen_fn(
-        self, conversation_id, chat_history, settings, state, *selecteds
-    ):
+    def regen_fn(self, conversation_id, chat_history, settings, state, *selecteds):
        """Regen function"""
         if not chat_history:
             gr.Warning("Empty chat")
@@ -456,12 +451,11 @@ class ChatPage(BasePage):
             return

         state["app"]["regen"] = True
-        async for chat, refs, state in self.chat_fn(
+        for chat, refs, state in self.chat_fn(
             conversation_id, chat_history, settings, state, *selecteds
         ):
             new_state = deepcopy(state)
             new_state["app"]["regen"] = False
             yield chat, refs, new_state
-        else:
-            state["app"]["regen"] = False
-            yield chat_history, "", state
+        state["app"]["regen"] = False
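The rewritten `chat_fn` only relies on the pipeline yielding `Document` objects tagged with a channel: "chat" content is accumulated into the chatbot message and "info" content into the information panel. A toy stand-in illustrating that contract (not a pipeline from this commit; it skips retrieval and reasoning entirely):

```python
from kotaemon.base import Document


def fake_stream(chat_input, conversation_id, chat_history):
    """Stand-in for a reasoning pipeline's .stream() method."""
    yield Document(channel="info", content="<b>2 evidence chunks retrieved</b>")
    for token in ["Streaming ", "tokens ", "as ", "they ", "arrive."]:
        yield Document(channel="chat", content=token)


text, refs = "", ""
for response in fake_stream("hello", "conv-1", []):
    if response.channel == "chat":
        text += response.content   # accumulated into the chatbot message
    elif response.channel == "info":
        refs += response.content   # accumulated into the information panel
```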
@@ -4,10 +4,10 @@ import logging
 import re
 from collections import defaultdict
 from functools import partial
+from typing import Generator

 import tiktoken
-from ktem.components import llms
-from theflow.settings import settings as flowsettings
+from ktem.llms.manager import llms

 from kotaemon.base import (
     BaseComponent,
@@ -190,10 +190,10 @@ class AnswerWithContextPipeline(BaseComponent):
         lang: the language of the answer. Currently support English and Japanese
     """

-    llm: ChatLLM = Node(default_callback=lambda _: llms.get_highest_accuracy())
-    vlm_endpoint: str = flowsettings.KH_VLM_ENDPOINT
+    llm: ChatLLM = Node(default_callback=lambda _: llms.get_default())
+    vlm_endpoint: str = ""
     citation_pipeline: CitationPipeline = Node(
-        default_callback=lambda _: CitationPipeline(llm=llms.get_lowest_cost())
+        default_callback=lambda _: CitationPipeline(llm=llms.get_default())
     )

     qa_template: str = DEFAULT_QA_TEXT_PROMPT
@@ -297,13 +297,95 @@ class AnswerWithContextPipeline(BaseComponent):

         return answer

-    def extract_evidence_images(self, evidence: str):
-        """Util function to extract and isolate images from context/evidence"""
-        image_pattern = r"src='(data:image\/[^;]+;base64[^']+)'"
-        matches = re.findall(image_pattern, evidence)
-        context = re.sub(image_pattern, "", evidence)
-        return context, matches
+    def stream(  # type: ignore
+        self, question: str, evidence: str, evidence_mode: int = 0, **kwargs
+    ) -> Generator[Document, None, Document]:
+        """Answer the question based on the evidence
+
+        In addition to the question and the evidence, this method also take into
+        account evidence_mode. The evidence_mode tells which kind of evidence is.
+        The kind of evidence affects:
+            1. How the evidence is represented.
+            2. The prompt to generate the answer.
+
+        By default, the evidence_mode is 0, which means the evidence is plain text with
+        no particular semantic representation. The evidence_mode can be:
+            1. "table": There will be HTML markup telling that there is a table
+                within the evidence.
+            2. "chatbot": There will be HTML markup telling that there is a chatbot.
+                This chatbot is a scenario, extracted from an Excel file, where each
+                row corresponds to an interaction.
+
+        Args:
+            question: the original question posed by user
+            evidence: the text that contain relevant information to answer the question
+                (determined by retrieval pipeline)
+            evidence_mode: the mode of evidence, 0 for text, 1 for table, 2 for chatbot
+        """
+        if evidence_mode == EVIDENCE_MODE_TEXT:
+            prompt_template = PromptTemplate(self.qa_template)
+        elif evidence_mode == EVIDENCE_MODE_TABLE:
+            prompt_template = PromptTemplate(self.qa_table_template)
+        elif evidence_mode == EVIDENCE_MODE_FIGURE:
+            prompt_template = PromptTemplate(self.qa_figure_template)
+        else:
+            prompt_template = PromptTemplate(self.qa_chatbot_template)
+
+        images = []
+        if evidence_mode == EVIDENCE_MODE_FIGURE:
+            # isolate image from evidence
+            evidence, images = self.extract_evidence_images(evidence)
+            prompt = prompt_template.populate(
+                context=evidence,
+                question=question,
+                lang=self.lang,
+            )
+        else:
+            prompt = prompt_template.populate(
+                context=evidence,
+                question=question,
+                lang=self.lang,
+            )
+
+        output = ""
+        if evidence_mode == EVIDENCE_MODE_FIGURE:
+            for text in stream_gpt4v(self.vlm_endpoint, images, prompt, max_tokens=768):
+                output += text
+                yield Document(channel="chat", content=text)
+        else:
+            messages = []
+            if self.system_prompt:
+                messages.append(SystemMessage(content=self.system_prompt))
+            messages.append(HumanMessage(content=prompt))
+
+            try:
+                # try streaming first
+                print("Trying LLM streaming")
+                for text in self.llm.stream(messages):
+                    output += text.text
+                    yield Document(channel="chat", content=text.text)
+            except NotImplementedError:
+                print("Streaming is not supported, falling back to normal processing")
+                output = self.llm(messages).text
+                yield Document(channel="chat", content=output)
+
+        # retrieve the citation
+        citation = None
+        if evidence and self.enable_citation:
+            citation = self.citation_pipeline.invoke(
+                context=evidence, question=question
+            )
+
+        answer = Document(text=output, metadata={"citation": citation})
+
+        return answer
+
+    def extract_evidence_images(self, evidence: str):
+        """Util function to extract and isolate images from context/evidence"""
+        image_pattern = r"src='(data:image\/[^;]+;base64[^']+)'"
+        matches = re.findall(image_pattern, evidence)
+        context = re.sub(image_pattern, "", evidence)
+        return context, matches


 class RewriteQuestionPipeline(BaseComponent):
@@ -315,27 +397,19 @@ class RewriteQuestionPipeline(BaseComponent):
         lang: the language of the answer. Currently support English and Japanese
     """

-    llm: ChatLLM = Node(default_callback=lambda _: llms.get_lowest_cost())
+    llm: ChatLLM = Node(default_callback=lambda _: llms.get_default())
     rewrite_template: str = DEFAULT_REWRITE_PROMPT

     lang: str = "English"

-    async def run(self, question: str) -> Document:  # type: ignore
+    def run(self, question: str) -> Document:  # type: ignore
         prompt_template = PromptTemplate(self.rewrite_template)
         prompt = prompt_template.populate(question=question, lang=self.lang)
         messages = [
             SystemMessage(content="You are a helpful assistant"),
             HumanMessage(content=prompt),
         ]
-        output = ""
-        for text in self.llm(messages):
-            if "content" in text:
-                output += text[1]
-                self.report_output({"chat_input": text[1]})
-                break
-            await asyncio.sleep(0)
-
-        return Document(text=output)
+        return self.llm(messages)


 class FullQAPipeline(BaseReasoning):
@@ -351,7 +425,7 @@ class FullQAPipeline(BaseReasoning):
     rewrite_pipeline: RewriteQuestionPipeline = RewriteQuestionPipeline.withx()
     use_rewrite: bool = False

-    async def run(  # type: ignore
+    async def ainvoke(  # type: ignore
         self, message: str, conv_id: str, history: list, **kwargs  # type: ignore
     ) -> Document:  # type: ignore
         import markdown
@@ -482,6 +556,132 @@ class FullQAPipeline(BaseReasoning):
         self.report_output(None)
         return answer

+    def stream(  # type: ignore
+        self, message: str, conv_id: str, history: list, **kwargs  # type: ignore
+    ) -> Generator[Document, None, Document]:
+        import markdown
+
+        docs = []
+        doc_ids = []
+        if self.use_rewrite:
+            message = self.rewrite_pipeline(question=message).text
+
+        for retriever in self.retrievers:
+            for doc in retriever(text=message):
+                if doc.doc_id not in doc_ids:
+                    docs.append(doc)
+                    doc_ids.append(doc.doc_id)
+        for doc in docs:
+            # TODO: a better approach to show the information
+            text = markdown.markdown(
+                doc.text, extensions=["markdown.extensions.tables"]
+            )
+            yield Document(
+                content=(
+                    "<details open>"
+                    f"<summary>{doc.metadata['file_name']}</summary>"
+                    f"{text}"
+                    "</details><br>"
+                ),
+                channel="info",
+            )
+
+        evidence_mode, evidence = self.evidence_pipeline(docs).content
+        answer = yield from self.answering_pipeline.stream(
+            question=message,
+            history=history,
+            evidence=evidence,
+            evidence_mode=evidence_mode,
+            conv_id=conv_id,
+            **kwargs,
+        )
+
+        # prepare citation
+        spans = defaultdict(list)
+        if answer.metadata["citation"] is not None:
+            for fact_with_evidence in answer.metadata["citation"].answer:
+                for quote in fact_with_evidence.substring_quote:
+                    for doc in docs:
+                        start_idx = doc.text.find(quote)
+                        if start_idx == -1:
+                            continue
+
+                        end_idx = start_idx + len(quote)
+
+                        current_idx = start_idx
+                        if "|" not in doc.text[start_idx:end_idx]:
+                            spans[doc.doc_id].append(
+                                {"start": start_idx, "end": end_idx}
+                            )
+                        else:
+                            while doc.text[current_idx:end_idx].find("|") != -1:
+                                match_idx = doc.text[current_idx:end_idx].find("|")
+                                spans[doc.doc_id].append(
+                                    {
+                                        "start": current_idx,
+                                        "end": current_idx + match_idx,
+                                    }
+                                )
+                                current_idx += match_idx + 2
+                                if current_idx > end_idx:
+                                    break
+                        break
+
+        id2docs = {doc.doc_id: doc for doc in docs}
+        lack_evidence = True
+        not_detected = set(id2docs.keys()) - set(spans.keys())
+        yield Document(channel="info", content=None)
+        for id, ss in spans.items():
+            if not ss:
+                not_detected.add(id)
+                continue
+            ss = sorted(ss, key=lambda x: x["start"])
+            text = id2docs[id].text[: ss[0]["start"]]
+            for idx, span in enumerate(ss):
+                text += (
+                    "<mark>" + id2docs[id].text[span["start"] : span["end"]] + "</mark>"
+                )
+                if idx < len(ss) - 1:
+                    text += id2docs[id].text[span["end"] : ss[idx + 1]["start"]]
+            text += id2docs[id].text[ss[-1]["end"] :]
+            text_out = markdown.markdown(
+                text, extensions=["markdown.extensions.tables"]
+            )
+            yield Document(
+                content=(
+                    "<details open>"
+                    f"<summary>{id2docs[id].metadata['file_name']}</summary>"
+                    f"{text_out}"
+                    "</details><br>"
+                ),
+                channel="info",
+            )
+            lack_evidence = False
+
+        if lack_evidence:
+            yield Document(channel="info", content="No evidence found.\n")
+
+        if not_detected:
+            yield Document(
+                channel="info",
+                content="Retrieved segments without matching evidence:\n",
+            )
+            for id in list(not_detected):
+                text_out = markdown.markdown(
+                    id2docs[id].text, extensions=["markdown.extensions.tables"]
+                )
+                yield Document(
+                    content=(
+                        "<details>"
+                        f"<summary>{id2docs[id].metadata['file_name']}</summary>"
+                        f"{text_out}"
+                        "</details><br>"
+                    ),
+                    channel="info",
+                )
+
+        return answer
+
     @classmethod
     def get_pipeline(cls, settings, states, retrievers):
         """Get the reasoning pipeline
@@ -493,12 +693,9 @@ class FullQAPipeline(BaseReasoning):
         _id = cls.get_info()["id"]

         pipeline = FullQAPipeline(retrievers=retrievers)
-        pipeline.answering_pipeline.llm = llms[
-            settings[f"reasoning.options.{_id}.main_llm"]
-        ]
-        pipeline.answering_pipeline.citation_pipeline.llm = llms[
-            settings[f"reasoning.options.{_id}.citation_llm"]
-        ]
+        pipeline.answering_pipeline.llm = llms.get_default()
+        pipeline.answering_pipeline.citation_pipeline.llm = llms.get_default()
         pipeline.answering_pipeline.enable_citation = settings[
             f"reasoning.options.{_id}.highlight_citation"
         ]
@@ -512,7 +709,7 @@ class FullQAPipeline(BaseReasoning):
             f"reasoning.options.{_id}.qa_prompt"
         ]
         pipeline.use_rewrite = states.get("app", {}).get("regen", False)
-        pipeline.rewrite_pipeline.llm = llms.get_lowest_cost()
+        pipeline.rewrite_pipeline.llm = llms.get_default()
         pipeline.rewrite_pipeline.lang = {"en": "English", "ja": "Japanese"}.get(
             settings["reasoning.lang"], "English"
         )
@@ -520,38 +717,12 @@ class FullQAPipeline(BaseReasoning):

     @classmethod
     def get_user_settings(cls) -> dict:
-        from ktem.components import llms
-
-        try:
-            citation_llm = llms.get_lowest_cost_name()
-            citation_llm_choices = list(llms.options().keys())
-            main_llm = llms.get_highest_accuracy_name()
-            main_llm_choices = list(llms.options().keys())
-        except Exception as e:
-            logger.error(e)
-            citation_llm = None
-            citation_llm_choices = []
-            main_llm = None
-            main_llm_choices = []
-
         return {
             "highlight_citation": {
                 "name": "Highlight Citation",
                 "value": False,
                 "component": "checkbox",
             },
-            "citation_llm": {
-                "name": "LLM for citation",
-                "value": citation_llm,
-                "component": "dropdown",
-                "choices": citation_llm_choices,
-            },
-            "main_llm": {
-                "name": "LLM for main generation",
-                "value": main_llm,
-                "component": "dropdown",
-                "choices": main_llm_choices,
-            },
             "system_prompt": {
                 "name": "System Prompt",
                 "value": "This is a question answering system",
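One pattern worth noting in the new `FullQAPipeline.stream()` above is `answer = yield from self.answering_pipeline.stream(...)`: the inner generator's yielded Documents pass straight through to the caller, while its `return` value becomes the value of the `yield from` expression. A minimal, self-contained illustration in plain Python (not code from this commit):

```python
def inner():
    yield "token-1"
    yield "token-2"
    return "final answer"        # becomes the value of `yield from`


def outer():
    answer = yield from inner()  # re-yields token-1 and token-2
    yield f"[done: {answer}]"


print(list(outer()))             # ['token-1', 'token-2', '[done: final answer]']
```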
@@ -7,7 +7,7 @@ from index import ReaderIndexingPipeline
 from openai.resources.embeddings import Embeddings
 from openai.types.chat.chat_completion import ChatCompletion

-from kotaemon.llms import AzureChatOpenAI
+from kotaemon.llms import LCAzureChatOpenAI

 with open(Path(__file__).parent / "resources" / "embedding_openai.json") as f:
     openai_embedding = json.load(f)
@@ -61,7 +61,7 @@ def test_ingest_pipeline(patch, mock_openai_embedding, tmp_path):
     assert len(results) == 1

     # create llm
-    llm = AzureChatOpenAI(
+    llm = LCAzureChatOpenAI(
         openai_api_base="https://test.openai.azure.com/",
         openai_api_key="some-key",
         openai_api_version="2023-03-15-preview",
@@ -2,4 +2,4 @@ from ktem.main import App

 app = App()
 demo = app.make()
-demo.queue().launch(favicon_path=app._favicon, inbrowser=True)
+demo.queue().launch(favicon_path=app._favicon)
@@ -5,7 +5,7 @@ from kotaemon.base import BaseComponent, Document, LLMInterface, Node, Param, la
 from kotaemon.contribs.promptui.logs import ResultLog
 from kotaemon.embeddings import AzureOpenAIEmbeddings
 from kotaemon.indices import VectorIndexing, VectorRetrieval
-from kotaemon.llms import AzureChatOpenAI
+from kotaemon.llms import LCAzureChatOpenAI
 from kotaemon.storages import ChromaVectorStore, SimpleFileDocumentStore


@@ -34,7 +34,7 @@ class QuestionAnsweringPipeline(BaseComponent):
     ]

     retrieval_top_k: int = 1
-    llm: AzureChatOpenAI = AzureChatOpenAI.withx(
+    llm: LCAzureChatOpenAI = LCAzureChatOpenAI.withx(
         azure_endpoint="https://bleh-dummy-2.openai.azure.com/",
         openai_api_key=os.environ.get("OPENAI_API_KEY", "default-key"),
         openai_api_version="2023-03-15-preview",