"""
|
|
Queue provider for stress policy tool.
|
|
|
|
This module provides interfaces and implementations for queue operations,
|
|
supporting both authentication and download workflows.
|
|
"""
|
|
|
|
import json
|
|
import logging
|
|
import time
|
|
from abc import ABC, abstractmethod
|
|
from typing import Dict, List, Optional, Any, Tuple, Union
|
|
|
|
import redis
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class QueueProvider(ABC):
    """Abstract base class for queue operations."""

    @abstractmethod
    def get_task(self, queue_name: str) -> Optional[Dict]:
        """Get a task from the specified queue."""
        pass

    @abstractmethod
    def get_tasks_batch(self, queue_name: str, batch_size: int) -> List[Dict]:
        """Get a batch of tasks from the specified queue."""
        pass

    @abstractmethod
    def report_success(self, queue_name: str, task_id: str, result: Dict) -> bool:
        """Report a successful task completion."""
        pass

    @abstractmethod
    def report_failure(self, queue_name: str, task_id: str, error: Dict) -> bool:
        """Report a task failure."""
        pass

    @abstractmethod
    def report_skipped(self, queue_name: str, task_id: str, reason: Dict) -> bool:
        """Report a task that was skipped."""
        pass

    @abstractmethod
    def mark_in_progress(self, queue_name: str, task_id: str, worker_id: str) -> bool:
        """Mark a task as in progress."""
        pass

    @abstractmethod
    def remove_in_progress(self, queue_name: str, task_id: str) -> bool:
        """Remove a task from the in-progress tracking."""
        pass

    @abstractmethod
    def get_queue_length(self, queue_name: str) -> int:
        """Get the current length of a queue."""
        pass

    @abstractmethod
    def add_task(self, queue_name: str, task: Dict) -> bool:
        """Add a task to a queue."""
        pass

    @abstractmethod
    def add_tasks_batch(self, queue_name: str, tasks: List[Dict]) -> int:
        """Add a batch of tasks to a queue. Returns number of tasks added."""
        pass


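# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original module: a minimal in-memory
# QueueProvider that satisfies the interface above without Redis, e.g. for
# unit tests. The class name InMemoryQueueProvider is ours; the list/hash
# convention (names ending in "_inbox"/"_tasks" are FIFO lists, everything
# else is a keyed hash) mirrors _get_queue_type in the Redis implementation.
# ---------------------------------------------------------------------------
class InMemoryQueueProvider(QueueProvider):
    """In-memory QueueProvider sketch for tests; single-process only."""

    def __init__(self):
        from collections import defaultdict, deque
        self._lists = defaultdict(deque)   # FIFO task queues
        self._hashes = defaultdict(dict)   # keyed result/progress stores

    def _is_list(self, queue_name: str) -> bool:
        return queue_name.endswith(("_inbox", "_tasks"))

    def get_task(self, queue_name: str) -> Optional[Dict]:
        if self._is_list(queue_name):
            q = self._lists[queue_name]
            return q.popleft() if q else None
        items = self._hashes[queue_name]
        if not items:
            return None
        key = next(iter(items))
        return dict(items[key], id=key)  # read without removing, like the Redis hash path

    def get_tasks_batch(self, queue_name: str, batch_size: int) -> List[Dict]:
        if self._is_list(queue_name):
            return [t for t in (self.get_task(queue_name) for _ in range(batch_size)) if t]
        items = self._hashes[queue_name]
        return [dict(v, id=k) for k, v in list(items.items())[:batch_size]]

    def report_success(self, queue_name: str, task_id: str, result: Dict) -> bool:
        self._hashes[queue_name][task_id] = result
        return True

    # Failure and skipped reports store their payload the same way.
    report_failure = report_success
    report_skipped = report_success

    def mark_in_progress(self, queue_name: str, task_id: str, worker_id: str) -> bool:
        self._hashes[queue_name][task_id] = {"task_id": task_id, "worker_id": worker_id}
        return True

    def remove_in_progress(self, queue_name: str, task_id: str) -> bool:
        self._hashes[queue_name].pop(task_id, None)
        return True

    def get_queue_length(self, queue_name: str) -> int:
        if self._is_list(queue_name):
            return len(self._lists[queue_name])
        return len(self._hashes[queue_name])

    def add_task(self, queue_name: str, task: Dict) -> bool:
        if self._is_list(queue_name):
            self._lists[queue_name].append(task)
            return True
        task_id = task.get("id") or task.get("task_id")
        if not task_id:
            return False
        self._hashes[queue_name][task_id] = task
        return True

    def add_tasks_batch(self, queue_name: str, tasks: List[Dict]) -> int:
        return sum(1 for task in tasks if self.add_task(queue_name, task))

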
class RedisQueueProvider(QueueProvider):
    """Redis implementation of the QueueProvider interface."""

    def __init__(self, redis_host: str = "localhost", redis_port: int = 6379,
                 redis_password: Optional[str] = None, redis_db: int = 0,
                 env_prefix: Optional[str] = None):
        """Initialize the Redis queue provider."""
        # e.g. env_prefix="prod" yields names like "prod_queue2_auth_inbox"
        prefix = f"{env_prefix}_" if env_prefix else ""

        # Queue name constants
        # Authentication stage
        self.AUTH_INBOX = f"{prefix}queue2_auth_inbox"
        self.AUTH_RESULT = f"{prefix}queue2_auth_result"
        self.AUTH_FAIL = f"{prefix}queue2_auth_fail"
        self.AUTH_SKIPPED = f"{prefix}queue2_auth_skipped"
        self.AUTH_PROGRESS = f"{prefix}queue2_auth_progress"

        # Download stage
        self.DL_TASKS = f"{prefix}queue2_dl_inbox"
        self.DL_RESULT = f"{prefix}queue2_dl_result"
        self.DL_FAIL = f"{prefix}queue2_dl_fail"
        self.DL_SKIPPED = f"{prefix}queue2_dl_skipped"
        self.DL_PROGRESS = f"{prefix}queue2_dl_progress"

        # decode_responses=True means all Redis replies come back as str
        self.redis_client = redis.Redis(
            host=redis_host,
            port=redis_port,
            password=redis_password,
            db=redis_db,
            decode_responses=True
        )
        self._validate_connection()

    def _validate_connection(self) -> None:
        """Validate the Redis connection."""
        try:
            self.redis_client.ping()
            logger.info("Successfully connected to Redis")
        except redis.ConnectionError as e:
            logger.error(f"Failed to connect to Redis: {e}")
            raise

    def get_task(self, queue_name: str) -> Optional[Dict]:
        """Get a task from the specified queue.

        For LIST type queues, this pops an item.
        For HASH type queues, this only reads an item without removing it.
        """
        try:
            queue_type = self._get_queue_type(queue_name)

            if queue_type == "list":
                # BRPOP with a timeout of 1 second
                result = self.redis_client.brpop(queue_name, timeout=1)
                if result:
                    _, task_data = result
                    try:
                        # Assume it's a JSON object, which is the standard format.
                        return json.loads(task_data)
                    except json.JSONDecodeError:
                        # If decoding fails, check whether this is the auth inbox
                        # queue and the payload is a plain string. This keeps
                        # backward compatibility with queues populated with raw URLs.
                        if queue_name == self.AUTH_INBOX and isinstance(task_data, str):
                            logger.debug(f"Task from '{queue_name}' is a plain string. Wrapping it in a task dictionary.")
                            return {"url": task_data}
                        else:
                            # Not the auth inbox, or not a string: log and re-raise.
                            logger.error(f"Failed to decode JSON task from queue '{queue_name}': {task_data}")
                            raise
                return None

            elif queue_type == "hash":
                # For hash queues, read a random field without removing it
                keys = self.redis_client.hkeys(queue_name)
                if not keys:
                    return None

                key = random.choice(keys)
                value = self.redis_client.hget(queue_name, key)

                if value:
                    task = json.loads(value)
                    task["id"] = key  # Add the hash field name as the task id
                    return task

                return None

            else:
                logger.warning(f"Unsupported queue type for {queue_name}: {queue_type}")
                return None

        except Exception as e:
            logger.error(f"Error getting task from {queue_name}: {e}")
            return None

    def get_tasks_batch(self, queue_name: str, batch_size: int) -> List[Dict]:
        """Get a batch of tasks from the specified queue."""
        tasks = []
        try:
            queue_type = self._get_queue_type(queue_name)

            if queue_type == "list":
                # Use a pipeline for efficiency
                pipe = self.redis_client.pipeline()
                for _ in range(batch_size):
                    pipe.rpop(queue_name)
                results = pipe.execute()

                for result in results:
                    if result:
                        try:
                            tasks.append(json.loads(result))
                        except json.JSONDecodeError:
                            if queue_name == self.AUTH_INBOX and isinstance(result, str):
                                tasks.append({"url": result})
                            else:
                                logger.error(f"Failed to decode JSON task from batch in queue '{queue_name}': {result}")
                                # In batch mode, skip the malformed item and continue.

            elif queue_type == "hash":
                # For hash queues, read multiple random fields without removing them
                keys = self.redis_client.hkeys(queue_name)
                if not keys:
                    return []

                # Sample random keys, up to batch_size
                selected_keys = random.sample(keys, min(batch_size, len(keys)))

                # Use a pipeline for efficiency
                pipe = self.redis_client.pipeline()
                for key in selected_keys:
                    pipe.hget(queue_name, key)
                results = pipe.execute()

                for i, result in enumerate(results):
                    if result:
                        try:
                            task = json.loads(result)
                        except json.JSONDecodeError:
                            # Skip the malformed item, consistent with the list branch.
                            logger.error(f"Failed to decode JSON task from batch in queue '{queue_name}': {result}")
                            continue
                        task["id"] = selected_keys[i]  # Add the hash field name as the task id
                        tasks.append(task)

            else:
                logger.warning(f"Unsupported queue type for batch operations: {queue_name}")

        except Exception as e:
            logger.error(f"Error getting tasks batch from {queue_name}: {e}")

        return tasks

    def report_success(self, queue_name: str, task_id: str, result: Dict) -> bool:
        """Report a successful task completion."""
        try:
            # Ensure task_id is included in the result
            result["task_id"] = task_id
            result["timestamp"] = time.time()

            # Store in the success hash
            self.redis_client.hset(queue_name, task_id, json.dumps(result))
            return True
        except Exception as e:
            logger.error(f"Error reporting success to {queue_name}: {e}")
            return False

    def report_failure(self, queue_name: str, task_id: str, error: Dict) -> bool:
        """Report a task failure."""
        try:
            # Ensure task_id is included in the error
            error["task_id"] = task_id
            error["timestamp"] = time.time()

            # Store in the failure hash
            self.redis_client.hset(queue_name, task_id, json.dumps(error))
            return True
        except Exception as e:
            logger.error(f"Error reporting failure to {queue_name}: {e}")
            return False

    def report_skipped(self, queue_name: str, task_id: str, reason: Dict) -> bool:
        """Report a task that was skipped."""
        try:
            # Ensure task_id is included in the reason
            reason["task_id"] = task_id
            reason["timestamp"] = time.time()

            # Store in the skipped hash
            self.redis_client.hset(queue_name, task_id, json.dumps(reason))
            return True
        except Exception as e:
            logger.error(f"Error reporting skipped task to {queue_name}: {e}")
            return False

    def mark_in_progress(self, queue_name: str, task_id: str, worker_id: str) -> bool:
        """Mark a task as in progress."""
        try:
            progress_data = {
                "task_id": task_id,
                "worker_id": worker_id,
                "start_time": time.time()
            }

            # Store in the progress hash
            self.redis_client.hset(queue_name, task_id, json.dumps(progress_data))
            return True
        except Exception as e:
            logger.error(f"Error marking task in progress in {queue_name}: {e}")
            return False

    def remove_in_progress(self, queue_name: str, task_id: str) -> bool:
        """Remove a task from the in-progress tracking."""
        try:
            # Remove from the progress hash
            self.redis_client.hdel(queue_name, task_id)
            return True
        except Exception as e:
            logger.error(f"Error removing task from progress in {queue_name}: {e}")
            return False

    def get_queue_length(self, queue_name: str) -> int:
        """Get the current length of a queue."""
        try:
            queue_type = self._get_queue_type(queue_name)

            if queue_type == "list":
                return self.redis_client.llen(queue_name)
            elif queue_type == "hash":
                return self.redis_client.hlen(queue_name)
            else:
                logger.warning(f"Unsupported queue type for {queue_name}: {queue_type}")
                return 0
        except Exception as e:
            logger.error(f"Error getting queue length for {queue_name}: {e}")
            return 0

    def add_task(self, queue_name: str, task: Dict) -> bool:
        """Add a task to a queue."""
        try:
            queue_type = self._get_queue_type(queue_name)

            if queue_type == "list":
                # For list queues, push to the left (LPUSH); workers pop from the right
                self.redis_client.lpush(queue_name, json.dumps(task))
                return True

            elif queue_type == "hash":
                # For hash queues, a task_id is required
                task_id = task.get("id") or task.get("task_id")
                if not task_id:
                    logger.error(f"Cannot add task to hash queue {queue_name} without an id")
                    return False

                self.redis_client.hset(queue_name, task_id, json.dumps(task))
                return True

            else:
                logger.warning(f"Unsupported queue type for {queue_name}: {queue_type}")
                return False

        except Exception as e:
            logger.error(f"Error adding task to {queue_name}: {e}")
            return False

    def add_tasks_batch(self, queue_name: str, tasks: List[Dict]) -> int:
        """Add a batch of tasks to a queue. Returns number of tasks added."""
        if not tasks:
            return 0

        try:
            queue_type = self._get_queue_type(queue_name)

            if queue_type == "list":
                # Use a pipeline for efficiency
                pipe = self.redis_client.pipeline()
                for task in tasks:
                    pipe.lpush(queue_name, json.dumps(task))
                results = pipe.execute()
                # LPUSH returns the new list length, so each successful push is truthy
                return len([r for r in results if r])

            elif queue_type == "hash":
                # Use a pipeline for efficiency
                pipe = self.redis_client.pipeline()
                added_count = 0

                for task in tasks:
                    task_id = task.get("id") or task.get("task_id")
                    if task_id:
                        pipe.hset(queue_name, task_id, json.dumps(task))
                        added_count += 1
                    else:
                        logger.warning(f"Skipping task without id for hash queue {queue_name}")

                pipe.execute()
                return added_count

            else:
                logger.warning(f"Unsupported queue type for batch operations: {queue_name}")
                return 0

        except Exception as e:
            logger.error(f"Error adding tasks batch to {queue_name}: {e}")
            return 0

    def _get_queue_type(self, queue_name: str) -> str:
        """Determine the Redis data type of a queue."""
        try:
            # With decode_responses=True, TYPE returns a plain string
            # such as "list", "hash", or "none".
            queue_type = self.redis_client.type(queue_name)
            if not queue_type or queue_type == "none":
                # Queue doesn't exist yet, so infer the type from its name
                if queue_name.endswith(("_inbox", "_tasks")):
                    return "list"
                else:
                    return "hash"
            return queue_type
        except Exception as e:
            logger.error(f"Error determining queue type for {queue_name}: {e}")
            return "unknown"

    def requeue_failed_tasks(self, source_queue: str, target_queue: str,
                             batch_size: int = 100) -> int:
        """Requeue failed tasks from a failure queue to an inbox queue.

        Returns the number of tasks requeued. Note that every fetched task is
        removed from the source queue, including tasks that carry neither a
        "url" nor an "original_task" field and therefore cannot be requeued.
        """
        try:
            # Get failed tasks
            failed_tasks = self.get_tasks_batch(source_queue, batch_size)
            if not failed_tasks:
                return 0

            # Prepare tasks for requeuing
            requeued_count = 0
            requeue_tasks = []

            for task in failed_tasks:
                # Extract the original URL or task data
                url = task.get("url")
                if url:
                    # For auth failures, only the URL is needed
                    requeue_tasks.append({"url": url})
                    requeued_count += 1
                else:
                    # For download failures, the original task data is needed
                    original_task = task.get("original_task")
                    if original_task:
                        requeue_tasks.append(original_task)
                        requeued_count += 1

            # Add tasks to the target queue
            if requeue_tasks:
                self.add_tasks_batch(target_queue, requeue_tasks)

            # Remove the fetched tasks from the source queue
            pipe = self.redis_client.pipeline()
            for task in failed_tasks:
                task_id = task.get("id") or task.get("task_id")
                if task_id:
                    pipe.hdel(source_queue, task_id)
            pipe.execute()

            return requeued_count

        except Exception as e:
            logger.error(f"Error requeuing failed tasks from {source_queue} to {target_queue}: {e}")
            return 0

    def get_queue_stats(self) -> Dict[str, Dict[str, int]]:
        """Get statistics for all queues."""
        stats = {}

        # Authentication queues
        auth_queues = {
            "auth_inbox": self.AUTH_INBOX,
            "auth_result": self.AUTH_RESULT,
            "auth_fail": self.AUTH_FAIL,
            "auth_skipped": self.AUTH_SKIPPED,
            "auth_progress": self.AUTH_PROGRESS
        }

        # Download queues
        dl_queues = {
            "dl_tasks": self.DL_TASKS,
            "dl_result": self.DL_RESULT,
            "dl_fail": self.DL_FAIL,
            "dl_skipped": self.DL_SKIPPED,
            "dl_progress": self.DL_PROGRESS
        }

        # Get stats for auth queues
        auth_stats = {}
        for name, queue in auth_queues.items():
            auth_stats[name] = self.get_queue_length(queue)
        stats["auth"] = auth_stats

        # Get stats for download queues
        dl_stats = {}
        for name, queue in dl_queues.items():
            dl_stats[name] = self.get_queue_length(queue)
        stats["download"] = dl_stats

        return stats
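

# ---------------------------------------------------------------------------
# Illustrative usage sketch, not part of the original module. Assumes a Redis
# server reachable at localhost:6379; the env_prefix "demo", the task id
# "task-1", and the example URL are made up for demonstration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    provider = RedisQueueProvider(redis_host="localhost", env_prefix="demo")

    # Enqueue an auth task; AUTH_INBOX ends in "_inbox", so it is a list queue.
    provider.add_task(provider.AUTH_INBOX, {"url": "https://example.com/login"})

    # A worker would pop the task, mark it in progress, then report the outcome.
    task = provider.get_task(provider.AUTH_INBOX)
    if task:
        provider.mark_in_progress(provider.AUTH_PROGRESS, "task-1", "worker-1")
        provider.report_success(provider.AUTH_RESULT, "task-1", {"status": "ok"})
        provider.remove_in_progress(provider.AUTH_PROGRESS, "task-1")

    print(json.dumps(provider.get_queue_stats(), indent=2))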