Skip to content

Commit 2e5859f

Browse files
committed
[FIX ERROR]: multiple args 'query'
1 parent 64b5451 commit 2e5859f

File tree

3 files changed

+91
-57
lines changed

3 files changed

+91
-57
lines changed

src/api/routers/rerank.py

Lines changed: 60 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"""
77

88
import time
9-
from typing import Union, List
9+
from typing import List
1010
from fastapi import APIRouter, Depends, HTTPException, status
1111
from loguru import logger
1212

@@ -18,15 +18,17 @@
1818
RerankingDocumentError,
1919
ValidationError,
2020
)
21-
2221
from src.api.dependencies import get_model_manager
2322
from src.utils.validators import extract_embedding_kwargs
2423

25-
router = APIRouter(tags=["rerank"])
24+
router = APIRouter(tags=["rerank"])
2625

2726

2827
@router.post(
29-
"/rerank", response_model=RerankResponse, summary="Rerank documents", description="Reranks the provided documents based on the given query."
28+
"/rerank",
29+
response_model=RerankResponse,
30+
summary="Rerank documents",
31+
description="Reranks the provided documents based on the given query.",
3032
)
3133
async def rerank_documents(
3234
request: RerankRequest,
@@ -35,54 +37,82 @@ async def rerank_documents(
3537
"""
3638
Rerank documents based on a query.
3739
38-
This endpoint processes a list of documents and returns them ranked according to their relevance to the query.
39-
40+
This endpoint processes a list of documents and returns them ranked
41+
according to their relevance to the query.
42+
4043
Args:
41-
request (RerankRequest): The request object containing the query and documents to rank.
42-
manager (ModelManager): The model manager dependency to access the model.
44+
request: The request object containing the query and documents to rank
45+
manager: The model manager dependency to access the model
4346
4447
Returns:
45-
RerankResponse: The response containing the ranked documents and processing time.
48+
RerankResponse: The response containing the ranked documents and processing time
4649
4750
Raises:
48-
HTTPException: If there are validation errors, model loading issues, or unexpected errors.
51+
HTTPException: If there are validation errors, model loading issues, or unexpected errors
4952
"""
5053
# Filter out empty documents and keep original indices
5154
valid_docs = [
5255
(i, doc.strip()) for i, doc in enumerate(request.documents) if doc.strip()
5356
]
5457

5558
if not valid_docs:
56-
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="No valid documents provided.")
59+
raise HTTPException(
60+
status_code=status.HTTP_400_BAD_REQUEST,
61+
detail="No valid documents provided.",
62+
)
5763

5864
try:
65+
# Extract kwargs but exclude rerank-specific fields
5966
kwargs = extract_embedding_kwargs(request)
67+
68+
# Remove fields that are already passed as positional arguments
69+
# to avoid "got multiple values for argument" error
70+
kwargs.pop("query", None)
71+
kwargs.pop("documents", None)
72+
kwargs.pop("top_k", None)
73+
6074
model = manager.get_model(request.model_id)
6175
config = manager.model_configs[request.model_id]
6276

63-
start = time.time()
64-
if config.type == "rerank":
65-
scores = model.rank_document(
66-
request.query, request.documents, request.top_k, **kwargs
77+
if config.type != "rerank":
78+
raise HTTPException(
79+
status_code=status.HTTP_400_BAD_REQUEST,
80+
detail=f"Model '{request.model_id}' is not a rerank model. Type: {config.type}",
6781
)
68-
processing_time = time.time() - start
6982

70-
original_indices, documents_list = zip(*valid_docs)
71-
results: List[RerankResult] = []
83+
start = time.time()
84+
85+
# Call rank_document with clean kwargs
86+
scores = model.rank_document(
87+
query=request.query,
88+
documents=[doc for _, doc in valid_docs], # Use filtered documents
89+
top_k=request.top_k,
90+
**kwargs,
91+
)
92+
93+
processing_time = time.time() - start
7294

73-
for i, (orig_idx, doc) in enumerate(zip(original_indices, documents_list)):
74-
results.append(RerankResult(text=doc, score=scores[i], index=orig_idx))
95+
# Build results with original indices
96+
original_indices, documents_list = zip(*valid_docs)
97+
results: List[RerankResult] = []
7598

76-
# Sort results by score in descending order
77-
results.sort(key=lambda x: x.score, reverse=True)
99+
for i, (orig_idx, doc) in enumerate(zip(original_indices, documents_list)):
100+
results.append(RerankResult(text=doc, score=scores[i], index=orig_idx))
78101

79-
logger.info(f"Reranked documents in {processing_time:.3f} seconds")
80-
return RerankResponse(
81-
model_id=request.model_id,
82-
processing_time=processing_time,
83-
query=request.query,
84-
results=results,
85-
)
102+
# Sort results by score in descending order
103+
results.sort(key=lambda x: x.score, reverse=True)
104+
105+
logger.info(
106+
f"Reranked {len(results)} documents in {processing_time:.3f}s "
107+
f"(model: {request.model_id})"
108+
)
109+
110+
return RerankResponse(
111+
model_id=request.model_id,
112+
processing_time=processing_time,
113+
query=request.query,
114+
results=results,
115+
)
86116

87117
except (ValidationError, ModelNotFoundError) as e:
88118
raise HTTPException(status_code=e.status_code, detail=e.message)
@@ -94,5 +124,5 @@ async def rerank_documents(
94124
logger.exception("Unexpected error in rerank_documents")
95125
raise HTTPException(
96126
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
97-
detail=f"Failed to create query embedding: {str(e)}",
127+
detail=f"Failed to rerank documents: {str(e)}",
98128
)

src/models/schemas/requests.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,6 @@ class Config:
126126
"model_id": "jina-reranker-v3",
127127
"query": "Python best programming languages for data science",
128128
"top_k": 4,
129-
"prompt": "Rerank document based user query",
130129
"documents": [
131130
"Python is a popular language for data science due to its extensive libraries.",
132131
"R is widely used in statistical computing and data analysis.",

src/utils/validators.py

Lines changed: 31 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,26 @@
1010
from src.core.exceptions import TextTooLongError, BatchTooLargeError, ValidationError
1111

1212

13+
def validate_text(text: str, max_length: int = 8192, allow_empty: bool = False) -> None:
14+
"""
15+
Validate a single text input.
16+
17+
Args:
18+
text: Input text to validate
19+
max_length: Maximum allowed text length
20+
allow_empty: Whether to allow empty strings
21+
22+
Raises:
23+
ValidationError: If text is empty and not allowed
24+
TextTooLongError: If text exceeds max_length
25+
"""
26+
if not allow_empty and not text.strip():
27+
raise ValidationError("text", "Text cannot be empty")
28+
29+
if len(text) > max_length:
30+
raise TextTooLongError(len(text), max_length)
31+
32+
1333
def validate_texts(
1434
texts: List[str],
1535
max_length: int = 8192,
@@ -71,30 +91,6 @@ def validate_model_id(model_id: str, available_models: List[str]) -> None:
7191
)
7292

7393

74-
def sanitize_text(text: str, max_length: int = 8192) -> str:
75-
"""
76-
Sanitize text input by removing excessive whitespace and truncating.
77-
78-
Args:
79-
text: Input text to sanitize
80-
max_length: Maximum length to truncate to
81-
82-
Returns:
83-
Sanitized text
84-
"""
85-
# Remove leading/trailing whitespace
86-
text = text.strip()
87-
88-
# Replace multiple whitespaces with single space
89-
text = " ".join(text.split())
90-
91-
# Truncate if too long
92-
if len(text) > max_length:
93-
text = text[:max_length]
94-
95-
return text
96-
97-
9894
def extract_embedding_kwargs(request: BaseModel) -> Dict[str, Any]:
9995
"""
10096
Extract embedding kwargs from a request object.
@@ -110,7 +106,7 @@ def extract_embedding_kwargs(request: BaseModel) -> Dict[str, Any]:
110106
111107
Example:
112108
>>> request = EmbedRequest(
113-
... text="hello",
109+
... texts=["hello"],
114110
... model_id="qwen3-0.6b",
115111
... options=EmbeddingOptions(normalize_embeddings=True),
116112
... batch_size=32 # Extra field
@@ -125,7 +121,16 @@ def extract_embedding_kwargs(request: BaseModel) -> Dict[str, Any]:
125121
kwargs.update(request.options.to_kwargs())
126122

127123
# Extract extra fields (excluding standard fields)
128-
standard_fields = {"text", "texts", "model_id", "prompt", "options"}
124+
standard_fields = {
125+
"text",
126+
"texts",
127+
"model_id",
128+
"prompt",
129+
"options",
130+
"query",
131+
"documents",
132+
"top_k",
133+
}
129134
request_dict = request.model_dump()
130135

131136
for key, value in request_dict.items():

0 commit comments

Comments
 (0)