feat: add lightrag support (#474) bump:patch

* feat: add lightrag support

* docs: update README
Tuan Anh Nguyen Dang (Tadashi_Cin) 2024-11-07 17:22:41 +07:00 committed by GitHub
parent 908452cc18
commit aeace96c48
6 changed files with 477 additions and 34 deletions

README.md

@@ -187,6 +187,18 @@ documents and developers who want to build their own RAG pipeline.

 <details>
+<summary>Setup LIGHTRAG</summary>
+
+- Install LightRAG: `pip install git+https://github.com/HKUDS/LightRAG.git`
+  - The `LightRAG` install might introduce version conflicts, see [this issue](https://github.com/Cinnamon/kotaemon/issues/440)
+  - To quickly fix: `pip uninstall hnswlib chroma-hnswlib && pip install chroma-hnswlib`
+- Launch Kotaemon with the `USE_LIGHTRAG=true` environment variable.
+- Set your default LLM & Embedding models in the Resources settings; they will be recognized automatically by LightRAG.
+</details>
+
+<details>
 <summary>Setup MS GRAPHRAG</summary>

 - **Non-Docker Installation**: If you are not using Docker, install GraphRAG with the following command:
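Note that the LIGHTRAG setup above takes effect at launch time, e.g. `USE_LIGHTRAG=true python app.py` (assuming the repository's usual `app.py` entrypoint); the flag itself is read in `flowsettings.py` below.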

flowsettings.py

@@ -287,41 +287,31 @@ SETTINGS_REASONING = {
 }

 USE_NANO_GRAPHRAG = config("USE_NANO_GRAPHRAG", default=False, cast=bool)
-GRAPHRAG_INDEX_TYPE = (
-    "ktem.index.file.graph.GraphRAGIndex"
-    if not USE_NANO_GRAPHRAG
-    else "ktem.index.file.graph.NanoGraphRAGIndex"
-)
+USE_LIGHTRAG = config("USE_LIGHTRAG", default=False, cast=bool)
+
+if USE_NANO_GRAPHRAG:
+    GRAPHRAG_INDEX_TYPE = "ktem.index.file.graph.NanoGraphRAGIndex"
+elif USE_LIGHTRAG:
+    GRAPHRAG_INDEX_TYPE = "ktem.index.file.graph.LightRAGIndex"
+else:
+    GRAPHRAG_INDEX_TYPE = "ktem.index.file.graph.GraphRAGIndex"

 KH_INDEX_TYPES = [
     "ktem.index.file.FileIndex",
     GRAPHRAG_INDEX_TYPE,
 ]
-GRAPHRAG_INDEX = (
-    {
-        "name": "GraphRAG",
-        "config": {
-            "supported_file_types": (
-                ".png, .jpeg, .jpg, .tiff, .tif, .pdf, .xls, .xlsx, .doc, .docx, "
-                ".pptx, .csv, .html, .mhtml, .txt, .md, .zip"
-            ),
-            "private": False,
-        },
-        "index_type": "ktem.index.file.graph.GraphRAGIndex",
-    }
-    if not USE_NANO_GRAPHRAG
-    else {
-        "name": "NanoGraphRAG",
-        "config": {
-            "supported_file_types": (
-                ".png, .jpeg, .jpg, .tiff, .tif, .pdf, .xls, .xlsx, .doc, .docx, "
-                ".pptx, .csv, .html, .mhtml, .txt, .md, .zip"
-            ),
-            "private": False,
-        },
-        "index_type": "ktem.index.file.graph.NanoGraphRAGIndex",
-    }
-)
+GRAPHRAG_INDEX = {
+    "name": GRAPHRAG_INDEX_TYPE.split(".")[-1].replace("Index", ""),  # e.g. "LightRAGIndex" -> "LightRAG"
+    "config": {
+        "supported_file_types": (
+            ".png, .jpeg, .jpg, .tiff, .tif, .pdf, .xls, .xlsx, .doc, .docx, "
+            ".pptx, .csv, .html, .mhtml, .txt, .md, .zip"
+        ),
+        "private": False,
+    },
+    "index_type": GRAPHRAG_INDEX_TYPE,
+}

 KH_INDICES = [
     {
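The index `name` is no longer hard-coded per variant; it is derived from the class path. For illustration, the derivation evaluates as:

>>> "ktem.index.file.graph.LightRAGIndex".split(".")[-1].replace("Index", "")
'LightRAG'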

libs/ktem/ktem/index/file/graph/__init__.py

@@ -1,4 +1,5 @@
 from .graph_index import GraphRAGIndex
+from .light_graph_index import LightRAGIndex
 from .nano_graph_index import NanoGraphRAGIndex

-__all__ = ["GraphRAGIndex", "NanoGraphRAGIndex"]
+__all__ = ["GraphRAGIndex", "NanoGraphRAGIndex", "LightRAGIndex"]

libs/ktem/ktem/index/file/graph/light_graph_index.py

@@ -0,0 +1,26 @@
from typing import Any

from ..base import BaseFileIndexRetriever
from .graph_index import GraphRAGIndex
from .lightrag_pipelines import LightRAGIndexingPipeline, LightRAGRetrieverPipeline


class LightRAGIndex(GraphRAGIndex):
    def _setup_indexing_cls(self):
        self._indexing_pipeline_cls = LightRAGIndexingPipeline

    def _setup_retriever_cls(self):
        self._retriever_pipeline_cls = [LightRAGRetrieverPipeline]

    def get_retriever_pipelines(
        self, settings: dict, user_id: int, selected: Any = None
    ) -> list["BaseFileIndexRetriever"]:
        _, file_ids, _ = selected
        retrievers = [
            LightRAGRetrieverPipeline(
                file_ids=file_ids,
                Index=self._resources["Index"],
            )
        ]
        return retrievers
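For context, ktem turns the dotted strings in KH_INDEX_TYPES into classes by dynamic import; a rough sketch of that resolution (illustrative only, not ktem's exact helper):

import importlib

module_path, cls_name = "ktem.index.file.graph.LightRAGIndex".rsplit(".", 1)
index_cls = getattr(importlib.import_module(module_path), cls_name)
assert index_cls.__name__ == "LightRAGIndex"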

libs/ktem/ktem/index/file/graph/lightrag_pipelines.py

@@ -0,0 +1,386 @@
import asyncio
import glob
import logging
import os
import re
from pathlib import Path
from typing import Generator

import numpy as np
import pandas as pd
from ktem.db.models import engine
from ktem.embeddings.manager import embedding_models_manager as embeddings
from ktem.llms.manager import llms
from sqlalchemy.orm import Session
from theflow.settings import settings

from kotaemon.base import Document, Param, RetrievedDocument
from kotaemon.base.schema import AIMessage, HumanMessage, SystemMessage

from ..pipelines import BaseFileIndexRetriever
from .pipelines import GraphRAGIndexingPipeline
from .visualize import create_knowledge_graph, visualize_graph

try:
    from lightrag import LightRAG, QueryParam
    from lightrag.operate import (
        _find_most_related_edges_from_entities,
        _find_most_related_text_unit_from_entities,
    )
    from lightrag.utils import EmbeddingFunc, compute_args_hash
except ImportError:
    print(
        "LightRAG dependencies not installed. "
        "Try `pip install git+https://github.com/HKUDS/LightRAG.git` to install. "
        "LightRAG retriever pipeline will not work properly."
    )

logging.getLogger("lightrag").setLevel(logging.INFO)

filestorage_path = Path(settings.KH_FILESTORAGE_PATH) / "lightrag"
filestorage_path.mkdir(parents=True, exist_ok=True)

# documents are inserted into LightRAG in batches of this size so that
# indexing progress can be reported between batches
INDEX_BATCHSIZE = 2
def get_llm_func(model):
    async def llm_func(
        prompt, system_prompt=None, history_messages=[], **kwargs
    ) -> str:
        input_messages = [SystemMessage(text=system_prompt)] if system_prompt else []
        hashing_kv = kwargs.pop("hashing_kv", None)
        if history_messages:
            for msg in history_messages:
                if msg.get("role") == "user":
                    input_messages.append(HumanMessage(text=msg["content"]))
                else:
                    input_messages.append(AIMessage(text=msg["content"]))

        input_messages.append(HumanMessage(text=prompt))

        # return the cached response if an identical prompt was answered before
        if hashing_kv is not None:
            args_hash = compute_args_hash("model", input_messages)
            if_cache_return = await hashing_kv.get_by_id(args_hash)
            if if_cache_return is not None:
                return if_cache_return["return"]

        output = model(input_messages).text
        print("-" * 50)
        print(output, "\n", "-" * 50)

        # cache the response for subsequent identical calls
        if hashing_kv is not None:
            await hashing_kv.upsert({args_hash: {"return": output, "model": "model"}})

        return output

    return llm_func


def get_embedding_func(model):
    async def embedding_func(texts: list[str]) -> np.ndarray:
        outputs = model(texts)
        embedding_outputs = np.array([doc.embedding for doc in outputs])
        return embedding_outputs

    return embedding_func
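# Note: LightRAG expects async llm_model_func / embedding functions, while
# kotaemon's model wrappers are synchronous. The two adapters above therefore
# wrap the blocking calls in async functions, and llm_func additionally reuses
# LightRAG's `hashing_kv` store to cache responses for identical prompts.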
def get_default_models_wrapper():
    # set up model functions
    default_embedding = embeddings.get_default()
    default_embedding_dim = len(default_embedding(["Hi"])[0].embedding)
    embedding_func = EmbeddingFunc(
        embedding_dim=default_embedding_dim,
        max_token_size=8192,
        func=get_embedding_func(default_embedding),
    )
    print("GraphRAG embedding dim", default_embedding_dim)

    default_llm = llms.get_default()
    llm_func = get_llm_func(default_llm)

    return llm_func, embedding_func, default_llm, default_embedding


def prepare_graph_index_path(graph_id: str):
    root_path = Path(filestorage_path) / graph_id
    input_path = root_path / "input"

    return root_path, input_path


def list_of_list_to_df(data: list[list]) -> pd.DataFrame:
    df = pd.DataFrame(data[1:], columns=data[0])
    return df


def clean_quote(text: str) -> str:
    return re.sub(r"[\"']", "", text)
async def lightrag_build_local_query_context(
    graph_func,
    query,
    query_param,
):
    knowledge_graph_inst = graph_func.chunk_entity_relation_graph
    entities_vdb = graph_func.entities_vdb
    text_chunks_db = graph_func.text_chunks

    results = await entities_vdb.query(query, top_k=query_param.top_k)
    if not len(results):
        raise ValueError("No results found")

    node_datas = await asyncio.gather(
        *[knowledge_graph_inst.get_node(r["entity_name"]) for r in results]
    )
    node_degrees = await asyncio.gather(
        *[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results]
    )
    node_datas = [
        {**n, "entity_name": k["entity_name"], "rank": d}
        for k, n, d in zip(results, node_datas, node_degrees)
        if n is not None
    ]
    use_text_units = await _find_most_related_text_unit_from_entities(
        node_datas, query_param, text_chunks_db, knowledge_graph_inst
    )
    use_relations = await _find_most_related_edges_from_entities(
        node_datas, query_param, knowledge_graph_inst
    )
    logging.info(
        f"Local query uses {len(node_datas)} entities, "
        f"{len(use_relations)} relations, {len(use_text_units)} text units"
    )

    entities_section_list = [["id", "entity", "type", "description", "rank"]]
    for i, n in enumerate(node_datas):
        entities_section_list.append(
            [
                str(i),
                clean_quote(n["entity_name"]),
                n.get("entity_type", "UNKNOWN"),
                clean_quote(n.get("description", "UNKNOWN")),
                n["rank"],
            ]
        )
    entities_df = list_of_list_to_df(entities_section_list)

    relations_section_list = [
        ["id", "source", "target", "description", "keywords", "weight", "rank"]
    ]
    for i, e in enumerate(use_relations):
        relations_section_list.append(
            [
                str(i),
                clean_quote(e["src_tgt"][0]),
                clean_quote(e["src_tgt"][1]),
                clean_quote(e["description"]),
                e["keywords"],
                e["weight"],
                e["rank"],
            ]
        )
    relations_df = list_of_list_to_df(relations_section_list)

    text_units_section_list = [["id", "content"]]
    for i, t in enumerate(use_text_units):
        text_units_section_list.append([str(i), t["content"]])
    sources_df = list_of_list_to_df(text_units_section_list)

    return entities_df, relations_df, sources_df
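# The three DataFrames mirror the sections of LightRAG's local-query context
# (entities, relationships, text-unit sources); keeping them as frames lets
# the retriever below render each section as its own table in the chat UI.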
def build_graphrag(working_dir, llm_func, embedding_func):
    graphrag_func = LightRAG(
        working_dir=working_dir,
        llm_model_func=llm_func,
        embedding_func=embedding_func,
    )
    return graphrag_func
class LightRAGIndexingPipeline(GraphRAGIndexingPipeline):
    """LightRAG-specific indexing pipeline"""

    def call_graphrag_index(self, graph_id: str, docs: list[Document]):
        _, input_path = prepare_graph_index_path(graph_id)
        input_path.mkdir(parents=True, exist_ok=True)

        (
            llm_func,
            embedding_func,
            default_llm,
            default_embedding,
        ) = get_default_models_wrapper()
        print(
            f"Indexing GraphRAG with LLM {default_llm} "
            f"and Embedding {default_embedding}..."
        )

        all_docs = [
            doc.text for doc in docs if doc.metadata.get("type", "text") == "text"
        ]

        yield Document(
            channel="debug",
            text="[GraphRAG] Creating index... This can take a long time.",
        )

        # remove all .json files in the input_path directory (previous cache)
        json_files = glob.glob(f"{input_path}/*.json")
        for json_file in json_files:
            os.remove(json_file)

        # indexing
        graphrag_func = build_graphrag(
            input_path,
            llm_func=llm_func,
            embedding_func=embedding_func,
        )
        # the output must contain: Loaded graph from
        # ..input/graph_chunk_entity_relation.graphml with xxx nodes, xxx edges
        total_docs = len(all_docs)
        process_doc_count = 0
        yield Document(
            channel="debug",
            text=f"[GraphRAG] Indexed {process_doc_count} / {total_docs} documents.",
        )

        # insert documents batch by batch, reporting progress between batches
        for doc_id in range(0, len(all_docs), INDEX_BATCHSIZE):
            cur_docs = all_docs[doc_id : doc_id + INDEX_BATCHSIZE]
            graphrag_func.insert(cur_docs)
            process_doc_count += len(cur_docs)
            yield Document(
                channel="debug",
                text=(
                    f"[GraphRAG] Indexed {process_doc_count} "
                    f"/ {total_docs} documents."
                ),
            )

        yield Document(
            channel="debug",
            text="[GraphRAG] Indexing finished.",
        )

    def stream(
        self, file_paths: str | Path | list[str | Path], reindex: bool = False, **kwargs
    ) -> Generator[
        Document, None, tuple[list[str | None], list[str | None], list[Document]]
    ]:
        # kept for symmetry with the other graph pipelines; defers to the base
        file_ids, errors, all_docs = yield from super().stream(
            file_paths, reindex=reindex, **kwargs
        )
        return file_ids, errors, all_docs
class LightRAGRetrieverPipeline(BaseFileIndexRetriever):
    """LightRAG-specific retriever pipeline"""

    Index = Param(help="The SQLAlchemy Index table")
    file_ids: list[str] = []

    def _build_graph_search(self):
        file_id = self.file_ids[0]

        # retrieve the graph_id from the index
        with Session(engine) as session:
            graph_id = (
                session.query(self.Index.target_id)
                .filter(self.Index.source_id == file_id)
                .filter(self.Index.relation_type == "graph")
                .first()
            )
            graph_id = graph_id[0] if graph_id else None
            assert graph_id, f"GraphRAG index not found for file_id: {file_id}"

        _, input_path = prepare_graph_index_path(graph_id)
        input_path.mkdir(parents=True, exist_ok=True)

        llm_func, embedding_func, _, _ = get_default_models_wrapper()
        graphrag_func = build_graphrag(
            input_path,
            llm_func=llm_func,
            embedding_func=embedding_func,
        )
        query_params = QueryParam(mode="local", only_need_context=True)

        return graphrag_func, query_params

    def _to_document(self, header: str, context_text: str) -> RetrievedDocument:
        return RetrievedDocument(
            text=context_text,
            metadata={
                "file_name": header,
                "type": "table",
                "llm_trulens_score": 1.0,
            },
            score=1.0,
        )

    def format_context_records(
        self, entities, relationships, sources
    ) -> list[RetrievedDocument]:
        docs = []
        context: str = ""

        # entities (current parsing error)
        header = "<b>Entities</b>\n"
        context = entities[["entity", "description"]].to_markdown(index=False)
        docs.append(self._to_document(header, context))

        header = "\n<b>Relationships</b>\n"
        context = relationships[["source", "target", "description"]].to_markdown(
            index=False
        )
        docs.append(self._to_document(header, context))

        header = "\n<b>Sources</b>\n"
        context = ""
        for _, row in sources.iterrows():
            title, content = row["id"], row["content"]
            context += f"\n\n<h5>Source <b>#{title}</b></h5>\n"
            context += content
        docs.append(self._to_document(header, context))

        return docs

    def plot_graph(self, relationships):
        G = create_knowledge_graph(relationships)
        plot = visualize_graph(G)
        return plot

    def run(
        self,
        text: str,
    ) -> list[RetrievedDocument]:
        if not self.file_ids:
            return []

        graphrag_func, query_params = self._build_graph_search()
        entities, relationships, sources = asyncio.run(
            lightrag_build_local_query_context(graphrag_func, text, query_params)
        )
        documents = self.format_context_records(entities, relationships, sources)
        plot = self.plot_graph(relationships)

        return documents + [
            RetrievedDocument(
                text="",
                metadata={
                    "file_name": "GraphRAG",
                    "type": "plot",
                    "data": plot,
                },
            ),
        ]
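End to end, the retriever is constructed with the indexed file ids plus the Index table and queried with plain text. A minimal usage sketch (assuming `file_index` is a built `LightRAGIndex` and `"<file_id>"` has been indexed):

retriever = LightRAGRetrieverPipeline(
    file_ids=["<file_id>"],
    Index=file_index._resources["Index"],
)
docs = retriever.run("What does the corpus say about X?")
# -> entity/relationship/source tables plus one "plot" document for the graph view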

nano-graphrag pipelines module (same package)

@@ -28,7 +28,7 @@ try:
         _find_most_related_edges_from_entities,
         _find_most_related_text_unit_from_entities,
     )
-    from nano_graphrag._utils import EmbeddingFunc
+    from nano_graphrag._utils import EmbeddingFunc, compute_args_hash
 except ImportError:
     print(
@@ -46,6 +46,8 @@ logging.getLogger("nano-graphrag").setLevel(logging.INFO)
 filestorage_path = Path(settings.KH_FILESTORAGE_PATH) / "nano_graphrag"
 filestorage_path.mkdir(parents=True, exist_ok=True)

+INDEX_BATCHSIZE = 4
+

 def get_llm_func(model):
@@ -53,6 +55,7 @@ def get_llm_func(model):
     ) -> str:
         input_messages = [SystemMessage(text=system_prompt)] if system_prompt else []
+        hashing_kv = kwargs.pop("hashing_kv", None)
         if history_messages:
             for msg in history_messages:
                 if msg.get("role") == "user":
@@ -61,11 +64,21 @@ def get_llm_func(model):
                 input_messages.append(AIMessage(text=msg["content"]))

         input_messages.append(HumanMessage(text=prompt))

+        if hashing_kv is not None:
+            args_hash = compute_args_hash("model", input_messages)
+            if_cache_return = await hashing_kv.get_by_id(args_hash)
+            if if_cache_return is not None:
+                return if_cache_return["return"]
+
         output = model(input_messages).text
         print("-" * 50)
         print(output, "\n", "-" * 50)

+        if hashing_kv is not None:
+            await hashing_kv.upsert({args_hash: {"return": output, "model": "model"}})
+
         return output

     return llm_func
@@ -196,7 +209,6 @@ def build_graphrag(working_dir, llm_func, embedding_func):
         best_model_func=llm_func,
         cheap_model_func=llm_func,
         embedding_func=embedding_func,
-        embedding_func_max_async=4,
     )
     return graphrag_func
@@ -241,7 +253,23 @@ class NanoGraphRAGIndexingPipeline(GraphRAGIndexingPipeline):
         )
         # the output must contain: Loaded graph from
         # ..input/graph_chunk_entity_relation.graphml with xxx nodes, xxx edges
-        graphrag_func.insert(all_docs)
+        total_docs = len(all_docs)
+        process_doc_count = 0
+        yield Document(
+            channel="debug",
+            text=f"[GraphRAG] Indexed {process_doc_count} / {total_docs} documents.",
+        )
+
+        for doc_id in range(0, len(all_docs), INDEX_BATCHSIZE):
+            cur_docs = all_docs[doc_id : doc_id + INDEX_BATCHSIZE]
+            graphrag_func.insert(cur_docs)
+            process_doc_count += len(cur_docs)
+            yield Document(
+                channel="debug",
+                text=(
+                    f"[GraphRAG] Indexed {process_doc_count} "
+                    f"/ {total_docs} documents."
+                ),
+            )
+
         yield Document(
             channel="debug",