# -*- coding: utf-8 -*- """ Regression testing script for the ytdlp-ops system. This script orchestrates a regression test by: 1. Populating a Redis queue with video URLs from an input file. 2. Triggering the `ytdlp_ops_orchestrator` Airflow DAG to start processing. 3. Monitoring the progress of the processing for a specified duration. 4. Generating a report of any failures. 5. Optionally cleaning up the Redis queues after the test. """ import argparse import csv import json import logging import os import re import requests import subprocess import signal import sys import time from datetime import datetime, timedelta from pathlib import Path import redis from tabulate import tabulate # It's safe to import these as the script runs in the same container as Airflow # where the yt_ops_services package is installed. try: from yt_ops_services.client_utils import get_thrift_client, format_timestamp from pangramia.yt.exceptions.ttypes import PBServiceException, PBUserException except ImportError: logging.error("Could not import Thrift modules. Ensure this script is run in the 'airflow-regression-runner' container.") sys.exit(1) # --- Configuration --- logging.basicConfig( level=logging.INFO, format="[%(asctime)s] [%(levelname)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S", ) INTERRUPTED = False def signal_handler(sig, frame): """Handles Ctrl+C interruption.""" global INTERRUPTED if not INTERRUPTED: logging.warning("Ctrl+C detected. Initiating graceful shutdown...") INTERRUPTED = True else: logging.warning("Second Ctrl+C detected. Forcing exit.") sys.exit(1) # --- Helper Functions --- def _get_redis_client(redis_url: str): """Gets a Redis client from a URL.""" try: # from_url is the modern way to connect and handles password auth client = redis.from_url(redis_url, decode_responses=True) client.ping() logging.info(f"Successfully connected to Redis at {client.connection_pool.connection_kwargs.get('host')}:{client.connection_pool.connection_kwargs.get('port')}") return client except redis.exceptions.ConnectionError as e: logging.error(f"Failed to connect to Redis: {e}") sys.exit(1) except Exception as e: logging.error(f"An unexpected error occurred while connecting to Redis: {e}") sys.exit(1) def _get_webserver_url(): """ Determines the Airflow webserver URL, prioritizing MASTER_HOST_IP from .env. """ master_host_ip = os.getenv("MASTER_HOST_IP") if master_host_ip: url = f"http://{master_host_ip}:8080" logging.info(f"Using MASTER_HOST_IP for webserver URL: {url}") return url # Fallback to AIRFLOW_WEBSERVER_URL or the default service name url = os.getenv("AIRFLOW_WEBSERVER_URL", "http://airflow-webserver:8080") logging.info(f"Using default webserver URL: {url}") return url def _normalize_to_url(item: str) -> str | None: """ Validates if an item is a recognizable YouTube URL or video ID, and normalizes it to a standard watch URL format. """ if not item: return None video_id_pattern = r"^[a-zA-Z0-9_-]{11}$" if re.match(video_id_pattern, item): return f"https://www.youtube.com/watch?v={item}" url_patterns = [r"(?:v=|\/v\/|youtu\.be\/|embed\/|shorts\/)([a-zA-Z0-9_-]{11})"] for pattern in url_patterns: match = re.search(pattern, item) if match: return f"https://www.youtube.com/watch?v={match.group(1)}" logging.warning(f"Could not recognize '{item}' as a valid YouTube URL or video ID.") return None def _read_input_file(file_path: str) -> list[str]: """Reads video IDs/URLs from a file (CSV or JSON list).""" path = Path(file_path) if not path.is_file(): logging.error(f"Input file not found: {file_path}") sys.exit(1) content = path.read_text(encoding='utf-8') # Try parsing as JSON list first if content.strip().startswith('['): try: data = json.loads(content) if isinstance(data, list): logging.info(f"Successfully parsed {file_path} as a JSON list.") return [str(item) for item in data] except json.JSONDecodeError: logging.warning("File looks like JSON but failed to parse. Will try treating as CSV/text.") # Fallback to CSV/text (one item per line) items = [] # Use io.StringIO to handle the content as a file for the csv reader from io import StringIO # Sniff to see if it has a header try: has_header = csv.Sniffer().has_header(content) except csv.Error: has_header = False # Not a CSV, treat as plain text reader = csv.reader(StringIO(content)) if has_header: next(reader) # Skip header row for row in reader: if row: items.append(row[0].strip()) # Assume the ID/URL is in the first column logging.info(f"Successfully parsed {len(items)} items from {file_path} as CSV/text.") return items def _get_api_auth(): """Gets Airflow API credentials from environment variables.""" username = os.getenv("AIRFLOW_ADMIN_USERNAME", "admin") password = os.getenv("AIRFLOW_ADMIN_PASSWORD") if not password: logging.error("AIRFLOW_ADMIN_PASSWORD not found in environment. Cannot interact with API.") return None, None return username, password def _pause_dag(dag_id: str, is_paused: bool = True): """Pauses or unpauses an Airflow DAG via the REST API.""" logging.info(f"Attempting to {'pause' if is_paused else 'unpause'} DAG: {dag_id}...") username, password = _get_api_auth() if not username: return webserver_url = _get_webserver_url() endpoint = f"{webserver_url}/api/v1/dags/{dag_id}" payload = {"is_paused": is_paused} try: response = requests.patch(endpoint, auth=(username, password), json=payload, timeout=30) response.raise_for_status() logging.info(f"Successfully {'paused' if is_paused else 'unpaused'} DAG '{dag_id}'.") except requests.exceptions.RequestException as e: logging.error(f"Failed to {'pause' if is_paused else 'unpause'} DAG '{dag_id}': {e}") if e.response is not None: logging.error(f"Response: {e.response.text}") def _fail_running_dag_runs(dag_id: str): """Finds all running DAG runs for a given DAG and marks them as failed.""" logging.info(f"Attempting to fail all running instances of DAG '{dag_id}'...") username, password = _get_api_auth() if not username: return webserver_url = _get_webserver_url() list_endpoint = f"{webserver_url}/api/v1/dags/{dag_id}/dagRuns?state=running" try: # Get running DAGs response = requests.get(list_endpoint, auth=(username, password), timeout=30) response.raise_for_status() running_runs = response.json().get("dag_runs", []) if not running_runs: logging.info(f"No running DAG runs found for '{dag_id}'.") return logging.info(f"Found {len(running_runs)} running DAG run(s) to fail.") for run in running_runs: dag_run_id = run["dag_run_id"] update_endpoint = f"{webserver_url}/api/v1/dags/{dag_id}/dagRuns/{dag_run_id}" payload = {"state": "failed"} try: update_response = requests.patch(update_endpoint, auth=(username, password), json=payload, timeout=30) update_response.raise_for_status() logging.info(f" - Successfully marked DAG run '{dag_run_id}' as failed.") except requests.exceptions.RequestException as e: logging.error(f" - Failed to mark DAG run '{dag_run_id}' as failed: {e}") except requests.exceptions.RequestException as e: logging.error(f"Failed to list running DAG runs for '{dag_id}': {e}") if e.response is not None: logging.error(f"Response: {e.response.text}") # --- Core Logic Functions --- def step_0_populate_queue(redis_client, queue_name: str, input_file: str): """Reads URLs from a file and populates the Redis inbox queue.""" logging.info("--- Step 0: Populating Redis Queue ---") raw_items = _read_input_file(input_file) if not raw_items: logging.error("No items found in the input file. Aborting.") sys.exit(1) valid_urls = [] for item in raw_items: url = _normalize_to_url(item) if url and url not in valid_urls: valid_urls.append(url) if not valid_urls: logging.error("No valid YouTube URLs or IDs were found in the input file. Aborting.") sys.exit(1) inbox_queue = f"{queue_name}_inbox" logging.info(f"Adding {len(valid_urls)} unique and valid URLs to Redis queue '{inbox_queue}'...") with redis_client.pipeline() as pipe: for url in valid_urls: pipe.rpush(inbox_queue, url) pipe.execute() logging.info(f"Successfully populated queue. Total items in '{inbox_queue}': {redis_client.llen(inbox_queue)}") return len(valid_urls) def step_1_trigger_orchestrator(args: argparse.Namespace): """Triggers the ytdlp_ops_orchestrator DAG using the Airflow REST API.""" logging.info("--- Step 1: Triggering Orchestrator DAG via REST API ---") # Get API details from environment variables webserver_url = _get_webserver_url() api_endpoint = f"{webserver_url}/api/v1/dags/ytdlp_ops_orchestrator/dagRuns" # Default admin user is 'admin' username = os.getenv("AIRFLOW_ADMIN_USERNAME", "admin") password = os.getenv("AIRFLOW_ADMIN_PASSWORD") if not password: logging.error("AIRFLOW_ADMIN_PASSWORD not found in environment. Please set it in your .env file.") sys.exit(1) # Construct the configuration for the DAG run conf = { "total_workers": args.workers, "workers_per_bunch": args.workers_per_bunch, "clients": args.client, } payload = { "conf": conf } logging.info(f"Triggering DAG at endpoint: {api_endpoint}") try: response = requests.post( api_endpoint, auth=(username, password), json=payload, timeout=30 # 30 second timeout ) response.raise_for_status() # Raises an HTTPError for bad responses (4xx or 5xx) logging.info("Successfully triggered the orchestrator DAG.") logging.debug(f"Airflow API response:\n{response.json()}") except requests.exceptions.RequestException as e: logging.error("Failed to trigger the orchestrator DAG via REST API.") logging.error(f"Error: {e}") if e.response is not None: logging.error(f"Response status code: {e.response.status_code}") logging.error(f"Response text: {e.response.text}") sys.exit(1) def step_2_monitor_progress(args: argparse.Namespace, redis_client, queue_name: str, total_urls: int, run_time_min: int, interval_min: int, show_status: bool): """Monitors the Redis queues for the duration of the test.""" logging.info("--- Step 2: Monitoring Progress ---") end_time = datetime.now() + timedelta(minutes=run_time_min) inbox_q = f"{queue_name}_inbox" progress_q = f"{queue_name}_progress" result_q = f"{queue_name}_result" fail_q = f"{queue_name}_fail" while datetime.now() < end_time and not INTERRUPTED: try: inbox_len = redis_client.llen(inbox_q) progress_len = redis_client.hlen(progress_q) result_len = redis_client.hlen(result_q) fail_len = redis_client.hlen(fail_q) processed = result_len + fail_len success_len = 0 if result_len > 0: # This is inefficient but gives a more accurate success count results = redis_client.hgetall(result_q) success_len = sum(1 for v in results.values() if '"status": "success"' in v) logging.info( f"Progress: {processed}/{total_urls} | " f"Success: {success_len} | Failed: {fail_len} | " f"In Progress: {progress_len} | Inbox: {inbox_len}" ) if show_status: # This function now connects directly to services to get status get_system_status(args, redis_client) except Exception as e: logging.error(f"Error while querying Redis for progress: {e}") # Wait for the interval, but check for interruption every second # for a more responsive shutdown. wait_until = time.time() + interval_min * 60 while time.time() < wait_until and not INTERRUPTED: # Check if we are past the main end_time if datetime.now() >= end_time: break time.sleep(1) if INTERRUPTED: logging.info("Monitoring interrupted.") else: logging.info("Monitoring period has ended.") # --- System Status Functions (Direct Connect) --- def _list_proxy_statuses(client, server_identity=None): """Lists proxy statuses by connecting directly to the Thrift service.""" logging.info(f"--- Proxy Statuses (Server: {server_identity or 'ALL'}) ---") try: statuses = client.getProxyStatus(server_identity) if not statuses: logging.info("No proxy statuses found.") return status_list = [] headers = ["Server", "Proxy URL", "Status", "Success", "Failures", "Last Success", "Last Failure"] for s in statuses: status_list.append({ "Server": s.serverIdentity, "Proxy URL": s.proxyUrl, "Status": s.status, "Success": s.successCount, "Failures": s.failureCount, "Last Success": format_timestamp(s.lastSuccessTimestamp), "Last Failure": format_timestamp(s.lastFailureTimestamp), }) logging.info("\n" + tabulate(status_list, headers='keys', tablefmt='grid')) except (PBServiceException, PBUserException) as e: logging.error(f"Failed to get proxy statuses: {e.message}") except Exception as e: logging.error(f"An unexpected error occurred while getting proxy statuses: {e}", exc_info=True) def _list_account_statuses(client, redis_client, account_id=None): """Lists account statuses from Thrift, enriched with live Redis data.""" logging.info(f"--- Account Statuses (Account: {account_id or 'ALL'}) ---") try: statuses = client.getAccountStatus(accountId=account_id, accountPrefix=None) if not statuses: logging.info("No account statuses found.") return status_list = [] for s in statuses: status_str = s.status if 'RESTING' in status_str: try: expiry_ts_bytes = redis_client.hget(f"account_status:{s.accountId}", "resting_until") if expiry_ts_bytes: expiry_ts = float(expiry_ts_bytes) now = datetime.now().timestamp() if now < expiry_ts: remaining_seconds = int(expiry_ts - now) status_str = f"RESTING ({remaining_seconds}s left)" except Exception: pass # Ignore if parsing fails last_success = float(s.lastSuccessTimestamp) if s.lastSuccessTimestamp else 0 last_failure = float(s.lastFailureTimestamp) if s.lastFailureTimestamp else 0 last_activity = max(last_success, last_failure) status_list.append({ "Account ID": s.accountId, "Status": status_str, "Success": s.successCount, "Failures": s.failureCount, "Last Success": format_timestamp(s.lastSuccessTimestamp), "Last Failure": format_timestamp(s.lastFailureTimestamp), "Last Proxy": s.lastUsedProxy or "N/A", "_last_activity": last_activity, }) status_list.sort(key=lambda item: item.get('_last_activity', 0), reverse=True) for item in status_list: del item['_last_activity'] logging.info("\n" + tabulate(status_list, headers='keys', tablefmt='grid')) except (PBServiceException, PBUserException) as e: logging.error(f"Failed to get account statuses: {e.message}") except Exception as e: logging.error(f"An unexpected error occurred while getting account statuses: {e}", exc_info=True) def _list_client_statuses(redis_client): """Lists client statistics from Redis.""" logging.info("--- Client Statuses ---") try: stats_key = "client_stats" all_stats_raw = redis_client.hgetall(stats_key) if not all_stats_raw: logging.info("No client stats found in Redis.") return status_list = [] for client, stats_json in all_stats_raw.items(): try: stats = json.loads(stats_json) def format_latest(data): if not data: return "N/A" ts = format_timestamp(data.get('timestamp')) url = data.get('url', 'N/A') video_id_match = re.search(r'v=([a-zA-Z0-9_-]{11})', url) video_id = video_id_match.group(1) if video_id_match else 'N/A' return f"{ts} ({video_id})" status_list.append({ "Client": client, "Success": stats.get('success_count', 0), "Failures": stats.get('failure_count', 0), "Last Success": format_latest(stats.get('latest_success')), "Last Failure": format_latest(stats.get('latest_failure')), }) except (json.JSONDecodeError, AttributeError): status_list.append({"Client": client, "Success": "ERROR", "Failures": "ERROR", "Last Success": "Parse Error", "Last Failure": "Parse Error"}) status_list.sort(key=lambda item: item.get('Client', '')) logging.info("\n" + tabulate(status_list, headers='keys', tablefmt='grid')) except Exception as e: logging.error(f"An unexpected error occurred while getting client statuses: {e}", exc_info=True) def get_system_status(args: argparse.Namespace, redis_client): """Connects to services and prints status tables.""" logging.info("--- Getting System Status ---") client, transport = None, None try: client, transport = get_thrift_client(args.management_host, args.management_port) _list_proxy_statuses(client) _list_account_statuses(client, redis_client) _list_client_statuses(redis_client) except Exception as e: logging.error(f"Could not get system status: {e}") finally: if transport and transport.isOpen(): transport.close() def step_3_generate_report(redis_client, queue_name: str, report_file: str | None): """Generates a CSV report of failed items.""" logging.info("--- Step 3: Generating Report ---") fail_q = f"{queue_name}_fail" failed_items = redis_client.hgetall(fail_q) if not failed_items: logging.info("No items found in the fail queue. No report will be generated.") return logging.info(f"Found {len(failed_items)} failed items. Writing to report...") report_data = [] for url, data_json in failed_items.items(): try: data = json.loads(data_json) error_details = data.get('error_details', {}) report_data.append({ 'url': url, 'video_id': _normalize_to_url(url).split('v=')[-1] if _normalize_to_url(url) else 'N/A', 'error_message': error_details.get('error_message', 'N/A'), 'error_code': error_details.get('error_code', 'N/A'), 'proxy_url': error_details.get('proxy_url', 'N/A'), 'timestamp': datetime.fromtimestamp(data.get('end_time', 0)).isoformat(), }) except (json.JSONDecodeError, AttributeError): report_data.append({'url': url, 'video_id': 'N/A', 'error_message': 'Could not parse error data', 'error_code': 'PARSE_ERROR', 'proxy_url': 'N/A', 'timestamp': 'N/A'}) if report_file: try: with open(report_file, 'w', newline='', encoding='utf-8') as f: writer = csv.DictWriter(f, fieldnames=report_data[0].keys()) writer.writeheader() writer.writerows(report_data) logging.info(f"Successfully wrote report to {report_file}") except IOError as e: logging.error(f"Could not write report to file {report_file}: {e}") else: # Print to stdout if no file is specified logging.info("--- Failure Report (stdout) ---") for item in report_data: logging.info(f"URL: {item['url']}, Error: {item['error_code']} - {item['error_message']}") logging.info("--- End of Report ---") def handle_interruption(redis_client, queue_name, report_file): """Graceful shutdown logic for when the script is interrupted.""" logging.warning("--- Interruption Detected: Starting Shutdown Procedure ---") # 1. Pause DAGs _pause_dag("ytdlp_ops_orchestrator") _pause_dag("ytdlp_ops_dispatcher") # 2. Fail running per_url jobs _fail_running_dag_runs("ytdlp_ops_worker_per_url") # 3. Generate report logging.info("Generating final report due to interruption...") step_3_generate_report(redis_client, queue_name, report_file) # Also print to stdout if a file was specified, so user sees it immediately if report_file: logging.info("Printing report to stdout as well...") step_3_generate_report(redis_client, queue_name, None) def step_4_cleanup_queues(redis_client, queue_name: str): """Cleans up the Redis queues used by the test.""" logging.info("--- Step 4: Cleaning Up Queues ---") queues_to_delete = [ f"{queue_name}_inbox", f"{queue_name}_progress", f"{queue_name}_result", f"{queue_name}_fail", ] logging.warning(f"This will delete the following Redis keys: {queues_to_delete}") deleted_count = redis_client.delete(*queues_to_delete) logging.info(f"Cleanup complete. Deleted {deleted_count} key(s).") def main(): """Main function to parse arguments and run the regression test.""" # Register the signal handler for Ctrl+C signal.signal(signal.SIGINT, signal_handler) parser = argparse.ArgumentParser(description="Run a regression test for the ytdlp-ops system.") # Environment parser.add_argument("--redis-host", type=str, default="redis", help="Hostname or IP address of the Redis server. Defaults to 'redis' for in-container execution.") parser.add_argument("--management-host", type=str, default=os.getenv("MANAGEMENT_SERVICE_HOST", "envoy-thrift-lb"), help="Hostname of the management Thrift service.") parser.add_argument("--management-port", type=int, default=int(os.getenv("MANAGEMENT_SERVICE_PORT", 9080)), help="Port of the management Thrift service.") # Test Configuration parser.add_argument("--client", type=str, required=True, help="Client persona to test (e.g., 'mweb').") parser.add_argument("--workers", type=int, required=True, help="Total number of worker loops to start.") parser.add_argument("--workers-per-bunch", type=int, default=1, help="Number of workers per bunch.") parser.add_argument("--run-time-min", type=int, required=True, help="How long to let the test run, in minutes.") parser.add_argument("--input-file", type=str, help="Path to a file containing video IDs/URLs. If not provided, the existing queue will be used.") # Monitoring & Reporting parser.add_argument("--progress-interval-min", type=int, default=2, help="How often to query and print progress, in minutes.") parser.add_argument("--report-file", type=str, help="Path to a CSV file to write the list of failed URLs to.") parser.add_argument("--show-status", action="store_true", help="If set, show proxy and account statuses during progress monitoring.") # Actions parser.add_argument("--cleanup", action="store_true", help="If set, clear the Redis queues after the test completes.") parser.add_argument("--skip-populate", action="store_true", help="If set, skip populating the queue (assumes it's already populated).") parser.add_argument("--skip-trigger", action="store_true", help="If set, skip triggering the orchestrator (assumes it's already running).") args = parser.parse_args() # --- Setup --- redis_password = os.getenv("REDIS_PASSWORD") if not redis_password: logging.error("REDIS_PASSWORD not found in environment. Please set it in your .env file.") sys.exit(1) # Use the provided redis-host, defaulting to 'redis' for in-container execution redis_url = f"redis://:{redis_password}@{args.redis_host}:6379/0" redis_client = _get_redis_client(redis_url) queue_name = "video_queue" # Hardcoded for now, could be an arg total_urls = 0 # --- Execution --- if not args.skip_populate: if args.input_file: total_urls = step_0_populate_queue(redis_client, queue_name, args.input_file) else: logging.info("No input file provided, using existing queue.") total_urls = redis_client.llen(f"{queue_name}_inbox") if total_urls == 0: logging.warning("Queue is empty and no input file was provided. The test may not have any work to do.") else: total_urls = redis_client.llen(f"{queue_name}_inbox") logging.info(f"Skipping population. Found {total_urls} URLs in the inbox.") if not args.skip_trigger: step_1_trigger_orchestrator(args) else: logging.info("Skipping orchestrator trigger.") step_2_monitor_progress(args, redis_client, queue_name, total_urls, args.run_time_min, args.progress_interval_min, args.show_status) if INTERRUPTED: handle_interruption(redis_client, queue_name, args.report_file) else: step_3_generate_report(redis_client, queue_name, args.report_file) if args.cleanup: step_4_cleanup_queues(redis_client, queue_name) if INTERRUPTED: logging.warning("Regression test script finished due to user interruption.") sys.exit(130) # Standard exit code for Ctrl+C else: logging.info("Regression test script finished.") if __name__ == "__main__": main()