feat: add web search (#580) bump:patch

* feat: add web search
* feat: update requirements
parent 4fe080737a
commit 95191f53d9
@@ -81,6 +81,10 @@ KH_FEATURE_USER_MANAGEMENT_PASSWORD = str(
 KH_ENABLE_ALEMBIC = False
 KH_DATABASE = f"sqlite:///{KH_USER_DATA_DIR / 'sql.db'}"
 KH_FILESTORAGE_PATH = str(KH_USER_DATA_DIR / "files")
+KH_WEB_SEARCH_BACKEND = (
+    "kotaemon.indices.retrievers.tavily_web_search.WebSearch"
+    # "kotaemon.indices.retrievers.jina_web_search.WebSearch"
+)
 
 KH_DOCSTORE = {
     # "__type__": "kotaemon.storages.ElasticsearchDocumentStore",
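KH_WEB_SEARCH_BACKEND is a dotted import path, so switching between the Tavily and Jina back-ends is a one-line settings change. A minimal sketch of how such a path can be resolved at runtime, mirroring the loader this commit adds to the chat page (the helper name is illustrative):

    import importlib

    def load_web_search_backend(dotted_path: str):
        """Resolve a 'package.module.ClassName' string into the class object."""
        module_name, class_name = dotted_path.rsplit(".", 1)
        module = importlib.import_module(module_name)
        return getattr(module, class_name)

    # e.g. WebSearch = load_web_search_backend(KH_WEB_SEARCH_BACKEND)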
libs/kotaemon/kotaemon/indices/retrievers/jina_web_search.py (new file, 60 lines)
@@ -0,0 +1,60 @@
+import requests
+from decouple import config
+
+from kotaemon.base import BaseComponent, RetrievedDocument
+
+JINA_API_KEY = config("JINA_API_KEY", default="")
+JINA_URL = config("JINA_URL", default="https://r.jina.ai/")
+
+
+class WebSearch(BaseComponent):
+    """WebSearch component for fetching data from the web
+    using Jina API
+    """
+
+    def run(
+        self,
+        text: str,
+        *args,
+        **kwargs,
+    ) -> list[RetrievedDocument]:
+        if JINA_API_KEY == "":
+            raise ValueError(
+                "This feature requires JINA_API_KEY "
+                "(get free one from https://jina.ai/reader)"
+            )
+
+        # setup the request
+        api_url = f"https://s.jina.ai/{text}"
+        headers = {"X-With-Generated-Alt": "true", "Accept": "application/json"}
+        if JINA_API_KEY:
+            headers["Authorization"] = f"Bearer {JINA_API_KEY}"
+
+        response = requests.get(api_url, headers=headers)
+        response.raise_for_status()
+        response_dict = response.json()
+
+        return [
+            RetrievedDocument(
+                text=(
+                    "###URL: [{url}]({url})\n\n"
+                    "####{title}\n\n"
+                    "{description}\n"
+                    "{content}"
+                ).format(
+                    url=item["url"],
+                    title=item["title"],
+                    description=item["description"],
+                    content=item["content"],
+                ),
+                metadata={
+                    "file_name": "Web search",
+                    "type": "table",
+                    "llm_trulens_score": 1.0,
+                },
+            )
+            for item in response_dict["data"]
+        ]
+
+    def generate_relevant_scores(self, text, documents: list[RetrievedDocument]):
+        return documents
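A minimal usage sketch of the new Jina-backed component, assuming JINA_API_KEY is set in the environment and using an illustrative query:

    from kotaemon.indices.retrievers.jina_web_search import WebSearch

    search = WebSearch()
    docs = search.run("retrieval augmented generation")  # list[RetrievedDocument]
    for doc in docs:
        print(doc.metadata["file_name"], doc.text[:200])

Each Jina result becomes its own RetrievedDocument, formatted as a small markdown block with URL, title, description, and content.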
libs/kotaemon/kotaemon/indices/retrievers/tavily_web_search.py (new file, 57 lines)
@@ -0,0 +1,57 @@
+from decouple import config
+
+from kotaemon.base import BaseComponent, RetrievedDocument
+
+TAVILY_API_KEY = config("TAVILY_API_KEY", default="")
+
+
+class WebSearch(BaseComponent):
+    """WebSearch component for fetching data from the web
+    using Tavily API
+    """
+
+    def run(
+        self,
+        text: str,
+        *args,
+        **kwargs,
+    ) -> list[RetrievedDocument]:
+        if TAVILY_API_KEY == "":
+            raise ValueError(
+                "This feature requires TAVILY_API_KEY "
+                "(get free one from https://app.tavily.com/)"
+            )
+
+        try:
+            from tavily import TavilyClient
+        except ImportError:
+            raise ImportError(
+                "Please install `pip install tavily-python` to use this feature"
+            )
+
+        tavily_client = TavilyClient(api_key=TAVILY_API_KEY)
+        results = tavily_client.search(
+            query=text,
+            search_depth="advanced",
+        )["results"]
+        context = "\n\n".join(
+            "###URL: [{url}]({url})\n\n{content}".format(
+                url=result["url"],
+                content=result["content"],
+            )
+            for result in results
+        )
+
+        return [
+            RetrievedDocument(
+                text=context,
+                metadata={
+                    "file_name": "Web search",
+                    "type": "table",
+                    "llm_trulens_score": 1.0,
+                },
+            )
+        ]
+
+    def generate_relevant_scores(self, text, documents: list[RetrievedDocument]):
+        return documents
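Unlike the Jina back-end, which yields one RetrievedDocument per result, this component concatenates all Tavily results into a single document. A rough sketch of the underlying call it wraps, with a placeholder API key and query (requires `pip install tavily-python`):

    from tavily import TavilyClient

    client = TavilyClient(api_key="tvly-...")  # placeholder key
    results = client.search(query="kotaemon web search", search_depth="advanced")["results"]
    for result in results:
        print(result["url"], result["content"][:120])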
@@ -55,6 +55,7 @@ dependencies = [
     "theflow>=0.8.6,<0.9.0",
     "trogon>=0.5.0,<0.6",
     "umap-learn==0.5.5",
+    "tavily-python>=0.4.0",
 ]
 readme = "README.md"
 authors = [
@@ -19,6 +19,8 @@ from sqlalchemy import select
 from sqlalchemy.orm import Session
 from theflow.settings import settings as flowsettings
 
+from ...utils.commands import WEB_SEARCH_COMMAND
+
 DOWNLOAD_MESSAGE = "Press again to download"
 MAX_FILENAME_LENGTH = 20
 
@@ -38,6 +40,13 @@ function(file_list) {
         value: '"' + file_list[i][0] + '"',
       });
     }
+
+    // manually push web search tag
+    values.push({
+      key: "web_search",
+      value: '"web_search"',
+    });
+
     var tribute = new Tribute({
       values: values,
       noMatchTemplate: "",
@@ -46,7 +55,9 @@ function(file_list) {
     input_box = document.querySelector('#chat-input textarea');
     tribute.attach(input_box);
 }
-"""
+""".replace(
+    "web_search", WEB_SEARCH_COMMAND
+)
 
 
 class File(gr.File):
@@ -1,4 +1,5 @@
 import asyncio
+import importlib
 import json
 import re
 from copy import deepcopy
@@ -23,11 +24,22 @@ from kotaemon.base import Document
 from kotaemon.indices.ingests.files import KH_DEFAULT_FILE_EXTRACTORS
 
 from ...utils import SUPPORTED_LANGUAGE_MAP, get_file_names_regex, get_urls
+from ...utils.commands import WEB_SEARCH_COMMAND
 from .chat_panel import ChatPanel
 from .common import STATE
 from .control import ConversationControl
 from .report import ReportIssue
 
+KH_WEB_SEARCH_BACKEND = getattr(flowsettings, "KH_WEB_SEARCH_BACKEND", None)
+WebSearch = None
+if KH_WEB_SEARCH_BACKEND:
+    try:
+        module_name, class_name = KH_WEB_SEARCH_BACKEND.rsplit(".", 1)
+        module = importlib.import_module(module_name)
+        WebSearch = getattr(module, class_name)
+    except (ImportError, AttributeError) as e:
+        print(f"Error importing {class_name} from {module_name}: {e}")
+
 DEFAULT_SETTING = "(default)"
 INFO_PANEL_SCALES = {True: 8, False: 4}
 
@@ -113,6 +125,7 @@ class ChatPage(BasePage):
             value=getattr(flowsettings, "KH_FEATURE_CHAT_SUGGESTION", False)
         )
         self._info_panel_expanded = gr.State(value=True)
+        self._command_state = gr.State(value=None)
 
     def on_building_ui(self):
         with gr.Row():
@@ -299,6 +312,7 @@ class ChatPage(BasePage):
                 # file selector from the first index
                 self._indices_input[0],
                 self._indices_input[1],
+                self._command_state,
             ],
             concurrency_limit=20,
             show_progress="hidden",
@@ -315,6 +329,7 @@ class ChatPage(BasePage):
                 self.citation,
                 self.language,
                 self.state_chat,
+                self._command_state,
                 self._app.user_id,
             ]
             + self._indices_input,
@@ -647,6 +662,7 @@ class ChatPage(BasePage):
 
         chat_input_text = chat_input.get("text", "")
         file_ids = []
+        used_command = None
 
         first_selector_choices_map = {
             item[0]: item[1] for item in first_selector_choices
@@ -654,6 +670,11 @@ class ChatPage(BasePage):
 
         # get all file names with pattern @"filename" in input_str
         file_names, chat_input_text = get_file_names_regex(chat_input_text)
+
+        # check if web search command is in file_names
+        if WEB_SEARCH_COMMAND in file_names:
+            used_command = WEB_SEARCH_COMMAND
+
         # get all urls in input_str
         urls, chat_input_text = get_urls(chat_input_text)
 
@@ -707,13 +728,17 @@ class ChatPage(BasePage):
             conv_update = gr.update()
             new_conv_name = conv_name
 
-        return [
+        return (
+            [
             {},
             chat_history,
             new_conv_id,
             conv_update,
             new_conv_name,
-        ] + selector_output
+            ]
+            + selector_output
+            + [used_command]
+        )
 
     def toggle_delete(self, conv_id):
         if conv_id:
@@ -877,6 +902,7 @@ class ChatPage(BasePage):
         session_use_citation: str,
         session_language: str,
         state: dict,
+        command_state: str | None,
         user_id: int,
         *selecteds,
     ):
@@ -934,6 +960,15 @@ class ChatPage(BasePage):
 
         # get retrievers
        retrievers = []
+
+        if command_state == WEB_SEARCH_COMMAND:
+            # set retriever for web search
+            if not WebSearch:
+                raise ValueError("Web search back-end is not available.")
+
+            web_search = WebSearch()
+            retrievers.append(web_search)
+        else:
             for index in self._app.index_manager.indices:
                 index_selected = []
                 if isinstance(index.selector, int):
@@ -966,7 +1001,8 @@ class ChatPage(BasePage):
         use_mind_map,
         use_citation,
         language,
-        state,
+        chat_state,
+        command_state,
         user_id,
         *selecteds,
     ):
@@ -976,7 +1012,7 @@ class ChatPage(BasePage):
 
         # if chat_input is empty, assume regen mode
         if chat_output:
-            state["app"]["regen"] = True
+            chat_state["app"]["regen"] = True
 
         queue: asyncio.Queue[Optional[dict]] = asyncio.Queue()
 
@@ -988,7 +1024,8 @@ class ChatPage(BasePage):
             use_mind_map,
             use_citation,
             language,
-            state,
+            chat_state,
+            command_state,
             user_id,
             *selecteds,
         )
@@ -1005,7 +1042,7 @@ class ChatPage(BasePage):
                 refs,
                 plot_gr,
                 plot,
-                state,
+                chat_state,
             )
 
         for response in pipeline.stream(chat_input, conversation_id, chat_history):
@@ -1032,14 +1069,14 @@ class ChatPage(BasePage):
                 plot = response.content
                 plot_gr = self._json_to_plot(plot)
 
-            state[pipeline.get_info()["id"]] = reasoning_state["pipeline"]
+            chat_state[pipeline.get_info()["id"]] = reasoning_state["pipeline"]
 
             yield (
                 chat_history + [(chat_input, text or msg_placeholder)],
                 refs,
                 plot_gr,
                 plot,
-                state,
+                chat_state,
             )
 
             if not text:
@@ -1052,7 +1089,7 @@ class ChatPage(BasePage):
             refs,
             plot_gr,
             plot,
-            state,
+            chat_state,
         )
 
     def check_and_suggest_name_conv(self, chat_history):
@@ -25,7 +25,9 @@ class ChatPanel(BasePage):
             interactive=True,
             scale=20,
             file_count="multiple",
-            placeholder="Type a message (or tag a file with @filename)",
+            placeholder=(
+                "Type a message, or search the @web, " "tag a file with @filename"
+            ),
             container=False,
             show_label=False,
             elem_id="chat-input",
libs/ktem/ktem/utils/commands.py (new file, 1 line)
@@ -0,0 +1 @@
+WEB_SEARCH_COMMAND = "web"
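This "web" constant ties the pieces together: the JavaScript above pushes an @web tag into the chat-input autocomplete, get_file_names_regex extracts it like any other @"filename" tag, and the chat page then routes the turn to the WebSearch retriever. A tiny illustration of that routing check (the file_names value is hypothetical):

    from ktem.utils.commands import WEB_SEARCH_COMMAND

    file_names = ["web"]  # tags extracted from a message such as '@web what is kotaemon?'
    used_command = WEB_SEARCH_COMMAND if WEB_SEARCH_COMMAND in file_names else None
    assert used_command == "web"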
@@ -59,6 +59,17 @@ class Render:
             ],
         )
 
+    @staticmethod
+    def table_preserve_linebreaks(text: str) -> str:
+        """Render table from markdown format into HTML"""
+        return markdown.markdown(
+            text,
+            extensions=[
+                "markdown.extensions.tables",
+                "markdown.extensions.fenced_code",
+            ],
+        ).replace("\n", "<br>")
+
     @staticmethod
     def preview(
         html_content: str,
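A quick illustration of what the new helper produces, using the same markdown package the method above calls (the described output is approximate):

    import markdown

    text = "| a | b |\n|---|---|\n| 1 | 2 |\nplain line"
    html = markdown.markdown(
        text,
        extensions=["markdown.extensions.tables", "markdown.extensions.fenced_code"],
    ).replace("\n", "<br>")
    # The table markup is kept, and the remaining newlines become <br>,
    # so line breaks in raw table text survive in the info panel.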
@@ -134,6 +145,8 @@ class Render:
         header = f"<i>{get_header(doc)}</i>"
         if doc.metadata.get("type", "") == "image":
             doc_content = Render.image(url=doc.metadata["image_origin"], text=doc.text)
+        elif doc.metadata.get("type", "") == "table_raw":
+            doc_content = Render.table_preserve_linebreaks(doc.text)
         else:
             doc_content = Render.table(doc.text)
 
@@ -174,6 +187,9 @@ class Render:
         if item_type_prefix:
             item_type_prefix += " from "
 
+        if "raw" in item_type_prefix:
+            item_type_prefix = ""
+
         if llm_reranking_score > 0:
             relevant_score = llm_reranking_score
         elif reranking_score > 0:
@@ -198,6 +214,8 @@ class Render:
                 url=doc.metadata["image_origin"],
                 text=text,
             )
+        elif doc.metadata.get("type", "") == "table_raw":
+            rendered_doc_content = Render.table_preserve_linebreaks(doc.text)
         else:
             rendered_doc_content = Render.table(text)
 