Make ktem official (#134)
* Move kotaemon and ktem into same folder * Update docs * Update CI * Resolve mypy, isorts * Re-allow test pdf files
This commit is contained in:
committed by
GitHub
parent
9c5b707010
commit
2dd531114f
130
libs/kotaemon/README.md
Normal file
130
libs/kotaemon/README.md
Normal file
@@ -0,0 +1,130 @@
|
||||
# kotaemon
|
||||
|
||||
Quick and easy AI components to build Kotaemon - applicable in client
|
||||
project.
|
||||
|
||||
## Documentation
|
||||
|
||||
https://docs.promptui.dm.cinnamon.is
|
||||
|
||||
## Install
|
||||
|
||||
```shell
|
||||
pip install kotaemon@git+ssh://git@github.com/Cinnamon/kotaemon.git
|
||||
```
|
||||
|
||||
## Contribute
|
||||
|
||||
### Setup
|
||||
|
||||
- Create conda environment (suggest 3.10)
|
||||
|
||||
```shell
|
||||
conda create -n kotaemon python=3.10
|
||||
conda activate kotaemon
|
||||
```
|
||||
|
||||
- Clone the repo
|
||||
|
||||
```shell
|
||||
git clone git@github.com:Cinnamon/kotaemon.git
|
||||
cd kotaemon
|
||||
```
|
||||
|
||||
- Install all
|
||||
|
||||
```shell
|
||||
pip install -e ".[dev]"
|
||||
```
|
||||
|
||||
- Pre-commit
|
||||
|
||||
```shell
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
- Test
|
||||
|
||||
```shell
|
||||
pytest tests
|
||||
```
|
||||
|
||||
### Credential sharing
|
||||
|
||||
This repo uses [git-secret](https://sobolevn.me/git-secret/) to share credentials, which
|
||||
internally uses `gpg` to encrypt and decrypt secret files.
|
||||
|
||||
This repo uses `python-dotenv` to manage credentials stored as environment variable.
|
||||
Please note that the use of `python-dotenv` and credentials are for development
|
||||
purposes only. Thus, it should not be used in the main source code (i.e. `kotaemon/` and `tests/`), but can be used in `examples/`.
|
||||
|
||||
#### Install git-secret
|
||||
|
||||
Please follow the [official guide](https://sobolevn.me/git-secret/installation) to install git-secret.
|
||||
|
||||
For Windows users, see [For Windows users](#for-windows-users).
|
||||
|
||||
For users who don't have sudo privilege to install packages, follow the `Manual Installation` in the [official guide](https://sobolevn.me/git-secret/installation) and set `PREFIX` to a path that you have access to. And please don't forget to add `PREFIX` to your `PATH`.
|
||||
|
||||
#### Gaining access
|
||||
|
||||
In order to gain access to the secret files, you must provide your gpg public file to anyone who has access and ask them to ask your key to the keyring. For a quick tutorial on generating your gpg key pair, you can refer to the `Using gpg` section from the [git-secret main page](https://sobolevn.me/git-secret/).
|
||||
|
||||
#### Decrypt the secret file
|
||||
|
||||
The credentials are encrypted in the `.env.secret` file. To print the decrypted content to stdout, run
|
||||
|
||||
```shell
|
||||
git-secret cat [filename]
|
||||
```
|
||||
|
||||
Or to get the decrypted `.env` file, run
|
||||
|
||||
```shell
|
||||
git-secret reveal [filename]
|
||||
```
|
||||
|
||||
#### For Windows users
|
||||
|
||||
git-secret is currently not available for Windows, thus the easiest way is to use it in WSL (please use the latest version of WSL2). From there you have 2 options:
|
||||
|
||||
1. Using the gpg of WSL.
|
||||
|
||||
This is the most straight-forward option since you would use WSL just like any other unix environment. However, the downside is that you have to make WSL your main environment, which means WSL must have write permission on your repo. To achieve this, you must either:
|
||||
|
||||
- Clone and store your repo inside WSL's file system.
|
||||
- Provide WSL with necessary permission on your Windows file system. This can be achieve by setting `automount` options for WSL. To do that, add these content to `/etc/wsl.conf` and then restart your sub-system.
|
||||
|
||||
```shell
|
||||
[automount]
|
||||
options = "metadata,umask=022,fmask=011"
|
||||
```
|
||||
|
||||
This enables all permissions for user owner.
|
||||
|
||||
2. Using the gpg of Windows but with git-secret from WSL.
|
||||
|
||||
For those who use Windows as the main environment, having to switch back and forth between Windows and WSL will be inconvenient. You can instead stay within your Windows environment and apply some tricks to use `git-secret` from WSL.
|
||||
|
||||
- Install and setup `gpg` on Windows.
|
||||
- Install `git-secret` on WSL. Now in Windows, you can invoke `git-secret` using `wsl git-secret`.
|
||||
- Alternatively you can setup alias in CMD to shorten the syntax. Please refer to [this SO answer](https://stackoverflow.com/a/65823225) for the instruction. Some recommended aliases are:
|
||||
|
||||
```bat
|
||||
@echo off
|
||||
|
||||
:: Commands
|
||||
DOSKEY ls=dir /B $*
|
||||
DOSKEY ll=dir /a $*
|
||||
DOSKEY git-secret=wsl git-secret $*
|
||||
DOSKEY gs=wsl git-secret $*
|
||||
```
|
||||
|
||||
Now you can invoke `git-secret` in CMD using `git-secret` or `gs`.
|
||||
|
||||
- For Powershell users, similar behaviours can be achieved using `Set-Alias` and `profile.ps1`. Please refer this [SO thread](https://stackoverflow.com/questions/61081434/how-do-i-create-a-permanent-alias-file-in-powershell-core) as an example.
|
||||
|
||||
### Code base structure
|
||||
|
||||
- documents: define document
|
||||
- loaders
|
23
libs/kotaemon/kotaemon/__init__.py
Normal file
23
libs/kotaemon/kotaemon/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
# Disable telemetry with monkey patching
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
try:
|
||||
import posthog
|
||||
|
||||
def capture(*args, **kwargs):
|
||||
logger.info("posthog.capture called with args: %s, kwargs: %s", args, kwargs)
|
||||
|
||||
posthog.capture = capture
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
import os
|
||||
|
||||
os.environ["HAYSTACK_TELEMETRY_ENABLED"] = "False"
|
||||
import haystack.telemetry
|
||||
|
||||
haystack.telemetry.telemetry = None
|
||||
except ImportError:
|
||||
pass
|
25
libs/kotaemon/kotaemon/agents/__init__.py
Normal file
25
libs/kotaemon/kotaemon/agents/__init__.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from .base import BaseAgent
|
||||
from .io import AgentFinish, AgentOutput, AgentType, BaseScratchPad
|
||||
from .langchain_based import LangchainAgent
|
||||
from .react.agent import ReactAgent
|
||||
from .rewoo.agent import RewooAgent
|
||||
from .tools import BaseTool, ComponentTool, GoogleSearchTool, LLMTool, WikipediaTool
|
||||
|
||||
__all__ = [
|
||||
# agent
|
||||
"BaseAgent",
|
||||
"ReactAgent",
|
||||
"RewooAgent",
|
||||
"LangchainAgent",
|
||||
# tool
|
||||
"BaseTool",
|
||||
"ComponentTool",
|
||||
"GoogleSearchTool",
|
||||
"WikipediaTool",
|
||||
"LLMTool",
|
||||
# io
|
||||
"AgentType",
|
||||
"AgentOutput",
|
||||
"AgentFinish",
|
||||
"BaseScratchPad",
|
||||
]
|
57
libs/kotaemon/kotaemon/agents/base.py
Normal file
57
libs/kotaemon/kotaemon/agents/base.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from typing import Optional, Union
|
||||
|
||||
from kotaemon.base import BaseComponent, Node, Param
|
||||
from kotaemon.llms import BaseLLM, PromptTemplate
|
||||
|
||||
from .io import AgentOutput, AgentType
|
||||
from .tools import BaseTool
|
||||
|
||||
|
||||
class BaseAgent(BaseComponent):
|
||||
"""Define base agent interface"""
|
||||
|
||||
name: str = Param(help="Name of the agent.")
|
||||
agent_type: AgentType = Param(help="Agent type, must be one of AgentType")
|
||||
description: str = Param(
|
||||
help=(
|
||||
"Description used to tell the model how/when/why to use the agent. You can"
|
||||
" provide few-shot examples as a part of the description. This will be"
|
||||
" input to the prompt of LLM."
|
||||
)
|
||||
)
|
||||
llm: Optional[BaseLLM] = Node(
|
||||
help=(
|
||||
"LLM to be used for the agent (optional). LLM must implement BaseLLM"
|
||||
" interface."
|
||||
)
|
||||
)
|
||||
prompt_template: Optional[Union[PromptTemplate, dict[str, PromptTemplate]]] = Param(
|
||||
help="A prompt template or a dict to supply different prompt to the agent"
|
||||
)
|
||||
plugins: list[BaseTool] = Param(
|
||||
default_callback=lambda _: [],
|
||||
help="List of plugins / tools to be used in the agent",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def safeguard_run(run_func, *args, **kwargs):
|
||||
def wrapper(self, *args, **kwargs):
|
||||
try:
|
||||
return run_func(self, *args, **kwargs)
|
||||
except Exception as e:
|
||||
return AgentOutput(
|
||||
text="",
|
||||
agent_type=self.agent_type,
|
||||
status="failed",
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
return wrapper
|
||||
|
||||
def add_tools(self, tools: list[BaseTool]) -> None:
|
||||
"""Helper method to add tools and update agent state if needed"""
|
||||
self.plugins.extend(tools)
|
||||
|
||||
def run(self, *args, **kwargs) -> AgentOutput | list[AgentOutput]:
|
||||
"""Run the component."""
|
||||
raise NotImplementedError()
|
3
libs/kotaemon/kotaemon/agents/io/__init__.py
Normal file
3
libs/kotaemon/kotaemon/agents/io/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .base import AgentAction, AgentFinish, AgentOutput, AgentType, BaseScratchPad
|
||||
|
||||
__all__ = ["AgentOutput", "AgentFinish", "BaseScratchPad", "AgentType", "AgentAction"]
|
254
libs/kotaemon/kotaemon/agents/io/base.py
Normal file
254
libs/kotaemon/kotaemon/agents/io/base.py
Normal file
@@ -0,0 +1,254 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Literal, NamedTuple, Optional, Union
|
||||
|
||||
from kotaemon.base import LLMInterface
|
||||
from pydantic import Extra
|
||||
|
||||
|
||||
def check_log():
|
||||
"""
|
||||
Checks if logging has been enabled.
|
||||
:return: True if logging has been enabled, False otherwise.
|
||||
:rtype: bool
|
||||
"""
|
||||
return os.environ.get("LOG_PATH", None) is not None
|
||||
|
||||
|
||||
class AgentType(Enum):
|
||||
"""
|
||||
Enumerated type for agent types.
|
||||
"""
|
||||
|
||||
openai = "openai"
|
||||
openai_multi = "openai_multi"
|
||||
openai_tool = "openai_tool"
|
||||
self_ask = "self_ask"
|
||||
react = "react"
|
||||
rewoo = "rewoo"
|
||||
vanilla = "vanilla"
|
||||
|
||||
|
||||
class BaseScratchPad:
|
||||
"""
|
||||
Base class for output handlers.
|
||||
|
||||
Attributes:
|
||||
-----------
|
||||
logger : logging.Logger
|
||||
The logger object to log messages.
|
||||
|
||||
Methods:
|
||||
--------
|
||||
stop():
|
||||
Stop the output.
|
||||
|
||||
update_status(output: str, **kwargs):
|
||||
Update the status of the output.
|
||||
|
||||
thinking(name: str):
|
||||
Log that a process is thinking.
|
||||
|
||||
done(_all=False):
|
||||
Log that the process is done.
|
||||
|
||||
stream_print(item: str):
|
||||
Not implemented.
|
||||
|
||||
json_print(item: Dict[str, Any]):
|
||||
Log a JSON object.
|
||||
|
||||
panel_print(item: Any, title: str = "Output", stream: bool = False):
|
||||
Log a panel output.
|
||||
|
||||
clear():
|
||||
Not implemented.
|
||||
|
||||
print(content: str, **kwargs):
|
||||
Log arbitrary content.
|
||||
|
||||
format_json(json_obj: str):
|
||||
Format a JSON object.
|
||||
|
||||
debug(content: str, **kwargs):
|
||||
Log a debug message.
|
||||
|
||||
info(content: str, **kwargs):
|
||||
Log an informational message.
|
||||
|
||||
warning(content: str, **kwargs):
|
||||
Log a warning message.
|
||||
|
||||
error(content: str, **kwargs):
|
||||
Log an error message.
|
||||
|
||||
critical(content: str, **kwargs):
|
||||
Log a critical message.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Initialize the BaseOutput object.
|
||||
|
||||
"""
|
||||
self.logger = logging
|
||||
self.log = []
|
||||
|
||||
def stop(self):
|
||||
"""
|
||||
Stop the output.
|
||||
"""
|
||||
|
||||
def update_status(self, output: str, **kwargs):
|
||||
"""
|
||||
Update the status of the output.
|
||||
"""
|
||||
if check_log():
|
||||
self.logger.info(output)
|
||||
|
||||
def thinking(self, name: str):
|
||||
"""
|
||||
Log that a process is thinking.
|
||||
"""
|
||||
if check_log():
|
||||
self.logger.info(f"{name} is thinking...")
|
||||
|
||||
def done(self, _all=False):
|
||||
"""
|
||||
Log that the process is done.
|
||||
"""
|
||||
|
||||
if check_log():
|
||||
self.logger.info("Done")
|
||||
|
||||
def stream_print(self, item: str):
|
||||
"""
|
||||
Stream print.
|
||||
"""
|
||||
|
||||
def json_print(self, item: Dict[str, Any]):
|
||||
"""
|
||||
Log a JSON object.
|
||||
"""
|
||||
if check_log():
|
||||
self.logger.info(json.dumps(item, indent=2))
|
||||
|
||||
def panel_print(self, item: Any, title: str = "Output", stream: bool = False):
|
||||
"""
|
||||
Log a panel output.
|
||||
|
||||
Args:
|
||||
item : Any
|
||||
The item to log.
|
||||
title : str, optional
|
||||
The title of the panel, defaults to "Output".
|
||||
stream : bool, optional
|
||||
"""
|
||||
if not stream:
|
||||
self.log.append(item)
|
||||
if check_log():
|
||||
self.logger.info("-" * 20)
|
||||
self.logger.info(item)
|
||||
self.logger.info("-" * 20)
|
||||
|
||||
def clear(self):
|
||||
"""
|
||||
Not implemented.
|
||||
"""
|
||||
|
||||
def print(self, content: str, **kwargs):
|
||||
"""
|
||||
Log arbitrary content.
|
||||
"""
|
||||
self.log.append(content)
|
||||
if check_log():
|
||||
self.logger.info(content)
|
||||
|
||||
def format_json(self, json_obj: str):
|
||||
"""
|
||||
Format a JSON object.
|
||||
"""
|
||||
formatted_json = json.dumps(json_obj, indent=2)
|
||||
return formatted_json
|
||||
|
||||
def debug(self, content: str, **kwargs):
|
||||
"""
|
||||
Log a debug message.
|
||||
"""
|
||||
if check_log():
|
||||
self.logger.debug(content, **kwargs)
|
||||
|
||||
def info(self, content: str, **kwargs):
|
||||
"""
|
||||
Log an informational message.
|
||||
"""
|
||||
if check_log():
|
||||
self.logger.info(content, **kwargs)
|
||||
|
||||
def warning(self, content: str, **kwargs):
|
||||
"""
|
||||
Log a warning message.
|
||||
"""
|
||||
if check_log():
|
||||
self.logger.warning(content, **kwargs)
|
||||
|
||||
def error(self, content: str, **kwargs):
|
||||
"""
|
||||
Log an error message.
|
||||
"""
|
||||
if check_log():
|
||||
self.logger.error(content, **kwargs)
|
||||
|
||||
def critical(self, content: str, **kwargs):
|
||||
"""
|
||||
Log a critical message.
|
||||
"""
|
||||
if check_log():
|
||||
self.logger.critical(content, **kwargs)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentAction:
|
||||
"""Agent's action to take.
|
||||
|
||||
Args:
|
||||
tool: The tool to invoke.
|
||||
tool_input: The input to the tool.
|
||||
log: The log message.
|
||||
"""
|
||||
|
||||
tool: str
|
||||
tool_input: Union[str, dict]
|
||||
log: str
|
||||
|
||||
|
||||
class AgentFinish(NamedTuple):
|
||||
"""Agent's return value when finishing execution.
|
||||
|
||||
Args:
|
||||
return_values: The return values of the agent.
|
||||
log: The log message.
|
||||
"""
|
||||
|
||||
return_values: dict
|
||||
log: str
|
||||
|
||||
|
||||
class AgentOutput(LLMInterface, extra=Extra.allow): # type: ignore [call-arg]
|
||||
"""Output from an agent.
|
||||
|
||||
Args:
|
||||
text: The text output from the agent.
|
||||
agent_type: The type of agent.
|
||||
status: The status after executing the agent.
|
||||
error: The error message if any.
|
||||
"""
|
||||
|
||||
text: str
|
||||
type: str = "agent"
|
||||
agent_type: AgentType
|
||||
status: Literal["finished", "stopped", "failed"]
|
||||
error: Optional[str] = None
|
78
libs/kotaemon/kotaemon/agents/langchain_based.py
Normal file
78
libs/kotaemon/kotaemon/agents/langchain_based.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from kotaemon.llms import LLM, ChatLLM
|
||||
from langchain.agents import AgentType as LCAgentType
|
||||
from langchain.agents import initialize_agent
|
||||
from langchain.agents.agent import AgentExecutor as LCAgentExecutor
|
||||
|
||||
from .base import BaseAgent
|
||||
from .io import AgentOutput, AgentType
|
||||
from .tools import BaseTool
|
||||
|
||||
|
||||
class LangchainAgent(BaseAgent):
|
||||
"""Wrapper for Langchain Agent"""
|
||||
|
||||
name: str = "LangchainAgent"
|
||||
agent_type: AgentType
|
||||
description: str = "LangchainAgent for answering multi-step reasoning questions"
|
||||
AGENT_TYPE_MAP = {
|
||||
AgentType.openai: LCAgentType.OPENAI_FUNCTIONS,
|
||||
AgentType.openai_multi: LCAgentType.OPENAI_MULTI_FUNCTIONS,
|
||||
AgentType.react: LCAgentType.ZERO_SHOT_REACT_DESCRIPTION,
|
||||
AgentType.self_ask: LCAgentType.SELF_ASK_WITH_SEARCH,
|
||||
}
|
||||
agent: Optional[LCAgentExecutor] = None
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
if self.agent_type not in self.AGENT_TYPE_MAP:
|
||||
raise NotImplementedError(
|
||||
f"AgentType {self.agent_type } not supported by Langchain wrapper"
|
||||
)
|
||||
self.update_agent_tools()
|
||||
|
||||
def update_agent_tools(self):
|
||||
assert isinstance(self.llm, (ChatLLM, LLM))
|
||||
langchain_plugins = [tool.to_langchain_format() for tool in self.plugins]
|
||||
|
||||
# a fix for search_doc tool name:
|
||||
# use "Intermediate Answer" for self-ask agent
|
||||
found_search_tool = False
|
||||
if self.agent_type == AgentType.self_ask:
|
||||
for plugin in langchain_plugins:
|
||||
if plugin.name == "search_doc":
|
||||
plugin.name = "Intermediate Answer"
|
||||
langchain_plugins = [plugin]
|
||||
found_search_tool = True
|
||||
break
|
||||
|
||||
if self.agent_type != AgentType.self_ask or found_search_tool:
|
||||
# reinit Langchain AgentExecutor
|
||||
self.agent = initialize_agent(
|
||||
langchain_plugins,
|
||||
self.llm.to_langchain_format(),
|
||||
agent=self.AGENT_TYPE_MAP[self.agent_type],
|
||||
handle_parsing_errors=True,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
def add_tools(self, tools: List[BaseTool]) -> None:
|
||||
super().add_tools(tools)
|
||||
self.update_agent_tools()
|
||||
return
|
||||
|
||||
def run(self, instruction: str) -> AgentOutput:
|
||||
assert (
|
||||
self.agent is not None
|
||||
), "Lanchain AgentExecutor is not correctly initialized"
|
||||
|
||||
# Langchain AgentExecutor call
|
||||
output = self.agent(instruction)["output"]
|
||||
|
||||
return AgentOutput(
|
||||
text=output,
|
||||
agent_type=self.agent_type,
|
||||
status="finished",
|
||||
)
|
3
libs/kotaemon/kotaemon/agents/react/__init__.py
Normal file
3
libs/kotaemon/kotaemon/agents/react/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .agent import ReactAgent
|
||||
|
||||
__all__ = ["ReactAgent"]
|
204
libs/kotaemon/kotaemon/agents/react/agent.py
Normal file
204
libs/kotaemon/kotaemon/agents/react/agent.py
Normal file
@@ -0,0 +1,204 @@
|
||||
import logging
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
from kotaemon.agents.base import BaseAgent, BaseLLM
|
||||
from kotaemon.agents.io import AgentAction, AgentFinish, AgentOutput, AgentType
|
||||
from kotaemon.agents.tools import BaseTool
|
||||
from kotaemon.base import Param
|
||||
from kotaemon.llms import PromptTemplate
|
||||
|
||||
FINAL_ANSWER_ACTION = "Final Answer:"
|
||||
|
||||
|
||||
class ReactAgent(BaseAgent):
|
||||
"""
|
||||
Sequential ReactAgent class inherited from BaseAgent.
|
||||
Implementing ReAct agent paradigm https://arxiv.org/pdf/2210.03629.pdf
|
||||
"""
|
||||
|
||||
name: str = "ReactAgent"
|
||||
agent_type: AgentType = AgentType.react
|
||||
description: str = "ReactAgent for answering multi-step reasoning questions"
|
||||
llm: BaseLLM
|
||||
prompt_template: Optional[PromptTemplate] = None
|
||||
plugins: list[BaseTool] = Param(
|
||||
default_callback=lambda _: [], help="List of tools to be used in the agent. "
|
||||
)
|
||||
examples: dict[str, str | list[str]] = Param(
|
||||
default_callback=lambda _: {}, help="Examples to be used in the agent. "
|
||||
)
|
||||
intermediate_steps: list[tuple[AgentAction | AgentFinish, str]] = Param(
|
||||
default_callback=lambda _: [],
|
||||
help="List of AgentAction and observation (tool) output",
|
||||
)
|
||||
max_iterations: int = 10
|
||||
strict_decode: bool = False
|
||||
|
||||
def _compose_plugin_description(self) -> str:
|
||||
"""
|
||||
Compose the worker prompt from the workers.
|
||||
|
||||
Example:
|
||||
toolname1[input]: tool1 description
|
||||
toolname2[input]: tool2 description
|
||||
"""
|
||||
prompt = ""
|
||||
try:
|
||||
for plugin in self.plugins:
|
||||
prompt += f"{plugin.name}[input]: {plugin.description}\n"
|
||||
except Exception:
|
||||
raise ValueError("Worker must have a name and description.")
|
||||
return prompt
|
||||
|
||||
def _construct_scratchpad(
|
||||
self, intermediate_steps: list[tuple[AgentAction | AgentFinish, str]] = []
|
||||
) -> str:
|
||||
"""Construct the scratchpad that lets the agent continue its thought process."""
|
||||
thoughts = ""
|
||||
for action, observation in intermediate_steps:
|
||||
thoughts += action.log
|
||||
thoughts += f"\nObservation: {observation}\nThought:"
|
||||
return thoughts
|
||||
|
||||
def _parse_output(self, text: str) -> Optional[AgentAction | AgentFinish]:
|
||||
"""
|
||||
Parse text output from LLM for the next Action or Final Answer
|
||||
Using Regex to parse "Action:\n Action Input:\n" for the next Action
|
||||
Using FINAL_ANSWER_ACTION to parse Final Answer
|
||||
|
||||
Args:
|
||||
text[str]: input text to parse
|
||||
"""
|
||||
includes_answer = FINAL_ANSWER_ACTION in text
|
||||
regex = (
|
||||
r"Action\s*\d*\s*:[\s]*(.*?)[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)"
|
||||
)
|
||||
action_match = re.search(regex, text, re.DOTALL)
|
||||
action_output: Optional[AgentAction | AgentFinish] = None
|
||||
if action_match:
|
||||
if includes_answer:
|
||||
raise Exception(
|
||||
"Parsing LLM output produced both a final answer "
|
||||
f"and a parse-able action: {text}"
|
||||
)
|
||||
action = action_match.group(1).strip()
|
||||
action_input = action_match.group(2)
|
||||
tool_input = action_input.strip(" ")
|
||||
# ensure if its a well formed SQL query we don't remove any trailing " chars
|
||||
if tool_input.startswith("SELECT ") is False:
|
||||
tool_input = tool_input.strip('"')
|
||||
|
||||
action_output = AgentAction(action, tool_input, text)
|
||||
|
||||
elif includes_answer:
|
||||
action_output = AgentFinish(
|
||||
{"output": text.split(FINAL_ANSWER_ACTION)[-1].strip()}, text
|
||||
)
|
||||
else:
|
||||
if self.strict_decode:
|
||||
raise Exception(f"Could not parse LLM output: `{text}`")
|
||||
else:
|
||||
action_output = AgentFinish({"output": text}, text)
|
||||
|
||||
return action_output
|
||||
|
||||
def _compose_prompt(self, instruction) -> str:
|
||||
"""
|
||||
Compose the prompt from template, worker description, examples and instruction.
|
||||
"""
|
||||
agent_scratchpad = self._construct_scratchpad(self.intermediate_steps)
|
||||
tool_description = self._compose_plugin_description()
|
||||
tool_names = ", ".join([plugin.name for plugin in self.plugins])
|
||||
if self.prompt_template is None:
|
||||
from .prompt import zero_shot_react_prompt
|
||||
|
||||
self.prompt_template = zero_shot_react_prompt
|
||||
return self.prompt_template.populate(
|
||||
instruction=instruction,
|
||||
agent_scratchpad=agent_scratchpad,
|
||||
tool_description=tool_description,
|
||||
tool_names=tool_names,
|
||||
)
|
||||
|
||||
def _format_function_map(self) -> dict[str, BaseTool]:
|
||||
"""Format the function map for the open AI function API.
|
||||
|
||||
Return:
|
||||
Dict[str, Callable]: The function map.
|
||||
"""
|
||||
# Map the function name to the real function object.
|
||||
function_map = {}
|
||||
for plugin in self.plugins:
|
||||
function_map[plugin.name] = plugin
|
||||
return function_map
|
||||
|
||||
def clear(self):
|
||||
"""
|
||||
Clear and reset the agent.
|
||||
"""
|
||||
self.intermediate_steps = []
|
||||
|
||||
def run(self, instruction, max_iterations=None) -> AgentOutput:
|
||||
"""
|
||||
Run the agent with the given instruction.
|
||||
|
||||
Args:
|
||||
instruction: Instruction to run the agent with.
|
||||
max_iterations: Maximum number of iterations
|
||||
of reasoning steps, defaults to 10.
|
||||
|
||||
Return:
|
||||
AgentOutput object.
|
||||
"""
|
||||
if not max_iterations:
|
||||
max_iterations = self.max_iterations
|
||||
assert max_iterations > 0
|
||||
|
||||
self.clear()
|
||||
logging.info(f"Running {self.name} with instruction: {instruction}")
|
||||
total_cost = 0.0
|
||||
total_token = 0
|
||||
status = "failed"
|
||||
response_text = None
|
||||
|
||||
for step_count in range(1, max_iterations + 1):
|
||||
prompt = self._compose_prompt(instruction)
|
||||
logging.info(f"Prompt: {prompt}")
|
||||
response = self.llm(
|
||||
prompt, stop=["Observation:"]
|
||||
) # could cause bugs if llm doesn't have `stop` as a parameter
|
||||
response_text = response.text
|
||||
logging.info(f"Response: {response_text}")
|
||||
action_step = self._parse_output(response_text)
|
||||
if action_step is None:
|
||||
raise ValueError("Invalid action")
|
||||
is_finished_chain = isinstance(action_step, AgentFinish)
|
||||
if is_finished_chain:
|
||||
result = ""
|
||||
else:
|
||||
assert isinstance(action_step, AgentAction)
|
||||
action_name = action_step.tool
|
||||
tool_input = action_step.tool_input
|
||||
logging.info(f"Action: {action_name}")
|
||||
logging.info(f"Tool Input: {tool_input}")
|
||||
result = self._format_function_map()[action_name](tool_input)
|
||||
logging.info(f"Result: {result}")
|
||||
|
||||
self.intermediate_steps.append((action_step, result))
|
||||
if is_finished_chain:
|
||||
logging.info(f"Finished after {step_count} steps.")
|
||||
status = "finished"
|
||||
break
|
||||
else:
|
||||
status = "stopped"
|
||||
|
||||
return AgentOutput(
|
||||
text=response_text,
|
||||
agent_type=self.agent_type,
|
||||
status=status,
|
||||
total_tokens=total_token,
|
||||
total_cost=total_cost,
|
||||
intermediate_steps=self.intermediate_steps,
|
||||
max_iterations=max_iterations,
|
||||
)
|
28
libs/kotaemon/kotaemon/agents/react/prompt.py
Normal file
28
libs/kotaemon/kotaemon/agents/react/prompt.py
Normal file
@@ -0,0 +1,28 @@
|
||||
# flake8: noqa
|
||||
|
||||
from kotaemon.llms import PromptTemplate
|
||||
|
||||
zero_shot_react_prompt = PromptTemplate(
|
||||
template="""Answer the following questions as best you can. You have access to the following tools:
|
||||
{tool_description}
|
||||
Use the following format:
|
||||
|
||||
Question: the input question you must answer
|
||||
Thought: you should always think about what to do
|
||||
|
||||
Action: the action to take, should be one of [{tool_names}]
|
||||
|
||||
Action Input: the input to the action
|
||||
|
||||
Observation: the result of the action
|
||||
|
||||
... (this Thought/Action/Action Input/Observation can repeat N times)
|
||||
#Thought: I now know the final answer
|
||||
Final Answer: the final answer to the original input question
|
||||
|
||||
Begin! After each Action Input.
|
||||
|
||||
Question: {instruction}
|
||||
Thought:{agent_scratchpad}
|
||||
"""
|
||||
)
|
3
libs/kotaemon/kotaemon/agents/rewoo/__init__.py
Normal file
3
libs/kotaemon/kotaemon/agents/rewoo/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .agent import RewooAgent
|
||||
|
||||
__all__ = ["RewooAgent"]
|
273
libs/kotaemon/kotaemon/agents/rewoo/agent.py
Normal file
273
libs/kotaemon/kotaemon/agents/rewoo/agent.py
Normal file
@@ -0,0 +1,273 @@
|
||||
import logging
|
||||
import re
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any
|
||||
|
||||
from kotaemon.agents.base import BaseAgent
|
||||
from kotaemon.agents.io import AgentOutput, AgentType, BaseScratchPad
|
||||
from kotaemon.agents.tools import BaseTool
|
||||
from kotaemon.agents.utils import get_plugin_response_content
|
||||
from kotaemon.base import Node, Param
|
||||
from kotaemon.indices.qa import CitationPipeline
|
||||
from kotaemon.llms import BaseLLM, PromptTemplate
|
||||
|
||||
from .planner import Planner
|
||||
from .solver import Solver
|
||||
|
||||
|
||||
class RewooAgent(BaseAgent):
|
||||
"""Distributive RewooAgent class inherited from BaseAgent.
|
||||
Implementing ReWOO paradigm https://arxiv.org/pdf/2305.18323.pdf"""
|
||||
|
||||
name: str = "RewooAgent"
|
||||
agent_type: AgentType = AgentType.rewoo
|
||||
description: str = "RewooAgent for answering multi-step reasoning questions"
|
||||
planner_llm: BaseLLM
|
||||
solver_llm: BaseLLM
|
||||
prompt_template: dict[str, PromptTemplate] = Param(
|
||||
default_callback=lambda _: {},
|
||||
help="A dict to supply different prompt to the agent.",
|
||||
)
|
||||
plugins: list[BaseTool] = Param(
|
||||
default_callback=lambda _: [], help="A list of plugins to be used in the model."
|
||||
)
|
||||
examples: dict[str, str | list[str]] = Param(
|
||||
default_callback=lambda _: {}, help="Examples to be used in the agent."
|
||||
)
|
||||
|
||||
@Node.auto(depends_on=["planner_llm", "plugins", "prompt_template", "examples"])
|
||||
def planner(self):
|
||||
return Planner(
|
||||
model=self.planner_llm,
|
||||
plugins=self.plugins,
|
||||
prompt_template=self.prompt_template.get("Planner", None),
|
||||
examples=self.examples.get("Planner", None),
|
||||
)
|
||||
|
||||
@Node.auto(depends_on=["solver_llm", "prompt_template", "examples"])
|
||||
def solver(self):
|
||||
return Solver(
|
||||
model=self.solver_llm,
|
||||
prompt_template=self.prompt_template.get("Solver", None),
|
||||
examples=self.examples.get("Solver", None),
|
||||
)
|
||||
|
||||
def _parse_plan_map(
|
||||
self, planner_response: str
|
||||
) -> tuple[dict[str, list[str]], dict[str, str]]:
|
||||
"""
|
||||
Parse planner output. It should be an n-to-n mapping from Plans to #Es.
|
||||
This is because sometimes LLM cannot follow the strict output format.
|
||||
Example:
|
||||
#Plan1
|
||||
#E1
|
||||
#E2
|
||||
should result in: {"#Plan1": ["#E1", "#E2"]}
|
||||
Or:
|
||||
#Plan1
|
||||
#Plan2
|
||||
#E1
|
||||
should result in: {"#Plan1": [], "#Plan2": ["#E1"]}
|
||||
This function should also return a plan map.
|
||||
|
||||
Returns:
|
||||
tuple[Dict[str, List[str]], Dict[str, str]]: A list of plan map
|
||||
"""
|
||||
valid_chunk = [
|
||||
line
|
||||
for line in planner_response.splitlines()
|
||||
if line.startswith("#Plan") or line.startswith("#E")
|
||||
]
|
||||
|
||||
plan_to_es: dict[str, list[str]] = dict()
|
||||
plans: dict[str, str] = dict()
|
||||
prev_key = ""
|
||||
for line in valid_chunk:
|
||||
key, description = line.split(":", 1)
|
||||
key = key.strip()
|
||||
if key.startswith("#Plan"):
|
||||
plans[key] = description.strip()
|
||||
plan_to_es[key] = []
|
||||
prev_key = key
|
||||
elif key.startswith("#E"):
|
||||
plan_to_es[prev_key].append(key)
|
||||
|
||||
return plan_to_es, plans
|
||||
|
||||
def _parse_planner_evidences(
|
||||
self, planner_response: str
|
||||
) -> tuple[dict[str, str], list[list[str]]]:
|
||||
"""
|
||||
Parse planner output. This should return a mapping from #E to tool call.
|
||||
It should also identify the level of each #E in dependency map.
|
||||
Example:
|
||||
{
|
||||
"#E1": "Tool1", "#E2": "Tool2",
|
||||
"#E3": "Tool3", "#E4": "Tool4"
|
||||
}, [[#E1, #E2], [#E3, #E4]]
|
||||
|
||||
Returns:
|
||||
tuple[dict[str, str], List[List[str]]]:
|
||||
A mapping from #E to tool call and a list of levels.
|
||||
"""
|
||||
evidences: dict[str, str] = dict()
|
||||
dependence: dict[str, list[str]] = dict()
|
||||
for line in planner_response.splitlines():
|
||||
if line.startswith("#E") and line[2].isdigit():
|
||||
e, tool_call = line.split(":", 1)
|
||||
e, tool_call = e.strip(), tool_call.strip()
|
||||
if len(e) == 3:
|
||||
dependence[e] = []
|
||||
evidences[e] = tool_call
|
||||
for var in re.findall(r"#E\d+", tool_call):
|
||||
if var in evidences:
|
||||
dependence[e].append(var)
|
||||
else:
|
||||
evidences[e] = "No evidence found"
|
||||
level = []
|
||||
while dependence:
|
||||
select = [i for i in dependence if not dependence[i]]
|
||||
if len(select) == 0:
|
||||
raise ValueError("Circular dependency detected.")
|
||||
level.append(select)
|
||||
for item in select:
|
||||
dependence.pop(item)
|
||||
for item in dependence:
|
||||
for i in select:
|
||||
if i in dependence[item]:
|
||||
dependence[item].remove(i)
|
||||
|
||||
return evidences, level
|
||||
|
||||
def _run_plugin(
|
||||
self,
|
||||
e: str,
|
||||
planner_evidences: dict[str, str],
|
||||
worker_evidences: dict[str, str],
|
||||
output=BaseScratchPad(),
|
||||
):
|
||||
"""
|
||||
Run a plugin for a given evidence.
|
||||
This function should also cumulate the cost and tokens.
|
||||
"""
|
||||
result = dict(e=e, plugin_cost=0, plugin_token=0, evidence="")
|
||||
tool_call = planner_evidences[e]
|
||||
if "[" not in tool_call:
|
||||
result["evidence"] = tool_call
|
||||
else:
|
||||
tool, tool_input = tool_call.split("[", 1)
|
||||
tool_input = tool_input[:-1]
|
||||
# find variables in input and replace with previous evidences
|
||||
for var in re.findall(r"#E\d+", tool_input):
|
||||
if var in worker_evidences:
|
||||
tool_input = tool_input.replace(var, worker_evidences.get(var, ""))
|
||||
try:
|
||||
selected_plugin = self._find_plugin(tool)
|
||||
if selected_plugin is None:
|
||||
raise ValueError("Invalid plugin detected")
|
||||
tool_response = selected_plugin(tool_input)
|
||||
result["evidence"] = get_plugin_response_content(tool_response)
|
||||
except ValueError:
|
||||
result["evidence"] = "No evidence found."
|
||||
finally:
|
||||
output.panel_print(
|
||||
result["evidence"], f"[green] Function Response of [blue]{tool}: "
|
||||
)
|
||||
return result
|
||||
|
||||
def _get_worker_evidence(
|
||||
self,
|
||||
planner_evidences: dict[str, str],
|
||||
evidences_level: list[list[str]],
|
||||
output=BaseScratchPad(),
|
||||
) -> Any:
|
||||
"""
|
||||
Parallel execution of plugins in DAG for speedup.
|
||||
This is one of core benefits of ReWOO agents.
|
||||
|
||||
Args:
|
||||
planner_evidences: A mapping from #E to tool call.
|
||||
evidences_level: A list of levels of evidences.
|
||||
Calculated from DAG of plugin calls.
|
||||
output: Output object, defaults to BaseOutput().
|
||||
Returns:
|
||||
A mapping from #E to tool call.
|
||||
"""
|
||||
worker_evidences: dict[str, str] = dict()
|
||||
plugin_cost, plugin_token = 0.0, 0.0
|
||||
with ThreadPoolExecutor() as pool:
|
||||
for level in evidences_level:
|
||||
results = []
|
||||
for e in level:
|
||||
results.append(
|
||||
pool.submit(
|
||||
self._run_plugin,
|
||||
e,
|
||||
planner_evidences,
|
||||
worker_evidences,
|
||||
output,
|
||||
)
|
||||
)
|
||||
if len(results) > 1:
|
||||
output.update_status(f"Running tasks {level} in parallel.")
|
||||
else:
|
||||
output.update_status(f"Running task {level[0]}.")
|
||||
for r in results:
|
||||
resp = r.result()
|
||||
plugin_cost += resp["plugin_cost"]
|
||||
plugin_token += resp["plugin_token"]
|
||||
worker_evidences[resp["e"]] = resp["evidence"]
|
||||
output.done()
|
||||
|
||||
return worker_evidences, plugin_cost, plugin_token
|
||||
|
||||
def _find_plugin(self, name: str):
|
||||
for p in self.plugins:
|
||||
if p.name == name:
|
||||
return p
|
||||
|
||||
@BaseAgent.safeguard_run
|
||||
def run(self, instruction: str, use_citation: bool = False) -> AgentOutput:
|
||||
"""
|
||||
Run the agent with a given instruction.
|
||||
"""
|
||||
logging.info(f"Running {self.name} with instruction: {instruction}")
|
||||
total_cost = 0.0
|
||||
total_token = 0
|
||||
|
||||
# Plan
|
||||
planner_output = self.planner(instruction)
|
||||
planner_text_output = planner_output.text
|
||||
plan_to_es, plans = self._parse_plan_map(planner_text_output)
|
||||
planner_evidences, evidence_level = self._parse_planner_evidences(
|
||||
planner_text_output
|
||||
)
|
||||
|
||||
# Work
|
||||
worker_evidences, plugin_cost, plugin_token = self._get_worker_evidence(
|
||||
planner_evidences, evidence_level
|
||||
)
|
||||
worker_log = ""
|
||||
for plan in plan_to_es:
|
||||
worker_log += f"{plan}: {plans[plan]}\n"
|
||||
for e in plan_to_es[plan]:
|
||||
worker_log += f"{e}: {worker_evidences[e]}\n"
|
||||
|
||||
# Solve
|
||||
solver_output = self.solver(instruction, worker_log)
|
||||
solver_output_text = solver_output.text
|
||||
if use_citation:
|
||||
citation_pipeline = CitationPipeline(llm=self.solver_llm)
|
||||
citation = citation_pipeline(context=worker_log, question=instruction)
|
||||
else:
|
||||
citation = None
|
||||
|
||||
return AgentOutput(
|
||||
text=solver_output_text,
|
||||
agent_type=self.agent_type,
|
||||
status="finished",
|
||||
total_tokens=total_token,
|
||||
total_cost=total_cost,
|
||||
citation=citation,
|
||||
metadata={"citation": citation},
|
||||
)
|
83
libs/kotaemon/kotaemon/agents/rewoo/planner.py
Normal file
83
libs/kotaemon/kotaemon/agents/rewoo/planner.py
Normal file
@@ -0,0 +1,83 @@
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
from kotaemon.agents.base import BaseLLM, BaseTool
|
||||
from kotaemon.agents.io import BaseScratchPad
|
||||
from kotaemon.base import BaseComponent
|
||||
from kotaemon.llms import PromptTemplate
|
||||
|
||||
from .prompt import few_shot_planner_prompt, zero_shot_planner_prompt
|
||||
|
||||
|
||||
class Planner(BaseComponent):
|
||||
model: BaseLLM
|
||||
prompt_template: Optional[PromptTemplate] = None
|
||||
examples: Optional[Union[str, List[str]]] = None
|
||||
plugins: List[BaseTool]
|
||||
|
||||
def _compose_worker_description(self) -> str:
|
||||
"""
|
||||
Compose the worker prompt from the workers.
|
||||
|
||||
Example:
|
||||
toolname1[input]: tool1 description
|
||||
toolname2[input]: tool2 description
|
||||
"""
|
||||
prompt = ""
|
||||
try:
|
||||
for worker in self.plugins:
|
||||
prompt += f"{worker.name}[input]: {worker.description}\n"
|
||||
except Exception:
|
||||
raise ValueError("Worker must have a name and description.")
|
||||
return prompt
|
||||
|
||||
def _compose_fewshot_prompt(self) -> str:
|
||||
if self.examples is None:
|
||||
return ""
|
||||
if isinstance(self.examples, str):
|
||||
return self.examples
|
||||
else:
|
||||
return "\n\n".join([e.strip("\n") for e in self.examples])
|
||||
|
||||
def _compose_prompt(self, instruction) -> str:
|
||||
"""
|
||||
Compose the prompt from template, worker description, examples and instruction.
|
||||
"""
|
||||
worker_desctription = self._compose_worker_description()
|
||||
fewshot = self._compose_fewshot_prompt()
|
||||
if self.prompt_template is not None:
|
||||
if "fewshot" in self.prompt_template.placeholders:
|
||||
return self.prompt_template.populate(
|
||||
tool_description=worker_desctription,
|
||||
fewshot=fewshot,
|
||||
task=instruction,
|
||||
)
|
||||
else:
|
||||
return self.prompt_template.populate(
|
||||
tool_description=worker_desctription, task=instruction
|
||||
)
|
||||
else:
|
||||
if self.examples is not None:
|
||||
return few_shot_planner_prompt.populate(
|
||||
tool_description=worker_desctription,
|
||||
fewshot=fewshot,
|
||||
task=instruction,
|
||||
)
|
||||
else:
|
||||
return zero_shot_planner_prompt.populate(
|
||||
tool_description=worker_desctription, task=instruction
|
||||
)
|
||||
|
||||
def run(self, instruction: str, output: BaseScratchPad = BaseScratchPad()) -> Any:
|
||||
response = None
|
||||
output.info("Running Planner")
|
||||
prompt = self._compose_prompt(instruction)
|
||||
output.debug(f"Prompt: {prompt}")
|
||||
try:
|
||||
response = self.model(prompt)
|
||||
self.log_progress(".planner", response=response)
|
||||
output.info("Planner run successful.")
|
||||
except ValueError as e:
|
||||
output.error("Planner failed to retrieve response from LLM")
|
||||
raise ValueError("Planner failed to retrieve response from LLM") from e
|
||||
|
||||
return response
|
119
libs/kotaemon/kotaemon/agents/rewoo/prompt.py
Normal file
119
libs/kotaemon/kotaemon/agents/rewoo/prompt.py
Normal file
@@ -0,0 +1,119 @@
|
||||
# flake8: noqa
|
||||
|
||||
from kotaemon.llms import PromptTemplate
|
||||
|
||||
zero_shot_planner_prompt = PromptTemplate(
|
||||
template="""You are an AI agent who makes step-by-step plans to solve a problem under the help of external tools.
|
||||
For each step, make one plan followed by one tool-call, which will be executed later to retrieve evidence for that step.
|
||||
You should store each evidence into a distinct variable #E1, #E2, #E3 ... that can be referred to in later tool-call inputs.
|
||||
|
||||
##Available Tools##
|
||||
{tool_description}
|
||||
|
||||
##Output Format (Replace '<...>')##
|
||||
#Plan1: <describe your plan here>
|
||||
#E1: <toolname>[<input here>] (eg. Search[What is Python])
|
||||
#Plan2: <describe next plan>
|
||||
#E2: <toolname>[<input here, you can use #E1 to represent its expected output>]
|
||||
And so on...
|
||||
|
||||
##Your Task##
|
||||
{task}
|
||||
|
||||
##Now Begin##
|
||||
"""
|
||||
)
|
||||
|
||||
one_shot_planner_prompt = PromptTemplate(
|
||||
template="""You are an AI agent who makes step-by-step plans to solve a problem under the help of external tools.
|
||||
For each step, make one plan followed by one tool-call, which will be executed later to retrieve evidence for that step.
|
||||
You should store each evidence into a distinct variable #E1, #E2, #E3 ... that can be referred to in later tool-call inputs.
|
||||
|
||||
##Available Tools##
|
||||
{tool_description}
|
||||
|
||||
##Output Format##
|
||||
#Plan1: <describe your plan here>
|
||||
#E1: <toolname>[<input here>]
|
||||
#Plan2: <describe next plan>
|
||||
#E2: <toolname>[<input here, you can use #E1 to represent its expected output>]
|
||||
And so on...
|
||||
|
||||
##Example##
|
||||
Task: What is the 4th root of 64 to the power of 3?
|
||||
#Plan1: Find the 4th root of 64
|
||||
#E1: Calculator[64^(1/4)]
|
||||
#Plan2: Raise the result from #Plan1 to the power of 3
|
||||
#E2: Calculator[#E1^3]
|
||||
|
||||
##Your Task##
|
||||
{task}
|
||||
|
||||
##Now Begin##
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
few_shot_planner_prompt = PromptTemplate(
|
||||
template="""You are an AI agent who makes step-by-step plans to solve a problem under the help of external tools.
|
||||
For each step, make one plan followed by one tool-call, which will be executed later to retrieve evidence for that step.
|
||||
You should store each evidence into a distinct variable #E1, #E2, #E3 ... that can be referred to in later tool-call inputs.
|
||||
|
||||
##Available Tools##
|
||||
{tool_description}
|
||||
|
||||
##Output Format (Replace '<...>')##
|
||||
#Plan1: <describe your plan here>
|
||||
#E1: <toolname>[<input>]
|
||||
#Plan2: <describe next plan>
|
||||
#E2: <toolname>[<input, you can use #E1 to represent its expected output>]
|
||||
And so on...
|
||||
|
||||
##Examples##
|
||||
{fewshot}
|
||||
|
||||
##Your Task##
|
||||
{task}
|
||||
|
||||
##Now Begin##
|
||||
"""
|
||||
)
|
||||
|
||||
zero_shot_solver_prompt = PromptTemplate(
|
||||
template="""You are an AI agent who solves a problem with my assistance. I will provide step-by-step plans(#Plan) and evidences(#E) that could be helpful.
|
||||
Your task is to briefly summarize each step, then make a short final conclusion for your task.
|
||||
|
||||
##My Plans and Evidences##
|
||||
{plan_evidence}
|
||||
|
||||
##Example Output##
|
||||
First, I <did something> , and I think <...>; Second, I <...>, and I think <...>; ....
|
||||
So, <your conclusion>.
|
||||
|
||||
##Your Task##
|
||||
{task}
|
||||
|
||||
##Now Begin##
|
||||
"""
|
||||
)
|
||||
|
||||
few_shot_solver_prompt = PromptTemplate(
|
||||
template="""You are an AI agent who solves a problem with my assistance. I will provide step-by-step plans and evidences that could be helpful.
|
||||
Your task is to briefly summarize each step, then make a short final conclusion for your task.
|
||||
|
||||
##My Plans and Evidences##
|
||||
{plan_evidence}
|
||||
|
||||
##Example Output##
|
||||
First, I <did something> , and I think <...>; Second, I <...>, and I think <...>; ....
|
||||
So, <your conclusion>.
|
||||
|
||||
##Example##
|
||||
{fewshot}
|
||||
|
||||
##Your Task##
|
||||
{task}
|
||||
|
||||
##Now Begin##
|
||||
"""
|
||||
)
|
65
libs/kotaemon/kotaemon/agents/rewoo/solver.py
Normal file
65
libs/kotaemon/kotaemon/agents/rewoo/solver.py
Normal file
@@ -0,0 +1,65 @@
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
from kotaemon.agents.io import BaseScratchPad
|
||||
from kotaemon.base import BaseComponent
|
||||
from kotaemon.llms import BaseLLM, PromptTemplate
|
||||
|
||||
from .prompt import few_shot_solver_prompt, zero_shot_solver_prompt
|
||||
|
||||
|
||||
class Solver(BaseComponent):
|
||||
model: BaseLLM
|
||||
prompt_template: Optional[PromptTemplate] = None
|
||||
examples: Optional[Union[str, List[str]]] = None
|
||||
|
||||
def _compose_fewshot_prompt(self) -> str:
|
||||
if self.examples is None:
|
||||
return ""
|
||||
if isinstance(self.examples, str):
|
||||
return self.examples
|
||||
else:
|
||||
return "\n\n".join([e.strip("\n") for e in self.examples])
|
||||
|
||||
def _compose_prompt(self, instruction, plan_evidence) -> str:
|
||||
"""
|
||||
Compose the prompt from template, plan&evidence, examples and instruction.
|
||||
"""
|
||||
fewshot = self._compose_fewshot_prompt()
|
||||
if self.prompt_template is not None:
|
||||
if "fewshot" in self.prompt_template.placeholders:
|
||||
return self.prompt_template.populate(
|
||||
plan_evidence=plan_evidence, fewshot=fewshot, task=instruction
|
||||
)
|
||||
else:
|
||||
return self.prompt_template.populate(
|
||||
plan_evidence=plan_evidence, task=instruction
|
||||
)
|
||||
else:
|
||||
if self.examples is not None:
|
||||
return few_shot_solver_prompt.populate(
|
||||
plan_evidence=plan_evidence, fewshot=fewshot, task=instruction
|
||||
)
|
||||
else:
|
||||
return zero_shot_solver_prompt.populate(
|
||||
plan_evidence=plan_evidence, task=instruction
|
||||
)
|
||||
|
||||
def run(
|
||||
self,
|
||||
instruction: str,
|
||||
plan_evidence: str,
|
||||
output: BaseScratchPad = BaseScratchPad(),
|
||||
) -> Any:
|
||||
response = None
|
||||
output.info("Running Solver")
|
||||
output.debug(f"Instruction: {instruction}")
|
||||
output.debug(f"Plan Evidence: {plan_evidence}")
|
||||
prompt = self._compose_prompt(instruction, plan_evidence)
|
||||
output.debug(f"Prompt: {prompt}")
|
||||
try:
|
||||
response = self.model(prompt)
|
||||
output.info("Solver run successful.")
|
||||
except ValueError:
|
||||
output.error("Solver failed to retrieve response from LLM")
|
||||
|
||||
return response
|
6
libs/kotaemon/kotaemon/agents/tools/__init__.py
Normal file
6
libs/kotaemon/kotaemon/agents/tools/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from .base import BaseTool, ComponentTool
|
||||
from .google import GoogleSearchTool
|
||||
from .llm import LLMTool
|
||||
from .wikipedia import WikipediaTool
|
||||
|
||||
__all__ = ["BaseTool", "ComponentTool", "GoogleSearchTool", "WikipediaTool", "LLMTool"]
|
138
libs/kotaemon/kotaemon/agents/tools/base.py
Normal file
138
libs/kotaemon/kotaemon/agents/tools/base.py
Normal file
@@ -0,0 +1,138 @@
|
||||
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
|
||||
|
||||
from kotaemon.base import BaseComponent
|
||||
from langchain.agents import Tool as LCTool
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ToolException(Exception):
|
||||
"""An optional exception that tool throws when execution error occurs.
|
||||
|
||||
When this exception is thrown, the agent will not stop working,
|
||||
but will handle the exception according to the handle_tool_error
|
||||
variable of the tool, and the processing result will be returned
|
||||
to the agent as observation, and printed in red on the console.
|
||||
"""
|
||||
|
||||
|
||||
class BaseTool(BaseComponent):
|
||||
name: str
|
||||
"""The unique name of the tool that clearly communicates its purpose."""
|
||||
description: str
|
||||
"""Description used to tell the model how/when/why to use the tool.
|
||||
You can provide few-shot examples as a part of the description. This will be
|
||||
input to the prompt of LLM.
|
||||
"""
|
||||
args_schema: Optional[Type[BaseModel]] = None
|
||||
"""Pydantic model class to validate and parse the tool's input arguments."""
|
||||
verbose: bool = False
|
||||
"""Whether to log the tool's progress."""
|
||||
handle_tool_error: Optional[
|
||||
Union[bool, str, Callable[[ToolException], str]]
|
||||
] = False
|
||||
"""Handle the content of the ToolException thrown."""
|
||||
|
||||
def _parse_input(
|
||||
self,
|
||||
tool_input: Union[str, Dict],
|
||||
) -> Union[str, Dict[str, Any]]:
|
||||
"""Convert tool input to pydantic model."""
|
||||
args_schema = self.args_schema
|
||||
if isinstance(tool_input, str):
|
||||
if args_schema is not None:
|
||||
key_ = next(iter(args_schema.model_fields.keys()))
|
||||
args_schema.validate({key_: tool_input})
|
||||
return tool_input
|
||||
else:
|
||||
if args_schema is not None:
|
||||
result = args_schema.parse_obj(tool_input)
|
||||
return {k: v for k, v in result.dict().items() if k in tool_input}
|
||||
return tool_input
|
||||
|
||||
def _run_tool(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Call tool."""
|
||||
raise NotImplementedError(f"_run_tool is not implemented for {self.name}")
|
||||
|
||||
def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]:
|
||||
# For backwards compatibility, if run_input is a string,
|
||||
# pass as a positional argument.
|
||||
if isinstance(tool_input, str):
|
||||
return (tool_input,), {}
|
||||
else:
|
||||
return (), tool_input
|
||||
|
||||
def _handle_tool_error(self, e: ToolException) -> Any:
|
||||
"""Handle the content of the ToolException thrown."""
|
||||
observation = None
|
||||
if not self.handle_tool_error:
|
||||
raise e
|
||||
elif isinstance(self.handle_tool_error, bool):
|
||||
if e.args:
|
||||
observation = e.args[0]
|
||||
else:
|
||||
observation = "Tool execution error"
|
||||
elif isinstance(self.handle_tool_error, str):
|
||||
observation = self.handle_tool_error
|
||||
elif callable(self.handle_tool_error):
|
||||
observation = self.handle_tool_error(e)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Got unexpected type of `handle_tool_error`. Expected bool, str "
|
||||
f"or callable. Received: {self.handle_tool_error}"
|
||||
)
|
||||
return observation
|
||||
|
||||
def to_langchain_format(self) -> LCTool:
|
||||
"""Convert this tool to Langchain format to use with its agent"""
|
||||
return LCTool(name=self.name, description=self.description, func=self.run)
|
||||
|
||||
def run(
|
||||
self,
|
||||
tool_input: Union[str, Dict],
|
||||
verbose: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run the tool."""
|
||||
parsed_input = self._parse_input(tool_input)
|
||||
# TODO (verbose_): Add logging
|
||||
try:
|
||||
tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input)
|
||||
call_kwargs = {**kwargs, **tool_kwargs}
|
||||
observation = self._run_tool(*tool_args, **call_kwargs)
|
||||
except ToolException as e:
|
||||
observation = self._handle_tool_error(e)
|
||||
return observation
|
||||
else:
|
||||
return observation
|
||||
|
||||
@classmethod
|
||||
def from_langchain_format(cls, langchain_tool: LCTool) -> "BaseTool":
|
||||
"""Wrapper for Langchain Tool"""
|
||||
new_tool = BaseTool(
|
||||
name=langchain_tool.name, description=langchain_tool.description
|
||||
)
|
||||
new_tool._run_tool = langchain_tool._run # type: ignore
|
||||
return new_tool
|
||||
|
||||
|
||||
class ComponentTool(BaseTool):
|
||||
"""Wrapper around other BaseComponent to use it as a tool
|
||||
|
||||
Args:
|
||||
component: BaseComponent-based component to wrap
|
||||
postprocessor: Optional postprocessor for the component output
|
||||
"""
|
||||
|
||||
component: BaseComponent
|
||||
postprocessor: Optional[Callable] = None
|
||||
|
||||
def _run_tool(self, *args: Any, **kwargs: Any) -> Any:
|
||||
output = self.component(*args, **kwargs)
|
||||
if self.postprocessor:
|
||||
output = self.postprocessor(output)
|
||||
|
||||
return output
|
51
libs/kotaemon/kotaemon/agents/tools/google.py
Normal file
51
libs/kotaemon/kotaemon/agents/tools/google.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from typing import AnyStr, Optional, Type
|
||||
|
||||
from langchain.utilities import SerpAPIWrapper
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from .base import BaseTool
|
||||
|
||||
|
||||
class GoogleSearchArgs(BaseModel):
|
||||
query: str = Field(..., description="a search query")
|
||||
|
||||
|
||||
class GoogleSearchTool(BaseTool):
|
||||
name: str = "google_search"
|
||||
description: str = (
|
||||
"A search engine retrieving top search results as snippets from Google. "
|
||||
"Input should be a search query."
|
||||
)
|
||||
args_schema: Optional[Type[BaseModel]] = GoogleSearchArgs
|
||||
|
||||
def _run_tool(self, query: AnyStr) -> str:
|
||||
try:
|
||||
from googlesearch import search
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"install googlesearch using `pip3 install googlesearch-python` to "
|
||||
"use this tool"
|
||||
)
|
||||
output = ""
|
||||
search_results = search(query, advanced=True)
|
||||
if search_results:
|
||||
output = "\n".join(
|
||||
"{} {}".format(item.title, item.description) for item in search_results
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class SerpTool(BaseTool):
|
||||
name = "google_search"
|
||||
description = (
|
||||
"Worker that searches results from Google. Useful when you need to find short "
|
||||
"and succinct answers about a specific topic. Input should be a search query."
|
||||
)
|
||||
args_schema: Optional[Type[BaseModel]] = GoogleSearchArgs
|
||||
|
||||
def _run_tool(self, query: AnyStr) -> str:
|
||||
tool = SerpAPIWrapper()
|
||||
evidence = tool.run(query)
|
||||
|
||||
return evidence
|
31
libs/kotaemon/kotaemon/agents/tools/llm.py
Normal file
31
libs/kotaemon/kotaemon/agents/tools/llm.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from typing import AnyStr, Optional, Type
|
||||
|
||||
from kotaemon.llms import BaseLLM
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from .base import BaseTool, ToolException
|
||||
|
||||
|
||||
class LLMArgs(BaseModel):
|
||||
query: str = Field(..., description="a search question or prompt")
|
||||
|
||||
|
||||
class LLMTool(BaseTool):
|
||||
name: str = "llm"
|
||||
description: str = (
|
||||
"A pretrained LLM like yourself. Useful when you need to act with "
|
||||
"general world knowledge and common sense. Prioritize it when you "
|
||||
"are confident in solving the problem "
|
||||
"yourself. Input can be any instruction."
|
||||
)
|
||||
llm: BaseLLM
|
||||
args_schema: Optional[Type[BaseModel]] = LLMArgs
|
||||
|
||||
def _run_tool(self, query: AnyStr) -> str:
|
||||
output = None
|
||||
try:
|
||||
response = self.llm(query)
|
||||
except ValueError:
|
||||
raise ToolException("LLM Tool call failed")
|
||||
output = response.text
|
||||
return output
|
65
libs/kotaemon/kotaemon/agents/tools/wikipedia.py
Normal file
65
libs/kotaemon/kotaemon/agents/tools/wikipedia.py
Normal file
@@ -0,0 +1,65 @@
|
||||
from typing import Any, AnyStr, Optional, Type, Union
|
||||
|
||||
from kotaemon.base import Document
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from .base import BaseTool
|
||||
|
||||
|
||||
class Wiki:
|
||||
"""Wrapper around wikipedia API."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Check that wikipedia package is installed."""
|
||||
try:
|
||||
import wikipedia # noqa: F401
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import wikipedia python package. "
|
||||
"Please install it with `pip install wikipedia`."
|
||||
)
|
||||
|
||||
def search(self, search: str) -> Union[str, Document]:
|
||||
"""Try to search for wiki page.
|
||||
|
||||
If page exists, return the page summary, and a PageWithLookups object.
|
||||
If page does not exist, return similar entries.
|
||||
"""
|
||||
import wikipedia
|
||||
|
||||
try:
|
||||
page_content = wikipedia.page(search).content
|
||||
url = wikipedia.page(search).url
|
||||
result: Union[str, Document] = Document(
|
||||
text=page_content, metadata={"page": url}
|
||||
)
|
||||
except wikipedia.PageError:
|
||||
result = f"Could not find [{search}]. Similar: {wikipedia.search(search)}"
|
||||
except wikipedia.DisambiguationError:
|
||||
result = f"Could not find [{search}]. Similar: {wikipedia.search(search)}"
|
||||
return result
|
||||
|
||||
|
||||
class WikipediaArgs(BaseModel):
|
||||
query: str = Field(..., description="a search query as input to wkipedia")
|
||||
|
||||
|
||||
class WikipediaTool(BaseTool):
|
||||
"""Tool that adds the capability to query the Wikipedia API."""
|
||||
|
||||
name: str = "wikipedia"
|
||||
description: str = (
|
||||
"Search engine from Wikipedia, retrieving relevant wiki page. "
|
||||
"Useful when you need to get holistic knowledge about people, "
|
||||
"places, companies, historical events, or other subjects. "
|
||||
"Input should be a search query."
|
||||
)
|
||||
args_schema: Optional[Type[BaseModel]] = WikipediaArgs
|
||||
doc_store: Any = None
|
||||
|
||||
def _run_tool(self, query: AnyStr) -> AnyStr:
|
||||
if not self.doc_store:
|
||||
self.doc_store = Wiki()
|
||||
tool = self.doc_store
|
||||
evidence = tool.search(query)
|
||||
return evidence
|
22
libs/kotaemon/kotaemon/agents/utils.py
Normal file
22
libs/kotaemon/kotaemon/agents/utils.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from kotaemon.base import Document
|
||||
|
||||
|
||||
def get_plugin_response_content(output) -> str:
|
||||
"""
|
||||
Wrapper for AgentOutput content return
|
||||
"""
|
||||
if isinstance(output, Document):
|
||||
return output.text
|
||||
else:
|
||||
return str(output)
|
||||
|
||||
|
||||
def calculate_cost(model_name: str, prompt_token: int, completion_token: int) -> float:
|
||||
"""
|
||||
Calculate the cost of a prompt and completion.
|
||||
|
||||
Returns:
|
||||
float: Cost of the provided model name with provided token information
|
||||
"""
|
||||
# TODO: to be implemented
|
||||
return 0.0
|
28
libs/kotaemon/kotaemon/base/__init__.py
Normal file
28
libs/kotaemon/kotaemon/base/__init__.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from .component import BaseComponent, Node, Param, lazy
|
||||
from .schema import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
Document,
|
||||
DocumentWithEmbedding,
|
||||
ExtractorOutput,
|
||||
HumanMessage,
|
||||
LLMInterface,
|
||||
RetrievedDocument,
|
||||
SystemMessage,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BaseComponent",
|
||||
"Document",
|
||||
"DocumentWithEmbedding",
|
||||
"BaseMessage",
|
||||
"SystemMessage",
|
||||
"AIMessage",
|
||||
"HumanMessage",
|
||||
"RetrievedDocument",
|
||||
"LLMInterface",
|
||||
"ExtractorOutput",
|
||||
"Param",
|
||||
"Node",
|
||||
"lazy",
|
||||
]
|
42
libs/kotaemon/kotaemon/base/component.py
Normal file
42
libs/kotaemon/kotaemon/base/component.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Iterator
|
||||
|
||||
from kotaemon.base.schema import Document
|
||||
from theflow import Function, Node, Param, lazy
|
||||
|
||||
|
||||
class BaseComponent(Function):
|
||||
"""A component is a class that can be used to compose a pipeline.
|
||||
|
||||
!!! tip "Benefits of component"
|
||||
- Auto caching, logging
|
||||
- Allow deployment
|
||||
|
||||
!!! tip "For each component, the spirit is"
|
||||
- Tolerate multiple input types, e.g. str, Document, List[str], List[Document]
|
||||
- Enforce single output type. Hence, the output type of a component should be
|
||||
as generic as possible.
|
||||
"""
|
||||
|
||||
inflow = None
|
||||
|
||||
def flow(self):
|
||||
if self.inflow is None:
|
||||
raise ValueError("No inflow provided.")
|
||||
|
||||
if not isinstance(self.inflow, BaseComponent):
|
||||
raise ValueError(
|
||||
f"inflow must be a BaseComponent, found {type(self.inflow)}"
|
||||
)
|
||||
|
||||
return self.__call__(self.inflow.flow())
|
||||
|
||||
@abstractmethod
|
||||
def run(
|
||||
self, *args, **kwargs
|
||||
) -> Document | list[Document] | Iterator[Document] | None:
|
||||
"""Run the component."""
|
||||
...
|
||||
|
||||
|
||||
__all__ = ["BaseComponent", "Param", "Node", "lazy"]
|
132
libs/kotaemon/kotaemon/base/schema.py
Normal file
132
libs/kotaemon/kotaemon/base/schema.py
Normal file
@@ -0,0 +1,132 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Optional, TypeVar
|
||||
|
||||
from langchain.schema.messages import AIMessage as LCAIMessage
|
||||
from langchain.schema.messages import HumanMessage as LCHumanMessage
|
||||
from langchain.schema.messages import SystemMessage as LCSystemMessage
|
||||
from llama_index.bridge.pydantic import Field
|
||||
from llama_index.schema import Document as BaseDocument
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from haystack.schema import Document as HaystackDocument
|
||||
|
||||
IO_Type = TypeVar("IO_Type", "Document", str)
|
||||
SAMPLE_TEXT = "A sample Document from kotaemon"
|
||||
|
||||
|
||||
class Document(BaseDocument):
|
||||
"""
|
||||
Base document class, mostly inherited from Document class from llama-index.
|
||||
|
||||
This class accept one positional argument `content` of an arbitrary type, which will
|
||||
store the raw content of the document. If specified, the class will use
|
||||
`content` to initialize the base llama_index class.
|
||||
|
||||
Attributes:
|
||||
content: raw content of the document, can be anything
|
||||
source: id of the source of the Document. Optional.
|
||||
"""
|
||||
|
||||
content: Any
|
||||
source: Optional[str] = None
|
||||
|
||||
def __init__(self, content: Optional[Any] = None, *args, **kwargs):
|
||||
if content is None:
|
||||
if kwargs.get("text", None) is not None:
|
||||
kwargs["content"] = kwargs["text"]
|
||||
elif kwargs.get("embedding", None) is not None:
|
||||
kwargs["content"] = kwargs["embedding"]
|
||||
# default text indicating this document only contains embedding
|
||||
kwargs["text"] = "<EMBEDDING>"
|
||||
elif isinstance(content, Document):
|
||||
kwargs = content.dict()
|
||||
else:
|
||||
kwargs["content"] = content
|
||||
if content:
|
||||
kwargs["text"] = str(content)
|
||||
else:
|
||||
kwargs["text"] = ""
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def __bool__(self):
|
||||
return bool(self.content)
|
||||
|
||||
@classmethod
|
||||
def example(cls) -> "Document":
|
||||
document = Document(
|
||||
text=SAMPLE_TEXT,
|
||||
metadata={"filename": "README.md", "category": "codebase"},
|
||||
)
|
||||
return document
|
||||
|
||||
def to_haystack_format(self) -> "HaystackDocument":
|
||||
"""Convert struct to Haystack document format."""
|
||||
from haystack.schema import Document as HaystackDocument
|
||||
|
||||
metadata = self.metadata or {}
|
||||
text = self.text
|
||||
return HaystackDocument(content=text, meta=metadata)
|
||||
|
||||
def __str__(self):
|
||||
return str(self.content)
|
||||
|
||||
|
||||
class DocumentWithEmbedding(Document):
|
||||
"""Subclass of Document which must contains embedding
|
||||
|
||||
Use this if you want to enforce component's IOs to must contain embedding.
|
||||
"""
|
||||
|
||||
def __init__(self, embedding: list[float], *args, **kwargs):
|
||||
kwargs["embedding"] = embedding
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class BaseMessage(Document):
|
||||
def __add__(self, other: Any):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SystemMessage(BaseMessage, LCSystemMessage):
|
||||
pass
|
||||
|
||||
|
||||
class AIMessage(BaseMessage, LCAIMessage):
|
||||
pass
|
||||
|
||||
|
||||
class HumanMessage(BaseMessage, LCHumanMessage):
|
||||
pass
|
||||
|
||||
|
||||
class RetrievedDocument(Document):
|
||||
"""Subclass of Document with retrieval-related information
|
||||
|
||||
Attributes:
|
||||
score (float): score of the document (from 0.0 to 1.0)
|
||||
retrieval_metadata (dict): metadata from the retrieval process, can be used
|
||||
by different components in a retrieved pipeline to communicate with each
|
||||
other
|
||||
"""
|
||||
|
||||
score: float = Field(default=0.0)
|
||||
retrieval_metadata: dict = Field(default={})
|
||||
|
||||
|
||||
class LLMInterface(AIMessage):
|
||||
candidates: list[str] = Field(default_factory=list)
|
||||
completion_tokens: int = -1
|
||||
total_tokens: int = -1
|
||||
prompt_tokens: int = -1
|
||||
total_cost: float = 0
|
||||
logits: list[list[float]] = Field(default_factory=list)
|
||||
messages: list[AIMessage] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ExtractorOutput(Document):
|
||||
"""
|
||||
Represents the output of an extractor.
|
||||
"""
|
||||
|
||||
matches: list[str]
|
4
libs/kotaemon/kotaemon/chatbot/__init__.py
Normal file
4
libs/kotaemon/kotaemon/chatbot/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .base import BaseChatBot, ChatConversation
|
||||
from .simple_respondent import SimpleRespondentChatbot
|
||||
|
||||
__all__ = ["BaseChatBot", "SimpleRespondentChatbot", "ChatConversation"]
|
114
libs/kotaemon/kotaemon/chatbot/base.py
Normal file
114
libs/kotaemon/kotaemon/chatbot/base.py
Normal file
@@ -0,0 +1,114 @@
|
||||
from abc import abstractmethod
|
||||
from typing import List, Optional
|
||||
|
||||
from kotaemon.base import BaseComponent, LLMInterface
|
||||
from kotaemon.base.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||
from theflow import SessionFunction
|
||||
|
||||
|
||||
class BaseChatBot(BaseComponent):
|
||||
@abstractmethod
|
||||
def run(self, messages: List[BaseMessage]) -> LLMInterface:
|
||||
...
|
||||
|
||||
|
||||
def session_chat_storage(obj):
|
||||
"""Store using the bot location rather than the session location"""
|
||||
return obj._store_result
|
||||
|
||||
|
||||
class ChatConversation(SessionFunction):
|
||||
"""Base implementation of a chat bot component
|
||||
|
||||
A chatbot component should:
|
||||
- handle internal state, including history messages
|
||||
- return output for a given input
|
||||
"""
|
||||
|
||||
class Config:
|
||||
store_result = session_chat_storage
|
||||
|
||||
system_message: str = ""
|
||||
bot: BaseChatBot
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._history: List[BaseMessage] = []
|
||||
self._store_result = (
|
||||
f"{self.__module__}.{self.__class__.__name__},uninitiated_bot"
|
||||
)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def run(self, message: HumanMessage) -> Optional[BaseMessage]:
|
||||
"""Chat, given a message, return a response
|
||||
|
||||
Args:
|
||||
message: The message to respond to
|
||||
|
||||
Returns:
|
||||
The response to the message. If None, no response is sent.
|
||||
"""
|
||||
user_message = (
|
||||
HumanMessage(content=message) if isinstance(message, str) else message
|
||||
)
|
||||
self.history.append(user_message)
|
||||
|
||||
output = self.bot(self.history).text
|
||||
output_message = None
|
||||
if output is not None:
|
||||
output_message = AIMessage(content=output)
|
||||
self.history.append(output_message)
|
||||
|
||||
return output_message
|
||||
|
||||
def start_session(self):
|
||||
self._store_result = self.bot.config.store_result
|
||||
super().start_session()
|
||||
if not self.history and self.system_message:
|
||||
system_message = SystemMessage(content=self.system_message)
|
||||
self.history.append(system_message)
|
||||
|
||||
def end_session(self):
|
||||
super().end_session()
|
||||
self._history = []
|
||||
|
||||
def check_end(
|
||||
self,
|
||||
history: Optional[List[BaseMessage]] = None,
|
||||
user_message: Optional[HumanMessage] = None,
|
||||
bot_message: Optional[AIMessage] = None,
|
||||
) -> bool:
|
||||
"""Check if a conversation should end"""
|
||||
if user_message is not None and user_message.content == "":
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def terminal_session(self):
|
||||
"""Create a terminal session"""
|
||||
self.start_session()
|
||||
print(">> Start chat:")
|
||||
|
||||
while True:
|
||||
human = HumanMessage(content=input("Human: "))
|
||||
if self.check_end(history=self.history, user_message=human):
|
||||
break
|
||||
|
||||
output = self(human)
|
||||
if output is None:
|
||||
print("AI: <No response>")
|
||||
else:
|
||||
print("AI:", output.content)
|
||||
|
||||
if self.check_end(history=self.history, bot_message=output):
|
||||
break
|
||||
|
||||
self.end_session()
|
||||
|
||||
@property
|
||||
def history(self):
|
||||
return self._history
|
||||
|
||||
@history.setter
|
||||
def history(self, value):
|
||||
self._history = value
|
||||
self._variablex()
|
11
libs/kotaemon/kotaemon/chatbot/simple_respondent.py
Normal file
11
libs/kotaemon/kotaemon/chatbot/simple_respondent.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from ..llms import ChatLLM
|
||||
from .base import BaseChatBot
|
||||
|
||||
|
||||
class SimpleRespondentChatbot(BaseChatBot):
|
||||
"""Simple text respondent chatbot that essentially wraps around a chat LLM"""
|
||||
|
||||
llm: ChatLLM
|
||||
|
||||
def _get_message(self) -> str:
|
||||
return self.llm(self.history).text
|
189
libs/kotaemon/kotaemon/cli.py
Normal file
189
libs/kotaemon/kotaemon/cli.py
Normal file
@@ -0,0 +1,189 @@
|
||||
import os
|
||||
|
||||
import click
|
||||
import yaml
|
||||
from trogon import tui
|
||||
|
||||
|
||||
# check if the output is not a .yml file -> raise error
|
||||
def check_config_format(config):
|
||||
if os.path.exists(config):
|
||||
if isinstance(config, str):
|
||||
with open(config) as f:
|
||||
yaml.safe_load(f)
|
||||
else:
|
||||
raise ValueError("config must be yaml format.")
|
||||
|
||||
|
||||
@tui(command="ui", help="Open the terminal UI") # generate the terminal UI
|
||||
@click.group()
|
||||
def main():
|
||||
pass
|
||||
|
||||
|
||||
@click.group()
|
||||
def promptui():
|
||||
pass
|
||||
|
||||
|
||||
main.add_command(promptui)
|
||||
|
||||
|
||||
@promptui.command()
|
||||
@click.argument("export_path", nargs=1)
|
||||
@click.option("--output", default="promptui.yml", show_default=True, required=False)
|
||||
def export(export_path, output):
|
||||
"""Export a pipeline to a config file"""
|
||||
import sys
|
||||
|
||||
from kotaemon.contribs.promptui.config import export_pipeline_to_config
|
||||
from theflow.utils.modules import import_dotted_string
|
||||
|
||||
sys.path.append(os.getcwd())
|
||||
cls = import_dotted_string(export_path, safe=False)
|
||||
export_pipeline_to_config(cls, output)
|
||||
check_config_format(output)
|
||||
|
||||
|
||||
@promptui.command()
|
||||
@click.argument("run_path", required=False, default="promptui.yml")
|
||||
@click.option(
|
||||
"--share",
|
||||
is_flag=True,
|
||||
show_default=True,
|
||||
default=False,
|
||||
help="Share the app through Gradio. Requires --username to enable authentication.",
|
||||
)
|
||||
@click.option(
|
||||
"--username",
|
||||
required=False,
|
||||
help=(
|
||||
"Username for the user. If not provided, the promptui will not have "
|
||||
"authentication."
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--password",
|
||||
required=False,
|
||||
help="Password for the user. If not provided, will be prompted.",
|
||||
)
|
||||
@click.option(
|
||||
"--appname",
|
||||
required=False,
|
||||
help="The share app subdomain. Requires --share and --username",
|
||||
)
|
||||
@click.option(
|
||||
"--port",
|
||||
required=False,
|
||||
help="Port to run the app. If not provided, will $GRADIO_SERVER_PORT (7860)",
|
||||
)
|
||||
def run(run_path, share, username, password, appname, port):
|
||||
"""Run the UI from a config file
|
||||
|
||||
Examples:
|
||||
|
||||
\b
|
||||
# Run with default config file
|
||||
$ kh promptui run
|
||||
|
||||
\b
|
||||
# Run with username and password supplied
|
||||
$ kh promptui run --username admin --password password
|
||||
|
||||
\b
|
||||
# Run with username and prompted password
|
||||
$ kh promptui run --username admin
|
||||
|
||||
# Run and share to promptui
|
||||
# kh promptui run --username admin --password password --share --appname hey \
|
||||
--port 7861
|
||||
"""
|
||||
import sys
|
||||
|
||||
from kotaemon.contribs.promptui.ui import build_from_dict
|
||||
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
check_config_format(run_path)
|
||||
demo = build_from_dict(run_path)
|
||||
|
||||
params: dict = {}
|
||||
if username is not None:
|
||||
if password is not None:
|
||||
auth = (username, password)
|
||||
else:
|
||||
auth = (username, click.prompt("Password", hide_input=True))
|
||||
params["auth"] = auth
|
||||
|
||||
port = int(port) if port else int(os.getenv("GRADIO_SERVER_PORT", "7860"))
|
||||
params["server_port"] = port
|
||||
|
||||
if share:
|
||||
if username is None:
|
||||
raise ValueError(
|
||||
"Username must be provided to enable authentication for sharing"
|
||||
)
|
||||
if appname:
|
||||
from kotaemon.contribs.promptui.tunnel import Tunnel
|
||||
|
||||
tunnel = Tunnel(
|
||||
appname=str(appname), username=str(username), local_port=port
|
||||
)
|
||||
url = tunnel.run()
|
||||
print(f"App is shared at {url}")
|
||||
else:
|
||||
params["share"] = True
|
||||
print("App is shared at Gradio")
|
||||
|
||||
demo.launch(**params)
|
||||
|
||||
|
||||
@main.command()
|
||||
@click.argument("module", required=True)
|
||||
@click.option(
|
||||
"--output", default="docs.md", required=False, help="The output markdown file"
|
||||
)
|
||||
@click.option(
|
||||
"--separation-level", required=False, default=1, help="Organize markdown layout"
|
||||
)
|
||||
def makedoc(module, output, separation_level):
|
||||
"""Make documentation for module `module`
|
||||
|
||||
Example:
|
||||
|
||||
\b
|
||||
# Make component documentation for kotaemon library
|
||||
$ kh makedoc kotaemon
|
||||
"""
|
||||
from kotaemon.contribs.docs import make_doc
|
||||
|
||||
make_doc(module, output, separation_level)
|
||||
print(f"Documentation exported to {output}")
|
||||
|
||||
|
||||
@main.command()
|
||||
@click.option(
|
||||
"--template",
|
||||
default="project-default",
|
||||
required=False,
|
||||
help="Template name",
|
||||
show_default=True,
|
||||
)
|
||||
def start_project(template):
|
||||
"""Start a project from a template.
|
||||
|
||||
Important: the value for --template corresponds to the name of the template folder,
|
||||
which is located at https://github.com/Cinnamon/kotaemon/tree/main/templates
|
||||
The default value is "project-default", which should work when you are starting a
|
||||
client project.
|
||||
"""
|
||||
|
||||
print("Retrieving template...")
|
||||
os.system(
|
||||
"cookiecutter git@github.com:Cinnamon/kotaemon.git "
|
||||
f"--directory='templates/{template}'"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
0
libs/kotaemon/kotaemon/contribs/__init__.py
Normal file
0
libs/kotaemon/kotaemon/contribs/__init__.py
Normal file
66
libs/kotaemon/kotaemon/contribs/docs.py
Normal file
66
libs/kotaemon/kotaemon/contribs/docs.py
Normal file
@@ -0,0 +1,66 @@
|
||||
import inspect
|
||||
from collections import defaultdict
|
||||
|
||||
from theflow.utils.documentation import get_function_documentation_from_module
|
||||
|
||||
|
||||
def from_definition_to_markdown(definition: dict) -> str:
|
||||
"""From definition to markdown"""
|
||||
|
||||
# Handle params
|
||||
params = " N/A\n"
|
||||
if definition["params"]:
|
||||
params = "\n| Name | Description | Type | Default |\n"
|
||||
params += "| --- | --- | --- | --- |\n"
|
||||
for name, p in definition["params"].items():
|
||||
type_ = p["type"].__name__ if inspect.isclass(p["type"]) else p["type"]
|
||||
params += f"| {name} | {p['desc']} | {type_} | {p['default']} |\n"
|
||||
|
||||
# Handle nodes
|
||||
nodes = " N/A\n"
|
||||
if definition["nodes"]:
|
||||
nodes = "\n| Name | Description | Type | Input | Output |\n"
|
||||
nodes += "| --- | --- | --- | --- | --- |\n"
|
||||
for name, n in definition["nodes"].items():
|
||||
type_ = n["type"].__name__ if inspect.isclass(n["type"]) else str(n["type"])
|
||||
input_ = (
|
||||
n["input"].__name__ if inspect.isclass(n["input"]) else str(n["input"])
|
||||
)
|
||||
output_ = (
|
||||
n["output"].__name__
|
||||
if inspect.isclass(n["output"])
|
||||
else str(n["output"])
|
||||
)
|
||||
nodes += f"|{name}|{n['desc']}|{type_}|{input_}|{output_}|\n"
|
||||
|
||||
description = inspect.cleandoc(definition["desc"])
|
||||
return f"{description}\n\n_**Params:**_{params}\n_**Nodes:**_{nodes}"
|
||||
|
||||
|
||||
def make_doc(module: str, output: str, separation_level: int):
|
||||
"""Run exporting components to markdown
|
||||
|
||||
Args:
|
||||
module (str): module name
|
||||
output_path (str): output path to save
|
||||
separation_level (int): level of separation
|
||||
"""
|
||||
documentation = sorted(
|
||||
get_function_documentation_from_module(module).items(), key=lambda x: x[0]
|
||||
)
|
||||
|
||||
entries = defaultdict(list)
|
||||
|
||||
for name, definition in documentation:
|
||||
section = name.split(".")[separation_level].capitalize()
|
||||
cls_name = name.split(".")[-1]
|
||||
|
||||
markdown = from_definition_to_markdown(definition)
|
||||
entries[section].append(f"### {cls_name}\n{markdown}")
|
||||
|
||||
final = "\n".join(
|
||||
[f"## {section}\n" + "\n".join(entries[section]) for section in entries]
|
||||
)
|
||||
|
||||
with open(output, "w") as f:
|
||||
f.write(final)
|
1
libs/kotaemon/kotaemon/contribs/promptui/.gitignore
vendored
Normal file
1
libs/kotaemon/kotaemon/contribs/promptui/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
/frpc_*
|
43
libs/kotaemon/kotaemon/contribs/promptui/base.py
Normal file
43
libs/kotaemon/kotaemon/contribs/promptui/base.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import gradio as gr
|
||||
|
||||
COMPONENTS_CLASS = {
|
||||
"text": gr.components.Textbox,
|
||||
"checkbox": gr.components.CheckboxGroup,
|
||||
"dropdown": gr.components.Dropdown,
|
||||
"file": gr.components.File,
|
||||
"image": gr.components.Image,
|
||||
"number": gr.components.Number,
|
||||
"radio": gr.components.Radio,
|
||||
"slider": gr.components.Slider,
|
||||
}
|
||||
SUPPORTED_COMPONENTS = set(COMPONENTS_CLASS.keys())
|
||||
DEFAULT_COMPONENT_BY_TYPES = {
|
||||
"str": "text",
|
||||
"bool": "checkbox",
|
||||
"int": "number",
|
||||
"float": "number",
|
||||
"list": "dropdown",
|
||||
}
|
||||
|
||||
|
||||
def get_component(component_def: dict) -> gr.components.Component:
|
||||
"""Get the component based on component definition"""
|
||||
component_cls = None
|
||||
|
||||
if "component" in component_def:
|
||||
component = component_def["component"]
|
||||
if component not in SUPPORTED_COMPONENTS:
|
||||
raise ValueError(
|
||||
f"Unsupported UI component: {component}. "
|
||||
f"Must be one of {SUPPORTED_COMPONENTS}"
|
||||
)
|
||||
|
||||
component_cls = COMPONENTS_CLASS[component]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Cannot decide the component from {component_def}. "
|
||||
"Please specify `component` with 1 of the following "
|
||||
f"values: {SUPPORTED_COMPONENTS}"
|
||||
)
|
||||
|
||||
return component_cls(**component_def.get("params", {}))
|
1
libs/kotaemon/kotaemon/contribs/promptui/cli.py
Normal file
1
libs/kotaemon/kotaemon/contribs/promptui/cli.py
Normal file
@@ -0,0 +1 @@
|
||||
"""CLI commands that can be imported by the kotaemon.cli module"""
|
182
libs/kotaemon/kotaemon/contribs/promptui/config.py
Normal file
182
libs/kotaemon/kotaemon/contribs/promptui/config.py
Normal file
@@ -0,0 +1,182 @@
|
||||
"""Get config from Pipeline"""
|
||||
import inspect
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Type, Union
|
||||
|
||||
import yaml
|
||||
from kotaemon.base import BaseComponent
|
||||
from kotaemon.chatbot import BaseChatBot
|
||||
|
||||
from .base import DEFAULT_COMPONENT_BY_TYPES
|
||||
|
||||
|
||||
def config_from_value(value: Any) -> dict:
|
||||
"""Get the config from default value
|
||||
|
||||
Args:
|
||||
value (Any): default value
|
||||
|
||||
Returns:
|
||||
dict: config
|
||||
"""
|
||||
component = DEFAULT_COMPONENT_BY_TYPES.get(type(value).__name__, "text")
|
||||
return {
|
||||
"component": component,
|
||||
"params": {
|
||||
"value": value,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def handle_param(param: dict) -> dict:
|
||||
"""Convert param definition into promptui-compliant config
|
||||
|
||||
Supported gradio's UI components are (https://www.gradio.app/docs/components)
|
||||
- CheckBoxGroup: list (multi select)
|
||||
- DropDown: list (single select)
|
||||
- File
|
||||
- Image
|
||||
- Number: int / float
|
||||
- Radio: list (single select)
|
||||
- Slider: int / float
|
||||
- TextBox: str
|
||||
"""
|
||||
params = {}
|
||||
default = param.get("default", None)
|
||||
if isinstance(default, str) and default.startswith("{{") and default.endswith("}}"):
|
||||
default = None
|
||||
if default is not None:
|
||||
params["value"] = default
|
||||
|
||||
ui_component = param.get("component_ui", "")
|
||||
if not ui_component:
|
||||
type_: str = type(default).__name__ if default is not None else ""
|
||||
ui_component = DEFAULT_COMPONENT_BY_TYPES.get(type_, "text")
|
||||
|
||||
return {
|
||||
"component": ui_component,
|
||||
"params": params,
|
||||
}
|
||||
|
||||
|
||||
def handle_node(node: dict) -> dict:
|
||||
"""Convert node definition into promptui-compliant config"""
|
||||
config = {}
|
||||
for name, param_def in node.get("params", {}).items():
|
||||
if isinstance(param_def["auto_callback"], str):
|
||||
continue
|
||||
if param_def.get("ignore_ui", False):
|
||||
continue
|
||||
config[name] = handle_param(param_def)
|
||||
for name, node_def in node.get("nodes", {}).items():
|
||||
if isinstance(node_def["auto_callback"], str):
|
||||
continue
|
||||
if node_def.get("ignore_ui", False):
|
||||
continue
|
||||
for key, value in handle_node(node_def["default"]).items():
|
||||
config[f"{name}.{key}"] = value
|
||||
for key, value in node_def.get("default_kwargs", {}).items():
|
||||
config[f"{name}.{key}"] = config_from_value(value)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def handle_input(pipeline: Union[BaseComponent, Type[BaseComponent]]) -> dict:
|
||||
"""Get the input from the pipeline"""
|
||||
signature = inspect.signature(pipeline.run)
|
||||
inputs: Dict[str, Dict] = {}
|
||||
for name, param in signature.parameters.items():
|
||||
if name in ["self", "args", "kwargs"]:
|
||||
continue
|
||||
input_def: Dict[str, Optional[Any]] = {"component": "text"}
|
||||
default = param.default
|
||||
if default is param.empty:
|
||||
inputs[name] = input_def
|
||||
continue
|
||||
|
||||
params = {}
|
||||
params["value"] = default
|
||||
type_ = type(default).__name__ if default is not None else None
|
||||
ui_component = None
|
||||
if type_ is not None:
|
||||
ui_component = "text"
|
||||
|
||||
input_def["component"] = ui_component
|
||||
input_def["params"] = params
|
||||
|
||||
inputs[name] = input_def
|
||||
|
||||
return inputs
|
||||
|
||||
|
||||
def export_pipeline_to_config(
|
||||
pipeline: Union[BaseComponent, Type[BaseComponent]],
|
||||
path: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""Export a pipeline to a promptui-compliant config dict"""
|
||||
if inspect.isclass(pipeline):
|
||||
pipeline = pipeline()
|
||||
|
||||
pipeline_def = pipeline.describe()
|
||||
ui_type = "chat" if isinstance(pipeline, BaseChatBot) else "simple"
|
||||
if ui_type == "chat":
|
||||
params = {f".bot.{k}": v for k, v in handle_node(pipeline_def).items()}
|
||||
params["system_message"] = {"component": "text", "params": {"value": ""}}
|
||||
outputs = []
|
||||
if hasattr(pipeline, "_promptui_outputs"):
|
||||
outputs = pipeline._promptui_outputs
|
||||
config_obj: dict = {
|
||||
"ui-type": ui_type,
|
||||
"params": params,
|
||||
"inputs": {},
|
||||
"outputs": outputs,
|
||||
"logs": {
|
||||
"full_pipeline": {
|
||||
"input": {
|
||||
"step": ".",
|
||||
"getter": "_get_input",
|
||||
},
|
||||
"output": {
|
||||
"step": ".",
|
||||
"getter": "_get_output",
|
||||
},
|
||||
"preference": {
|
||||
"step": "preference",
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
else:
|
||||
outputs = [{"step": ".", "getter": "_get_output", "component": "text"}]
|
||||
if hasattr(pipeline, "_promptui_outputs"):
|
||||
outputs = pipeline._promptui_outputs
|
||||
config_obj = {
|
||||
"ui-type": ui_type,
|
||||
"params": handle_node(pipeline_def),
|
||||
"inputs": handle_input(pipeline),
|
||||
"outputs": outputs,
|
||||
"logs": {
|
||||
"full_pipeline": {
|
||||
"input": {
|
||||
"step": ".",
|
||||
"getter": "_get_input",
|
||||
},
|
||||
"output": {
|
||||
"step": ".",
|
||||
"getter": "_get_output",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
config = {f"{pipeline.__module__}.{pipeline.__class__.__name__}": config_obj}
|
||||
if path is not None:
|
||||
old_config = config
|
||||
if Path(path).is_file():
|
||||
with open(path) as f:
|
||||
old_config = yaml.safe_load(f)
|
||||
old_config.update(config)
|
||||
with open(path, "w") as f:
|
||||
yaml.safe_dump(old_config, f, sort_keys=False)
|
||||
|
||||
return config
|
140
libs/kotaemon/kotaemon/contribs/promptui/export.py
Normal file
140
libs/kotaemon/kotaemon/contribs/promptui/export.py
Normal file
@@ -0,0 +1,140 @@
|
||||
"""Export logs into Excel file"""
|
||||
import os
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Type, Union
|
||||
|
||||
import pandas as pd
|
||||
import yaml
|
||||
from kotaemon.base import BaseComponent
|
||||
from theflow.storage import storage
|
||||
from theflow.utils.modules import import_dotted_string
|
||||
|
||||
from .logs import ResultLog
|
||||
|
||||
|
||||
def from_log_to_dict(pipeline_cls: Type[BaseComponent], log_config: dict) -> dict:
|
||||
"""Export the log to panda dataframes
|
||||
|
||||
Args:
|
||||
pipeline_cls (Type[BaseComponent]): Pipeline class
|
||||
log_config (dict): Log config
|
||||
|
||||
Returns:
|
||||
dataframe
|
||||
"""
|
||||
# get the directory
|
||||
pipeline_log_path = storage.url(pipeline_cls().config.store_result)
|
||||
dirs = list(sorted([f.path for f in os.scandir(pipeline_log_path) if f.is_dir()]))
|
||||
|
||||
# get resultlog callback
|
||||
resultlog = getattr(pipeline_cls, "_promptui_resultlog", ResultLog)
|
||||
allowed_resultlog_callbacks = {i for i in dir(resultlog) if not i.startswith("__")}
|
||||
|
||||
ids = []
|
||||
params: Dict[str, List[Any]] = {}
|
||||
logged_infos: Dict[str, List[Any]] = {}
|
||||
|
||||
for idx, each_dir in enumerate(dirs):
|
||||
ids.append(str(Path(each_dir).name))
|
||||
|
||||
# get the params
|
||||
params_file = os.path.join(each_dir, "params.pkl")
|
||||
if os.path.exists(params_file):
|
||||
with open(params_file, "rb") as f:
|
||||
each_params = pickle.load(f)
|
||||
for key, value in each_params.items():
|
||||
if key not in params:
|
||||
params[key] = [None] * len(dirs)
|
||||
params[key][idx] = value
|
||||
|
||||
# get the progress
|
||||
progress_file = os.path.join(each_dir, "progress.pkl")
|
||||
if os.path.exists(progress_file):
|
||||
with open(progress_file, "rb") as f:
|
||||
progress = pickle.load(f)
|
||||
|
||||
for name, col_info in log_config.items():
|
||||
step = col_info["step"]
|
||||
getter = col_info.get("getter", None)
|
||||
if name not in logged_infos:
|
||||
logged_infos[name] = [None] * len(dirs)
|
||||
|
||||
if step not in progress:
|
||||
continue
|
||||
|
||||
info = progress[step]
|
||||
if getter:
|
||||
if getter in allowed_resultlog_callbacks:
|
||||
info = getattr(resultlog, getter)(info)
|
||||
else:
|
||||
implicit_name = f"get_{name}"
|
||||
if implicit_name in allowed_resultlog_callbacks:
|
||||
info = getattr(resultlog, implicit_name)(info)
|
||||
logged_infos[name][idx] = info
|
||||
|
||||
return {"ids": ids, **params, **logged_infos}
|
||||
|
||||
|
||||
def export(config: dict, pipeline_def, output_path):
|
||||
"""Export from config to Excel file"""
|
||||
|
||||
pipeline_name = f"{pipeline_def.__module__}.{pipeline_def.__name__}"
|
||||
|
||||
# export to Excel
|
||||
if not config.get("logs", {}):
|
||||
raise ValueError(f"Pipeline {pipeline_name} has no logs to export")
|
||||
|
||||
pds: Dict[str, pd.DataFrame] = {}
|
||||
for log_name, log_def in config["logs"].items():
|
||||
pds[log_name] = pd.DataFrame(from_log_to_dict(pipeline_def, log_def))
|
||||
|
||||
# from the list of pds, export to Excel to output_path
|
||||
with pd.ExcelWriter(output_path, engine="openpyxl") as writer: # type: ignore
|
||||
for log_name, df in pds.items():
|
||||
df.to_excel(writer, sheet_name=log_name)
|
||||
|
||||
|
||||
def export_from_dict(
|
||||
config: Union[str, dict],
|
||||
pipeline: Union[str, Type[BaseComponent]],
|
||||
output_path: str,
|
||||
):
|
||||
"""CLI to export the logs of a pipeline into Excel file
|
||||
|
||||
Args:
|
||||
config_path (str): Path to the config file
|
||||
pipeline_name (str): Name of the pipeline
|
||||
output_path (str): Path to the output Excel file
|
||||
"""
|
||||
# get the pipeline class and the relevant config dict
|
||||
config_dict: dict
|
||||
if isinstance(config, str):
|
||||
with open(config) as f:
|
||||
config_dict = yaml.safe_load(f)
|
||||
elif isinstance(config, dict):
|
||||
config_dict = config
|
||||
else:
|
||||
raise TypeError(f"`config` must be str or dict, not {type(config)}")
|
||||
|
||||
pipeline_name: str
|
||||
pipeline_cls: Type[BaseComponent]
|
||||
pipeline_config: dict
|
||||
if isinstance(pipeline, str):
|
||||
if pipeline not in config_dict:
|
||||
raise ValueError(f"Pipeline {pipeline} not found in config file")
|
||||
pipeline_name = pipeline
|
||||
pipeline_cls = import_dotted_string(pipeline, safe=False)
|
||||
pipeline_config = config_dict[pipeline]
|
||||
elif isinstance(pipeline, type) and issubclass(pipeline, BaseComponent):
|
||||
pipeline_name = f"{pipeline.__module__}.{pipeline.__name__}"
|
||||
if pipeline_name not in config_dict:
|
||||
raise ValueError(f"Pipeline {pipeline_name} not found in config file")
|
||||
pipeline_cls = pipeline
|
||||
pipeline_config = config_dict[pipeline_name]
|
||||
else:
|
||||
raise TypeError(
|
||||
f"`pipeline` must be str or subclass of BaseComponent, not {type(pipeline)}"
|
||||
)
|
||||
|
||||
export(pipeline_config, pipeline_cls, output_path)
|
16
libs/kotaemon/kotaemon/contribs/promptui/logs.py
Normal file
16
libs/kotaemon/kotaemon/contribs/promptui/logs.py
Normal file
@@ -0,0 +1,16 @@
|
||||
class ResultLog:
|
||||
"""Callback getter to get the desired log result
|
||||
|
||||
The callback resolution will be as follow:
|
||||
1. Explicit string name
|
||||
2. Implicitly by: `get_<name>`
|
||||
3. Pass through
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _get_input(obj):
|
||||
return obj["input"]
|
||||
|
||||
@staticmethod
|
||||
def _get_output(obj):
|
||||
return obj["output"]
|
95
libs/kotaemon/kotaemon/contribs/promptui/themes.py
Normal file
95
libs/kotaemon/kotaemon/contribs/promptui/themes.py
Normal file
@@ -0,0 +1,95 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Iterable
|
||||
|
||||
from gradio.themes.base import Base
|
||||
from gradio.themes.utils import colors, fonts, sizes
|
||||
|
||||
|
||||
class John(Base):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
primary_hue: colors.Color | str = colors.neutral,
|
||||
secondary_hue: colors.Color | str = colors.neutral,
|
||||
neutral_hue: colors.Color | str = colors.neutral,
|
||||
spacing_size: sizes.Size | str = sizes.spacing_sm,
|
||||
radius_size: sizes.Size | str = sizes.radius_none,
|
||||
text_size: sizes.Size | str = sizes.text_sm,
|
||||
font: fonts.Font
|
||||
| str
|
||||
| Iterable[fonts.Font | str] = (
|
||||
fonts.GoogleFont("Quicksand"),
|
||||
"ui-sans-serif",
|
||||
"system-ui",
|
||||
"sans-serif",
|
||||
),
|
||||
font_mono: fonts.Font
|
||||
| str
|
||||
| Iterable[fonts.Font | str] = (
|
||||
fonts.GoogleFont("IBM Plex Mono"),
|
||||
"ui-monospace",
|
||||
"Consolas",
|
||||
"monospace",
|
||||
),
|
||||
):
|
||||
super().__init__(
|
||||
primary_hue=primary_hue,
|
||||
secondary_hue=secondary_hue,
|
||||
neutral_hue=neutral_hue,
|
||||
spacing_size=spacing_size,
|
||||
radius_size=radius_size,
|
||||
text_size=text_size,
|
||||
font=font,
|
||||
font_mono=font_mono,
|
||||
)
|
||||
self.name = "monochrome"
|
||||
super().set(
|
||||
# Colors
|
||||
slider_color="*neutral_900",
|
||||
slider_color_dark="*neutral_500",
|
||||
body_text_color="*neutral_900",
|
||||
block_label_text_color="*body_text_color",
|
||||
block_title_text_color="*body_text_color",
|
||||
body_text_color_subdued="*neutral_700",
|
||||
background_fill_primary_dark="*neutral_900",
|
||||
background_fill_secondary_dark="*neutral_800",
|
||||
block_background_fill_dark="*neutral_800",
|
||||
input_background_fill_dark="*neutral_700",
|
||||
# Button Colors
|
||||
button_primary_background_fill="*neutral_900",
|
||||
button_primary_background_fill_hover="*neutral_700",
|
||||
button_primary_text_color="white",
|
||||
button_primary_background_fill_dark="*neutral_600",
|
||||
button_primary_background_fill_hover_dark="*neutral_600",
|
||||
button_primary_text_color_dark="white",
|
||||
button_secondary_background_fill=(
|
||||
"linear-gradient(to bottom right, *neutral_100, *neutral_200)"
|
||||
),
|
||||
button_secondary_background_fill_hover=(
|
||||
"linear-gradient(to bottom right, *neutral_100, *neutral_100)"
|
||||
),
|
||||
button_secondary_background_fill_dark=(
|
||||
"linear-gradient(to bottom right, *neutral_600, *neutral_700)"
|
||||
),
|
||||
button_secondary_background_fill_hover_dark=(
|
||||
"linear-gradient(to bottom right, *neutral_600, *neutral_600)"
|
||||
),
|
||||
button_cancel_background_fill="*button_primary_background_fill",
|
||||
button_cancel_background_fill_hover="*button_primary_background_fill_hover",
|
||||
button_cancel_text_color="*button_primary_text_color",
|
||||
# Padding
|
||||
checkbox_label_padding="*spacing_sm",
|
||||
button_large_padding="*spacing_sm",
|
||||
button_small_padding="*spacing_sm",
|
||||
# Borders
|
||||
block_border_width="0px",
|
||||
block_border_width_dark="1px",
|
||||
shadow_drop_lg="0 1px 4px 0 rgb(0 0 0 / 0.1)",
|
||||
block_shadow="*shadow_drop_lg",
|
||||
block_shadow_dark="none",
|
||||
# Block Labels
|
||||
block_title_text_weight="600",
|
||||
block_label_text_weight="600",
|
||||
block_label_text_size="*text_sm",
|
||||
)
|
107
libs/kotaemon/kotaemon/contribs/promptui/tunnel.py
Normal file
107
libs/kotaemon/kotaemon/contribs/promptui/tunnel.py
Normal file
@@ -0,0 +1,107 @@
|
||||
import atexit
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import stat
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
|
||||
VERSION = "1.0"
|
||||
|
||||
machine = platform.machine()
|
||||
if machine == "x86_64":
|
||||
machine = "amd64"
|
||||
|
||||
BINARY_REMOTE_NAME = f"frpc_{platform.system().lower()}_{machine.lower()}"
|
||||
EXTENSION = ".exe" if os.name == "nt" else ""
|
||||
BINARY_URL = (
|
||||
"some-endpoint.com"
|
||||
f"/kotaemon/tunneling/{VERSION}/{BINARY_REMOTE_NAME}{EXTENSION}"
|
||||
)
|
||||
|
||||
BINARY_FILENAME = f"{BINARY_REMOTE_NAME}_v{VERSION}"
|
||||
BINARY_FOLDER = Path(__file__).parent
|
||||
BINARY_PATH = f"{BINARY_FOLDER / BINARY_FILENAME}"
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Tunnel:
|
||||
def __init__(self, appname, username, local_port):
|
||||
self.proc = None
|
||||
self.url = None
|
||||
self.appname = appname
|
||||
self.username = username
|
||||
self.local_port = local_port
|
||||
|
||||
@staticmethod
|
||||
def download_binary():
|
||||
if not Path(BINARY_PATH).exists():
|
||||
print("First time setting tunneling...")
|
||||
resp = requests.get(BINARY_URL)
|
||||
|
||||
if resp.status_code == 404:
|
||||
raise OSError(
|
||||
f"Cannot set up a share link as this platform is incompatible. "
|
||||
"Please create a GitHub issue with information about your "
|
||||
f"platform: {platform.uname()}"
|
||||
)
|
||||
|
||||
if resp.status_code == 403:
|
||||
raise OSError(
|
||||
"You do not have permission to setup the tunneling. Please "
|
||||
"make sure that you are within Cinnamon VPN or within other "
|
||||
"approved IPs. If this is new server, please contact @channel "
|
||||
"at #llm-productization to add your IP address"
|
||||
)
|
||||
|
||||
resp.raise_for_status()
|
||||
|
||||
# Save file data to local copy
|
||||
with open(BINARY_PATH, "wb") as file:
|
||||
file.write(resp.content)
|
||||
st = os.stat(BINARY_PATH)
|
||||
os.chmod(BINARY_PATH, st.st_mode | stat.S_IEXEC)
|
||||
|
||||
def run(self) -> str:
|
||||
"""Setting up tunneling"""
|
||||
if platform.system().lower() == "windows":
|
||||
logger.warning("Tunneling is not fully supported on Windows.")
|
||||
|
||||
self.download_binary()
|
||||
self.url = self._start_tunnel(BINARY_PATH)
|
||||
return self.url
|
||||
|
||||
def kill(self):
|
||||
if self.proc is not None:
|
||||
print(f"Killing tunnel 127.0.0.1:{self.local_port} <> {self.url}")
|
||||
self.proc.terminate()
|
||||
self.proc = None
|
||||
|
||||
def _start_tunnel(self, binary: str) -> str:
|
||||
command = [
|
||||
binary,
|
||||
"http",
|
||||
"-l",
|
||||
str(self.local_port),
|
||||
"-i",
|
||||
"127.0.0.1",
|
||||
"--uc",
|
||||
"--sd",
|
||||
str(self.appname),
|
||||
"-n",
|
||||
str(self.appname + self.username),
|
||||
"--server_addr",
|
||||
"44.229.38.9:7000",
|
||||
"--token",
|
||||
"Wz807/DyC;#t;#/",
|
||||
"--disable_log_color",
|
||||
]
|
||||
self.proc = subprocess.Popen(
|
||||
command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
|
||||
)
|
||||
atexit.register(self.kill)
|
||||
return f"https://{self.appname}.promptui.dm.cinnamon.is"
|
45
libs/kotaemon/kotaemon/contribs/promptui/ui/__init__.py
Normal file
45
libs/kotaemon/kotaemon/contribs/promptui/ui/__init__.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from typing import Union
|
||||
|
||||
import gradio as gr
|
||||
import yaml
|
||||
from theflow.utils.modules import import_dotted_string
|
||||
|
||||
from ..themes import John
|
||||
from .chat import build_chat_ui
|
||||
from .pipeline import build_pipeline_ui
|
||||
|
||||
|
||||
def build_from_dict(config: Union[str, dict]):
|
||||
"""Build a full UI from YAML config file"""
|
||||
|
||||
if isinstance(config, str):
|
||||
with open(config) as f:
|
||||
config_dict: dict = yaml.safe_load(f)
|
||||
elif isinstance(config, dict):
|
||||
config_dict = config
|
||||
else:
|
||||
raise ValueError(
|
||||
f"config must be either a yaml path or a dict, got {type(config)}"
|
||||
)
|
||||
|
||||
demos = []
|
||||
for key, value in config_dict.items():
|
||||
pipeline_def = import_dotted_string(key, safe=False)
|
||||
if value["ui-type"] == "chat":
|
||||
demos.append(build_chat_ui(value, pipeline_def).queue())
|
||||
else:
|
||||
demos.append(build_pipeline_ui(value, pipeline_def).queue())
|
||||
if len(demos) == 1:
|
||||
demo = demos[0]
|
||||
else:
|
||||
demo = gr.TabbedInterface(
|
||||
demos,
|
||||
tab_names=list(config_dict.keys()),
|
||||
title="PromptUI from kotaemon",
|
||||
analytics_enabled=False,
|
||||
theme=John(),
|
||||
)
|
||||
|
||||
demo.queue()
|
||||
|
||||
return demo
|
181
libs/kotaemon/kotaemon/contribs/promptui/ui/blocks.py
Normal file
181
libs/kotaemon/kotaemon/contribs/promptui/ui/blocks.py
Normal file
@@ -0,0 +1,181 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
import anyio
|
||||
from gradio import ChatInterface
|
||||
from gradio.components import Component, get_component_instance
|
||||
from gradio.events import on
|
||||
from gradio.helpers import special_args
|
||||
from gradio.routes import Request
|
||||
|
||||
|
||||
class ChatBlock(ChatInterface):
|
||||
"""The ChatBlock subclasses ChatInterface to provide extra functionalities:
|
||||
|
||||
- Show additional outputs to the chat interface
|
||||
- Disallow blank user message
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
additional_outputs: str | Component | list[str | Component] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
if additional_outputs:
|
||||
if not isinstance(additional_outputs, list):
|
||||
additional_outputs = [additional_outputs]
|
||||
self.additional_outputs = [
|
||||
get_component_instance(i) for i in additional_outputs # type: ignore
|
||||
]
|
||||
else:
|
||||
self.additional_outputs = []
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
async def _submit_fn(
|
||||
self,
|
||||
message: str,
|
||||
history_with_input: list[list[str | None]],
|
||||
request: Request,
|
||||
*args,
|
||||
) -> tuple[Any, ...]:
|
||||
input_args = args[: -len(self.additional_outputs)]
|
||||
output_args = args[-len(self.additional_outputs) :]
|
||||
if not message:
|
||||
return history_with_input, history_with_input, *output_args
|
||||
|
||||
history = history_with_input[:-1]
|
||||
inputs, _, _ = special_args(
|
||||
self.fn, inputs=[message, history, *input_args], request=request
|
||||
)
|
||||
|
||||
if self.is_async:
|
||||
response = await self.fn(*inputs)
|
||||
else:
|
||||
response = await anyio.to_thread.run_sync(
|
||||
self.fn, *inputs, limiter=self.limiter
|
||||
)
|
||||
|
||||
output = []
|
||||
if self.additional_outputs:
|
||||
text = response[0]
|
||||
output = response[1:]
|
||||
else:
|
||||
text = response
|
||||
|
||||
history.append([message, text])
|
||||
return history, history, *output
|
||||
|
||||
async def _stream_fn(
|
||||
self,
|
||||
message: str,
|
||||
history_with_input: list[list[str | None]],
|
||||
*args,
|
||||
) -> AsyncGenerator:
|
||||
raise NotImplementedError("Stream function not implemented for ChatBlock")
|
||||
|
||||
def _display_input(
|
||||
self, message: str, history: list[list[str | None]]
|
||||
) -> tuple[list[list[str | None]], list[list[str | None]]]:
|
||||
"""Stop displaying the input message if the message is a blank string"""
|
||||
if not message:
|
||||
return history, history
|
||||
return super()._display_input(message, history)
|
||||
|
||||
def _setup_events(self) -> None:
|
||||
"""Include additional outputs in the submit event"""
|
||||
submit_fn = self._stream_fn if self.is_generator else self._submit_fn
|
||||
submit_triggers = (
|
||||
[self.textbox.submit, self.submit_btn.click]
|
||||
if self.submit_btn
|
||||
else [self.textbox.submit]
|
||||
)
|
||||
submit_event = (
|
||||
on(
|
||||
submit_triggers,
|
||||
self._clear_and_save_textbox,
|
||||
[self.textbox],
|
||||
[self.textbox, self.saved_input],
|
||||
api_name=False,
|
||||
queue=False,
|
||||
)
|
||||
.then(
|
||||
self._display_input,
|
||||
[self.saved_input, self.chatbot_state],
|
||||
[self.chatbot, self.chatbot_state],
|
||||
api_name=False,
|
||||
queue=False,
|
||||
)
|
||||
.then(
|
||||
submit_fn,
|
||||
[self.saved_input, self.chatbot_state]
|
||||
+ self.additional_inputs
|
||||
+ self.additional_outputs,
|
||||
[self.chatbot, self.chatbot_state] + self.additional_outputs,
|
||||
api_name=False,
|
||||
)
|
||||
)
|
||||
self._setup_stop_events(submit_triggers, submit_event)
|
||||
|
||||
if self.retry_btn:
|
||||
retry_event = (
|
||||
self.retry_btn.click(
|
||||
self._delete_prev_fn,
|
||||
[self.chatbot_state],
|
||||
[self.chatbot, self.saved_input, self.chatbot_state],
|
||||
api_name=False,
|
||||
queue=False,
|
||||
)
|
||||
.then(
|
||||
self._display_input,
|
||||
[self.saved_input, self.chatbot_state],
|
||||
[self.chatbot, self.chatbot_state],
|
||||
api_name=False,
|
||||
queue=False,
|
||||
)
|
||||
.then(
|
||||
submit_fn,
|
||||
[self.saved_input, self.chatbot_state]
|
||||
+ self.additional_inputs
|
||||
+ self.additional_outputs,
|
||||
[self.chatbot, self.chatbot_state] + self.additional_outputs,
|
||||
api_name=False,
|
||||
)
|
||||
)
|
||||
self._setup_stop_events([self.retry_btn.click], retry_event)
|
||||
|
||||
if self.undo_btn:
|
||||
self.undo_btn.click(
|
||||
self._delete_prev_fn,
|
||||
[self.chatbot_state],
|
||||
[self.chatbot, self.saved_input, self.chatbot_state],
|
||||
api_name=False,
|
||||
queue=False,
|
||||
).then(
|
||||
lambda x: x,
|
||||
[self.saved_input],
|
||||
[self.textbox],
|
||||
api_name=False,
|
||||
queue=False,
|
||||
)
|
||||
|
||||
if self.clear_btn:
|
||||
self.clear_btn.click(
|
||||
lambda: ([], [], None),
|
||||
None,
|
||||
[self.chatbot, self.chatbot_state, self.saved_input],
|
||||
queue=False,
|
||||
api_name=False,
|
||||
)
|
||||
|
||||
def _setup_api(self) -> None:
|
||||
api_fn = self._api_stream_fn if self.is_generator else self._api_submit_fn
|
||||
|
||||
self.fake_api_btn.click(
|
||||
api_fn,
|
||||
[self.textbox, self.chatbot_state] + self.additional_inputs,
|
||||
[self.textbox, self.chatbot_state] + self.additional_outputs,
|
||||
api_name="chat",
|
||||
)
|
308
libs/kotaemon/kotaemon/contribs/promptui/ui/chat.py
Normal file
308
libs/kotaemon/kotaemon/contribs/promptui/ui/chat.py
Normal file
@@ -0,0 +1,308 @@
|
||||
import pickle
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import gradio as gr
|
||||
from kotaemon.chatbot import ChatConversation
|
||||
from kotaemon.contribs.promptui.base import get_component
|
||||
from kotaemon.contribs.promptui.export import export
|
||||
from kotaemon.contribs.promptui.ui.blocks import ChatBlock
|
||||
from theflow.storage import storage
|
||||
|
||||
from ..logs import ResultLog
|
||||
|
||||
USAGE_INSTRUCTION = """## How to use:
|
||||
|
||||
1. Set the desired parameters.
|
||||
2. Click "New chat" to start a chat session with the supplied parameters. This
|
||||
set of parameters will persist until the end of the chat session. During an
|
||||
ongoing chat session, changing the parameters will not take any effect.
|
||||
3. Chat and interact with the chat bot on the right panel. You can add any
|
||||
additional input (if any), and they will be supplied to the chatbot.
|
||||
4. During chat, the log of the chat will show up in the "Output" tabs. This is
|
||||
empty by default, so if you want to show the log here, tell the AI developers
|
||||
to configure the UI settings.
|
||||
5. When finishing chat, select your preference in the radio box. Click "End chat".
|
||||
This will save the chat log and the preference to disk.
|
||||
6. To compare the result of different run, click "Export" to get an Excel
|
||||
spreadsheet summary of different run.
|
||||
|
||||
## Support:
|
||||
|
||||
In case of errors, you can:
|
||||
|
||||
- PromptUI instruction:
|
||||
https://github.com/Cinnamon/kotaemon/wiki/Utilities#prompt-engineering-ui
|
||||
- Create bug fix and make PR at: https://github.com/Cinnamon/kotaemon
|
||||
- Ping any of @john @tadashi @ian @jacky in Slack channel #llm-productization
|
||||
|
||||
## Contribute:
|
||||
|
||||
- Follow installation at: https://github.com/Cinnamon/kotaemon/
|
||||
"""
|
||||
|
||||
|
||||
def construct_chat_ui(
|
||||
config, func_new_chat, func_chat, func_end_chat, func_export_to_excel
|
||||
) -> gr.Blocks:
|
||||
"""Construct the prompt engineering UI for chat
|
||||
|
||||
Args:
|
||||
config: the UI config
|
||||
func_new_chat: the function for starting a new chat session
|
||||
func_chat: the function for chatting interaction
|
||||
func_end_chat: the function for ending and saving the chat
|
||||
func_export_to_excel: the function to export the logs to excel
|
||||
|
||||
Returns:
|
||||
the UI object
|
||||
"""
|
||||
inputs, outputs, params = [], [], []
|
||||
for name, component_def in config.get("inputs", {}).items():
|
||||
if "params" not in component_def:
|
||||
component_def["params"] = {}
|
||||
component_def["params"]["interactive"] = True
|
||||
component = get_component(component_def)
|
||||
if hasattr(component, "label") and not component.label: # type: ignore
|
||||
component.label = name # type: ignore
|
||||
|
||||
inputs.append(component)
|
||||
|
||||
for name, component_def in config.get("params", {}).items():
|
||||
if "params" not in component_def:
|
||||
component_def["params"] = {}
|
||||
component_def["params"]["interactive"] = True
|
||||
component = get_component(component_def)
|
||||
if hasattr(component, "label") and not component.label: # type: ignore
|
||||
component.label = name # type: ignore
|
||||
|
||||
params.append(component)
|
||||
|
||||
for idx, component_def in enumerate(config.get("outputs", [])):
|
||||
if "params" not in component_def:
|
||||
component_def["params"] = {}
|
||||
component_def["params"]["interactive"] = False
|
||||
component = get_component(component_def)
|
||||
if hasattr(component, "label") and not component.label: # type: ignore
|
||||
component.label = f"Output {idx}" # type: ignore
|
||||
|
||||
outputs.append(component)
|
||||
|
||||
sess = gr.State(value=None)
|
||||
chatbot = gr.Chatbot(label="Chatbot", show_copy_button=True)
|
||||
chat = ChatBlock(
|
||||
func_chat, chatbot=chatbot, additional_inputs=[sess], additional_outputs=outputs
|
||||
)
|
||||
param_state = gr.Textbox(interactive=False)
|
||||
|
||||
with gr.Blocks(analytics_enabled=False, title="Welcome to PromptUI") as demo:
|
||||
sess.render()
|
||||
with gr.Accordion(label="HOW TO", open=False):
|
||||
gr.Markdown(USAGE_INSTRUCTION)
|
||||
with gr.Row():
|
||||
run_btn = gr.Button("New chat")
|
||||
run_btn.click(
|
||||
func_new_chat,
|
||||
inputs=params,
|
||||
outputs=[
|
||||
chat.chatbot,
|
||||
chat.chatbot_state,
|
||||
chat.saved_input,
|
||||
param_state,
|
||||
sess,
|
||||
*outputs,
|
||||
],
|
||||
)
|
||||
with gr.Accordion(label="End chat", open=False):
|
||||
likes = gr.Radio(["like", "dislike", "neutral"], value="neutral")
|
||||
save_log = gr.Checkbox(
|
||||
value=True,
|
||||
label="Save log",
|
||||
info="If saved, log can be exported later",
|
||||
show_label=True,
|
||||
)
|
||||
end_btn = gr.Button("End chat")
|
||||
end_btn.click(
|
||||
func_end_chat,
|
||||
inputs=[likes, save_log, sess],
|
||||
outputs=[param_state, sess],
|
||||
)
|
||||
with gr.Accordion(label="Export", open=False):
|
||||
exported_file = gr.File(
|
||||
label="Output file", show_label=True, height=100
|
||||
)
|
||||
export_btn = gr.Button("Export")
|
||||
export_btn.click(
|
||||
func_export_to_excel, inputs=None, outputs=exported_file
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
with gr.Tab("Params"):
|
||||
for component in params:
|
||||
component.render()
|
||||
with gr.Accordion(label="Session state", open=False):
|
||||
param_state.render()
|
||||
|
||||
with gr.Tab("Outputs"):
|
||||
for component in outputs:
|
||||
component.render()
|
||||
with gr.Column():
|
||||
chat.render()
|
||||
|
||||
return demo.queue()
|
||||
|
||||
|
||||
def build_chat_ui(config, pipeline_def):
|
||||
"""Build the chat UI
|
||||
|
||||
Args:
|
||||
config: the UI config
|
||||
pipeline_def: the pipeline definition
|
||||
|
||||
Returns:
|
||||
the UI object
|
||||
"""
|
||||
output_dir: Path = Path(storage.url(pipeline_def().config.store_result))
|
||||
exported_dir = output_dir.parent / "exported"
|
||||
exported_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
resultlog = getattr(pipeline_def, "_promptui_resultlog", ResultLog)
|
||||
allowed_resultlog_callbacks = {i for i in dir(resultlog) if not i.startswith("__")}
|
||||
|
||||
def new_chat(*args):
|
||||
"""Start a new chat function
|
||||
|
||||
Args:
|
||||
*args: the pipeline init params
|
||||
|
||||
Returns:
|
||||
new empty states
|
||||
"""
|
||||
gr.Info("Starting new session...")
|
||||
param_dicts = {
|
||||
name: value for name, value in zip(config["params"].keys(), args)
|
||||
}
|
||||
for key in param_dicts.keys():
|
||||
if config["params"][key].get("component").lower() == "file":
|
||||
param_dicts[key] = param_dicts[key].name
|
||||
|
||||
# TODO: currently hard-code as ChatConversation
|
||||
pipeline = pipeline_def()
|
||||
session = ChatConversation(bot=pipeline)
|
||||
session.set(param_dicts)
|
||||
session.start_session()
|
||||
|
||||
param_state_str = "\n".join(
|
||||
f"- {name}: {value}" for name, value in param_dicts.items()
|
||||
)
|
||||
|
||||
gr.Info("New chat session started.")
|
||||
return (
|
||||
[],
|
||||
[],
|
||||
None,
|
||||
param_state_str,
|
||||
session,
|
||||
*[None] * len(config.get("outputs", [])),
|
||||
)
|
||||
|
||||
def chat(message, history, session, *args):
|
||||
"""The chat interface
|
||||
|
||||
# TODO: wrap the input and output of this chat function so that it
|
||||
work with more types of chat conversation than simple text
|
||||
|
||||
Args:
|
||||
message: the message from the user
|
||||
history: the gradio history of the chat
|
||||
session: the chat object session
|
||||
*args: the additional inputs
|
||||
|
||||
Returns:
|
||||
the response from the chatbot
|
||||
"""
|
||||
if session is None:
|
||||
raise gr.Error(
|
||||
"No active chat session. Please set the params and click New chat"
|
||||
)
|
||||
|
||||
pred = session(message)
|
||||
text_response = pred.content
|
||||
|
||||
additional_outputs = []
|
||||
for output_def in config.get("outputs", []):
|
||||
value = session.last_run.logs(output_def["step"])
|
||||
getter = output_def.get("getter", None)
|
||||
if getter and getter in allowed_resultlog_callbacks:
|
||||
value = getattr(resultlog, getter)(value)
|
||||
additional_outputs.append(value)
|
||||
|
||||
return text_response, *additional_outputs
|
||||
|
||||
def end_chat(preference: str, save_log: bool, session):
|
||||
"""End the chat session
|
||||
|
||||
Args:
|
||||
preference: the preference of the user
|
||||
save_log: whether to save the result
|
||||
session: the chat object session
|
||||
|
||||
Returns:
|
||||
the new empty state
|
||||
"""
|
||||
gr.Info("Ending session...")
|
||||
session.end_session()
|
||||
output_dir: Path = (
|
||||
Path(storage.url(session.config.store_result)) / session.last_run.id()
|
||||
)
|
||||
|
||||
if not save_log:
|
||||
if output_dir.exists():
|
||||
import shutil
|
||||
|
||||
shutil.rmtree(output_dir)
|
||||
|
||||
session = None
|
||||
param_state = ""
|
||||
gr.Info("End session without saving log.")
|
||||
return param_state, session
|
||||
|
||||
# add preference result to progress
|
||||
with (output_dir / "progress.pkl").open("rb") as fi:
|
||||
progress = pickle.load(fi)
|
||||
progress["preference"] = preference
|
||||
with (output_dir / "progress.pkl").open("wb") as fo:
|
||||
pickle.dump(progress, fo)
|
||||
|
||||
# get the original params
|
||||
param_dicts = {name: session.getx(name) for name in config["params"].keys()}
|
||||
with (output_dir / "params.pkl").open("wb") as fo:
|
||||
pickle.dump(param_dicts, fo)
|
||||
|
||||
session = None
|
||||
param_state = ""
|
||||
gr.Info("End session and save log.")
|
||||
return param_state, session
|
||||
|
||||
def export_func():
|
||||
name = (
|
||||
f"{pipeline_def.__module__}.{pipeline_def.__name__}_{datetime.now()}.xlsx"
|
||||
)
|
||||
path = str(exported_dir / name)
|
||||
gr.Info(f"Begin exporting {name}...")
|
||||
try:
|
||||
export(config=config, pipeline_def=pipeline_def, output_path=path)
|
||||
except Exception as e:
|
||||
raise gr.Error(f"Failed to export. Please contact project's AIR: {e}")
|
||||
gr.Info(f"Exported {name}. Please go to the `Exported file` tab to download")
|
||||
return path
|
||||
|
||||
demo = construct_chat_ui(
|
||||
config=config,
|
||||
func_new_chat=new_chat,
|
||||
func_chat=chat,
|
||||
func_end_chat=end_chat,
|
||||
func_export_to_excel=export_func,
|
||||
)
|
||||
return demo
|
245
libs/kotaemon/kotaemon/contribs/promptui/ui/pipeline.py
Normal file
245
libs/kotaemon/kotaemon/contribs/promptui/ui/pipeline.py
Normal file
@@ -0,0 +1,245 @@
|
||||
import pickle
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
import gradio as gr
|
||||
import pandas as pd
|
||||
from kotaemon.contribs.promptui.base import get_component
|
||||
from kotaemon.contribs.promptui.export import export
|
||||
from theflow.storage import storage
|
||||
|
||||
from ..logs import ResultLog
|
||||
|
||||
USAGE_INSTRUCTION = """## How to use:
|
||||
|
||||
1. Set the desired parameters.
|
||||
2. Set the desired inputs.
|
||||
3. Click "Run" to execute the pipeline with the supplied parameters and inputs
|
||||
4. The pipeline output will show up in the output panel.
|
||||
5. Repeat from step 1.
|
||||
6. To compare the result of different run, click "Export" to get an Excel
|
||||
spreadsheet summary of different run.
|
||||
|
||||
## Support:
|
||||
|
||||
In case of errors, you can:
|
||||
|
||||
- PromptUI instruction:
|
||||
https://github.com/Cinnamon/kotaemon/wiki/Utilities#prompt-engineering-ui
|
||||
- Create bug fix and make PR at: https://github.com/Cinnamon/kotaemon
|
||||
- Ping any of @john @tadashi @ian @jacky in Slack channel #llm-productization
|
||||
|
||||
## Contribute:
|
||||
|
||||
- Follow installation at: https://github.com/Cinnamon/kotaemon/
|
||||
"""
|
||||
|
||||
|
||||
def construct_pipeline_ui(
|
||||
config, func_run, func_save, func_load_params, func_activate_params, func_export
|
||||
) -> gr.Blocks:
|
||||
"""Create UI from config file. Execute the UI from config file
|
||||
|
||||
- Can do now: Log from stdout to UI
|
||||
- In the future, we can provide some hooks and callbacks to let developers better
|
||||
fine-tune the UI behavior.
|
||||
"""
|
||||
inputs, outputs, params = [], [], []
|
||||
for name, component_def in config.get("inputs", {}).items():
|
||||
if "params" not in component_def:
|
||||
component_def["params"] = {}
|
||||
component_def["params"]["interactive"] = True
|
||||
component = get_component(component_def)
|
||||
if hasattr(component, "label") and not component.label: # type: ignore
|
||||
component.label = name # type: ignore
|
||||
|
||||
inputs.append(component)
|
||||
|
||||
for name, component_def in config.get("params", {}).items():
|
||||
if "params" not in component_def:
|
||||
component_def["params"] = {}
|
||||
component_def["params"]["interactive"] = True
|
||||
component = get_component(component_def)
|
||||
if hasattr(component, "label") and not component.label: # type: ignore
|
||||
component.label = name # type: ignore
|
||||
|
||||
params.append(component)
|
||||
|
||||
for idx, component_def in enumerate(config.get("outputs", [])):
|
||||
if "params" not in component_def:
|
||||
component_def["params"] = {}
|
||||
component_def["params"]["interactive"] = False
|
||||
component = get_component(component_def)
|
||||
if hasattr(component, "label") and not component.label: # type: ignore
|
||||
component.label = f"Output {idx}" # type: ignore
|
||||
|
||||
outputs.append(component)
|
||||
|
||||
exported_file = gr.File(label="Output file", show_label=True)
|
||||
history_dataframe = gr.DataFrame(wrap=True)
|
||||
|
||||
temp = gr.Tab
|
||||
with gr.Blocks(analytics_enabled=False, title="Welcome to PromptUI") as demo:
|
||||
with gr.Accordion(label="HOW TO", open=False):
|
||||
gr.Markdown(USAGE_INSTRUCTION)
|
||||
with gr.Accordion(label="Params History", open=False):
|
||||
with gr.Row():
|
||||
save_btn = gr.Button("Save params")
|
||||
save_btn.click(func_save, inputs=params, outputs=history_dataframe)
|
||||
load_params_btn = gr.Button("Reload params")
|
||||
load_params_btn.click(
|
||||
func_load_params, inputs=None, outputs=history_dataframe
|
||||
)
|
||||
history_dataframe.render()
|
||||
history_dataframe.select(
|
||||
func_activate_params, inputs=params, outputs=params
|
||||
)
|
||||
with gr.Row():
|
||||
run_btn = gr.Button("Run")
|
||||
run_btn.click(func_run, inputs=inputs + params, outputs=outputs)
|
||||
export_btn = gr.Button(
|
||||
"Export (Result will be in Exported file next to Output)"
|
||||
)
|
||||
export_btn.click(func_export, inputs=None, outputs=exported_file)
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
if params:
|
||||
with temp("Params"):
|
||||
for component in params:
|
||||
component.render()
|
||||
if inputs:
|
||||
with temp("Inputs"):
|
||||
for component in inputs:
|
||||
component.render()
|
||||
if not params and not inputs:
|
||||
gr.Text("No params or inputs")
|
||||
with gr.Column():
|
||||
with temp("Outputs"):
|
||||
for component in outputs:
|
||||
component.render()
|
||||
with temp("Exported file"):
|
||||
exported_file.render()
|
||||
|
||||
return demo
|
||||
|
||||
|
||||
def load_saved_params(path: str) -> Dict:
|
||||
"""Load the saved params from path to a dataframe"""
|
||||
# get all pickle files
|
||||
files = list(sorted(Path(path).glob("*.pkl")))
|
||||
data: Dict[str, Any] = {"_id": [None] * len(files)}
|
||||
for idx, each_file in enumerate(files):
|
||||
with open(each_file, "rb") as f:
|
||||
each_data = pickle.load(f)
|
||||
data["_id"][idx] = Path(each_file).stem
|
||||
for key, value in each_data.items():
|
||||
if key not in data:
|
||||
data[key] = [None] * len(files)
|
||||
data[key][idx] = value
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def build_pipeline_ui(config: dict, pipeline_def):
|
||||
"""Build a tab from config file"""
|
||||
inputs_name = list(config.get("inputs", {}).keys())
|
||||
params_name = list(config.get("params", {}).keys())
|
||||
outputs_def = config.get("outputs", [])
|
||||
|
||||
output_dir: Path = Path(storage.url(pipeline_def().config.store_result))
|
||||
exported_dir = output_dir.parent / "exported"
|
||||
exported_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
save_dir = (
|
||||
output_dir.parent
|
||||
/ "saved"
|
||||
/ f"{pipeline_def.__module__}.{pipeline_def.__name__}"
|
||||
)
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
resultlog = getattr(pipeline_def, "_promptui_resultlog", ResultLog)
|
||||
allowed_resultlog_callbacks = {i for i in dir(resultlog) if not i.startswith("__")}
|
||||
|
||||
def run_func(*args):
|
||||
inputs = {
|
||||
name: value for name, value in zip(inputs_name, args[: len(inputs_name)])
|
||||
}
|
||||
params = {
|
||||
name: value for name, value in zip(params_name, args[len(inputs_name) :])
|
||||
}
|
||||
pipeline = pipeline_def()
|
||||
pipeline.set(params)
|
||||
pipeline(**inputs)
|
||||
with storage.open(
|
||||
storage.url(
|
||||
pipeline.config.store_result, pipeline.last_run.id(), "params.pkl"
|
||||
),
|
||||
"wb",
|
||||
) as f:
|
||||
pickle.dump(params, f)
|
||||
if outputs_def:
|
||||
outputs = []
|
||||
for output_def in outputs_def:
|
||||
output = pipeline.last_run.logs(output_def["step"])
|
||||
getter = output_def.get("getter", None)
|
||||
if getter and getter in allowed_resultlog_callbacks:
|
||||
output = getattr(resultlog, getter)(output)
|
||||
outputs.append(output)
|
||||
if len(outputs_def) == 1:
|
||||
return outputs[0]
|
||||
return outputs
|
||||
|
||||
def save_func(*args):
|
||||
params = {name: value for name, value in zip(params_name, args)}
|
||||
filename = save_dir / f"{int(time.time())}.pkl"
|
||||
with open(filename, "wb") as f:
|
||||
pickle.dump(params, f)
|
||||
gr.Info("Params saved")
|
||||
|
||||
data = load_saved_params(str(save_dir))
|
||||
return pd.DataFrame(data)
|
||||
|
||||
def load_params_func():
|
||||
data = load_saved_params(str(save_dir))
|
||||
return pd.DataFrame(data)
|
||||
|
||||
def activate_params_func(ev: gr.SelectData, *args):
|
||||
data = load_saved_params(str(save_dir))
|
||||
output_args = [each for each in args]
|
||||
if ev.value is None:
|
||||
gr.Info(f'Blank value: "{ev.value}". Skip')
|
||||
return output_args
|
||||
|
||||
column = list(data.keys())[ev.index[1]]
|
||||
|
||||
if column not in params_name:
|
||||
gr.Info(f'Column "{column}" not in params. Skip')
|
||||
return output_args
|
||||
|
||||
value = data[column][ev.index[0]]
|
||||
if value is None:
|
||||
gr.Info(f'Blank value: "{ev.value}". Skip')
|
||||
return output_args
|
||||
|
||||
output_args[params_name.index(column)] = value
|
||||
|
||||
return output_args
|
||||
|
||||
def export_func():
|
||||
name = (
|
||||
f"{pipeline_def.__module__}.{pipeline_def.__name__}_{datetime.now()}.xlsx"
|
||||
)
|
||||
path = str(exported_dir / name)
|
||||
gr.Info(f"Begin exporting {name}...")
|
||||
try:
|
||||
export(config=config, pipeline_def=pipeline_def, output_path=path)
|
||||
except Exception as e:
|
||||
raise gr.Error(f"Failed to export. Please contact project's AIR: {e}")
|
||||
gr.Info(f"Exported {name}. Please go to the `Exported file` tab to download")
|
||||
return path
|
||||
|
||||
return construct_pipeline_ui(
|
||||
config, run_func, save_func, load_params_func, activate_params_func, export_func
|
||||
)
|
15
libs/kotaemon/kotaemon/embeddings/__init__.py
Normal file
15
libs/kotaemon/kotaemon/embeddings/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from .base import BaseEmbeddings
|
||||
from .langchain_based import (
|
||||
AzureOpenAIEmbeddings,
|
||||
CohereEmbdeddings,
|
||||
HuggingFaceEmbeddings,
|
||||
OpenAIEmbeddings,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BaseEmbeddings",
|
||||
"OpenAIEmbeddings",
|
||||
"AzureOpenAIEmbeddings",
|
||||
"CohereEmbdeddings",
|
||||
"HuggingFaceEmbeddings",
|
||||
]
|
13
libs/kotaemon/kotaemon/embeddings/base.py
Normal file
13
libs/kotaemon/kotaemon/embeddings/base.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
|
||||
from kotaemon.base import BaseComponent, Document, DocumentWithEmbedding
|
||||
|
||||
|
||||
class BaseEmbeddings(BaseComponent):
|
||||
@abstractmethod
|
||||
def run(
|
||||
self, text: str | list[str] | Document | list[Document]
|
||||
) -> list[DocumentWithEmbedding]:
|
||||
...
|
209
libs/kotaemon/kotaemon/embeddings/langchain_based.py
Normal file
209
libs/kotaemon/kotaemon/embeddings/langchain_based.py
Normal file
@@ -0,0 +1,209 @@
|
||||
from typing import Optional
|
||||
|
||||
from kotaemon.base import Document, DocumentWithEmbedding
|
||||
|
||||
from .base import BaseEmbeddings
|
||||
|
||||
|
||||
class LCEmbeddingMixin:
|
||||
def _get_lc_class(self):
|
||||
raise NotImplementedError(
|
||||
"Please return the relevant Langchain class in in _get_lc_class"
|
||||
)
|
||||
|
||||
def __init__(self, **params):
|
||||
self._lc_class = self._get_lc_class()
|
||||
self._obj = self._lc_class(**params)
|
||||
self._kwargs: dict = params
|
||||
|
||||
super().__init__()
|
||||
|
||||
def run(self, text):
|
||||
input_: list[str] = []
|
||||
if not isinstance(text, list):
|
||||
text = [text]
|
||||
|
||||
for item in text:
|
||||
if isinstance(item, str):
|
||||
input_.append(item)
|
||||
elif isinstance(item, Document):
|
||||
input_.append(item.text)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid input type {type(item)}, should be str or Document"
|
||||
)
|
||||
|
||||
embeddings = self._obj.embed_documents(input_)
|
||||
|
||||
return [
|
||||
DocumentWithEmbedding(text=each_text, embedding=each_embedding)
|
||||
for each_text, each_embedding in zip(input_, embeddings)
|
||||
]
|
||||
|
||||
def __repr__(self):
|
||||
kwargs = []
|
||||
for key, value_obj in self._kwargs.items():
|
||||
value = repr(value_obj)
|
||||
kwargs.append(f"{key}={value}")
|
||||
kwargs_repr = ", ".join(kwargs)
|
||||
return f"{self.__class__.__name__}({kwargs_repr})"
|
||||
|
||||
def __str__(self):
|
||||
kwargs = []
|
||||
for key, value_obj in self._kwargs.items():
|
||||
value = str(value_obj)
|
||||
if len(value) > 20:
|
||||
value = f"{value[:15]}..."
|
||||
kwargs.append(f"{key}={value}")
|
||||
kwargs_repr = ", ".join(kwargs)
|
||||
return f"{self.__class__.__name__}({kwargs_repr})"
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
if name == "_lc_class":
|
||||
return super().__setattr__(name, value)
|
||||
|
||||
if name in self._lc_class.__fields__:
|
||||
self._kwargs[name] = value
|
||||
self._obj = self._lc_class(**self._kwargs)
|
||||
else:
|
||||
super().__setattr__(name, value)
|
||||
|
||||
def __getattr__(self, name):
|
||||
if name in self._kwargs:
|
||||
return self._kwargs[name]
|
||||
return getattr(self._obj, name)
|
||||
|
||||
def dump(self, *args, **kwargs):
|
||||
from theflow.utils.modules import serialize
|
||||
|
||||
params = {key: serialize(value) for key, value in self._kwargs.items()}
|
||||
return {
|
||||
"__type__": f"{self.__module__}.{self.__class__.__qualname__}",
|
||||
**params,
|
||||
}
|
||||
|
||||
def specs(self, path: str):
|
||||
path = path.strip(".")
|
||||
if "." in path:
|
||||
raise ValueError("path should not contain '.'")
|
||||
|
||||
if path in self._lc_class.__fields__:
|
||||
return {
|
||||
"__type__": "theflow.base.ParamAttr",
|
||||
"refresh_on_set": True,
|
||||
"strict_type": True,
|
||||
}
|
||||
|
||||
raise ValueError(f"Invalid param {path}")
|
||||
|
||||
|
||||
class OpenAIEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
|
||||
"""Wrapper around Langchain's OpenAI embedding, focusing on key parameters"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "text-embedding-ada-002",
|
||||
openai_api_version: Optional[str] = None,
|
||||
openai_api_base: Optional[str] = None,
|
||||
openai_api_type: Optional[str] = None,
|
||||
openai_api_key: Optional[str] = None,
|
||||
request_timeout: Optional[float] = None,
|
||||
**params,
|
||||
):
|
||||
super().__init__(
|
||||
model=model,
|
||||
openai_api_version=openai_api_version,
|
||||
openai_api_base=openai_api_base,
|
||||
openai_api_type=openai_api_type,
|
||||
openai_api_key=openai_api_key,
|
||||
request_timeout=request_timeout,
|
||||
**params,
|
||||
)
|
||||
|
||||
def _get_lc_class(self):
|
||||
try:
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
except ImportError:
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
|
||||
return OpenAIEmbeddings
|
||||
|
||||
|
||||
class AzureOpenAIEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
|
||||
"""Wrapper around Langchain's AzureOpenAI embedding, focusing on key parameters"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
azure_endpoint: Optional[str] = None,
|
||||
deployment: Optional[str] = None,
|
||||
openai_api_key: Optional[str] = None,
|
||||
openai_api_version: Optional[str] = None,
|
||||
request_timeout: Optional[float] = None,
|
||||
**params,
|
||||
):
|
||||
super().__init__(
|
||||
azure_endpoint=azure_endpoint,
|
||||
deployment=deployment,
|
||||
openai_api_version=openai_api_version,
|
||||
openai_api_key=openai_api_key,
|
||||
request_timeout=request_timeout,
|
||||
**params,
|
||||
)
|
||||
|
||||
def _get_lc_class(self):
|
||||
try:
|
||||
from langchain_openai import AzureOpenAIEmbeddings
|
||||
except ImportError:
|
||||
from langchain.embeddings import AzureOpenAIEmbeddings
|
||||
|
||||
return AzureOpenAIEmbeddings
|
||||
|
||||
|
||||
class CohereEmbdeddings(LCEmbeddingMixin, BaseEmbeddings):
|
||||
"""Wrapper around Langchain's Cohere embedding, focusing on key parameters"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "embed-english-v2.0",
|
||||
cohere_api_key: Optional[str] = None,
|
||||
truncate: Optional[str] = None,
|
||||
request_timeout: Optional[float] = None,
|
||||
**params,
|
||||
):
|
||||
super().__init__(
|
||||
model=model,
|
||||
cohere_api_key=cohere_api_key,
|
||||
truncate=truncate,
|
||||
request_timeout=request_timeout,
|
||||
**params,
|
||||
)
|
||||
|
||||
def _get_lc_class(self):
|
||||
try:
|
||||
from langchain_community.embeddings import CohereEmbeddings
|
||||
except ImportError:
|
||||
from langchain.embeddings import CohereEmbeddings
|
||||
|
||||
return CohereEmbeddings
|
||||
|
||||
|
||||
class HuggingFaceEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
|
||||
"""Wrapper around Langchain's HuggingFace embedding, focusing on key parameters"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = "sentence-transformers/all-mpnet-base-v2",
|
||||
**params,
|
||||
):
|
||||
super().__init__(
|
||||
model_name=model_name,
|
||||
**params,
|
||||
)
|
||||
|
||||
def _get_lc_class(self):
|
||||
try:
|
||||
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
|
||||
except ImportError:
|
||||
from langchain.embeddings import HuggingFaceBgeEmbeddings
|
||||
|
||||
return HuggingFaceBgeEmbeddings
|
3
libs/kotaemon/kotaemon/indices/__init__.py
Normal file
3
libs/kotaemon/kotaemon/indices/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .vectorindex import VectorIndexing, VectorRetrieval
|
||||
|
||||
__all__ = ["VectorIndexing", "VectorRetrieval"]
|
122
libs/kotaemon/kotaemon/indices/base.py
Normal file
122
libs/kotaemon/kotaemon/indices/base.py
Normal file
@@ -0,0 +1,122 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Type
|
||||
|
||||
from kotaemon.base import BaseComponent, Document, RetrievedDocument
|
||||
from llama_index.node_parser.interface import NodeParser
|
||||
|
||||
|
||||
class DocTransformer(BaseComponent):
|
||||
"""This is a base class for document transformers
|
||||
|
||||
A document transformer transforms a list of documents into another list
|
||||
of documents. Transforming can mean splitting a document into multiple documents,
|
||||
reducing a large list of documents into a smaller list of documents, or adding
|
||||
metadata to each document in a list of documents, etc.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def run(
|
||||
self,
|
||||
documents: list[Document],
|
||||
**kwargs,
|
||||
) -> list[Document]:
|
||||
...
|
||||
|
||||
|
||||
class LlamaIndexDocTransformerMixin:
|
||||
"""Allow automatically wrapping a Llama-index component into kotaemon component
|
||||
|
||||
Example:
|
||||
class TokenSplitter(LlamaIndexMixin, BaseSplitter):
|
||||
def _get_li_class(self):
|
||||
from llama_index.text_splitter import TokenTextSplitter
|
||||
return TokenTextSplitter
|
||||
|
||||
To use this mixin, please:
|
||||
1. Use this class as the 1st parent class, so that Python will prefer to use
|
||||
the attributes and methods of this class whenever possible.
|
||||
2. Overwrite `_get_li_class` to return the relevant LlamaIndex component.
|
||||
"""
|
||||
|
||||
def _get_li_class(self) -> Type[NodeParser]:
|
||||
raise NotImplementedError(
|
||||
"Please return the relevant LlamaIndex class in _get_li_class"
|
||||
)
|
||||
|
||||
def __init__(self, **params):
|
||||
self._li_cls = self._get_li_class()
|
||||
self._obj = self._li_cls(**params)
|
||||
self._kwargs = params
|
||||
super().__init__()
|
||||
|
||||
def __repr__(self):
|
||||
kwargs = []
|
||||
for key, value_obj in self._kwargs.items():
|
||||
value = repr(value_obj)
|
||||
kwargs.append(f"{key}={value}")
|
||||
kwargs_repr = ", ".join(kwargs)
|
||||
return f"{self.__class__.__name__}({kwargs_repr})"
|
||||
|
||||
def __str__(self):
|
||||
kwargs = []
|
||||
for key, value_obj in self._kwargs.items():
|
||||
value = str(value_obj)
|
||||
if len(value) > 20:
|
||||
value = f"{value[:15]}..."
|
||||
kwargs.append(f"{key}={value}")
|
||||
kwargs_repr = ", ".join(kwargs)
|
||||
return f"{self.__class__.__name__}({kwargs_repr})"
|
||||
|
||||
def __setattr__(self, name: str, value: Any) -> None:
|
||||
if name.startswith("_") or name in self._protected_keywords():
|
||||
return super().__setattr__(name, value)
|
||||
|
||||
self._kwargs[name] = value
|
||||
return setattr(self._obj, name, value)
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
if name in self._kwargs:
|
||||
return self._kwargs[name]
|
||||
return getattr(self._obj, name)
|
||||
|
||||
def dump(self, *args, **kwargs):
|
||||
from theflow.utils.modules import serialize
|
||||
|
||||
params = {key: serialize(value) for key, value in self._kwargs.items()}
|
||||
return {
|
||||
"__type__": f"{self.__module__}.{self.__class__.__qualname__}",
|
||||
**params,
|
||||
}
|
||||
|
||||
def run(
|
||||
self,
|
||||
documents: list[Document],
|
||||
**kwargs,
|
||||
) -> list[Document]:
|
||||
"""Run Llama-index node parser and convert the output to Document from
|
||||
kotaemon
|
||||
"""
|
||||
docs = self._obj(documents, **kwargs) # type: ignore
|
||||
return [Document.from_dict(doc.to_dict()) for doc in docs]
|
||||
|
||||
|
||||
class BaseIndexing(BaseComponent):
|
||||
"""Define the base interface for indexing pipeline"""
|
||||
|
||||
def to_retrieval_pipeline(self, **kwargs):
|
||||
"""Convert the indexing pipeline to a retrieval pipeline"""
|
||||
raise NotImplementedError
|
||||
|
||||
def to_qa_pipeline(self, **kwargs):
|
||||
"""Convert the indexing pipeline to a QA pipeline"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class BaseRetrieval(BaseComponent):
|
||||
"""Define the base interface for retrieval pipeline"""
|
||||
|
||||
@abstractmethod
|
||||
def run(self, *args, **kwargs) -> list[RetrievedDocument]:
|
||||
...
|
7
libs/kotaemon/kotaemon/indices/extractors/__init__.py
Normal file
7
libs/kotaemon/kotaemon/indices/extractors/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from .doc_parsers import BaseDocParser, SummaryExtractor, TitleExtractor
|
||||
|
||||
__all__ = [
|
||||
"BaseDocParser",
|
||||
"TitleExtractor",
|
||||
"SummaryExtractor",
|
||||
]
|
35
libs/kotaemon/kotaemon/indices/extractors/doc_parsers.py
Normal file
35
libs/kotaemon/kotaemon/indices/extractors/doc_parsers.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from ..base import DocTransformer, LlamaIndexDocTransformerMixin
|
||||
|
||||
|
||||
class BaseDocParser(DocTransformer):
|
||||
...
|
||||
|
||||
|
||||
class TitleExtractor(LlamaIndexDocTransformerMixin, BaseDocParser):
|
||||
def __init__(
|
||||
self,
|
||||
llm=None,
|
||||
nodes: int = 5,
|
||||
**params,
|
||||
):
|
||||
super().__init__(llm=llm, nodes=nodes, **params)
|
||||
|
||||
def _get_li_class(self):
|
||||
from llama_index.extractors import TitleExtractor
|
||||
|
||||
return TitleExtractor
|
||||
|
||||
|
||||
class SummaryExtractor(LlamaIndexDocTransformerMixin, BaseDocParser):
|
||||
def __init__(
|
||||
self,
|
||||
llm=None,
|
||||
summaries: list[str] = ["self"],
|
||||
**params,
|
||||
):
|
||||
super().__init__(llm=llm, summaries=summaries, **params)
|
||||
|
||||
def _get_li_class(self):
|
||||
from llama_index.extractors import SummaryExtractor
|
||||
|
||||
return SummaryExtractor
|
3
libs/kotaemon/kotaemon/indices/ingests/__init__.py
Normal file
3
libs/kotaemon/kotaemon/indices/ingests/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .files import DocumentIngestor
|
||||
|
||||
__all__ = ["DocumentIngestor"]
|
85
libs/kotaemon/kotaemon/indices/ingests/files.py
Normal file
85
libs/kotaemon/kotaemon/indices/ingests/files.py
Normal file
@@ -0,0 +1,85 @@
|
||||
from pathlib import Path
|
||||
|
||||
from kotaemon.base import BaseComponent, Document, Param
|
||||
from kotaemon.indices.extractors import BaseDocParser
|
||||
from kotaemon.indices.splitters import BaseSplitter, TokenSplitter
|
||||
from kotaemon.loaders import (
|
||||
AutoReader,
|
||||
DirectoryReader,
|
||||
MathpixPDFReader,
|
||||
OCRReader,
|
||||
PandasExcelReader,
|
||||
UnstructuredReader,
|
||||
)
|
||||
from llama_index.readers.base import BaseReader
|
||||
|
||||
|
||||
class DocumentIngestor(BaseComponent):
|
||||
"""Ingest common office document types into Document for indexing
|
||||
|
||||
Document types:
|
||||
- pdf
|
||||
- xlsx, xls
|
||||
- docx, doc
|
||||
|
||||
Args:
|
||||
pdf_mode: mode for pdf extraction, one of "normal", "mathpix", "ocr"
|
||||
- normal: parse pdf text
|
||||
- mathpix: parse pdf text using mathpix
|
||||
- ocr: parse pdf image using flax
|
||||
doc_parsers: list of document parsers to parse the document
|
||||
text_splitter: splitter to split the document into text nodes
|
||||
"""
|
||||
|
||||
pdf_mode: str = "normal" # "normal", "mathpix", "ocr"
|
||||
doc_parsers: list[BaseDocParser] = Param(default_callback=lambda _: [])
|
||||
text_splitter: BaseSplitter = TokenSplitter.withx(
|
||||
chunk_size=1024,
|
||||
chunk_overlap=256,
|
||||
)
|
||||
|
||||
def _get_reader(self, input_files: list[str | Path]):
|
||||
"""Get appropriate readers for the input files based on file extension"""
|
||||
file_extractor: dict[str, AutoReader | BaseReader] = {
|
||||
".xlsx": PandasExcelReader(),
|
||||
".docx": UnstructuredReader(),
|
||||
".xls": UnstructuredReader(),
|
||||
".doc": UnstructuredReader(),
|
||||
}
|
||||
|
||||
if self.pdf_mode == "normal":
|
||||
file_extractor[".pdf"] = AutoReader("UnstructuredReader")
|
||||
elif self.pdf_mode == "ocr":
|
||||
file_extractor[".pdf"] = OCRReader()
|
||||
else:
|
||||
file_extractor[".pdf"] = MathpixPDFReader()
|
||||
|
||||
main_reader = DirectoryReader(
|
||||
input_files=input_files,
|
||||
file_extractor=file_extractor,
|
||||
)
|
||||
|
||||
return main_reader
|
||||
|
||||
def run(self, file_paths: list[str | Path] | str | Path) -> list[Document]:
|
||||
"""Ingest the file paths into Document
|
||||
|
||||
Args:
|
||||
file_paths: list of file paths or a single file path
|
||||
|
||||
Returns:
|
||||
list of parsed Documents
|
||||
"""
|
||||
if not isinstance(file_paths, list):
|
||||
file_paths = [file_paths]
|
||||
|
||||
documents = self._get_reader(input_files=file_paths)()
|
||||
nodes = self.text_splitter(documents)
|
||||
self.log_progress(".num_docs", num_docs=len(nodes))
|
||||
|
||||
# document parsers call
|
||||
if self.doc_parsers:
|
||||
for parser in self.doc_parsers:
|
||||
nodes = parser(nodes)
|
||||
|
||||
return nodes
|
7
libs/kotaemon/kotaemon/indices/qa/__init__.py
Normal file
7
libs/kotaemon/kotaemon/indices/qa/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from .citation import CitationPipeline
|
||||
from .text_based import CitationQAPipeline
|
||||
|
||||
__all__ = [
|
||||
"CitationPipeline",
|
||||
"CitationQAPipeline",
|
||||
]
|
101
libs/kotaemon/kotaemon/indices/qa/citation.py
Normal file
101
libs/kotaemon/kotaemon/indices/qa/citation.py
Normal file
@@ -0,0 +1,101 @@
|
||||
from typing import Iterator, List
|
||||
|
||||
from kotaemon.base import BaseComponent
|
||||
from kotaemon.base.schema import HumanMessage, SystemMessage
|
||||
from kotaemon.llms import BaseLLM
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class FactWithEvidence(BaseModel):
|
||||
"""Class representing a single statement.
|
||||
|
||||
Each fact has a body and a list of sources.
|
||||
If there are multiple facts make sure to break them apart
|
||||
such that each one only uses a set of sources that are relevant to it.
|
||||
"""
|
||||
|
||||
fact: str = Field(..., description="Body of the sentence, as part of a response")
|
||||
substring_quote: List[str] = Field(
|
||||
...,
|
||||
description=(
|
||||
"Each source should be a direct quote from the context, "
|
||||
"as a substring of the original content"
|
||||
),
|
||||
)
|
||||
|
||||
def _get_span(self, quote: str, context: str, errs: int = 100) -> Iterator[str]:
|
||||
import regex
|
||||
|
||||
minor = quote
|
||||
major = context
|
||||
|
||||
errs_ = 0
|
||||
s = regex.search(f"({minor}){{e<={errs_}}}", major)
|
||||
while s is None and errs_ <= errs:
|
||||
errs_ += 1
|
||||
s = regex.search(f"({minor}){{e<={errs_}}}", major)
|
||||
|
||||
if s is not None:
|
||||
yield from s.spans()
|
||||
|
||||
def get_spans(self, context: str) -> Iterator[str]:
|
||||
for quote in self.substring_quote:
|
||||
yield from self._get_span(quote, context)
|
||||
|
||||
|
||||
class QuestionAnswer(BaseModel):
|
||||
"""A question and its answer as a list of facts each one should have a source.
|
||||
each sentence contains a body and a list of sources."""
|
||||
|
||||
question: str = Field(..., description="Question that was asked")
|
||||
answer: List[FactWithEvidence] = Field(
|
||||
...,
|
||||
description=(
|
||||
"Body of the answer, each fact should be "
|
||||
"its separate object with a body and a list of sources"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class CitationPipeline(BaseComponent):
|
||||
"""Citation pipeline to extract cited evidences from source
|
||||
(based on input question)"""
|
||||
|
||||
llm: BaseLLM
|
||||
|
||||
def run(self, context: str, question: str):
|
||||
schema = QuestionAnswer.schema()
|
||||
function = {
|
||||
"name": schema["title"],
|
||||
"description": schema["description"],
|
||||
"parameters": schema,
|
||||
}
|
||||
llm_kwargs = {
|
||||
"functions": [function],
|
||||
"function_call": {"name": function["name"]},
|
||||
}
|
||||
messages = [
|
||||
SystemMessage(
|
||||
content=(
|
||||
"You are a world class algorithm to answer "
|
||||
"questions with correct and exact citations."
|
||||
)
|
||||
),
|
||||
HumanMessage(content="Answer question using the following context"),
|
||||
HumanMessage(content=context),
|
||||
HumanMessage(content=f"Question: {question}"),
|
||||
HumanMessage(
|
||||
content=(
|
||||
"Tips: Make sure to cite your sources, "
|
||||
"and use the exact words from the context."
|
||||
)
|
||||
),
|
||||
]
|
||||
|
||||
llm_output = self.llm(messages, **llm_kwargs)
|
||||
function_output = llm_output.messages[0].additional_kwargs["function_call"][
|
||||
"arguments"
|
||||
]
|
||||
output = QuestionAnswer.parse_raw(function_output)
|
||||
|
||||
return output
|
63
libs/kotaemon/kotaemon/indices/qa/text_based.py
Normal file
63
libs/kotaemon/kotaemon/indices/qa/text_based.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import os
|
||||
|
||||
from kotaemon.base import BaseComponent, Document, Node, RetrievedDocument
|
||||
from kotaemon.llms import AzureChatOpenAI, BaseLLM, PromptTemplate
|
||||
|
||||
from .citation import CitationPipeline
|
||||
|
||||
|
||||
class CitationQAPipeline(BaseComponent):
|
||||
"""Answering question from a text corpus with citation"""
|
||||
|
||||
qa_prompt_template: PromptTemplate = PromptTemplate(
|
||||
'Answer the following question: "{question}". '
|
||||
"The context is: \n{context}\nAnswer: "
|
||||
)
|
||||
llm: BaseLLM = AzureChatOpenAI.withx(
|
||||
azure_endpoint="https://bleh-dummy.openai.azure.com/",
|
||||
openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
|
||||
openai_api_version="2023-07-01-preview",
|
||||
deployment_name="dummy-q2-16k",
|
||||
temperature=0,
|
||||
request_timeout=60,
|
||||
)
|
||||
citation_pipeline: CitationPipeline = Node(
|
||||
default_callback=lambda self: CitationPipeline(llm=self.llm)
|
||||
)
|
||||
|
||||
def _format_doc_text(self, text: str) -> str:
|
||||
"""Format the text of each document"""
|
||||
return text.replace("\n", " ")
|
||||
|
||||
def _format_retrieved_context(self, documents: list[RetrievedDocument]) -> str:
|
||||
"""Format the texts between all documents"""
|
||||
matched_texts: list[str] = [
|
||||
self._format_doc_text(doc.text) for doc in documents
|
||||
]
|
||||
return "\n\n".join(matched_texts)
|
||||
|
||||
def run(
|
||||
self,
|
||||
question: str,
|
||||
documents: list[RetrievedDocument],
|
||||
use_citation: bool = False,
|
||||
**kwargs
|
||||
) -> Document:
|
||||
# retrieve relevant documents as context
|
||||
context = self._format_retrieved_context(documents)
|
||||
self.log_progress(".context", context=context)
|
||||
|
||||
# generate the answer
|
||||
prompt = self.qa_prompt_template.populate(
|
||||
context=context,
|
||||
question=question,
|
||||
)
|
||||
self.log_progress(".prompt", prompt=prompt)
|
||||
answer_text = self.llm(prompt).text
|
||||
if use_citation:
|
||||
citation = self.citation_pipeline(context=context, question=question)
|
||||
else:
|
||||
citation = None
|
||||
|
||||
answer = Document(text=answer_text, metadata={"citation": citation})
|
||||
return answer
|
5
libs/kotaemon/kotaemon/indices/rankings/__init__.py
Normal file
5
libs/kotaemon/kotaemon/indices/rankings/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from .base import BaseReranking
|
||||
from .cohere import CohereReranking
|
||||
from .llm import LLMReranking
|
||||
|
||||
__all__ = ["CohereReranking", "LLMReranking", "BaseReranking"]
|
13
libs/kotaemon/kotaemon/indices/rankings/base.py
Normal file
13
libs/kotaemon/kotaemon/indices/rankings/base.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
|
||||
from kotaemon.base import BaseComponent, Document
|
||||
|
||||
|
||||
class BaseReranking(BaseComponent):
|
||||
@abstractmethod
|
||||
def run(self, documents: list[Document], query: str) -> list[Document]:
|
||||
"""Main method to transform list of documents
|
||||
(re-ranking, filtering, etc)"""
|
||||
...
|
40
libs/kotaemon/kotaemon/indices/rankings/cohere.py
Normal file
40
libs/kotaemon/kotaemon/indices/rankings/cohere.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
from kotaemon.base import Document
|
||||
|
||||
from .base import BaseReranking
|
||||
|
||||
|
||||
class CohereReranking(BaseReranking):
|
||||
model_name: str = "rerank-multilingual-v2.0"
|
||||
cohere_api_key: str = os.environ.get("COHERE_API_KEY", "")
|
||||
top_k: int = 1
|
||||
|
||||
def run(self, documents: list[Document], query: str) -> list[Document]:
|
||||
"""Use Cohere Reranker model to re-order documents
|
||||
with their relevance score"""
|
||||
try:
|
||||
import cohere
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Please install Cohere " "`pip install cohere` to use Cohere Reranking"
|
||||
)
|
||||
|
||||
cohere_client = cohere.Client(self.cohere_api_key)
|
||||
compressed_docs: list[Document] = []
|
||||
|
||||
if not documents: # to avoid empty api call
|
||||
return compressed_docs
|
||||
|
||||
_docs = [d.content for d in documents]
|
||||
results = cohere_client.rerank(
|
||||
model=self.model_name, query=query, documents=_docs, top_n=self.top_k
|
||||
)
|
||||
for r in results:
|
||||
doc = documents[r.index]
|
||||
doc.metadata["relevance_score"] = r.relevance_score
|
||||
compressed_docs.append(doc)
|
||||
|
||||
return compressed_docs
|
65
libs/kotaemon/kotaemon/indices/rankings/llm.py
Normal file
65
libs/kotaemon/kotaemon/indices/rankings/llm.py
Normal file
@@ -0,0 +1,65 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from kotaemon.base import Document
|
||||
from kotaemon.llms import BaseLLM, PromptTemplate
|
||||
from langchain.output_parsers.boolean import BooleanOutputParser
|
||||
|
||||
from .base import BaseReranking
|
||||
|
||||
RERANK_PROMPT_TEMPLATE = """Given the following question and context,
|
||||
return YES if the context is relevant to the question and NO if it isn't.
|
||||
|
||||
> Question: {question}
|
||||
> Context:
|
||||
>>>
|
||||
{context}
|
||||
>>>
|
||||
> Relevant (YES / NO):"""
|
||||
|
||||
|
||||
class LLMReranking(BaseReranking):
|
||||
llm: BaseLLM
|
||||
prompt_template: PromptTemplate = PromptTemplate(template=RERANK_PROMPT_TEMPLATE)
|
||||
top_k: int = 3
|
||||
concurrent: bool = True
|
||||
|
||||
def run(
|
||||
self,
|
||||
documents: list[Document],
|
||||
query: str,
|
||||
) -> list[Document]:
|
||||
"""Filter down documents based on their relevance to the query."""
|
||||
filtered_docs = []
|
||||
output_parser = BooleanOutputParser()
|
||||
|
||||
if self.concurrent:
|
||||
with ThreadPoolExecutor() as executor:
|
||||
futures = []
|
||||
for doc in documents:
|
||||
_prompt = self.prompt_template.populate(
|
||||
question=query, context=doc.get_content()
|
||||
)
|
||||
futures.append(executor.submit(lambda: self.llm(_prompt).text))
|
||||
|
||||
results = [future.result() for future in futures]
|
||||
else:
|
||||
results = []
|
||||
for doc in documents:
|
||||
_prompt = self.prompt_template.populate(
|
||||
question=query, context=doc.get_content()
|
||||
)
|
||||
results.append(self.llm(_prompt).text)
|
||||
|
||||
# use Boolean parser to extract relevancy output from LLM
|
||||
results = [output_parser.parse(result) for result in results]
|
||||
for include_doc, doc in zip(results, documents):
|
||||
if include_doc:
|
||||
filtered_docs.append(doc)
|
||||
|
||||
# prevent returning empty result
|
||||
if len(filtered_docs) == 0:
|
||||
filtered_docs = documents[: self.top_k]
|
||||
|
||||
return filtered_docs
|
49
libs/kotaemon/kotaemon/indices/splitters/__init__.py
Normal file
49
libs/kotaemon/kotaemon/indices/splitters/__init__.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from ..base import DocTransformer, LlamaIndexDocTransformerMixin
|
||||
|
||||
|
||||
class BaseSplitter(DocTransformer):
|
||||
"""Represent base splitter class"""
|
||||
|
||||
...
|
||||
|
||||
|
||||
class TokenSplitter(LlamaIndexDocTransformerMixin, BaseSplitter):
|
||||
def __init__(
|
||||
self,
|
||||
chunk_size: int = 1024,
|
||||
chunk_overlap: int = 20,
|
||||
separator: str = " ",
|
||||
**params,
|
||||
):
|
||||
super().__init__(
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
separator=separator,
|
||||
**params,
|
||||
)
|
||||
|
||||
def _get_li_class(self):
|
||||
from llama_index.text_splitter import TokenTextSplitter
|
||||
|
||||
return TokenTextSplitter
|
||||
|
||||
|
||||
class SentenceWindowSplitter(LlamaIndexDocTransformerMixin, BaseSplitter):
|
||||
def __init__(
|
||||
self,
|
||||
window_size: int = 3,
|
||||
window_metadata_key: str = "window",
|
||||
original_text_metadata_key: str = "original_text",
|
||||
**params,
|
||||
):
|
||||
super().__init__(
|
||||
window_size=window_size,
|
||||
window_metadata_key=window_metadata_key,
|
||||
original_text_metadata_key=original_text_metadata_key,
|
||||
**params,
|
||||
)
|
||||
|
||||
def _get_li_class(self):
|
||||
from llama_index.node_parser import SentenceWindowNodeParser
|
||||
|
||||
return SentenceWindowNodeParser
|
122
libs/kotaemon/kotaemon/indices/vectorindex.py
Normal file
122
libs/kotaemon/kotaemon/indices/vectorindex.py
Normal file
@@ -0,0 +1,122 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from typing import Optional, Sequence, cast
|
||||
|
||||
from kotaemon.base import BaseComponent, Document, RetrievedDocument
|
||||
from kotaemon.embeddings import BaseEmbeddings
|
||||
from kotaemon.storages import BaseDocumentStore, BaseVectorStore
|
||||
|
||||
from .base import BaseIndexing, BaseRetrieval
|
||||
from .rankings import BaseReranking
|
||||
|
||||
VECTOR_STORE_FNAME = "vectorstore"
|
||||
DOC_STORE_FNAME = "docstore"
|
||||
|
||||
|
||||
class VectorIndexing(BaseIndexing):
|
||||
"""Ingest the document, run through the embedding, and store the embedding in a
|
||||
vector store.
|
||||
|
||||
This pipeline supports the following set of inputs:
|
||||
- List of documents
|
||||
- List of texts
|
||||
"""
|
||||
|
||||
vector_store: BaseVectorStore
|
||||
doc_store: Optional[BaseDocumentStore] = None
|
||||
embedding: BaseEmbeddings
|
||||
|
||||
def to_retrieval_pipeline(self, *args, **kwargs):
|
||||
"""Convert the indexing pipeline to a retrieval pipeline"""
|
||||
return VectorRetrieval(
|
||||
vector_store=self.vector_store,
|
||||
doc_store=self.doc_store,
|
||||
embedding=self.embedding,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def to_qa_pipeline(self, *args, **kwargs):
|
||||
from .qa import CitationQAPipeline
|
||||
|
||||
return TextVectorQA(
|
||||
retrieving_pipeline=self.to_retrieval_pipeline(**kwargs),
|
||||
qa_pipeline=CitationQAPipeline(**kwargs),
|
||||
)
|
||||
|
||||
def run(self, text: str | list[str] | Document | list[Document]):
|
||||
input_: list[Document] = []
|
||||
if not isinstance(text, list):
|
||||
text = [text]
|
||||
|
||||
for item in cast(list, text):
|
||||
if isinstance(item, str):
|
||||
input_.append(Document(text=item, id_=str(uuid.uuid4())))
|
||||
elif isinstance(item, Document):
|
||||
input_.append(item)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid input type {type(item)}, should be str or Document"
|
||||
)
|
||||
|
||||
embeddings = self.embedding(input_)
|
||||
self.vector_store.add(
|
||||
embeddings=embeddings,
|
||||
ids=[t.doc_id for t in input_],
|
||||
)
|
||||
if self.doc_store:
|
||||
self.doc_store.add(input_)
|
||||
|
||||
|
||||
class VectorRetrieval(BaseRetrieval):
|
||||
"""Retrieve list of documents from vector store"""
|
||||
|
||||
vector_store: BaseVectorStore
|
||||
doc_store: Optional[BaseDocumentStore] = None
|
||||
embedding: BaseEmbeddings
|
||||
rerankers: Sequence[BaseReranking] = []
|
||||
top_k: int = 1
|
||||
|
||||
def run(
|
||||
self, text: str | Document, top_k: Optional[int] = None, **kwargs
|
||||
) -> list[RetrievedDocument]:
|
||||
"""Retrieve a list of documents from vector store
|
||||
|
||||
Args:
|
||||
text: the text to retrieve similar documents
|
||||
top_k: number of top similar documents to return
|
||||
|
||||
Returns:
|
||||
list[RetrievedDocument]: list of retrieved documents
|
||||
"""
|
||||
if top_k is None:
|
||||
top_k = self.top_k
|
||||
|
||||
if self.doc_store is None:
|
||||
raise ValueError(
|
||||
"doc_store is not provided. Please provide a doc_store to "
|
||||
"retrieve the documents"
|
||||
)
|
||||
|
||||
emb: list[float] = self.embedding(text)[0].embedding
|
||||
_, scores, ids = self.vector_store.query(embedding=emb, top_k=top_k, **kwargs)
|
||||
docs = self.doc_store.get(ids)
|
||||
result = [
|
||||
RetrievedDocument(**doc.to_dict(), score=score)
|
||||
for doc, score in zip(docs, scores)
|
||||
]
|
||||
# use additional reranker to re-order the document list
|
||||
if self.rerankers:
|
||||
for reranker in self.rerankers:
|
||||
result = reranker(documents=result, query=text)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class TextVectorQA(BaseComponent):
|
||||
retrieving_pipeline: BaseRetrieval
|
||||
qa_pipeline: BaseComponent
|
||||
|
||||
def run(self, question, **kwargs):
|
||||
retrieved_documents = self.retrieving_pipeline(question, **kwargs)
|
||||
return self.qa_pipeline(question, retrieved_documents, **kwargs)
|
35
libs/kotaemon/kotaemon/llms/__init__.py
Normal file
35
libs/kotaemon/kotaemon/llms/__init__.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from kotaemon.base.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||
|
||||
from .base import BaseLLM
|
||||
from .branching import GatedBranchingPipeline, SimpleBranchingPipeline
|
||||
from .chats import AzureChatOpenAI, ChatLLM
|
||||
from .completions import LLM, AzureOpenAI, OpenAI
|
||||
from .cot import ManualSequentialChainOfThought, Thought
|
||||
from .linear import GatedLinearPipeline, SimpleLinearPipeline
|
||||
from .prompts import BasePromptComponent, PromptTemplate
|
||||
|
||||
__all__ = [
|
||||
"BaseLLM",
|
||||
# chat-specific components
|
||||
"ChatLLM",
|
||||
"BaseMessage",
|
||||
"HumanMessage",
|
||||
"AIMessage",
|
||||
"SystemMessage",
|
||||
"AzureChatOpenAI",
|
||||
# completion-specific components
|
||||
"LLM",
|
||||
"OpenAI",
|
||||
"AzureOpenAI",
|
||||
# prompt-specific components
|
||||
"BasePromptComponent",
|
||||
"PromptTemplate",
|
||||
# strategies
|
||||
"SimpleLinearPipeline",
|
||||
"GatedLinearPipeline",
|
||||
"SimpleBranchingPipeline",
|
||||
"GatedBranchingPipeline",
|
||||
# chain-of-thoughts
|
||||
"ManualSequentialChainOfThought",
|
||||
"Thought",
|
||||
]
|
7
libs/kotaemon/kotaemon/llms/base.py
Normal file
7
libs/kotaemon/kotaemon/llms/base.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from kotaemon.base import BaseComponent
|
||||
from langchain_core.language_models.base import BaseLanguageModel
|
||||
|
||||
|
||||
class BaseLLM(BaseComponent):
|
||||
def to_langchain_format(self) -> BaseLanguageModel:
|
||||
raise NotImplementedError
|
186
libs/kotaemon/kotaemon/llms/branching.py
Normal file
186
libs/kotaemon/kotaemon/llms/branching.py
Normal file
@@ -0,0 +1,186 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from kotaemon.base import BaseComponent, Document, Param
|
||||
|
||||
from .linear import GatedLinearPipeline
|
||||
|
||||
|
||||
class SimpleBranchingPipeline(BaseComponent):
|
||||
"""
|
||||
A simple branching pipeline for executing multiple branches.
|
||||
|
||||
Attributes:
|
||||
branches (List[BaseComponent]): The list of branches to be executed.
|
||||
|
||||
Example:
|
||||
```python
|
||||
from kotaemon.llms import (
|
||||
AzureChatOpenAI,
|
||||
BasePromptComponent,
|
||||
GatedLinearPipeline,
|
||||
)
|
||||
from kotaemon.parsers import RegexExtractor
|
||||
|
||||
def identity(x):
|
||||
return x
|
||||
|
||||
pipeline = SimpleBranchingPipeline()
|
||||
llm = AzureChatOpenAI(
|
||||
openai_api_base="your openai api base",
|
||||
openai_api_key="your openai api key",
|
||||
openai_api_version="your openai api version",
|
||||
deployment_name="dummy-q2-gpt35",
|
||||
temperature=0,
|
||||
request_timeout=600,
|
||||
)
|
||||
|
||||
for i in range(3):
|
||||
pipeline.add_branch(
|
||||
GatedLinearPipeline(
|
||||
prompt=BasePromptComponent(template=f"what is {i} in Japanese ?"),
|
||||
condition=RegexExtractor(pattern=f"{i}"),
|
||||
llm=llm,
|
||||
post_processor=identity,
|
||||
)
|
||||
)
|
||||
print(pipeline(condition_text="1"))
|
||||
print(pipeline(condition_text="2"))
|
||||
print(pipeline(condition_text="12"))
|
||||
```
|
||||
"""
|
||||
|
||||
branches: List[BaseComponent] = Param(default_callback=lambda *_: [])
|
||||
|
||||
def add_branch(self, component: BaseComponent):
|
||||
"""
|
||||
Add a new branch to the pipeline.
|
||||
|
||||
Args:
|
||||
component (BaseComponent): The branch component to be added.
|
||||
"""
|
||||
self.branches.append(component)
|
||||
|
||||
def run(self, **prompt_kwargs):
|
||||
"""
|
||||
Execute the pipeline by running each branch and return the outputs as a list.
|
||||
|
||||
Args:
|
||||
**prompt_kwargs: Keyword arguments for the branches.
|
||||
|
||||
Returns:
|
||||
List: The outputs of each branch as a list.
|
||||
"""
|
||||
output = []
|
||||
for i, branch in enumerate(self.branches):
|
||||
self._prepare_child(branch, name=f"branch-{i}")
|
||||
output.append(branch(**prompt_kwargs))
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class GatedBranchingPipeline(SimpleBranchingPipeline):
|
||||
"""
|
||||
A simple gated branching pipeline for executing multiple branches based on a
|
||||
condition.
|
||||
|
||||
This class extends the SimpleBranchingPipeline class and adds the ability to execute
|
||||
the branches until a branch returns a non-empty output based on a condition.
|
||||
|
||||
Attributes:
|
||||
branches (List[BaseComponent]): The list of branches to be executed.
|
||||
|
||||
Example:
|
||||
```python
|
||||
from kotaemon.llms import (
|
||||
AzureChatOpenAI,
|
||||
BasePromptComponent,
|
||||
GatedLinearPipeline,
|
||||
)
|
||||
from kotaemon.parsers import RegexExtractor
|
||||
|
||||
def identity(x):
|
||||
return x
|
||||
|
||||
pipeline = GatedBranchingPipeline()
|
||||
llm = AzureChatOpenAI(
|
||||
openai_api_base="your openai api base",
|
||||
openai_api_key="your openai api key",
|
||||
openai_api_version="your openai api version",
|
||||
deployment_name="dummy-q2-gpt35",
|
||||
temperature=0,
|
||||
request_timeout=600,
|
||||
)
|
||||
|
||||
for i in range(3):
|
||||
pipeline.add_branch(
|
||||
GatedLinearPipeline(
|
||||
prompt=BasePromptComponent(template=f"what is {i} in Japanese ?"),
|
||||
condition=RegexExtractor(pattern=f"{i}"),
|
||||
llm=llm,
|
||||
post_processor=identity,
|
||||
)
|
||||
)
|
||||
print(pipeline(condition_text="1"))
|
||||
print(pipeline(condition_text="2"))
|
||||
```
|
||||
"""
|
||||
|
||||
def run(self, *, condition_text: Optional[str] = None, **prompt_kwargs):
|
||||
"""
|
||||
Execute the pipeline by running each branch and return the output of the first
|
||||
branch that returns a non-empty output based on the provided condition.
|
||||
|
||||
Args:
|
||||
condition_text (str): The condition text to evaluate for each branch.
|
||||
Default to None.
|
||||
**prompt_kwargs: Keyword arguments for the branches.
|
||||
|
||||
Returns:
|
||||
Union[OutputType, None]: The output of the first branch that satisfies the
|
||||
condition, or None if no branch satisfies the condition.
|
||||
|
||||
Raises:
|
||||
ValueError: If condition_text is None
|
||||
"""
|
||||
if condition_text is None:
|
||||
raise ValueError("`condition_text` must be provided.")
|
||||
|
||||
for i, branch in enumerate(self.branches):
|
||||
self._prepare_child(branch, name=f"branch-{i}")
|
||||
output = branch(condition_text=condition_text, **prompt_kwargs)
|
||||
if output:
|
||||
return output
|
||||
|
||||
return Document(None)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import dotenv
|
||||
from kotaemon.llms import AzureChatOpenAI, BasePromptComponent
|
||||
from kotaemon.parsers import RegexExtractor
|
||||
|
||||
def identity(x):
|
||||
return x
|
||||
|
||||
secrets = dotenv.dotenv_values(".env")
|
||||
|
||||
pipeline = GatedBranchingPipeline()
|
||||
llm = AzureChatOpenAI(
|
||||
openai_api_base=secrets.get("OPENAI_API_BASE", ""),
|
||||
openai_api_key=secrets.get("OPENAI_API_KEY", ""),
|
||||
openai_api_version=secrets.get("OPENAI_API_VERSION", ""),
|
||||
deployment_name="dummy-q2-gpt35",
|
||||
temperature=0,
|
||||
request_timeout=600,
|
||||
)
|
||||
|
||||
for i in range(3):
|
||||
pipeline.add_branch(
|
||||
GatedLinearPipeline(
|
||||
prompt=BasePromptComponent(template=f"what is {i} in Japanese ?"),
|
||||
condition=RegexExtractor(pattern=f"{i}"),
|
||||
llm=llm,
|
||||
post_processor=identity,
|
||||
)
|
||||
)
|
||||
pipeline(condition_text="1")
|
4
libs/kotaemon/kotaemon/llms/chats/__init__.py
Normal file
4
libs/kotaemon/kotaemon/llms/chats/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .base import ChatLLM
|
||||
from .langchain_based import AzureChatOpenAI, LCChatMixin
|
||||
|
||||
__all__ = ["ChatLLM", "AzureChatOpenAI", "LCChatMixin"]
|
22
libs/kotaemon/kotaemon/llms/chats/base.py
Normal file
22
libs/kotaemon/kotaemon/llms/chats/base.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from kotaemon.base import BaseComponent
|
||||
from kotaemon.llms.base import BaseLLM
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChatLLM(BaseLLM):
|
||||
def flow(self):
|
||||
if self.inflow is None:
|
||||
raise ValueError("No inflow provided.")
|
||||
|
||||
if not isinstance(self.inflow, BaseComponent):
|
||||
raise ValueError(
|
||||
f"inflow must be a BaseComponent, found {type(self.inflow)}"
|
||||
)
|
||||
|
||||
text = self.inflow.flow().text
|
||||
return self.__call__(text)
|
170
libs/kotaemon/kotaemon/llms/chats/langchain_based.py
Normal file
170
libs/kotaemon/kotaemon/llms/chats/langchain_based.py
Normal file
@@ -0,0 +1,170 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from kotaemon.base import BaseMessage, HumanMessage, LLMInterface
|
||||
|
||||
from .base import ChatLLM
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LCChatMixin:
|
||||
def _get_lc_class(self):
|
||||
raise NotImplementedError(
|
||||
"Please return the relevant Langchain class in in _get_lc_class"
|
||||
)
|
||||
|
||||
def __init__(self, stream: bool = False, **params):
|
||||
self._lc_class = self._get_lc_class()
|
||||
self._obj = self._lc_class(**params)
|
||||
self._kwargs: dict = params
|
||||
self._stream = stream
|
||||
|
||||
super().__init__()
|
||||
|
||||
def run(
|
||||
self, messages: str | BaseMessage | list[BaseMessage], **kwargs
|
||||
) -> LLMInterface:
|
||||
if self._stream:
|
||||
return self.stream(messages, **kwargs) # type: ignore
|
||||
return self.invoke(messages, **kwargs)
|
||||
|
||||
def invoke(
|
||||
self, messages: str | BaseMessage | list[BaseMessage], **kwargs
|
||||
) -> LLMInterface:
|
||||
"""Generate response from messages
|
||||
|
||||
Args:
|
||||
messages: history of messages to generate response from
|
||||
**kwargs: additional arguments to pass to the langchain chat model
|
||||
|
||||
Returns:
|
||||
LLMInterface: generated response
|
||||
"""
|
||||
input_: list[BaseMessage] = []
|
||||
|
||||
if isinstance(messages, str):
|
||||
input_ = [HumanMessage(content=messages)]
|
||||
elif isinstance(messages, BaseMessage):
|
||||
input_ = [messages]
|
||||
else:
|
||||
input_ = messages
|
||||
|
||||
pred = self._obj.generate(messages=[input_], **kwargs)
|
||||
all_text = [each.text for each in pred.generations[0]]
|
||||
all_messages = [each.message for each in pred.generations[0]]
|
||||
|
||||
completion_tokens, total_tokens, prompt_tokens = 0, 0, 0
|
||||
try:
|
||||
if pred.llm_output is not None:
|
||||
completion_tokens = pred.llm_output["token_usage"]["completion_tokens"]
|
||||
total_tokens = pred.llm_output["token_usage"]["total_tokens"]
|
||||
prompt_tokens = pred.llm_output["token_usage"]["prompt_tokens"]
|
||||
except Exception:
|
||||
logger.warning(
|
||||
f"Cannot get token usage from LLM output for {self._lc_class.__name__}"
|
||||
)
|
||||
|
||||
return LLMInterface(
|
||||
text=all_text[0] if len(all_text) > 0 else "",
|
||||
candidates=all_text,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
prompt_tokens=prompt_tokens,
|
||||
messages=all_messages,
|
||||
logits=[],
|
||||
)
|
||||
|
||||
def stream(self, messages: str | BaseMessage | list[BaseMessage], **kwargs):
|
||||
for response in self._obj.stream(input=messages, **kwargs):
|
||||
yield LLMInterface(content=response.content)
|
||||
|
||||
def to_langchain_format(self):
|
||||
return self._obj
|
||||
|
||||
def __repr__(self):
|
||||
kwargs = []
|
||||
for key, value_obj in self._kwargs.items():
|
||||
value = repr(value_obj)
|
||||
kwargs.append(f"{key}={value}")
|
||||
kwargs_repr = ", ".join(kwargs)
|
||||
return f"{self.__class__.__name__}({kwargs_repr})"
|
||||
|
||||
def __str__(self):
|
||||
kwargs = []
|
||||
for key, value_obj in self._kwargs.items():
|
||||
value = str(value_obj)
|
||||
if len(value) > 20:
|
||||
value = f"{value[:15]}..."
|
||||
kwargs.append(f"{key}={value}")
|
||||
kwargs_repr = ", ".join(kwargs)
|
||||
return f"{self.__class__.__name__}({kwargs_repr})"
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
if name == "_lc_class":
|
||||
return super().__setattr__(name, value)
|
||||
|
||||
if name in self._lc_class.__fields__:
|
||||
self._kwargs[name] = value
|
||||
self._obj = self._lc_class(**self._kwargs)
|
||||
else:
|
||||
super().__setattr__(name, value)
|
||||
|
||||
def __getattr__(self, name):
|
||||
if name in self._kwargs:
|
||||
return self._kwargs[name]
|
||||
return getattr(self._obj, name)
|
||||
|
||||
def dump(self, *args, **kwargs):
|
||||
from theflow.utils.modules import serialize
|
||||
|
||||
params = {key: serialize(value) for key, value in self._kwargs.items()}
|
||||
return {
|
||||
"__type__": f"{self.__module__}.{self.__class__.__qualname__}",
|
||||
**params,
|
||||
}
|
||||
|
||||
def specs(self, path: str):
|
||||
path = path.strip(".")
|
||||
if "." in path:
|
||||
raise ValueError("path should not contain '.'")
|
||||
|
||||
if path in self._lc_class.__fields__:
|
||||
return {
|
||||
"__type__": "theflow.base.ParamAttr",
|
||||
"refresh_on_set": True,
|
||||
"strict_type": True,
|
||||
}
|
||||
|
||||
raise ValueError(f"Invalid param {path}")
|
||||
|
||||
|
||||
class AzureChatOpenAI(LCChatMixin, ChatLLM):
|
||||
def __init__(
|
||||
self,
|
||||
azure_endpoint: str | None = None,
|
||||
openai_api_key: str | None = None,
|
||||
openai_api_version: str = "",
|
||||
deployment_name: str | None = None,
|
||||
temperature: float = 0.7,
|
||||
request_timeout: float | None = None,
|
||||
**params,
|
||||
):
|
||||
super().__init__(
|
||||
azure_endpoint=azure_endpoint,
|
||||
openai_api_key=openai_api_key,
|
||||
openai_api_version=openai_api_version,
|
||||
deployment_name=deployment_name,
|
||||
temperature=temperature,
|
||||
request_timeout=request_timeout,
|
||||
**params,
|
||||
)
|
||||
|
||||
def _get_lc_class(self):
|
||||
try:
|
||||
from langchain_openai import AzureChatOpenAI
|
||||
except ImportError:
|
||||
from langchain.chat_models import AzureChatOpenAI
|
||||
|
||||
return AzureChatOpenAI
|
4
libs/kotaemon/kotaemon/llms/completions/__init__.py
Normal file
4
libs/kotaemon/kotaemon/llms/completions/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .base import LLM
|
||||
from .langchain_based import AzureOpenAI, LCCompletionMixin, OpenAI
|
||||
|
||||
__all__ = ["LLM", "OpenAI", "AzureOpenAI", "LCCompletionMixin"]
|
5
libs/kotaemon/kotaemon/llms/completions/base.py
Normal file
5
libs/kotaemon/kotaemon/llms/completions/base.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from kotaemon.llms.base import BaseLLM
|
||||
|
||||
|
||||
class LLM(BaseLLM):
|
||||
pass
|
197
libs/kotaemon/kotaemon/llms/completions/langchain_based.py
Normal file
197
libs/kotaemon/kotaemon/llms/completions/langchain_based.py
Normal file
@@ -0,0 +1,197 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from kotaemon.base import LLMInterface
|
||||
|
||||
from .base import LLM
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LCCompletionMixin:
|
||||
def _get_lc_class(self):
|
||||
raise NotImplementedError(
|
||||
"Please return the relevant Langchain class in in _get_lc_class"
|
||||
)
|
||||
|
||||
def __init__(self, **params):
|
||||
self._lc_class = self._get_lc_class()
|
||||
self._obj = self._lc_class(**params)
|
||||
self._kwargs: dict = params
|
||||
|
||||
super().__init__()
|
||||
|
||||
def run(self, text: str) -> LLMInterface:
|
||||
pred = self._obj.generate([text])
|
||||
all_text = [each.text for each in pred.generations[0]]
|
||||
|
||||
completion_tokens, total_tokens, prompt_tokens = 0, 0, 0
|
||||
try:
|
||||
if pred.llm_output is not None:
|
||||
completion_tokens = pred.llm_output["token_usage"]["completion_tokens"]
|
||||
total_tokens = pred.llm_output["token_usage"]["total_tokens"]
|
||||
prompt_tokens = pred.llm_output["token_usage"]["prompt_tokens"]
|
||||
except Exception:
|
||||
logger.warning(
|
||||
f"Cannot get token usage from LLM output for {self._lc_class.__name__}"
|
||||
)
|
||||
|
||||
return LLMInterface(
|
||||
text=all_text[0] if len(all_text) > 0 else "",
|
||||
candidates=all_text,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
prompt_tokens=prompt_tokens,
|
||||
logits=[],
|
||||
)
|
||||
|
||||
def to_langchain_format(self):
|
||||
return self._obj
|
||||
|
||||
def __repr__(self):
|
||||
kwargs = []
|
||||
for key, value_obj in self._kwargs.items():
|
||||
value = repr(value_obj)
|
||||
kwargs.append(f"{key}={value}")
|
||||
kwargs_repr = ", ".join(kwargs)
|
||||
return f"{self.__class__.__name__}({kwargs_repr})"
|
||||
|
||||
def __str__(self):
|
||||
kwargs = []
|
||||
for key, value_obj in self._kwargs.items():
|
||||
value = str(value_obj)
|
||||
if len(value) > 20:
|
||||
value = f"{value[:15]}..."
|
||||
kwargs.append(f"{key}={value}")
|
||||
kwargs_repr = ", ".join(kwargs)
|
||||
return f"{self.__class__.__name__}({kwargs_repr})"
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
if name == "_lc_class":
|
||||
return super().__setattr__(name, value)
|
||||
|
||||
if name in self._lc_class.__fields__:
|
||||
self._kwargs[name] = value
|
||||
self._obj = self._lc_class(**self._kwargs)
|
||||
else:
|
||||
super().__setattr__(name, value)
|
||||
|
||||
def __getattr__(self, name):
|
||||
if name in self._kwargs:
|
||||
return self._kwargs[name]
|
||||
return getattr(self._obj, name)
|
||||
|
||||
def dump(self, *args, **kwargs):
|
||||
from theflow.utils.modules import serialize
|
||||
|
||||
params = {key: serialize(value) for key, value in self._kwargs.items()}
|
||||
return {
|
||||
"__type__": f"{self.__module__}.{self.__class__.__qualname__}",
|
||||
**params,
|
||||
}
|
||||
|
||||
def specs(self, path: str):
|
||||
path = path.strip(".")
|
||||
if "." in path:
|
||||
raise ValueError("path should not contain '.'")
|
||||
|
||||
if path in self._lc_class.__fields__:
|
||||
return {
|
||||
"__type__": "theflow.base.ParamAttr",
|
||||
"refresh_on_set": True,
|
||||
"strict_type": True,
|
||||
}
|
||||
|
||||
raise ValueError(f"Invalid param {path}")
|
||||
|
||||
|
||||
class OpenAI(LCCompletionMixin, LLM):
|
||||
"""Wrapper around Langchain's OpenAI class, focusing on key parameters"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
openai_api_key: Optional[str] = None,
|
||||
openai_api_base: Optional[str] = None,
|
||||
model_name: str = "text-davinci-003",
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 256,
|
||||
top_p: float = 1,
|
||||
frequency_penalty: float = 0,
|
||||
n: int = 1,
|
||||
best_of: int = 1,
|
||||
request_timeout: Optional[float] = None,
|
||||
max_retries: int = 2,
|
||||
streaming: bool = False,
|
||||
**params,
|
||||
):
|
||||
super().__init__(
|
||||
openai_api_key=openai_api_key,
|
||||
openai_api_base=openai_api_base,
|
||||
model_name=model_name,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
top_p=top_p,
|
||||
frequency_penalty=frequency_penalty,
|
||||
n=n,
|
||||
best_of=best_of,
|
||||
request_timeout=request_timeout,
|
||||
max_retries=max_retries,
|
||||
streaming=streaming,
|
||||
**params,
|
||||
)
|
||||
|
||||
def _get_lc_class(self):
|
||||
try:
|
||||
from langchain_openai import OpenAI
|
||||
except ImportError:
|
||||
from langchain.llms import OpenAI
|
||||
|
||||
return OpenAI
|
||||
|
||||
|
||||
class AzureOpenAI(LCCompletionMixin, LLM):
|
||||
"""Wrapper around Langchain's AzureOpenAI class, focusing on key parameters"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
azure_endpoint: Optional[str] = None,
|
||||
deployment_name: Optional[str] = None,
|
||||
openai_api_version: str = "",
|
||||
openai_api_key: Optional[str] = None,
|
||||
model_name: str = "text-davinci-003",
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 256,
|
||||
top_p: float = 1,
|
||||
frequency_penalty: float = 0,
|
||||
n: int = 1,
|
||||
best_of: int = 1,
|
||||
request_timeout: Optional[float] = None,
|
||||
max_retries: int = 2,
|
||||
streaming: bool = False,
|
||||
**params,
|
||||
):
|
||||
super().__init__(
|
||||
azure_endpoint=azure_endpoint,
|
||||
deployment_name=deployment_name,
|
||||
openai_api_version=openai_api_version,
|
||||
openai_api_key=openai_api_key,
|
||||
model_name=model_name,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
top_p=top_p,
|
||||
frequency_penalty=frequency_penalty,
|
||||
n=n,
|
||||
best_of=best_of,
|
||||
request_timeout=request_timeout,
|
||||
max_retries=max_retries,
|
||||
streaming=streaming,
|
||||
**params,
|
||||
)
|
||||
|
||||
def _get_lc_class(self):
|
||||
try:
|
||||
from langchain_openai import AzureOpenAI
|
||||
except ImportError:
|
||||
from langchain.llms import AzureOpenAI
|
||||
|
||||
return AzureOpenAI
|
174
libs/kotaemon/kotaemon/llms/cot.py
Normal file
174
libs/kotaemon/kotaemon/llms/cot.py
Normal file
@@ -0,0 +1,174 @@
|
||||
from copy import deepcopy
|
||||
from typing import Callable, List
|
||||
|
||||
from kotaemon.base import BaseComponent, Document
|
||||
from theflow import Function, Node, Param
|
||||
|
||||
from .chats import AzureChatOpenAI
|
||||
from .completions import LLM
|
||||
from .prompts import BasePromptComponent
|
||||
|
||||
|
||||
class Thought(BaseComponent):
|
||||
"""A thought in the chain of thought
|
||||
|
||||
- Input: `**kwargs` pairs, where key is the placeholder in the prompt, and
|
||||
value is the value.
|
||||
- Output: an output dictionary
|
||||
|
||||
_**Usage:**_
|
||||
|
||||
Create and run a thought:
|
||||
|
||||
```python
|
||||
>> from kotaemon.pipelines.cot import Thought
|
||||
>> thought = Thought(
|
||||
prompt="How to {action} {object}?",
|
||||
llm=AzureChatOpenAI(...),
|
||||
post_process=lambda string: {"tutorial": string},
|
||||
)
|
||||
>> output = thought(action="install", object="python")
|
||||
>> print(output)
|
||||
{'tutorial': 'As an AI language model,...'}
|
||||
```
|
||||
|
||||
Basically, when a thought is run, it will:
|
||||
|
||||
1. Populate the prompt template with the input `**kwargs`.
|
||||
2. Run the LLM model with the populated prompt.
|
||||
3. Post-process the LLM output with the post-processor.
|
||||
|
||||
This `Thought` allows chaining sequentially with the + operator. For example:
|
||||
|
||||
```python
|
||||
>> llm = AzureChatOpenAI(...)
|
||||
>> thought1 = Thought(
|
||||
prompt="Word {word} in {language} is ",
|
||||
llm=llm,
|
||||
post_process=lambda string: {"translated": string},
|
||||
)
|
||||
>> thought2 = Thought(
|
||||
prompt="Translate {translated} to Japanese",
|
||||
llm=llm,
|
||||
post_process=lambda string: {"output": string},
|
||||
)
|
||||
|
||||
>> thought = thought1 + thought2
|
||||
>> thought(word="hello", language="French")
|
||||
{'word': 'hello',
|
||||
'language': 'French',
|
||||
'translated': '"Bonjour"',
|
||||
'output': 'こんにちは (Konnichiwa)'}
|
||||
```
|
||||
|
||||
Under the hood, when the `+` operator is used, a `ManualSequentialChainOfThought`
|
||||
is created.
|
||||
"""
|
||||
|
||||
prompt: str = Param(
|
||||
help=(
|
||||
"The prompt template string. This prompt template has Python-like variable"
|
||||
" placeholders, that then will be substituted with real values when this"
|
||||
" component is executed"
|
||||
)
|
||||
)
|
||||
llm: LLM = Node(AzureChatOpenAI, help="The LLM model to execute the input prompt")
|
||||
post_process: Function = Node(
|
||||
help=(
|
||||
"The function post-processor that post-processes LLM output prediction ."
|
||||
"It should take a string as input (this is the LLM output text) and return "
|
||||
"a dictionary, where the key should"
|
||||
)
|
||||
)
|
||||
|
||||
@Node.auto(depends_on="prompt")
|
||||
def prompt_template(self):
|
||||
"""Automatically wrap around param prompt. Can ignore"""
|
||||
return BasePromptComponent(self.prompt)
|
||||
|
||||
def run(self, **kwargs) -> Document:
|
||||
"""Run the chain of thought"""
|
||||
prompt = self.prompt_template(**kwargs).text
|
||||
response = self.llm(prompt).text
|
||||
response = self.post_process(response)
|
||||
|
||||
return Document(response)
|
||||
|
||||
def get_variables(self) -> List[str]:
|
||||
return []
|
||||
|
||||
def __add__(self, next_thought: "Thought") -> "ManualSequentialChainOfThought":
|
||||
return ManualSequentialChainOfThought(
|
||||
thoughts=[self, next_thought], llm=self.llm
|
||||
)
|
||||
|
||||
|
||||
class ManualSequentialChainOfThought(BaseComponent):
|
||||
"""Perform sequential chain-of-thought with manual pre-defined prompts
|
||||
|
||||
This method supports variable number of steps. Each step corresponds to a
|
||||
`kotaemon.pipelines.cot.Thought`. Please refer that section for
|
||||
Thought's detail. This section is about chaining thought together.
|
||||
|
||||
_**Usage:**_
|
||||
|
||||
**Create and run a chain of thought without "+" operator:**
|
||||
|
||||
```pycon
|
||||
>>> from kotaemon.pipelines.cot import Thought, ManualSequentialChainOfThought
|
||||
>>> llm = AzureChatOpenAI(...)
|
||||
>>> thought1 = Thought(
|
||||
>>> prompt="Word {word} in {language} is ",
|
||||
>>> post_process=lambda string: {"translated": string},
|
||||
>>> )
|
||||
>>> thought2 = Thought(
|
||||
>>> prompt="Translate {translated} to Japanese",
|
||||
>>> post_process=lambda string: {"output": string},
|
||||
>>> )
|
||||
>>> thought = ManualSequentialChainOfThought(thoughts=[thought1, thought2], llm=llm)
|
||||
>>> thought(word="hello", language="French")
|
||||
{'word': 'hello',
|
||||
'language': 'French',
|
||||
'translated': '"Bonjour"',
|
||||
'output': 'こんにちは (Konnichiwa)'}
|
||||
```
|
||||
|
||||
**Create and run a chain of thought without "+" operator:** Please refer the
|
||||
`kotaemon.pipelines.cot.Thought` section for examples.
|
||||
|
||||
This chain-of-thought optionally takes a termination check callback function.
|
||||
This function will be called after each thought is executed. It takes in a
|
||||
dictionary of all thought outputs so far, and it returns True or False. If
|
||||
True, the chain-of-thought will terminate. If unset, the default callback always
|
||||
returns False.
|
||||
"""
|
||||
|
||||
thoughts: List[Thought] = Param(
|
||||
default_callback=lambda *_: [], help="List of Thought"
|
||||
)
|
||||
llm: LLM = Param(help="The LLM model to use (base of kotaemon.llms.BaseLLM)")
|
||||
terminate: Callable = Param(
|
||||
default=lambda _: False,
|
||||
help="Callback on terminate condition. Default to always return False",
|
||||
)
|
||||
|
||||
def run(self, **kwargs) -> Document:
|
||||
"""Run the manual chain of thought"""
|
||||
|
||||
inputs = deepcopy(kwargs)
|
||||
for idx, thought in enumerate(self.thoughts):
|
||||
if self.llm:
|
||||
thought.llm = self.llm
|
||||
self._prepare_child(thought, f"thought{idx}")
|
||||
|
||||
output = thought(**inputs)
|
||||
inputs.update(output.content)
|
||||
if self.terminate(inputs):
|
||||
break
|
||||
|
||||
return Document(inputs)
|
||||
|
||||
def __add__(self, next_thought: Thought) -> "ManualSequentialChainOfThought":
|
||||
return ManualSequentialChainOfThought(
|
||||
thoughts=self.thoughts + [next_thought], llm=self.llm
|
||||
)
|
155
libs/kotaemon/kotaemon/llms/linear.py
Normal file
155
libs/kotaemon/kotaemon/llms/linear.py
Normal file
@@ -0,0 +1,155 @@
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
from ..base import BaseComponent
|
||||
from ..base.schema import Document, IO_Type
|
||||
from .chats import ChatLLM
|
||||
from .completions import LLM
|
||||
from .prompts import BasePromptComponent
|
||||
|
||||
|
||||
class SimpleLinearPipeline(BaseComponent):
|
||||
"""
|
||||
A simple pipeline for running a function with a prompt, a language model, and an
|
||||
optional post-processor.
|
||||
|
||||
Attributes:
|
||||
prompt (BasePromptComponent): The prompt component used to generate the initial
|
||||
input.
|
||||
llm (Union[ChatLLM, LLM]): The language model component used to generate the
|
||||
output.
|
||||
post_processor (Union[BaseComponent, Callable[[IO_Type], IO_Type]]): An optional
|
||||
post-processor component or function.
|
||||
|
||||
Example Usage:
|
||||
```python
|
||||
from kotaemon.llms import AzureChatOpenAI, BasePromptComponent
|
||||
|
||||
def identity(x):
|
||||
return x
|
||||
|
||||
llm = AzureChatOpenAI(
|
||||
openai_api_base="your openai api base",
|
||||
openai_api_key="your openai api key",
|
||||
openai_api_version="your openai api version",
|
||||
deployment_name="dummy-q2-gpt35",
|
||||
temperature=0,
|
||||
request_timeout=600,
|
||||
)
|
||||
|
||||
pipeline = SimpleLinearPipeline(
|
||||
prompt=BasePromptComponent(template="what is {word} in Japanese ?"),
|
||||
llm=llm,
|
||||
post_processor=identity,
|
||||
)
|
||||
print(pipeline(word="lone"))
|
||||
```
|
||||
"""
|
||||
|
||||
prompt: BasePromptComponent
|
||||
llm: Union[ChatLLM, LLM]
|
||||
post_processor: Union[BaseComponent, Callable[[IO_Type], IO_Type]]
|
||||
|
||||
def run(
|
||||
self,
|
||||
*,
|
||||
llm_kwargs: Optional[dict] = {},
|
||||
post_processor_kwargs: Optional[dict] = {},
|
||||
**prompt_kwargs,
|
||||
):
|
||||
"""
|
||||
Run the function with the given arguments and return the final output as a
|
||||
Document object.
|
||||
|
||||
Args:
|
||||
llm_kwargs (dict): Keyword arguments for the llm call.
|
||||
post_processor_kwargs (dict): Keyword arguments for the post_processor.
|
||||
**prompt_kwargs: Keyword arguments for populating the prompt.
|
||||
|
||||
Returns:
|
||||
Document: The final output of the function as a Document object.
|
||||
"""
|
||||
prompt = self.prompt(**prompt_kwargs)
|
||||
llm_output = self.llm(prompt.text, **llm_kwargs)
|
||||
if self.post_processor is not None:
|
||||
final_output = self.post_processor(llm_output, **post_processor_kwargs)[0]
|
||||
else:
|
||||
final_output = llm_output
|
||||
|
||||
return Document(final_output)
|
||||
|
||||
|
||||
class GatedLinearPipeline(SimpleLinearPipeline):
|
||||
"""
|
||||
A pipeline that extends the SimpleLinearPipeline class and adds a condition
|
||||
attribute.
|
||||
|
||||
Attributes:
|
||||
condition (Callable[[IO_Type], Any]): A callable function that represents the
|
||||
condition.
|
||||
|
||||
Usage:
|
||||
```{.py3 title="Example Usage"}
|
||||
from kotaemon.llms import AzureChatOpenAI, BasePromptComponent
|
||||
from kotaemon.parsers import RegexExtractor
|
||||
|
||||
def identity(x):
|
||||
return x
|
||||
|
||||
llm = AzureChatOpenAI(
|
||||
openai_api_base="your openai api base",
|
||||
openai_api_key="your openai api key",
|
||||
openai_api_version="your openai api version",
|
||||
deployment_name="dummy-q2-gpt35",
|
||||
temperature=0,
|
||||
request_timeout=600,
|
||||
)
|
||||
|
||||
pipeline = GatedLinearPipeline(
|
||||
prompt=BasePromptComponent(template="what is {word} in Japanese ?"),
|
||||
condition=RegexExtractor(pattern="some pattern"),
|
||||
llm=llm,
|
||||
post_processor=identity,
|
||||
)
|
||||
print(pipeline(condition_text="some pattern", word="lone"))
|
||||
print(pipeline(condition_text="other pattern", word="lone"))
|
||||
```
|
||||
"""
|
||||
|
||||
condition: Callable[[IO_Type], Any]
|
||||
|
||||
def run(
|
||||
self,
|
||||
*,
|
||||
condition_text: Optional[str] = None,
|
||||
llm_kwargs: Optional[dict] = {},
|
||||
post_processor_kwargs: Optional[dict] = {},
|
||||
**prompt_kwargs,
|
||||
) -> Document:
|
||||
"""
|
||||
Run the pipeline with the given arguments and return the final output as a
|
||||
Document object.
|
||||
|
||||
Args:
|
||||
condition_text (str): The condition text to evaluate. Default to None.
|
||||
llm_kwargs (dict): Additional keyword arguments for the language model call.
|
||||
post_processor_kwargs (dict): Additional keyword arguments for the
|
||||
post-processor.
|
||||
**prompt_kwargs: Keyword arguments for populating the prompt.
|
||||
|
||||
Returns:
|
||||
Document: The final output of the pipeline as a Document object.
|
||||
|
||||
Raises:
|
||||
ValueError: If condition_text is None
|
||||
"""
|
||||
if condition_text is None:
|
||||
raise ValueError("`condition_text` must be provided")
|
||||
|
||||
if self.condition(condition_text)[0]:
|
||||
return super().run(
|
||||
llm_kwargs=llm_kwargs,
|
||||
post_processor_kwargs=post_processor_kwargs,
|
||||
**prompt_kwargs,
|
||||
)
|
||||
|
||||
return Document(None)
|
4
libs/kotaemon/kotaemon/llms/prompts/__init__.py
Normal file
4
libs/kotaemon/kotaemon/llms/prompts/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .base import BasePromptComponent
|
||||
from .template import PromptTemplate
|
||||
|
||||
__all__ = ["BasePromptComponent", "PromptTemplate"]
|
179
libs/kotaemon/kotaemon/llms/prompts/base.py
Normal file
179
libs/kotaemon/kotaemon/llms/prompts/base.py
Normal file
@@ -0,0 +1,179 @@
|
||||
from typing import Callable, Union
|
||||
|
||||
from kotaemon.base import BaseComponent, Document
|
||||
|
||||
from .template import PromptTemplate
|
||||
|
||||
|
||||
class BasePromptComponent(BaseComponent):
|
||||
"""
|
||||
Base class for prompt components.
|
||||
|
||||
Args:
|
||||
template (PromptTemplate): The prompt template.
|
||||
**kwargs: Any additional keyword arguments that will be used to populate the
|
||||
given template.
|
||||
"""
|
||||
|
||||
class Config:
|
||||
middleware_switches = {"theflow.middleware.CachingMiddleware": False}
|
||||
allow_extra = True
|
||||
|
||||
def __init__(self, template: Union[str, PromptTemplate], **kwargs):
|
||||
super().__init__()
|
||||
self.template = (
|
||||
template
|
||||
if isinstance(template, PromptTemplate)
|
||||
else PromptTemplate(template)
|
||||
)
|
||||
|
||||
self.__set(**kwargs)
|
||||
|
||||
def __check_redundant_kwargs(self, **kwargs):
|
||||
"""
|
||||
Check for redundant keyword arguments.
|
||||
|
||||
Parameters:
|
||||
**kwargs (dict): A dictionary of keyword arguments.
|
||||
|
||||
Raises:
|
||||
ValueError: If any keys provided are not in the template.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
self.template.check_redundant_kwargs(**kwargs)
|
||||
|
||||
def __check_unset_placeholders(self):
|
||||
"""
|
||||
Check if all the placeholders in the template are set.
|
||||
|
||||
This function checks if all the expected placeholders in the template are set as
|
||||
attributes of the object. If any placeholders are missing, a `ValueError`
|
||||
is raised with the names of the missing keys.
|
||||
|
||||
Parameters:
|
||||
None
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
self.template.check_missing_kwargs(**self.__dict__)
|
||||
|
||||
def __validate_value_type(self, **kwargs):
|
||||
"""
|
||||
Validates the value types of the given keyword arguments.
|
||||
|
||||
Parameters:
|
||||
**kwargs (dict): A dictionary of keyword arguments to be validated.
|
||||
|
||||
Raises:
|
||||
ValueError: If any of the values in the kwargs dictionary have an
|
||||
unsupported type.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
type_error = []
|
||||
for k, v in kwargs.items():
|
||||
if not isinstance(v, (str, int, Document, Callable)): # type: ignore
|
||||
type_error.append((k, type(v)))
|
||||
|
||||
if type_error:
|
||||
raise ValueError(
|
||||
"Type of values must be either int, str, Document, Callable, "
|
||||
f"found unsupported type for (key, type): {type_error}"
|
||||
)
|
||||
|
||||
def __set(self, **kwargs):
|
||||
"""
|
||||
Set the values of the attributes in the object based on the provided keyword
|
||||
arguments.
|
||||
|
||||
Args:
|
||||
kwargs (dict): A dictionary with the attribute names as keys and the new
|
||||
values as values.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
self.__check_redundant_kwargs(**kwargs)
|
||||
self.__validate_value_type(**kwargs)
|
||||
|
||||
self.__dict__.update(kwargs)
|
||||
|
||||
def __prepare_value(self):
|
||||
"""
|
||||
Generate a dictionary of keyword arguments based on the template's placeholders
|
||||
and the current instance's attributes.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary of keyword arguments.
|
||||
"""
|
||||
|
||||
def __prepare(key, value):
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
if isinstance(value, (int, Document)):
|
||||
return str(value)
|
||||
|
||||
raise ValueError(
|
||||
f"Unsupported type {type(value)} for template value of key {key}"
|
||||
)
|
||||
|
||||
kwargs = {}
|
||||
for k in self.template.placeholders:
|
||||
v = getattr(self, k)
|
||||
|
||||
# if get a callable, execute to get its output
|
||||
if isinstance(v, Callable): # type: ignore[arg-type]
|
||||
v = v()
|
||||
|
||||
if isinstance(v, list):
|
||||
v = str([__prepare(k, each) for each in v])
|
||||
elif isinstance(v, (str, int, Document)):
|
||||
v = __prepare(k, v)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported type {type(v)} for template value of key `{k}`"
|
||||
)
|
||||
kwargs[k] = v
|
||||
|
||||
return kwargs
|
||||
|
||||
def set(self, **kwargs):
|
||||
"""
|
||||
Similar to `__set` but for external use.
|
||||
|
||||
Set the values of the attributes in the object based on the provided keyword
|
||||
arguments.
|
||||
|
||||
Args:
|
||||
kwargs (dict): A dictionary with the attribute names as keys and the new
|
||||
values as values.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
self.__set(**kwargs)
|
||||
|
||||
def run(self, **kwargs):
|
||||
"""
|
||||
Run the function with the given keyword arguments.
|
||||
|
||||
Args:
|
||||
**kwargs: The keyword arguments to pass to the function.
|
||||
|
||||
Returns:
|
||||
The result of calling the `populate` method of the `template` object
|
||||
with the given keyword arguments.
|
||||
"""
|
||||
self.__set(**kwargs)
|
||||
self.__check_unset_placeholders()
|
||||
prepared_kwargs = self.__prepare_value()
|
||||
|
||||
text = self.template.populate(**prepared_kwargs)
|
||||
return Document(text=text, metadata={"origin": "PromptComponent"})
|
||||
|
||||
def flow(self):
|
||||
return self.__call__()
|
140
libs/kotaemon/kotaemon/llms/prompts/template.py
Normal file
140
libs/kotaemon/kotaemon/llms/prompts/template.py
Normal file
@@ -0,0 +1,140 @@
|
||||
import warnings
|
||||
from string import Formatter
|
||||
|
||||
|
||||
class PromptTemplate:
|
||||
"""
|
||||
Base class for prompt templates.
|
||||
"""
|
||||
|
||||
def __init__(self, template: str, ignore_invalid=True):
|
||||
template = template
|
||||
formatter = Formatter()
|
||||
parsed_template = list(formatter.parse(template))
|
||||
|
||||
placeholders = set()
|
||||
for _, key, _, _ in parsed_template:
|
||||
if key is None:
|
||||
continue
|
||||
if not key.isidentifier():
|
||||
if ignore_invalid:
|
||||
warnings.warn(f"Ignore invalid placeholder: {key}.", UserWarning)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Placeholder name must be a valid Python identifier, found:"
|
||||
f" {key}."
|
||||
)
|
||||
placeholders.add(key)
|
||||
|
||||
self.template = template
|
||||
self.placeholders = placeholders
|
||||
self.__formatter = formatter
|
||||
self.__parsed_template = parsed_template
|
||||
|
||||
def check_missing_kwargs(self, **kwargs):
|
||||
"""
|
||||
Check if all the placeholders in the template are set.
|
||||
|
||||
This function checks if all the expected placeholders in the template are set as
|
||||
attributes of the object. If any placeholders are missing, a `ValueError`
|
||||
is raised with the names of the missing keys.
|
||||
|
||||
Parameters:
|
||||
None
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
missing_keys = self.placeholders.difference(kwargs.keys())
|
||||
if missing_keys:
|
||||
raise ValueError(f"Missing keys in template: {','.join(missing_keys)}")
|
||||
|
||||
def check_redundant_kwargs(self, **kwargs):
|
||||
"""
|
||||
Check if all the placeholders in the template are set.
|
||||
|
||||
This function checks if all the expected placeholders in the template are set as
|
||||
attributes of the object. If any placeholders are missing, a `ValueError`
|
||||
is raised with the names of the missing keys.
|
||||
|
||||
Parameters:
|
||||
None
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
provided_keys = set(kwargs.keys())
|
||||
redundant_keys = provided_keys - self.placeholders
|
||||
|
||||
if redundant_keys:
|
||||
warnings.warn(
|
||||
f"Keys provided but not in template: {','.join(redundant_keys)}",
|
||||
UserWarning,
|
||||
)
|
||||
|
||||
def populate(self, **kwargs) -> str:
|
||||
"""
|
||||
Strictly populate the template with the given keyword arguments.
|
||||
|
||||
Args:
|
||||
**kwargs: The keyword arguments to populate the template.
|
||||
Each keyword corresponds to a placeholder in the template.
|
||||
|
||||
Returns:
|
||||
The populated template.
|
||||
|
||||
Raises:
|
||||
ValueError: If an unknown placeholder is provided.
|
||||
"""
|
||||
self.check_missing_kwargs(**kwargs)
|
||||
|
||||
return self.partial_populate(**kwargs)
|
||||
|
||||
def partial_populate(self, **kwargs):
|
||||
"""
|
||||
Partially populate the template with the given keyword arguments.
|
||||
|
||||
Args:
|
||||
**kwargs: The keyword arguments to populate the template.
|
||||
Each keyword corresponds to a placeholder in the template.
|
||||
|
||||
Returns:
|
||||
str: The populated template.
|
||||
"""
|
||||
self.check_redundant_kwargs(**kwargs)
|
||||
|
||||
prompt = []
|
||||
for literal_text, field_name, format_spec, conversion in self.__parsed_template:
|
||||
prompt.append(literal_text)
|
||||
|
||||
if field_name is None:
|
||||
continue
|
||||
|
||||
if field_name not in kwargs:
|
||||
if conversion:
|
||||
value = f"{{{field_name}}}!{conversion}:{format_spec}"
|
||||
else:
|
||||
value = f"{{{field_name}:{format_spec}}}"
|
||||
else:
|
||||
value = kwargs[field_name]
|
||||
if conversion is not None:
|
||||
value = self.__formatter.convert_field(value, conversion)
|
||||
if format_spec is not None:
|
||||
value = self.__formatter.format_field(value, format_spec)
|
||||
|
||||
prompt.append(value)
|
||||
|
||||
return "".join(prompt)
|
||||
|
||||
def __add__(self, other):
|
||||
"""
|
||||
Create a new PromptTemplate object by concatenating the template of the current
|
||||
object with the template of another PromptTemplate object.
|
||||
|
||||
Parameters:
|
||||
other (PromptTemplate): Another PromptTemplate object.
|
||||
|
||||
Returns:
|
||||
PromptTemplate: A new PromptTemplate object with the concatenated templates.
|
||||
"""
|
||||
return PromptTemplate(self.template + "\n" + other.template)
|
14
libs/kotaemon/kotaemon/loaders/__init__.py
Normal file
14
libs/kotaemon/kotaemon/loaders/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from .base import AutoReader, DirectoryReader
|
||||
from .excel_loader import PandasExcelReader
|
||||
from .mathpix_loader import MathpixPDFReader
|
||||
from .ocr_loader import OCRReader
|
||||
from .unstructured_loader import UnstructuredReader
|
||||
|
||||
__all__ = [
|
||||
"AutoReader",
|
||||
"PandasExcelReader",
|
||||
"MathpixPDFReader",
|
||||
"OCRReader",
|
||||
"DirectoryReader",
|
||||
"UnstructuredReader",
|
||||
]
|
65
libs/kotaemon/kotaemon/loaders/base.py
Normal file
65
libs/kotaemon/kotaemon/loaders/base.py
Normal file
@@ -0,0 +1,65 @@
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Type, Union
|
||||
|
||||
from kotaemon.base import BaseComponent, Document
|
||||
from llama_index import SimpleDirectoryReader, download_loader
|
||||
from llama_index.readers.base import BaseReader
|
||||
|
||||
|
||||
class AutoReader(BaseComponent):
|
||||
"""General auto reader for a variety of files. (based on llama-hub)"""
|
||||
|
||||
def __init__(self, reader_type: Union[str, Type[BaseReader]]) -> None:
|
||||
"""Init reader using string identifier or class name from llama-hub"""
|
||||
|
||||
if isinstance(reader_type, str):
|
||||
self._reader = download_loader(reader_type)()
|
||||
else:
|
||||
self._reader = reader_type()
|
||||
super().__init__()
|
||||
|
||||
def load_data(self, file: Union[Path, str], **kwargs: Any) -> List[Document]:
|
||||
documents = self._reader.load_data(file=file, **kwargs)
|
||||
|
||||
# convert Document to new base class from kotaemon
|
||||
converted_documents = [Document.from_dict(doc.to_dict()) for doc in documents]
|
||||
return converted_documents
|
||||
|
||||
def run(self, file: Union[Path, str], **kwargs: Any) -> List[Document]:
|
||||
return self.load_data(file=file, **kwargs)
|
||||
|
||||
|
||||
class LIBaseReader(BaseComponent):
|
||||
_reader_class: Type[BaseReader]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
if self._reader_class is None:
|
||||
raise AttributeError(
|
||||
"Require `_reader_class` to set a BaseReader class from LlamarIndex"
|
||||
)
|
||||
|
||||
self._reader = self._reader_class(*args, **kwargs)
|
||||
super().__init__()
|
||||
|
||||
def __setattr__(self, name: str, value: Any) -> None:
|
||||
if name.startswith("_"):
|
||||
return super().__setattr__(name, value)
|
||||
|
||||
return setattr(self._reader, name, value)
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
return getattr(self._reader, name)
|
||||
|
||||
def load_data(self, *args, **kwargs: Any) -> List[Document]:
|
||||
documents = self._reader.load_data(*args, **kwargs)
|
||||
|
||||
# convert Document to new base class from kotaemon
|
||||
converted_documents = [Document.from_dict(doc.to_dict()) for doc in documents]
|
||||
return converted_documents
|
||||
|
||||
def run(self, *args, **kwargs: Any) -> List[Document]:
|
||||
return self.load_data(*args, **kwargs)
|
||||
|
||||
|
||||
class DirectoryReader(LIBaseReader):
|
||||
_reader_class = SimpleDirectoryReader
|
99
libs/kotaemon/kotaemon/loaders/excel_loader.py
Normal file
99
libs/kotaemon/kotaemon/loaders/excel_loader.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""Pandas Excel reader.
|
||||
|
||||
Pandas parser for .xlsx files.
|
||||
|
||||
"""
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
from kotaemon.base import Document
|
||||
from llama_index.readers.base import BaseReader
|
||||
|
||||
|
||||
class PandasExcelReader(BaseReader):
|
||||
r"""Pandas-based CSV parser.
|
||||
|
||||
Parses CSVs using the separator detection from Pandas `read_csv` function.
|
||||
If special parameters are required, use the `pandas_config` dict.
|
||||
|
||||
Args:
|
||||
|
||||
pandas_config (dict): Options for the `pandas.read_excel` function call.
|
||||
Refer to https://pandas.pydata.org/docs/reference/api/pandas.read_excel.html
|
||||
for more information. Set to empty dict by default,
|
||||
this means defaults will be used.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args: Any,
|
||||
pandas_config: Optional[dict] = None,
|
||||
row_joiner: str = "\n",
|
||||
col_joiner: str = " ",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Init params."""
|
||||
super().__init__(*args, **kwargs)
|
||||
self._pandas_config = pandas_config or {}
|
||||
self._row_joiner = row_joiner if row_joiner else "\n"
|
||||
self._col_joiner = col_joiner if col_joiner else " "
|
||||
|
||||
def load_data(
|
||||
self,
|
||||
file: Path,
|
||||
include_sheetname: bool = False,
|
||||
sheet_name: Optional[Union[str, int, list]] = None,
|
||||
**kwargs,
|
||||
) -> List[Document]:
|
||||
"""Parse file and extract values from a specific column.
|
||||
|
||||
Args:
|
||||
file (Path): The path to the Excel file to read.
|
||||
include_sheetname (bool): Whether to include the sheet name in the output.
|
||||
sheet_name (Union[str, int, None]): The specific sheet to read from,
|
||||
default is None which reads all sheets.
|
||||
|
||||
Returns:
|
||||
List[Document]: A list of`Document objects containing the
|
||||
values from the specified column in the Excel file.
|
||||
"""
|
||||
import itertools
|
||||
|
||||
try:
|
||||
import pandas as pd
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"install pandas using `pip3 install pandas` to use this loader"
|
||||
)
|
||||
|
||||
if sheet_name is not None:
|
||||
sheet_name = (
|
||||
[sheet_name] if not isinstance(sheet_name, list) else sheet_name
|
||||
)
|
||||
|
||||
dfs = pd.read_excel(file, sheet_name=sheet_name, **self._pandas_config)
|
||||
sheet_names = dfs.keys()
|
||||
df_sheets = []
|
||||
|
||||
for key in sheet_names:
|
||||
sheet = []
|
||||
if include_sheetname:
|
||||
sheet.append([key])
|
||||
sheet.extend(dfs[key].values.astype(str).tolist())
|
||||
df_sheets.append(sheet)
|
||||
|
||||
text_list = list(
|
||||
itertools.chain.from_iterable(df_sheets)
|
||||
) # flatten list of lists
|
||||
|
||||
output = [
|
||||
Document(
|
||||
text=self._row_joiner.join(
|
||||
self._col_joiner.join(sublist) for sublist in text_list
|
||||
),
|
||||
metadata={"source": file.stem},
|
||||
)
|
||||
]
|
||||
|
||||
return output
|
174
libs/kotaemon/kotaemon/loaders/mathpix_loader.py
Normal file
174
libs/kotaemon/kotaemon/loaders/mathpix_loader.py
Normal file
@@ -0,0 +1,174 @@
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import requests
|
||||
from kotaemon.base import Document
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
from llama_index.readers.base import BaseReader
|
||||
|
||||
from .utils.table import parse_markdown_text_to_tables, strip_special_chars_markdown
|
||||
|
||||
|
||||
# MathpixPDFLoader implementation taken largely from Daniel Gross's:
|
||||
# https://gist.github.com/danielgross/3ab4104e14faccc12b49200843adab21
|
||||
class MathpixPDFReader(BaseReader):
|
||||
"""Load `PDF` files using `Mathpix` service."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
processed_file_format: str = "md",
|
||||
max_wait_time_seconds: int = 500,
|
||||
should_clean_pdf: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize with a file path.
|
||||
|
||||
Args:
|
||||
processed_file_format: a format of the processed file. Default is "mmd".
|
||||
max_wait_time_seconds: a maximum time to wait for the response from
|
||||
the server. Default is 500.
|
||||
should_clean_pdf: a flag to clean the PDF file. Default is False.
|
||||
**kwargs: additional keyword arguments.
|
||||
"""
|
||||
self.mathpix_api_key = get_from_dict_or_env(
|
||||
kwargs, "mathpix_api_key", "MATHPIX_API_KEY", default="empty"
|
||||
)
|
||||
self.mathpix_api_id = get_from_dict_or_env(
|
||||
kwargs, "mathpix_api_id", "MATHPIX_API_ID", default="empty"
|
||||
)
|
||||
self.processed_file_format = processed_file_format
|
||||
self.max_wait_time_seconds = max_wait_time_seconds
|
||||
self.should_clean_pdf = should_clean_pdf
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def _mathpix_headers(self) -> Dict[str, str]:
|
||||
return {"app_id": self.mathpix_api_id, "app_key": self.mathpix_api_key}
|
||||
|
||||
@property
|
||||
def url(self) -> str:
|
||||
return "https://api.mathpix.com/v3/pdf"
|
||||
|
||||
@property
|
||||
def data(self) -> dict:
|
||||
options = {
|
||||
"conversion_formats": {self.processed_file_format: True},
|
||||
"enable_tables_fallback": True,
|
||||
}
|
||||
return {"options_json": json.dumps(options)}
|
||||
|
||||
def send_pdf(self, file_path) -> str:
|
||||
with open(file_path, "rb") as f:
|
||||
files = {"file": f}
|
||||
response = requests.post(
|
||||
self.url, headers=self._mathpix_headers, files=files, data=self.data
|
||||
)
|
||||
response_data = response.json()
|
||||
if "pdf_id" in response_data:
|
||||
pdf_id = response_data["pdf_id"]
|
||||
return pdf_id
|
||||
else:
|
||||
raise ValueError("Unable to send PDF to Mathpix.")
|
||||
|
||||
def wait_for_processing(self, pdf_id: str) -> None:
|
||||
"""Wait for processing to complete.
|
||||
|
||||
Args:
|
||||
pdf_id: a PDF id.
|
||||
|
||||
Returns: None
|
||||
"""
|
||||
url = self.url + "/" + pdf_id
|
||||
for _ in range(0, self.max_wait_time_seconds, 5):
|
||||
response = requests.get(url, headers=self._mathpix_headers)
|
||||
response_data = response.json()
|
||||
status = response_data.get("status", None)
|
||||
|
||||
if status == "completed":
|
||||
return
|
||||
elif status == "error":
|
||||
raise ValueError("Unable to retrieve PDF from Mathpix")
|
||||
else:
|
||||
print(response_data)
|
||||
print(url)
|
||||
time.sleep(5)
|
||||
raise TimeoutError
|
||||
|
||||
def get_processed_pdf(self, pdf_id: str) -> str:
|
||||
self.wait_for_processing(pdf_id)
|
||||
url = f"{self.url}/{pdf_id}.{self.processed_file_format}"
|
||||
response = requests.get(url, headers=self._mathpix_headers)
|
||||
return response.content.decode("utf-8")
|
||||
|
||||
def clean_pdf(self, contents: str) -> str:
|
||||
"""Clean the PDF file.
|
||||
|
||||
Args:
|
||||
contents: a PDF file contents.
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
contents = "\n".join(
|
||||
[line for line in contents.split("\n") if not line.startswith("![]")]
|
||||
)
|
||||
# replace \section{Title} with # Title
|
||||
contents = contents.replace("\\section{", "# ")
|
||||
# replace the "\" slash that Mathpix adds to escape $, %, (, etc.
|
||||
|
||||
# http:// or https:// followed by anything but a closing paren
|
||||
url_regex = "http[s]?://[^)]+"
|
||||
markup_regex = r"\[]\(\s*({0})\s*\)".format(url_regex)
|
||||
contents = (
|
||||
contents.replace(r"\$", "$")
|
||||
.replace(r"\%", "%")
|
||||
.replace(r"\(", "(")
|
||||
.replace(r"\)", ")")
|
||||
.replace("$\\begin{array}", "")
|
||||
.replace("\\end{array}$", "")
|
||||
.replace("\\\\", "")
|
||||
.replace("\\text", "")
|
||||
.replace("}", "")
|
||||
.replace("{", "")
|
||||
.replace("\\mathrm", "")
|
||||
)
|
||||
contents = re.sub(markup_regex, "", contents)
|
||||
return contents
|
||||
|
||||
def load_data(self, file_path: Path, **kwargs) -> List[Document]:
|
||||
if "response_content" in kwargs:
|
||||
# overriding response content if specified
|
||||
content = kwargs["response_content"]
|
||||
else:
|
||||
# call original API
|
||||
pdf_id = self.send_pdf(file_path)
|
||||
content = self.get_processed_pdf(pdf_id)
|
||||
|
||||
if self.should_clean_pdf:
|
||||
content = self.clean_pdf(content)
|
||||
tables, texts = parse_markdown_text_to_tables(content)
|
||||
documents = []
|
||||
for table in tables:
|
||||
text = strip_special_chars_markdown(table)
|
||||
metadata = {
|
||||
"source": file_path.name,
|
||||
"table_origin": table,
|
||||
"type": "table",
|
||||
}
|
||||
documents.append(
|
||||
Document(
|
||||
text=text,
|
||||
metadata=metadata,
|
||||
metadata_template="",
|
||||
metadata_seperator="",
|
||||
)
|
||||
)
|
||||
|
||||
for text in texts:
|
||||
metadata = {"source": file_path.name, "type": "text"}
|
||||
documents.append(Document(text=text, metadata=metadata))
|
||||
|
||||
return documents
|
102
libs/kotaemon/kotaemon/loaders/ocr_loader.py
Normal file
102
libs/kotaemon/kotaemon/loaders/ocr_loader.py
Normal file
@@ -0,0 +1,102 @@
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from uuid import uuid4
|
||||
|
||||
import requests
|
||||
from kotaemon.base import Document
|
||||
from llama_index.readers.base import BaseReader
|
||||
|
||||
from .utils.pdf_ocr import parse_ocr_output, read_pdf_unstructured
|
||||
from .utils.table import strip_special_chars_markdown
|
||||
|
||||
DEFAULT_OCR_ENDPOINT = "http://127.0.0.1:8000/v2/ai/infer/"
|
||||
|
||||
|
||||
class OCRReader(BaseReader):
|
||||
def __init__(self, endpoint: str = DEFAULT_OCR_ENDPOINT, use_ocr=True):
|
||||
"""Init the OCR reader with OCR endpoint (FullOCR pipeline)
|
||||
|
||||
Args:
|
||||
endpoint: URL to FullOCR endpoint. Defaults to OCR_ENDPOINT.
|
||||
use_ocr: whether to use OCR to read text
|
||||
(e.g: from images, tables) in the PDF
|
||||
"""
|
||||
super().__init__()
|
||||
self.ocr_endpoint = endpoint
|
||||
self.use_ocr = use_ocr
|
||||
|
||||
def load_data(self, file_path: Path, **kwargs) -> List[Document]:
|
||||
"""Load data using OCR reader
|
||||
|
||||
Args:
|
||||
file_path (Path): Path to PDF file
|
||||
debug_path (Path): Path to store debug image output
|
||||
artifact_path (Path): Path to OCR endpoints artifacts directory
|
||||
|
||||
Returns:
|
||||
List[Document]: list of documents extracted from the PDF file
|
||||
"""
|
||||
file_path = Path(file_path).resolve()
|
||||
|
||||
with file_path.open("rb") as content:
|
||||
files = {"input": content}
|
||||
data = {"job_id": uuid4(), "table_only": not self.use_ocr}
|
||||
|
||||
# call the API from FullOCR endpoint
|
||||
if "response_content" in kwargs:
|
||||
# overriding response content if specified
|
||||
ocr_results = kwargs["response_content"]
|
||||
else:
|
||||
# call original API
|
||||
resp = requests.post(url=self.ocr_endpoint, files=files, data=data)
|
||||
ocr_results = resp.json()["result"]
|
||||
|
||||
debug_path = kwargs.pop("debug_path", None)
|
||||
artifact_path = kwargs.pop("artifact_path", None)
|
||||
|
||||
# read PDF through normal reader (unstructured)
|
||||
pdf_page_items = read_pdf_unstructured(file_path)
|
||||
# merge PDF text output with OCR output
|
||||
tables, texts = parse_ocr_output(
|
||||
ocr_results,
|
||||
pdf_page_items,
|
||||
debug_path=debug_path,
|
||||
artifact_path=artifact_path,
|
||||
)
|
||||
|
||||
# create output Document with metadata from table
|
||||
documents = [
|
||||
Document(
|
||||
text=strip_special_chars_markdown(table_text),
|
||||
metadata={
|
||||
"table_origin": table_text,
|
||||
"type": "table",
|
||||
"page_label": page_id + 1,
|
||||
"source": file_path.name,
|
||||
"file_path": str(file_path),
|
||||
"file_name": file_path.name,
|
||||
"filename": str(file_path),
|
||||
},
|
||||
metadata_template="",
|
||||
metadata_seperator="",
|
||||
)
|
||||
for page_id, table_text in tables
|
||||
]
|
||||
# create Document from non-table text
|
||||
documents.extend(
|
||||
[
|
||||
Document(
|
||||
text=non_table_text,
|
||||
metadata={
|
||||
"page_label": page_id + 1,
|
||||
"source": file_path.name,
|
||||
"file_path": str(file_path),
|
||||
"file_name": file_path.name,
|
||||
"filename": str(file_path),
|
||||
},
|
||||
)
|
||||
for page_id, non_table_text in texts
|
||||
]
|
||||
)
|
||||
|
||||
return documents
|
110
libs/kotaemon/kotaemon/loaders/unstructured_loader.py
Normal file
110
libs/kotaemon/kotaemon/loaders/unstructured_loader.py
Normal file
@@ -0,0 +1,110 @@
|
||||
"""Unstructured file reader.
|
||||
|
||||
A parser for unstructured text files using Unstructured.io.
|
||||
Supports .txt, .docx, .pptx, .jpg, .png, .eml, .html, and .pdf documents.
|
||||
|
||||
To use .doc and .xls parser, install
|
||||
|
||||
sudo apt-get install -y libmagic-dev poppler-utils libreoffice
|
||||
pip install xlrd
|
||||
|
||||
"""
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from kotaemon.base import Document
|
||||
from llama_index.readers.base import BaseReader
|
||||
|
||||
|
||||
class UnstructuredReader(BaseReader):
|
||||
"""General unstructured text reader for a variety of files."""
|
||||
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
"""Init params."""
|
||||
super().__init__(*args) # not passing kwargs to parent bc it cannot accept it
|
||||
|
||||
self.api = False # we default to local
|
||||
if "url" in kwargs:
|
||||
self.server_url = str(kwargs["url"])
|
||||
self.api = True # is url was set, switch to api
|
||||
else:
|
||||
self.server_url = "http://localhost:8000"
|
||||
|
||||
if "api" in kwargs:
|
||||
self.api = kwargs["api"]
|
||||
|
||||
self.api_key = ""
|
||||
if "api_key" in kwargs:
|
||||
self.api_key = kwargs["api_key"]
|
||||
|
||||
""" Loads data using Unstructured.io
|
||||
|
||||
Depending on the construction if url is set or api = True
|
||||
it'll parse file using API call, else parse it locally
|
||||
additional_metadata is extended by the returned metadata if
|
||||
split_documents is True
|
||||
|
||||
Returns list of documents
|
||||
"""
|
||||
|
||||
def load_data(
|
||||
self,
|
||||
file: Path,
|
||||
additional_metadata: Optional[Dict] = None,
|
||||
split_documents: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> List[Document]:
|
||||
"""If api is set, parse through api"""
|
||||
file_path_str = str(file)
|
||||
if self.api:
|
||||
from unstructured.partition.api import partition_via_api
|
||||
|
||||
elements = partition_via_api(
|
||||
filename=file_path_str,
|
||||
api_key=self.api_key,
|
||||
api_url=self.server_url + "/general/v0/general",
|
||||
)
|
||||
else:
|
||||
"""Parse file locally"""
|
||||
from unstructured.partition.auto import partition
|
||||
|
||||
elements = partition(filename=file_path_str)
|
||||
|
||||
""" Process elements """
|
||||
docs = []
|
||||
file_name = Path(file).name
|
||||
file_path = str(Path(file).resolve())
|
||||
if split_documents:
|
||||
for node in elements:
|
||||
metadata = {"file_name": file_name, "file_path": file_path}
|
||||
if hasattr(node, "metadata"):
|
||||
"""Load metadata fields"""
|
||||
for field, val in vars(node.metadata).items():
|
||||
if field == "_known_field_names":
|
||||
continue
|
||||
# removing coordinates because it does not serialize
|
||||
# and dont want to bother with it
|
||||
if field == "coordinates":
|
||||
continue
|
||||
# removing bc it might cause interference
|
||||
if field == "parent_id":
|
||||
continue
|
||||
metadata[field] = val
|
||||
|
||||
if additional_metadata is not None:
|
||||
metadata.update(additional_metadata)
|
||||
|
||||
metadata["file_name"] = file_name
|
||||
docs.append(Document(text=node.text, metadata=metadata))
|
||||
|
||||
else:
|
||||
text_chunks = [" ".join(str(el).split()) for el in elements]
|
||||
metadata = {"file_name": file_name, "file_path": file_path}
|
||||
|
||||
if additional_metadata is not None:
|
||||
metadata.update(additional_metadata)
|
||||
|
||||
# Create a single document by joining all the texts
|
||||
docs.append(Document(text="\n\n".join(text_chunks), metadata=metadata))
|
||||
|
||||
return docs
|
0
libs/kotaemon/kotaemon/loaders/utils/__init__.py
Normal file
0
libs/kotaemon/kotaemon/loaders/utils/__init__.py
Normal file
144
libs/kotaemon/kotaemon/loaders/utils/box.py
Normal file
144
libs/kotaemon/kotaemon/loaders/utils/box.py
Normal file
@@ -0,0 +1,144 @@
|
||||
from typing import List, Tuple
|
||||
|
||||
|
||||
def bbox_to_points(box: List[int]):
|
||||
"""Convert bounding box to list of points"""
|
||||
x1, y1, x2, y2 = box
|
||||
return [(x1, y1), (x2, y1), (x2, y2), (x1, y2)]
|
||||
|
||||
|
||||
def points_to_bbox(points: List[Tuple[int, int]]):
|
||||
"""Convert list of points to bounding box"""
|
||||
all_x = [p[0] for p in points]
|
||||
all_y = [p[1] for p in points]
|
||||
return [min(all_x), min(all_y), max(all_x), max(all_y)]
|
||||
|
||||
|
||||
def scale_points(points: List[Tuple[int, int]], scale_factor: float = 1.0):
|
||||
"""Scale points by a scale factor"""
|
||||
return [(int(pos[0] * scale_factor), int(pos[1] * scale_factor)) for pos in points]
|
||||
|
||||
|
||||
def union_points(points: List[Tuple[int, int]]):
|
||||
"""Return union bounding box of list of points"""
|
||||
all_x = [p[0] for p in points]
|
||||
all_y = [p[1] for p in points]
|
||||
bbox = (min(all_x), min(all_y), max(all_x), max(all_y))
|
||||
return bbox
|
||||
|
||||
|
||||
def scale_box(box: List[int], scale_factor: float = 1.0):
|
||||
"""Scale box by a scale factor"""
|
||||
return [int(pos * scale_factor) for pos in box]
|
||||
|
||||
|
||||
def box_h(box: List[int]):
|
||||
"Return box height"
|
||||
return box[3] - box[1]
|
||||
|
||||
|
||||
def box_w(box: List[int]):
|
||||
"Return box width"
|
||||
return box[2] - box[0]
|
||||
|
||||
|
||||
def box_area(box: List[int]):
|
||||
"Return box area"
|
||||
x1, y1, x2, y2 = box
|
||||
return (x2 - x1) * (y2 - y1)
|
||||
|
||||
|
||||
def get_rect_iou(gt_box: List[tuple], pd_box: List[tuple], iou_type=0) -> int:
|
||||
"""Intersection over union on layout rectangle
|
||||
|
||||
Args:
|
||||
gt_box: List[tuple]
|
||||
A list contains bounding box coordinates of ground truth
|
||||
pd_box: List[tuple]
|
||||
A list contains bounding box coordinates of prediction
|
||||
iou_type: int
|
||||
0: intersection / union, normal IOU
|
||||
1: intersection / min(areas), useful when boxes are under/over-segmented
|
||||
|
||||
Input format: [(x1, y1), (x2, y1), (x2, y2), (x1, y2)]
|
||||
Annotation for each element in bbox:
|
||||
(x1, y1) (x2, y1)
|
||||
+-------+
|
||||
| |
|
||||
| |
|
||||
+-------+
|
||||
(x1, y2) (x2, y2)
|
||||
|
||||
Returns:
|
||||
Intersection over union value
|
||||
"""
|
||||
|
||||
assert iou_type in [0, 1], "Only support 0: origin iou, 1: intersection / min(area)"
|
||||
|
||||
# determine the (x, y)-coordinates of the intersection rectangle
|
||||
# gt_box: [(x1, y1), (x2, y1), (x2, y2), (x1, y2)]
|
||||
# pd_box: [(x1, y1), (x2, y1), (x2, y2), (x1, y2)]
|
||||
x_left = max(gt_box[0][0], pd_box[0][0])
|
||||
y_top = max(gt_box[0][1], pd_box[0][1])
|
||||
x_right = min(gt_box[2][0], pd_box[2][0])
|
||||
y_bottom = min(gt_box[2][1], pd_box[2][1])
|
||||
|
||||
# compute the area of intersection rectangle
|
||||
interArea = max(0, x_right - x_left) * max(0, y_bottom - y_top)
|
||||
|
||||
# compute the area of both the prediction and ground-truth
|
||||
# rectangles
|
||||
gt_area = (gt_box[2][0] - gt_box[0][0]) * (gt_box[2][1] - gt_box[0][1])
|
||||
pd_area = (pd_box[2][0] - pd_box[0][0]) * (pd_box[2][1] - pd_box[0][1])
|
||||
|
||||
# compute the intersection over union by taking the intersection
|
||||
# area and dividing it by the sum of prediction + ground-truth
|
||||
# areas - the intersection area
|
||||
if iou_type == 0:
|
||||
iou = interArea / float(gt_area + pd_area - interArea)
|
||||
elif iou_type == 1:
|
||||
iou = interArea / max(min(gt_area, pd_area), 1)
|
||||
|
||||
# return the intersection over union value
|
||||
return iou
|
||||
|
||||
|
||||
def sort_funsd_reading_order(lines: List[dict], box_key_name: str = "box"):
|
||||
"""Sort cell list to create the right reading order using their locations
|
||||
|
||||
Args:
|
||||
lines: list of cells to sort
|
||||
|
||||
Returns:
|
||||
a list of cell lists in the right reading order that contain
|
||||
no key or start with a key and contain no other key
|
||||
"""
|
||||
sorted_list = []
|
||||
|
||||
if len(lines) == 0:
|
||||
return lines
|
||||
|
||||
while len(lines) > 1:
|
||||
topleft_line = lines[0]
|
||||
for line in lines[1:]:
|
||||
topleft_line_pos = topleft_line[box_key_name]
|
||||
topleft_line_center_y = (topleft_line_pos[1] + topleft_line_pos[3]) / 2
|
||||
x1, y1, x2, y2 = line[box_key_name]
|
||||
box_center_x = (x1 + x2) / 2
|
||||
box_center_y = (y1 + y2) / 2
|
||||
cell_h = y2 - y1
|
||||
if box_center_y <= topleft_line_center_y - cell_h / 2:
|
||||
topleft_line = line
|
||||
continue
|
||||
if (
|
||||
box_center_x < topleft_line_pos[2]
|
||||
and box_center_y < topleft_line_pos[3]
|
||||
):
|
||||
topleft_line = line
|
||||
continue
|
||||
sorted_list.append(topleft_line)
|
||||
lines.remove(topleft_line)
|
||||
|
||||
sorted_list.append(lines[0])
|
||||
|
||||
return sorted_list
|
294
libs/kotaemon/kotaemon/loaders/utils/pdf_ocr.py
Normal file
294
libs/kotaemon/kotaemon/loaders/utils/pdf_ocr.py
Normal file
@@ -0,0 +1,294 @@
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from .box import (
|
||||
bbox_to_points,
|
||||
box_area,
|
||||
box_h,
|
||||
box_w,
|
||||
get_rect_iou,
|
||||
points_to_bbox,
|
||||
scale_box,
|
||||
scale_points,
|
||||
sort_funsd_reading_order,
|
||||
union_points,
|
||||
)
|
||||
from .table import table_cells_to_markdown
|
||||
|
||||
IOU_THRES = 0.5
|
||||
PADDING_THRES = 1.1
|
||||
|
||||
|
||||
def read_pdf_unstructured(input_path: Union[Path, str]):
|
||||
"""Convert PDF from specified path to list of text items with
|
||||
location information
|
||||
|
||||
Args:
|
||||
input_path: path to input file
|
||||
|
||||
Returns:
|
||||
Dict page_number: list of text boxes
|
||||
"""
|
||||
try:
|
||||
from unstructured.partition.auto import partition
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Please install unstructured PDF reader `pip install unstructured[pdf]`"
|
||||
)
|
||||
|
||||
page_items = defaultdict(list)
|
||||
items = partition(input_path)
|
||||
for item in items:
|
||||
page_number = item.metadata.page_number
|
||||
bbox = points_to_bbox(item.metadata.coordinates.points)
|
||||
coord_system = item.metadata.coordinates.system
|
||||
max_w, max_h = coord_system.width, coord_system.height
|
||||
page_items[page_number - 1].append(
|
||||
{
|
||||
"text": item.text,
|
||||
"box": bbox,
|
||||
"location": bbox_to_points(bbox),
|
||||
"page_shape": (max_w, max_h),
|
||||
}
|
||||
)
|
||||
|
||||
return page_items
|
||||
|
||||
|
||||
def merge_ocr_and_pdf_texts(
|
||||
ocr_list: List[dict], pdf_text_list: List[dict], debug_info=None
|
||||
):
|
||||
"""Merge PDF and OCR text using IOU overlapping location
|
||||
Args:
|
||||
ocr_list: List of OCR items {"text", "box", "location"}
|
||||
pdf_text_list: List of PDF items {"text", "box", "location"}
|
||||
|
||||
Returns:
|
||||
Combined list of PDF text and non-overlap OCR text
|
||||
"""
|
||||
not_matched_ocr = []
|
||||
|
||||
# check for debug info
|
||||
if debug_info is not None:
|
||||
cv2, debug_im = debug_info
|
||||
|
||||
for ocr_item in ocr_list:
|
||||
matched = False
|
||||
for pdf_item in pdf_text_list:
|
||||
if (
|
||||
get_rect_iou(ocr_item["location"], pdf_item["location"], iou_type=1)
|
||||
> IOU_THRES
|
||||
):
|
||||
matched = True
|
||||
break
|
||||
|
||||
color = (255, 0, 0)
|
||||
if not matched:
|
||||
ocr_item["matched"] = False
|
||||
not_matched_ocr.append(ocr_item)
|
||||
color = (0, 255, 255)
|
||||
|
||||
if debug_info is not None:
|
||||
cv2.rectangle(
|
||||
debug_im,
|
||||
ocr_item["location"][0],
|
||||
ocr_item["location"][2],
|
||||
color=color,
|
||||
thickness=1,
|
||||
)
|
||||
|
||||
if debug_info is not None:
|
||||
for pdf_item in pdf_text_list:
|
||||
cv2.rectangle(
|
||||
debug_im,
|
||||
pdf_item["location"][0],
|
||||
pdf_item["location"][2],
|
||||
color=(0, 255, 0),
|
||||
thickness=2,
|
||||
)
|
||||
|
||||
return pdf_text_list + not_matched_ocr
|
||||
|
||||
|
||||
def merge_table_cell_and_ocr(
|
||||
table_list: List[dict], ocr_list: List[dict], pdf_list: List[dict], debug_info=None
|
||||
):
|
||||
"""Merge table items with OCR text using IOU overlapping location
|
||||
Args:
|
||||
table_list: List of table items
|
||||
"type": ("table", "cell", "text"), "text", "box", "location"}
|
||||
ocr_list: List of OCR items {"text", "box", "location"}
|
||||
pdf_list: List of PDF items {"text", "box", "location"}
|
||||
|
||||
Returns:
|
||||
all_table_cells: List of tables, each of table is represented
|
||||
by list of cells with combined text from OCR
|
||||
not_matched_items: List of PDF text which is not overlapped by table region
|
||||
"""
|
||||
# check for debug info
|
||||
if debug_info is not None:
|
||||
cv2, debug_im = debug_info
|
||||
|
||||
cell_list = [item for item in table_list if item["type"] == "cell"]
|
||||
table_list = [item for item in table_list if item["type"] == "table"]
|
||||
|
||||
# sort table by area
|
||||
table_list = sorted(table_list, key=lambda item: box_area(item["bbox"]))
|
||||
|
||||
all_tables = []
|
||||
matched_pdf_ids = []
|
||||
matched_cell_ids = []
|
||||
|
||||
for table in table_list:
|
||||
if debug_info is not None:
|
||||
cv2.rectangle(
|
||||
debug_im,
|
||||
table["location"][0],
|
||||
table["location"][2],
|
||||
color=[0, 0, 255],
|
||||
thickness=5,
|
||||
)
|
||||
|
||||
cur_table_cells = []
|
||||
for cell_id, cell in enumerate(cell_list):
|
||||
if cell_id in matched_cell_ids:
|
||||
continue
|
||||
|
||||
if get_rect_iou(
|
||||
table["location"], cell["location"], iou_type=1
|
||||
) > IOU_THRES and box_area(table["bbox"]) > box_area(cell["bbox"]):
|
||||
color = [128, 0, 128]
|
||||
# cell matched to table
|
||||
for item_list, item_type in [(pdf_list, "pdf"), (ocr_list, "ocr")]:
|
||||
cell["ocr"] = []
|
||||
for item_id, item in enumerate(item_list):
|
||||
if item_type == "pdf" and item_id in matched_pdf_ids:
|
||||
continue
|
||||
if (
|
||||
get_rect_iou(item["location"], cell["location"], iou_type=1)
|
||||
> IOU_THRES
|
||||
):
|
||||
cell["ocr"].append(item)
|
||||
if item_type == "pdf":
|
||||
matched_pdf_ids.append(item_id)
|
||||
|
||||
if len(cell["ocr"]) > 0:
|
||||
# check if union of matched ocr does
|
||||
# not extend over cell boundary,
|
||||
# if True, continue to use OCR_list to match
|
||||
all_box_points_in_cell = []
|
||||
for item in cell["ocr"]:
|
||||
all_box_points_in_cell.extend(item["location"])
|
||||
union_box = union_points(all_box_points_in_cell)
|
||||
cell_okay = (
|
||||
box_h(union_box) <= box_h(cell["bbox"]) * PADDING_THRES
|
||||
and box_w(union_box) <= box_w(cell["bbox"]) * PADDING_THRES
|
||||
)
|
||||
else:
|
||||
cell_okay = False
|
||||
|
||||
if cell_okay:
|
||||
if item_type == "pdf":
|
||||
color = [255, 0, 255]
|
||||
break
|
||||
|
||||
if debug_info is not None:
|
||||
cv2.rectangle(
|
||||
debug_im,
|
||||
cell["location"][0],
|
||||
cell["location"][2],
|
||||
color=color,
|
||||
thickness=3,
|
||||
)
|
||||
|
||||
matched_cell_ids.append(cell_id)
|
||||
cur_table_cells.append(cell)
|
||||
|
||||
all_tables.append(cur_table_cells)
|
||||
|
||||
not_matched_items = [
|
||||
item for _id, item in enumerate(pdf_list) if _id not in matched_pdf_ids
|
||||
]
|
||||
if debug_info is not None:
|
||||
for item in not_matched_items:
|
||||
cv2.rectangle(
|
||||
debug_im,
|
||||
item["location"][0],
|
||||
item["location"][2],
|
||||
color=[128, 128, 128],
|
||||
thickness=3,
|
||||
)
|
||||
|
||||
return all_tables, not_matched_items
|
||||
|
||||
|
||||
def parse_ocr_output(
|
||||
ocr_page_items: List[dict],
|
||||
pdf_page_items: Dict[int, List[dict]],
|
||||
artifact_path: Optional[str] = None,
|
||||
debug_path: Optional[str] = None,
|
||||
):
|
||||
"""Main function to combine OCR output and PDF text to
|
||||
form list of table / non-table regions
|
||||
Args:
|
||||
ocr_page_items: List of OCR items by page
|
||||
pdf_page_items: Dict of PDF texts (page number as key)
|
||||
debug_path: If specified, use OpenCV to plot debug image and save to debug_path
|
||||
"""
|
||||
all_tables = []
|
||||
all_texts = []
|
||||
|
||||
for page_id, page in enumerate(ocr_page_items):
|
||||
ocr_list = page["json"]["ocr"]
|
||||
table_list = page["json"]["table"]
|
||||
page_shape = page["image_shape"]
|
||||
pdf_item_list = pdf_page_items[page_id]
|
||||
|
||||
# create bbox additional information
|
||||
for item in ocr_list:
|
||||
item["box"] = points_to_bbox(item["location"])
|
||||
|
||||
# re-scale pdf items according to new image size
|
||||
for item in pdf_item_list:
|
||||
scale_factor = page_shape[0] / item["page_shape"][0]
|
||||
item["box"] = scale_box(item["box"], scale_factor=scale_factor)
|
||||
item["location"] = scale_points(item["location"], scale_factor=scale_factor)
|
||||
|
||||
# if using debug mode, openCV must be installed
|
||||
if debug_path and artifact_path is not None:
|
||||
try:
|
||||
import cv2
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Please install openCV first to use OCRReader debug mode"
|
||||
)
|
||||
image_path = Path(artifact_path) / page["image"]
|
||||
image = cv2.imread(str(image_path))
|
||||
debug_info = (cv2, image)
|
||||
else:
|
||||
debug_info = None
|
||||
|
||||
new_pdf_list = merge_ocr_and_pdf_texts(
|
||||
ocr_list, pdf_item_list, debug_info=debug_info
|
||||
)
|
||||
|
||||
# sort by reading order
|
||||
ocr_list = sort_funsd_reading_order(ocr_list)
|
||||
new_pdf_list = sort_funsd_reading_order(new_pdf_list)
|
||||
|
||||
all_table_cells, non_table_text_list = merge_table_cell_and_ocr(
|
||||
table_list, ocr_list, new_pdf_list, debug_info=debug_info
|
||||
)
|
||||
|
||||
table_texts = [table_cells_to_markdown(cells) for cells in all_table_cells]
|
||||
all_tables.extend([(page_id, text) for text in table_texts])
|
||||
all_texts.append(
|
||||
(page_id, " ".join(item["text"] for item in non_table_text_list))
|
||||
)
|
||||
|
||||
# export debug image to debug_path
|
||||
if debug_path:
|
||||
cv2.imwrite(str(Path(debug_path) / "page_{}.png".format(page_id)), image)
|
||||
|
||||
return all_tables, all_texts
|
288
libs/kotaemon/kotaemon/loaders/utils/table.py
Normal file
288
libs/kotaemon/kotaemon/loaders/utils/table.py
Normal file
@@ -0,0 +1,288 @@
|
||||
import csv
|
||||
from io import StringIO
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from .box import get_rect_iou
|
||||
|
||||
|
||||
def check_col_conflicts(
|
||||
col_a: List[str], col_b: List[str], thres: float = 0.15
|
||||
) -> bool:
|
||||
"""Check if 2 columns A and B has non-empty content in the same row
|
||||
(to be used with merge_cols)
|
||||
|
||||
Args:
|
||||
col_a: column A (list of str)
|
||||
col_b: column B (list of str)
|
||||
thres: percentage of overlapping allowed
|
||||
Returns:
|
||||
if number of overlapping greater than threshold
|
||||
"""
|
||||
num_rows = len([cell for cell in col_a if cell])
|
||||
assert len(col_a) == len(col_b)
|
||||
conflict_count = 0
|
||||
for cell_a, cell_b in zip(col_a, col_b):
|
||||
if cell_a and cell_b:
|
||||
conflict_count += 1
|
||||
return conflict_count > num_rows * thres
|
||||
|
||||
|
||||
def merge_cols(col_a: List[str], col_b: List[str]) -> List[str]:
|
||||
"""Merge column A and B if they do not have conflict rows
|
||||
|
||||
Args:
|
||||
col_a: column A (list of str)
|
||||
col_b: column B (list of str)
|
||||
Returns:
|
||||
merged column
|
||||
"""
|
||||
for r_id in range(len(col_a)):
|
||||
if col_b[r_id]:
|
||||
col_a[r_id] = col_a[r_id] + " " + col_b[r_id]
|
||||
return col_a
|
||||
|
||||
|
||||
def add_index_col(csv_rows: List[List[str]]) -> List[List[str]]:
|
||||
"""Add index column as the first column of the table csv_rows
|
||||
|
||||
Args:
|
||||
csv_rows: input table
|
||||
Returns:
|
||||
output table with index column
|
||||
"""
|
||||
new_csv_rows = [["row id"] + [""] * len(csv_rows[0])]
|
||||
for r_id, row in enumerate(csv_rows):
|
||||
new_csv_rows.append([str(r_id + 1)] + row)
|
||||
return new_csv_rows
|
||||
|
||||
|
||||
def compress_csv(csv_rows: List[List[str]]) -> List[List[str]]:
|
||||
"""Compress table csv_rows by merging sparse columns (merge_cols)
|
||||
|
||||
Args:
|
||||
csv_rows: input table
|
||||
Returns:
|
||||
output: compressed table
|
||||
"""
|
||||
csv_cols = [[r[c_id] for r in csv_rows] for c_id in range(len(csv_rows[0]))]
|
||||
to_remove_col_ids = []
|
||||
last_c_id = 0
|
||||
for c_id in range(1, len(csv_cols)):
|
||||
if not check_col_conflicts(csv_cols[last_c_id], csv_cols[c_id]):
|
||||
to_remove_col_ids.append(c_id)
|
||||
csv_cols[last_c_id] = merge_cols(csv_cols[last_c_id], csv_cols[c_id])
|
||||
else:
|
||||
last_c_id = c_id
|
||||
|
||||
csv_cols = [r for c_id, r in enumerate(csv_cols) if c_id not in to_remove_col_ids]
|
||||
csv_rows = [[c[r_id] for c in csv_cols] for r_id in range(len(csv_cols[0]))]
|
||||
return csv_rows
|
||||
|
||||
|
||||
def get_table_from_ocr(ocr_list: List[dict], table_list: List[dict]):
|
||||
"""Get list of text lines belong to table regions specified by table_list
|
||||
|
||||
Args:
|
||||
ocr_list: list of OCR output in Casia format (Flax)
|
||||
table_list: list of table output in Casia format (Flax)
|
||||
|
||||
Returns:
|
||||
_type_: _description_
|
||||
"""
|
||||
table_texts = []
|
||||
for table in table_list:
|
||||
if table["type"] != "table":
|
||||
continue
|
||||
cur_table_texts = []
|
||||
for ocr in ocr_list:
|
||||
_iou = get_rect_iou(table["location"], ocr["location"], iou_type=1)
|
||||
if _iou > 0.8:
|
||||
cur_table_texts.append(ocr["text"])
|
||||
table_texts.append(cur_table_texts)
|
||||
|
||||
return table_texts
|
||||
|
||||
|
||||
def make_markdown_table(array: List[List[str]]) -> str:
|
||||
"""Convert table rows in list format to markdown string
|
||||
|
||||
Args:
|
||||
Python list with rows of table as lists
|
||||
First element as header.
|
||||
Example Input:
|
||||
[["Name", "Age", "Height"],
|
||||
["Jake", 20, 5'10],
|
||||
["Mary", 21, 5'7]]
|
||||
Returns:
|
||||
String to put into a .md file
|
||||
"""
|
||||
array = compress_csv(array)
|
||||
array = add_index_col(array)
|
||||
markdown = "\n" + str("| ")
|
||||
|
||||
for e in array[0]:
|
||||
to_add = " " + str(e) + str(" |")
|
||||
markdown += to_add
|
||||
markdown += "\n"
|
||||
|
||||
markdown += "| "
|
||||
for i in range(len(array[0])):
|
||||
markdown += str("--- | ")
|
||||
markdown += "\n"
|
||||
|
||||
for entry in array[1:]:
|
||||
markdown += str("| ")
|
||||
for e in entry:
|
||||
to_add = str(e) + str(" | ")
|
||||
markdown += to_add
|
||||
markdown += "\n"
|
||||
|
||||
return markdown + "\n"
|
||||
|
||||
|
||||
def parse_csv_string_to_list(csv_str: str) -> List[List[str]]:
|
||||
"""Convert CSV string to list of rows
|
||||
|
||||
Args:
|
||||
csv_str: input CSV string
|
||||
|
||||
Returns:
|
||||
Output table in list format
|
||||
"""
|
||||
io = StringIO(csv_str)
|
||||
csv_reader = csv.reader(io, delimiter=",")
|
||||
rows = [row for row in csv_reader]
|
||||
return rows
|
||||
|
||||
|
||||
def format_cell(cell: str, length_limit: Optional[int] = None) -> str:
|
||||
"""Format cell content by remove redundant character and enforce length limit
|
||||
|
||||
Args:
|
||||
cell: input cell text
|
||||
length_limit: limit of text length.
|
||||
|
||||
Returns:
|
||||
new cell text
|
||||
"""
|
||||
cell = cell.replace("\n", " ")
|
||||
if length_limit:
|
||||
cell = cell[:length_limit]
|
||||
return cell
|
||||
|
||||
|
||||
def extract_tables_from_csv_string(
|
||||
csv_content: str, table_texts: List[List[str]]
|
||||
) -> Tuple[List[str], str]:
|
||||
"""Extract list of table from FullOCR output
|
||||
(csv_content) with the specified table_texts
|
||||
|
||||
Args:
|
||||
csv_content: CSV output from FullOCR pipeline
|
||||
table_texts: list of table texts extracted
|
||||
from get_table_from_ocr()
|
||||
|
||||
Returns:
|
||||
List of tables and non-text content
|
||||
"""
|
||||
rows = parse_csv_string_to_list(csv_content)
|
||||
used_row_ids = []
|
||||
table_csv_list = []
|
||||
for table in table_texts:
|
||||
cur_rows = []
|
||||
for row_id, row in enumerate(rows):
|
||||
scores = [
|
||||
any(cell in cell_reference for cell in table)
|
||||
for cell_reference in row
|
||||
if cell_reference
|
||||
]
|
||||
score = sum(scores) / len(scores)
|
||||
if score > 0.5 and row_id not in used_row_ids:
|
||||
used_row_ids.append(row_id)
|
||||
cur_rows.append([format_cell(cell) for cell in row])
|
||||
if cur_rows:
|
||||
table_csv_list.append(make_markdown_table(cur_rows))
|
||||
else:
|
||||
print("table not matched", table)
|
||||
|
||||
non_table_rows = [
|
||||
row for row_id, row in enumerate(rows) if row_id not in used_row_ids
|
||||
]
|
||||
non_table_text = "\n".join(
|
||||
" ".join(format_cell(cell) for cell in row) for row in non_table_rows
|
||||
)
|
||||
return table_csv_list, non_table_text
|
||||
|
||||
|
||||
def strip_special_chars_markdown(text: str) -> str:
|
||||
"""Strip special characters from input text in markdown table format"""
|
||||
return text.replace("|", "").replace(":---:", "").replace("---", "")
|
||||
|
||||
|
||||
def parse_markdown_text_to_tables(text: str) -> Tuple[List[str], List[str]]:
|
||||
"""Convert markdown text to list of non-table spans and table spans
|
||||
|
||||
Args:
|
||||
text: input markdown text
|
||||
|
||||
Returns:
|
||||
list of table spans and non-table spans
|
||||
"""
|
||||
# init empty tables and texts list
|
||||
tables = []
|
||||
texts = []
|
||||
|
||||
# split input by line break
|
||||
lines = text.split("\n")
|
||||
cur_table = []
|
||||
cur_text: List[str] = []
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if line.startswith("|"):
|
||||
if len(cur_text) > 0:
|
||||
texts.append(cur_text)
|
||||
cur_text = []
|
||||
cur_table.append(line)
|
||||
else:
|
||||
# add new table to the list
|
||||
if len(cur_table) > 0:
|
||||
tables.append(cur_table)
|
||||
cur_table = []
|
||||
cur_text.append(line)
|
||||
|
||||
table_texts = ["\n".join(table) for table in tables]
|
||||
non_table_texts = ["\n".join(text) for text in texts]
|
||||
return table_texts, non_table_texts
|
||||
|
||||
|
||||
def table_cells_to_markdown(cells: List[dict]):
|
||||
"""Convert list of cells with attached text to Markdown table"""
|
||||
|
||||
if len(cells) == 0:
|
||||
return ""
|
||||
|
||||
all_row_ids = []
|
||||
all_col_ids = []
|
||||
for cell in cells:
|
||||
all_row_ids.extend(cell["rows"])
|
||||
all_col_ids.extend(cell["columns"])
|
||||
|
||||
num_rows, num_cols = max(all_row_ids) + 1, max(all_col_ids) + 1
|
||||
table_rows = [["" for c in range(num_cols)] for r in range(num_rows)]
|
||||
|
||||
# start filling in the grid
|
||||
for cell in cells:
|
||||
cell_text = " ".join(item["text"] for item in cell["ocr"])
|
||||
start_row_id, end_row_id = cell["rows"]
|
||||
start_col_id, end_col_id = cell["columns"]
|
||||
span_cell = end_row_id != start_row_id or end_col_id != start_col_id
|
||||
|
||||
# do not repeat long text in span cell to prevent context length issue
|
||||
if span_cell and len(cell_text.replace(" ", "")) < 20 and start_row_id > 0:
|
||||
for row in range(start_row_id, end_row_id + 1):
|
||||
for col in range(start_col_id, end_col_id + 1):
|
||||
table_rows[row][col] += cell_text + " "
|
||||
else:
|
||||
table_rows[start_row_id][start_col_id] += cell_text + " "
|
||||
|
||||
return make_markdown_table(table_rows)
|
3
libs/kotaemon/kotaemon/parsers/__init__.py
Normal file
3
libs/kotaemon/kotaemon/parsers/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .regex_extractor import FirstMatchRegexExtractor, RegexExtractor
|
||||
|
||||
__all__ = ["RegexExtractor", "FirstMatchRegexExtractor"]
|
150
libs/kotaemon/kotaemon/parsers/regex_extractor.py
Normal file
150
libs/kotaemon/kotaemon/parsers/regex_extractor.py
Normal file
@@ -0,0 +1,150 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Callable
|
||||
|
||||
from kotaemon.base import BaseComponent, Document, ExtractorOutput, Param
|
||||
|
||||
|
||||
class RegexExtractor(BaseComponent):
|
||||
"""
|
||||
Simple class for extracting text from a document using a regex pattern.
|
||||
|
||||
Args:
|
||||
pattern (List[str]): The regex pattern(s) to use.
|
||||
output_map (dict, optional): A mapping from extracted text to the
|
||||
desired output. Defaults to None.
|
||||
"""
|
||||
|
||||
class Config:
|
||||
middleware_switches = {"theflow.middleware.CachingMiddleware": False}
|
||||
|
||||
pattern: list[str]
|
||||
output_map: dict[str, str] | Callable[[str], str] = Param(
|
||||
default_callback=lambda *_: {}
|
||||
)
|
||||
|
||||
def __init__(self, pattern: str | list[str], **kwargs):
|
||||
if isinstance(pattern, str):
|
||||
pattern = [pattern]
|
||||
super().__init__(pattern=pattern, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def run_raw_static(pattern: str, text: str) -> list[str]:
|
||||
"""
|
||||
Finds all non-overlapping occurrences of a pattern in a string.
|
||||
|
||||
Parameters:
|
||||
pattern (str): The regular expression pattern to search for.
|
||||
text (str): The input string to search in.
|
||||
|
||||
Returns:
|
||||
List[str]: A list of all non-overlapping occurrences of the pattern in the
|
||||
string.
|
||||
"""
|
||||
return re.findall(pattern, text)
|
||||
|
||||
@staticmethod
|
||||
def map_output(text, output_map) -> str:
|
||||
"""
|
||||
Maps the given `text` to its corresponding value in the `output_map` dictionary.
|
||||
|
||||
Parameters:
|
||||
text (str): The input text to be mapped.
|
||||
output_map (dict): A dictionary containing mapping of input text to output
|
||||
values.
|
||||
|
||||
Returns:
|
||||
str: The corresponding value from the `output_map` if `text` is found in the
|
||||
dictionary, otherwise returns the original `text`.
|
||||
"""
|
||||
if not output_map:
|
||||
return text
|
||||
|
||||
if isinstance(output_map, dict):
|
||||
return output_map.get(text, text)
|
||||
|
||||
return output_map(text)
|
||||
|
||||
def run_raw(self, text: str) -> ExtractorOutput:
|
||||
"""
|
||||
Matches the raw text against the pattern and rans the output mapping, returning
|
||||
an instance of ExtractorOutput.
|
||||
|
||||
Args:
|
||||
text (str): The raw text to be processed.
|
||||
|
||||
Returns:
|
||||
ExtractorOutput: The processed output as a list of ExtractorOutput.
|
||||
"""
|
||||
output: list[str] = sum(
|
||||
[self.run_raw_static(p, text) for p in self.pattern], []
|
||||
)
|
||||
output = [self.map_output(text, self.output_map) for text in output]
|
||||
|
||||
return ExtractorOutput(
|
||||
text=output[0] if output else "",
|
||||
matches=output,
|
||||
metadata={"origin": "RegexExtractor"},
|
||||
)
|
||||
|
||||
def run(
|
||||
self, text: str | list[str] | Document | list[Document]
|
||||
) -> list[ExtractorOutput]:
|
||||
"""Match the input against a pattern and return the output for each input
|
||||
|
||||
Parameters:
|
||||
text: contains the input string to be processed
|
||||
|
||||
Returns:
|
||||
A list contains the output ExtractorOutput for each input
|
||||
|
||||
Example:
|
||||
```pycon
|
||||
>>> document1 = Document(...)
|
||||
>>> document2 = Document(...)
|
||||
>>> document_batch = [document1, document2]
|
||||
>>> batch_output = self(document_batch)
|
||||
>>> print(batch_output)
|
||||
[output1_document1, output1_document2]
|
||||
```
|
||||
"""
|
||||
# TODO: this conversion seems common
|
||||
input_: list[str] = []
|
||||
if not isinstance(text, list):
|
||||
text = [text]
|
||||
|
||||
for item in text:
|
||||
if isinstance(item, str):
|
||||
input_.append(item)
|
||||
elif isinstance(item, Document):
|
||||
input_.append(item.text)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid input type {type(item)}, should be str or Document"
|
||||
)
|
||||
|
||||
output = []
|
||||
for each_input in input_:
|
||||
output.append(self.run_raw(each_input))
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class FirstMatchRegexExtractor(RegexExtractor):
|
||||
pattern: list[str]
|
||||
|
||||
def run_raw(self, text: str) -> ExtractorOutput:
|
||||
for p in self.pattern:
|
||||
output = self.run_raw_static(p, text)
|
||||
if output:
|
||||
output = [self.map_output(text, self.output_map) for text in output]
|
||||
return ExtractorOutput(
|
||||
text=output[0],
|
||||
matches=output,
|
||||
metadata={"origin": "FirstMatchRegexExtractor"},
|
||||
)
|
||||
|
||||
return ExtractorOutput(
|
||||
text=None, matches=[], metadata={"origin": "FirstMatchRegexExtractor"}
|
||||
)
|
25
libs/kotaemon/kotaemon/storages/__init__.py
Normal file
25
libs/kotaemon/kotaemon/storages/__init__.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from .docstores import (
|
||||
BaseDocumentStore,
|
||||
ElasticsearchDocumentStore,
|
||||
InMemoryDocumentStore,
|
||||
SimpleFileDocumentStore,
|
||||
)
|
||||
from .vectorstores import (
|
||||
BaseVectorStore,
|
||||
ChromaVectorStore,
|
||||
InMemoryVectorStore,
|
||||
SimpleFileVectorStore,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Document stores
|
||||
"BaseDocumentStore",
|
||||
"InMemoryDocumentStore",
|
||||
"ElasticsearchDocumentStore",
|
||||
"SimpleFileDocumentStore",
|
||||
# Vector stores
|
||||
"BaseVectorStore",
|
||||
"ChromaVectorStore",
|
||||
"InMemoryVectorStore",
|
||||
"SimpleFileVectorStore",
|
||||
]
|
11
libs/kotaemon/kotaemon/storages/docstores/__init__.py
Normal file
11
libs/kotaemon/kotaemon/storages/docstores/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from .base import BaseDocumentStore
|
||||
from .elasticsearch import ElasticsearchDocumentStore
|
||||
from .in_memory import InMemoryDocumentStore
|
||||
from .simple_file import SimpleFileDocumentStore
|
||||
|
||||
__all__ = [
|
||||
"BaseDocumentStore",
|
||||
"InMemoryDocumentStore",
|
||||
"ElasticsearchDocumentStore",
|
||||
"SimpleFileDocumentStore",
|
||||
]
|
47
libs/kotaemon/kotaemon/storages/docstores/base.py
Normal file
47
libs/kotaemon/kotaemon/storages/docstores/base.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from kotaemon.base import Document
|
||||
|
||||
|
||||
class BaseDocumentStore(ABC):
|
||||
"""A document store is in charged of storing and managing documents"""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, *args, **kwargs):
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def add(
|
||||
self,
|
||||
docs: Union[Document, List[Document]],
|
||||
ids: Optional[Union[List[str], str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Add document into document store
|
||||
|
||||
Args:
|
||||
docs: Document or list of documents
|
||||
ids: List of ids of the documents. Optional, if not set will use doc.doc_id
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get(self, ids: Union[List[str], str]) -> List[Document]:
|
||||
"""Get document by id"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_all(self) -> List[Document]:
|
||||
"""Get all documents"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def count(self) -> int:
|
||||
"""Count number of documents"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, ids: Union[List[str], str]):
|
||||
"""Delete document by id"""
|
||||
...
|
173
libs/kotaemon/kotaemon/storages/docstores/elasticsearch.py
Normal file
173
libs/kotaemon/kotaemon/storages/docstores/elasticsearch.py
Normal file
@@ -0,0 +1,173 @@
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from kotaemon.base import Document
|
||||
|
||||
from .base import BaseDocumentStore
|
||||
|
||||
MAX_DOCS_TO_GET = 10**4
|
||||
|
||||
|
||||
class ElasticsearchDocumentStore(BaseDocumentStore):
|
||||
"""Simple memory document store that store document in a dictionary"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
index_name: str = "docstore",
|
||||
elasticsearch_url: str = "http://localhost:9200",
|
||||
k1: float = 2.0,
|
||||
b: float = 0.75,
|
||||
**kwargs,
|
||||
):
|
||||
try:
|
||||
from elasticsearch import Elasticsearch
|
||||
from elasticsearch.helpers import bulk
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"To use ElaticsearchDocstore please install `pip install elasticsearch`"
|
||||
)
|
||||
|
||||
self.elasticsearch_url = elasticsearch_url
|
||||
self.index_name = index_name
|
||||
self.k1 = k1
|
||||
self.b = b
|
||||
|
||||
# Create an Elasticsearch client instance
|
||||
self.client = Elasticsearch(elasticsearch_url, **kwargs)
|
||||
self.es_bulk = bulk
|
||||
# Define the index settings and mappings
|
||||
settings = {
|
||||
"analysis": {"analyzer": {"default": {"type": "standard"}}},
|
||||
"similarity": {
|
||||
"custom_bm25": {
|
||||
"type": "BM25",
|
||||
"k1": k1,
|
||||
"b": b,
|
||||
}
|
||||
},
|
||||
}
|
||||
mappings = {
|
||||
"properties": {
|
||||
"content": {
|
||||
"type": "text",
|
||||
"similarity": "custom_bm25", # Use the custom BM25 similarity
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Create the index with the specified settings and mappings
|
||||
if not self.client.indices.exists(index=index_name):
|
||||
self.client.indices.create(
|
||||
index=index_name, mappings=mappings, settings=settings
|
||||
)
|
||||
|
||||
def add(
|
||||
self,
|
||||
docs: Union[Document, List[Document]],
|
||||
ids: Optional[Union[List[str], str]] = None,
|
||||
refresh_indices: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
"""Add document into document store
|
||||
|
||||
Args:
|
||||
docs: list of documents to add
|
||||
ids: specify the ids of documents to add or use existing doc.doc_id
|
||||
refresh_indices: request Elasticsearch to update its index (default to True)
|
||||
"""
|
||||
if ids and not isinstance(ids, list):
|
||||
ids = [ids]
|
||||
if not isinstance(docs, list):
|
||||
docs = [docs]
|
||||
doc_ids = ids if ids else [doc.doc_id for doc in docs]
|
||||
|
||||
requests = []
|
||||
for doc_id, doc in zip(doc_ids, docs):
|
||||
text = doc.text
|
||||
metadata = doc.metadata
|
||||
request = {
|
||||
"_op_type": "index",
|
||||
"_index": self.index_name,
|
||||
"content": text,
|
||||
"metadata": metadata,
|
||||
"_id": doc_id,
|
||||
}
|
||||
requests.append(request)
|
||||
self.es_bulk(self.client, requests)
|
||||
|
||||
if refresh_indices:
|
||||
self.client.indices.refresh(index=self.index_name)
|
||||
|
||||
def query_raw(self, query: dict) -> List[Document]:
|
||||
"""Query Elasticsearch store using query format of ES client
|
||||
|
||||
Args:
|
||||
query (dict): Elasticsearch query format
|
||||
|
||||
Returns:
|
||||
List[Document]: List of result documents
|
||||
"""
|
||||
res = self.client.search(index=self.index_name, body=query)
|
||||
docs = []
|
||||
for r in res["hits"]["hits"]:
|
||||
docs.append(
|
||||
Document(
|
||||
id_=r["_id"],
|
||||
text=r["_source"]["content"],
|
||||
metadata=r["_source"]["metadata"],
|
||||
)
|
||||
)
|
||||
return docs
|
||||
|
||||
def query(
|
||||
self, query: str, top_k: int = 10, doc_ids: Optional[list] = None
|
||||
) -> List[Document]:
|
||||
"""Search Elasticsearch docstore using search query (BM25)
|
||||
|
||||
Args:
|
||||
query (str): query text
|
||||
top_k (int, optional): number of
|
||||
top documents to return. Defaults to 10.
|
||||
|
||||
Returns:
|
||||
List[Document]: List of result documents
|
||||
"""
|
||||
query_dict: dict = {"query": {"match": {"content": query}}, "size": top_k}
|
||||
if doc_ids:
|
||||
query_dict["query"]["match"]["_id"] = {"values": doc_ids}
|
||||
return self.query_raw(query_dict)
|
||||
|
||||
def get(self, ids: Union[List[str], str]) -> List[Document]:
|
||||
"""Get document by id"""
|
||||
if not isinstance(ids, list):
|
||||
ids = [ids]
|
||||
query_dict = {"query": {"terms": {"_id": ids}}}
|
||||
return self.query_raw(query_dict)
|
||||
|
||||
def count(self) -> int:
|
||||
"""Count number of documents"""
|
||||
count = int(
|
||||
self.client.cat.count(index=self.index_name, format="json")[0]["count"]
|
||||
)
|
||||
return count
|
||||
|
||||
def get_all(self) -> List[Document]:
|
||||
"""Get all documents"""
|
||||
query_dict = {"query": {"match_all": {}}, "size": MAX_DOCS_TO_GET}
|
||||
return self.query_raw(query_dict)
|
||||
|
||||
def delete(self, ids: Union[List[str], str]):
|
||||
"""Delete document by id"""
|
||||
if not isinstance(ids, list):
|
||||
ids = [ids]
|
||||
|
||||
query = {"query": {"terms": {"_id": ids}}}
|
||||
self.client.delete_by_query(index=self.index_name, body=query)
|
||||
self.client.indices.refresh(index=self.index_name)
|
||||
|
||||
def __persist_flow__(self):
|
||||
return {
|
||||
"index_name": self.index_name,
|
||||
"elasticsearch_url": self.elasticsearch_url,
|
||||
"k1": self.k1,
|
||||
"b": self.b,
|
||||
}
|
85
libs/kotaemon/kotaemon/storages/docstores/in_memory.py
Normal file
85
libs/kotaemon/kotaemon/storages/docstores/in_memory.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from kotaemon.base import Document
|
||||
|
||||
from .base import BaseDocumentStore
|
||||
|
||||
|
||||
class InMemoryDocumentStore(BaseDocumentStore):
|
||||
"""Simple memory document store that store document in a dictionary"""
|
||||
|
||||
def __init__(self):
|
||||
self._store = {}
|
||||
|
||||
def add(
|
||||
self,
|
||||
docs: Union[Document, List[Document]],
|
||||
ids: Optional[Union[List[str], str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Add document into document store
|
||||
|
||||
Args:
|
||||
docs: list of documents to add
|
||||
ids: specify the ids of documents to add or
|
||||
use existing doc.doc_id
|
||||
exist_ok: raise error when duplicate doc-id
|
||||
found in the docstore (default to False)
|
||||
"""
|
||||
exist_ok: bool = kwargs.pop("exist_ok", False)
|
||||
|
||||
if ids and not isinstance(ids, list):
|
||||
ids = [ids]
|
||||
if not isinstance(docs, list):
|
||||
docs = [docs]
|
||||
doc_ids = ids if ids else [doc.doc_id for doc in docs]
|
||||
|
||||
for doc_id, doc in zip(doc_ids, docs):
|
||||
if doc_id in self._store and not exist_ok:
|
||||
raise ValueError(f"Document with id {doc_id} already exist")
|
||||
self._store[doc_id] = doc
|
||||
|
||||
def get(self, ids: Union[List[str], str]) -> List[Document]:
|
||||
"""Get document by id"""
|
||||
if not isinstance(ids, list):
|
||||
ids = [ids]
|
||||
|
||||
return [self._store[doc_id] for doc_id in ids]
|
||||
|
||||
def get_all(self) -> List[Document]:
|
||||
"""Get all documents"""
|
||||
return list(self._store.values())
|
||||
|
||||
def count(self) -> int:
|
||||
"""Count number of documents"""
|
||||
return len(self._store)
|
||||
|
||||
def delete(self, ids: Union[List[str], str]):
|
||||
"""Delete document by id"""
|
||||
if not isinstance(ids, list):
|
||||
ids = [ids]
|
||||
|
||||
for doc_id in ids:
|
||||
del self._store[doc_id]
|
||||
|
||||
def save(self, path: Union[str, Path]):
|
||||
"""Save document to path"""
|
||||
store = {key: value.to_dict() for key, value in self._store.items()}
|
||||
with open(path, "w") as f:
|
||||
json.dump(store, f)
|
||||
|
||||
def load(self, path: Union[str, Path]):
|
||||
"""Load document store from path"""
|
||||
with open(path) as f:
|
||||
store = json.load(f)
|
||||
# TODO: save and load aren't lossless. A Document-subclass will lose
|
||||
# information. Need to edit the `to_dict` and `from_dict` methods in
|
||||
# the Document class.
|
||||
# For better query support, utilize SQLite as the default document store.
|
||||
# Also, for portability, use SQLAlchemy for document store.
|
||||
self._store = {key: Document.from_dict(value) for key, value in store.items()}
|
||||
|
||||
def __persist_flow__(self):
|
||||
return {}
|
56
libs/kotaemon/kotaemon/storages/docstores/simple_file.py
Normal file
56
libs/kotaemon/kotaemon/storages/docstores/simple_file.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from kotaemon.base import Document
|
||||
|
||||
from .in_memory import InMemoryDocumentStore
|
||||
|
||||
|
||||
class SimpleFileDocumentStore(InMemoryDocumentStore):
|
||||
"""Improve InMemoryDocumentStore by auto saving whenever the corpus is changed"""
|
||||
|
||||
def __init__(self, path: str | Path):
|
||||
super().__init__()
|
||||
self._path = path
|
||||
if path is not None and Path(path).is_file():
|
||||
self.load(path)
|
||||
|
||||
def get(self, ids: Union[List[str], str]) -> List[Document]:
|
||||
"""Get document by id"""
|
||||
if not isinstance(ids, list):
|
||||
ids = [ids]
|
||||
|
||||
for doc_id in ids:
|
||||
if doc_id not in self._store:
|
||||
self.load(self._path)
|
||||
break
|
||||
|
||||
return [self._store[doc_id] for doc_id in ids]
|
||||
|
||||
def add(
|
||||
self,
|
||||
docs: Union[Document, List[Document]],
|
||||
ids: Optional[Union[List[str], str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Add document into document store
|
||||
|
||||
Args:
|
||||
docs: list of documents to add
|
||||
ids: specify the ids of documents to add or
|
||||
use existing doc.doc_id
|
||||
exist_ok: raise error when duplicate doc-id
|
||||
found in the docstore (default to False)
|
||||
"""
|
||||
super().add(docs=docs, ids=ids, **kwargs)
|
||||
self.save(self._path)
|
||||
|
||||
def delete(self, ids: Union[List[str], str]):
|
||||
"""Delete document by id"""
|
||||
super().delete(ids=ids)
|
||||
self.save(self._path)
|
||||
|
||||
def __persist_flow__(self):
|
||||
from theflow.utils.modules import serialize
|
||||
|
||||
return {"path": serialize(self._path)}
|
11
libs/kotaemon/kotaemon/storages/vectorstores/__init__.py
Normal file
11
libs/kotaemon/kotaemon/storages/vectorstores/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from .base import BaseVectorStore
|
||||
from .chroma import ChromaVectorStore
|
||||
from .in_memory import InMemoryVectorStore
|
||||
from .simple_file import SimpleFileVectorStore
|
||||
|
||||
__all__ = [
|
||||
"BaseVectorStore",
|
||||
"ChromaVectorStore",
|
||||
"InMemoryVectorStore",
|
||||
"SimpleFileVectorStore",
|
||||
]
|
169
libs/kotaemon/kotaemon/storages/vectorstores/base.py
Normal file
169
libs/kotaemon/kotaemon/storages/vectorstores/base.py
Normal file
@@ -0,0 +1,169 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Optional
|
||||
|
||||
from kotaemon.base import DocumentWithEmbedding
|
||||
from llama_index.schema import NodeRelationship, RelatedNodeInfo
|
||||
from llama_index.vector_stores.types import BasePydanticVectorStore
|
||||
from llama_index.vector_stores.types import VectorStore as LIVectorStore
|
||||
from llama_index.vector_stores.types import VectorStoreQuery
|
||||
|
||||
|
||||
class BaseVectorStore(ABC):
|
||||
@abstractmethod
|
||||
def __init__(self, *args, **kwargs):
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def add(
|
||||
self,
|
||||
embeddings: list[list[float]] | list[DocumentWithEmbedding],
|
||||
metadatas: Optional[list[dict]] = None,
|
||||
ids: Optional[list[str]] = None,
|
||||
) -> list[str]:
|
||||
"""Add vector embeddings to vector stores
|
||||
|
||||
Args:
|
||||
embeddings: List of embeddings
|
||||
metadatas: List of metadata of the embeddings
|
||||
ids: List of ids of the embeddings
|
||||
kwargs: meant for vectorstore-specific parameters
|
||||
|
||||
Returns:
|
||||
List of ids of the embeddings
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, ids: list[str], **kwargs):
|
||||
"""Delete vector embeddings from vector stores
|
||||
|
||||
Args:
|
||||
ids: List of ids of the embeddings to be deleted
|
||||
kwargs: meant for vectorstore-specific parameters
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def query(
|
||||
self,
|
||||
embedding: list[float],
|
||||
top_k: int = 1,
|
||||
ids: Optional[list[str]] = None,
|
||||
**kwargs,
|
||||
) -> tuple[list[list[float]], list[float], list[str]]:
|
||||
"""Return the top k most similar vector embeddings
|
||||
|
||||
Args:
|
||||
embedding: List of embeddings
|
||||
top_k: Number of most similar embeddings to return
|
||||
ids: List of ids of the embeddings to be queried
|
||||
|
||||
Returns:
|
||||
the matched embeddings, the similarity scores, and the ids
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class LlamaIndexVectorStore(BaseVectorStore):
|
||||
_li_class: type[LIVectorStore | BasePydanticVectorStore]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
if self._li_class is None:
|
||||
raise AttributeError(
|
||||
"Require `_li_class` to set a VectorStore class from LlamarIndex"
|
||||
)
|
||||
|
||||
from dataclasses import fields
|
||||
|
||||
self._client = self._li_class(*args, **kwargs)
|
||||
|
||||
self._vsq_kwargs = {_.name for _ in fields(VectorStoreQuery)}
|
||||
for key in ["query_embedding", "similarity_top_k", "node_ids"]:
|
||||
if key in self._vsq_kwargs:
|
||||
self._vsq_kwargs.remove(key)
|
||||
|
||||
def __setattr__(self, name: str, value: Any) -> None:
|
||||
if name.startswith("_"):
|
||||
return super().__setattr__(name, value)
|
||||
|
||||
return setattr(self._client, name, value)
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
return getattr(self._client, name)
|
||||
|
||||
def add(
|
||||
self,
|
||||
embeddings: list[list[float]] | list[DocumentWithEmbedding],
|
||||
metadatas: Optional[list[dict]] = None,
|
||||
ids: Optional[list[str]] = None,
|
||||
):
|
||||
if isinstance(embeddings[0], list):
|
||||
nodes: list[DocumentWithEmbedding] = [
|
||||
DocumentWithEmbedding(embedding=embedding) for embedding in embeddings
|
||||
]
|
||||
else:
|
||||
nodes = embeddings # type: ignore
|
||||
if metadatas is not None:
|
||||
for node, metadata in zip(nodes, metadatas):
|
||||
node.metadata = metadata
|
||||
if ids is not None:
|
||||
for node, id in zip(nodes, ids):
|
||||
node.id_ = id
|
||||
node.relationships = {
|
||||
NodeRelationship.SOURCE: RelatedNodeInfo(node_id=id)
|
||||
}
|
||||
|
||||
return self._client.add(nodes=nodes)
|
||||
|
||||
def delete(self, ids: list[str], **kwargs):
|
||||
for id_ in ids:
|
||||
self._client.delete(ref_doc_id=id_, **kwargs)
|
||||
|
||||
def query(
|
||||
self,
|
||||
embedding: list[float],
|
||||
top_k: int = 1,
|
||||
ids: Optional[list[str]] = None,
|
||||
**kwargs,
|
||||
) -> tuple[list[list[float]], list[float], list[str]]:
|
||||
"""Return the top k most similar vector embeddings
|
||||
|
||||
Args:
|
||||
embedding: List of embeddings
|
||||
top_k: Number of most similar embeddings to return
|
||||
ids: List of ids of the embeddings to be queried
|
||||
kwargs: extra query parameters. Depending on the name, these parameters
|
||||
will be used when constructing the VectorStoreQuery object or when
|
||||
performing querying of the underlying vector store.
|
||||
|
||||
Returns:
|
||||
the matched embeddings, the similarity scores, and the ids
|
||||
"""
|
||||
vsq_kwargs = {}
|
||||
vs_kwargs = {}
|
||||
for kwkey, kwvalue in kwargs.items():
|
||||
if kwkey in self._vsq_kwargs:
|
||||
vsq_kwargs[kwkey] = kwvalue
|
||||
else:
|
||||
vs_kwargs[kwkey] = kwvalue
|
||||
|
||||
output = self._client.query(
|
||||
query=VectorStoreQuery(
|
||||
query_embedding=embedding,
|
||||
similarity_top_k=top_k,
|
||||
node_ids=ids,
|
||||
**vsq_kwargs,
|
||||
),
|
||||
**vs_kwargs,
|
||||
)
|
||||
|
||||
embeddings = []
|
||||
if output.nodes:
|
||||
for node in output.nodes:
|
||||
embeddings.append(node.embedding)
|
||||
similarities = output.similarities if output.similarities else []
|
||||
out_ids = output.ids if output.ids else []
|
||||
|
||||
return embeddings, similarities, out_ids
|
96
libs/kotaemon/kotaemon/storages/vectorstores/chroma.py
Normal file
96
libs/kotaemon/kotaemon/storages/vectorstores/chroma.py
Normal file
@@ -0,0 +1,96 @@
|
||||
from typing import Any, Dict, List, Optional, Type, cast
|
||||
|
||||
from llama_index.vector_stores.chroma import ChromaVectorStore as LIChromaVectorStore
|
||||
|
||||
from .base import LlamaIndexVectorStore
|
||||
|
||||
|
||||
class ChromaVectorStore(LlamaIndexVectorStore):
|
||||
_li_class: Type[LIChromaVectorStore] = LIChromaVectorStore
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
path: str = "./chroma",
|
||||
collection_name: str = "default",
|
||||
host: str = "localhost",
|
||||
port: str = "8000",
|
||||
ssl: bool = False,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
collection_kwargs: Optional[dict] = None,
|
||||
stores_text: bool = True,
|
||||
flat_metadata: bool = True,
|
||||
**kwargs: Any,
|
||||
):
|
||||
self._path = path
|
||||
self._collection_name = collection_name
|
||||
self._host = host
|
||||
self._port = port
|
||||
self._ssl = ssl
|
||||
self._headers = headers
|
||||
self._collection_kwargs = collection_kwargs
|
||||
self._stores_text = stores_text
|
||||
self._flat_metadata = flat_metadata
|
||||
self._kwargs = kwargs
|
||||
|
||||
try:
|
||||
import chromadb
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"ChromaVectorStore requires chromadb. "
|
||||
"Please install chromadb first `pip install chromadb`"
|
||||
)
|
||||
|
||||
client = chromadb.PersistentClient(path=path)
|
||||
collection = client.get_or_create_collection(collection_name)
|
||||
|
||||
# pass through for nice IDE support
|
||||
super().__init__(
|
||||
chroma_collection=collection,
|
||||
host=host,
|
||||
port=port,
|
||||
ssl=ssl,
|
||||
headers=headers or {},
|
||||
collection_kwargs=collection_kwargs or {},
|
||||
stores_text=stores_text,
|
||||
flat_metadata=flat_metadata,
|
||||
**kwargs,
|
||||
)
|
||||
self._client = cast(LIChromaVectorStore, self._client)
|
||||
|
||||
def delete(self, ids: List[str], **kwargs):
|
||||
"""Delete vector embeddings from vector stores
|
||||
|
||||
Args:
|
||||
ids: List of ids of the embeddings to be deleted
|
||||
kwargs: meant for vectorstore-specific parameters
|
||||
"""
|
||||
self._client.client.delete(ids=ids)
|
||||
|
||||
def delete_collection(self, collection_name: Optional[str] = None):
|
||||
"""Delete entire collection under specified name from vector stores
|
||||
|
||||
Args:
|
||||
collection_name: Name of the collection to delete
|
||||
"""
|
||||
# a rather ugly chain call but it do the job of finding
|
||||
# original chromadb client and call delete_collection() method
|
||||
if collection_name is None:
|
||||
collection_name = self._client.client.name
|
||||
self._client.client._client.delete_collection(collection_name)
|
||||
|
||||
def count(self) -> int:
|
||||
return self._collection.count()
|
||||
|
||||
def __persist_flow__(self):
|
||||
return {
|
||||
"path": self._path,
|
||||
"collection_name": self._collection_name,
|
||||
"host": self._host,
|
||||
"port": self._port,
|
||||
"ssl": self._ssl,
|
||||
"headers": self._headers,
|
||||
"collection_kwargs": self._collection_kwargs,
|
||||
"stores_text": self._stores_text,
|
||||
"flat_metadata": self._flat_metadata,
|
||||
**self._kwargs,
|
||||
}
|
62
libs/kotaemon/kotaemon/storages/vectorstores/in_memory.py
Normal file
62
libs/kotaemon/kotaemon/storages/vectorstores/in_memory.py
Normal file
@@ -0,0 +1,62 @@
|
||||
"""Simple vector store index."""
|
||||
from typing import Any, Optional, Type
|
||||
|
||||
import fsspec
|
||||
from llama_index.vector_stores import SimpleVectorStore as LISimpleVectorStore
|
||||
from llama_index.vector_stores.simple import SimpleVectorStoreData
|
||||
|
||||
from .base import LlamaIndexVectorStore
|
||||
|
||||
|
||||
class InMemoryVectorStore(LlamaIndexVectorStore):
|
||||
_li_class: Type[LISimpleVectorStore] = LISimpleVectorStore
|
||||
store_text: bool = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data: Optional[SimpleVectorStoreData] = None,
|
||||
fs: Optional[fsspec.AbstractFileSystem] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize params."""
|
||||
self._data = data or SimpleVectorStoreData()
|
||||
self._fs = fs or fsspec.filesystem("file")
|
||||
|
||||
super().__init__(
|
||||
data=data,
|
||||
fs=fs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def save(
|
||||
self,
|
||||
save_path: str,
|
||||
fs: Optional[fsspec.AbstractFileSystem] = None,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
"""save a simpleVectorStore to a dictionary.
|
||||
|
||||
Args:
|
||||
save_path: Path of saving vector to disk.
|
||||
fs: An abstract super-class for pythonic file-systems
|
||||
"""
|
||||
self._client.persist(persist_path=save_path, fs=fs)
|
||||
|
||||
def load(self, load_path: str, fs: Optional[fsspec.AbstractFileSystem] = None):
|
||||
|
||||
"""Create a SimpleKVStore from a load directory.
|
||||
|
||||
Args:
|
||||
load_path: Path of loading vector.
|
||||
fs: An abstract super-class for pythonic file-systems
|
||||
"""
|
||||
self._client = self._client.from_persist_path(persist_path=load_path, fs=fs)
|
||||
|
||||
def __persist_flow__(self):
|
||||
d = self._data.to_dict()
|
||||
d["__type__"] = f"{self._data.__module__}.{self._data.__class__.__qualname__}"
|
||||
return {
|
||||
"data": d,
|
||||
# "fs": self._fs,
|
||||
}
|
65
libs/kotaemon/kotaemon/storages/vectorstores/simple_file.py
Normal file
65
libs/kotaemon/kotaemon/storages/vectorstores/simple_file.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""Simple file vector store index."""
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, Type
|
||||
|
||||
import fsspec
|
||||
from kotaemon.base import DocumentWithEmbedding
|
||||
from llama_index.vector_stores import SimpleVectorStore as LISimpleVectorStore
|
||||
from llama_index.vector_stores.simple import SimpleVectorStoreData
|
||||
|
||||
from .base import LlamaIndexVectorStore
|
||||
|
||||
|
||||
class SimpleFileVectorStore(LlamaIndexVectorStore):
|
||||
"""Similar to InMemoryVectorStore but is backed by file by default"""
|
||||
|
||||
_li_class: Type[LISimpleVectorStore] = LISimpleVectorStore
|
||||
store_text: bool = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
path: str | Path,
|
||||
data: Optional[SimpleVectorStoreData] = None,
|
||||
fs: Optional[fsspec.AbstractFileSystem] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize params."""
|
||||
self._data = data or SimpleVectorStoreData()
|
||||
self._fs = fs or fsspec.filesystem("file")
|
||||
self._path = path
|
||||
self._save_path = Path(path)
|
||||
|
||||
super().__init__(
|
||||
data=data,
|
||||
fs=fs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if self._save_path.is_file():
|
||||
self._client = self._li_class.from_persist_path(
|
||||
persist_path=str(self._save_path), fs=self._fs
|
||||
)
|
||||
|
||||
def add(
|
||||
self,
|
||||
embeddings: list[list[float]] | list[DocumentWithEmbedding],
|
||||
metadatas: Optional[list[dict]] = None,
|
||||
ids: Optional[list[str]] = None,
|
||||
):
|
||||
r = super().add(embeddings, metadatas, ids)
|
||||
self._client.persist(str(self._save_path), self._fs)
|
||||
return r
|
||||
|
||||
def delete(self, ids: list[str], **kwargs):
|
||||
r = super().delete(ids, **kwargs)
|
||||
self._client.persist(str(self._save_path), self._fs)
|
||||
return r
|
||||
|
||||
def __persist_flow__(self):
|
||||
d = self._data.to_dict()
|
||||
d["__type__"] = f"{self._data.__module__}.{self._data.__class__.__qualname__}"
|
||||
return {
|
||||
"data": d,
|
||||
"path": str(self._path),
|
||||
# "fs": self._fs,
|
||||
}
|
73
libs/kotaemon/pyproject.toml
Normal file
73
libs/kotaemon/pyproject.toml
Normal file
@@ -0,0 +1,73 @@
|
||||
# build backand and build dependencies
|
||||
[build-system]
|
||||
requires = ["setuptools >= 61.0"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[tool.setuptools]
|
||||
include-package-data = false
|
||||
packages.find.include = ["kotaemon*"]
|
||||
packages.find.exclude = ["tests*", "env*"]
|
||||
|
||||
# metadata and dependencies
|
||||
[project]
|
||||
name = "kotaemon"
|
||||
version = "0.3.5"
|
||||
requires-python = ">= 3.10"
|
||||
description = "Kotaemon core library for AI development."
|
||||
dependencies = [
|
||||
"langchain",
|
||||
"langchain-community",
|
||||
"theflow",
|
||||
"llama-index>=0.9.0",
|
||||
"llama-hub",
|
||||
"gradio>=4.0.0",
|
||||
"openpyxl",
|
||||
"cookiecutter",
|
||||
"click",
|
||||
"pandas",
|
||||
"trogon",
|
||||
]
|
||||
readme = "README.md"
|
||||
license = { text = "MIT License" }
|
||||
authors = [
|
||||
{ name = "john", email = "john@cinnamon.is" },
|
||||
{ name = "ian", email = "ian@cinnamon.is" },
|
||||
{ name = "tadashi", email = "tadashi@cinnamon.is" },
|
||||
]
|
||||
classifiers = [
|
||||
"Programming Language :: Python :: 3",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Operating System :: OS Independent",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"ipython",
|
||||
"pytest",
|
||||
"pre-commit",
|
||||
"black",
|
||||
"flake8",
|
||||
"sphinx",
|
||||
"coverage",
|
||||
"openai",
|
||||
"langchain-openai",
|
||||
"chromadb",
|
||||
"wikipedia",
|
||||
"duckduckgo-search",
|
||||
"googlesearch-python",
|
||||
"python-dotenv",
|
||||
"pytest-mock",
|
||||
"unstructured[pdf]",
|
||||
"sentence_transformers",
|
||||
"cohere",
|
||||
"elasticsearch",
|
||||
"pypdf",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
kh = "kotaemon.cli:main"
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://github.com/Cinnamon/kotaemon/"
|
||||
Repository = "https://github.com/Cinnamon/kotaemon/"
|
||||
Documentation = "https://github.com/Cinnamon/kotaemon/wiki"
|
9
libs/kotaemon/pytest.ini
Normal file
9
libs/kotaemon/pytest.ini
Normal file
@@ -0,0 +1,9 @@
|
||||
[pytest]
|
||||
minversion = 7.4.0
|
||||
testpaths = tests
|
||||
addopts = -ra -q
|
||||
log_cli=true
|
||||
log_level=WARNING
|
||||
log_format = %(asctime)s %(levelname)s %(message)s
|
||||
log_date_format = %Y-%m-%d %H:%M:%S
|
||||
log_file = logs/pytest-logs.txt
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user