[AUR-395, AUR-415] Adopt Example1 Injury pipeline; add .flow() for enabling bottom-up pipeline execution (#32)

* add example1/injury pipeline example
* add dotenv
* update various api
This commit is contained in:
ian_Cin
2023-10-02 16:24:56 +07:00
committed by GitHub
parent 3cceec63ef
commit d83c22aa4e
16 changed files with 114 additions and 69 deletions

View File

@@ -1,3 +1,4 @@
import warnings
from typing import Union
from kotaemon.base import BaseComponent
@@ -5,7 +6,7 @@ from kotaemon.documents.base import Document
from kotaemon.prompt.template import PromptTemplate
class BasePrompt(BaseComponent):
class BasePromptComponent(BaseComponent):
"""
Base class for prompt components.
@@ -15,6 +16,16 @@ class BasePrompt(BaseComponent):
given template.
"""
def __init__(self, template: Union[str, PromptTemplate], **kwargs):
super().__init__()
self.template = (
template
if isinstance(template, PromptTemplate)
else PromptTemplate(template)
)
self.__set(**kwargs)
def __check_redundant_kwargs(self, **kwargs):
"""
Check for redundant keyword arguments.
@@ -33,7 +44,9 @@ class BasePrompt(BaseComponent):
redundant_keys = provided_keys - expected_keys
if redundant_keys:
raise ValueError(f"\nKeys provided but not in template: {redundant_keys}")
warnings.warn(
f"Keys provided but not in template: {redundant_keys}", UserWarning
)
def __check_unset_placeholders(self):
"""
@@ -111,27 +124,34 @@ class BasePrompt(BaseComponent):
Returns:
dict: A dictionary of keyword arguments.
"""
def __prepare(key, value):
if isinstance(value, str):
return value
if isinstance(value, (int, Document)):
return str(value)
raise ValueError(
f"Unsupported type {type(value)} for template value of key {key}"
)
kwargs = {}
for k in self.template.placeholders:
v = getattr(self, k)
if isinstance(v, (int, Document)):
v = str(v)
elif isinstance(v, BaseComponent):
v = str(v())
if isinstance(v, BaseComponent):
v = v()
if isinstance(v, list):
v = str([__prepare(k, each) for each in v])
elif isinstance(v, (str, int, Document)):
v = __prepare(k, v)
else:
raise ValueError(
f"Unsupported type {type(v)} for template value of key {k}"
)
kwargs[k] = v
return kwargs
def __init__(self, template: Union[str, PromptTemplate], **kwargs):
super().__init__()
self.template = (
template
if isinstance(template, PromptTemplate)
else PromptTemplate(template)
)
self.__set(**kwargs)
def set(self, **kwargs):
"""
Similar to `__set` but for external use.
@@ -163,7 +183,8 @@ class BasePrompt(BaseComponent):
self.__check_unset_placeholders()
prepared_kwargs = self.__prepare_value()
return self.template.populate(**prepared_kwargs)
text = self.template.populate(**prepared_kwargs)
return Document(text=text, metadata={"origin": "PromptComponent"})
def run_raw(self, *args, **kwargs):
pass
@@ -182,3 +203,6 @@ class BasePrompt(BaseComponent):
def is_batch(self, *args, **kwargs):
pass
def flow(self):
return self.__call__()