[AUR-395, AUR-415] Adopt Example1 Injury pipeline; add .flow() for enabling bottom-up pipeline execution (#32)

* add example1/injury pipeline example
* add dotenv
* update various api
This commit is contained in:
ian_Cin 2023-10-02 16:24:56 +07:00 committed by GitHub
parent 3cceec63ef
commit d83c22aa4e
16 changed files with 114 additions and 69 deletions

BIN
.env.secret Normal file

Binary file not shown.

2
.gitignore vendored
View File

@ -453,7 +453,7 @@ $RECYCLE.BIN/
logs/ logs/
.gitsecret/keys/random_seed .gitsecret/keys/random_seed
!*.secret !*.secret
credentials.txt .env
S.gpg-agent* S.gpg-agent*
.vscode/settings.json .vscode/settings.json

View File

@ -1 +1 @@
credentials.txt:272c4eb7f422bebcc5d0f1da8bde47016b185ba8cb6ca06639bb2a3e88ea9bc5 .env:272c4eb7f422bebcc5d0f1da8bde47016b185ba8cb6ca06639bb2a3e88ea9bc5

View File

@ -47,7 +47,12 @@ pip install kotaemon@git+ssh://git@github.com/Cinnamon/kotaemon.git
### Credential sharing ### Credential sharing
This repo uses [git-secret](https://sobolevn.me/git-secret/) to share credentials, which internally uses `gpg` to encrypt and decrypt secret files. This repo uses [git-secret](https://sobolevn.me/git-secret/) to share credentials, which
internally uses `gpg` to encrypt and decrypt secret files.
This repo uses `python-dotenv` to manage credentials stored as enviroment variable.
Please note that the use of `python-dotenv` and credentials are for development
purposes only. Thus, it should not be used in the main source code (i.e. `kotaemon/` and `tests/`), but can be used in `examples/`.
#### Install git-secret #### Install git-secret
@ -63,13 +68,13 @@ In order to gain access to the secret files, you must provide your gpg public fi
#### Decrypt the secret file #### Decrypt the secret file
The credentials are encrypted in the `credentials.txt.secret` file. To print the decrypted content to stdout, run The credentials are encrypted in the `.env.secret` file. To print the decrypted content to stdout, run
```shell ```shell
git-secret cat [filename] git-secret cat [filename]
``` ```
Or to get the decrypted `credentials.txt` file, run Or to get the decrypted `.env` file, run
```shell ```shell
git-secret reveal [filename] git-secret reveal [filename]

Binary file not shown.

View File

@ -16,6 +16,19 @@ class BaseComponent(Compose):
- is_batch: check if input is batch - is_batch: check if input is batch
""" """
inflow = None
def flow(self):
if self.inflow is None:
raise ValueError("No inflow provided.")
if not isinstance(self.inflow, BaseComponent):
raise ValueError(
f"inflow must be a BaseComponent, found {type(self.inflow)}"
)
return self.__call__(self.inflow.flow())
@abstractmethod @abstractmethod
def run_raw(self, *args, **kwargs): def run_raw(self, *args, **kwargs):
... ...

View File

@ -1,25 +1,13 @@
from dataclasses import dataclass, field
from typing import List from typing import List
from ..base import BaseComponent from pydantic import Field
from kotaemon.documents.base import Document
@dataclass class LLMInterface(Document):
class LLMInterface: candidates: List[str]
text: List[str]
completion_tokens: int = -1 completion_tokens: int = -1
total_tokens: int = -1 total_tokens: int = -1
prompt_tokens: int = -1 prompt_tokens: int = -1
logits: List[List[float]] = field(default_factory=list) logits: List[List[float]] = Field(default_factory=list)
class PromptTemplate(BaseComponent):
pass
class Extract(BaseComponent):
pass
class PromptNode(BaseComponent):
pass

View File

@ -11,7 +11,17 @@ Message = TypeVar("Message", bound=BaseMessage)
class ChatLLM(BaseComponent): class ChatLLM(BaseComponent):
... def flow(self):
if self.inflow is None:
raise ValueError("No inflow provided.")
if not isinstance(self.inflow, BaseComponent):
raise ValueError(
f"inflow must be a BaseComponent, found {type(self.inflow)}"
)
text = self.inflow.flow().text
return self.__call__(text)
class LangchainChatLLM(ChatLLM): class LangchainChatLLM(ChatLLM):
@ -44,8 +54,10 @@ class LangchainChatLLM(ChatLLM):
def run_document(self, text: List[Message], **kwargs) -> LLMInterface: def run_document(self, text: List[Message], **kwargs) -> LLMInterface:
pred = self.agent.generate([text], **kwargs) # type: ignore pred = self.agent.generate([text], **kwargs) # type: ignore
all_text = [each.text for each in pred.generations[0]]
return LLMInterface( return LLMInterface(
text=[each.text for each in pred.generations[0]], text=all_text[0] if len(all_text) > 0 else "",
candidates=all_text,
completion_tokens=pred.llm_output["token_usage"]["completion_tokens"], completion_tokens=pred.llm_output["token_usage"]["completion_tokens"],
total_tokens=pred.llm_output["token_usage"]["total_tokens"], total_tokens=pred.llm_output["token_usage"]["total_tokens"],
prompt_tokens=pred.llm_output["token_usage"]["prompt_tokens"], prompt_tokens=pred.llm_output["token_usage"]["prompt_tokens"],

View File

@ -33,8 +33,10 @@ class LangchainLLM(LLM):
def run_raw(self, text: str) -> LLMInterface: def run_raw(self, text: str) -> LLMInterface:
pred = self.agent.generate([text]) pred = self.agent.generate([text])
all_text = [each.text for each in pred.generations[0]]
return LLMInterface( return LLMInterface(
text=[each.text for each in pred.generations[0]], text=all_text[0] if len(all_text) > 0 else "",
candidates=all_text,
completion_tokens=pred.llm_output["token_usage"]["completion_tokens"], completion_tokens=pred.llm_output["token_usage"]["completion_tokens"],
total_tokens=pred.llm_output["token_usage"]["total_tokens"], total_tokens=pred.llm_output["token_usage"]["total_tokens"],
prompt_tokens=pred.llm_output["token_usage"]["prompt_tokens"], prompt_tokens=pred.llm_output["token_usage"]["prompt_tokens"],

View File

@ -162,7 +162,7 @@ class ReactAgent(BaseAgent):
prompt = self._compose_prompt(instruction) prompt = self._compose_prompt(instruction)
logging.info(f"Prompt: {prompt}") logging.info(f"Prompt: {prompt}")
response = self.llm(prompt, stop=["Observation:"]) # type: ignore response = self.llm(prompt, stop=["Observation:"]) # type: ignore
response_text = response.text[0] response_text = response.text
logging.info(f"Response: {response_text}") logging.info(f"Response: {response_text}")
action_step = self._parse_output(response_text) action_step = self._parse_output(response_text)
if action_step is None: if action_step is None:

View File

@ -245,7 +245,7 @@ class RewooAgent(BaseAgent):
# Plan # Plan
planner_output = planner(instruction) planner_output = planner(instruction)
plannner_text_output = planner_output.text[0] plannner_text_output = planner_output.text
plan_to_es, plans = self._parse_plan_map(plannner_text_output) plan_to_es, plans = self._parse_plan_map(plannner_text_output)
planner_evidences, evidence_level = self._parse_planner_evidences( planner_evidences, evidence_level = self._parse_planner_evidences(
plannner_text_output plannner_text_output
@ -263,7 +263,7 @@ class RewooAgent(BaseAgent):
# Solve # Solve
solver_output = solver(instruction, worker_log) solver_output = solver(instruction, worker_log)
solver_output_text = solver_output.text[0] solver_output_text = solver_output.text
return AgentOutput( return AgentOutput(
output=solver_output_text, cost=total_cost, token_usage=total_token output=solver_output_text, cost=total_cost, token_usage=total_token

View File

@ -50,9 +50,9 @@ class RegexExtractor(BaseComponent):
if not output_map: if not output_map:
return text return text
return output_map.get(text, text) return str(output_map.get(text, text))
def run_raw(self, text: str) -> List[str]: def run_raw(self, text: str) -> List[Document]:
""" """
Runs the raw text through the static pattern and output mapping, returning a Runs the raw text through the static pattern and output mapping, returning a
list of strings. list of strings.
@ -66,9 +66,12 @@ class RegexExtractor(BaseComponent):
output = self.run_raw_static(self.pattern, text) output = self.run_raw_static(self.pattern, text)
output = [self.map_output(text, self.output_map) for text in output] output = [self.map_output(text, self.output_map) for text in output]
return output return [
Document(text=text, metadata={"origin": "RegexExtractor"})
for text in output
]
def run_batch_raw(self, text_batch: List[str]) -> List[List[str]]: def run_batch_raw(self, text_batch: List[str]) -> List[List[Document]]:
""" """
Runs a batch of raw text inputs through the `run_raw()` method and returns the Runs a batch of raw text inputs through the `run_raw()` method and returns the
output for each input. output for each input.
@ -95,13 +98,7 @@ class RegexExtractor(BaseComponent):
Returns: Returns:
List[Document]: A list of extracted documents. List[Document]: A list of extracted documents.
""" """
texts = self.run_raw(document.text) return self.run_raw(document.text)
output = [
Document(text=text, metadata={**document.metadata, "RegexExtractor": True})
for text in texts
]
return output
def run_batch_document( def run_batch_document(
self, document_batch: List[Document] self, document_batch: List[Document]

View File

@ -1,3 +1,4 @@
import warnings
from typing import Union from typing import Union
from kotaemon.base import BaseComponent from kotaemon.base import BaseComponent
@ -5,7 +6,7 @@ from kotaemon.documents.base import Document
from kotaemon.prompt.template import PromptTemplate from kotaemon.prompt.template import PromptTemplate
class BasePrompt(BaseComponent): class BasePromptComponent(BaseComponent):
""" """
Base class for prompt components. Base class for prompt components.
@ -15,6 +16,16 @@ class BasePrompt(BaseComponent):
given template. given template.
""" """
def __init__(self, template: Union[str, PromptTemplate], **kwargs):
super().__init__()
self.template = (
template
if isinstance(template, PromptTemplate)
else PromptTemplate(template)
)
self.__set(**kwargs)
def __check_redundant_kwargs(self, **kwargs): def __check_redundant_kwargs(self, **kwargs):
""" """
Check for redundant keyword arguments. Check for redundant keyword arguments.
@ -33,7 +44,9 @@ class BasePrompt(BaseComponent):
redundant_keys = provided_keys - expected_keys redundant_keys = provided_keys - expected_keys
if redundant_keys: if redundant_keys:
raise ValueError(f"\nKeys provided but not in template: {redundant_keys}") warnings.warn(
f"Keys provided but not in template: {redundant_keys}", UserWarning
)
def __check_unset_placeholders(self): def __check_unset_placeholders(self):
""" """
@ -111,27 +124,34 @@ class BasePrompt(BaseComponent):
Returns: Returns:
dict: A dictionary of keyword arguments. dict: A dictionary of keyword arguments.
""" """
def __prepare(key, value):
if isinstance(value, str):
return value
if isinstance(value, (int, Document)):
return str(value)
raise ValueError(
f"Unsupported type {type(value)} for template value of key {key}"
)
kwargs = {} kwargs = {}
for k in self.template.placeholders: for k in self.template.placeholders:
v = getattr(self, k) v = getattr(self, k)
if isinstance(v, (int, Document)): if isinstance(v, BaseComponent):
v = str(v) v = v()
elif isinstance(v, BaseComponent): if isinstance(v, list):
v = str(v()) v = str([__prepare(k, each) for each in v])
elif isinstance(v, (str, int, Document)):
v = __prepare(k, v)
else:
raise ValueError(
f"Unsupported type {type(v)} for template value of key {k}"
)
kwargs[k] = v kwargs[k] = v
return kwargs return kwargs
def __init__(self, template: Union[str, PromptTemplate], **kwargs):
super().__init__()
self.template = (
template
if isinstance(template, PromptTemplate)
else PromptTemplate(template)
)
self.__set(**kwargs)
def set(self, **kwargs): def set(self, **kwargs):
""" """
Similar to `__set` but for external use. Similar to `__set` but for external use.
@ -163,7 +183,8 @@ class BasePrompt(BaseComponent):
self.__check_unset_placeholders() self.__check_unset_placeholders()
prepared_kwargs = self.__prepare_value() prepared_kwargs = self.__prepare_value()
return self.template.populate(**prepared_kwargs) text = self.template.populate(**prepared_kwargs)
return Document(text=text, metadata={"origin": "PromptComponent"})
def run_raw(self, *args, **kwargs): def run_raw(self, *args, **kwargs):
pass pass
@ -182,3 +203,6 @@ class BasePrompt(BaseComponent):
def is_batch(self, *args, **kwargs): def is_batch(self, *args, **kwargs):
pass pass
def flow(self):
return self.__call__()

View File

@ -56,6 +56,8 @@ setuptools.setup(
"chromadb", "chromadb",
"wikipedia", "wikipedia",
"googlesearch-python", "googlesearch-python",
"python-dotenv",
"pytest-mock",
], ],
}, },
entry_points={"console_scripts": ["kh=kotaemon.cli:main"]}, entry_points={"console_scripts": ["kh=kotaemon.cli:main"]},

View File

@ -30,9 +30,11 @@ def test_is_batch(regex_extractor):
def test_run_raw(regex_extractor): def test_run_raw(regex_extractor):
output = regex_extractor("This is a test. 123") output = regex_extractor("This is a test. 123")
output = [each.text for each in output]
assert output == ["123"] assert output == ["123"]
def test_run_batch_raw(regex_extractor): def test_run_batch_raw(regex_extractor):
output = regex_extractor(["This is a test. 123", "456"]) output = regex_extractor(["This is a test. 123", "456"])
output = [[each.text for each in batch] for batch in output]
assert output == [["123"], ["456"]] assert output == [["123"], ["456"]]

View File

@ -2,7 +2,7 @@ import pytest
from kotaemon.documents.base import Document from kotaemon.documents.base import Document
from kotaemon.post_processing.extractor import RegexExtractor from kotaemon.post_processing.extractor import RegexExtractor
from kotaemon.prompt.base import BasePrompt from kotaemon.prompt.base import BasePromptComponent
from kotaemon.prompt.template import PromptTemplate from kotaemon.prompt.template import PromptTemplate
@ -14,7 +14,7 @@ def test_set_attributes():
) )
comp.set_run(kwargs={"text": "This is a test. 1 2 3"}, temp=True) comp.set_run(kwargs={"text": "This is a test. 1 2 3"}, temp=True)
prompt = BasePrompt(template=template, s="Alice", i=30, doc=doc, comp=comp) prompt = BasePromptComponent(template=template, s="Alice", i=30, doc=doc, comp=comp)
assert prompt.s == "Alice" assert prompt.s == "Alice"
assert prompt.i == 30 assert prompt.i == 30
assert prompt.doc == doc assert prompt.doc == doc
@ -23,23 +23,23 @@ def test_set_attributes():
def test_check_redundant_kwargs(): def test_check_redundant_kwargs():
template = PromptTemplate("Hello, {name}!") template = PromptTemplate("Hello, {name}!")
prompt = BasePrompt(template, name="Alice") prompt = BasePromptComponent(template, name="Alice")
with pytest.raises(ValueError): with pytest.warns(UserWarning, match="Keys provided but not in template: {'age'}"):
prompt._BasePrompt__check_redundant_kwargs(name="Alice", age=30) prompt._BasePromptComponent__check_redundant_kwargs(name="Alice", age=30)
def test_check_unset_placeholders(): def test_check_unset_placeholders():
template = PromptTemplate("Hello, {name}! I'm {age} years old.") template = PromptTemplate("Hello, {name}! I'm {age} years old.")
prompt = BasePrompt(template, name="Alice") prompt = BasePromptComponent(template, name="Alice")
with pytest.raises(ValueError): with pytest.raises(ValueError):
prompt._BasePrompt__check_unset_placeholders() prompt._BasePromptComponent__check_unset_placeholders()
def test_validate_value_type(): def test_validate_value_type():
template = PromptTemplate("Hello, {name}!") template = PromptTemplate("Hello, {name}!")
prompt = BasePrompt(template) prompt = BasePromptComponent(template)
with pytest.raises(ValueError): with pytest.raises(ValueError):
prompt._BasePrompt__validate_value_type(name={}) prompt._BasePromptComponent__validate_value_type(name={})
def test_run(): def test_run():
@ -50,18 +50,18 @@ def test_run():
) )
comp.set_run(kwargs={"text": "This is a test. 1 2 3"}, temp=True) comp.set_run(kwargs={"text": "This is a test. 1 2 3"}, temp=True)
prompt = BasePrompt(template=template, s="Alice", i=30, doc=doc, comp=comp) prompt = BasePromptComponent(template=template, s="Alice", i=30, doc=doc, comp=comp)
result = prompt() result = prompt()
assert ( assert (
result result.text
== "str = Alice, int = 30, doc = Helloo, Alice!, comp = ['One', 'Two', 'Three']" == "str = Alice, int = 30, doc = Helloo, Alice!, comp = ['One', 'Two', 'Three']"
) )
def test_set_method(): def test_set_method():
template = PromptTemplate("Hello, {name}!") template = PromptTemplate("Hello, {name}!")
prompt = BasePrompt(template) prompt = BasePromptComponent(template)
prompt.set(name="Alice") prompt.set(name="Alice")
assert prompt.name == "Alice" assert prompt.name == "Alice"