@@ -238,32 +238,62 @@ def _max_total_attempts(self, max_retries: int | None) -> int | None:
238238 """Convert max retries to max total attempts (includes initial attempt)."""
239239 return max_retries + 1 if max_retries is not None else None
240240
241- def _get_total_attempts (self ) -> int :
242- """Get total attempts including re-deliveries from visibility timeout expiration ."""
241+ def _get_retry_count (self ) -> int :
242+ """Get retry count safely, handling missing request or retries attribute ."""
243243 if not hasattr (self , "request" ):
244244 return 0
245+ return self .request .retries if hasattr (self .request , "retries" ) else 0
246+
247+ def _get_request_headers (self ) -> dict :
248+ """Get request headers safely, handling missing request or get method."""
249+ if not hasattr (self , "request" ) or self .request is None :
250+ return {}
251+ request = self .request
252+ try :
253+ if hasattr (request , "get" ) and callable (getattr (request , "get" , None )):
254+ return request .get ("headers" , {})
255+ except (AttributeError , TypeError ):
256+ pass
257+ return getattr (request , "headers" , {})
258+
259+ def _expected_total_attempts (self ) -> int :
260+ """Get expected total attempts based on retry count (retries + initial attempt)."""
261+ return self ._get_retry_count () + 1
262+
263+ def _get_total_attempts (self ) -> int :
264+ """Get total attempts including re-deliveries from visibility timeout expiration.
265+
266+ Returns:
267+ - Header value if present and valid (most accurate)
268+ - retry_count + 1 if header missing/invalid (best guess based on retry count)
269+ - 0 if request unavailable (rare, safe default for comparisons)
270+
271+ Returns int (not None) to be safe for comparisons and logging without null checks.
272+ """
273+ if not hasattr (self , "request" ) or self .request is None :
274+ return 0
245275
246- retry_count = self .request .retries if hasattr (self .request , "retries" ) else 0
247- headers = self .request .get ("headers" , {}) or {}
276+ headers = self ._get_request_headers ()
248277 total_attempts_header = headers .get ("total_attempts" , None )
249278
250279 if total_attempts_header is not None :
251280 try :
252281 return int (total_attempts_header )
253282 except (ValueError , TypeError ):
283+ retry_count = self ._get_retry_count ()
254284 log .warning (
255285 "Invalid total_attempts header value" ,
256286 extra = {"value" : total_attempts_header , "retry_count" : retry_count },
257287 )
258- return retry_count + 1
259- return retry_count + 1
288+ return self . _expected_total_attempts ()
289+ return self . _expected_total_attempts ()
260290
261291 def _has_exceeded_max_attempts (self , max_retries : int | None ) -> bool :
262292 """Check if task has exceeded max attempts (including re-deliveries)."""
263293 if max_retries is None :
264294 return False
265295
266- current_retries = self .request . retries if hasattr ( self , "request" ) else 0
296+ current_retries = self ._get_retry_count ()
267297 total_attempts = self ._get_total_attempts ()
268298 max_total_attempts = self ._max_total_attempts (max_retries )
269299
@@ -279,31 +309,32 @@ def safe_retry(self, max_retries=None, countdown=None, exc=None, **kwargs):
279309 task_max_retries = max_retries if max_retries is not None else self .max_retries
280310
281311 if self ._has_exceeded_max_attempts (task_max_retries ):
282- current_retries = self .request . retries if hasattr ( self , "request" ) else 0
312+ current_retries = self ._get_retry_count ()
283313 total_attempts = self ._get_total_attempts ()
284314 max_total_attempts = self ._max_total_attempts (task_max_retries )
285315 log .error (
286316 f"Task { self .name } exceeded max retries" ,
287317 extra = {
288- "task_name" : self .name ,
289318 "current_retries" : current_retries ,
290- "total_attempts" : total_attempts ,
291319 "max_retries" : task_max_retries ,
292320 "max_total_attempts" : max_total_attempts ,
321+ "task_name" : self .name ,
322+ "total_attempts" : total_attempts ,
293323 },
294324 )
295325 TASK_MAX_RETRIES_EXCEEDED_COUNTER .labels (task = self .name ).inc ()
296326 return False
297327
298- current_retries = self .request . retries if hasattr ( self , "request" ) else 0
328+ current_retries = self ._get_retry_count ()
299329 if countdown is None :
300330 countdown = TASK_RETRY_BACKOFF_BASE_SECONDS * (2 ** current_retries )
301331
302332 try :
303333 total_attempts = self ._get_total_attempts ()
304334 headers = {}
305- if hasattr (self , "request" ) and hasattr (self .request , "headers" ):
306- headers .update (self .request .headers or {})
335+ request_headers = self ._get_request_headers ()
336+ if request_headers :
337+ headers .update (request_headers )
307338 headers .update (kwargs .get ("headers" , {}) or {})
308339 headers ["total_attempts" ] = total_attempts + 1
309340 kwargs ["headers" ] = headers
@@ -351,13 +382,23 @@ def _analyse_error(self, exception: SQLAlchemyError, *args, **kwargs):
351382 )
352383
353384 def _emit_queue_metrics (self ):
354- created_timestamp = self .request .get ("created_timestamp" , None )
385+ if not hasattr (self , "request" ) or self .request is None :
386+ return
387+ request = self .request
388+ if not hasattr (request , "get" ):
389+ return
390+ created_timestamp = request .get ("created_timestamp" , None )
355391 if created_timestamp :
356392 enqueued_time = datetime .fromisoformat (created_timestamp )
357393 now = datetime .now ()
358394 delta = now - enqueued_time
359395
360- queue_name = self .request .get ("delivery_info" , {}).get ("routing_key" , None )
396+ delivery_info = request .get ("delivery_info" , {})
397+ queue_name = (
398+ delivery_info .get ("routing_key" , None )
399+ if isinstance (delivery_info , dict )
400+ else None
401+ )
361402 time_in_queue_timer = TASK_TIME_IN_QUEUE .labels (
362403 task = self .name , queue = queue_name
363404 ) # TODO is None a valid label value
@@ -377,36 +418,41 @@ def run(self, *args, **kwargs):
377418 task = get_current_task ()
378419 if task and task .request :
379420 log_context .task_name = task .name
380- log_context .task_id = task .request .id
381-
382- # Track total attempts including re-deliveries from visibility timeout
383- # If this is a re-delivery (visibility timeout expired), log it
384- # Note: We can't modify headers here (they're read-only), but we can detect
385- # re-deliveries by checking if total_attempts is ahead of retry_count + 1
386- # (which indicates Redis re-queued the task without incrementing retry_count)
387- headers = task .request .get ("headers" , {}) or {}
388- total_attempts = headers .get ("total_attempts" , None )
389- retry_count = (
390- task .request .retries if hasattr (task .request , "retries" ) else 0
391- )
392-
393- if total_attempts is not None :
394- try :
395- total_attempts_int = int (total_attempts )
396- if total_attempts_int > retry_count + 1 :
397- log .warning (
398- f"Task { task .name } re-delivered (visibility timeout expired)" ,
421+ task_id = getattr (task .request , "id" , None )
422+ if task_id :
423+ log_context .task_id = task_id
424+
425+ try :
426+ headers = self ._get_request_headers ()
427+ header_total_attempts = headers .get ("total_attempts" )
428+ if header_total_attempts is not None :
429+ try :
430+ total_attempts = int (header_total_attempts )
431+ retry_count = (
432+ getattr (task .request , "retries" , 0 )
433+ if hasattr (task .request , "retries" )
434+ else 0
435+ )
436+ if total_attempts > retry_count + 1 :
437+ log .warning (
438+ f"Task { task .name } re-delivered (visibility timeout expired)" ,
439+ extra = {
440+ "retry_count" : retry_count ,
441+ "task_id" : task_id ,
442+ "total_attempts" : total_attempts ,
443+ },
444+ )
445+ except (ValueError , TypeError ):
446+ log .debug (
447+ "Invalid total_attempts header, skipping re-delivery detection" ,
399448 extra = {
400- "task_id" : task .request .id ,
401- "retry_count" : retry_count ,
402- "total_attempts" : total_attempts_int ,
449+ "task_id" : task_id ,
450+ "value" : header_total_attempts ,
403451 },
404452 )
405- except (ValueError , TypeError ):
406- log .debug (
407- "Invalid total_attempts header, skipping re-delivery detection" ,
408- extra = {"task_id" : task .request .id , "value" : total_attempts },
409- )
453+ except Exception :
454+ # Silently ignore errors accessing request headers to avoid breaking task execution
455+ pass
410456
411457 log_context .populate_from_sqlalchemy (db_session )
412458 set_log_context (log_context )
@@ -485,10 +531,8 @@ def wrap_up_dbsession(self, db_session):
485531 def on_retry (self , exc , task_id , args , kwargs , einfo ):
486532 res = super ().on_retry (exc , task_id , args , kwargs , einfo )
487533 self .task_retry_counter .inc ()
488- # Track retry count for better observability
489- retry_count = self .request .retries if hasattr (self , "request" ) else 0
490534 TASK_RETRY_WITH_COUNT_COUNTER .labels (
491- task = self .name , retry_count = str (retry_count )
535+ task = self .name , retry_count = str (self . _get_retry_count () )
492536 ).inc ()
493537 return res
494538
0 commit comments