[AUR-389] Add base interface and embedding model (#17)
This change provides the base interface for embeddings and wraps LangChain's OpenAI embedding. Usage is as follows:
```python
from kotaemon.embeddings import AzureOpenAIEmbeddings

model = AzureOpenAIEmbeddings(
    model="text-embedding-ada-002",
    deployment="embedding-deployment",
    openai_api_base="https://test.openai.azure.com/",
    openai_api_key="some-key",
)
output = model("Hello world")
```
This commit is contained in:
committed by
GitHub
parent
1061192731
commit
c339912312
1552
tests/resources/embedding_openai.json
Normal file
1552
tests/resources/embedding_openai.json
Normal file
File diff suppressed because it is too large
Load Diff
3094
tests/resources/embedding_openai_batch.json
Normal file
3094
tests/resources/embedding_openai_batch.json
Normal file
File diff suppressed because it is too large
Load Diff
46
tests/test_embedding_models.py
Normal file
46
tests/test_embedding_models.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from kotaemon.embeddings.openai import AzureOpenAIEmbeddings
|
||||
|
||||
# Canned OpenAI embedding API responses, loaded once at import time. The tests
# in this module hand them back from the patched `Embedding.create` call.
_RESOURCE_DIR = Path(__file__).parent / "resources"

# Response returned to the batch (list-of-texts) test.
with (_RESOURCE_DIR / "embedding_openai_batch.json").open() as f:
    openai_embedding_batch = json.load(f)

# Response returned to the single-text test.
with (_RESOURCE_DIR / "embedding_openai.json").open() as f:
    openai_embedding = json.load(f)
|
||||
|
||||
|
||||
@patch(
    "openai.api_resources.embedding.Embedding.create",
    side_effect=lambda *_args, **_kwargs: openai_embedding,
)
def test_azureopenai_embeddings_raw(openai_embedding_call):
    """Embed a single string and expect a flat list of floats.

    ``Embedding.create`` is replaced by a mock that returns the canned
    ``openai_embedding`` response loaded at module import.
    """
    settings = {
        "model": "text-embedding-ada-002",
        "deployment": "embedding-deployment",
        "openai_api_base": "https://test.openai.azure.com/",
        "openai_api_key": "some-key",
    }
    embedder = AzureOpenAIEmbeddings(**settings)

    vector = embedder("Hello world")

    assert isinstance(vector, list)
    assert isinstance(vector[0], float)
    # The mocked OpenAI endpoint must have been invoked by the model call.
    openai_embedding_call.assert_called()
|
||||
|
||||
|
||||
@patch(
    "openai.api_resources.embedding.Embedding.create",
    side_effect=lambda *_args, **_kwargs: openai_embedding_batch,
)
def test_azureopenai_embeddings_batch_raw(openai_embedding_call):
    """Embed a list of strings and expect a list of float vectors.

    ``Embedding.create`` is replaced by a mock that returns the canned
    ``openai_embedding_batch`` response loaded at module import.
    """
    settings = {
        "model": "text-embedding-ada-002",
        "deployment": "embedding-deployment",
        "openai_api_base": "https://test.openai.azure.com/",
        "openai_api_key": "some-key",
    }
    embedder = AzureOpenAIEmbeddings(**settings)

    texts = ["Hello world", "Goodbye world"]
    vectors = embedder(texts)

    assert isinstance(vectors, list)
    assert isinstance(vectors[0], list)
    assert isinstance(vectors[0][0], float)
    # The mocked OpenAI endpoint must have been invoked by the model call.
    openai_embedding_call.assert_called()
|
Reference in New Issue
Block a user