fix: selecting search all does not work on LightRAG / NanoGraphRAG (#627)

* fix: base path
* fix: select all doesn't work
* fix: adding new documents should update the existing index within the file collection instead of creating a new one (#561)
* fix: linter issues
* feat: update NanoGraphRAG with global collection search

Co-authored-by: Tadashi <tadashi@cinnamon.is>

parent 0b090896fd · commit e3921f7704
@@ -317,6 +317,7 @@ SETTINGS_REASONING = {
     },
 }
 
+USE_GLOBAL_GRAPHRAG = config("USE_GLOBAL_GRAPHRAG", default=True, cast=bool)
 USE_NANO_GRAPHRAG = config("USE_NANO_GRAPHRAG", default=False, cast=bool)
 USE_LIGHTRAG = config("USE_LIGHTRAG", default=True, cast=bool)
 USE_MS_GRAPHRAG = config("USE_MS_GRAPHRAG", default=True, cast=bool)
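The new flag follows the same pattern as its neighbors: `config()` reads an environment variable and falls back to the default. A minimal sketch of how such a flag resolves, assuming the python-decouple package (whose `config(name, default=..., cast=...)` signature these calls match):

```python
# Minimal sketch, assuming python-decouple (its config() signature matches
# the calls above). Reads the USE_GLOBAL_GRAPHRAG environment variable when
# set (e.g. "true", "false", "1", "0"), otherwise uses the default.
from decouple import config

USE_GLOBAL_GRAPHRAG = config("USE_GLOBAL_GRAPHRAG", default=True, cast=bool)

if USE_GLOBAL_GRAPHRAG:
    # All files in a collection are recorded under one shared graph ID.
    print("collection-wide graph enabled")
```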
@@ -25,7 +25,7 @@ class GraphRAGIndex(FileIndex):
     def get_retriever_pipelines(
         self, settings: dict, user_id: int, selected: Any = None
     ) -> list["BaseFileIndexRetriever"]:
-        _, file_ids, _ = selected
+        file_ids = self._selector_ui.get_selected_ids(selected)
         retrievers = [
             GraphRAGRetrieverPipeline(
                 file_ids=file_ids,
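This one-line change is the heart of the "search all" fix: blind tuple unpacking assumes `selected` always carries an explicit list of file IDs, which does not hold when the user chooses to search all files. Delegating to the selector UI lets a single place resolve the selection. A hypothetical sketch of that contract (the helper below is an illustration, not kotaemon's actual `get_selected_ids`):

```python
# Hypothetical sketch of the selection contract, assuming `selected` is a
# (mode, explicit_ids, extras) triple; this illustrates the idea rather than
# kotaemon's actual get_selected_ids implementation.
from typing import Any


def list_all_file_ids() -> list[str]:
    # Stand-in for a database lookup of every file in the collection.
    return ["file-1", "file-2", "file-3"]


def get_selected_ids(selected: Any) -> list[str]:
    mode, file_ids, _ = selected
    if mode == "all":
        # The old `_, file_ids, _ = selected` kept the (empty) explicit
        # list here, so "search all" retrieved nothing.
        return list_all_file_ids()
    return list(file_ids)


print(get_selected_ids(("select", ["file-2"], None)))  # ['file-2']
print(get_selected_ids(("all", [], None)))  # every file in the collection
```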
@@ -1,4 +1,8 @@
-from typing import Any
+from typing import Any, Optional
+from uuid import uuid4
+
+from ktem.db.engine import engine
+from sqlalchemy.orm import Session
 
 from ..base import BaseFileIndexIndexing, BaseFileIndexRetriever
 from .graph_index import GraphRAGIndex
@@ -6,12 +10,35 @@ from .lightrag_pipelines import LightRAGIndexingPipeline, LightRAGRetrieverPipeline
 
 
 class LightRAGIndex(GraphRAGIndex):
+    def __init__(self, app, id: int, name: str, config: dict):
+        super().__init__(app, id, name, config)
+        self._collection_graph_id: Optional[str] = None
+
     def _setup_indexing_cls(self):
         self._indexing_pipeline_cls = LightRAGIndexingPipeline
 
     def _setup_retriever_cls(self):
         self._retriever_pipeline_cls = [LightRAGRetrieverPipeline]
 
+    def _get_or_create_collection_graph_id(self):
+        if self._collection_graph_id:
+            return self._collection_graph_id
+
+        # Try to find existing graph ID for this collection
+        with Session(engine) as session:
+            result = (
+                session.query(self._resources["Index"].target_id)  # type: ignore
+                .filter(
+                    self._resources["Index"].relation_type == "graph"  # type: ignore
+                )
+                .first()
+            )
+            if result:
+                self._collection_graph_id = result[0]
+            else:
+                self._collection_graph_id = str(uuid4())
+        return self._collection_graph_id
+
     def get_indexing_pipeline(self, settings, user_id) -> BaseFileIndexIndexing:
         pipeline = super().get_indexing_pipeline(settings, user_id)
         # indexing settings
@@ -23,12 +50,14 @@ class LightRAGIndex(GraphRAGIndex):
         }
         # set the prompts
         pipeline.prompts = striped_settings
+        # set collection graph id
+        pipeline.collection_graph_id = self._get_or_create_collection_graph_id()
         return pipeline
 
     def get_retriever_pipelines(
         self, settings: dict, user_id: int, selected: Any = None
     ) -> list["BaseFileIndexRetriever"]:
-        _, file_ids, _ = selected
+        file_ids = self._selector_ui.get_selected_ids(selected)
         # retrieval settings
         prefix = f"index.options.{self.id}."
         search_type = settings.get(prefix + "search_type", "local")
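Downstream, the `search_type` read here selects between local and global querying. A sketch of that mapping against the upstream lightrag package's documented interface; the working directory and default model functions are assumptions, and kotaemon's retriever pipeline wraps this with its own LLM and embedding plumbing:

```python
# Sketch against the upstream lightrag (lightrag-hku) package's documented
# interface; working_dir and the default model functions are assumptions.
from lightrag import LightRAG, QueryParam

rag = LightRAG(working_dir="./rag_storage")
rag.insert("Some document text to build the graph from.")

for search_type in ("local", "global"):
    # QueryParam.mode selects entity-neighborhood ("local") vs.
    # community-summary ("global") retrieval.
    print(rag.query("What is this about?", param=QueryParam(mode=search_type)))
```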
@@ -242,6 +242,40 @@ class LightRAGIndexingPipeline(GraphRAGIndexingPipeline):
     """GraphRAG specific indexing pipeline"""
 
     prompts: dict[str, str] = {}
+    collection_graph_id: str
+
+    def store_file_id_with_graph_id(self, file_ids: list[str | None]):
+        if not settings.USE_GLOBAL_GRAPHRAG:
+            return super().store_file_id_with_graph_id(file_ids)
+
+        # Use the collection-wide graph ID for LightRAG
+        graph_id = self.collection_graph_id
+
+        # Record all files under this graph_id
+        with Session(engine) as session:
+            for file_id in file_ids:
+                if not file_id:
+                    continue
+                # Check if mapping already exists
+                existing = (
+                    session.query(self.Index)
+                    .filter(
+                        self.Index.source_id == file_id,
+                        self.Index.target_id == graph_id,
+                        self.Index.relation_type == "graph",
+                    )
+                    .first()
+                )
+                if not existing:
+                    node = self.Index(
+                        source_id=file_id,
+                        target_id=graph_id,
+                        relation_type="graph",
+                    )
+                    session.add(node)
+            session.commit()
+
+        return graph_id
 
     @classmethod
     def get_user_settings(cls) -> dict:
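The exists-check before `session.add` is what makes re-indexing idempotent: running the pipeline again over the same files leaves the file-to-graph mapping table unchanged. A self-contained sketch of that pattern with an in-memory SQLite database (the `Index` model below is a pared-down stand-in for kotaemon's actual table):

```python
# Self-contained sketch of the idempotent file -> graph mapping, using an
# in-memory SQLite database. The Index model is illustrative; kotaemon's
# actual Index table has more columns.
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Index(Base):
    __tablename__ = "index_table"
    id = Column(Integer, primary_key=True)
    source_id = Column(String)     # file ID
    target_id = Column(String)     # collection graph ID
    relation_type = Column(String)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)


def store(file_ids: list[str], graph_id: str) -> None:
    # Re-running this for the same files adds no duplicate rows, which is
    # what lets re-indexing update the existing collection graph.
    with Session(engine) as session:
        for file_id in file_ids:
            existing = (
                session.query(Index)
                .filter(
                    Index.source_id == file_id,
                    Index.target_id == graph_id,
                    Index.relation_type == "graph",
                )
                .first()
            )
            if not existing:
                session.add(
                    Index(source_id=file_id, target_id=graph_id, relation_type="graph")
                )
        session.commit()


store(["f1", "f2"], "graph-123")
store(["f1", "f2"], "graph-123")  # idempotent: still two rows
with Session(engine) as session:
    print(session.query(Index).count())  # 2
```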
@@ -295,46 +329,54 @@ class LightRAGIndexingPipeline(GraphRAGIndexingPipeline):
 
         yield Document(
             channel="debug",
-            text="[GraphRAG] Creating index... This can take a long time.",
+            text="[GraphRAG] Creating/Updating index... This can take a long time.",
         )
 
-        # remove all .json files in the input_path directory (previous cache)
-        json_files = glob.glob(f"{input_path}/*.json")
-        for json_file in json_files:
-            os.remove(json_file)
+        # Check if graph already exists
+        graph_file = input_path / "graph_chunk_entity_relation.graphml"
+        is_incremental = graph_file.exists()
+
+        # Only clear cache if it's a new graph
+        if not is_incremental:
+            json_files = glob.glob(f"{input_path}/*.json")
+            for json_file in json_files:
+                os.remove(json_file)
 
-        # indexing
+        # Initialize or load existing GraphRAG
         graphrag_func = build_graphrag(
             input_path,
             llm_func=llm_func,
             embedding_func=embedding_func,
         )
-        # output must be contain: Loaded graph from
-        # ..input/graph_chunk_entity_relation.graphml with xxx nodes, xxx edges
+
         total_docs = len(all_docs)
         process_doc_count = 0
         yield Document(
             channel="debug",
-            text=f"[GraphRAG] Indexed {process_doc_count} / {total_docs} documents.",
+            text=(
+                f"[GraphRAG] {'Updating' if is_incremental else 'Creating'} index: "
+                f"{process_doc_count} / {total_docs} documents."
+            ),
         )
 
         for doc_id in range(0, len(all_docs), INDEX_BATCHSIZE):
             cur_docs = all_docs[doc_id : doc_id + INDEX_BATCHSIZE]
             combined_doc = "\n".join(cur_docs)
+
+            # Use insert for incremental updates
             graphrag_func.insert(combined_doc)
             process_doc_count += len(cur_docs)
             yield Document(
                 channel="debug",
                 text=(
-                    f"[GraphRAG] Indexed {process_doc_count} "
-                    f"/ {total_docs} documents."
+                    f"[GraphRAG] {'Updated' if is_incremental else 'Indexed'} "
+                    f"{process_doc_count} / {total_docs} documents."
                ),
             )
 
         yield Document(
             channel="debug",
-            text="[GraphRAG] Indexing finished.",
+            text=f"[GraphRAG] {'Update' if is_incremental else 'Indexing'} finished.",
         )
 
     def stream(
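The loop feeds documents to `insert()` in fixed-size batches so progress can be reported per chunk, and, when a graph already exists, each insert extends it rather than rebuilding from scratch. A standalone sketch of the batching arithmetic (the insert stub below is a stand-in for `graphrag_func.insert`):

```python
# Standalone sketch of the batched indexing loop above.
INDEX_BATCHSIZE = 4  # illustrative; the real value comes from the pipeline

all_docs = [f"document {i}" for i in range(10)]


def fake_insert(text: str) -> None:
    # Stand-in for graphrag_func.insert(); on an existing working directory
    # the upstream insert() extends the graph instead of rebuilding it.
    pass


process_doc_count = 0
for doc_id in range(0, len(all_docs), INDEX_BATCHSIZE):
    cur_docs = all_docs[doc_id : doc_id + INDEX_BATCHSIZE]
    fake_insert("\n".join(cur_docs))
    process_doc_count += len(cur_docs)
    print(f"Indexed {process_doc_count} / {len(all_docs)} documents.")
```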
@@ -1,4 +1,8 @@
-from typing import Any
+from typing import Any, Optional
+from uuid import uuid4
+
+from ktem.db.engine import engine
+from sqlalchemy.orm import Session
 
 from ..base import BaseFileIndexIndexing, BaseFileIndexRetriever
 from .graph_index import GraphRAGIndex
@@ -6,12 +10,35 @@ from .nano_pipelines import NanoGraphRAGIndexingPipeline, NanoGraphRAGRetrieverPipeline
 
 
 class NanoGraphRAGIndex(GraphRAGIndex):
+    def __init__(self, app, id: int, name: str, config: dict):
+        super().__init__(app, id, name, config)
+        self._collection_graph_id: Optional[str] = None
+
     def _setup_indexing_cls(self):
         self._indexing_pipeline_cls = NanoGraphRAGIndexingPipeline
 
     def _setup_retriever_cls(self):
         self._retriever_pipeline_cls = [NanoGraphRAGRetrieverPipeline]
 
+    def _get_or_create_collection_graph_id(self):
+        if self._collection_graph_id:
+            return self._collection_graph_id
+
+        # Try to find existing graph ID for this collection
+        with Session(engine) as session:
+            result = (
+                session.query(self._resources["Index"].target_id)  # type: ignore
+                .filter(
+                    self._resources["Index"].relation_type == "graph"  # type: ignore
+                )
+                .first()
+            )
+            if result:
+                self._collection_graph_id = result[0]
+            else:
+                self._collection_graph_id = str(uuid4())
+        return self._collection_graph_id
+
     def get_indexing_pipeline(self, settings, user_id) -> BaseFileIndexIndexing:
         pipeline = super().get_indexing_pipeline(settings, user_id)
         # indexing settings
@@ -23,12 +50,14 @@ class NanoGraphRAGIndex(GraphRAGIndex):
         }
         # set the prompts
         pipeline.prompts = striped_settings
+        # set collection graph id
+        pipeline.collection_graph_id = self._get_or_create_collection_graph_id()
         return pipeline
 
     def get_retriever_pipelines(
         self, settings: dict, user_id: int, selected: Any = None
     ) -> list["BaseFileIndexRetriever"]:
-        _, file_ids, _ = selected
+        file_ids = self._selector_ui.get_selected_ids(selected)
         # retrieval settings
         prefix = f"index.options.{self.id}."
         search_type = settings.get(prefix + "search_type", "local")
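For NanoGraphRAG the same `search_type` choice maps onto upstream nano-graphrag query modes, which is what the global collection search in this commit exercises. A sketch using that package's documented `GraphRAG`/`QueryParam` interface (working directory and default model functions are assumptions; kotaemon's retriever adds its own wiring):

```python
# Sketch against the upstream nano-graphrag package's documented interface;
# working_dir and default model functions are assumptions.
from nano_graphrag import GraphRAG, QueryParam

graph = GraphRAG(working_dir="./nano_storage")
graph.insert("Some document text to build the graph from.")

# "local" answers from a neighborhood of matched entities; "global" answers
# from community summaries across the whole collection graph.
print(graph.query("What is this about?", param=QueryParam(mode="local")))
print(graph.query("What is this about?", param=QueryParam(mode="global")))
```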
@@ -238,6 +238,40 @@ class NanoGraphRAGIndexingPipeline(GraphRAGIndexingPipeline):
     """GraphRAG specific indexing pipeline"""
 
     prompts: dict[str, str] = {}
+    collection_graph_id: str
+
+    def store_file_id_with_graph_id(self, file_ids: list[str | None]):
+        if not settings.USE_GLOBAL_GRAPHRAG:
+            return super().store_file_id_with_graph_id(file_ids)
+
+        # Use the collection-wide graph ID for LightRAG
+        graph_id = self.collection_graph_id
+
+        # Record all files under this graph_id
+        with Session(engine) as session:
+            for file_id in file_ids:
+                if not file_id:
+                    continue
+                # Check if mapping already exists
+                existing = (
+                    session.query(self.Index)
+                    .filter(
+                        self.Index.source_id == file_id,
+                        self.Index.target_id == graph_id,
+                        self.Index.relation_type == "graph",
+                    )
+                    .first()
+                )
+                if not existing:
+                    node = self.Index(
+                        source_id=file_id,
+                        target_id=graph_id,
+                        relation_type="graph",
+                    )
+                    session.add(node)
+            session.commit()
+
+        return graph_id
 
     @classmethod
     def get_user_settings(cls) -> dict:
@@ -291,45 +325,54 @@ class NanoGraphRAGIndexingPipeline(GraphRAGIndexingPipeline):
 
         yield Document(
             channel="debug",
-            text="[GraphRAG] Creating index... This can take a long time.",
+            text="[GraphRAG] Creating/Updating index... This can take a long time.",
         )
 
-        # remove all .json files in the input_path directory (previous cache)
-        json_files = glob.glob(f"{input_path}/*.json")
-        for json_file in json_files:
-            os.remove(json_file)
+        # Check if graph already exists
+        graph_file = input_path / "graph_chunk_entity_relation.graphml"
+        is_incremental = graph_file.exists()
+
+        # Only clear cache if it's a new graph
+        if not is_incremental:
+            json_files = glob.glob(f"{input_path}/*.json")
+            for json_file in json_files:
+                os.remove(json_file)
 
-        # indexing
+        # Initialize or load existing GraphRAG
         graphrag_func = build_graphrag(
             input_path,
             llm_func=llm_func,
             embedding_func=embedding_func,
         )
-        # output must be contain: Loaded graph from
-        # ..input/graph_chunk_entity_relation.graphml with xxx nodes, xxx edges
+
         total_docs = len(all_docs)
         process_doc_count = 0
         yield Document(
             channel="debug",
-            text=f"[GraphRAG] Indexed {process_doc_count} / {total_docs} documents.",
+            text=(
+                f"[GraphRAG] {'Updating' if is_incremental else 'Creating'} index: "
+                f"{process_doc_count} / {total_docs} documents."
+            ),
         )
 
         for doc_id in range(0, len(all_docs), INDEX_BATCHSIZE):
             cur_docs = all_docs[doc_id : doc_id + INDEX_BATCHSIZE]
             combined_doc = "\n".join(cur_docs)
+
+            # Use insert for incremental updates
             graphrag_func.insert(combined_doc)
             process_doc_count += len(cur_docs)
             yield Document(
                 channel="debug",
                 text=(
-                    f"[GraphRAG] Indexed {process_doc_count} "
-                    f"/ {total_docs} documents."
+                    f"[GraphRAG] {'Updated' if is_incremental else 'Indexed'} "
+                    f"{process_doc_count} / {total_docs} documents."
                 ),
             )
 
         yield Document(
             channel="debug",
-            text="[GraphRAG] Indexing finished.",
+            text=f"[GraphRAG] {'Update' if is_incremental else 'Indexing'} finished.",
         )
 
     def stream(