diff --git a/knowledgehub/embeddings/openai.py b/knowledgehub/embeddings/openai.py index 25de270..da02755 100644 --- a/knowledgehub/embeddings/openai.py +++ b/knowledgehub/embeddings/openai.py @@ -1,4 +1,4 @@ -from langchain.embeddings import OpenAIEmbeddings as LCOpenAIEmbeddings +from langchain import embeddings as lcembeddings from .base import LangchainEmbeddings @@ -9,23 +9,13 @@ class OpenAIEmbeddings(LangchainEmbeddings): This method is wrapped around the Langchain OpenAIEmbeddings class. """ - _lc_class = LCOpenAIEmbeddings + _lc_class = lcembeddings.OpenAIEmbeddings class AzureOpenAIEmbeddings(LangchainEmbeddings): """Azure OpenAI embeddings. - This method is wrapped around the Langchain OpenAIEmbeddings class. + This method is wrapped around the Langchain AzureOpenAIEmbeddings class. """ - _lc_class = LCOpenAIEmbeddings - - def __init__(self, **params): - params["openai_api_type"] = "azure" - - # openai.error.InvalidRequestError: Too many inputs. The max number of - # inputs is 16. We hope to increase the number of inputs per request - # soon. Please contact us through an Azure support request at: - # https://go.microsoft.com/fwlink/?linkid=2213926 for further questions. - params["chunk_size"] = 16 - super().__init__(**params) + _lc_class = lcembeddings.AzureOpenAIEmbeddings diff --git a/tests/test_embedding_models.py b/tests/test_embedding_models.py index 5353006..2b29538 100644 --- a/tests/test_embedding_models.py +++ b/tests/test_embedding_models.py @@ -21,7 +21,7 @@ def test_azureopenai_embeddings_raw(openai_embedding_call): model = AzureOpenAIEmbeddings( model="text-embedding-ada-002", deployment="embedding-deployment", - openai_api_base="https://test.openai.azure.com/", + azure_endpoint="https://test.openai.azure.com/", openai_api_key="some-key", ) output = model("Hello world") @@ -39,7 +39,7 @@ def test_azureopenai_embeddings_batch_raw(openai_embedding_call): model = AzureOpenAIEmbeddings( model="text-embedding-ada-002", deployment="embedding-deployment", - openai_api_base="https://test.openai.azure.com/", + azure_endpoint="https://test.openai.azure.com/", openai_api_key="some-key", ) output = model(["Hello world", "Goodbye world"])