Make ktem official (#134)

* Move kotaemon and ktem into same folder

* Update docs

* Update CI

* Resolve mypy, isorts

* Re-allow test pdf files
Duc Nguyen (john)
2024-01-23 10:54:18 +07:00
committed by GitHub
parent 9c5b707010
commit 2dd531114f
180 changed files with 4638 additions and 235 deletions

130
libs/kotaemon/README.md Normal file
View File

@@ -0,0 +1,130 @@
# kotaemon
Quick and easy AI components to build Kotaemon - applicable in client projects.
## Documentation
https://docs.promptui.dm.cinnamon.is
## Install
```shell
pip install kotaemon@git+ssh://git@github.com/Cinnamon/kotaemon.git
```
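After installing, you can do a quick import check. This is only an illustrative snippet; `Document.example()` is a small helper defined in `kotaemon.base.schema` within this PR.
```python
# Quick sanity check that the package is importable (illustrative only).
from kotaemon.base import Document

print(Document.example())  # prints a sample Document provided by the library
```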
## Contribute
### Setup
- Create a conda environment (Python 3.10 suggested)
```shell
conda create -n kotaemon python=3.10
conda activate kotaemon
```
- Clone the repo
```shell
git clone git@github.com:Cinnamon/kotaemon.git
cd kotaemon
```
- Install all
```shell
pip install -e ".[dev]"
```
- Pre-commit
```shell
pre-commit install
```
- Test
```shell
pytest tests
```
### Credential sharing
This repo uses [git-secret](https://sobolevn.me/git-secret/) to share credentials, which
internally uses `gpg` to encrypt and decrypt secret files.
This repo uses `python-dotenv` to manage credentials stored as environment variables.
Please note that the use of `python-dotenv` and credentials is for development
purposes only. Thus, they should not be used in the main source code (i.e. `kotaemon/` and `tests/`), but they can be used in `examples/`.
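For reference, a minimal sketch of how `python-dotenv` might be used inside `examples/` (the variable name `OPENAI_API_KEY` is purely illustrative, not something this repo mandates):
```python
# examples/ only: load the decrypted .env into the process environment.
import os

from dotenv import load_dotenv

load_dotenv()  # reads key=value pairs from a local .env file
api_key = os.environ.get("OPENAI_API_KEY")  # illustrative variable name
```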
#### Install git-secret
Please follow the [official guide](https://sobolevn.me/git-secret/installation) to install git-secret.
For Windows users, see [For Windows users](#for-windows-users).
For users who don't have sudo privileges to install packages, follow the `Manual Installation` section in the [official guide](https://sobolevn.me/git-secret/installation) and set `PREFIX` to a path that you have write access to. Don't forget to add `PREFIX` to your `PATH`.
#### Gaining access
To gain access to the secret files, provide your gpg public key file to someone who already has access and ask them to add your key to the keyring. For a quick tutorial on generating your gpg key pair, refer to the `Using gpg` section of the [git-secret main page](https://sobolevn.me/git-secret/).
#### Decrypt the secret file
The credentials are encrypted in the `.env.secret` file. To print the decrypted content to stdout, run
```shell
git-secret cat [filename]
```
Or to get the decrypted `.env` file, run
```shell
git-secret reveal [filename]
```
#### For Windows users
git-secret is currently not available for Windows, so the easiest way is to use it inside WSL (please use the latest version of WSL2). From there you have two options:
1. Using the gpg of WSL.
This is the most straightforward option since you would use WSL just like any other Unix environment. However, the downside is that you have to make WSL your main environment, which means WSL must have write permission on your repo. To achieve this, you must either:
- Clone and store your repo inside WSL's file system.
- Provide WSL with the necessary permissions on your Windows file system. This can be achieved by setting `automount` options for WSL. To do that, add the following content to `/etc/wsl.conf` and then restart your sub-system.
```shell
[automount]
options = "metadata,umask=022,fmask=011"
```
This enables all permissions for the owning user.
2. Using the gpg of Windows but with git-secret from WSL.
For those who use Windows as their main environment, switching back and forth between Windows and WSL is inconvenient. You can instead stay within your Windows environment and apply a few tricks to use `git-secret` from WSL.
- Install and setup `gpg` on Windows.
- Install `git-secret` on WSL. Now in Windows, you can invoke `git-secret` using `wsl git-secret`.
- Alternatively, you can set up aliases in CMD to shorten the syntax. Please refer to [this SO answer](https://stackoverflow.com/a/65823225) for instructions. Some recommended aliases are:
```bat
@echo off
:: Commands
DOSKEY ls=dir /B $*
DOSKEY ll=dir /a $*
DOSKEY git-secret=wsl git-secret $*
DOSKEY gs=wsl git-secret $*
```
Now you can invoke `git-secret` in CMD using `git-secret` or `gs`.
- For PowerShell users, similar behavior can be achieved using `Set-Alias` and `profile.ps1`. Please refer to this [SO thread](https://stackoverflow.com/questions/61081434/how-do-i-create-a-permanent-alias-file-in-powershell-core) for an example.
### Code base structure
- documents: define document
- loaders

View File

@@ -0,0 +1,23 @@
# Disable telemetry with monkey patching
import logging
logger = logging.getLogger(__name__)
try:
import posthog
def capture(*args, **kwargs):
logger.info("posthog.capture called with args: %s, kwargs: %s", args, kwargs)
posthog.capture = capture
except ImportError:
pass
try:
import os
os.environ["HAYSTACK_TELEMETRY_ENABLED"] = "False"
import haystack.telemetry
haystack.telemetry.telemetry = None
except ImportError:
pass

View File

@@ -0,0 +1,25 @@
from .base import BaseAgent
from .io import AgentFinish, AgentOutput, AgentType, BaseScratchPad
from .langchain_based import LangchainAgent
from .react.agent import ReactAgent
from .rewoo.agent import RewooAgent
from .tools import BaseTool, ComponentTool, GoogleSearchTool, LLMTool, WikipediaTool
__all__ = [
# agent
"BaseAgent",
"ReactAgent",
"RewooAgent",
"LangchainAgent",
# tool
"BaseTool",
"ComponentTool",
"GoogleSearchTool",
"WikipediaTool",
"LLMTool",
# io
"AgentType",
"AgentOutput",
"AgentFinish",
"BaseScratchPad",
]

View File

@@ -0,0 +1,57 @@
from typing import Optional, Union
from kotaemon.base import BaseComponent, Node, Param
from kotaemon.llms import BaseLLM, PromptTemplate
from .io import AgentOutput, AgentType
from .tools import BaseTool
class BaseAgent(BaseComponent):
"""Define base agent interface"""
name: str = Param(help="Name of the agent.")
agent_type: AgentType = Param(help="Agent type, must be one of AgentType")
description: str = Param(
help=(
"Description used to tell the model how/when/why to use the agent. You can"
" provide few-shot examples as a part of the description. This will be"
" input to the prompt of LLM."
)
)
llm: Optional[BaseLLM] = Node(
help=(
"LLM to be used for the agent (optional). LLM must implement BaseLLM"
" interface."
)
)
prompt_template: Optional[Union[PromptTemplate, dict[str, PromptTemplate]]] = Param(
help="A prompt template or a dict to supply different prompt to the agent"
)
plugins: list[BaseTool] = Param(
default_callback=lambda _: [],
help="List of plugins / tools to be used in the agent",
)
@staticmethod
def safeguard_run(run_func, *args, **kwargs):
def wrapper(self, *args, **kwargs):
try:
return run_func(self, *args, **kwargs)
except Exception as e:
return AgentOutput(
text="",
agent_type=self.agent_type,
status="failed",
error=str(e),
)
return wrapper
def add_tools(self, tools: list[BaseTool]) -> None:
"""Helper method to add tools and update agent state if needed"""
self.plugins.extend(tools)
def run(self, *args, **kwargs) -> AgentOutput | list[AgentOutput]:
"""Run the component."""
raise NotImplementedError()

View File

@@ -0,0 +1,3 @@
from .base import AgentAction, AgentFinish, AgentOutput, AgentType, BaseScratchPad
__all__ = ["AgentOutput", "AgentFinish", "BaseScratchPad", "AgentType", "AgentAction"]

View File

@@ -0,0 +1,254 @@
import json
import logging
import os
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, Literal, NamedTuple, Optional, Union
from kotaemon.base import LLMInterface
from pydantic import Extra
def check_log():
"""
Checks if logging has been enabled.
:return: True if logging has been enabled, False otherwise.
:rtype: bool
"""
return os.environ.get("LOG_PATH", None) is not None
class AgentType(Enum):
"""
Enumerated type for agent types.
"""
openai = "openai"
openai_multi = "openai_multi"
openai_tool = "openai_tool"
self_ask = "self_ask"
react = "react"
rewoo = "rewoo"
vanilla = "vanilla"
class BaseScratchPad:
"""
Base class for output handlers.
Attributes:
-----------
logger : logging.Logger
The logger object to log messages.
Methods:
--------
stop():
Stop the output.
update_status(output: str, **kwargs):
Update the status of the output.
thinking(name: str):
Log that a process is thinking.
done(_all=False):
Log that the process is done.
stream_print(item: str):
Not implemented.
json_print(item: Dict[str, Any]):
Log a JSON object.
panel_print(item: Any, title: str = "Output", stream: bool = False):
Log a panel output.
clear():
Not implemented.
print(content: str, **kwargs):
Log arbitrary content.
format_json(json_obj: str):
Format a JSON object.
debug(content: str, **kwargs):
Log a debug message.
info(content: str, **kwargs):
Log an informational message.
warning(content: str, **kwargs):
Log a warning message.
error(content: str, **kwargs):
Log an error message.
critical(content: str, **kwargs):
Log a critical message.
"""
def __init__(self):
"""
Initialize the BaseOutput object.
"""
self.logger = logging
self.log = []
def stop(self):
"""
Stop the output.
"""
def update_status(self, output: str, **kwargs):
"""
Update the status of the output.
"""
if check_log():
self.logger.info(output)
def thinking(self, name: str):
"""
Log that a process is thinking.
"""
if check_log():
self.logger.info(f"{name} is thinking...")
def done(self, _all=False):
"""
Log that the process is done.
"""
if check_log():
self.logger.info("Done")
def stream_print(self, item: str):
"""
Stream print.
"""
def json_print(self, item: Dict[str, Any]):
"""
Log a JSON object.
"""
if check_log():
self.logger.info(json.dumps(item, indent=2))
def panel_print(self, item: Any, title: str = "Output", stream: bool = False):
"""
Log a panel output.
Args:
item : Any
The item to log.
title : str, optional
The title of the panel, defaults to "Output".
stream : bool, optional
"""
if not stream:
self.log.append(item)
if check_log():
self.logger.info("-" * 20)
self.logger.info(item)
self.logger.info("-" * 20)
def clear(self):
"""
Not implemented.
"""
def print(self, content: str, **kwargs):
"""
Log arbitrary content.
"""
self.log.append(content)
if check_log():
self.logger.info(content)
def format_json(self, json_obj: str):
"""
Format a JSON object.
"""
formatted_json = json.dumps(json_obj, indent=2)
return formatted_json
def debug(self, content: str, **kwargs):
"""
Log a debug message.
"""
if check_log():
self.logger.debug(content, **kwargs)
def info(self, content: str, **kwargs):
"""
Log an informational message.
"""
if check_log():
self.logger.info(content, **kwargs)
def warning(self, content: str, **kwargs):
"""
Log a warning message.
"""
if check_log():
self.logger.warning(content, **kwargs)
def error(self, content: str, **kwargs):
"""
Log an error message.
"""
if check_log():
self.logger.error(content, **kwargs)
def critical(self, content: str, **kwargs):
"""
Log a critical message.
"""
if check_log():
self.logger.critical(content, **kwargs)
@dataclass
class AgentAction:
"""Agent's action to take.
Args:
tool: The tool to invoke.
tool_input: The input to the tool.
log: The log message.
"""
tool: str
tool_input: Union[str, dict]
log: str
class AgentFinish(NamedTuple):
"""Agent's return value when finishing execution.
Args:
return_values: The return values of the agent.
log: The log message.
"""
return_values: dict
log: str
class AgentOutput(LLMInterface, extra=Extra.allow): # type: ignore [call-arg]
"""Output from an agent.
Args:
text: The text output from the agent.
agent_type: The type of agent.
status: The status after executing the agent.
error: The error message if any.
"""
text: str
type: str = "agent"
agent_type: AgentType
status: Literal["finished", "stopped", "failed"]
error: Optional[str] = None
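As an editorial aside (not part of the diff), a minimal sketch of how these I/O containers can be constructed, using only the fields declared above:
```python
# Illustrative construction of the agent I/O containers defined above.
from kotaemon.agents.io import AgentAction, AgentFinish, AgentOutput, AgentType

action = AgentAction(tool="search_doc", tool_input="what is ReWOO?", log="Thought: ...")
finish = AgentFinish(return_values={"output": "ReWOO is ..."}, log="Final Answer: ...")
output = AgentOutput(text="ReWOO is ...", agent_type=AgentType.rewoo, status="finished")
```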

View File

@@ -0,0 +1,78 @@
from typing import List, Optional
from kotaemon.llms import LLM, ChatLLM
from langchain.agents import AgentType as LCAgentType
from langchain.agents import initialize_agent
from langchain.agents.agent import AgentExecutor as LCAgentExecutor
from .base import BaseAgent
from .io import AgentOutput, AgentType
from .tools import BaseTool
class LangchainAgent(BaseAgent):
"""Wrapper for Langchain Agent"""
name: str = "LangchainAgent"
agent_type: AgentType
description: str = "LangchainAgent for answering multi-step reasoning questions"
AGENT_TYPE_MAP = {
AgentType.openai: LCAgentType.OPENAI_FUNCTIONS,
AgentType.openai_multi: LCAgentType.OPENAI_MULTI_FUNCTIONS,
AgentType.react: LCAgentType.ZERO_SHOT_REACT_DESCRIPTION,
AgentType.self_ask: LCAgentType.SELF_ASK_WITH_SEARCH,
}
agent: Optional[LCAgentExecutor] = None
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if self.agent_type not in self.AGENT_TYPE_MAP:
raise NotImplementedError(
f"AgentType {self.agent_type } not supported by Langchain wrapper"
)
self.update_agent_tools()
def update_agent_tools(self):
assert isinstance(self.llm, (ChatLLM, LLM))
langchain_plugins = [tool.to_langchain_format() for tool in self.plugins]
# a fix for search_doc tool name:
# use "Intermediate Answer" for self-ask agent
found_search_tool = False
if self.agent_type == AgentType.self_ask:
for plugin in langchain_plugins:
if plugin.name == "search_doc":
plugin.name = "Intermediate Answer"
langchain_plugins = [plugin]
found_search_tool = True
break
if self.agent_type != AgentType.self_ask or found_search_tool:
# reinit Langchain AgentExecutor
self.agent = initialize_agent(
langchain_plugins,
self.llm.to_langchain_format(),
agent=self.AGENT_TYPE_MAP[self.agent_type],
handle_parsing_errors=True,
verbose=True,
)
def add_tools(self, tools: List[BaseTool]) -> None:
super().add_tools(tools)
self.update_agent_tools()
return
def run(self, instruction: str) -> AgentOutput:
assert (
self.agent is not None
), "Lanchain AgentExecutor is not correctly initialized"
# Langchain AgentExecutor call
output = self.agent(instruction)["output"]
return AgentOutput(
text=output,
agent_type=self.agent_type,
status="finished",
)

View File

@@ -0,0 +1,3 @@
from .agent import ReactAgent
__all__ = ["ReactAgent"]

View File

@@ -0,0 +1,204 @@
import logging
import re
from typing import Optional
from kotaemon.agents.base import BaseAgent, BaseLLM
from kotaemon.agents.io import AgentAction, AgentFinish, AgentOutput, AgentType
from kotaemon.agents.tools import BaseTool
from kotaemon.base import Param
from kotaemon.llms import PromptTemplate
FINAL_ANSWER_ACTION = "Final Answer:"
class ReactAgent(BaseAgent):
"""
Sequential ReactAgent class inherited from BaseAgent.
Implementing ReAct agent paradigm https://arxiv.org/pdf/2210.03629.pdf
"""
name: str = "ReactAgent"
agent_type: AgentType = AgentType.react
description: str = "ReactAgent for answering multi-step reasoning questions"
llm: BaseLLM
prompt_template: Optional[PromptTemplate] = None
plugins: list[BaseTool] = Param(
default_callback=lambda _: [], help="List of tools to be used in the agent. "
)
examples: dict[str, str | list[str]] = Param(
default_callback=lambda _: {}, help="Examples to be used in the agent. "
)
intermediate_steps: list[tuple[AgentAction | AgentFinish, str]] = Param(
default_callback=lambda _: [],
help="List of AgentAction and observation (tool) output",
)
max_iterations: int = 10
strict_decode: bool = False
def _compose_plugin_description(self) -> str:
"""
Compose the worker prompt from the workers.
Example:
toolname1[input]: tool1 description
toolname2[input]: tool2 description
"""
prompt = ""
try:
for plugin in self.plugins:
prompt += f"{plugin.name}[input]: {plugin.description}\n"
except Exception:
raise ValueError("Worker must have a name and description.")
return prompt
def _construct_scratchpad(
self, intermediate_steps: list[tuple[AgentAction | AgentFinish, str]] = []
) -> str:
"""Construct the scratchpad that lets the agent continue its thought process."""
thoughts = ""
for action, observation in intermediate_steps:
thoughts += action.log
thoughts += f"\nObservation: {observation}\nThought:"
return thoughts
def _parse_output(self, text: str) -> Optional[AgentAction | AgentFinish]:
"""
Parse text output from LLM for the next Action or Final Answer
Using Regex to parse "Action:\n Action Input:\n" for the next Action
Using FINAL_ANSWER_ACTION to parse Final Answer
Args:
text[str]: input text to parse
"""
includes_answer = FINAL_ANSWER_ACTION in text
regex = (
r"Action\s*\d*\s*:[\s]*(.*?)[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)"
)
action_match = re.search(regex, text, re.DOTALL)
action_output: Optional[AgentAction | AgentFinish] = None
if action_match:
if includes_answer:
raise Exception(
"Parsing LLM output produced both a final answer "
f"and a parse-able action: {text}"
)
action = action_match.group(1).strip()
action_input = action_match.group(2)
tool_input = action_input.strip(" ")
# ensure if its a well formed SQL query we don't remove any trailing " chars
if tool_input.startswith("SELECT ") is False:
tool_input = tool_input.strip('"')
action_output = AgentAction(action, tool_input, text)
elif includes_answer:
action_output = AgentFinish(
{"output": text.split(FINAL_ANSWER_ACTION)[-1].strip()}, text
)
else:
if self.strict_decode:
raise Exception(f"Could not parse LLM output: `{text}`")
else:
action_output = AgentFinish({"output": text}, text)
return action_output
def _compose_prompt(self, instruction) -> str:
"""
Compose the prompt from template, worker description, examples and instruction.
"""
agent_scratchpad = self._construct_scratchpad(self.intermediate_steps)
tool_description = self._compose_plugin_description()
tool_names = ", ".join([plugin.name for plugin in self.plugins])
if self.prompt_template is None:
from .prompt import zero_shot_react_prompt
self.prompt_template = zero_shot_react_prompt
return self.prompt_template.populate(
instruction=instruction,
agent_scratchpad=agent_scratchpad,
tool_description=tool_description,
tool_names=tool_names,
)
def _format_function_map(self) -> dict[str, BaseTool]:
"""Format the function map for the open AI function API.
Return:
Dict[str, Callable]: The function map.
"""
# Map the function name to the real function object.
function_map = {}
for plugin in self.plugins:
function_map[plugin.name] = plugin
return function_map
def clear(self):
"""
Clear and reset the agent.
"""
self.intermediate_steps = []
def run(self, instruction, max_iterations=None) -> AgentOutput:
"""
Run the agent with the given instruction.
Args:
instruction: Instruction to run the agent with.
max_iterations: Maximum number of iterations
of reasoning steps, defaults to 10.
Return:
AgentOutput object.
"""
if not max_iterations:
max_iterations = self.max_iterations
assert max_iterations > 0
self.clear()
logging.info(f"Running {self.name} with instruction: {instruction}")
total_cost = 0.0
total_token = 0
status = "failed"
response_text = None
for step_count in range(1, max_iterations + 1):
prompt = self._compose_prompt(instruction)
logging.info(f"Prompt: {prompt}")
response = self.llm(
prompt, stop=["Observation:"]
) # could cause bugs if llm doesn't have `stop` as a parameter
response_text = response.text
logging.info(f"Response: {response_text}")
action_step = self._parse_output(response_text)
if action_step is None:
raise ValueError("Invalid action")
is_finished_chain = isinstance(action_step, AgentFinish)
if is_finished_chain:
result = ""
else:
assert isinstance(action_step, AgentAction)
action_name = action_step.tool
tool_input = action_step.tool_input
logging.info(f"Action: {action_name}")
logging.info(f"Tool Input: {tool_input}")
result = self._format_function_map()[action_name](tool_input)
logging.info(f"Result: {result}")
self.intermediate_steps.append((action_step, result))
if is_finished_chain:
logging.info(f"Finished after {step_count} steps.")
status = "finished"
break
else:
status = "stopped"
return AgentOutput(
text=response_text,
agent_type=self.agent_type,
status=status,
total_tokens=total_token,
total_cost=total_cost,
intermediate_steps=self.intermediate_steps,
max_iterations=max_iterations,
)
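As a hedged sketch of wiring the ReactAgent above (not part of the commit); `my_llm` stands for any configured `BaseLLM` implementation and is not defined in this diff:
```python
# Hypothetical ReactAgent wiring; `my_llm` is assumed to be a BaseLLM instance.
from kotaemon.agents import LLMTool, ReactAgent, WikipediaTool

agent = ReactAgent(
    llm=my_llm,
    plugins=[WikipediaTool(), LLMTool(llm=my_llm)],
    max_iterations=5,
)
result = agent.run("Who introduced the ReAct prompting paradigm?")
print(result.status, result.text)
```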

View File

@@ -0,0 +1,28 @@
# flake8: noqa
from kotaemon.llms import PromptTemplate
zero_shot_react_prompt = PromptTemplate(
template="""Answer the following questions as best you can. You have access to the following tools:
{tool_description}
Use the following format:
Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can repeat N times)
#Thought: I now know the final answer
Final Answer: the final answer to the original input question
Begin! After each Action Input.
Question: {instruction}
Thought:{agent_scratchpad}
"""
)

View File

@@ -0,0 +1,3 @@
from .agent import RewooAgent
__all__ = ["RewooAgent"]

View File

@@ -0,0 +1,273 @@
import logging
import re
from concurrent.futures import ThreadPoolExecutor
from typing import Any
from kotaemon.agents.base import BaseAgent
from kotaemon.agents.io import AgentOutput, AgentType, BaseScratchPad
from kotaemon.agents.tools import BaseTool
from kotaemon.agents.utils import get_plugin_response_content
from kotaemon.base import Node, Param
from kotaemon.indices.qa import CitationPipeline
from kotaemon.llms import BaseLLM, PromptTemplate
from .planner import Planner
from .solver import Solver
class RewooAgent(BaseAgent):
"""Distributive RewooAgent class inherited from BaseAgent.
Implementing ReWOO paradigm https://arxiv.org/pdf/2305.18323.pdf"""
name: str = "RewooAgent"
agent_type: AgentType = AgentType.rewoo
description: str = "RewooAgent for answering multi-step reasoning questions"
planner_llm: BaseLLM
solver_llm: BaseLLM
prompt_template: dict[str, PromptTemplate] = Param(
default_callback=lambda _: {},
help="A dict to supply different prompt to the agent.",
)
plugins: list[BaseTool] = Param(
default_callback=lambda _: [], help="A list of plugins to be used in the model."
)
examples: dict[str, str | list[str]] = Param(
default_callback=lambda _: {}, help="Examples to be used in the agent."
)
@Node.auto(depends_on=["planner_llm", "plugins", "prompt_template", "examples"])
def planner(self):
return Planner(
model=self.planner_llm,
plugins=self.plugins,
prompt_template=self.prompt_template.get("Planner", None),
examples=self.examples.get("Planner", None),
)
@Node.auto(depends_on=["solver_llm", "prompt_template", "examples"])
def solver(self):
return Solver(
model=self.solver_llm,
prompt_template=self.prompt_template.get("Solver", None),
examples=self.examples.get("Solver", None),
)
def _parse_plan_map(
self, planner_response: str
) -> tuple[dict[str, list[str]], dict[str, str]]:
"""
Parse planner output. It should be an n-to-n mapping from Plans to #Es.
This is because sometimes LLM cannot follow the strict output format.
Example:
#Plan1
#E1
#E2
should result in: {"#Plan1": ["#E1", "#E2"]}
Or:
#Plan1
#Plan2
#E1
should result in: {"#Plan1": [], "#Plan2": ["#E1"]}
This function also returns the plan descriptions.
Returns:
tuple[dict[str, list[str]], dict[str, str]]: the plan-to-#E mapping and the plan descriptions.
"""
valid_chunk = [
line
for line in planner_response.splitlines()
if line.startswith("#Plan") or line.startswith("#E")
]
plan_to_es: dict[str, list[str]] = dict()
plans: dict[str, str] = dict()
prev_key = ""
for line in valid_chunk:
key, description = line.split(":", 1)
key = key.strip()
if key.startswith("#Plan"):
plans[key] = description.strip()
plan_to_es[key] = []
prev_key = key
elif key.startswith("#E"):
plan_to_es[prev_key].append(key)
return plan_to_es, plans
def _parse_planner_evidences(
self, planner_response: str
) -> tuple[dict[str, str], list[list[str]]]:
"""
Parse planner output. This should return a mapping from #E to tool call.
It should also identify the level of each #E in dependency map.
Example:
{
"#E1": "Tool1", "#E2": "Tool2",
"#E3": "Tool3", "#E4": "Tool4"
}, [[#E1, #E2], [#E3, #E4]]
Returns:
tuple[dict[str, str], List[List[str]]]:
A mapping from #E to tool call and a list of levels.
"""
evidences: dict[str, str] = dict()
dependence: dict[str, list[str]] = dict()
for line in planner_response.splitlines():
if line.startswith("#E") and line[2].isdigit():
e, tool_call = line.split(":", 1)
e, tool_call = e.strip(), tool_call.strip()
if len(e) == 3:
dependence[e] = []
evidences[e] = tool_call
for var in re.findall(r"#E\d+", tool_call):
if var in evidences:
dependence[e].append(var)
else:
evidences[e] = "No evidence found"
level = []
while dependence:
select = [i for i in dependence if not dependence[i]]
if len(select) == 0:
raise ValueError("Circular dependency detected.")
level.append(select)
for item in select:
dependence.pop(item)
for item in dependence:
for i in select:
if i in dependence[item]:
dependence[item].remove(i)
return evidences, level
def _run_plugin(
self,
e: str,
planner_evidences: dict[str, str],
worker_evidences: dict[str, str],
output=BaseScratchPad(),
):
"""
Run a plugin for a given evidence.
This function should also accumulate the cost and tokens.
"""
result = dict(e=e, plugin_cost=0, plugin_token=0, evidence="")
tool_call = planner_evidences[e]
if "[" not in tool_call:
result["evidence"] = tool_call
else:
tool, tool_input = tool_call.split("[", 1)
tool_input = tool_input[:-1]
# find variables in input and replace with previous evidences
for var in re.findall(r"#E\d+", tool_input):
if var in worker_evidences:
tool_input = tool_input.replace(var, worker_evidences.get(var, ""))
try:
selected_plugin = self._find_plugin(tool)
if selected_plugin is None:
raise ValueError("Invalid plugin detected")
tool_response = selected_plugin(tool_input)
result["evidence"] = get_plugin_response_content(tool_response)
except ValueError:
result["evidence"] = "No evidence found."
finally:
output.panel_print(
result["evidence"], f"[green] Function Response of [blue]{tool}: "
)
return result
def _get_worker_evidence(
self,
planner_evidences: dict[str, str],
evidences_level: list[list[str]],
output=BaseScratchPad(),
) -> Any:
"""
Parallel execution of plugins in DAG for speedup.
This is one of the core benefits of ReWOO agents.
Args:
planner_evidences: A mapping from #E to tool call.
evidences_level: A list of levels of evidences.
Calculated from DAG of plugin calls.
output: Output object, defaults to BaseOutput().
Returns:
A mapping from #E to evidence, plus the accumulated plugin cost and token count.
"""
worker_evidences: dict[str, str] = dict()
plugin_cost, plugin_token = 0.0, 0.0
with ThreadPoolExecutor() as pool:
for level in evidences_level:
results = []
for e in level:
results.append(
pool.submit(
self._run_plugin,
e,
planner_evidences,
worker_evidences,
output,
)
)
if len(results) > 1:
output.update_status(f"Running tasks {level} in parallel.")
else:
output.update_status(f"Running task {level[0]}.")
for r in results:
resp = r.result()
plugin_cost += resp["plugin_cost"]
plugin_token += resp["plugin_token"]
worker_evidences[resp["e"]] = resp["evidence"]
output.done()
return worker_evidences, plugin_cost, plugin_token
def _find_plugin(self, name: str):
for p in self.plugins:
if p.name == name:
return p
@BaseAgent.safeguard_run
def run(self, instruction: str, use_citation: bool = False) -> AgentOutput:
"""
Run the agent with a given instruction.
"""
logging.info(f"Running {self.name} with instruction: {instruction}")
total_cost = 0.0
total_token = 0
# Plan
planner_output = self.planner(instruction)
planner_text_output = planner_output.text
plan_to_es, plans = self._parse_plan_map(planner_text_output)
planner_evidences, evidence_level = self._parse_planner_evidences(
planner_text_output
)
# Work
worker_evidences, plugin_cost, plugin_token = self._get_worker_evidence(
planner_evidences, evidence_level
)
worker_log = ""
for plan in plan_to_es:
worker_log += f"{plan}: {plans[plan]}\n"
for e in plan_to_es[plan]:
worker_log += f"{e}: {worker_evidences[e]}\n"
# Solve
solver_output = self.solver(instruction, worker_log)
solver_output_text = solver_output.text
if use_citation:
citation_pipeline = CitationPipeline(llm=self.solver_llm)
citation = citation_pipeline(context=worker_log, question=instruction)
else:
citation = None
return AgentOutput(
text=solver_output_text,
agent_type=self.agent_type,
status="finished",
total_tokens=total_token,
total_cost=total_cost,
citation=citation,
metadata={"citation": citation},
)
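Similarly, a hedged sketch for the RewooAgent above; `planner_llm` and `solver_llm` stand for any `BaseLLM` instances you have configured and are not defined in this diff:
```python
# Hypothetical RewooAgent wiring; both LLM objects are assumed, not defined here.
from kotaemon.agents import GoogleSearchTool, LLMTool, RewooAgent

agent = RewooAgent(
    planner_llm=planner_llm,
    solver_llm=solver_llm,
    plugins=[GoogleSearchTool(), LLMTool(llm=solver_llm)],
)
answer = agent.run("Compare ReAct and ReWOO in one paragraph.")
print(answer.text)
```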

View File

@@ -0,0 +1,83 @@
from typing import Any, List, Optional, Union
from kotaemon.agents.base import BaseLLM, BaseTool
from kotaemon.agents.io import BaseScratchPad
from kotaemon.base import BaseComponent
from kotaemon.llms import PromptTemplate
from .prompt import few_shot_planner_prompt, zero_shot_planner_prompt
class Planner(BaseComponent):
model: BaseLLM
prompt_template: Optional[PromptTemplate] = None
examples: Optional[Union[str, List[str]]] = None
plugins: List[BaseTool]
def _compose_worker_description(self) -> str:
"""
Compose the worker prompt from the workers.
Example:
toolname1[input]: tool1 description
toolname2[input]: tool2 description
"""
prompt = ""
try:
for worker in self.plugins:
prompt += f"{worker.name}[input]: {worker.description}\n"
except Exception:
raise ValueError("Worker must have a name and description.")
return prompt
def _compose_fewshot_prompt(self) -> str:
if self.examples is None:
return ""
if isinstance(self.examples, str):
return self.examples
else:
return "\n\n".join([e.strip("\n") for e in self.examples])
def _compose_prompt(self, instruction) -> str:
"""
Compose the prompt from template, worker description, examples and instruction.
"""
worker_description = self._compose_worker_description()
fewshot = self._compose_fewshot_prompt()
if self.prompt_template is not None:
if "fewshot" in self.prompt_template.placeholders:
return self.prompt_template.populate(
tool_description=worker_description,
fewshot=fewshot,
task=instruction,
)
else:
return self.prompt_template.populate(
tool_description=worker_description, task=instruction
)
else:
if self.examples is not None:
return few_shot_planner_prompt.populate(
tool_description=worker_description,
fewshot=fewshot,
task=instruction,
)
else:
return zero_shot_planner_prompt.populate(
tool_description=worker_description, task=instruction
)
def run(self, instruction: str, output: BaseScratchPad = BaseScratchPad()) -> Any:
response = None
output.info("Running Planner")
prompt = self._compose_prompt(instruction)
output.debug(f"Prompt: {prompt}")
try:
response = self.model(prompt)
self.log_progress(".planner", response=response)
output.info("Planner run successful.")
except ValueError as e:
output.error("Planner failed to retrieve response from LLM")
raise ValueError("Planner failed to retrieve response from LLM") from e
return response

View File

@@ -0,0 +1,119 @@
# flake8: noqa
from kotaemon.llms import PromptTemplate
zero_shot_planner_prompt = PromptTemplate(
template="""You are an AI agent who makes step-by-step plans to solve a problem under the help of external tools.
For each step, make one plan followed by one tool-call, which will be executed later to retrieve evidence for that step.
You should store each evidence into a distinct variable #E1, #E2, #E3 ... that can be referred to in later tool-call inputs.
##Available Tools##
{tool_description}
##Output Format (Replace '<...>')##
#Plan1: <describe your plan here>
#E1: <toolname>[<input here>] (eg. Search[What is Python])
#Plan2: <describe next plan>
#E2: <toolname>[<input here, you can use #E1 to represent its expected output>]
And so on...
##Your Task##
{task}
##Now Begin##
"""
)
one_shot_planner_prompt = PromptTemplate(
template="""You are an AI agent who makes step-by-step plans to solve a problem under the help of external tools.
For each step, make one plan followed by one tool-call, which will be executed later to retrieve evidence for that step.
You should store each evidence into a distinct variable #E1, #E2, #E3 ... that can be referred to in later tool-call inputs.
##Available Tools##
{tool_description}
##Output Format##
#Plan1: <describe your plan here>
#E1: <toolname>[<input here>]
#Plan2: <describe next plan>
#E2: <toolname>[<input here, you can use #E1 to represent its expected output>]
And so on...
##Example##
Task: What is the 4th root of 64 to the power of 3?
#Plan1: Find the 4th root of 64
#E1: Calculator[64^(1/4)]
#Plan2: Raise the result from #Plan1 to the power of 3
#E2: Calculator[#E1^3]
##Your Task##
{task}
##Now Begin##
"""
)
few_shot_planner_prompt = PromptTemplate(
template="""You are an AI agent who makes step-by-step plans to solve a problem under the help of external tools.
For each step, make one plan followed by one tool-call, which will be executed later to retrieve evidence for that step.
You should store each evidence into a distinct variable #E1, #E2, #E3 ... that can be referred to in later tool-call inputs.
##Available Tools##
{tool_description}
##Output Format (Replace '<...>')##
#Plan1: <describe your plan here>
#E1: <toolname>[<input>]
#Plan2: <describe next plan>
#E2: <toolname>[<input, you can use #E1 to represent its expected output>]
And so on...
##Examples##
{fewshot}
##Your Task##
{task}
##Now Begin##
"""
)
zero_shot_solver_prompt = PromptTemplate(
template="""You are an AI agent who solves a problem with my assistance. I will provide step-by-step plans(#Plan) and evidences(#E) that could be helpful.
Your task is to briefly summarize each step, then make a short final conclusion for your task.
##My Plans and Evidences##
{plan_evidence}
##Example Output##
First, I <did something> , and I think <...>; Second, I <...>, and I think <...>; ....
So, <your conclusion>.
##Your Task##
{task}
##Now Begin##
"""
)
few_shot_solver_prompt = PromptTemplate(
template="""You are an AI agent who solves a problem with my assistance. I will provide step-by-step plans and evidences that could be helpful.
Your task is to briefly summarize each step, then make a short final conclusion for your task.
##My Plans and Evidences##
{plan_evidence}
##Example Output##
First, I <did something> , and I think <...>; Second, I <...>, and I think <...>; ....
So, <your conclusion>.
##Example##
{fewshot}
##Your Task##
{task}
##Now Begin##
"""
)

View File

@@ -0,0 +1,65 @@
from typing import Any, List, Optional, Union
from kotaemon.agents.io import BaseScratchPad
from kotaemon.base import BaseComponent
from kotaemon.llms import BaseLLM, PromptTemplate
from .prompt import few_shot_solver_prompt, zero_shot_solver_prompt
class Solver(BaseComponent):
model: BaseLLM
prompt_template: Optional[PromptTemplate] = None
examples: Optional[Union[str, List[str]]] = None
def _compose_fewshot_prompt(self) -> str:
if self.examples is None:
return ""
if isinstance(self.examples, str):
return self.examples
else:
return "\n\n".join([e.strip("\n") for e in self.examples])
def _compose_prompt(self, instruction, plan_evidence) -> str:
"""
Compose the prompt from template, plan&evidence, examples and instruction.
"""
fewshot = self._compose_fewshot_prompt()
if self.prompt_template is not None:
if "fewshot" in self.prompt_template.placeholders:
return self.prompt_template.populate(
plan_evidence=plan_evidence, fewshot=fewshot, task=instruction
)
else:
return self.prompt_template.populate(
plan_evidence=plan_evidence, task=instruction
)
else:
if self.examples is not None:
return few_shot_solver_prompt.populate(
plan_evidence=plan_evidence, fewshot=fewshot, task=instruction
)
else:
return zero_shot_solver_prompt.populate(
plan_evidence=plan_evidence, task=instruction
)
def run(
self,
instruction: str,
plan_evidence: str,
output: BaseScratchPad = BaseScratchPad(),
) -> Any:
response = None
output.info("Running Solver")
output.debug(f"Instruction: {instruction}")
output.debug(f"Plan Evidence: {plan_evidence}")
prompt = self._compose_prompt(instruction, plan_evidence)
output.debug(f"Prompt: {prompt}")
try:
response = self.model(prompt)
output.info("Solver run successful.")
except ValueError:
output.error("Solver failed to retrieve response from LLM")
return response

View File

@@ -0,0 +1,6 @@
from .base import BaseTool, ComponentTool
from .google import GoogleSearchTool
from .llm import LLMTool
from .wikipedia import WikipediaTool
__all__ = ["BaseTool", "ComponentTool", "GoogleSearchTool", "WikipediaTool", "LLMTool"]

View File

@@ -0,0 +1,138 @@
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
from kotaemon.base import BaseComponent
from langchain.agents import Tool as LCTool
from pydantic import BaseModel
class ToolException(Exception):
"""An optional exception that tool throws when execution error occurs.
When this exception is thrown, the agent will not stop working,
but will handle the exception according to the handle_tool_error
variable of the tool, and the processing result will be returned
to the agent as observation, and printed in red on the console.
"""
class BaseTool(BaseComponent):
name: str
"""The unique name of the tool that clearly communicates its purpose."""
description: str
"""Description used to tell the model how/when/why to use the tool.
You can provide few-shot examples as a part of the description. This will be
input to the prompt of LLM.
"""
args_schema: Optional[Type[BaseModel]] = None
"""Pydantic model class to validate and parse the tool's input arguments."""
verbose: bool = False
"""Whether to log the tool's progress."""
handle_tool_error: Optional[
Union[bool, str, Callable[[ToolException], str]]
] = False
"""Handle the content of the ToolException thrown."""
def _parse_input(
self,
tool_input: Union[str, Dict],
) -> Union[str, Dict[str, Any]]:
"""Convert tool input to pydantic model."""
args_schema = self.args_schema
if isinstance(tool_input, str):
if args_schema is not None:
key_ = next(iter(args_schema.model_fields.keys()))
args_schema.validate({key_: tool_input})
return tool_input
else:
if args_schema is not None:
result = args_schema.parse_obj(tool_input)
return {k: v for k, v in result.dict().items() if k in tool_input}
return tool_input
def _run_tool(
self,
*args: Any,
**kwargs: Any,
) -> Any:
"""Call tool."""
raise NotImplementedError(f"_run_tool is not implemented for {self.name}")
def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]:
# For backwards compatibility, if run_input is a string,
# pass as a positional argument.
if isinstance(tool_input, str):
return (tool_input,), {}
else:
return (), tool_input
def _handle_tool_error(self, e: ToolException) -> Any:
"""Handle the content of the ToolException thrown."""
observation = None
if not self.handle_tool_error:
raise e
elif isinstance(self.handle_tool_error, bool):
if e.args:
observation = e.args[0]
else:
observation = "Tool execution error"
elif isinstance(self.handle_tool_error, str):
observation = self.handle_tool_error
elif callable(self.handle_tool_error):
observation = self.handle_tool_error(e)
else:
raise ValueError(
f"Got unexpected type of `handle_tool_error`. Expected bool, str "
f"or callable. Received: {self.handle_tool_error}"
)
return observation
def to_langchain_format(self) -> LCTool:
"""Convert this tool to Langchain format to use with its agent"""
return LCTool(name=self.name, description=self.description, func=self.run)
def run(
self,
tool_input: Union[str, Dict],
verbose: Optional[bool] = None,
**kwargs: Any,
) -> Any:
"""Run the tool."""
parsed_input = self._parse_input(tool_input)
# TODO (verbose_): Add logging
try:
tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input)
call_kwargs = {**kwargs, **tool_kwargs}
observation = self._run_tool(*tool_args, **call_kwargs)
except ToolException as e:
observation = self._handle_tool_error(e)
return observation
else:
return observation
@classmethod
def from_langchain_format(cls, langchain_tool: LCTool) -> "BaseTool":
"""Wrapper for Langchain Tool"""
new_tool = BaseTool(
name=langchain_tool.name, description=langchain_tool.description
)
new_tool._run_tool = langchain_tool._run # type: ignore
return new_tool
class ComponentTool(BaseTool):
"""Wrapper around other BaseComponent to use it as a tool
Args:
component: BaseComponent-based component to wrap
postprocessor: Optional postprocessor for the component output
"""
component: BaseComponent
postprocessor: Optional[Callable] = None
def _run_tool(self, *args: Any, **kwargs: Any) -> Any:
output = self.component(*args, **kwargs)
if self.postprocessor:
output = self.postprocessor(output)
return output
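A brief illustrative sketch (not part of the commit) of defining a custom tool on top of `BaseTool`: subclass it, give it a `name` and `description`, and implement `_run_tool`. `EchoTool` is a made-up example name:
```python
# Minimal custom tool built on BaseTool; EchoTool is an illustrative example.
from kotaemon.agents.tools import BaseTool


class EchoTool(BaseTool):
    name: str = "echo"
    description: str = "Repeats the input text back. Input should be any string."

    def _run_tool(self, query: str) -> str:
        return query


print(EchoTool().run("hello"))  # BaseTool.run parses the input, then calls _run_tool
```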

View File

@@ -0,0 +1,51 @@
from typing import AnyStr, Optional, Type
from langchain.utilities import SerpAPIWrapper
from pydantic import BaseModel, Field
from .base import BaseTool
class GoogleSearchArgs(BaseModel):
query: str = Field(..., description="a search query")
class GoogleSearchTool(BaseTool):
name: str = "google_search"
description: str = (
"A search engine retrieving top search results as snippets from Google. "
"Input should be a search query."
)
args_schema: Optional[Type[BaseModel]] = GoogleSearchArgs
def _run_tool(self, query: AnyStr) -> str:
try:
from googlesearch import search
except ImportError:
raise ImportError(
"install googlesearch using `pip3 install googlesearch-python` to "
"use this tool"
)
output = ""
search_results = search(query, advanced=True)
if search_results:
output = "\n".join(
"{} {}".format(item.title, item.description) for item in search_results
)
return output
class SerpTool(BaseTool):
name = "google_search"
description = (
"Worker that searches results from Google. Useful when you need to find short "
"and succinct answers about a specific topic. Input should be a search query."
)
args_schema: Optional[Type[BaseModel]] = GoogleSearchArgs
def _run_tool(self, query: AnyStr) -> str:
tool = SerpAPIWrapper()
evidence = tool.run(query)
return evidence

View File

@@ -0,0 +1,31 @@
from typing import AnyStr, Optional, Type
from kotaemon.llms import BaseLLM
from pydantic import BaseModel, Field
from .base import BaseTool, ToolException
class LLMArgs(BaseModel):
query: str = Field(..., description="a search question or prompt")
class LLMTool(BaseTool):
name: str = "llm"
description: str = (
"A pretrained LLM like yourself. Useful when you need to act with "
"general world knowledge and common sense. Prioritize it when you "
"are confident in solving the problem "
"yourself. Input can be any instruction."
)
llm: BaseLLM
args_schema: Optional[Type[BaseModel]] = LLMArgs
def _run_tool(self, query: AnyStr) -> str:
output = None
try:
response = self.llm(query)
except ValueError:
raise ToolException("LLM Tool call failed")
output = response.text
return output

View File

@@ -0,0 +1,65 @@
from typing import Any, AnyStr, Optional, Type, Union
from kotaemon.base import Document
from pydantic import BaseModel, Field
from .base import BaseTool
class Wiki:
"""Wrapper around wikipedia API."""
def __init__(self) -> None:
"""Check that wikipedia package is installed."""
try:
import wikipedia # noqa: F401
except ImportError:
raise ValueError(
"Could not import wikipedia python package. "
"Please install it with `pip install wikipedia`."
)
def search(self, search: str) -> Union[str, Document]:
"""Try to search for wiki page.
If the page exists, return its content wrapped in a Document (with the page URL in metadata).
If the page does not exist, return a string listing similar entries.
"""
import wikipedia
try:
page_content = wikipedia.page(search).content
url = wikipedia.page(search).url
result: Union[str, Document] = Document(
text=page_content, metadata={"page": url}
)
except wikipedia.PageError:
result = f"Could not find [{search}]. Similar: {wikipedia.search(search)}"
except wikipedia.DisambiguationError:
result = f"Could not find [{search}]. Similar: {wikipedia.search(search)}"
return result
class WikipediaArgs(BaseModel):
query: str = Field(..., description="a search query as input to wkipedia")
class WikipediaTool(BaseTool):
"""Tool that adds the capability to query the Wikipedia API."""
name: str = "wikipedia"
description: str = (
"Search engine from Wikipedia, retrieving relevant wiki page. "
"Useful when you need to get holistic knowledge about people, "
"places, companies, historical events, or other subjects. "
"Input should be a search query."
)
args_schema: Optional[Type[BaseModel]] = WikipediaArgs
doc_store: Any = None
def _run_tool(self, query: AnyStr) -> AnyStr:
if not self.doc_store:
self.doc_store = Wiki()
tool = self.doc_store
evidence = tool.search(query)
return evidence

View File

@@ -0,0 +1,22 @@
from kotaemon.base import Document
def get_plugin_response_content(output) -> str:
"""
Wrapper for AgentOutput content return
"""
if isinstance(output, Document):
return output.text
else:
return str(output)
def calculate_cost(model_name: str, prompt_token: int, completion_token: int) -> float:
"""
Calculate the cost of a prompt and completion.
Returns:
float: Cost of the provided model name with provided token information
"""
# TODO: to be implemented
return 0.0

View File

@@ -0,0 +1,28 @@
from .component import BaseComponent, Node, Param, lazy
from .schema import (
AIMessage,
BaseMessage,
Document,
DocumentWithEmbedding,
ExtractorOutput,
HumanMessage,
LLMInterface,
RetrievedDocument,
SystemMessage,
)
__all__ = [
"BaseComponent",
"Document",
"DocumentWithEmbedding",
"BaseMessage",
"SystemMessage",
"AIMessage",
"HumanMessage",
"RetrievedDocument",
"LLMInterface",
"ExtractorOutput",
"Param",
"Node",
"lazy",
]

View File

@@ -0,0 +1,42 @@
from abc import abstractmethod
from typing import Iterator
from kotaemon.base.schema import Document
from theflow import Function, Node, Param, lazy
class BaseComponent(Function):
"""A component is a class that can be used to compose a pipeline.
!!! tip "Benefits of component"
- Auto caching, logging
- Allow deployment
!!! tip "For each component, the spirit is"
- Tolerate multiple input types, e.g. str, Document, List[str], List[Document]
- Enforce single output type. Hence, the output type of a component should be
as generic as possible.
"""
inflow = None
def flow(self):
if self.inflow is None:
raise ValueError("No inflow provided.")
if not isinstance(self.inflow, BaseComponent):
raise ValueError(
f"inflow must be a BaseComponent, found {type(self.inflow)}"
)
return self.__call__(self.inflow.flow())
@abstractmethod
def run(
self, *args, **kwargs
) -> Document | list[Document] | Iterator[Document] | None:
"""Run the component."""
...
__all__ = ["BaseComponent", "Param", "Node", "lazy"]
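As an illustrative aside, a minimal `BaseComponent` subclass; it relies on components being callable (as the `flow()` method above assumes) so that calling an instance dispatches to `run`. `UpperCase` is an invented example name:
```python
# Minimal BaseComponent subclass; UpperCase is an illustrative name.
from kotaemon.base import BaseComponent, Document


class UpperCase(BaseComponent):
    def run(self, text: str) -> Document:
        return Document(text=text.upper())


print(UpperCase()("hello"))  # components are callable; the call runs `run`
```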

View File

@@ -0,0 +1,132 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Optional, TypeVar
from langchain.schema.messages import AIMessage as LCAIMessage
from langchain.schema.messages import HumanMessage as LCHumanMessage
from langchain.schema.messages import SystemMessage as LCSystemMessage
from llama_index.bridge.pydantic import Field
from llama_index.schema import Document as BaseDocument
if TYPE_CHECKING:
from haystack.schema import Document as HaystackDocument
IO_Type = TypeVar("IO_Type", "Document", str)
SAMPLE_TEXT = "A sample Document from kotaemon"
class Document(BaseDocument):
"""
Base document class, mostly inherited from Document class from llama-index.
This class accepts one positional argument `content` of an arbitrary type, which will
store the raw content of the document. If specified, the class will use
`content` to initialize the base llama_index class.
Attributes:
content: raw content of the document, can be anything
source: id of the source of the Document. Optional.
"""
content: Any
source: Optional[str] = None
def __init__(self, content: Optional[Any] = None, *args, **kwargs):
if content is None:
if kwargs.get("text", None) is not None:
kwargs["content"] = kwargs["text"]
elif kwargs.get("embedding", None) is not None:
kwargs["content"] = kwargs["embedding"]
# default text indicating this document only contains embedding
kwargs["text"] = "<EMBEDDING>"
elif isinstance(content, Document):
kwargs = content.dict()
else:
kwargs["content"] = content
if content:
kwargs["text"] = str(content)
else:
kwargs["text"] = ""
super().__init__(*args, **kwargs)
def __bool__(self):
return bool(self.content)
@classmethod
def example(cls) -> "Document":
document = Document(
text=SAMPLE_TEXT,
metadata={"filename": "README.md", "category": "codebase"},
)
return document
def to_haystack_format(self) -> "HaystackDocument":
"""Convert struct to Haystack document format."""
from haystack.schema import Document as HaystackDocument
metadata = self.metadata or {}
text = self.text
return HaystackDocument(content=text, meta=metadata)
def __str__(self):
return str(self.content)
class DocumentWithEmbedding(Document):
"""Subclass of Document which must contains embedding
Use this if you want to enforce component's IOs to must contain embedding.
"""
def __init__(self, embedding: list[float], *args, **kwargs):
kwargs["embedding"] = embedding
super().__init__(*args, **kwargs)
class BaseMessage(Document):
def __add__(self, other: Any):
raise NotImplementedError
class SystemMessage(BaseMessage, LCSystemMessage):
pass
class AIMessage(BaseMessage, LCAIMessage):
pass
class HumanMessage(BaseMessage, LCHumanMessage):
pass
class RetrievedDocument(Document):
"""Subclass of Document with retrieval-related information
Attributes:
score (float): score of the document (from 0.0 to 1.0)
retrieval_metadata (dict): metadata from the retrieval process, can be used
by different components in a retrieved pipeline to communicate with each
other
"""
score: float = Field(default=0.0)
retrieval_metadata: dict = Field(default={})
class LLMInterface(AIMessage):
candidates: list[str] = Field(default_factory=list)
completion_tokens: int = -1
total_tokens: int = -1
prompt_tokens: int = -1
total_cost: float = 0
logits: list[list[float]] = Field(default_factory=list)
messages: list[AIMessage] = Field(default_factory=list)
class ExtractorOutput(Document):
"""
Represents the output of an extractor.
"""
matches: list[str]
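An illustrative construction of the schema objects above, mirroring `Document.example()` (values are made up):
```python
# Sketch of the containers defined above; values are illustrative.
from kotaemon.base import Document, RetrievedDocument

doc = Document(text="A sample Document from kotaemon", metadata={"filename": "README.md"})
hit = RetrievedDocument(text=doc.text, score=0.87)
print(bool(doc), str(doc))  # truthiness and str() are driven by `content`
```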

View File

@@ -0,0 +1,4 @@
from .base import BaseChatBot, ChatConversation
from .simple_respondent import SimpleRespondentChatbot
__all__ = ["BaseChatBot", "SimpleRespondentChatbot", "ChatConversation"]

View File

@@ -0,0 +1,114 @@
from abc import abstractmethod
from typing import List, Optional
from kotaemon.base import BaseComponent, LLMInterface
from kotaemon.base.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage
from theflow import SessionFunction
class BaseChatBot(BaseComponent):
@abstractmethod
def run(self, messages: List[BaseMessage]) -> LLMInterface:
...
def session_chat_storage(obj):
"""Store using the bot location rather than the session location"""
return obj._store_result
class ChatConversation(SessionFunction):
"""Base implementation of a chat bot component
A chatbot component should:
- handle internal state, including history messages
- return output for a given input
"""
class Config:
store_result = session_chat_storage
system_message: str = ""
bot: BaseChatBot
def __init__(self, *args, **kwargs):
self._history: List[BaseMessage] = []
self._store_result = (
f"{self.__module__}.{self.__class__.__name__},uninitiated_bot"
)
super().__init__(*args, **kwargs)
def run(self, message: HumanMessage) -> Optional[BaseMessage]:
"""Chat, given a message, return a response
Args:
message: The message to respond to
Returns:
The response to the message. If None, no response is sent.
"""
user_message = (
HumanMessage(content=message) if isinstance(message, str) else message
)
self.history.append(user_message)
output = self.bot(self.history).text
output_message = None
if output is not None:
output_message = AIMessage(content=output)
self.history.append(output_message)
return output_message
def start_session(self):
self._store_result = self.bot.config.store_result
super().start_session()
if not self.history and self.system_message:
system_message = SystemMessage(content=self.system_message)
self.history.append(system_message)
def end_session(self):
super().end_session()
self._history = []
def check_end(
self,
history: Optional[List[BaseMessage]] = None,
user_message: Optional[HumanMessage] = None,
bot_message: Optional[AIMessage] = None,
) -> bool:
"""Check if a conversation should end"""
if user_message is not None and user_message.content == "":
return True
return False
def terminal_session(self):
"""Create a terminal session"""
self.start_session()
print(">> Start chat:")
while True:
human = HumanMessage(content=input("Human: "))
if self.check_end(history=self.history, user_message=human):
break
output = self(human)
if output is None:
print("AI: <No response>")
else:
print("AI:", output.content)
if self.check_end(history=self.history, bot_message=output):
break
self.end_session()
@property
def history(self):
return self._history
@history.setter
def history(self, value):
self._history = value
self._variablex()

View File

@@ -0,0 +1,11 @@
from ..llms import ChatLLM
from .base import BaseChatBot
class SimpleRespondentChatbot(BaseChatBot):
"""Simple text respondent chatbot that essentially wraps around a chat LLM"""
llm: ChatLLM
def _get_message(self) -> str:
return self.llm(self.history).text

View File

@@ -0,0 +1,189 @@
import os
import click
import yaml
from trogon import tui
# check if the output is not a .yml file -> raise error
def check_config_format(config):
if os.path.exists(config):
if isinstance(config, str):
with open(config) as f:
yaml.safe_load(f)
else:
raise ValueError("config must be yaml format.")
@tui(command="ui", help="Open the terminal UI") # generate the terminal UI
@click.group()
def main():
pass
@click.group()
def promptui():
pass
main.add_command(promptui)
@promptui.command()
@click.argument("export_path", nargs=1)
@click.option("--output", default="promptui.yml", show_default=True, required=False)
def export(export_path, output):
"""Export a pipeline to a config file"""
import sys
from kotaemon.contribs.promptui.config import export_pipeline_to_config
from theflow.utils.modules import import_dotted_string
sys.path.append(os.getcwd())
cls = import_dotted_string(export_path, safe=False)
export_pipeline_to_config(cls, output)
check_config_format(output)
@promptui.command()
@click.argument("run_path", required=False, default="promptui.yml")
@click.option(
"--share",
is_flag=True,
show_default=True,
default=False,
help="Share the app through Gradio. Requires --username to enable authentication.",
)
@click.option(
"--username",
required=False,
help=(
"Username for the user. If not provided, the promptui will not have "
"authentication."
),
)
@click.option(
"--password",
required=False,
help="Password for the user. If not provided, will be prompted.",
)
@click.option(
"--appname",
required=False,
help="The share app subdomain. Requires --share and --username",
)
@click.option(
"--port",
required=False,
help="Port to run the app. If not provided, will $GRADIO_SERVER_PORT (7860)",
)
def run(run_path, share, username, password, appname, port):
"""Run the UI from a config file
Examples:
\b
# Run with default config file
$ kh promptui run
\b
# Run with username and password supplied
$ kh promptui run --username admin --password password
\b
# Run with username and prompted password
$ kh promptui run --username admin
# Run and share to promptui
$ kh promptui run --username admin --password password --share --appname hey \
--port 7861
"""
import sys
from kotaemon.contribs.promptui.ui import build_from_dict
sys.path.append(os.getcwd())
check_config_format(run_path)
demo = build_from_dict(run_path)
params: dict = {}
if username is not None:
if password is not None:
auth = (username, password)
else:
auth = (username, click.prompt("Password", hide_input=True))
params["auth"] = auth
port = int(port) if port else int(os.getenv("GRADIO_SERVER_PORT", "7860"))
params["server_port"] = port
if share:
if username is None:
raise ValueError(
"Username must be provided to enable authentication for sharing"
)
if appname:
from kotaemon.contribs.promptui.tunnel import Tunnel
tunnel = Tunnel(
appname=str(appname), username=str(username), local_port=port
)
url = tunnel.run()
print(f"App is shared at {url}")
else:
params["share"] = True
print("App is shared at Gradio")
demo.launch(**params)
@main.command()
@click.argument("module", required=True)
@click.option(
"--output", default="docs.md", required=False, help="The output markdown file"
)
@click.option(
"--separation-level", required=False, default=1, help="Organize markdown layout"
)
def makedoc(module, output, separation_level):
"""Make documentation for module `module`
Example:
\b
# Make component documentation for kotaemon library
$ kh makedoc kotaemon
"""
from kotaemon.contribs.docs import make_doc
make_doc(module, output, separation_level)
print(f"Documentation exported to {output}")
@main.command()
@click.option(
"--template",
default="project-default",
required=False,
help="Template name",
show_default=True,
)
def start_project(template):
"""Start a project from a template.
Important: the value for --template corresponds to the name of the template folder,
which is located at https://github.com/Cinnamon/kotaemon/tree/main/templates
The default value is "project-default", which should work when you are starting a
client project.
"""
print("Retrieving template...")
os.system(
"cookiecutter git@github.com:Cinnamon/kotaemon.git "
f"--directory='templates/{template}'"
)
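# Illustrative usage sketch (hedged): assuming the package exposes the `kh`
# entry point seen above and Click derives the command name `start-project`
# from the function name, a client project could be scaffolded with:
#   $ kh start-project --template project-default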
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,66 @@
import inspect
from collections import defaultdict
from theflow.utils.documentation import get_function_documentation_from_module
def from_definition_to_markdown(definition: dict) -> str:
"""From definition to markdown"""
# Handle params
params = " N/A\n"
if definition["params"]:
params = "\n| Name | Description | Type | Default |\n"
params += "| --- | --- | --- | --- |\n"
for name, p in definition["params"].items():
type_ = p["type"].__name__ if inspect.isclass(p["type"]) else p["type"]
params += f"| {name} | {p['desc']} | {type_} | {p['default']} |\n"
# Handle nodes
nodes = " N/A\n"
if definition["nodes"]:
nodes = "\n| Name | Description | Type | Input | Output |\n"
nodes += "| --- | --- | --- | --- | --- |\n"
for name, n in definition["nodes"].items():
type_ = n["type"].__name__ if inspect.isclass(n["type"]) else str(n["type"])
input_ = (
n["input"].__name__ if inspect.isclass(n["input"]) else str(n["input"])
)
output_ = (
n["output"].__name__
if inspect.isclass(n["output"])
else str(n["output"])
)
nodes += f"|{name}|{n['desc']}|{type_}|{input_}|{output_}|\n"
description = inspect.cleandoc(definition["desc"])
return f"{description}\n\n_**Params:**_{params}\n_**Nodes:**_{nodes}"
def make_doc(module: str, output: str, separation_level: int):
"""Run exporting components to markdown
Args:
module (str): module name
output_path (str): output path to save
separation_level (int): level of separation
"""
documentation = sorted(
get_function_documentation_from_module(module).items(), key=lambda x: x[0]
)
entries = defaultdict(list)
for name, definition in documentation:
section = name.split(".")[separation_level].capitalize()
cls_name = name.split(".")[-1]
markdown = from_definition_to_markdown(definition)
entries[section].append(f"### {cls_name}\n{markdown}")
final = "\n".join(
[f"## {section}\n" + "\n".join(entries[section]) for section in entries]
)
with open(output, "w") as f:
f.write(final)

View File

@@ -0,0 +1 @@
/frpc_*

View File

@@ -0,0 +1,43 @@
import gradio as gr
COMPONENTS_CLASS = {
"text": gr.components.Textbox,
"checkbox": gr.components.CheckboxGroup,
"dropdown": gr.components.Dropdown,
"file": gr.components.File,
"image": gr.components.Image,
"number": gr.components.Number,
"radio": gr.components.Radio,
"slider": gr.components.Slider,
}
SUPPORTED_COMPONENTS = set(COMPONENTS_CLASS.keys())
DEFAULT_COMPONENT_BY_TYPES = {
"str": "text",
"bool": "checkbox",
"int": "number",
"float": "number",
"list": "dropdown",
}
def get_component(component_def: dict) -> gr.components.Component:
"""Get the component based on component definition"""
component_cls = None
if "component" in component_def:
component = component_def["component"]
if component not in SUPPORTED_COMPONENTS:
raise ValueError(
f"Unsupported UI component: {component}. "
f"Must be one of {SUPPORTED_COMPONENTS}"
)
component_cls = COMPONENTS_CLASS[component]
else:
raise ValueError(
f"Cannot decide the component from {component_def}. "
"Please specify `component` with 1 of the following "
f"values: {SUPPORTED_COMPONENTS}"
)
return component_cls(**component_def.get("params", {}))
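# Illustrative sketch (hypothetical values): a definition such as
#   {"component": "slider", "params": {"minimum": 0, "maximum": 1, "value": 0.5}}
# resolves to gr.components.Slider and is instantiated as
#   gr.components.Slider(minimum=0, maximum=1, value=0.5)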

View File

@@ -0,0 +1 @@
"""CLI commands that can be imported by the kotaemon.cli module"""

View File

@@ -0,0 +1,182 @@
"""Get config from Pipeline"""
import inspect
from pathlib import Path
from typing import Any, Dict, Optional, Type, Union
import yaml
from kotaemon.base import BaseComponent
from kotaemon.chatbot import BaseChatBot
from .base import DEFAULT_COMPONENT_BY_TYPES
def config_from_value(value: Any) -> dict:
"""Get the config from default value
Args:
value (Any): default value
Returns:
dict: config
"""
component = DEFAULT_COMPONENT_BY_TYPES.get(type(value).__name__, "text")
return {
"component": component,
"params": {
"value": value,
},
}
def handle_param(param: dict) -> dict:
"""Convert param definition into promptui-compliant config
Supported Gradio UI components (https://www.gradio.app/docs/components) are:
- CheckBoxGroup: list (multi select)
- DropDown: list (single select)
- File
- Image
- Number: int / float
- Radio: list (single select)
- Slider: int / float
- TextBox: str
"""
params = {}
default = param.get("default", None)
if isinstance(default, str) and default.startswith("{{") and default.endswith("}}"):
default = None
if default is not None:
params["value"] = default
ui_component = param.get("component_ui", "")
if not ui_component:
type_: str = type(default).__name__ if default is not None else ""
ui_component = DEFAULT_COMPONENT_BY_TYPES.get(type_, "text")
return {
"component": ui_component,
"params": params,
}
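# Illustrative sketch (hypothetical input): a param definition with a float
# default and no explicit `component_ui`, e.g. {"default": 0.7}, falls back to
# the type-based mapping above and would be converted to
#   {"component": "number", "params": {"value": 0.7}}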
def handle_node(node: dict) -> dict:
"""Convert node definition into promptui-compliant config"""
config = {}
for name, param_def in node.get("params", {}).items():
if isinstance(param_def["auto_callback"], str):
continue
if param_def.get("ignore_ui", False):
continue
config[name] = handle_param(param_def)
for name, node_def in node.get("nodes", {}).items():
if isinstance(node_def["auto_callback"], str):
continue
if node_def.get("ignore_ui", False):
continue
for key, value in handle_node(node_def["default"]).items():
config[f"{name}.{key}"] = value
for key, value in node_def.get("default_kwargs", {}).items():
config[f"{name}.{key}"] = config_from_value(value)
return config
def handle_input(pipeline: Union[BaseComponent, Type[BaseComponent]]) -> dict:
"""Get the input from the pipeline"""
signature = inspect.signature(pipeline.run)
inputs: Dict[str, Dict] = {}
for name, param in signature.parameters.items():
if name in ["self", "args", "kwargs"]:
continue
input_def: Dict[str, Optional[Any]] = {"component": "text"}
default = param.default
if default is param.empty:
inputs[name] = input_def
continue
params = {}
params["value"] = default
type_ = type(default).__name__ if default is not None else None
ui_component = None
if type_ is not None:
ui_component = "text"
input_def["component"] = ui_component
input_def["params"] = params
inputs[name] = input_def
return inputs
def export_pipeline_to_config(
pipeline: Union[BaseComponent, Type[BaseComponent]],
path: Optional[str] = None,
) -> dict:
"""Export a pipeline to a promptui-compliant config dict"""
if inspect.isclass(pipeline):
pipeline = pipeline()
pipeline_def = pipeline.describe()
ui_type = "chat" if isinstance(pipeline, BaseChatBot) else "simple"
if ui_type == "chat":
params = {f".bot.{k}": v for k, v in handle_node(pipeline_def).items()}
params["system_message"] = {"component": "text", "params": {"value": ""}}
outputs = []
if hasattr(pipeline, "_promptui_outputs"):
outputs = pipeline._promptui_outputs
config_obj: dict = {
"ui-type": ui_type,
"params": params,
"inputs": {},
"outputs": outputs,
"logs": {
"full_pipeline": {
"input": {
"step": ".",
"getter": "_get_input",
},
"output": {
"step": ".",
"getter": "_get_output",
},
"preference": {
"step": "preference",
},
}
},
}
else:
outputs = [{"step": ".", "getter": "_get_output", "component": "text"}]
if hasattr(pipeline, "_promptui_outputs"):
outputs = pipeline._promptui_outputs
config_obj = {
"ui-type": ui_type,
"params": handle_node(pipeline_def),
"inputs": handle_input(pipeline),
"outputs": outputs,
"logs": {
"full_pipeline": {
"input": {
"step": ".",
"getter": "_get_input",
},
"output": {
"step": ".",
"getter": "_get_output",
},
},
},
}
config = {f"{pipeline.__module__}.{pipeline.__class__.__name__}": config_obj}
if path is not None:
old_config = config
if Path(path).is_file():
with open(path) as f:
old_config = yaml.safe_load(f)
old_config.update(config)
with open(path, "w") as f:
yaml.safe_dump(old_config, f, sort_keys=False)
return config
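# Hedged usage sketch (the dotted pipeline path is hypothetical): exporting a
# pipeline to a promptui config file programmatically instead of via the CLI.
#
#   from theflow.utils.modules import import_dotted_string
#   cls = import_dotted_string("my_project.pipelines.QAPipeline", safe=False)
#   export_pipeline_to_config(cls, path="promptui.yml")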

View File

@@ -0,0 +1,140 @@
"""Export logs into Excel file"""
import os
import pickle
from pathlib import Path
from typing import Any, Dict, List, Type, Union
import pandas as pd
import yaml
from kotaemon.base import BaseComponent
from theflow.storage import storage
from theflow.utils.modules import import_dotted_string
from .logs import ResultLog
def from_log_to_dict(pipeline_cls: Type[BaseComponent], log_config: dict) -> dict:
"""Export the log to panda dataframes
Args:
pipeline_cls (Type[BaseComponent]): Pipeline class
log_config (dict): Log config
Returns:
dict: mapping from column name to the list of logged values
"""
# get the directory
pipeline_log_path = storage.url(pipeline_cls().config.store_result)
dirs = list(sorted([f.path for f in os.scandir(pipeline_log_path) if f.is_dir()]))
# get resultlog callback
resultlog = getattr(pipeline_cls, "_promptui_resultlog", ResultLog)
allowed_resultlog_callbacks = {i for i in dir(resultlog) if not i.startswith("__")}
ids = []
params: Dict[str, List[Any]] = {}
logged_infos: Dict[str, List[Any]] = {}
for idx, each_dir in enumerate(dirs):
ids.append(str(Path(each_dir).name))
# get the params
params_file = os.path.join(each_dir, "params.pkl")
if os.path.exists(params_file):
with open(params_file, "rb") as f:
each_params = pickle.load(f)
for key, value in each_params.items():
if key not in params:
params[key] = [None] * len(dirs)
params[key][idx] = value
# get the progress
progress_file = os.path.join(each_dir, "progress.pkl")
if os.path.exists(progress_file):
with open(progress_file, "rb") as f:
progress = pickle.load(f)
for name, col_info in log_config.items():
step = col_info["step"]
getter = col_info.get("getter", None)
if name not in logged_infos:
logged_infos[name] = [None] * len(dirs)
if step not in progress:
continue
info = progress[step]
if getter:
if getter in allowed_resultlog_callbacks:
info = getattr(resultlog, getter)(info)
else:
implicit_name = f"get_{name}"
if implicit_name in allowed_resultlog_callbacks:
info = getattr(resultlog, implicit_name)(info)
logged_infos[name][idx] = info
return {"ids": ids, **params, **logged_infos}
def export(config: dict, pipeline_def, output_path):
"""Export from config to Excel file"""
pipeline_name = f"{pipeline_def.__module__}.{pipeline_def.__name__}"
# export to Excel
if not config.get("logs", {}):
raise ValueError(f"Pipeline {pipeline_name} has no logs to export")
pds: Dict[str, pd.DataFrame] = {}
for log_name, log_def in config["logs"].items():
pds[log_name] = pd.DataFrame(from_log_to_dict(pipeline_def, log_def))
# from the list of pds, export to Excel to output_path
with pd.ExcelWriter(output_path, engine="openpyxl") as writer: # type: ignore
for log_name, df in pds.items():
df.to_excel(writer, sheet_name=log_name)
def export_from_dict(
config: Union[str, dict],
pipeline: Union[str, Type[BaseComponent]],
output_path: str,
):
"""CLI to export the logs of a pipeline into Excel file
Args:
config_path (str): Path to the config file
pipeline_name (str): Name of the pipeline
output_path (str): Path to the output Excel file
"""
# get the pipeline class and the relevant config dict
config_dict: dict
if isinstance(config, str):
with open(config) as f:
config_dict = yaml.safe_load(f)
elif isinstance(config, dict):
config_dict = config
else:
raise TypeError(f"`config` must be str or dict, not {type(config)}")
pipeline_name: str
pipeline_cls: Type[BaseComponent]
pipeline_config: dict
if isinstance(pipeline, str):
if pipeline not in config_dict:
raise ValueError(f"Pipeline {pipeline} not found in config file")
pipeline_name = pipeline
pipeline_cls = import_dotted_string(pipeline, safe=False)
pipeline_config = config_dict[pipeline]
elif isinstance(pipeline, type) and issubclass(pipeline, BaseComponent):
pipeline_name = f"{pipeline.__module__}.{pipeline.__name__}"
if pipeline_name not in config_dict:
raise ValueError(f"Pipeline {pipeline_name} not found in config file")
pipeline_cls = pipeline
pipeline_config = config_dict[pipeline_name]
else:
raise TypeError(
f"`pipeline` must be str or subclass of BaseComponent, not {type(pipeline)}"
)
export(pipeline_config, pipeline_cls, output_path)

View File

@@ -0,0 +1,16 @@
class ResultLog:
"""Callback getter to get the desired log result
The callback resolution will be as follows:
1. Explicit string name
2. Implicitly by: `get_<name>`
3. Pass through
"""
@staticmethod
def _get_input(obj):
return obj["input"]
@staticmethod
def _get_output(obj):
return obj["output"]

View File

@@ -0,0 +1,95 @@
from __future__ import annotations
from typing import Iterable
from gradio.themes.base import Base
from gradio.themes.utils import colors, fonts, sizes
class John(Base):
def __init__(
self,
*,
primary_hue: colors.Color | str = colors.neutral,
secondary_hue: colors.Color | str = colors.neutral,
neutral_hue: colors.Color | str = colors.neutral,
spacing_size: sizes.Size | str = sizes.spacing_sm,
radius_size: sizes.Size | str = sizes.radius_none,
text_size: sizes.Size | str = sizes.text_sm,
font: fonts.Font
| str
| Iterable[fonts.Font | str] = (
fonts.GoogleFont("Quicksand"),
"ui-sans-serif",
"system-ui",
"sans-serif",
),
font_mono: fonts.Font
| str
| Iterable[fonts.Font | str] = (
fonts.GoogleFont("IBM Plex Mono"),
"ui-monospace",
"Consolas",
"monospace",
),
):
super().__init__(
primary_hue=primary_hue,
secondary_hue=secondary_hue,
neutral_hue=neutral_hue,
spacing_size=spacing_size,
radius_size=radius_size,
text_size=text_size,
font=font,
font_mono=font_mono,
)
self.name = "monochrome"
super().set(
# Colors
slider_color="*neutral_900",
slider_color_dark="*neutral_500",
body_text_color="*neutral_900",
block_label_text_color="*body_text_color",
block_title_text_color="*body_text_color",
body_text_color_subdued="*neutral_700",
background_fill_primary_dark="*neutral_900",
background_fill_secondary_dark="*neutral_800",
block_background_fill_dark="*neutral_800",
input_background_fill_dark="*neutral_700",
# Button Colors
button_primary_background_fill="*neutral_900",
button_primary_background_fill_hover="*neutral_700",
button_primary_text_color="white",
button_primary_background_fill_dark="*neutral_600",
button_primary_background_fill_hover_dark="*neutral_600",
button_primary_text_color_dark="white",
button_secondary_background_fill=(
"linear-gradient(to bottom right, *neutral_100, *neutral_200)"
),
button_secondary_background_fill_hover=(
"linear-gradient(to bottom right, *neutral_100, *neutral_100)"
),
button_secondary_background_fill_dark=(
"linear-gradient(to bottom right, *neutral_600, *neutral_700)"
),
button_secondary_background_fill_hover_dark=(
"linear-gradient(to bottom right, *neutral_600, *neutral_600)"
),
button_cancel_background_fill="*button_primary_background_fill",
button_cancel_background_fill_hover="*button_primary_background_fill_hover",
button_cancel_text_color="*button_primary_text_color",
# Padding
checkbox_label_padding="*spacing_sm",
button_large_padding="*spacing_sm",
button_small_padding="*spacing_sm",
# Borders
block_border_width="0px",
block_border_width_dark="1px",
shadow_drop_lg="0 1px 4px 0 rgb(0 0 0 / 0.1)",
block_shadow="*shadow_drop_lg",
block_shadow_dark="none",
# Block Labels
block_title_text_weight="600",
block_label_text_weight="600",
block_label_text_size="*text_sm",
)

View File

@@ -0,0 +1,107 @@
import atexit
import logging
import os
import platform
import stat
import subprocess
from pathlib import Path
import requests
VERSION = "1.0"
machine = platform.machine()
if machine == "x86_64":
machine = "amd64"
BINARY_REMOTE_NAME = f"frpc_{platform.system().lower()}_{machine.lower()}"
EXTENSION = ".exe" if os.name == "nt" else ""
BINARY_URL = (
"some-endpoint.com"
f"/kotaemon/tunneling/{VERSION}/{BINARY_REMOTE_NAME}{EXTENSION}"
)
BINARY_FILENAME = f"{BINARY_REMOTE_NAME}_v{VERSION}"
BINARY_FOLDER = Path(__file__).parent
BINARY_PATH = f"{BINARY_FOLDER / BINARY_FILENAME}"
logger = logging.getLogger(__name__)
class Tunnel:
def __init__(self, appname, username, local_port):
self.proc = None
self.url = None
self.appname = appname
self.username = username
self.local_port = local_port
@staticmethod
def download_binary():
if not Path(BINARY_PATH).exists():
print("First time setting tunneling...")
resp = requests.get(BINARY_URL)
if resp.status_code == 404:
raise OSError(
f"Cannot set up a share link as this platform is incompatible. "
"Please create a GitHub issue with information about your "
f"platform: {platform.uname()}"
)
if resp.status_code == 403:
raise OSError(
"You do not have permission to setup the tunneling. Please "
"make sure that you are within Cinnamon VPN or within other "
"approved IPs. If this is new server, please contact @channel "
"at #llm-productization to add your IP address"
)
resp.raise_for_status()
# Save file data to local copy
with open(BINARY_PATH, "wb") as file:
file.write(resp.content)
st = os.stat(BINARY_PATH)
os.chmod(BINARY_PATH, st.st_mode | stat.S_IEXEC)
def run(self) -> str:
"""Setting up tunneling"""
if platform.system().lower() == "windows":
logger.warning("Tunneling is not fully supported on Windows.")
self.download_binary()
self.url = self._start_tunnel(BINARY_PATH)
return self.url
def kill(self):
if self.proc is not None:
print(f"Killing tunnel 127.0.0.1:{self.local_port} <> {self.url}")
self.proc.terminate()
self.proc = None
def _start_tunnel(self, binary: str) -> str:
command = [
binary,
"http",
"-l",
str(self.local_port),
"-i",
"127.0.0.1",
"--uc",
"--sd",
str(self.appname),
"-n",
str(self.appname + self.username),
"--server_addr",
"44.229.38.9:7000",
"--token",
"Wz807/DyC;#t;#/",
"--disable_log_color",
]
self.proc = subprocess.Popen(
command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
atexit.register(self.kill)
return f"https://{self.appname}.promptui.dm.cinnamon.is"

View File

@@ -0,0 +1,45 @@
from typing import Union
import gradio as gr
import yaml
from theflow.utils.modules import import_dotted_string
from ..themes import John
from .chat import build_chat_ui
from .pipeline import build_pipeline_ui
def build_from_dict(config: Union[str, dict]):
"""Build a full UI from YAML config file"""
if isinstance(config, str):
with open(config) as f:
config_dict: dict = yaml.safe_load(f)
elif isinstance(config, dict):
config_dict = config
else:
raise ValueError(
f"config must be either a yaml path or a dict, got {type(config)}"
)
demos = []
for key, value in config_dict.items():
pipeline_def = import_dotted_string(key, safe=False)
if value["ui-type"] == "chat":
demos.append(build_chat_ui(value, pipeline_def).queue())
else:
demos.append(build_pipeline_ui(value, pipeline_def).queue())
if len(demos) == 1:
demo = demos[0]
else:
demo = gr.TabbedInterface(
demos,
tab_names=list(config_dict.keys()),
title="PromptUI from kotaemon",
analytics_enabled=False,
theme=John(),
)
demo.queue()
return demo

View File

@@ -0,0 +1,181 @@
from __future__ import annotations
from typing import Any, AsyncGenerator
import anyio
from gradio import ChatInterface
from gradio.components import Component, get_component_instance
from gradio.events import on
from gradio.helpers import special_args
from gradio.routes import Request
class ChatBlock(ChatInterface):
"""The ChatBlock subclasses ChatInterface to provide extra functionalities:
- Show additional outputs to the chat interface
- Disallow blank user message
"""
def __init__(
self,
*args,
additional_outputs: str | Component | list[str | Component] | None = None,
**kwargs,
):
if additional_outputs:
if not isinstance(additional_outputs, list):
additional_outputs = [additional_outputs]
self.additional_outputs = [
get_component_instance(i) for i in additional_outputs # type: ignore
]
else:
self.additional_outputs = []
super().__init__(*args, **kwargs)
async def _submit_fn(
self,
message: str,
history_with_input: list[list[str | None]],
request: Request,
*args,
) -> tuple[Any, ...]:
input_args = args[: -len(self.additional_outputs)]
output_args = args[-len(self.additional_outputs) :]
if not message:
return history_with_input, history_with_input, *output_args
history = history_with_input[:-1]
inputs, _, _ = special_args(
self.fn, inputs=[message, history, *input_args], request=request
)
if self.is_async:
response = await self.fn(*inputs)
else:
response = await anyio.to_thread.run_sync(
self.fn, *inputs, limiter=self.limiter
)
output = []
if self.additional_outputs:
text = response[0]
output = response[1:]
else:
text = response
history.append([message, text])
return history, history, *output
async def _stream_fn(
self,
message: str,
history_with_input: list[list[str | None]],
*args,
) -> AsyncGenerator:
raise NotImplementedError("Stream function not implemented for ChatBlock")
def _display_input(
self, message: str, history: list[list[str | None]]
) -> tuple[list[list[str | None]], list[list[str | None]]]:
"""Stop displaying the input message if the message is a blank string"""
if not message:
return history, history
return super()._display_input(message, history)
def _setup_events(self) -> None:
"""Include additional outputs in the submit event"""
submit_fn = self._stream_fn if self.is_generator else self._submit_fn
submit_triggers = (
[self.textbox.submit, self.submit_btn.click]
if self.submit_btn
else [self.textbox.submit]
)
submit_event = (
on(
submit_triggers,
self._clear_and_save_textbox,
[self.textbox],
[self.textbox, self.saved_input],
api_name=False,
queue=False,
)
.then(
self._display_input,
[self.saved_input, self.chatbot_state],
[self.chatbot, self.chatbot_state],
api_name=False,
queue=False,
)
.then(
submit_fn,
[self.saved_input, self.chatbot_state]
+ self.additional_inputs
+ self.additional_outputs,
[self.chatbot, self.chatbot_state] + self.additional_outputs,
api_name=False,
)
)
self._setup_stop_events(submit_triggers, submit_event)
if self.retry_btn:
retry_event = (
self.retry_btn.click(
self._delete_prev_fn,
[self.chatbot_state],
[self.chatbot, self.saved_input, self.chatbot_state],
api_name=False,
queue=False,
)
.then(
self._display_input,
[self.saved_input, self.chatbot_state],
[self.chatbot, self.chatbot_state],
api_name=False,
queue=False,
)
.then(
submit_fn,
[self.saved_input, self.chatbot_state]
+ self.additional_inputs
+ self.additional_outputs,
[self.chatbot, self.chatbot_state] + self.additional_outputs,
api_name=False,
)
)
self._setup_stop_events([self.retry_btn.click], retry_event)
if self.undo_btn:
self.undo_btn.click(
self._delete_prev_fn,
[self.chatbot_state],
[self.chatbot, self.saved_input, self.chatbot_state],
api_name=False,
queue=False,
).then(
lambda x: x,
[self.saved_input],
[self.textbox],
api_name=False,
queue=False,
)
if self.clear_btn:
self.clear_btn.click(
lambda: ([], [], None),
None,
[self.chatbot, self.chatbot_state, self.saved_input],
queue=False,
api_name=False,
)
def _setup_api(self) -> None:
api_fn = self._api_stream_fn if self.is_generator else self._api_submit_fn
self.fake_api_btn.click(
api_fn,
[self.textbox, self.chatbot_state] + self.additional_inputs,
[self.textbox, self.chatbot_state] + self.additional_outputs,
api_name="chat",
)
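# Minimal usage sketch (hedged; `answer_fn` is a hypothetical chat function that
# returns a reply plus one extra value): wiring a ChatBlock with an additional
# output component.
#
#   import gradio as gr
#   extra = gr.Textbox(label="Evidence")
#   block = ChatBlock(fn=answer_fn, additional_outputs=[extra])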

View File

@@ -0,0 +1,308 @@
import pickle
from datetime import datetime
from pathlib import Path
import gradio as gr
from kotaemon.chatbot import ChatConversation
from kotaemon.contribs.promptui.base import get_component
from kotaemon.contribs.promptui.export import export
from kotaemon.contribs.promptui.ui.blocks import ChatBlock
from theflow.storage import storage
from ..logs import ResultLog
USAGE_INSTRUCTION = """## How to use:
1. Set the desired parameters.
2. Click "New chat" to start a chat session with the supplied parameters. This
set of parameters will persist until the end of the chat session. During an
ongoing chat session, changing the parameters will not take effect.
3. Chat and interact with the chat bot on the right panel. You can add any
additional input (if any), and they will be supplied to the chatbot.
4. During chat, the log of the chat will show up in the "Output" tabs. This is
empty by default, so if you want to show the log here, tell the AI developers
to configure the UI settings.
5. When finishing chat, select your preference in the radio box. Click "End chat".
This will save the chat log and the preference to disk.
6. To compare the results of different runs, click "Export" to get an Excel
spreadsheet summarizing those runs.
## Support:
In case of errors, you can:
- Check the PromptUI instruction:
https://github.com/Cinnamon/kotaemon/wiki/Utilities#prompt-engineering-ui
- Create a bug fix and make a PR at: https://github.com/Cinnamon/kotaemon
- Ping any of @john @tadashi @ian @jacky in the Slack channel #llm-productization
## Contribute:
- Follow the installation guide at: https://github.com/Cinnamon/kotaemon/
"""
def construct_chat_ui(
config, func_new_chat, func_chat, func_end_chat, func_export_to_excel
) -> gr.Blocks:
"""Construct the prompt engineering UI for chat
Args:
config: the UI config
func_new_chat: the function for starting a new chat session
func_chat: the function for chatting interaction
func_end_chat: the function for ending and saving the chat
func_export_to_excel: the function to export the logs to excel
Returns:
the UI object
"""
inputs, outputs, params = [], [], []
for name, component_def in config.get("inputs", {}).items():
if "params" not in component_def:
component_def["params"] = {}
component_def["params"]["interactive"] = True
component = get_component(component_def)
if hasattr(component, "label") and not component.label: # type: ignore
component.label = name # type: ignore
inputs.append(component)
for name, component_def in config.get("params", {}).items():
if "params" not in component_def:
component_def["params"] = {}
component_def["params"]["interactive"] = True
component = get_component(component_def)
if hasattr(component, "label") and not component.label: # type: ignore
component.label = name # type: ignore
params.append(component)
for idx, component_def in enumerate(config.get("outputs", [])):
if "params" not in component_def:
component_def["params"] = {}
component_def["params"]["interactive"] = False
component = get_component(component_def)
if hasattr(component, "label") and not component.label: # type: ignore
component.label = f"Output {idx}" # type: ignore
outputs.append(component)
sess = gr.State(value=None)
chatbot = gr.Chatbot(label="Chatbot", show_copy_button=True)
chat = ChatBlock(
func_chat, chatbot=chatbot, additional_inputs=[sess], additional_outputs=outputs
)
param_state = gr.Textbox(interactive=False)
with gr.Blocks(analytics_enabled=False, title="Welcome to PromptUI") as demo:
sess.render()
with gr.Accordion(label="HOW TO", open=False):
gr.Markdown(USAGE_INSTRUCTION)
with gr.Row():
run_btn = gr.Button("New chat")
run_btn.click(
func_new_chat,
inputs=params,
outputs=[
chat.chatbot,
chat.chatbot_state,
chat.saved_input,
param_state,
sess,
*outputs,
],
)
with gr.Accordion(label="End chat", open=False):
likes = gr.Radio(["like", "dislike", "neutral"], value="neutral")
save_log = gr.Checkbox(
value=True,
label="Save log",
info="If saved, log can be exported later",
show_label=True,
)
end_btn = gr.Button("End chat")
end_btn.click(
func_end_chat,
inputs=[likes, save_log, sess],
outputs=[param_state, sess],
)
with gr.Accordion(label="Export", open=False):
exported_file = gr.File(
label="Output file", show_label=True, height=100
)
export_btn = gr.Button("Export")
export_btn.click(
func_export_to_excel, inputs=None, outputs=exported_file
)
with gr.Row():
with gr.Column():
with gr.Tab("Params"):
for component in params:
component.render()
with gr.Accordion(label="Session state", open=False):
param_state.render()
with gr.Tab("Outputs"):
for component in outputs:
component.render()
with gr.Column():
chat.render()
return demo.queue()
def build_chat_ui(config, pipeline_def):
"""Build the chat UI
Args:
config: the UI config
pipeline_def: the pipeline definition
Returns:
the UI object
"""
output_dir: Path = Path(storage.url(pipeline_def().config.store_result))
exported_dir = output_dir.parent / "exported"
exported_dir.mkdir(parents=True, exist_ok=True)
resultlog = getattr(pipeline_def, "_promptui_resultlog", ResultLog)
allowed_resultlog_callbacks = {i for i in dir(resultlog) if not i.startswith("__")}
def new_chat(*args):
"""Start a new chat function
Args:
*args: the pipeline init params
Returns:
new empty states
"""
gr.Info("Starting new session...")
param_dicts = {
name: value for name, value in zip(config["params"].keys(), args)
}
for key in param_dicts.keys():
if config["params"][key].get("component").lower() == "file":
param_dicts[key] = param_dicts[key].name
# TODO: currently hard-code as ChatConversation
pipeline = pipeline_def()
session = ChatConversation(bot=pipeline)
session.set(param_dicts)
session.start_session()
param_state_str = "\n".join(
f"- {name}: {value}" for name, value in param_dicts.items()
)
gr.Info("New chat session started.")
return (
[],
[],
None,
param_state_str,
session,
*[None] * len(config.get("outputs", [])),
)
def chat(message, history, session, *args):
"""The chat interface
# TODO: wrap the input and output of this chat function so that it
works with more types of chat conversation than simple text
Args:
message: the message from the user
history: the gradio history of the chat
session: the chat object session
*args: the additional inputs
Returns:
the response from the chatbot
"""
if session is None:
raise gr.Error(
"No active chat session. Please set the params and click New chat"
)
pred = session(message)
text_response = pred.content
additional_outputs = []
for output_def in config.get("outputs", []):
value = session.last_run.logs(output_def["step"])
getter = output_def.get("getter", None)
if getter and getter in allowed_resultlog_callbacks:
value = getattr(resultlog, getter)(value)
additional_outputs.append(value)
return text_response, *additional_outputs
def end_chat(preference: str, save_log: bool, session):
"""End the chat session
Args:
preference: the preference of the user
save_log: whether to save the result
session: the chat object session
Returns:
the new empty state
"""
gr.Info("Ending session...")
session.end_session()
output_dir: Path = (
Path(storage.url(session.config.store_result)) / session.last_run.id()
)
if not save_log:
if output_dir.exists():
import shutil
shutil.rmtree(output_dir)
session = None
param_state = ""
gr.Info("End session without saving log.")
return param_state, session
# add preference result to progress
with (output_dir / "progress.pkl").open("rb") as fi:
progress = pickle.load(fi)
progress["preference"] = preference
with (output_dir / "progress.pkl").open("wb") as fo:
pickle.dump(progress, fo)
# get the original params
param_dicts = {name: session.getx(name) for name in config["params"].keys()}
with (output_dir / "params.pkl").open("wb") as fo:
pickle.dump(param_dicts, fo)
session = None
param_state = ""
gr.Info("End session and save log.")
return param_state, session
def export_func():
name = (
f"{pipeline_def.__module__}.{pipeline_def.__name__}_{datetime.now()}.xlsx"
)
path = str(exported_dir / name)
gr.Info(f"Begin exporting {name}...")
try:
export(config=config, pipeline_def=pipeline_def, output_path=path)
except Exception as e:
raise gr.Error(f"Failed to export. Please contact project's AIR: {e}")
gr.Info(f"Exported {name}. Please go to the `Exported file` tab to download")
return path
demo = construct_chat_ui(
config=config,
func_new_chat=new_chat,
func_chat=chat,
func_end_chat=end_chat,
func_export_to_excel=export_func,
)
return demo

View File

@@ -0,0 +1,245 @@
import pickle
import time
from datetime import datetime
from pathlib import Path
from typing import Any, Dict
import gradio as gr
import pandas as pd
from kotaemon.contribs.promptui.base import get_component
from kotaemon.contribs.promptui.export import export
from theflow.storage import storage
from ..logs import ResultLog
USAGE_INSTRUCTION = """## How to use:
1. Set the desired parameters.
2. Set the desired inputs.
3. Click "Run" to execute the pipeline with the supplied parameters and inputs
4. The pipeline output will show up in the output panel.
5. Repeat from step 1.
6. To compare the results of different runs, click "Export" to get an Excel
spreadsheet summarizing those runs.
## Support:
In case of errors, you can:
- Check the PromptUI instruction:
https://github.com/Cinnamon/kotaemon/wiki/Utilities#prompt-engineering-ui
- Create a bug fix and make a PR at: https://github.com/Cinnamon/kotaemon
- Ping any of @john @tadashi @ian @jacky in the Slack channel #llm-productization
## Contribute:
- Follow the installation guide at: https://github.com/Cinnamon/kotaemon/
"""
def construct_pipeline_ui(
config, func_run, func_save, func_load_params, func_activate_params, func_export
) -> gr.Blocks:
"""Create UI from config file. Execute the UI from config file
- Can do now: Log from stdout to UI
- In the future, we can provide some hooks and callbacks to let developers better
fine-tune the UI behavior.
"""
inputs, outputs, params = [], [], []
for name, component_def in config.get("inputs", {}).items():
if "params" not in component_def:
component_def["params"] = {}
component_def["params"]["interactive"] = True
component = get_component(component_def)
if hasattr(component, "label") and not component.label: # type: ignore
component.label = name # type: ignore
inputs.append(component)
for name, component_def in config.get("params", {}).items():
if "params" not in component_def:
component_def["params"] = {}
component_def["params"]["interactive"] = True
component = get_component(component_def)
if hasattr(component, "label") and not component.label: # type: ignore
component.label = name # type: ignore
params.append(component)
for idx, component_def in enumerate(config.get("outputs", [])):
if "params" not in component_def:
component_def["params"] = {}
component_def["params"]["interactive"] = False
component = get_component(component_def)
if hasattr(component, "label") and not component.label: # type: ignore
component.label = f"Output {idx}" # type: ignore
outputs.append(component)
exported_file = gr.File(label="Output file", show_label=True)
history_dataframe = gr.DataFrame(wrap=True)
temp = gr.Tab
with gr.Blocks(analytics_enabled=False, title="Welcome to PromptUI") as demo:
with gr.Accordion(label="HOW TO", open=False):
gr.Markdown(USAGE_INSTRUCTION)
with gr.Accordion(label="Params History", open=False):
with gr.Row():
save_btn = gr.Button("Save params")
save_btn.click(func_save, inputs=params, outputs=history_dataframe)
load_params_btn = gr.Button("Reload params")
load_params_btn.click(
func_load_params, inputs=None, outputs=history_dataframe
)
history_dataframe.render()
history_dataframe.select(
func_activate_params, inputs=params, outputs=params
)
with gr.Row():
run_btn = gr.Button("Run")
run_btn.click(func_run, inputs=inputs + params, outputs=outputs)
export_btn = gr.Button(
"Export (Result will be in Exported file next to Output)"
)
export_btn.click(func_export, inputs=None, outputs=exported_file)
with gr.Row():
with gr.Column():
if params:
with temp("Params"):
for component in params:
component.render()
if inputs:
with temp("Inputs"):
for component in inputs:
component.render()
if not params and not inputs:
gr.Text("No params or inputs")
with gr.Column():
with temp("Outputs"):
for component in outputs:
component.render()
with temp("Exported file"):
exported_file.render()
return demo
def load_saved_params(path: str) -> Dict:
"""Load the saved params from path to a dataframe"""
# get all pickle files
files = list(sorted(Path(path).glob("*.pkl")))
data: Dict[str, Any] = {"_id": [None] * len(files)}
for idx, each_file in enumerate(files):
with open(each_file, "rb") as f:
each_data = pickle.load(f)
data["_id"][idx] = Path(each_file).stem
for key, value in each_data.items():
if key not in data:
data[key] = [None] * len(files)
data[key][idx] = value
return data
def build_pipeline_ui(config: dict, pipeline_def):
"""Build a tab from config file"""
inputs_name = list(config.get("inputs", {}).keys())
params_name = list(config.get("params", {}).keys())
outputs_def = config.get("outputs", [])
output_dir: Path = Path(storage.url(pipeline_def().config.store_result))
exported_dir = output_dir.parent / "exported"
exported_dir.mkdir(parents=True, exist_ok=True)
save_dir = (
output_dir.parent
/ "saved"
/ f"{pipeline_def.__module__}.{pipeline_def.__name__}"
)
save_dir.mkdir(parents=True, exist_ok=True)
resultlog = getattr(pipeline_def, "_promptui_resultlog", ResultLog)
allowed_resultlog_callbacks = {i for i in dir(resultlog) if not i.startswith("__")}
def run_func(*args):
inputs = {
name: value for name, value in zip(inputs_name, args[: len(inputs_name)])
}
params = {
name: value for name, value in zip(params_name, args[len(inputs_name) :])
}
pipeline = pipeline_def()
pipeline.set(params)
pipeline(**inputs)
with storage.open(
storage.url(
pipeline.config.store_result, pipeline.last_run.id(), "params.pkl"
),
"wb",
) as f:
pickle.dump(params, f)
if outputs_def:
outputs = []
for output_def in outputs_def:
output = pipeline.last_run.logs(output_def["step"])
getter = output_def.get("getter", None)
if getter and getter in allowed_resultlog_callbacks:
output = getattr(resultlog, getter)(output)
outputs.append(output)
if len(outputs_def) == 1:
return outputs[0]
return outputs
def save_func(*args):
params = {name: value for name, value in zip(params_name, args)}
filename = save_dir / f"{int(time.time())}.pkl"
with open(filename, "wb") as f:
pickle.dump(params, f)
gr.Info("Params saved")
data = load_saved_params(str(save_dir))
return pd.DataFrame(data)
def load_params_func():
data = load_saved_params(str(save_dir))
return pd.DataFrame(data)
def activate_params_func(ev: gr.SelectData, *args):
data = load_saved_params(str(save_dir))
output_args = [each for each in args]
if ev.value is None:
gr.Info(f'Blank value: "{ev.value}". Skip')
return output_args
column = list(data.keys())[ev.index[1]]
if column not in params_name:
gr.Info(f'Column "{column}" not in params. Skip')
return output_args
value = data[column][ev.index[0]]
if value is None:
gr.Info(f'Blank value: "{ev.value}". Skip')
return output_args
output_args[params_name.index(column)] = value
return output_args
def export_func():
name = (
f"{pipeline_def.__module__}.{pipeline_def.__name__}_{datetime.now()}.xlsx"
)
path = str(exported_dir / name)
gr.Info(f"Begin exporting {name}...")
try:
export(config=config, pipeline_def=pipeline_def, output_path=path)
except Exception as e:
raise gr.Error(f"Failed to export. Please contact project's AIR: {e}")
gr.Info(f"Exported {name}. Please go to the `Exported file` tab to download")
return path
return construct_pipeline_ui(
config, run_func, save_func, load_params_func, activate_params_func, export_func
)

View File

@@ -0,0 +1,15 @@
from .base import BaseEmbeddings
from .langchain_based import (
AzureOpenAIEmbeddings,
CohereEmbdeddings,
HuggingFaceEmbeddings,
OpenAIEmbeddings,
)
__all__ = [
"BaseEmbeddings",
"OpenAIEmbeddings",
"AzureOpenAIEmbeddings",
"CohereEmbdeddings",
"HuggingFaceEmbeddings",
]

View File

@@ -0,0 +1,13 @@
from __future__ import annotations
from abc import abstractmethod
from kotaemon.base import BaseComponent, Document, DocumentWithEmbedding
class BaseEmbeddings(BaseComponent):
@abstractmethod
def run(
self, text: str | list[str] | Document | list[Document]
) -> list[DocumentWithEmbedding]:
...

View File

@@ -0,0 +1,209 @@
from typing import Optional
from kotaemon.base import Document, DocumentWithEmbedding
from .base import BaseEmbeddings
class LCEmbeddingMixin:
def _get_lc_class(self):
raise NotImplementedError(
"Please return the relevant Langchain class in in _get_lc_class"
)
def __init__(self, **params):
self._lc_class = self._get_lc_class()
self._obj = self._lc_class(**params)
self._kwargs: dict = params
super().__init__()
def run(self, text):
input_: list[str] = []
if not isinstance(text, list):
text = [text]
for item in text:
if isinstance(item, str):
input_.append(item)
elif isinstance(item, Document):
input_.append(item.text)
else:
raise ValueError(
f"Invalid input type {type(item)}, should be str or Document"
)
embeddings = self._obj.embed_documents(input_)
return [
DocumentWithEmbedding(text=each_text, embedding=each_embedding)
for each_text, each_embedding in zip(input_, embeddings)
]
def __repr__(self):
kwargs = []
for key, value_obj in self._kwargs.items():
value = repr(value_obj)
kwargs.append(f"{key}={value}")
kwargs_repr = ", ".join(kwargs)
return f"{self.__class__.__name__}({kwargs_repr})"
def __str__(self):
kwargs = []
for key, value_obj in self._kwargs.items():
value = str(value_obj)
if len(value) > 20:
value = f"{value[:15]}..."
kwargs.append(f"{key}={value}")
kwargs_repr = ", ".join(kwargs)
return f"{self.__class__.__name__}({kwargs_repr})"
def __setattr__(self, name, value):
if name == "_lc_class":
return super().__setattr__(name, value)
if name in self._lc_class.__fields__:
self._kwargs[name] = value
self._obj = self._lc_class(**self._kwargs)
else:
super().__setattr__(name, value)
def __getattr__(self, name):
if name in self._kwargs:
return self._kwargs[name]
return getattr(self._obj, name)
def dump(self, *args, **kwargs):
from theflow.utils.modules import serialize
params = {key: serialize(value) for key, value in self._kwargs.items()}
return {
"__type__": f"{self.__module__}.{self.__class__.__qualname__}",
**params,
}
def specs(self, path: str):
path = path.strip(".")
if "." in path:
raise ValueError("path should not contain '.'")
if path in self._lc_class.__fields__:
return {
"__type__": "theflow.base.ParamAttr",
"refresh_on_set": True,
"strict_type": True,
}
raise ValueError(f"Invalid param {path}")
class OpenAIEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
"""Wrapper around Langchain's OpenAI embedding, focusing on key parameters"""
def __init__(
self,
model: str = "text-embedding-ada-002",
openai_api_version: Optional[str] = None,
openai_api_base: Optional[str] = None,
openai_api_type: Optional[str] = None,
openai_api_key: Optional[str] = None,
request_timeout: Optional[float] = None,
**params,
):
super().__init__(
model=model,
openai_api_version=openai_api_version,
openai_api_base=openai_api_base,
openai_api_type=openai_api_type,
openai_api_key=openai_api_key,
request_timeout=request_timeout,
**params,
)
def _get_lc_class(self):
try:
from langchain_openai import OpenAIEmbeddings
except ImportError:
from langchain.embeddings import OpenAIEmbeddings
return OpenAIEmbeddings
class AzureOpenAIEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
"""Wrapper around Langchain's AzureOpenAI embedding, focusing on key parameters"""
def __init__(
self,
azure_endpoint: Optional[str] = None,
deployment: Optional[str] = None,
openai_api_key: Optional[str] = None,
openai_api_version: Optional[str] = None,
request_timeout: Optional[float] = None,
**params,
):
super().__init__(
azure_endpoint=azure_endpoint,
deployment=deployment,
openai_api_version=openai_api_version,
openai_api_key=openai_api_key,
request_timeout=request_timeout,
**params,
)
def _get_lc_class(self):
try:
from langchain_openai import AzureOpenAIEmbeddings
except ImportError:
from langchain.embeddings import AzureOpenAIEmbeddings
return AzureOpenAIEmbeddings
class CohereEmbdeddings(LCEmbeddingMixin, BaseEmbeddings):
"""Wrapper around Langchain's Cohere embedding, focusing on key parameters"""
def __init__(
self,
model: str = "embed-english-v2.0",
cohere_api_key: Optional[str] = None,
truncate: Optional[str] = None,
request_timeout: Optional[float] = None,
**params,
):
super().__init__(
model=model,
cohere_api_key=cohere_api_key,
truncate=truncate,
request_timeout=request_timeout,
**params,
)
def _get_lc_class(self):
try:
from langchain_community.embeddings import CohereEmbeddings
except ImportError:
from langchain.embeddings import CohereEmbeddings
return CohereEmbeddings
class HuggingFaceEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
"""Wrapper around Langchain's HuggingFace embedding, focusing on key parameters"""
def __init__(
self,
model_name: str = "sentence-transformers/all-mpnet-base-v2",
**params,
):
super().__init__(
model_name=model_name,
**params,
)
def _get_lc_class(self):
try:
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
except ImportError:
from langchain.embeddings import HuggingFaceBgeEmbeddings
return HuggingFaceBgeEmbeddings
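# Hedged usage sketch (the API key value is a placeholder): embedding raw
# strings through the OpenAI wrapper; calling the component dispatches to run().
#
#   embedder = OpenAIEmbeddings(
#       model="text-embedding-ada-002", openai_api_key="<your-key>"
#   )
#   docs = embedder(["hello world"])  # -> list[DocumentWithEmbedding]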

View File

@@ -0,0 +1,3 @@
from .vectorindex import VectorIndexing, VectorRetrieval
__all__ = ["VectorIndexing", "VectorRetrieval"]

View File

@@ -0,0 +1,122 @@
from __future__ import annotations
from abc import abstractmethod
from typing import Any, Type
from kotaemon.base import BaseComponent, Document, RetrievedDocument
from llama_index.node_parser.interface import NodeParser
class DocTransformer(BaseComponent):
"""This is a base class for document transformers
A document transformer transforms a list of documents into another list
of documents. Transforming can mean splitting a document into multiple documents,
reducing a large list of documents into a smaller list of documents, or adding
metadata to each document in a list of documents, etc.
"""
@abstractmethod
def run(
self,
documents: list[Document],
**kwargs,
) -> list[Document]:
...
class LlamaIndexDocTransformerMixin:
"""Allow automatically wrapping a Llama-index component into kotaemon component
Example:
class TokenSplitter(LlamaIndexMixin, BaseSplitter):
def _get_li_class(self):
from llama_index.text_splitter import TokenTextSplitter
return TokenTextSplitter
To use this mixin, please:
1. Use this class as the 1st parent class, so that Python will prefer to use
the attributes and methods of this class whenever possible.
2. Overwrite `_get_li_class` to return the relevant LlamaIndex component.
"""
def _get_li_class(self) -> Type[NodeParser]:
raise NotImplementedError(
"Please return the relevant LlamaIndex class in _get_li_class"
)
def __init__(self, **params):
self._li_cls = self._get_li_class()
self._obj = self._li_cls(**params)
self._kwargs = params
super().__init__()
def __repr__(self):
kwargs = []
for key, value_obj in self._kwargs.items():
value = repr(value_obj)
kwargs.append(f"{key}={value}")
kwargs_repr = ", ".join(kwargs)
return f"{self.__class__.__name__}({kwargs_repr})"
def __str__(self):
kwargs = []
for key, value_obj in self._kwargs.items():
value = str(value_obj)
if len(value) > 20:
value = f"{value[:15]}..."
kwargs.append(f"{key}={value}")
kwargs_repr = ", ".join(kwargs)
return f"{self.__class__.__name__}({kwargs_repr})"
def __setattr__(self, name: str, value: Any) -> None:
if name.startswith("_") or name in self._protected_keywords():
return super().__setattr__(name, value)
self._kwargs[name] = value
return setattr(self._obj, name, value)
def __getattr__(self, name: str) -> Any:
if name in self._kwargs:
return self._kwargs[name]
return getattr(self._obj, name)
def dump(self, *args, **kwargs):
from theflow.utils.modules import serialize
params = {key: serialize(value) for key, value in self._kwargs.items()}
return {
"__type__": f"{self.__module__}.{self.__class__.__qualname__}",
**params,
}
def run(
self,
documents: list[Document],
**kwargs,
) -> list[Document]:
"""Run Llama-index node parser and convert the output to Document from
kotaemon
"""
docs = self._obj(documents, **kwargs) # type: ignore
return [Document.from_dict(doc.to_dict()) for doc in docs]
class BaseIndexing(BaseComponent):
"""Define the base interface for indexing pipeline"""
def to_retrieval_pipeline(self, **kwargs):
"""Convert the indexing pipeline to a retrieval pipeline"""
raise NotImplementedError
def to_qa_pipeline(self, **kwargs):
"""Convert the indexing pipeline to a QA pipeline"""
raise NotImplementedError
class BaseRetrieval(BaseComponent):
"""Define the base interface for retrieval pipeline"""
@abstractmethod
def run(self, *args, **kwargs) -> list[RetrievedDocument]:
...

View File

@@ -0,0 +1,7 @@
from .doc_parsers import BaseDocParser, SummaryExtractor, TitleExtractor
__all__ = [
"BaseDocParser",
"TitleExtractor",
"SummaryExtractor",
]

View File

@@ -0,0 +1,35 @@
from ..base import DocTransformer, LlamaIndexDocTransformerMixin
class BaseDocParser(DocTransformer):
...
class TitleExtractor(LlamaIndexDocTransformerMixin, BaseDocParser):
def __init__(
self,
llm=None,
nodes: int = 5,
**params,
):
super().__init__(llm=llm, nodes=nodes, **params)
def _get_li_class(self):
from llama_index.extractors import TitleExtractor
return TitleExtractor
class SummaryExtractor(LlamaIndexDocTransformerMixin, BaseDocParser):
def __init__(
self,
llm=None,
summaries: list[str] = ["self"],
**params,
):
super().__init__(llm=llm, summaries=summaries, **params)
def _get_li_class(self):
from llama_index.extractors import SummaryExtractor
return SummaryExtractor

View File

@@ -0,0 +1,3 @@
from .files import DocumentIngestor
__all__ = ["DocumentIngestor"]

View File

@@ -0,0 +1,85 @@
from pathlib import Path
from kotaemon.base import BaseComponent, Document, Param
from kotaemon.indices.extractors import BaseDocParser
from kotaemon.indices.splitters import BaseSplitter, TokenSplitter
from kotaemon.loaders import (
AutoReader,
DirectoryReader,
MathpixPDFReader,
OCRReader,
PandasExcelReader,
UnstructuredReader,
)
from llama_index.readers.base import BaseReader
class DocumentIngestor(BaseComponent):
"""Ingest common office document types into Document for indexing
Document types:
- pdf
- xlsx, xls
- docx, doc
Args:
pdf_mode: mode for pdf extraction, one of "normal", "mathpix", "ocr"
- normal: parse pdf text
- mathpix: parse pdf text using mathpix
- ocr: parse pdf image using flax
doc_parsers: list of document parsers to parse the document
text_splitter: splitter to split the document into text nodes
"""
pdf_mode: str = "normal" # "normal", "mathpix", "ocr"
doc_parsers: list[BaseDocParser] = Param(default_callback=lambda _: [])
text_splitter: BaseSplitter = TokenSplitter.withx(
chunk_size=1024,
chunk_overlap=256,
)
def _get_reader(self, input_files: list[str | Path]):
"""Get appropriate readers for the input files based on file extension"""
file_extractor: dict[str, AutoReader | BaseReader] = {
".xlsx": PandasExcelReader(),
".docx": UnstructuredReader(),
".xls": UnstructuredReader(),
".doc": UnstructuredReader(),
}
if self.pdf_mode == "normal":
file_extractor[".pdf"] = AutoReader("UnstructuredReader")
elif self.pdf_mode == "ocr":
file_extractor[".pdf"] = OCRReader()
else:
file_extractor[".pdf"] = MathpixPDFReader()
main_reader = DirectoryReader(
input_files=input_files,
file_extractor=file_extractor,
)
return main_reader
def run(self, file_paths: list[str | Path] | str | Path) -> list[Document]:
"""Ingest the file paths into Document
Args:
file_paths: list of file paths or a single file path
Returns:
list of parsed Documents
"""
if not isinstance(file_paths, list):
file_paths = [file_paths]
documents = self._get_reader(input_files=file_paths)()
nodes = self.text_splitter(documents)
self.log_progress(".num_docs", num_docs=len(nodes))
# document parsers call
if self.doc_parsers:
for parser in self.doc_parsers:
nodes = parser(nodes)
return nodes
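# Hedged usage sketch (file paths are hypothetical; the OCR/Mathpix extras are
# only needed for the corresponding pdf_mode):
#
#   ingestor = DocumentIngestor(pdf_mode="normal")
#   nodes = ingestor(["report.pdf", "tables.xlsx"])  # -> list[Document]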

View File

@@ -0,0 +1,7 @@
from .citation import CitationPipeline
from .text_based import CitationQAPipeline
__all__ = [
"CitationPipeline",
"CitationQAPipeline",
]

View File

@@ -0,0 +1,101 @@
from typing import Iterator, List
from kotaemon.base import BaseComponent
from kotaemon.base.schema import HumanMessage, SystemMessage
from kotaemon.llms import BaseLLM
from pydantic import BaseModel, Field
class FactWithEvidence(BaseModel):
"""Class representing a single statement.
Each fact has a body and a list of sources.
If there are multiple facts, make sure to break them apart
such that each one only uses a set of sources that are relevant to it.
"""
fact: str = Field(..., description="Body of the sentence, as part of a response")
substring_quote: List[str] = Field(
...,
description=(
"Each source should be a direct quote from the context, "
"as a substring of the original content"
),
)
def _get_span(self, quote: str, context: str, errs: int = 100) -> Iterator[str]:
import regex
minor = quote
major = context
errs_ = 0
s = regex.search(f"({minor}){{e<={errs_}}}", major)
while s is None and errs_ <= errs:
errs_ += 1
s = regex.search(f"({minor}){{e<={errs_}}}", major)
if s is not None:
yield from s.spans()
def get_spans(self, context: str) -> Iterator[str]:
for quote in self.substring_quote:
yield from self._get_span(quote, context)
class QuestionAnswer(BaseModel):
"""A question and its answer as a list of facts each one should have a source.
each sentence contains a body and a list of sources."""
question: str = Field(..., description="Question that was asked")
answer: List[FactWithEvidence] = Field(
...,
description=(
"Body of the answer, each fact should be "
"its separate object with a body and a list of sources"
),
)
class CitationPipeline(BaseComponent):
"""Citation pipeline to extract cited evidences from source
(based on input question)"""
llm: BaseLLM
def run(self, context: str, question: str):
schema = QuestionAnswer.schema()
function = {
"name": schema["title"],
"description": schema["description"],
"parameters": schema,
}
llm_kwargs = {
"functions": [function],
"function_call": {"name": function["name"]},
}
messages = [
SystemMessage(
content=(
"You are a world class algorithm to answer "
"questions with correct and exact citations."
)
),
HumanMessage(content="Answer question using the following context"),
HumanMessage(content=context),
HumanMessage(content=f"Question: {question}"),
HumanMessage(
content=(
"Tips: Make sure to cite your sources, "
"and use the exact words from the context."
)
),
]
llm_output = self.llm(messages, **llm_kwargs)
function_output = llm_output.messages[0].additional_kwargs["function_call"][
"arguments"
]
output = QuestionAnswer.parse_raw(function_output)
return output
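# Hedged usage sketch (assumes `llm` is a chat model configured elsewhere that
# supports OpenAI-style function calling):
#
#   pipeline = CitationPipeline(llm=llm)
#   qa = pipeline(context=source_text, question="What does the text claim about X?")
#   for fact in qa.answer:
#       print(fact.fact, fact.substring_quote)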

View File

@@ -0,0 +1,63 @@
import os
from kotaemon.base import BaseComponent, Document, Node, RetrievedDocument
from kotaemon.llms import AzureChatOpenAI, BaseLLM, PromptTemplate
from .citation import CitationPipeline
class CitationQAPipeline(BaseComponent):
"""Answering question from a text corpus with citation"""
qa_prompt_template: PromptTemplate = PromptTemplate(
'Answer the following question: "{question}". '
"The context is: \n{context}\nAnswer: "
)
llm: BaseLLM = AzureChatOpenAI.withx(
azure_endpoint="https://bleh-dummy.openai.azure.com/",
openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
openai_api_version="2023-07-01-preview",
deployment_name="dummy-q2-16k",
temperature=0,
request_timeout=60,
)
citation_pipeline: CitationPipeline = Node(
default_callback=lambda self: CitationPipeline(llm=self.llm)
)
def _format_doc_text(self, text: str) -> str:
"""Format the text of each document"""
return text.replace("\n", " ")
def _format_retrieved_context(self, documents: list[RetrievedDocument]) -> str:
"""Format the texts between all documents"""
matched_texts: list[str] = [
self._format_doc_text(doc.text) for doc in documents
]
return "\n\n".join(matched_texts)
def run(
self,
question: str,
documents: list[RetrievedDocument],
use_citation: bool = False,
**kwargs
) -> Document:
# retrieve relevant documents as context
context = self._format_retrieved_context(documents)
self.log_progress(".context", context=context)
# generate the answer
prompt = self.qa_prompt_template.populate(
context=context,
question=question,
)
self.log_progress(".prompt", prompt=prompt)
answer_text = self.llm(prompt).text
if use_citation:
citation = self.citation_pipeline(context=context, question=question)
else:
citation = None
answer = Document(text=answer_text, metadata={"citation": citation})
return answer

View File

@@ -0,0 +1,5 @@
from .base import BaseReranking
from .cohere import CohereReranking
from .llm import LLMReranking
__all__ = ["CohereReranking", "LLMReranking", "BaseReranking"]

View File

@@ -0,0 +1,13 @@
from __future__ import annotations
from abc import abstractmethod
from kotaemon.base import BaseComponent, Document
class BaseReranking(BaseComponent):
@abstractmethod
def run(self, documents: list[Document], query: str) -> list[Document]:
"""Main method to transform list of documents
(re-ranking, filtering, etc)"""
...

View File

@@ -0,0 +1,40 @@
from __future__ import annotations
import os
from kotaemon.base import Document
from .base import BaseReranking
class CohereReranking(BaseReranking):
model_name: str = "rerank-multilingual-v2.0"
cohere_api_key: str = os.environ.get("COHERE_API_KEY", "")
top_k: int = 1
def run(self, documents: list[Document], query: str) -> list[Document]:
"""Use Cohere Reranker model to re-order documents
with their relevance score"""
try:
import cohere
except ImportError:
raise ImportError(
"Please install Cohere " "`pip install cohere` to use Cohere Reranking"
)
cohere_client = cohere.Client(self.cohere_api_key)
compressed_docs: list[Document] = []
if not documents: # to avoid empty api call
return compressed_docs
_docs = [d.content for d in documents]
results = cohere_client.rerank(
model=self.model_name, query=query, documents=_docs, top_n=self.top_k
)
for r in results:
doc = documents[r.index]
doc.metadata["relevance_score"] = r.relevance_score
compressed_docs.append(doc)
return compressed_docs

View File

@@ -0,0 +1,65 @@
from __future__ import annotations
from concurrent.futures import ThreadPoolExecutor
from kotaemon.base import Document
from kotaemon.llms import BaseLLM, PromptTemplate
from langchain.output_parsers.boolean import BooleanOutputParser
from .base import BaseReranking
RERANK_PROMPT_TEMPLATE = """Given the following question and context,
return YES if the context is relevant to the question and NO if it isn't.
> Question: {question}
> Context:
>>>
{context}
>>>
> Relevant (YES / NO):"""
class LLMReranking(BaseReranking):
llm: BaseLLM
prompt_template: PromptTemplate = PromptTemplate(template=RERANK_PROMPT_TEMPLATE)
top_k: int = 3
concurrent: bool = True
def run(
self,
documents: list[Document],
query: str,
) -> list[Document]:
"""Filter down documents based on their relevance to the query."""
filtered_docs = []
output_parser = BooleanOutputParser()
if self.concurrent:
with ThreadPoolExecutor() as executor:
futures = []
for doc in documents:
_prompt = self.prompt_template.populate(
question=query, context=doc.get_content()
)
futures.append(executor.submit(lambda: self.llm(_prompt).text))
results = [future.result() for future in futures]
else:
results = []
for doc in documents:
_prompt = self.prompt_template.populate(
question=query, context=doc.get_content()
)
results.append(self.llm(_prompt).text)
# use Boolean parser to extract relevancy output from LLM
results = [output_parser.parse(result) for result in results]
for include_doc, doc in zip(results, documents):
if include_doc:
filtered_docs.append(doc)
# prevent returning empty result
if len(filtered_docs) == 0:
filtered_docs = documents[: self.top_k]
return filtered_docs
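
A wiring sketch for the LLM-based reranker; the Azure credentials below are placeholders, and any `BaseLLM` can be substituted for `AzureChatOpenAI`:

```python
from kotaemon.base import Document
from kotaemon.llms import AzureChatOpenAI

llm = AzureChatOpenAI(
    azure_endpoint="https://<your-resource>.openai.azure.com/",
    openai_api_key="<api-key>",
    openai_api_version="2023-07-01-preview",
    deployment_name="<chat-deployment>",
    temperature=0,
)
reranker = LLMReranking(llm=llm, top_k=2, concurrent=True)
kept = reranker(
    documents=[Document(text="Tokyo is the capital of Japan.")],
    query="What is the capital of Japan?",
)
```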

View File

@@ -0,0 +1,49 @@
from ..base import DocTransformer, LlamaIndexDocTransformerMixin
class BaseSplitter(DocTransformer):
"""Represent base splitter class"""
...
class TokenSplitter(LlamaIndexDocTransformerMixin, BaseSplitter):
def __init__(
self,
chunk_size: int = 1024,
chunk_overlap: int = 20,
separator: str = " ",
**params,
):
super().__init__(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
separator=separator,
**params,
)
def _get_li_class(self):
from llama_index.text_splitter import TokenTextSplitter
return TokenTextSplitter
class SentenceWindowSplitter(LlamaIndexDocTransformerMixin, BaseSplitter):
def __init__(
self,
window_size: int = 3,
window_metadata_key: str = "window",
original_text_metadata_key: str = "original_text",
**params,
):
super().__init__(
window_size=window_size,
window_metadata_key=window_metadata_key,
original_text_metadata_key=original_text_metadata_key,
**params,
)
def _get_li_class(self):
from llama_index.node_parser import SentenceWindowNodeParser
return SentenceWindowNodeParser
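
A short chunking sketch with the `TokenSplitter` wrapper, assuming `llama-index` is installed and that the splitter, like other `DocTransformer`s in this package, is callable on a list of `Document`s (chunk sizes are illustrative):

```python
from kotaemon.base import Document

splitter = TokenSplitter(chunk_size=128, chunk_overlap=16)
long_doc = Document(text="kotaemon ships reusable building blocks. " * 200)
chunks = splitter([long_doc])
print(len(chunks), chunks[0].text[:60])
```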

View File

@@ -0,0 +1,122 @@
from __future__ import annotations
import uuid
from typing import Optional, Sequence, cast
from kotaemon.base import BaseComponent, Document, RetrievedDocument
from kotaemon.embeddings import BaseEmbeddings
from kotaemon.storages import BaseDocumentStore, BaseVectorStore
from .base import BaseIndexing, BaseRetrieval
from .rankings import BaseReranking
VECTOR_STORE_FNAME = "vectorstore"
DOC_STORE_FNAME = "docstore"
class VectorIndexing(BaseIndexing):
"""Ingest the document, run through the embedding, and store the embedding in a
vector store.
This pipeline supports the following set of inputs:
- List of documents
- List of texts
"""
vector_store: BaseVectorStore
doc_store: Optional[BaseDocumentStore] = None
embedding: BaseEmbeddings
def to_retrieval_pipeline(self, *args, **kwargs):
"""Convert the indexing pipeline to a retrieval pipeline"""
return VectorRetrieval(
vector_store=self.vector_store,
doc_store=self.doc_store,
embedding=self.embedding,
**kwargs,
)
def to_qa_pipeline(self, *args, **kwargs):
from .qa import CitationQAPipeline
return TextVectorQA(
retrieving_pipeline=self.to_retrieval_pipeline(**kwargs),
qa_pipeline=CitationQAPipeline(**kwargs),
)
def run(self, text: str | list[str] | Document | list[Document]):
input_: list[Document] = []
if not isinstance(text, list):
text = [text]
for item in cast(list, text):
if isinstance(item, str):
input_.append(Document(text=item, id_=str(uuid.uuid4())))
elif isinstance(item, Document):
input_.append(item)
else:
raise ValueError(
f"Invalid input type {type(item)}, should be str or Document"
)
embeddings = self.embedding(input_)
self.vector_store.add(
embeddings=embeddings,
ids=[t.doc_id for t in input_],
)
if self.doc_store:
self.doc_store.add(input_)
class VectorRetrieval(BaseRetrieval):
"""Retrieve list of documents from vector store"""
vector_store: BaseVectorStore
doc_store: Optional[BaseDocumentStore] = None
embedding: BaseEmbeddings
rerankers: Sequence[BaseReranking] = []
top_k: int = 1
def run(
self, text: str | Document, top_k: Optional[int] = None, **kwargs
) -> list[RetrievedDocument]:
"""Retrieve a list of documents from vector store
Args:
text: the text to retrieve similar documents
top_k: number of top similar documents to return
Returns:
list[RetrievedDocument]: list of retrieved documents
"""
if top_k is None:
top_k = self.top_k
if self.doc_store is None:
raise ValueError(
"doc_store is not provided. Please provide a doc_store to "
"retrieve the documents"
)
emb: list[float] = self.embedding(text)[0].embedding
_, scores, ids = self.vector_store.query(embedding=emb, top_k=top_k, **kwargs)
docs = self.doc_store.get(ids)
result = [
RetrievedDocument(**doc.to_dict(), score=score)
for doc, score in zip(docs, scores)
]
# use additional reranker to re-order the document list
if self.rerankers:
for reranker in self.rerankers:
result = reranker(documents=result, query=text)
return result
class TextVectorQA(BaseComponent):
retrieving_pipeline: BaseRetrieval
qa_pipeline: BaseComponent
def run(self, question, **kwargs):
retrieved_documents = self.retrieving_pipeline(question, **kwargs)
return self.qa_pipeline(question, retrieved_documents, **kwargs)
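
A wiring sketch tying the two pipelines together. `my_vector_store`, `my_doc_store`, and `my_embedding` are placeholders for concrete `BaseVectorStore`, `BaseDocumentStore`, and `BaseEmbeddings` instances from your configuration; they are not defined here:

```python
indexing = VectorIndexing(
    vector_store=my_vector_store,
    doc_store=my_doc_store,
    embedding=my_embedding,
)
# Index plain strings (Documents are also accepted).
indexing(["Python was created by Guido van Rossum.", "Tokyo is the capital of Japan."])

# Reuse the same stores and embedding model for retrieval.
retrieval = indexing.to_retrieval_pipeline(top_k=2)
for hit in retrieval("Who created Python?"):
    print(hit.score, hit.text)
```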

View File

@@ -0,0 +1,35 @@
from kotaemon.base.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage
from .base import BaseLLM
from .branching import GatedBranchingPipeline, SimpleBranchingPipeline
from .chats import AzureChatOpenAI, ChatLLM
from .completions import LLM, AzureOpenAI, OpenAI
from .cot import ManualSequentialChainOfThought, Thought
from .linear import GatedLinearPipeline, SimpleLinearPipeline
from .prompts import BasePromptComponent, PromptTemplate
__all__ = [
"BaseLLM",
# chat-specific components
"ChatLLM",
"BaseMessage",
"HumanMessage",
"AIMessage",
"SystemMessage",
"AzureChatOpenAI",
# completion-specific components
"LLM",
"OpenAI",
"AzureOpenAI",
# prompt-specific components
"BasePromptComponent",
"PromptTemplate",
# strategies
"SimpleLinearPipeline",
"GatedLinearPipeline",
"SimpleBranchingPipeline",
"GatedBranchingPipeline",
# chain-of-thoughts
"ManualSequentialChainOfThought",
"Thought",
]

View File

@@ -0,0 +1,7 @@
from kotaemon.base import BaseComponent
from langchain_core.language_models.base import BaseLanguageModel
class BaseLLM(BaseComponent):
def to_langchain_format(self) -> BaseLanguageModel:
raise NotImplementedError

View File

@@ -0,0 +1,186 @@
from typing import List, Optional
from kotaemon.base import BaseComponent, Document, Param
from .linear import GatedLinearPipeline
class SimpleBranchingPipeline(BaseComponent):
"""
A simple branching pipeline for executing multiple branches.
Attributes:
branches (List[BaseComponent]): The list of branches to be executed.
Example:
```python
from kotaemon.llms import (
AzureChatOpenAI,
BasePromptComponent,
GatedLinearPipeline,
)
from kotaemon.parsers import RegexExtractor
def identity(x):
return x
pipeline = SimpleBranchingPipeline()
llm = AzureChatOpenAI(
openai_api_base="your openai api base",
openai_api_key="your openai api key",
openai_api_version="your openai api version",
deployment_name="dummy-q2-gpt35",
temperature=0,
request_timeout=600,
)
for i in range(3):
pipeline.add_branch(
GatedLinearPipeline(
prompt=BasePromptComponent(template=f"what is {i} in Japanese ?"),
condition=RegexExtractor(pattern=f"{i}"),
llm=llm,
post_processor=identity,
)
)
print(pipeline(condition_text="1"))
print(pipeline(condition_text="2"))
print(pipeline(condition_text="12"))
```
"""
branches: List[BaseComponent] = Param(default_callback=lambda *_: [])
def add_branch(self, component: BaseComponent):
"""
Add a new branch to the pipeline.
Args:
component (BaseComponent): The branch component to be added.
"""
self.branches.append(component)
def run(self, **prompt_kwargs):
"""
Execute the pipeline by running each branch and return the outputs as a list.
Args:
**prompt_kwargs: Keyword arguments for the branches.
Returns:
List: The outputs of each branch as a list.
"""
output = []
for i, branch in enumerate(self.branches):
self._prepare_child(branch, name=f"branch-{i}")
output.append(branch(**prompt_kwargs))
return output
class GatedBranchingPipeline(SimpleBranchingPipeline):
"""
A simple gated branching pipeline for executing multiple branches based on a
condition.
This class extends the SimpleBranchingPipeline class and adds the ability to execute
the branches until a branch returns a non-empty output based on a condition.
Attributes:
branches (List[BaseComponent]): The list of branches to be executed.
Example:
```python
from kotaemon.llms import (
AzureChatOpenAI,
BasePromptComponent,
GatedLinearPipeline,
)
from kotaemon.parsers import RegexExtractor
def identity(x):
return x
pipeline = GatedBranchingPipeline()
llm = AzureChatOpenAI(
openai_api_base="your openai api base",
openai_api_key="your openai api key",
openai_api_version="your openai api version",
deployment_name="dummy-q2-gpt35",
temperature=0,
request_timeout=600,
)
for i in range(3):
pipeline.add_branch(
GatedLinearPipeline(
prompt=BasePromptComponent(template=f"what is {i} in Japanese ?"),
condition=RegexExtractor(pattern=f"{i}"),
llm=llm,
post_processor=identity,
)
)
print(pipeline(condition_text="1"))
print(pipeline(condition_text="2"))
```
"""
def run(self, *, condition_text: Optional[str] = None, **prompt_kwargs):
"""
Execute the pipeline by running each branch and return the output of the first
branch that returns a non-empty output based on the provided condition.
Args:
condition_text (str): The condition text to evaluate for each branch.
Default to None.
**prompt_kwargs: Keyword arguments for the branches.
Returns:
Union[OutputType, None]: The output of the first branch that satisfies the
condition, or None if no branch satisfies the condition.
Raises:
ValueError: If condition_text is None
"""
if condition_text is None:
raise ValueError("`condition_text` must be provided.")
for i, branch in enumerate(self.branches):
self._prepare_child(branch, name=f"branch-{i}")
output = branch(condition_text=condition_text, **prompt_kwargs)
if output:
return output
return Document(None)
if __name__ == "__main__":
import dotenv
from kotaemon.llms import AzureChatOpenAI, BasePromptComponent
from kotaemon.parsers import RegexExtractor
def identity(x):
return x
secrets = dotenv.dotenv_values(".env")
pipeline = GatedBranchingPipeline()
llm = AzureChatOpenAI(
openai_api_base=secrets.get("OPENAI_API_BASE", ""),
openai_api_key=secrets.get("OPENAI_API_KEY", ""),
openai_api_version=secrets.get("OPENAI_API_VERSION", ""),
deployment_name="dummy-q2-gpt35",
temperature=0,
request_timeout=600,
)
for i in range(3):
pipeline.add_branch(
GatedLinearPipeline(
prompt=BasePromptComponent(template=f"what is {i} in Japanese ?"),
condition=RegexExtractor(pattern=f"{i}"),
llm=llm,
post_processor=identity,
)
)
pipeline(condition_text="1")

View File

@@ -0,0 +1,4 @@
from .base import ChatLLM
from .langchain_based import AzureChatOpenAI, LCChatMixin
__all__ = ["ChatLLM", "AzureChatOpenAI", "LCChatMixin"]

View File

@@ -0,0 +1,22 @@
from __future__ import annotations
import logging
from kotaemon.base import BaseComponent
from kotaemon.llms.base import BaseLLM
logger = logging.getLogger(__name__)
class ChatLLM(BaseLLM):
def flow(self):
if self.inflow is None:
raise ValueError("No inflow provided.")
if not isinstance(self.inflow, BaseComponent):
raise ValueError(
f"inflow must be a BaseComponent, found {type(self.inflow)}"
)
text = self.inflow.flow().text
return self.__call__(text)

View File

@@ -0,0 +1,170 @@
from __future__ import annotations
import logging
from kotaemon.base import BaseMessage, HumanMessage, LLMInterface
from .base import ChatLLM
logger = logging.getLogger(__name__)
class LCChatMixin:
def _get_lc_class(self):
raise NotImplementedError(
"Please return the relevant Langchain class in in _get_lc_class"
)
def __init__(self, stream: bool = False, **params):
self._lc_class = self._get_lc_class()
self._obj = self._lc_class(**params)
self._kwargs: dict = params
self._stream = stream
super().__init__()
def run(
self, messages: str | BaseMessage | list[BaseMessage], **kwargs
) -> LLMInterface:
if self._stream:
return self.stream(messages, **kwargs) # type: ignore
return self.invoke(messages, **kwargs)
def invoke(
self, messages: str | BaseMessage | list[BaseMessage], **kwargs
) -> LLMInterface:
"""Generate response from messages
Args:
messages: history of messages to generate response from
**kwargs: additional arguments to pass to the langchain chat model
Returns:
LLMInterface: generated response
"""
input_: list[BaseMessage] = []
if isinstance(messages, str):
input_ = [HumanMessage(content=messages)]
elif isinstance(messages, BaseMessage):
input_ = [messages]
else:
input_ = messages
pred = self._obj.generate(messages=[input_], **kwargs)
all_text = [each.text for each in pred.generations[0]]
all_messages = [each.message for each in pred.generations[0]]
completion_tokens, total_tokens, prompt_tokens = 0, 0, 0
try:
if pred.llm_output is not None:
completion_tokens = pred.llm_output["token_usage"]["completion_tokens"]
total_tokens = pred.llm_output["token_usage"]["total_tokens"]
prompt_tokens = pred.llm_output["token_usage"]["prompt_tokens"]
except Exception:
logger.warning(
f"Cannot get token usage from LLM output for {self._lc_class.__name__}"
)
return LLMInterface(
text=all_text[0] if len(all_text) > 0 else "",
candidates=all_text,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
prompt_tokens=prompt_tokens,
messages=all_messages,
logits=[],
)
def stream(self, messages: str | BaseMessage | list[BaseMessage], **kwargs):
for response in self._obj.stream(input=messages, **kwargs):
yield LLMInterface(content=response.content)
def to_langchain_format(self):
return self._obj
def __repr__(self):
kwargs = []
for key, value_obj in self._kwargs.items():
value = repr(value_obj)
kwargs.append(f"{key}={value}")
kwargs_repr = ", ".join(kwargs)
return f"{self.__class__.__name__}({kwargs_repr})"
def __str__(self):
kwargs = []
for key, value_obj in self._kwargs.items():
value = str(value_obj)
if len(value) > 20:
value = f"{value[:15]}..."
kwargs.append(f"{key}={value}")
kwargs_repr = ", ".join(kwargs)
return f"{self.__class__.__name__}({kwargs_repr})"
def __setattr__(self, name, value):
if name == "_lc_class":
return super().__setattr__(name, value)
if name in self._lc_class.__fields__:
self._kwargs[name] = value
self._obj = self._lc_class(**self._kwargs)
else:
super().__setattr__(name, value)
def __getattr__(self, name):
if name in self._kwargs:
return self._kwargs[name]
return getattr(self._obj, name)
def dump(self, *args, **kwargs):
from theflow.utils.modules import serialize
params = {key: serialize(value) for key, value in self._kwargs.items()}
return {
"__type__": f"{self.__module__}.{self.__class__.__qualname__}",
**params,
}
def specs(self, path: str):
path = path.strip(".")
if "." in path:
raise ValueError("path should not contain '.'")
if path in self._lc_class.__fields__:
return {
"__type__": "theflow.base.ParamAttr",
"refresh_on_set": True,
"strict_type": True,
}
raise ValueError(f"Invalid param {path}")
class AzureChatOpenAI(LCChatMixin, ChatLLM):
def __init__(
self,
azure_endpoint: str | None = None,
openai_api_key: str | None = None,
openai_api_version: str = "",
deployment_name: str | None = None,
temperature: float = 0.7,
request_timeout: float | None = None,
**params,
):
super().__init__(
azure_endpoint=azure_endpoint,
openai_api_key=openai_api_key,
openai_api_version=openai_api_version,
deployment_name=deployment_name,
temperature=temperature,
request_timeout=request_timeout,
**params,
)
def _get_lc_class(self):
try:
from langchain_openai import AzureChatOpenAI
except ImportError:
from langchain.chat_models import AzureChatOpenAI
return AzureChatOpenAI
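
A minimal call sketch for the chat wrapper; the credentials are placeholders, and `invoke` returns an `LLMInterface` whose `.text` holds the first completion:

```python
llm = AzureChatOpenAI(
    azure_endpoint="https://<your-resource>.openai.azure.com/",
    openai_api_key="<api-key>",
    openai_api_version="2023-07-01-preview",
    deployment_name="<chat-deployment>",
    temperature=0,
)
# A plain string is wrapped into a HumanMessage; message lists also work.
response = llm("Summarize what a reranker does in one sentence.")
print(response.text, response.total_tokens)
```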

View File

@@ -0,0 +1,4 @@
from .base import LLM
from .langchain_based import AzureOpenAI, LCCompletionMixin, OpenAI
__all__ = ["LLM", "OpenAI", "AzureOpenAI", "LCCompletionMixin"]

View File

@@ -0,0 +1,5 @@
from kotaemon.llms.base import BaseLLM
class LLM(BaseLLM):
pass

View File

@@ -0,0 +1,197 @@
import logging
from typing import Optional
from kotaemon.base import LLMInterface
from .base import LLM
logger = logging.getLogger(__name__)
class LCCompletionMixin:
def _get_lc_class(self):
raise NotImplementedError(
"Please return the relevant Langchain class in in _get_lc_class"
)
def __init__(self, **params):
self._lc_class = self._get_lc_class()
self._obj = self._lc_class(**params)
self._kwargs: dict = params
super().__init__()
def run(self, text: str) -> LLMInterface:
pred = self._obj.generate([text])
all_text = [each.text for each in pred.generations[0]]
completion_tokens, total_tokens, prompt_tokens = 0, 0, 0
try:
if pred.llm_output is not None:
completion_tokens = pred.llm_output["token_usage"]["completion_tokens"]
total_tokens = pred.llm_output["token_usage"]["total_tokens"]
prompt_tokens = pred.llm_output["token_usage"]["prompt_tokens"]
except Exception:
logger.warning(
f"Cannot get token usage from LLM output for {self._lc_class.__name__}"
)
return LLMInterface(
text=all_text[0] if len(all_text) > 0 else "",
candidates=all_text,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
prompt_tokens=prompt_tokens,
logits=[],
)
def to_langchain_format(self):
return self._obj
def __repr__(self):
kwargs = []
for key, value_obj in self._kwargs.items():
value = repr(value_obj)
kwargs.append(f"{key}={value}")
kwargs_repr = ", ".join(kwargs)
return f"{self.__class__.__name__}({kwargs_repr})"
def __str__(self):
kwargs = []
for key, value_obj in self._kwargs.items():
value = str(value_obj)
if len(value) > 20:
value = f"{value[:15]}..."
kwargs.append(f"{key}={value}")
kwargs_repr = ", ".join(kwargs)
return f"{self.__class__.__name__}({kwargs_repr})"
def __setattr__(self, name, value):
if name == "_lc_class":
return super().__setattr__(name, value)
if name in self._lc_class.__fields__:
self._kwargs[name] = value
self._obj = self._lc_class(**self._kwargs)
else:
super().__setattr__(name, value)
def __getattr__(self, name):
if name in self._kwargs:
return self._kwargs[name]
return getattr(self._obj, name)
def dump(self, *args, **kwargs):
from theflow.utils.modules import serialize
params = {key: serialize(value) for key, value in self._kwargs.items()}
return {
"__type__": f"{self.__module__}.{self.__class__.__qualname__}",
**params,
}
def specs(self, path: str):
path = path.strip(".")
if "." in path:
raise ValueError("path should not contain '.'")
if path in self._lc_class.__fields__:
return {
"__type__": "theflow.base.ParamAttr",
"refresh_on_set": True,
"strict_type": True,
}
raise ValueError(f"Invalid param {path}")
class OpenAI(LCCompletionMixin, LLM):
"""Wrapper around Langchain's OpenAI class, focusing on key parameters"""
def __init__(
self,
openai_api_key: Optional[str] = None,
openai_api_base: Optional[str] = None,
model_name: str = "text-davinci-003",
temperature: float = 0.7,
max_tokens: int = 256,
top_p: float = 1,
frequency_penalty: float = 0,
n: int = 1,
best_of: int = 1,
request_timeout: Optional[float] = None,
max_retries: int = 2,
streaming: bool = False,
**params,
):
super().__init__(
openai_api_key=openai_api_key,
openai_api_base=openai_api_base,
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
frequency_penalty=frequency_penalty,
n=n,
best_of=best_of,
request_timeout=request_timeout,
max_retries=max_retries,
streaming=streaming,
**params,
)
def _get_lc_class(self):
try:
from langchain_openai import OpenAI
except ImportError:
from langchain.llms import OpenAI
return OpenAI
class AzureOpenAI(LCCompletionMixin, LLM):
"""Wrapper around Langchain's AzureOpenAI class, focusing on key parameters"""
def __init__(
self,
azure_endpoint: Optional[str] = None,
deployment_name: Optional[str] = None,
openai_api_version: str = "",
openai_api_key: Optional[str] = None,
model_name: str = "text-davinci-003",
temperature: float = 0.7,
max_tokens: int = 256,
top_p: float = 1,
frequency_penalty: float = 0,
n: int = 1,
best_of: int = 1,
request_timeout: Optional[float] = None,
max_retries: int = 2,
streaming: bool = False,
**params,
):
super().__init__(
azure_endpoint=azure_endpoint,
deployment_name=deployment_name,
openai_api_version=openai_api_version,
openai_api_key=openai_api_key,
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
frequency_penalty=frequency_penalty,
n=n,
best_of=best_of,
request_timeout=request_timeout,
max_retries=max_retries,
streaming=streaming,
**params,
)
def _get_lc_class(self):
try:
from langchain_openai import AzureOpenAI
except ImportError:
from langchain.llms import AzureOpenAI
return AzureOpenAI

View File

@@ -0,0 +1,174 @@
from copy import deepcopy
from typing import Callable, List
from kotaemon.base import BaseComponent, Document
from theflow import Function, Node, Param
from .chats import AzureChatOpenAI
from .completions import LLM
from .prompts import BasePromptComponent
class Thought(BaseComponent):
"""A thought in the chain of thought
- Input: `**kwargs` pairs, where key is the placeholder in the prompt, and
value is the value.
- Output: an output dictionary
_**Usage:**_
Create and run a thought:
```python
>> from kotaemon.pipelines.cot import Thought
>> thought = Thought(
prompt="How to {action} {object}?",
llm=AzureChatOpenAI(...),
post_process=lambda string: {"tutorial": string},
)
>> output = thought(action="install", object="python")
>> print(output)
{'tutorial': 'As an AI language model,...'}
```
Basically, when a thought is run, it will:
1. Populate the prompt template with the input `**kwargs`.
2. Run the LLM model with the populated prompt.
3. Post-process the LLM output with the post-processor.
This `Thought` allows chaining sequentially with the + operator. For example:
```python
>> llm = AzureChatOpenAI(...)
>> thought1 = Thought(
prompt="Word {word} in {language} is ",
llm=llm,
post_process=lambda string: {"translated": string},
)
>> thought2 = Thought(
prompt="Translate {translated} to Japanese",
llm=llm,
post_process=lambda string: {"output": string},
)
>> thought = thought1 + thought2
>> thought(word="hello", language="French")
{'word': 'hello',
'language': 'French',
'translated': '"Bonjour"',
'output': 'こんにちは (Konnichiwa)'}
```
Under the hood, when the `+` operator is used, a `ManualSequentialChainOfThought`
is created.
"""
prompt: str = Param(
help=(
"The prompt template string. This prompt template has Python-like variable"
" placeholders, that then will be substituted with real values when this"
" component is executed"
)
)
llm: LLM = Node(AzureChatOpenAI, help="The LLM model to execute the input prompt")
post_process: Function = Node(
help=(
"The function post-processor that post-processes LLM output prediction ."
"It should take a string as input (this is the LLM output text) and return "
"a dictionary, where the key should"
)
)
@Node.auto(depends_on="prompt")
def prompt_template(self):
"""Automatically wrap around param prompt. Can ignore"""
return BasePromptComponent(self.prompt)
def run(self, **kwargs) -> Document:
"""Run the chain of thought"""
prompt = self.prompt_template(**kwargs).text
response = self.llm(prompt).text
response = self.post_process(response)
return Document(response)
def get_variables(self) -> List[str]:
return []
def __add__(self, next_thought: "Thought") -> "ManualSequentialChainOfThought":
return ManualSequentialChainOfThought(
thoughts=[self, next_thought], llm=self.llm
)
class ManualSequentialChainOfThought(BaseComponent):
"""Perform sequential chain-of-thought with manual pre-defined prompts
This method supports a variable number of steps. Each step corresponds to a
`kotaemon.pipelines.cot.Thought`. Please refer to that section for details about
Thought. This section is about chaining thoughts together.
_**Usage:**_
**Create and run a chain of thought without "+" operator:**
```pycon
>>> from kotaemon.pipelines.cot import Thought, ManualSequentialChainOfThought
>>> llm = AzureChatOpenAI(...)
>>> thought1 = Thought(
>>> prompt="Word {word} in {language} is ",
>>> post_process=lambda string: {"translated": string},
>>> )
>>> thought2 = Thought(
>>> prompt="Translate {translated} to Japanese",
>>> post_process=lambda string: {"output": string},
>>> )
>>> thought = ManualSequentialChainOfThought(thoughts=[thought1, thought2], llm=llm)
>>> thought(word="hello", language="French")
{'word': 'hello',
'language': 'French',
'translated': '"Bonjour"',
'output': 'こんにちは (Konnichiwa)'}
```
**Create and run a chain of thought with the "+" operator:** Please refer to the
`kotaemon.pipelines.cot.Thought` section for examples.
This chain-of-thought optionally takes a termination check callback function.
This function will be called after each thought is executed. It takes in a
dictionary of all thought outputs so far, and it returns True or False. If
True, the chain-of-thought will terminate. If unset, the default callback always
returns False.
"""
thoughts: List[Thought] = Param(
default_callback=lambda *_: [], help="List of Thought"
)
llm: LLM = Param(help="The LLM model to use (base of kotaemon.llms.BaseLLM)")
terminate: Callable = Param(
default=lambda _: False,
help="Callback on terminate condition. Default to always return False",
)
def run(self, **kwargs) -> Document:
"""Run the manual chain of thought"""
inputs = deepcopy(kwargs)
for idx, thought in enumerate(self.thoughts):
if self.llm:
thought.llm = self.llm
self._prepare_child(thought, f"thought{idx}")
output = thought(**inputs)
inputs.update(output.content)
if self.terminate(inputs):
break
return Document(inputs)
def __add__(self, next_thought: Thought) -> "ManualSequentialChainOfThought":
return ManualSequentialChainOfThought(
thoughts=self.thoughts + [next_thought], llm=self.llm
)

View File

@@ -0,0 +1,155 @@
from typing import Any, Callable, Optional, Union
from ..base import BaseComponent
from ..base.schema import Document, IO_Type
from .chats import ChatLLM
from .completions import LLM
from .prompts import BasePromptComponent
class SimpleLinearPipeline(BaseComponent):
"""
A simple pipeline for running a function with a prompt, a language model, and an
optional post-processor.
Attributes:
prompt (BasePromptComponent): The prompt component used to generate the initial
input.
llm (Union[ChatLLM, LLM]): The language model component used to generate the
output.
post_processor (Union[BaseComponent, Callable[[IO_Type], IO_Type]]): An optional
post-processor component or function.
Example Usage:
```python
from kotaemon.llms import AzureChatOpenAI, BasePromptComponent
def identity(x):
return x
llm = AzureChatOpenAI(
openai_api_base="your openai api base",
openai_api_key="your openai api key",
openai_api_version="your openai api version",
deployment_name="dummy-q2-gpt35",
temperature=0,
request_timeout=600,
)
pipeline = SimpleLinearPipeline(
prompt=BasePromptComponent(template="what is {word} in Japanese ?"),
llm=llm,
post_processor=identity,
)
print(pipeline(word="lone"))
```
"""
prompt: BasePromptComponent
llm: Union[ChatLLM, LLM]
post_processor: Union[BaseComponent, Callable[[IO_Type], IO_Type]]
def run(
self,
*,
llm_kwargs: Optional[dict] = {},
post_processor_kwargs: Optional[dict] = {},
**prompt_kwargs,
):
"""
Run the function with the given arguments and return the final output as a
Document object.
Args:
llm_kwargs (dict): Keyword arguments for the llm call.
post_processor_kwargs (dict): Keyword arguments for the post_processor.
**prompt_kwargs: Keyword arguments for populating the prompt.
Returns:
Document: The final output of the function as a Document object.
"""
prompt = self.prompt(**prompt_kwargs)
llm_output = self.llm(prompt.text, **llm_kwargs)
if self.post_processor is not None:
final_output = self.post_processor(llm_output, **post_processor_kwargs)[0]
else:
final_output = llm_output
return Document(final_output)
class GatedLinearPipeline(SimpleLinearPipeline):
"""
A pipeline that extends the SimpleLinearPipeline class and adds a condition
attribute.
Attributes:
condition (Callable[[IO_Type], Any]): A callable function that represents the
condition.
Usage:
```{.py3 title="Example Usage"}
from kotaemon.llms import AzureChatOpenAI, BasePromptComponent
from kotaemon.parsers import RegexExtractor
def identity(x):
return x
llm = AzureChatOpenAI(
openai_api_base="your openai api base",
openai_api_key="your openai api key",
openai_api_version="your openai api version",
deployment_name="dummy-q2-gpt35",
temperature=0,
request_timeout=600,
)
pipeline = GatedLinearPipeline(
prompt=BasePromptComponent(template="what is {word} in Japanese ?"),
condition=RegexExtractor(pattern="some pattern"),
llm=llm,
post_processor=identity,
)
print(pipeline(condition_text="some pattern", word="lone"))
print(pipeline(condition_text="other pattern", word="lone"))
```
"""
condition: Callable[[IO_Type], Any]
def run(
self,
*,
condition_text: Optional[str] = None,
llm_kwargs: Optional[dict] = {},
post_processor_kwargs: Optional[dict] = {},
**prompt_kwargs,
) -> Document:
"""
Run the pipeline with the given arguments and return the final output as a
Document object.
Args:
condition_text (str): The condition text to evaluate. Default to None.
llm_kwargs (dict): Additional keyword arguments for the language model call.
post_processor_kwargs (dict): Additional keyword arguments for the
post-processor.
**prompt_kwargs: Keyword arguments for populating the prompt.
Returns:
Document: The final output of the pipeline as a Document object.
Raises:
ValueError: If condition_text is None
"""
if condition_text is None:
raise ValueError("`condition_text` must be provided")
if self.condition(condition_text)[0]:
return super().run(
llm_kwargs=llm_kwargs,
post_processor_kwargs=post_processor_kwargs,
**prompt_kwargs,
)
return Document(None)

View File

@@ -0,0 +1,4 @@
from .base import BasePromptComponent
from .template import PromptTemplate
__all__ = ["BasePromptComponent", "PromptTemplate"]

View File

@@ -0,0 +1,179 @@
from typing import Callable, Union
from kotaemon.base import BaseComponent, Document
from .template import PromptTemplate
class BasePromptComponent(BaseComponent):
"""
Base class for prompt components.
Args:
template (PromptTemplate): The prompt template.
**kwargs: Any additional keyword arguments that will be used to populate the
given template.
"""
class Config:
middleware_switches = {"theflow.middleware.CachingMiddleware": False}
allow_extra = True
def __init__(self, template: Union[str, PromptTemplate], **kwargs):
super().__init__()
self.template = (
template
if isinstance(template, PromptTemplate)
else PromptTemplate(template)
)
self.__set(**kwargs)
def __check_redundant_kwargs(self, **kwargs):
"""
Check for redundant keyword arguments.
Parameters:
**kwargs (dict): A dictionary of keyword arguments.
Warns:
UserWarning: If any keys provided are not in the template.
Returns:
None
"""
self.template.check_redundant_kwargs(**kwargs)
def __check_unset_placeholders(self):
"""
Check if all the placeholders in the template are set.
This function checks if all the expected placeholders in the template are set as
attributes of the object. If any placeholders are missing, a `ValueError`
is raised with the names of the missing keys.
Parameters:
None
Returns:
None
"""
self.template.check_missing_kwargs(**self.__dict__)
def __validate_value_type(self, **kwargs):
"""
Validates the value types of the given keyword arguments.
Parameters:
**kwargs (dict): A dictionary of keyword arguments to be validated.
Raises:
ValueError: If any of the values in the kwargs dictionary have an
unsupported type.
Returns:
None
"""
type_error = []
for k, v in kwargs.items():
if not isinstance(v, (str, int, Document, Callable)): # type: ignore
type_error.append((k, type(v)))
if type_error:
raise ValueError(
"Type of values must be either int, str, Document, Callable, "
f"found unsupported type for (key, type): {type_error}"
)
def __set(self, **kwargs):
"""
Set the values of the attributes in the object based on the provided keyword
arguments.
Args:
kwargs (dict): A dictionary with the attribute names as keys and the new
values as values.
Returns:
None
"""
self.__check_redundant_kwargs(**kwargs)
self.__validate_value_type(**kwargs)
self.__dict__.update(kwargs)
def __prepare_value(self):
"""
Generate a dictionary of keyword arguments based on the template's placeholders
and the current instance's attributes.
Returns:
dict: A dictionary of keyword arguments.
"""
def __prepare(key, value):
if isinstance(value, str):
return value
if isinstance(value, (int, Document)):
return str(value)
raise ValueError(
f"Unsupported type {type(value)} for template value of key {key}"
)
kwargs = {}
for k in self.template.placeholders:
v = getattr(self, k)
# if get a callable, execute to get its output
if isinstance(v, Callable): # type: ignore[arg-type]
v = v()
if isinstance(v, list):
v = str([__prepare(k, each) for each in v])
elif isinstance(v, (str, int, Document)):
v = __prepare(k, v)
else:
raise ValueError(
f"Unsupported type {type(v)} for template value of key `{k}`"
)
kwargs[k] = v
return kwargs
def set(self, **kwargs):
"""
Similar to `__set` but for external use.
Set the values of the attributes in the object based on the provided keyword
arguments.
Args:
kwargs (dict): A dictionary with the attribute names as keys and the new
values as values.
Returns:
None
"""
self.__set(**kwargs)
def run(self, **kwargs):
"""
Run the function with the given keyword arguments.
Args:
**kwargs: The keyword arguments to pass to the function.
Returns:
The result of calling the `populate` method of the `template` object
with the given keyword arguments.
"""
self.__set(**kwargs)
self.__check_unset_placeholders()
prepared_kwargs = self.__prepare_value()
text = self.template.populate(**prepared_kwargs)
return Document(text=text, metadata={"origin": "PromptComponent"})
def flow(self):
return self.__call__()
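
A small sketch of the prompt component: placeholder values can be fixed at construction, via `set`, or at call time, and the populated prompt comes back as a `Document`:

```python
prompt = BasePromptComponent(
    template="Translate '{text}' into {language}.",
    language="French",
)
doc = prompt(text="good morning")
print(doc.text)      # Translate 'good morning' into French.
print(doc.metadata)  # {'origin': 'PromptComponent'}
```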

View File

@@ -0,0 +1,140 @@
import warnings
from string import Formatter
class PromptTemplate:
"""
Base class for prompt templates.
"""
def __init__(self, template: str, ignore_invalid=True):
template = template
formatter = Formatter()
parsed_template = list(formatter.parse(template))
placeholders = set()
for _, key, _, _ in parsed_template:
if key is None:
continue
if not key.isidentifier():
if ignore_invalid:
warnings.warn(f"Ignore invalid placeholder: {key}.", UserWarning)
else:
raise ValueError(
"Placeholder name must be a valid Python identifier, found:"
f" {key}."
)
placeholders.add(key)
self.template = template
self.placeholders = placeholders
self.__formatter = formatter
self.__parsed_template = parsed_template
def check_missing_kwargs(self, **kwargs):
"""
Check if all the placeholders in the template are set.
This function checks if all the expected placeholders in the template are set as
attributes of the object. If any placeholders are missing, a `ValueError`
is raised with the names of the missing keys.
Parameters:
None
Returns:
None
"""
missing_keys = self.placeholders.difference(kwargs.keys())
if missing_keys:
raise ValueError(f"Missing keys in template: {','.join(missing_keys)}")
def check_redundant_kwargs(self, **kwargs):
"""
Check for provided keyword arguments that do not match any placeholder.
This function compares the given keyword arguments against the template's
placeholders and emits a `UserWarning` listing any redundant keys.
Parameters:
**kwargs: The keyword arguments to check against the template.
Returns:
None
"""
provided_keys = set(kwargs.keys())
redundant_keys = provided_keys - self.placeholders
if redundant_keys:
warnings.warn(
f"Keys provided but not in template: {','.join(redundant_keys)}",
UserWarning,
)
def populate(self, **kwargs) -> str:
"""
Strictly populate the template with the given keyword arguments.
Args:
**kwargs: The keyword arguments to populate the template.
Each keyword corresponds to a placeholder in the template.
Returns:
The populated template.
Raises:
ValueError: If an unknown placeholder is provided.
"""
self.check_missing_kwargs(**kwargs)
return self.partial_populate(**kwargs)
def partial_populate(self, **kwargs):
"""
Partially populate the template with the given keyword arguments.
Args:
**kwargs: The keyword arguments to populate the template.
Each keyword corresponds to a placeholder in the template.
Returns:
str: The populated template.
"""
self.check_redundant_kwargs(**kwargs)
prompt = []
for literal_text, field_name, format_spec, conversion in self.__parsed_template:
prompt.append(literal_text)
if field_name is None:
continue
if field_name not in kwargs:
if conversion:
value = f"{{{field_name}}}!{conversion}:{format_spec}"
else:
value = f"{{{field_name}:{format_spec}}}"
else:
value = kwargs[field_name]
if conversion is not None:
value = self.__formatter.convert_field(value, conversion)
if format_spec is not None:
value = self.__formatter.format_field(value, format_spec)
prompt.append(value)
return "".join(prompt)
def __add__(self, other):
"""
Create a new PromptTemplate object by concatenating the template of the current
object with the template of another PromptTemplate object.
Parameters:
other (PromptTemplate): Another PromptTemplate object.
Returns:
PromptTemplate: A new PromptTemplate object with the concatenated templates.
"""
return PromptTemplate(self.template + "\n" + other.template)
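
A quick sketch of strict vs. partial population and template concatenation:

```python
template = PromptTemplate("Answer the question: {question}\nContext: {context}")
print(template.placeholders)  # {'question', 'context'}

# Partial population leaves unfilled placeholders in the output;
# strict population raises ValueError if any key is missing.
print(template.partial_populate(question="Why is the sky blue?"))
print(template.populate(question="Why is the sky blue?", context="Rayleigh scattering."))

longer = template + PromptTemplate("Answer concisely.")
```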

View File

@@ -0,0 +1,14 @@
from .base import AutoReader, DirectoryReader
from .excel_loader import PandasExcelReader
from .mathpix_loader import MathpixPDFReader
from .ocr_loader import OCRReader
from .unstructured_loader import UnstructuredReader
__all__ = [
"AutoReader",
"PandasExcelReader",
"MathpixPDFReader",
"OCRReader",
"DirectoryReader",
"UnstructuredReader",
]

View File

@@ -0,0 +1,65 @@
from pathlib import Path
from typing import Any, List, Type, Union
from kotaemon.base import BaseComponent, Document
from llama_index import SimpleDirectoryReader, download_loader
from llama_index.readers.base import BaseReader
class AutoReader(BaseComponent):
"""General auto reader for a variety of files. (based on llama-hub)"""
def __init__(self, reader_type: Union[str, Type[BaseReader]]) -> None:
"""Init reader using string identifier or class name from llama-hub"""
if isinstance(reader_type, str):
self._reader = download_loader(reader_type)()
else:
self._reader = reader_type()
super().__init__()
def load_data(self, file: Union[Path, str], **kwargs: Any) -> List[Document]:
documents = self._reader.load_data(file=file, **kwargs)
# convert Document to new base class from kotaemon
converted_documents = [Document.from_dict(doc.to_dict()) for doc in documents]
return converted_documents
def run(self, file: Union[Path, str], **kwargs: Any) -> List[Document]:
return self.load_data(file=file, **kwargs)
class LIBaseReader(BaseComponent):
_reader_class: Type[BaseReader]
def __init__(self, *args, **kwargs):
if self._reader_class is None:
raise AttributeError(
"Require `_reader_class` to set a BaseReader class from LlamarIndex"
)
self._reader = self._reader_class(*args, **kwargs)
super().__init__()
def __setattr__(self, name: str, value: Any) -> None:
if name.startswith("_"):
return super().__setattr__(name, value)
return setattr(self._reader, name, value)
def __getattr__(self, name: str) -> Any:
return getattr(self._reader, name)
def load_data(self, *args, **kwargs: Any) -> List[Document]:
documents = self._reader.load_data(*args, **kwargs)
# convert Document to new base class from kotaemon
converted_documents = [Document.from_dict(doc.to_dict()) for doc in documents]
return converted_documents
def run(self, *args, **kwargs: Any) -> List[Document]:
return self.load_data(*args, **kwargs)
class DirectoryReader(LIBaseReader):
_reader_class = SimpleDirectoryReader
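
A usage sketch: `DirectoryReader` forwards its constructor arguments to llama-index's `SimpleDirectoryReader`, so standard arguments such as `input_dir` pass straight through; `AutoReader` pulls a loader from llama-hub by name. The paths are placeholders and "PDFReader" is one example loader name:

```python
from pathlib import Path

reader = DirectoryReader(input_dir="./docs", recursive=True)
documents = reader()  # __call__ dispatches to run() -> load_data()
print(len(documents))

pdf_reader = AutoReader("PDFReader")
pdf_docs = pdf_reader(file=Path("./docs/report.pdf"))
```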

View File

@@ -0,0 +1,99 @@
"""Pandas Excel reader.
Pandas parser for .xlsx files.
"""
from pathlib import Path
from typing import Any, List, Optional, Union
from kotaemon.base import Document
from llama_index.readers.base import BaseReader
class PandasExcelReader(BaseReader):
r"""Pandas-based CSV parser.
Parses CSVs using the separator detection from Pandas `read_csv` function.
If special parameters are required, use the `pandas_config` dict.
Args:
pandas_config (dict): Options for the `pandas.read_excel` function call.
Refer to https://pandas.pydata.org/docs/reference/api/pandas.read_excel.html
for more information. Set to empty dict by default,
this means defaults will be used.
"""
def __init__(
self,
*args: Any,
pandas_config: Optional[dict] = None,
row_joiner: str = "\n",
col_joiner: str = " ",
**kwargs: Any,
) -> None:
"""Init params."""
super().__init__(*args, **kwargs)
self._pandas_config = pandas_config or {}
self._row_joiner = row_joiner if row_joiner else "\n"
self._col_joiner = col_joiner if col_joiner else " "
def load_data(
self,
file: Path,
include_sheetname: bool = False,
sheet_name: Optional[Union[str, int, list]] = None,
**kwargs,
) -> List[Document]:
"""Parse file and extract values from a specific column.
Args:
file (Path): The path to the Excel file to read.
include_sheetname (bool): Whether to include the sheet name in the output.
sheet_name (Union[str, int, None]): The specific sheet to read from,
default is None which reads all sheets.
Returns:
List[Document]: A list of Document objects containing the
values extracted from the Excel file.
"""
import itertools
try:
import pandas as pd
except ImportError:
raise ImportError(
"install pandas using `pip3 install pandas` to use this loader"
)
if sheet_name is not None:
sheet_name = (
[sheet_name] if not isinstance(sheet_name, list) else sheet_name
)
dfs = pd.read_excel(file, sheet_name=sheet_name, **self._pandas_config)
sheet_names = dfs.keys()
df_sheets = []
for key in sheet_names:
sheet = []
if include_sheetname:
sheet.append([key])
sheet.extend(dfs[key].values.astype(str).tolist())
df_sheets.append(sheet)
text_list = list(
itertools.chain.from_iterable(df_sheets)
) # flatten list of lists
output = [
Document(
text=self._row_joiner.join(
self._col_joiner.join(sublist) for sublist in text_list
),
metadata={"source": file.stem},
)
]
return output
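
A usage sketch, assuming `pandas` (and an Excel engine such as `openpyxl`) is installed; the workbook path is a placeholder. Every sheet is flattened into a single `Document`:

```python
from pathlib import Path

reader = PandasExcelReader(row_joiner="\n", col_joiner=" | ")
docs = reader.load_data(Path("report.xlsx"), include_sheetname=True)
print(docs[0].metadata["source"])  # "report"
print(docs[0].text[:200])
```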

View File

@@ -0,0 +1,174 @@
import json
import re
import time
from pathlib import Path
from typing import Any, Dict, List
import requests
from kotaemon.base import Document
from langchain.utils import get_from_dict_or_env
from llama_index.readers.base import BaseReader
from .utils.table import parse_markdown_text_to_tables, strip_special_chars_markdown
# MathpixPDFLoader implementation taken largely from Daniel Gross's:
# https://gist.github.com/danielgross/3ab4104e14faccc12b49200843adab21
class MathpixPDFReader(BaseReader):
"""Load `PDF` files using `Mathpix` service."""
def __init__(
self,
processed_file_format: str = "md",
max_wait_time_seconds: int = 500,
should_clean_pdf: bool = True,
**kwargs: Any,
) -> None:
"""Initialize with a file path.
Args:
processed_file_format: a format of the processed file. Default is "mmd".
max_wait_time_seconds: a maximum time to wait for the response from
the server. Default is 500.
should_clean_pdf: a flag to clean the PDF file. Default is False.
**kwargs: additional keyword arguments.
"""
self.mathpix_api_key = get_from_dict_or_env(
kwargs, "mathpix_api_key", "MATHPIX_API_KEY", default="empty"
)
self.mathpix_api_id = get_from_dict_or_env(
kwargs, "mathpix_api_id", "MATHPIX_API_ID", default="empty"
)
self.processed_file_format = processed_file_format
self.max_wait_time_seconds = max_wait_time_seconds
self.should_clean_pdf = should_clean_pdf
super().__init__()
@property
def _mathpix_headers(self) -> Dict[str, str]:
return {"app_id": self.mathpix_api_id, "app_key": self.mathpix_api_key}
@property
def url(self) -> str:
return "https://api.mathpix.com/v3/pdf"
@property
def data(self) -> dict:
options = {
"conversion_formats": {self.processed_file_format: True},
"enable_tables_fallback": True,
}
return {"options_json": json.dumps(options)}
def send_pdf(self, file_path) -> str:
with open(file_path, "rb") as f:
files = {"file": f}
response = requests.post(
self.url, headers=self._mathpix_headers, files=files, data=self.data
)
response_data = response.json()
if "pdf_id" in response_data:
pdf_id = response_data["pdf_id"]
return pdf_id
else:
raise ValueError("Unable to send PDF to Mathpix.")
def wait_for_processing(self, pdf_id: str) -> None:
"""Wait for processing to complete.
Args:
pdf_id: a PDF id.
Returns: None
"""
url = self.url + "/" + pdf_id
for _ in range(0, self.max_wait_time_seconds, 5):
response = requests.get(url, headers=self._mathpix_headers)
response_data = response.json()
status = response_data.get("status", None)
if status == "completed":
return
elif status == "error":
raise ValueError("Unable to retrieve PDF from Mathpix")
else:
print(response_data)
print(url)
time.sleep(5)
raise TimeoutError
def get_processed_pdf(self, pdf_id: str) -> str:
self.wait_for_processing(pdf_id)
url = f"{self.url}/{pdf_id}.{self.processed_file_format}"
response = requests.get(url, headers=self._mathpix_headers)
return response.content.decode("utf-8")
def clean_pdf(self, contents: str) -> str:
"""Clean the PDF file.
Args:
contents: the PDF file contents.
Returns:
the cleaned contents.
"""
contents = "\n".join(
[line for line in contents.split("\n") if not line.startswith("![]")]
)
# replace \section{Title} with # Title
contents = contents.replace("\\section{", "# ")
# replace the "\" slash that Mathpix adds to escape $, %, (, etc.
# http:// or https:// followed by anything but a closing paren
url_regex = "http[s]?://[^)]+"
markup_regex = r"\[]\(\s*({0})\s*\)".format(url_regex)
contents = (
contents.replace(r"\$", "$")
.replace(r"\%", "%")
.replace(r"\(", "(")
.replace(r"\)", ")")
.replace("$\\begin{array}", "")
.replace("\\end{array}$", "")
.replace("\\\\", "")
.replace("\\text", "")
.replace("}", "")
.replace("{", "")
.replace("\\mathrm", "")
)
contents = re.sub(markup_regex, "", contents)
return contents
def load_data(self, file_path: Path, **kwargs) -> List[Document]:
if "response_content" in kwargs:
# overriding response content if specified
content = kwargs["response_content"]
else:
# call original API
pdf_id = self.send_pdf(file_path)
content = self.get_processed_pdf(pdf_id)
if self.should_clean_pdf:
content = self.clean_pdf(content)
tables, texts = parse_markdown_text_to_tables(content)
documents = []
for table in tables:
text = strip_special_chars_markdown(table)
metadata = {
"source": file_path.name,
"table_origin": table,
"type": "table",
}
documents.append(
Document(
text=text,
metadata=metadata,
metadata_template="",
metadata_seperator="",
)
)
for text in texts:
metadata = {"source": file_path.name, "type": "text"}
documents.append(Document(text=text, metadata=metadata))
return documents

View File

@@ -0,0 +1,102 @@
from pathlib import Path
from typing import List
from uuid import uuid4
import requests
from kotaemon.base import Document
from llama_index.readers.base import BaseReader
from .utils.pdf_ocr import parse_ocr_output, read_pdf_unstructured
from .utils.table import strip_special_chars_markdown
DEFAULT_OCR_ENDPOINT = "http://127.0.0.1:8000/v2/ai/infer/"
class OCRReader(BaseReader):
def __init__(self, endpoint: str = DEFAULT_OCR_ENDPOINT, use_ocr=True):
"""Init the OCR reader with OCR endpoint (FullOCR pipeline)
Args:
endpoint: URL of the FullOCR endpoint. Defaults to DEFAULT_OCR_ENDPOINT.
use_ocr: whether to use OCR to read text
(e.g: from images, tables) in the PDF
"""
super().__init__()
self.ocr_endpoint = endpoint
self.use_ocr = use_ocr
def load_data(self, file_path: Path, **kwargs) -> List[Document]:
"""Load data using OCR reader
Args:
file_path (Path): Path to PDF file
debug_path (Path): Path to store debug image output
artifact_path (Path): Path to OCR endpoints artifacts directory
Returns:
List[Document]: list of documents extracted from the PDF file
"""
file_path = Path(file_path).resolve()
with file_path.open("rb") as content:
files = {"input": content}
data = {"job_id": uuid4(), "table_only": not self.use_ocr}
# call the API from FullOCR endpoint
if "response_content" in kwargs:
# overriding response content if specified
ocr_results = kwargs["response_content"]
else:
# call original API
resp = requests.post(url=self.ocr_endpoint, files=files, data=data)
ocr_results = resp.json()["result"]
debug_path = kwargs.pop("debug_path", None)
artifact_path = kwargs.pop("artifact_path", None)
# read PDF through normal reader (unstructured)
pdf_page_items = read_pdf_unstructured(file_path)
# merge PDF text output with OCR output
tables, texts = parse_ocr_output(
ocr_results,
pdf_page_items,
debug_path=debug_path,
artifact_path=artifact_path,
)
# create output Document with metadata from table
documents = [
Document(
text=strip_special_chars_markdown(table_text),
metadata={
"table_origin": table_text,
"type": "table",
"page_label": page_id + 1,
"source": file_path.name,
"file_path": str(file_path),
"file_name": file_path.name,
"filename": str(file_path),
},
metadata_template="",
metadata_seperator="",
)
for page_id, table_text in tables
]
# create Document from non-table text
documents.extend(
[
Document(
text=non_table_text,
metadata={
"page_label": page_id + 1,
"source": file_path.name,
"file_path": str(file_path),
"file_name": file_path.name,
"filename": str(file_path),
},
)
for page_id, non_table_text in texts
]
)
return documents

View File

@@ -0,0 +1,110 @@
"""Unstructured file reader.
A parser for unstructured text files using Unstructured.io.
Supports .txt, .docx, .pptx, .jpg, .png, .eml, .html, and .pdf documents.
To use .doc and .xls parser, install
sudo apt-get install -y libmagic-dev poppler-utils libreoffice
pip install xlrd
"""
from pathlib import Path
from typing import Any, Dict, List, Optional
from kotaemon.base import Document
from llama_index.readers.base import BaseReader
class UnstructuredReader(BaseReader):
"""General unstructured text reader for a variety of files."""
def __init__(self, *args: Any, **kwargs: Any) -> None:
"""Init params."""
super().__init__(*args) # not passing kwargs to parent bc it cannot accept it
self.api = False # we default to local
if "url" in kwargs:
self.server_url = str(kwargs["url"])
self.api = True # if url was set, switch to api
else:
self.server_url = "http://localhost:8000"
if "api" in kwargs:
self.api = kwargs["api"]
self.api_key = ""
if "api_key" in kwargs:
self.api_key = kwargs["api_key"]
""" Loads data using Unstructured.io
Depending on the construction if url is set or api = True
it'll parse file using API call, else parse it locally
additional_metadata is extended by the returned metadata if
split_documents is True
Returns list of documents
"""
def load_data(
self,
file: Path,
additional_metadata: Optional[Dict] = None,
split_documents: Optional[bool] = False,
**kwargs,
) -> List[Document]:
"""If api is set, parse through api"""
file_path_str = str(file)
if self.api:
from unstructured.partition.api import partition_via_api
elements = partition_via_api(
filename=file_path_str,
api_key=self.api_key,
api_url=self.server_url + "/general/v0/general",
)
else:
"""Parse file locally"""
from unstructured.partition.auto import partition
elements = partition(filename=file_path_str)
""" Process elements """
docs = []
file_name = Path(file).name
file_path = str(Path(file).resolve())
if split_documents:
for node in elements:
metadata = {"file_name": file_name, "file_path": file_path}
if hasattr(node, "metadata"):
"""Load metadata fields"""
for field, val in vars(node.metadata).items():
if field == "_known_field_names":
continue
# removing coordinates because they do not serialize
# and we don't want to bother with them
if field == "coordinates":
continue
# removing bc it might cause interference
if field == "parent_id":
continue
metadata[field] = val
if additional_metadata is not None:
metadata.update(additional_metadata)
metadata["file_name"] = file_name
docs.append(Document(text=node.text, metadata=metadata))
else:
text_chunks = [" ".join(str(el).split()) for el in elements]
metadata = {"file_name": file_name, "file_path": file_path}
if additional_metadata is not None:
metadata.update(additional_metadata)
# Create a single document by joining all the texts
docs.append(Document(text="\n\n".join(text_chunks), metadata=metadata))
return docs
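
A local-parsing sketch (no API call), assuming `unstructured` is installed; the file path is a placeholder:

```python
from pathlib import Path

reader = UnstructuredReader()
docs = reader.load_data(Path("handbook.docx"), split_documents=True)
for doc in docs[:3]:
    print(doc.metadata["file_name"], doc.text[:60])
```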

View File

@@ -0,0 +1,144 @@
from typing import List, Tuple
def bbox_to_points(box: List[int]):
"""Convert bounding box to list of points"""
x1, y1, x2, y2 = box
return [(x1, y1), (x2, y1), (x2, y2), (x1, y2)]
def points_to_bbox(points: List[Tuple[int, int]]):
"""Convert list of points to bounding box"""
all_x = [p[0] for p in points]
all_y = [p[1] for p in points]
return [min(all_x), min(all_y), max(all_x), max(all_y)]
def scale_points(points: List[Tuple[int, int]], scale_factor: float = 1.0):
"""Scale points by a scale factor"""
return [(int(pos[0] * scale_factor), int(pos[1] * scale_factor)) for pos in points]
def union_points(points: List[Tuple[int, int]]):
"""Return union bounding box of list of points"""
all_x = [p[0] for p in points]
all_y = [p[1] for p in points]
bbox = (min(all_x), min(all_y), max(all_x), max(all_y))
return bbox
def scale_box(box: List[int], scale_factor: float = 1.0):
"""Scale box by a scale factor"""
return [int(pos * scale_factor) for pos in box]
def box_h(box: List[int]):
"Return box height"
return box[3] - box[1]
def box_w(box: List[int]):
"Return box width"
return box[2] - box[0]
def box_area(box: List[int]):
"Return box area"
x1, y1, x2, y2 = box
return (x2 - x1) * (y2 - y1)
def get_rect_iou(gt_box: List[tuple], pd_box: List[tuple], iou_type=0) -> float:
"""Intersection over union on layout rectangle
Args:
gt_box: List[tuple]
A list contains bounding box coordinates of ground truth
pd_box: List[tuple]
A list contains bounding box coordinates of prediction
iou_type: int
0: intersection / union, normal IOU
1: intersection / min(areas), useful when boxes are under/over-segmented
Input format: [(x1, y1), (x2, y1), (x2, y2), (x1, y2)]
Annotation for each element in bbox:
(x1, y1) (x2, y1)
+-------+
| |
| |
+-------+
(x1, y2) (x2, y2)
Returns:
Intersection over union value
"""
assert iou_type in [0, 1], "Only support 0: origin iou, 1: intersection / min(area)"
# determine the (x, y)-coordinates of the intersection rectangle
# gt_box: [(x1, y1), (x2, y1), (x2, y2), (x1, y2)]
# pd_box: [(x1, y1), (x2, y1), (x2, y2), (x1, y2)]
x_left = max(gt_box[0][0], pd_box[0][0])
y_top = max(gt_box[0][1], pd_box[0][1])
x_right = min(gt_box[2][0], pd_box[2][0])
y_bottom = min(gt_box[2][1], pd_box[2][1])
# compute the area of intersection rectangle
interArea = max(0, x_right - x_left) * max(0, y_bottom - y_top)
# compute the area of both the prediction and ground-truth
# rectangles
gt_area = (gt_box[2][0] - gt_box[0][0]) * (gt_box[2][1] - gt_box[0][1])
pd_area = (pd_box[2][0] - pd_box[0][0]) * (pd_box[2][1] - pd_box[0][1])
# compute the intersection over union by taking the intersection
# area and dividing it by the sum of prediction + ground-truth
# areas - the intersection area
if iou_type == 0:
iou = interArea / float(gt_area + pd_area - interArea)
elif iou_type == 1:
iou = interArea / max(min(gt_area, pd_area), 1)
# return the intersection over union value
return iou
def sort_funsd_reading_order(lines: List[dict], box_key_name: str = "box"):
"""Sort cell list to create the right reading order using their locations
Args:
lines: list of cells to sort
Returns:
the input cells sorted into natural reading order
(top-to-bottom, then left-to-right)
"""
sorted_list = []
if len(lines) == 0:
return lines
while len(lines) > 1:
topleft_line = lines[0]
for line in lines[1:]:
topleft_line_pos = topleft_line[box_key_name]
topleft_line_center_y = (topleft_line_pos[1] + topleft_line_pos[3]) / 2
x1, y1, x2, y2 = line[box_key_name]
box_center_x = (x1 + x2) / 2
box_center_y = (y1 + y2) / 2
cell_h = y2 - y1
if box_center_y <= topleft_line_center_y - cell_h / 2:
topleft_line = line
continue
if (
box_center_x < topleft_line_pos[2]
and box_center_y < topleft_line_pos[3]
):
topleft_line = line
continue
sorted_list.append(topleft_line)
lines.remove(topleft_line)
sorted_list.append(lines[0])
return sorted_list
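
A worked example of the IoU helper using the corner-point format documented above; the numbers check both modes:

```python
box_a = bbox_to_points([0, 0, 2, 2])  # 2 x 2 box
box_b = bbox_to_points([1, 1, 3, 3])  # overlapping 2 x 2 box, shifted by (1, 1)

# intersection = 1, union = 4 + 4 - 1 = 7
print(get_rect_iou(box_a, box_b, iou_type=0))  # ~0.1429
# intersection / min(areas) = 1 / 4
print(get_rect_iou(box_a, box_b, iou_type=1))  # 0.25
```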

View File

@@ -0,0 +1,294 @@
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Union
from .box import (
bbox_to_points,
box_area,
box_h,
box_w,
get_rect_iou,
points_to_bbox,
scale_box,
scale_points,
sort_funsd_reading_order,
union_points,
)
from .table import table_cells_to_markdown
IOU_THRES = 0.5
PADDING_THRES = 1.1
def read_pdf_unstructured(input_path: Union[Path, str]):
"""Convert PDF from specified path to list of text items with
location information
Args:
input_path: path to input file
Returns:
Dict page_number: list of text boxes
"""
try:
from unstructured.partition.auto import partition
except ImportError:
raise ImportError(
"Please install unstructured PDF reader `pip install unstructured[pdf]`"
)
page_items = defaultdict(list)
items = partition(input_path)
for item in items:
page_number = item.metadata.page_number
bbox = points_to_bbox(item.metadata.coordinates.points)
coord_system = item.metadata.coordinates.system
max_w, max_h = coord_system.width, coord_system.height
page_items[page_number - 1].append(
{
"text": item.text,
"box": bbox,
"location": bbox_to_points(bbox),
"page_shape": (max_w, max_h),
}
)
return page_items
def merge_ocr_and_pdf_texts(
ocr_list: List[dict], pdf_text_list: List[dict], debug_info=None
):
"""Merge PDF and OCR text using IOU overlapping location
Args:
ocr_list: List of OCR items {"text", "box", "location"}
pdf_text_list: List of PDF items {"text", "box", "location"}
Returns:
Combined list of PDF text items and non-overlapping OCR text items
"""
not_matched_ocr = []
# check for debug info
if debug_info is not None:
cv2, debug_im = debug_info
for ocr_item in ocr_list:
matched = False
for pdf_item in pdf_text_list:
if (
get_rect_iou(ocr_item["location"], pdf_item["location"], iou_type=1)
> IOU_THRES
):
matched = True
break
color = (255, 0, 0)
if not matched:
ocr_item["matched"] = False
not_matched_ocr.append(ocr_item)
color = (0, 255, 255)
if debug_info is not None:
cv2.rectangle(
debug_im,
ocr_item["location"][0],
ocr_item["location"][2],
color=color,
thickness=1,
)
if debug_info is not None:
for pdf_item in pdf_text_list:
cv2.rectangle(
debug_im,
pdf_item["location"][0],
pdf_item["location"][2],
color=(0, 255, 0),
thickness=2,
)
return pdf_text_list + not_matched_ocr
def merge_table_cell_and_ocr(
table_list: List[dict], ocr_list: List[dict], pdf_list: List[dict], debug_info=None
):
"""Merge table items with OCR text using IOU overlapping location
Args:
table_list: List of table items
"type": ("table", "cell", "text"), "text", "box", "location"}
ocr_list: List of OCR items {"text", "box", "location"}
pdf_list: List of PDF items {"text", "box", "location"}
Returns:
all_table_cells: List of tables, each of table is represented
by list of cells with combined text from OCR
not_matched_items: List of PDF text items that are not overlapped by any table region
"""
# check for debug info
if debug_info is not None:
cv2, debug_im = debug_info
cell_list = [item for item in table_list if item["type"] == "cell"]
table_list = [item for item in table_list if item["type"] == "table"]
# sort table by area
table_list = sorted(table_list, key=lambda item: box_area(item["bbox"]))
all_tables = []
matched_pdf_ids = []
matched_cell_ids = []
for table in table_list:
if debug_info is not None:
cv2.rectangle(
debug_im,
table["location"][0],
table["location"][2],
color=[0, 0, 255],
thickness=5,
)
cur_table_cells = []
for cell_id, cell in enumerate(cell_list):
if cell_id in matched_cell_ids:
continue
if get_rect_iou(
table["location"], cell["location"], iou_type=1
) > IOU_THRES and box_area(table["bbox"]) > box_area(cell["bbox"]):
color = [128, 0, 128]
# cell matched to table
for item_list, item_type in [(pdf_list, "pdf"), (ocr_list, "ocr")]:
cell["ocr"] = []
for item_id, item in enumerate(item_list):
if item_type == "pdf" and item_id in matched_pdf_ids:
continue
if (
get_rect_iou(item["location"], cell["location"], iou_type=1)
> IOU_THRES
):
cell["ocr"].append(item)
if item_type == "pdf":
matched_pdf_ids.append(item_id)
if len(cell["ocr"]) > 0:
# check if union of matched ocr does
# not extend over cell boundary,
# if True, continue to use OCR_list to match
all_box_points_in_cell = []
for item in cell["ocr"]:
all_box_points_in_cell.extend(item["location"])
union_box = union_points(all_box_points_in_cell)
cell_okay = (
box_h(union_box) <= box_h(cell["bbox"]) * PADDING_THRES
and box_w(union_box) <= box_w(cell["bbox"]) * PADDING_THRES
)
else:
cell_okay = False
if cell_okay:
if item_type == "pdf":
color = [255, 0, 255]
break
if debug_info is not None:
cv2.rectangle(
debug_im,
cell["location"][0],
cell["location"][2],
color=color,
thickness=3,
)
matched_cell_ids.append(cell_id)
cur_table_cells.append(cell)
all_tables.append(cur_table_cells)
not_matched_items = [
item for _id, item in enumerate(pdf_list) if _id not in matched_pdf_ids
]
if debug_info is not None:
for item in not_matched_items:
cv2.rectangle(
debug_im,
item["location"][0],
item["location"][2],
color=[128, 128, 128],
thickness=3,
)
return all_tables, not_matched_items
def parse_ocr_output(
ocr_page_items: List[dict],
pdf_page_items: Dict[int, List[dict]],
artifact_path: Optional[str] = None,
debug_path: Optional[str] = None,
):
"""Main function to combine OCR output and PDF text to
form list of table / non-table regions
Args:
ocr_page_items: List of OCR items by page
pdf_page_items: Dict of PDF texts (page number as key)
artifact_path: path to the directory containing the page images referenced by the OCR output
debug_path: If specified, use OpenCV to draw debug boxes and save the images to debug_path
"""
all_tables = []
all_texts = []
for page_id, page in enumerate(ocr_page_items):
ocr_list = page["json"]["ocr"]
table_list = page["json"]["table"]
page_shape = page["image_shape"]
pdf_item_list = pdf_page_items[page_id]
# create bbox additional information
for item in ocr_list:
item["box"] = points_to_bbox(item["location"])
# re-scale pdf items according to new image size
for item in pdf_item_list:
scale_factor = page_shape[0] / item["page_shape"][0]
item["box"] = scale_box(item["box"], scale_factor=scale_factor)
item["location"] = scale_points(item["location"], scale_factor=scale_factor)
# if using debug mode, OpenCV must be installed
if debug_path and artifact_path is not None:
try:
import cv2
except ImportError:
raise ImportError(
"Please install openCV first to use OCRReader debug mode"
)
image_path = Path(artifact_path) / page["image"]
image = cv2.imread(str(image_path))
debug_info = (cv2, image)
else:
debug_info = None
new_pdf_list = merge_ocr_and_pdf_texts(
ocr_list, pdf_item_list, debug_info=debug_info
)
# sort by reading order
ocr_list = sort_funsd_reading_order(ocr_list)
new_pdf_list = sort_funsd_reading_order(new_pdf_list)
all_table_cells, non_table_text_list = merge_table_cell_and_ocr(
table_list, ocr_list, new_pdf_list, debug_info=debug_info
)
table_texts = [table_cells_to_markdown(cells) for cells in all_table_cells]
all_tables.extend([(page_id, text) for text in table_texts])
all_texts.append(
(page_id, " ".join(item["text"] for item in non_table_text_list))
)
# export debug image to debug_path
if debug_path:
cv2.imwrite(str(Path(debug_path) / "page_{}.png".format(page_id)), image)
return all_tables, all_texts
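
A minimal, hedged sketch of the OCR/PDF merging step above: an OCR line that overlaps an existing PDF text item (IoU over min-area above IOU_THRES) is dropped, while non-overlapping OCR lines are kept. The module paths are assumptions:

```python
from kotaemon.loaders.utils.box import bbox_to_points               # assumed path
from kotaemon.loaders.utils.pdf_ocr import merge_ocr_and_pdf_texts  # assumed path

pdf_items = [
    {"text": "Hello", "box": [0, 0, 50, 10], "location": bbox_to_points([0, 0, 50, 10])},
]
ocr_items = [
    # fully overlaps the PDF item above -> considered matched and dropped
    {"text": "He1lo", "box": [1, 1, 49, 9], "location": bbox_to_points([1, 1, 49, 9])},
    # no overlap -> kept alongside the PDF text
    {"text": "World", "box": [0, 20, 50, 30], "location": bbox_to_points([0, 20, 50, 30])},
]
merged = merge_ocr_and_pdf_texts(ocr_items, pdf_items)
print([item["text"] for item in merged])  # ['Hello', 'World']
```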

View File

@@ -0,0 +1,288 @@
import csv
from io import StringIO
from typing import List, Optional, Tuple
from .box import get_rect_iou
def check_col_conflicts(
col_a: List[str], col_b: List[str], thres: float = 0.15
) -> bool:
"""Check if 2 columns A and B has non-empty content in the same row
(to be used with merge_cols)
Args:
col_a: column A (list of str)
col_b: column B (list of str)
thres: maximum fraction of overlapping rows allowed
Returns:
True if the number of overlapping (conflicting) rows exceeds the threshold
"""
num_rows = len([cell for cell in col_a if cell])
assert len(col_a) == len(col_b)
conflict_count = 0
for cell_a, cell_b in zip(col_a, col_b):
if cell_a and cell_b:
conflict_count += 1
return conflict_count > num_rows * thres
def merge_cols(col_a: List[str], col_b: List[str]) -> List[str]:
"""Merge column A and B if they do not have conflict rows
Args:
col_a: column A (list of str)
col_b: column B (list of str)
Returns:
merged column
"""
for r_id in range(len(col_a)):
if col_b[r_id]:
col_a[r_id] = col_a[r_id] + " " + col_b[r_id]
return col_a
def add_index_col(csv_rows: List[List[str]]) -> List[List[str]]:
"""Add index column as the first column of the table csv_rows
Args:
csv_rows: input table
Returns:
output table with index column
"""
new_csv_rows = [["row id"] + [""] * len(csv_rows[0])]
for r_id, row in enumerate(csv_rows):
new_csv_rows.append([str(r_id + 1)] + row)
return new_csv_rows
def compress_csv(csv_rows: List[List[str]]) -> List[List[str]]:
"""Compress table csv_rows by merging sparse columns (merge_cols)
Args:
csv_rows: input table
Returns:
output: compressed table
"""
csv_cols = [[r[c_id] for r in csv_rows] for c_id in range(len(csv_rows[0]))]
to_remove_col_ids = []
last_c_id = 0
for c_id in range(1, len(csv_cols)):
if not check_col_conflicts(csv_cols[last_c_id], csv_cols[c_id]):
to_remove_col_ids.append(c_id)
csv_cols[last_c_id] = merge_cols(csv_cols[last_c_id], csv_cols[c_id])
else:
last_c_id = c_id
csv_cols = [r for c_id, r in enumerate(csv_cols) if c_id not in to_remove_col_ids]
csv_rows = [[c[r_id] for c in csv_cols] for r_id in range(len(csv_cols[0]))]
return csv_rows
def get_table_from_ocr(ocr_list: List[dict], table_list: List[dict]):
"""Get list of text lines belong to table regions specified by table_list
Args:
ocr_list: list of OCR output in Casia format (Flax)
table_list: list of table output in Casia format (Flax)
Returns:
table_texts: list of lists, one per table region, each containing the OCR text lines inside that region
"""
table_texts = []
for table in table_list:
if table["type"] != "table":
continue
cur_table_texts = []
for ocr in ocr_list:
_iou = get_rect_iou(table["location"], ocr["location"], iou_type=1)
if _iou > 0.8:
cur_table_texts.append(ocr["text"])
table_texts.append(cur_table_texts)
return table_texts
def make_markdown_table(array: List[List[str]]) -> str:
"""Convert table rows in list format to markdown string
Args:
array: Python list with the rows of the table as lists.
The first element is the header row.
Example Input:
[["Name", "Age", "Height"],
["Jake", 20, 5'10],
["Mary", 21, 5'7]]
Returns:
String to put into a .md file
"""
array = compress_csv(array)
array = add_index_col(array)
markdown = "\n" + str("| ")
for e in array[0]:
to_add = " " + str(e) + str(" |")
markdown += to_add
markdown += "\n"
markdown += "| "
for i in range(len(array[0])):
markdown += str("--- | ")
markdown += "\n"
for entry in array[1:]:
markdown += str("| ")
for e in entry:
to_add = str(e) + str(" | ")
markdown += to_add
markdown += "\n"
return markdown + "\n"
def parse_csv_string_to_list(csv_str: str) -> List[List[str]]:
"""Convert CSV string to list of rows
Args:
csv_str: input CSV string
Returns:
Output table in list format
"""
io = StringIO(csv_str)
csv_reader = csv.reader(io, delimiter=",")
rows = [row for row in csv_reader]
return rows
def format_cell(cell: str, length_limit: Optional[int] = None) -> str:
"""Format cell content by remove redundant character and enforce length limit
Args:
cell: input cell text
length_limit: limit of text length.
Returns:
new cell text
"""
cell = cell.replace("\n", " ")
if length_limit:
cell = cell[:length_limit]
return cell
def extract_tables_from_csv_string(
csv_content: str, table_texts: List[List[str]]
) -> Tuple[List[str], str]:
"""Extract list of table from FullOCR output
(csv_content) with the specified table_texts
Args:
csv_content: CSV output from FullOCR pipeline
table_texts: list of table texts extracted
from get_table_from_ocr()
Returns:
List of tables (as markdown strings) and the remaining non-table content
"""
rows = parse_csv_string_to_list(csv_content)
used_row_ids = []
table_csv_list = []
for table in table_texts:
cur_rows = []
for row_id, row in enumerate(rows):
scores = [
any(cell in cell_reference for cell in table)
for cell_reference in row
if cell_reference
]
score = sum(scores) / len(scores)
if score > 0.5 and row_id not in used_row_ids:
used_row_ids.append(row_id)
cur_rows.append([format_cell(cell) for cell in row])
if cur_rows:
table_csv_list.append(make_markdown_table(cur_rows))
else:
print("table not matched", table)
non_table_rows = [
row for row_id, row in enumerate(rows) if row_id not in used_row_ids
]
non_table_text = "\n".join(
" ".join(format_cell(cell) for cell in row) for row in non_table_rows
)
return table_csv_list, non_table_text
def strip_special_chars_markdown(text: str) -> str:
"""Strip special characters from input text in markdown table format"""
return text.replace("|", "").replace(":---:", "").replace("---", "")
def parse_markdown_text_to_tables(text: str) -> Tuple[List[str], List[str]]:
"""Convert markdown text to list of non-table spans and table spans
Args:
text: input markdown text
Returns:
list of table spans and non-table spans
"""
# init empty tables and texts list
tables = []
texts = []
# split input by line break
lines = text.split("\n")
cur_table = []
cur_text: List[str] = []
for line in lines:
line = line.strip()
if line.startswith("|"):
if len(cur_text) > 0:
texts.append(cur_text)
cur_text = []
cur_table.append(line)
else:
# add new table to the list
if len(cur_table) > 0:
tables.append(cur_table)
cur_table = []
cur_text.append(line)
table_texts = ["\n".join(table) for table in tables]
non_table_texts = ["\n".join(text) for text in texts]
return table_texts, non_table_texts
def table_cells_to_markdown(cells: List[dict]):
"""Convert list of cells with attached text to Markdown table"""
if len(cells) == 0:
return ""
all_row_ids = []
all_col_ids = []
for cell in cells:
all_row_ids.extend(cell["rows"])
all_col_ids.extend(cell["columns"])
num_rows, num_cols = max(all_row_ids) + 1, max(all_col_ids) + 1
table_rows = [["" for c in range(num_cols)] for r in range(num_rows)]
# start filling in the grid
for cell in cells:
cell_text = " ".join(item["text"] for item in cell["ocr"])
start_row_id, end_row_id = cell["rows"]
start_col_id, end_col_id = cell["columns"]
span_cell = end_row_id != start_row_id or end_col_id != start_col_id
# do not repeat long text in span cell to prevent context length issue
if span_cell and len(cell_text.replace(" ", "")) < 20 and start_row_id > 0:
for row in range(start_row_id, end_row_id + 1):
for col in range(start_col_id, end_col_id + 1):
table_rows[row][col] += cell_text + " "
else:
table_rows[start_row_id][start_col_id] += cell_text + " "
return make_markdown_table(table_rows)
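
A hedged usage sketch for the markdown helpers above (import path assumed). `parse_markdown_text_to_tables` splits mixed markdown into table spans and plain-text spans; note that a trailing span is only flushed when another boundary follows it:

```python
from kotaemon.loaders.utils.table import parse_markdown_text_to_tables  # assumed path

text = """Intro paragraph.

| Name | Age |
| --- | --- |
| Jake | 20 |

Closing note."""

tables, texts = parse_markdown_text_to_tables(text)
print(tables)  # ['| Name | Age |\n| --- | --- |\n| Jake | 20 |']
print(texts)   # ['Intro paragraph.\n']
```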

View File

@@ -0,0 +1,3 @@
from .regex_extractor import FirstMatchRegexExtractor, RegexExtractor
__all__ = ["RegexExtractor", "FirstMatchRegexExtractor"]

View File

@@ -0,0 +1,150 @@
from __future__ import annotations
import re
from typing import Callable
from kotaemon.base import BaseComponent, Document, ExtractorOutput, Param
class RegexExtractor(BaseComponent):
"""
Simple class for extracting text from a document using a regex pattern.
Args:
pattern (List[str]): The regex pattern(s) to use.
output_map (dict, optional): A mapping from extracted text to the
desired output. Defaults to None.
"""
class Config:
middleware_switches = {"theflow.middleware.CachingMiddleware": False}
pattern: list[str]
output_map: dict[str, str] | Callable[[str], str] = Param(
default_callback=lambda *_: {}
)
def __init__(self, pattern: str | list[str], **kwargs):
if isinstance(pattern, str):
pattern = [pattern]
super().__init__(pattern=pattern, **kwargs)
@staticmethod
def run_raw_static(pattern: str, text: str) -> list[str]:
"""
Finds all non-overlapping occurrences of a pattern in a string.
Parameters:
pattern (str): The regular expression pattern to search for.
text (str): The input string to search in.
Returns:
List[str]: A list of all non-overlapping occurrences of the pattern in the
string.
"""
return re.findall(pattern, text)
@staticmethod
def map_output(text, output_map) -> str:
"""
Maps the given `text` to its corresponding value in the `output_map` dictionary.
Parameters:
text (str): The input text to be mapped.
output_map (dict): A dictionary containing mapping of input text to output
values.
Returns:
str: The corresponding value from the `output_map` if `text` is found in the
dictionary, otherwise returns the original `text`.
"""
if not output_map:
return text
if isinstance(output_map, dict):
return output_map.get(text, text)
return output_map(text)
def run_raw(self, text: str) -> ExtractorOutput:
"""
Matches the raw text against the pattern and runs the output mapping, returning
an instance of ExtractorOutput.
Args:
text (str): The raw text to be processed.
Returns:
ExtractorOutput: the mapped matches, with `text` set to the first match (or an empty string if nothing matched).
"""
output: list[str] = sum(
[self.run_raw_static(p, text) for p in self.pattern], []
)
output = [self.map_output(text, self.output_map) for text in output]
return ExtractorOutput(
text=output[0] if output else "",
matches=output,
metadata={"origin": "RegexExtractor"},
)
def run(
self, text: str | list[str] | Document | list[Document]
) -> list[ExtractorOutput]:
"""Match the input against a pattern and return the output for each input
Parameters:
text: contains the input string to be processed
Returns:
A list containing one ExtractorOutput for each input
Example:
```pycon
>>> document1 = Document(...)
>>> document2 = Document(...)
>>> document_batch = [document1, document2]
>>> batch_output = self(document_batch)
>>> print(batch_output)
[output1_document1, output1_document2]
```
"""
# TODO: this conversion seems common
input_: list[str] = []
if not isinstance(text, list):
text = [text]
for item in text:
if isinstance(item, str):
input_.append(item)
elif isinstance(item, Document):
input_.append(item.text)
else:
raise ValueError(
f"Invalid input type {type(item)}, should be str or Document"
)
output = []
for each_input in input_:
output.append(self.run_raw(each_input))
return output
class FirstMatchRegexExtractor(RegexExtractor):
pattern: list[str]
def run_raw(self, text: str) -> ExtractorOutput:
for p in self.pattern:
output = self.run_raw_static(p, text)
if output:
output = [self.map_output(text, self.output_map) for text in output]
return ExtractorOutput(
text=output[0],
matches=output,
metadata={"origin": "FirstMatchRegexExtractor"},
)
return ExtractorOutput(
text=None, matches=[], metadata={"origin": "FirstMatchRegexExtractor"}
)
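
A hedged usage sketch; the import path is a guess, so adjust it to wherever `regex_extractor.py` is exposed in the package:

```python
from kotaemon.base import Document
from kotaemon.parsers import FirstMatchRegexExtractor  # assumed path

extractor = FirstMatchRegexExtractor(pattern=r"\d{4}-\d{2}-\d{2}")
outputs = extractor.run(["Report dated 2024-01-23.", Document(text="no date here")])
print(outputs[0].text)     # '2024-01-23'
print(outputs[1].matches)  # [] -- no pattern matched
```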

View File

@@ -0,0 +1,25 @@
from .docstores import (
BaseDocumentStore,
ElasticsearchDocumentStore,
InMemoryDocumentStore,
SimpleFileDocumentStore,
)
from .vectorstores import (
BaseVectorStore,
ChromaVectorStore,
InMemoryVectorStore,
SimpleFileVectorStore,
)
__all__ = [
# Document stores
"BaseDocumentStore",
"InMemoryDocumentStore",
"ElasticsearchDocumentStore",
"SimpleFileDocumentStore",
# Vector stores
"BaseVectorStore",
"ChromaVectorStore",
"InMemoryVectorStore",
"SimpleFileVectorStore",
]

View File

@@ -0,0 +1,11 @@
from .base import BaseDocumentStore
from .elasticsearch import ElasticsearchDocumentStore
from .in_memory import InMemoryDocumentStore
from .simple_file import SimpleFileDocumentStore
__all__ = [
"BaseDocumentStore",
"InMemoryDocumentStore",
"ElasticsearchDocumentStore",
"SimpleFileDocumentStore",
]

View File

@@ -0,0 +1,47 @@
from abc import ABC, abstractmethod
from typing import List, Optional, Union
from kotaemon.base import Document
class BaseDocumentStore(ABC):
"""A document store is in charged of storing and managing documents"""
@abstractmethod
def __init__(self, *args, **kwargs):
...
@abstractmethod
def add(
self,
docs: Union[Document, List[Document]],
ids: Optional[Union[List[str], str]] = None,
**kwargs,
):
"""Add document into document store
Args:
docs: Document or list of documents
ids: List of ids of the documents. Optional, if not set will use doc.doc_id
"""
...
@abstractmethod
def get(self, ids: Union[List[str], str]) -> List[Document]:
"""Get document by id"""
...
@abstractmethod
def get_all(self) -> List[Document]:
"""Get all documents"""
...
@abstractmethod
def count(self) -> int:
"""Count number of documents"""
...
@abstractmethod
def delete(self, ids: Union[List[str], str]):
"""Delete document by id"""
...

View File

@@ -0,0 +1,173 @@
from typing import List, Optional, Union
from kotaemon.base import Document
from .base import BaseDocumentStore
MAX_DOCS_TO_GET = 10**4
class ElasticsearchDocumentStore(BaseDocumentStore):
"""Simple memory document store that store document in a dictionary"""
def __init__(
self,
index_name: str = "docstore",
elasticsearch_url: str = "http://localhost:9200",
k1: float = 2.0,
b: float = 0.75,
**kwargs,
):
try:
from elasticsearch import Elasticsearch
from elasticsearch.helpers import bulk
except ImportError:
raise ImportError(
"To use ElaticsearchDocstore please install `pip install elasticsearch`"
)
self.elasticsearch_url = elasticsearch_url
self.index_name = index_name
self.k1 = k1
self.b = b
# Create an Elasticsearch client instance
self.client = Elasticsearch(elasticsearch_url, **kwargs)
self.es_bulk = bulk
# Define the index settings and mappings
settings = {
"analysis": {"analyzer": {"default": {"type": "standard"}}},
"similarity": {
"custom_bm25": {
"type": "BM25",
"k1": k1,
"b": b,
}
},
}
mappings = {
"properties": {
"content": {
"type": "text",
"similarity": "custom_bm25", # Use the custom BM25 similarity
}
}
}
# Create the index with the specified settings and mappings
if not self.client.indices.exists(index=index_name):
self.client.indices.create(
index=index_name, mappings=mappings, settings=settings
)
def add(
self,
docs: Union[Document, List[Document]],
ids: Optional[Union[List[str], str]] = None,
refresh_indices: bool = True,
**kwargs,
):
"""Add document into document store
Args:
docs: list of documents to add
ids: specify the ids of documents to add or use existing doc.doc_id
refresh_indices: request Elasticsearch to update its index (default to True)
"""
if ids and not isinstance(ids, list):
ids = [ids]
if not isinstance(docs, list):
docs = [docs]
doc_ids = ids if ids else [doc.doc_id for doc in docs]
requests = []
for doc_id, doc in zip(doc_ids, docs):
text = doc.text
metadata = doc.metadata
request = {
"_op_type": "index",
"_index": self.index_name,
"content": text,
"metadata": metadata,
"_id": doc_id,
}
requests.append(request)
self.es_bulk(self.client, requests)
if refresh_indices:
self.client.indices.refresh(index=self.index_name)
def query_raw(self, query: dict) -> List[Document]:
"""Query Elasticsearch store using query format of ES client
Args:
query (dict): Elasticsearch query format
Returns:
List[Document]: List of result documents
"""
res = self.client.search(index=self.index_name, body=query)
docs = []
for r in res["hits"]["hits"]:
docs.append(
Document(
id_=r["_id"],
text=r["_source"]["content"],
metadata=r["_source"]["metadata"],
)
)
return docs
def query(
self, query: str, top_k: int = 10, doc_ids: Optional[list] = None
) -> List[Document]:
"""Search Elasticsearch docstore using search query (BM25)
Args:
query (str): query text
top_k (int, optional): number of
top documents to return. Defaults to 10.
Returns:
List[Document]: List of result documents
"""
query_dict: dict = {"query": {"match": {"content": query}}, "size": top_k}
if doc_ids:
query_dict["query"]["match"]["_id"] = {"values": doc_ids}
return self.query_raw(query_dict)
def get(self, ids: Union[List[str], str]) -> List[Document]:
"""Get document by id"""
if not isinstance(ids, list):
ids = [ids]
query_dict = {"query": {"terms": {"_id": ids}}}
return self.query_raw(query_dict)
def count(self) -> int:
"""Count number of documents"""
count = int(
self.client.cat.count(index=self.index_name, format="json")[0]["count"]
)
return count
def get_all(self) -> List[Document]:
"""Get all documents"""
query_dict = {"query": {"match_all": {}}, "size": MAX_DOCS_TO_GET}
return self.query_raw(query_dict)
def delete(self, ids: Union[List[str], str]):
"""Delete document by id"""
if not isinstance(ids, list):
ids = [ids]
query = {"query": {"terms": {"_id": ids}}}
self.client.delete_by_query(index=self.index_name, body=query)
self.client.indices.refresh(index=self.index_name)
def __persist_flow__(self):
return {
"index_name": self.index_name,
"elasticsearch_url": self.elasticsearch_url,
"k1": self.k1,
"b": self.b,
}
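
A hedged sketch of the docstore above; it assumes an Elasticsearch instance is reachable at localhost:9200, that `pip install elasticsearch` has been run, and that the import path below is correct:

```python
from kotaemon.base import Document
from kotaemon.storages import ElasticsearchDocumentStore  # assumed path

store = ElasticsearchDocumentStore(index_name="demo-docstore")
store.add([Document(text="cats purr"), Document(text="dogs bark")])

hits = store.query("purr", top_k=1)  # BM25 full-text search
print(hits[0].text)                  # 'cats purr'

store.delete([doc.doc_id for doc in store.get_all()])
```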

View File

@@ -0,0 +1,85 @@
import json
from pathlib import Path
from typing import List, Optional, Union
from kotaemon.base import Document
from .base import BaseDocumentStore
class InMemoryDocumentStore(BaseDocumentStore):
"""Simple memory document store that store document in a dictionary"""
def __init__(self):
self._store = {}
def add(
self,
docs: Union[Document, List[Document]],
ids: Optional[Union[List[str], str]] = None,
**kwargs,
):
"""Add document into document store
Args:
docs: list of documents to add
ids: specify the ids of documents to add or
use existing doc.doc_id
exist_ok: if False (default), raise an error when a duplicate
doc_id is found in the docstore
"""
exist_ok: bool = kwargs.pop("exist_ok", False)
if ids and not isinstance(ids, list):
ids = [ids]
if not isinstance(docs, list):
docs = [docs]
doc_ids = ids if ids else [doc.doc_id for doc in docs]
for doc_id, doc in zip(doc_ids, docs):
if doc_id in self._store and not exist_ok:
raise ValueError(f"Document with id {doc_id} already exist")
self._store[doc_id] = doc
def get(self, ids: Union[List[str], str]) -> List[Document]:
"""Get document by id"""
if not isinstance(ids, list):
ids = [ids]
return [self._store[doc_id] for doc_id in ids]
def get_all(self) -> List[Document]:
"""Get all documents"""
return list(self._store.values())
def count(self) -> int:
"""Count number of documents"""
return len(self._store)
def delete(self, ids: Union[List[str], str]):
"""Delete document by id"""
if not isinstance(ids, list):
ids = [ids]
for doc_id in ids:
del self._store[doc_id]
def save(self, path: Union[str, Path]):
"""Save document to path"""
store = {key: value.to_dict() for key, value in self._store.items()}
with open(path, "w") as f:
json.dump(store, f)
def load(self, path: Union[str, Path]):
"""Load document store from path"""
with open(path) as f:
store = json.load(f)
# TODO: save and load aren't lossless. A Document-subclass will lose
# information. Need to edit the `to_dict` and `from_dict` methods in
# the Document class.
# For better query support, utilize SQLite as the default document store.
# Also, for portability, use SQLAlchemy for document store.
self._store = {key: Document.from_dict(value) for key, value in store.items()}
def __persist_flow__(self):
return {}
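
A minimal sketch of the in-memory docstore, including the JSON save/load round trip (the file name and import path are illustrative assumptions):

```python
from kotaemon.base import Document
from kotaemon.storages import InMemoryDocumentStore  # assumed path

store = InMemoryDocumentStore()
doc = Document(text="hello world")
store.add(doc)
print(store.count())                  # 1
print(store.get(doc.doc_id)[0].text)  # 'hello world'

store.save("docstore.json")           # serialize doc_id -> document dict as JSON
restored = InMemoryDocumentStore()
restored.load("docstore.json")
print(restored.count())               # 1
```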

View File

@@ -0,0 +1,56 @@
from pathlib import Path
from typing import List, Optional, Union
from kotaemon.base import Document
from .in_memory import InMemoryDocumentStore
class SimpleFileDocumentStore(InMemoryDocumentStore):
"""Improve InMemoryDocumentStore by auto saving whenever the corpus is changed"""
def __init__(self, path: str | Path):
super().__init__()
self._path = path
if path is not None and Path(path).is_file():
self.load(path)
def get(self, ids: Union[List[str], str]) -> List[Document]:
"""Get document by id"""
if not isinstance(ids, list):
ids = [ids]
for doc_id in ids:
if doc_id not in self._store:
self.load(self._path)
break
return [self._store[doc_id] for doc_id in ids]
def add(
self,
docs: Union[Document, List[Document]],
ids: Optional[Union[List[str], str]] = None,
**kwargs,
):
"""Add document into document store
Args:
docs: list of documents to add
ids: specify the ids of documents to add or
use existing doc.doc_id
exist_ok: if False (default), raise an error when a duplicate
doc_id is found in the docstore
"""
super().add(docs=docs, ids=ids, **kwargs)
self.save(self._path)
def delete(self, ids: Union[List[str], str]):
"""Delete document by id"""
super().delete(ids=ids)
self.save(self._path)
def __persist_flow__(self):
from theflow.utils.modules import serialize
return {"path": serialize(self._path)}

View File

@@ -0,0 +1,11 @@
from .base import BaseVectorStore
from .chroma import ChromaVectorStore
from .in_memory import InMemoryVectorStore
from .simple_file import SimpleFileVectorStore
__all__ = [
"BaseVectorStore",
"ChromaVectorStore",
"InMemoryVectorStore",
"SimpleFileVectorStore",
]

View File

@@ -0,0 +1,169 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Optional
from kotaemon.base import DocumentWithEmbedding
from llama_index.schema import NodeRelationship, RelatedNodeInfo
from llama_index.vector_stores.types import BasePydanticVectorStore
from llama_index.vector_stores.types import VectorStore as LIVectorStore
from llama_index.vector_stores.types import VectorStoreQuery
class BaseVectorStore(ABC):
@abstractmethod
def __init__(self, *args, **kwargs):
...
@abstractmethod
def add(
self,
embeddings: list[list[float]] | list[DocumentWithEmbedding],
metadatas: Optional[list[dict]] = None,
ids: Optional[list[str]] = None,
) -> list[str]:
"""Add vector embeddings to vector stores
Args:
embeddings: List of embeddings
metadatas: List of metadata of the embeddings
ids: List of ids of the embeddings
kwargs: meant for vectorstore-specific parameters
Returns:
List of ids of the embeddings
"""
...
@abstractmethod
def delete(self, ids: list[str], **kwargs):
"""Delete vector embeddings from vector stores
Args:
ids: List of ids of the embeddings to be deleted
kwargs: meant for vectorstore-specific parameters
"""
...
@abstractmethod
def query(
self,
embedding: list[float],
top_k: int = 1,
ids: Optional[list[str]] = None,
**kwargs,
) -> tuple[list[list[float]], list[float], list[str]]:
"""Return the top k most similar vector embeddings
Args:
embedding: List of embeddings
top_k: Number of most similar embeddings to return
ids: List of ids of the embeddings to be queried
Returns:
the matched embeddings, the similarity scores, and the ids
"""
...
class LlamaIndexVectorStore(BaseVectorStore):
_li_class: type[LIVectorStore | BasePydanticVectorStore]
def __init__(self, *args, **kwargs):
if self._li_class is None:
raise AttributeError(
"Require `_li_class` to set a VectorStore class from LlamarIndex"
)
from dataclasses import fields
self._client = self._li_class(*args, **kwargs)
self._vsq_kwargs = {_.name for _ in fields(VectorStoreQuery)}
for key in ["query_embedding", "similarity_top_k", "node_ids"]:
if key in self._vsq_kwargs:
self._vsq_kwargs.remove(key)
def __setattr__(self, name: str, value: Any) -> None:
if name.startswith("_"):
return super().__setattr__(name, value)
return setattr(self._client, name, value)
def __getattr__(self, name: str) -> Any:
return getattr(self._client, name)
def add(
self,
embeddings: list[list[float]] | list[DocumentWithEmbedding],
metadatas: Optional[list[dict]] = None,
ids: Optional[list[str]] = None,
):
if isinstance(embeddings[0], list):
nodes: list[DocumentWithEmbedding] = [
DocumentWithEmbedding(embedding=embedding) for embedding in embeddings
]
else:
nodes = embeddings # type: ignore
if metadatas is not None:
for node, metadata in zip(nodes, metadatas):
node.metadata = metadata
if ids is not None:
for node, id in zip(nodes, ids):
node.id_ = id
node.relationships = {
NodeRelationship.SOURCE: RelatedNodeInfo(node_id=id)
}
return self._client.add(nodes=nodes)
def delete(self, ids: list[str], **kwargs):
for id_ in ids:
self._client.delete(ref_doc_id=id_, **kwargs)
def query(
self,
embedding: list[float],
top_k: int = 1,
ids: Optional[list[str]] = None,
**kwargs,
) -> tuple[list[list[float]], list[float], list[str]]:
"""Return the top k most similar vector embeddings
Args:
embedding: List of embeddings
top_k: Number of most similar embeddings to return
ids: List of ids of the embeddings to be queried
kwargs: extra query parameters. Depending on the name, these parameters
will be used when constructing the VectorStoreQuery object or when
performing querying of the underlying vector store.
Returns:
the matched embeddings, the similarity scores, and the ids
"""
vsq_kwargs = {}
vs_kwargs = {}
for kwkey, kwvalue in kwargs.items():
if kwkey in self._vsq_kwargs:
vsq_kwargs[kwkey] = kwvalue
else:
vs_kwargs[kwkey] = kwvalue
output = self._client.query(
query=VectorStoreQuery(
query_embedding=embedding,
similarity_top_k=top_k,
node_ids=ids,
**vsq_kwargs,
),
**vs_kwargs,
)
embeddings = []
if output.nodes:
for node in output.nodes:
embeddings.append(node.embedding)
similarities = output.similarities if output.similarities else []
out_ids = output.ids if output.ids else []
return embeddings, similarities, out_ids

View File

@@ -0,0 +1,96 @@
from typing import Any, Dict, List, Optional, Type, cast
from llama_index.vector_stores.chroma import ChromaVectorStore as LIChromaVectorStore
from .base import LlamaIndexVectorStore
class ChromaVectorStore(LlamaIndexVectorStore):
_li_class: Type[LIChromaVectorStore] = LIChromaVectorStore
def __init__(
self,
path: str = "./chroma",
collection_name: str = "default",
host: str = "localhost",
port: str = "8000",
ssl: bool = False,
headers: Optional[Dict[str, str]] = None,
collection_kwargs: Optional[dict] = None,
stores_text: bool = True,
flat_metadata: bool = True,
**kwargs: Any,
):
self._path = path
self._collection_name = collection_name
self._host = host
self._port = port
self._ssl = ssl
self._headers = headers
self._collection_kwargs = collection_kwargs
self._stores_text = stores_text
self._flat_metadata = flat_metadata
self._kwargs = kwargs
try:
import chromadb
except ImportError:
raise ImportError(
"ChromaVectorStore requires chromadb. "
"Please install chromadb first `pip install chromadb`"
)
client = chromadb.PersistentClient(path=path)
collection = client.get_or_create_collection(collection_name)
# pass through for nice IDE support
super().__init__(
chroma_collection=collection,
host=host,
port=port,
ssl=ssl,
headers=headers or {},
collection_kwargs=collection_kwargs or {},
stores_text=stores_text,
flat_metadata=flat_metadata,
**kwargs,
)
self._client = cast(LIChromaVectorStore, self._client)
def delete(self, ids: List[str], **kwargs):
"""Delete vector embeddings from vector stores
Args:
ids: List of ids of the embeddings to be deleted
kwargs: meant for vectorstore-specific parameters
"""
self._client.client.delete(ids=ids)
def delete_collection(self, collection_name: Optional[str] = None):
"""Delete entire collection under specified name from vector stores
Args:
collection_name: Name of the collection to delete
"""
# a rather ugly chained call, but it does the job of finding the
# original chromadb client and calling its delete_collection() method
if collection_name is None:
collection_name = self._client.client.name
self._client.client._client.delete_collection(collection_name)
def count(self) -> int:
return self._collection.count()
def __persist_flow__(self):
return {
"path": self._path,
"collection_name": self._collection_name,
"host": self._host,
"port": self._port,
"ssl": self._ssl,
"headers": self._headers,
"collection_kwargs": self._collection_kwargs,
"stores_text": self._stores_text,
"flat_metadata": self._flat_metadata,
**self._kwargs,
}
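
A hedged sketch of the Chroma-backed store; it requires `pip install chromadb`, and the path, names, tiny embeddings, and import path are illustrative assumptions only:

```python
from kotaemon.storages import ChromaVectorStore  # assumed path

store = ChromaVectorStore(path="./chroma_demo", collection_name="demo")
store.add(
    embeddings=[[0.1, 0.2, 0.0], [0.9, 0.8, 0.7]],
    metadatas=[{"label": "a"}, {"label": "b"}],
    ids=["doc-a", "doc-b"],
)
_, scores, hit_ids = store.query(embedding=[0.1, 0.2, 0.0], top_k=1)
print(hit_ids)  # ['doc-a'] -- nearest stored embedding
store.delete(["doc-a", "doc-b"])
```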

View File

@@ -0,0 +1,62 @@
"""Simple vector store index."""
from typing import Any, Optional, Type
import fsspec
from llama_index.vector_stores import SimpleVectorStore as LISimpleVectorStore
from llama_index.vector_stores.simple import SimpleVectorStoreData
from .base import LlamaIndexVectorStore
class InMemoryVectorStore(LlamaIndexVectorStore):
_li_class: Type[LISimpleVectorStore] = LISimpleVectorStore
store_text: bool = False
def __init__(
self,
data: Optional[SimpleVectorStoreData] = None,
fs: Optional[fsspec.AbstractFileSystem] = None,
**kwargs: Any,
) -> None:
"""Initialize params."""
self._data = data or SimpleVectorStoreData()
self._fs = fs or fsspec.filesystem("file")
super().__init__(
data=data,
fs=fs,
**kwargs,
)
def save(
self,
save_path: str,
fs: Optional[fsspec.AbstractFileSystem] = None,
**kwargs,
):
"""save a simpleVectorStore to a dictionary.
Args:
save_path: Path of saving vector to disk.
fs: An abstract super-class for pythonic file-systems
"""
self._client.persist(persist_path=save_path, fs=fs)
def load(self, load_path: str, fs: Optional[fsspec.AbstractFileSystem] = None):
"""Create a SimpleKVStore from a load directory.
Args:
load_path: Path of loading vector.
fs: An abstract super-class for pythonic file-systems
"""
self._client = self._client.from_persist_path(persist_path=load_path, fs=fs)
def __persist_flow__(self):
d = self._data.to_dict()
d["__type__"] = f"{self._data.__module__}.{self._data.__class__.__qualname__}"
return {
"data": d,
# "fs": self._fs,
}
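
A short sketch of persisting and reloading the in-memory vector store (the file name and import path are illustrative assumptions):

```python
from kotaemon.storages import InMemoryVectorStore  # assumed path

vs = InMemoryVectorStore()
vs.add(embeddings=[[0.0, 1.0], [1.0, 0.0]], ids=["a", "b"])
vs.save("vectorstore.json")

restored = InMemoryVectorStore()
restored.load("vectorstore.json")
_, _, ids = restored.query(embedding=[0.0, 0.9], top_k=1)
print(ids)  # ['a']
```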

View File

@@ -0,0 +1,65 @@
"""Simple file vector store index."""
from pathlib import Path
from typing import Any, Optional, Type
import fsspec
from kotaemon.base import DocumentWithEmbedding
from llama_index.vector_stores import SimpleVectorStore as LISimpleVectorStore
from llama_index.vector_stores.simple import SimpleVectorStoreData
from .base import LlamaIndexVectorStore
class SimpleFileVectorStore(LlamaIndexVectorStore):
"""Similar to InMemoryVectorStore but is backed by file by default"""
_li_class: Type[LISimpleVectorStore] = LISimpleVectorStore
store_text: bool = False
def __init__(
self,
path: str | Path,
data: Optional[SimpleVectorStoreData] = None,
fs: Optional[fsspec.AbstractFileSystem] = None,
**kwargs: Any,
) -> None:
"""Initialize params."""
self._data = data or SimpleVectorStoreData()
self._fs = fs or fsspec.filesystem("file")
self._path = path
self._save_path = Path(path)
super().__init__(
data=data,
fs=fs,
**kwargs,
)
if self._save_path.is_file():
self._client = self._li_class.from_persist_path(
persist_path=str(self._save_path), fs=self._fs
)
def add(
self,
embeddings: list[list[float]] | list[DocumentWithEmbedding],
metadatas: Optional[list[dict]] = None,
ids: Optional[list[str]] = None,
):
r = super().add(embeddings, metadatas, ids)
self._client.persist(str(self._save_path), self._fs)
return r
def delete(self, ids: list[str], **kwargs):
r = super().delete(ids, **kwargs)
self._client.persist(str(self._save_path), self._fs)
return r
def __persist_flow__(self):
d = self._data.to_dict()
d["__type__"] = f"{self._data.__module__}.{self._data.__class__.__qualname__}"
return {
"data": d,
"path": str(self._path),
# "fs": self._fs,
}

View File

@@ -0,0 +1,73 @@
# build backend and build dependencies
[build-system]
requires = ["setuptools >= 61.0"]
build-backend = "setuptools.build_meta"
[tool.setuptools]
include-package-data = false
packages.find.include = ["kotaemon*"]
packages.find.exclude = ["tests*", "env*"]
# metadata and dependencies
[project]
name = "kotaemon"
version = "0.3.5"
requires-python = ">= 3.10"
description = "Kotaemon core library for AI development."
dependencies = [
"langchain",
"langchain-community",
"theflow",
"llama-index>=0.9.0",
"llama-hub",
"gradio>=4.0.0",
"openpyxl",
"cookiecutter",
"click",
"pandas",
"trogon",
]
readme = "README.md"
license = { text = "MIT License" }
authors = [
{ name = "john", email = "john@cinnamon.is" },
{ name = "ian", email = "ian@cinnamon.is" },
{ name = "tadashi", email = "tadashi@cinnamon.is" },
]
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
]
[project.optional-dependencies]
dev = [
"ipython",
"pytest",
"pre-commit",
"black",
"flake8",
"sphinx",
"coverage",
"openai",
"langchain-openai",
"chromadb",
"wikipedia",
"duckduckgo-search",
"googlesearch-python",
"python-dotenv",
"pytest-mock",
"unstructured[pdf]",
"sentence_transformers",
"cohere",
"elasticsearch",
"pypdf",
]
[project.scripts]
kh = "kotaemon.cli:main"
[project.urls]
Homepage = "https://github.com/Cinnamon/kotaemon/"
Repository = "https://github.com/Cinnamon/kotaemon/"
Documentation = "https://github.com/Cinnamon/kotaemon/wiki"

9
libs/kotaemon/pytest.ini Normal file
View File

@@ -0,0 +1,9 @@
[pytest]
minversion = 7.4.0
testpaths = tests
addopts = -ra -q
log_cli=true
log_level=WARNING
log_format = %(asctime)s %(levelname)s %(message)s
log_date_format = %Y-%m-%d %H:%M:%S
log_file = logs/pytest-logs.txt

Some files were not shown because too many files have changed in this diff