Update the Citation pipeline according to the new OpenAI function call interface (#40)
parent 1b2082a140 · commit c6045bcb9f
@@ -75,8 +75,8 @@ class CitationPipeline(BaseComponent):
             "parameters": schema,
         }
         llm_kwargs = {
-            "functions": [function],
-            "function_call": {"name": function["name"]},
+            "tools": [{"type": "function", "function": function}],
+            "tool_choice": "auto",
         }
         messages = [
             SystemMessage(
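For reference, this swaps the deprecated functions / function_call request parameters for the tools / tool_choice interface of the OpenAI v1 SDK. A minimal standalone sketch of a tools-based request, assuming an OpenAI client; the schema, function name, and model name below are placeholders, not values from this repository (the pipeline supplies its own schema variable here):

from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment

# Hypothetical JSON Schema for the structured answer.
schema = {
    "type": "object",
    "properties": {"answer": {"type": "string"}},
    "required": ["answer"],
}
function = {
    "name": "answer_with_citations",  # placeholder name
    "description": "Answer the question using the given context.",
    "parameters": schema,
}

response = client.chat.completions.create(
    model="gpt-4o-mini",  # placeholder model
    messages=[{"role": "user", "content": "What does the context say about X?"}],
    tools=[{"type": "function", "function": function}],
    # "auto" lets the model decide; to force this function, pass
    # {"type": "function", "function": {"name": function["name"]}} instead.
    tool_choice="auto",
)
print(response.choices[0].message.tool_calls)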
@@ -99,14 +99,13 @@ class CitationPipeline(BaseComponent):

     def invoke(self, context: str, question: str):
         messages, llm_kwargs = self.prepare_llm(context, question)

         try:
             print("CitationPipeline: invoking LLM")
             llm_output = self.get_from_path("llm").invoke(messages, **llm_kwargs)
             print("CitationPipeline: finish invoking LLM")
-            if not llm_output.messages:
-                return None
-            function_output = llm_output.messages[0].additional_kwargs["function_call"][
+            function_output = llm_output.additional_kwargs["tool_calls"][0]["function"][
                 "arguments"
             ]
             output = QuestionAnswer.parse_raw(function_output)
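With the v1 SDK the structured output no longer lives in message.additional_kwargs["function_call"]; it arrives as a list of tool calls whose arguments field is a JSON string. A small self-contained sketch of the shape being unpacked above, with a stand-in pydantic model in place of QuestionAnswer:

from pydantic import BaseModel


class QA(BaseModel):
    # Stand-in for QuestionAnswer, whose real fields are defined elsewhere.
    answer: str


# Example of the additional_kwargs attached to the LLM output after this change.
additional_kwargs = {
    "tool_calls": [
        {
            "id": "call_123",
            "type": "function",
            "function": {
                "name": "answer_with_citations",  # hypothetical tool name
                "arguments": '{"answer": "42"}',  # JSON string, not a dict
            },
        }
    ]
}

raw_args = additional_kwargs["tool_calls"][0]["function"]["arguments"]
qa = QA.parse_raw(raw_args)  # same pydantic v1-style call as QuestionAnswer.parse_raw
print(qa.answer)             # -> 42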
@@ -123,16 +122,12 @@ class CitationPipeline(BaseComponent):
             print("CitationPipeline: async invoking LLM")
             llm_output = await self.get_from_path("llm").ainvoke(messages, **llm_kwargs)
             print("CitationPipeline: finish async invoking LLM")
+            function_output = llm_output.additional_kwargs["tool_calls"][0]["function"][
+                "arguments"
+            ]
+            output = QuestionAnswer.parse_raw(function_output)
         except Exception as e:
             print(e)
             return None

-        if not llm_output.messages:
-            return None
-
-        function_output = llm_output.messages[0].additional_kwargs["function_call"][
-            "arguments"
-        ]
-        output = QuestionAnswer.parse_raw(function_output)
-
         return output
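Note that both paths now index additional_kwargs["tool_calls"][0] directly inside the try block, so a reply without any tool call (for example, a plain text answer) raises and is turned into a None return by the except. A tiny illustration of that failure mode, not code from this repository:

# A plain-text completion carries no "tool_calls" entry.
additional_kwargs = {}

try:
    args = additional_kwargs["tool_calls"][0]["function"]["arguments"]
except (KeyError, IndexError):
    args = None  # the pipeline would print the exception and return None here

print(args)  # -> None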
@@ -152,6 +152,28 @@ class BaseChatOpenAI(ChatLLM):
         return output_

+    def prepare_output(self, resp: dict) -> LLMInterface:
+        """Convert the OpenAI response into LLMInterface"""
+        additional_kwargs = {}
+        if "tool_calls" in resp["choices"][0]["message"]:
+            additional_kwargs["tool_calls"] = resp["choices"][0]["message"][
+                "tool_calls"
+            ]
+        output = LLMInterface(
+            candidates=[(_["message"]["content"] or "") for _ in resp["choices"]],
+            content=resp["choices"][0]["message"]["content"] or "",
+            total_tokens=resp["usage"]["total_tokens"],
+            prompt_tokens=resp["usage"]["prompt_tokens"],
+            completion_tokens=resp["usage"]["completion_tokens"],
+            additional_kwargs=additional_kwargs,
+            messages=[
+                AIMessage(content=(_["message"]["content"]) or "")
+                for _ in resp["choices"]
+            ],
+        )
+
+        return output
+
     def prepare_client(self, async_version: bool = False):
         """Get the OpenAI client
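prepare_output centralises the response parsing for both the sync and async paths. The dict below is a trimmed, hypothetical chat.completions payload (roughly what .dict() on the SDK response yields), showing only the fields the method reads; the extraction mirrors its logic:

# Trimmed, hypothetical response dict in the shape prepare_output expects.
resp = {
    "choices": [
        {
            "message": {
                "content": None,  # content may be null when a tool is called
                "tool_calls": [
                    {
                        "id": "call_123",
                        "type": "function",
                        "function": {
                            "name": "answer_with_citations",
                            "arguments": '{"answer": "42"}',
                        },
                    }
                ],
            }
        }
    ],
    "usage": {"total_tokens": 30, "prompt_tokens": 20, "completion_tokens": 10},
}

# Mirrors the field access in prepare_output above.
message = resp["choices"][0]["message"]
tool_calls = message.get("tool_calls", [])
content = message["content"] or ""  # None content is normalised to ""
usage = resp["usage"]
print(content, len(tool_calls), usage["total_tokens"])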
@@ -172,19 +194,7 @@ class BaseChatOpenAI(ChatLLM):
         resp = self.openai_response(
             client, messages=input_messages, stream=False, **kwargs
         ).dict()

-        output = LLMInterface(
-            candidates=[_["message"]["content"] for _ in resp["choices"]],
-            content=resp["choices"][0]["message"]["content"],
-            total_tokens=resp["usage"]["total_tokens"],
-            prompt_tokens=resp["usage"]["prompt_tokens"],
-            completion_tokens=resp["usage"]["completion_tokens"],
-            messages=[
-                AIMessage(content=_["message"]["content"]) for _ in resp["choices"]
-            ],
-        )
-
-        return output
+        return self.prepare_output(resp)

     async def ainvoke(
         self, messages: str | BaseMessage | list[BaseMessage], *args, **kwargs
@@ -195,18 +205,7 @@ class BaseChatOpenAI(ChatLLM):
             client, messages=input_messages, stream=False, **kwargs
         ).dict()

-        output = LLMInterface(
-            candidates=[_["message"]["content"] for _ in resp["choices"]],
-            content=resp["choices"][0]["message"]["content"],
-            total_tokens=resp["usage"]["total_tokens"],
-            prompt_tokens=resp["usage"]["prompt_tokens"],
-            completion_tokens=resp["usage"]["completion_tokens"],
-            messages=[
-                AIMessage(content=_["message"]["content"]) for _ in resp["choices"]
-            ],
-        )
-
-        return output
+        return self.prepare_output(resp)

     def stream(
         self, messages: str | BaseMessage | list[BaseMessage], *args, **kwargs
@@ -338,7 +337,7 @@ class AzureChatOpenAI(BaseChatOpenAI):

     def openai_response(self, client, **kwargs):
         """Get the openai response"""
-        params = {
+        params_ = {
             "model": self.azure_deployment,
             "temperature": self.temperature,
             "max_tokens": self.max_tokens,
@@ -353,6 +352,7 @@ class AzureChatOpenAI(BaseChatOpenAI):
             "top_logprobs": self.top_logprobs,
             "top_p": self.top_p,
         }
+        params = {k: v for k, v in params_.items() if v is not None}
         params.update(kwargs)

         return client.chat.completions.create(**params)
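The rename to params_ plus the new comprehension drops every setting that is still None, so unset options such as max_tokens or top_p are omitted from the request rather than sent as explicit nulls. A small self-contained sketch of the same pattern, with made-up values:

# Hypothetical settings; None means "not configured, let the API default apply".
params_ = {
    "model": "my-azure-deployment",  # placeholder deployment name
    "temperature": 0.0,
    "max_tokens": None,
    "top_p": None,
}

params = {k: v for k, v in params_.items() if v is not None}
params.update({"stream": False})  # call-time kwargs are still merged on top

print(params)
# -> {'model': 'my-azure-deployment', 'temperature': 0.0, 'stream': False}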