feat: add graphrag modes (#574)

* feat: add support for retrieval modes in LightRAG & NanoGraphRAG

* feat: expose custom prompts in LightRAG & NanoGraphRAG

* fix: optimize setting UI

* fix: update non local mode in LightRAG

* fix: update graphRAG mode
This commit is contained in:
Tuan Anh Nguyen Dang (Tadashi_Cin) 2024-12-17 16:49:37 +07:00 committed by GitHub
parent c667bf9d0a
commit 1d3c4f4433
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 237 additions and 45 deletions

View File

@@ -302,7 +302,7 @@ GRAPHRAG_INDEX_TYPES = ["ktem.index.file.graph.GraphRAGIndex"]
if USE_NANO_GRAPHRAG:
GRAPHRAG_INDEX_TYPES.append("ktem.index.file.graph.NanoGraphRAGIndex")
elif USE_LIGHTRAG:
if USE_LIGHTRAG:
GRAPHRAG_INDEX_TYPES.append("ktem.index.file.graph.LightRAGIndex")
KH_INDEX_TYPES = [

View File

@@ -204,6 +204,11 @@ mark {
right: 15px;
}
/* prevent overflow of html info panel */
#html-info-panel {
overflow-x: auto !important;
}
#chat-expand-button {
position: absolute;
top: 6px;
@@ -211,6 +216,12 @@ mark {
z-index: 1;
}
#save-setting-btn {
width: 150px;
height: 30px;
min-width: 100px !important;
}
#quick-setting-labels {
margin-top: 5px;
margin-bottom: -10px;

View File

@@ -21,6 +21,11 @@ function run() {
let chat_column = document.getElementById("main-chat-bot");
let conv_column = document.getElementById("conv-settings-panel");
// move setting close button
let setting_tab_nav_bar = document.querySelector("#settings-tab .tab-nav");
let setting_close_button = document.getElementById("save-setting-btn");
setting_tab_nav_bar.appendChild(setting_close_button);
let default_conv_column_min_width = "min(300px, 100%)";
conv_column.style.minWidth = default_conv_column_min_width

View File

@@ -1,6 +1,6 @@
from typing import Any
from ..base import BaseFileIndexRetriever
from ..base import BaseFileIndexIndexing, BaseFileIndexRetriever
from .graph_index import GraphRAGIndex
from .lightrag_pipelines import LightRAGIndexingPipeline, LightRAGRetrieverPipeline
@@ -12,14 +12,32 @@ class LightRAGIndex(GraphRAGIndex):
def _setup_retriever_cls(self):
self._retriever_pipeline_cls = [LightRAGRetrieverPipeline]
def get_indexing_pipeline(self, settings, user_id) -> BaseFileIndexIndexing:
pipeline = super().get_indexing_pipeline(settings, user_id)
# indexing settings
prefix = f"index.options.{self.id}."
striped_settings = {
key[len(prefix) :]: value
for key, value in settings.items()
if key.startswith(prefix)
}
# set the prompts
pipeline.prompts = striped_settings
return pipeline
def get_retriever_pipelines(
self, settings: dict, user_id: int, selected: Any = None
) -> list["BaseFileIndexRetriever"]:
_, file_ids, _ = selected
# retrieval settings
prefix = f"index.options.{self.id}."
search_type = settings.get(prefix + "search_type", "local")
retrievers = [
LightRAGRetrieverPipeline(
file_ids=file_ids,
Index=self._resources["Index"],
search_type=search_type,
)
]

View File

@@ -70,7 +70,7 @@ def get_llm_func(model):
if if_cache_return is not None:
return if_cache_return["return"]
output = model(input_messages).text
output = (await model.ainvoke(input_messages)).text
print("-" * 50)
print(output, "\n", "-" * 50)
@@ -220,7 +220,37 @@ def build_graphrag(working_dir, llm_func, embedding_func):
class LightRAGIndexingPipeline(GraphRAGIndexingPipeline):
"""GraphRAG specific indexing pipeline"""
prompts: dict[str, str] = {}
@classmethod
def get_user_settings(cls) -> dict:
try:
from lightrag.prompt import PROMPTS
blacklist_keywords = ["default", "response", "process"]
return {
prompt_name: {
"name": f"Prompt for '{prompt_name}'",
"value": content,
"component": "text",
}
for prompt_name, content in PROMPTS.items()
if all(
keyword not in prompt_name.lower() for keyword in blacklist_keywords
)
}
except ImportError as e:
print(e)
return {}
def call_graphrag_index(self, graph_id: str, docs: list[Document]):
from lightrag.prompt import PROMPTS
# modify the prompt if it is set in the settings
for prompt_name, content in self.prompts.items():
if prompt_name in PROMPTS:
PROMPTS[prompt_name] = content
_, input_path = prepare_graph_index_path(graph_id)
input_path.mkdir(parents=True, exist_ok=True)
@@ -302,6 +332,19 @@ class LightRAGRetrieverPipeline(BaseFileIndexRetriever):
Index = Param(help="The SQLAlchemy Index table")
file_ids: list[str] = []
search_type: str = "local"
@classmethod
def get_user_settings(cls) -> dict:
return {
"search_type": {
"name": "Search type",
"value": "local",
"choices": ["local", "global", "hybrid"],
"component": "dropdown",
"info": "Whether to use local or global search in the graph.",
}
}
def _build_graph_search(self):
file_id = self.file_ids[0]
@@ -326,7 +369,8 @@ class LightRAGRetrieverPipeline(BaseFileIndexRetriever):
llm_func=llm_func,
embedding_func=embedding_func,
)
query_params = QueryParam(mode="local", only_need_context=True)
print("search_type", self.search_type)
query_params = QueryParam(mode=self.search_type, only_need_context=True)
return graphrag_func, query_params
@@ -381,20 +425,40 @@ class LightRAGRetrieverPipeline(BaseFileIndexRetriever):
return []
graphrag_func, query_params = self._build_graph_search()
entities, relationships, sources = asyncio.run(
lightrag_build_local_query_context(graphrag_func, text, query_params)
)
documents = self.format_context_records(entities, relationships, sources)
plot = self.plot_graph(relationships)
# only local mode support graph visualization
if query_params.mode == "local":
entities, relationships, sources = asyncio.run(
lightrag_build_local_query_context(graphrag_func, text, query_params)
)
documents = self.format_context_records(entities, relationships, sources)
plot = self.plot_graph(relationships)
documents += [
RetrievedDocument(
text="",
metadata={
"file_name": "GraphRAG",
"type": "plot",
"data": plot,
},
),
]
else:
context = graphrag_func.query(text, query_params)
return documents + [
RetrievedDocument(
text="",
metadata={
"file_name": "GraphRAG",
"type": "plot",
"data": plot,
},
),
]
# account for missing ``` for closing code block
context += "\n```"
documents = [
RetrievedDocument(
text=context,
metadata={
"file_name": "GraphRAG {} Search".format(
query_params.mode.capitalize()
),
"type": "table",
},
)
]
return documents

