feat: integrate nano-graphrag (#433)

* add nano graph-rag

* ignore entities for relevant context reference

* refactor and add local model as default nano-graphrag

* feat: add kotaemon llm & embedding integration with nanographrag

* fix: add env var for nano GraphRAG

---------

Co-authored-by: Tadashi <tadashi@cinnamon.is>
This commit is contained in:
cin-klein 2024-10-30 15:32:30 +07:00 committed by GitHub
parent 19b386b51e
commit 66e565649e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 470 additions and 18 deletions

View File

@ -19,6 +19,8 @@ COHERE_API_KEY=<COHERE_API_KEY>
# settings for local models
LOCAL_MODEL=llama3.1:8b
LOCAL_MODEL_EMBEDDINGS=nomic-embed-text
LOCAL_EMBEDDING_MODEL_DIM = 768
LOCAL_EMBEDDING_MODEL_MAX_TOKENS = 8192
# settings for GraphRAG
GRAPHRAG_API_KEY=<YOUR_OPENAI_KEY>

View File

@ -170,7 +170,22 @@ documents and developers who want to build their own RAG pipeline.
### Setup GraphRAG
> [!NOTE]
> Currently GraphRAG feature only works with OpenAI or Ollama API.
> Official MS GraphRAG indexing only works with OpenAI or Ollama API.
> We recommend that most users use the NanoGraphRAG implementation for straightforward integration with Kotaemon.
<details>
<summary>Setup Nano GRAPHRAG</summary>
- Install nano-GraphRAG: `pip install nano-graphrag`
- Launch Kotaemon with `USE_NANO_GRAPHRAG=true` environment variable.
- Set your default LLM & Embedding models in the Resources settings; NanoGraphRAG will pick them up automatically.
</details>
<details>
<summary>Setup MS GRAPHRAG</summary>
- **Non-Docker Installation**: If you are not using Docker, install GraphRAG with the following command:
@ -181,6 +196,8 @@ documents and developers who want to build their own RAG pipeline.
- **Setting Up API KEY**: To use the GraphRAG retriever feature, ensure you set the `GRAPHRAG_API_KEY` environment variable. You can do this directly in your environment or by adding it to a `.env` file.
- **Using Local Models and Custom Settings**: If you want to use GraphRAG with local models (like `Ollama`) or customize the default LLM and other configurations, set the `USE_CUSTOMIZED_GRAPHRAG_SETTING` environment variable to true. Then, adjust your settings in the `settings.yaml.example` file.
</details>
### Setup Local Models (for local/private RAG)
See [Local model setup](docs/local_model.md).

View File

@ -284,11 +284,43 @@ SETTINGS_REASONING = {
},
}
# Opt-in switch read from the environment: when truthy, the nano-graphrag
# backed index class is selected below instead of the default GraphRAG one.
USE_NANO_GRAPHRAG = config("USE_NANO_GRAPHRAG", default=False, cast=bool)

# Dotted path of the graph index class matching the switch above.
GRAPHRAG_INDEX_TYPE = (
    "ktem.index.file.graph.NanoGraphRAGIndex"
    if USE_NANO_GRAPHRAG
    else "ktem.index.file.graph.GraphRAGIndex"
)
# Index classes available to the application. The graph entry comes from
# GRAPHRAG_INDEX_TYPE so that only the selected GraphRAG flavour is listed;
# hard-coding the GraphRAGIndex path as well would list it twice (or list it
# even when the nano variant has been selected).
KH_INDEX_TYPES = [
    "ktem.index.file.FileIndex",
    GRAPHRAG_INDEX_TYPE,
]
# Index entry for the graph-based index. Both flavours share the same
# configuration; only the display name and the index class differ.
GRAPHRAG_INDEX = {
    "name": "NanoGraphRAG" if USE_NANO_GRAPHRAG else "GraphRAG",
    "config": {
        "supported_file_types": (
            ".png, .jpeg, .jpg, .tiff, .tif, .pdf, .xls, .xlsx, .doc, .docx, "
            ".pptx, .csv, .html, .mhtml, .txt, .md, .zip"
        ),
        "private": False,
    },
    "index_type": (
        "ktem.index.file.graph.NanoGraphRAGIndex"
        if USE_NANO_GRAPHRAG
        else "ktem.index.file.graph.GraphRAGIndex"
    ),
}
KH_INDICES = [
{
"name": "File",
@ -301,15 +333,5 @@ KH_INDICES = [
},
"index_type": "ktem.index.file.FileIndex",
},
{
"name": "GraphRAG",
"config": {
"supported_file_types": (
".png, .jpeg, .jpg, .tiff, .tif, .pdf, .xls, .xlsx, .doc, .docx, "
".pptx, .csv, .html, .mhtml, .txt, .md, .zip"
),
"private": False,
},
"index_type": "ktem.index.file.graph.GraphRAGIndex",
},
GRAPHRAG_INDEX,
]

