feat: add index batch size setting for lightrag (#720) #none
This commit is contained in:
parent
79a5f064a2
commit
2ffe374c2f
|
@ -86,7 +86,7 @@ RUN --mount=type=ssh \
|
||||||
ENV USE_LIGHTRAG=true
|
ENV USE_LIGHTRAG=true
|
||||||
RUN --mount=type=ssh \
|
RUN --mount=type=ssh \
|
||||||
--mount=type=cache,target=/root/.cache/pip \
|
--mount=type=cache,target=/root/.cache/pip \
|
||||||
pip install aioboto3 nano-vectordb ollama xxhash "lightrag-hku<=0.0.8"
|
pip install aioboto3 nano-vectordb ollama xxhash "lightrag-hku<=1.3.0"
|
||||||
|
|
||||||
RUN --mount=type=ssh \
|
RUN --mount=type=ssh \
|
||||||
--mount=type=cache,target=/root/.cache/pip \
|
--mount=type=cache,target=/root/.cache/pip \
|
||||||
|
|
|
@ -52,6 +52,10 @@ class LightRAGIndex(GraphRAGIndex):
|
||||||
pipeline.prompts = striped_settings
|
pipeline.prompts = striped_settings
|
||||||
# set collection graph id
|
# set collection graph id
|
||||||
pipeline.collection_graph_id = self._get_or_create_collection_graph_id()
|
pipeline.collection_graph_id = self._get_or_create_collection_graph_id()
|
||||||
|
# set index batch size
|
||||||
|
pipeline.index_batch_size = striped_settings.get(
|
||||||
|
"batch_size", pipeline.index_batch_size
|
||||||
|
)
|
||||||
return pipeline
|
return pipeline
|
||||||
|
|
||||||
def get_retriever_pipelines(
|
def get_retriever_pipelines(
|
||||||
|
|
|
@ -243,6 +243,7 @@ class LightRAGIndexingPipeline(GraphRAGIndexingPipeline):
|
||||||
|
|
||||||
prompts: dict[str, str] = {}
|
prompts: dict[str, str] = {}
|
||||||
collection_graph_id: str
|
collection_graph_id: str
|
||||||
|
index_batch_size: int = INDEX_BATCHSIZE
|
||||||
|
|
||||||
def store_file_id_with_graph_id(self, file_ids: list[str | None]):
|
def store_file_id_with_graph_id(self, file_ids: list[str | None]):
|
||||||
if not settings.USE_GLOBAL_GRAPHRAG:
|
if not settings.USE_GLOBAL_GRAPHRAG:
|
||||||
|
@ -283,18 +284,31 @@ class LightRAGIndexingPipeline(GraphRAGIndexingPipeline):
|
||||||
from lightrag.prompt import PROMPTS
|
from lightrag.prompt import PROMPTS
|
||||||
|
|
||||||
blacklist_keywords = ["default", "response", "process"]
|
blacklist_keywords = ["default", "response", "process"]
|
||||||
return {
|
settings_dict = {
|
||||||
prompt_name: {
|
"batch_size": {
|
||||||
"name": f"Prompt for '{prompt_name}'",
|
"name": (
|
||||||
"value": content,
|
"Index batch size " "(reduce if you have rate limit issues)"
|
||||||
"component": "text",
|
),
|
||||||
|
"value": INDEX_BATCHSIZE,
|
||||||
|
"component": "number",
|
||||||
}
|
}
|
||||||
for prompt_name, content in PROMPTS.items()
|
|
||||||
if all(
|
|
||||||
keyword not in prompt_name.lower() for keyword in blacklist_keywords
|
|
||||||
)
|
|
||||||
and isinstance(content, str)
|
|
||||||
}
|
}
|
||||||
|
settings_dict.update(
|
||||||
|
{
|
||||||
|
prompt_name: {
|
||||||
|
"name": f"Prompt for '{prompt_name}'",
|
||||||
|
"value": content,
|
||||||
|
"component": "text",
|
||||||
|
}
|
||||||
|
for prompt_name, content in PROMPTS.items()
|
||||||
|
if all(
|
||||||
|
keyword not in prompt_name.lower()
|
||||||
|
for keyword in blacklist_keywords
|
||||||
|
)
|
||||||
|
and isinstance(content, str)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return settings_dict
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
print(e)
|
print(e)
|
||||||
return {}
|
return {}
|
||||||
|
@ -359,8 +373,8 @@ class LightRAGIndexingPipeline(GraphRAGIndexingPipeline):
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
for doc_id in range(0, len(all_docs), INDEX_BATCHSIZE):
|
for doc_id in range(0, len(all_docs), self.index_batch_size):
|
||||||
cur_docs = all_docs[doc_id : doc_id + INDEX_BATCHSIZE]
|
cur_docs = all_docs[doc_id : doc_id + self.index_batch_size]
|
||||||
combined_doc = "\n".join(cur_docs)
|
combined_doc = "\n".join(cur_docs)
|
||||||
|
|
||||||
# Use insert for incremental updates
|
# Use insert for incremental updates
|
||||||
|
|
|
@ -52,6 +52,10 @@ class NanoGraphRAGIndex(GraphRAGIndex):
|
||||||
pipeline.prompts = striped_settings
|
pipeline.prompts = striped_settings
|
||||||
# set collection graph id
|
# set collection graph id
|
||||||
pipeline.collection_graph_id = self._get_or_create_collection_graph_id()
|
pipeline.collection_graph_id = self._get_or_create_collection_graph_id()
|
||||||
|
# set index batch size
|
||||||
|
pipeline.index_batch_size = striped_settings.get(
|
||||||
|
"batch_size", pipeline.index_batch_size
|
||||||
|
)
|
||||||
return pipeline
|
return pipeline
|
||||||
|
|
||||||
def get_retriever_pipelines(
|
def get_retriever_pipelines(
|
||||||
|
|
|
@ -239,6 +239,7 @@ class NanoGraphRAGIndexingPipeline(GraphRAGIndexingPipeline):
|
||||||
|
|
||||||
prompts: dict[str, str] = {}
|
prompts: dict[str, str] = {}
|
||||||
collection_graph_id: str
|
collection_graph_id: str
|
||||||
|
index_batch_size: int = INDEX_BATCHSIZE
|
||||||
|
|
||||||
def store_file_id_with_graph_id(self, file_ids: list[str | None]):
|
def store_file_id_with_graph_id(self, file_ids: list[str | None]):
|
||||||
if not settings.USE_GLOBAL_GRAPHRAG:
|
if not settings.USE_GLOBAL_GRAPHRAG:
|
||||||
|
@ -279,18 +280,31 @@ class NanoGraphRAGIndexingPipeline(GraphRAGIndexingPipeline):
|
||||||
from nano_graphrag.prompt import PROMPTS
|
from nano_graphrag.prompt import PROMPTS
|
||||||
|
|
||||||
blacklist_keywords = ["default", "response", "process"]
|
blacklist_keywords = ["default", "response", "process"]
|
||||||
return {
|
settings_dict = {
|
||||||
prompt_name: {
|
"batch_size": {
|
||||||
"name": f"Prompt for '{prompt_name}'",
|
"name": (
|
||||||
"value": content,
|
"Index batch size " "(reduce if you have rate limit issues)"
|
||||||
"component": "text",
|
),
|
||||||
|
"value": INDEX_BATCHSIZE,
|
||||||
|
"component": "number",
|
||||||
}
|
}
|
||||||
for prompt_name, content in PROMPTS.items()
|
|
||||||
if all(
|
|
||||||
keyword not in prompt_name.lower() for keyword in blacklist_keywords
|
|
||||||
)
|
|
||||||
and isinstance(content, str)
|
|
||||||
}
|
}
|
||||||
|
settings_dict.update(
|
||||||
|
{
|
||||||
|
prompt_name: {
|
||||||
|
"name": f"Prompt for '{prompt_name}'",
|
||||||
|
"value": content,
|
||||||
|
"component": "text",
|
||||||
|
}
|
||||||
|
for prompt_name, content in PROMPTS.items()
|
||||||
|
if all(
|
||||||
|
keyword not in prompt_name.lower()
|
||||||
|
for keyword in blacklist_keywords
|
||||||
|
)
|
||||||
|
and isinstance(content, str)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return settings_dict
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
print(e)
|
print(e)
|
||||||
return {}
|
return {}
|
||||||
|
@ -355,8 +369,8 @@ class NanoGraphRAGIndexingPipeline(GraphRAGIndexingPipeline):
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
for doc_id in range(0, len(all_docs), INDEX_BATCHSIZE):
|
for doc_id in range(0, len(all_docs), self.index_batch_size):
|
||||||
cur_docs = all_docs[doc_id : doc_id + INDEX_BATCHSIZE]
|
cur_docs = all_docs[doc_id : doc_id + self.index_batch_size]
|
||||||
combined_doc = "\n".join(cur_docs)
|
combined_doc = "\n".join(cur_docs)
|
||||||
|
|
||||||
# Use insert for incremental updates
|
# Use insert for incremental updates
|
||||||
|
|
Loading…
Reference in New Issue
Block a user