[AUR-395] Adopt Example1 disclaimer pipeline (#42)
* Adopt Example1 disclaimer pipeline * Update Document class * Add composite components * Modify Extractor behaviours
This commit is contained in:
141
tests/test_composite.py
Normal file
141
tests/test_composite.py
Normal file
@@ -0,0 +1,141 @@
|
||||
import pytest
|
||||
|
||||
from kotaemon.composite import (
|
||||
GatedBranchingPipeline,
|
||||
GatedLinearPipeline,
|
||||
SimpleBranchingPipeline,
|
||||
SimpleLinearPipeline,
|
||||
)
|
||||
from kotaemon.llms.chats.openai import AzureChatOpenAI
|
||||
from kotaemon.post_processing.extractor import RegexExtractor
|
||||
from kotaemon.prompt.base import BasePromptComponent
|
||||
|
||||
|
||||
@pytest.fixture
def mock_llm():
    """Provide an ``AzureChatOpenAI`` instance configured with placeholder values.

    The endpoint, key and API version are literal placeholder strings, and
    the deployment name is a dummy one.
    """
    llm_settings = {
        "openai_api_base": "OPENAI_API_BASE",
        "openai_api_key": "OPENAI_API_KEY",
        "openai_api_version": "OPENAI_API_VERSION",
        "deployment_name": "dummy-q2-gpt35",
        "temperature": 0,
        "request_timeout": 600,
    }
    return AzureChatOpenAI(**llm_settings)
|
||||
|
||||
|
||||
@pytest.fixture
def mock_post_processor():
    """Provide a ``RegexExtractor`` whose pattern matches runs of digits."""
    digit_extractor = RegexExtractor(pattern=r"\d+")
    return digit_extractor
|
||||
|
||||
|
||||
@pytest.fixture
def mock_prompt():
    """Provide a prompt component whose template has a single ``value`` slot."""
    prompt_component = BasePromptComponent(template="Test prompt {value}")
    return prompt_component
|
||||
|
||||
|
||||
@pytest.fixture
def mock_simple_linear_pipeline(mock_prompt, mock_llm, mock_post_processor):
    """Assemble a ``SimpleLinearPipeline`` from the prompt, LLM and extractor fixtures."""
    linear_pipeline = SimpleLinearPipeline(
        prompt=mock_prompt,
        llm=mock_llm,
        post_processor=mock_post_processor,
    )
    return linear_pipeline
|
||||
|
||||
|
||||
@pytest.fixture
def mock_gated_linear_pipeline_positive(mock_prompt, mock_llm, mock_post_processor):
    """Assemble a ``GatedLinearPipeline`` whose condition pattern is ``positive``."""
    positive_gate = RegexExtractor(pattern="positive")
    gated_pipeline = GatedLinearPipeline(
        prompt=mock_prompt,
        llm=mock_llm,
        post_processor=mock_post_processor,
        condition=positive_gate,
    )
    return gated_pipeline
|
||||
|
||||
|
||||
@pytest.fixture
def mock_gated_linear_pipeline_negative(mock_prompt, mock_llm, mock_post_processor):
    """Assemble a ``GatedLinearPipeline`` whose condition pattern is ``negative``."""
    negative_gate = RegexExtractor(pattern="negative")
    gated_pipeline = GatedLinearPipeline(
        prompt=mock_prompt,
        llm=mock_llm,
        post_processor=mock_post_processor,
        condition=negative_gate,
    )
    return gated_pipeline
|
||||
|
||||
|
||||
def test_simple_linear_pipeline_run(mocker, mock_simple_linear_pipeline):
    """Run the simple linear pipeline once against a patched LLM.

    ``AzureChatOpenAI.run`` is replaced with a stub returning a fixed string;
    the test expects the pipeline output text to be ``"123"`` and the stub to
    have been invoked exactly once.
    """
    patched_run = mocker.patch.object(
        AzureChatOpenAI,
        "run",
        return_value="This is a test 123",
    )

    pipeline_output = mock_simple_linear_pipeline.run(value="abc")

    assert pipeline_output.text == "123"
    assert patched_run.call_count == 1
|
||||
|
||||
|
||||
def test_gated_linear_pipeline_run_positive(
    mocker, mock_gated_linear_pipeline_positive
):
    """Run the gated pipeline with a condition text containing ``positive``.

    The test expects the pipeline output text to be ``"123"`` and the patched
    ``AzureChatOpenAI.run`` to have been invoked exactly once.
    """
    patched_run = mocker.patch.object(
        AzureChatOpenAI,
        "run",
        return_value="This is a test 123.",
    )

    run_kwargs = {"value": "abc", "condition_text": "positive condition"}
    pipeline_output = mock_gated_linear_pipeline_positive.run(**run_kwargs)

    assert pipeline_output.text == "123"
    assert patched_run.call_count == 1
|
||||
|
||||
|
||||
def test_gated_linear_pipeline_run_negative(
    mocker, mock_gated_linear_pipeline_positive
):
    """Run the ``positive``-gated pipeline with a non-matching condition text.

    The condition text here is ``"negative condition"`` while the fixture's
    condition pattern is ``positive``. The test expects a result whose
    ``content`` is ``None`` and no call to the patched ``AzureChatOpenAI.run``.
    """
    patched_run = mocker.patch.object(
        AzureChatOpenAI,
        "run",
        return_value="This is a test 123.",
    )

    run_kwargs = {"value": "abc", "condition_text": "negative condition"}
    pipeline_output = mock_gated_linear_pipeline_positive.run(**run_kwargs)

    assert pipeline_output.content is None
    assert patched_run.call_count == 0
|
||||
|
||||
|
||||
def test_simple_branching_pipeline_run(mocker, mock_simple_linear_pipeline):
    """Run a three-branch ``SimpleBranchingPipeline`` against a patched LLM.

    The patched ``AzureChatOpenAI.run`` yields a different string on each
    call. The test expects three results with texts ``"123"``, ``""`` and
    ``"456"``, and three calls to the stub in total.
    """
    llm_replies = [
        "This is a test 123.",
        "a quick brown fox",
        "jumps over the lazy dog 456",
    ]
    patched_run = mocker.patch.object(AzureChatOpenAI, "run", side_effect=llm_replies)

    branching = SimpleBranchingPipeline()
    branch_count = 3
    for _branch_index in range(branch_count):
        # The same linear pipeline fixture instance is registered for every branch.
        branching.add_branch(mock_simple_linear_pipeline)

    outputs = branching.run(value="abc")
    output_texts = [branch_output.text for branch_output in outputs]

    assert len(outputs) == 3
    assert output_texts == ["123", "", "456"]
    assert patched_run.call_count == 3
|
||||
|
||||
|
||||
def test_simple_gated_branching_pipeline_run(
    mocker, mock_gated_linear_pipeline_positive, mock_gated_linear_pipeline_negative
):
    """Run a ``GatedBranchingPipeline`` with one negative and two positive branches.

    The condition text is ``"positive condition"`` and the patched
    ``AzureChatOpenAI.run`` always returns a string without digits. The test
    expects an empty result text and exactly two calls to the stub.
    """
    patched_run = mocker.patch.object(
        AzureChatOpenAI,
        "run",
        return_value="a quick brown fox",
    )
    gated_branching = GatedBranchingPipeline()

    # Registration order matters to the original test: negative first, then
    # the positive fixture twice.
    branches_in_order = (
        mock_gated_linear_pipeline_negative,
        mock_gated_linear_pipeline_positive,
        mock_gated_linear_pipeline_positive,
    )
    for branch in branches_in_order:
        gated_branching.add_branch(branch)

    pipeline_output = gated_branching.run(
        value="abc", condition_text="positive condition"
    )

    assert pipeline_output.text == ""
    assert patched_run.call_count == 2
|
49
tests/test_documents.py
Normal file
49
tests/test_documents.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from haystack.schema import Document as HaystackDocument
|
||||
|
||||
from kotaemon.documents.base import Document, RetrievedDocument
|
||||
|
||||
|
||||
def test_document_constructor_with_builtin_types():
    """Construct ``Document`` from values of several builtin types.

    For every sample value the test expects:
    * ``text`` to be ``str(value)`` for truthy values and ``""`` otherwise,
    * ``content`` to compare equal to the value passed in,
    * the truthiness of the document to mirror the truthiness of the value.
    """
    # ``tuple()`` (an empty tuple instance) is used rather than the ``tuple``
    # class object, so the empty-tuple case is exercised alongside the other
    # empty containers.
    sample_values = ["str", 1, {}, set(), [], tuple(), None]
    for value in sample_values:
        doc = Document(value)
        expected_text = str(value) if value else ""
        assert doc.text == expected_text
        assert doc.content == value
        assert bool(doc) == bool(value)
|
||||
|
||||
|
||||
def test_document_constructor_with_document():
    """Build a ``Document`` from another ``Document`` and compare their fields.

    The test expects ``text`` and ``content`` of the copy to equal those of
    the source document.
    """
    source_doc = Document("Sample text")
    copied_doc = Document(source_doc)

    assert copied_doc.text == source_doc.text
    assert copied_doc.content == source_doc.content
|
||||
|
||||
|
||||
def test_document_to_haystack_format():
    """Convert a ``Document`` with metadata via ``to_haystack_format``.

    The test expects a ``HaystackDocument`` whose ``content`` equals the
    source text and whose ``meta`` equals the metadata passed in.
    """
    sample_metadata = {"filename": "sample.txt"}
    source_doc = Document("Sample text", metadata=sample_metadata)

    converted = source_doc.to_haystack_format()

    assert isinstance(converted, HaystackDocument)
    assert converted.content == source_doc.text
    assert converted.meta == sample_metadata
|
||||
|
||||
|
||||
def test_retrieved_document_default_values():
    """Build a ``RetrievedDocument`` from text only and check its other fields.

    The test expects ``score`` to be ``0.0`` and ``retrieval_metadata`` to be
    an empty dict when neither is supplied.
    """
    given_text = "text"
    doc = RetrievedDocument(text=given_text)

    assert doc.text == given_text
    assert doc.score == 0.0
    assert doc.retrieval_metadata == {}
|
||||
|
||||
|
||||
def test_retrieved_document_attributes():
    """Build a ``RetrievedDocument`` with explicit score and retrieval metadata.

    The test expects each attribute to read back as the value it was given.
    """
    expected = {
        "text": "text",
        "score": 0.8,
        "retrieval_metadata": {"source": "retrieval_system"},
    }
    doc = RetrievedDocument(**expected)

    assert doc.text == expected["text"]
    assert doc.score == expected["score"]
    assert doc.retrieval_metadata == expected["retrieval_metadata"]
|
@@ -14,8 +14,8 @@ def regex_extractor():
|
||||
def test_run_document(regex_extractor):
|
||||
document = Document(text="This is a test. 1 2 3")
|
||||
extracted_document = regex_extractor(document)
|
||||
extracted_texts = [each.text for each in extracted_document]
|
||||
assert extracted_texts == ["One", "Two", "Three"]
|
||||
assert extracted_document.text == "One"
|
||||
assert extracted_document.matches == ["One", "Two", "Three"]
|
||||
|
||||
|
||||
def test_is_document(regex_extractor):
|
||||
@@ -30,11 +30,13 @@ def test_is_batch(regex_extractor):
|
||||
|
||||
def test_run_raw(regex_extractor):
|
||||
output = regex_extractor("This is a test. 123")
|
||||
output = [each.text for each in output]
|
||||
assert output == ["123"]
|
||||
assert output.text == "123"
|
||||
assert output.matches == ["123"]
|
||||
|
||||
|
||||
def test_run_batch_raw(regex_extractor):
|
||||
output = regex_extractor(["This is a test. 123", "456"])
|
||||
output = [[each.text for each in batch] for batch in output]
|
||||
assert output == [["123"], ["456"]]
|
||||
extracted_text = [each.text for each in output]
|
||||
extracted_matches = [each.matches for each in output]
|
||||
assert extracted_text == ["123", "456"]
|
||||
assert extracted_matches == [["123"], ["456"]]
|
||||
|
@@ -54,10 +54,7 @@ def test_run():
|
||||
|
||||
result = prompt()
|
||||
|
||||
assert (
|
||||
result.text
|
||||
== "str = Alice, int = 30, doc = Helloo, Alice!, comp = ['One', 'Two', 'Three']"
|
||||
)
|
||||
assert result.text == "str = Alice, int = 30, doc = Helloo, Alice!, comp = One"
|
||||
|
||||
|
||||
def test_set_method():
|
||||
|
Reference in New Issue
Block a user