feat: allow use of a customized GraphRAG settings.yaml (#387) bump:patch

* allow using a customized GraphRAG settings.yaml

* adjust import style

* fix typo

* add a reference to the original GraphRAG documentation

* feat: allow using a customized GraphRAG settings.yaml (#387)

---------

Co-authored-by: Chen, Ron Gang <git@git.com>
ronchengang 2024-10-14 22:18:34 +08:00 committed by GitHub
parent f0f3b4b23e
commit 8188760f32
3 changed files with 193 additions and 2 deletions

@@ -25,6 +25,9 @@ GRAPHRAG_API_KEY=<YOUR_OPENAI_KEY>
 GRAPHRAG_LLM_MODEL=gpt-4o-mini
 GRAPHRAG_EMBEDDING_MODEL=text-embedding-3-small
+# set to true if you want to use a customized GraphRAG config file
+USE_CUSTOMIZED_GRAPHRAG_SETTING=false
 # settings for Azure DI
 AZURE_DI_ENDPOINT=
 AZURE_DI_CREDENTIAL=
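
A minimal sketch of how this flag can be read on the Python side: python-decouple returns the raw string from .env, so the value is normalized and compared against "true". The helper name below is illustrative, not part of this commit.

# Sketch only: reading the new flag with python-decouple (helper name is hypothetical).
from decouple import config

def use_customized_graphrag_setting() -> bool:
    # decouple returns a string, so normalize before comparing
    return config("USE_CUSTOMIZED_GRAPHRAG_SETTING", default="false").lower() == "true"

print(use_customized_graphrag_setting())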

@@ -1,4 +1,5 @@
 import os
+import shutil
 import subprocess
 from pathlib import Path
 from shutil import rmtree
@@ -7,6 +8,8 @@ from uuid import uuid4
 import pandas as pd
 import tiktoken
+import yaml
+from decouple import config
 from ktem.db.models import engine
 from sqlalchemy.orm import Session
 from theflow.settings import settings
@@ -116,6 +119,16 @@ class GraphRAGIndexingPipeline(IndexDocumentPipeline):
         print(result.stdout)
         command = command[:-1]
 
+        # copy the customized GraphRAG config file if it exists
+        if config("USE_CUSTOMIZED_GRAPHRAG_SETTING", default="false").lower() == "true":
+            setting_file_path = os.path.join(os.getcwd(), "settings.yaml.example")
+            destination_file_path = os.path.join(input_path, "settings.yaml")
+            try:
+                shutil.copy(setting_file_path, destination_file_path)
+            except (OSError, shutil.Error):
+                # handle the case where the file copy fails
+                print("Failed to copy the customized GraphRAG config file.")
+
         # Run the command and stream stdout
         with subprocess.Popen(command, stdout=subprocess.PIPE, text=True) as process:
             if process.stdout:
@@ -221,12 +234,28 @@ class GraphRAGRetrieverPipeline(BaseFileIndexRetriever):
         text_unit_df = pd.read_parquet(f"{INPUT_DIR}/{TEXT_UNIT_TABLE}.parquet")
         text_units = read_indexer_text_units(text_unit_df)
 
+        # initialize default settings
         embedding_model = os.getenv(
             "GRAPHRAG_EMBEDDING_MODEL", "text-embedding-3-small"
         )
+        embedding_api_key = os.getenv("GRAPHRAG_API_KEY")
+        embedding_api_base = None
+
+        # use the customized GraphRAG settings if the flag is set
+        if config("USE_CUSTOMIZED_GRAPHRAG_SETTING", default="false").lower() == "true":
+            settings_yaml_path = Path(root_path) / "settings.yaml"
+            with open(settings_yaml_path, "r") as f:
+                settings = yaml.safe_load(f)
+            if settings["embeddings"]["llm"]["model"]:
+                embedding_model = settings["embeddings"]["llm"]["model"]
+            if settings["embeddings"]["llm"]["api_key"]:
+                embedding_api_key = settings["embeddings"]["llm"]["api_key"]
+            if settings["embeddings"]["llm"]["api_base"]:
+                embedding_api_base = settings["embeddings"]["llm"]["api_base"]
+
         text_embedder = OpenAIEmbedding(
-            api_key=os.getenv("GRAPHRAG_API_KEY"),
-            api_base=None,
+            api_key=embedding_api_key,
+            api_base=embedding_api_base,
             api_type=OpenaiApiType.OpenAI,
             model=embedding_model,
             deployment_name=embedding_model,
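
The retrieval side above reads the embedding model, API key, and API base from settings.yaml when the flag is on, otherwise it keeps the environment defaults. A defensive sketch of the same lookup using .get(), so a settings.yaml without an embeddings.llm block falls back to the env values (this helper is illustrative, not the code in this diff):

# Sketch only: read embedding overrides from settings.yaml with env fallbacks.
import os
from pathlib import Path

import yaml

def load_embedding_overrides(root_path: str) -> dict:
    defaults = {
        "model": os.getenv("GRAPHRAG_EMBEDDING_MODEL", "text-embedding-3-small"),
        "api_key": os.getenv("GRAPHRAG_API_KEY"),
        "api_base": None,
    }
    settings_path = Path(root_path) / "settings.yaml"
    if not settings_path.exists():
        return defaults
    with open(settings_path, "r") as f:
        settings = yaml.safe_load(f) or {}
    llm = (settings.get("embeddings") or {}).get("llm") or {}
    # only override values that are actually present in the YAML
    return {key: llm.get(key) or default for key, default in defaults.items()}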

settings.yaml.example (new file, 159 lines)

@ -0,0 +1,159 @@
# This is a sample GraphRAG settings.yaml file that lets users run the GraphRAG index process with their own customized parameters.
# The parameters in this file only take effect when USE_CUSTOMIZED_GRAPHRAG_SETTING is set to true in the .env file.
# For a comprehensive understanding of GraphRAG parameters, please refer to: https://microsoft.github.io/graphrag/config/json_yaml/.
encoding_model: cl100k_base
skip_workflows: []
llm:
  api_key: ${GRAPHRAG_API_KEY}
  type: openai_chat # or azure_openai_chat
  api_base: http://127.0.0.1:11434/v1
  model: qwen2
  model_supports_json: true # recommended if this is available for your model.
  # max_tokens: 4000
  request_timeout: 1800.0
  # api_base: https://<instance>.openai.azure.com
  # api_version: 2024-02-15-preview
  # organization: <organization_id>
  # deployment_name: <azure_model_deployment_name>
  # tokens_per_minute: 150_000 # set a leaky bucket throttle
  # requests_per_minute: 10_000 # set a leaky bucket throttle
  # max_retries: 10
  # max_retry_wait: 10.0
  # sleep_on_rate_limit_recommendation: true # whether to sleep when azure suggests wait-times
  concurrent_requests: 5 # the number of parallel inflight requests that may be made
  # temperature: 0 # temperature for sampling
  # top_p: 1 # top-p sampling
  # n: 1 # Number of completions to generate

parallelization:
  stagger: 0.3
  # num_threads: 50 # the number of threads to use for parallel processing

async_mode: threaded # or asyncio

embeddings:
  ## parallelization: override the global parallelization settings for embeddings
  async_mode: threaded # or asyncio
  # target: required # or all
  # batch_size: 16 # the number of documents to send in a single request
  # batch_max_tokens: 8191 # the maximum number of tokens to send in a single request
  llm:
    api_base: http://localhost:11434/v1
    api_key: ${GRAPHRAG_API_KEY}
    model: nomic-embed-text
    type: openai_embedding
    # api_base: https://<instance>.openai.azure.com
    # api_version: 2024-02-15-preview
    # organization: <organization_id>
    # deployment_name: <azure_model_deployment_name>
    # tokens_per_minute: 150_000 # set a leaky bucket throttle
    # requests_per_minute: 10_000 # set a leaky bucket throttle
    # max_retries: 10
    # max_retry_wait: 10.0
    # sleep_on_rate_limit_recommendation: true # whether to sleep when azure suggests wait-times
    # concurrent_requests: 25 # the number of parallel inflight requests that may be made

chunks:
  size: 1200
  overlap: 100
  group_by_columns: [id] # by default, we don't allow chunks to cross documents

input:
  type: file # or blob
  file_type: text # or csv
  base_dir: "input"
  file_encoding: utf-8
  file_pattern: ".*\\.txt$"

cache:
  type: file # or blob
  base_dir: "cache"
  # connection_string: <azure_blob_storage_connection_string>
  # container_name: <azure_blob_storage_container_name>

storage:
  type: file # or blob
  base_dir: "output"
  # connection_string: <azure_blob_storage_connection_string>
  # container_name: <azure_blob_storage_container_name>

reporting:
  type: file # or console, blob
  base_dir: "output"
  # connection_string: <azure_blob_storage_connection_string>
  # container_name: <azure_blob_storage_container_name>

entity_extraction:
  ## strategy: fully override the entity extraction strategy.
  ##   type: one of graph_intelligence, graph_intelligence_json and nltk
  ## llm: override the global llm settings for this task
  ## parallelization: override the global parallelization settings for this task
  ## async_mode: override the global async_mode settings for this task
  prompt: "prompts/entity_extraction.txt"
  entity_types: [organization,person,geo,event]
  max_gleanings: 1

summarize_descriptions:
  ## llm: override the global llm settings for this task
  ## parallelization: override the global parallelization settings for this task
  ## async_mode: override the global async_mode settings for this task
  prompt: "prompts/summarize_descriptions.txt"
  max_length: 500

claim_extraction:
  ## llm: override the global llm settings for this task
  ## parallelization: override the global parallelization settings for this task
  ## async_mode: override the global async_mode settings for this task
  # enabled: true
  prompt: "prompts/claim_extraction.txt"
  description: "Any claims or facts that could be relevant to information discovery."
  max_gleanings: 1

community_reports:
  ## llm: override the global llm settings for this task
  ## parallelization: override the global parallelization settings for this task
  ## async_mode: override the global async_mode settings for this task
  prompt: "prompts/community_report.txt"
  max_length: 2000
  max_input_length: 8000

cluster_graph:
  max_cluster_size: 10

embed_graph:
  enabled: false # if true, will generate node2vec embeddings for nodes
  # num_walks: 10
  # walk_length: 40
  # window_size: 2
  # iterations: 3
  # random_seed: 597832

umap:
  enabled: false # if true, will generate UMAP embeddings for nodes

snapshots:
  graphml: false
  raw_entities: false
  top_level_nodes: false

local_search:
  # text_unit_prop: 0.5
  # community_prop: 0.1
  # conversation_history_max_turns: 5
  # top_k_mapped_entities: 10
  # top_k_relationships: 10
  # llm_temperature: 0 # temperature for sampling
  # llm_top_p: 1 # top-p sampling
  # llm_n: 1 # Number of completions to generate
  # max_tokens: 12000

global_search:
  # llm_temperature: 0 # temperature for sampling
  # llm_top_p: 1 # top-p sampling
  # llm_n: 1 # Number of completions to generate
  # max_tokens: 12000
  # data_max_tokens: 12000
  # map_max_tokens: 1000
  # reduce_max_tokens: 2000
  # concurrency: 32
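
As a quick sanity check before turning the flag on, the example file can be parsed and the fields the pipelines read can be printed. A small sketch, assuming the file sits in the current working directory:

# Sketch only: parse settings.yaml.example and show the fields used above.
import yaml

with open("settings.yaml.example", "r") as f:
    settings = yaml.safe_load(f)

llm = settings["embeddings"]["llm"]
print("chat model:", settings["llm"].get("model"))
print("embedding model:", llm.get("model"))
print("embedding api_base:", llm.get("api_base"))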