feat: add structured output to openai (#603)

* add structured output to openai
* remove notebook, modify prepare output method
* fix: comfort precommit

Co-authored-by: Tadashi <tadashi@cinnamon.is>
parent 6f4acc979c
commit 9b05693e4f
@@ -8,6 +8,7 @@ from .schema import (
     HumanMessage,
     LLMInterface,
     RetrievedDocument,
+    StructuredOutputLLMInterface,
     SystemMessage,
 )
@@ -21,6 +22,7 @@ __all__ = [
     "HumanMessage",
     "RetrievedDocument",
     "LLMInterface",
+    "StructuredOutputLLMInterface",
     "ExtractorOutput",
     "Param",
     "Node",
@@ -143,6 +143,11 @@ class LLMInterface(AIMessage):
     logprobs: list[float] = []


+class StructuredOutputLLMInterface(LLMInterface):
+    parsed: Any
+    refusal: str = ""
+
+
 class ExtractorOutput(Document):
     """
     Represents the output of an extractor.
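For orientation: the new `StructuredOutputLLMInterface` extends `LLMInterface` with a `parsed` payload (the schema-validated object) and a `refusal` string. Below is a minimal sketch of a populated instance, assuming it can be built directly like any pydantic model; the `Citation` schema is illustrative, not part of the commit:

```python
from pydantic import BaseModel

from kotaemon.base import StructuredOutputLLMInterface


class Citation(BaseModel):  # illustrative schema, not from the commit
    title: str
    year: int


# `parsed` holds the schema-validated object; `refusal` stays empty
# unless the model declined to produce the structured answer.
out = StructuredOutputLLMInterface(
    content='{"title": "Attention Is All You Need", "year": 2017}',
    parsed=Citation(title="Attention Is All You Need", year=2017),
)
assert out.parsed.year == 2017
```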
@@ -14,6 +14,7 @@ from .chats import (
     LCGeminiChat,
     LCOllamaChat,
     LlamaCppChat,
+    StructuredOutputChatOpenAI,
 )
 from .completions import LLM, AzureOpenAI, LlamaCpp, OpenAI
 from .cot import ManualSequentialChainOfThought, Thought
@@ -31,6 +32,7 @@ __all__ = [
     "SystemMessage",
     "AzureChatOpenAI",
     "ChatOpenAI",
+    "StructuredOutputChatOpenAI",
     "LCAnthropicChat",
     "LCGeminiChat",
     "LCCohereChat",
@@ -10,7 +10,7 @@ from .langchain_based import (
     LCOllamaChat,
 )
 from .llamacpp import LlamaCppChat
-from .openai import AzureChatOpenAI, ChatOpenAI
+from .openai import AzureChatOpenAI, ChatOpenAI, StructuredOutputChatOpenAI

 __all__ = [
     "ChatOpenAI",
@@ -18,6 +18,7 @@ __all__ = [
     "ChatLLM",
     "EndpointChatLLM",
     "ChatOpenAI",
+    "StructuredOutputChatOpenAI",
     "LCAnthropicChat",
     "LCGeminiChat",
     "LCCohereChat",
@@ -1,8 +1,16 @@
-from typing import TYPE_CHECKING, AsyncGenerator, Iterator, Optional
+from typing import TYPE_CHECKING, AsyncGenerator, Iterator, Optional, Type

+from pydantic import BaseModel
 from theflow.utils.modules import import_dotted_string

-from kotaemon.base import AIMessage, BaseMessage, HumanMessage, LLMInterface, Param
+from kotaemon.base import (
+    AIMessage,
+    BaseMessage,
+    HumanMessage,
+    LLMInterface,
+    Param,
+    StructuredOutputLLMInterface,
+)

 from .base import ChatLLM
@@ -330,6 +338,88 @@ class ChatOpenAI(BaseChatOpenAI):
         return await client.chat.completions.create(**params)


+class StructuredOutputChatOpenAI(ChatOpenAI):
+    """OpenAI chat model that returns structured output"""
+
+    response_schema: Type[BaseModel] = Param(
+        help="class that subclasses pydantic's BaseModel", required=True
+    )
+
+    def prepare_output(self, resp: dict) -> StructuredOutputLLMInterface:
+        """Convert the OpenAI response into StructuredOutputLLMInterface"""
+        additional_kwargs = {}
+
+        if "tool_calls" in resp["choices"][0]["message"]:
+            additional_kwargs["tool_calls"] = resp["choices"][0]["message"][
+                "tool_calls"
+            ]
+
+        if resp["choices"][0].get("logprobs") is None:
+            logprobs = []
+        else:
+            all_logprobs = resp["choices"][0]["logprobs"].get("content")
+            logprobs = (
+                [logprob["logprob"] for logprob in all_logprobs] if all_logprobs else []
+            )
+
+        output = StructuredOutputLLMInterface(
+            parsed=resp["choices"][0]["message"]["parsed"],
+            candidates=[(_["message"]["content"] or "") for _ in resp["choices"]],
+            content=resp["choices"][0]["message"]["content"] or "",
+            total_tokens=resp["usage"]["total_tokens"],
+            prompt_tokens=resp["usage"]["prompt_tokens"],
+            completion_tokens=resp["usage"]["completion_tokens"],
+            messages=[
+                AIMessage(content=(_["message"]["content"]) or "")
+                for _ in resp["choices"]
+            ],
+            additional_kwargs=additional_kwargs,
+            logprobs=logprobs,
+        )
+
+        return output
+
+    def prepare_params(self, **kwargs):
+        if "tools_pydantic" in kwargs:
+            kwargs.pop("tools_pydantic")
+
+        params_ = {
+            "model": self.model,
+            "temperature": self.temperature,
+            "max_tokens": self.max_tokens,
+            "n": self.n,
+            "stop": self.stop,
+            "frequency_penalty": self.frequency_penalty,
+            "presence_penalty": self.presence_penalty,
+            "tool_choice": self.tool_choice,
+            "tools": self.tools,
+            "logprobs": self.logprobs,
+            "logit_bias": self.logit_bias,
+            "top_logprobs": self.top_logprobs,
+            "top_p": self.top_p,
+            "response_format": self.response_schema,
+        }
+        params = {k: v for k, v in params_.items() if v is not None}
+        params.update(kwargs)
+
+        # doesn't do streaming
+        params.pop("stream")
+
+        return params
+
+    def openai_response(self, client, **kwargs):
+        """Get the openai response"""
+        params = self.prepare_params(**kwargs)
+
+        return client.beta.chat.completions.parse(**params)
+
+    async def aopenai_response(self, client, **kwargs):
+        """Get the openai response"""
+        params = self.prepare_params(**kwargs)
+
+        return await client.beta.chat.completions.parse(**params)
+
+
 class AzureChatOpenAI(BaseChatOpenAI):
     """OpenAI chat model provided by Microsoft Azure"""
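To see how the pieces fit together, here is a hedged usage sketch: `response_schema` is forwarded to the API as `response_format`, and the reply is fetched through `client.beta.chat.completions.parse(...)`, so `parsed` comes back as an instance of the schema. The `MovieReview` model, API key, and model name are illustrative, and the invocation style assumes kotaemon's callable LLM interface:

```python
from pydantic import BaseModel

from kotaemon.llms import StructuredOutputChatOpenAI


class MovieReview(BaseModel):  # illustrative schema, not from the commit
    title: str
    rating: int  # 1-10


llm = StructuredOutputChatOpenAI(
    api_key="sk-...",             # placeholder credential
    model="gpt-4o-mini",          # assumes a model with structured-output support
    response_schema=MovieReview,  # passed to the API as `response_format`
)

# Note: prepare_params() pops "stream", so this code path never streams.
resp = llm("Review the movie Inception in one sentence with a 1-10 rating.")
print(resp.parsed.title, resp.parsed.rating)  # resp.parsed is a MovieReview
```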