[AUR-361] Setup pre-commit, pytest, GitHub actions, ssh-secret (#3)
Co-authored-by: trducng <trungduc1992@gmail.com>
This commit is contained in:
@@ -1,15 +1,16 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List
|
||||
|
||||
from ..components import BaseComponent
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMInterface:
|
||||
text: list[str]
|
||||
text: List[str]
|
||||
completion_tokens: int = -1
|
||||
total_tokens: int = -1
|
||||
prompt_tokens: int = -1
|
||||
logits: list[list[float]] = field(default_factory=list)
|
||||
logits: List[List[float]] = field(default_factory=list)
|
||||
|
||||
|
||||
class PromptTemplate(BaseComponent):
|
||||
|
@@ -1,17 +1,12 @@
|
||||
from typing import Type, TypeVar
|
||||
from typing import List, Type, TypeVar
|
||||
|
||||
from theflow.base import Param
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
|
||||
from langchain.schema.messages import (
|
||||
BaseMessage,
|
||||
HumanMessage,
|
||||
)
|
||||
from langchain.schema.messages import BaseMessage, HumanMessage
|
||||
from theflow.base import Param
|
||||
|
||||
from ...components import BaseComponent
|
||||
from ..base import LLMInterface
|
||||
|
||||
|
||||
Message = TypeVar("Message", bound=BaseMessage)
|
||||
|
||||
|
||||
@@ -43,11 +38,11 @@ class LangchainChatLLM(ChatLLM):
|
||||
message = HumanMessage(content=text)
|
||||
return self.run_document([message])
|
||||
|
||||
def run_batch_raw(self, text: list[str]) -> list[LLMInterface]:
|
||||
def run_batch_raw(self, text: List[str]) -> List[LLMInterface]:
|
||||
inputs = [[HumanMessage(content=each)] for each in text]
|
||||
return self.run_batch_document(inputs)
|
||||
|
||||
def run_document(self, text: list[Message]) -> LLMInterface:
|
||||
def run_document(self, text: List[Message]) -> LLMInterface:
|
||||
pred = self.agent.generate([text])
|
||||
return LLMInterface(
|
||||
text=[each.text for each in pred.generations[0]],
|
||||
@@ -57,7 +52,7 @@ class LangchainChatLLM(ChatLLM):
|
||||
logits=[],
|
||||
)
|
||||
|
||||
def run_batch_document(self, text: list[list[Message]]) -> list[LLMInterface]:
|
||||
def run_batch_document(self, text: List[List[Message]]) -> List[LLMInterface]:
|
||||
outputs = []
|
||||
for each_text in text:
|
||||
outputs.append(self.run_document(each_text))
|
||||
@@ -66,14 +61,14 @@ class LangchainChatLLM(ChatLLM):
|
||||
def is_document(self, text) -> bool:
|
||||
if isinstance(text, str):
|
||||
return False
|
||||
elif isinstance(text, list) and isinstance(text[0], str):
|
||||
elif isinstance(text, List) and isinstance(text[0], str):
|
||||
return False
|
||||
return True
|
||||
|
||||
def is_batch(self, text) -> bool:
|
||||
if isinstance(text, str):
|
||||
return False
|
||||
elif isinstance(text, list):
|
||||
elif isinstance(text, List):
|
||||
if isinstance(text[0], BaseMessage):
|
||||
return False
|
||||
return True
|
||||
|
@@ -1,7 +1,7 @@
|
||||
from typing import Type
|
||||
from typing import List, Type
|
||||
|
||||
from theflow.base import Param
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from theflow.base import Param
|
||||
|
||||
from ...components import BaseComponent
|
||||
from ..base import LLMInterface
|
||||
@@ -41,7 +41,7 @@ class LangchainLLM(LLM):
|
||||
logits=[],
|
||||
)
|
||||
|
||||
def run_batch_raw(self, text: list[str]) -> list[LLMInterface]:
|
||||
def run_batch_raw(self, text: List[str]) -> List[LLMInterface]:
|
||||
outputs = []
|
||||
for each_text in text:
|
||||
outputs.append(self.run_raw(each_text))
|
||||
@@ -50,7 +50,7 @@ class LangchainLLM(LLM):
|
||||
def run_document(self, text: str) -> LLMInterface:
|
||||
return self.run_raw(text)
|
||||
|
||||
def run_batch_document(self, text: list[str]) -> list[LLMInterface]:
|
||||
def run_batch_document(self, text: List[str]) -> List[LLMInterface]:
|
||||
return self.run_batch_raw(text)
|
||||
|
||||
def is_document(self, text) -> bool:
|
||||
|
@@ -5,9 +5,11 @@ from .base import LangchainLLM
|
||||
|
||||
class OpenAI(LangchainLLM):
|
||||
"""Wrapper around Langchain's OpenAI class"""
|
||||
|
||||
_lc_class = langchain_llms.OpenAI
|
||||
|
||||
|
||||
class AzureOpenAI(LangchainLLM):
|
||||
"""Wrapper around Langchain's AzureOpenAI class"""
|
||||
|
||||
_lc_class = langchain_llms.AzureOpenAI
|
||||
|
@@ -1,13 +1,10 @@
|
||||
class DocumentLoader:
|
||||
"""Document loader"""
|
||||
pass
|
||||
|
||||
|
||||
class TextManipulator:
|
||||
"""Text manipulation"""
|
||||
pass
|
||||
|
||||
|
||||
class DocumentManipulator:
|
||||
"""Document manipulation"""
|
||||
pass
|
||||
|
Reference in New Issue
Block a user