2025-12-26 10:05:00 +03:00

471 lines
18 KiB
Python

"""
Queue provider for stress policy tool.
This module provides interfaces and implementations for queue operations,
supporting both authentication and download workflows.
"""
import json
import logging
import time
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Any, Tuple, Union
import redis
logger = logging.getLogger(__name__)
class QueueProvider(ABC):
"""Abstract base class for queue operations."""
@abstractmethod
def get_task(self, queue_name: str) -> Optional[Dict]:
"""Get a task from the specified queue."""
pass
@abstractmethod
def get_tasks_batch(self, queue_name: str, batch_size: int) -> List[Dict]:
"""Get a batch of tasks from the specified queue."""
pass
@abstractmethod
def report_success(self, queue_name: str, task_id: str, result: Dict) -> bool:
"""Report a successful task completion."""
pass
@abstractmethod
def report_failure(self, queue_name: str, task_id: str, error: Dict) -> bool:
"""Report a task failure."""
pass
@abstractmethod
def report_skipped(self, queue_name: str, task_id: str, reason: Dict) -> bool:
"""Report a task that was skipped."""
pass
@abstractmethod
def mark_in_progress(self, queue_name: str, task_id: str, worker_id: str) -> bool:
"""Mark a task as in progress."""
pass
@abstractmethod
def remove_in_progress(self, queue_name: str, task_id: str) -> bool:
"""Remove a task from the in-progress tracking."""
pass
@abstractmethod
def get_queue_length(self, queue_name: str) -> int:
"""Get the current length of a queue."""
pass
@abstractmethod
def add_task(self, queue_name: str, task: Dict) -> bool:
"""Add a task to a queue."""
pass
@abstractmethod
def add_tasks_batch(self, queue_name: str, tasks: List[Dict]) -> int:
"""Add a batch of tasks to a queue. Returns number of tasks added."""
pass
class RedisQueueProvider(QueueProvider):
"""Redis implementation of the QueueProvider interface."""
def __init__(self, redis_host: str = "localhost", redis_port: int = 6379,
redis_password: Optional[str] = None, redis_db: int = 0,
env_prefix: Optional[str] = None):
"""Initialize the Redis queue provider."""
prefix = f"{env_prefix}_" if env_prefix else ""
# Queue name constants
# Authentication stage
self.AUTH_INBOX = f"{prefix}queue2_auth_inbox"
self.AUTH_RESULT = f"{prefix}queue2_auth_result"
self.AUTH_FAIL = f"{prefix}queue2_auth_fail"
self.AUTH_SKIPPED = f"{prefix}queue2_auth_skipped"
self.AUTH_PROGRESS = f"{prefix}queue2_auth_progress"
# Download stage
self.DL_TASKS = f"{prefix}queue2_dl_inbox"
self.DL_RESULT = f"{prefix}queue2_dl_result"
self.DL_FAIL = f"{prefix}queue2_dl_fail"
self.DL_SKIPPED = f"{prefix}queue2_dl_skipped"
self.DL_PROGRESS = f"{prefix}queue2_dl_progress"
self.redis_client = redis.Redis(
host=redis_host,
port=redis_port,
password=redis_password,
db=redis_db,
decode_responses=True
)
self._validate_connection()
def _validate_connection(self) -> None:
"""Validate the Redis connection."""
try:
self.redis_client.ping()
logger.info("Successfully connected to Redis")
except redis.ConnectionError as e:
logger.error(f"Failed to connect to Redis: {e}")
raise
def get_task(self, queue_name: str) -> Optional[Dict]:
"""Get a task from the specified queue.
For LIST type queues, this pops an item.
For HASH type queues, this just reads an item without removing it.
"""
try:
queue_type = self._get_queue_type(queue_name)
if queue_type == "list":
# BRPOP with a timeout of 1 second
result = self.redis_client.brpop(queue_name, timeout=1)
if result:
_, task_data = result
try:
# Assume it's a JSON object, which is the standard format.
return json.loads(task_data)
except json.JSONDecodeError:
# If it fails, check if it's the auth inbox queue and a plain string.
# This provides backward compatibility with queues populated with raw URLs.
if queue_name == self.AUTH_INBOX and isinstance(task_data, str):
logger.debug(f"Task from '{queue_name}' is a plain string. Wrapping it in a task dictionary.")
return {"url": task_data}
else:
# If it's not the auth inbox or not a string, log and re-raise.
logger.error(f"Failed to decode JSON task from queue '{queue_name}': {task_data}")
raise
return None
elif queue_type == "hash":
# For hash queues, we just get a random key
keys = self.redis_client.hkeys(queue_name)
if not keys:
return None
# Get a random key
import random
key = random.choice(keys)
value = self.redis_client.hget(queue_name, key)
if value:
task = json.loads(value)
task["id"] = key # Add the key as id
return task
return None
else:
logger.warning(f"Unsupported queue type for {queue_name}: {queue_type}")
return None
except Exception as e:
logger.error(f"Error getting task from {queue_name}: {e}")
return None
def get_tasks_batch(self, queue_name: str, batch_size: int) -> List[Dict]:
"""Get a batch of tasks from the specified queue."""
tasks = []
try:
queue_type = self._get_queue_type(queue_name)
if queue_type == "list":
# Use pipeline for efficiency
pipe = self.redis_client.pipeline()
for _ in range(batch_size):
pipe.rpop(queue_name)
results = pipe.execute()
for result in results:
if result:
try:
tasks.append(json.loads(result))
except json.JSONDecodeError:
if queue_name == self.AUTH_INBOX and isinstance(result, str):
tasks.append({"url": result})
else:
logger.error(f"Failed to decode JSON task from batch in queue '{queue_name}': {result}")
# In batch mode, we skip the malformed item and continue.
elif queue_type == "hash":
# For hash queues, get multiple random keys
keys = self.redis_client.hkeys(queue_name)
if not keys:
return []
# Get random keys up to batch_size
import random
selected_keys = random.sample(keys, min(batch_size, len(keys)))
# Use pipeline for efficiency
pipe = self.redis_client.pipeline()
for key in selected_keys:
pipe.hget(queue_name, key)
results = pipe.execute()
for i, result in enumerate(results):
if result:
task = json.loads(result)
task["id"] = selected_keys[i] # Add the key as id
tasks.append(task)
else:
logger.warning(f"Unsupported queue type for batch operations: {queue_name}")
except Exception as e:
logger.error(f"Error getting tasks batch from {queue_name}: {e}")
return tasks
def report_success(self, queue_name: str, task_id: str, result: Dict) -> bool:
"""Report a successful task completion."""
try:
# Ensure task_id is included in the result
result["task_id"] = task_id
result["timestamp"] = time.time()
# Store in the success hash
self.redis_client.hset(queue_name, task_id, json.dumps(result))
return True
except Exception as e:
logger.error(f"Error reporting success to {queue_name}: {e}")
return False
def report_failure(self, queue_name: str, task_id: str, error: Dict) -> bool:
"""Report a task failure."""
try:
# Ensure task_id is included in the error
error["task_id"] = task_id
error["timestamp"] = time.time()
# Store in the failure hash
self.redis_client.hset(queue_name, task_id, json.dumps(error))
return True
except Exception as e:
logger.error(f"Error reporting failure to {queue_name}: {e}")
return False
def report_skipped(self, queue_name: str, task_id: str, reason: Dict) -> bool:
"""Report a task that was skipped."""
try:
# Ensure task_id is included in the reason
reason["task_id"] = task_id
reason["timestamp"] = time.time()
# Store in the skipped hash
self.redis_client.hset(queue_name, task_id, json.dumps(reason))
return True
except Exception as e:
logger.error(f"Error reporting skipped to {queue_name}: {e}")
return False
def mark_in_progress(self, queue_name: str, task_id: str, worker_id: str) -> bool:
"""Mark a task as in progress."""
try:
progress_data = {
"task_id": task_id,
"worker_id": worker_id,
"start_time": time.time()
}
# Store in the progress hash
self.redis_client.hset(queue_name, task_id, json.dumps(progress_data))
return True
except Exception as e:
logger.error(f"Error marking task in progress in {queue_name}: {e}")
return False
def remove_in_progress(self, queue_name: str, task_id: str) -> bool:
"""Remove a task from the in-progress tracking."""
try:
# Remove from the progress hash
self.redis_client.hdel(queue_name, task_id)
return True
except Exception as e:
logger.error(f"Error removing task from progress in {queue_name}: {e}")
return False
def get_queue_length(self, queue_name: str) -> int:
"""Get the current length of a queue."""
try:
queue_type = self._get_queue_type(queue_name)
if queue_type == "list":
return self.redis_client.llen(queue_name)
elif queue_type == "hash":
return self.redis_client.hlen(queue_name)
else:
logger.warning(f"Unsupported queue type for {queue_name}: {queue_type}")
return 0
except Exception as e:
logger.error(f"Error getting queue length for {queue_name}: {e}")
return 0
def add_task(self, queue_name: str, task: Dict) -> bool:
"""Add a task to a queue."""
try:
queue_type = self._get_queue_type(queue_name)
if queue_type == "list":
# For list queues, we push to the left (LPUSH)
self.redis_client.lpush(queue_name, json.dumps(task))
return True
elif queue_type == "hash":
# For hash queues, we need a task_id
task_id = task.get("id") or task.get("task_id")
if not task_id:
logger.error(f"Cannot add task to hash queue {queue_name} without an id")
return False
self.redis_client.hset(queue_name, task_id, json.dumps(task))
return True
else:
logger.warning(f"Unsupported queue type for {queue_name}: {queue_type}")
return False
except Exception as e:
logger.error(f"Error adding task to {queue_name}: {e}")
return False
def add_tasks_batch(self, queue_name: str, tasks: List[Dict]) -> int:
"""Add a batch of tasks to a queue. Returns number of tasks added."""
if not tasks:
return 0
try:
queue_type = self._get_queue_type(queue_name)
if queue_type == "list":
# Use pipeline for efficiency
pipe = self.redis_client.pipeline()
for task in tasks:
pipe.lpush(queue_name, json.dumps(task))
results = pipe.execute()
return len([r for r in results if r])
elif queue_type == "hash":
# Use pipeline for efficiency
pipe = self.redis_client.pipeline()
added_count = 0
for task in tasks:
task_id = task.get("id") or task.get("task_id")
if task_id:
pipe.hset(queue_name, task_id, json.dumps(task))
added_count += 1
else:
logger.warning(f"Skipping task without id for hash queue {queue_name}")
pipe.execute()
return added_count
else:
logger.warning(f"Unsupported queue type for batch operations: {queue_name}")
return 0
except Exception as e:
logger.error(f"Error adding tasks batch to {queue_name}: {e}")
return 0
def _get_queue_type(self, queue_name: str) -> str:
"""Determine the Redis data type of a queue."""
try:
queue_type = self.redis_client.type(queue_name)
if not queue_type or queue_type == "none":
# Queue doesn't exist yet, infer type from name
if queue_name.endswith(("_inbox", "_tasks")):
return "list"
else:
return "hash"
return queue_type
except Exception as e:
logger.error(f"Error determining queue type for {queue_name}: {e}")
return "unknown"
def requeue_failed_tasks(self, source_queue: str, target_queue: str,
batch_size: int = 100) -> int:
"""Requeue failed tasks from a failure queue to an inbox queue."""
try:
# Get failed tasks
failed_tasks = self.get_tasks_batch(source_queue, batch_size)
if not failed_tasks:
return 0
# Prepare tasks for requeuing
requeued_count = 0
requeue_tasks = []
for task in failed_tasks:
# Extract the original URL or task data
url = task.get("url")
if url:
# For auth failures, we just need the URL
requeue_tasks.append({"url": url})
requeued_count += 1
else:
# For download failures, we need the original task data
original_task = task.get("original_task")
if original_task:
requeue_tasks.append(original_task)
requeued_count += 1
# Add tasks to target queue
if requeue_tasks:
self.add_tasks_batch(target_queue, requeue_tasks)
# Remove from source queue
pipe = self.redis_client.pipeline()
for task in failed_tasks:
task_id = task.get("id") or task.get("task_id")
if task_id:
pipe.hdel(source_queue, task_id)
pipe.execute()
return requeued_count
except Exception as e:
logger.error(f"Error requeuing failed tasks from {source_queue} to {target_queue}: {e}")
return 0
def get_queue_stats(self) -> Dict[str, Dict[str, int]]:
"""Get statistics for all queues."""
stats = {}
# Authentication queues
auth_queues = {
"auth_inbox": self.AUTH_INBOX,
"auth_result": self.AUTH_RESULT,
"auth_fail": self.AUTH_FAIL,
"auth_skipped": self.AUTH_SKIPPED,
"auth_progress": self.AUTH_PROGRESS
}
# Download queues
dl_queues = {
"dl_tasks": self.DL_TASKS,
"dl_result": self.DL_RESULT,
"dl_fail": self.DL_FAIL,
"dl_skipped": self.DL_SKIPPED,
"dl_progress": self.DL_PROGRESS
}
# Get stats for auth queues
auth_stats = {}
for name, queue in auth_queues.items():
auth_stats[name] = self.get_queue_length(queue)
stats["auth"] = auth_stats
# Get stats for download queues
dl_stats = {}
for name, queue in dl_queues.items():
dl_stats[name] = self.get_queue_length(queue)
stats["download"] = dl_stats
return stats