Update the Citation pipeline according to new OpenAI function call interface (#40)

Duc Nguyen (john) 2024-04-20 01:12:23 +07:00 committed by GitHub
parent 1b2082a140
commit c6045bcb9f
2 changed files with 33 additions and 38 deletions


@@ -75,8 +75,8 @@ class CitationPipeline(BaseComponent):
             "parameters": schema,
         }
         llm_kwargs = {
-            "functions": [function],
-            "function_call": {"name": function["name"]},
+            "tools": [{"type": "function", "function": function}],
+            "tool_choice": "auto",
         }
         messages = [
             SystemMessage(
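Note: the old functions / function_call request parameters are deprecated in the OpenAI Chat Completions API in favour of tools / tool_choice, which is what this hunk switches to. A minimal sketch of the request the new llm_kwargs correspond to, assuming the openai>=1.x Python SDK; the model name and schema contents are placeholders, not part of this commit:

from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment

# Placeholder stand-in for the QuestionAnswer schema the pipeline derives from pydantic.
schema = {
    "title": "QuestionAnswer",
    "description": "Answer a question with citations.",
    "type": "object",
    "properties": {"answer": {"type": "string"}},
}
function = {"name": schema["title"], "description": schema["description"], "parameters": schema}

resp = client.chat.completions.create(
    model="gpt-4o-mini",  # placeholder model name
    messages=[{"role": "user", "content": "..."}],
    tools=[{"type": "function", "function": function}],  # was: functions=[function]
    tool_choice="auto",  # was: function_call={"name": function["name"]}
)
# tool_choice can also force the specific function:
# tool_choice={"type": "function", "function": {"name": function["name"]}}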
@@ -99,14 +99,13 @@ class CitationPipeline(BaseComponent):
     def invoke(self, context: str, question: str):
         messages, llm_kwargs = self.prepare_llm(context, question)
         try:
             print("CitationPipeline: invoking LLM")
             llm_output = self.get_from_path("llm").invoke(messages, **llm_kwargs)
             print("CitationPipeline: finish invoking LLM")
             if not llm_output.messages:
                 return None
-            function_output = llm_output.messages[0].additional_kwargs["function_call"][
+            function_output = llm_output.additional_kwargs["tool_calls"][0]["function"][
                 "arguments"
             ]
             output = QuestionAnswer.parse_raw(function_output)
@@ -123,16 +122,12 @@ class CitationPipeline(BaseComponent):
             print("CitationPipeline: async invoking LLM")
             llm_output = await self.get_from_path("llm").ainvoke(messages, **llm_kwargs)
             print("CitationPipeline: finish async invoking LLM")
+            function_output = llm_output.additional_kwargs["tool_calls"][0]["function"][
+                "arguments"
+            ]
+            output = QuestionAnswer.parse_raw(function_output)
         except Exception as e:
             print(e)
             return None
-        if not llm_output.messages:
-            return None
-        function_output = llm_output.messages[0].additional_kwargs["function_call"][
-            "arguments"
-        ]
-        output = QuestionAnswer.parse_raw(function_output)
         return output
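The parsing change on the read side mirrors the request change: with tools, the reply carries a list of tool calls instead of a single function_call, and the arguments field is still a JSON-encoded string. A rough illustration of the structure the new code indexes into (the payload itself is invented; only the shape follows the OpenAI tool-calls format):

import json

# Invented example of the tool_calls list that prepare_output (second file below)
# copies onto LLMInterface.additional_kwargs.
tool_calls = [
    {
        "id": "call_abc123",
        "type": "function",
        "function": {
            "name": "QuestionAnswer",
            "arguments": '{"question": "...", "answer": "..."}',  # JSON string, not a dict
        },
    }
]

function_output = tool_calls[0]["function"]["arguments"]
parsed = json.loads(function_output)  # the pipeline uses QuestionAnswer.parse_raw instead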


@@ -152,6 +152,28 @@ class BaseChatOpenAI(ChatLLM):
         return output_

+    def prepare_output(self, resp: dict) -> LLMInterface:
+        """Convert the OpenAI response into LLMInterface"""
+        additional_kwargs = {}
+        if "tool_calls" in resp["choices"][0]["message"]:
+            additional_kwargs["tool_calls"] = resp["choices"][0]["message"][
+                "tool_calls"
+            ]
+        output = LLMInterface(
+            candidates=[(_["message"]["content"] or "") for _ in resp["choices"]],
+            content=resp["choices"][0]["message"]["content"] or "",
+            total_tokens=resp["usage"]["total_tokens"],
+            prompt_tokens=resp["usage"]["prompt_tokens"],
+            completion_tokens=resp["usage"]["completion_tokens"],
+            additional_kwargs=additional_kwargs,
+            messages=[
+                AIMessage(content=(_["message"]["content"]) or "")
+                for _ in resp["choices"]
+            ],
+        )
+        return output
+
     def prepare_client(self, async_version: bool = False):
         """Get the OpenAI client
@@ -172,19 +194,7 @@ class BaseChatOpenAI(ChatLLM):
         resp = self.openai_response(
             client, messages=input_messages, stream=False, **kwargs
         ).dict()
-        output = LLMInterface(
-            candidates=[_["message"]["content"] for _ in resp["choices"]],
-            content=resp["choices"][0]["message"]["content"],
-            total_tokens=resp["usage"]["total_tokens"],
-            prompt_tokens=resp["usage"]["prompt_tokens"],
-            completion_tokens=resp["usage"]["completion_tokens"],
-            messages=[
-                AIMessage(content=_["message"]["content"]) for _ in resp["choices"]
-            ],
-        )
-        return output
+        return self.prepare_output(resp)

     async def ainvoke(
         self, messages: str | BaseMessage | list[BaseMessage], *args, **kwargs
@@ -195,18 +205,7 @@ class BaseChatOpenAI(ChatLLM):
             client, messages=input_messages, stream=False, **kwargs
         ).dict()
-        output = LLMInterface(
-            candidates=[_["message"]["content"] for _ in resp["choices"]],
-            content=resp["choices"][0]["message"]["content"],
-            total_tokens=resp["usage"]["total_tokens"],
-            prompt_tokens=resp["usage"]["prompt_tokens"],
-            completion_tokens=resp["usage"]["completion_tokens"],
-            messages=[
-                AIMessage(content=_["message"]["content"]) for _ in resp["choices"]
-            ],
-        )
-        return output
+        return self.prepare_output(resp)

     def stream(
         self, messages: str | BaseMessage | list[BaseMessage], *args, **kwargs
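With both invoke and ainvoke delegating to prepare_output, the only remaining difference between them is the (a)sync call into openai_response; tool-related kwargs pass straight through **kwargs. A hedged usage sketch follows; the constructor arguments are placeholders and depend on how the deployment is configured, not on anything in this commit:

# Placeholder configuration; exact constructor parameters depend on the subclass/deployment.
llm = AzureChatOpenAI(azure_deployment="my-gpt-4-deployment", temperature=0)

out = llm.invoke(
    "Answer with citations for ...",
    tools=[{"type": "function", "function": function}],  # `function` as built by prepare_llm
    tool_choice="auto",
)
tool_calls = out.additional_kwargs.get("tool_calls", [])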
@@ -338,7 +337,7 @@ class AzureChatOpenAI(BaseChatOpenAI):
     def openai_response(self, client, **kwargs):
         """Get the openai response"""
-        params = {
+        params_ = {
             "model": self.azure_deployment,
             "temperature": self.temperature,
             "max_tokens": self.max_tokens,
@@ -353,6 +352,7 @@ class AzureChatOpenAI(BaseChatOpenAI):
             "top_logprobs": self.top_logprobs,
             "top_p": self.top_p,
         }
+        params = {k: v for k, v in params_.items() if v is not None}
         params.update(kwargs)
         return client.chat.completions.create(**params)
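The params_ / params split filters out options that were left as None before the request is built, so unset parameters are not forwarded to the Azure endpoint alongside the new tool kwargs. A minimal illustration of the filtering step; the values are placeholders mirroring unset defaults, not the real class attributes:

params_ = {
    "model": "my-deployment",  # placeholder deployment name
    "temperature": 0.7,
    "max_tokens": None,  # unset optional parameters stay None
    "logprobs": None,
    "top_p": None,
}
params = {k: v for k, v in params_.items() if v is not None}
assert params == {"model": "my-deployment", "temperature": 0.7}

# Runtime kwargs (e.g. tools / tool_choice coming from CitationPipeline) still apply afterwards:
params.update({"tool_choice": "auto"})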