View File

@@ -1,6 +1,6 @@
from typing import Any
from ..base import BaseFileIndexRetriever
from ..base import BaseFileIndexIndexing, BaseFileIndexRetriever
from .graph_index import GraphRAGIndex
from .nano_pipelines import NanoGraphRAGIndexingPipeline, NanoGraphRAGRetrieverPipeline
@@ -12,14 +12,32 @@ class NanoGraphRAGIndex(GraphRAGIndex):
def _setup_retriever_cls(self):
self._retriever_pipeline_cls = [NanoGraphRAGRetrieverPipeline]
def get_indexing_pipeline(self, settings, user_id) -> BaseFileIndexIndexing:
pipeline = super().get_indexing_pipeline(settings, user_id)
# indexing settings
prefix = f"index.options.{self.id}."
striped_settings = {
key[len(prefix) :]: value
for key, value in settings.items()
if key.startswith(prefix)
}
# set the prompts
pipeline.prompts = striped_settings
return pipeline
def get_retriever_pipelines(
self, settings: dict, user_id: int, selected: Any = None
) -> list["BaseFileIndexRetriever"]:
_, file_ids, _ = selected
# retrieval settings
prefix = f"index.options.{self.id}."
search_type = settings.get(prefix + "search_type", "local")
retrievers = [
NanoGraphRAGRetrieverPipeline(
file_ids=file_ids,
Index=self._resources["Index"],
search_type=search_type,
)
]

View File

