Subclass chat messages from Document (#86)
This commit is contained in:
parent 3ac277cc0b
commit 0dede9c82d
@@ -31,5 +31,6 @@ class BaseComponent(Function):
 
     @abstractmethod
     def run(self, *args, **kwargs):
+        # enforce output type to be compatible with Document
         """Run the component."""
         ...
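The comment added to BaseComponent.run records the convention that a component's output should be compatible with Document. A minimal sketch of a conforming component, not part of this commit (EchoComponent is hypothetical, and it assumes Document exposes a content field, as the str(self.content) line in the schema hunk below suggests):

from kotaemon.base import BaseComponent
from kotaemon.base.schema import Document


class EchoComponent(BaseComponent):
    """Hypothetical component whose run() returns Document-compatible output."""

    def run(self, text: str) -> Document:
        # Wrap the raw string in a Document instead of returning a bare str,
        # so every component in a pipeline yields the same kind of object.
        return Document(content=text)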
@@ -2,7 +2,9 @@ from __future__ import annotations
 
 from typing import TYPE_CHECKING, Any, Optional, TypeVar
 
-from langchain.schema.messages import AIMessage
+from langchain.schema.messages import AIMessage as LCAIMessage
+from langchain.schema.messages import HumanMessage as LCHumanMessage
+from langchain.schema.messages import SystemMessage as LCSystemMessage
 from llama_index.bridge.pydantic import Field
 from llama_index.schema import Document as BaseDocument
 
@@ -63,11 +65,28 @@ class Document(BaseDocument):
         return str(self.content)
 
 
+class BaseMessage(Document):
+    def __add__(self, other: Any):
+        raise NotImplementedError
+
+
+class SystemMessage(BaseMessage, LCSystemMessage):
+    pass
+
+
+class AIMessage(BaseMessage, LCAIMessage):
+    pass
+
+
+class HumanMessage(BaseMessage, LCHumanMessage):
+    pass
+
+
 class RetrievedDocument(Document):
     """Subclass of Document with retrieval-related information
 
     Attributes:
         score (float): score of the document (from 0.0 to 1.0)
         retrieval_metadata (dict): metadata from the retrieval process, can be used
             by different components in a retrieved pipeline to communicate with each
             other
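The new message classes combine the kotaemon Document with the corresponding LangChain message types, so one object can travel through Document-based pipelines and still be handed to a LangChain chat model. A rough usage sketch, not part of the commit (contents are invented; it assumes the classes accept a content keyword, as the LangChain messages do). Note that BaseMessage.__add__ raises NotImplementedError, so these messages are not meant to be concatenated the way LangChain message chunks are.

from kotaemon.base.schema import AIMessage, HumanMessage, SystemMessage

chat_history = [
    SystemMessage(content="You are a helpful assistant."),
    HumanMessage(content="Summarise the attached report."),
    AIMessage(content="Here is a short summary of the report."),
]

for message in chat_history:
    # Each entry is a kotaemon Document, so Document-style access works,
    # and it is also a LangChain message, so the same list can be passed
    # to a LangChain chat model unchanged.
    print(type(message).__name__, message.content)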
@@ -77,7 +96,7 @@ class RetrievedDocument(Document):
     retrieval_metadata: dict = Field(default={})
 
 
-class LLMInterface(Document):
+class LLMInterface(AIMessage):
     candidates: list[str] = Field(default_factory=list)
     completion_tokens: int = -1
     total_tokens: int = -1
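Rebasing LLMInterface onto AIMessage means the object an LLM returns is itself an assistant message, so it can be appended straight onto a chat history instead of being converted first. A sketch under that assumption, with made-up field values (completion_tokens and total_tokens are fields shown in this hunk; content comes from the message side of the class):

from kotaemon.base.schema import HumanMessage, LLMInterface

history = [HumanMessage(content="Hello!")]

# The LLM output doubles as the AI turn in the conversation.
reply = LLMInterface(content="Hi, how can I help?", completion_tokens=7, total_tokens=15)
history.append(reply)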
@@ -1,9 +1,10 @@
 from abc import abstractmethod
 from typing import List, Optional
 
-from langchain.schema.messages import AIMessage, SystemMessage
 from theflow import SessionFunction
 
+from kotaemon.base.schema import AIMessage, SystemMessage
+
 from ..base import BaseComponent
 from ..base.schema import LLMInterface
 from ..llms.chats.base import BaseMessage, HumanMessage
@@ -1,7 +1,7 @@
-from langchain.schema.messages import AIMessage, SystemMessage
+from kotaemon.base.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage
 
 from .branching import GatedBranchingPipeline, SimpleBranchingPipeline
-from .chats import AzureChatOpenAI, BaseMessage, ChatLLM, HumanMessage
+from .chats import AzureChatOpenAI, ChatLLM
 from .completions import LLM, AzureOpenAI, OpenAI
 from .linear import GatedLinearPipeline, SimpleLinearPipeline
 from .prompts import BasePromptComponent, PromptTemplate
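If this hunk is the kotaemon.llms package __init__ (the page does not show file names), the message classes are now re-exported from kotaemon.base.schema instead of from the chats subpackage, so both import paths should resolve to the same classes. A hedged sanity check under that assumption:

from kotaemon.base.schema import HumanMessage as SchemaHumanMessage
from kotaemon.llms import HumanMessage as LLMsHumanMessage

# Both names should point at the very same class object after this change.
assert SchemaHumanMessage is LLMsHumanMessage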
@@ -4,9 +4,10 @@ import logging
 from typing import Type
 
 from langchain.chat_models.base import BaseChatModel
-from langchain.schema.messages import BaseMessage, HumanMessage
 from theflow.base import Param
 
+from kotaemon.base.schema import BaseMessage, HumanMessage
+
 from ...base import BaseComponent
 from ...base.schema import LLMInterface
 
@@ -1,9 +1,9 @@
 from typing import Iterator, List, Union
 
-from langchain.schema.messages import HumanMessage, SystemMessage
 from pydantic import BaseModel, Field
 
 from kotaemon.base import BaseComponent
+from kotaemon.base.schema import HumanMessage, SystemMessage
 
 from ..llms.chats.base import ChatLLM
 from ..llms.completions.base import LLM
@@ -1,10 +1,14 @@
 from unittest.mock import patch
 
 from langchain.chat_models import AzureChatOpenAI as AzureChatOpenAILC
-from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage
 from openai.types.chat.chat_completion import ChatCompletion
 
-from kotaemon.base.schema import LLMInterface
+from kotaemon.base.schema import (
+    AIMessage,
+    HumanMessage,
+    LLMInterface,
+    SystemMessage,
+)
 from kotaemon.llms.chats.openai import AzureChatOpenAI
 
 _openai_chat_completion_response = ChatCompletion.parse_obj(