# -*- coding: utf-8 -*-
"""
Regression testing script for the ytdlp-ops system.
This script orchestrates a regression test by:
1. Populating a Redis queue with video URLs from an input file.
2. Triggering the `ytdlp_ops_orchestrator` Airflow DAG to start processing.
3. Monitoring the progress of the processing for a specified duration.
4. Generating a report of any failures.
5. Optionally cleaning up the Redis queues after the test.
"""
import argparse
import csv
import json
import logging
import os
import re
import requests
import subprocess
import signal
import sys
import time
from datetime import datetime, timedelta
from pathlib import Path
import redis
from tabulate import tabulate
# It's safe to import these as the script runs in the same container as Airflow
# where the yt_ops_services package is installed.
try:
from yt_ops_services.client_utils import get_thrift_client, format_timestamp
from pangramia.yt.exceptions.ttypes import PBServiceException, PBUserException
except ImportError:
logging.error("Could not import Thrift modules. Ensure this script is run in the 'airflow-regression-runner' container.")
sys.exit(1)
# --- Configuration ---
logging.basicConfig(
level=logging.INFO,
format="[%(asctime)s] [%(levelname)s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
INTERRUPTED = False
def signal_handler(sig, frame):
"""Handles Ctrl+C interruption."""
global INTERRUPTED
if not INTERRUPTED:
logging.warning("Ctrl+C detected. Initiating graceful shutdown...")
INTERRUPTED = True
else:
logging.warning("Second Ctrl+C detected. Forcing exit.")
sys.exit(1)
# --- Helper Functions ---
def _get_redis_client(redis_url: str):
"""Gets a Redis client from a URL."""
try:
# from_url is the modern way to connect and handles password auth
client = redis.from_url(redis_url, decode_responses=True)
client.ping()
logging.info(f"Successfully connected to Redis at {client.connection_pool.connection_kwargs.get('host')}:{client.connection_pool.connection_kwargs.get('port')}")
return client
except redis.exceptions.ConnectionError as e:
logging.error(f"Failed to connect to Redis: {e}")
sys.exit(1)
except Exception as e:
logging.error(f"An unexpected error occurred while connecting to Redis: {e}")
sys.exit(1)
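# Note: main() assembles the connection URL as redis://:<REDIS_PASSWORD>@<redis-host>:6379/0,
# so password authentication is handled transparently by redis.from_url() above.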
def _get_webserver_url():
"""
Determines the Airflow webserver URL, prioritizing MASTER_HOST_IP from .env.
"""
master_host_ip = os.getenv("MASTER_HOST_IP")
if master_host_ip:
url = f"http://{master_host_ip}:8080"
logging.info(f"Using MASTER_HOST_IP for webserver URL: {url}")
return url
# Fallback to AIRFLOW_WEBSERVER_URL or the default service name
url = os.getenv("AIRFLOW_WEBSERVER_URL", "http://airflow-webserver:8080")
logging.info(f"Using default webserver URL: {url}")
return url
def _normalize_to_url(item: str) -> str | None:
"""
Validates if an item is a recognizable YouTube URL or video ID,
and normalizes it to a standard watch URL format.
"""
if not item:
return None
video_id_pattern = r"^[a-zA-Z0-9_-]{11}$"
if re.match(video_id_pattern, item):
return f"https://www.youtube.com/watch?v={item}"
url_patterns = [r"(?:v=|\/v\/|youtu\.be\/|embed\/|shorts\/)([a-zA-Z0-9_-]{11})"]
for pattern in url_patterns:
match = re.search(pattern, item)
if match:
return f"https://www.youtube.com/watch?v={match.group(1)}"
logging.warning(f"Could not recognize '{item}' as a valid YouTube URL or video ID.")
return None
def _read_input_file(file_path: str) -> list[str]:
"""Reads video IDs/URLs from a file (CSV or JSON list)."""
path = Path(file_path)
if not path.is_file():
logging.error(f"Input file not found: {file_path}")
sys.exit(1)
content = path.read_text(encoding='utf-8')
# Try parsing as JSON list first
if content.strip().startswith('['):
try:
data = json.loads(content)
if isinstance(data, list):
logging.info(f"Successfully parsed {file_path} as a JSON list.")
return [str(item) for item in data]
except json.JSONDecodeError:
logging.warning("File looks like JSON but failed to parse. Will try treating as CSV/text.")
# Fallback to CSV/text (one item per line)
items = []
# Use io.StringIO to handle the content as a file for the csv reader
from io import StringIO
# Sniff to see if it has a header
try:
has_header = csv.Sniffer().has_header(content)
except csv.Error:
has_header = False # Not a CSV, treat as plain text
reader = csv.reader(StringIO(content))
if has_header:
next(reader) # Skip header row
for row in reader:
if row:
items.append(row[0].strip()) # Assume the ID/URL is in the first column
logging.info(f"Successfully parsed {len(items)} items from {file_path} as CSV/text.")
return items
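# Example input files (illustrative):
#   JSON list: ["dQw4w9WgXcQ", "https://youtu.be/dQw4w9WgXcQ"]
#   CSV/text:  one URL or 11-character video ID per line; only the first column is
#              read, and a detected header row is skipped.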
def _get_api_auth():
"""Gets Airflow API credentials from environment variables."""
username = os.getenv("AIRFLOW_ADMIN_USERNAME", "admin")
password = os.getenv("AIRFLOW_ADMIN_PASSWORD")
if not password:
logging.error("AIRFLOW_ADMIN_PASSWORD not found in environment. Cannot interact with API.")
return None, None
return username, password
def _pause_dag(dag_id: str, is_paused: bool = True):
"""Pauses or unpauses an Airflow DAG via the REST API."""
logging.info(f"Attempting to {'pause' if is_paused else 'unpause'} DAG: {dag_id}...")
username, password = _get_api_auth()
if not username:
return
webserver_url = _get_webserver_url()
endpoint = f"{webserver_url}/api/v1/dags/{dag_id}"
payload = {"is_paused": is_paused}
try:
response = requests.patch(endpoint, auth=(username, password), json=payload, timeout=30)
response.raise_for_status()
logging.info(f"Successfully {'paused' if is_paused else 'unpaused'} DAG '{dag_id}'.")
except requests.exceptions.RequestException as e:
logging.error(f"Failed to {'pause' if is_paused else 'unpause'} DAG '{dag_id}': {e}")
if e.response is not None:
logging.error(f"Response: {e.response.text}")
def _fail_running_dag_runs(dag_id: str):
"""Finds all running DAG runs for a given DAG and marks them as failed."""
logging.info(f"Attempting to fail all running instances of DAG '{dag_id}'...")
username, password = _get_api_auth()
if not username:
return
webserver_url = _get_webserver_url()
list_endpoint = f"{webserver_url}/api/v1/dags/{dag_id}/dagRuns?state=running"
try:
# Get running DAGs
response = requests.get(list_endpoint, auth=(username, password), timeout=30)
response.raise_for_status()
running_runs = response.json().get("dag_runs", [])
if not running_runs:
logging.info(f"No running DAG runs found for '{dag_id}'.")
return
logging.info(f"Found {len(running_runs)} running DAG run(s) to fail.")
for run in running_runs:
dag_run_id = run["dag_run_id"]
update_endpoint = f"{webserver_url}/api/v1/dags/{dag_id}/dagRuns/{dag_run_id}"
payload = {"state": "failed"}
try:
update_response = requests.patch(update_endpoint, auth=(username, password), json=payload, timeout=30)
update_response.raise_for_status()
logging.info(f" - Successfully marked DAG run '{dag_run_id}' as failed.")
except requests.exceptions.RequestException as e:
logging.error(f" - Failed to mark DAG run '{dag_run_id}' as failed: {e}")
except requests.exceptions.RequestException as e:
logging.error(f"Failed to list running DAG runs for '{dag_id}': {e}")
if e.response is not None:
logging.error(f"Response: {e.response.text}")
# --- Core Logic Functions ---
def step_0_populate_queue(redis_client, queue_name: str, input_file: str):
"""Reads URLs from a file and populates the Redis inbox queue."""
logging.info("--- Step 0: Populating Redis Queue ---")
raw_items = _read_input_file(input_file)
if not raw_items:
logging.error("No items found in the input file. Aborting.")
sys.exit(1)
valid_urls = []
for item in raw_items:
url = _normalize_to_url(item)
if url and url not in valid_urls:
valid_urls.append(url)
if not valid_urls:
logging.error("No valid YouTube URLs or IDs were found in the input file. Aborting.")
sys.exit(1)
inbox_queue = f"{queue_name}_inbox"
logging.info(f"Adding {len(valid_urls)} unique and valid URLs to Redis queue '{inbox_queue}'...")
with redis_client.pipeline() as pipe:
for url in valid_urls:
pipe.rpush(inbox_queue, url)
pipe.execute()
logging.info(f"Successfully populated queue. Total items in '{inbox_queue}': {redis_client.llen(inbox_queue)}")
return len(valid_urls)
def step_1_trigger_orchestrator(args: argparse.Namespace):
"""Triggers the ytdlp_ops_orchestrator DAG using the Airflow REST API."""
logging.info("--- Step 1: Triggering Orchestrator DAG via REST API ---")
# Get API details from environment variables
webserver_url = _get_webserver_url()
api_endpoint = f"{webserver_url}/api/v1/dags/ytdlp_ops_orchestrator/dagRuns"
# Default admin user is 'admin'
username = os.getenv("AIRFLOW_ADMIN_USERNAME", "admin")
password = os.getenv("AIRFLOW_ADMIN_PASSWORD")
if not password:
logging.error("AIRFLOW_ADMIN_PASSWORD not found in environment. Please set it in your .env file.")
sys.exit(1)
# Construct the configuration for the DAG run
conf = {
"total_workers": args.workers,
"workers_per_bunch": args.workers_per_bunch,
"clients": args.client,
}
payload = {
"conf": conf
}
logging.info(f"Triggering DAG at endpoint: {api_endpoint}")
try:
response = requests.post(
api_endpoint,
auth=(username, password),
json=payload,
timeout=30 # 30 second timeout
)
response.raise_for_status() # Raises an HTTPError for bad responses (4xx or 5xx)
logging.info("Successfully triggered the orchestrator DAG.")
logging.debug(f"Airflow API response:\n{response.json()}")
except requests.exceptions.RequestException as e:
logging.error("Failed to trigger the orchestrator DAG via REST API.")
logging.error(f"Error: {e}")
if e.response is not None:
logging.error(f"Response status code: {e.response.status_code}")
logging.error(f"Response text: {e.response.text}")
sys.exit(1)
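# For reference, the equivalent manual trigger through Airflow's stable REST API
# (hypothetical host and credentials):
#   curl -X POST "http://<webserver-host>:8080/api/v1/dags/ytdlp_ops_orchestrator/dagRuns" \
#        -u "admin:<password>" -H "Content-Type: application/json" \
#        -d '{"conf": {"total_workers": 4, "workers_per_bunch": 1, "clients": "mweb"}}'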
def step_2_monitor_progress(args: argparse.Namespace, redis_client, queue_name: str, total_urls: int, run_time_min: int, interval_min: int, show_status: bool):
"""Monitors the Redis queues for the duration of the test."""
logging.info("--- Step 2: Monitoring Progress ---")
end_time = datetime.now() + timedelta(minutes=run_time_min)
inbox_q = f"{queue_name}_inbox"
progress_q = f"{queue_name}_progress"
result_q = f"{queue_name}_result"
fail_q = f"{queue_name}_fail"
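    # Queue naming convention used throughout this script: <queue>_inbox is a Redis
    # list of pending URLs (llen), while <queue>_progress, <queue>_result and
    # <queue>_fail are hashes (hlen/hgetall), with the fail hash keyed by URL.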
while datetime.now() < end_time and not INTERRUPTED:
try:
inbox_len = redis_client.llen(inbox_q)
progress_len = redis_client.hlen(progress_q)
result_len = redis_client.hlen(result_q)
fail_len = redis_client.hlen(fail_q)
processed = result_len + fail_len
            success_len = 0
            if result_len > 0:
                # Parsing each result is inefficient but counts successes accurately,
                # regardless of how the result JSON was serialized.
                results = redis_client.hgetall(result_q)
                for value in results.values():
                    try:
                        success_len += 1 if json.loads(value).get("status") == "success" else 0
                    except (json.JSONDecodeError, AttributeError):
                        pass
logging.info(
f"Progress: {processed}/{total_urls} | "
f"Success: {success_len} | Failed: {fail_len} | "
f"In Progress: {progress_len} | Inbox: {inbox_len}"
)
if show_status:
# This function now connects directly to services to get status
get_system_status(args, redis_client)
except Exception as e:
logging.error(f"Error while querying Redis for progress: {e}")
# Wait for the interval, but check for interruption every second
# for a more responsive shutdown.
wait_until = time.time() + interval_min * 60
while time.time() < wait_until and not INTERRUPTED:
# Check if we are past the main end_time
if datetime.now() >= end_time:
break
time.sleep(1)
if INTERRUPTED:
logging.info("Monitoring interrupted.")
else:
logging.info("Monitoring period has ended.")
# --- System Status Functions (Direct Connect) ---
def _list_proxy_statuses(client, server_identity=None):
"""Lists proxy statuses by connecting directly to the Thrift service."""
logging.info(f"--- Proxy Statuses (Server: {server_identity or 'ALL'}) ---")
try:
statuses = client.getProxyStatus(server_identity)
if not statuses:
logging.info("No proxy statuses found.")
return
status_list = []
headers = ["Server", "Proxy URL", "Status", "Success", "Failures", "Last Success", "Last Failure"]
for s in statuses:
status_list.append({
"Server": s.serverIdentity, "Proxy URL": s.proxyUrl, "Status": s.status,
"Success": s.successCount, "Failures": s.failureCount,
"Last Success": format_timestamp(s.lastSuccessTimestamp),
"Last Failure": format_timestamp(s.lastFailureTimestamp),
})
logging.info("\n" + tabulate(status_list, headers='keys', tablefmt='grid'))
except (PBServiceException, PBUserException) as e:
logging.error(f"Failed to get proxy statuses: {e.message}")
except Exception as e:
logging.error(f"An unexpected error occurred while getting proxy statuses: {e}", exc_info=True)
def _list_account_statuses(client, redis_client, account_id=None):
"""Lists account statuses from Thrift, enriched with live Redis data."""
logging.info(f"--- Account Statuses (Account: {account_id or 'ALL'}) ---")
try:
statuses = client.getAccountStatus(accountId=account_id, accountPrefix=None)
if not statuses:
logging.info("No account statuses found.")
return
status_list = []
for s in statuses:
status_str = s.status
if 'RESTING' in status_str:
try:
                    # decode_responses=True means hget returns a str (or None), not bytes.
                    expiry_ts_raw = redis_client.hget(f"account_status:{s.accountId}", "resting_until")
                    if expiry_ts_raw:
                        expiry_ts = float(expiry_ts_raw)
now = datetime.now().timestamp()
if now < expiry_ts:
remaining_seconds = int(expiry_ts - now)
status_str = f"RESTING ({remaining_seconds}s left)"
except Exception:
pass # Ignore if parsing fails
last_success = float(s.lastSuccessTimestamp) if s.lastSuccessTimestamp else 0
last_failure = float(s.lastFailureTimestamp) if s.lastFailureTimestamp else 0
last_activity = max(last_success, last_failure)
status_list.append({
"Account ID": s.accountId, "Status": status_str, "Success": s.successCount,
"Failures": s.failureCount, "Last Success": format_timestamp(s.lastSuccessTimestamp),
"Last Failure": format_timestamp(s.lastFailureTimestamp), "Last Proxy": s.lastUsedProxy or "N/A",
"_last_activity": last_activity,
})
status_list.sort(key=lambda item: item.get('_last_activity', 0), reverse=True)
for item in status_list:
del item['_last_activity']
logging.info("\n" + tabulate(status_list, headers='keys', tablefmt='grid'))
except (PBServiceException, PBUserException) as e:
logging.error(f"Failed to get account statuses: {e.message}")
except Exception as e:
logging.error(f"An unexpected error occurred while getting account statuses: {e}", exc_info=True)
def _list_client_statuses(redis_client):
"""Lists client statistics from Redis."""
logging.info("--- Client Statuses ---")
try:
stats_key = "client_stats"
all_stats_raw = redis_client.hgetall(stats_key)
if not all_stats_raw:
logging.info("No client stats found in Redis.")
return
        def format_latest(data):
            """Renders a latest success/failure record as '<timestamp> (<video_id>)'."""
            if not data:
                return "N/A"
            ts = format_timestamp(data.get('timestamp'))
            url = data.get('url', 'N/A')
            video_id_match = re.search(r'v=([a-zA-Z0-9_-]{11})', url)
            video_id = video_id_match.group(1) if video_id_match else 'N/A'
            return f"{ts} ({video_id})"
        status_list = []
        for client, stats_json in all_stats_raw.items():
            try:
                stats = json.loads(stats_json)
status_list.append({
"Client": client, "Success": stats.get('success_count', 0),
"Failures": stats.get('failure_count', 0),
"Last Success": format_latest(stats.get('latest_success')),
"Last Failure": format_latest(stats.get('latest_failure')),
})
except (json.JSONDecodeError, AttributeError):
status_list.append({"Client": client, "Success": "ERROR", "Failures": "ERROR", "Last Success": "Parse Error", "Last Failure": "Parse Error"})
status_list.sort(key=lambda item: item.get('Client', ''))
logging.info("\n" + tabulate(status_list, headers='keys', tablefmt='grid'))
except Exception as e:
logging.error(f"An unexpected error occurred while getting client statuses: {e}", exc_info=True)
def get_system_status(args: argparse.Namespace, redis_client):
"""Connects to services and prints status tables."""
logging.info("--- Getting System Status ---")
client, transport = None, None
try:
client, transport = get_thrift_client(args.management_host, args.management_port)
_list_proxy_statuses(client)
_list_account_statuses(client, redis_client)
_list_client_statuses(redis_client)
except Exception as e:
logging.error(f"Could not get system status: {e}")
finally:
if transport and transport.isOpen():
transport.close()
def step_3_generate_report(redis_client, queue_name: str, report_file: str | None):
"""Generates a CSV report of failed items."""
logging.info("--- Step 3: Generating Report ---")
fail_q = f"{queue_name}_fail"
failed_items = redis_client.hgetall(fail_q)
if not failed_items:
logging.info("No items found in the fail queue. No report will be generated.")
return
logging.info(f"Found {len(failed_items)} failed items. Writing to report...")
report_data = []
    for url, data_json in failed_items.items():
        try:
            data = json.loads(data_json)
            error_details = data.get('error_details', {})
            # Normalize once to avoid a duplicate call (and duplicate warning log) per URL.
            normalized_url = _normalize_to_url(url)
            report_data.append({
                'url': url,
                'video_id': normalized_url.split('v=')[-1] if normalized_url else 'N/A',
'error_message': error_details.get('error_message', 'N/A'),
'error_code': error_details.get('error_code', 'N/A'),
'proxy_url': error_details.get('proxy_url', 'N/A'),
'timestamp': datetime.fromtimestamp(data.get('end_time', 0)).isoformat(),
})
except (json.JSONDecodeError, AttributeError):
report_data.append({'url': url, 'video_id': 'N/A', 'error_message': 'Could not parse error data', 'error_code': 'PARSE_ERROR', 'proxy_url': 'N/A', 'timestamp': 'N/A'})
if report_file:
try:
with open(report_file, 'w', newline='', encoding='utf-8') as f:
writer = csv.DictWriter(f, fieldnames=report_data[0].keys())
writer.writeheader()
writer.writerows(report_data)
logging.info(f"Successfully wrote report to {report_file}")
except IOError as e:
logging.error(f"Could not write report to file {report_file}: {e}")
else:
# Print to stdout if no file is specified
logging.info("--- Failure Report (stdout) ---")
for item in report_data:
logging.info(f"URL: {item['url']}, Error: {item['error_code']} - {item['error_message']}")
logging.info("--- End of Report ---")
def handle_interruption(redis_client, queue_name, report_file):
"""Graceful shutdown logic for when the script is interrupted."""
logging.warning("--- Interruption Detected: Starting Shutdown Procedure ---")
# 1. Pause DAGs
_pause_dag("ytdlp_ops_orchestrator")
_pause_dag("ytdlp_ops_dispatcher")
# 2. Fail running per_url jobs
_fail_running_dag_runs("ytdlp_ops_worker_per_url")
# 3. Generate report
logging.info("Generating final report due to interruption...")
step_3_generate_report(redis_client, queue_name, report_file)
# Also print to stdout if a file was specified, so user sees it immediately
if report_file:
logging.info("Printing report to stdout as well...")
step_3_generate_report(redis_client, queue_name, None)
def step_4_cleanup_queues(redis_client, queue_name: str):
"""Cleans up the Redis queues used by the test."""
logging.info("--- Step 4: Cleaning Up Queues ---")
queues_to_delete = [
f"{queue_name}_inbox",
f"{queue_name}_progress",
f"{queue_name}_result",
f"{queue_name}_fail",
]
logging.warning(f"This will delete the following Redis keys: {queues_to_delete}")
deleted_count = redis_client.delete(*queues_to_delete)
logging.info(f"Cleanup complete. Deleted {deleted_count} key(s).")
def main():
"""Main function to parse arguments and run the regression test."""
# Register the signal handler for Ctrl+C
signal.signal(signal.SIGINT, signal_handler)
parser = argparse.ArgumentParser(description="Run a regression test for the ytdlp-ops system.")
# Environment
parser.add_argument("--redis-host", type=str, default="redis", help="Hostname or IP address of the Redis server. Defaults to 'redis' for in-container execution.")
parser.add_argument("--management-host", type=str, default=os.getenv("MANAGEMENT_SERVICE_HOST", "envoy-thrift-lb"), help="Hostname of the management Thrift service.")
parser.add_argument("--management-port", type=int, default=int(os.getenv("MANAGEMENT_SERVICE_PORT", 9080)), help="Port of the management Thrift service.")
# Test Configuration
parser.add_argument("--client", type=str, required=True, help="Client persona to test (e.g., 'mweb').")
parser.add_argument("--workers", type=int, required=True, help="Total number of worker loops to start.")
parser.add_argument("--workers-per-bunch", type=int, default=1, help="Number of workers per bunch.")
parser.add_argument("--run-time-min", type=int, required=True, help="How long to let the test run, in minutes.")
parser.add_argument("--input-file", type=str, help="Path to a file containing video IDs/URLs. If not provided, the existing queue will be used.")
# Monitoring & Reporting
parser.add_argument("--progress-interval-min", type=int, default=2, help="How often to query and print progress, in minutes.")
parser.add_argument("--report-file", type=str, help="Path to a CSV file to write the list of failed URLs to.")
parser.add_argument("--show-status", action="store_true", help="If set, show proxy and account statuses during progress monitoring.")
# Actions
parser.add_argument("--cleanup", action="store_true", help="If set, clear the Redis queues after the test completes.")
parser.add_argument("--skip-populate", action="store_true", help="If set, skip populating the queue (assumes it's already populated).")
parser.add_argument("--skip-trigger", action="store_true", help="If set, skip triggering the orchestrator (assumes it's already running).")
args = parser.parse_args()
# --- Setup ---
redis_password = os.getenv("REDIS_PASSWORD")
if not redis_password:
logging.error("REDIS_PASSWORD not found in environment. Please set it in your .env file.")
sys.exit(1)
# Use the provided redis-host, defaulting to 'redis' for in-container execution
redis_url = f"redis://:{redis_password}@{args.redis_host}:6379/0"
redis_client = _get_redis_client(redis_url)
queue_name = "video_queue" # Hardcoded for now, could be an arg
total_urls = 0
# --- Execution ---
if not args.skip_populate:
if args.input_file:
total_urls = step_0_populate_queue(redis_client, queue_name, args.input_file)
else:
logging.info("No input file provided, using existing queue.")
total_urls = redis_client.llen(f"{queue_name}_inbox")
if total_urls == 0:
logging.warning("Queue is empty and no input file was provided. The test may not have any work to do.")
else:
total_urls = redis_client.llen(f"{queue_name}_inbox")
logging.info(f"Skipping population. Found {total_urls} URLs in the inbox.")
if not args.skip_trigger:
step_1_trigger_orchestrator(args)
else:
logging.info("Skipping orchestrator trigger.")
step_2_monitor_progress(args, redis_client, queue_name, total_urls, args.run_time_min, args.progress_interval_min, args.show_status)
if INTERRUPTED:
handle_interruption(redis_client, queue_name, args.report_file)
else:
step_3_generate_report(redis_client, queue_name, args.report_file)
if args.cleanup:
step_4_cleanup_queues(redis_client, queue_name)
if INTERRUPTED:
logging.warning("Regression test script finished due to user interruption.")
sys.exit(130) # Standard exit code for Ctrl+C
else:
logging.info("Regression test script finished.")
if __name__ == "__main__":
main()