""" Queue provider for stress policy tool. This module provides interfaces and implementations for queue operations, supporting both authentication and download workflows. """ import json import logging import time from abc import ABC, abstractmethod from typing import Dict, List, Optional, Any, Tuple, Union import redis logger = logging.getLogger(__name__) class QueueProvider(ABC): """Abstract base class for queue operations.""" @abstractmethod def get_task(self, queue_name: str) -> Optional[Dict]: """Get a task from the specified queue.""" pass @abstractmethod def get_tasks_batch(self, queue_name: str, batch_size: int) -> List[Dict]: """Get a batch of tasks from the specified queue.""" pass @abstractmethod def report_success(self, queue_name: str, task_id: str, result: Dict) -> bool: """Report a successful task completion.""" pass @abstractmethod def report_failure(self, queue_name: str, task_id: str, error: Dict) -> bool: """Report a task failure.""" pass @abstractmethod def report_skipped(self, queue_name: str, task_id: str, reason: Dict) -> bool: """Report a task that was skipped.""" pass @abstractmethod def mark_in_progress(self, queue_name: str, task_id: str, worker_id: str) -> bool: """Mark a task as in progress.""" pass @abstractmethod def remove_in_progress(self, queue_name: str, task_id: str) -> bool: """Remove a task from the in-progress tracking.""" pass @abstractmethod def get_queue_length(self, queue_name: str) -> int: """Get the current length of a queue.""" pass @abstractmethod def add_task(self, queue_name: str, task: Dict) -> bool: """Add a task to a queue.""" pass @abstractmethod def add_tasks_batch(self, queue_name: str, tasks: List[Dict]) -> int: """Add a batch of tasks to a queue. Returns number of tasks added.""" pass class RedisQueueProvider(QueueProvider): """Redis implementation of the QueueProvider interface.""" def __init__(self, redis_host: str = "localhost", redis_port: int = 6379, redis_password: Optional[str] = None, redis_db: int = 0, env_prefix: Optional[str] = None): """Initialize the Redis queue provider.""" prefix = f"{env_prefix}_" if env_prefix else "" # Queue name constants # Authentication stage self.AUTH_INBOX = f"{prefix}queue2_auth_inbox" self.AUTH_RESULT = f"{prefix}queue2_auth_result" self.AUTH_FAIL = f"{prefix}queue2_auth_fail" self.AUTH_SKIPPED = f"{prefix}queue2_auth_skipped" self.AUTH_PROGRESS = f"{prefix}queue2_auth_progress" # Download stage self.DL_TASKS = f"{prefix}queue2_dl_inbox" self.DL_RESULT = f"{prefix}queue2_dl_result" self.DL_FAIL = f"{prefix}queue2_dl_fail" self.DL_SKIPPED = f"{prefix}queue2_dl_skipped" self.DL_PROGRESS = f"{prefix}queue2_dl_progress" self.redis_client = redis.Redis( host=redis_host, port=redis_port, password=redis_password, db=redis_db, decode_responses=True ) self._validate_connection() def _validate_connection(self) -> None: """Validate the Redis connection.""" try: self.redis_client.ping() logger.info("Successfully connected to Redis") except redis.ConnectionError as e: logger.error(f"Failed to connect to Redis: {e}") raise def get_task(self, queue_name: str) -> Optional[Dict]: """Get a task from the specified queue. For LIST type queues, this pops an item. For HASH type queues, this just reads an item without removing it. """ try: queue_type = self._get_queue_type(queue_name) if queue_type == "list": # BRPOP with a timeout of 1 second result = self.redis_client.brpop(queue_name, timeout=1) if result: _, task_data = result try: # Assume it's a JSON object, which is the standard format. return json.loads(task_data) except json.JSONDecodeError: # If it fails, check if it's the auth inbox queue and a plain string. # This provides backward compatibility with queues populated with raw URLs. if queue_name == self.AUTH_INBOX and isinstance(task_data, str): logger.debug(f"Task from '{queue_name}' is a plain string. Wrapping it in a task dictionary.") return {"url": task_data} else: # If it's not the auth inbox or not a string, log and re-raise. logger.error(f"Failed to decode JSON task from queue '{queue_name}': {task_data}") raise return None elif queue_type == "hash": # For hash queues, we just get a random key keys = self.redis_client.hkeys(queue_name) if not keys: return None # Get a random key import random key = random.choice(keys) value = self.redis_client.hget(queue_name, key) if value: task = json.loads(value) task["id"] = key # Add the key as id return task return None else: logger.warning(f"Unsupported queue type for {queue_name}: {queue_type}") return None except Exception as e: logger.error(f"Error getting task from {queue_name}: {e}") return None def get_tasks_batch(self, queue_name: str, batch_size: int) -> List[Dict]: """Get a batch of tasks from the specified queue.""" tasks = [] try: queue_type = self._get_queue_type(queue_name) if queue_type == "list": # Use pipeline for efficiency pipe = self.redis_client.pipeline() for _ in range(batch_size): pipe.rpop(queue_name) results = pipe.execute() for result in results: if result: try: tasks.append(json.loads(result)) except json.JSONDecodeError: if queue_name == self.AUTH_INBOX and isinstance(result, str): tasks.append({"url": result}) else: logger.error(f"Failed to decode JSON task from batch in queue '{queue_name}': {result}") # In batch mode, we skip the malformed item and continue. elif queue_type == "hash": # For hash queues, get multiple random keys keys = self.redis_client.hkeys(queue_name) if not keys: return [] # Get random keys up to batch_size import random selected_keys = random.sample(keys, min(batch_size, len(keys))) # Use pipeline for efficiency pipe = self.redis_client.pipeline() for key in selected_keys: pipe.hget(queue_name, key) results = pipe.execute() for i, result in enumerate(results): if result: task = json.loads(result) task["id"] = selected_keys[i] # Add the key as id tasks.append(task) else: logger.warning(f"Unsupported queue type for batch operations: {queue_name}") except Exception as e: logger.error(f"Error getting tasks batch from {queue_name}: {e}") return tasks def report_success(self, queue_name: str, task_id: str, result: Dict) -> bool: """Report a successful task completion.""" try: # Ensure task_id is included in the result result["task_id"] = task_id result["timestamp"] = time.time() # Store in the success hash self.redis_client.hset(queue_name, task_id, json.dumps(result)) return True except Exception as e: logger.error(f"Error reporting success to {queue_name}: {e}") return False def report_failure(self, queue_name: str, task_id: str, error: Dict) -> bool: """Report a task failure.""" try: # Ensure task_id is included in the error error["task_id"] = task_id error["timestamp"] = time.time() # Store in the failure hash self.redis_client.hset(queue_name, task_id, json.dumps(error)) return True except Exception as e: logger.error(f"Error reporting failure to {queue_name}: {e}") return False def report_skipped(self, queue_name: str, task_id: str, reason: Dict) -> bool: """Report a task that was skipped.""" try: # Ensure task_id is included in the reason reason["task_id"] = task_id reason["timestamp"] = time.time() # Store in the skipped hash self.redis_client.hset(queue_name, task_id, json.dumps(reason)) return True except Exception as e: logger.error(f"Error reporting skipped to {queue_name}: {e}") return False def mark_in_progress(self, queue_name: str, task_id: str, worker_id: str) -> bool: """Mark a task as in progress.""" try: progress_data = { "task_id": task_id, "worker_id": worker_id, "start_time": time.time() } # Store in the progress hash self.redis_client.hset(queue_name, task_id, json.dumps(progress_data)) return True except Exception as e: logger.error(f"Error marking task in progress in {queue_name}: {e}") return False def remove_in_progress(self, queue_name: str, task_id: str) -> bool: """Remove a task from the in-progress tracking.""" try: # Remove from the progress hash self.redis_client.hdel(queue_name, task_id) return True except Exception as e: logger.error(f"Error removing task from progress in {queue_name}: {e}") return False def get_queue_length(self, queue_name: str) -> int: """Get the current length of a queue.""" try: queue_type = self._get_queue_type(queue_name) if queue_type == "list": return self.redis_client.llen(queue_name) elif queue_type == "hash": return self.redis_client.hlen(queue_name) else: logger.warning(f"Unsupported queue type for {queue_name}: {queue_type}") return 0 except Exception as e: logger.error(f"Error getting queue length for {queue_name}: {e}") return 0 def add_task(self, queue_name: str, task: Dict) -> bool: """Add a task to a queue.""" try: queue_type = self._get_queue_type(queue_name) if queue_type == "list": # For list queues, we push to the left (LPUSH) self.redis_client.lpush(queue_name, json.dumps(task)) return True elif queue_type == "hash": # For hash queues, we need a task_id task_id = task.get("id") or task.get("task_id") if not task_id: logger.error(f"Cannot add task to hash queue {queue_name} without an id") return False self.redis_client.hset(queue_name, task_id, json.dumps(task)) return True else: logger.warning(f"Unsupported queue type for {queue_name}: {queue_type}") return False except Exception as e: logger.error(f"Error adding task to {queue_name}: {e}") return False def add_tasks_batch(self, queue_name: str, tasks: List[Dict]) -> int: """Add a batch of tasks to a queue. Returns number of tasks added.""" if not tasks: return 0 try: queue_type = self._get_queue_type(queue_name) if queue_type == "list": # Use pipeline for efficiency pipe = self.redis_client.pipeline() for task in tasks: pipe.lpush(queue_name, json.dumps(task)) results = pipe.execute() return len([r for r in results if r]) elif queue_type == "hash": # Use pipeline for efficiency pipe = self.redis_client.pipeline() added_count = 0 for task in tasks: task_id = task.get("id") or task.get("task_id") if task_id: pipe.hset(queue_name, task_id, json.dumps(task)) added_count += 1 else: logger.warning(f"Skipping task without id for hash queue {queue_name}") pipe.execute() return added_count else: logger.warning(f"Unsupported queue type for batch operations: {queue_name}") return 0 except Exception as e: logger.error(f"Error adding tasks batch to {queue_name}: {e}") return 0 def _get_queue_type(self, queue_name: str) -> str: """Determine the Redis data type of a queue.""" try: queue_type = self.redis_client.type(queue_name) if not queue_type or queue_type == "none": # Queue doesn't exist yet, infer type from name if queue_name.endswith(("_inbox", "_tasks")): return "list" else: return "hash" return queue_type except Exception as e: logger.error(f"Error determining queue type for {queue_name}: {e}") return "unknown" def requeue_failed_tasks(self, source_queue: str, target_queue: str, batch_size: int = 100) -> int: """Requeue failed tasks from a failure queue to an inbox queue.""" try: # Get failed tasks failed_tasks = self.get_tasks_batch(source_queue, batch_size) if not failed_tasks: return 0 # Prepare tasks for requeuing requeued_count = 0 requeue_tasks = [] for task in failed_tasks: # Extract the original URL or task data url = task.get("url") if url: # For auth failures, we just need the URL requeue_tasks.append({"url": url}) requeued_count += 1 else: # For download failures, we need the original task data original_task = task.get("original_task") if original_task: requeue_tasks.append(original_task) requeued_count += 1 # Add tasks to target queue if requeue_tasks: self.add_tasks_batch(target_queue, requeue_tasks) # Remove from source queue pipe = self.redis_client.pipeline() for task in failed_tasks: task_id = task.get("id") or task.get("task_id") if task_id: pipe.hdel(source_queue, task_id) pipe.execute() return requeued_count except Exception as e: logger.error(f"Error requeuing failed tasks from {source_queue} to {target_queue}: {e}") return 0 def get_queue_stats(self) -> Dict[str, Dict[str, int]]: """Get statistics for all queues.""" stats = {} # Authentication queues auth_queues = { "auth_inbox": self.AUTH_INBOX, "auth_result": self.AUTH_RESULT, "auth_fail": self.AUTH_FAIL, "auth_skipped": self.AUTH_SKIPPED, "auth_progress": self.AUTH_PROGRESS } # Download queues dl_queues = { "dl_tasks": self.DL_TASKS, "dl_result": self.DL_RESULT, "dl_fail": self.DL_FAIL, "dl_skipped": self.DL_SKIPPED, "dl_progress": self.DL_PROGRESS } # Get stats for auth queues auth_stats = {} for name, queue in auth_queues.items(): auth_stats[name] = self.get_queue_length(queue) stats["auth"] = auth_stats # Get stats for download queues dl_stats = {} for name, queue in dl_queues.items(): dl_stats[name] = self.get_queue_length(queue) stats["download"] = dl_stats return stats