View File

@ -1,3 +1,4 @@
from .graph_index import GraphRAGIndex
from .nano_graph_index import NanoGraphRAGIndex
# Public names of this package: both graph index flavours.
__all__ = ["GraphRAGIndex", "NanoGraphRAGIndex"]

View File

@ -0,0 +1,26 @@
from typing import Any
from ..base import BaseFileIndexRetriever
from .graph_index import GraphRAGIndex
from .nano_pipelines import NanoGraphRAGIndexingPipeline, NanoGraphRAGRetrieverPipeline
class NanoGraphRAGIndex(GraphRAGIndex):
    """Graph index whose indexing and retrieval run on nano-graphrag.

    Only the pipeline classes and the retriever construction are overridden
    here; everything else is inherited from ``GraphRAGIndex``.
    """

    def _setup_indexing_cls(self):
        """Select the nano-graphrag indexing pipeline class."""
        self._indexing_pipeline_cls = NanoGraphRAGIndexingPipeline

    def _setup_retriever_cls(self):
        """Select the nano-graphrag retriever pipeline class."""
        self._retriever_pipeline_cls = [NanoGraphRAGRetrieverPipeline]

    def get_retriever_pipelines(
        self, settings: dict, user_id: int, selected: Any = None
    ) -> list["BaseFileIndexRetriever"]:
        """Build the retriever pipeline(s) for the current file selection.

        Args:
            settings: accepted for interface compatibility; not read here.
            user_id: accepted for interface compatibility; not read here.
            selected: a 3-item sequence whose second item holds the selected
                file ids. ``None`` (the declared default) is treated as
                "no file selected".

        Returns:
            A single-element list holding the configured retriever.
        """
        if selected is None:
            # Unpacking the default ``None`` would raise TypeError; fall back
            # to an empty selection instead.
            file_ids = []
        else:
            _, file_ids, _ = selected

        return [
            NanoGraphRAGRetrieverPipeline(
                file_ids=file_ids,
                Index=self._resources["Index"],
            )
        ]

View File

@ -0,0 +1,380 @@
import asyncio
import glob
import logging
import os
import re
from pathlib import Path
from typing import Generator
import numpy as np
import pandas as pd
from ktem.db.models import engine
from ktem.embeddings.manager import embedding_models_manager as embeddings
from ktem.llms.manager import llms
from sqlalchemy.orm import Session
from theflow.settings import settings
from kotaemon.base import Document, Param, RetrievedDocument
from kotaemon.base.schema import AIMessage, HumanMessage, SystemMessage
from ..pipelines import BaseFileIndexRetriever
from .pipelines import GraphRAGIndexingPipeline
from .visualize import create_knowledge_graph, visualize_graph
try:
from nano_graphrag import GraphRAG, QueryParam
from nano_graphrag._op import (
_find_most_related_community_from_entities,
_find_most_related_edges_from_entities,
_find_most_related_text_unit_from_entities,
)
from nano_graphrag._utils import EmbeddingFunc
from nano_graphrag.base import (
BaseGraphStorage,
BaseKVStorage,
BaseVectorStorage,
CommunitySchema,
TextChunkSchema,
)
except ImportError:
print(
(
"Nano-GraphRAG dependencies not installed. "
"Try `pip install nano-graphrag` to install. "
"Nano-GraphRAG retriever pipeline will not work properly."
)
)
# Let INFO records from the "nano-graphrag" logger through.
logging.getLogger("nano-graphrag").setLevel(level=logging.INFO)

