Improve behavior of simple reasoning (#157)
* Add base reasoning implementation
* Provide explicit async and streaming capability
* Allow refreshing the information panel
committed by GitHub
parent cb01d27d19
commit 2950e6ed02
@@ -1,5 +1,5 @@
 from abc import abstractmethod
-from typing import Iterator, Optional
+from typing import AsyncGenerator, Iterator, Optional

 from theflow import Function, Node, Param, lazy
@@ -43,6 +43,18 @@ class BaseComponent(Function):
         if self._queue is not None:
             self._queue.put_nowait(output)

+    def invoke(self, *args, **kwargs) -> Document | list[Document] | None:
+        ...
+
+    async def ainvoke(self, *args, **kwargs) -> Document | list[Document] | None:
+        ...
+
+    def stream(self, *args, **kwargs) -> Iterator[Document] | None:
+        ...
+
+    async def astream(self, *args, **kwargs) -> AsyncGenerator[Document, None] | None:
+        ...
+
     @abstractmethod
     def run(
         self, *args, **kwargs
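A minimal sketch (not part of this commit) of how a subclass might fill in the four new entry points, assuming Document lives alongside BaseComponent in kotaemon.base and accepts a text argument:

    from typing import AsyncGenerator, Iterator

    from kotaemon.base import BaseComponent, Document

    class EchoComponent(BaseComponent):
        """Hypothetical component: echoes input back as Document(s)."""

        def run(self, text: str) -> Document:
            # run() is the abstract method every component must implement
            return Document(text=text)

        def invoke(self, text: str) -> Document:
            # Synchronous entry point delegates to run()
            return self.run(text)

        async def ainvoke(self, text: str) -> Document:
            # Async entry point; a real component would await I/O here
            return self.run(text)

        def stream(self, text: str) -> Iterator[Document]:
            # Yield one Document per whitespace-separated token
            for token in text.split():
                yield Document(text=token)

        async def astream(self, text: str) -> AsyncGenerator[Document, None]:
            # Async streaming mirrors stream()
            for token in text.split():
                yield Document(text=token)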
@@ -65,6 +65,9 @@ class CitationPipeline(BaseComponent):
     llm: BaseLLM

+    def run(self, context: str, question: str):
+        return self.invoke(context, question)
+
     def prepare_llm(self, context: str, question: str):
         schema = QuestionAnswer.schema()
         function = {
             "name": schema["title"],
@@ -92,8 +95,37 @@ class CitationPipeline(BaseComponent):
                 )
             ),
         ]
+        return messages, llm_kwargs
+
+    def invoke(self, context: str, question: str):
+        messages, llm_kwargs = self.prepare_llm(context, question)
+
+        try:
+            print("CitationPipeline: invoking LLM")
+            llm_output = self.get_from_path("llm").invoke(messages, **llm_kwargs)
+            print("CitationPipeline: finish invoking LLM")
+        except Exception as e:
+            print(e)
+            return None
+
+        function_output = llm_output.messages[0].additional_kwargs["function_call"][
+            "arguments"
+        ]
+        output = QuestionAnswer.parse_raw(function_output)
+
+        return output
+
+    async def ainvoke(self, context: str, question: str):
+        messages, llm_kwargs = self.prepare_llm(context, question)
+
+        try:
+            print("CitationPipeline: async invoking LLM")
+            llm_output = await self.get_from_path("llm").ainvoke(messages, **llm_kwargs)
+            print("CitationPipeline: finish async invoking LLM")
+        except Exception as e:
+            print(e)
+            return None
+
-        llm_output = self.llm(messages, **llm_kwargs)
         function_output = llm_output.messages[0].additional_kwargs["function_call"][
             "arguments"
         ]
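For context, a hedged usage sketch of the new sync and async paths; the constructor argument is an assumption, since only the llm node declaration is visible in this diff:

    import asyncio

    citation = CitationPipeline(llm=llm)  # llm: any concrete BaseLLM instance

    # Synchronous path: run() now delegates to invoke()
    qa = citation.run(
        context="Einstein was born in Ulm.",
        question="Where was Einstein born?",
    )

    # Asynchronous path: same prompt preparation, non-blocking LLM call
    qa = asyncio.run(citation.ainvoke(context="...", question="..."))

Both paths return a parsed QuestionAnswer, or None when the LLM call raises.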
@@ -1,8 +1,22 @@
+from typing import AsyncGenerator, Iterator
+
 from langchain_core.language_models.base import BaseLanguageModel

-from kotaemon.base import BaseComponent
+from kotaemon.base import BaseComponent, LLMInterface


 class BaseLLM(BaseComponent):
     def to_langchain_format(self) -> BaseLanguageModel:
         raise NotImplementedError
+
+    def invoke(self, *args, **kwargs) -> LLMInterface:
+        raise NotImplementedError
+
+    async def ainvoke(self, *args, **kwargs) -> LLMInterface:
+        raise NotImplementedError
+
+    def stream(self, *args, **kwargs) -> Iterator[LLMInterface]:
+        raise NotImplementedError
+
+    async def astream(self, *args, **kwargs) -> AsyncGenerator[LLMInterface, None]:
+        raise NotImplementedError
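These stubs define the contract every concrete LLM now satisfies. A toy subclass, purely illustrative and not part of the commit, showing all four entry points (LLMInterface taking a content keyword matches its use in the streaming code below):

    from typing import AsyncGenerator, Iterator

    class EchoLLM(BaseLLM):
        """Hypothetical LLM that returns its prompt verbatim."""

        def run(self, prompt: str, **kwargs) -> LLMInterface:
            # Satisfies BaseComponent's abstract run()
            return self.invoke(prompt, **kwargs)

        def invoke(self, prompt: str, **kwargs) -> LLMInterface:
            return LLMInterface(content=prompt)

        async def ainvoke(self, prompt: str, **kwargs) -> LLMInterface:
            return self.invoke(prompt, **kwargs)

        def stream(self, prompt: str, **kwargs) -> Iterator[LLMInterface]:
            for token in prompt.split():
                yield LLMInterface(content=token)

        async def astream(self, prompt: str, **kwargs) -> AsyncGenerator[LLMInterface, None]:
            for token in prompt.split():
                yield LLMInterface(content=token)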
@@ -1,6 +1,7 @@
 from __future__ import annotations

 import logging
+from typing import AsyncGenerator, Iterator

 from kotaemon.base import BaseMessage, HumanMessage, LLMInterface
@@ -10,6 +11,8 @@ logger = logging.getLogger(__name__)


 class LCChatMixin:
+    """Mixin for langchain based chat models"""
+
     def _get_lc_class(self):
         raise NotImplementedError(
             "Please return the relevant Langchain class in _get_lc_class"
@@ -30,18 +33,7 @@ class LCChatMixin:
             return self.stream(messages, **kwargs)  # type: ignore
         return self.invoke(messages, **kwargs)

-    def invoke(
-        self, messages: str | BaseMessage | list[BaseMessage], **kwargs
-    ) -> LLMInterface:
-        """Generate response from messages
-
-        Args:
-            messages: history of messages to generate response from
-            **kwargs: additional arguments to pass to the langchain chat model
-
-        Returns:
-            LLMInterface: generated response
-        """
+    def prepare_message(self, messages: str | BaseMessage | list[BaseMessage]):
         input_: list[BaseMessage] = []

         if isinstance(messages, str):
@@ -51,7 +43,9 @@ class LCChatMixin:
         else:
             input_ = messages

-        pred = self._obj.generate(messages=[input_], **kwargs)
+        return input_
+
+    def prepare_response(self, pred):
         all_text = [each.text for each in pred.generations[0]]
         all_messages = [each.message for each in pred.generations[0]]
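The refactor splits the old monolithic invoke() into prepare_message() (normalize str | BaseMessage | list[BaseMessage] into a message list) and prepare_response() (fold langchain generations into an LLMInterface), so the sync and async paths in the next hunk can share both. A quick illustration of the normalization, assuming the str branch wraps input in a HumanMessage as the import list suggests:

    chat_llm.prepare_message("hello")                      # -> [HumanMessage(content="hello")]
    chat_llm.prepare_message([HumanMessage(content="hi")])  # -> list passed through unchanged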
@@ -76,10 +70,41 @@ class LCChatMixin:
             logits=[],
         )

-    def stream(self, messages: str | BaseMessage | list[BaseMessage], **kwargs):
+    def invoke(
+        self, messages: str | BaseMessage | list[BaseMessage], **kwargs
+    ) -> LLMInterface:
+        """Generate response from messages
+
+        Args:
+            messages: history of messages to generate response from
+            **kwargs: additional arguments to pass to the langchain chat model
+
+        Returns:
+            LLMInterface: generated response
+        """
+        input_ = self.prepare_message(messages)
+        pred = self._obj.generate(messages=[input_], **kwargs)
+        return self.prepare_response(pred)
+
+    async def ainvoke(
+        self, messages: str | BaseMessage | list[BaseMessage], **kwargs
+    ) -> LLMInterface:
+        input_ = self.prepare_message(messages)
+        pred = await self._obj.agenerate(messages=[input_], **kwargs)
+        return self.prepare_response(pred)
+
+    def stream(
+        self, messages: str | BaseMessage | list[BaseMessage], **kwargs
+    ) -> Iterator[LLMInterface]:
+        for response in self._obj.stream(input=messages, **kwargs):
+            yield LLMInterface(content=response.content)
+
+    async def astream(
+        self, messages: str | BaseMessage | list[BaseMessage], **kwargs
+    ) -> AsyncGenerator[LLMInterface, None]:
+        async for response in self._obj.astream(input=messages, **kwargs):
+            yield LLMInterface(content=response.content)
+
     def to_langchain_format(self):
         return self._obj
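A hedged consumption sketch for the two new streaming paths; chat_llm stands in for any LCChatMixin-based model (e.g. the AzureChatOpenAI below), and the .content attribute is assumed to mirror the content keyword LLMInterface is constructed with:

    import asyncio

    # Sync: chunks arrive as langchain's stream() yields them
    for chunk in chat_llm.stream("Tell me a story"):
        print(chunk.content, end="", flush=True)

    # Async: same shape, driven by an event loop
    async def main():
        async for chunk in chat_llm.astream("Tell me a story"):
            print(chunk.content, end="", flush=True)

    asyncio.run(main())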
@@ -140,7 +165,7 @@ class LCChatMixin:
         raise ValueError(f"Invalid param {path}")


-class AzureChatOpenAI(LCChatMixin, ChatLLM):
+class AzureChatOpenAI(LCChatMixin, ChatLLM):  # type: ignore
     def __init__(
         self,
         azure_endpoint: str | None = None,