Subclass chat messages from Document (#86)

This commit is contained in:
ian_Cin 2023-11-27 10:38:19 +07:00 committed by GitHub
parent 3ac277cc0b
commit 0dede9c82d
7 changed files with 36 additions and 10 deletions

View File

@ -31,5 +31,6 @@ class BaseComponent(Function):
@abstractmethod
def run(self, *args, **kwargs):
# enforce output type to be compatible with Document
"""Run the component."""
...

View File

@ -2,7 +2,9 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Any, Optional, TypeVar
from langchain.schema.messages import AIMessage
from langchain.schema.messages import AIMessage as LCAIMessage
from langchain.schema.messages import HumanMessage as LCHumanMessage
from langchain.schema.messages import SystemMessage as LCSystemMessage
from llama_index.bridge.pydantic import Field
from llama_index.schema import Document as BaseDocument
@ -63,6 +65,23 @@ class Document(BaseDocument):
return str(self.content)
class BaseMessage(Document):
def __add__(self, other: Any):
raise NotImplementedError
class SystemMessage(BaseMessage, LCSystemMessage):
pass
class AIMessage(BaseMessage, LCAIMessage):
pass
class HumanMessage(BaseMessage, LCHumanMessage):
pass
class RetrievedDocument(Document):
"""Subclass of Document with retrieval-related information
@ -77,7 +96,7 @@ class RetrievedDocument(Document):
retrieval_metadata: dict = Field(default={})
class LLMInterface(Document):
class LLMInterface(AIMessage):
candidates: list[str] = Field(default_factory=list)
completion_tokens: int = -1
total_tokens: int = -1

View File

@ -1,9 +1,10 @@
from abc import abstractmethod
from typing import List, Optional
from langchain.schema.messages import AIMessage, SystemMessage
from theflow import SessionFunction
from kotaemon.base.schema import AIMessage, SystemMessage
from ..base import BaseComponent
from ..base.schema import LLMInterface
from ..llms.chats.base import BaseMessage, HumanMessage

View File

@ -1,7 +1,7 @@
from langchain.schema.messages import AIMessage, SystemMessage
from kotaemon.base.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage
from .branching import GatedBranchingPipeline, SimpleBranchingPipeline
from .chats import AzureChatOpenAI, BaseMessage, ChatLLM, HumanMessage
from .chats import AzureChatOpenAI, ChatLLM
from .completions import LLM, AzureOpenAI, OpenAI
from .linear import GatedLinearPipeline, SimpleLinearPipeline
from .prompts import BasePromptComponent, PromptTemplate

View File

@ -4,9 +4,10 @@ import logging
from typing import Type
from langchain.chat_models.base import BaseChatModel
from langchain.schema.messages import BaseMessage, HumanMessage
from theflow.base import Param
from kotaemon.base.schema import BaseMessage, HumanMessage
from ...base import BaseComponent
from ...base.schema import LLMInterface

View File

@ -1,9 +1,9 @@
from typing import Iterator, List, Union
from langchain.schema.messages import HumanMessage, SystemMessage
from pydantic import BaseModel, Field
from kotaemon.base import BaseComponent
from kotaemon.base.schema import HumanMessage, SystemMessage
from ..llms.chats.base import ChatLLM
from ..llms.completions.base import LLM

View File

@ -1,10 +1,14 @@
from unittest.mock import patch
from langchain.chat_models import AzureChatOpenAI as AzureChatOpenAILC
from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage
from openai.types.chat.chat_completion import ChatCompletion
from kotaemon.base.schema import LLMInterface
from kotaemon.base.schema import (
AIMessage,
HumanMessage,
LLMInterface,
SystemMessage,
)
from kotaemon.llms.chats.openai import AzureChatOpenAI
_openai_chat_completion_response = ChatCompletion.parse_obj(