Upgrade the declarative pipeline for cleaner interface (#51)

Nguyen Trung Duc (john) 2023-10-24 11:12:22 +07:00 committed by GitHub
parent aab982ddc4
commit 9035e25666
26 changed files with 365 additions and 169 deletions

View File

@@ -22,4 +22,4 @@ try:
except ImportError:
pass
-__version__ = "0.0.4"
+__version__ = "0.2.0"

View File

@@ -10,4 +10,4 @@ class SimpleRespondentChatbot(BaseChatBot):
llm: Node[ChatLLM]

def _get_message(self) -> str:
-return self.llm(self.history).text[0]
+return self.llm(self.history).text

View File

@@ -63,19 +63,19 @@ def handle_node(node: dict) -> dict:
"""Convert node definition into promptui-compliant config"""
config = {}
for name, param_def in node.get("params", {}).items():
-if isinstance(param_def["default_callback"], str):
+if isinstance(param_def["auto_callback"], str):
continue
if param_def.get("ignore_ui", False):
continue
config[name] = handle_param(param_def)
for name, node_def in node.get("nodes", {}).items():
-if isinstance(node_def["default_callback"], str):
+if isinstance(node_def["auto_callback"], str):
continue
if node_def.get("ignore_ui", False):
continue
for key, value in handle_node(node_def["default"]).items():
config[f"{name}.{key}"] = value
-for key, value in node_def["default_kwargs"].items():
+for key, value in node_def.get("default_kwargs", {}).items():
config[f"{name}.{key}"] = config_from_value(value)
return config
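For orientation, handle_node consumes a nested node-definition dict exported from a pipeline. The sketch below is purely illustrative, built only from the keys the function above actually reads (auto_callback, ignore_ui, default, default_kwargs); every concrete value in it is made up.

# Hypothetical input to handle_node(); only the keys read above are meaningful,
# anything else (such as "default" inside a param) is a guess.
pipeline_def = {
    "params": {
        "retrieval_top_k": {"auto_callback": None, "ignore_ui": False, "default": 3},
    },
    "nodes": {
        "llm": {
            "auto_callback": None,
            "ignore_ui": False,
            "default": {"params": {}, "nodes": {}},  # recursed into by handle_node
            "default_kwargs": {"deployment_name": "dummy-q2-gpt35"},
        },
    },
}
# Keys are flattened with dots, so this example yields config entries such as
# "retrieval_top_k" and "llm.deployment_name" in the promptui config.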
@@ -124,11 +124,14 @@ def export_pipeline_to_config(
if ui_type == "chat":
params = {f".bot.{k}": v for k, v in handle_node(pipeline_def).items()}
params["system_message"] = {"component": "text", "params": {"value": ""}}
+outputs = []
+if hasattr(pipeline, "_promptui_outputs"):
+outputs = pipeline._promptui_outputs
config_obj: dict = {
"ui-type": ui_type,
"params": params,
"inputs": {},
-"outputs": [],
+"outputs": outputs,
"logs": {
"full_pipeline": {
"input": {

View File

@@ -61,6 +61,9 @@ def from_log_to_dict(pipeline_cls: Type[BaseComponent], log_config: dict) -> dict:
if name not in logged_infos:
logged_infos[name] = [None] * len(dirs)
+if step not in progress:
+continue
info = progress[step]
if getter:
if getter in allowed_resultlog_callbacks:

View File

@@ -13,9 +13,9 @@ class John(Base):
primary_hue: colors.Color | str = colors.neutral,
secondary_hue: colors.Color | str = colors.neutral,
neutral_hue: colors.Color | str = colors.neutral,
-spacing_size: sizes.Size | str = sizes.spacing_lg,
+spacing_size: sizes.Size | str = sizes.spacing_sm,
radius_size: sizes.Size | str = sizes.radius_none,
-text_size: sizes.Size | str = sizes.text_md,
+text_size: sizes.Size | str = sizes.text_sm,
font: fonts.Font
| str
| Iterable[fonts.Font | str] = (
@@ -79,8 +79,8 @@ class John(Base):
button_cancel_background_fill_hover="*button_primary_background_fill_hover",
button_cancel_text_color="*button_primary_text_color",
# Padding
-checkbox_label_padding="*spacing_md",
-button_large_padding="*spacing_lg",
+checkbox_label_padding="*spacing_sm",
+button_large_padding="*spacing_sm",
button_small_padding="*spacing_sm",
# Borders
block_border_width="0px",
@@ -91,5 +91,5 @@ class John(Base):
# Block Labels
block_title_text_weight="600",
block_label_text_weight="600",
-block_label_text_size="*text_md",
+block_label_text_size="*text_sm",
)

View File

@@ -26,9 +26,9 @@ def build_from_dict(config: Union[str, dict]):
for key, value in config_dict.items():
pipeline_def = import_dotted_string(key, safe=False)
if value["ui-type"] == "chat":
-demos.append(build_chat_ui(value, pipeline_def))
+demos.append(build_chat_ui(value, pipeline_def).queue())
else:
-demos.append(build_pipeline_ui(value, pipeline_def))
+demos.append(build_pipeline_ui(value, pipeline_def).queue())
if len(demos) == 1:
demo = demos[0]
else:

View File

@@ -0,0 +1,181 @@
from __future__ import annotations
from typing import Any, AsyncGenerator
import anyio
from gradio import ChatInterface
from gradio.components import IOComponent, get_component_instance
from gradio.events import on
from gradio.helpers import special_args
from gradio.routes import Request
class ChatBlock(ChatInterface):
"""The ChatBlock subclasses ChatInterface to provide extra functionalities:
- Show additional outputs to the chat interface
- Disallow blank user message
"""
def __init__(
self,
*args,
additional_outputs: str | IOComponent | list[str | IOComponent] | None = None,
**kwargs,
):
if additional_outputs:
if not isinstance(additional_outputs, list):
additional_outputs = [additional_outputs]
self.additional_outputs = [
get_component_instance(i) for i in additional_outputs # type: ignore
]
else:
self.additional_outputs = []
super().__init__(*args, **kwargs)
async def _submit_fn(
self,
message: str,
history_with_input: list[list[str | None]],
request: Request,
*args,
) -> tuple[Any, ...]:
input_args = args[: -len(self.additional_outputs)]
output_args = args[-len(self.additional_outputs) :]
if not message:
return history_with_input, history_with_input, *output_args
history = history_with_input[:-1]
inputs, _, _ = special_args(
self.fn, inputs=[message, history, *input_args], request=request
)
if self.is_async:
response = await self.fn(*inputs)
else:
response = await anyio.to_thread.run_sync(
self.fn, *inputs, limiter=self.limiter
)
output = []
if self.additional_outputs:
text = response[0]
output = response[1:]
else:
text = response
history.append([message, text])
return history, history, *output
async def _stream_fn(
self,
message: str,
history_with_input: list[list[str | None]],
*args,
) -> AsyncGenerator:
raise NotImplementedError("Stream function not implemented for ChatBlock")
def _display_input(
self, message: str, history: list[list[str | None]]
) -> tuple[list[list[str | None]], list[list[str | None]]]:
"""Stop displaying the input message if the message is a blank string"""
if not message:
return history, history
return super()._display_input(message, history)
def _setup_events(self) -> None:
"""Include additional outputs in the submit event"""
submit_fn = self._stream_fn if self.is_generator else self._submit_fn
submit_triggers = (
[self.textbox.submit, self.submit_btn.click]
if self.submit_btn
else [self.textbox.submit]
)
submit_event = (
on(
submit_triggers,
self._clear_and_save_textbox,
[self.textbox],
[self.textbox, self.saved_input],
api_name=False,
queue=False,
)
.then(
self._display_input,
[self.saved_input, self.chatbot_state],
[self.chatbot, self.chatbot_state],
api_name=False,
queue=False,
)
.then(
submit_fn,
[self.saved_input, self.chatbot_state]
+ self.additional_inputs
+ self.additional_outputs,
[self.chatbot, self.chatbot_state] + self.additional_outputs,
api_name=False,
)
)
self._setup_stop_events(submit_triggers, submit_event)
if self.retry_btn:
retry_event = (
self.retry_btn.click(
self._delete_prev_fn,
[self.chatbot_state],
[self.chatbot, self.saved_input, self.chatbot_state],
api_name=False,
queue=False,
)
.then(
self._display_input,
[self.saved_input, self.chatbot_state],
[self.chatbot, self.chatbot_state],
api_name=False,
queue=False,
)
.then(
submit_fn,
[self.saved_input, self.chatbot_state]
+ self.additional_inputs
+ self.additional_outputs,
[self.chatbot, self.chatbot_state] + self.additional_outputs,
api_name=False,
)
)
self._setup_stop_events([self.retry_btn.click], retry_event)
if self.undo_btn:
self.undo_btn.click(
self._delete_prev_fn,
[self.chatbot_state],
[self.chatbot, self.saved_input, self.chatbot_state],
api_name=False,
queue=False,
).then(
lambda x: x,
[self.saved_input],
[self.textbox],
api_name=False,
queue=False,
)
if self.clear_btn:
self.clear_btn.click(
lambda: ([], [], None),
None,
[self.chatbot, self.chatbot_state, self.saved_input],
queue=False,
api_name=False,
)
def _setup_api(self) -> None:
api_fn = self._api_stream_fn if self.is_generator else self._api_submit_fn
self.fake_api_btn.click(
api_fn,
[self.textbox, self.chatbot_state] + self.additional_inputs,
[self.textbox, self.chatbot_state] + self.additional_outputs,
api_name="chat",
)
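A rough usage sketch of the new ChatBlock, assuming (as _submit_fn above implies) that when additional_outputs is set the chat function returns the reply text first and the extra output values after it. The component and function names are illustrative, and laying out the extra component is left to the surrounding Blocks, as construct_chat_ui does further down in this commit.

import gradio as gr

from kotaemon.contribs.promptui.ui.blocks import ChatBlock

def respond(message, history):
    # First returned value is the chat reply; the rest map onto additional_outputs.
    return f"Echo: {message}", f"last message had {len(message)} characters"

extra = gr.Textbox(label="Extra output", interactive=False)
demo = ChatBlock(respond, additional_outputs=[extra])
demo.queue().launch()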

View File

@@ -8,6 +8,9 @@ from theflow.storage import storage
from kotaemon.chatbot import ChatConversation
from kotaemon.contribs.promptui.base import get_component
from kotaemon.contribs.promptui.export import export
+from kotaemon.contribs.promptui.ui.blocks import ChatBlock
+from ..logs import ResultLog

USAGE_INSTRUCTION = """## How to use:
@@ -87,8 +90,10 @@ def construct_chat_ui(
outputs.append(component)

sess = gr.State(value=None)
-chatbot = gr.Chatbot(label="Chatbot")
-chat = gr.ChatInterface(func_chat, chatbot=chatbot, additional_inputs=[sess])
+chatbot = gr.Chatbot(label="Chatbot", show_copy_button=True)
+chat = ChatBlock(
+func_chat, chatbot=chatbot, additional_inputs=[sess], additional_outputs=outputs
+)
param_state = gr.Textbox(interactive=False)

with gr.Blocks(analytics_enabled=False, title="Welcome to PromptUI") as demo:
@@ -106,6 +111,7 @@ def construct_chat_ui(
chat.saved_input,
param_state,
sess,
+*outputs,
],
)
with gr.Accordion(label="End chat", open=False):
@@ -162,6 +168,9 @@ def build_chat_ui(config, pipeline_def):
exported_dir = output_dir.parent / "exported"
exported_dir.mkdir(parents=True, exist_ok=True)

+resultlog = getattr(pipeline_def, "_promptui_resultlog", ResultLog)
+allowed_resultlog_callbacks = {i for i in dir(resultlog) if not i.startswith("__")}
+
def new_chat(*args):
"""Start a new chat function
@@ -190,7 +199,14 @@ def build_chat_ui(config, pipeline_def):
)
gr.Info("New chat session started.")
-return [], [], None, param_state_str, session
+return (
+[],
+[],
+None,
+param_state_str,
+session,
+*[None] * len(config.get("outputs", [])),
+)

def chat(message, history, session, *args):
"""The chat interface
@@ -212,7 +228,18 @@ def build_chat_ui(config, pipeline_def):
"No active chat session. Please set the params and click New chat"
)

-return session(message).content
+pred = session(message)
+text_response = pred.content
+
+additional_outputs = []
+for output_def in config.get("outputs", []):
+value = session.last_run.logs(output_def["step"])
+getter = output_def.get("getter", None)
+if getter and getter in allowed_resultlog_callbacks:
+value = getattr(resultlog, getter)(value)
+additional_outputs.append(value)
+
+return text_response, *additional_outputs

def end_chat(preference: str, save_log: bool, session):
"""End the chat session

View File

@@ -1,9 +1,10 @@
-from typing import Any, Optional
-from haystack.schema import Document as HaystackDocument
+from typing import TYPE_CHECKING, Any, Optional, TypeVar

from llama_index.bridge.pydantic import Field
from llama_index.schema import Document as BaseDocument
-from pyparsing import TypeVar
+
+if TYPE_CHECKING:
+from haystack.schema import Document as HaystackDocument

IO_Type = TypeVar("IO_Type", "Document", str)
SAMPLE_TEXT = "A sample Document from kotaemon"
@@ -49,6 +50,8 @@ class Document(BaseDocument):
def to_haystack_format(self) -> "HaystackDocument":
"""Convert struct to Haystack document format."""
+from haystack.schema import Document as HaystackDocument
+
metadata = self.metadata or {}
text = self.text
return HaystackDocument(content=text, meta=metadata)

View File

@@ -56,11 +56,11 @@ class LangchainEmbeddings(BaseEmbeddings):
def __setattr__(self, name, value):
if name in self._lc_class.__fields__:
-setattr(self.agent, name, value)
+self._kwargs[name] = value
else:
super().__setattr__(name, value)

-@Param.decorate(no_cache=True)
+@Param.auto(cache=False)
def agent(self):
return self._lc_class(**self._kwargs)
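With this change, assigning a field that belongs to the wrapped Langchain class updates the stored kwargs instead of a possibly stale agent, and @Param.auto(cache=False) rebuilds the agent from those kwargs on the next access. That is what lets the ingestion test later in this commit set indexing_pipeline.embedding.openai_api_key after construction. A small sketch of the intended effect, assuming the wrapper behaves as this diff suggests:

from kotaemon.embeddings import AzureOpenAIEmbeddings  # a LangchainEmbeddings subclass

emb = AzureOpenAIEmbeddings(model="text-embedding-ada-002", openai_api_key="old-key")
emb.openai_api_key = "new-key"  # stored in emb._kwargs rather than on a built agent
client = emb.agent              # rebuilt from the updated kwargs because cache=False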

View File

@@ -6,7 +6,7 @@ from kotaemon.documents.base import Document
class LLMInterface(Document):
-candidates: List[str]
+candidates: List[str] = Field(default_factory=list)
completion_tokens: int = -1
total_tokens: int = -1
prompt_tokens: int = -1

View File

@@ -40,7 +40,7 @@ class LangchainChatLLM(ChatLLM):
self._kwargs[param] = params.pop(param)
super().__init__(**params)

-@Param.decorate(no_cache=True)
+@Param.auto(cache=False)
def agent(self) -> BaseLanguageModel:
return self._lc_class(**self._kwargs)
@@ -92,3 +92,9 @@ class LangchainChatLLM(ChatLLM):
setattr(self.agent, name, value)
else:
super().__setattr__(name, value)
+
+def __getattr__(self, name):
+if name in self._lc_class.__fields__:
+return getattr(self.agent, name)
+else:
+return super().__getattr__(name)

View File

@@ -27,7 +27,7 @@ class LangchainLLM(LLM):
self._kwargs[param] = params.pop(param)
super().__init__(**params)

-@Param.decorate(no_cache=True)
+@Param.auto(cache=False)
def agent(self):
return self._lc_class(**self._kwargs)

View File

@@ -69,8 +69,8 @@ class Thought(BaseComponent):
"variable placeholders, that then will be subsituted with real values when "
"this component is executed"
)
-llm = Node(
-default=AzureChatOpenAI, help="The LLM model to execute the input prompt"
+llm: Node[BaseComponent] = Node(
+AzureChatOpenAI, help="The LLM model to execute the input prompt"
)
post_process: Node[Compose] = Node(
help="The function post-processor that post-processes LLM output prediction ."
@@ -78,7 +78,7 @@ class Thought(BaseComponent):
"a dictionary, where the key should"
)

-@Node.decorate(depends_on="prompt")
+@Node.auto(depends_on="prompt")
def prompt_template(self):
"""Automatically wrap around param prompt. Can ignore"""
return BasePromptComponent(self.prompt)

View File

@@ -1,8 +1,10 @@
import os
from pathlib import Path
-from typing import List, Optional, Union
-from theflow import Node, Param
+from typing import Dict, List, Optional, Union

+from llama_index.readers.base import BaseReader
+from theflow import Node
+from theflow.utils.modules import ObjectInitDeclaration as _

from kotaemon.base import BaseComponent
from kotaemon.docstores import InMemoryDocumentStore
@@ -32,33 +34,22 @@ class ReaderIndexingPipeline(BaseComponent):
# Expose variables for users to switch in prompt ui
storage_path: Path = Path("./storage")
reader_name: str = "normal"  # "normal" or "mathpix"
-openai_api_base: str = "https://bleh-dummy-2.openai.azure.com/"
-openai_api_key: str = os.environ.get("OPENAI_API_KEY", "")
chunk_size: int = 1024
chunk_overlap: int = 256
file_name_list: List[str] = list()
+vector_store: _[InMemoryVectorStore] = _(InMemoryVectorStore)
+doc_store: _[InMemoryDocumentStore] = _(InMemoryDocumentStore)

-@Param.decorate()
-def vector_store(self):
-return InMemoryVectorStore()
-
-@Param.decorate()
-def doc_store(self):
-doc_store = InMemoryDocumentStore()
-return doc_store
-
-@Node.decorate(depends_on=["openai_api_base", "openai_api_key"])
-def embedding(self):
-return AzureOpenAIEmbeddings(
-model="text-embedding-ada-002",
-deployment="dummy-q2-text-embedding",
-openai_api_base=self.openai_api_base,
-openai_api_key=self.openai_api_key,
-)
+embedding: AzureOpenAIEmbeddings = AzureOpenAIEmbeddings.withx(
+model="text-embedding-ada-002",
+deployment="dummy-q2-text-embedding",
+openai_api_base="https://bleh-dummy-2.openai.azure.com/",
+openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
+)

def get_reader(self, input_files: List[Union[str, Path]]):
# document parsers
-file_extractor = {
+file_extractor: Dict[str, BaseReader] = {
".xlsx": PandasExcelReader(),
}
if self.reader_name == "normal":
@@ -71,7 +62,7 @@ class ReaderIndexingPipeline(BaseComponent):
)
return main_reader

-@Node.decorate(depends_on=["doc_store", "vector_store", "embedding"])
+@Node.auto(depends_on=["doc_store", "vector_store", "embedding"])
def indexing_vector_pipeline(self):
return IndexVectorStoreFromDocumentPipeline(
doc_store=self.doc_store,
@@ -79,12 +70,9 @@ class ReaderIndexingPipeline(BaseComponent):
embedding=self.embedding,
)

-@Node.decorate(depends_on=["chunk_size", "chunk_overlap"])
-def text_splitter(self):
-# chunking using NodeParser from llama-index
-return SimpleNodeParser(
-chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
-)
+text_splitter: SimpleNodeParser = SimpleNodeParser.withx(
+chunk_size=1024, chunk_overlap=256
+)

def run(
self,
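This file shows the commit's theme in one place: dependencies that used to be built inside @Param.decorate/@Node.decorate factory methods are now declared where the attribute is defined, either with Class.withx(**kwargs), which reads as a lazily applied constructor, or with theflow's ObjectInitDeclaration (imported as _) for argument-free defaults. A minimal before/after sketch, assuming the theflow API exactly as it appears in this diff:

from theflow import Node, Param
from theflow.utils.modules import ObjectInitDeclaration as _

from kotaemon.base import BaseComponent
from kotaemon.docstores import InMemoryDocumentStore
from kotaemon.embeddings import AzureOpenAIEmbeddings

class OldStylePipeline(BaseComponent):
    # Pre-commit style (removed here): every default is a decorated factory method.
    openai_api_key: str = ""

    @Param.decorate()
    def doc_store(self):
        return InMemoryDocumentStore()

    @Node.decorate(depends_on=["openai_api_key"])
    def embedding(self):
        return AzureOpenAIEmbeddings(
            model="text-embedding-ada-002", openai_api_key=self.openai_api_key
        )

class NewStylePipeline(BaseComponent):
    # Post-commit style: defaults are declared inline on the attribute.
    doc_store: _[InMemoryDocumentStore] = _(InMemoryDocumentStore)
    embedding: AzureOpenAIEmbeddings = AzureOpenAIEmbeddings.withx(
        model="text-embedding-ada-002",
        openai_api_key="",
    )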

View File

@@ -3,6 +3,7 @@ from pathlib import Path
from typing import List

from theflow import Node, Param
+from theflow.utils.modules import ObjectInitDeclaration as _

from kotaemon.base import BaseComponent
from kotaemon.docstores import InMemoryDocumentStore
@@ -25,8 +26,6 @@ class QuestionAnsweringPipeline(BaseComponent):
storage_path: Path = Path("./storage")
retrieval_top_k: int = 3
-openai_api_base: str = "https://bleh-dummy-2.openai.azure.com/"
-openai_api_key: str = os.environ.get("OPENAI_API_KEY", "")
file_name_list: List[str]
"""List of filename, incombination with storage_path to
create persistent path of vectorstore"""
@@ -35,37 +34,27 @@ class QuestionAnsweringPipeline(BaseComponent):
"The context is: \n{context}\nAnswer: "
)

-@Node.decorate(depends_on=["openai_api_base", "openai_api_key"])
-def llm(self):
-return AzureChatOpenAI(
-openai_api_base="https://bleh-dummy-2.openai.azure.com/",
-openai_api_key=self.openai_api_key,
-openai_api_version="2023-03-15-preview",
-deployment_name="dummy-q2-gpt35",
-temperature=0,
-request_timeout=60,
-)
+llm: AzureChatOpenAI = AzureChatOpenAI.withx(
+openai_api_base="https://bleh-dummy-2.openai.azure.com/",
+openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
+openai_api_version="2023-03-15-preview",
+deployment_name="dummy-q2-gpt35",
+temperature=0,
+request_timeout=60,
+)

-@Param.decorate()
-def vector_store(self):
-return InMemoryVectorStore()
-
-@Param.decorate()
-def doc_store(self):
-doc_store = InMemoryDocumentStore()
-return doc_store
+vector_store: Param[InMemoryVectorStore] = Param(_(InMemoryVectorStore))
+doc_store: Param[InMemoryDocumentStore] = Param(_(InMemoryDocumentStore))

-@Node.decorate(depends_on=["openai_api_base", "openai_api_key"])
-def embedding(self):
-return AzureOpenAIEmbeddings(
-model="text-embedding-ada-002",
-deployment="dummy-q2-text-embedding",
-openai_api_base=self.openai_api_base,
-openai_api_key=self.openai_api_key,
-)
+embedding: AzureOpenAIEmbeddings = AzureOpenAIEmbeddings.withx(
+model="text-embedding-ada-002",
+deployment="dummy-q2-text-embedding",
+openai_api_base="https://bleh-dummy-2.openai.azure.com/",
+openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
+)

-@Node.decorate(depends_on=["doc_store", "vector_store", "embedding"])
-def retrieving_pipeline(self):
+@Node.default()
+def retrieving_pipeline(self) -> RetrieveDocumentFromVectorStorePipeline:
retrieving_pipeline = RetrieveDocumentFromVectorStorePipeline(
vector_store=self.vector_store,
doc_store=self.doc_store,

View File

@@ -32,5 +32,5 @@ class LLMTool(BaseTool):
response = self.llm(query)
except ValueError:
raise ToolException("LLM Tool call failed")
-output = response.text[0]
+output = response.text
return output

View File

@@ -30,7 +30,6 @@ setuptools.setup(
exclude=("tests", "tests.*", "examples", "examples.*")
),
install_requires=[
-"farm-haystack==1.19.0",
"langchain",
"theflow",
"llama-index",
@@ -59,6 +58,7 @@ setuptools.setup(
"python-dotenv",
"pytest-mock",
"unstructured[pdf]",
+"farm-haystack==1.19.0",
],
},
entry_points={"console_scripts": ["kh=kotaemon.cli:main"]},

View File

@@ -1,7 +1,8 @@
import os
from typing import List

-from theflow import Node, Param
+from theflow import Param
+from theflow.utils.modules import ObjectInitDeclaration as _

from kotaemon.base import BaseComponent
from kotaemon.docstores import InMemoryDocumentStore
@@ -13,35 +14,28 @@ from kotaemon.vectorstores import ChromaVectorStore

class QuestionAnsweringPipeline(BaseComponent):
-vectorstore_path: str = str("./tmp")
retrieval_top_k: int = 1
-openai_api_key: str = os.environ.get("OPENAI_API_KEY", "")

-@Node.decorate(depends_on="openai_api_key")
-def llm(self):
-return AzureOpenAI(
-openai_api_base="https://bleh-dummy-2.openai.azure.com/",
-openai_api_key=self.openai_api_key,
-openai_api_version="2023-03-15-preview",
-deployment_name="dummy-q2-gpt35",
-temperature=0,
-request_timeout=60,
-)
+llm: AzureOpenAI = AzureOpenAI.withx(
+openai_api_base="https://bleh-dummy-2.openai.azure.com/",
+openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
+openai_api_version="2023-03-15-preview",
+deployment_name="dummy-q2-gpt35",
+temperature=0,
+request_timeout=60,
+)

-@Node.decorate(depends_on=["vectorstore_path", "openai_api_key"])
-def retrieving_pipeline(self):
-vector_store = ChromaVectorStore(self.vectorstore_path)
-embedding = AzureOpenAIEmbeddings(
-model="text-embedding-ada-002",
-deployment="dummy-q2-text-embedding",
-openai_api_base="https://bleh-dummy-2.openai.azure.com/",
-openai_api_key=self.openai_api_key,
-)
-return RetrieveDocumentFromVectorStorePipeline(
-vector_store=vector_store,
-embedding=embedding,
-)
+retrieving_pipeline: RetrieveDocumentFromVectorStorePipeline = (
+RetrieveDocumentFromVectorStorePipeline.withx(
+vector_store=_(ChromaVectorStore).withx(path="./tmp"),
+embedding=AzureOpenAIEmbeddings.withx(
+model="text-embedding-ada-002",
+deployment="dummy-q2-text-embedding",
+openai_api_base="https://bleh-dummy-2.openai.azure.com/",
+openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
+),
+)
+)

def run_raw(self, text: str) -> str:
# reload the document store, in case it has been updated
@@ -60,36 +54,27 @@ class QuestionAnsweringPipeline(BaseComponent):
prompt = f'Answer the following question: "{text}". The context is: \n{context}'
self.log_progress(".prompt", prompt=prompt)
-return self.llm(prompt).text[0]
+return self.llm(prompt).text

class IndexingPipeline(IndexVectorStoreFromDocumentPipeline):
# Expose variables for users to switch in prompt ui
-vectorstore_path: str = str("./tmp")
embedding_model: str = "text-embedding-ada-002"
-deployment: str = "dummy-q2-text-embedding"
-openai_api_base: str = "https://bleh-dummy-2.openai.azure.com/"
-openai_api_key: str = os.environ.get("OPENAI_API_KEY", "")
+vector_store: _[ChromaVectorStore] = _(ChromaVectorStore).withx(path="./tmp")

-@Param.decorate(depends_on=["vectorstore_path"])
-def vector_store(self):
-return ChromaVectorStore(self.vectorstore_path)
-
-@Param.decorate()
-def doc_store(self):
+@Param.auto()
+def doc_store(self) -> InMemoryDocumentStore:
doc_store = InMemoryDocumentStore()
if os.path.isfile("docstore.json"):
doc_store.load("docstore.json")
return doc_store

-@Node.decorate(depends_on=["vector_store"])
-def embedding(self):
-return AzureOpenAIEmbeddings(
-model="text-embedding-ada-002",
-deployment=self.deployment,
-openai_api_base=self.openai_api_base,
-openai_api_key=self.openai_api_key,
-)
+embedding: AzureOpenAIEmbeddings = AzureOpenAIEmbeddings.withx(
+model="text-embedding-ada-002",
+deployment="dummy-q2-text-embedding",
+openai_api_base="https://bleh-dummy-2.openai.azure.com/",
+openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
+)

def run_raw(self, text: str) -> int:  # type: ignore
"""Normally, this indexing pipeline returns nothing. For demonstration,
@@ -100,7 +85,7 @@ class IndexingPipeline(IndexVectorStoreFromDocumentPipeline):
if self.doc_store is not None:
# persist to local anytime an indexing is created
-# this can be bypassed when we have a FileDocucmentStore
+# this can be bypassed when we have a FileDocumentStore
self.doc_store.save("docstore.json")
return self.vector_store._collection.count()

View File

@@ -13,3 +13,17 @@ def mock_google_search(monkeypatch):
)
monkeypatch.setattr(googlesearch, "search", result)

+
+def if_haystack_not_installed():
+try:
+import haystack  # noqa: F401
+except ImportError:
+return True
+else:
+return False
+
+
+skip_when_haystack_not_installed = pytest.mark.skipif(
+if_haystack_not_installed(), reason="Haystack is not installed"
+)

View File

@@ -1,7 +1,7 @@
import tempfile
from typing import List

-from theflow import Node
+from theflow.utils.modules import ObjectInitDeclaration as _

from kotaemon.base import BaseComponent
from kotaemon.embeddings import AzureOpenAIEmbeddings
@@ -11,33 +11,27 @@ from kotaemon.vectorstores import ChromaVectorStore

class Pipeline(BaseComponent):
-vectorstore_path: str = str(tempfile.mkdtemp())
-llm: Node[AzureOpenAI] = Node(
-default=AzureOpenAI,
-default_kwargs={
-"openai_api_base": "https://test.openai.azure.com/",
-"openai_api_key": "some-key",
-"openai_api_version": "2023-03-15-preview",
-"deployment_name": "gpt35turbo",
-"temperature": 0,
-"request_timeout": 60,
-},
-)
+llm: AzureOpenAI = AzureOpenAI.withx(
+openai_api_base="https://test.openai.azure.com/",
+openai_api_key="some-key",
+openai_api_version="2023-03-15-preview",
+deployment_name="gpt35turbo",
+temperature=0,
+request_timeout=60,
+)

-@Node.decorate(depends_on=["vectorstore_path"])
-def retrieving_pipeline(self):
-vector_store = ChromaVectorStore(self.vectorstore_path)
-embedding = AzureOpenAIEmbeddings(
-model="text-embedding-ada-002",
-deployment="embedding-deployment",
-openai_api_base="https://test.openai.azure.com/",
-openai_api_key="some-key",
-)
-return RetrieveDocumentFromVectorStorePipeline(
-vector_store=vector_store, embedding=embedding
-)
+retrieving_pipeline: RetrieveDocumentFromVectorStorePipeline = (
+RetrieveDocumentFromVectorStorePipeline.withx(
+vector_store=_(ChromaVectorStore).withx(path=str(tempfile.mkdtemp())),
+embedding=AzureOpenAIEmbeddings.withx(
+model="text-embedding-ada-002",
+deployment="embedding-deployment",
+openai_api_base="https://test.openai.azure.com/",
+openai_api_key="some-key",
+),
+)
+)

def run_raw(self, text: str) -> str:
matched_texts: List[str] = self.retrieving_pipeline(text)
-return self.llm("\n".join(matched_texts)).text[0]
+return self.llm("\n".join(matched_texts)).text

View File

@@ -1,7 +1,7 @@
-from haystack.schema import Document as HaystackDocument
from kotaemon.documents.base import Document, RetrievedDocument

+from .conftest import skip_when_haystack_not_installed

def test_document_constructor_with_builtin_types():
for value in ["str", 1, {}, set(), [], tuple, None]:
@@ -19,7 +19,10 @@ def test_document_constructor_with_document():
assert doc2.content == doc1.content

+@skip_when_haystack_not_installed
def test_document_to_haystack_format():
+from haystack.schema import Document as HaystackDocument
+
text = "Sample text"
metadata = {"filename": "sample.txt"}
doc = Document(text, metadata=metadata)

View File

@@ -16,7 +16,6 @@ class TestPromptConfig:
assert "text" in config["inputs"], "inputs should have config"

assert "params" in config, "params should be in config"
-assert "vectorstore_path" in config["params"]
assert "llm.deployment_name" in config["params"]
assert "llm.openai_api_base" in config["params"]
assert "llm.openai_api_key" in config["params"]

View File

@@ -42,8 +42,9 @@ def mock_openai_embedding(monkeypatch):
)

def test_ingest_pipeline(patch, mock_openai_embedding, tmp_path):
indexing_pipeline = ReaderIndexingPipeline(
-storage=tmp_path, openai_api_key="some-key"
+storage_path=tmp_path,
)
+indexing_pipeline.embedding.openai_api_key = "some-key"
input_file_path = Path(__file__).parent / "resources/dummy.pdf"

# call ingestion pipeline

View File

@@ -3,7 +3,7 @@ from pathlib import Path
from langchain.schema import Document as LangchainDocument
from llama_index.node_parser import SimpleNodeParser

-from kotaemon.documents.base import Document, HaystackDocument
+from kotaemon.documents.base import Document
from kotaemon.loaders import AutoReader
@@ -19,10 +19,6 @@ def test_pdf_reader():
assert isinstance(first_doc, Document)
assert first_doc.text.lower().replace(" ", "") == "dummypdffile"

-# check conversion output
-haystack_doc = first_doc.to_haystack_format()
-assert isinstance(haystack_doc, HaystackDocument)
-
langchain_doc = first_doc.to_langchain_format()
assert isinstance(langchain_doc, LangchainDocument)

View File

@@ -3,6 +3,8 @@ import sys

import pytest

+from .conftest import skip_when_haystack_not_installed

@pytest.fixture
def clean_artifacts_for_telemetry():
@@ -26,6 +28,7 @@ def clean_artifacts_for_telemetry():

@pytest.mark.usefixtures("clean_artifacts_for_telemetry")
+@skip_when_haystack_not_installed
def test_disable_telemetry_import_haystack_first():
"""Test that telemetry is disabled when kotaemon lib is initiated after"""
import os
@@ -42,6 +45,7 @@ def test_disable_telemetry_import_haystack_first():

@pytest.mark.usefixtures("clean_artifacts_for_telemetry")
+@skip_when_haystack_not_installed
def test_disable_telemetry_import_haystack_after_kotaemon():
"""Test that telemetry is disabled when kotaemon lib is initiated before"""
import os