Subclass chat messages from Document (#86)
This commit is contained in:
parent
3ac277cc0b
commit
0dede9c82d
|
@ -31,5 +31,6 @@ class BaseComponent(Function):
|
|||
|
||||
@abstractmethod
|
||||
def run(self, *args, **kwargs):
|
||||
# enforce output type to be compatible with Document
|
||||
"""Run the component."""
|
||||
...
|
||||
|
|
|
@ -2,7 +2,9 @@ from __future__ import annotations
|
|||
|
||||
from typing import TYPE_CHECKING, Any, Optional, TypeVar
|
||||
|
||||
from langchain.schema.messages import AIMessage
|
||||
from langchain.schema.messages import AIMessage as LCAIMessage
|
||||
from langchain.schema.messages import HumanMessage as LCHumanMessage
|
||||
from langchain.schema.messages import SystemMessage as LCSystemMessage
|
||||
from llama_index.bridge.pydantic import Field
|
||||
from llama_index.schema import Document as BaseDocument
|
||||
|
||||
|
@ -63,6 +65,23 @@ class Document(BaseDocument):
|
|||
return str(self.content)
|
||||
|
||||
|
||||
class BaseMessage(Document):
|
||||
def __add__(self, other: Any):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SystemMessage(BaseMessage, LCSystemMessage):
|
||||
pass
|
||||
|
||||
|
||||
class AIMessage(BaseMessage, LCAIMessage):
|
||||
pass
|
||||
|
||||
|
||||
class HumanMessage(BaseMessage, LCHumanMessage):
|
||||
pass
|
||||
|
||||
|
||||
class RetrievedDocument(Document):
|
||||
"""Subclass of Document with retrieval-related information
|
||||
|
||||
|
@ -77,7 +96,7 @@ class RetrievedDocument(Document):
|
|||
retrieval_metadata: dict = Field(default={})
|
||||
|
||||
|
||||
class LLMInterface(Document):
|
||||
class LLMInterface(AIMessage):
|
||||
candidates: list[str] = Field(default_factory=list)
|
||||
completion_tokens: int = -1
|
||||
total_tokens: int = -1
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
from abc import abstractmethod
|
||||
from typing import List, Optional
|
||||
|
||||
from langchain.schema.messages import AIMessage, SystemMessage
|
||||
from theflow import SessionFunction
|
||||
|
||||
from kotaemon.base.schema import AIMessage, SystemMessage
|
||||
|
||||
from ..base import BaseComponent
|
||||
from ..base.schema import LLMInterface
|
||||
from ..llms.chats.base import BaseMessage, HumanMessage
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from langchain.schema.messages import AIMessage, SystemMessage
|
||||
from kotaemon.base.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||
|
||||
from .branching import GatedBranchingPipeline, SimpleBranchingPipeline
|
||||
from .chats import AzureChatOpenAI, BaseMessage, ChatLLM, HumanMessage
|
||||
from .chats import AzureChatOpenAI, ChatLLM
|
||||
from .completions import LLM, AzureOpenAI, OpenAI
|
||||
from .linear import GatedLinearPipeline, SimpleLinearPipeline
|
||||
from .prompts import BasePromptComponent, PromptTemplate
|
||||
|
|
|
@ -4,9 +4,10 @@ import logging
|
|||
from typing import Type
|
||||
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.schema.messages import BaseMessage, HumanMessage
|
||||
from theflow.base import Param
|
||||
|
||||
from kotaemon.base.schema import BaseMessage, HumanMessage
|
||||
|
||||
from ...base import BaseComponent
|
||||
from ...base.schema import LLMInterface
|
||||
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
from typing import Iterator, List, Union
|
||||
|
||||
from langchain.schema.messages import HumanMessage, SystemMessage
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from kotaemon.base import BaseComponent
|
||||
from kotaemon.base.schema import HumanMessage, SystemMessage
|
||||
|
||||
from ..llms.chats.base import ChatLLM
|
||||
from ..llms.completions.base import LLM
|
||||
|
|
|
@ -1,10 +1,14 @@
|
|||
from unittest.mock import patch
|
||||
|
||||
from langchain.chat_models import AzureChatOpenAI as AzureChatOpenAILC
|
||||
from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
|
||||
from kotaemon.base.schema import LLMInterface
|
||||
from kotaemon.base.schema import (
|
||||
AIMessage,
|
||||
HumanMessage,
|
||||
LLMInterface,
|
||||
SystemMessage,
|
||||
)
|
||||
from kotaemon.llms.chats.openai import AzureChatOpenAI
|
||||
|
||||
_openai_chat_completion_response = ChatCompletion.parse_obj(
|
||||
|
|
Loading…
Reference in New Issue
Block a user