66"""
77
88import time
9- from typing import Union , List
9+ from typing import List
1010from fastapi import APIRouter , Depends , HTTPException , status
1111from loguru import logger
1212
1818 RerankingDocumentError ,
1919 ValidationError ,
2020)
21-
2221from src .api .dependencies import get_model_manager
2322from src .utils .validators import extract_embedding_kwargs
2423
25- router = APIRouter (tags = ["rerank" ])
24+ router = APIRouter (tags = ["rerank" ])
2625
2726
2827@router .post (
29- "/rerank" , response_model = RerankResponse , summary = "Rerank documents" , description = "Reranks the provided documents based on the given query."
28+ "/rerank" ,
29+ response_model = RerankResponse ,
30+ summary = "Rerank documents" ,
31+ description = "Reranks the provided documents based on the given query." ,
3032)
3133async def rerank_documents (
3234 request : RerankRequest ,
@@ -35,54 +37,82 @@ async def rerank_documents(
3537 """
3638 Rerank documents based on a query.
3739
38- This endpoint processes a list of documents and returns them ranked according to their relevance to the query.
39-
40+ This endpoint processes a list of documents and returns them ranked
41+ according to their relevance to the query.
42+
4043 Args:
41- request (RerankRequest) : The request object containing the query and documents to rank.
42- manager (ModelManager) : The model manager dependency to access the model.
44+ request: The request object containing the query and documents to rank
45+ manager: The model manager dependency to access the model
4346
4447 Returns:
45- RerankResponse: The response containing the ranked documents and processing time.
48+ RerankResponse: The response containing the ranked documents and processing time
4649
4750 Raises:
48- HTTPException: If there are validation errors, model loading issues, or unexpected errors.
51+ HTTPException: If there are validation errors, model loading issues, or unexpected errors
4952 """
5053 # Filter out empty documents and keep original indices
5154 valid_docs = [
5255 (i , doc .strip ()) for i , doc in enumerate (request .documents ) if doc .strip ()
5356 ]
5457
5558 if not valid_docs :
56- raise HTTPException (status_code = status .HTTP_400_BAD_REQUEST , detail = "No valid documents provided." )
59+ raise HTTPException (
60+ status_code = status .HTTP_400_BAD_REQUEST ,
61+ detail = "No valid documents provided." ,
62+ )
5763
5864 try :
65+ # Extract kwargs but exclude rerank-specific fields
5966 kwargs = extract_embedding_kwargs (request )
67+
68+ # Remove fields that are passed explicitly as keyword arguments below
69+ # to avoid a "got multiple values for keyword argument" error
70+ kwargs .pop ("query" , None )
71+ kwargs .pop ("documents" , None )
72+ kwargs .pop ("top_k" , None )
73+
6074 model = manager .get_model (request .model_id )
6175 config = manager .model_configs [request .model_id ]
6276
63- start = time . time ()
64- if config . type == "rerank" :
65- scores = model . rank_document (
66- request .query , request . documents , request . top_k , ** kwargs
77+ if config . type != "rerank" :
78+ raise HTTPException (
79+ status_code = status . HTTP_400_BAD_REQUEST ,
80+ detail = f"Model ' { request .model_id } ' is not a rerank model. Type: { config . type } " ,
6781 )
68- processing_time = time .time () - start
6982
70- original_indices , documents_list = zip (* valid_docs )
71- results : List [RerankResult ] = []
83+ start = time .time ()
84+
85+ # Call rank_document with clean kwargs
86+ scores = model .rank_document (
87+ query = request .query ,
88+ documents = [doc for _ , doc in valid_docs ], # Use filtered documents
89+ top_k = request .top_k ,
90+ ** kwargs ,
91+ )
92+
93+ processing_time = time .time () - start
7294
73- for i , (orig_idx , doc ) in enumerate (zip (original_indices , documents_list )):
74- results .append (RerankResult (text = doc , score = scores [i ], index = orig_idx ))
95+ # Build results with original indices
96+ original_indices , documents_list = zip (* valid_docs )
97+ results : List [RerankResult ] = []
7598
76- # Sort results by score in descending order
77- results .sort ( key = lambda x : x . score , reverse = True )
99+ for i , ( orig_idx , doc ) in enumerate ( zip ( original_indices , documents_list )):
100+ results .append ( RerankResult ( text = doc , score = scores [ i ], index = orig_idx ) )
78101
79- logger .info (f"Reranked documents in { processing_time :.3f} seconds" )
80- return RerankResponse (
81- model_id = request .model_id ,
82- processing_time = processing_time ,
83- query = request .query ,
84- results = results ,
85- )
102+ # Sort results by score in descending order
103+ results .sort (key = lambda x : x .score , reverse = True )
104+
105+ logger .info (
106+ f"Reranked { len (results )} documents in { processing_time :.3f} s "
107+ f"(model: { request .model_id } )"
108+ )
109+
110+ return RerankResponse (
111+ model_id = request .model_id ,
112+ processing_time = processing_time ,
113+ query = request .query ,
114+ results = results ,
115+ )
86116
87117 except (ValidationError , ModelNotFoundError ) as e :
88118 raise HTTPException (status_code = e .status_code , detail = e .message )
@@ -94,5 +124,5 @@ async def rerank_documents(
94124 logger .exception ("Unexpected error in rerank_documents" )
95125 raise HTTPException (
96126 status_code = status .HTTP_500_INTERNAL_SERVER_ERROR ,
97- detail = f"Failed to create query embedding : { str (e )} " ,
127+ detail = f"Failed to rerank documents : { str (e )} " ,
98128 )
0 commit comments