[AUR-395, AUR-415] Adopt Example1 Injury pipeline; add .flow() for enabling bottom-up pipeline execution (#32)
* add example1/injury pipeline example * add dotenv * update various api
This commit is contained in:
parent
3cceec63ef
commit
d83c22aa4e
BIN
.env.secret
Normal file
BIN
.env.secret
Normal file
Binary file not shown.
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -453,7 +453,7 @@ $RECYCLE.BIN/
|
|||
logs/
|
||||
.gitsecret/keys/random_seed
|
||||
!*.secret
|
||||
credentials.txt
|
||||
.env
|
||||
|
||||
S.gpg-agent*
|
||||
.vscode/settings.json
|
||||
|
|
|
@ -1 +1 @@
|
|||
credentials.txt:272c4eb7f422bebcc5d0f1da8bde47016b185ba8cb6ca06639bb2a3e88ea9bc5
|
||||
.env:272c4eb7f422bebcc5d0f1da8bde47016b185ba8cb6ca06639bb2a3e88ea9bc5
|
||||
|
|
11
README.md
11
README.md
|
@ -47,7 +47,12 @@ pip install kotaemon@git+ssh://git@github.com/Cinnamon/kotaemon.git
|
|||
|
||||
### Credential sharing
|
||||
|
||||
This repo uses [git-secret](https://sobolevn.me/git-secret/) to share credentials, which internally uses `gpg` to encrypt and decrypt secret files.
|
||||
This repo uses [git-secret](https://sobolevn.me/git-secret/) to share credentials, which
|
||||
internally uses `gpg` to encrypt and decrypt secret files.
|
||||
|
||||
This repo uses `python-dotenv` to manage credentials stored as enviroment variable.
|
||||
Please note that the use of `python-dotenv` and credentials are for development
|
||||
purposes only. Thus, it should not be used in the main source code (i.e. `kotaemon/` and `tests/`), but can be used in `examples/`.
|
||||
|
||||
#### Install git-secret
|
||||
|
||||
|
@ -63,13 +68,13 @@ In order to gain access to the secret files, you must provide your gpg public fi
|
|||
|
||||
#### Decrypt the secret file
|
||||
|
||||
The credentials are encrypted in the `credentials.txt.secret` file. To print the decrypted content to stdout, run
|
||||
The credentials are encrypted in the `.env.secret` file. To print the decrypted content to stdout, run
|
||||
|
||||
```shell
|
||||
git-secret cat [filename]
|
||||
```
|
||||
|
||||
Or to get the decrypted `credentials.txt` file, run
|
||||
Or to get the decrypted `.env` file, run
|
||||
|
||||
```shell
|
||||
git-secret reveal [filename]
|
||||
|
|
Binary file not shown.
|
@ -16,6 +16,19 @@ class BaseComponent(Compose):
|
|||
- is_batch: check if input is batch
|
||||
"""
|
||||
|
||||
inflow = None
|
||||
|
||||
def flow(self):
|
||||
if self.inflow is None:
|
||||
raise ValueError("No inflow provided.")
|
||||
|
||||
if not isinstance(self.inflow, BaseComponent):
|
||||
raise ValueError(
|
||||
f"inflow must be a BaseComponent, found {type(self.inflow)}"
|
||||
)
|
||||
|
||||
return self.__call__(self.inflow.flow())
|
||||
|
||||
@abstractmethod
|
||||
def run_raw(self, *args, **kwargs):
|
||||
...
|
||||
|
|
|
@ -1,25 +1,13 @@
|
|||
from dataclasses import dataclass, field
|
||||
from typing import List
|
||||
|
||||
from ..base import BaseComponent
|
||||
from pydantic import Field
|
||||
|
||||
from kotaemon.documents.base import Document
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMInterface:
|
||||
text: List[str]
|
||||
class LLMInterface(Document):
|
||||
candidates: List[str]
|
||||
completion_tokens: int = -1
|
||||
total_tokens: int = -1
|
||||
prompt_tokens: int = -1
|
||||
logits: List[List[float]] = field(default_factory=list)
|
||||
|
||||
|
||||
class PromptTemplate(BaseComponent):
|
||||
pass
|
||||
|
||||
|
||||
class Extract(BaseComponent):
|
||||
pass
|
||||
|
||||
|
||||
class PromptNode(BaseComponent):
|
||||
pass
|
||||
logits: List[List[float]] = Field(default_factory=list)
|
||||
|
|
|
@ -11,7 +11,17 @@ Message = TypeVar("Message", bound=BaseMessage)
|
|||
|
||||
|
||||
class ChatLLM(BaseComponent):
|
||||
...
|
||||
def flow(self):
|
||||
if self.inflow is None:
|
||||
raise ValueError("No inflow provided.")
|
||||
|
||||
if not isinstance(self.inflow, BaseComponent):
|
||||
raise ValueError(
|
||||
f"inflow must be a BaseComponent, found {type(self.inflow)}"
|
||||
)
|
||||
|
||||
text = self.inflow.flow().text
|
||||
return self.__call__(text)
|
||||
|
||||
|
||||
class LangchainChatLLM(ChatLLM):
|
||||
|
@ -44,8 +54,10 @@ class LangchainChatLLM(ChatLLM):
|
|||
|
||||
def run_document(self, text: List[Message], **kwargs) -> LLMInterface:
|
||||
pred = self.agent.generate([text], **kwargs) # type: ignore
|
||||
all_text = [each.text for each in pred.generations[0]]
|
||||
return LLMInterface(
|
||||
text=[each.text for each in pred.generations[0]],
|
||||
text=all_text[0] if len(all_text) > 0 else "",
|
||||
candidates=all_text,
|
||||
completion_tokens=pred.llm_output["token_usage"]["completion_tokens"],
|
||||
total_tokens=pred.llm_output["token_usage"]["total_tokens"],
|
||||
prompt_tokens=pred.llm_output["token_usage"]["prompt_tokens"],
|
||||
|
|
|
@ -33,8 +33,10 @@ class LangchainLLM(LLM):
|
|||
|
||||
def run_raw(self, text: str) -> LLMInterface:
|
||||
pred = self.agent.generate([text])
|
||||
all_text = [each.text for each in pred.generations[0]]
|
||||
return LLMInterface(
|
||||
text=[each.text for each in pred.generations[0]],
|
||||
text=all_text[0] if len(all_text) > 0 else "",
|
||||
candidates=all_text,
|
||||
completion_tokens=pred.llm_output["token_usage"]["completion_tokens"],
|
||||
total_tokens=pred.llm_output["token_usage"]["total_tokens"],
|
||||
prompt_tokens=pred.llm_output["token_usage"]["prompt_tokens"],
|
||||
|
|
|
@ -162,7 +162,7 @@ class ReactAgent(BaseAgent):
|
|||
prompt = self._compose_prompt(instruction)
|
||||
logging.info(f"Prompt: {prompt}")
|
||||
response = self.llm(prompt, stop=["Observation:"]) # type: ignore
|
||||
response_text = response.text[0]
|
||||
response_text = response.text
|
||||
logging.info(f"Response: {response_text}")
|
||||
action_step = self._parse_output(response_text)
|
||||
if action_step is None:
|
||||
|
|
|
@ -245,7 +245,7 @@ class RewooAgent(BaseAgent):
|
|||
|
||||
# Plan
|
||||
planner_output = planner(instruction)
|
||||
plannner_text_output = planner_output.text[0]
|
||||
plannner_text_output = planner_output.text
|
||||
plan_to_es, plans = self._parse_plan_map(plannner_text_output)
|
||||
planner_evidences, evidence_level = self._parse_planner_evidences(
|
||||
plannner_text_output
|
||||
|
@ -263,7 +263,7 @@ class RewooAgent(BaseAgent):
|
|||
|
||||
# Solve
|
||||
solver_output = solver(instruction, worker_log)
|
||||
solver_output_text = solver_output.text[0]
|
||||
solver_output_text = solver_output.text
|
||||
|
||||
return AgentOutput(
|
||||
output=solver_output_text, cost=total_cost, token_usage=total_token
|
||||
|
|
|
@ -50,9 +50,9 @@ class RegexExtractor(BaseComponent):
|
|||
if not output_map:
|
||||
return text
|
||||
|
||||
return output_map.get(text, text)
|
||||
return str(output_map.get(text, text))
|
||||
|
||||
def run_raw(self, text: str) -> List[str]:
|
||||
def run_raw(self, text: str) -> List[Document]:
|
||||
"""
|
||||
Runs the raw text through the static pattern and output mapping, returning a
|
||||
list of strings.
|
||||
|
@ -66,9 +66,12 @@ class RegexExtractor(BaseComponent):
|
|||
output = self.run_raw_static(self.pattern, text)
|
||||
output = [self.map_output(text, self.output_map) for text in output]
|
||||
|
||||
return output
|
||||
return [
|
||||
Document(text=text, metadata={"origin": "RegexExtractor"})
|
||||
for text in output
|
||||
]
|
||||
|
||||
def run_batch_raw(self, text_batch: List[str]) -> List[List[str]]:
|
||||
def run_batch_raw(self, text_batch: List[str]) -> List[List[Document]]:
|
||||
"""
|
||||
Runs a batch of raw text inputs through the `run_raw()` method and returns the
|
||||
output for each input.
|
||||
|
@ -95,13 +98,7 @@ class RegexExtractor(BaseComponent):
|
|||
Returns:
|
||||
List[Document]: A list of extracted documents.
|
||||
"""
|
||||
texts = self.run_raw(document.text)
|
||||
output = [
|
||||
Document(text=text, metadata={**document.metadata, "RegexExtractor": True})
|
||||
for text in texts
|
||||
]
|
||||
|
||||
return output
|
||||
return self.run_raw(document.text)
|
||||
|
||||
def run_batch_document(
|
||||
self, document_batch: List[Document]
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import warnings
|
||||
from typing import Union
|
||||
|
||||
from kotaemon.base import BaseComponent
|
||||
|
@ -5,7 +6,7 @@ from kotaemon.documents.base import Document
|
|||
from kotaemon.prompt.template import PromptTemplate
|
||||
|
||||
|
||||
class BasePrompt(BaseComponent):
|
||||
class BasePromptComponent(BaseComponent):
|
||||
"""
|
||||
Base class for prompt components.
|
||||
|
||||
|
@ -15,6 +16,16 @@ class BasePrompt(BaseComponent):
|
|||
given template.
|
||||
"""
|
||||
|
||||
def __init__(self, template: Union[str, PromptTemplate], **kwargs):
|
||||
super().__init__()
|
||||
self.template = (
|
||||
template
|
||||
if isinstance(template, PromptTemplate)
|
||||
else PromptTemplate(template)
|
||||
)
|
||||
|
||||
self.__set(**kwargs)
|
||||
|
||||
def __check_redundant_kwargs(self, **kwargs):
|
||||
"""
|
||||
Check for redundant keyword arguments.
|
||||
|
@ -33,7 +44,9 @@ class BasePrompt(BaseComponent):
|
|||
|
||||
redundant_keys = provided_keys - expected_keys
|
||||
if redundant_keys:
|
||||
raise ValueError(f"\nKeys provided but not in template: {redundant_keys}")
|
||||
warnings.warn(
|
||||
f"Keys provided but not in template: {redundant_keys}", UserWarning
|
||||
)
|
||||
|
||||
def __check_unset_placeholders(self):
|
||||
"""
|
||||
|
@ -111,27 +124,34 @@ class BasePrompt(BaseComponent):
|
|||
Returns:
|
||||
dict: A dictionary of keyword arguments.
|
||||
"""
|
||||
|
||||
def __prepare(key, value):
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
if isinstance(value, (int, Document)):
|
||||
return str(value)
|
||||
|
||||
raise ValueError(
|
||||
f"Unsupported type {type(value)} for template value of key {key}"
|
||||
)
|
||||
|
||||
kwargs = {}
|
||||
for k in self.template.placeholders:
|
||||
v = getattr(self, k)
|
||||
if isinstance(v, (int, Document)):
|
||||
v = str(v)
|
||||
elif isinstance(v, BaseComponent):
|
||||
v = str(v())
|
||||
if isinstance(v, BaseComponent):
|
||||
v = v()
|
||||
if isinstance(v, list):
|
||||
v = str([__prepare(k, each) for each in v])
|
||||
elif isinstance(v, (str, int, Document)):
|
||||
v = __prepare(k, v)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported type {type(v)} for template value of key {k}"
|
||||
)
|
||||
kwargs[k] = v
|
||||
|
||||
return kwargs
|
||||
|
||||
def __init__(self, template: Union[str, PromptTemplate], **kwargs):
|
||||
super().__init__()
|
||||
self.template = (
|
||||
template
|
||||
if isinstance(template, PromptTemplate)
|
||||
else PromptTemplate(template)
|
||||
)
|
||||
|
||||
self.__set(**kwargs)
|
||||
|
||||
def set(self, **kwargs):
|
||||
"""
|
||||
Similar to `__set` but for external use.
|
||||
|
@ -163,7 +183,8 @@ class BasePrompt(BaseComponent):
|
|||
self.__check_unset_placeholders()
|
||||
prepared_kwargs = self.__prepare_value()
|
||||
|
||||
return self.template.populate(**prepared_kwargs)
|
||||
text = self.template.populate(**prepared_kwargs)
|
||||
return Document(text=text, metadata={"origin": "PromptComponent"})
|
||||
|
||||
def run_raw(self, *args, **kwargs):
|
||||
pass
|
||||
|
@ -182,3 +203,6 @@ class BasePrompt(BaseComponent):
|
|||
|
||||
def is_batch(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def flow(self):
|
||||
return self.__call__()
|
||||
|
|
2
setup.py
2
setup.py
|
@ -56,6 +56,8 @@ setuptools.setup(
|
|||
"chromadb",
|
||||
"wikipedia",
|
||||
"googlesearch-python",
|
||||
"python-dotenv",
|
||||
"pytest-mock",
|
||||
],
|
||||
},
|
||||
entry_points={"console_scripts": ["kh=kotaemon.cli:main"]},
|
||||
|
|
|
@ -30,9 +30,11 @@ def test_is_batch(regex_extractor):
|
|||
|
||||
def test_run_raw(regex_extractor):
|
||||
output = regex_extractor("This is a test. 123")
|
||||
output = [each.text for each in output]
|
||||
assert output == ["123"]
|
||||
|
||||
|
||||
def test_run_batch_raw(regex_extractor):
|
||||
output = regex_extractor(["This is a test. 123", "456"])
|
||||
output = [[each.text for each in batch] for batch in output]
|
||||
assert output == [["123"], ["456"]]
|
||||
|
|
|
@ -2,7 +2,7 @@ import pytest
|
|||
|
||||
from kotaemon.documents.base import Document
|
||||
from kotaemon.post_processing.extractor import RegexExtractor
|
||||
from kotaemon.prompt.base import BasePrompt
|
||||
from kotaemon.prompt.base import BasePromptComponent
|
||||
from kotaemon.prompt.template import PromptTemplate
|
||||
|
||||
|
||||
|
@ -14,7 +14,7 @@ def test_set_attributes():
|
|||
)
|
||||
comp.set_run(kwargs={"text": "This is a test. 1 2 3"}, temp=True)
|
||||
|
||||
prompt = BasePrompt(template=template, s="Alice", i=30, doc=doc, comp=comp)
|
||||
prompt = BasePromptComponent(template=template, s="Alice", i=30, doc=doc, comp=comp)
|
||||
assert prompt.s == "Alice"
|
||||
assert prompt.i == 30
|
||||
assert prompt.doc == doc
|
||||
|
@ -23,23 +23,23 @@ def test_set_attributes():
|
|||
|
||||
def test_check_redundant_kwargs():
|
||||
template = PromptTemplate("Hello, {name}!")
|
||||
prompt = BasePrompt(template, name="Alice")
|
||||
with pytest.raises(ValueError):
|
||||
prompt._BasePrompt__check_redundant_kwargs(name="Alice", age=30)
|
||||
prompt = BasePromptComponent(template, name="Alice")
|
||||
with pytest.warns(UserWarning, match="Keys provided but not in template: {'age'}"):
|
||||
prompt._BasePromptComponent__check_redundant_kwargs(name="Alice", age=30)
|
||||
|
||||
|
||||
def test_check_unset_placeholders():
|
||||
template = PromptTemplate("Hello, {name}! I'm {age} years old.")
|
||||
prompt = BasePrompt(template, name="Alice")
|
||||
prompt = BasePromptComponent(template, name="Alice")
|
||||
with pytest.raises(ValueError):
|
||||
prompt._BasePrompt__check_unset_placeholders()
|
||||
prompt._BasePromptComponent__check_unset_placeholders()
|
||||
|
||||
|
||||
def test_validate_value_type():
|
||||
template = PromptTemplate("Hello, {name}!")
|
||||
prompt = BasePrompt(template)
|
||||
prompt = BasePromptComponent(template)
|
||||
with pytest.raises(ValueError):
|
||||
prompt._BasePrompt__validate_value_type(name={})
|
||||
prompt._BasePromptComponent__validate_value_type(name={})
|
||||
|
||||
|
||||
def test_run():
|
||||
|
@ -50,18 +50,18 @@ def test_run():
|
|||
)
|
||||
comp.set_run(kwargs={"text": "This is a test. 1 2 3"}, temp=True)
|
||||
|
||||
prompt = BasePrompt(template=template, s="Alice", i=30, doc=doc, comp=comp)
|
||||
prompt = BasePromptComponent(template=template, s="Alice", i=30, doc=doc, comp=comp)
|
||||
|
||||
result = prompt()
|
||||
|
||||
assert (
|
||||
result
|
||||
result.text
|
||||
== "str = Alice, int = 30, doc = Helloo, Alice!, comp = ['One', 'Two', 'Three']"
|
||||
)
|
||||
|
||||
|
||||
def test_set_method():
|
||||
template = PromptTemplate("Hello, {name}!")
|
||||
prompt = BasePrompt(template)
|
||||
prompt = BasePromptComponent(template)
|
||||
prompt.set(name="Alice")
|
||||
assert prompt.name == "Alice"
|
||||
|
|
Loading…
Reference in New Issue
Block a user