Move LLM-related components into LLM module (#74)
* Move splitters into the indexing module
* Rename the post_processing module to parsers
* Migrate LLM-specific composite pipelines into the llms module

This change moves the `splitters` module into the `indexing` module. The `indexing` module will be created shortly to house indexing-related components.

It also renames the `post_processing` module to `parsers`. "Post-processing" is a generic term that conveys very little information; in the future, other extractors, such as a metadata extractor, will be added to the `parsers` module.

Finally, it migrates the composite pipelines into the `llms` module. These components heavily assume that their internal nodes are LLM-specific, so housing them in `llms` makes them more discoverable and simplifies the code base structure.
parent 9945afdf6f
commit f8b8d86d4e
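For downstream code, the reorganization is chiefly an import-path change. A minimal before/after sketch, assuming the package is imported as `kotaemon` (the name used in the docstrings below) and that the extractor previously lived directly under `post_processing`:

# Old (assumed pre-commit path):
# from kotaemon.post_processing import RegexExtractor
# New, after this commit:
from kotaemon.parsers import RegexExtractor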
knowledgehub/llms/__init__.py
@@ -1,8 +1,9 @@
 from langchain.schema.messages import AIMessage, SystemMessage

-from .chats import AzureChatOpenAI, ChatLLM
-from .chats.base import BaseMessage, HumanMessage
+from .branching import GatedBranchingPipeline, SimpleBranchingPipeline
+from .chats import AzureChatOpenAI, BaseMessage, ChatLLM, HumanMessage
 from .completions import LLM, AzureOpenAI, OpenAI
+from .linear import GatedLinearPipeline, SimpleLinearPipeline
 from .prompts import BasePromptComponent, PromptTemplate

 __all__ = [
@@ -20,4 +21,9 @@ __all__ = [
     # prompt-specific components
     "BasePromptComponent",
     "PromptTemplate",
+    # strategies
+    "SimpleLinearPipeline",
+    "GatedLinearPipeline",
+    "SimpleBranchingPipeline",
+    "GatedBranchingPipeline",
 ]
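With these re-exports in place, all four composite strategies are importable from the module root; for example:

from kotaemon.llms import (
    GatedBranchingPipeline,
    GatedLinearPipeline,
    SimpleBranchingPipeline,
    SimpleLinearPipeline,
)

(`kotaemon` is the import name used in the docstrings below; the repository directory is `knowledgehub`.)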
185  knowledgehub/llms/branching.py  Normal file
@@ -0,0 +1,185 @@
from typing import List, Optional

from theflow import Param

from ..base import BaseComponent, Document
from .linear import GatedLinearPipeline


class SimpleBranchingPipeline(BaseComponent):
    """
    A simple branching pipeline for executing multiple branches.

    Attributes:
        branches (List[BaseComponent]): The list of branches to be executed.

    Example Usage:
        from kotaemon.llms import (
            AzureChatOpenAI,
            BasePromptComponent,
            GatedLinearPipeline,
            SimpleBranchingPipeline,
        )
        from kotaemon.parsers import RegexExtractor

        def identity(x):
            return x

        pipeline = SimpleBranchingPipeline()
        llm = AzureChatOpenAI(
            openai_api_base="your openai api base",
            openai_api_key="your openai api key",
            openai_api_version="your openai api version",
            deployment_name="dummy-q2-gpt35",
            temperature=0,
            request_timeout=600,
        )

        for i in range(3):
            pipeline.add_branch(
                GatedLinearPipeline(
                    prompt=BasePromptComponent(template=f"what is {i} in Japanese ?"),
                    condition=RegexExtractor(pattern=f"{i}"),
                    llm=llm,
                    post_processor=identity,
                )
            )
        print(pipeline(condition_text="1"))
        print(pipeline(condition_text="2"))
        print(pipeline(condition_text="12"))
    """

    branches: List[BaseComponent] = Param(default_callback=lambda *_: [])

    def add_branch(self, component: BaseComponent):
        """
        Add a new branch to the pipeline.

        Args:
            component (BaseComponent): The branch component to be added.
        """
        self.branches.append(component)

    def run(self, **prompt_kwargs):
        """
        Execute the pipeline by running each branch and return the outputs as a list.

        Args:
            **prompt_kwargs: Keyword arguments for the branches.

        Returns:
            List: The outputs of each branch as a list.
        """
        output = []
        for i, branch in enumerate(self.branches):
            self._prepare_child(branch, name=f"branch-{i}")
            output.append(branch(**prompt_kwargs))

        return output


class GatedBranchingPipeline(SimpleBranchingPipeline):
    """
    A simple gated branching pipeline for executing multiple branches based on a
    condition.

    This class extends the SimpleBranchingPipeline class and adds the ability to
    execute the branches until a branch returns a non-empty output based on a
    condition.

    Attributes:
        branches (List[BaseComponent]): The list of branches to be executed.

    Example Usage:
        from kotaemon.llms import (
            AzureChatOpenAI,
            BasePromptComponent,
            GatedBranchingPipeline,
            GatedLinearPipeline,
        )
        from kotaemon.parsers import RegexExtractor

        def identity(x):
            return x

        pipeline = GatedBranchingPipeline()
        llm = AzureChatOpenAI(
            openai_api_base="your openai api base",
            openai_api_key="your openai api key",
            openai_api_version="your openai api version",
            deployment_name="dummy-q2-gpt35",
            temperature=0,
            request_timeout=600,
        )

        for i in range(3):
            pipeline.add_branch(
                GatedLinearPipeline(
                    prompt=BasePromptComponent(template=f"what is {i} in Japanese ?"),
                    condition=RegexExtractor(pattern=f"{i}"),
                    llm=llm,
                    post_processor=identity,
                )
            )
        print(pipeline(condition_text="1"))
        print(pipeline(condition_text="2"))
    """

    def run(self, *, condition_text: Optional[str] = None, **prompt_kwargs):
        """
        Execute the pipeline by running each branch and return the output of the first
        branch that returns a non-empty output based on the provided condition.

        Args:
            condition_text (str): The condition text to evaluate for each branch.
                Defaults to None.
            **prompt_kwargs: Keyword arguments for the branches.

        Returns:
            Union[OutputType, None]: The output of the first branch that satisfies the
                condition, or None if no branch satisfies the condition.

        Raises:
            ValueError: If condition_text is None.
        """
        if condition_text is None:
            raise ValueError("`condition_text` must be provided.")

        for i, branch in enumerate(self.branches):
            self._prepare_child(branch, name=f"branch-{i}")
            output = branch(condition_text=condition_text, **prompt_kwargs)
            if output:
                return output

        return Document(None)


if __name__ == "__main__":
    import dotenv

    from kotaemon.llms import BasePromptComponent
    from kotaemon.llms.chats.openai import AzureChatOpenAI
    from kotaemon.parsers import RegexExtractor

    def identity(x):
        return x

    secrets = dotenv.dotenv_values(".env")

    pipeline = GatedBranchingPipeline()
    llm = AzureChatOpenAI(
        openai_api_base=secrets.get("OPENAI_API_BASE", ""),
        openai_api_key=secrets.get("OPENAI_API_KEY", ""),
        openai_api_version=secrets.get("OPENAI_API_VERSION", ""),
        deployment_name="dummy-q2-gpt35",
        temperature=0,
        request_timeout=600,
    )

    for i in range(3):
        pipeline.add_branch(
            GatedLinearPipeline(
                prompt=BasePromptComponent(template=f"what is {i} in Japanese ?"),
                condition=RegexExtractor(pattern=f"{i}"),
                llm=llm,
                post_processor=identity,
            )
        )
    pipeline(condition_text="1")
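The gating logic above reduces to "first non-empty branch output wins". A self-contained sketch of that control flow with plain callables, independent of the kotaemon classes (all names here are illustrative):

from typing import Callable, List, Optional

def run_gated(branches: List[Callable[[str], Optional[str]]], condition_text: str):
    # Mirrors GatedBranchingPipeline.run: return the first non-empty output.
    for branch in branches:
        output = branch(condition_text)
        if output:
            return output
    return None

# Each branch fires only when its digit appears in the condition text.
branches = [
    lambda text, i=i: f"branch-{i} fired" if str(i) in text else None
    for i in range(3)
]
print(run_gated(branches, "1"))   # branch-1 fired
print(run_gated(branches, "12"))  # branch-1 fired (first matching branch wins)
print(run_gated(branches, "9"))   # None

Note that, unlike this sketch, the pipeline itself returns `Document(None)` rather than `None` when no branch matches.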
knowledgehub/llms/chats/__init__.py
@@ -1,4 +1,4 @@
-from .base import ChatLLM
+from .base import BaseMessage, ChatLLM, HumanMessage
 from .openai import AzureChatOpenAI

-__all__ = ["ChatLLM", "AzureChatOpenAI"]
+__all__ = ["ChatLLM", "AzureChatOpenAI", "BaseMessage", "HumanMessage"]
153  knowledgehub/llms/linear.py  Normal file
@@ -0,0 +1,153 @@
from typing import Any, Callable, Optional, Union

from ..base import BaseComponent
from ..base.schema import Document, IO_Type
from .chats import ChatLLM
from .completions import LLM
from .prompts import BasePromptComponent


class SimpleLinearPipeline(BaseComponent):
    """
    A simple pipeline for running a function with a prompt, a language model, and an
    optional post-processor.

    Attributes:
        prompt (BasePromptComponent): The prompt component used to generate the initial
            input.
        llm (Union[ChatLLM, LLM]): The language model component used to generate the
            output.
        post_processor (Union[BaseComponent, Callable[[IO_Type], IO_Type]]): An
            optional post-processor component or function.

    Example Usage:
        from kotaemon.llms.chats.openai import AzureChatOpenAI
        from kotaemon.llms import BasePromptComponent, SimpleLinearPipeline

        def identity(x):
            return x

        llm = AzureChatOpenAI(
            openai_api_base="your openai api base",
            openai_api_key="your openai api key",
            openai_api_version="your openai api version",
            deployment_name="dummy-q2-gpt35",
            temperature=0,
            request_timeout=600,
        )

        pipeline = SimpleLinearPipeline(
            prompt=BasePromptComponent(template="what is {word} in Japanese ?"),
            llm=llm,
            post_processor=identity,
        )
        print(pipeline(word="lone"))
    """

    prompt: BasePromptComponent
    llm: Union[ChatLLM, LLM]
    post_processor: Union[BaseComponent, Callable[[IO_Type], IO_Type]]

    def run(
        self,
        *,
        llm_kwargs: Optional[dict] = {},
        post_processor_kwargs: Optional[dict] = {},
        **prompt_kwargs,
    ):
        """
        Run the function with the given arguments and return the final output as a
        Document object.

        Args:
            llm_kwargs (dict): Keyword arguments for the llm call.
            post_processor_kwargs (dict): Keyword arguments for the post_processor.
            **prompt_kwargs: Keyword arguments for populating the prompt.

        Returns:
            Document: The final output of the function as a Document object.
        """
        prompt = self.prompt(**prompt_kwargs)
        llm_output = self.llm(prompt.text, **llm_kwargs)
        if self.post_processor is not None:
            final_output = self.post_processor(llm_output, **post_processor_kwargs)[0]
        else:
            final_output = llm_output

        return Document(final_output)


class GatedLinearPipeline(SimpleLinearPipeline):
    """
    A pipeline that extends the SimpleLinearPipeline class and adds a condition
    attribute.

    Attributes:
        condition (Callable[[IO_Type], Any]): A callable function that represents the
            condition.

    Example Usage:
        from kotaemon.llms.chats.openai import AzureChatOpenAI
        from kotaemon.llms import BasePromptComponent, GatedLinearPipeline
        from kotaemon.parsers import RegexExtractor

        def identity(x):
            return x

        llm = AzureChatOpenAI(
            openai_api_base="your openai api base",
            openai_api_key="your openai api key",
            openai_api_version="your openai api version",
            deployment_name="dummy-q2-gpt35",
            temperature=0,
            request_timeout=600,
        )

        pipeline = GatedLinearPipeline(
            prompt=BasePromptComponent(template="what is {word} in Japanese ?"),
            condition=RegexExtractor(pattern="some pattern"),
            llm=llm,
            post_processor=identity,
        )
        print(pipeline(condition_text="some pattern", word="lone"))
        print(pipeline(condition_text="other pattern", word="lone"))
    """

    condition: Callable[[IO_Type], Any]

    def run(
        self,
        *,
        condition_text: Optional[str] = None,
        llm_kwargs: Optional[dict] = {},
        post_processor_kwargs: Optional[dict] = {},
        **prompt_kwargs,
    ) -> Document:
        """
        Run the pipeline with the given arguments and return the final output as a
        Document object.

        Args:
            condition_text (str): The condition text to evaluate. Defaults to None.
            llm_kwargs (dict): Additional keyword arguments for the language model
                call.
            post_processor_kwargs (dict): Additional keyword arguments for the
                post-processor.
            **prompt_kwargs: Keyword arguments for populating the prompt.

        Returns:
            Document: The final output of the pipeline as a Document object.

        Raises:
            ValueError: If condition_text is None.
        """
        if condition_text is None:
            raise ValueError("`condition_text` must be provided")

        if self.condition(condition_text)[0]:
            return super().run(
                llm_kwargs=llm_kwargs,
                post_processor_kwargs=post_processor_kwargs,
                **prompt_kwargs,
            )

        return Document(None)
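SimpleLinearPipeline is a prompt -> llm -> post-processor chain, and GatedLinearPipeline prepends a check on `self.condition(condition_text)[0]`. A minimal stand-alone sketch of that data flow with stand-in functions (not the kotaemon API):

import re

def fake_llm(prompt_text: str) -> str:
    # Stand-in for the ChatLLM/LLM call.
    return f"echo: {prompt_text}"

def run_gated_linear(template: str, pattern: str, condition_text: str, **prompt_kwargs):
    # Gate first, mirroring GatedLinearPipeline.run's truthiness check.
    if not re.search(pattern, condition_text):
        return None  # the real pipeline returns Document(None) here
    prompt_text = template.format(**prompt_kwargs)  # BasePromptComponent step
    return fake_llm(prompt_text)                    # llm + (identity) post-processing

print(run_gated_linear("what is {word} in Japanese ?", "1", "1", word="one"))
print(run_gated_linear("what is {word} in Japanese ?", "1", "2", word="two"))  # None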