# -*- coding: utf-8 -*-
"""
Regression testing script for the ytdlp-ops system.

This script orchestrates a regression test by:
1. Populating a Redis queue with video URLs from an input file.
2. Triggering the `ytdlp_ops_orchestrator` Airflow DAG to start processing.
3. Monitoring the progress of the processing for a specified duration.
4. Generating a report of any failures.
5. Optionally cleaning up the Redis queues after the test.
"""
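
# Example invocation (illustrative only: the script name, file paths, and values
# below are placeholders; every flag shown is defined in main()):
#
#   python regression_test.py \
#       --client mweb --workers 4 --workers-per-bunch 1 \
#       --run-time-min 30 --progress-interval-min 2 \
#       --input-file video_ids.csv --report-file failures.csv \
#       --show-status --cleanup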

import argparse
import csv
import json
import logging
import os
import re
import subprocess
import signal
import sys
import time
from datetime import datetime, timedelta
from pathlib import Path

import redis
import requests
from tabulate import tabulate

# It's safe to import these as the script runs in the same container as Airflow
# where the yt_ops_services package is installed.
try:
    from yt_ops_services.client_utils import get_thrift_client, format_timestamp
    from pangramia.yt.exceptions.ttypes import PBServiceException, PBUserException
except ImportError:
    logging.error("Could not import Thrift modules. Ensure this script is run in the 'airflow-regression-runner' container.")
    sys.exit(1)

# --- Configuration ---
logging.basicConfig(
    level=logging.INFO,
    format="[%(asctime)s] [%(levelname)s] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

INTERRUPTED = False


def signal_handler(sig, frame):
    """Handles Ctrl+C interruption."""
    global INTERRUPTED
    if not INTERRUPTED:
        logging.warning("Ctrl+C detected. Initiating graceful shutdown...")
        INTERRUPTED = True
    else:
        logging.warning("Second Ctrl+C detected. Forcing exit.")
        sys.exit(1)


# --- Helper Functions ---

def _get_redis_client(redis_url: str):
    """Gets a Redis client from a URL."""
    try:
        # from_url is the modern way to connect and handles password auth
        client = redis.from_url(redis_url, decode_responses=True)
        client.ping()
        conn_kwargs = client.connection_pool.connection_kwargs
        logging.info(f"Successfully connected to Redis at {conn_kwargs.get('host')}:{conn_kwargs.get('port')}")
        return client
    except redis.exceptions.ConnectionError as e:
        logging.error(f"Failed to connect to Redis: {e}")
        sys.exit(1)
    except Exception as e:
        logging.error(f"An unexpected error occurred while connecting to Redis: {e}")
        sys.exit(1)


def _get_webserver_url():
    """
    Determines the Airflow webserver URL, prioritizing MASTER_HOST_IP from .env.
    """
    master_host_ip = os.getenv("MASTER_HOST_IP")
    if master_host_ip:
        url = f"http://{master_host_ip}:8080"
        logging.info(f"Using MASTER_HOST_IP for webserver URL: {url}")
        return url

    # Fall back to AIRFLOW_WEBSERVER_URL or the default service name
    url = os.getenv("AIRFLOW_WEBSERVER_URL", "http://airflow-webserver:8080")
    logging.info(f"Using default webserver URL: {url}")
    return url


def _normalize_to_url(item: str) -> str | None:
    """
    Validates if an item is a recognizable YouTube URL or video ID,
    and normalizes it to a standard watch URL format.
    """
    if not item:
        return None

    video_id_pattern = r"^[a-zA-Z0-9_-]{11}$"
    if re.match(video_id_pattern, item):
        return f"https://www.youtube.com/watch?v={item}"

    url_patterns = [r"(?:v=|\/v\/|youtu\.be\/|embed\/|shorts\/)([a-zA-Z0-9_-]{11})"]
    for pattern in url_patterns:
        match = re.search(pattern, item)
        if match:
            return f"https://www.youtube.com/watch?v={match.group(1)}"

    logging.warning(f"Could not recognize '{item}' as a valid YouTube URL or video ID.")
    return None
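
# Illustrative behavior of _normalize_to_url, based on the patterns above
# (the 11-character ID "dQw4w9WgXcQ" is just a sample value):
#   "dQw4w9WgXcQ"                  -> "https://www.youtube.com/watch?v=dQw4w9WgXcQ"
#   "https://youtu.be/dQw4w9WgXcQ" -> "https://www.youtube.com/watch?v=dQw4w9WgXcQ"
#   "https://www.youtube.com/shorts/dQw4w9WgXcQ" -> same watch URL
#   "not-a-video"                  -> None (a warning is logged)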


def _read_input_file(file_path: str) -> list[str]:
    """Reads video IDs/URLs from a file (CSV or JSON list)."""
    path = Path(file_path)
    if not path.is_file():
        logging.error(f"Input file not found: {file_path}")
        sys.exit(1)

    content = path.read_text(encoding='utf-8')

    # Try parsing as JSON list first
    if content.strip().startswith('['):
        try:
            data = json.loads(content)
            if isinstance(data, list):
                logging.info(f"Successfully parsed {file_path} as a JSON list.")
                return [str(item) for item in data]
        except json.JSONDecodeError:
            logging.warning("File looks like JSON but failed to parse. Will try treating as CSV/text.")

    # Fallback to CSV/text (one item per line)
    items = []
    # Use io.StringIO to handle the content as a file for the csv reader
    from io import StringIO

    # Sniff to see if it has a header
    try:
        has_header = csv.Sniffer().has_header(content)
    except csv.Error:
        has_header = False  # Not a CSV, treat as plain text

    reader = csv.reader(StringIO(content))
    if has_header:
        next(reader)  # Skip header row

    for row in reader:
        if row:
            items.append(row[0].strip())  # Assume the ID/URL is in the first column

    logging.info(f"Successfully parsed {len(items)} items from {file_path} as CSV/text.")
    return items
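
# The input file may be either a JSON list or CSV/plain text with one item per
# line, e.g. (illustrative contents):
#   JSON:      ["dQw4w9WgXcQ", "https://youtu.be/dQw4w9WgXcQ"]
#   CSV/text:  an optional header row, then one ID or URL per line; only the
#              first column of each row is used.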


def _get_api_auth():
    """Gets Airflow API credentials from environment variables."""
    username = os.getenv("AIRFLOW_ADMIN_USERNAME", "admin")
    password = os.getenv("AIRFLOW_ADMIN_PASSWORD")
    if not password:
        logging.error("AIRFLOW_ADMIN_PASSWORD not found in environment. Cannot interact with API.")
        return None, None
    return username, password


def _pause_dag(dag_id: str, is_paused: bool = True):
    """Pauses or unpauses an Airflow DAG via the REST API."""
    logging.info(f"Attempting to {'pause' if is_paused else 'unpause'} DAG: {dag_id}...")
    username, password = _get_api_auth()
    if not username:
        return

    webserver_url = _get_webserver_url()
    endpoint = f"{webserver_url}/api/v1/dags/{dag_id}"
    payload = {"is_paused": is_paused}

    try:
        response = requests.patch(endpoint, auth=(username, password), json=payload, timeout=30)
        response.raise_for_status()
        logging.info(f"Successfully {'paused' if is_paused else 'unpaused'} DAG '{dag_id}'.")
    except requests.exceptions.RequestException as e:
        logging.error(f"Failed to {'pause' if is_paused else 'unpause'} DAG '{dag_id}': {e}")
        if e.response is not None:
            logging.error(f"Response: {e.response.text}")


def _fail_running_dag_runs(dag_id: str):
    """Finds all running DAG runs for a given DAG and marks them as failed."""
    logging.info(f"Attempting to fail all running instances of DAG '{dag_id}'...")
    username, password = _get_api_auth()
    if not username:
        return

    webserver_url = _get_webserver_url()
    list_endpoint = f"{webserver_url}/api/v1/dags/{dag_id}/dagRuns?state=running"

    try:
        # Get the currently running DAG runs
        response = requests.get(list_endpoint, auth=(username, password), timeout=30)
        response.raise_for_status()
        running_runs = response.json().get("dag_runs", [])

        if not running_runs:
            logging.info(f"No running DAG runs found for '{dag_id}'.")
            return

        logging.info(f"Found {len(running_runs)} running DAG run(s) to fail.")

        for run in running_runs:
            dag_run_id = run["dag_run_id"]
            update_endpoint = f"{webserver_url}/api/v1/dags/{dag_id}/dagRuns/{dag_run_id}"
            payload = {"state": "failed"}
            try:
                update_response = requests.patch(update_endpoint, auth=(username, password), json=payload, timeout=30)
                update_response.raise_for_status()
                logging.info(f" - Successfully marked DAG run '{dag_run_id}' as failed.")
            except requests.exceptions.RequestException as e:
                logging.error(f" - Failed to mark DAG run '{dag_run_id}' as failed: {e}")

    except requests.exceptions.RequestException as e:
        logging.error(f"Failed to list running DAG runs for '{dag_id}': {e}")
        if e.response is not None:
            logging.error(f"Response: {e.response.text}")


# --- Core Logic Functions ---

def step_0_populate_queue(redis_client, queue_name: str, input_file: str):
    """Reads URLs from a file and populates the Redis inbox queue."""
    logging.info("--- Step 0: Populating Redis Queue ---")
    raw_items = _read_input_file(input_file)
    if not raw_items:
        logging.error("No items found in the input file. Aborting.")
        sys.exit(1)

    valid_urls = []
    for item in raw_items:
        url = _normalize_to_url(item)
        if url and url not in valid_urls:
            valid_urls.append(url)

    if not valid_urls:
        logging.error("No valid YouTube URLs or IDs were found in the input file. Aborting.")
        sys.exit(1)

    inbox_queue = f"{queue_name}_inbox"
    logging.info(f"Adding {len(valid_urls)} unique and valid URLs to Redis queue '{inbox_queue}'...")

    with redis_client.pipeline() as pipe:
        for url in valid_urls:
            pipe.rpush(inbox_queue, url)
        pipe.execute()

    logging.info(f"Successfully populated queue. Total items in '{inbox_queue}': {redis_client.llen(inbox_queue)}")
    return len(valid_urls)
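
# Redis key layout, as used by this script: "<queue_name>_inbox" is a list
# (rpush/llen), while "<queue_name>_progress", "<queue_name>_result" and
# "<queue_name>_fail" are hashes (hlen/hgetall); step_3_generate_report reads
# the fail hash as URL -> JSON error payload.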


def step_1_trigger_orchestrator(args: argparse.Namespace):
    """Triggers the ytdlp_ops_orchestrator DAG using the Airflow REST API."""
    logging.info("--- Step 1: Triggering Orchestrator DAG via REST API ---")

    # Get API details from environment variables
    webserver_url = _get_webserver_url()
    api_endpoint = f"{webserver_url}/api/v1/dags/ytdlp_ops_orchestrator/dagRuns"

    # Default admin user is 'admin'
    username = os.getenv("AIRFLOW_ADMIN_USERNAME", "admin")
    password = os.getenv("AIRFLOW_ADMIN_PASSWORD")

    if not password:
        logging.error("AIRFLOW_ADMIN_PASSWORD not found in environment. Please set it in your .env file.")
        sys.exit(1)

    # Construct the configuration for the DAG run
    conf = {
        "total_workers": args.workers,
        "workers_per_bunch": args.workers_per_bunch,
        "clients": args.client,
    }

    payload = {
        "conf": conf
    }

    logging.info(f"Triggering DAG at endpoint: {api_endpoint}")

    try:
        response = requests.post(
            api_endpoint,
            auth=(username, password),
            json=payload,
            timeout=30  # 30 second timeout
        )
        response.raise_for_status()  # Raises an HTTPError for bad responses (4xx or 5xx)

        logging.info("Successfully triggered the orchestrator DAG.")
        logging.debug(f"Airflow API response:\n{response.json()}")

    except requests.exceptions.RequestException as e:
        logging.error("Failed to trigger the orchestrator DAG via REST API.")
        logging.error(f"Error: {e}")
        if e.response is not None:
            logging.error(f"Response status code: {e.response.status_code}")
            logging.error(f"Response text: {e.response.text}")
        sys.exit(1)
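
# For reference, step_1_trigger_orchestrator issues a request of roughly this
# shape (values are placeholders taken from the CLI flags):
#   POST <webserver_url>/api/v1/dags/ytdlp_ops_orchestrator/dagRuns
#   {"conf": {"total_workers": 4, "workers_per_bunch": 1, "clients": "mweb"}}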


def step_2_monitor_progress(args: argparse.Namespace, redis_client, queue_name: str, total_urls: int, run_time_min: int, interval_min: int, show_status: bool):
    """Monitors the Redis queues for the duration of the test."""
    logging.info("--- Step 2: Monitoring Progress ---")

    end_time = datetime.now() + timedelta(minutes=run_time_min)
    inbox_q = f"{queue_name}_inbox"
    progress_q = f"{queue_name}_progress"
    result_q = f"{queue_name}_result"
    fail_q = f"{queue_name}_fail"

    while datetime.now() < end_time and not INTERRUPTED:
        try:
            inbox_len = redis_client.llen(inbox_q)
            progress_len = redis_client.hlen(progress_q)
            result_len = redis_client.hlen(result_q)
            fail_len = redis_client.hlen(fail_q)

            processed = result_len + fail_len
            success_len = 0
            if result_len > 0:
                # This is inefficient but gives a more accurate success count
                results = redis_client.hgetall(result_q)
                success_len = sum(1 for v in results.values() if '"status": "success"' in v)

            logging.info(
                f"Progress: {processed}/{total_urls} | "
                f"Success: {success_len} | Failed: {fail_len} | "
                f"In Progress: {progress_len} | Inbox: {inbox_len}"
            )
            if show_status:
                # This function connects directly to services to get status
                get_system_status(args, redis_client)
        except Exception as e:
            logging.error(f"Error while querying Redis for progress: {e}")

        # Wait for the interval, but check for interruption every second
        # for a more responsive shutdown.
        wait_until = time.time() + interval_min * 60
        while time.time() < wait_until and not INTERRUPTED:
            # Check if we are past the main end_time
            if datetime.now() >= end_time:
                break
            time.sleep(1)

    if INTERRUPTED:
        logging.info("Monitoring interrupted.")
    else:
        logging.info("Monitoring period has ended.")
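
# Note: the success count above is derived by substring-matching
# '"status": "success"' in the raw "<queue_name>_result" hash values, so result
# entries are assumed to be JSON documents that include a "status" field.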


# --- System Status Functions (Direct Connect) ---

def _list_proxy_statuses(client, server_identity=None):
    """Lists proxy statuses by connecting directly to the Thrift service."""
    logging.info(f"--- Proxy Statuses (Server: {server_identity or 'ALL'}) ---")
    try:
        statuses = client.getProxyStatus(server_identity)
        if not statuses:
            logging.info("No proxy statuses found.")
            return

        status_list = []
        headers = ["Server", "Proxy URL", "Status", "Success", "Failures", "Last Success", "Last Failure"]
        for s in statuses:
            status_list.append({
                "Server": s.serverIdentity, "Proxy URL": s.proxyUrl, "Status": s.status,
                "Success": s.successCount, "Failures": s.failureCount,
                "Last Success": format_timestamp(s.lastSuccessTimestamp),
                "Last Failure": format_timestamp(s.lastFailureTimestamp),
            })
        logging.info("\n" + tabulate(status_list, headers='keys', tablefmt='grid'))
    except (PBServiceException, PBUserException) as e:
        logging.error(f"Failed to get proxy statuses: {e.message}")
    except Exception as e:
        logging.error(f"An unexpected error occurred while getting proxy statuses: {e}", exc_info=True)


def _list_account_statuses(client, redis_client, account_id=None):
    """Lists account statuses from Thrift, enriched with live Redis data."""
    logging.info(f"--- Account Statuses (Account: {account_id or 'ALL'}) ---")
    try:
        statuses = client.getAccountStatus(accountId=account_id, accountPrefix=None)
        if not statuses:
            logging.info("No account statuses found.")
            return

        status_list = []
        for s in statuses:
            status_str = s.status
            if 'RESTING' in status_str:
                try:
                    expiry_ts_raw = redis_client.hget(f"account_status:{s.accountId}", "resting_until")
                    if expiry_ts_raw:
                        expiry_ts = float(expiry_ts_raw)
                        now = datetime.now().timestamp()
                        if now < expiry_ts:
                            remaining_seconds = int(expiry_ts - now)
                            status_str = f"RESTING ({remaining_seconds}s left)"
                except Exception:
                    pass  # Ignore if parsing fails

            last_success = float(s.lastSuccessTimestamp) if s.lastSuccessTimestamp else 0
            last_failure = float(s.lastFailureTimestamp) if s.lastFailureTimestamp else 0
            last_activity = max(last_success, last_failure)

            status_list.append({
                "Account ID": s.accountId, "Status": status_str, "Success": s.successCount,
                "Failures": s.failureCount, "Last Success": format_timestamp(s.lastSuccessTimestamp),
                "Last Failure": format_timestamp(s.lastFailureTimestamp), "Last Proxy": s.lastUsedProxy or "N/A",
                "_last_activity": last_activity,
            })

        status_list.sort(key=lambda item: item.get('_last_activity', 0), reverse=True)
        for item in status_list:
            del item['_last_activity']

        logging.info("\n" + tabulate(status_list, headers='keys', tablefmt='grid'))
    except (PBServiceException, PBUserException) as e:
        logging.error(f"Failed to get account statuses: {e.message}")
    except Exception as e:
        logging.error(f"An unexpected error occurred while getting account statuses: {e}", exc_info=True)


def _list_client_statuses(redis_client):
    """Lists client statistics from Redis."""
    logging.info("--- Client Statuses ---")
    try:
        stats_key = "client_stats"
        all_stats_raw = redis_client.hgetall(stats_key)
        if not all_stats_raw:
            logging.info("No client stats found in Redis.")
            return

        status_list = []
        for client, stats_json in all_stats_raw.items():
            try:
                stats = json.loads(stats_json)

                def format_latest(data):
                    if not data:
                        return "N/A"
                    ts = format_timestamp(data.get('timestamp'))
                    url = data.get('url', 'N/A')
                    video_id_match = re.search(r'v=([a-zA-Z0-9_-]{11})', url)
                    video_id = video_id_match.group(1) if video_id_match else 'N/A'
                    return f"{ts} ({video_id})"

                status_list.append({
                    "Client": client, "Success": stats.get('success_count', 0),
                    "Failures": stats.get('failure_count', 0),
                    "Last Success": format_latest(stats.get('latest_success')),
                    "Last Failure": format_latest(stats.get('latest_failure')),
                })
            except (json.JSONDecodeError, AttributeError):
                status_list.append({"Client": client, "Success": "ERROR", "Failures": "ERROR", "Last Success": "Parse Error", "Last Failure": "Parse Error"})

        status_list.sort(key=lambda item: item.get('Client', ''))
        logging.info("\n" + tabulate(status_list, headers='keys', tablefmt='grid'))
    except Exception as e:
        logging.error(f"An unexpected error occurred while getting client statuses: {e}", exc_info=True)


def get_system_status(args: argparse.Namespace, redis_client):
    """Connects to services and prints status tables."""
    logging.info("--- Getting System Status ---")
    client, transport = None, None
    try:
        client, transport = get_thrift_client(args.management_host, args.management_port)
        _list_proxy_statuses(client)
        _list_account_statuses(client, redis_client)
        _list_client_statuses(redis_client)
    except Exception as e:
        logging.error(f"Could not get system status: {e}")
    finally:
        if transport and transport.isOpen():
            transport.close()


def step_3_generate_report(redis_client, queue_name: str, report_file: str | None):
    """Generates a CSV report of failed items."""
    logging.info("--- Step 3: Generating Report ---")
    fail_q = f"{queue_name}_fail"

    failed_items = redis_client.hgetall(fail_q)
    if not failed_items:
        logging.info("No items found in the fail queue. No report will be generated.")
        return

    logging.info(f"Found {len(failed_items)} failed items. Writing to report...")

    report_data = []
    for url, data_json in failed_items.items():
        try:
            data = json.loads(data_json)
            error_details = data.get('error_details', {})
            normalized_url = _normalize_to_url(url)
            report_data.append({
                'url': url,
                'video_id': normalized_url.split('v=')[-1] if normalized_url else 'N/A',
                'error_message': error_details.get('error_message', 'N/A'),
                'error_code': error_details.get('error_code', 'N/A'),
                'proxy_url': error_details.get('proxy_url', 'N/A'),
                'timestamp': datetime.fromtimestamp(data.get('end_time', 0)).isoformat(),
            })
        except (json.JSONDecodeError, AttributeError):
            report_data.append({'url': url, 'video_id': 'N/A', 'error_message': 'Could not parse error data', 'error_code': 'PARSE_ERROR', 'proxy_url': 'N/A', 'timestamp': 'N/A'})

    if report_file:
        try:
            with open(report_file, 'w', newline='', encoding='utf-8') as f:
                writer = csv.DictWriter(f, fieldnames=report_data[0].keys())
                writer.writeheader()
                writer.writerows(report_data)
            logging.info(f"Successfully wrote report to {report_file}")
        except IOError as e:
            logging.error(f"Could not write report to file {report_file}: {e}")
    else:
        # Print to stdout if no file is specified
        logging.info("--- Failure Report (stdout) ---")
        for item in report_data:
            logging.info(f"URL: {item['url']}, Error: {item['error_code']} - {item['error_message']}")
        logging.info("--- End of Report ---")
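
# Illustrative shape of a "<queue_name>_fail" entry, limited to the fields the
# report above actually reads (other fields may be present):
#   {"end_time": 1700000000.0,
#    "error_details": {"error_message": "...", "error_code": "...", "proxy_url": "..."}}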


def handle_interruption(redis_client, queue_name, report_file):
    """Graceful shutdown logic for when the script is interrupted."""
    logging.warning("--- Interruption Detected: Starting Shutdown Procedure ---")

    # 1. Pause DAGs
    _pause_dag("ytdlp_ops_orchestrator")
    _pause_dag("ytdlp_ops_dispatcher")

    # 2. Fail running per_url jobs
    _fail_running_dag_runs("ytdlp_ops_worker_per_url")

    # 3. Generate report
    logging.info("Generating final report due to interruption...")
    step_3_generate_report(redis_client, queue_name, report_file)
    # Also print to stdout if a file was specified, so the user sees it immediately
    if report_file:
        logging.info("Printing report to stdout as well...")
        step_3_generate_report(redis_client, queue_name, None)


def step_4_cleanup_queues(redis_client, queue_name: str):
    """Cleans up the Redis queues used by the test."""
    logging.info("--- Step 4: Cleaning Up Queues ---")
    queues_to_delete = [
        f"{queue_name}_inbox",
        f"{queue_name}_progress",
        f"{queue_name}_result",
        f"{queue_name}_fail",
    ]
    logging.warning(f"This will delete the following Redis keys: {queues_to_delete}")

    deleted_count = redis_client.delete(*queues_to_delete)
    logging.info(f"Cleanup complete. Deleted {deleted_count} key(s).")


def main():
    """Main function to parse arguments and run the regression test."""
    # Register the signal handler for Ctrl+C
    signal.signal(signal.SIGINT, signal_handler)

    parser = argparse.ArgumentParser(description="Run a regression test for the ytdlp-ops system.")

    # Environment
    parser.add_argument("--redis-host", type=str, default="redis", help="Hostname or IP address of the Redis server. Defaults to 'redis' for in-container execution.")
    parser.add_argument("--management-host", type=str, default=os.getenv("MANAGEMENT_SERVICE_HOST", "envoy-thrift-lb"), help="Hostname of the management Thrift service.")
    parser.add_argument("--management-port", type=int, default=int(os.getenv("MANAGEMENT_SERVICE_PORT", 9080)), help="Port of the management Thrift service.")

    # Test Configuration
    parser.add_argument("--client", type=str, required=True, help="Client persona to test (e.g., 'mweb').")
    parser.add_argument("--workers", type=int, required=True, help="Total number of worker loops to start.")
    parser.add_argument("--workers-per-bunch", type=int, default=1, help="Number of workers per bunch.")
    parser.add_argument("--run-time-min", type=int, required=True, help="How long to let the test run, in minutes.")
    parser.add_argument("--input-file", type=str, help="Path to a file containing video IDs/URLs. If not provided, the existing queue will be used.")

    # Monitoring & Reporting
    parser.add_argument("--progress-interval-min", type=int, default=2, help="How often to query and print progress, in minutes.")
    parser.add_argument("--report-file", type=str, help="Path to a CSV file to write the list of failed URLs to.")
    parser.add_argument("--show-status", action="store_true", help="If set, show proxy and account statuses during progress monitoring.")

    # Actions
    parser.add_argument("--cleanup", action="store_true", help="If set, clear the Redis queues after the test completes.")
    parser.add_argument("--skip-populate", action="store_true", help="If set, skip populating the queue (assumes it's already populated).")
    parser.add_argument("--skip-trigger", action="store_true", help="If set, skip triggering the orchestrator (assumes it's already running).")

    args = parser.parse_args()

    # --- Setup ---
    redis_password = os.getenv("REDIS_PASSWORD")
    if not redis_password:
        logging.error("REDIS_PASSWORD not found in environment. Please set it in your .env file.")
        sys.exit(1)

    # Use the provided redis-host, defaulting to 'redis' for in-container execution
    redis_url = f"redis://:{redis_password}@{args.redis_host}:6379/0"
    redis_client = _get_redis_client(redis_url)

    queue_name = "video_queue"  # Hardcoded for now, could be an arg
    total_urls = 0

    # --- Execution ---
    if not args.skip_populate:
        if args.input_file:
            total_urls = step_0_populate_queue(redis_client, queue_name, args.input_file)
        else:
            logging.info("No input file provided, using existing queue.")
            total_urls = redis_client.llen(f"{queue_name}_inbox")
            if total_urls == 0:
                logging.warning("Queue is empty and no input file was provided. The test may not have any work to do.")
    else:
        total_urls = redis_client.llen(f"{queue_name}_inbox")
        logging.info(f"Skipping population. Found {total_urls} URLs in the inbox.")

    if not args.skip_trigger:
        step_1_trigger_orchestrator(args)
    else:
        logging.info("Skipping orchestrator trigger.")

    step_2_monitor_progress(args, redis_client, queue_name, total_urls, args.run_time_min, args.progress_interval_min, args.show_status)

    if INTERRUPTED:
        handle_interruption(redis_client, queue_name, args.report_file)
    else:
        step_3_generate_report(redis_client, queue_name, args.report_file)

    if args.cleanup:
        step_4_cleanup_queues(redis_client, queue_name)

    if INTERRUPTED:
        logging.warning("Regression test script finished due to user interruption.")
        sys.exit(130)  # Standard exit code for Ctrl+C
    else:
        logging.info("Regression test script finished.")


if __name__ == "__main__":
    main()