Feat/local endpoint llm (#148)
* serve local model in a different process from the app

---------

Co-authored-by: albert <albert@cinnamon.is>
Co-authored-by: trducng <trungduc1992@gmail.com>
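Note: the new `EndpointChatLLM` and `EndpointEmbeddings` components added below only assume an HTTP endpoint that speaks the OpenAI API schema. As a minimal sketch of the intended setup (not part of this commit; the model path, port, and URL are placeholders), one way to serve such an endpoint in a separate process is llama-cpp-python's bundled server:

```python
import subprocess
import sys

from kotaemon.llms import EndpointChatLLM

# Launch an OpenAI-compatible server in its own process (llama-cpp-python ships
# one); the model path is a placeholder, not something this commit provides.
server = subprocess.Popen(
    [sys.executable, "-m", "llama_cpp.server", "--model", "path/to/model.gguf"]
)

# The new component only needs the endpoint URL (llama_cpp.server defaults to port 8000).
llm = EndpointChatLLM(endpoint_url="http://localhost:8000/v1/chat/completions")
```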
@@ -5,7 +5,7 @@ from dataclasses import dataclass
 from enum import Enum
 from typing import Any, Dict, Literal, NamedTuple, Optional, Union

-from pydantic import Extra
+from pydantic import ConfigDict

 from kotaemon.base import LLMInterface

@@ -238,7 +238,7 @@ class AgentFinish(NamedTuple):
     log: str


-class AgentOutput(LLMInterface, extra=Extra.allow):  # type: ignore [call-arg]
+class AgentOutput(LLMInterface):
     """Output from an agent.

     Args:
@@ -248,6 +248,8 @@ class AgentOutput(LLMInterface, extra=Extra.allow):  # type: ignore [call-arg]
         error: The error message if any.
     """

+    model_config = ConfigDict(extra="allow")
+
     text: str
     type: str = "agent"
     agent_type: AgentType
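For context, this is the pydantic v2 spelling of the same `extra` setting: the v1 class-keyword form (`extra=Extra.allow`) becomes `model_config = ConfigDict(extra="allow")`. A minimal standalone sketch (illustrative names, not from this repo):

```python
from pydantic import BaseModel, ConfigDict

class Output(BaseModel):
    # pydantic v2: undeclared fields are kept on the instance instead of rejected
    model_config = ConfigDict(extra="allow")

    text: str

out = Output(text="hi", score=0.9)  # "score" is not declared, but allowed
print(out.score)                     # 0.9
```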
@@ -1,4 +1,5 @@
 from .base import BaseEmbeddings
+from .endpoint_based import EndpointEmbeddings
 from .langchain_based import (
     AzureOpenAIEmbeddings,
     CohereEmbdeddings,
@@ -8,6 +9,7 @@ from .langchain_based import (

 __all__ = [
     "BaseEmbeddings",
+    "EndpointEmbeddings",
     "OpenAIEmbeddings",
     "AzureOpenAIEmbeddings",
     "CohereEmbdeddings",
libs/kotaemon/kotaemon/embeddings/endpoint_based.py  (new file, 46 lines)
@@ -0,0 +1,46 @@
+import requests
+
+from kotaemon.base import Document, DocumentWithEmbedding
+
+from .base import BaseEmbeddings
+
+
+class EndpointEmbeddings(BaseEmbeddings):
+    """
+    An Embeddings component that uses an OpenAI API compatible endpoint.
+
+    Attributes:
+        endpoint_url (str): The url of an OpenAI API compatible endpoint.
+    """
+
+    endpoint_url: str
+
+    def run(
+        self, text: str | list[str] | Document | list[Document]
+    ) -> list[DocumentWithEmbedding]:
+        """
+        Generate embeddings from text Args:
+            text (str | list[str] | Document | list[Document]): text to generate
+            embeddings from
+        Returns:
+            list[DocumentWithEmbedding]: embeddings
+        """
+        if not isinstance(text, list):
+            text = [text]
+
+        outputs = []
+
+        for item in text:
+            response = requests.post(
+                self.endpoint_url, json={"input": str(item)}
+            ).json()
+            outputs.append(
+                DocumentWithEmbedding(
+                    text=str(item),
+                    embedding=response["data"][0]["embedding"],
+                    total_tokens=response["usage"]["total_tokens"],
+                    prompt_tokens=response["usage"]["prompt_tokens"],
+                )
+            )
+
+        return outputs
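A possible usage sketch for the new component (the URL is a placeholder; it assumes a server exposing the OpenAI embeddings schema, i.e. a JSON response with `data[0].embedding` and `usage` fields, as parsed above):

```python
from kotaemon.embeddings import EndpointEmbeddings

# Placeholder URL for a locally served, OpenAI-compatible embeddings endpoint
embedder = EndpointEmbeddings(endpoint_url="http://localhost:8000/v1/embeddings")

docs = embedder.run(["hello world", "local endpoint llm"])
print(len(docs), len(docs[0].embedding))
```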
@@ -108,6 +108,9 @@ class CitationPipeline(BaseComponent):
             print(e)
             return None

+        if not llm_output.messages:
+            return None
+
         function_output = llm_output.messages[0].additional_kwargs["function_call"][
             "arguments"
         ]
@@ -126,6 +129,9 @@ class CitationPipeline(BaseComponent):
             print(e)
             return None

+        if not llm_output.messages:
+            return None
+
         function_output = llm_output.messages[0].additional_kwargs["function_call"][
             "arguments"
         ]
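The added guard covers responses where the LLM returns no messages (for example an error payload without a function call); without it the very next statement would fail. A tiny illustration of the failure mode (illustrative values only):

```python
messages = []    # what an empty llm_output.messages amounts to
messages[0]      # IndexError: list index out of range -> now short-circuited by "return None"
```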
@@ -2,7 +2,7 @@ from kotaemon.base.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage

 from .base import BaseLLM
 from .branching import GatedBranchingPipeline, SimpleBranchingPipeline
-from .chats import AzureChatOpenAI, ChatLLM, LlamaCppChat
+from .chats import AzureChatOpenAI, ChatLLM, EndpointChatLLM, LlamaCppChat
 from .completions import LLM, AzureOpenAI, LlamaCpp, OpenAI
 from .cot import ManualSequentialChainOfThought, Thought
 from .linear import GatedLinearPipeline, SimpleLinearPipeline
@@ -12,6 +12,7 @@ __all__ = [
     "BaseLLM",
     # chat-specific components
     "ChatLLM",
+    "EndpointChatLLM",
     "BaseMessage",
     "HumanMessage",
     "AIMessage",
@@ -1,5 +1,12 @@
 from .base import ChatLLM
+from .endpoint_based import EndpointChatLLM
 from .langchain_based import AzureChatOpenAI, LCChatMixin
 from .llamacpp import LlamaCppChat

-__all__ = ["ChatLLM", "AzureChatOpenAI", "LCChatMixin", "LlamaCppChat"]
+__all__ = [
+    "ChatLLM",
+    "EndpointChatLLM",
+    "AzureChatOpenAI",
+    "LCChatMixin",
+    "LlamaCppChat",
+]
libs/kotaemon/kotaemon/llms/chats/endpoint_based.py  (new file, 85 lines)
@@ -0,0 +1,85 @@
+import requests
+
+from kotaemon.base import (
+    AIMessage,
+    BaseMessage,
+    HumanMessage,
+    LLMInterface,
+    SystemMessage,
+)
+
+from .base import ChatLLM
+
+
+class EndpointChatLLM(ChatLLM):
+    """
+    A ChatLLM that uses an endpoint to generate responses. This expects an OpenAI API
+    compatible endpoint.
+
+    Attributes:
+        endpoint_url (str): The url of an OpenAI API compatible endpoint.
+    """
+
+    endpoint_url: str
+
+    def run(
+        self, messages: str | BaseMessage | list[BaseMessage], **kwargs
+    ) -> LLMInterface:
+        """
+        Generate response from messages
+        Args:
+            messages (str | BaseMessage | list[BaseMessage]): history of messages to
+                generate response from
+            **kwargs: additional arguments to pass to the OpenAI API
+        Returns:
+            LLMInterface: generated response
+        """
+        if isinstance(messages, str):
+            input_ = [HumanMessage(content=messages)]
+        elif isinstance(messages, BaseMessage):
+            input_ = [messages]
+        else:
+            input_ = messages
+
+        def decide_role(message: BaseMessage):
+            if isinstance(message, SystemMessage):
+                return "system"
+            elif isinstance(message, AIMessage):
+                return "assistant"
+            else:
+                return "user"
+
+        request_json = {
+            "messages": [{"content": m.text, "role": decide_role(m)} for m in input_]
+        }
+
+        response = requests.post(self.endpoint_url, json=request_json).json()
+
+        content = ""
+        candidates = []
+        if response["choices"]:
+            candidates = [
+                each["message"]["content"]
+                for each in response["choices"]
+                if each["message"]["content"]
+            ]
+            content = candidates[0]
+
+        return LLMInterface(
+            content=content,
+            candidates=candidates,
+            completion_tokens=response["usage"]["completion_tokens"],
+            total_tokens=response["usage"]["total_tokens"],
+            prompt_tokens=response["usage"]["prompt_tokens"],
+        )
+
+    def invoke(
+        self, messages: str | BaseMessage | list[BaseMessage], **kwargs
+    ) -> LLMInterface:
+        """Same as run"""
+        return self.run(messages, **kwargs)
+
+    async def ainvoke(
+        self, messages: str | BaseMessage | list[BaseMessage], **kwargs
+    ) -> LLMInterface:
+        return self.invoke(messages, **kwargs)
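A possible usage sketch for the new chat component (placeholder URL; assumes the server implements the OpenAI chat-completions schema whose `choices` and `usage` fields are read above):

```python
from kotaemon.base import HumanMessage, SystemMessage
from kotaemon.llms import EndpointChatLLM

# Placeholder URL for a locally served, OpenAI-compatible chat endpoint
llm = EndpointChatLLM(endpoint_url="http://localhost:8000/v1/chat/completions")

result = llm.run(
    [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="Say hello in one short sentence."),
    ]
)
print(result.content, result.total_tokens)
```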