Feat/local endpoint llm (#148)

* serve local model in a different process from the app
---------

Co-authored-by: albert <albert@cinnamon.is>
Co-authored-by: trducng <trungduc1992@gmail.com>
ian_Cin
2024-03-15 16:17:33 +07:00
committed by GitHub
parent 2950e6ed02
commit df12dec732
20 changed files with 675 additions and 79 deletions

View File

@@ -5,7 +5,7 @@ from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, Literal, NamedTuple, Optional, Union
from pydantic import Extra
from pydantic import ConfigDict
from kotaemon.base import LLMInterface
@@ -238,7 +238,7 @@ class AgentFinish(NamedTuple):
log: str
class AgentOutput(LLMInterface, extra=Extra.allow): # type: ignore [call-arg]
class AgentOutput(LLMInterface):
"""Output from an agent.
Args:
@@ -248,6 +248,8 @@ class AgentOutput(LLMInterface, extra=Extra.allow): # type: ignore [call-arg]
error: The error message if any.
"""
model_config = ConfigDict(extra="allow")
text: str
type: str = "agent"
agent_type: AgentType
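The hunk above is the pydantic v2 migration for AgentOutput: the v1 class keyword `extra=Extra.allow` becomes an explicit `model_config = ConfigDict(extra="allow")` attribute. A minimal standalone sketch of the same pattern (the ExampleOutput model is hypothetical and not part of this commit):

from pydantic import BaseModel, ConfigDict

class ExampleOutput(BaseModel):
    # pydantic v2 style: replaces `class ExampleOutput(BaseModel, extra=Extra.allow)`
    model_config = ConfigDict(extra="allow")

    text: str

# Extra fields are kept on the instance instead of raising a validation error.
out = ExampleOutput(text="hello", score=0.9)
print(out.score)  # 0.9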

View File

@@ -1,4 +1,5 @@
from .base import BaseEmbeddings
from .endpoint_based import EndpointEmbeddings
from .langchain_based import (
AzureOpenAIEmbeddings,
CohereEmbdeddings,
@@ -8,6 +9,7 @@ from .langchain_based import (
__all__ = [
"BaseEmbeddings",
"EndpointEmbeddings",
"OpenAIEmbeddings",
"AzureOpenAIEmbeddings",
"CohereEmbdeddings",

View File

@@ -0,0 +1,46 @@
import requests
from kotaemon.base import Document, DocumentWithEmbedding
from .base import BaseEmbeddings
class EndpointEmbeddings(BaseEmbeddings):
"""
An Embeddings component that uses an OpenAI API compatible endpoint.
Attributes:
endpoint_url (str): The url of an OpenAI API compatible endpoint.
"""
endpoint_url: str
def run(
self, text: str | list[str] | Document | list[Document]
) -> list[DocumentWithEmbedding]:
"""
Generate embeddings from text
Args:
text (str | list[str] | Document | list[Document]): text to generate
embeddings from
Returns:
list[DocumentWithEmbedding]: embeddings
"""
if not isinstance(text, list):
text = [text]
outputs = []
for item in text:
response = requests.post(
self.endpoint_url, json={"input": str(item)}
).json()
outputs.append(
DocumentWithEmbedding(
text=str(item),
embedding=response["data"][0]["embedding"],
total_tokens=response["usage"]["total_tokens"],
prompt_tokens=response["usage"]["prompt_tokens"],
)
)
return outputs
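
A minimal usage sketch for the new EndpointEmbeddings component; the URL below is a placeholder and assumes an OpenAI API compatible embeddings server is already running locally:

from kotaemon.embeddings import EndpointEmbeddings

# Hypothetical local endpoint; any OpenAI API compatible /v1/embeddings server works.
embedder = EndpointEmbeddings(endpoint_url="http://localhost:8000/v1/embeddings")

docs = embedder.run(["hello world", "local models behind an endpoint"])
print(len(docs), len(docs[0].embedding))  # number of documents, embedding dimension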

View File

@@ -108,6 +108,9 @@ class CitationPipeline(BaseComponent):
print(e)
return None
if not llm_output.messages:
return None
function_output = llm_output.messages[0].additional_kwargs["function_call"][
"arguments"
]
@@ -126,6 +129,9 @@ class CitationPipeline(BaseComponent):
print(e)
return None
if not llm_output.messages:
return None
function_output = llm_output.messages[0].additional_kwargs["function_call"][
"arguments"
]

View File

@@ -2,7 +2,7 @@ from kotaemon.base.schema import AIMessage, BaseMessage, HumanMessage, SystemMes
from .base import BaseLLM
from .branching import GatedBranchingPipeline, SimpleBranchingPipeline
from .chats import AzureChatOpenAI, ChatLLM, LlamaCppChat
from .chats import AzureChatOpenAI, ChatLLM, EndpointChatLLM, LlamaCppChat
from .completions import LLM, AzureOpenAI, LlamaCpp, OpenAI
from .cot import ManualSequentialChainOfThought, Thought
from .linear import GatedLinearPipeline, SimpleLinearPipeline
@@ -12,6 +12,7 @@ __all__ = [
"BaseLLM",
# chat-specific components
"ChatLLM",
"EndpointChatLLM",
"BaseMessage",
"HumanMessage",
"AIMessage",

View File

@@ -1,5 +1,12 @@
from .base import ChatLLM
from .endpoint_based import EndpointChatLLM
from .langchain_based import AzureChatOpenAI, LCChatMixin
from .llamacpp import LlamaCppChat
__all__ = ["ChatLLM", "AzureChatOpenAI", "LCChatMixin", "LlamaCppChat"]
__all__ = [
"ChatLLM",
"EndpointChatLLM",
"AzureChatOpenAI",
"LCChatMixin",
"LlamaCppChat",
]

View File

@@ -0,0 +1,85 @@
import requests
from kotaemon.base import (
AIMessage,
BaseMessage,
HumanMessage,
LLMInterface,
SystemMessage,
)
from .base import ChatLLM
class EndpointChatLLM(ChatLLM):
"""
A ChatLLM that uses an endpoint to generate responses. This expects an OpenAI API
compatible endpoint.
Attributes:
endpoint_url (str): The url of an OpenAI API compatible endpoint.
"""
endpoint_url: str
def run(
self, messages: str | BaseMessage | list[BaseMessage], **kwargs
) -> LLMInterface:
"""
Generate response from messages
Args:
messages (str | BaseMessage | list[BaseMessage]): history of messages to
generate response from
**kwargs: additional arguments to pass to the OpenAI API
Returns:
LLMInterface: generated response
"""
if isinstance(messages, str):
input_ = [HumanMessage(content=messages)]
elif isinstance(messages, BaseMessage):
input_ = [messages]
else:
input_ = messages
def decide_role(message: BaseMessage):
if isinstance(message, SystemMessage):
return "system"
elif isinstance(message, AIMessage):
return "assistant"
else:
return "user"
request_json = {
"messages": [{"content": m.text, "role": decide_role(m)} for m in input_]
}
response = requests.post(self.endpoint_url, json=request_json).json()
content = ""
candidates = []
if response["choices"]:
candidates = [
each["message"]["content"]
for each in response["choices"]
if each["message"]["content"]
]
content = candidates[0]
return LLMInterface(
content=content,
candidates=candidates,
completion_tokens=response["usage"]["completion_tokens"],
total_tokens=response["usage"]["total_tokens"],
prompt_tokens=response["usage"]["prompt_tokens"],
)
def invoke(
self, messages: str | BaseMessage | list[BaseMessage], **kwargs
) -> LLMInterface:
"""Same as run"""
return self.run(messages, **kwargs)
async def ainvoke(
self, messages: str | BaseMessage | list[BaseMessage], **kwargs
) -> LLMInterface:
return self.invoke(messages, **kwargs)
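
A minimal usage sketch for the new EndpointChatLLM; the URL is a placeholder for an OpenAI API compatible chat completions endpoint, e.g. a local model served in a separate process as described in the commit message:

from kotaemon.base import HumanMessage, SystemMessage
from kotaemon.llms import EndpointChatLLM

# Hypothetical local endpoint (for example, llama-cpp-python's OpenAI compatible server).
llm = EndpointChatLLM(endpoint_url="http://localhost:8000/v1/chat/completions")

output = llm.run(
    [
        SystemMessage(content="You are a concise assistant."),
        HumanMessage(content="Say hello."),
    ]
)
print(output.content, output.total_tokens)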