# Root directory for all graph indices: <KH_FILESTORAGE_PATH>/nano_graphrag.
# It is created at import time.
filestorage_path = Path(settings.KH_FILESTORAGE_PATH).joinpath("nano_graphrag")
filestorage_path.mkdir(exist_ok=True, parents=True)
def get_llm_func(model):
    """Wrap a chat model as an async completion function.

    Args:
        model: callable that takes a list of chat messages and returns an
            object exposing the generated text as ``.text``.

    Returns:
        An ``async`` function ``(prompt, system_prompt=None,
        history_messages=None, **kwargs) -> str``. Extra keyword arguments are
        accepted and ignored.
    """

    async def llm_func(
        prompt, system_prompt=None, history_messages=None, **kwargs
    ) -> str:
        input_messages = [SystemMessage(text=system_prompt)] if system_prompt else []

        # ``None`` replaces the former shared ``[]`` default. Entries whose
        # role is "user" become human turns; any other role is replayed as an
        # AI turn.
        for msg in history_messages or ():
            if msg.get("role") == "user":
                input_messages.append(HumanMessage(text=msg["content"]))
            else:
                input_messages.append(AIMessage(text=msg["content"]))

        input_messages.append(HumanMessage(text=prompt))
        output = model(input_messages).text

        print("-" * 50)
        print(output, "\n", "-" * 50)
        return output

    return llm_func
def get_embedding_func(model):
    """Adapt an embedding model to an async function returning a numpy array.

    ``model`` is called with the list of texts and must return an iterable of
    items exposing an ``embedding`` attribute; the vectors are stacked in the
    order they are returned.
    """

    async def embedding_func(texts: list[str]) -> np.ndarray:
        vectors = []
        for item in model(texts):
            vectors.append(item.embedding)
        return np.array(vectors)

    return embedding_func
def get_default_models_wrapper():
    """Wrap the default LLM and embedding model as callables for GraphRAG.

    Returns:
        ``(llm_func, embedding_func, default_llm, default_embedding)``: the
        two adapted callables followed by the underlying default models.
    """
    # Embedding side: the vector size is measured by embedding a throwaway
    # string with the default model.
    default_embedding = embeddings.get_default()
    probe = default_embedding(["Hi"])
    default_embedding_dim = len(probe[0].embedding)
    embedding_func = EmbeddingFunc(
        embedding_dim=default_embedding_dim,
        max_token_size=8192,
        func=get_embedding_func(default_embedding),
    )
    print("GraphRAG embedding dim", default_embedding_dim)

    # LLM side
    default_llm = llms.get_default()
    return get_llm_func(default_llm), embedding_func, default_llm, default_embedding
def prepare_graph_index_path(graph_id: str):
    """Return ``(root_path, input_path)`` for the given graph.

    ``root_path`` is ``filestorage_path / graph_id`` and ``input_path`` is its
    ``input`` sub-directory. Nothing is created on disk here.
    """
    graph_root = Path(filestorage_path).joinpath(graph_id)
    return graph_root, graph_root.joinpath("input")
def list_of_list_to_df(data: list[list]) -> pd.DataFrame:
    """Turn a header-first list of rows into a DataFrame.

    The first inner list supplies the column names; the remaining ones are the
    rows.
    """
    rows = data[1:]
    header = data[0]
    return pd.DataFrame(rows, columns=header)
# Matches a single or a double quote character.
_QUOTE_CHARS_PATTERN = re.compile(r"[\"']")


def clean_quote(input: str) -> str:
    """Return ``input`` with every single and double quote removed."""
    return _QUOTE_CHARS_PATTERN.sub("", input)
