[AUR-395, AUR-415] Adopt Example1 Injury pipeline; add .flow() for enabling bottom-up pipeline execution (#32)

* add example1/injury pipeline example
* add dotenv
* update various api
This commit is contained in:
ian_Cin 2023-10-02 16:24:56 +07:00 committed by GitHub
parent 3cceec63ef
commit d83c22aa4e
16 changed files with 114 additions and 69 deletions

BIN
.env.secret Normal file

Binary file not shown.

2
.gitignore vendored
View File

@ -453,7 +453,7 @@ $RECYCLE.BIN/
logs/
.gitsecret/keys/random_seed
!*.secret
credentials.txt
.env
S.gpg-agent*
.vscode/settings.json

View File

@ -1 +1 @@
credentials.txt:272c4eb7f422bebcc5d0f1da8bde47016b185ba8cb6ca06639bb2a3e88ea9bc5
.env:272c4eb7f422bebcc5d0f1da8bde47016b185ba8cb6ca06639bb2a3e88ea9bc5

View File

@ -47,7 +47,12 @@ pip install kotaemon@git+ssh://git@github.com/Cinnamon/kotaemon.git
### Credential sharing
This repo uses [git-secret](https://sobolevn.me/git-secret/) to share credentials, which internally uses `gpg` to encrypt and decrypt secret files.
This repo uses [git-secret](https://sobolevn.me/git-secret/) to share credentials, which
internally uses `gpg` to encrypt and decrypt secret files.
This repo uses `python-dotenv` to manage credentials stored as environment variables.
Please note that `python-dotenv` and these credentials are for development
purposes only. Thus, they should not be used in the main source code (i.e. `kotaemon/` and `tests/`), but can be used in `examples/`.
#### Install git-secret
@ -63,13 +68,13 @@ In order to gain access to the secret files, you must provide your gpg public fi
#### Decrypt the secret file
The credentials are encrypted in the `credentials.txt.secret` file. To print the decrypted content to stdout, run
The credentials are encrypted in the `.env.secret` file. To print the decrypted content to stdout, run
```shell
git-secret cat [filename]
```
Or to get the decrypted `credentials.txt` file, run
Or to get the decrypted `.env` file, run
```shell
git-secret reveal [filename]

Binary file not shown.

View File

@ -16,6 +16,19 @@ class BaseComponent(Compose):
- is_batch: check if input is batch
"""
inflow = None
def flow(self):
    """Run this component on the output of its upstream component.

    The upstream component is read from ``self.inflow``. Its own ``flow()``
    is invoked first, so a chain of components is resolved from the last
    one back to the first, and the result is passed to this component's
    ``__call__``.

    Returns:
        Whatever ``self.__call__`` returns for the upstream output.

    Raises:
        ValueError: if ``inflow`` is unset, or is not a ``BaseComponent``.
    """
    upstream = self.inflow
    if upstream is None:
        raise ValueError("No inflow provided.")

    if isinstance(upstream, BaseComponent):
        return self.__call__(upstream.flow())

    raise ValueError(f"inflow must be a BaseComponent, found {type(upstream)}")
@abstractmethod
def run_raw(self, *args, **kwargs):
...

View File

@ -1,25 +1,13 @@
from dataclasses import dataclass, field
from typing import List
from ..base import BaseComponent
from pydantic import Field
from kotaemon.documents.base import Document
@dataclass
class LLMInterface:
text: List[str]
class LLMInterface(Document):
candidates: List[str]
completion_tokens: int = -1
total_tokens: int = -1
prompt_tokens: int = -1
logits: List[List[float]] = field(default_factory=list)
class PromptTemplate(BaseComponent):
pass
class Extract(BaseComponent):
pass
class PromptNode(BaseComponent):
pass
logits: List[List[float]] = Field(default_factory=list)

View File

@ -11,7 +11,17 @@ Message = TypeVar("Message", bound=BaseMessage)
class ChatLLM(BaseComponent):
...
def flow(self):
    """Run this chat model on the text produced by its upstream component.

    Unlike a plain pass-through of the upstream result, only the ``text``
    attribute of the upstream output is forwarded to ``__call__``.

    Returns:
        Whatever ``self.__call__`` returns for the upstream text.

    Raises:
        ValueError: if ``inflow`` is unset, or is not a ``BaseComponent``.
    """
    upstream = self.inflow
    if upstream is None:
        raise ValueError("No inflow provided.")
    if not isinstance(upstream, BaseComponent):
        raise ValueError(
            f"inflow must be a BaseComponent, found {type(upstream)}"
        )

    upstream_output = upstream.flow()
    return self.__call__(upstream_output.text)
class LangchainChatLLM(ChatLLM):
@ -44,8 +54,10 @@ class LangchainChatLLM(ChatLLM):
def run_document(self, text: List[Message], **kwargs) -> LLMInterface:
pred = self.agent.generate([text], **kwargs) # type: ignore
all_text = [each.text for each in pred.generations[0]]
return LLMInterface(
text=[each.text for each in pred.generations[0]],
text=all_text[0] if len(all_text) > 0 else "",
candidates=all_text,
completion_tokens=pred.llm_output["token_usage"]["completion_tokens"],
total_tokens=pred.llm_output["token_usage"]["total_tokens"],
prompt_tokens=pred.llm_output["token_usage"]["prompt_tokens"],

View File

@ -33,8 +33,10 @@ class LangchainLLM(LLM):
def run_raw(self, text: str) -> LLMInterface:
pred = self.agent.generate([text])
all_text = [each.text for each in pred.generations[0]]
return LLMInterface(
text=[each.text for each in pred.generations[0]],
text=all_text[0] if len(all_text) > 0 else "",
candidates=all_text,
completion_tokens=pred.llm_output["token_usage"]["completion_tokens"],
total_tokens=pred.llm_output["token_usage"]["total_tokens"],
prompt_tokens=pred.llm_output["token_usage"]["prompt_tokens"],

View File

@ -162,7 +162,7 @@ class ReactAgent(BaseAgent):
prompt = self._compose_prompt(instruction)
logging.info(f"Prompt: {prompt}")
response = self.llm(prompt, stop=["Observation:"]) # type: ignore
response_text = response.text[0]
response_text = response.text
logging.info(f"Response: {response_text}")
action_step = self._parse_output(response_text)
if action_step is None:

View File

@ -245,7 +245,7 @@ class RewooAgent(BaseAgent):
# Plan
planner_output = planner(instruction)
plannner_text_output = planner_output.text[0]
plannner_text_output = planner_output.text
plan_to_es, plans = self._parse_plan_map(plannner_text_output)
planner_evidences, evidence_level = self._parse_planner_evidences(
plannner_text_output
@ -263,7 +263,7 @@ class RewooAgent(BaseAgent):
# Solve
solver_output = solver(instruction, worker_log)
solver_output_text = solver_output.text[0]
solver_output_text = solver_output.text
return AgentOutput(
output=solver_output_text, cost=total_cost, token_usage=total_token

View File

@ -50,9 +50,9 @@ class RegexExtractor(BaseComponent):
if not output_map:
return text
return output_map.get(text, text)
return str(output_map.get(text, text))
def run_raw(self, text: str) -> List[str]:
def run_raw(self, text: str) -> List[Document]:
"""
Runs the raw text through the static pattern and output mapping, returning a
list of Documents.
@ -66,9 +66,12 @@ class RegexExtractor(BaseComponent):
output = self.run_raw_static(self.pattern, text)
output = [self.map_output(text, self.output_map) for text in output]
return output
return [
Document(text=text, metadata={"origin": "RegexExtractor"})
for text in output
]
def run_batch_raw(self, text_batch: List[str]) -> List[List[str]]:
def run_batch_raw(self, text_batch: List[str]) -> List[List[Document]]:
"""
Runs a batch of raw text inputs through the `run_raw()` method and returns the
output for each input.
@ -95,13 +98,7 @@ class RegexExtractor(BaseComponent):
Returns:
List[Document]: A list of extracted documents.
"""
texts = self.run_raw(document.text)
output = [
Document(text=text, metadata={**document.metadata, "RegexExtractor": True})
for text in texts
]
return output
return self.run_raw(document.text)
def run_batch_document(
self, document_batch: List[Document]

View File

@ -1,3 +1,4 @@
import warnings
from typing import Union
from kotaemon.base import BaseComponent
@ -5,7 +6,7 @@ from kotaemon.documents.base import Document
from kotaemon.prompt.template import PromptTemplate
class BasePrompt(BaseComponent):
class BasePromptComponent(BaseComponent):
"""
Base class for prompt components.
@ -15,6 +16,16 @@ class BasePrompt(BaseComponent):
given template.
"""
def __init__(self, template: Union[str, PromptTemplate], **kwargs):
    """Create a prompt component from a template and initial values.

    Args:
        template: either a ready ``PromptTemplate`` or a plain string,
            which is wrapped in a ``PromptTemplate``.
        **kwargs: forwarded unchanged to the private ``__set`` helper
            (presumably the initial placeholder values; confirm against
            ``__set``).
    """
    super().__init__()

    # Normalize so that ``self.template`` is always a PromptTemplate.
    if not isinstance(template, PromptTemplate):
        template = PromptTemplate(template)
    self.template = template

    self.__set(**kwargs)
def __check_redundant_kwargs(self, **kwargs):
"""
Check for redundant keyword arguments.
@ -33,7 +44,9 @@ class BasePrompt(BaseComponent):
redundant_keys = provided_keys - expected_keys
if redundant_keys:
raise ValueError(f"\nKeys provided but not in template: {redundant_keys}")
warnings.warn(
f"Keys provided but not in template: {redundant_keys}", UserWarning
)
def __check_unset_placeholders(self):
"""
@ -111,27 +124,34 @@ class BasePrompt(BaseComponent):
Returns:
dict: A dictionary of keyword arguments.
"""
def __prepare(key, value):
if isinstance(value, str):
return value
if isinstance(value, (int, Document)):
return str(value)
raise ValueError(
f"Unsupported type {type(value)} for template value of key {key}"
)
kwargs = {}
for k in self.template.placeholders:
v = getattr(self, k)
if isinstance(v, (int, Document)):
v = str(v)
elif isinstance(v, BaseComponent):
v = str(v())
if isinstance(v, BaseComponent):
v = v()
if isinstance(v, list):
v = str([__prepare(k, each) for each in v])
elif isinstance(v, (str, int, Document)):
v = __prepare(k, v)
else:
raise ValueError(
f"Unsupported type {type(v)} for template value of key {k}"
)
kwargs[k] = v
return kwargs
def __init__(self, template: Union[str, PromptTemplate], **kwargs):
super().__init__()
self.template = (
template
if isinstance(template, PromptTemplate)
else PromptTemplate(template)
)
self.__set(**kwargs)
def set(self, **kwargs):
"""
Similar to `__set` but for external use.
@ -163,7 +183,8 @@ class BasePrompt(BaseComponent):
self.__check_unset_placeholders()
prepared_kwargs = self.__prepare_value()
return self.template.populate(**prepared_kwargs)
text = self.template.populate(**prepared_kwargs)
return Document(text=text, metadata={"origin": "PromptComponent"})
def run_raw(self, *args, **kwargs):
pass
@ -182,3 +203,6 @@ class BasePrompt(BaseComponent):
def is_batch(self, *args, **kwargs):
pass
# Pipeline entry point for a prompt component: it does not consult
# ``self.inflow`` and simply invokes itself with no arguments, returning
# whatever ``__call__`` returns.
def flow(self):
return self.__call__()

View File

@ -56,6 +56,8 @@ setuptools.setup(
"chromadb",
"wikipedia",
"googlesearch-python",
"python-dotenv",
"pytest-mock",
],
},
entry_points={"console_scripts": ["kh=kotaemon.cli:main"]},

View File

@ -30,9 +30,11 @@ def test_is_batch(regex_extractor):
def test_run_raw(regex_extractor):
    """A single raw string yields documents whose text is the extracted match."""
    documents = regex_extractor("This is a test. 123")
    extracted = [doc.text for doc in documents]
    assert extracted == ["123"]
def test_run_batch_raw(regex_extractor):
    """A list of raw strings yields one list of documents per input string."""
    batches = regex_extractor(["This is a test. 123", "456"])
    extracted = [[doc.text for doc in documents] for documents in batches]
    assert extracted == [["123"], ["456"]]

View File

@ -2,7 +2,7 @@ import pytest
from kotaemon.documents.base import Document
from kotaemon.post_processing.extractor import RegexExtractor
from kotaemon.prompt.base import BasePrompt
from kotaemon.prompt.base import BasePromptComponent
from kotaemon.prompt.template import PromptTemplate
@ -14,7 +14,7 @@ def test_set_attributes():
)
comp.set_run(kwargs={"text": "This is a test. 1 2 3"}, temp=True)
prompt = BasePrompt(template=template, s="Alice", i=30, doc=doc, comp=comp)
prompt = BasePromptComponent(template=template, s="Alice", i=30, doc=doc, comp=comp)
assert prompt.s == "Alice"
assert prompt.i == 30
assert prompt.doc == doc
@ -23,23 +23,23 @@ def test_set_attributes():
def test_check_redundant_kwargs():
template = PromptTemplate("Hello, {name}!")
prompt = BasePrompt(template, name="Alice")
with pytest.raises(ValueError):
prompt._BasePrompt__check_redundant_kwargs(name="Alice", age=30)
prompt = BasePromptComponent(template, name="Alice")
with pytest.warns(UserWarning, match="Keys provided but not in template: {'age'}"):
prompt._BasePromptComponent__check_redundant_kwargs(name="Alice", age=30)
def test_check_unset_placeholders():
template = PromptTemplate("Hello, {name}! I'm {age} years old.")
prompt = BasePrompt(template, name="Alice")
prompt = BasePromptComponent(template, name="Alice")
with pytest.raises(ValueError):
prompt._BasePrompt__check_unset_placeholders()
prompt._BasePromptComponent__check_unset_placeholders()
def test_validate_value_type():
template = PromptTemplate("Hello, {name}!")
prompt = BasePrompt(template)
prompt = BasePromptComponent(template)
with pytest.raises(ValueError):
prompt._BasePrompt__validate_value_type(name={})
prompt._BasePromptComponent__validate_value_type(name={})
def test_run():
@ -50,18 +50,18 @@ def test_run():
)
comp.set_run(kwargs={"text": "This is a test. 1 2 3"}, temp=True)
prompt = BasePrompt(template=template, s="Alice", i=30, doc=doc, comp=comp)
prompt = BasePromptComponent(template=template, s="Alice", i=30, doc=doc, comp=comp)
result = prompt()
assert (
result
result.text
== "str = Alice, int = 30, doc = Helloo, Alice!, comp = ['One', 'Two', 'Three']"
)
def test_set_method():
template = PromptTemplate("Hello, {name}!")
prompt = BasePrompt(template)
prompt = BasePromptComponent(template)
prompt.set(name="Alice")
assert prompt.name == "Alice"