feat: add URL indexing directly from chat input (#571) bump:patch
* feat: enable lightrag by default and add graphrag key check
* feat: add URL indexing from chatbox
This commit is contained in:
parent
a0c9a6e8de
commit
7a02cb72af
|
@ -296,7 +296,7 @@ SETTINGS_REASONING = {
|
||||||
}
|
}
|
||||||
|
|
||||||
USE_NANO_GRAPHRAG = config("USE_NANO_GRAPHRAG", default=False, cast=bool)
|
USE_NANO_GRAPHRAG = config("USE_NANO_GRAPHRAG", default=False, cast=bool)
|
||||||
USE_LIGHTRAG = config("USE_LIGHTRAG", default=False, cast=bool)
|
USE_LIGHTRAG = config("USE_LIGHTRAG", default=True, cast=bool)
|
||||||
|
|
||||||
GRAPHRAG_INDEX_TYPES = ["ktem.index.file.graph.GraphRAGIndex"]
|
GRAPHRAG_INDEX_TYPES = ["ktem.index.file.graph.GraphRAGIndex"]
|
||||||
|
|
||||||
|
|
|
@ -45,7 +45,7 @@ logging.getLogger("lightrag").setLevel(logging.INFO)
|
||||||
filestorage_path = Path(settings.KH_FILESTORAGE_PATH) / "lightrag"
|
filestorage_path = Path(settings.KH_FILESTORAGE_PATH) / "lightrag"
|
||||||
filestorage_path.mkdir(parents=True, exist_ok=True)
|
filestorage_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
INDEX_BATCHSIZE = 2
|
INDEX_BATCHSIZE = 4
|
||||||
|
|
||||||
|
|
||||||
def get_llm_func(model):
|
def get_llm_func(model):
|
||||||
|
@ -268,7 +268,9 @@ class LightRAGIndexingPipeline(GraphRAGIndexingPipeline):
|
||||||
|
|
||||||
for doc_id in range(0, len(all_docs), INDEX_BATCHSIZE):
|
for doc_id in range(0, len(all_docs), INDEX_BATCHSIZE):
|
||||||
cur_docs = all_docs[doc_id : doc_id + INDEX_BATCHSIZE]
|
cur_docs = all_docs[doc_id : doc_id + INDEX_BATCHSIZE]
|
||||||
graphrag_func.insert(cur_docs)
|
combined_doc = "\n".join(cur_docs)
|
||||||
|
|
||||||
|
graphrag_func.insert(combined_doc)
|
||||||
process_doc_count += len(cur_docs)
|
process_doc_count += len(cur_docs)
|
||||||
yield Document(
|
yield Document(
|
||||||
channel="debug",
|
channel="debug",
|
||||||
|
|
|
@ -263,7 +263,9 @@ class NanoGraphRAGIndexingPipeline(GraphRAGIndexingPipeline):
|
||||||
)
|
)
|
||||||
for doc_id in range(0, len(all_docs), INDEX_BATCHSIZE):
|
for doc_id in range(0, len(all_docs), INDEX_BATCHSIZE):
|
||||||
cur_docs = all_docs[doc_id : doc_id + INDEX_BATCHSIZE]
|
cur_docs = all_docs[doc_id : doc_id + INDEX_BATCHSIZE]
|
||||||
graphrag_func.insert(cur_docs)
|
combined_doc = "\n".join(cur_docs)
|
||||||
|
|
||||||
|
graphrag_func.insert(combined_doc)
|
||||||
process_doc_count += len(cur_docs)
|
process_doc_count += len(cur_docs)
|
||||||
yield Document(
|
yield Document(
|
||||||
channel="debug",
|
channel="debug",
|
||||||
|
|
|
@ -47,6 +47,14 @@ except ImportError:
|
||||||
filestorage_path = Path(settings.KH_FILESTORAGE_PATH) / "graphrag"
|
filestorage_path = Path(settings.KH_FILESTORAGE_PATH) / "graphrag"
|
||||||
filestorage_path.mkdir(parents=True, exist_ok=True)
|
filestorage_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Error text used when GRAPHRAG_API_KEY is absent. It is raised as a ValueError
# from both the GraphRAG indexing path and the GraphRAG retriever path, so the
# wording must not single out only one of them.
GRAPHRAG_KEY_MISSING_MESSAGE = (
    "GRAPHRAG_API_KEY is not set. "
    "Please set it to use the GraphRAG indexing and retriever pipelines."
)
|
||||||
|
|
||||||
|
|
||||||
|
def check_graphrag_api_key() -> bool:
    """Return whether a usable ``GRAPHRAG_API_KEY`` is set in the environment.

    The variable is read at call time, so changes made to the environment
    after import are picked up. A value that is empty or consists only of
    whitespace (e.g. a padded entry in an env file) is treated as missing,
    so callers fail early with a clear message instead of passing a blank
    key downstream.

    Returns:
        ``True`` if the variable holds a non-blank value, ``False`` otherwise.
    """
    return bool(os.getenv("GRAPHRAG_API_KEY", "").strip())
|
||||||
|
|
||||||
|
|
||||||
def prepare_graph_index_path(graph_id: str):
|
def prepare_graph_index_path(graph_id: str):
|
||||||
root_path = Path(filestorage_path) / graph_id
|
root_path = Path(filestorage_path) / graph_id
|
||||||
|
@ -99,6 +107,9 @@ class GraphRAGIndexingPipeline(IndexDocumentPipeline):
|
||||||
return root_path
|
return root_path
|
||||||
|
|
||||||
def call_graphrag_index(self, graph_id: str, all_docs: list[Document]):
|
def call_graphrag_index(self, graph_id: str, all_docs: list[Document]):
|
||||||
|
if not check_graphrag_api_key():
|
||||||
|
raise ValueError(GRAPHRAG_KEY_MISSING_MESSAGE)
|
||||||
|
|
||||||
# call GraphRAG index with docs and graph_id
|
# call GraphRAG index with docs and graph_id
|
||||||
input_path = self.write_docs_to_files(graph_id, all_docs)
|
input_path = self.write_docs_to_files(graph_id, all_docs)
|
||||||
input_path = str(input_path.absolute())
|
input_path = str(input_path.absolute())
|
||||||
|
@ -346,6 +357,10 @@ class GraphRAGRetrieverPipeline(BaseFileIndexRetriever):
|
||||||
) -> list[RetrievedDocument]:
|
) -> list[RetrievedDocument]:
|
||||||
if not self.file_ids:
|
if not self.file_ids:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
if not check_graphrag_api_key():
|
||||||
|
raise ValueError(GRAPHRAG_KEY_MISSING_MESSAGE)
|
||||||
|
|
||||||
context_builder = self._build_graph_search()
|
context_builder = self._build_graph_search()
|
||||||
|
|
||||||
local_context_params = {
|
local_context_params = {
|
||||||
|
|
|
@ -683,6 +683,11 @@ class FileIndexPage(BasePage):
|
||||||
if self._index.id == 1:
|
if self._index.id == 1:
|
||||||
self.quick_upload_state = gr.State(value=[])
|
self.quick_upload_state = gr.State(value=[])
|
||||||
print("Setting up quick upload event")
|
print("Setting up quick upload event")
|
||||||
|
|
||||||
|
# override indexing function from chat page
|
||||||
|
self._app.chat_page.first_indexing_url_fn = (
|
||||||
|
self.index_fn_url_with_default_loaders
|
||||||
|
)
|
||||||
quickUploadedEvent = (
|
quickUploadedEvent = (
|
||||||
self._app.chat_page.quick_file_upload.upload(
|
self._app.chat_page.quick_file_upload.upload(
|
||||||
fn=lambda: gr.update(
|
fn=lambda: gr.update(
|
||||||
|
|
|
@ -22,7 +22,7 @@ from theflow.settings import settings as flowsettings
|
||||||
from kotaemon.base import Document
|
from kotaemon.base import Document
|
||||||
from kotaemon.indices.ingests.files import KH_DEFAULT_FILE_EXTRACTORS
|
from kotaemon.indices.ingests.files import KH_DEFAULT_FILE_EXTRACTORS
|
||||||
|
|
||||||
from ...utils import SUPPORTED_LANGUAGE_MAP, get_file_names_regex
|
from ...utils import SUPPORTED_LANGUAGE_MAP, get_file_names_regex, get_urls
|
||||||
from .chat_panel import ChatPanel
|
from .chat_panel import ChatPanel
|
||||||
from .common import STATE
|
from .common import STATE
|
||||||
from .control import ConversationControl
|
from .control import ConversationControl
|
||||||
|
@ -140,6 +140,7 @@ class ChatPage(BasePage):
|
||||||
# get the file selector choices for the first index
|
# get the file selector choices for the first index
|
||||||
if index_id == 0:
|
if index_id == 0:
|
||||||
self.first_selector_choices = index_ui.selector_choices
|
self.first_selector_choices = index_ui.selector_choices
|
||||||
|
self.first_indexing_url_fn = None
|
||||||
|
|
||||||
if gr_index:
|
if gr_index:
|
||||||
if isinstance(gr_index, list):
|
if isinstance(gr_index, list):
|
||||||
|
@ -284,6 +285,7 @@ class ChatPage(BasePage):
|
||||||
self.chat_panel.text_input,
|
self.chat_panel.text_input,
|
||||||
self.chat_panel.chatbot,
|
self.chat_panel.chatbot,
|
||||||
self._app.user_id,
|
self._app.user_id,
|
||||||
|
self._app.settings_state,
|
||||||
self.chat_control.conversation_id,
|
self.chat_control.conversation_id,
|
||||||
self.chat_control.conversation_rn,
|
self.chat_control.conversation_rn,
|
||||||
self.first_selector_choices,
|
self.first_selector_choices,
|
||||||
|
@ -634,6 +636,7 @@ class ChatPage(BasePage):
|
||||||
chat_input,
|
chat_input,
|
||||||
chat_history,
|
chat_history,
|
||||||
user_id,
|
user_id,
|
||||||
|
settings,
|
||||||
conv_id,
|
conv_id,
|
||||||
conv_name,
|
conv_name,
|
||||||
first_selector_choices,
|
first_selector_choices,
|
||||||
|
@ -643,22 +646,44 @@ class ChatPage(BasePage):
|
||||||
raise ValueError("Input is empty")
|
raise ValueError("Input is empty")
|
||||||
|
|
||||||
chat_input_text = chat_input.get("text", "")
|
chat_input_text = chat_input.get("text", "")
|
||||||
|
file_ids = []
|
||||||
|
|
||||||
# get all file names with pattern @"filename" in input_str
|
|
||||||
file_names, chat_input_text = get_file_names_regex(chat_input_text)
|
|
||||||
first_selector_choices_map = {
|
first_selector_choices_map = {
|
||||||
item[0]: item[1] for item in first_selector_choices
|
item[0]: item[1] for item in first_selector_choices
|
||||||
}
|
}
|
||||||
file_ids = []
|
|
||||||
|
|
||||||
if file_names:
|
# get all file names with pattern @"filename" in input_str
|
||||||
|
file_names, chat_input_text = get_file_names_regex(chat_input_text)
|
||||||
|
# get all urls in input_str
|
||||||
|
urls, chat_input_text = get_urls(chat_input_text)
|
||||||
|
|
||||||
|
if urls and self.first_indexing_url_fn:
|
||||||
|
print("Detected URLs", urls)
|
||||||
|
file_ids = self.first_indexing_url_fn(
|
||||||
|
"\n".join(urls),
|
||||||
|
True,
|
||||||
|
settings,
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
elif file_names:
|
||||||
for file_name in file_names:
|
for file_name in file_names:
|
||||||
file_id = first_selector_choices_map.get(file_name)
|
file_id = first_selector_choices_map.get(file_name)
|
||||||
if file_id:
|
if file_id:
|
||||||
file_ids.append(file_id)
|
file_ids.append(file_id)
|
||||||
|
|
||||||
|
# add new file ids to the first selector choices
|
||||||
|
first_selector_choices.extend(zip(urls, file_ids))
|
||||||
|
|
||||||
|
# if file_ids is not empty and chat_input_text is empty
|
||||||
|
# set the input to summary
|
||||||
|
if not chat_input_text and file_ids:
|
||||||
|
chat_input_text = "Summary"
|
||||||
|
|
||||||
if file_ids:
|
if file_ids:
|
||||||
selector_output = ["select", file_ids]
|
selector_output = [
|
||||||
|
"select",
|
||||||
|
gr.update(value=file_ids, choices=first_selector_choices),
|
||||||
|
]
|
||||||
else:
|
else:
|
||||||
selector_output = [gr.update(), gr.update()]
|
selector_output = [gr.update(), gr.update()]
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from .conversation import get_file_names_regex
|
from .conversation import get_file_names_regex, get_urls
|
||||||
from .lang import SUPPORTED_LANGUAGE_MAP
|
from .lang import SUPPORTED_LANGUAGE_MAP
|
||||||
|
|
||||||
__all__ = ["SUPPORTED_LANGUAGE_MAP", "get_file_names_regex"]
|
__all__ = ["SUPPORTED_LANGUAGE_MAP", "get_file_names_regex", "get_urls"]
|
||||||
|
|
|
@ -29,5 +29,15 @@ def get_file_names_regex(input_str: str) -> tuple[list[str], str]:
|
||||||
return matches, input_str
|
return matches, input_str
|
||||||
|
|
||||||
|
|
||||||
|
def get_urls(input_str: str) -> tuple[list[str], str]:
    """Extract http(s) URLs from ``input_str`` and remove them from the text.

    A URL is any run of non-whitespace characters starting with ``http://``
    or ``https://``. Because users usually type URLs inside sentences, the
    raw match often drags along trailing punctuation (``https://a.com.``,
    ``(https://a.com)``, ``"https://a.com",``). Such trailing characters are
    trimmed from the returned URLs; a closing parenthesis is only trimmed
    when it has no matching opening parenthesis inside the URL, so links such
    as ``.../wiki/Python_(language)`` stay intact.

    Args:
        input_str: raw text that may contain URLs.

    Returns:
        A tuple ``(urls, remaining_text)`` where ``urls`` lists the cleaned
        URLs in order of appearance (duplicates are kept) and
        ``remaining_text`` is ``input_str`` with every raw match removed and
        leading/trailing whitespace stripped.
    """
    pattern = r"https?://[^\s]+"
    # characters that are far more likely sentence punctuation than URL tail
    trailing_punctuation = ".,;:!?\"'"

    urls = []
    for raw_url in re.findall(pattern, input_str):
        url = raw_url
        while url:
            last_char = url[-1]
            if last_char in trailing_punctuation:
                url = url[:-1]
            elif last_char == ")" and url.count(")") > url.count("("):
                # unbalanced ")" belongs to the surrounding text, not the URL
                url = url[:-1]
            else:
                break
        urls.append(url)

    # remove the full raw matches (including any trimmed punctuation) so no
    # dangling "." or ")" is left behind in the remaining prompt text
    input_str = re.sub(pattern, "", input_str).strip()

    return urls, input_str
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
print(sync_retrieval_n_message([[""], [""], [""]], []))
|
print(sync_retrieval_n_message([[""], [""], [""]], []))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user