async def nano_graph_rag_build_local_query_context(
    graph_func,
    query,
    query_param: "QueryParam",
):
    """Collect the local-mode query context of a GraphRAG instance as tables.

    The ``query_param`` annotation is a string on purpose: ``QueryParam`` comes
    from an optional import, and an unquoted annotation would be evaluated when
    this function is defined, raising ``NameError`` at module import whenever
    nano-graphrag is not installed.

    Args:
        graph_func: object exposing the ``chunk_entity_relation_graph``,
            ``entities_vdb``, ``community_reports`` and ``text_chunks``
            storages.
        query: query passed to the entity vector storage.
        query_param: query parameters; ``top_k`` bounds the entity lookup and
            the object is forwarded to the nano-graphrag helper functions.

    Returns:
        ``(entities_df, relations_df, communities_df, sources_df)`` where
        entities have columns ``id, entity, type, description, rank``,
        relations ``id, source, target, description, weight, rank`` and both
        communities and sources ``id, content``.

    Raises:
        ValueError: if the entity lookup returns no result.
    """
    knowledge_graph_inst: BaseGraphStorage = graph_func.chunk_entity_relation_graph
    entities_vdb: BaseVectorStorage = graph_func.entities_vdb
    community_reports: BaseKVStorage[CommunitySchema] = graph_func.community_reports
    text_chunks_db: BaseKVStorage[TextChunkSchema] = graph_func.text_chunks

    results = await entities_vdb.query(query, top_k=query_param.top_k)
    if not results:
        raise ValueError("No results found")

    # Fetch node payloads and degrees for every matched entity concurrently.
    node_datas = await asyncio.gather(
        *[knowledge_graph_inst.get_node(r["entity_name"]) for r in results]
    )
    node_degrees = await asyncio.gather(
        *[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results]
    )
    # Drop entities whose node lookup returned None; the node degree is kept
    # as the entity's "rank".
    node_datas = [
        {**n, "entity_name": k["entity_name"], "rank": d}
        for k, n, d in zip(results, node_datas, node_degrees)
        if n is not None
    ]

    use_communities = await _find_most_related_community_from_entities(
        node_datas, query_param, community_reports
    )
    use_text_units = await _find_most_related_text_unit_from_entities(
        node_datas, query_param, text_chunks_db, knowledge_graph_inst
    )
    use_relations = await _find_most_related_edges_from_entities(
        node_datas, query_param, knowledge_graph_inst
    )

    # Each section is a header row followed by data rows.
    entities_section_list = [["id", "entity", "type", "description", "rank"]]
    for i, n in enumerate(node_datas):
        entities_section_list.append(
            [
                str(i),
                clean_quote(n["entity_name"]),
                n.get("entity_type", "UNKNOWN"),
                clean_quote(n.get("description", "UNKNOWN")),
                n["rank"],
            ]
        )
    entities_df = list_of_list_to_df(entities_section_list)

    relations_section_list = [
        ["id", "source", "target", "description", "weight", "rank"]
    ]
    for i, e in enumerate(use_relations):
        relations_section_list.append(
            [
                str(i),
                clean_quote(e["src_tgt"][0]),
                clean_quote(e["src_tgt"][1]),
                clean_quote(e["description"]),
                e["weight"],
                e["rank"],
            ]
        )
    relations_df = list_of_list_to_df(relations_section_list)

    communities_section_list = [["id", "content"]]
    for i, c in enumerate(use_communities):
        communities_section_list.append([str(i), c["report_string"]])
    communities_df = list_of_list_to_df(communities_section_list)

    text_units_section_list = [["id", "content"]]
    for i, t in enumerate(use_text_units):
        text_units_section_list.append([str(i), t["content"]])
    sources_df = list_of_list_to_df(text_units_section_list)

    return entities_df, relations_df, communities_df, sources_df
