feat: integrate nano-graphrag (#433)
* add nano graph-rag * ignore entities for relevant context reference * refactor and add local model as default nano-graphrag * feat: add kotaemon llm & embedding integration with nanographrag * fix: add env var for nano GraphRAG --------- Co-authored-by: Tadashi <tadashi@cinnamon.is>
This commit is contained in:
parent
19b386b51e
commit
66e565649e
|
@ -19,6 +19,8 @@ COHERE_API_KEY=<COHERE_API_KEY>
|
||||||
# settings for local models
|
# settings for local models
|
||||||
LOCAL_MODEL=llama3.1:8b
|
LOCAL_MODEL=llama3.1:8b
|
||||||
LOCAL_MODEL_EMBEDDINGS=nomic-embed-text
|
LOCAL_MODEL_EMBEDDINGS=nomic-embed-text
|
||||||
|
LOCAL_EMBEDDING_MODEL_DIM=768
|
||||||
|
LOCAL_EMBEDDING_MODEL_MAX_TOKENS=8192
|
||||||
|
|
||||||
# settings for GraphRAG
|
# settings for GraphRAG
|
||||||
GRAPHRAG_API_KEY=<YOUR_OPENAI_KEY>
|
GRAPHRAG_API_KEY=<YOUR_OPENAI_KEY>
|
||||||
|
|
19
README.md
19
README.md
|
@ -170,7 +170,22 @@ documents and developers who want to build their own RAG pipeline.
|
||||||
### Setup GraphRAG
|
### Setup GraphRAG
|
||||||
|
|
||||||
> [!NOTE]
|
> [!NOTE]
|
||||||
> Currently GraphRAG feature only works with OpenAI or Ollama API.
|
> Official MS GraphRAG indexing only works with OpenAI or Ollama API.
|
||||||
|
> We recommend that most users use the NanoGraphRAG implementation for straightforward integration with Kotaemon.
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>Setup Nano GRAPHRAG</summary>
|
||||||
|
|
||||||
|
- Install nano-GraphRAG: `pip install nano-graphrag`
|
||||||
|
- Launch Kotaemon with `USE_NANO_GRAPHRAG=true` environment variable.
|
||||||
|
- Set your default LLM & Embedding models in the Resources settings, and they will be picked up automatically by NanoGraphRAG.
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>Setup MS GRAPHRAG</summary>
|
||||||
|
|
||||||
- **Non-Docker Installation**: If you are not using Docker, install GraphRAG with the following command:
|
- **Non-Docker Installation**: If you are not using Docker, install GraphRAG with the following command:
|
||||||
|
|
||||||
|
@ -181,6 +196,8 @@ documents and developers who want to build their own RAG pipeline.
|
||||||
- **Setting Up API KEY**: To use the GraphRAG retriever feature, ensure you set the `GRAPHRAG_API_KEY` environment variable. You can do this directly in your environment or by adding it to a `.env` file.
|
- **Setting Up API KEY**: To use the GraphRAG retriever feature, ensure you set the `GRAPHRAG_API_KEY` environment variable. You can do this directly in your environment or by adding it to a `.env` file.
|
||||||
- **Using Local Models and Custom Settings**: If you want to use GraphRAG with local models (like `Ollama`) or customize the default LLM and other configurations, set the `USE_CUSTOMIZED_GRAPHRAG_SETTING` environment variable to true. Then, adjust your settings in the `settings.yaml.example` file.
|
- **Using Local Models and Custom Settings**: If you want to use GraphRAG with local models (like `Ollama`) or customize the default LLM and other configurations, set the `USE_CUSTOMIZED_GRAPHRAG_SETTING` environment variable to true. Then, adjust your settings in the `settings.yaml.example` file.
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
### Setup Local Models (for local/private RAG)
|
### Setup Local Models (for local/private RAG)
|
||||||
|
|
||||||
See [Local model setup](docs/local_model.md).
|
See [Local model setup](docs/local_model.md).
|
||||||
|
|
|
@ -284,11 +284,43 @@ SETTINGS_REASONING = {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Opt-in switch read from the environment: when true, the nano-graphrag based
# index class is used instead of the default GraphRAG index class.
USE_NANO_GRAPHRAG = config("USE_NANO_GRAPHRAG", default=False, cast=bool)
# Dotted path of the graph index class selected by the switch above.
GRAPHRAG_INDEX_TYPE = (
    "ktem.index.file.graph.NanoGraphRAGIndex"
    if USE_NANO_GRAPHRAG
    else "ktem.index.file.graph.GraphRAGIndex"
)
|
||||||
KH_INDEX_TYPES = [
|
KH_INDEX_TYPES = [
|
||||||
"ktem.index.file.FileIndex",
|
"ktem.index.file.FileIndex",
|
||||||
"ktem.index.file.graph.GraphRAGIndex",
|
GRAPHRAG_INDEX_TYPE,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Default graph index entry. Both backends share the same config; only the
# display name and the index class depend on USE_NANO_GRAPHRAG.
GRAPHRAG_INDEX = {
    "name": "NanoGraphRAG" if USE_NANO_GRAPHRAG else "GraphRAG",
    "config": {
        "supported_file_types": (
            ".png, .jpeg, .jpg, .tiff, .tif, .pdf, .xls, .xlsx, .doc, .docx, "
            ".pptx, .csv, .html, .mhtml, .txt, .md, .zip"
        ),
        "private": False,
    },
    "index_type": (
        "ktem.index.file.graph.NanoGraphRAGIndex"
        if USE_NANO_GRAPHRAG
        else "ktem.index.file.graph.GraphRAGIndex"
    ),
}
|
||||||
|
|
||||||
KH_INDICES = [
|
KH_INDICES = [
|
||||||
{
|
{
|
||||||
"name": "File",
|
"name": "File",
|
||||||
|
@ -301,15 +333,5 @@ KH_INDICES = [
|
||||||
},
|
},
|
||||||
"index_type": "ktem.index.file.FileIndex",
|
"index_type": "ktem.index.file.FileIndex",
|
||||||
},
|
},
|
||||||
{
|
GRAPHRAG_INDEX,
|
||||||
"name": "GraphRAG",
|
|
||||||
"config": {
|
|
||||||
"supported_file_types": (
|
|
||||||
".png, .jpeg, .jpg, .tiff, .tif, .pdf, .xls, .xlsx, .doc, .docx, "
|
|
||||||
".pptx, .csv, .html, .mhtml, .txt, .md, .zip"
|
|
||||||
),
|
|
||||||
"private": False,
|
|
||||||
},
|
|
||||||
"index_type": "ktem.index.file.graph.GraphRAGIndex",
|
|
||||||
},
|
|
||||||
]
|
]
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
from .graph_index import GraphRAGIndex
|
from .graph_index import GraphRAGIndex
|
||||||
|
from .nano_graph_index import NanoGraphRAGIndex
|
||||||
|
|
||||||
__all__ = ["GraphRAGIndex"]
|
__all__ = ["GraphRAGIndex", "NanoGraphRAGIndex"]
|
||||||
|
|
26
libs/ktem/ktem/index/file/graph/nano_graph_index.py
Normal file
26
libs/ktem/ktem/index/file/graph/nano_graph_index.py
Normal file
|
@ -0,0 +1,26 @@
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from ..base import BaseFileIndexRetriever
|
||||||
|
from .graph_index import GraphRAGIndex
|
||||||
|
from .nano_pipelines import NanoGraphRAGIndexingPipeline, NanoGraphRAGRetrieverPipeline
|
||||||
|
|
||||||
|
|
||||||
|
class NanoGraphRAGIndex(GraphRAGIndex):
    """GraphRAG file index that uses the nano-graphrag pipelines.

    Extends ``GraphRAGIndex`` and swaps in the nano-graphrag indexing and
    retriever pipeline classes.
    """

    def _setup_indexing_cls(self):
        """Select the nano-graphrag indexing pipeline class."""
        self._indexing_pipeline_cls = NanoGraphRAGIndexingPipeline

    def _setup_retriever_cls(self):
        """Select the nano-graphrag retriever pipeline class."""
        self._retriever_pipeline_cls = [NanoGraphRAGRetrieverPipeline]

    def get_retriever_pipelines(
        self, settings: dict, user_id: int, selected: Any = None
    ) -> list["BaseFileIndexRetriever"]:
        """Build the retriever pipeline for the selected files.

        Args:
            settings: user settings (not read here).
            user_id: id of the requesting user (not read here).
            selected: 3-item selection whose second item is the list of
                selected file ids. ``None`` is treated as "nothing selected".

        Returns:
            A single-element list holding the configured retriever.
        """
        # `selected` defaults to None and unpacking None raises TypeError, so
        # fall back to an empty selection instead of crashing.
        if selected is None:
            file_ids = []
        else:
            _, file_ids, _ = selected

        return [
            NanoGraphRAGRetrieverPipeline(
                file_ids=file_ids,
                Index=self._resources["Index"],
            )
        ]
|
380
libs/ktem/ktem/index/file/graph/nano_pipelines.py
Normal file
380
libs/ktem/ktem/index/file/graph/nano_pipelines.py
Normal file
|
@ -0,0 +1,380 @@
|
||||||
|
import asyncio
|
||||||
|
import glob
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Generator
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
from ktem.db.models import engine
|
||||||
|
from ktem.embeddings.manager import embedding_models_manager as embeddings
|
||||||
|
from ktem.llms.manager import llms
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from theflow.settings import settings
|
||||||
|
|
||||||
|
from kotaemon.base import Document, Param, RetrievedDocument
|
||||||
|
from kotaemon.base.schema import AIMessage, HumanMessage, SystemMessage
|
||||||
|
|
||||||
|
from ..pipelines import BaseFileIndexRetriever
|
||||||
|
from .pipelines import GraphRAGIndexingPipeline
|
||||||
|
from .visualize import create_knowledge_graph, visualize_graph
|
||||||
|
|
||||||
|
try:
|
||||||
|
from nano_graphrag import GraphRAG, QueryParam
|
||||||
|
from nano_graphrag._op import (
|
||||||
|
_find_most_related_community_from_entities,
|
||||||
|
_find_most_related_edges_from_entities,
|
||||||
|
_find_most_related_text_unit_from_entities,
|
||||||
|
)
|
||||||
|
from nano_graphrag._utils import EmbeddingFunc
|
||||||
|
from nano_graphrag.base import (
|
||||||
|
BaseGraphStorage,
|
||||||
|
BaseKVStorage,
|
||||||
|
BaseVectorStorage,
|
||||||
|
CommunitySchema,
|
||||||
|
TextChunkSchema,
|
||||||
|
)
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
print(
|
||||||
|
(
|
||||||
|
"Nano-GraphRAG dependencies not installed. "
|
||||||
|
"Try `pip install nano-graphrag` to install. "
|
||||||
|
"Nano-GraphRAG retriever pipeline will not work properly."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Set the "nano-graphrag" logger to INFO level.
logging.getLogger("nano-graphrag").setLevel(level=logging.INFO)

# Root directory for nano-graphrag data; created at import time if missing.
filestorage_path = Path(settings.KH_FILESTORAGE_PATH, "nano_graphrag")
filestorage_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
def get_llm_func(model):
    """Wrap a chat model as an async prompt-to-text function.

    Args:
        model: callable that takes a list of chat messages and returns an
            object exposing the completion as ``.text``.

    Returns:
        An async function
        ``llm_func(prompt, system_prompt=None, history_messages=None, **kwargs)``
        that returns the completion text.
    """
    logger = logging.getLogger(__name__)

    async def llm_func(
        prompt, system_prompt=None, history_messages=None, **kwargs
    ) -> str:
        # `history_messages` defaults to None instead of a shared mutable list.
        input_messages = [SystemMessage(text=system_prompt)] if system_prompt else []

        # History entries are dicts with "role" and "content"; every entry
        # that is not from the user is replayed as an AI message.
        for msg in history_messages or []:
            if msg.get("role") == "user":
                input_messages.append(HumanMessage(text=msg["content"]))
            else:
                input_messages.append(AIMessage(text=msg["content"]))

        input_messages.append(HumanMessage(text=prompt))
        # NOTE: the model is called synchronously and extra **kwargs are
        # accepted but not forwarded.
        output = model(input_messages).text

        # Log at debug level rather than printing every completion to stdout.
        logger.debug("LLM output:\n%s", output)

        return output

    return llm_func
|
||||||
|
|
||||||
|
|
||||||
|
def get_embedding_func(model):
    """Wrap an embedding model as an async texts-to-array function.

    The returned coroutine calls ``model(texts)`` and stacks the
    ``.embedding`` of each returned item into one NumPy array.
    """

    async def embedding_func(texts: list[str]) -> np.ndarray:
        return np.array([item.embedding for item in model(texts)])

    return embedding_func
|
||||||
|
|
||||||
|
|
||||||
|
def get_default_models_wrapper():
    """Wrap the default embedding model and LLM for use with GraphRAG.

    Returns:
        Tuple ``(llm_func, embedding_func, default_llm, default_embedding)``
        where the first two are the wrapped callables and the last two are
        the underlying default models.
    """
    default_embedding = embeddings.get_default()
    # Embed a short probe text to find out the vector size of the model.
    probe_vector = default_embedding(["Hi"])[0].embedding
    default_embedding_dim = len(probe_vector)
    embedding_func = EmbeddingFunc(
        embedding_dim=default_embedding_dim,
        max_token_size=8192,
        func=get_embedding_func(default_embedding),
    )
    print("GraphRAG embedding dim", default_embedding_dim)

    default_llm = llms.get_default()
    return get_llm_func(default_llm), embedding_func, default_llm, default_embedding
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_graph_index_path(graph_id: str):
    """Return ``(root_path, input_path)`` for *graph_id*.

    ``root_path`` is ``filestorage_path / graph_id`` and ``input_path`` is its
    ``input`` sub-directory. Nothing is created on disk here.
    """
    graph_root = Path(filestorage_path) / graph_id
    return graph_root, graph_root / "input"
|
||||||
|
|
||||||
|
|
||||||
|
def list_of_list_to_df(data: list[list]) -> pd.DataFrame:
    """Build a DataFrame from rows whose first row holds the column names."""
    records = data[1:]
    header = data[0]
    return pd.DataFrame(records, columns=header)
|
||||||
|
|
||||||
|
|
||||||
|
def clean_quote(input: str) -> str:
    """Return *input* with all single and double quote characters removed."""
    quote_pattern = re.compile(r"[\"']")
    return quote_pattern.sub("", input)
|
||||||
|
|
||||||
|
|
||||||
|
async def nano_graph_rag_build_local_query_context(
    graph_func,
    query,
    query_param: "QueryParam",
):
    """Collect the local-query context for *query* as four DataFrames.

    Args:
        graph_func: object exposing ``chunk_entity_relation_graph``,
            ``entities_vdb``, ``community_reports`` and ``text_chunks``.
        query: the query passed to the entity vector store.
        query_param: query parameters; ``top_k`` is read here and the object
            is forwarded to the nano-graphrag helper functions.

    Returns:
        Tuple ``(entities_df, relations_df, communities_df, sources_df)`` with
        columns ``id/entity/type/description/rank``,
        ``id/source/target/description/weight/rank``, ``id/content`` and
        ``id/content`` respectively.

    Raises:
        ValueError: if the entity vector store returns no result.
    """
    # The parameter annotation above is a string on purpose: parameter
    # annotations are evaluated when the function is defined, and QueryParam
    # is undefined when the optional nano_graphrag import failed. Annotations
    # on local variables below are never evaluated.
    knowledge_graph_inst: BaseGraphStorage = graph_func.chunk_entity_relation_graph
    entities_vdb: BaseVectorStorage = graph_func.entities_vdb
    community_reports: BaseKVStorage[CommunitySchema] = graph_func.community_reports
    text_chunks_db: BaseKVStorage[TextChunkSchema] = graph_func.text_chunks

    results = await entities_vdb.query(query, top_k=query_param.top_k)
    if not len(results):
        raise ValueError("No results found")

    entity_names = [r["entity_name"] for r in results]
    node_datas = await asyncio.gather(
        *[knowledge_graph_inst.get_node(name) for name in entity_names]
    )
    node_degrees = await asyncio.gather(
        *[knowledge_graph_inst.node_degree(name) for name in entity_names]
    )
    # Keep only entities that still have a node in the graph; the node degree
    # is stored as the entity's rank.
    node_datas = [
        {**n, "entity_name": name, "rank": d}
        for name, n, d in zip(entity_names, node_datas, node_degrees)
        if n is not None
    ]
    use_communities = await _find_most_related_community_from_entities(
        node_datas, query_param, community_reports
    )
    use_text_units = await _find_most_related_text_unit_from_entities(
        node_datas, query_param, text_chunks_db, knowledge_graph_inst
    )
    use_relations = await _find_most_related_edges_from_entities(
        node_datas, query_param, knowledge_graph_inst
    )

    # Each section is a header row followed by one row per item.
    entities_section_list = [["id", "entity", "type", "description", "rank"]]
    entities_section_list.extend(
        [
            str(i),
            clean_quote(n["entity_name"]),
            n.get("entity_type", "UNKNOWN"),
            clean_quote(n.get("description", "UNKNOWN")),
            n["rank"],
        ]
        for i, n in enumerate(node_datas)
    )
    entities_df = list_of_list_to_df(entities_section_list)

    relations_section_list = [
        ["id", "source", "target", "description", "weight", "rank"]
    ]
    relations_section_list.extend(
        [
            str(i),
            clean_quote(e["src_tgt"][0]),
            clean_quote(e["src_tgt"][1]),
            clean_quote(e["description"]),
            e["weight"],
            e["rank"],
        ]
        for i, e in enumerate(use_relations)
    )
    relations_df = list_of_list_to_df(relations_section_list)

    communities_section_list = [["id", "content"]]
    communities_section_list.extend(
        [str(i), c["report_string"]] for i, c in enumerate(use_communities)
    )
    communities_df = list_of_list_to_df(communities_section_list)

    text_units_section_list = [["id", "content"]]
    text_units_section_list.extend(
        [str(i), t["content"]] for i, t in enumerate(use_text_units)
    )
    sources_df = list_of_list_to_df(text_units_section_list)

    return entities_df, relations_df, communities_df, sources_df
|
||||||
|
|
||||||
|
|
||||||
|
def build_graphrag(working_dir, llm_func, embedding_func):
    """Create a ``GraphRAG`` instance that works inside *working_dir*.

    The same *llm_func* is passed as both ``best_model_func`` and
    ``cheap_model_func``; ``embedding_func_max_async`` is fixed to 4.
    """
    return GraphRAG(
        working_dir=working_dir,
        best_model_func=llm_func,
        cheap_model_func=llm_func,
        embedding_func=embedding_func,
        embedding_func_max_async=4,
    )
|
||||||
|
|
||||||
|
|
||||||
|
class NanoGraphRAGIndexingPipeline(GraphRAGIndexingPipeline):
    """Indexing pipeline that builds the graph index with nano-graphrag."""

    def call_graphrag_index(self, graph_id: str, docs: list[Document]):
        """Index the text documents in *docs* under the graph *graph_id*.

        This is a generator: it yields one debug-channel ``Document`` before
        indexing starts and one after it has finished.
        """
        _, input_path = prepare_graph_index_path(graph_id)
        input_path.mkdir(parents=True, exist_ok=True)

        models = get_default_models_wrapper()
        llm_func, embedding_func, default_llm, default_embedding = models
        print(
            f"Indexing GraphRAG with LLM {default_llm} "
            f"and Embedding {default_embedding}..."
        )

        # Only text documents are indexed; a document without a "type"
        # metadata entry counts as text.
        text_contents = []
        for doc in docs:
            if doc.metadata.get("type", "text") == "text":
                text_contents.append(doc.text)

        yield Document(
            channel="debug",
            text="[GraphRAG] Creating index... This can take a long time.",
        )

        # Remove all .json files in the working directory (previous cache).
        for stale_json in glob.glob(f"{input_path}/*.json"):
            os.remove(stale_json)

        graphrag_func = build_graphrag(
            input_path,
            llm_func=llm_func,
            embedding_func=embedding_func,
        )
        # The output is expected to contain: "Loaded graph from
        # ..input/graph_chunk_entity_relation.graphml with xxx nodes, xxx edges"
        graphrag_func.insert(text_contents)

        yield Document(
            channel="debug",
            text="[GraphRAG] Indexing finished.",
        )

    def stream(
        self, file_paths: str | Path | list[str | Path], reindex: bool = False, **kwargs
    ) -> Generator[
        Document, None, tuple[list[str | None], list[str | None], list[Document]]
    ]:
        """Delegate to the parent ``stream`` and pass its result through."""
        outcome = yield from super().stream(file_paths, reindex=reindex, **kwargs)
        file_ids, errors, all_docs = outcome
        return file_ids, errors, all_docs
|
||||||
|
|
||||||
|
|
||||||
|
class NanoGraphRAGRetrieverPipeline(BaseFileIndexRetriever):
    """Retriever pipeline that reads context from a nano-graphrag index."""

    Index = Param(help="The SQLAlchemy Index table")
    file_ids: list[str] = []

    def _build_graph_search(self):
        """Open the graph of the first selected file.

        Returns:
            Tuple ``(graphrag_func, query_params)``: the ``GraphRAG`` instance
            and a local-mode, context-only ``QueryParam``.
        """
        # Only the first selected file is used to look up the graph.
        file_id = self.file_ids[0]

        # Look up the graph id linked to the file in the Index table.
        with Session(engine) as session:
            row = (
                session.query(self.Index.target_id)
                .filter(self.Index.source_id == file_id)
                .filter(self.Index.relation_type == "graph")
                .first()
            )
            graph_id = row[0] if row else None
            # NOTE(review): assert is stripped under `python -O`; kept as-is so
            # the raised exception type does not change.
            assert graph_id, f"GraphRAG index not found for file_id: {file_id}"

        _, input_path = prepare_graph_index_path(graph_id)
        input_path.mkdir(parents=True, exist_ok=True)

        llm_func, embedding_func, _, _ = get_default_models_wrapper()
        graphrag_func = build_graphrag(
            input_path,
            llm_func=llm_func,
            embedding_func=embedding_func,
        )
        return graphrag_func, QueryParam(mode="local", only_need_context=True)

    def _to_document(self, header: str, context_text: str) -> RetrievedDocument:
        """Wrap one context section as a table-type ``RetrievedDocument``."""
        section_metadata = {
            "file_name": header,
            "type": "table",
            "llm_trulens_score": 1.0,
        }
        return RetrievedDocument(
            text=context_text,
            metadata=section_metadata,
            score=1.0,
        )

    def format_context_records(
        self, entities, relationships, reports, sources
    ) -> list[RetrievedDocument]:
        """Turn the four context DataFrames into four retrieved documents.

        Entities and relationships are rendered as markdown tables; reports
        and sources are rendered as HTML-headed text, one entry per row.
        """
        entities_md = entities[["entity", "description"]].to_markdown(index=False)
        relationships_md = relationships[
            ["source", "target", "description"]
        ].to_markdown(index=False)

        report_parts = []
        for _, row in reports.iterrows():
            # The "id" column stands in for the title: reports have no title.
            title, content = row["id"], row["content"]
            report_parts.append(f"\n\n<h5>Report <b>{title}</b></h5>\n")
            report_parts.append(content)

        source_parts = []
        for _, row in sources.iterrows():
            title, content = row["id"], row["content"]
            source_parts.append(f"\n\n<h5>Source <b>#{title}</b></h5>\n")
            source_parts.append(content)

        return [
            self._to_document("<b>Entities</b>\n", entities_md),
            self._to_document("\n<b>Relationships</b>\n", relationships_md),
            self._to_document("\n<b>Reports</b>\n", "".join(report_parts)),
            self._to_document("\n<b>Sources</b>\n", "".join(source_parts)),
        ]

    def plot_graph(self, relationships):
        """Build and return the visualisation of the relationships graph."""
        return visualize_graph(create_knowledge_graph(relationships))

    def run(
        self,
        text: str,
    ) -> list[RetrievedDocument]:
        """Retrieve graph context for *text*.

        Returns an empty list when no file is selected; otherwise the four
        context documents followed by one plot-type document.
        """
        if not self.file_ids:
            return []

        graphrag_func, query_params = self._build_graph_search()
        context_frames = asyncio.run(
            nano_graph_rag_build_local_query_context(graphrag_func, text, query_params)
        )
        entities, relationships, reports, sources = context_frames

        documents = self.format_context_records(
            entities, relationships, reports, sources
        )
        plot_document = RetrievedDocument(
            text="",
            metadata={
                "file_name": "GraphRAG",
                "type": "plot",
                "data": self.plot_graph(relationships),
            },
        )
        return documents + [plot_document]
|
|
@ -38,6 +38,7 @@ except ImportError:
|
||||||
print(
|
print(
|
||||||
(
|
(
|
||||||
"GraphRAG dependencies not installed. "
|
"GraphRAG dependencies not installed. "
|
||||||
|
"Try `pip install graphrag future` to install. "
|
||||||
"GraphRAG retriever pipeline will not work properly."
|
"GraphRAG retriever pipeline will not work properly."
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -97,7 +98,11 @@ class GraphRAGIndexingPipeline(IndexDocumentPipeline):
|
||||||
|
|
||||||
return root_path
|
return root_path
|
||||||
|
|
||||||
def call_graphrag_index(self, input_path: str):
|
def call_graphrag_index(self, graph_id: str, all_docs: list[Document]):
|
||||||
|
# call GraphRAG index with docs and graph_id
|
||||||
|
input_path = self.write_docs_to_files(graph_id, all_docs)
|
||||||
|
input_path = str(input_path.absolute())
|
||||||
|
|
||||||
# Construct the command
|
# Construct the command
|
||||||
command = [
|
command = [
|
||||||
"python",
|
"python",
|
||||||
|
@ -147,8 +152,7 @@ class GraphRAGIndexingPipeline(IndexDocumentPipeline):
|
||||||
# assign graph_id to file_ids
|
# assign graph_id to file_ids
|
||||||
graph_id = self.store_file_id_with_graph_id(file_ids)
|
graph_id = self.store_file_id_with_graph_id(file_ids)
|
||||||
# call GraphRAG index with docs and graph_id
|
# call GraphRAG index with docs and graph_id
|
||||||
graph_index_path = self.write_docs_to_files(graph_id, all_docs)
|
yield from self.call_graphrag_index(graph_id, all_docs)
|
||||||
yield from self.call_graphrag_index(str(graph_index_path.absolute()))
|
|
||||||
|
|
||||||
return file_ids, errors, all_docs
|
return file_ids, errors, all_docs
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user