feat: add graphrag modes (#574)
* feat: add support for retrieval modes in LightRAG & NanoGraphRAG
* feat: expose custom prompts in LightRAG & NanoGraphRAG
* fix: optimize setting UI
* fix: update non local mode in LightRAG
* fix: update graphRAG mode
parent: c667bf9d0a
commit: 1d3c4f4433
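Each graph index now reads its own settings under an index.options.<id>. prefix: prompt overrides go to the indexing pipeline and the selected search_type goes to the retriever. A minimal sketch of that flow, using a made-up index id and prompt key (both illustrative, not from a real config):

    # Sketch only: mirrors the prefix-stripping and search_type wiring in the hunks below.
    index_id = 2  # hypothetical index id
    settings = {
        f"index.options.{index_id}.search_type": "global",     # retrieval mode for this index
        f"index.options.{index_id}.entity_extraction": "...",  # hypothetical custom prompt key
    }

    prefix = f"index.options.{index_id}."
    striped_settings = {
        key[len(prefix):]: value for key, value in settings.items() if key.startswith(prefix)
    }
    search_type = striped_settings.get("search_type", "local")
    # striped_settings is assigned to pipeline.prompts; search_type becomes QueryParam(mode=search_type).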
@@ -302,7 +302,7 @@ GRAPHRAG_INDEX_TYPES = ["ktem.index.file.graph.GraphRAGIndex"]
 if USE_NANO_GRAPHRAG:
     GRAPHRAG_INDEX_TYPES.append("ktem.index.file.graph.NanoGraphRAGIndex")
-elif USE_LIGHTRAG:
+if USE_LIGHTRAG:
     GRAPHRAG_INDEX_TYPES.append("ktem.index.file.graph.LightRAGIndex")
 
 KH_INDEX_TYPES = [
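Changing elif to if means the two backends are no longer mutually exclusive. A small sketch of the resulting behavior, assuming USE_NANO_GRAPHRAG and USE_LIGHTRAG are boolean flags defined earlier in this settings module:

    # Sketch: with both flags on, both graph backends get registered alongside the default.
    USE_NANO_GRAPHRAG = True
    USE_LIGHTRAG = True

    GRAPHRAG_INDEX_TYPES = ["ktem.index.file.graph.GraphRAGIndex"]
    if USE_NANO_GRAPHRAG:
        GRAPHRAG_INDEX_TYPES.append("ktem.index.file.graph.NanoGraphRAGIndex")
    if USE_LIGHTRAG:
        GRAPHRAG_INDEX_TYPES.append("ktem.index.file.graph.LightRAGIndex")

    print(GRAPHRAG_INDEX_TYPES)  # previously LightRAG was skipped whenever NanoGraphRAG was enabled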
@@ -204,6 +204,11 @@ mark {
   right: 15px;
 }
 
+/* prevent overflow of html info panel */
+#html-info-panel {
+  overflow-x: auto !important;
+}
+
 #chat-expand-button {
   position: absolute;
   top: 6px;
@@ -211,6 +216,12 @@ mark {
   z-index: 1;
 }
 
+#save-setting-btn {
+  width: 150px;
+  height: 30px;
+  min-width: 100px !important;
+}
+
 #quick-setting-labels {
   margin-top: 5px;
   margin-bottom: -10px;
@@ -21,6 +21,11 @@ function run() {
   let chat_column = document.getElementById("main-chat-bot");
   let conv_column = document.getElementById("conv-settings-panel");
 
+  // move setting close button
+  let setting_tab_nav_bar = document.querySelector("#settings-tab .tab-nav");
+  let setting_close_button = document.getElementById("save-setting-btn");
+  setting_tab_nav_bar.appendChild(setting_close_button);
+
   let default_conv_column_min_width = "min(300px, 100%)";
   conv_column.style.minWidth = default_conv_column_min_width
@@ -1,6 +1,6 @@
 from typing import Any
 
-from ..base import BaseFileIndexRetriever
+from ..base import BaseFileIndexIndexing, BaseFileIndexRetriever
 from .graph_index import GraphRAGIndex
 from .lightrag_pipelines import LightRAGIndexingPipeline, LightRAGRetrieverPipeline
@@ -12,14 +12,32 @@ class LightRAGIndex(GraphRAGIndex):
     def _setup_retriever_cls(self):
         self._retriever_pipeline_cls = [LightRAGRetrieverPipeline]
 
+    def get_indexing_pipeline(self, settings, user_id) -> BaseFileIndexIndexing:
+        pipeline = super().get_indexing_pipeline(settings, user_id)
+        # indexing settings
+        prefix = f"index.options.{self.id}."
+        striped_settings = {
+            key[len(prefix) :]: value
+            for key, value in settings.items()
+            if key.startswith(prefix)
+        }
+        # set the prompts
+        pipeline.prompts = striped_settings
+        return pipeline
+
     def get_retriever_pipelines(
         self, settings: dict, user_id: int, selected: Any = None
     ) -> list["BaseFileIndexRetriever"]:
         _, file_ids, _ = selected
+        # retrieval settings
+        prefix = f"index.options.{self.id}."
+        search_type = settings.get(prefix + "search_type", "local")
+
         retrievers = [
             LightRAGRetrieverPipeline(
                 file_ids=file_ids,
                 Index=self._resources["Index"],
+                search_type=search_type,
             )
         ]
@@ -70,7 +70,7 @@ def get_llm_func(model):
         if if_cache_return is not None:
             return if_cache_return["return"]
 
-        output = model(input_messages).text
+        output = (await model.ainvoke(input_messages)).text
 
         print("-" * 50)
         print(output, "\n", "-" * 50)
@@ -220,7 +220,37 @@ def build_graphrag(working_dir, llm_func, embedding_func):
 class LightRAGIndexingPipeline(GraphRAGIndexingPipeline):
     """GraphRAG specific indexing pipeline"""
 
+    prompts: dict[str, str] = {}
+
+    @classmethod
+    def get_user_settings(cls) -> dict:
+        try:
+            from lightrag.prompt import PROMPTS
+
+            blacklist_keywords = ["default", "response", "process"]
+            return {
+                prompt_name: {
+                    "name": f"Prompt for '{prompt_name}'",
+                    "value": content,
+                    "component": "text",
+                }
+                for prompt_name, content in PROMPTS.items()
+                if all(
+                    keyword not in prompt_name.lower() for keyword in blacklist_keywords
+                )
+            }
+        except ImportError as e:
+            print(e)
+            return {}
+
     def call_graphrag_index(self, graph_id: str, docs: list[Document]):
+        from lightrag.prompt import PROMPTS
+
+        # modify the prompt if it is set in the settings
+        for prompt_name, content in self.prompts.items():
+            if prompt_name in PROMPTS:
+                PROMPTS[prompt_name] = content
+
         _, input_path = prepare_graph_index_path(graph_id)
         input_path.mkdir(parents=True, exist_ok=True)
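The prompt override works by patching lightrag's module-level PROMPTS dict before indexing. A hedged sketch with a hypothetical prompt key (which keys exist depends on the installed lightrag version):

    # Sketch only: patch lightrag's prompt table with a user-edited prompt before indexing runs.
    from lightrag.prompt import PROMPTS  # requires lightrag to be installed

    user_prompts = {"entity_extraction": "My customized extraction prompt ..."}  # hypothetical key
    for prompt_name, content in user_prompts.items():
        if prompt_name in PROMPTS:  # silently ignore keys the installed version does not define
            PROMPTS[prompt_name] = content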
@@ -302,6 +332,19 @@ class LightRAGRetrieverPipeline(BaseFileIndexRetriever):
 
     Index = Param(help="The SQLAlchemy Index table")
     file_ids: list[str] = []
+    search_type: str = "local"
+
+    @classmethod
+    def get_user_settings(cls) -> dict:
+        return {
+            "search_type": {
+                "name": "Search type",
+                "value": "local",
+                "choices": ["local", "global", "hybrid"],
+                "component": "dropdown",
+                "info": "Whether to use local or global search in the graph.",
+            }
+        }
 
     def _build_graph_search(self):
         file_id = self.file_ids[0]
@@ -326,7 +369,8 @@ class LightRAGRetrieverPipeline(BaseFileIndexRetriever):
             llm_func=llm_func,
             embedding_func=embedding_func,
         )
-        query_params = QueryParam(mode="local", only_need_context=True)
+        print("search_type", self.search_type)
+        query_params = QueryParam(mode=self.search_type, only_need_context=True)
 
         return graphrag_func, query_params
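search_type now drives QueryParam directly. A sketch of issuing the same question in each LightRAG mode, mirroring the call pattern above (the import path for QueryParam is assumed from lightrag's public API):

    # Sketch: query the same graph in each LightRAG mode, as run() does for the selected one.
    def query_all_modes(graphrag_func, text: str) -> dict:
        from lightrag import QueryParam  # assumed import path

        results = {}
        for mode in ("local", "global", "hybrid"):
            query_params = QueryParam(mode=mode, only_need_context=True)
            results[mode] = graphrag_func.query(text, query_params)
        return results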
@@ -381,14 +425,15 @@ class LightRAGRetrieverPipeline(BaseFileIndexRetriever):
             return []
 
         graphrag_func, query_params = self._build_graph_search()
 
+        # only local mode support graph visualization
+        if query_params.mode == "local":
             entities, relationships, sources = asyncio.run(
                 lightrag_build_local_query_context(graphrag_func, text, query_params)
             )
 
             documents = self.format_context_records(entities, relationships, sources)
             plot = self.plot_graph(relationships)
 
-            return documents + [
+            documents += [
                 RetrievedDocument(
                     text="",
                     metadata={
@@ -398,3 +443,22 @@ class LightRAGRetrieverPipeline(BaseFileIndexRetriever):
                     },
                 ),
             ]
+        else:
+            context = graphrag_func.query(text, query_params)
+
+            # account for missing ``` for closing code block
+            context += "\n```"
+
+            documents = [
+                RetrievedDocument(
+                    text=context,
+                    metadata={
+                        "file_name": "GraphRAG {} Search".format(
+                            query_params.mode.capitalize()
+                        ),
+                        "type": "table",
+                    },
+                )
+            ]
+
+        return documents
@@ -1,6 +1,6 @@
 from typing import Any
 
-from ..base import BaseFileIndexRetriever
+from ..base import BaseFileIndexIndexing, BaseFileIndexRetriever
 from .graph_index import GraphRAGIndex
 from .nano_pipelines import NanoGraphRAGIndexingPipeline, NanoGraphRAGRetrieverPipeline
@@ -12,14 +12,32 @@ class NanoGraphRAGIndex(GraphRAGIndex):
     def _setup_retriever_cls(self):
         self._retriever_pipeline_cls = [NanoGraphRAGRetrieverPipeline]
 
+    def get_indexing_pipeline(self, settings, user_id) -> BaseFileIndexIndexing:
+        pipeline = super().get_indexing_pipeline(settings, user_id)
+        # indexing settings
+        prefix = f"index.options.{self.id}."
+        striped_settings = {
+            key[len(prefix) :]: value
+            for key, value in settings.items()
+            if key.startswith(prefix)
+        }
+        # set the prompts
+        pipeline.prompts = striped_settings
+        return pipeline
+
     def get_retriever_pipelines(
         self, settings: dict, user_id: int, selected: Any = None
     ) -> list["BaseFileIndexRetriever"]:
         _, file_ids, _ = selected
+        # retrieval settings
+        prefix = f"index.options.{self.id}."
+        search_type = settings.get(prefix + "search_type", "local")
+
         retrievers = [
             NanoGraphRAGRetrieverPipeline(
                 file_ids=file_ids,
                 Index=self._resources["Index"],
+                search_type=search_type,
             )
         ]
@@ -71,7 +71,7 @@ def get_llm_func(model):
         if if_cache_return is not None:
             return if_cache_return["return"]
 
-        output = model(input_messages).text
+        output = (await model.ainvoke(input_messages)).text
 
         print("-" * 50)
         print(output, "\n", "-" * 50)
@@ -216,7 +216,37 @@ def build_graphrag(working_dir, llm_func, embedding_func):
 class NanoGraphRAGIndexingPipeline(GraphRAGIndexingPipeline):
     """GraphRAG specific indexing pipeline"""
 
+    prompts: dict[str, str] = {}
+
+    @classmethod
+    def get_user_settings(cls) -> dict:
+        try:
+            from nano_graphrag.prompt import PROMPTS
+
+            blacklist_keywords = ["default", "response", "process"]
+            return {
+                prompt_name: {
+                    "name": f"Prompt for '{prompt_name}'",
+                    "value": content,
+                    "component": "text",
+                }
+                for prompt_name, content in PROMPTS.items()
+                if all(
+                    keyword not in prompt_name.lower() for keyword in blacklist_keywords
+                )
+            }
+        except ImportError as e:
+            print(e)
+            return {}
+
     def call_graphrag_index(self, graph_id: str, docs: list[Document]):
+        from nano_graphrag.prompt import PROMPTS
+
+        # modify the prompt if it is set in the settings
+        for prompt_name, content in self.prompts.items():
+            if prompt_name in PROMPTS:
+                PROMPTS[prompt_name] = content
+
         _, input_path = prepare_graph_index_path(graph_id)
         input_path.mkdir(parents=True, exist_ok=True)
@@ -297,6 +327,19 @@ class NanoGraphRAGRetrieverPipeline(BaseFileIndexRetriever):
 
     Index = Param(help="The SQLAlchemy Index table")
     file_ids: list[str] = []
+    search_type: str = "local"
+
+    @classmethod
+    def get_user_settings(cls) -> dict:
+        return {
+            "search_type": {
+                "name": "Search type",
+                "value": "local",
+                "choices": ["local", "global"],
+                "component": "dropdown",
+                "info": "Whether to use local or global search in the graph.",
+            }
+        }
 
     def _build_graph_search(self):
         file_id = self.file_ids[0]
@@ -321,7 +364,8 @@ class NanoGraphRAGRetrieverPipeline(BaseFileIndexRetriever):
             llm_func=llm_func,
             embedding_func=embedding_func,
         )
-        query_params = QueryParam(mode="local", only_need_context=True)
+        print("search_type", self.search_type)
+        query_params = QueryParam(mode=self.search_type, only_need_context=True)
 
         return graphrag_func, query_params
@@ -384,8 +428,13 @@ class NanoGraphRAGRetrieverPipeline(BaseFileIndexRetriever):
             return []
 
         graphrag_func, query_params = self._build_graph_search()
 
+        # only local mode support graph visualization
+        if query_params.mode == "local":
             entities, relationships, reports, sources = asyncio.run(
-                nano_graph_rag_build_local_query_context(graphrag_func, text, query_params)
+                nano_graph_rag_build_local_query_context(
+                    graphrag_func, text, query_params
+                )
             )
 
             documents = self.format_context_records(
@@ -393,7 +442,7 @@ class NanoGraphRAGRetrieverPipeline(BaseFileIndexRetriever):
             )
             plot = self.plot_graph(relationships)
 
-            return documents + [
+            documents += [
                 RetrievedDocument(
                     text="",
                     metadata={
@@ -403,3 +452,19 @@ class NanoGraphRAGRetrieverPipeline(BaseFileIndexRetriever):
                     },
                 ),
             ]
+        else:
+            context = graphrag_func.query(text, query_params)
+
+            documents = [
+                RetrievedDocument(
+                    text=context,
+                    metadata={
+                        "file_name": "GraphRAG {} Search".format(
+                            query_params.mode.capitalize()
+                        ),
+                        "type": "table",
+                    },
+                )
+            ]
+
+        return documents
@@ -180,7 +180,7 @@ class GraphRAGRetrieverPipeline(BaseFileIndexRetriever):
             "search_type": {
                 "name": "Search type",
                 "value": "local",
-                "choices": ["local", "global"],
+                "choices": ["local"],
                 "component": "dropdown",
                 "info": "Whether to use local or global search in the graph.",
             }
@@ -106,6 +106,12 @@ class SettingsPage(BasePage):
         self.on_building_ui()
 
     def on_building_ui(self):
+        self.setting_save_btn = gr.Button(
+            "Save & Close",
+            variant="primary",
+            elem_classes=["right-button"],
+            elem_id="save-setting-btn",
+        )
         if self._app.f_user_management:
             with gr.Tab("User settings"):
                 self.user_tab()
@@ -114,10 +120,6 @@ class SettingsPage(BasePage):
         self.index_tab()
         self.reasoning_tab()
 
-        self.setting_save_btn = gr.Button(
-            "Save changes", variant="primary", scale=1, elem_classes=["right-button"]
-        )
-
     def on_subscribe_public_events(self):
         """
         Subscribes to public events related to user management.
@@ -177,6 +179,9 @@ class SettingsPage(BasePage):
             self.save_setting,
             inputs=[self._user_id] + self.components(),
             outputs=self._settings_state,
+        ).then(
+            lambda: gr.Tabs(selected="chat-tab"),
+            outputs=self._app.tabs,
         )
         self._components["reasoning.use"].change(
             self.change_reasoning_mode,
@@ -49,7 +49,13 @@ class Render:
     def table(text: str) -> str:
         """Render table from markdown format into HTML"""
         text = replace_mardown_header(text)
-        return markdown.markdown(text, extensions=["markdown.extensions.tables"])
+        return markdown.markdown(
+            text,
+            extensions=[
+                "markdown.extensions.tables",
+                "markdown.extensions.fenced_code",
+            ],
+        )
 
     @staticmethod
     def preview(
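For context on the fenced_code addition, a standalone sketch using the markdown package directly (not Render.table) to show what the extra extension changes:

    import markdown

    text = "| a | b |\n| --- | --- |\n| 1 | 2 |\n\n```\nprint('hi')\n```\n"
    html = markdown.markdown(
        text,
        extensions=["markdown.extensions.tables", "markdown.extensions.fenced_code"],
    )
    # With fenced_code the ``` block becomes <pre><code>...</code></pre>;
    # without it, the backticks are left in the output as plain text.
    print(html)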