def build_graphrag(working_dir, llm_func, embedding_func):
    """Create a ``GraphRAG`` instance bound to ``working_dir``.

    The same ``llm_func`` is passed as both ``best_model_func`` and
    ``cheap_model_func``; ``embedding_func_max_async`` is fixed to 4.
    """
    graphrag_options = dict(
        working_dir=working_dir,
        best_model_func=llm_func,
        cheap_model_func=llm_func,
        embedding_func=embedding_func,
        embedding_func_max_async=4,
    )
    return GraphRAG(**graphrag_options)
class NanoGraphRAGIndexingPipeline(GraphRAGIndexingPipeline):
    """Indexing pipeline that builds the graph index with nano-graphrag."""

    def call_graphrag_index(self, graph_id: str, docs: list[Document]):
        """Index ``docs`` into the working directory of ``graph_id``.

        Args:
            graph_id: identifier used to derive the index directory.
            docs: documents to index; only those whose ``type`` metadata is
                ``"text"`` (or missing) contribute their text.

        Yields:
            ``debug``-channel documents announcing the start and the end of
            the indexing run.
        """
        _, input_path = prepare_graph_index_path(graph_id)
        input_path.mkdir(parents=True, exist_ok=True)

        (
            llm_func,
            embedding_func,
            default_llm,
            default_embedding,
        ) = get_default_models_wrapper()
        print(
            f"Indexing GraphRAG with LLM {default_llm} "
            f"and Embedding {default_embedding}..."
        )

        all_docs = [
            doc.text for doc in docs if doc.metadata.get("type", "text") == "text"
        ]

        yield Document(
            channel="debug",
            text="[GraphRAG] Creating index... This can take a long time.",
        )

        # Remove all .json files in the input_path directory (previous cache).
        # The directory part is escaped so that glob metacharacters in the
        # storage path (e.g. "[" or "*") are matched literally.
        json_pattern = os.path.join(glob.escape(str(input_path)), "*.json")
        for json_file in glob.glob(json_pattern):
            os.remove(json_file)

        # indexing
        graphrag_func = build_graphrag(
            input_path,
            llm_func=llm_func,
            embedding_func=embedding_func,
        )
        # The output should contain: Loaded graph from
        # ..input/graph_chunk_entity_relation.graphml with xxx nodes, xxx edges
        graphrag_func.insert(all_docs)

        yield Document(
            channel="debug",
            text="[GraphRAG] Indexing finished.",
        )

    def stream(
        self, file_paths: str | Path | list[str | Path], reindex: bool = False, **kwargs
    ) -> Generator[
        Document, None, tuple[list[str | None], list[str | None], list[Document]]
    ]:
        """Delegate to the parent ``stream`` and return its result unchanged.

        Returns:
            ``(file_ids, errors, all_docs)`` exactly as produced by the parent
            implementation.
        """
        file_ids, errors, all_docs = yield from super().stream(
            file_paths, reindex=reindex, **kwargs
        )
        return file_ids, errors, all_docs
