fix: selecting "search all" does not work on LightRAG / NanoGraphRAG (#627)

* fix: base path

* fix: select all doesn't work

* fix: adding new documents should update the existing index within the file collection instead of creating a new one #561

* fix linter issues

* feat: update NanoGraphRAG with global collection search

---------

Co-authored-by: Tadashi <tadashi@cinnamon.is>
Varun Sharma 2025-02-14 15:13:39 +01:00 committed by GitHub
parent 0b090896fd
commit e3921f7704
6 changed files with 173 additions and 29 deletions
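
The root cause of the select-all bug is visible in the hunks below: the retriever pipelines unpacked the UI selection as a fixed three-tuple (_, file_ids, _ = selected), which breaks as soon as the selector runs in "search all" mode and emits a different shape. The fix delegates to the selector UI, which knows how to resolve every selection mode into a flat list of file ids. A minimal sketch of the idea, in which the SelectorUI class, its constructor, and the "all" sentinel are illustrative assumptions rather than this codebase's actual API:

    # Illustrative sketch only -- names and shapes are assumptions, not the real API.
    from typing import Any

    class SelectorUI:
        """Resolves whatever the file-selector widget emits into a flat list of ids."""

        def __init__(self, all_file_ids: list[str]):
            self.all_file_ids = all_file_ids

        def get_selected_ids(self, selected: Any = None) -> list[str]:
            # "Search all" mode: no explicit selection means every file in the collection
            if selected is None or selected == "all":
                return self.all_file_ids
            # Explicit selection arrives as a ("select", [ids], extra) style tuple
            if isinstance(selected, tuple):
                _, file_ids, _ = selected
                return file_ids
            return list(selected)

    # The old tuple unpacking either raised (on None) or silently mis-selected
    # (on other shapes); the new call always yields a usable list:
    selector = SelectorUI(all_file_ids=["f1", "f2", "f3"])
    assert selector.get_selected_ids("all") == ["f1", "f2", "f3"]
    assert selector.get_selected_ids(("select", ["f2"], None)) == ["f2"]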

File 1 of 6

@@ -317,6 +317,7 @@ SETTINGS_REASONING = {
     },
 }
+USE_GLOBAL_GRAPHRAG = config("USE_GLOBAL_GRAPHRAG", default=True, cast=bool)
 USE_NANO_GRAPHRAG = config("USE_NANO_GRAPHRAG", default=False, cast=bool)
 USE_LIGHTRAG = config("USE_LIGHTRAG", default=True, cast=bool)
 USE_MS_GRAPHRAG = config("USE_MS_GRAPHRAG", default=True, cast=bool)
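
The new USE_GLOBAL_GRAPHRAG flag follows the same pattern as its neighbours: config(...) reads an environment variable with a typed default, so the collection-wide graph behaviour can be toggled per deployment without code changes. Assuming config here is python-decouple's (an assumption; the import is not shown in this hunk), disabling it would look like:

    # Hypothetical override (python-decouple semantics assumed):
    #   USE_GLOBAL_GRAPHRAG=false python app.py
    from decouple import config

    USE_GLOBAL_GRAPHRAG = config("USE_GLOBAL_GRAPHRAG", default=True, cast=bool)
    print(USE_GLOBAL_GRAPHRAG)  # False when the env var above is set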

File 2 of 6

@@ -25,7 +25,7 @@ class GraphRAGIndex(FileIndex):
     def get_retriever_pipelines(
         self, settings: dict, user_id: int, selected: Any = None
     ) -> list["BaseFileIndexRetriever"]:
-        _, file_ids, _ = selected
+        file_ids = self._selector_ui.get_selected_ids(selected)
         retrievers = [
             GraphRAGRetrieverPipeline(
                 file_ids=file_ids,

File 3 of 6

@@ -1,4 +1,8 @@
-from typing import Any
+from typing import Any, Optional
+from uuid import uuid4
+
+from ktem.db.engine import engine
+from sqlalchemy.orm import Session

 from ..base import BaseFileIndexIndexing, BaseFileIndexRetriever
 from .graph_index import GraphRAGIndex
@@ -6,12 +10,35 @@ from .lightrag_pipelines import LightRAGIndexingPipeline, LightRAGRetrieverPipeline


 class LightRAGIndex(GraphRAGIndex):
+    def __init__(self, app, id: int, name: str, config: dict):
+        super().__init__(app, id, name, config)
+        self._collection_graph_id: Optional[str] = None
+
     def _setup_indexing_cls(self):
         self._indexing_pipeline_cls = LightRAGIndexingPipeline

     def _setup_retriever_cls(self):
         self._retriever_pipeline_cls = [LightRAGRetrieverPipeline]

+    def _get_or_create_collection_graph_id(self):
+        if self._collection_graph_id:
+            return self._collection_graph_id
+
+        # Try to find existing graph ID for this collection
+        with Session(engine) as session:
+            result = (
+                session.query(self._resources["Index"].target_id)  # type: ignore
+                .filter(
+                    self._resources["Index"].relation_type == "graph"  # type: ignore
+                )
+                .first()
+            )
+        if result:
+            self._collection_graph_id = result[0]
+        else:
+            self._collection_graph_id = str(uuid4())
+        return self._collection_graph_id
+
     def get_indexing_pipeline(self, settings, user_id) -> BaseFileIndexIndexing:
         pipeline = super().get_indexing_pipeline(settings, user_id)
         # indexing settings
@@ -23,12 +50,14 @@ class LightRAGIndex(GraphRAGIndex):
         }
         # set the prompts
         pipeline.prompts = striped_settings
+        # set collection graph id
+        pipeline.collection_graph_id = self._get_or_create_collection_graph_id()
         return pipeline

     def get_retriever_pipelines(
         self, settings: dict, user_id: int, selected: Any = None
     ) -> list["BaseFileIndexRetriever"]:
-        _, file_ids, _ = selected
+        file_ids = self._selector_ui.get_selected_ids(selected)
         # retrieval settings
         prefix = f"index.options.{self.id}."
         search_type = settings.get(prefix + "search_type", "local")
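
Both graph indices now resolve one collection-wide graph id: the first call looks for any existing Index row with relation_type == "graph" and reuses its target_id; only when none exists is a fresh UUID minted, and the result is cached on the instance so later indexing runs keep extending the same graph. A condensed, self-contained sketch of that get-or-create pattern, with a plain dict standing in for the SQLAlchemy-backed Index table:

    # Sketch: get-or-create a collection graph id (dict replaces the Index table).
    from typing import Optional
    from uuid import uuid4

    class CollectionGraph:
        def __init__(self) -> None:
            self._cached_id: Optional[str] = None
            # file_id -> graph_id, mimicking Index rows with relation_type="graph"
            self.index_rows: dict[str, str] = {}

        def get_or_create_graph_id(self) -> str:
            if self._cached_id:  # fast path: already resolved this session
                return self._cached_id
            existing = next(iter(self.index_rows.values()), None)
            self._cached_id = existing or str(uuid4())  # reuse, else mint a new one
            return self._cached_id

    graph = CollectionGraph()
    first = graph.get_or_create_graph_id()
    assert graph.get_or_create_graph_id() == first  # idempotent across calls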

File 4 of 6

@@ -242,6 +242,40 @@ class LightRAGIndexingPipeline(GraphRAGIndexingPipeline):
     """GraphRAG specific indexing pipeline"""

     prompts: dict[str, str] = {}
+    collection_graph_id: str
+
+    def store_file_id_with_graph_id(self, file_ids: list[str | None]):
+        if not settings.USE_GLOBAL_GRAPHRAG:
+            return super().store_file_id_with_graph_id(file_ids)
+
+        # Use the collection-wide graph ID for LightRAG
+        graph_id = self.collection_graph_id
+
+        # Record all files under this graph_id
+        with Session(engine) as session:
+            for file_id in file_ids:
+                if not file_id:
+                    continue
+                # Check if mapping already exists
+                existing = (
+                    session.query(self.Index)
+                    .filter(
+                        self.Index.source_id == file_id,
+                        self.Index.target_id == graph_id,
+                        self.Index.relation_type == "graph",
+                    )
+                    .first()
+                )
+                if not existing:
+                    node = self.Index(
+                        source_id=file_id,
+                        target_id=graph_id,
+                        relation_type="graph",
+                    )
+                    session.add(node)
+            session.commit()
+
+        return graph_id

     @classmethod
     def get_user_settings(cls) -> dict:
@@ -295,46 +329,54 @@ class LightRAGIndexingPipeline(GraphRAGIndexingPipeline):
         yield Document(
             channel="debug",
-            text="[GraphRAG] Creating index... This can take a long time.",
+            text="[GraphRAG] Creating/Updating index... This can take a long time.",
         )

-        # remove all .json files in the input_path directory (previous cache)
-        json_files = glob.glob(f"{input_path}/*.json")
-        for json_file in json_files:
-            os.remove(json_file)
+        # Check if graph already exists
+        graph_file = input_path / "graph_chunk_entity_relation.graphml"
+        is_incremental = graph_file.exists()

-        # indexing
+        # Only clear cache if it's a new graph
+        if not is_incremental:
+            json_files = glob.glob(f"{input_path}/*.json")
+            for json_file in json_files:
+                os.remove(json_file)

+        # Initialize or load existing GraphRAG
         graphrag_func = build_graphrag(
             input_path,
             llm_func=llm_func,
             embedding_func=embedding_func,
         )
-        # output must be contain: Loaded graph from
-        # ..input/graph_chunk_entity_relation.graphml with xxx nodes, xxx edges

         total_docs = len(all_docs)
         process_doc_count = 0
         yield Document(
             channel="debug",
-            text=f"[GraphRAG] Indexed {process_doc_count} / {total_docs} documents.",
+            text=(
+                f"[GraphRAG] {'Updating' if is_incremental else 'Creating'} index: "
+                f"{process_doc_count} / {total_docs} documents."
+            ),
         )

         for doc_id in range(0, len(all_docs), INDEX_BATCHSIZE):
             cur_docs = all_docs[doc_id : doc_id + INDEX_BATCHSIZE]
             combined_doc = "\n".join(cur_docs)

+            # Use insert for incremental updates
             graphrag_func.insert(combined_doc)
             process_doc_count += len(cur_docs)
             yield Document(
                 channel="debug",
                 text=(
-                    f"[GraphRAG] Indexed {process_doc_count} "
-                    f"/ {total_docs} documents."
+                    f"[GraphRAG] {'Updated' if is_incremental else 'Indexed'} "
+                    f"{process_doc_count} / {total_docs} documents."
                 ),
             )

         yield Document(
             channel="debug",
-            text="[GraphRAG] Indexing finished.",
+            text=f"[GraphRAG] {'Update' if is_incremental else 'Indexing'} finished.",
         )

     def stream(
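
The indexing change above hinges on one observation: LightRAG persists its graph as graph_chunk_entity_relation.graphml inside the index directory, so that file's existence distinguishes a fresh build from an update. Only a fresh build may wipe the cached *.json state; on an update, insert() merges new batches into the already-loaded graph. A minimal sketch of that decision, where prepare_index_dir is a hypothetical helper, not a function from this codebase:

    # Sketch: choose between a fresh build and an incremental update.
    import glob
    import os
    from pathlib import Path

    def prepare_index_dir(input_path: Path) -> bool:
        """Return True when an existing graph should be updated in place."""
        is_incremental = (input_path / "graph_chunk_entity_relation.graphml").exists()
        if not is_incremental:
            # No graph yet: previously cached .json state is stale, clear it
            for json_file in glob.glob(f"{input_path}/*.json"):
                os.remove(json_file)
        return is_incremental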

File 5 of 6

@@ -1,4 +1,8 @@
-from typing import Any
+from typing import Any, Optional
+from uuid import uuid4
+
+from ktem.db.engine import engine
+from sqlalchemy.orm import Session

 from ..base import BaseFileIndexIndexing, BaseFileIndexRetriever
 from .graph_index import GraphRAGIndex
@@ -6,12 +10,35 @@ from .nano_pipelines import NanoGraphRAGIndexingPipeline, NanoGraphRAGRetrieverPipeline


 class NanoGraphRAGIndex(GraphRAGIndex):
+    def __init__(self, app, id: int, name: str, config: dict):
+        super().__init__(app, id, name, config)
+        self._collection_graph_id: Optional[str] = None
+
     def _setup_indexing_cls(self):
         self._indexing_pipeline_cls = NanoGraphRAGIndexingPipeline

     def _setup_retriever_cls(self):
         self._retriever_pipeline_cls = [NanoGraphRAGRetrieverPipeline]

+    def _get_or_create_collection_graph_id(self):
+        if self._collection_graph_id:
+            return self._collection_graph_id
+
+        # Try to find existing graph ID for this collection
+        with Session(engine) as session:
+            result = (
+                session.query(self._resources["Index"].target_id)  # type: ignore
+                .filter(
+                    self._resources["Index"].relation_type == "graph"  # type: ignore
+                )
+                .first()
+            )
+        if result:
+            self._collection_graph_id = result[0]
+        else:
+            self._collection_graph_id = str(uuid4())
+        return self._collection_graph_id
+
     def get_indexing_pipeline(self, settings, user_id) -> BaseFileIndexIndexing:
         pipeline = super().get_indexing_pipeline(settings, user_id)
         # indexing settings
@@ -23,12 +50,14 @@ class NanoGraphRAGIndex(GraphRAGIndex):
         }
         # set the prompts
         pipeline.prompts = striped_settings
+        # set collection graph id
+        pipeline.collection_graph_id = self._get_or_create_collection_graph_id()
         return pipeline

     def get_retriever_pipelines(
         self, settings: dict, user_id: int, selected: Any = None
     ) -> list["BaseFileIndexRetriever"]:
-        _, file_ids, _ = selected
+        file_ids = self._selector_ui.get_selected_ids(selected)
         # retrieval settings
         prefix = f"index.options.{self.id}."
         search_type = settings.get(prefix + "search_type", "local")

File 6 of 6

@@ -238,6 +238,40 @@ class NanoGraphRAGIndexingPipeline(GraphRAGIndexingPipeline):
     """GraphRAG specific indexing pipeline"""

     prompts: dict[str, str] = {}
+    collection_graph_id: str
+
+    def store_file_id_with_graph_id(self, file_ids: list[str | None]):
+        if not settings.USE_GLOBAL_GRAPHRAG:
+            return super().store_file_id_with_graph_id(file_ids)
+
+        # Use the collection-wide graph ID for NanoGraphRAG
+        graph_id = self.collection_graph_id
+
+        # Record all files under this graph_id
+        with Session(engine) as session:
+            for file_id in file_ids:
+                if not file_id:
+                    continue
+                # Check if mapping already exists
+                existing = (
+                    session.query(self.Index)
+                    .filter(
+                        self.Index.source_id == file_id,
+                        self.Index.target_id == graph_id,
+                        self.Index.relation_type == "graph",
+                    )
+                    .first()
+                )
+                if not existing:
+                    node = self.Index(
+                        source_id=file_id,
+                        target_id=graph_id,
+                        relation_type="graph",
+                    )
+                    session.add(node)
+            session.commit()
+
+        return graph_id

     @classmethod
     def get_user_settings(cls) -> dict:
@@ -291,45 +325,54 @@ class NanoGraphRAGIndexingPipeline(GraphRAGIndexingPipeline):
         yield Document(
             channel="debug",
-            text="[GraphRAG] Creating index... This can take a long time.",
+            text="[GraphRAG] Creating/Updating index... This can take a long time.",
         )

-        # remove all .json files in the input_path directory (previous cache)
-        json_files = glob.glob(f"{input_path}/*.json")
-        for json_file in json_files:
-            os.remove(json_file)
+        # Check if graph already exists
+        graph_file = input_path / "graph_chunk_entity_relation.graphml"
+        is_incremental = graph_file.exists()

-        # indexing
+        # Only clear cache if it's a new graph
+        if not is_incremental:
+            json_files = glob.glob(f"{input_path}/*.json")
+            for json_file in json_files:
+                os.remove(json_file)

+        # Initialize or load existing GraphRAG
         graphrag_func = build_graphrag(
             input_path,
             llm_func=llm_func,
             embedding_func=embedding_func,
         )
-        # output must be contain: Loaded graph from
-        # ..input/graph_chunk_entity_relation.graphml with xxx nodes, xxx edges

         total_docs = len(all_docs)
         process_doc_count = 0
         yield Document(
             channel="debug",
-            text=f"[GraphRAG] Indexed {process_doc_count} / {total_docs} documents.",
+            text=(
+                f"[GraphRAG] {'Updating' if is_incremental else 'Creating'} index: "
+                f"{process_doc_count} / {total_docs} documents."
+            ),
        )

         for doc_id in range(0, len(all_docs), INDEX_BATCHSIZE):
             cur_docs = all_docs[doc_id : doc_id + INDEX_BATCHSIZE]
             combined_doc = "\n".join(cur_docs)

+            # Use insert for incremental updates
             graphrag_func.insert(combined_doc)
             process_doc_count += len(cur_docs)
             yield Document(
                 channel="debug",
                 text=(
-                    f"[GraphRAG] Indexed {process_doc_count} "
-                    f"/ {total_docs} documents."
+                    f"[GraphRAG] {'Updated' if is_incremental else 'Indexed'} "
+                    f"{process_doc_count} / {total_docs} documents."
                 ),
             )

         yield Document(
             channel="debug",
-            text="[GraphRAG] Indexing finished.",
+            text=f"[GraphRAG] {'Update' if is_incremental else 'Indexing'} finished.",
         )

     def stream(
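
Finally, store_file_id_with_graph_id makes the file-to-graph bookkeeping idempotent: re-indexing a file that is already mapped to the collection graph inserts nothing, so the Index table never accumulates duplicate (source_id, target_id, "graph") rows. A standalone sketch of that check-then-insert behaviour, with plain tuples in place of SQLAlchemy rows:

    # Sketch: idempotent file -> graph mapping (tuples stand in for Index rows).
    def store_file_ids(
        rows: set[tuple[str, str, str]],
        file_ids: list[str | None],
        graph_id: str,
    ) -> str:
        for file_id in file_ids:
            if not file_id:
                continue
            mapping = (file_id, graph_id, "graph")
            if mapping not in rows:  # skip mappings that already exist
                rows.add(mapping)
        return graph_id

    rows: set[tuple[str, str, str]] = set()
    store_file_ids(rows, ["f1", "f2"], "g1")
    store_file_ids(rows, ["f1", "f2"], "g1")  # second run adds nothing new
    assert len(rows) == 2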