[AUR-361] Setup pre-commit, pytest, GitHub actions, ssh-secret (#3)

Co-authored-by: trducng <trungduc1992@gmail.com>
This commit is contained in:
ian_Cin
2023-08-30 07:22:01 +07:00
committed by GitHub
parent c3c25db48c
commit 5241edbc46
19 changed files with 268 additions and 54 deletions

View File

@@ -1,15 +1,16 @@
from dataclasses import dataclass, field
from typing import List
from ..components import BaseComponent
@dataclass
class LLMInterface:
text: list[str]
text: List[str]
completion_tokens: int = -1
total_tokens: int = -1
prompt_tokens: int = -1
logits: list[list[float]] = field(default_factory=list)
logits: List[List[float]] = field(default_factory=list)
class PromptTemplate(BaseComponent):

View File

@@ -1,17 +1,12 @@
from typing import Type, TypeVar
from typing import List, Type, TypeVar
from theflow.base import Param
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.messages import (
BaseMessage,
HumanMessage,
)
from langchain.schema.messages import BaseMessage, HumanMessage
from theflow.base import Param
from ...components import BaseComponent
from ..base import LLMInterface
Message = TypeVar("Message", bound=BaseMessage)
@@ -43,11 +38,11 @@ class LangchainChatLLM(ChatLLM):
message = HumanMessage(content=text)
return self.run_document([message])
def run_batch_raw(self, text: list[str]) -> list[LLMInterface]:
def run_batch_raw(self, text: List[str]) -> List[LLMInterface]:
inputs = [[HumanMessage(content=each)] for each in text]
return self.run_batch_document(inputs)
def run_document(self, text: list[Message]) -> LLMInterface:
def run_document(self, text: List[Message]) -> LLMInterface:
pred = self.agent.generate([text])
return LLMInterface(
text=[each.text for each in pred.generations[0]],
@@ -57,7 +52,7 @@ class LangchainChatLLM(ChatLLM):
logits=[],
)
def run_batch_document(self, text: list[list[Message]]) -> list[LLMInterface]:
def run_batch_document(self, text: List[List[Message]]) -> List[LLMInterface]:
outputs = []
for each_text in text:
outputs.append(self.run_document(each_text))
@@ -66,14 +61,14 @@ class LangchainChatLLM(ChatLLM):
def is_document(self, text) -> bool:
if isinstance(text, str):
return False
elif isinstance(text, list) and isinstance(text[0], str):
elif isinstance(text, List) and isinstance(text[0], str):
return False
return True
def is_batch(self, text) -> bool:
if isinstance(text, str):
return False
elif isinstance(text, list):
elif isinstance(text, List):
if isinstance(text[0], BaseMessage):
return False
return True

View File

@@ -1,7 +1,7 @@
from typing import Type
from typing import List, Type
from theflow.base import Param
from langchain.schema.language_model import BaseLanguageModel
from theflow.base import Param
from ...components import BaseComponent
from ..base import LLMInterface
@@ -41,7 +41,7 @@ class LangchainLLM(LLM):
logits=[],
)
def run_batch_raw(self, text: list[str]) -> list[LLMInterface]:
def run_batch_raw(self, text: List[str]) -> List[LLMInterface]:
outputs = []
for each_text in text:
outputs.append(self.run_raw(each_text))
@@ -50,7 +50,7 @@ class LangchainLLM(LLM):
def run_document(self, text: str) -> LLMInterface:
return self.run_raw(text)
def run_batch_document(self, text: list[str]) -> list[LLMInterface]:
def run_batch_document(self, text: List[str]) -> List[LLMInterface]:
return self.run_batch_raw(text)
def is_document(self, text) -> bool:

View File

@@ -5,9 +5,11 @@ from .base import LangchainLLM
class OpenAI(LangchainLLM):
"""Wrapper around Langchain's OpenAI class"""
_lc_class = langchain_llms.OpenAI
class AzureOpenAI(LangchainLLM):
"""Wrapper around Langchain's AzureOpenAI class"""
_lc_class = langchain_llms.AzureOpenAI

View File

@@ -1,13 +1,10 @@
class DocumentLoader:
"""Document loader"""
pass
class TextManipulator:
"""Text manipulation"""
pass
class DocumentManipulator:
"""Document manipulation"""
pass