class NanoGraphRAGRetrieverPipeline(BaseFileIndexRetriever):
    """Retriever that returns the local-mode GraphRAG context of a query."""

    Index = Param(help="The SQLAlchemy Index table")
    file_ids: list[str] = []

    def _build_graph_search(self):
        """Open the graph index of the first selected file.

        Returns:
            ``(graphrag_func, query_params)``: the GraphRAG instance bound to
            the index directory and the local-mode, context-only parameters.

        Raises:
            AssertionError: if no graph is linked to the first file id.
        """
        first_file_id = self.file_ids[0]

        # retrieve the graph_id from the index
        with Session(engine) as session:
            row = (
                session.query(self.Index.target_id)
                .filter(self.Index.source_id == first_file_id)
                .filter(self.Index.relation_type == "graph")
                .first()
            )
            graph_id = row[0] if row else None
            assert graph_id, f"GraphRAG index not found for file_id: {first_file_id}"

        _, input_path = prepare_graph_index_path(graph_id)
        input_path.mkdir(parents=True, exist_ok=True)

        llm_func, embedding_func, _, _ = get_default_models_wrapper()
        graphrag_func = build_graphrag(
            input_path, llm_func=llm_func, embedding_func=embedding_func
        )
        return graphrag_func, QueryParam(mode="local", only_need_context=True)

    def _to_document(self, header: str, context_text: str) -> RetrievedDocument:
        """Wrap one context section as a retrieved document with score 1.0."""
        section_metadata = {
            "file_name": header,
            "type": "table",
            "llm_trulens_score": 1.0,
        }
        return RetrievedDocument(
            text=context_text, metadata=section_metadata, score=1.0
        )

    def format_context_records(
        self, entities, relationships, reports, sources
    ) -> list[RetrievedDocument]:
        """Render the four context tables as four retrieved documents.

        Entities and relationships become markdown tables; reports and sources
        become a sequence of ``<h5>`` headed text blocks. The documents are
        returned in that order.
        """
        # entities current parsing error
        entities_md = entities[["entity", "description"]].to_markdown(index=False)
        entities_doc = self._to_document("<b>Entities</b>\n", entities_md)

        relationships_md = relationships[
            ["source", "target", "description"]
        ].to_markdown(index=False)
        relationships_doc = self._to_document(
            "\n<b>Relationships</b>\n", relationships_md
        )

        report_parts = []
        for _, row in reports.iterrows():
            # the report table has no title column, so the id is shown instead
            report_id, report_text = row["id"], row["content"]
            report_parts.append(
                f"\n\n<h5>Report <b>{report_id}</b></h5>\n" + report_text
            )
        reports_doc = self._to_document("\n<b>Reports</b>\n", "".join(report_parts))

        source_parts = []
        for _, row in sources.iterrows():
            source_id, source_text = row["id"], row["content"]
            source_parts.append(
                f"\n\n<h5>Source <b>#{source_id}</b></h5>\n" + source_text
            )
        sources_doc = self._to_document("\n<b>Sources</b>\n", "".join(source_parts))

        return [entities_doc, relationships_doc, reports_doc, sources_doc]

    def plot_graph(self, relationships):
        """Build the knowledge graph of ``relationships`` and visualize it."""
        return visualize_graph(create_knowledge_graph(relationships))

    def run(
        self,
        text: str,
    ) -> list[RetrievedDocument]:
        """Retrieve the graph context for ``text``.

        Returns:
            An empty list when no file is selected; otherwise the four context
            documents followed by one ``plot`` document carrying the graph
            visualization.
        """
        if not self.file_ids:
            return []

        graphrag_func, query_params = self._build_graph_search()
        context_frames = asyncio.run(
            nano_graph_rag_build_local_query_context(graphrag_func, text, query_params)
        )
        entities, relationships, reports, sources = context_frames

        documents = self.format_context_records(
            entities, relationships, reports, sources
        )
        plot = self.plot_graph(relationships)

        plot_document = RetrievedDocument(
            text="",
            metadata={
                "file_name": "GraphRAG",
                "type": "plot",
                "data": plot,
            },
        )
        return documents + [plot_document]

View File

@ -38,6 +38,7 @@ except ImportError:
print(
(
"GraphRAG dependencies not installed. "
"Try `pip install graphrag future` to install. "
"GraphRAG retriever pipeline will not work properly."
)
)
@ -97,7 +98,11 @@ class GraphRAGIndexingPipeline(IndexDocumentPipeline):
return root_path
def call_graphrag_index(self, input_path: str):
def call_graphrag_index(self, graph_id: str, all_docs: list[Document]):
# call GraphRAG index with docs and graph_id
input_path = self.write_docs_to_files(graph_id, all_docs)
input_path = str(input_path.absolute())
# Construct the command
command = [
"python",
@ -147,8 +152,7 @@ class GraphRAGIndexingPipeline(IndexDocumentPipeline):
# assign graph_id to file_ids
graph_id = self.store_file_id_with_graph_id(file_ids)
# call GraphRAG index with docs and graph_id
graph_index_path = self.write_docs_to_files(graph_id, all_docs)
yield from self.call_graphrag_index(str(graph_index_path.absolute()))
yield from self.call_graphrag_index(graph_id, all_docs)
return file_ids, errors, all_docs