feat: allow using a customized GraphRAG settings.yaml (#387) bump:patch

* allow using a customized GraphRAG settings.yaml

* adjust import style

* fix typo

* Added a reference to the original GraphRAG documentation.

* feat: allow using a customized GraphRAG settings.yaml (#387)

---------

Co-authored-by: Chen, Ron Gang <git@git.com>
This commit is contained in:
ronchengang
2024-10-14 22:18:34 +08:00
committed by GitHub
parent f0f3b4b23e
commit 8188760f32
3 changed files with 193 additions and 2 deletions

View File

@@ -1,4 +1,5 @@
import os
import shutil
import subprocess
from pathlib import Path
from shutil import rmtree
@@ -7,6 +8,8 @@ from uuid import uuid4
import pandas as pd
import tiktoken
import yaml
from decouple import config
from ktem.db.models import engine
from sqlalchemy.orm import Session
from theflow.settings import settings
@@ -116,6 +119,16 @@ class GraphRAGIndexingPipeline(IndexDocumentPipeline):
print(result.stdout)
command = command[:-1]
# copy the customized GraphRAG config file when USE_CUSTOMIZED_GRAPHRAG_SETTING is "true"
if config("USE_CUSTOMIZED_GRAPHRAG_SETTING", default="value").lower() == "true":
setting_file_path = os.path.join(os.getcwd(), "settings.yaml.example")
destination_file_path = os.path.join(input_path, "settings.yaml")
try:
shutil.copy(setting_file_path, destination_file_path)
except shutil.Error:
# Handle the error if the file copy fails
print("failed to copy customized GraphRAG config file. ")
# Run the command and stream stdout
with subprocess.Popen(command, stdout=subprocess.PIPE, text=True) as process:
if process.stdout:
@@ -221,12 +234,28 @@ class GraphRAGRetrieverPipeline(BaseFileIndexRetriever):
text_unit_df = pd.read_parquet(f"{INPUT_DIR}/{TEXT_UNIT_TABLE}.parquet")
text_units = read_indexer_text_units(text_unit_df)
# initialize default settings
embedding_model = os.getenv(
"GRAPHRAG_EMBEDDING_MODEL", "text-embedding-3-small"
)
embedding_api_key = os.getenv("GRAPHRAG_API_KEY")
embedding_api_base = None
# use customized GraphRAG settings if the flag is set
if config("USE_CUSTOMIZED_GRAPHRAG_SETTING", default="value").lower() == "true":
settings_yaml_path = Path(root_path) / "settings.yaml"
with open(settings_yaml_path, "r") as f:
settings = yaml.safe_load(f)
if settings["embeddings"]["llm"]["model"]:
embedding_model = settings["embeddings"]["llm"]["model"]
if settings["embeddings"]["llm"]["api_key"]:
embedding_api_key = settings["embeddings"]["llm"]["api_key"]
if settings["embeddings"]["llm"]["api_base"]:
embedding_api_base = settings["embeddings"]["llm"]["api_base"]
text_embedder = OpenAIEmbedding(
api_key=os.getenv("GRAPHRAG_API_KEY"),
api_base=None,
api_key=embedding_api_key,
api_base=embedding_api_base,
api_type=OpenaiApiType.OpenAI,
model=embedding_model,
deployment_name=embedding_model,