feat: Add chain-of-thought (#37)
* Add chain-of-thought * Use BasePromptComponent * Add terminate callback for the chain-of-thought
This commit is contained in:
parent
f80a4ea883
commit
6ab1854532
BIN
.env.secret
BIN
.env.secret
Binary file not shown.
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -454,6 +454,7 @@ logs/
|
||||||
.gitsecret/keys/random_seed
|
.gitsecret/keys/random_seed
|
||||||
!*.secret
|
!*.secret
|
||||||
.env
|
.env
|
||||||
|
.envrc
|
||||||
|
|
||||||
S.gpg-agent*
|
S.gpg-agent*
|
||||||
.vscode/settings.json
|
.vscode/settings.json
|
||||||
|
|
|
@ -1 +1 @@
|
||||||
.env:272c4eb7f422bebcc5d0f1da8bde47016b185ba8cb6ca06639bb2a3e88ea9bc5
|
.env:555d804179d7207ad6784a84afb88d2ec44f90ea3b7a061d0e38f9dd53fe7211
|
||||||
|
|
169
knowledgehub/pipelines/cot.py
Normal file
169
knowledgehub/pipelines/cot.py
Normal file
|
@ -0,0 +1,169 @@
|
||||||
|
from copy import deepcopy
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from theflow import Compose, Node, Param
|
||||||
|
|
||||||
|
from kotaemon.base import BaseComponent
|
||||||
|
from kotaemon.llms.chats.openai import AzureChatOpenAI
|
||||||
|
from kotaemon.prompt.base import BasePromptComponent
|
||||||
|
|
||||||
|
|
||||||
|
class Thought(BaseComponent):
|
||||||
|
"""A thought in the chain of thought
|
||||||
|
|
||||||
|
- Input: `**kwargs` pairs, where key is the placeholder in the prompt, and
|
||||||
|
value is the value.
|
||||||
|
- Output: an output dictionary
|
||||||
|
|
||||||
|
##### Usage:
|
||||||
|
|
||||||
|
Create and run a thought:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>> from kotaemon.pipelines.cot import Thought
|
||||||
|
>> thought = Thought(
|
||||||
|
prompt="How to {action} {object}?",
|
||||||
|
llm=AzureChatOpenAI(...),
|
||||||
|
post_process=lambda string: {"tutorial": string},
|
||||||
|
)
|
||||||
|
>> output = thought(action="install", object="python")
|
||||||
|
>> print(output)
|
||||||
|
{'tutorial': 'As an AI language model,...'}
|
||||||
|
```
|
||||||
|
|
||||||
|
Basically, when a thought is run, it will:
|
||||||
|
|
||||||
|
1. Populate the prompt template with the input `**kwargs`.
|
||||||
|
2. Run the LLM model with the populated prompt.
|
||||||
|
3. Post-process the LLM output with the post-processor.
|
||||||
|
|
||||||
|
This `Thought` allows chaining sequentially with the + operator. For example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>> llm = AzureChatOpenAI(...)
|
||||||
|
>> thought1 = Thought(
|
||||||
|
prompt="Word {word} in {language} is ",
|
||||||
|
llm=llm,
|
||||||
|
post_process=lambda string: {"translated": string},
|
||||||
|
)
|
||||||
|
>> thought2 = Thought(
|
||||||
|
prompt="Translate {translated} to Japanese",
|
||||||
|
llm=llm,
|
||||||
|
post_process=lambda string: {"output": string},
|
||||||
|
)
|
||||||
|
|
||||||
|
>> thought = thought1 + thought2
|
||||||
|
>> thought(word="hello", language="French")
|
||||||
|
{'word': 'hello',
|
||||||
|
'language': 'French',
|
||||||
|
'translated': '"Bonjour"',
|
||||||
|
'output': 'こんにちは (Konnichiwa)'}
|
||||||
|
```
|
||||||
|
|
||||||
|
Under the hood, when the `+` operator is used, a `ManualSequentialChainOfThought`
|
||||||
|
is created.
|
||||||
|
"""
|
||||||
|
|
||||||
|
prompt: Param[str] = Param(
|
||||||
|
help="The prompt template string. This prompt template has Python-like "
|
||||||
|
"variable placeholders, that then will be subsituted with real values when "
|
||||||
|
"this component is executed"
|
||||||
|
)
|
||||||
|
llm = Node(
|
||||||
|
default=AzureChatOpenAI, help="The LLM model to execute the input prompt"
|
||||||
|
)
|
||||||
|
post_process: Node[Compose] = Node(
|
||||||
|
help="The function post-processor that post-processes LLM output prediction ."
|
||||||
|
"It should take a string as input (this is the LLM output text) and return "
|
||||||
|
"a dictionary, where the key should"
|
||||||
|
)
|
||||||
|
|
||||||
|
@Node.decorate(depends_on="prompt")
|
||||||
|
def prompt_template(self):
|
||||||
|
return BasePromptComponent(self.prompt)
|
||||||
|
|
||||||
|
def run(self, **kwargs) -> dict:
|
||||||
|
"""Run the chain of thought"""
|
||||||
|
prompt = self.prompt_template(**kwargs).text
|
||||||
|
response = self.llm(prompt).text
|
||||||
|
return self.post_process(response)
|
||||||
|
|
||||||
|
def get_variables(self) -> List[str]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
def __add__(self, next_thought: "Thought") -> "ManualSequentialChainOfThought":
|
||||||
|
return ManualSequentialChainOfThought(
|
||||||
|
thoughts=[self, next_thought], llm=self.llm
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ManualSequentialChainOfThought(BaseComponent):
|
||||||
|
"""Perform sequential chain-of-thought with manual pre-defined prompts
|
||||||
|
|
||||||
|
This method supports variable number of steps. Each step corresponds to a
|
||||||
|
`kotaemon.pipelines.cot.Thought`. Please refer that section for
|
||||||
|
Thought's detail. This section is about chaining thought together.
|
||||||
|
|
||||||
|
##### Usage:
|
||||||
|
|
||||||
|
**Create and run a chain of thought without "+" operator:**
|
||||||
|
|
||||||
|
```python
|
||||||
|
>> from kotaemon.pipelines.cot import Thought, ManualSequentialChainOfThought
|
||||||
|
|
||||||
|
>> llm = AzureChatOpenAI(...)
|
||||||
|
>> thought1 = Thought(
|
||||||
|
prompt="Word {word} in {language} is ",
|
||||||
|
post_process=lambda string: {"translated": string},
|
||||||
|
)
|
||||||
|
>> thought2 = Thought(
|
||||||
|
prompt="Translate {translated} to Japanese",
|
||||||
|
post_process=lambda string: {"output": string},
|
||||||
|
)
|
||||||
|
>> thought = ManualSequentialChainOfThought(thoughts=[thought1, thought2], llm=llm)
|
||||||
|
>> thought(word="hello", language="French")
|
||||||
|
{'word': 'hello',
|
||||||
|
'language': 'French',
|
||||||
|
'translated': '"Bonjour"',
|
||||||
|
'output': 'こんにちは (Konnichiwa)'}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Create and run a chain of thought without "+" operator:** Please refer the
|
||||||
|
`kotaemon.pipelines.cot.Thought` section for examples.
|
||||||
|
|
||||||
|
This chain-of-thought optionally takes a termination check callback function.
|
||||||
|
This function will be called after each thought is executed. It takes in a
|
||||||
|
dictionary of all thought outputs so far, and it returns True or False. If
|
||||||
|
True, the chain-of-thought will terminate. If unset, the default callback always
|
||||||
|
returns False.
|
||||||
|
"""
|
||||||
|
|
||||||
|
thoughts: Param[List[Thought]] = Param(
|
||||||
|
default_callback=lambda *_: [], help="List of Thought"
|
||||||
|
)
|
||||||
|
llm: Param = Param(help="The LLM model to use (base of kotaemon.llms.LLM)")
|
||||||
|
terminate: Param = Param(
|
||||||
|
default=lambda _: False,
|
||||||
|
help="Callback on terminate condition. Default to always return False",
|
||||||
|
)
|
||||||
|
|
||||||
|
def run(self, **kwargs) -> dict:
|
||||||
|
"""Run the manual chain of thought"""
|
||||||
|
|
||||||
|
inputs = deepcopy(kwargs)
|
||||||
|
for idx, thought in enumerate(self.thoughts):
|
||||||
|
if self.llm:
|
||||||
|
thought.llm = self.llm
|
||||||
|
self._prepare_child(thought, f"thought{idx}")
|
||||||
|
|
||||||
|
output = thought(**inputs)
|
||||||
|
inputs.update(output)
|
||||||
|
if self.terminate(inputs):
|
||||||
|
break
|
||||||
|
|
||||||
|
return inputs
|
||||||
|
|
||||||
|
def __add__(self, next_thought: Thought) -> "ManualSequentialChainOfThought":
|
||||||
|
return ManualSequentialChainOfThought(
|
||||||
|
thoughts=self.thoughts + [next_thought], llm=self.llm
|
||||||
|
)
|
120
tests/test_cot.py
Normal file
120
tests/test_cot.py
Normal file
|
@ -0,0 +1,120 @@
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from kotaemon.llms.chats.openai import AzureChatOpenAI
|
||||||
|
from kotaemon.pipelines.cot import ManualSequentialChainOfThought, Thought
|
||||||
|
|
||||||
|
_openai_chat_completion_response = [
|
||||||
|
{
|
||||||
|
"id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"created": 1692338378,
|
||||||
|
"model": "gpt-35-turbo",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"finish_reason": "stop",
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": text,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19},
|
||||||
|
}
|
||||||
|
for text in ["Bonjour", "こんにちは (Konnichiwa)"]
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@patch(
|
||||||
|
"openai.api_resources.chat_completion.ChatCompletion.create",
|
||||||
|
side_effect=_openai_chat_completion_response,
|
||||||
|
)
|
||||||
|
def test_cot_plus_operator(openai_completion):
|
||||||
|
llm = AzureChatOpenAI(
|
||||||
|
openai_api_base="https://dummy.openai.azure.com/",
|
||||||
|
openai_api_key="dummy",
|
||||||
|
openai_api_version="2023-03-15-preview",
|
||||||
|
deployment_name="dummy-q2",
|
||||||
|
temperature=0,
|
||||||
|
)
|
||||||
|
thought1 = Thought(
|
||||||
|
prompt="Word {word} in {language} is ",
|
||||||
|
llm=llm,
|
||||||
|
post_process=lambda string: {"translated": string},
|
||||||
|
)
|
||||||
|
thought2 = Thought(
|
||||||
|
prompt="Translate {translated} to Japanese",
|
||||||
|
llm=llm,
|
||||||
|
post_process=lambda string: {"output": string},
|
||||||
|
)
|
||||||
|
thought = thought1 + thought2
|
||||||
|
output = thought(word="hello", language="French")
|
||||||
|
assert output == {
|
||||||
|
"word": "hello",
|
||||||
|
"language": "French",
|
||||||
|
"translated": "Bonjour",
|
||||||
|
"output": "こんにちは (Konnichiwa)",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@patch(
|
||||||
|
"openai.api_resources.chat_completion.ChatCompletion.create",
|
||||||
|
side_effect=_openai_chat_completion_response,
|
||||||
|
)
|
||||||
|
def test_cot_manual(openai_completion):
|
||||||
|
llm = AzureChatOpenAI(
|
||||||
|
openai_api_base="https://dummy.openai.azure.com/",
|
||||||
|
openai_api_key="dummy",
|
||||||
|
openai_api_version="2023-03-15-preview",
|
||||||
|
deployment_name="dummy-q2",
|
||||||
|
temperature=0,
|
||||||
|
)
|
||||||
|
thought1 = Thought(
|
||||||
|
prompt="Word {word} in {language} is ",
|
||||||
|
post_process=lambda string: {"translated": string},
|
||||||
|
)
|
||||||
|
thought2 = Thought(
|
||||||
|
prompt="Translate {translated} to Japanese",
|
||||||
|
post_process=lambda string: {"output": string},
|
||||||
|
)
|
||||||
|
thought = ManualSequentialChainOfThought(thoughts=[thought1, thought2], llm=llm)
|
||||||
|
output = thought(word="hello", language="French")
|
||||||
|
assert output == {
|
||||||
|
"word": "hello",
|
||||||
|
"language": "French",
|
||||||
|
"translated": "Bonjour",
|
||||||
|
"output": "こんにちは (Konnichiwa)",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@patch(
|
||||||
|
"openai.api_resources.chat_completion.ChatCompletion.create",
|
||||||
|
side_effect=_openai_chat_completion_response,
|
||||||
|
)
|
||||||
|
def test_cot_with_termination_callback(openai_completion):
|
||||||
|
llm = AzureChatOpenAI(
|
||||||
|
openai_api_base="https://dummy.openai.azure.com/",
|
||||||
|
openai_api_key="dummy",
|
||||||
|
openai_api_version="2023-03-15-preview",
|
||||||
|
deployment_name="dummy-q2",
|
||||||
|
temperature=0,
|
||||||
|
)
|
||||||
|
thought1 = Thought(
|
||||||
|
prompt="Word {word} in {language} is ",
|
||||||
|
post_process=lambda string: {"translated": string},
|
||||||
|
)
|
||||||
|
thought2 = Thought(
|
||||||
|
prompt="Translate {translated} to Japanese",
|
||||||
|
post_process=lambda string: {"output": string},
|
||||||
|
)
|
||||||
|
thought = ManualSequentialChainOfThought(
|
||||||
|
thoughts=[thought1, thought2],
|
||||||
|
llm=llm,
|
||||||
|
terminate=lambda d: True if d.get("translated", "") == "Bonjour" else False,
|
||||||
|
)
|
||||||
|
output = thought(word="hallo", language="French")
|
||||||
|
assert output == {
|
||||||
|
"word": "hallo",
|
||||||
|
"language": "French",
|
||||||
|
"translated": "Bonjour",
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user