[AUR-395, AUR-415] Adopt Example1 Injury pipeline; add .flow() for enabling bottom-up pipeline execution (#32)
* add example1/injury pipeline example * add dotenv * update various api
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import warnings
|
||||
from typing import Union
|
||||
|
||||
from kotaemon.base import BaseComponent
|
||||
@@ -5,7 +6,7 @@ from kotaemon.documents.base import Document
|
||||
from kotaemon.prompt.template import PromptTemplate
|
||||
|
||||
|
||||
class BasePrompt(BaseComponent):
|
||||
class BasePromptComponent(BaseComponent):
|
||||
"""
|
||||
Base class for prompt components.
|
||||
|
||||
@@ -15,6 +16,16 @@ class BasePrompt(BaseComponent):
|
||||
given template.
|
||||
"""
|
||||
|
||||
def __init__(self, template: Union[str, PromptTemplate], **kwargs):
|
||||
super().__init__()
|
||||
self.template = (
|
||||
template
|
||||
if isinstance(template, PromptTemplate)
|
||||
else PromptTemplate(template)
|
||||
)
|
||||
|
||||
self.__set(**kwargs)
|
||||
|
||||
def __check_redundant_kwargs(self, **kwargs):
|
||||
"""
|
||||
Check for redundant keyword arguments.
|
||||
@@ -33,7 +44,9 @@ class BasePrompt(BaseComponent):
|
||||
|
||||
redundant_keys = provided_keys - expected_keys
|
||||
if redundant_keys:
|
||||
raise ValueError(f"\nKeys provided but not in template: {redundant_keys}")
|
||||
warnings.warn(
|
||||
f"Keys provided but not in template: {redundant_keys}", UserWarning
|
||||
)
|
||||
|
||||
def __check_unset_placeholders(self):
|
||||
"""
|
||||
@@ -111,27 +124,34 @@ class BasePrompt(BaseComponent):
|
||||
Returns:
|
||||
dict: A dictionary of keyword arguments.
|
||||
"""
|
||||
|
||||
def __prepare(key, value):
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
if isinstance(value, (int, Document)):
|
||||
return str(value)
|
||||
|
||||
raise ValueError(
|
||||
f"Unsupported type {type(value)} for template value of key {key}"
|
||||
)
|
||||
|
||||
kwargs = {}
|
||||
for k in self.template.placeholders:
|
||||
v = getattr(self, k)
|
||||
if isinstance(v, (int, Document)):
|
||||
v = str(v)
|
||||
elif isinstance(v, BaseComponent):
|
||||
v = str(v())
|
||||
if isinstance(v, BaseComponent):
|
||||
v = v()
|
||||
if isinstance(v, list):
|
||||
v = str([__prepare(k, each) for each in v])
|
||||
elif isinstance(v, (str, int, Document)):
|
||||
v = __prepare(k, v)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported type {type(v)} for template value of key {k}"
|
||||
)
|
||||
kwargs[k] = v
|
||||
|
||||
return kwargs
|
||||
|
||||
def __init__(self, template: Union[str, PromptTemplate], **kwargs):
|
||||
super().__init__()
|
||||
self.template = (
|
||||
template
|
||||
if isinstance(template, PromptTemplate)
|
||||
else PromptTemplate(template)
|
||||
)
|
||||
|
||||
self.__set(**kwargs)
|
||||
|
||||
def set(self, **kwargs):
|
||||
"""
|
||||
Similar to `__set` but for external use.
|
||||
@@ -163,7 +183,8 @@ class BasePrompt(BaseComponent):
|
||||
self.__check_unset_placeholders()
|
||||
prepared_kwargs = self.__prepare_value()
|
||||
|
||||
return self.template.populate(**prepared_kwargs)
|
||||
text = self.template.populate(**prepared_kwargs)
|
||||
return Document(text=text, metadata={"origin": "PromptComponent"})
|
||||
|
||||
def run_raw(self, *args, **kwargs):
|
||||
pass
|
||||
@@ -182,3 +203,6 @@ class BasePrompt(BaseComponent):
|
||||
|
||||
def is_batch(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def flow(self):
|
||||
return self.__call__()
|
||||
|
Reference in New Issue
Block a user