chore: added base_url parameter to CochereReranking (#743)
Co-authored-by: Mauro Gattari <mauro.gattari@infn.it>
This commit is contained in:
parent
833982ac81
commit
ffe766f24d
|
@ -1,5 +1,6 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
from decouple import config
|
from decouple import config
|
||||||
|
|
||||||
from kotaemon.base import Document, Param
|
from kotaemon.base import Document, Param
|
||||||
|
@ -23,6 +24,9 @@ class CohereReranking(BaseReranking):
|
||||||
help="Cohere API key",
|
help="Cohere API key",
|
||||||
required=True,
|
required=True,
|
||||||
)
|
)
|
||||||
|
base_url: str = Param(
|
||||||
|
None, help="Rerank API base url. Default is https://api.cohere.com", required=False
|
||||||
|
)
|
||||||
|
|
||||||
def run(self, documents: list[Document], query: str) -> list[Document]:
|
def run(self, documents: list[Document], query: str) -> list[Document]:
|
||||||
"""Use Cohere Reranker model to re-order documents
|
"""Use Cohere Reranker model to re-order documents
|
||||||
|
@ -38,7 +42,7 @@ class CohereReranking(BaseReranking):
|
||||||
print("Cohere API key not found. Skipping rerankings.")
|
print("Cohere API key not found. Skipping rerankings.")
|
||||||
return documents
|
return documents
|
||||||
|
|
||||||
cohere_client = cohere.Client(self.cohere_api_key)
|
cohere_client = cohere.Client(self.cohere_api_key, base_url=self.base_url or os.getenv("CO_API_URL"))
|
||||||
compressed_docs: list[Document] = []
|
compressed_docs: list[Document] = []
|
||||||
|
|
||||||
if not documents: # to avoid empty api call
|
if not documents: # to avoid empty api call
|
||||||
|
|
Loading…
Reference in New Issue
Block a user