@@ -71,7 +71,7 @@ def get_llm_func(model):
if if_cache_return is not None:
return if_cache_return["return"]
output = model(input_messages).text
output = (await model.ainvoke(input_messages)).text
print("-" * 50)
print(output, "\n", "-" * 50)
@@ -216,7 +216,37 @@ def build_graphrag(working_dir, llm_func, embedding_func):
class NanoGraphRAGIndexingPipeline(GraphRAGIndexingPipeline):
"""GraphRAG specific indexing pipeline"""
prompts: dict[str, str] = {}
@classmethod
def get_user_settings(cls) -> dict:
try:
from nano_graphrag.prompt import PROMPTS
blacklist_keywords = ["default", "response", "process"]
return {
prompt_name: {
"name": f"Prompt for '{prompt_name}'",
"value": content,
"component": "text",
}
for prompt_name, content in PROMPTS.items()
if all(
keyword not in prompt_name.lower() for keyword in blacklist_keywords
)
}
except ImportError as e:
print(e)
return {}
def call_graphrag_index(self, graph_id: str, docs: list[Document]):
from nano_graphrag.prompt import PROMPTS
# modify the prompt if it is set in the settings
for prompt_name, content in self.prompts.items():
if prompt_name in PROMPTS:
PROMPTS[prompt_name] = content
_, input_path = prepare_graph_index_path(graph_id)
input_path.mkdir(parents=True, exist_ok=True)
@@ -297,6 +327,19 @@ class NanoGraphRAGRetrieverPipeline(BaseFileIndexRetriever):
Index = Param(help="The SQLAlchemy Index table")
file_ids: list[str] = []
search_type: str = "local"
@classmethod
def get_user_settings(cls) -> dict:
return {
"search_type": {
"name": "Search type",
"value": "local",
"choices": ["local", "global"],
"component": "dropdown",
"info": "Whether to use local or global search in the graph.",
}
}
def _build_graph_search(self):
file_id = self.file_ids[0]
@@ -321,7 +364,8 @@ class NanoGraphRAGRetrieverPipeline(BaseFileIndexRetriever):
llm_func=llm_func,
embedding_func=embedding_func,
)
query_params = QueryParam(mode="local", only_need_context=True)
print("search_type", self.search_type)
query_params = QueryParam(mode=self.search_type, only_need_context=True)
return graphrag_func, query_params
@@ -384,22 +428,43 @@ class NanoGraphRAGRetrieverPipeline(BaseFileIndexRetriever):
return []
graphrag_func, query_params = self._build_graph_search()
entities, relationships, reports, sources = asyncio.run(
nano_graph_rag_build_local_query_context(graphrag_func, text, query_params)
)
documents = self.format_context_records(
entities, relationships, reports, sources
)
plot = self.plot_graph(relationships)
# only local mode support graph visualization
if query_params.mode == "local":
entities, relationships, reports, sources = asyncio.run(
nano_graph_rag_build_local_query_context(
graphrag_func, text, query_params
)
)
return documents + [
RetrievedDocument(
text="",
metadata={
"file_name": "GraphRAG",
"type": "plot",
"data": plot,
},
),
]
documents = self.format_context_records(
entities, relationships, reports, sources
)
plot = self.plot_graph(relationships)
documents += [
RetrievedDocument(
text="",
metadata={
"file_name": "GraphRAG",
"type": "plot",
"data": plot,
},
),
]
else:
context = graphrag_func.query(text, query_params)
documents = [
RetrievedDocument(
text=context,
metadata={
"file_name": "GraphRAG {} Search".format(
query_params.mode.capitalize()
),
"type": "table",
},
)
]
return documents

View File

@@ -180,7 +180,7 @@ class GraphRAGRetrieverPipeline(BaseFileIndexRetriever):
"search_type": {
"name": "Search type",
"value": "local",
"choices": ["local", "global"],
"choices": ["local"],
"component": "dropdown",
"info": "Whether to use local or global search in the graph.",
}

View File

@@ -106,6 +106,12 @@ class SettingsPage(BasePage):
self.on_building_ui()
def on_building_ui(self):
self.setting_save_btn = gr.Button(
"Save & Close",
variant="primary",
elem_classes=["right-button"],
elem_id="save-setting-btn",
)
if self._app.f_user_management:
with gr.Tab("User settings"):
self.user_tab()
@@ -114,10 +120,6 @@
self.index_tab()
self.reasoning_tab()
self.setting_save_btn = gr.Button(
"Save changes", variant="primary", scale=1, elem_classes=["right-button"]
)
def on_subscribe_public_events(self):
"""
Subscribes to public events related to user management.
@@ -177,6 +179,9 @@
self.save_setting,
inputs=[self._user_id] + self.components(),
outputs=self._settings_state,
).then(
lambda: gr.Tabs(selected="chat-tab"),
outputs=self._app.tabs,
)
self._components["reasoning.use"].change(
self.change_reasoning_mode,

View File

@@ -49,7 +49,13 @@ class Render:
def table(text: str) -> str:
"""Render table from markdown format into HTML"""
text = replace_mardown_header(text)
return markdown.markdown(text, extensions=["markdown.extensions.tables"])
return markdown.markdown(
text,
extensions=[
"markdown.extensions.tables",
"markdown.extensions.fenced_code",
],
)
@staticmethod
def preview(