From 95191f53d9ee54d40bc78885a6bf706b29ea4a9a Mon Sep 17 00:00:00 2001 From: "Tuan Anh Nguyen Dang (Tadashi_Cin)" Date: Mon, 23 Dec 2024 09:28:24 +0700 Subject: [PATCH] feat: add web search (#580) bump:patch * feat: add web search * feat: update requirements --- flowsettings.py | 4 + .../kotaemon/indices/retrievers/__init__.py | 0 .../indices/retrievers/jina_web_search.py | 60 +++++++++++++ .../indices/retrievers/tavily_web_search.py | 57 ++++++++++++ libs/kotaemon/pyproject.toml | 1 + libs/ktem/ktem/index/file/ui.py | 13 ++- libs/ktem/ktem/pages/chat/__init__.py | 87 +++++++++++++------ libs/ktem/ktem/pages/chat/chat_panel.py | 4 +- libs/ktem/ktem/utils/commands.py | 1 + libs/ktem/ktem/utils/render.py | 18 ++++ 10 files changed, 218 insertions(+), 27 deletions(-) create mode 100644 libs/kotaemon/kotaemon/indices/retrievers/__init__.py create mode 100644 libs/kotaemon/kotaemon/indices/retrievers/jina_web_search.py create mode 100644 libs/kotaemon/kotaemon/indices/retrievers/tavily_web_search.py create mode 100644 libs/ktem/ktem/utils/commands.py diff --git a/flowsettings.py b/flowsettings.py index 0647b71..0962eef 100644 --- a/flowsettings.py +++ b/flowsettings.py @@ -81,6 +81,10 @@ KH_FEATURE_USER_MANAGEMENT_PASSWORD = str( KH_ENABLE_ALEMBIC = False KH_DATABASE = f"sqlite:///{KH_USER_DATA_DIR / 'sql.db'}" KH_FILESTORAGE_PATH = str(KH_USER_DATA_DIR / "files") +KH_WEB_SEARCH_BACKEND = ( + "kotaemon.indices.retrievers.tavily_web_search.WebSearch" + # "kotaemon.indices.retrievers.jina_web_search.WebSearch" +) KH_DOCSTORE = { # "__type__": "kotaemon.storages.ElasticsearchDocumentStore", diff --git a/libs/kotaemon/kotaemon/indices/retrievers/__init__.py b/libs/kotaemon/kotaemon/indices/retrievers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/libs/kotaemon/kotaemon/indices/retrievers/jina_web_search.py b/libs/kotaemon/kotaemon/indices/retrievers/jina_web_search.py new file mode 100644 index 0000000..48fa60f --- /dev/null +++ 
b/libs/kotaemon/kotaemon/indices/retrievers/jina_web_search.py @@ -0,0 +1,60 @@ +import requests +from decouple import config + +from kotaemon.base import BaseComponent, RetrievedDocument + +JINA_API_KEY = config("JINA_API_KEY", default="") +JINA_URL = config("JINA_URL", default="https://r.jina.ai/") + + +class WebSearch(BaseComponent): + """WebSearch component for fetching data from the web + using Jina API + """ + + def run( + self, + text: str, + *args, + **kwargs, + ) -> list[RetrievedDocument]: + if JINA_API_KEY == "": + raise ValueError( + "This feature requires JINA_API_KEY " + "(get free one from https://jina.ai/reader)" + ) + + # setup the request + api_url = f"https://s.jina.ai/{text}" + headers = {"X-With-Generated-Alt": "true", "Accept": "application/json"} + if JINA_API_KEY: + headers["Authorization"] = f"Bearer {JINA_API_KEY}" + + response = requests.get(api_url, headers=headers) + response.raise_for_status() + response_dict = response.json() + + return [ + RetrievedDocument( + text=( + "###URL: [{url}]({url})\n\n" + "####{title}\n\n" + "{description}\n" + "{content}" + ).format( + url=item["url"], + title=item["title"], + description=item["description"], + content=item["content"], + ), + metadata={ + "file_name": "Web search", + "type": "table", + "llm_trulens_score": 1.0, + }, + ) + for item in response_dict["data"] + ] + + def generate_relevant_scores(self, text, documents: list[RetrievedDocument]): + return documents diff --git a/libs/kotaemon/kotaemon/indices/retrievers/tavily_web_search.py b/libs/kotaemon/kotaemon/indices/retrievers/tavily_web_search.py new file mode 100644 index 0000000..f6087d5 --- /dev/null +++ b/libs/kotaemon/kotaemon/indices/retrievers/tavily_web_search.py @@ -0,0 +1,57 @@ +from decouple import config + +from kotaemon.base import BaseComponent, RetrievedDocument + +TAVILY_API_KEY = config("TAVILY_API_KEY", default="") + + +class WebSearch(BaseComponent): + """WebSearch component for fetching data from the web + using Tavily
API + """ + + def run( + self, + text: str, + *args, + **kwargs, + ) -> list[RetrievedDocument]: + if TAVILY_API_KEY == "": + raise ValueError( + "This feature requires TAVILY_API_KEY " + "(get free one from https://app.tavily.com/)" + ) + + try: + from tavily import TavilyClient + except ImportError: + raise ImportError( + "Please install `pip install tavily-python` to use this feature" + ) + + tavily_client = TavilyClient(api_key=TAVILY_API_KEY) + results = tavily_client.search( + query=text, + search_depth="advanced", + )["results"] + context = "\n\n".join( + "###URL: [{url}]({url})\n\n{content}".format( + url=result["url"], + content=result["content"], + ) + for result in results + ) + + return [ + RetrievedDocument( + text=context, + metadata={ + "file_name": "Web search", + "type": "table", + "llm_trulens_score": 1.0, + }, + ) + ] + + def generate_relevant_scores(self, text, documents: list[RetrievedDocument]): + return documents diff --git a/libs/kotaemon/pyproject.toml b/libs/kotaemon/pyproject.toml index ee5c468..2e60886 100644 --- a/libs/kotaemon/pyproject.toml +++ b/libs/kotaemon/pyproject.toml @@ -55,6 +55,7 @@ dependencies = [ "theflow>=0.8.6,<0.9.0", "trogon>=0.5.0,<0.6", "umap-learn==0.5.5", + "tavily-python>=0.4.0", ] readme = "README.md" authors = [ diff --git a/libs/ktem/ktem/index/file/ui.py b/libs/ktem/ktem/index/file/ui.py index 417391c..aa2da36 100644 --- a/libs/ktem/ktem/index/file/ui.py +++ b/libs/ktem/ktem/index/file/ui.py @@ -19,6 +19,8 @@ from sqlalchemy import select from sqlalchemy.orm import Session from theflow.settings import settings as flowsettings +from ...utils.commands import WEB_SEARCH_COMMAND + DOWNLOAD_MESSAGE = "Press again to download" MAX_FILENAME_LENGTH = 20 @@ -38,6 +40,13 @@ function(file_list) { value: '"' + file_list[i][0] + '"', }); } + + // manually push web search tag + values.push({ + key: "web_search", + value: '"web_search"', + }); + var tribute = new Tribute({ values: values, noMatchTemplate: "", @@ -46,7 +55,9 
@@ function(file_list) { input_box = document.querySelector('#chat-input textarea'); tribute.attach(input_box); } -""" +""".replace( + "web_search", WEB_SEARCH_COMMAND +) class File(gr.File): diff --git a/libs/ktem/ktem/pages/chat/__init__.py b/libs/ktem/ktem/pages/chat/__init__.py index 8aec594..7696027 100644 --- a/libs/ktem/ktem/pages/chat/__init__.py +++ b/libs/ktem/ktem/pages/chat/__init__.py @@ -1,4 +1,5 @@ import asyncio +import importlib import json import re from copy import deepcopy @@ -23,11 +24,22 @@ from kotaemon.base import Document from kotaemon.indices.ingests.files import KH_DEFAULT_FILE_EXTRACTORS from ...utils import SUPPORTED_LANGUAGE_MAP, get_file_names_regex, get_urls +from ...utils.commands import WEB_SEARCH_COMMAND from .chat_panel import ChatPanel from .common import STATE from .control import ConversationControl from .report import ReportIssue +KH_WEB_SEARCH_BACKEND = getattr(flowsettings, "KH_WEB_SEARCH_BACKEND", None) +WebSearch = None +if KH_WEB_SEARCH_BACKEND: + try: + module_name, class_name = KH_WEB_SEARCH_BACKEND.rsplit(".", 1) + module = importlib.import_module(module_name) + WebSearch = getattr(module, class_name) + except (ImportError, AttributeError) as e: + print(f"Error importing {class_name} from {module_name}: {e}") + DEFAULT_SETTING = "(default)" INFO_PANEL_SCALES = {True: 8, False: 4} @@ -113,6 +125,7 @@ class ChatPage(BasePage): value=getattr(flowsettings, "KH_FEATURE_CHAT_SUGGESTION", False) ) self._info_panel_expanded = gr.State(value=True) + self._command_state = gr.State(value=None) def on_building_ui(self): with gr.Row(): @@ -299,6 +312,7 @@ class ChatPage(BasePage): # file selector from the first index self._indices_input[0], self._indices_input[1], + self._command_state, ], concurrency_limit=20, show_progress="hidden", @@ -315,6 +329,7 @@ class ChatPage(BasePage): self.citation, self.language, self.state_chat, + self._command_state, self._app.user_id, ] + self._indices_input, @@ -647,6 +662,7 @@ class 
ChatPage(BasePage): chat_input_text = chat_input.get("text", "") file_ids = [] + used_command = None first_selector_choices_map = { item[0]: item[1] for item in first_selector_choices @@ -654,6 +670,11 @@ class ChatPage(BasePage): # get all file names with pattern @"filename" in input_str file_names, chat_input_text = get_file_names_regex(chat_input_text) + + # check if web search command is in file_names + if WEB_SEARCH_COMMAND in file_names: + used_command = WEB_SEARCH_COMMAND + # get all urls in input_str urls, chat_input_text = get_urls(chat_input_text) @@ -707,13 +728,17 @@ class ChatPage(BasePage): conv_update = gr.update() new_conv_name = conv_name - return [ - {}, - chat_history, - new_conv_id, - conv_update, - new_conv_name, - ] + selector_output + return ( + [ + {}, + chat_history, + new_conv_id, + conv_update, + new_conv_name, + ] + + selector_output + + [used_command] + ) def toggle_delete(self, conv_id): if conv_id: @@ -877,6 +902,7 @@ class ChatPage(BasePage): session_use_citation: str, session_language: str, state: dict, + command_state: str | None, user_id: int, *selecteds, ): @@ -934,17 +960,26 @@ class ChatPage(BasePage): # get retrievers retrievers = [] - for index in self._app.index_manager.indices: - index_selected = [] - if isinstance(index.selector, int): - index_selected = selecteds[index.selector] - if isinstance(index.selector, tuple): - for i in index.selector: - index_selected.append(selecteds[i]) - iretrievers = index.get_retriever_pipelines( - settings, user_id, index_selected - ) - retrievers += iretrievers + + if command_state == WEB_SEARCH_COMMAND: + # set retriever for web search + if not WebSearch: + raise ValueError("Web search back-end is not available.") + + web_search = WebSearch() + retrievers.append(web_search) + else: + for index in self._app.index_manager.indices: + index_selected = [] + if isinstance(index.selector, int): + index_selected = selecteds[index.selector] + if isinstance(index.selector, tuple): + for i in 
index.selector: + index_selected.append(selecteds[i]) + iretrievers = index.get_retriever_pipelines( + settings, user_id, index_selected + ) + retrievers += iretrievers # prepare states reasoning_state = { @@ -966,7 +1001,8 @@ class ChatPage(BasePage): use_mind_map, use_citation, language, - state, + chat_state, + command_state, user_id, *selecteds, ): @@ -976,7 +1012,7 @@ class ChatPage(BasePage): # if chat_input is empty, assume regen mode if chat_output: - state["app"]["regen"] = True + chat_state["app"]["regen"] = True queue: asyncio.Queue[Optional[dict]] = asyncio.Queue() @@ -988,7 +1024,8 @@ class ChatPage(BasePage): use_mind_map, use_citation, language, - state, + chat_state, + command_state, user_id, *selecteds, ) @@ -1005,7 +1042,7 @@ class ChatPage(BasePage): refs, plot_gr, plot, - state, + chat_state, ) for response in pipeline.stream(chat_input, conversation_id, chat_history): @@ -1032,14 +1069,14 @@ class ChatPage(BasePage): plot = response.content plot_gr = self._json_to_plot(plot) - state[pipeline.get_info()["id"]] = reasoning_state["pipeline"] + chat_state[pipeline.get_info()["id"]] = reasoning_state["pipeline"] yield ( chat_history + [(chat_input, text or msg_placeholder)], refs, plot_gr, plot, - state, + chat_state, ) if not text: @@ -1052,7 +1089,7 @@ class ChatPage(BasePage): refs, plot_gr, plot, - state, + chat_state, ) def check_and_suggest_name_conv(self, chat_history): diff --git a/libs/ktem/ktem/pages/chat/chat_panel.py b/libs/ktem/ktem/pages/chat/chat_panel.py index 2adc52f..3db13ed 100644 --- a/libs/ktem/ktem/pages/chat/chat_panel.py +++ b/libs/ktem/ktem/pages/chat/chat_panel.py @@ -25,7 +25,9 @@ class ChatPanel(BasePage): interactive=True, scale=20, file_count="multiple", - placeholder="Type a message (or tag a file with @filename)", + placeholder=( + "Type a message, or search the @web, " "tag a file with @filename" + ), container=False, show_label=False, elem_id="chat-input", diff --git a/libs/ktem/ktem/utils/commands.py 
b/libs/ktem/ktem/utils/commands.py new file mode 100644 index 0000000..48f7a70 --- /dev/null +++ b/libs/ktem/ktem/utils/commands.py @@ -0,0 +1 @@ +WEB_SEARCH_COMMAND = "web" diff --git a/libs/ktem/ktem/utils/render.py b/libs/ktem/ktem/utils/render.py index 63e1ab5..49c2f79 100644 --- a/libs/ktem/ktem/utils/render.py +++ b/libs/ktem/ktem/utils/render.py @@ -59,6 +59,17 @@ class Render: ], ) + @staticmethod + def table_preserve_linebreaks(text: str) -> str: + """Render table from markdown format into HTML""" + return markdown.markdown( + text, + extensions=[ + "markdown.extensions.tables", + "markdown.extensions.fenced_code", + ], + ).replace("\n", "<br>") + @staticmethod def preview( html_content: str, @@ -134,6 +145,8 @@ class Render: header = f"{get_header(doc)}" if doc.metadata.get("type", "") == "image": doc_content = Render.image(url=doc.metadata["image_origin"], text=doc.text) + elif doc.metadata.get("type", "") == "table_raw": + doc_content = Render.table_preserve_linebreaks(doc.text) else: doc_content = Render.table(doc.text) @@ -174,6 +187,9 @@ class Render: if item_type_prefix: item_type_prefix += " from " + if "raw" in item_type_prefix: + item_type_prefix = "" + if llm_reranking_score > 0: relevant_score = llm_reranking_score elif reranking_score > 0: @@ -198,6 +214,8 @@ class Render: url=doc.metadata["image_origin"], text=text, ) + elif doc.metadata.get("type", "") == "table_raw": + rendered_doc_content = Render.table_preserve_linebreaks(doc.text) else: rendered_doc_content = Render.table(text)