Subclass chat messages from Document (#86)

This commit is contained in:
ian_Cin 2023-11-27 10:38:19 +07:00 committed by GitHub
parent 3ac277cc0b
commit 0dede9c82d
7 changed files with 36 additions and 10 deletions

View File

@ -31,5 +31,6 @@ class BaseComponent(Function):
@abstractmethod @abstractmethod
def run(self, *args, **kwargs): def run(self, *args, **kwargs):
# enforce output type to be compatible with Document
"""Run the component.""" """Run the component."""
... ...

View File

@ -2,7 +2,9 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Any, Optional, TypeVar from typing import TYPE_CHECKING, Any, Optional, TypeVar
from langchain.schema.messages import AIMessage from langchain.schema.messages import AIMessage as LCAIMessage
from langchain.schema.messages import HumanMessage as LCHumanMessage
from langchain.schema.messages import SystemMessage as LCSystemMessage
from llama_index.bridge.pydantic import Field from llama_index.bridge.pydantic import Field
from llama_index.schema import Document as BaseDocument from llama_index.schema import Document as BaseDocument
@ -63,11 +65,28 @@ class Document(BaseDocument):
return str(self.content) return str(self.content)
class BaseMessage(Document):
def __add__(self, other: Any):
raise NotImplementedError
class SystemMessage(BaseMessage, LCSystemMessage):
pass
class AIMessage(BaseMessage, LCAIMessage):
pass
class HumanMessage(BaseMessage, LCHumanMessage):
pass
class RetrievedDocument(Document): class RetrievedDocument(Document):
"""Subclass of Document with retrieval-related information """Subclass of Document with retrieval-related information
Attributes: Attributes:
score (float): score of the document (from 0.0 to 1.0) score (float): score of the document (from 0.0 to 1.0)
retrieval_metadata (dict): metadata from the retrieval process, can be used retrieval_metadata (dict): metadata from the retrieval process, can be used
by different components in a retrieved pipeline to communicate with each by different components in a retrieved pipeline to communicate with each
other other
@ -77,7 +96,7 @@ class RetrievedDocument(Document):
retrieval_metadata: dict = Field(default={}) retrieval_metadata: dict = Field(default={})
class LLMInterface(Document): class LLMInterface(AIMessage):
candidates: list[str] = Field(default_factory=list) candidates: list[str] = Field(default_factory=list)
completion_tokens: int = -1 completion_tokens: int = -1
total_tokens: int = -1 total_tokens: int = -1

View File

@ -1,9 +1,10 @@
from abc import abstractmethod from abc import abstractmethod
from typing import List, Optional from typing import List, Optional
from langchain.schema.messages import AIMessage, SystemMessage
from theflow import SessionFunction from theflow import SessionFunction
from kotaemon.base.schema import AIMessage, SystemMessage
from ..base import BaseComponent from ..base import BaseComponent
from ..base.schema import LLMInterface from ..base.schema import LLMInterface
from ..llms.chats.base import BaseMessage, HumanMessage from ..llms.chats.base import BaseMessage, HumanMessage

View File

@ -1,7 +1,7 @@
from langchain.schema.messages import AIMessage, SystemMessage from kotaemon.base.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage
from .branching import GatedBranchingPipeline, SimpleBranchingPipeline from .branching import GatedBranchingPipeline, SimpleBranchingPipeline
from .chats import AzureChatOpenAI, BaseMessage, ChatLLM, HumanMessage from .chats import AzureChatOpenAI, ChatLLM
from .completions import LLM, AzureOpenAI, OpenAI from .completions import LLM, AzureOpenAI, OpenAI
from .linear import GatedLinearPipeline, SimpleLinearPipeline from .linear import GatedLinearPipeline, SimpleLinearPipeline
from .prompts import BasePromptComponent, PromptTemplate from .prompts import BasePromptComponent, PromptTemplate

View File

@ -4,9 +4,10 @@ import logging
from typing import Type from typing import Type
from langchain.chat_models.base import BaseChatModel from langchain.chat_models.base import BaseChatModel
from langchain.schema.messages import BaseMessage, HumanMessage
from theflow.base import Param from theflow.base import Param
from kotaemon.base.schema import BaseMessage, HumanMessage
from ...base import BaseComponent from ...base import BaseComponent
from ...base.schema import LLMInterface from ...base.schema import LLMInterface

View File

@ -1,9 +1,9 @@
from typing import Iterator, List, Union from typing import Iterator, List, Union
from langchain.schema.messages import HumanMessage, SystemMessage
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from kotaemon.base import BaseComponent from kotaemon.base import BaseComponent
from kotaemon.base.schema import HumanMessage, SystemMessage
from ..llms.chats.base import ChatLLM from ..llms.chats.base import ChatLLM
from ..llms.completions.base import LLM from ..llms.completions.base import LLM

View File

@ -1,10 +1,14 @@
from unittest.mock import patch from unittest.mock import patch
from langchain.chat_models import AzureChatOpenAI as AzureChatOpenAILC from langchain.chat_models import AzureChatOpenAI as AzureChatOpenAILC
from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage
from openai.types.chat.chat_completion import ChatCompletion from openai.types.chat.chat_completion import ChatCompletion
from kotaemon.base.schema import LLMInterface from kotaemon.base.schema import (
AIMessage,
HumanMessage,
LLMInterface,
SystemMessage,
)
from kotaemon.llms.chats.openai import AzureChatOpenAI from kotaemon.llms.chats.openai import AzureChatOpenAI
_openai_chat_completion_response = ChatCompletion.parse_obj( _openai_chat_completion_response = ChatCompletion.parse_obj(