diff --git a/VERSION.client b/VERSION.client new file mode 100644 index 0000000..6d7de6e --- /dev/null +++ b/VERSION.client @@ -0,0 +1 @@ +1.0.2 diff --git a/airflow/Dockerfile b/airflow/Dockerfile index a48a7cb..87a13bd 100644 --- a/airflow/Dockerfile +++ b/airflow/Dockerfile @@ -132,13 +132,18 @@ COPY --chown=airflow:airflow bin/ytops-client /app/bin/ytops-client RUN chmod +x /app/bin/ytops-client ENV PATH="/app/bin:${PATH}" -# Install the package in editable mode. This runs setup.py and installs all dependencies -# listed in `install_requires`, making the `yt_ops_services` module available everywhere. +# Install dependencies for the ytops_client package, then install the package itself +# in editable mode. This makes the `yt_ops_services` and `ytops_client` modules +# available everywhere. # Bypass the pip root check again. RUN mv /usr/local/bin/pip /usr/local/bin/pip.orig && \ + python3 -m pip install --no-cache-dir -r ytops_client/requirements.txt && \ python3 -m pip install --no-cache-dir -e . && \ mv /usr/local/bin/pip.orig /usr/local/bin/pip +# Ensure all files in /app, including the generated .egg-info directory, are owned by the airflow user. +RUN chown -R airflow:airflow /app + # Copy token generator scripts and utils with correct permissions # COPY --chown=airflow:airflow generate_tokens_direct.mjs ./ # COPY --chown=airflow:airflow utils ./utils/ diff --git a/airflow/configs/docker-compose-ytdlp-ops.yaml.j2 b/airflow/configs/docker-compose-ytdlp-ops.yaml.j2 index ed6b329..2032582 100644 --- a/airflow/configs/docker-compose-ytdlp-ops.yaml.j2 +++ b/airflow/configs/docker-compose-ytdlp-ops.yaml.j2 @@ -127,13 +127,11 @@ services: - "${CAMOUFOX_PROXIES}" - "--camoufox-endpoints-file" - "/app/config/camoufox_endpoints.json" - - "--print-tokens" - "--stop-if-no-proxy" - "--comms-log-root-dir" - "/app/logs/yt-dlp-ops/communication_logs" - - "--bgutils-no-innertube" - - "--visitor-rotation-threshold" - - "250" + #- "--visitor-rotation-threshold" + #- "250" {% endif %} restart: unless-stopped pull_policy: always diff --git a/airflow/dags/QUEUE.md b/airflow/dags/QUEUE.md new file mode 100644 index 0000000..d34b42a --- /dev/null +++ b/airflow/dags/QUEUE.md @@ -0,0 +1,76 @@ +V2 System: Separated Auth & Download Flow + +The v2 system splits the process into two distinct stages, each with its own set of queues. The base names for these queues are queue2_auth and queue2_dl. + +───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── +1. Authentication Stage (ytdlp_ops_v02_worker_per_url_auth) + +This stage is responsible for taking a raw YouTube URL, authenticating with the yt-ops-server to get an info.json, and creating granular download tasks. + + • Getting Data (Input): + • Queue: queue2_auth_inbox + • Redis Type: LIST + • Purpose: This is the main entry point for the entire v2 system. Raw YouTube URLs or video IDs are pushed here. The ytdlp_ops_v02_dispatcher_auth DAG pulls URLs from this list to start the process. + • Reporting Results: + • Success: + • Queue: queue2_auth_result (Redis HASH) - A success record for the authentication step is stored here. + • Queue: queue_dl_format_tasks (Redis LIST) - This is the critical handoff queue. 
Upon successful authentication, the auth worker resolves the desired formats (e.g., bestvideo+bestaudio) into specific format IDs (e.g., 299, 140) and pushes one JSON job payload for each format into this list. This queue + feeds the download stage. + • Failure: + • Queue: queue2_auth_fail (Redis HASH) - If the authentication fails due to a system error (like bot detection or a proxy failure), the error details are stored here. + • Skipped: + • Queue: queue2_auth_skipped (Redis HASH) - If the video is unavailable for a non-system reason (e.g., it's private, deleted, or geo-restricted), the URL is logged here. This is not considered a system failure. + • Tracking Tasks: + • Queue: queue2_auth_progress + • Redis Type: HASH + • Purpose: When an auth worker picks up a URL, it adds an entry to this hash to show that the URL is actively being processed. The entry is removed upon completion (success, failure, or skip). + +───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── +2. Download Stage (ytdlp_ops_v02_worker_per_url_dl) + +This stage is responsible for executing the download and probing of a single media format, based on the job created by the auth worker. + + • Getting Data (Input): + • Queue: queue_dl_format_tasks + • Redis Type: LIST + • Purpose: The ytdlp_ops_v02_worker_per_url_dl DAG pulls granular job payloads from this list. Each payload contains everything needed to download a single format (the path to the info.json, the format ID, etc.). + • Reporting Results: + • Success: + • Queue: queue2_dl_result (Redis HASH) - A success record for the download of a specific format is stored here. + • Failure: + • Queue: queue2_dl_fail (Redis HASH) - If the download or probe fails, the error is logged here. As seen in ytdlp_mgmt_queues.py, these failed items can be requeued, which sends them back to queue2_auth_inbox to start the process over. + • Skipped: + • Queue: queue2_dl_skipped (Redis HASH) - Used for unrecoverable download errors (e.g., HTTP 403 Forbidden), similar to the auth stage. + • Tracking Tasks: + • Queue: queue2_dl_progress + • Redis Type: HASH + • Purpose: Tracks download tasks that are actively in progress. + +Summary Table (V2) + + + Queue Name Pattern Redis Type Purpose + ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + queue2_auth_inbox LIST Input for Auth: Holds raw YouTube URLs to be authenticated. + queue2_auth_progress HASH Tracks URLs currently being authenticated. + queue2_auth_result HASH Stores successful authentication results. + queue2_auth_fail HASH Stores failed authentication attempts. + queue2_auth_skipped HASH Stores URLs skipped due to content issues (private, deleted, etc.). + queue_dl_format_tasks LIST Input for Download: Holds granular download jobs (one per format) created by the auth worker. + queue2_dl_progress HASH Tracks download jobs currently in progress. + queue2_dl_result HASH Stores successful download results. + queue2_dl_fail HASH Stores failed download attempts. + queue2_dl_skipped HASH Stores downloads skipped due to unrecoverable errors (e.g., 403 Forbidden). 
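+
+Example: Feeding and Inspecting the V2 Queues
+
+The sketch below illustrates the handoff described above from a producer's point of view. It is a minimal, illustrative example only: it assumes the redis-py client and a Redis instance reachable at localhost:6379, and it uses the queue names from the summary table. The push direction (RPUSH vs. LPUSH) and the exact JSON payload schema of queue_dl_format_tasks are determined by the dispatcher and auth worker and are not reproduced here.
+
+```python
+# Illustrative only: how a producer might feed the v2 system and inspect queue depth.
+# Assumes redis-py is installed and Redis runs on localhost:6379; adjust to your
+# deployment. The dispatcher's pop direction is not shown in this document, so the
+# push direction below is an assumption.
+import redis
+
+r = redis.Redis(host="localhost", port=6379, decode_responses=True)
+
+# 1. Enqueue raw YouTube URLs (or video IDs) for the auth stage.
+r.rpush("queue2_auth_inbox", "https://www.youtube.com/watch?v=dQw4w9WgXcQ")
+
+# 2. Inspect queue depths: LISTs use LLEN, HASHes use HLEN (see the table above).
+for name, kind in [
+    ("queue2_auth_inbox", "list"),
+    ("queue_dl_format_tasks", "list"),
+    ("queue2_auth_result", "hash"),
+    ("queue2_dl_fail", "hash"),
+]:
+    size = r.llen(name) if kind == "list" else r.hlen(name)
+    print(f"{name}: {size}")
+```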
+ + +───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── +V1 System (Monolithic) for Contrast + +For completeness, the older v1 system (ytdlp_ops_v01_worker_per_url) uses a simpler, monolithic set of queues, typically with the base name video_queue. + + • Input: video_queue_inbox (Redis LIST) + • Results: video_queue_result, video_queue_fail, video_queue_skipped (all Redis HASHes) + • In-Progress: video_queue_progress (Redis HASH) + +In this model, there is no handoff between stages; a single worker handles both authentication and download for all requested formats of a URL. + diff --git a/airflow/dags/ytdlp_mgmt_proxy_account.py b/airflow/dags/ytdlp_mgmt_proxy_account.py index b189ba0..fe8fa7d 100644 --- a/airflow/dags/ytdlp_mgmt_proxy_account.py +++ b/airflow/dags/ytdlp_mgmt_proxy_account.py @@ -3,6 +3,12 @@ DAG to manage the state of proxies and accounts used by the ytdlp-ops-server. """ from __future__ import annotations +# --- Add project root to path to allow for yt-ops-client imports --- +import sys +# The yt-ops-client package is installed in editable mode in /app +if '/app' not in sys.path: + sys.path.insert(0, '/app') + import logging import json import re @@ -17,6 +23,7 @@ from airflow.models.dagrun import DagRun from airflow.models.param import Param from airflow.models.taskinstance import TaskInstance from airflow.operators.python import PythonOperator +from airflow.decorators import task from airflow.utils.dates import days_ago from airflow.models.variable import Variable from airflow.providers.redis.hooks.redis import RedisHook @@ -35,12 +42,13 @@ except ImportError: except Exception as e: logger.error(f"Error applying Thrift exceptions patch: {e}") -# Thrift imports +# Thrift imports (kept for DEPRECATED proxy management) try: + from ytops_client.profile_manager_tool import ProfileManager, format_duration, format_timestamp from pangramia.yt.exceptions.ttypes import PBServiceException, PBUserException - from yt_ops_services.client_utils import get_thrift_client, format_timestamp + from yt_ops_services.client_utils import get_thrift_client except ImportError as e: - logger.critical(f"Could not import Thrift modules: {e}. Ensure yt_ops_services package is installed correctly.") + logger.critical(f"Could not import project modules: {e}. 
Ensure yt-ops-client and services are installed correctly.") # Fail DAG parsing if thrift modules are not available raise @@ -70,6 +78,7 @@ def _get_redis_client(redis_conn_id: str): def _list_proxy_statuses(client, server_identity): """Lists the status of proxies.""" logger.info(f"Listing proxy statuses for server: {server_identity or 'ALL'}") + logger.warning("DEPRECATED: Proxy management is now handled by the standalone policy-enforcer.") logger.info("NOTE: Proxy statuses are read from server's internal state via Thrift service") try: statuses = client.getProxyStatus(server_identity) @@ -126,88 +135,54 @@ def _list_proxy_statuses(client, server_identity): print("NOTE: To see Recent Accounts/Machines, the server's `getProxyStatus` method must be updated to return these fields.") -def _list_account_statuses(client, account_id, redis_conn_id): - """Lists the status of accounts, enriching with live data from Redis.""" - logger.info(f"Listing account statuses for account: {account_id or 'ALL'}") - logger.info("NOTE: Account statuses are read from the Thrift service and enriched with live data from Redis.") +def _list_account_statuses(pm: ProfileManager, account_id_prefix: str | None): + """Lists the status of profiles from Redis using ProfileManager.""" + logger.info(f"Listing v2 profile statuses from Redis for prefix: {account_id_prefix or 'ALL'}") - redis_client = None try: - redis_client = _get_redis_client(redis_conn_id) - logger.info("Successfully connected to Redis to fetch detailed account status.") - except Exception as e: - logger.warning(f"Could not connect to Redis to get detailed status. Will show basic status. Error: {e}") - redis_client = None - - try: - # The thrift method takes accountId (specific) or accountPrefix. - # If account_id is provided, we use it. If not, we get all by leaving both params as None. - statuses = client.getAccountStatus(accountId=account_id, accountPrefix=None) - if not statuses: - print("\n--- Account Statuses ---\nNo account statuses found.\n------------------------\n") + profiles = pm.list_profiles() + if not profiles: + print("\n--- V2 Profile Statuses ---\nNo profiles found.\n---------------------------\n") return from tabulate import tabulate status_list = [] + now = time.time() - for s in statuses: - status_str = s.status - # If an account is resting, get the live countdown from Redis for accuracy. - if redis_client and 'RESTING' in status_str: - try: - status_key = f"account_status:{s.accountId}" - # The server stores resting expiry time in 'resting_until'. - expiry_ts_bytes = redis_client.hget(status_key, "resting_until") - if expiry_ts_bytes: - expiry_ts = float(expiry_ts_bytes) - now = datetime.now().timestamp() - if now >= expiry_ts: - status_str = "ACTIVE (was RESTING)" - else: - remaining_seconds = int(expiry_ts - now) - if remaining_seconds > 3600: - status_str = f"RESTING (active in {remaining_seconds // 3600}h {remaining_seconds % 3600 // 60}m)" - elif remaining_seconds > 60: - status_str = f"RESTING (active in {remaining_seconds // 60}m {remaining_seconds % 60}s)" - else: - status_str = f"RESTING (active in {remaining_seconds}s)" - except Exception as e: - logger.warning(f"Could not parse resting time for {s.accountId} from Redis: {e}. 
Using server status.") + for p in profiles: + if account_id_prefix and not p['name'].startswith(account_id_prefix): + continue + + status = p.get('state', 'UNKNOWN') + if status == 'RESTING': + rest_until = p.get('rest_until', 0) + if rest_until > now: + status += f" ({format_duration(rest_until - now)} left)" + elif status == 'COOLDOWN': + cooldown_until = p.get('cooldown_until', 0) + if cooldown_until > now: + status += f" ({format_duration(cooldown_until - now)} left)" - # Determine the last activity timestamp for sorting - last_success = float(s.lastSuccessTimestamp) if s.lastSuccessTimestamp else 0 - last_failure = float(s.lastFailureTimestamp) if s.lastFailureTimestamp else 0 - last_activity = max(last_success, last_failure) status_item = { - "Account ID": s.accountId, - "Status": status_str, - "Success": s.successCount, - "Failures": s.failureCount, - "Last Success": format_timestamp(s.lastSuccessTimestamp), - "Last Failure": format_timestamp(s.lastFailureTimestamp), - "Last Proxy": s.lastUsedProxy or "N/A", - "Last Machine": s.lastUsedMachine or "N/A", - "_last_activity": last_activity, # Add a temporary key for sorting + "Name": p.get('name'), + "Status": status, + "Proxy": p.get('proxy', 'N/A'), + "Success": p.get('success', 0), + "Failures": p.get('failure', 0), + "Last Activity": format_timestamp(p.get('last_activity_ts', 0)), + "Owner": p.get('owner', 'None'), + "Lock Time": format_duration(now - p.get('lock_ts', 0)) if p.get('state') == 'LOCKED' else 'N/A', } status_list.append(status_item) - # Sort the list by the last activity timestamp in descending order - status_list.sort(key=lambda item: item.get('_last_activity', 0), reverse=True) + status_list.sort(key=lambda item: item.get('Name', '')) - # Remove the temporary sort key before printing - for item in status_list: - del item['_last_activity'] - - print("\n--- Account Statuses ---") - # The f-string with a newline ensures the table starts on a new line in the logs. + print("\n--- V2 Profile Statuses ---") print(f"\n{tabulate(status_list, headers='keys', tablefmt='grid')}") - print("------------------------\n") - except (PBServiceException, PBUserException) as e: - logger.error(f"Failed to get account statuses: {e.message}", exc_info=True) - print(f"\nERROR: Could not retrieve account statuses. Server returned: {e.message}\n") + print("---------------------------\n") except Exception as e: - logger.error(f"An unexpected error occurred while getting account statuses: {e}", exc_info=True) + logger.error(f"An unexpected error occurred while getting v2 profile statuses: {e}", exc_info=True) print(f"\nERROR: An unexpected error occurred: {e}\n") @@ -317,6 +292,49 @@ def _list_activity_counters(redis_conn_id: str): print(f"\nERROR: An unexpected error occurred: {e}\n") +def _create_profiles_from_json(**context): + """Creates profiles by calling the yt-ops-client setup-profiles tool.""" + import subprocess + import tempfile + import yaml + + params = context['params'] + json_payload_str = params.get('create_profiles_json') + if not json_payload_str: + raise AirflowException("Parameter 'create_profiles_json' is empty.") + + try: + # We accept JSON but the setup tool uses YAML, so we parse and dump. + # This also serves as validation. 
+ json_payload = json.loads(json_payload_str) + yaml_payload = yaml.dump(json_payload) + except (json.JSONDecodeError, yaml.YAMLError) as e: + raise AirflowException(f"Invalid JSON/YAML in 'create_profiles_json': {e}") + + with tempfile.NamedTemporaryFile(mode='w+', delete=True, suffix='.yaml', prefix='airflow-profile-setup-') as temp_policy_file: + temp_policy_file.write(yaml_payload) + temp_policy_file.flush() + logger.info(f"Created temporary policy file for profile setup: {temp_policy_file.name}") + + cmd = [ + 'ytops-client', 'setup-profiles', + '--policy', temp_policy_file.name, + ] + # Pass through Redis connection params if provided + if params.get('redis_conn_id') != DEFAULT_REDIS_CONN_ID: + logger.warning("Custom Redis connection is not supported for `create_profiles` yet. It will use the default from .env or localhost.") + + logger.info(f"Running command: {' '.join(cmd)}") + process = subprocess.run(cmd, capture_output=True, text=True, timeout=300) + + if process.stdout: + print(f"\n--- yt-ops-client setup-profiles STDOUT ---\n{process.stdout}\n----------------------------------------\n") + if process.stderr: + print(f"\n--- yt-ops-client setup-profiles STDERR ---\n{process.stderr}\n----------------------------------------\n") + + if process.returncode != 0: + raise AirflowException(f"Profile creation failed with exit code {process.returncode}.") + def manage_system_callable(**context): """Main callable to interact with the system management endpoints.""" # Log version for debugging @@ -327,7 +345,7 @@ def manage_system_callable(**context): action = params["action"] # For Thrift actions, use the new management host/port - if entity not in ["activity_counters"]: + if entity not in ["activity_counters", "account"]: host = params["management_host"] port = params["management_port"] else: @@ -335,12 +353,13 @@ def manage_system_callable(**context): server_identity = params.get("server_identity") proxy_url = params.get("proxy_url") - account_id = params.get("account_id") + account_id = params.get("account_id") # Used as prefix for v2 profiles + redis_env = params.get("redis_env") # --- Validate Action/Entity Combination and Parameters --- valid_actions = { "proxy": ["list_with_status", "ban", "unban", "ban_all", "unban_all", "delete_from_redis"], - "account": ["list_with_status", "ban", "unban", "unban_all", "delete_from_redis"], + "account": ["list_with_status", "create_profiles", "ban", "unban", "activate", "pause", "delete", "delete_all"], "client": ["list_with_status", "delete_from_redis"], "accounts_and_proxies": ["list_with_status", "ban", "unban", "ban_all", "unban_all", "delete_from_redis"], "activity_counters": ["list_with_status"], @@ -360,9 +379,23 @@ def manage_system_callable(**context): raise ValueError(f"A 'proxy_url' is required for proxy action '{action}'.") if entity == "account": - if action in ["ban", "unban"] and not account_id: - raise ValueError(f"An 'account_id' is required for account action '{action}'.") + if action in ["ban", "unban", "pause", "activate", "delete"] and not account_id: + raise ValueError(f"An 'account_id' (profile name) is required for account action '{action}'.") + # --- ProfileManager setup for v2 account actions --- + pm = None + if entity == "account": + try: + redis_hook = RedisHook(redis_conn_id=params["redis_conn_id"]) + if redis_env: + key_prefix = f"{redis_env}_profile_mgmt_" + else: + raise ValueError("A 'redis_env' (e.g., 'sim_auth') must be provided for v2 profile actions.") + + pm = ProfileManager(redis_hook=redis_hook, 
key_prefix=key_prefix) + logger.info(f"Initialized ProfileManager for env '{redis_env}' (Redis key prefix: '{key_prefix}')") + except Exception as e: + raise AirflowException(f"Failed to initialize ProfileManager: {e}") # --- Handle Activity Counter action --- if entity == "activity_counters": @@ -372,13 +405,25 @@ def manage_system_callable(**context): else: raise ValueError(f"Action '{action}' is not valid for entity 'activity_counters'. Only 'list_with_status' is supported.") - # Handle Thrift-based deletion actions + # Handle direct Redis deletion actions if action == "delete_from_redis": + if entity == "client": + logger.info("Deleting all client stats from Redis...") + redis_client = _get_redis_client(params["redis_conn_id"]) + result = redis_client.delete("client_stats") + if result > 0: + print(f"\nSuccessfully deleted 'client_stats' key from Redis.\n") + else: + print(f"\nKey 'client_stats' not found in Redis. Nothing to delete.\n") + return + + # All other delete actions are handled by Thrift for now. client, transport = None, None try: client, transport = get_thrift_client(host, port) if entity == "proxy": + logger.warning("DEPRECATED: Proxy management is now handled by the standalone policy-enforcer.") proxy_url = params.get("proxy_url") server_identity = params.get("server_identity") @@ -391,63 +436,12 @@ def manage_system_callable(**context): print(f"\nFailed to delete proxy '{proxy_url}' for server '{server_identity}' from Redis.\n") else: logger.info("Deleting all proxies from Redis via Thrift service...") - # If server_identity is provided, delete all proxies for that server - # If server_identity is None, delete all proxies for ALL servers result = client.deleteAllProxiesFromRedis(server_identity) if server_identity: print(f"\nSuccessfully deleted all proxies for server '{server_identity}' from Redis. Count: {result}\n") else: print(f"\nSuccessfully deleted all proxies from Redis across ALL servers. Count: {result}\n") - elif entity == "account": - account_id = params.get("account_id") - - if account_id: - logger.info(f"Deleting account '{account_id}' from Redis via Thrift service...") - result = client.deleteAccountFromRedis(account_id) - if result: - print(f"\nSuccessfully deleted account '{account_id}' from Redis.\n") - else: - print(f"\nFailed to delete account '{account_id}' from Redis.\n") - else: - logger.info("Deleting all accounts from Redis via Thrift service...") - # If account_id is provided as prefix, delete all accounts with that prefix - # If account_id is None, delete all accounts - account_prefix = params.get("account_id") - result = client.deleteAllAccountsFromRedis(account_prefix) - if account_prefix: - print(f"\nSuccessfully deleted all accounts with prefix '{account_prefix}' from Redis. Count: {result}\n") - else: - print(f"\nSuccessfully deleted all accounts from Redis. 
Count: {result}\n") - - elif entity == "accounts_and_proxies": - # Delete accounts - account_prefix = params.get("account_id") # Repurpose account_id param as an optional prefix - logger.info("Deleting accounts from Redis via Thrift service...") - account_result = client.deleteAllAccountsFromRedis(account_prefix) - if account_prefix: - print(f"\nSuccessfully deleted {account_result} account keys with prefix '{account_prefix}' from Redis.\n") - else: - print(f"\nSuccessfully deleted {account_result} account keys from Redis.\n") - - # Delete proxies - server_identity = params.get("server_identity") - logger.info("Deleting proxies from Redis via Thrift service...") - proxy_result = client.deleteAllProxiesFromRedis(server_identity) - if server_identity: - print(f"\nSuccessfully deleted {proxy_result} proxy keys for server '{server_identity}' from Redis.\n") - else: - print(f"\nSuccessfully deleted {proxy_result} proxy keys from Redis across ALL servers.\n") - - elif entity == "client": - logger.info("Deleting all client stats from Redis...") - redis_client = _get_redis_client(params["redis_conn_id"]) - result = redis_client.delete("client_stats") - if result > 0: - print(f"\nSuccessfully deleted 'client_stats' key from Redis.\n") - else: - print(f"\nKey 'client_stats' not found in Redis. Nothing to delete.\n") - except (PBServiceException, PBUserException) as e: logger.error(f"Thrift error performing delete action: {e.message}", exc_info=True) print(f"\nERROR: Thrift service error: {e.message}\n") @@ -460,16 +454,21 @@ def manage_system_callable(**context): if transport and transport.isOpen(): transport.close() logger.info("Thrift connection closed.") - return # End execution for this action + return + # --- Main Action Handler --- client, transport = None, None try: - client, transport = get_thrift_client(host, port) + # Connect to Thrift only if needed + if entity == "proxy": + client, transport = get_thrift_client(host, port) if entity == "client": if action == "list_with_status": _list_client_statuses(params["redis_conn_id"]) + elif entity == "proxy": + logger.warning("DEPRECATED: Proxy management is now handled by the standalone policy-enforcer. These actions are for legacy support.") if action == "list_with_status": _list_proxy_statuses(client, server_identity) elif action == "ban": @@ -483,300 +482,60 @@ def manage_system_callable(**context): client.unbanProxy(proxy_url, server_identity) print(f"Successfully sent request to unban proxy '{proxy_url}'.") elif action == "ban_all": - if server_identity: + if server_identity: logger.info(f"Banning all proxies for server '{server_identity}'...") client.banAllProxies(server_identity) print(f"Successfully sent request to ban all proxies for '{server_identity}'.") - else: - logger.info("No server_identity provided. Banning all proxies for ALL servers...") - all_statuses = client.getProxyStatus(None) - if not all_statuses: - print("\nNo proxy statuses found for any server. Nothing to ban.\n") - return - - all_server_identities = sorted(list(set(s.serverIdentity for s in all_statuses))) - logger.info(f"Found {len(all_server_identities)} server identities: {all_server_identities}") - print(f"Found {len(all_server_identities)} server identities. 
Sending ban request for each...") - - success_count = 0 - fail_count = 0 - for identity in all_server_identities: - try: - client.banAllProxies(identity) - logger.info(f" - Sent ban_all for '{identity}'.") - success_count += 1 - except Exception as e: - logger.error(f" - Failed to ban all proxies for '{identity}': {e}") - fail_count += 1 - - print(f"\nSuccessfully sent ban_all requests for {success_count} server identities.") - if fail_count > 0: - print(f"Failed to send ban_all requests for {fail_count} server identities. See logs for details.") + else: + raise ValueError("A 'server_identity' is required for 'ban_all' on proxies.") elif action == "unban_all": if server_identity: logger.info(f"Unbanning all proxy statuses for server '{server_identity}'...") client.resetAllProxyStatuses(server_identity) print(f"Successfully sent request to unban all proxy statuses for '{server_identity}'.") else: - logger.info("No server_identity provided. Unbanning all proxies for ALL servers...") - all_statuses = client.getProxyStatus(None) - if not all_statuses: - print("\nNo proxy statuses found for any server. Nothing to unban.\n") - return - - all_server_identities = sorted(list(set(s.serverIdentity for s in all_statuses))) - logger.info(f"Found {len(all_server_identities)} server identities: {all_server_identities}") - print(f"Found {len(all_server_identities)} server identities. Sending unban request for each...") - - success_count = 0 - fail_count = 0 - for identity in all_server_identities: - try: - client.resetAllProxyStatuses(identity) - logger.info(f" - Sent unban_all for '{identity}'.") - success_count += 1 - except Exception as e: - logger.error(f" - Failed to unban all proxies for '{identity}': {e}") - fail_count += 1 - - print(f"\nSuccessfully sent unban_all requests for {success_count} server identities.") - if fail_count > 0: - print(f"Failed to send unban_all requests for {fail_count} server identities. 
See logs for details.") + raise ValueError("A 'server_identity' is required for 'unban_all' on proxies.") elif entity == "account": if action == "list_with_status": - _list_account_statuses(client, account_id, params["redis_conn_id"]) + _list_account_statuses(pm, account_id) + elif action == "create_profiles": + # This action is handled by a separate PythonOperator + pass elif action == "ban": - if not account_id: raise ValueError("An 'account_id' is required.") - reason = f"Manual ban from Airflow mgmt DAG by {socket.gethostname()}" - logger.info(f"Banning account '{account_id}'...") - client.banAccount(accountId=account_id, reason=reason) - print(f"Successfully sent request to ban account '{account_id}'.") - elif action == "unban": - if not account_id: raise ValueError("An 'account_id' is required.") - reason = f"Manual un-ban from Airflow mgmt DAG by {socket.gethostname()}" - logger.info(f"Unbanning account '{account_id}'...") - - # Fetch status to get current success count before unbanning - statuses = client.getAccountStatus(accountId=account_id, accountPrefix=None) - if not statuses: - raise AirflowException(f"Account '{account_id}' not found.") - current_success_count = statuses[0].successCount or 0 - - client.unbanAccount(accountId=account_id, reason=reason) - print(f"Successfully sent request to unban account '{account_id}'.") - - # Set the success_count_at_activation to baseline the account - redis_client = _get_redis_client(params["redis_conn_id"]) - redis_client.hset(f"account_status:{account_id}", "success_count_at_activation", current_success_count) - logger.info(f"Set 'success_count_at_activation' for '{account_id}' to {current_success_count}.") - elif action == "unban_all": - account_prefix = account_id # Repurpose account_id param as an optional prefix - logger.info(f"Unbanning all account statuses to ACTIVE (prefix: '{account_prefix or 'ALL'}')...") + logger.info(f"Banning profile '{account_id}' in env '{redis_env}'...") + pm.update_profile_state(account_id, "BANNED", f"Manual ban from Airflow mgmt DAG") + print(f"Successfully set state of profile '{account_id}' to BANNED.") + elif action == "unban" or action == "activate": + logger.info(f"Activating profile '{account_id}' in env '{redis_env}'...") + pm.update_profile_state(account_id, "ACTIVE", f"Manual activation from Airflow mgmt DAG") + print(f"Successfully set state of profile '{account_id}' to ACTIVE.") + elif action == "pause": + logger.info(f"Pausing (resting) profile '{account_id}' in env '{redis_env}'...") + pm.update_profile_state(account_id, "RESTING", f"Manual pause from Airflow mgmt DAG") + print(f"Successfully set state of profile '{account_id}' to RESTING.") + elif action == "delete": + logger.info(f"Deleting profile '{account_id}' in env '{redis_env}'...") + pm.delete_profile(account_id) + print(f"Successfully deleted profile '{account_id}'.") + elif action == "delete_all": + logger.warning(f"DESTRUCTIVE: Deleting all profiles with prefix '{account_id}' in env '{redis_env}'...") + profiles = pm.list_profiles() + deleted_count = 0 + for p in profiles: + if not account_id or p['name'].startswith(account_id): + pm.delete_profile(p['name']) + deleted_count += 1 + print(f"Successfully deleted {deleted_count} profile(s).") - all_statuses = client.getAccountStatus(accountId=None, accountPrefix=account_prefix) - if not all_statuses: - print(f"No accounts found with prefix '{account_prefix or 'ALL'}' to unban.") - return - - accounts_to_unban = [s.accountId for s in all_statuses] - account_map = {s.accountId: s 
for s in all_statuses} - redis_client = _get_redis_client(params["redis_conn_id"]) - - logger.info(f"Found {len(accounts_to_unban)} accounts to unban.") - print(f"Found {len(accounts_to_unban)} accounts. Sending unban request for each...") - - unban_count = 0 - fail_count = 0 - for acc_id in accounts_to_unban: - try: - reason = f"Manual unban_all from Airflow mgmt DAG by {socket.gethostname()}" - client.unbanAccount(accountId=acc_id, reason=reason) - logger.info(f" - Sent unban for '{acc_id}'.") - - # Also set the success_count_at_activation to baseline the account - current_success_count = account_map[acc_id].successCount or 0 - redis_client.hset(f"account_status:{acc_id}", "success_count_at_activation", current_success_count) - logger.info(f" - Set 'success_count_at_activation' for '{acc_id}' to {current_success_count}.") - - unban_count += 1 - except Exception as e: - logger.error(f" - Failed to unban account '{acc_id}': {e}") - fail_count += 1 - - print(f"\nSuccessfully sent unban requests for {unban_count} accounts.") - if fail_count > 0: - print(f"Failed to send unban requests for {fail_count} accounts. See logs for details.") - - # Optionally, list statuses again to confirm - print("\n--- Listing statuses after unban_all ---") - _list_account_statuses(client, account_prefix, params["redis_conn_id"]) - elif entity == "accounts_and_proxies": + logger.warning("DEPRECATED: Combined 'accounts_and_proxies' actions are no longer supported in v2. Please manage accounts and proxies separately.") if action == "list_with_status": - print("\n--- Listing statuses for Proxies, Accounts, and Clients ---") + print("\n--- Listing statuses for Proxies, V2 Profiles, and Clients ---") _list_proxy_statuses(client, server_identity) - _list_account_statuses(client, account_id, params["redis_conn_id"]) + _list_account_statuses(pm, account_id) _list_client_statuses(params["redis_conn_id"]) - return # End execution for list_with_status - - print(f"\n--- Performing action '{action}' on BOTH Proxies and Accounts ---") - - # --- Proxy Action --- - try: - print("\n-- Running Proxy Action --") - if action == "list_with_status": - _list_proxy_statuses(client, server_identity) - elif action == "ban": - if not proxy_url: raise ValueError("A 'proxy_url' is required.") - logger.info(f"Banning proxy '{proxy_url}' for server '{server_identity}'...") - client.banProxy(proxy_url, server_identity) - print(f"Successfully sent request to ban proxy '{proxy_url}'.") - elif action == "unban": - if not proxy_url: raise ValueError("A 'proxy_url' is required.") - logger.info(f"Unbanning proxy '{proxy_url}' for server '{server_identity}'...") - client.unbanProxy(proxy_url, server_identity) - print(f"Successfully sent request to unban proxy '{proxy_url}'.") - elif action == "ban_all": - if server_identity: - logger.info(f"Banning all proxies for server '{server_identity}'...") - client.banAllProxies(server_identity) - print(f"Successfully sent request to ban all proxies for '{server_identity}'.") - else: - logger.info("No server_identity provided. Banning all proxies for ALL servers...") - all_statuses = client.getProxyStatus(None) - if not all_statuses: - print("\nNo proxy statuses found for any server. Nothing to ban.\n") - else: - all_server_identities = sorted(list(set(s.serverIdentity for s in all_statuses))) - logger.info(f"Found {len(all_server_identities)} server identities: {all_server_identities}") - print(f"Found {len(all_server_identities)} server identities. 
Sending ban request for each...") - - success_count = 0 - fail_count = 0 - for identity in all_server_identities: - try: - client.banAllProxies(identity) - logger.info(f" - Sent ban_all for '{identity}'.") - success_count += 1 - except Exception as e: - logger.error(f" - Failed to ban all proxies for '{identity}': {e}") - fail_count += 1 - - print(f"\nSuccessfully sent ban_all requests for {success_count} server identities.") - if fail_count > 0: - print(f"Failed to send ban_all requests for {fail_count} server identities. See logs for details.") - elif action == "unban_all": - if server_identity: - logger.info(f"Unbanning all proxy statuses for server '{server_identity}'...") - client.resetAllProxyStatuses(server_identity) - print(f"Successfully sent request to unban all proxy statuses for '{server_identity}'.") - else: - logger.info("No server_identity provided. Unbanning all proxies for ALL servers...") - all_statuses = client.getProxyStatus(None) - if not all_statuses: - print("\nNo proxy statuses found for any server. Nothing to unban.\n") - else: - all_server_identities = sorted(list(set(s.serverIdentity for s in all_statuses))) - logger.info(f"Found {len(all_server_identities)} server identities: {all_server_identities}") - print(f"Found {len(all_server_identities)} server identities. Sending unban request for each...") - - success_count = 0 - fail_count = 0 - for identity in all_server_identities: - try: - client.resetAllProxyStatuses(identity) - logger.info(f" - Sent unban_all for '{identity}'.") - success_count += 1 - except Exception as e: - logger.error(f" - Failed to unban all proxies for '{identity}': {e}") - fail_count += 1 - - print(f"\nSuccessfully sent unban_all requests for {success_count} server identities.") - if fail_count > 0: - print(f"Failed to send unban_all requests for {fail_count} server identities. See logs for details.") - except Exception as proxy_e: - logger.error(f"Error during proxy action '{action}': {proxy_e}", exc_info=True) - print(f"\nERROR during proxy action: {proxy_e}") - - # --- Account Action --- - try: - print("\n-- Running Account Action --") - if action == "list_with_status": - _list_account_statuses(client, account_id, params["redis_conn_id"]) - elif action == "ban": - if not account_id: raise ValueError("An 'account_id' is required.") - reason = f"Manual ban from Airflow mgmt DAG by {socket.gethostname()}" - logger.info(f"Banning account '{account_id}'...") - client.banAccount(accountId=account_id, reason=reason) - print(f"Successfully sent request to ban account '{account_id}'.") - elif action == "unban": - if not account_id: raise ValueError("An 'account_id' is required.") - reason = f"Manual un-ban from Airflow mgmt DAG by {socket.gethostname()}" - logger.info(f"Unbanning account '{account_id}'...") - - # Fetch status to get current success count before unbanning - statuses = client.getAccountStatus(accountId=account_id, accountPrefix=None) - if not statuses: - logger.warning(f"Account '{account_id}' not found. 
Skipping account unban.") - else: - current_success_count = statuses[0].successCount or 0 - client.unbanAccount(accountId=account_id, reason=reason) - print(f"Successfully sent request to unban account '{account_id}'.") - - # Set the success_count_at_activation to baseline the account - redis_client = _get_redis_client(params["redis_conn_id"]) - redis_client.hset(f"account_status:{account_id}", "success_count_at_activation", current_success_count) - logger.info(f"Set 'success_count_at_activation' for '{account_id}' to {current_success_count}.") - elif action == "unban_all": - account_prefix = account_id # Repurpose account_id param as an optional prefix - logger.info(f"Unbanning all account statuses to ACTIVE (prefix: '{account_prefix or 'ALL'}')...") - - all_statuses = client.getAccountStatus(accountId=None, accountPrefix=account_prefix) - if not all_statuses: - print(f"No accounts found with prefix '{account_prefix or 'ALL'}' to unban.") - else: - accounts_to_unban = [s.accountId for s in all_statuses] - account_map = {s.accountId: s for s in all_statuses} - redis_client = _get_redis_client(params["redis_conn_id"]) - - logger.info(f"Found {len(accounts_to_unban)} accounts to unban.") - print(f"Found {len(accounts_to_unban)} accounts. Sending unban request for each...") - - unban_count = 0 - fail_count = 0 - for acc_id in accounts_to_unban: - try: - reason = f"Manual unban_all from Airflow mgmt DAG by {socket.gethostname()}" - client.unbanAccount(accountId=acc_id, reason=reason) - logger.info(f" - Sent unban for '{acc_id}'.") - - # Also set the success_count_at_activation to baseline the account - current_success_count = account_map[acc_id].successCount or 0 - redis_client.hset(f"account_status:{acc_id}", "success_count_at_activation", current_success_count) - logger.info(f" - Set 'success_count_at_activation' for '{acc_id}' to {current_success_count}.") - - unban_count += 1 - except Exception as e: - logger.error(f" - Failed to unban account '{acc_id}': {e}") - fail_count += 1 - - print(f"\nSuccessfully sent unban requests for {unban_count} accounts.") - if fail_count > 0: - print(f"Failed to send unban requests for {fail_count} accounts. See logs for details.") - - # Optionally, list statuses again to confirm - print("\n--- Listing statuses after unban_all ---") - _list_account_statuses(client, account_prefix, params["redis_conn_id"]) - except Exception as account_e: - logger.error(f"Error during account action '{action}': {account_e}", exc_info=True) - print(f"\nERROR during account action: {account_e}") - - elif entity == "all": - if action == "list_with_status": - print("\nListing all entities...") - _list_proxy_statuses(client, server_identity) - _list_account_statuses(client, account_id, params["redis_conn_id"]) + return except (PBServiceException, PBUserException) as e: logger.error(f"Thrift error performing action '{action}': {e.message}", exc_info=True) @@ -800,91 +559,120 @@ with DAG( catchup=False, tags=["ytdlp", "mgmt", "master"], doc_md=""" - ### YT-DLP Proxy and Account Manager DAG - This DAG provides tools to manage the state of proxies and accounts used by the `ytdlp-ops-server`. + ### YT-DLP v2 Profile and System Manager + + This DAG provides tools to manage the state of **v2 profiles** (formerly accounts) and other system components. Select an `entity` and an `action` to perform. - - **IMPORTANT NOTE ABOUT DATA SOURCES:** - - **Proxy Statuses**: Read from the server's internal state via Thrift service calls. 
- - **Account Statuses**: Read from the Thrift service, and then enriched with live cooldown data directly from Redis. - - **IMPORTANT NOTE ABOUT PROXY MANAGEMENT:** - - Proxies are managed by the server's internal state through Thrift methods - - There is NO direct Redis manipulation for proxies - they are managed entirely by the server - - To properly manage proxies, use the Thrift service methods (ban, unban, etc.) + + **V2 Profile Management (`entity: account`):** + - All account/profile actions are now performed directly on Redis using the `ProfileManager`. + - A `redis_env` (e.g., `sim_auth` or `sim_download`) is **required** to target the correct set of profiles. + - Actions include `list`, `create`, `ban`, `activate`, `pause`, and `delete`. + + **Legacy Proxy Management (`entity: proxy`):** + - **DEPRECATED**: Proxy state is now managed automatically by the standalone `policy-enforcer` service. + - These actions are provided for legacy support and interact with the old Thrift service. They may be removed in the future. """, params={ - "management_host": Param(DEFAULT_MANAGEMENT_SERVICE_IP, type="string", title="Management Service Host", description="The hostname or IP of the management service. Can be a Docker container name (e.g., 'envoy-thrift-lb') if on the same network."), - "management_port": Param(DEFAULT_MANAGEMENT_SERVICE_PORT, type="integer", title="Management Service Port", description="The port of the dedicated management service."), + "management_host": Param(DEFAULT_MANAGEMENT_SERVICE_IP, type="string", title="Management Service Host (DEPRECATED)", description="The hostname or IP of the management service. Used only for legacy proxy actions."), + "management_port": Param(DEFAULT_MANAGEMENT_SERVICE_PORT, type="integer", title="Management Service Port (DEPRECATED)", description="The port of the dedicated management service."), "entity": Param( - "accounts_and_proxies", + "account", type="string", - enum=["account", "proxy", "client", "accounts_and_proxies", "activity_counters"], + enum=["account", "proxy", "client", "activity_counters", "accounts_and_proxies"], description="The type of entity to manage.", ), "action": Param( "list_with_status", type="string", - enum=["list_with_status", "ban", "unban", "ban_all", "unban_all", "delete_from_redis"], + enum=["list_with_status", "create_profiles", "ban", "unban", "activate", "pause", "delete", "delete_all", "ban_all", "unban_all", "delete_from_redis"], description="""The management action to perform. --- - #### Actions for `entity: proxy` - - `list_with_status`: View status of all proxies, optionally filtered by `server_identity`. - - `ban`: Ban a specific proxy for a given `server_identity`. Requires `proxy_url`. - - `unban`: Un-ban a specific proxy. Requires `proxy_url`. - - `ban_all`: Sets the status of all proxies for a given `server_identity` (or all servers) to `BANNED`. - - `unban_all`: Resets the status of all proxies for a given `server_identity` (or all servers) to `ACTIVE`. - - `delete_from_redis`: **(Destructive)** Deletes proxy status from Redis via Thrift service. This permanently removes the proxy from being tracked by the system. If `proxy_url` and `server_identity` are provided, it deletes a single proxy. If only `server_identity` is provided, it deletes all proxies for that server. If neither is provided, it deletes ALL proxies across all servers. + #### Actions for `entity: account` (V2 Profiles) + - `list_with_status`: View status of all profiles, optionally filtered by `account_id` as a prefix. 
+ - `create_profiles`: Creates new profiles from a JSON payload. See `create_profiles_json` param. + - `ban`: Sets a profile's state to BANNED. Requires `account_id`. + - `unban`/`activate`: Sets a profile's state to ACTIVE. Requires `account_id`. + - `pause`: Sets a profile's state to RESTING. Requires `account_id`. + - `delete`: Deletes a single profile. Requires `account_id`. + - `delete_all`: **(Destructive)** Deletes all profiles, or those matching the `account_id` as a prefix. - #### Actions for `entity: account` - - `list_with_status`: View status of all accounts, optionally filtered by `account_id` (as a prefix). - - `ban`: Ban a specific account. Requires `account_id`. - - `unban`: Un-ban a specific account. Requires `account_id`. - - `unban_all`: Sets the status of all accounts (or those matching a prefix in `account_id`) to `ACTIVE`. - - `delete_from_redis`: **(Destructive)** Deletes account status from Redis via Thrift service. This permanently removes the account from being tracked by the system. If `account_id` is provided, it deletes that specific account. If `account_id` is provided as a prefix, it deletes all accounts matching that prefix. If `account_id` is empty, it deletes ALL accounts. + #### Actions for `entity: proxy` (DEPRECATED) + - `list_with_status`, `ban`, `unban`, `ban_all`, `unban_all`, `delete_from_redis`. #### Actions for `entity: client` - `list_with_status`: View success/failure statistics for each client type. - `delete_from_redis`: **(Destructive)** Deletes all client stats from Redis. #### Actions for `entity: activity_counters` - - `list_with_status`: View current activity rates (ops/min, ops/hr) for proxies and accounts. - - #### Actions for `entity: accounts_and_proxies` - - This entity performs the selected action on **both** proxies and accounts where applicable. - - `list_with_status`: View statuses for both proxies and accounts. - - `ban`: Ban a specific proxy AND a specific account. Requires `proxy_url`, `server_identity`, and `account_id`. - - `unban`: Un-ban a specific proxy AND a specific account. Requires `proxy_url`, `server_identity`, and `account_id`. - - `ban_all`: Ban all proxies for a `server_identity` (or all servers). Does not affect accounts. - - `unban_all`: Un-ban all proxies for a `server_identity` (or all servers) AND all accounts (optionally filtered by `account_id` as a prefix). - - `delete_from_redis`: Deletes both account and proxy status from Redis via Thrift service. For accounts, if `account_id` is provided as a prefix, it deletes all accounts matching that prefix. If `account_id` is empty, it deletes ALL accounts. For proxies, if `server_identity` is provided, it deletes all proxies for that server. If `server_identity` is empty, it deletes ALL proxies across all servers. - + - `list_with_status`: View current activity rates for proxies and accounts. """, ), - "server_identity": Param( - None, - type=["null", "string"], - description="The identity of the server instance (for proxy management). Leave blank to list all or delete all proxies.", - ), - "proxy_url": Param( - None, - type=["null", "string"], - description="The proxy URL to act upon (e.g., 'socks5://host:port').", + "redis_env": Param( + "sim_auth", + type="string", + enum=["sim_auth", "sim_download"], + title="[V2 Profiles] Redis Environment", + description="The environment for v2 profile management (e.g., 'sim_auth'). Determines the Redis key prefix.", ), "account_id": Param( None, type=["null", "string"], - description="The account ID to act upon. 
For `unban_all` or `delete_from_redis` on accounts, this can be an optional prefix. Leave blank to delete all accounts.", + description="For v2 profiles: The profile name (e.g., 'auth_user_0') or a prefix for `list` and `delete_all`.", + ), + "create_profiles_json": Param( + """{ + "auth_profile_setup": { + "env": "sim_auth", + "cleanup_before_run": false, + "pools": [ + { + "prefix": "auth_user", + "proxy": "sslocal-rust-1090:1090", + "count": 2 + } + ] + } +}""", + type="string", + title="[V2 Profiles] Create Profiles JSON", + description="For action `create_profiles`. A JSON payload defining the profiles to create. This is passed to `yt-ops-client setup-profiles`.", + **{'ui_widget': 'json', 'multi_line': True} + ), + "server_identity": Param( + None, + type=["null", "string"], + description="[DEPRECATED] The server identity for proxy management.", + ), + "proxy_url": Param( + None, + type=["null", "string"], + description="[DEPRECATED] The proxy URL to act upon.", ), "redis_conn_id": Param( DEFAULT_REDIS_CONN_ID, type="string", title="Redis Connection ID", - description="The Airflow connection ID for the Redis server (used for 'delete_from_redis' and for fetching detailed account status).", + description="The Airflow connection ID for the Redis server.", ), }, ) as dag: + + @task.branch(task_id="branch_on_action") + def branch_on_action(**context): + action = context["params"]["action"] + if action == "create_profiles": + return "create_profiles_task" + return "system_management_task" + + create_profiles_task = PythonOperator( + task_id="create_profiles_task", + python_callable=_create_profiles_from_json, + ) + system_management_task = PythonOperator( task_id="system_management_task", python_callable=manage_system_callable, ) + + branch_on_action() >> [create_profiles_task, system_management_task] diff --git a/airflow/dags/ytdlp_mgmt_queues.py b/airflow/dags/ytdlp_mgmt_queues.py index eaedfdb..5c979c5 100644 --- a/airflow/dags/ytdlp_mgmt_queues.py +++ b/airflow/dags/ytdlp_mgmt_queues.py @@ -322,7 +322,14 @@ def clear_queue_callable(**context): dump_redis_data_to_csv(redis_client, dump_dir, dump_patterns) all_suffixes = ['_inbox', '_fail', '_result', '_progress', '_skipped'] + special_queues = ['queue_dl_format_tasks'] keys_to_delete = set() + + # Handle special queues first + for q in special_queues: + if q in queues_to_clear_options: + keys_to_delete.add(q) + for queue_base_name in queue_base_names_to_clear: if '_all' in queues_to_clear_options: logger.info(f"'_all' option selected. 
Clearing all standard queues for base '{queue_base_name}'.") @@ -446,6 +453,7 @@ def check_status_callable(**context): raise ValueError(f"Invalid queue_system: {queue_system}") queue_suffixes = ['_inbox', '_progress', '_result', '_fail', '_skipped'] + special_queues = ['queue_dl_format_tasks'] logger.info(f"--- Checking Status for Queue System: '{queue_system}' ---") @@ -468,6 +476,18 @@ def check_status_callable(**context): else: logger.info(f" - Queue '{queue_to_check}': Does not exist.") + logger.info(f"--- Special Queues ---") + for queue_name in special_queues: + key_type = redis_client.type(queue_name).decode('utf-8') + size = 0 + if key_type == 'list': + size = redis_client.llen(queue_name) + + if key_type != 'none': + logger.info(f" - Queue '{queue_name}': Type='{key_type.upper()}', Size={size}") + else: + logger.info(f" - Queue '{queue_name}': Does not exist.") + logger.info(f"--- End of Status Check ---") except Exception as e: @@ -794,10 +814,10 @@ with DAG( None, type=["null", "array"], title="[clear_queue] Queues to Clear", - description="Select which standard queues to clear. '_all' clears all four. If left empty, it defaults to '_all'.", + description="Select which standard queues to clear. '_all' clears all standard queues. 'queue_dl_format_tasks' is the new granular download task queue.", items={ "type": "string", - "enum": ["_inbox", "_fail", "_result", "_progress", "_skipped", "_all"], + "enum": ["_inbox", "_fail", "_result", "_progress", "_skipped", "_all", "queue_dl_format_tasks"], } ), "confirm_clear": Param( @@ -826,7 +846,7 @@ with DAG( ), # --- Params for 'list_contents' --- "queue_to_list": Param( - 'video_queue_inbox,queue2_auth_inbox,queue2_dl_inbox,queue2_dl_result', + 'queue2_auth_inbox,queue_dl_format_tasks,queue2_dl_inbox', type="string", title="[list_contents] Queues to List", description="Comma-separated list of exact Redis key names to list.", diff --git a/airflow/dags/ytdlp_ops_account_maintenance.py b/airflow/dags/ytdlp_ops_account_maintenance.py index 0ae7b52..bf3054b 100644 --- a/airflow/dags/ytdlp_ops_account_maintenance.py +++ b/airflow/dags/ytdlp_ops_account_maintenance.py @@ -4,255 +4,44 @@ # # Distributed under terms of the MIT license. +# -*- coding: utf-8 -*- +# +# Copyright © 2024 rl +# +# Distributed under terms of the MIT license. + """ -Maintenance DAG for managing the lifecycle of ytdlp-ops accounts. -This DAG is responsible for: -- Un-banning accounts whose ban duration has expired. -- Transitioning accounts from RESTING to ACTIVE after their cooldown period. -- Transitioning accounts from ACTIVE to RESTING after their active duration. -This logic was previously handled inside the ytdlp-ops-server and has been -moved here to give the orchestrator full control over account state. +DEPRECATED: Maintenance DAG for managing the lifecycle of ytdlp-ops accounts. 
""" from __future__ import annotations -import logging -import time -from datetime import datetime, timedelta - -from airflow.decorators import task -from airflow.models import Variable from airflow.models.dag import DAG -from airflow.models.param import Param from airflow.utils.dates import days_ago -# Import utility functions and Thrift modules -from utils.redis_utils import _get_redis_client -from pangramia.yt.management import YTManagementService -from thrift.protocol import TBinaryProtocol -from thrift.transport import TSocket, TTransport - -# Configure logging -logger = logging.getLogger(__name__) - -# Default settings from Airflow Variables or hardcoded fallbacks -DEFAULT_REDIS_CONN_ID = 'redis_default' -DEFAULT_MANAGEMENT_SERVICE_IP = Variable.get("MANAGEMENT_SERVICE_HOST", default_var="172.17.0.1") -DEFAULT_MANAGEMENT_SERVICE_PORT = Variable.get("MANAGEMENT_SERVICE_PORT", default_var=9080) - DEFAULT_ARGS = { 'owner': 'airflow', - 'retries': 1, - 'retry_delay': 30, + 'retries': 0, 'queue': 'queue-mgmt', } - -# --- Helper Functions --- - -def _get_thrift_client(host, port, timeout=60): - """Helper to create and connect a Thrift client.""" - transport = TSocket.TSocket(host, port) - transport.setTimeout(timeout * 1000) - transport = TTransport.TFramedTransport(transport) - protocol = TBinaryProtocol.TBinaryProtocol(transport) - client = YTManagementService.Client(protocol) - transport.open() - logger.info(f"Connected to Thrift server at {host}:{port}") - return client, transport - - -@task -def manage_account_states(**context): - """ - Fetches all account statuses and performs necessary state transitions - based on time durations configured in the DAG parameters. - """ - params = context['params'] - requests_limit = params['account_requests_limit'] - cooldown_duration_s = params['account_cooldown_duration_min'] * 60 - ban_duration_s = params['account_ban_duration_hours'] * 3600 - - host = DEFAULT_MANAGEMENT_SERVICE_IP - port = int(DEFAULT_MANAGEMENT_SERVICE_PORT) - redis_conn_id = DEFAULT_REDIS_CONN_ID - logger.info(f"Starting account maintenance. Service: {host}:{port}, Redis: {redis_conn_id}") - logger.info(f"Using limits: Requests={requests_limit}, Cooldown={params['account_cooldown_duration_min']}m, Ban={params['account_ban_duration_hours']}h") - - client, transport = None, None - try: - client, transport = _get_thrift_client(host, port) - redis_client = _get_redis_client(redis_conn_id) - - logger.info(f"--- Step 1: Fetching all account statuses from the ytdlp-ops-server at {host}:{port}... ---") - all_accounts = client.getAccountStatus(accountId=None, accountPrefix=None) - logger.info(f"Found {len(all_accounts)} total accounts to process.") - - accounts_to_unban = [] - accounts_to_activate = [] - accounts_to_rest = [] - - now_ts = int(time.time()) - - for acc in all_accounts: - # Thrift can return 0 for unset integer fields. - # The AccountStatus thrift object is missing status_changed_timestamp and active_since_timestamp. - # We use available timestamps as proxies. - last_failure_ts = int(acc.lastFailureTimestamp or 0) - last_success_ts = int(acc.lastSuccessTimestamp or 0) - last_usage_ts = max(last_failure_ts, last_success_ts) - - if acc.status == "BANNED" and last_failure_ts > 0: - time_since_ban = now_ts - last_failure_ts - if time_since_ban >= ban_duration_s: - accounts_to_unban.append(acc.accountId) - else: - remaining_s = ban_duration_s - time_since_ban - logger.info(f"Account {acc.accountId} is BANNED. 
Time until unban: {timedelta(seconds=remaining_s)}") - elif acc.status == "RESTING" and last_usage_ts > 0: - time_since_rest = now_ts - last_usage_ts - if time_since_rest >= cooldown_duration_s: - accounts_to_activate.append(acc.accountId) - else: - remaining_s = cooldown_duration_s - time_since_rest - logger.info(f"Account {acc.accountId} is RESTING. Time until active: {timedelta(seconds=remaining_s)}") - elif acc.status == "ACTIVE": - # For ACTIVE -> RESTING, check how many requests have been made since activation. - count_at_activation_raw = redis_client.hget(f"account_status:{acc.accountId}", "success_count_at_activation") - - if count_at_activation_raw is not None: - count_at_activation = int(count_at_activation_raw) - current_success_count = acc.successCount or 0 - requests_made = current_success_count - count_at_activation - - if requests_made >= requests_limit: - logger.info(f"Account {acc.accountId} reached request limit ({requests_made}/{requests_limit}). Moving to RESTING.") - accounts_to_rest.append(acc.accountId) - else: - requests_remaining = requests_limit - requests_made - logger.info(f"Account {acc.accountId} is ACTIVE. Requests until rest: {requests_remaining}/{requests_limit}") - else: - # This is a fallback for accounts that were activated before this logic was deployed. - # We can activate them "fresh" by setting their baseline count now. - logger.info(f"Account {acc.accountId} is ACTIVE but has no 'success_count_at_activation'. Setting it now.") - redis_client.hset(f"account_status:{acc.accountId}", "success_count_at_activation", acc.successCount or 0) - - logger.info("--- Step 2: Analyzing accounts for state transitions ---") - logger.info(f"Found {len(accounts_to_unban)} accounts with expired bans to un-ban.") - logger.info(f"Found {len(accounts_to_activate)} accounts with expired rest periods to activate.") - logger.info(f"Found {len(accounts_to_rest)} accounts with expired active periods to put to rest.") - - # --- Perform State Transitions --- - - # 1. Un-ban accounts via Thrift call - logger.info("--- Step 3: Processing un-bans ---") - if accounts_to_unban: - logger.info(f"Un-banning {len(accounts_to_unban)} accounts: {accounts_to_unban}") - account_map = {acc.accountId: acc for acc in all_accounts} - for acc_id in accounts_to_unban: - try: - client.unbanAccount(acc_id, "Automatic un-ban by Airflow maintenance DAG.") - logger.info(f"Successfully un-banned account '{acc_id}'.") - - # Set the activation count to baseline the account immediately after un-banning. - key = f"account_status:{acc_id}" - current_success_count = account_map[acc_id].successCount or 0 - redis_client.hset(key, "success_count_at_activation", current_success_count) - logger.info(f"Set 'success_count_at_activation' for un-banned account '{acc_id}' to {current_success_count}.") - except Exception as e: - logger.error(f"Failed to un-ban account '{acc_id}': {e}") - else: - logger.info("No accounts to un-ban.") - - # 2. 
Activate resting accounts via direct Redis write - logger.info("--- Step 4: Processing activations ---") - if accounts_to_activate: - logger.info(f"Activating {len(accounts_to_activate)} accounts: {accounts_to_activate}") - now_ts = int(time.time()) - account_map = {acc.accountId: acc for acc in all_accounts} - with redis_client.pipeline() as pipe: - for acc_id in accounts_to_activate: - key = f"account_status:{acc_id}" - current_success_count = account_map[acc_id].successCount or 0 - pipe.hset(key, "status", "ACTIVE") - pipe.hset(key, "active_since_timestamp", now_ts) - pipe.hset(key, "status_changed_timestamp", now_ts) - pipe.hset(key, "success_count_at_activation", current_success_count) - pipe.execute() - logger.info("Finished activating accounts.") - else: - logger.info("No accounts to activate.") - - # 3. Rest active accounts via direct Redis write - logger.info("--- Step 5: Processing rests ---") - if accounts_to_rest: - logger.info(f"Putting {len(accounts_to_rest)} accounts to rest: {accounts_to_rest}") - now_ts = int(time.time()) - with redis_client.pipeline() as pipe: - for acc_id in accounts_to_rest: - key = f"account_status:{acc_id}" - pipe.hset(key, "status", "RESTING") - pipe.hset(key, "status_changed_timestamp", now_ts) - pipe.hdel(key, "success_count_at_activation") - pipe.execute() - logger.info("Finished putting accounts to rest.") - else: - logger.info("No accounts to put to rest.") - - logger.info("--- Account maintenance run complete. ---") - - finally: - if transport and transport.isOpen(): - transport.close() - - with DAG( dag_id='ytdlp_ops_account_maintenance', default_args=DEFAULT_ARGS, - schedule='*/5 * * * *', # Run every 5 minutes + schedule=None, # Disabled start_date=days_ago(1), catchup=False, - tags=['ytdlp', 'maintenance'], + is_paused_upon_creation=True, + tags=['ytdlp', 'maintenance', 'deprecated'], doc_md=""" - ### YT-DLP Account Maintenance: Time-Based State Transitions + ### DEPRECATED: YT-DLP Account Maintenance - This DAG is the central authority for automated, **time-based** state management for ytdlp-ops accounts. - It runs periodically to fetch the status of all accounts and applies its own logic to determine if an account's state should change based on configurable time durations. + This DAG is **DEPRECATED** and should not be used. Its functionality has been replaced + by a standalone, continuously running `policy-enforcer` service. - The thresholds are defined as DAG parameters and can be configured via the Airflow UI: - - **Requests Limit**: How many successful requests an account can perform before it needs to rest. - - **Cooldown Duration**: How long an account must rest before it can be used again. - - **Ban Duration**: How long a ban lasts before the account is automatically un-banned. + To run the new enforcer, use the following command on a management node: + `bin/ytops-client policy-enforcer --policy policies/8_unified_simulation_enforcer.yaml --live` - --- - - #### Separation of Concerns: Time vs. Errors - - It is critical to understand that this DAG primarily handles time-based state changes. Error-based banning may be handled by worker DAGs during URL processing. This separation ensures that maintenance is predictable and based on timers, while acute, error-driven actions are handled immediately by the workers that encounter them. - - --- - - #### State Transitions Performed by This DAG: - - On each run, this DAG fetches the raw status and timestamps for all accounts and performs the following checks: - - 1. 
**Un-banning (`BANNED` -> `ACTIVE`)**: - - **Condition**: An account has been in the `BANNED` state for longer than the configured `account_ban_duration_hours`. - - **Action**: The DAG calls the `unbanAccount` service endpoint to lift the ban. - - 2. **Activation (`RESTING` -> `ACTIVE`)**: - - **Condition**: An account has been in the `RESTING` state for longer than the configured `account_cooldown_duration_min`. - - **Action**: The DAG updates the account's status to `ACTIVE` directly in Redis. - - 3. **Resting (`ACTIVE` -> `RESTING`)**: - - **Condition**: An account has performed more successful requests than the configured `account_requests_limit` since it was last activated. - - **Action**: The DAG updates the account's status to `RESTING` directly in Redis. - - This process gives full control over time-based account lifecycle management to the Airflow orchestrator. + This DAG is paused by default and will be removed in a future version. """, - params={ - 'account_requests_limit': Param(250, type="integer", description="Number of successful requests an account can make before it is rested. Default is 250."), - 'account_cooldown_duration_min': Param(60, type="integer", description="Duration in minutes an account must rest ('pause') before being activated again. Default is 60 minutes (1 hour)."), - 'account_ban_duration_hours': Param(24, type="integer", description="Duration in hours an account stays banned before it can be un-banned."), - } ) as dag: - manage_account_states() + pass diff --git a/airflow/dags/ytdlp_ops_v01_orchestrator.py b/airflow/dags/ytdlp_ops_v01_orchestrator.py index 6d0dc64..9602d85 100644 --- a/airflow/dags/ytdlp_ops_v01_orchestrator.py +++ b/airflow/dags/ytdlp_ops_v01_orchestrator.py @@ -48,6 +48,65 @@ DEFAULT_BUNCH_DELAY_S = 1 DEFAULT_YT_AUTH_SERVICE_IP = Variable.get("YT_AUTH_SERVICE_IP", default_var="172.17.0.1") DEFAULT_YT_AUTH_SERVICE_PORT = Variable.get("YT_AUTH_SERVICE_PORT", default_var=9080) +# Default ytdlp.json content for the unified config parameter +DEFAULT_YTDLP_CONFIG = { + "ytops": { + "force_renew": [], + "session_params": { + # "visitor_rotation_threshold": 250 + } + }, + "ytdlp_params": { + "debug_printtraffic": True, + "write_pages": True, + "verbose": True, + "no_color": True, + "ignoreerrors": True, + "noresizebuffer": True, + "buffersize": "4M", + "concurrent_fragments": 8, + "socket_timeout": 60, + "outtmpl": { + "default": "%(id)s.f%(format_id)s.%(ext)s" + }, + "restrictfilenames": True, + "updatetime": False, + "noplaylist": True, + "match_filter": "!is_live", + "writeinfojson": True, + "skip_download": True, + "allow_playlist_files": False, + "clean_infojson": True, + "getcomments": False, + "writesubtitles": False, + "writethumbnail": False, + "sleep_interval_requests": 0.75, + "parse_metadata": [ + ":(?P)" + ], + "extractor_args": { + "youtube": { + "player_client": ["tv_simply"], + "formats": ["duplicate"], + "jsc_trace": ["true"], + "pot_trace": ["true"], + "skip": ["translated_subs", "hls"] + }, + "youtubepot-bgutilhttp": { + "base_url": ["http://172.17.0.1:4416"] + } + }, + "noprogress": True, + "format_sort": [ + "res", + "ext:mp4:m4a" + ], + "remuxvideo": "mp4", + "nooverwrites": True, + "continuedl": True + } +} + # --- Helper Functions --- def _check_application_queue(redis_client, queue_base_name: str) -> int: @@ -159,26 +218,43 @@ def orchestrate_workers_ignition_callable(**context): # --- Generate a consistent timestamped prefix for this orchestrator run --- # This ensures all workers spawned from this run use the same set 
of accounts. final_account_pool_prefix = params['account_pool'] + + # --- Unified JSON Config Handling --- + # Start with the JSON config from params, then merge legacy params into it. + try: + ytdlp_config = json.loads(params.get('ytdlp_config_json', '{}')) + except json.JSONDecodeError as e: + logger.error(f"Invalid ytdlp_config_json parameter. Must be valid JSON. Error: {e}") + raise AirflowException("Invalid ytdlp_config_json parameter.") + if params.get('prepend_client_to_account') and params.get('account_pool_size') is not None: - clients_str = params.get('clients', '') + try: + clients_str = ','.join(ytdlp_config['ytdlp_params']['extractor_args']['youtube']['player_client']) + except KeyError: + clients_str = '' + primary_client = clients_str.split(',')[0].strip() if clients_str else 'unknown' - # Use a timestamp from the orchestrator's run for consistency timestamp = datetime.now().strftime('%Y%m%d%H%M%S') final_account_pool_prefix = f"{params['account_pool']}_{timestamp}_{primary_client}" logger.info(f"Generated consistent account prefix for this run: '{final_account_pool_prefix}'") + final_ytdlp_config_str = json.dumps(ytdlp_config) + # --- End of JSON Config Handling --- + for i, bunch in enumerate(bunches): logger.info(f"--- Triggering Bunch {i+1}/{len(bunches)} (contains {len(bunch)} dispatcher(s)) ---") - for j, _ in enumerate(bunch): + for j, worker_index in enumerate(bunch): # Create a unique run_id for each dispatcher run run_id = f"dispatched_{dag_run_id}_{total_triggered}" # Pass all orchestrator params to the dispatcher, which will then pass them to the worker. conf_to_pass = {p: params[p] for p in params} - # Override account_pool with the generated prefix + # Override account_pool with the generated prefix and set the unified JSON config conf_to_pass['account_pool'] = final_account_pool_prefix + conf_to_pass['worker_index'] = worker_index + conf_to_pass['ytdlp_config_json'] = final_ytdlp_config_str - logger.info(f"Triggering dispatcher {j+1}/{len(bunch)} in bunch {i+1} (run {total_triggered + 1}/{total_workers}) (Run ID: {run_id})") + logger.info(f"Triggering dispatcher {j+1}/{len(bunch)} in bunch {i+1} (run {total_triggered + 1}/{total_workers}, worker_index: {worker_index}) (Run ID: {run_id})") logger.debug(f"Full conf for dispatcher run {run_id}: {conf_to_pass}") trigger_dag( @@ -299,73 +375,22 @@ with DAG( 'delay_between_bunches_s': Param(DEFAULT_BUNCH_DELAY_S, type="integer", description="Delay in seconds between starting each bunch."), 'skip_if_queue_empty': Param(False, type="boolean", title="[Ignition Control] Skip if Queue Empty", description="If True, the orchestrator will not start any dispatchers if the application's work queue is empty."), + # --- Unified Worker Configuration --- + 'ytdlp_config_json': Param( + json.dumps(DEFAULT_YTDLP_CONFIG, indent=2), + type="string", + title="[Worker Param] Unified yt-dlp JSON Config", + description="A JSON string containing all parameters for both yt-ops-server and the yt-dlp downloaders. This is the primary way to configure workers.", + **{'ui_widget': 'json', 'multi_line': True} + ), + # --- Worker Passthrough Parameters --- - 'on_auth_failure': Param( - 'proceed_loop_under_manual_inspection', - type="string", - enum=['stop_loop', 'retry_with_new_account', 'retry_without_ban', 'proceed_loop_under_manual_inspection'], - title="[Worker Param] On Authentication Failure Policy", - description="Policy for a worker when a bannable authentication error occurs. 
" - "'stop_loop': Ban the account, mark URL as failed, and stop the worker's loop. " - "'retry_with_new_account': (Default) Ban the failed account, retry ONCE with a new account. If retry fails, ban the second account and stop." - "'retry_without_ban': If a connection error (e.g. SOCKS timeout) occurs, retry with a new account but do NOT ban the first account/proxy. If retry fails, stop the loop without banning." - "'proceed_loop_under_manual_inspection': **BEWARE: MANUAL SUPERVISION REQUIRED.** Marks the URL as failed but continues the processing loop. Use this only when you can manually intervene." - ), - 'on_download_failure': Param( - 'proceed_loop', - type="string", - enum=['stop_loop', 'proceed_loop', 'retry_with_new_token'], - title="[Worker Param] On Download Failure Policy", - description="Policy for a worker when a download or probe error occurs. " - "'stop_loop': Mark URL as failed and stop the worker's loop. " - "'proceed_loop': (Default) Mark URL as failed but continue the processing loop with a new URL. " - "'retry_with_new_token': Attempt to get a new token with a new account and retry the download once. If it fails again, proceed loop." - ), - 'request_params_json': Param('{}', type="string", title="[Worker Param] Request Params JSON", description="JSON string with per-request parameters to override server defaults. Can be a full JSON object or comma-separated key=value pairs (e.g., 'session_params.location=DE,ytdlp_params.skip_cache=true')."), - 'language_code': Param('en-US', type="string", title="[Worker Param] Language Code", description="The language code (e.g., 'en-US', 'de-DE') to use for the YouTube request headers."), + # These are used by the orchestrator itself and are also passed to workers. 'queue_name': Param(DEFAULT_QUEUE_NAME, type="string", description="[Worker Param] Base name for Redis queues."), 'redis_conn_id': Param(DEFAULT_REDIS_CONN_ID, type="string", description="[Worker Param] Airflow Redis connection ID."), - 'clients': Param( - 'tv_simply', - type="string", - title="[Worker Param] Clients", - description="[Worker Param] Comma-separated list of clients for token generation. Full list: web, web_safari, web_embedded, web_music, web_creator, mweb, web_camoufox, web_safari_camoufox, web_embedded_camoufox, web_music_camoufox, web_creator_camoufox, mweb_camoufox, android, android_music, android_creator, android_vr, ios, ios_music, ios_creator, tv, tv_simply, tv_embedded. See DAG documentation for details." - ), 'account_pool': Param('ytdlp_account', type="string", description="[Worker Param] Account pool prefix or comma-separated list."), 'account_pool_size': Param(10, type=["integer", "null"], description="[Worker Param] If using a prefix for 'account_pool', this specifies the number of accounts to generate (e.g., 10 for 'prefix_01' through 'prefix_10'). Required when using a prefix."), 'prepend_client_to_account': Param(True, type="boolean", title="[Worker Param] Prepend Client to Account", description="If True, prepends client and timestamp to account names in prefix mode. Format: prefix_YYYYMMDDHHMMSS_client_XX."), - 'service_ip': Param(DEFAULT_YT_AUTH_SERVICE_IP, type="string", description="[Worker Param] IP of the ytdlp-ops-server. Default is from Airflow variable YT_AUTH_SERVICE_IP or hardcoded."), - 'service_port': Param(DEFAULT_YT_AUTH_SERVICE_PORT, type="integer", description="[Worker Param] Port of the Envoy load balancer. 
Default is from Airflow variable YT_AUTH_SERVICE_PORT or hardcoded."), - 'machine_id': Param("ytdlp-ops-airflow-service", type="string", description="[Worker Param] Identifier for the client machine."), - 'assigned_proxy_url': Param(None, type=["string", "null"], title="[Worker Param] Assigned Proxy URL", description="A specific proxy URL to use for the request, overriding the server's proxy pool logic."), - 'auto_create_new_accounts_on_exhaustion': Param(True, type="boolean", description="[Worker Param] If True and all accounts in a prefix-based pool are exhausted, create a new one automatically."), - # --- Download Control Parameters --- - 'delay_between_formats_s': Param(15, type="integer", title="[Worker Param] Delay Between Formats (s)", description="Delay in seconds between downloading each format when multiple formats are specified. A 22s wait may be effective for batch downloads, while 6-12s may suffice if cookies are refreshed regularly."), - 'yt_dlp_test_mode': Param(False, type="boolean", title="[Worker Param] yt-dlp Test Mode", description="If True, runs yt-dlp with --test flag (dry run without downloading)."), - 'skip_probe': Param(True, type="boolean", title="[Worker Param] Skip Probe", description="If True, skips the ffmpeg probe of downloaded files."), - 'yt_dlp_cleanup_mode': Param(False, type="boolean", title="[Worker Param] yt-dlp Cleanup Mode", description="If True, creates a .empty file and deletes the original media file after successful download and probe."), - 'socket_timeout': Param(15, type="integer", title="[Worker Param] Socket Timeout", description="Timeout in seconds for socket operations."), - 'download_format': Param( - 'bestvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best', - type="string", - title="[Worker Param] Download Format", - description="Custom yt-dlp format string. Common presets: [1] 'bestvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best' (Default, best quality MP4). [2] '18-dashy/18,140-dashy/140,133-dashy/134-dashy/136-dashy/137-dashy/250-dashy/298-dashy/299-dashy' (Legacy formats). [3] '299-dashy/298-dashy/250-dashy/137-dashy/136-dashy/135-dashy/134-dashy/133-dashy' (High-framerate formats)." - ), - 'downloader': Param( - 'cli', - type="string", - enum=['py', 'aria-rpc', 'cli'], - title="[Worker Param] Download Tool", - description="Choose the download tool to use: 'py' (native python, recommended), 'aria-rpc' (send to aria2c daemon), 'cli' (legacy yt-dlp wrapper)." - ), - 'aria_host': Param('172.17.0.1', type="string", title="[Worker Param] Aria2c Host", description="For 'aria-rpc' downloader: Host of the aria2c RPC server. Can be set via Airflow Variable 'YTDLP_ARIA_HOST'."), - 'aria_port': Param(6800, type="integer", title="[Worker Param] Aria2c Port", description="For 'aria-rpc' downloader: Port of the aria2c RPC server. Can be set via Airflow Variable 'YTDLP_ARIA_PORT'."), - 'aria_secret': Param('SQGCQPLVFQIASMPNPOJYLVGJYLMIDIXDXAIXOTX', type="string", title="[Worker Param] Aria2c Secret", description="For 'aria-rpc' downloader: Secret token. 
Can be set via Airflow Variable 'YTDLP_ARIA_SECRET'."), - 'yt_dlp_extra_args': Param( - '', - type=["string", "null"], - title="[Worker Param] Extra yt-dlp arguments", - ), } ) as dag: diff --git a/airflow/dags/ytdlp_ops_v01_worker_per_url.py b/airflow/dags/ytdlp_ops_v01_worker_per_url.py index c42b38e..815b340 100644 --- a/airflow/dags/ytdlp_ops_v01_worker_per_url.py +++ b/airflow/dags/ytdlp_ops_v01_worker_per_url.py @@ -215,6 +215,15 @@ def _get_account_pool(params: dict) -> list: # TASK DEFINITIONS (TaskFlow API) # ============================================================================= +def _get_worker_params(params: dict) -> dict: + """Loads and returns the worker_params dict from the unified JSON config.""" + try: + ytdlp_config = json.loads(params.get('ytdlp_config_json', '{}')) + return ytdlp_config.get('ytops', {}).get('worker_params', {}) + except json.JSONDecodeError: + logger.error("Could not parse ytdlp_config_json. Using empty worker_params.") + return {} + @task def get_url_and_assign_account(**context): """ @@ -223,6 +232,15 @@ def get_url_and_assign_account(**context): """ params = context['params'] ti = context['task_instance'] + worker_params = _get_worker_params(params) + + # Log the active policies + auth_policy = worker_params.get('on_auth_failure', 'not_set') + download_policy = worker_params.get('on_download_failure', 'not_set') + logger.info(f"--- Worker Policies ---") + logger.info(f" Auth Failure Policy: {auth_policy}") + logger.info(f" Download Failure Policy: {download_policy}") + logger.info(f"-----------------------") # --- Worker Pinning Verification --- # This is a safeguard against a known Airflow issue where clearing a task @@ -293,9 +311,20 @@ def get_url_and_assign_account(**context): except Exception as e: logger.error(f"Could not mark URL as in-progress in Redis: {e}", exc_info=True) - # Account assignment logic is the same as before. - account_id = random.choice(_get_account_pool(params)) - logger.info(f"Selected account '{account_id}' for this run.") + # Account assignment logic + account_id = params.get('account_id') + if account_id: + logger.info(f"Using sticky account '{account_id}' passed from previous run.") + else: + account_pool = _get_account_pool(params) + worker_index = params.get('worker_index') + if worker_index is not None: + account_id = account_pool[worker_index % len(account_pool)] + logger.info(f"Selected account '{account_id}' deterministically using worker_index {worker_index}.") + else: + # Fallback to random choice if no worker_index is provided (e.g., for manual runs) + account_id = random.choice(account_pool) + logger.warning(f"No worker_index provided. 
Selected account '{account_id}' randomly as a fallback.") return { 'url_to_process': url_to_process, @@ -305,10 +334,7 @@ def get_url_and_assign_account(**context): @task def get_token(initial_data: dict, **context): - """Makes a single attempt to get a token by calling the ytops-client get-info tool.""" - import subprocess - import shlex - + """Makes a single attempt to get a token by calling the Thrift service directly.""" ti = context['task_instance'] params = context['params'] @@ -318,26 +344,13 @@ def get_token(initial_data: dict, **context): host, port = params['service_ip'], int(params['service_port']) machine_id = params.get('machine_id') or socket.gethostname() - clients = params.get('clients') - request_params_json = params.get('request_params_json') - language_code = params.get('language_code') + + # For sticky proxy assigned_proxy_url = params.get('assigned_proxy_url') - - if language_code: - try: - params_dict = json.loads(request_params_json) - logger.info(f"Setting language for request: {language_code}") - if 'session_params' not in params_dict: - params_dict['session_params'] = {} - params_dict['session_params']['lang'] = language_code - request_params_json = json.dumps(params_dict) - except (json.JSONDecodeError, TypeError): - logger.warning("Could not parse request_params_json as JSON. Treating as key=value pairs and appending language code.") - lang_kv = f"session_params.lang={language_code}" - if request_params_json: - request_params_json += f",{lang_kv}" - else: - request_params_json = lang_kv + + # The unified JSON config is now the primary source of parameters. + request_params_json = params.get('ytdlp_config_json', '{}') + clients = None # This will be read from the JSON config on the server side. video_id = _extract_video_id(url) timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") @@ -346,75 +359,111 @@ def get_token(initial_data: dict, **context): os.makedirs(job_dir_path, exist_ok=True) info_json_path = os.path.join(job_dir_path, f"info_{video_id or 'unknown'}_{account_id}_{timestamp}.json") - cmd = [ - 'ytops-client', 'get-info', - '--host', host, - '--port', str(port), - '--profile', account_id, - '--output', info_json_path, - '--print-proxy', - '--verbose', - '--log-return', - ] + # Save the received JSON config to the job directory for the download tool. + ytdlp_config_path = os.path.join(job_dir_path, 'ytdlp.json') + try: + with open(ytdlp_config_path, 'w', encoding='utf-8') as f: + # Pretty-print the JSON for readability + config_data = json.loads(request_params_json) + json.dump(config_data, f, indent=2) + logger.info(f"Saved ytdlp config to {ytdlp_config_path}") + except (IOError, json.JSONDecodeError) as e: + logger.error(f"Failed to save ytdlp.json config: {e}") + # Continue anyway, but download may fail. 
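        # --- Illustrative sketch, not part of this patch ---
        # The ytdlp.json saved above is what download_and_probe later receives via the
        # `--config` flag and reads back, roughly like this (mirroring the
        # `ytops.worker_params` / `ytdlp_params` layout of DEFAULT_YTDLP_CONFIG):
        #
        #     with open(ytdlp_config_path, 'r', encoding='utf-8') as f:
        #         cfg = json.load(f)
        #     worker_params = cfg.get('ytops', {}).get('worker_params', {})
        #     download_format = cfg.get('ytdlp_params', {}).get('format')
        #
        # If saving fails here, ytdlp_config_path is set to None just below and the
        # download step proceeds with an empty config, which is why the comment above
        # warns that the download may fail.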
+ ytdlp_config_path = None - if clients: - cmd.extend(['--client', clients]) - if machine_id: - cmd.extend(['--machine-id', machine_id]) - if request_params_json and request_params_json != '{}': - cmd.extend(['--request-params-json', request_params_json]) - if assigned_proxy_url: - cmd.extend(['--assigned-proxy-url', assigned_proxy_url]) - - cmd.append(url) + client, transport = None, None + try: + timeout = int(params.get('timeout', DEFAULT_TIMEOUT)) + client, transport = _get_thrift_client(host, port, timeout) - logger.info(f"--- Attempting to get token for URL '{url}' with account '{account_id}' (Clients: {clients}) ---") - copy_paste_cmd = ' '.join(shlex.quote(arg) for arg in cmd) - logger.info(f"Executing command: {copy_paste_cmd}") + airflow_log_context = AirflowLogContext( + taskId=ti.task_id, + runId=ti.run_id, + tryNumber=ti.try_number + ) - process = subprocess.run(cmd, capture_output=True, text=True, timeout=int(params.get('timeout', DEFAULT_TIMEOUT))) - - if process.stdout: - logger.info(f"ytops-client STDOUT:\n{process.stdout}") - if process.stderr: - logger.info(f"ytops-client STDERR:\n{process.stderr}") - - if process.returncode != 0: - error_message = "ytops-client failed. See logs for details." - # Try to find a more specific error message from the Thrift client's output - thrift_error_match = re.search(r'A Thrift error occurred: (.*)', process.stderr) - if thrift_error_match: - error_message = thrift_error_match.group(1).strip() - else: # Fallback to old line-by-line parsing - for line in reversed(process.stderr.strip().split('\n')): - if 'ERROR' in line or 'Thrift error' in line or 'Connection to server failed' in line: - error_message = line.strip() - break + logger.info(f"--- Attempting to get token for URL '{url}' with account '{account_id}' (Clients: {clients}, Proxy: {assigned_proxy_url or 'any'}) ---") - # Determine error code for branching logic - error_code = 'GET_INFO_CLIENT_FAIL' - stderr_lower = process.stderr.lower() + token_data = client.getOrRefreshToken( + accountId=account_id, + updateType=TokenUpdateMode.AUTO, + url=url, + clients=clients, + machineId=machine_id, + airflowLogContext=airflow_log_context, + requestParamsJson=request_params_json, + assignedProxyUrl=assigned_proxy_url + ) + + # --- Log server-side details for debugging --- + if hasattr(token_data, 'serverVersionInfo') and token_data.serverVersionInfo: + logger.info(f"--- Server Version Info ---\n{token_data.serverVersionInfo}") - # These patterns should match the error codes from PBUserException and others - error_patterns = { - "BOT_DETECTED": ["bot_detected"], - "BOT_DETECTION_SIGN_IN_REQUIRED": ["bot_detection_sign_in_required"], - "TRANSPORT_ERROR": ["connection to server failed"], - "PRIVATE_VIDEO": ["private video"], - "COPYRIGHT_REMOVAL": ["copyright"], - "GEO_RESTRICTED": ["in your country"], - "VIDEO_REMOVED": ["video has been removed"], - "VIDEO_UNAVAILABLE": ["video unavailable"], - "MEMBERS_ONLY": ["members-only"], - "AGE_GATED_SIGN_IN": ["sign in to confirm your age"], - "VIDEO_PROCESSING": ["processing this video"], + if hasattr(token_data, 'requestSummary') and token_data.requestSummary: + try: + summary_data = json.loads(token_data.requestSummary) + summary_text = summary_data.get('summary', 'Not available.') + prefetch_log = summary_data.get('prefetch_log', 'Not available.') + nodejs_log = summary_data.get('nodejs_log', 'Not available.') + ytdlp_log = summary_data.get('ytdlp_log', 'Not available.') + + logger.info(f"--- Request Summary ---\n{summary_text}") + 
logger.info(f"--- Prefetch Log ---\n{prefetch_log}") + logger.info(f"--- Node.js Log ---\n{nodejs_log}") + logger.info(f"--- yt-dlp Log ---\n{ytdlp_log}") + except (json.JSONDecodeError, AttributeError): + logger.info(f"--- Raw Request Summary (could not parse JSON) ---\n{token_data.requestSummary}") + + if hasattr(token_data, 'communicationLogPaths') and token_data.communicationLogPaths: + logger.info("--- Communication Log Paths on Server ---") + for log_path in token_data.communicationLogPaths: + logger.info(f" - {log_path}") + # --- End of server-side logging --- + + if not token_data or not token_data.infoJson: + raise AirflowException("Thrift service did not return valid info.json data.") + + # Save info.json to file + with open(info_json_path, 'w', encoding='utf-8') as f: + f.write(token_data.infoJson) + + proxy = token_data.socks + + # Rename file with proxy + final_info_json_path = info_json_path + if proxy: + sanitized_proxy = proxy.replace('://', '---') + new_filename = f"info_{video_id or 'unknown'}_{account_id}_{timestamp}_proxy_{sanitized_proxy}.json" + new_path = os.path.join(job_dir_path, new_filename) + try: + os.rename(info_json_path, new_path) + final_info_json_path = new_path + logger.info(f"Renamed info.json to include proxy: {new_path}") + except OSError as e: + logger.error(f"Failed to rename info.json to include proxy: {e}. Using original path.") + + return { + 'info_json_path': final_info_json_path, + 'job_dir_path': job_dir_path, + 'socks_proxy': proxy, + 'ytdlp_command': None, + 'successful_account_id': account_id, + 'original_url': url, + 'ytdlp_config_path': ytdlp_config_path, } - for code, patterns in error_patterns.items(): - if any(p in stderr_lower for p in patterns): - error_code = code - break # Found a match, stop searching + except (PBServiceException, PBUserException) as e: + error_message = e.message or "Unknown Thrift error" + error_code = getattr(e, 'errorCode', 'THRIFT_ERROR') + # If a "Video unavailable" error mentions rate-limiting, it's a form of bot detection. + if error_code == 'VIDEO_UNAVAILABLE' and 'rate-limited' in error_message.lower(): + logger.warning("Re-classifying rate-limit-related 'VIDEO_UNAVAILABLE' error as 'BOT_DETECTED'.") + error_code = 'BOT_DETECTED' + + logger.error(f"Thrift error getting token: {error_code} - {error_message}") + error_details = { 'error_message': error_message, 'error_code': error_code, @@ -422,35 +471,18 @@ def get_token(initial_data: dict, **context): } ti.xcom_push(key='error_details', value=error_details) raise AirflowException(f"ytops-client get-info failed: {error_message}") - - proxy = None - proxy_match = re.search(r"Proxy used: (.*)", process.stderr) - if proxy_match: - proxy = proxy_match.group(1).strip() - - # Rename the info.json to include the proxy for the download worker - final_info_json_path = info_json_path - if proxy: - # Sanitize for filename: replace '://' which is invalid in paths. Colons are usually fine. - sanitized_proxy = proxy.replace('://', '---') - - new_filename = f"info_{video_id or 'unknown'}_{account_id}_{timestamp}_proxy_{sanitized_proxy}.json" - new_path = os.path.join(job_dir_path, new_filename) - try: - os.rename(info_json_path, new_path) - final_info_json_path = new_path - logger.info(f"Renamed info.json to include proxy: {new_path}") - except OSError as e: - logger.error(f"Failed to rename info.json to include proxy: {e}. 
Using original path.") - - return { - 'info_json_path': final_info_json_path, - 'job_dir_path': job_dir_path, - 'socks_proxy': proxy, - 'ytdlp_command': None, - 'successful_account_id': account_id, - 'original_url': url, - } + except TTransportException as e: + logger.error(f"Thrift transport error: {e}", exc_info=True) + error_details = { + 'error_message': f"Thrift transport error: {e}", + 'error_code': 'TRANSPORT_ERROR', + 'proxy_url': None + } + ti.xcom_push(key='error_details', value=error_details) + raise AirflowException(f"Thrift transport error: {e}") + finally: + if transport and transport.isOpen(): + transport.close() @task.branch def handle_bannable_error_branch(task_id_to_check: str, **context): @@ -460,7 +492,31 @@ def handle_bannable_error_branch(task_id_to_check: str, **context): """ ti = context['task_instance'] params = context['params'] - error_details = ti.xcom_pull(task_ids=task_id_to_check, key='error_details') + + # Try to get error details from the specified task + error_details = None + try: + error_details = ti.xcom_pull(task_ids=task_id_to_check, key='error_details') + except Exception as e: + logger.warning(f"Could not pull error details from task '{task_id_to_check}': {e}") + + # If not found, try to get from any task in the DAG run + if not error_details: + # Look for error details in any task that may have pushed them + # This is a fallback mechanism + dag_run = ti.get_dagrun() + task_instances = dag_run.get_task_instances() + for task_instance in task_instances: + if task_instance.task_id != ti.task_id: + try: + details = task_instance.xcom_pull(key='error_details') + if details: + error_details = details + logger.info(f"Found error details in task '{task_instance.task_id}'") + break + except Exception: + pass + if not error_details: logger.error(f"Task {task_id_to_check} failed without error details. Marking as fatal.") return 'handle_fatal_error' @@ -577,7 +633,7 @@ def ban_and_retry_logic(initial_data: dict): @task(task_id='ban_account_task') def ban_account_task(data: dict, **context): """Wrapper task to call the main ban_account function.""" - ban_account(initial_data=data, reason="Banned by Airflow worker after sliding window check", **context) + _ban_account(initial_data=data, reason="Banned by Airflow worker after sliding window check", context=context) @task(task_id='skip_ban_task') def skip_ban_task(): @@ -591,8 +647,7 @@ def ban_and_retry_logic(initial_data: dict): check_task >> [ban_task_in_group, skip_task] -@task -def ban_account(initial_data: dict, reason: str, **context): +def _ban_account(initial_data: dict, reason: str, context: dict): """Bans a single account via the Thrift service.""" params = context['params'] account_id = initial_data['account_id'] @@ -602,7 +657,8 @@ def ban_account(initial_data: dict, reason: str, **context): client, transport = _get_thrift_client(host, port, timeout) logger.warning(f"Banning account '{account_id}'. 
Reason: {reason}") client.banAccount(accountId=account_id, reason=reason) - except Exception as e: + except BaseException as e: + # Catch BaseException to include SystemExit, which may be raised by the Thrift client logger.error(f"Failed to issue ban for account '{account_id}': {e}", exc_info=True) finally: if transport and transport.isOpen(): @@ -650,13 +706,29 @@ def assign_new_account_after_ban_check(initial_data: dict, **context): 'accounts_tried': accounts_tried, } -@task -def ban_and_report_immediately(initial_data: dict, reason: str, **context): - """Bans an account and prepares for failure reporting and continuing the loop.""" - ban_account(initial_data, reason, **context) - logger.info(f"Account '{initial_data.get('account_id')}' banned. Proceeding to report failure.") - # This task is a leaf in its path and is followed by the failure reporting task. - return initial_data # Pass data along if needed by reporting +@task(retries=0) +def ban_and_report_immediately(**context): + """Bans an account and prepares for failure reporting and stopping the loop.""" + ti = context['task_instance'] + # Manually pull initial_data. This is more robust if the upstream task was skipped. + initial_data = ti.xcom_pull(task_ids='get_url_and_assign_account') + if not initial_data: + logger.error("Could not retrieve initial_data to ban account.") + # Return a default dict to allow downstream reporting to proceed. + return {'account_id': 'unknown', 'url_to_process': context['params'].get('url_to_process', 'unknown')} + + try: + reason = "Banned by Airflow worker (policy is stop_loop)" + _ban_account(initial_data, reason, context) + logger.info(f"Account '{initial_data.get('account_id')}' banned. Proceeding to report failure.") + except BaseException as e: + # Catch BaseException to include SystemExit, which may be raised by the Thrift client + logger.error(f"Error during ban_and_report_immediately: {e}", exc_info=True) + # Swallow the exception to ensure this task succeeds. The loop will be stopped by downstream tasks. 
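        # Illustrative note, not part of this patch: _ban_account is now a plain helper
        # rather than an @task, so several tasks in this file (ban_account_task,
        # ban_and_report_immediately, ban_and_report_after_retry) can call it inline
        # without extra XCom wiring, e.g.:
        #
        #     _ban_account(initial_data, reason="Banned by Airflow worker (policy is stop_loop)", context=context)
        #
        # The `except BaseException` guard above is deliberate: per the comments in this
        # patch, the Thrift client may raise SystemExit, which a plain `except Exception`
        # would not catch, and this reporting path must still complete.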
+ + # Always return the initial data, even if banning failed + # Make a copy to ensure we're not returning a reference that might be modified elsewhere + return dict(initial_data) if initial_data else {} @task def list_available_formats(token_data: dict, **context): @@ -787,6 +859,14 @@ def download_and_probe(token_data: dict, available_formats: list[str], **context params = context['params'] info_json_path = token_data.get('info_json_path') original_url = token_data.get('original_url') + ytdlp_config_path = token_data.get('ytdlp_config_path') + ytdlp_config = {} + if ytdlp_config_path and os.path.exists(ytdlp_config_path): + try: + with open(ytdlp_config_path, 'r', encoding='utf-8') as f: + ytdlp_config = json.load(f) + except (IOError, json.JSONDecodeError) as e: + logger.warning(f"Could not load ytdlp config from {ytdlp_config_path}: {e}") # Extract proxy from filename, with fallback to token_data for backward compatibility proxy = None @@ -839,6 +919,10 @@ def download_and_probe(token_data: dict, available_formats: list[str], **context downloader = params.get('downloader', 'py') cmd = ['ytops-client', 'download', downloader, '--load-info-json', info_json_path, '-f', format_selector] + # Pass the unified config file to the download tool + if ytdlp_config_path: + cmd.extend(['--config', ytdlp_config_path]) + if downloader == 'py': if proxy: cmd.extend(['--proxy', proxy]) @@ -846,15 +930,13 @@ def download_and_probe(token_data: dict, available_formats: list[str], **context # The 'py' tool maps many yt-dlp flags via --extra-ytdlp-args # The 'py' tool maps many yt-dlp flags via --extra-ytdlp-args - py_extra_args = ['--output', output_template, '--no-resize-buffer', '--buffer-size', '4M'] - if params.get('fragment_retries'): - py_extra_args.extend(['--fragment-retries', str(params['fragment_retries'])]) - if params.get('socket_timeout'): - py_extra_args.extend(['--socket-timeout', str(params['socket_timeout'])]) + py_extra_args = ['--output', output_template] if params.get('yt_dlp_test_mode'): py_extra_args.append('--test') - existing_extra = shlex.split(params.get('yt_dlp_extra_args') or '') + # Get extra args from the config file now + existing_extra_str = ytdlp_config.get('ytops', {}).get('worker_params', {}).get('yt_dlp_extra_args', '') + existing_extra = shlex.split(existing_extra_str or '') final_extra_args_list = existing_extra + py_extra_args if final_extra_args_list: final_extra_args_str = shlex.join(final_extra_args_list) @@ -877,10 +959,13 @@ def download_and_probe(token_data: dict, available_formats: list[str], **context # The remote-dir is the path relative to aria2c's working directory on the host. # The output-dir is the container's local path to the same shared volume. remote_dir = os.path.relpath(download_dir, '/opt/airflow/downloadfiles/videos') + + # Get aria params from config file + worker_params = ytdlp_config.get('ytops', {}).get('worker_params', {}) cmd.extend([ - '--aria-host', params.get('aria_host', '172.17.0.1'), - '--aria-port', str(params.get('aria_port', 6800)), - '--aria-secret', params.get('aria_secret'), + '--aria-host', worker_params.get('aria_host', '172.17.0.1'), + '--aria-port', str(worker_params.get('aria_port', 6800)), + '--aria-secret', worker_params.get('aria_secret'), '--wait', '--output-dir', download_dir, '--remote-dir', remote_dir, @@ -900,11 +985,7 @@ def download_and_probe(token_data: dict, available_formats: list[str], **context cmd.extend(['--proxy', proxy]) # The 'cli' tool is the old yt-dlp wrapper, so it takes similar arguments. 
- cli_extra_args = ['--output', full_output_path, '--no-resize-buffer', '--buffer-size', '4M'] - if params.get('fragment_retries'): - cli_extra_args.extend(['--fragment-retries', str(params['fragment_retries'])]) - if params.get('socket_timeout'): - cli_extra_args.extend(['--socket-timeout', str(params['socket_timeout'])]) + cli_extra_args = ['--output', full_output_path, '--verbose'] if params.get('yt_dlp_test_mode'): cli_extra_args.append('--test') @@ -1030,71 +1111,84 @@ def download_and_probe(token_data: dict, available_formats: list[str], **context with open(info_json_path, 'r', encoding='utf-8') as f: info = json.load(f) - # Split the format string by commas to get a list of individual format selectors. - # This enables parallel downloads of different formats or format groups. - # For example, '18,140,299/298' becomes ['18', '140', '299/298'], - # and each item will be downloaded in a separate yt-dlp process. - if download_format and isinstance(download_format, str): - formats_to_download_initial = [selector.strip() for selector in download_format.split(',') if selector.strip()] - else: - # Fallback for safety, though download_format should always be a string. - formats_to_download_initial = [] + ytdlp_params = ytdlp_config.get('ytdlp_params', {}) + download_format = ytdlp_params.get('format') - if not formats_to_download_initial: - raise AirflowException("No valid download format selectors were found after parsing.") - - # --- Filter and resolve requested formats --- final_formats_to_download = [] - if not available_formats: - logger.warning("List of available formats is empty. Cannot validate numeric selectors, but will attempt to resolve generic selectors.") + downloader = params.get('downloader', 'cli') + pass_without_splitting = params.get('pass_without_formats_splitting', False) - for selector in formats_to_download_initial: - # A selector is considered generic if it contains keywords like 'best' or filter brackets '[]'. - is_generic = bool(re.search(r'(best|\[|\])', selector)) - - if is_generic: - resolved_selector = _resolve_generic_selector(selector, info_json_path, logger) - if resolved_selector: - # The resolver returns a list for '+' selectors, or a string for others. - resolved_formats = resolved_selector if isinstance(resolved_selector, list) else [resolved_selector] - - for res_format in resolved_formats: - # Prefer -dashy version if available and the format is a simple numeric ID - if res_format.isdigit() and f"{res_format}-dashy" in available_formats: - final_format = f"{res_format}-dashy" - logger.info(f"Resolved format '{res_format}' from selector '{selector}'. Preferred '-dashy' version: '{final_format}'.") - else: - final_format = res_format - - # Validate the chosen format against available formats - if available_formats: - individual_ids = re.split(r'[/+]', final_format) - is_available = any(fid in available_formats for fid in individual_ids) - - if is_available: - final_formats_to_download.append(final_format) - else: - logger.warning(f"Resolved format '{final_format}' (from '{selector}') contains no available formats. Skipping.") - else: - # Cannot validate, so we trust the resolver's output. - final_formats_to_download.append(final_format) - else: - logger.warning(f"Could not resolve generic selector '{selector}' using yt-dlp. Skipping.") + if pass_without_splitting and downloader != 'aria-rpc': + logger.info("'pass_without_formats_splitting' is True. 
Passing download format string directly to the download tool.") + final_formats_to_download = download_format + else: + if pass_without_splitting and downloader == 'aria-rpc': + logger.warning("'pass_without_formats_splitting' is True but is not compatible with 'aria-rpc' downloader. Splitting formats as normal.") + + # Split the format string by commas to get a list of individual format selectors. + # This enables parallel downloads of different formats or format groups. + # For example, '18,140,299/298' becomes ['18', '140', '299/298'], + # and each item will be downloaded in a separate yt-dlp process. + if download_format and isinstance(download_format, str): + formats_to_download_initial = [selector.strip() for selector in download_format.split(',') if selector.strip()] else: - # This is a numeric-based selector (e.g., '140' or '299/298' or '140-dashy'). - # Validate it against the available formats. - if not available_formats: - logger.warning(f"Cannot validate numeric selector '{selector}' because available formats list is empty. Assuming it's valid.") - final_formats_to_download.append(selector) - continue + # Fallback for safety, though download_format should always be a string. + formats_to_download_initial = [] - individual_ids = re.split(r'[/+]', selector) - is_available = any(fid in available_formats for fid in individual_ids) - - if is_available: - final_formats_to_download.append(selector) + if not formats_to_download_initial: + raise AirflowException("No valid download format selectors were found after parsing.") + + # --- Filter and resolve requested formats --- + if not available_formats: + logger.warning("List of available formats is empty. Cannot validate numeric selectors, but will attempt to resolve generic selectors.") + + for selector in formats_to_download_initial: + # A selector is considered generic if it contains keywords like 'best' or filter brackets '[]'. + is_generic = bool(re.search(r'(best|\[|\])', selector)) + + if is_generic: + resolved_selector = _resolve_generic_selector(selector, info_json_path, logger) + if resolved_selector: + # The resolver returns a list for '+' selectors, or a string for others. + resolved_formats = resolved_selector if isinstance(resolved_selector, list) else [resolved_selector] + + for res_format in resolved_formats: + # Prefer -dashy version if available and the format is a simple numeric ID + if res_format.isdigit() and f"{res_format}-dashy" in available_formats: + final_format = f"{res_format}-dashy" + logger.info(f"Resolved format '{res_format}' from selector '{selector}'. Preferred '-dashy' version: '{final_format}'.") + else: + final_format = res_format + + # Validate the chosen format against available formats + if available_formats: + individual_ids = re.split(r'[/+]', final_format) + is_available = any(fid in available_formats for fid in individual_ids) + + if is_available: + final_formats_to_download.append(final_format) + else: + logger.warning(f"Resolved format '{final_format}' (from '{selector}') contains no available formats. Skipping.") + else: + # Cannot validate, so we trust the resolver's output. + final_formats_to_download.append(final_format) + else: + logger.warning(f"Could not resolve generic selector '{selector}' using yt-dlp. Skipping.") else: - logger.warning(f"Requested numeric format selector '{selector}' contains no available formats. Skipping.") + # This is a numeric-based selector (e.g., '140' or '299/298' or '140-dashy'). + # Validate it against the available formats. 
+ if not available_formats: + logger.warning(f"Cannot validate numeric selector '{selector}' because available formats list is empty. Assuming it's valid.") + final_formats_to_download.append(selector) + continue + + individual_ids = re.split(r'[/+]', selector) + is_available = any(fid in available_formats for fid in individual_ids) + + if is_available: + final_formats_to_download.append(selector) + else: + logger.warning(f"Requested numeric format selector '{selector}' contains no available formats. Skipping.") if not final_formats_to_download: raise AirflowException("None of the requested formats are available for this video.") @@ -1323,6 +1417,8 @@ def mark_url_as_success(initial_data: dict, downloaded_file_paths: list, token_d logger.info(f"Stored success result for URL '{url}' and removed from progress queue.") + return token_data + @task(trigger_rule='one_failed') def report_failure_and_stop(**context): """ @@ -1331,7 +1427,12 @@ def report_failure_and_stop(**context): """ params = context['params'] ti = context['task_instance'] - url = params.get('url_to_process', 'unknown') + url = params.get('url_to_process') + + # Ensure we have a valid URL string for Redis keys + if not url or url == 'None': + url = f"unknown_url_{context['dag_run'].run_id}" + logger.warning(f"No valid URL found in params. Using generated key: {url}") # Collect error details from XCom error_details = {} @@ -1379,12 +1480,15 @@ def report_failure_and_stop(**context): with client.pipeline() as pipe: pipe.hset(result_queue, url, json.dumps(result_data)) pipe.hset(fail_queue, url, json.dumps(result_data)) - pipe.hdel(progress_queue, url) + # Only try to remove from progress queue if we have a real URL + if url != f"unknown_url_{context['dag_run'].run_id}": + pipe.hdel(progress_queue, url) pipe.execute() logger.info(f"Stored failure result for URL '{url}' in '{result_queue}' and '{fail_queue}' and removed from progress queue.") except Exception as e: logger.error(f"Could not report failure to Redis: {e}", exc_info=True) + return None @task(trigger_rule='one_failed') @@ -1395,7 +1499,12 @@ def report_failure_and_continue(**context): """ params = context['params'] ti = context['task_instance'] - url = params.get('url_to_process', 'unknown') + url = params.get('url_to_process') + + # Ensure we have a valid URL string for Redis keys + if not url or url == 'None': + url = f"unknown_url_{context['dag_run'].run_id}" + logger.warning(f"No valid URL found in params. Using generated key: {url}") # Collect error details from XCom error_details = {} @@ -1446,7 +1555,9 @@ def report_failure_and_continue(**context): with client.pipeline() as pipe: pipe.hset(result_queue, url, json.dumps(result_data)) pipe.hset(fail_queue, url, json.dumps(result_data)) - pipe.hdel(progress_queue, url) + # Only try to remove from progress queue if we have a real URL + if url != f"unknown_url_{context['dag_run'].run_id}": + pipe.hdel(progress_queue, url) pipe.execute() logger.info(f"Stored failure result for URL '{url}' in '{result_queue}' and '{fail_queue}' and removed from progress queue.") @@ -1463,7 +1574,12 @@ def handle_fatal_error(**context): """ params = context['params'] ti = context['task_instance'] - url = params.get('url_to_process', 'unknown') + url = params.get('url_to_process') + + # Ensure we have a valid URL string for Redis keys + if not url or url == 'None': + url = f"unknown_url_{context['dag_run'].run_id}" + logger.warning(f"No valid URL found in params. 
Using generated key: {url}") # Collect error details error_details = {} @@ -1509,25 +1625,36 @@ def handle_fatal_error(**context): with client.pipeline() as pipe: pipe.hset(result_queue, url, json.dumps(result_data)) pipe.hset(fail_queue, url, json.dumps(result_data)) - pipe.hdel(progress_queue, url) + # Only try to remove from progress queue if we have a real URL + if url != f"unknown_url_{context['dag_run'].run_id}": + pipe.hdel(progress_queue, url) pipe.execute() logger.info(f"Stored fatal error result for URL '{url}' in '{result_queue}' and '{fail_queue}' for later reprocessing.") except Exception as e: logger.error(f"Could not report fatal error to Redis: {e}", exc_info=True) - # Fail the DAG run to prevent automatic continuation of the processing loop - raise AirflowException("Failing DAG due to fatal error. The dispatcher loop will stop.") + # Instead of raising an exception, log a clear message and return a result + # This allows the task to complete successfully while still indicating the error + logger.error("FATAL ERROR: The dispatcher loop will stop due to a non-retryable error.") + return {'status': 'fatal_error', 'url': url} @task(trigger_rule='one_success') -def continue_processing_loop(**context): +def continue_processing_loop(token_data: dict | None = None, **context): """ - After a successful run, triggers a new dispatcher to continue the processing loop, - effectively asking for the next URL to be processed. + After a run, triggers a new dispatcher to continue the processing loop, + passing along the account/proxy to make them sticky if available. """ params = context['params'] dag_run = context['dag_run'] + ti = context['task_instance'] + + # Check if we're coming from a fatal error path + fatal_error_result = ti.xcom_pull(task_ids='handle_fatal_error') + if fatal_error_result and isinstance(fatal_error_result, dict) and fatal_error_result.get('status') == 'fatal_error': + logger.error("Not continuing processing loop due to fatal error in previous task.") + return # Do not continue the loop for manual runs of the worker DAG. # A worker DAG triggered by the dispatcher will have a run_id starting with 'worker_run_'. @@ -1542,18 +1669,29 @@ def continue_processing_loop(**context): return # Create a new unique run_id for the dispatcher. - # Using a timestamp and UUID ensures the ID is unique and does not grow in length over time, - # preventing database errors. new_dispatcher_run_id = f"retriggered_by_worker_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{str(uuid.uuid4())[:8]}" # Pass all original parameters from the orchestrator through to the new dispatcher run. conf_to_pass = {k: v for k, v in params.items() if v is not None} + conf_to_pass['worker_index'] = params.get('worker_index') - # The new dispatcher will pull its own URL and determine its own queue, so we don't pass these. + if token_data: + # On success path, make the account and proxy "sticky" for the next run. + conf_to_pass['account_id'] = token_data.get('successful_account_id') + conf_to_pass['assigned_proxy_url'] = token_data.get('socks_proxy') + logger.info(f"Worker finished successfully. Triggering a new dispatcher ('{new_dispatcher_run_id}') to continue the loop with sticky account/proxy.") + logger.info(f" - Sticky Account: {conf_to_pass.get('account_id')}") + logger.info(f" - Sticky Proxy: {conf_to_pass.get('assigned_proxy_url')}") + else: + # On failure/skip paths, no token_data is passed. Clear sticky params to allow re-selection. 
+ conf_to_pass.pop('account_id', None) + conf_to_pass.pop('assigned_proxy_url', None) + logger.info(f"Worker finished on a non-success path. Triggering a new dispatcher ('{new_dispatcher_run_id}') to continue the loop without sticky account/proxy.") + + # The new dispatcher will pull its own URL and determine its own queue. conf_to_pass.pop('url_to_process', None) conf_to_pass.pop('worker_queue', None) - logger.info(f"Worker finished successfully. Triggering a new dispatcher ('{new_dispatcher_run_id}') to continue the loop.") trigger_dag( dag_id=dispatcher_dag_id, run_id=new_dispatcher_run_id, @@ -1606,10 +1744,21 @@ def handle_retry_failure_branch(task_id_to_check: str, **context): @task -def ban_and_report_after_retry(retry_data: dict, reason: str, **context): +def ban_and_report_after_retry(**context): """Bans the account used in a failed retry and prepares for failure reporting.""" + ti = context['task_instance'] + reason = "Banned by Airflow worker after failed retry" + + # Manually pull XCom because trigger rules can make XComArgs resolve to None. + retry_data = ti.xcom_pull(task_ids='retry_logic.coalesce_retry_data') + if not retry_data: + # This can happen if the upstream task that generates the data was skipped. + logger.error("Could not retrieve retry data to ban account. This may be due to an unexpected task flow.") + # Instead of failing, return a default dict with enough info to continue + return {'account_id': 'unknown', 'url_to_process': context['params'].get('url_to_process', 'unknown')} + # The account to ban is the one from the retry attempt. - ban_account(retry_data, reason, **context) + _ban_account(retry_data, reason, context) logger.info(f"Account '{retry_data.get('account_id')}' banned after retry failed. Proceeding to report failure.") return retry_data @@ -1624,24 +1773,31 @@ def handle_download_failure_branch(**context): policy = params.get('on_download_failure', 'proceed_loop') ti = context['task_instance'] - # The full task_id for download_and_probe is 'download_processing.download_and_probe' download_error_details = ti.xcom_pull(task_ids='download_processing.download_and_probe', key='download_error_details') + # First, check for specific error codes that override the general policy. if download_error_details: error_code = download_error_details.get('error_code') + + # Unrecoverable video errors always go to the 'skipped' handler. unrecoverable_video_errors = [ "AGE_GATED_SIGN_IN", "MEMBERS_ONLY", "VIDEO_PROCESSING", "COPYRIGHT_REMOVAL", - "GEO_RESTRICTED", "VIDEO_UNAVAILABLE", "PRIVATE_VIDEO", "VIDEO_REMOVED", - "HTTP_403_FORBIDDEN" + "GEO_RESTRICTED", "VIDEO_UNAVAILABLE", "PRIVATE_VIDEO", "VIDEO_REMOVED" ] if error_code in unrecoverable_video_errors: logger.warning(f"Unrecoverable video error '{error_code}' during download. Skipping.") return 'handle_unrecoverable_video_error' + # A 403 Forbidden error is not retryable, regardless of policy. + if error_code == 'HTTP_403_FORBIDDEN': + logger.error("Download failed with HTTP 403 Forbidden. This is not retryable. Reporting failure and continuing loop.") + return 'report_failure_and_continue' + + # Now, apply the general policy for other download failures. if policy == 'retry_with_new_token': logger.info("Download failed. Policy is to retry with a new token. Branching to retry logic.") return 'retry_logic_for_download' - + if policy == 'stop_loop': logger.error(f"Download or probe failed with policy '{policy}'. 
Stopping loop by routing to fatal error handler.") return 'handle_fatal_error' @@ -1667,15 +1823,46 @@ def coalesce_token_data(get_token_result=None, retry_get_token_result=None): raise AirflowException("Could not find a successful token result from any attempt.") -@task +# FIX: Use 'all_done' trigger rule so this task runs even when upstream tasks fail. +# The branch operator will skip other branches, but this task needs to run +# when the branch points to it, regardless of the failed get_token task. +@task(trigger_rule='all_done') def handle_unrecoverable_video_error(**context): """ Handles errors for videos that are unavailable (private, removed, etc.). These are not system failures, so the URL is logged to a 'skipped' queue and the processing loop continues without marking the run as failed. """ - params = context['params'] ti = context['task_instance'] + + # Check if this task was actually selected by the branch operator. + # If it was skipped by the branch, we should not execute the logic. + # We can check if the branch task's result points to us. + dag_run = ti.get_dagrun() + + # Check multiple possible branch tasks that could route here + branch_task_ids = [ + 'initial_attempt.handle_bannable_error_branch', + 'retry_logic.handle_retry_failure_branch', + 'download_processing.handle_download_failure_branch' + ] + + was_selected_by_branch = False + for branch_task_id in branch_task_ids: + try: + branch_result = ti.xcom_pull(task_ids=branch_task_id) + if branch_result == 'handle_unrecoverable_video_error': + was_selected_by_branch = True + logger.info(f"Task was selected by branch '{branch_task_id}'") + break + except Exception: + pass + + if not was_selected_by_branch: + logger.info("Task was not selected by any branch operator. Skipping execution.") + raise AirflowSkipException("Not selected by branch operator") + + params = context['params'] url = params.get('url_to_process', 'unknown') # Collect error details from the failed task @@ -1717,22 +1904,50 @@ def handle_unrecoverable_video_error(**context): logger.info(f"Stored skipped result for URL '{url}' in '{skipped_queue}' and removed from progress queue.") except Exception as e: logger.error(f"Could not report skipped video to Redis: {e}", exc_info=True) + + # Return a marker so downstream tasks know this path was taken + return {'status': 'skipped', 'url': url} -@task +# FIX: Use 'all_done' trigger rule for the same reason as handle_unrecoverable_video_error +@task(trigger_rule='all_done') def report_bannable_and_continue(**context): """ Handles a bannable error by reporting it, but continues the loop as per the 'proceed_loop_under_manual_inspection' policy. """ - params = context['params'] ti = context['task_instance'] + + # Check if this task was actually selected by the branch operator. + dag_run = ti.get_dagrun() + + branch_task_ids = [ + 'initial_attempt.handle_bannable_error_branch', + 'retry_logic.handle_retry_failure_branch' + ] + + was_selected_by_branch = False + for branch_task_id in branch_task_ids: + try: + branch_result = ti.xcom_pull(task_ids=branch_task_id) + if branch_result == 'report_bannable_and_continue': + was_selected_by_branch = True + logger.info(f"Task was selected by branch '{branch_task_id}'") + break + except Exception: + pass + + if not was_selected_by_branch: + logger.info("Task was not selected by any branch operator. 
Skipping execution.") + raise AirflowSkipException("Not selected by branch operator") + + params = context['params'] url = params.get('url_to_process', 'unknown') # Collect error details error_details = {} - first_token_task_id = 'get_token' - retry_token_task_id = 'retry_get_token' + first_token_task_id = 'initial_attempt.get_token' + retry_token_task_id = 'retry_logic.retry_get_token' first_token_error = ti.xcom_pull(task_ids=first_token_task_id, key='error_details') retry_token_error = ti.xcom_pull(task_ids=retry_token_task_id, key='error_details') @@ -1779,6 +1994,11 @@ def report_bannable_and_continue(**context): logger.info(f"Stored bannable error for URL '{url}' in '{result_queue}' and '{fail_queue}'.") except Exception as e: logger.error(f"Could not report bannable error to Redis: {e}", exc_info=True) + + # Return a marker so downstream tasks know this path was taken + return {'status': 'bannable_reported', 'url': url} + + # ============================================================================= @@ -1802,8 +2022,7 @@ with DAG( 'account_pool_size': Param(None, type=["integer", "null"]), 'prepend_client_to_account': Param(True, type="boolean", title="[Worker Param] Prepend Client to Account", description="If True, prepends client and timestamp to account names in prefix mode."), 'machine_id': Param(None, type=["string", "null"]), - 'assigned_proxy_url': Param(None, type=["string", "null"], title="[Worker Param] Assigned Proxy URL", description="A specific proxy URL to use for the request, overriding the server's proxy pool logic."), - 'clients': Param('tv_simply', type="string", description="Comma-separated list of clients for token generation. e.g. mweb,tv,web_camoufox"), + 'assigned_proxy_url': Param(None, type=["string", "null"], title="[Manual/Worker Param] Assigned Proxy URL", description="For manual runs or sticky loops: a specific proxy URL to use, overriding the server's proxy pool logic."), 'timeout': Param(DEFAULT_TIMEOUT, type="integer"), 'output_path_template': Param("%(id)s.f%(format_id)s.%(ext)s", type="string", title="[Worker Param] Output Path Template", description="Output filename template for yt-dlp. It is highly recommended to include `%(format_id)s` to prevent filename collisions when downloading multiple formats."), 'on_auth_failure': Param( @@ -1820,21 +2039,23 @@ with DAG( title="[Worker Param] On Download Failure Policy", description="Policy for handling download or probe failures." ), - 'request_params_json': Param('{}', type="string", title="[Worker Param] Request Params JSON", description="JSON string with request parameters for the token service."), - 'language_code': Param('en-US', type="string", title="[Worker Param] Language Code", description="The language code (e.g., 'en-US', 'de-DE') to use for the YouTube request headers."), 'retry_on_probe_failure': Param(False, type="boolean"), 'skip_probe': Param(False, type="boolean", title="[Worker Param] Skip Probe", description="If True, skips the ffmpeg probe of downloaded files."), 'yt_dlp_cleanup_mode': Param(False, type="boolean", title="[Worker Param] yt-dlp Cleanup Mode", description="If True, creates a .empty file and deletes the original media file after successful download and probe."), 'auto_create_new_accounts_on_exhaustion': Param(True, type="boolean"), - 'fragment_retries': Param(2, type="integer", title="[Worker Param] Fragment Retries", description="Number of retries for a fragment before giving up. 
Default is 2 to fail fast on expired tokens."), 'delay_between_formats_s': Param(15, type="integer", title="[Worker Param] Delay Between Formats (s)", description="Delay in seconds between downloading each format when multiple formats are specified. A 22s wait may be effective for batch downloads, while 6-12s may suffice if cookies are refreshed regularly."), 'yt_dlp_test_mode': Param(False, type="boolean", title="[Worker Param] yt-dlp Test Mode", description="If True, runs yt-dlp with --test flag (dry run without downloading)."), - 'socket_timeout': Param(15, type="integer", title="[Worker Param] Socket Timeout", description="Timeout in seconds for socket operations."), 'download_format': Param( 'bestvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best', type="string", title="[Worker Param] Download Format", - description="Custom yt-dlp format string. Common presets: [1] 'bestvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best' (Default, best quality MP4). [2] '18-dashy/18,140-dashy/140,133-dashy/134-dashy/136-dashy/137-dashy/250-dashy/298-dashy/299-dashy' (Legacy formats). [3] '299-dashy/298-dashy/250-dashy/137-dashy/136-dashy/135-dashy/134-dashy/133-dashy' (High-framerate formats)." + description="Custom yt-dlp format string. Common presets: [1] 'bestvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best' (Default, best quality MP4). [2] '18,140-dashy/140,133-dashy/134-dashy/136-dashy/137-dashy/298-dashy/299-dashy' (Legacy formats). [3] '299-dashy/298-dashy/137-dashy/136-dashy/135-dashy/134-dashy/133-dashy' (High-framerate formats)." + ), + 'pass_without_formats_splitting': Param( + False, + type="boolean", + title="[Worker Param] Pass format string without splitting", + description="If True, passes the entire 'download_format' string to the download tool as-is. This is for complex selectors. Not compatible with 'aria-rpc' downloader." ), 'downloader': Param( 'cli', @@ -1846,15 +2067,14 @@ with DAG( 'aria_host': Param('172.17.0.1', type="string", title="Aria2c Host", description="For 'aria-rpc' downloader: Host of the aria2c RPC server."), 'aria_port': Param(6800, type="integer", title="Aria2c Port", description="For 'aria-rpc' downloader: Port of the aria2c RPC server."), 'aria_secret': Param('SQGCQPLVFQIASMPNPOJYLVGJYLMIDIXDXAIXOTX', type="string", title="Aria2c Secret", description="For 'aria-rpc' downloader: Secret token."), - 'yt_dlp_extra_args': Param( - '', - type=["string", "null"], - title="Extra yt-dlp arguments", - ), + # --- Unified JSON Config (passed from orchestrator) --- + 'ytdlp_config_json': Param('{}', type="string", title="[Internal] Unified JSON config from orchestrator."), # --- Manual Run / Internal Parameters --- 'manual_url_to_process': Param('iPwdia3gAnk', type=["string", "null"], title="[Manual Run] URL to Process", description="For manual runs, provide a single YouTube URL, or the special value 'PULL_FROM_QUEUE' to pull one URL from the Redis inbox. 
This is ignored if triggered by the dispatcher."), 'url_to_process': Param(None, type=["string", "null"], title="[Internal] URL from Dispatcher", description="This parameter is set by the dispatcher DAG and should not be used for manual runs."), 'worker_queue': Param(None, type=["string", "null"], title="[Internal] Worker Queue", description="This parameter is set by the dispatcher DAG and should not be used for manual runs."), + 'worker_index': Param(None, type=["integer", "null"], title="[Internal] Worker Index", description="A unique index for each parallel worker loop, assigned by the orchestrator."), + 'account_id': Param(None, type=["string", "null"], title="[Internal] Assigned Account ID", description="A specific account_id to use, making the account 'sticky' for a worker loop."), } ) as dag: initial_data = get_url_and_assign_account() @@ -1862,10 +2082,11 @@ with DAG( # --- Task Instantiation with TaskGroups --- # Main success/failure handlers (outside groups for clear end points) - fatal_error_task = handle_fatal_error() - report_failure_and_stop_task = report_failure_and_stop() - report_failure_task = report_failure_and_continue() - continue_loop_task = continue_processing_loop() + # These tasks are targets of branch operators that run after failures. + # They need trigger_rule='all_done' to run when the branch points to them. + fatal_error_task = handle_fatal_error.override(trigger_rule='all_done')() + report_failure_and_stop_task = report_failure_and_stop.override(trigger_rule='all_done')() + report_failure_task = report_failure_and_continue.override(trigger_rule='all_done')() unrecoverable_video_error_task = handle_unrecoverable_video_error() report_bannable_and_continue_task = report_bannable_and_continue() @@ -1877,10 +2098,7 @@ with DAG( ) # Tasks for the "stop_loop" policy on initial attempt - ban_and_report_immediately_task = ban_and_report_immediately.override(task_id='ban_and_report_immediately')( - initial_data=initial_data, - reason="Banned by Airflow worker (policy is stop_loop)" - ) + ban_and_report_immediately_task = ban_and_report_immediately.override(task_id='ban_and_report_immediately')() first_token_attempt >> initial_branch_task initial_branch_task >> [fatal_error_task, ban_and_report_immediately_task, unrecoverable_video_error_task, report_bannable_and_continue_task] @@ -1922,10 +2140,7 @@ with DAG( retry_branch_task = handle_retry_failure_branch.override(trigger_rule='one_failed')( task_id_to_check=retry_token_task.operator.task_id ) - ban_after_retry_report_task = ban_and_report_after_retry.override(task_id='ban_and_report_after_retry')( - retry_data=coalesced_retry_data, - reason="Banned by Airflow worker after failed retry" - ) + ban_after_retry_report_task = ban_and_report_after_retry.override(task_id='ban_and_report_after_retry', trigger_rule='all_done')() # Internal dependencies within retry group ban_and_retry_group >> after_ban_account_task @@ -2032,25 +2247,49 @@ with DAG( download_retry_token_task_result=new_token_data ) - # Final success task, fed by coalesced results - final_success_task = mark_url_as_success.override(task_id='final_success_report')( + # Instantiate final success task + final_success_task = mark_url_as_success( initial_data=initial_data, downloaded_file_paths=final_files, token_data=final_token ) - final_success_task >> continue_loop_task + + # Coalesce all paths that lead to the continuation of the loop. 
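For readers unfamiliar with the fan-in pattern used by the coalescing task that follows, the sketch below is a minimal, standalone illustration with toy task names (not code from this patch): mutually exclusive branch targets feed one coalescing task whose `trigger_rule='one_success'` lets it run as soon as any single path succeeds, while the skipped path's XCom resolves to `None`.

```python
# Minimal, self-contained sketch of the branch fan-in pattern (toy task names).
from airflow.decorators import dag, task
from airflow.utils.dates import days_ago


@dag(schedule_interval=None, start_date=days_ago(1), catchup=False)
def coalesce_pattern_demo():

    @task.branch
    def pick_path():
        # In the real DAG this inspects XCom error details and the failure policy.
        return 'success_path'

    @task
    def success_path():
        return {'status': 'ok'}

    @task
    def failure_path():
        return {'status': 'failed'}

    @task(trigger_rule='one_success')
    def coalesce(success_result=None, failure_result=None):
        # Exactly one upstream ran; the other was skipped and resolves to None.
        return success_result or failure_result

    branch = pick_path()
    s, f = success_path(), failure_path()
    branch >> [s, f]
    coalesce(success_result=s, failure_result=f)


coalesce_pattern_demo()
```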
+ @task(trigger_rule='one_success') + def coalesce_all_continue_paths(success_result=None, unrecoverable_result=None, bannable_result=None, failure_result=None, fatal_error_result=None): + """ + Gathers results from all possible paths that can continue the processing loop. + Only the success path provides data; others provide None. + """ + if fatal_error_result and isinstance(fatal_error_result, dict) and fatal_error_result.get('status') == 'fatal_error': + logger.error("Fatal error detected in coalesce_all_continue_paths. Will not continue processing loop.") + return {'status': 'fatal_error'} + + if success_result: + return success_result + return None + + final_data_for_loop = coalesce_all_continue_paths( + success_result=final_success_task, + unrecoverable_result=unrecoverable_video_error_task, + bannable_result=report_bannable_and_continue_task, + failure_result=report_failure_task, + fatal_error_result=fatal_error_task, + ) + + # Final task to trigger the next DAG run + continue_processing_loop(token_data=final_data_for_loop) + + # Final success task, fed by coalesced results + final_files >> final_success_task + final_token >> final_success_task # --- DAG Dependencies between TaskGroups --- # Initial attempt can lead to retry logic or direct failure initial_branch_task >> [retry_logic_group, fatal_error_task, ban_and_report_immediately_task, unrecoverable_video_error_task, report_bannable_and_continue_task] - + # Ban and report immediately leads to failure reporting ban_and_report_immediately_task >> report_failure_and_stop_task - - # Unrecoverable/bannable errors that don't stop the loop should continue processing - unrecoverable_video_error_task >> continue_loop_task - report_bannable_and_continue_task >> continue_loop_task - report_failure_task >> continue_loop_task # Connect download failure branch to the new retry group download_branch_task >> [retry_logic_for_download_group, report_failure_task, fatal_error_task, unrecoverable_video_error_task] @@ -2058,7 +2297,7 @@ with DAG( # Connect success paths to the coalescing tasks download_task >> final_files retry_download_task >> final_files - + # The token from the initial auth path is one input to the final token coalesce coalesce_token_data(get_token_result=first_token_attempt, retry_get_token_result=retry_token_task) >> final_token # The token from the download retry path is the other input diff --git a/airflow/dags/ytdlp_ops_v02_dispatcher_dl.py b/airflow/dags/ytdlp_ops_v02_dispatcher_dl.py index c1cdd12..78614ef 100644 --- a/airflow/dags/ytdlp_ops_v02_dispatcher_dl.py +++ b/airflow/dags/ytdlp_ops_v02_dispatcher_dl.py @@ -27,34 +27,22 @@ DEFAULT_REDIS_CONN_ID = 'redis_default' @task(queue='queue-dl') def dispatch_job_to_dl_worker(**context): """ - Pulls one job payload from Redis, determines the current worker's dedicated queue, - and triggers the download worker DAG to process the job on that specific queue. + Triggers a v2 download worker for the 'profile-first' model. + The worker itself is responsible for locking a profile and finding a suitable task. + This dispatcher simply starts a worker process. 
""" ti = context['task_instance'] logger.info(f"Download Dispatcher task '{ti.task_id}' running on queue '{ti.queue}'.") - params = context['params'] - redis_conn_id = params['redis_conn_id'] - queue_name = params['queue_name'] - inbox_queue = f"{queue_name}_inbox" - - logger.info(f"Attempting to pull one job from Redis queue '{inbox_queue}'...") - client = _get_redis_client(redis_conn_id) - job_bytes = client.lpop(inbox_queue) - - if not job_bytes: - logger.info("Redis download inbox queue is empty. No work to dispatch. Skipping task.") - raise AirflowSkipException("Redis download inbox queue is empty. No work to dispatch.") - - job_data_str = job_bytes.decode('utf-8') - logger.info(f"Pulled job from the queue.") # Determine the worker-specific queue for affinity hostname = socket.gethostname() worker_queue = f"queue-dl-{hostname}" - logger.info(f"Running on worker '{hostname}'. Dispatching job to its dedicated queue '{worker_queue}'.") + logger.info(f"Running on worker '{hostname}'. Dispatching a new profile-first worker instance to its dedicated queue '{worker_queue}'.") - conf_to_pass = {**params, 'job_data': job_data_str, 'worker_queue': worker_queue} + # Pass all orchestrator params, but remove job_data as the worker finds its own job. + conf_to_pass = {**params, 'worker_queue': worker_queue} + conf_to_pass.pop('job_data', None) run_id = f"worker_run_dl_{context['dag_run'].run_id}_{context['ts_nodash']}_q_{worker_queue}" @@ -75,10 +63,12 @@ with DAG( tags=['ytdlp', 'worker', 'dispatcher', 'download'], is_paused_upon_creation=True, doc_md=""" - ### YT-DLP Download Job Dispatcher + ### YT-DLP v2 Download Worker Dispatcher (Profile-First) - This DAG dispatches a single download job to a download worker with a pinned queue. - It pulls a JSON payload from the `queue2_dl_inbox` Redis queue and triggers the `ytdlp_ops_v02_worker_per_url_dl` DAG. + This DAG dispatches a single "profile-first" download worker. + It does **not** pull a job from a queue. Instead, it triggers the `ytdlp_ops_v02_worker_per_url_dl` DAG, + which is responsible for locking an available download profile and then finding a matching task + from the `queue_dl_format_tasks` Redis list. 
""", render_template_as_native_obj=True, params={ diff --git a/airflow/dags/ytdlp_ops_v02_orchestrator_auth.py b/airflow/dags/ytdlp_ops_v02_orchestrator_auth.py index 7039cd2..02eb52e 100644 --- a/airflow/dags/ytdlp_ops_v02_orchestrator_auth.py +++ b/airflow/dags/ytdlp_ops_v02_orchestrator_auth.py @@ -24,6 +24,12 @@ import random import time import json +# --- Add project root to path to allow for yt-ops-client imports --- +import sys +# The yt-ops-client package is installed in editable mode in /app +if '/app' not in sys.path: + sys.path.insert(0, '/app') + # Import utility functions from utils.redis_utils import _get_redis_client @@ -45,6 +51,66 @@ DEFAULT_BUNCH_DELAY_S = 1 DEFAULT_YT_AUTH_SERVICE_IP = Variable.get("YT_AUTH_SERVICE_IP", default_var="172.17.0.1") DEFAULT_YT_AUTH_SERVICE_PORT = Variable.get("YT_AUTH_SERVICE_PORT", default_var=9080) +# Default ytdlp.json content for the unified config parameter +DEFAULT_YTDLP_CONFIG = { + "ytops": { + "force_renew": [], + "session_params": { + # "visitor_rotation_threshold": 250 + } + }, + "ytdlp_params": { + "debug_printtraffic": True, + "write_pages": True, + "verbose": True, + "no_color": True, + "ignoreerrors": True, + "noresizebuffer": True, + "buffersize": "4M", + "concurrent_fragments": 8, + "socket_timeout": 60, + "outtmpl": { + "default": "%(id)s.f%(format_id)s.%(ext)s" + }, + "restrictfilenames": True, + "updatetime": False, + "noplaylist": True, + "match_filter": "!is_live", + "writeinfojson": True, + "skip_download": True, + "allow_playlist_files": False, + "clean_infojson": True, + "getcomments": False, + "writesubtitles": False, + "writethumbnail": False, + "sleep_interval_requests": 0.75, + "parse_metadata": [ + ":(?P)" + ], + "extractor_args": { + "youtube": { + "player_client": ["tv_simply"], + "formats": ["duplicate"], + "jsc_trace": ["true"], + "pot_trace": ["true"], + "skip": ["translated_subs", "hls"] + }, + "youtubepot-bgutilhttp": { + "base_url": ["http://172.17.0.1:4416"] + } + }, + "noprogress": True, + "format_sort": [ + "res", + "ext:mp4:m4a" + ], + "remuxvideo": "mp4", + "nooverwrites": True, + "continuedl": True + } +} + + # --- Helper Functions --- def _check_application_queue(redis_client, queue_base_name: str) -> int: @@ -153,27 +219,21 @@ def orchestrate_workers_ignition_callable(**context): dag_run_id = context['dag_run'].run_id total_triggered = 0 - # --- Generate a consistent timestamped prefix for this orchestrator run --- - # This ensures all workers spawned from this run use the same set of accounts. - final_account_pool_prefix = params['account_pool'] - if params.get('prepend_client_to_account') and params.get('account_pool_size') is not None: - clients_str = params.get('clients', '') - primary_client = clients_str.split(',')[0].strip() if clients_str else 'unknown' - # Use a timestamp from the orchestrator's run for consistency - timestamp = datetime.now().strftime('%Y%m%d%H%M%S') - final_account_pool_prefix = f"{params['account_pool']}_{timestamp}_{primary_client}" - logger.info(f"Generated consistent account prefix for this run: '{final_account_pool_prefix}'") + # --- End of Inspection --- + + logger.info(f"Plan: Triggering {total_workers} total dispatcher runs in {len(bunches)} bunches. 
Each run will attempt to process one URL.") + + dag_run_id = context['dag_run'].run_id + total_triggered = 0 for i, bunch in enumerate(bunches): logger.info(f"--- Triggering Bunch {i+1}/{len(bunches)} (contains {len(bunch)} dispatcher(s)) ---") - for j, _ in enumerate(bunch): + for j, worker_index in enumerate(bunch): # Create a unique run_id for each dispatcher run run_id = f"dispatched_{dag_run_id}_{total_triggered}" # Pass all orchestrator params to the dispatcher, which will then pass them to the worker. conf_to_pass = {p: params[p] for p in params} - # Override account_pool with the generated prefix - conf_to_pass['account_pool'] = final_account_pool_prefix logger.info(f"Triggering dispatcher {j+1}/{len(bunch)} in bunch {i+1} (run {total_triggered + 1}/{total_workers}) (Run ID: {run_id})") logger.debug(f"Full conf for dispatcher run {run_id}: {conf_to_pass}") @@ -259,36 +319,33 @@ with DAG( 'delay_between_bunches_s': Param(DEFAULT_BUNCH_DELAY_S, type="integer", description="Delay in seconds between starting each bunch."), 'skip_if_queue_empty': Param(False, type="boolean", title="[Ignition Control] Skip if Queue Empty", description="If True, the orchestrator will not start any dispatchers if the application's work queue is empty."), + # --- Unified Worker Configuration --- + 'ytdlp_config_json': Param( + json.dumps(DEFAULT_YTDLP_CONFIG, indent=2), + type="string", + title="[Worker Param] Unified yt-dlp JSON Config", + description="A JSON string containing all parameters for both yt-ops-server and the yt-dlp downloaders. This is the primary way to configure workers.", + **{'ui_widget': 'json', 'multi_line': True} + ), + # --- Worker Passthrough Parameters --- - 'on_bannable_failure': Param( - 'proceed_loop_under_manual_inspection', - type="string", - enum=['stop_loop', 'retry_with_new_account', 'retry_without_ban', 'retry_and_ban_account_only', 'retry_on_connection_error', 'proceed_loop_under_manual_inspection', 'stop_loop_on_auth_proceed_on_download_error'], - title="[Worker Param] On Bannable Failure Policy", - description="Policy for a worker when a bannable error occurs. " - "'stop_loop': Ban the account, mark URL as failed, and stop the worker's loop on any failure (auth or download). " - "'retry_with_new_account': Ban the failed account, retry ONCE with a new account. If retry fails, ban the second account and proxy, then stop." - "'retry_on_connection_error': If a connection error (e.g. SOCKS timeout) occurs, retry with a new account but do NOT ban the first account/proxy. If retry fails, stop the loop without banning." - "'proceed_loop_under_manual_inspection': **BEWARE: MANUAL SUPERVISION REQUIRED.** Marks the URL as failed but continues the processing loop. Use this only when you can manually intervene by pausing the dispatcher DAG or creating a lock file (`/opt/airflow/inputfiles/AIRFLOW.PREVENT_URL_PULL.lockfile`) to prevent a runaway failure loop." - "'stop_loop_on_auth_proceed_on_download_error': **(Default)** Stops the loop on an authentication/token error (like 'stop_loop'), but continues the loop on a download/probe error (like 'proceed...')." - ), - 'request_params_json': Param('{}', type="string", title="[Worker Param] Request Params JSON", description="JSON string with per-request parameters to override server defaults. 
Can be a full JSON object or comma-separated key=value pairs (e.g., 'session_params.location=DE,ytdlp_params.skip_cache=true')."), - 'language_code': Param('en-US', type="string", title="[Worker Param] Language Code", description="The language code (e.g., 'en-US', 'de-DE') to use for the YouTube request headers."), + # --- V2 Profile Management Parameters --- + 'redis_env': Param("sim_auth", type="string", title="[V2 Profiles] Redis Environment", description="The environment for v2 profile management (e.g., 'sim_auth'). Determines the Redis key prefix."), + 'profile_prefix': Param("auth_user", type="string", title="[V2 Profiles] Profile Prefix", description="The prefix for auth profiles that workers should attempt to lock."), + + # --- Worker Passthrough Parameters --- + 'on_bannable_failure': Param('proceed_loop_under_manual_inspection', type="string", title="DEPRECATED: Worker handles failures internally."), 'redis_conn_id': Param(DEFAULT_REDIS_CONN_ID, type="string", description="[Worker Param] Airflow Redis connection ID."), - 'clients': Param( - 'tv_simply', - type="string", - title="[Worker Param] Clients", - description="[Worker Param] Comma-separated list of clients for token generation. Full list: web, web_safari, web_embedded, web_music, web_creator, mweb, web_camoufox, web_safari_camoufox, web_embedded_camoufox, web_music_camoufox, web_creator_camoufox, mweb_camoufox, android, android_music, android_creator, android_vr, ios, ios_music, ios_creator, tv, tv_simply, tv_embedded. See DAG documentation for details." - ), - 'account_pool': Param('ytdlp_account', type="string", description="[Worker Param] Account pool prefix or comma-separated list."), - 'account_pool_size': Param(10, type=["integer", "null"], description="[Worker Param] If using a prefix for 'account_pool', this specifies the number of accounts to generate (e.g., 10 for 'prefix_01' through 'prefix_10'). Required when using a prefix."), - 'prepend_client_to_account': Param(True, type="boolean", title="[Worker Param] Prepend Client to Account", description="If True, prepends client and timestamp to account names in prefix mode. Format: prefix_YYYYMMDDHHMMSS_client_XX."), 'service_ip': Param(DEFAULT_YT_AUTH_SERVICE_IP, type="string", description="[Worker Param] IP of the ytdlp-ops-server. Default is from Airflow variable YT_AUTH_SERVICE_IP or hardcoded."), 'service_port': Param(DEFAULT_YT_AUTH_SERVICE_PORT, type="integer", description="[Worker Param] Port of the Envoy load balancer. 
Default is from Airflow variable YT_AUTH_SERVICE_PORT or hardcoded."), 'machine_id': Param("ytdlp-ops-airflow-service", type="string", description="[Worker Param] Identifier for the client machine."), - 'assigned_proxy_url': Param(None, type=["string", "null"], title="[Worker Param] Assigned Proxy URL", description="If provided, forces the token service to use this specific proxy for the request."), - 'auto_create_new_accounts_on_exhaustion': Param(True, type="boolean", description="[Worker Param] If True and all accounts in a prefix-based pool are exhausted, create a new one automatically."), + + # --- DEPRECATED PARAMS --- + 'account_pool': Param('ytdlp_account', type="string", description="DEPRECATED: Use profile_prefix instead."), + 'account_pool_size': Param(10, type=["integer", "null"], description="DEPRECATED: Pool size is managed in Redis."), + 'prepend_client_to_account': Param(True, type="boolean", description="DEPRECATED"), + 'assigned_proxy_url': Param(None, type=["string", "null"], description="DEPRECATED: Proxy is determined by the locked profile."), + 'auto_create_new_accounts_on_exhaustion': Param(True, type="boolean", description="DEPRECATED"), } ) as dag: diff --git a/airflow/dags/ytdlp_ops_v02_orchestrator_dl.py b/airflow/dags/ytdlp_ops_v02_orchestrator_dl.py index 54c7499..e478bbc 100644 --- a/airflow/dags/ytdlp_ops_v02_orchestrator_dl.py +++ b/airflow/dags/ytdlp_ops_v02_orchestrator_dl.py @@ -24,6 +24,12 @@ import random import time import json +# --- Add project root to path to allow for yt-ops-client imports --- +import sys +# The yt-ops-client package is installed in editable mode in /app +if '/app' not in sys.path: + sys.path.insert(0, '/app') + # Import utility functions from utils.redis_utils import _get_redis_client @@ -242,6 +248,11 @@ with DAG( 'delay_between_workers_s': Param(DEFAULT_WORKER_DELAY_S, type="integer", description="Delay in seconds between starting each dispatcher within a bunch."), 'delay_between_bunches_s': Param(DEFAULT_BUNCH_DELAY_S, type="integer", description="Delay in seconds between starting each bunch."), 'skip_if_queue_empty': Param(False, type="boolean", title="[Ignition Control] Skip if Queue Empty", description="If True, the orchestrator will not start any dispatchers if the application's work queue is empty."), + + # --- V2 Profile Management Parameters --- + 'redis_env': Param("sim_download", type="string", title="[V2 Profiles] Redis Environment", description="The environment for v2 profile management (e.g., 'sim_download'). Determines the Redis key prefix."), + 'profile_prefix': Param("download_user", type="string", title="[V2 Profiles] Profile Prefix", description="The prefix for download profiles that workers should attempt to lock."), + 'redis_conn_id': Param(DEFAULT_REDIS_CONN_ID, type="string", description="[Worker Param] Airflow Redis connection ID."), 'clients': Param('mweb,web_camoufox,tv', type="string", title="[Worker Param] Clients", description="Comma-separated list of clients for token generation. e.g. 
mweb,tv,web_camoufox"), @@ -250,16 +261,17 @@ with DAG( 'yt_dlp_test_mode': Param(False, type="boolean", title="[Worker Param] yt-dlp Test Mode", description="If True, runs yt-dlp with --test flag (dry run without downloading)."), 'skip_probe': Param(True, type="boolean", title="[Worker Param] Skip Probe", description="If True, skips the ffmpeg probe of downloaded files."), 'yt_dlp_cleanup_mode': Param(False, type="boolean", title="[Worker Param] yt-dlp Cleanup Mode", description="If True, creates a .empty file and deletes the original media file after successful download and probe."), - 'fragment_retries': Param(2, type="integer", title="[Worker Param] Fragment Retries", description="Number of retries for a fragment before giving up."), - 'limit_rate': Param('5M', type=["string", "null"], title="[Worker Param] Limit Rate", description="Download speed limit (e.g., 50K, 4.2M)."), - 'socket_timeout': Param(15, type="integer", title="[Worker Param] Socket Timeout", description="Timeout in seconds for socket operations."), - 'min_sleep_interval': Param(5, type="integer", title="[Worker Param] Min Sleep Interval", description="Minimum time to sleep between downloads (seconds)."), - 'max_sleep_interval': Param(10, type="integer", title="[Worker Param] Max Sleep Interval", description="Maximum time to sleep between downloads (seconds)."), 'download_format': Param( 'bestvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best', type="string", title="[Worker Param] Download Format", - description="Custom yt-dlp format string. Common presets: [1] 'bestvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best' (Default, best quality MP4). [2] '18-dashy/18,140-dashy/140,133-dashy/134-dashy/136-dashy/137-dashy/250-dashy/298-dashy/299-dashy' (Legacy formats). [3] '299-dashy/298-dashy/250-dashy/137-dashy/136-dashy/135-dashy/134-dashy/133-dashy' (High-framerate formats)." + description="Custom yt-dlp format string. Common presets: [1] 'bestvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best' (Default, best quality MP4). [2] '18,140-dashy/140,133-dashy/134-dashy/136-dashy/137-dashy/298-dashy/299-dashy' (Legacy formats). [3] '299-dashy/298-dashy/137-dashy/136-dashy/135-dashy/134-dashy/133-dashy' (High-framerate formats)." + ), + 'pass_without_formats_splitting': Param( + False, + type="boolean", + title="[Worker Param] Pass format string without splitting", + description="If True, passes the entire 'download_format' string to the download tool as-is. This is for complex selectors. Not compatible with 'aria-rpc' downloader." ), 'downloader': Param( 'cli', @@ -272,7 +284,7 @@ with DAG( 'aria_port': Param(6800, type="integer", title="[Worker Param] Aria2c Port", description="For 'aria-rpc' downloader: Port of the aria2c RPC server. Can be set via Airflow Variable 'YTDLP_ARIA_PORT'."), 'aria_secret': Param('SQGCQPLVFQIASMPNPOJYLVGJYLMIDIXDXAIXOTX', type="string", title="[Worker Param] Aria2c Secret", description="For 'aria-rpc' downloader: Secret token. Can be set via Airflow Variable 'YTDLP_ARIA_SECRET'."), 'yt_dlp_extra_args': Param( - '--no-part --restrict-filenames', + '--verbose --no-resize-buffer --buffer-size 4M --fragment-retries 2 --concurrent-fragments 8 --socket-timeout 15 --sleep-interval 5 --max-sleep-interval 10 --no-part --restrict-filenames', type=["string", "null"], title="[Worker Param] Extra yt-dlp arguments", description="Extra command-line arguments for yt-dlp during download." 
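The download worker that consumes these parameters is not part of this diff. As a hedged sketch of the assumed consumer side: `yt_dlp_extra_args` is a plain argument string that can be tokenized with `shlex.split` and appended to the yt-dlp command line, while `download_format` is either split into per-format jobs (the default path, analogous to the format resolution done by the auth stage) or, when `pass_without_formats_splitting` is true, handed to yt-dlp verbatim as shown here.

```python
# Hedged sketch of the assumed consumer-side behaviour (not code from this patch):
# tokenize 'yt_dlp_extra_args' safely and assemble one yt-dlp invocation for a
# single format selector.
import shlex


def build_yt_dlp_command(format_selector: str,
                         yt_dlp_extra_args: str | None,
                         url: str) -> list[str]:
    """Assemble a yt-dlp invocation from orchestrator parameters."""
    cmd = ['yt-dlp', '-f', format_selector]
    if yt_dlp_extra_args:
        # shlex.split respects quoting, so flag/value pairs like '--buffer-size 4M' stay intact.
        cmd.extend(shlex.split(yt_dlp_extra_args))
    cmd.append(url)
    return cmd


if __name__ == '__main__':
    cmd = build_yt_dlp_command(
        'bestvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best',
        '--verbose --no-part --restrict-filenames --socket-timeout 15',
        'https://www.youtube.com/watch?v=iPwdia3gAnk',
    )
    print(' '.join(shlex.quote(c) for c in cmd))
```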
diff --git a/airflow/dags/ytdlp_ops_v02_worker_per_url_auth.py b/airflow/dags/ytdlp_ops_v02_worker_per_url_auth.py index daa81c0..4b6a9f6 100644 --- a/airflow/dags/ytdlp_ops_v02_worker_per_url_auth.py +++ b/airflow/dags/ytdlp_ops_v02_worker_per_url_auth.py @@ -24,7 +24,6 @@ from airflow.operators.dummy import DummyOperator from airflow.utils.dates import days_ago from airflow.utils.task_group import TaskGroup from airflow.api.common.trigger_dag import trigger_dag -import copy from datetime import datetime, timedelta import concurrent.futures import json @@ -50,63 +49,12 @@ from thrift.transport.TTransport import TTransportException # Configure logging logger = logging.getLogger(__name__) - -# --- Client Stats Helper --- - -def _update_client_stats(redis_client, clients_str: str, status: str, url: str, machine_id: str, dag_run_id: str): - """Updates success/failure statistics for a client type in Redis.""" - if not clients_str: - logger.warning("Cannot update client stats: 'clients' string is empty.") - return - - # Assumption: The service tries clients in the order provided. - # We attribute the result to the first client in the list. - primary_client = clients_str.split(',')[0].strip() - if not primary_client: - logger.warning("Cannot update client stats: could not determine primary client.") - return - - stats_key = "client_stats" - - try: - # Using a pipeline with WATCH for safe concurrent updates. - with redis_client.pipeline() as pipe: - pipe.watch(stats_key) - - current_stats_json = redis_client.hget(stats_key, primary_client) - stats = {} - if current_stats_json: - try: - stats = json.loads(current_stats_json) - except json.JSONDecodeError: - logger.warning(f"Could not parse existing stats for client '{primary_client}'. Resetting stats.") - stats = {} - - stats.setdefault('success_count', 0) - stats.setdefault('failure_count', 0) - - details = { - 'timestamp': time.time(), 'url': url, - 'machine_id': machine_id, 'dag_run_id': dag_run_id, - } - - if status == 'success': - stats['success_count'] += 1 - stats['latest_success'] = details - elif status == 'failure': - stats['failure_count'] += 1 - stats['latest_failure'] = details - - pipe.multi() - pipe.hset(stats_key, primary_client, json.dumps(stats)) - pipe.execute() - - logger.info(f"Successfully updated '{status}' stats for client '{primary_client}'.") - - except redis.exceptions.WatchError: - logger.warning(f"WatchError updating stats for client '{primary_client}'. Another process updated it. Skipping this update.") - except Exception as e: - logger.error(f"Failed to update client stats for '{primary_client}': {e}", exc_info=True) +# ytops_client imports for v2 profile management +try: + from ytops_client.profile_manager_tool import ProfileManager, format_duration, format_timestamp +except ImportError as e: + logger.critical(f"Could not import ytops_client modules: {e}. Ensure yt-ops-client package is installed correctly in Airflow's environment.") + raise # Default settings from Airflow Variables or hardcoded fallbacks @@ -192,60 +140,6 @@ def _extract_video_id(url): return match.group(1) return None -def _get_account_pool(params: dict) -> list: - """ - Gets the list of accounts to use for processing, filtering out banned/resting accounts. - Supports explicit list, prefix-based generation, and single account modes. 
- """ - account_pool_str = params.get('account_pool', 'default_account') - accounts = [] - is_prefix_mode = False - - if ',' in account_pool_str: - accounts = [acc.strip() for acc in account_pool_str.split(',') if acc.strip()] - else: - prefix = account_pool_str - pool_size_param = params.get('account_pool_size') - if pool_size_param is not None: - is_prefix_mode = True - pool_size = int(pool_size_param) - - # The orchestrator now generates the full prefix if prepend_client_to_account is True. - # The worker just appends the numbers. - accounts = [f"{prefix}_{i:02d}" for i in range(1, pool_size + 1)] - else: - accounts = [prefix] - - if not accounts: - raise AirflowException("Initial account pool is empty.") - - redis_conn_id = params.get('redis_conn_id', DEFAULT_REDIS_CONN_ID) - try: - redis_client = _get_redis_client(redis_conn_id) - active_accounts = [] - for account in accounts: - status_bytes = redis_client.hget(f"account_status:{account}", "status") - status = status_bytes.decode('utf-8') if status_bytes else "ACTIVE" - if status not in ['BANNED'] and 'RESTING' not in status: - active_accounts.append(account) - - if not active_accounts and accounts: - auto_create = params.get('auto_create_new_accounts_on_exhaustion', False) - if auto_create and is_prefix_mode: - new_account_id = f"{account_pool_str}-auto-{str(uuid.uuid4())[:8]}" - logger.warning(f"Account pool exhausted. Auto-creating new account: '{new_account_id}'") - active_accounts.append(new_account_id) - else: - raise AirflowException("All accounts in the configured pool are currently exhausted.") - accounts = active_accounts - except Exception as e: - logger.error(f"Could not filter accounts from Redis. Using unfiltered pool. Error: {e}", exc_info=True) - - if not accounts: - raise AirflowException("Account pool is empty after filtering.") - - logger.info(f"Final active account pool with {len(accounts)} accounts.") - return accounts @task def list_available_formats(token_data: dict, **context): @@ -306,15 +200,61 @@ def list_available_formats(token_data: dict, **context): # TASK DEFINITIONS (TaskFlow API) # ============================================================================= +def _resolve_formats(info_json_path: str, format_selector: str, logger) -> list[str]: + """Uses yt-dlp to resolve a format selector into a list of specific format IDs.""" + import subprocess + import shlex + + if not format_selector: + return [] + + try: + cmd = [ + 'yt-dlp', '--print', 'format_id', + '-f', format_selector, + '--load-info-json', info_json_path, + ] + + copy_paste_cmd = ' '.join(shlex.quote(arg) for arg in cmd) + logger.info(f"Resolving format selector '{format_selector}' with command: {copy_paste_cmd}") + + process = subprocess.run(cmd, capture_output=True, text=True, timeout=60) + + if process.stderr: + logger.info(f"yt-dlp format resolver STDERR:\n{process.stderr}") + + if process.returncode != 0: + logger.error(f"yt-dlp format resolver failed with exit code {process.returncode}") + return [] + + output_ids = [fid for fid in process.stdout.strip().split('\n') if fid] + final_ids = [] + for fid in output_ids: + final_ids.extend(fid.split('+')) + + logger.info(f"Resolved selector '{format_selector}' to {len(final_ids)} format(s): {final_ids}") + return final_ids + + except Exception as e: + logger.error(f"An error occurred while resolving format selector: {e}", exc_info=True) + return [] + + @task -def get_url_and_assign_account(**context): +def get_url_and_lock_profile(**context): """ - Gets the URL to process from the DAG run 
configuration and assigns an active account. + Gets the URL to process, then locks an available auth profile from the Redis pool. This is the first task in the pinned-worker DAG. """ params = context['params'] ti = context['task_instance'] + # Log the active policies + auth_policy = params.get('on_bannable_failure', 'not_set') + logger.info(f"--- Worker Policies ---") + logger.info(f" Auth Failure Policy: {auth_policy}") + logger.info(f"-----------------------") + # --- Worker Pinning Verification --- # This is a safeguard against a known Airflow issue where clearing a task # can cause the task_instance_mutation_hook to be skipped, breaking pinning. @@ -384,603 +324,285 @@ def get_url_and_assign_account(**context): except Exception as e: logger.error(f"Could not mark URL as in-progress in Redis: {e}", exc_info=True) - # Account assignment logic is the same as before. - account_id = random.choice(_get_account_pool(params)) - logger.info(f"Selected account '{account_id}' for this run.") + # V2 Profile Locking + redis_conn_id = params['redis_conn_id'] + redis_env = params['redis_env'] + profile_prefix = params['profile_prefix'] + + try: + redis_hook = _get_redis_client(redis_conn_id, return_hook=True) + key_prefix = f"{redis_env}_profile_mgmt_" + pm = ProfileManager(redis_hook=redis_hook, key_prefix=key_prefix) + logger.info(f"Initialized ProfileManager for env '{redis_env}' (Redis key prefix: '{key_prefix}')") + except Exception as e: + raise AirflowException(f"Failed to initialize ProfileManager: {e}") + + owner_id = f"airflow_auth_worker_{context['dag_run'].run_id}" + locked_profile = None + logger.info(f"Attempting to lock a profile with owner '{owner_id}' and prefix '{profile_prefix}'...") + + lock_attempts = 0 + while not locked_profile: + locked_profile = pm.lock_profile(owner=owner_id, profile_prefix=profile_prefix) + if not locked_profile: + logger.info("No auth profiles available to lock. 
Waiting for 15 seconds...") + time.sleep(15) + lock_attempts += 1 + if lock_attempts > 20: # 5 minutes timeout + raise AirflowException("Timed out waiting to lock an auth profile.") + + logger.info(f"Successfully locked profile: {locked_profile['name']}") return { 'url_to_process': url_to_process, - 'account_id': account_id, - 'accounts_tried': [account_id], + 'locked_profile': locked_profile, } @task def get_token(initial_data: dict, **context): - """Makes a single attempt to get a token by calling the ytops-client get-info tool.""" - import subprocess - import shlex - + """Makes a single attempt to get a token by calling the Thrift service directly.""" ti = context['task_instance'] params = context['params'] - account_id = initial_data['account_id'] + locked_profile = initial_data['locked_profile'] + account_id = locked_profile['name'] + assigned_proxy_url = locked_profile['proxy'] url = initial_data['url_to_process'] info_json_dir = os.path.join(Variable.get('DOWNLOADS_TEMP', '/opt/airflow/downloadfiles'), 'videos', 'in-progress') host, port = params['service_ip'], int(params['service_port']) machine_id = params.get('machine_id') or socket.gethostname() - clients = params.get('clients') - request_params_json = params.get('request_params_json') - language_code = params.get('language_code') - assigned_proxy_url = params.get('assigned_proxy_url') - if language_code: - try: - params_dict = json.loads(request_params_json) - if not params_dict: - params_dict = copy.deepcopy(DEFAULT_REQUEST_PARAMS) - - logger.info(f"Setting language for request: {language_code}") - if 'session_params' not in params_dict: - params_dict['session_params'] = {} - params_dict['session_params']['lang'] = language_code - request_params_json = json.dumps(params_dict) - except (json.JSONDecodeError, TypeError): - logger.warning("Could not parse request_params_json as JSON. Treating as key=value pairs and appending language code.") - lang_kv = f"session_params.lang={language_code}" - if request_params_json: - request_params_json += f",{lang_kv}" - else: - request_params_json = lang_kv + # The unified JSON config is now the primary source of parameters. + request_params_json = params.get('ytdlp_config_json', '{}') + clients = None # This will be read from the JSON config on the server side. video_id = _extract_video_id(url) timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") job_dir_name = f"{timestamp}-{video_id or 'unknown'}" job_dir_path = os.path.join(info_json_dir, job_dir_name) os.makedirs(job_dir_path, exist_ok=True) - info_json_filename = f"info_{video_id or 'unknown'}_{account_id}_{timestamp}.json" - info_json_path = os.path.join(job_dir_path, info_json_filename) + info_json_path = os.path.join(job_dir_path, f"info_{video_id or 'unknown'}_{account_id}_{timestamp}.json") - cmd = [ - 'ytops-client', 'get-info', - '--host', host, - '--port', str(port), - '--profile', account_id, - '--output', info_json_path, - '--print-proxy', - '--verbose', - '--log-return', - ] + # Save the received JSON config to the job directory for the download worker. + ytdlp_config_path = os.path.join(job_dir_path, 'ytdlp.json') + try: + with open(ytdlp_config_path, 'w', encoding='utf-8') as f: + # Pretty-print the JSON for readability + config_data = json.loads(request_params_json) + json.dump(config_data, f, indent=2) + logger.info(f"Saved ytdlp config to {ytdlp_config_path}") + except (IOError, json.JSONDecodeError) as e: + logger.error(f"Failed to save ytdlp.json config: {e}") + # Continue anyway, but download worker may fail. 
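The download stage later needs the `ytdlp.json` written above. Below is a hedged reader-side sketch (an assumed counterpart, not code from this patch) that loads the per-job config from the job directory and degrades gracefully when the file is missing or invalid.

```python
# Hedged sketch of the download-stage counterpart to the ytdlp.json write above.
import json
import logging
import os

logger = logging.getLogger(__name__)


def load_job_ytdlp_config(job_dir_path: str) -> dict:
    """Load the unified yt-dlp config saved alongside the job's info.json."""
    config_path = os.path.join(job_dir_path, 'ytdlp.json')
    if not os.path.isfile(config_path):
        logger.warning("No ytdlp.json found in %s; using empty config.", job_dir_path)
        return {}
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except (OSError, json.JSONDecodeError) as e:
        logger.error("Could not read %s: %s; using empty config.", config_path, e)
        return {}
```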
+ ytdlp_config_path = None - if clients: - cmd.extend(['--client', clients]) - if machine_id: - cmd.extend(['--machine-id', machine_id]) - if request_params_json and request_params_json != '{}': - cmd.extend(['--request-params-json', request_params_json]) - if assigned_proxy_url: - cmd.extend(['--assigned-proxy-url', assigned_proxy_url]) - - cmd.append(url) - logger.info(f"--- Attempting to get token for URL '{url}' with account '{account_id}' (Clients: {clients}) ---") - copy_paste_cmd = ' '.join(shlex.quote(arg) for arg in cmd) - logger.info(f"Executing command: {copy_paste_cmd}") + client, transport = None, None + try: + timeout = int(params.get('timeout', DEFAULT_TIMEOUT)) + client, transport = _get_thrift_client(host, port, timeout) - process = subprocess.run(cmd, capture_output=True, text=True, timeout=int(params.get('timeout', DEFAULT_TIMEOUT))) + airflow_log_context = AirflowLogContext( + taskId=ti.task_id, + runId=ti.run_id, + tryNumber=ti.try_number + ) - if process.stdout: - logger.info(f"ytops-client STDOUT:\n{process.stdout}") - if process.stderr: - logger.info(f"ytops-client STDERR:\n{process.stderr}") - - if process.returncode != 0: - error_message = "ytops-client failed. See logs for details." - # Try to find a more specific error message from the Thrift client's output - thrift_error_match = re.search(r'A Thrift error occurred: (.*)', process.stderr) - if thrift_error_match: - error_message = thrift_error_match.group(1).strip() - else: # Fallback to old line-by-line parsing - for line in reversed(process.stderr.strip().split('\n')): - if 'ERROR' in line or 'Thrift error' in line or 'Connection to server failed' in line: - error_message = line.strip() - break + logger.info(f"--- Attempting to get token for URL '{url}' with account '{account_id}' (Clients: {clients}, Proxy: {assigned_proxy_url or 'any'}) ---") - # Determine error code for branching logic - error_code = 'GET_INFO_CLIENT_FAIL' - stderr_lower = process.stderr.lower() + token_data = client.getOrRefreshToken( + accountId=account_id, + updateType=TokenUpdateMode.AUTO, + url=url, + clients=clients, + machineId=machine_id, + airflowLogContext=airflow_log_context, + requestParamsJson=request_params_json, + assignedProxyUrl=assigned_proxy_url + ) - # These patterns should match the error codes from PBUserException and others - error_patterns = { - "BOT_DETECTED": ["bot_detected"], - "BOT_DETECTION_SIGN_IN_REQUIRED": ["bot_detection_sign_in_required"], - "TRANSPORT_ERROR": ["connection to server failed"], - "PRIVATE_VIDEO": ["private video"], - "COPYRIGHT_REMOVAL": ["copyright"], - "GEO_RESTRICTED": ["in your country"], - "VIDEO_REMOVED": ["video has been removed"], - "VIDEO_UNAVAILABLE": ["video unavailable"], - "MEMBERS_ONLY": ["members-only"], - "AGE_GATED_SIGN_IN": ["sign in to confirm your age"], - "VIDEO_PROCESSING": ["processing this video"], + # --- Log server-side details for debugging --- + if hasattr(token_data, 'serverVersionInfo') and token_data.serverVersionInfo: + logger.info(f"--- Server Version Info ---\n{token_data.serverVersionInfo}") + + if hasattr(token_data, 'requestSummary') and token_data.requestSummary: + try: + summary_data = json.loads(token_data.requestSummary) + summary_text = summary_data.get('summary', 'Not available.') + prefetch_log = summary_data.get('prefetch_log', 'Not available.') + nodejs_log = summary_data.get('nodejs_log', 'Not available.') + ytdlp_log = summary_data.get('ytdlp_log', 'Not available.') + + logger.info(f"--- Request Summary ---\n{summary_text}") + logger.info(f"--- 
Prefetch Log ---\n{prefetch_log}") + logger.info(f"--- Node.js Log ---\n{nodejs_log}") + logger.info(f"--- yt-dlp Log ---\n{ytdlp_log}") + except (json.JSONDecodeError, AttributeError): + logger.info(f"--- Raw Request Summary (could not parse JSON) ---\n{token_data.requestSummary}") + + if hasattr(token_data, 'communicationLogPaths') and token_data.communicationLogPaths: + logger.info("--- Communication Log Paths on Server ---") + for log_path in token_data.communicationLogPaths: + logger.info(f" - {log_path}") + # --- End of server-side logging --- + + if not token_data or not token_data.infoJson: + raise AirflowException("Thrift service did not return valid info.json data.") + + # Save info.json to file + with open(info_json_path, 'w', encoding='utf-8') as f: + f.write(token_data.infoJson) + + proxy = token_data.socks + + # Rename file with proxy + final_info_json_path = info_json_path + if proxy: + sanitized_proxy = proxy.replace('://', '---') + new_filename = f"info_{video_id or 'unknown'}_{account_id}_{timestamp}_proxy_{sanitized_proxy}.json" + new_path = os.path.join(job_dir_path, new_filename) + try: + os.rename(info_json_path, new_path) + final_info_json_path = new_path + logger.info(f"Renamed info.json to include proxy: {new_path}") + except OSError as e: + logger.error(f"Failed to rename info.json to include proxy: {e}. Using original path.") + + return { + 'info_json_path': final_info_json_path, + 'job_dir_path': job_dir_path, + 'socks_proxy': proxy, + 'ytdlp_command': None, + 'successful_account_id': account_id, + 'original_url': url, + 'ytdlp_config_path': ytdlp_config_path, + 'ytdlp_config_json': request_params_json, + # Pass locked profile through for unlock/activity tasks + 'locked_profile': locked_profile, } - for code, patterns in error_patterns.items(): - if any(p in stderr_lower for p in patterns): - error_code = code - break # Found a match, stop searching + except (PBServiceException, PBUserException) as e: + error_message = e.message or "Unknown Thrift error" + error_code = getattr(e, 'errorCode', 'THRIFT_ERROR') + # If a "Video unavailable" error mentions rate-limiting, it's a form of bot detection. + if error_code == 'VIDEO_UNAVAILABLE' and 'rate-limited' in error_message.lower(): + logger.warning("Re-classifying rate-limit-related 'VIDEO_UNAVAILABLE' error as 'BOT_DETECTED'.") + error_code = 'BOT_DETECTED' + + unrecoverable_video_errors = [ + "AGE_GATED_SIGN_IN", "MEMBERS_ONLY", "VIDEO_PROCESSING", "COPYRIGHT_REMOVAL", + "GEO_RESTRICTED", "VIDEO_UNAVAILABLE", "PRIVATE_VIDEO", "VIDEO_REMOVED" + ] + + if error_code in unrecoverable_video_errors: + error_details = { + 'error_message': error_message, + 'error_code': error_code, + 'proxy_url': None + } + ti.xcom_push(key='error_details', value=error_details) + logger.warning(f"Unrecoverable video error '{error_code}' - {error_message}. 
Marking for skip without failing the task.") + return {'status': 'unrecoverable_video_error'} + else: + logger.error(f"Thrift error getting token: {error_code} - {error_message}") + + error_details = { + 'error_message': error_message, + 'error_code': error_code, + 'proxy_url': None + } + ti.xcom_push(key='error_details', value=error_details) + raise AirflowException(f"ytops-client get-info failed: {error_message}") + except TTransportException as e: + logger.error(f"Thrift transport error: {e}", exc_info=True) error_details = { - 'error_message': error_message, - 'error_code': error_code, + 'error_message': f"Thrift transport error: {e}", + 'error_code': 'TRANSPORT_ERROR', 'proxy_url': None } ti.xcom_push(key='error_details', value=error_details) - raise AirflowException(f"ytops-client get-info failed: {error_message}") - - proxy = None - proxy_match = re.search(r"Proxy used: (.*)", process.stderr) - if proxy_match: - proxy = proxy_match.group(1).strip() - - # Rename the info.json to include the proxy for the download worker - final_info_json_path = info_json_path - if proxy: - # Sanitize for filename: replace '://' which is invalid in paths. Colons are usually fine. - sanitized_proxy = proxy.replace('://', '---') - - new_filename = f"info_{video_id or 'unknown'}_{account_id}_{timestamp}_proxy_{sanitized_proxy}.json" - new_path = os.path.join(job_dir_path, new_filename) - try: - os.rename(info_json_path, new_path) - final_info_json_path = new_path - logger.info(f"Renamed info.json to include proxy: {new_path}") - except OSError as e: - logger.error(f"Failed to rename info.json to include proxy: {e}. Using original path.") - - return { - 'info_json_path': final_info_json_path, - 'job_dir_path': job_dir_path, - 'socks_proxy': proxy, - 'ytdlp_command': None, - 'successful_account_id': account_id, - 'original_url': url, - 'clients': clients, - } - -@task.branch -def handle_bannable_error_branch(task_id_to_check: str, **context): - """ - Inspects a failed task and routes to retry logic if the error is retryable. - Routes to a fatal error handler for non-retryable infrastructure issues. - """ - ti = context['task_instance'] - params = context['params'] - error_details = ti.xcom_pull(task_ids=task_id_to_check, key='error_details') - if not error_details: - logger.error(f"Task {task_id_to_check} failed without error details. Marking as fatal.") - return 'handle_fatal_error' - - error_message = error_details.get('error_message', '').strip() - error_code = error_details.get('error_code', '').strip() - policy = params.get('on_bannable_failure', 'retry_with_new_account') - - # Unrecoverable video errors that should not be retried or treated as system failures. - unrecoverable_video_errors = [ - "AGE_GATED_SIGN_IN", "MEMBERS_ONLY", "VIDEO_PROCESSING", "COPYRIGHT_REMOVAL", - "GEO_RESTRICTED", "VIDEO_UNAVAILABLE", "PRIVATE_VIDEO", "VIDEO_REMOVED" - ] - - if error_code in unrecoverable_video_errors: - logger.warning(f"Unrecoverable video error '{error_code}' detected for '{task_id_to_check}'. This is a content issue, not a system failure.") - return 'handle_unrecoverable_video_error' - - # Fatal Thrift connection errors that should stop all processing. - if error_code == 'TRANSPORT_ERROR': - logger.error(f"Fatal Thrift connection error from '{task_id_to_check}'. Stopping processing.") - return 'handle_fatal_error' - - # Service-side connection errors that are potentially retryable. 
- connection_errors = ['SOCKS5_CONNECTION_FAILED', 'SOCKET_TIMEOUT', 'CAMOUFOX_TIMEOUT'] - if error_code in connection_errors: - logger.info(f"Handling connection error '{error_code}' from '{task_id_to_check}'. Policy: '{policy}'") - if policy == 'stop_loop': - logger.warning(f"Connection error with 'stop_loop' policy. Marking as fatal.") - return 'handle_fatal_error' - else: - logger.info("Retrying with a new account without banning.") - return 'assign_new_account_for_direct_retry' - - # Bannable errors (e.g., bot detection) that can be retried with a new account. - is_bannable = error_code in ["BOT_DETECTED", "BOT_DETECTION_SIGN_IN_REQUIRED"] - logger.info(f"Handling failure from '{task_id_to_check}'. Error code: '{error_code}', Policy: '{policy}'") - if is_bannable: - if policy in ['retry_with_new_account', 'retry_and_ban_account_only']: - return 'ban_account_and_prepare_for_retry' - if policy in ['retry_on_connection_error', 'retry_without_ban']: - return 'assign_new_account_for_direct_retry' - if policy in ['stop_loop', 'stop_loop_on_auth_proceed_on_download_error']: - return 'ban_and_report_immediately' - if policy == 'proceed_loop_under_manual_inspection': - logger.warning(f"Bannable error with 'proceed_loop_under_manual_inspection' policy. Reporting failure and continuing loop. MANUAL INTERVENTION IS LIKELY REQUIRED.") - return 'report_bannable_and_continue' - - # Any other error is considered fatal for this run. - logger.error(f"Unhandled or non-retryable error '{error_code}' from '{task_id_to_check}'. Marking as fatal.") - return 'handle_fatal_error' - -@task_group(group_id='ban_and_retry_logic') -def ban_and_retry_logic(initial_data: dict): - """ - Task group that checks for sliding window failures before banning an account. - If the account meets ban criteria, it's banned. Otherwise, the ban is skipped - but the retry proceeds. - """ - - @task.branch - def check_sliding_window_for_ban(data: dict, **context): - """ - Checks Redis for recent failures. If thresholds are met, proceeds to ban. - Otherwise, proceeds to a dummy task to allow retry without ban. - """ - params = context['params'] - account_id = data['account_id'] - redis_conn_id = params.get('redis_conn_id', DEFAULT_REDIS_CONN_ID) - - # These thresholds should ideally be Airflow Variables to be configurable - failure_window_seconds = 3600 # 1 hour - failure_threshold_count = 5 - failure_threshold_unique_proxies = 3 - - try: - redis_client = _get_redis_client(redis_conn_id) - failure_key = f"account_failures:{account_id}" - now = time.time() - window_start = now - failure_window_seconds - - # 1. Remove old failures and get recent ones - redis_client.zremrangebyscore(failure_key, '-inf', window_start) - recent_failures = redis_client.zrange(failure_key, 0, -1) - - if len(recent_failures) >= failure_threshold_count: - # Decode from bytes to string for processing - recent_failures_str = [f.decode('utf-8') for f in recent_failures] - # Failure format is "context:job_id:timestamp" - unique_proxies = {f.split(':')[0] for f in recent_failures_str} - - if len(unique_proxies) >= failure_threshold_unique_proxies: - logger.warning( - f"Account {account_id} has failed {len(recent_failures)} times " - f"with {len(unique_proxies)} unique contexts in the last hour. Proceeding to ban." - ) - return 'ban_account_task' - else: - logger.info( - f"Account {account_id} has {len(recent_failures)} failures, but only " - f"from {len(unique_proxies)} unique contexts (threshold is {failure_threshold_unique_proxies}). Skipping ban." 
- ) - else: - logger.info(f"Account {account_id} has {len(recent_failures)} failures (threshold is {failure_threshold_count}). Skipping ban.") - - except Exception as e: - logger.error(f"Error during sliding window check for account {account_id}: {e}. Skipping ban as a precaution.", exc_info=True) - - return 'skip_ban_task' - - @task(task_id='ban_account_task') - def ban_account_task(data: dict, **context): - """Wrapper task to call the main ban_account function.""" - ban_account(initial_data=data, reason="Banned by Airflow worker after sliding window check", **context) - - @task(task_id='skip_ban_task') - def skip_ban_task(): - """Dummy task to represent the 'skip ban' path.""" - pass - - check_task = check_sliding_window_for_ban(data=initial_data) - ban_task_in_group = ban_account_task(data=initial_data) - skip_task = skip_ban_task() - - check_task >> [ban_task_in_group, skip_task] - - -@task -def ban_account(initial_data: dict, reason: str, **context): - """Bans a single account via the Thrift service.""" - params = context['params'] - account_id = initial_data['account_id'] - client, transport = None, None - try: - host, port, timeout = params['service_ip'], int(params['service_port']), int(params.get('timeout', DEFAULT_TIMEOUT)) - client, transport = _get_thrift_client(host, port, timeout) - logger.warning(f"Banning account '{account_id}'. Reason: {reason}") - client.banAccount(accountId=account_id, reason=reason) - except Exception as e: - logger.error(f"Failed to issue ban for account '{account_id}': {e}", exc_info=True) + raise AirflowException(f"Thrift transport error: {e}") finally: if transport and transport.isOpen(): transport.close() -@task -def assign_new_account_for_direct_retry(initial_data: dict, **context): - """Selects a new, unused account for a direct retry (e.g., after connection error).""" - params = context['params'] - accounts_tried = initial_data['accounts_tried'] - account_pool = _get_account_pool(params) - available_for_retry = [acc for acc in account_pool if acc not in accounts_tried] - if not available_for_retry: - raise AirflowException("No other accounts available in the pool for a retry.") - - new_account_id = random.choice(available_for_retry) - accounts_tried.append(new_account_id) - logger.info(f"Selected new account for retry: '{new_account_id}'") - - # Return updated initial_data with new account - return { - 'url_to_process': initial_data['url_to_process'], - 'account_id': new_account_id, - 'accounts_tried': accounts_tried, - } @task -def assign_new_account_after_ban_check(initial_data: dict, **context): - """Selects a new, unused account for the retry attempt after a ban check.""" - params = context['params'] - accounts_tried = initial_data['accounts_tried'] - account_pool = _get_account_pool(params) - available_for_retry = [acc for acc in account_pool if acc not in accounts_tried] - if not available_for_retry: - raise AirflowException("No other accounts available in the pool for a retry.") - - new_account_id = random.choice(available_for_retry) - accounts_tried.append(new_account_id) - logger.info(f"Selected new account for retry: '{new_account_id}'") - - # Return updated initial_data with new account - return { - 'url_to_process': initial_data['url_to_process'], - 'account_id': new_account_id, - 'accounts_tried': accounts_tried, - } - -@task -def ban_and_report_immediately(initial_data: dict, reason: str, **context): - """Bans an account and prepares for failure reporting and continuing the loop.""" - ban_account(initial_data, reason, **context) - 
logger.info(f"Account '{initial_data.get('account_id')}' banned. Proceeding to report failure.") - # This task is a leaf in its path and is followed by the failure reporting task. - return initial_data # Pass data along if needed by reporting - -@task -def push_auth_success_to_redis(initial_data: dict, token_data: dict, **context): +def generate_and_push_download_tasks(token_data: dict, **context): """ - On successful token acquisition, pushes the complete token data to the - Redis queue for the download worker and records the auth success. + On success, resolves the format selector into individual format IDs and pushes + granular download tasks to the `queue_dl_format_tasks` Redis list. + Also records the successful auth activity for the profile. """ params = context['params'] - url = initial_data['url_to_process'] + url = token_data['original_url'] + info_json_path = token_data['info_json_path'] + locked_profile = token_data['locked_profile'] + + # Resolve format selector from the JSON config + try: + ytdlp_config = json.loads(token_data.get('ytdlp_config_json', '{}')) + download_format_selector = ytdlp_config.get('download_format', 'bestvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best') + # This profile prefix is for the *download* worker that will pick up the task + download_profile_prefix = ytdlp_config.get('download_profile_prefix', 'download_user') + except (json.JSONDecodeError, KeyError): + logger.error("Could not parse download_format from ytdlp_config_json. Falling back to default.") + download_format_selector = 'bestvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best' + download_profile_prefix = 'download_user' + + resolved_formats = _resolve_formats(info_json_path, download_format_selector, logger) + if not resolved_formats: + raise AirflowException(f"Format selector '{download_format_selector}' resolved to no formats for {url}.") + + tasks = [] + for format_id in resolved_formats: + task_payload = { + "info_json_path": info_json_path, + "format_id": format_id, + "profile_prefix": download_profile_prefix, + "original_url": url, + "dag_run_id": context['dag_run'].run_id, + } + tasks.append(json.dumps(task_payload)) - # The download inbox queue is derived from the auth queue name. - dl_inbox_queue = f"{params['queue_name'].replace('_auth', '_dl')}_inbox" + dl_task_queue = "queue_dl_format_tasks" auth_result_queue = f"{params['queue_name']}_result" progress_queue = f"{params['queue_name']}_progress" - - client = _get_redis_client(params['redis_conn_id']) - - payload = { - 'timestamp': time.time(), - 'dag_run_id': context['dag_run'].run_id, - **token_data - } result_data = { 'status': 'success', 'end_time': time.time(), 'url': url, 'dag_run_id': context['dag_run'].run_id, - 'token_data': token_data + 'token_data': {k: v for k, v in token_data.items() if k != 'locked_profile'} # Don't store profile in result } - with client.pipeline() as pipe: - pipe.lpush(dl_inbox_queue, json.dumps(payload)) - pipe.hset(auth_result_queue, url, json.dumps(result_data)) - pipe.hdel(progress_queue, url) - pipe.execute() - - logger.info(f"Pushed successful auth data for URL '{url}' to '{dl_inbox_queue}'.") - logger.info(f"Stored success result for auth on URL '{url}' in '{auth_result_queue}'.") - -@task -def handle_unrecoverable_video_error(**context): - """ - Handles errors for videos that are unavailable (private, removed, etc.). - These are not system failures, so the URL is logged to a 'skipped' queue - and the processing loop continues without marking the run as failed. 
- """ - params = context['params'] - ti = context['task_instance'] - url = params.get('url_to_process', 'unknown') - - # Collect error details from the failed get_token task - error_details = {} - first_token_error = ti.xcom_pull(task_ids='initial_attempt.get_token', key='error_details') - retry_token_error = ti.xcom_pull(task_ids='retry_logic.retry_get_token', key='error_details') - - if retry_token_error: - error_details = retry_token_error - elif first_token_error: - error_details = first_token_error - - error_code = error_details.get('error_code', 'UNKNOWN_VIDEO_ERROR') - error_message = error_details.get('error_message', 'Video is unavailable for an unknown reason.') - - logger.warning(f"Skipping URL '{url}' due to unrecoverable video error: {error_code} - {error_message}") - - result_data = { - 'status': 'skipped', - 'end_time': time.time(), - 'url': url, - 'dag_run_id': context['dag_run'].run_id, - 'reason': error_code, - 'details': error_message, - 'error_details': error_details - } - try: - client = _get_redis_client(params['redis_conn_id']) - - # New queue for skipped videos - skipped_queue = f"{params['queue_name']}_skipped" - progress_queue = f"{params['queue_name']}_progress" - - with client.pipeline() as pipe: - pipe.hset(skipped_queue, url, json.dumps(result_data)) + redis_client = _get_redis_client(params['redis_conn_id']) + with redis_client.pipeline() as pipe: + pipe.rpush(dl_task_queue, *tasks) + pipe.hset(auth_result_queue, url, json.dumps(result_data)) pipe.hdel(progress_queue, url) pipe.execute() - - logger.info(f"Stored skipped result for URL '{url}' in '{skipped_queue}' and removed from progress queue.") - except Exception as e: - logger.error(f"Could not report skipped video to Redis: {e}", exc_info=True) - - -@task(trigger_rule='one_failed') -def report_failure_and_continue(**context): - """ - Handles a failed URL processing attempt by recording a detailed error report to Redis. - This is a common endpoint for various failure paths that should not stop the overall dispatcher loop. - """ - params = context['params'] - ti = context['task_instance'] - url = params.get('url_to_process', 'unknown') - - # Collect error details from XCom - error_details = {} - - # Check for error details from get_token tasks - first_token_task_id = 'initial_attempt.get_token' - retry_token_task_id = 'retry_logic.retry_get_token' - - first_token_error = ti.xcom_pull(task_ids=first_token_task_id, key='error_details') - retry_token_error = ti.xcom_pull(task_ids=retry_token_task_id, key='error_details') - - # Use the most recent error details - if retry_token_error: - error_details = retry_token_error - elif first_token_error: - error_details = first_token_error - else: - # Check for other possible error sources - # This is a simplified approach - in a real implementation you might want to - # check more task IDs or use a more sophisticated error collection mechanism - pass - - logger.error(f"A failure occurred while processing URL '{url}'. 
Reporting to Redis.") - - result_data = { - 'status': 'failed', - 'end_time': time.time(), - 'url': url, - 'dag_run_id': context['dag_run'].run_id, - 'error_details': error_details - } - - try: - client = _get_redis_client(params['redis_conn_id']) - - # Update client-specific stats - try: - machine_id = params.get('machine_id') or socket.gethostname() - _update_client_stats(client, params.get('clients', ''), 'failure', url, machine_id, context['dag_run'].run_id) - except Exception as e: - logger.error(f"Could not update client stats on failure: {e}", exc_info=True) - - result_queue = f"{params['queue_name']}_result" - fail_queue = f"{params['queue_name']}_fail" - progress_queue = f"{params['queue_name']}_progress" - - with client.pipeline() as pipe: - pipe.hset(result_queue, url, json.dumps(result_data)) - pipe.hset(fail_queue, url, json.dumps(result_data)) - pipe.hdel(progress_queue, url) - pipe.execute() - - logger.info(f"Stored failure result for URL '{url}' in '{result_queue}' and '{fail_queue}' and removed from progress queue.") + logger.info(f"Pushed {len(tasks)} granular download task(s) for URL '{url}' to '{dl_task_queue}'.") + logger.info(f"Stored success result for auth on URL '{url}' in '{auth_result_queue}'.") except Exception as e: - logger.error(f"Could not report failure to Redis: {e}", exc_info=True) + logger.error(f"Failed to push download tasks to Redis: {e}", exc_info=True) + raise AirflowException("Failed to push tasks to Redis.") + # Return the original token_data (including locked_profile) for the unlock task + return token_data -@task(trigger_rule='one_failed') -def handle_fatal_error(**context): - """ - Handles fatal, non-retryable errors (e.g., infrastructure issues). - This task reports the failure to Redis before failing the DAG run to ensure - failed URLs are queued for later reprocessing, then stops the processing loop. - """ - params = context['params'] - ti = context['task_instance'] - url = params.get('url_to_process', 'unknown') - - # Collect error details - error_details = {} - first_token_task_id = 'initial_attempt.get_token' - retry_token_task_id = 'retry_logic.retry_get_token' - - first_token_error = ti.xcom_pull(task_ids=first_token_task_id, key='error_details') - retry_token_error = ti.xcom_pull(task_ids=retry_token_task_id, key='error_details') - - # Use the most recent error details - if retry_token_error: - error_details = retry_token_error - elif first_token_error: - error_details = first_token_error - - logger.error(f"A fatal, non-retryable error occurred for URL '{url}'. 
See previous task logs for details.") - - # Report failure to Redis so the URL can be reprocessed later - try: - client = _get_redis_client(params['redis_conn_id']) - - # Update client-specific stats - try: - machine_id = params.get('machine_id') or socket.gethostname() - _update_client_stats(client, params.get('clients', ''), 'failure', url, machine_id, context['dag_run'].run_id) - except Exception as e: - logger.error(f"Could not update client stats on fatal error: {e}", exc_info=True) - - result_data = { - 'status': 'failed', - 'end_time': time.time(), - 'url': url, - 'dag_run_id': context['dag_run'].run_id, - 'error': 'fatal_error', - 'error_message': 'Fatal non-retryable error occurred', - 'error_details': error_details - } - result_queue = f"{params['queue_name']}_result" - fail_queue = f"{params['queue_name']}_fail" - - progress_queue = f"{params['queue_name']}_progress" - - with client.pipeline() as pipe: - pipe.hset(result_queue, url, json.dumps(result_data)) - pipe.hset(fail_queue, url, json.dumps(result_data)) - pipe.hdel(progress_queue, url) - pipe.execute() - - logger.info(f"Stored fatal error result for URL '{url}' in '{result_queue}' and '{fail_queue}' for later reprocessing.") - except Exception as e: - logger.error(f"Could not report fatal error to Redis: {e}", exc_info=True) - - # Fail the DAG run to prevent automatic continuation of the processing loop - raise AirflowException("Failing DAG due to fatal error. The dispatcher loop will stop.") @task(trigger_rule='one_success') -def continue_processing_loop(**context): +def continue_processing_loop(token_data: dict | None = None, **context): """ - After a successful run, triggers a new dispatcher to continue the processing loop, - effectively asking for the next URL to be processed. + After a run, triggers a new dispatcher to continue the processing loop, + passing along the account/proxy to make them sticky if available. """ params = context['params'] dag_run = context['dag_run'] @@ -998,18 +620,29 @@ def continue_processing_loop(**context): return # Create a new unique run_id for the dispatcher. - # Using a timestamp and UUID ensures the ID is unique and does not grow in length over time, - # preventing database errors. new_dispatcher_run_id = f"retriggered_by_worker_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{str(uuid.uuid4())[:8]}" # Pass all original parameters from the orchestrator through to the new dispatcher run. conf_to_pass = {k: v for k, v in params.items() if v is not None} + conf_to_pass['worker_index'] = params.get('worker_index') - # The new dispatcher will pull its own URL and determine its own queue, so we don't pass these. + if token_data: + # On success path, make the account and proxy "sticky" for the next run. + conf_to_pass['account_id'] = token_data.get('successful_account_id') + conf_to_pass['assigned_proxy_url'] = token_data.get('socks_proxy') + logger.info(f"Worker finished successfully. Triggering a new dispatcher ('{new_dispatcher_run_id}') to continue the loop with sticky account/proxy.") + logger.info(f" - Sticky Account: {conf_to_pass.get('account_id')}") + logger.info(f" - Sticky Proxy: {conf_to_pass.get('assigned_proxy_url')}") + else: + # On failure/skip paths, no token_data is passed. Clear sticky params to allow re-selection. + conf_to_pass.pop('account_id', None) + conf_to_pass.pop('assigned_proxy_url', None) + logger.info(f"Worker finished on a non-success path. 
Triggering a new dispatcher ('{new_dispatcher_run_id}') to continue the loop without sticky account/proxy.") + + # The new dispatcher will pull its own URL and determine its own queue. conf_to_pass.pop('url_to_process', None) conf_to_pass.pop('worker_queue', None) - logger.info(f"Worker finished successfully. Triggering a new dispatcher ('{new_dispatcher_run_id}') to continue the loop.") trigger_dag( dag_id=dispatcher_dag_id, run_id=new_dispatcher_run_id, @@ -1018,136 +651,53 @@ def continue_processing_loop(**context): ) -@task.branch(trigger_rule='one_failed') -def handle_retry_failure_branch(task_id_to_check: str, **context): + + + + +# ============================================================================= +# DAG Definition with TaskGroups +# ============================================================================= +@task(trigger_rule='all_done') +def unlock_profile(**context): """ - Inspects a failed retry attempt and decides on the final action. - On retry, most errors are considered fatal for the URL, but not for the system. - """ - ti = context['task_instance'] - params = context['params'] - error_details = ti.xcom_pull(task_ids=task_id_to_check, key='error_details') - if not error_details: - return 'handle_fatal_error' - - error_message = error_details.get('error_message', '').strip() - error_code = error_details.get('error_code', '').strip() - - # Unrecoverable video errors that should not be retried or treated as system failures. - unrecoverable_video_errors = [ - "AGE_GATED_SIGN_IN", "MEMBERS_ONLY", "VIDEO_PROCESSING", "COPYRIGHT_REMOVAL", - "GEO_RESTRICTED", "VIDEO_UNAVAILABLE", "PRIVATE_VIDEO", "VIDEO_REMOVED" - ] - - if error_code in unrecoverable_video_errors: - logger.warning(f"Unrecoverable video error '{error_code}' detected on retry for '{task_id_to_check}'.") - return 'handle_unrecoverable_video_error' - - if error_code == 'TRANSPORT_ERROR': - logger.error(f"Fatal Thrift connection error on retry from '{task_id_to_check}'.") - return 'handle_fatal_error' - - is_bannable = error_code in ["BOT_DETECTED", "BOT_DETECTION_SIGN_IN_REQUIRED"] - if is_bannable: - policy = params.get('on_bannable_failure', 'retry_with_new_account') - if policy == 'proceed_loop_under_manual_inspection': - logger.warning(f"Bannable error '{error_code}' on retry with 'proceed_loop_under_manual_inspection' policy. Reporting failure and continuing loop. MANUAL INTERVENTION IS LIKELY REQUIRED.") - return 'report_bannable_and_continue' - - logger.warning(f"Bannable error '{error_code}' on retry. Banning account and reporting failure.") - return 'ban_and_report_after_retry' - - logger.error(f"URL failed on retry with code '{error_code}'. Reporting failure and continuing loop.") - return 'report_failure_and_continue' - - -@task -def ban_and_report_after_retry(retry_data: dict, reason: str, **context): - """Bans the account used in a failed retry and prepares for failure reporting.""" - # The account to ban is the one from the retry attempt. - ban_account(retry_data, reason, **context) - logger.info(f"Account '{retry_data.get('account_id')}' banned after retry failed. Proceeding to report failure.") - return retry_data - - - - -@task(trigger_rule='one_success') -def coalesce_token_data(get_token_result=None, retry_get_token_result=None): - """ - Selects the successful token data from either the first attempt or the retry. - The task that did not run or failed will have a result of None. 
- """ - if retry_get_token_result: - logger.info("Using token data from retry attempt.") - return retry_get_token_result - if get_token_result: - logger.info("Using token data from initial attempt.") - return get_token_result - # This should not be reached if trigger_rule='one_success' is working correctly. - raise AirflowException("Could not find a successful token result from any attempt.") - - -@task -def report_bannable_and_continue(**context): - """ - Handles a bannable error by reporting it, but continues the loop - as per the 'proceed_loop_under_manual_inspection' policy. + Unlocks the profile and records activity (success or failure). + This task runs regardless of upstream success or failure. """ params = context['params'] + dag_run = context['dag_run'] + + failed_tasks = [ti for ti in dag_run.get_task_instances() if ti.state == 'failed'] + is_success = not failed_tasks + activity_type = 'auth' if is_success else 'auth_error' + ti = context['task_instance'] - url = params.get('url_to_process', 'unknown') + initial_data = ti.xcom_pull(task_ids='get_url_and_lock_profile') - # Collect error details - error_details = {} - first_token_task_id = 'initial_attempt.get_token' - retry_token_task_id = 'retry_logic.retry_get_token' - - first_token_error = ti.xcom_pull(task_ids=first_token_task_id, key='error_details') - retry_token_error = ti.xcom_pull(task_ids=retry_token_task_id, key='error_details') - - # Use the most recent error details - if retry_token_error: - error_details = retry_token_error - elif first_token_error: - error_details = first_token_error - - logger.error(f"Bannable error for URL '{url}'. Policy is to continue loop under manual supervision.") - - # Report failure to Redis + locked_profile = initial_data.get('locked_profile') if initial_data else None + + if not locked_profile: + logger.warning("No locked_profile data found. 
Cannot unlock or record activity.") + return + + profile_name = locked_profile.get('name') + owner_id = f"airflow_auth_worker_{dag_run.run_id}" + try: - client = _get_redis_client(params['redis_conn_id']) + redis_conn_id = params['redis_conn_id'] + redis_env = params['redis_env'] + redis_hook = _get_redis_client(redis_conn_id, return_hook=True) + key_prefix = f"{redis_env}_profile_mgmt_" + pm = ProfileManager(redis_hook=redis_hook, key_prefix=key_prefix) + + logger.info(f"Recording activity '{activity_type}' for profile '{profile_name}'.") + pm.record_activity(profile_name, activity_type) - # Update client-specific stats - try: - machine_id = params.get('machine_id') or socket.gethostname() - _update_client_stats(client, params.get('clients', ''), 'failure', url, machine_id, context['dag_run'].run_id) - except Exception as e: - logger.error(f"Could not update client stats on bannable error: {e}", exc_info=True) + logger.info(f"Unlocking profile '{profile_name}' with owner '{owner_id}'.") + pm.unlock_profile(profile_name, owner=owner_id) - result_data = { - 'status': 'failed', - 'end_time': time.time(), - 'url': url, - 'dag_run_id': context['dag_run'].run_id, - 'error': 'bannable_error_manual_override', - 'error_message': 'Bannable error occurred, but policy is set to continue loop under manual supervision.', - 'error_details': error_details - } - result_queue = f"{params['queue_name']}_result" - fail_queue = f"{params['queue_name']}_fail" - - progress_queue = f"{params['queue_name']}_progress" - - with client.pipeline() as pipe: - pipe.hset(result_queue, url, json.dumps(result_data)) - pipe.hset(fail_queue, url, json.dumps(result_data)) - pipe.hdel(progress_queue, url) - pipe.execute() - - logger.info(f"Stored bannable error for URL '{url}' in '{result_queue}' and '{fail_queue}'.") except Exception as e: - logger.error(f"Could not report bannable error to Redis: {e}", exc_info=True) + logger.error(f"Failed to unlock profile or record activity for '{profile_name}': {e}", exc_info=True) # ============================================================================= @@ -1159,141 +709,62 @@ with DAG( schedule=None, start_date=days_ago(1), catchup=False, - tags=['ytdlp', 'worker'], + tags=['ytdlp', 'worker', 'v2'], doc_md=__doc__, render_template_as_native_obj=True, is_paused_upon_creation=True, params={ + # V2 Profile Params + 'redis_env': Param("sim_auth", type="string", title="[V2 Profiles] Redis Environment", description="The environment for v2 profile management (e.g., 'sim_auth'). 
Determines the Redis key prefix."), + 'profile_prefix': Param("auth_user", type="string", title="[V2 Profiles] Profile Prefix", description="The prefix for auth profiles that workers should attempt to lock."), + 'queue_name': Param(DEFAULT_QUEUE_NAME, type="string"), 'redis_conn_id': Param(DEFAULT_REDIS_CONN_ID, type="string"), 'service_ip': Param(DEFAULT_YT_AUTH_SERVICE_IP, type="string"), 'service_port': Param(DEFAULT_YT_AUTH_SERVICE_PORT, type="integer"), - 'account_pool': Param('default_account', type="string"), - 'account_pool_size': Param(None, type=["integer", "null"]), - 'prepend_client_to_account': Param(True, type="boolean", title="[Worker Param] Prepend Client to Account", description="If True, prepends client and timestamp to account names in prefix mode."), + # DEPRECATED PARAMS (kept for reference, but no longer used) + 'account_pool': Param('default_account', type="string", description="DEPRECATED: Use profile_prefix instead."), + 'account_pool_size': Param(None, type=["integer", "null"], description="DEPRECATED: Pool size is managed in Redis."), + 'prepend_client_to_account': Param(True, type="boolean", description="DEPRECATED"), + 'assigned_proxy_url': Param(None, type=["string", "null"], description="DEPRECATED: Proxy is now determined by the locked profile."), + 'account_id': Param(None, type=["string", "null"], description="DEPRECATED: Profile is locked dynamically."), + 'worker_index': Param(None, type=["integer", "null"], description="DEPRECATED"), + 'auto_create_new_accounts_on_exhaustion': Param(True, type="boolean", description="DEPRECATED"), + 'machine_id': Param(None, type=["string", "null"]), - 'assigned_proxy_url': Param(None, type=["string", "null"], title="[Worker Param] Assigned Proxy URL", description="If provided, forces the token service to use this specific proxy for the request."), - 'clients': Param('tv_simply', type="string", description="Comma-separated list of clients for token generation. e.g. mweb,tv,web_camoufox"), + 'clients': Param('tv_simply', type="string", description="DEPRECATED: This is now read from the ytdlp_config_json."), 'timeout': Param(DEFAULT_TIMEOUT, type="integer"), 'on_bannable_failure': Param('stop_loop_on_auth_proceed_on_download_error', type="string", enum=['stop_loop', 'retry_with_new_account', 'retry_without_ban', 'retry_and_ban_account_only', 'retry_on_connection_error', 'proceed_loop_under_manual_inspection', 'stop_loop_on_auth_proceed_on_download_error']), - 'request_params_json': Param(json.dumps(DEFAULT_REQUEST_PARAMS), type="string", title="[Worker Param] Request Params JSON", description="JSON string with request parameters for the token service."), - 'language_code': Param('en-US', type="string", title="[Worker Param] Language Code", description="The language code (e.g., 'en-US', 'de-DE') to use for the YouTube request headers."), - 'auto_create_new_accounts_on_exhaustion': Param(True, type="boolean"), + # --- Unified JSON Config (passed from orchestrator) --- + 'ytdlp_config_json': Param('{}', type="string", title="[Internal] Unified JSON config from orchestrator."), # --- Manual Run / Internal Parameters --- 'manual_url_to_process': Param('iPwdia3gAnk', type=["string", "null"], title="[Manual Run] URL to Process", description="For manual runs, provide a single YouTube URL, or the special value 'PULL_FROM_QUEUE' to pull one URL from the Redis inbox. 
This is ignored if triggered by the dispatcher."), 'url_to_process': Param(None, type=["string", "null"], title="[Internal] URL from Dispatcher", description="This parameter is set by the dispatcher DAG and should not be used for manual runs."), 'worker_queue': Param(None, type=["string", "null"], title="[Internal] Worker Queue", description="This parameter is set by the dispatcher DAG and should not be used for manual runs."), } ) as dag: - initial_data = get_url_and_assign_account() + initial_data = get_url_and_lock_profile() + unlock_profile_task = unlock_profile() # --- Task Instantiation with TaskGroups --- - # Main success/failure handlers (outside groups for clear end points) - fatal_error_task = handle_fatal_error() - report_failure_task = report_failure_and_continue() - continue_loop_task = continue_processing_loop() - unrecoverable_video_error_task = handle_unrecoverable_video_error() - report_bannable_and_continue_task = report_bannable_and_continue() - + # This is simplified. The auth worker does not retry with different accounts anymore, + # as the policy enforcer is responsible for managing profile health. If get_token fails, + # the profile is unlocked with a failure, and the loop continues to the next URL. + # --- Task Group 1: Initial Attempt --- - with TaskGroup("initial_attempt", tooltip="Initial token acquisition attempt") as initial_attempt_group: - first_token_attempt = get_token(initial_data) - initial_branch_task = handle_bannable_error_branch.override(trigger_rule='one_failed')( - task_id_to_check=first_token_attempt.operator.task_id - ) - - # Tasks for the "stop_loop" policy on initial attempt - ban_and_report_immediately_task = ban_and_report_immediately.override(task_id='ban_and_report_immediately')( - initial_data=initial_data, - reason="Banned by Airflow worker (policy is stop_loop)" - ) - - first_token_attempt >> initial_branch_task - initial_branch_task >> [fatal_error_task, ban_and_report_immediately_task, unrecoverable_video_error_task, report_bannable_and_continue_task] - - # --- Task Group 2: Retry Logic --- - with TaskGroup("retry_logic", tooltip="Retry logic with account management") as retry_logic_group: - # Retry path tasks - ban_and_retry_group = ban_and_retry_logic.override(group_id='ban_account_and_prepare_for_retry')( - initial_data=initial_data - ) - # This task is for retries after a ban check - after_ban_account_task = assign_new_account_after_ban_check.override(task_id='assign_new_account_after_ban_check')( - initial_data=initial_data - ) - # This task is for direct retries (e.g., on connection error) - direct_retry_account_task = assign_new_account_for_direct_retry.override(task_id='assign_new_account_for_direct_retry')( - initial_data=initial_data - ) - - @task(trigger_rule='one_success') - def coalesce_retry_data(direct_retry_data=None, after_ban_data=None): - """Coalesces account data from one of the two mutually exclusive retry paths.""" - if direct_retry_data: - return direct_retry_data - if after_ban_data: - return after_ban_data - raise AirflowException("Could not find valid account data for retry.") - - coalesced_retry_data = coalesce_retry_data( - direct_retry_data=direct_retry_account_task, - after_ban_data=after_ban_account_task - ) - - retry_token_task = get_token.override(task_id='retry_get_token')( - initial_data=coalesced_retry_data - ) - - # Retry failure branch and its tasks - retry_branch_task = handle_retry_failure_branch.override(trigger_rule='one_failed')( - task_id_to_check=retry_token_task.operator.task_id - ) - 
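For reference, a minimal sketch of the unified `ytdlp_config_json` payload the orchestrator is expected to pass. Only `download_format` and `download_profile_prefix` are read by `generate_and_push_download_tasks` above, and the deprecated `clients` param notes that client selection now comes from this config; any other keys, and the exact values, are assumptions.

    import json

    ytdlp_config_json = json.dumps({
        # Format selector the auth worker resolves into concrete format IDs
        "download_format": "bestvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best",
        # Prefix stamped onto each granular download task for the DL workers
        "download_profile_prefix": "download_user",
        # Client list for token generation (illustrative; full schema not shown here)
        "clients": "tv_simply",
    })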
ban_after_retry_report_task = ban_and_report_after_retry.override(task_id='ban_and_report_after_retry')( - retry_data=coalesced_retry_data, - reason="Banned by Airflow worker after failed retry" - ) - - # Internal dependencies within retry group - ban_and_retry_group >> after_ban_account_task - after_ban_account_task >> coalesced_retry_data - direct_retry_account_task >> coalesced_retry_data - coalesced_retry_data >> retry_token_task - retry_token_task >> retry_branch_task - retry_branch_task >> [fatal_error_task, report_failure_task, ban_after_retry_report_task, unrecoverable_video_error_task, report_bannable_and_continue_task] - ban_after_retry_report_task >> report_failure_task - - # --- Task Group 3: Success/Continuation Logic --- - with TaskGroup("success_and_continuation", tooltip="Push to DL queue and continue loop") as success_group: - token_data = coalesce_token_data( - get_token_result=first_token_attempt, - retry_get_token_result=retry_token_task - ) + with TaskGroup("auth_attempt", tooltip="Token acquisition attempt") as auth_attempt_group: + token_data = get_token(initial_data) list_formats_task = list_available_formats(token_data=token_data) - success_task = push_auth_success_to_redis( - initial_data=initial_data, - token_data=token_data - ) + generate_tasks = generate_and_push_download_tasks(token_data=token_data) - first_token_attempt >> token_data - retry_token_task >> token_data - token_data >> list_formats_task >> success_task - success_task >> continue_loop_task + token_data >> list_formats_task >> generate_tasks - # --- DAG Dependencies between TaskGroups --- - # Initial attempt can lead to retry logic or direct failure - initial_branch_task >> [retry_logic_group, fatal_error_task, ban_and_report_immediately_task, unrecoverable_video_error_task, report_bannable_and_continue_task] + # --- Failure Handling --- + # `unlock_profile` is the terminal task, running after all upstream tasks are done. + # It determines success/failure and records activity. - # A successful initial attempt bypasses retry and goes straight to the success group - initial_attempt_group >> success_group - - # Retry logic leads to success/continuation on success or failure reporting on failure - retry_branch_task >> [report_failure_task] # Handled within the group - retry_logic_group >> success_group - - # Ban and report immediately leads to failure reporting - ban_and_report_immediately_task >> report_failure_task - - # Unrecoverable/bannable errors that don't stop the loop should continue processing - unrecoverable_video_error_task >> continue_loop_task - report_bannable_and_continue_task >> continue_loop_task + # --- DAG Dependencies --- + initial_data >> auth_attempt_group + auth_attempt_group >> unlock_profile_task + unlock_profile_task >> continue_processing_loop(token_data=None) # Continue loop regardless of outcome diff --git a/airflow/dags/ytdlp_ops_v02_worker_per_url_dl.py b/airflow/dags/ytdlp_ops_v02_worker_per_url_dl.py index 29d12f5..d878e34 100644 --- a/airflow/dags/ytdlp_ops_v02_worker_per_url_dl.py +++ b/airflow/dags/ytdlp_ops_v02_worker_per_url_dl.py @@ -11,9 +11,14 @@ This is the "Download Worker" part of a separated Auth/Download pattern. It receives a job payload with all necessary token info and handles only the downloading and probing of media files. 
""" - from __future__ import annotations +# --- Add project root to path to allow for yt-ops-client imports --- +import sys +# The yt-ops-client package is installed in editable mode in /app +if '/app' not in sys.path: + sys.path.insert(0, '/app') + from airflow.decorators import task, task_group from airflow.exceptions import AirflowException, AirflowSkipException from airflow.models import Variable @@ -30,6 +35,9 @@ import json import logging import os import random + +# Configure logging +logger = logging.getLogger(__name__) import re import redis import socket @@ -47,8 +55,12 @@ from thrift.protocol import TBinaryProtocol from thrift.transport import TSocket, TTransport from thrift.transport.TTransport import TTransportException -# Configure logging -logger = logging.getLogger(__name__) +# ytops_client imports for v2 profile management +try: + from ytops_client.profile_manager_tool import ProfileManager, format_duration, format_timestamp +except ImportError as e: + logger.critical(f"Could not import ytops_client modules: {e}. Ensure yt-ops-client package is installed correctly in Airflow's environment.") + raise # --- Client Stats Helper --- @@ -149,79 +161,86 @@ def _extract_video_id(url): # ============================================================================= @task -def get_download_job_from_conf(**context): +def lock_profile_and_find_task(**context): """ - Gets the download job details (which includes token data) from the DAG run conf. - This is the first task in the download worker DAG. + Profile-first worker logic: + 1. Locks an available download profile from the Redis pool. + 2. Scans the granular download task queue for a job matching the profile's prefix. + 3. Returns both the locked profile and the claimed job data. """ params = context['params'] ti = context['task_instance'] - - # --- Worker Pinning Verification --- - # This is a safeguard against a known Airflow issue where clearing a task - # can cause the task_instance_mutation_hook to be skipped, breaking pinning. - # See: https://github.com/apache/airflow/issues/20143 - expected_queue = None - if ti.run_id and '_q_' in ti.run_id: - expected_queue = ti.run_id.split('_q_')[-1] + dag_run = context['dag_run'] - if not expected_queue: - # Fallback to conf if run_id parsing fails for some reason - expected_queue = params.get('worker_queue') - - if expected_queue and ti.queue != expected_queue: - error_msg = ( - f"WORKER PINNING FAILURE: Task is running on queue '{ti.queue}' but was expected on '{expected_queue}'. " - "This usually happens after manually clearing a task, which is not the recommended recovery method for this DAG. " - "To recover a failed URL, let the DAG run fail, use the 'ytdlp_mgmt_queues' DAG to requeue the URL, " - "and use the 'ytdlp_ops_orchestrator' to start a new worker loop if needed." - ) - logger.error(error_msg) - raise AirflowException(error_msg) - elif expected_queue: - logger.info(f"Worker pinning verified. Task is correctly running on queue '{ti.queue}'.") - # --- End Verification --- + redis_conn_id = params['redis_conn_id'] + redis_env = params['redis_env'] + profile_prefix = params['profile_prefix'] - # The job data is passed by the dispatcher DAG via 'job_data'. 
- job_data = params.get('job_data') - if not job_data: - raise AirflowException("No job_data provided in DAG run configuration.") - - # If job_data is a string, parse it as JSON - if isinstance(job_data, str): - try: - job_data = json.loads(job_data) - except json.JSONDecodeError: - raise AirflowException(f"Could not decode job_data JSON: {job_data}") - - url_to_process = job_data.get('original_url') - if not url_to_process: - raise AirflowException("'original_url' not found in job_data.") - - logger.info(f"Received job for URL '{url_to_process}'.") - - # Mark the URL as in-progress in Redis + # Initialize ProfileManager try: - redis_conn_id = params.get('redis_conn_id', DEFAULT_REDIS_CONN_ID) - queue_name = params.get('queue_name', DEFAULT_QUEUE_NAME) - progress_queue = f"{queue_name}_progress" - client = _get_redis_client(redis_conn_id) - - progress_data = { - 'status': 'in_progress', - 'start_time': time.time(), - 'dag_run_id': context['dag_run'].run_id, - 'hostname': socket.gethostname(), - } - client.hset(progress_queue, url_to_process, json.dumps(progress_data)) - logger.info(f"Marked URL '{url_to_process}' as in-progress.") + redis_hook = _get_redis_client(redis_conn_id, return_hook=True) + key_prefix = f"{redis_env}_profile_mgmt_" + pm = ProfileManager(redis_hook=redis_hook, key_prefix=key_prefix) + logger.info(f"Initialized ProfileManager for env '{redis_env}' (Redis key prefix: '{key_prefix}')") except Exception as e: - logger.error(f"Could not mark URL as in-progress in Redis: {e}", exc_info=True) + raise AirflowException(f"Failed to initialize ProfileManager: {e}") - return job_data + # Step 1: Lock a profile + owner_id = f"airflow_dl_worker_{dag_run.run_id}" + locked_profile = None + logger.info(f"Attempting to lock a profile with owner '{owner_id}' and prefix '{profile_prefix}'...") + + # This is a blocking loop until a profile is found or the task times out. + while not locked_profile: + locked_profile = pm.lock_profile(owner=owner_id, profile_prefix=profile_prefix) + if not locked_profile: + logger.info("No download profiles available to lock. Waiting for 15 seconds...") + time.sleep(15) + + logger.info(f"Successfully locked profile: {locked_profile['name']}") + + # Step 2: Find a matching task + task_queue = "queue_dl_format_tasks" + job_data = None + logger.info(f"Scanning Redis list '{task_queue}' for a matching task...") + + # This is a simple, non-atomic 'claim' logic suitable for Airflow's concurrency model. + # It's not perfectly race-proof but is a reasonable starting point. + redis_client = pm.redis + max_scan_attempts = 100 # To prevent infinite loops on a busy queue + + for i in range(max_scan_attempts): + task_json = redis_client.lpop(task_queue) + if not task_json: + logger.info("Task queue is empty. Waiting for 10 seconds...") + time.sleep(10) + continue + + try: + task_data = json.loads(task_json) + if task_data.get('profile_prefix') == profile_prefix: + job_data = task_data + logger.info(f"Claimed task for profile prefix '{profile_prefix}': {job_data}") + break + else: + # Not a match, push it back to the end of the queue and try again. + redis_client.rpush(task_queue, task_json) + except (json.JSONDecodeError, TypeError): + logger.error(f"Could not parse task from queue. Discarding item: {task_json}") + + if not job_data: + # If no task is found, unlock the profile and fail gracefully. 
+ pm.unlock_profile(locked_profile['name'], owner=owner_id) + raise AirflowSkipException(f"Could not find a matching task in '{task_queue}' for prefix '{profile_prefix}' after {max_scan_attempts} attempts.") + + # Combine profile and job data to pass to the next task + return { + 'locked_profile': locked_profile, + 'job_data': job_data, + } @task -def list_available_formats(token_data: dict, **context): +def list_available_formats(worker_data: dict, **context): """ Lists available formats for the given video using the info.json. This is for debugging and informational purposes. @@ -229,7 +248,7 @@ def list_available_formats(token_data: dict, **context): import subprocess import shlex - info_json_path = token_data.get('info_json_path') + info_json_path = worker_data['job_data'].get('info_json_path') if not (info_json_path and os.path.exists(info_json_path)): logger.warning(f"Cannot list formats: info.json path is missing or file does not exist ({info_json_path}).") return [] @@ -334,12 +353,55 @@ def _resolve_generic_selector(selector: str, info_json_path: str, logger) -> str return None -@task -def download_and_probe(token_data: dict, available_formats: list[str], **context): +def _check_format_expiry(info_json_path: str, formats_to_check: list[str], logger) -> bool: """ - Uses retrieved token data to download and probe media files. - Supports parallel downloading of specific, comma-separated format IDs. - If probing fails, retries downloading only the failed files. + Checks if any of the specified format URLs have expired using yt-ops-client. + Returns True if any format is expired, False otherwise. + """ + import subprocess + import shlex + + if not formats_to_check: + return False + + logger.info(f"Checking for URL expiry for formats: {formats_to_check}") + + # We can check all formats at once. The tool will report if any of them are expired. + try: + cmd = [ + 'ytops-client', 'check-expiry', + '--load-info-json', info_json_path, + '-f', ','.join(formats_to_check), + ] + + copy_paste_cmd = ' '.join(shlex.quote(arg) for arg in cmd) + logger.info(f"Executing expiry check for all selected formats: {copy_paste_cmd}") + + process = subprocess.run(cmd, capture_output=True, text=True, timeout=60) + + if process.stdout: + logger.info(f"ytops-client check-expiry STDOUT:\n{process.stdout}") + if process.stderr: + logger.info(f"ytops-client check-expiry STDERR:\n{process.stderr}") + + # The tool exits with a non-zero code if a URL is expired. + if process.returncode != 0: + logger.error("Expiry check failed. One or more URLs are likely expired.") + return True # An expiry was found + + except Exception as e: + logger.error(f"An error occurred during expiry check: {e}", exc_info=True) + # To be safe, treat this as a potential expiry to trigger re-authentication. + return True + + logger.info("No expired URLs found for the selected formats.") + return False + + +@task +def download_and_probe(worker_data: dict, **context): + """ + Uses profile and job data to download and probe a single media format. 
""" try: import subprocess @@ -347,23 +409,11 @@ def download_and_probe(token_data: dict, available_formats: list[str], **context import concurrent.futures params = context['params'] - info_json_path = token_data.get('info_json_path') - original_url = token_data.get('original_url') - - # Extract proxy from filename, with fallback to token_data for backward compatibility - proxy = None - if info_json_path: - filename = os.path.basename(info_json_path) - proxy_match = re.search(r'_proxy_(.+)\.json$', filename) - if proxy_match: - sanitized_proxy = proxy_match.group(1) - # Reverse sanitization from auth worker (replace '---' with '://') - proxy = sanitized_proxy.replace('---', '://') - logger.info(f"Extracted proxy '{proxy}' from filename.") + job_data = worker_data['job_data'] + locked_profile = worker_data['locked_profile'] - if not proxy: - logger.warning("Proxy not found in filename. Falling back to 'socks_proxy' from token_data.") - proxy = token_data.get('socks_proxy') + info_json_path = job_data.get('info_json_path') + proxy = locked_profile.get('proxy') if not (info_json_path and os.path.exists(info_json_path)): raise AirflowException(f"Error: info.json path is missing or file does not exist ({info_json_path}).") @@ -383,14 +433,11 @@ def download_and_probe(token_data: dict, available_formats: list[str], **context except Exception as e: logger.warning(f"Could not process/remove 'js_runtimes' from info.json: {e}", exc_info=True) - download_dir = token_data.get('job_dir_path') - if not download_dir: - # Fallback for older runs or if job_dir_path is missing - download_dir = os.path.dirname(info_json_path) + download_dir = os.path.dirname(info_json_path) - download_format = params.get('download_format') + download_format = job_data.get('format_id') if not download_format: - raise AirflowException("The 'download_format' parameter is missing or empty.") + raise AirflowException("The 'format_id' is missing from the job data.") output_template = params.get('output_path_template', "%(id)s.f%(format_id)s.%(ext)s") full_output_path = os.path.join(download_dir, output_template) @@ -408,17 +455,7 @@ def download_and_probe(token_data: dict, available_formats: list[str], **context # The 'py' tool maps many yt-dlp flags via --extra-ytdlp-args # The 'py' tool maps many yt-dlp flags via --extra-ytdlp-args - py_extra_args = ['--output', output_template, '--no-resize-buffer', '--buffer-size', '4M'] - if params.get('fragment_retries'): - py_extra_args.extend(['--fragment-retries', str(params['fragment_retries'])]) - if params.get('limit_rate'): - py_extra_args.extend(['--limit-rate', params['limit_rate']]) - if params.get('socket_timeout'): - py_extra_args.extend(['--socket-timeout', str(params['socket_timeout'])]) - if params.get('min_sleep_interval'): - py_extra_args.extend(['--sleep-interval', str(params['min_sleep_interval'])]) - if params.get('max_sleep_interval'): - py_extra_args.extend(['--max-sleep-interval', str(params['max_sleep_interval'])]) + py_extra_args = ['--output', output_template] if params.get('yt_dlp_test_mode'): py_extra_args.append('--test') @@ -468,17 +505,7 @@ def download_and_probe(token_data: dict, available_formats: list[str], **context cmd.extend(['--proxy', proxy]) # The 'cli' tool is the old yt-dlp wrapper, so it takes similar arguments. 
- cli_extra_args = ['--output', full_output_path, '--no-resize-buffer', '--buffer-size', '4M'] - if params.get('fragment_retries'): - cli_extra_args.extend(['--fragment-retries', str(params['fragment_retries'])]) - if params.get('limit_rate'): - cli_extra_args.extend(['--limit-rate', params['limit_rate']]) - if params.get('socket_timeout'): - cli_extra_args.extend(['--socket-timeout', str(params['socket_timeout'])]) - if params.get('min_sleep_interval'): - cli_extra_args.extend(['--sleep-interval', str(params['min_sleep_interval'])]) - if params.get('max_sleep_interval'): - cli_extra_args.extend(['--max-sleep-interval', str(params['max_sleep_interval'])]) + cli_extra_args = ['--output', full_output_path, '--verbose'] if params.get('yt_dlp_test_mode'): cli_extra_args.append('--test') @@ -600,79 +627,19 @@ def download_and_probe(token_data: dict, available_formats: list[str], **context return successful_probes, failed_probes # --- Main Execution Logic --- - with open(info_json_path, 'r', encoding='utf-8') as f: - info = json.load(f) - - # Split the format string by commas to get a list of individual format selectors. - # This enables parallel downloads of different formats or format groups. - # For example, '18,140,299/298' becomes ['18', '140', '299/298'], - # and each item will be downloaded in a separate yt-dlp process. - if download_format and isinstance(download_format, str): - formats_to_download_initial = [selector.strip() for selector in download_format.split(',') if selector.strip()] - else: - # Fallback for safety, though download_format should always be a string. - formats_to_download_initial = [] - - if not formats_to_download_initial: - raise AirflowException("No valid download format selectors were found after parsing.") - - # --- Filter and resolve requested formats --- - final_formats_to_download = [] - if not available_formats: - logger.warning("List of available formats is empty. Cannot validate numeric selectors, but will attempt to resolve generic selectors.") - - for selector in formats_to_download_initial: - # A selector is considered generic if it contains keywords like 'best' or filter brackets '[]'. - is_generic = bool(re.search(r'(best|\[|\])', selector)) - - if is_generic: - resolved_selector = _resolve_generic_selector(selector, info_json_path, logger) - if resolved_selector: - # The resolver returns a list for '+' selectors, or a string for others. - resolved_formats = resolved_selector if isinstance(resolved_selector, list) else [resolved_selector] - - for res_format in resolved_formats: - # Prefer -dashy version if available and the format is a simple numeric ID - if res_format.isdigit() and f"{res_format}-dashy" in available_formats: - final_format = f"{res_format}-dashy" - logger.info(f"Resolved format '{res_format}' from selector '{selector}'. Preferred '-dashy' version: '{final_format}'.") - else: - final_format = res_format - - # Validate the chosen format against available formats - if available_formats: - individual_ids = re.split(r'[/+]', final_format) - is_available = any(fid in available_formats for fid in individual_ids) - - if is_available: - final_formats_to_download.append(final_format) - else: - logger.warning(f"Resolved format '{final_format}' (from '{selector}') contains no available formats. Skipping.") - else: - # Cannot validate, so we trust the resolver's output. - final_formats_to_download.append(final_format) - else: - logger.warning(f"Could not resolve generic selector '{selector}' using yt-dlp. 
Skipping.") - else: - # This is a numeric-based selector (e.g., '140' or '299/298' or '140-dashy'). - # Validate it against the available formats. - if not available_formats: - logger.warning(f"Cannot validate numeric selector '{selector}' because available formats list is empty. Assuming it's valid.") - final_formats_to_download.append(selector) - continue - - individual_ids = re.split(r'[/+]', selector) - is_available = any(fid in available_formats for fid in individual_ids) - - if is_available: - final_formats_to_download.append(selector) - else: - logger.warning(f"Requested numeric format selector '{selector}' contains no available formats. Skipping.") + final_formats_to_download = download_format if not final_formats_to_download: - raise AirflowException("None of the requested formats are available for this video.") + raise AirflowException("The format_id for this job is empty.") + + # --- Check for expired URLs before attempting download --- + if _check_format_expiry(info_json_path, [final_formats_to_download], logger): + # If URL is expired, we need to fail the task so it can be re-queued for auth. + # We also need to record a failure for the profile. + raise AirflowException("Format URL has expired. The job must be re-authenticated.") # --- Initial Download and Probe --- + # The worker now handles one format at a time. successful_files, failed_files = _download_and_probe_formats(final_formats_to_download) if params.get('yt_dlp_test_mode'): @@ -690,9 +657,10 @@ def download_and_probe(token_data: dict, available_formats: list[str], **context logger.warning(f"Probe failed for {len(failed_files)} file(s). Attempting one re-download for failed files...") delay_between_formats = params.get('delay_between_formats_s', 0) - if delay_between_formats > 0: - logger.info(f"Waiting {delay_between_formats}s before re-download attempt...") - time.sleep(delay_between_formats) + # This delay is no longer needed in the profile-first model. + # if delay_between_formats > 0: + # logger.info(f"Waiting {delay_between_formats}s before re-download attempt...") + # time.sleep(delay_between_formats) format_ids_to_retry = [] # Since each download is now for a specific selector and the output template @@ -744,79 +712,9 @@ def download_and_probe(token_data: dict, available_formats: list[str], **context logger.error(f"Error during cleanup for file {f}: {e}", exc_info=True) # Do not fail the task for a cleanup error, just log it. - # --- Move completed job directory to final destination --- - try: - video_id = _extract_video_id(original_url) - if not video_id: - logger.error(f"Could not extract video_id from URL '{original_url}' for final move. Skipping.") - else: - # --- Rename info.json to a simple format before moving --- - path_to_info_json_for_move = info_json_path # Default to original path - try: - # info_json_path is the full path to the original info.json - if info_json_path and os.path.exists(info_json_path): - new_info_json_name = f"info_{video_id}.json" - new_info_json_path = os.path.join(os.path.dirname(info_json_path), new_info_json_name) - - if info_json_path != new_info_json_path: - logger.info(f"Renaming '{info_json_path}' to '{new_info_json_path}' for final delivery.") - os.rename(info_json_path, new_info_json_path) - path_to_info_json_for_move = new_info_json_path - else: - logger.info("info.json already has the simple name. 
No rename needed.") - else: - logger.warning("Could not find info.json to rename before moving.") - except Exception as rename_e: - logger.error(f"Failed to rename info.json before move: {rename_e}", exc_info=True) - # --- End of rename logic --- - - source_dir = download_dir # This is the job_dir_path - - # Group downloads into 10-minute batch folders based on completion time. - now = datetime.now() - rounded_minute = (now.minute // 10) * 10 - timestamp_str = now.strftime('%Y%m%dT%H') + f"{rounded_minute:02d}" - - final_dir_base = os.path.join(Variable.get('DOWNLOADS_TEMP', '/opt/airflow/downloadfiles'), 'videos', 'ready', timestamp_str) - final_dir_path = os.path.join(final_dir_base, video_id) - - os.makedirs(final_dir_base, exist_ok=True) - - logger.info(f"Moving completed job from '{source_dir}' to final destination '{final_dir_path}'") - if os.path.exists(final_dir_path): - logger.warning(f"Destination '{final_dir_path}' already exists. It will be removed and replaced.") - shutil.rmtree(final_dir_path) - - # Create the destination directory and move only the essential files, then clean up the source. - # This ensures no temporary or junk files are carried over. - os.makedirs(final_dir_path) - - # 1. Move the info.json file - if path_to_info_json_for_move and os.path.exists(path_to_info_json_for_move): - shutil.move(path_to_info_json_for_move, final_dir_path) - logger.info(f"Moved '{os.path.basename(path_to_info_json_for_move)}' to destination.") - - # 2. Move the media files (or their .empty placeholders) - files_to_move = [] - if params.get('yt_dlp_cleanup_mode', False): - files_to_move = [f"{f}.empty" for f in final_success_list] - else: - files_to_move = final_success_list - - for f in files_to_move: - if os.path.exists(f): - shutil.move(f, final_dir_path) - logger.info(f"Moved '{os.path.basename(f)}' to destination.") - else: - logger.warning(f"File '{f}' expected but not found for moving.") - - # 3. Clean up the original source directory - logger.info(f"Cleaning up original source directory '{source_dir}'") - shutil.rmtree(source_dir) - logger.info(f"Successfully moved job to '{final_dir_path}' and cleaned up source.") - except Exception as e: - logger.error(f"Failed to move completed job directory: {e}", exc_info=True) - # Do not fail the task for a move error, just log it. + # The logic for moving files to a final destination is now handled by the `ytops-client download py` tool + # when `output_to_airflow_ready_dir` is used. This worker no longer needs to perform the move. + # It just needs to return the list of successfully downloaded files. 
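Since the final move to the Airflow-ready directory now happens inside the `ytops-client download py` tool, the destination convention is no longer visible in this DAG. Based on the move logic removed above, the layout downstream consumers presumably still expect is sketched here (assuming the tool reproduces the same 10-minute batching; the video id is illustrative):

    import os
    from datetime import datetime

    downloads_temp = '/opt/airflow/downloadfiles'  # old DOWNLOADS_TEMP fallback
    video_id = 'iPwdia3gAnk'                       # illustrative

    now = datetime.now()
    rounded_minute = (now.minute // 10) * 10
    timestamp_str = now.strftime('%Y%m%dT%H') + f"{rounded_minute:02d}"

    final_dir_path = os.path.join(downloads_temp, 'videos', 'ready', timestamp_str, video_id)
    # e.g. /opt/airflow/downloadfiles/videos/ready/20240101T0010/iPwdia3gAnk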
return final_success_list except Exception as e: @@ -834,7 +732,8 @@ def download_and_probe(token_data: dict, available_formats: list[str], **context "PRIVATE_VIDEO": ['private video'], "VIDEO_REMOVED": ['video has been removed'], "VIDEO_UNAVAILABLE": ['video unavailable'], - "HTTP_403_FORBIDDEN": ['http error 403: forbidden'] + "HTTP_403_FORBIDDEN": ['http error 403: forbidden'], + "URL_EXPIRED": ['urls have expired'] } for code, patterns in unrecoverable_patterns.items(): @@ -846,54 +745,56 @@ def download_and_probe(token_data: dict, available_formats: list[str], **context ti.xcom_push(key='download_error_details', value=error_details) raise e -@task -def mark_url_as_success(job_data: dict, downloaded_file_paths: list, **context): - """Records the successful download result in Redis.""" +@task(trigger_rule='all_done') +def unlock_profile(worker_data: dict, **context): + """ + Unlocks the profile and records activity (success or failure). + This task runs regardless of upstream success or failure. + """ params = context['params'] - url = job_data['original_url'] - result_data = { - 'status': 'success', 'end_time': time.time(), 'url': url, - 'downloaded_file_paths': downloaded_file_paths, **job_data, - 'dag_run_id': context['dag_run'].run_id, - } - client = _get_redis_client(params['redis_conn_id']) - - # Update activity counters - try: - proxy_url = job_data.get('socks_proxy') - account_id = job_data.get('successful_account_id') - now = time.time() - # Use a unique member to prevent collisions, e.g., dag_run_id - member = context['dag_run'].run_id - - if proxy_url: - proxy_key = f"activity:per_proxy:{proxy_url}" - client.zadd(proxy_key, {member: now}) - client.expire(proxy_key, 3600 * 2) # Expire after 2 hours - if account_id: - account_key = f"activity:per_account:{account_id}" - client.zadd(account_key, {member: now}) - client.expire(account_key, 3600 * 2) # Expire after 2 hours - except Exception as e: - logger.error(f"Could not update activity counters: {e}", exc_info=True) + dag_run = context['dag_run'] - # Update client-specific stats - try: - machine_id = params.get('machine_id') or socket.gethostname() - clients_str = job_data.get('clients', params.get('clients', '')) # Prefer clients from job, fallback to params - _update_client_stats(client, clients_str, 'success', url, machine_id, context['dag_run'].run_id) - except Exception as e: - logger.error(f"Could not update client stats on success: {e}", exc_info=True) - - progress_queue = f"{params['queue_name']}_progress" - result_queue = f"{params['queue_name']}_result" + # Check if the DAG run failed + failed_tasks = [ti for ti in dag_run.get_task_instances() if ti.state == 'failed'] + is_success = not failed_tasks + activity_type = 'download' if is_success else 'download_error' - with client.pipeline() as pipe: - pipe.hset(result_queue, url, json.dumps(result_data)) - pipe.hdel(progress_queue, url) - pipe.execute() + # Use XCom pull to get the data from the initial task, which is more robust + # in case of upstream failures where the data is not passed directly. + ti = context['task_instance'] + worker_data_pulled = ti.xcom_pull(task_ids='lock_profile_and_find_task') + + locked_profile = worker_data_pulled.get('locked_profile') if worker_data_pulled else None + + if not locked_profile: + logger.warning("No locked_profile data found from 'lock_profile_and_find_task'. 
Cannot unlock or record activity.") + return + + profile_name = locked_profile.get('name') + owner_id = f"airflow_dl_worker_{dag_run.run_id}" + + try: + redis_conn_id = params['redis_conn_id'] + redis_env = params['redis_env'] + redis_hook = _get_redis_client(redis_conn_id, return_hook=True) + key_prefix = f"{redis_env}_profile_mgmt_" + pm = ProfileManager(redis_hook=redis_hook, key_prefix=key_prefix) - logger.info(f"Stored success result for URL '{url}' and removed from progress queue.") + logger.info(f"Recording activity '{activity_type}' for profile '{profile_name}'.") + pm.record_activity(profile_name, activity_type) + + logger.info(f"Unlocking profile '{profile_name}' with owner '{owner_id}'.") + # Read cooldown from config if available + cooldown_str = pm.get_config('unlock_cooldown_seconds') + cooldown = int(cooldown_str) if cooldown_str and cooldown_str.isdigit() else None + + pm.unlock_profile(profile_name, owner=owner_id, rest_for_seconds=cooldown) + if cooldown: + logger.info(f"Profile '{profile_name}' was put into COOLDOWN for {cooldown} seconds.") + + except Exception as e: + logger.error(f"Failed to unlock profile or record activity for '{profile_name}': {e}", exc_info=True) + # Do not fail the task, as this is a cleanup step. @task(trigger_rule='one_failed') def report_failure_and_continue(**context): @@ -1121,8 +1022,7 @@ def handle_download_failure_branch(**context): error_code = download_error_details.get('error_code') unrecoverable_video_errors = [ "AGE_GATED_SIGN_IN", "MEMBERS_ONLY", "VIDEO_PROCESSING", "COPYRIGHT_REMOVAL", - "GEO_RESTRICTED", "VIDEO_UNAVAILABLE", "PRIVATE_VIDEO", "VIDEO_REMOVED", - "HTTP_403_FORBIDDEN" + "GEO_RESTRICTED", "VIDEO_UNAVAILABLE", "PRIVATE_VIDEO", "VIDEO_REMOVED" ] if error_code in unrecoverable_video_errors: logger.warning(f"Unrecoverable video error '{error_code}' during download. Skipping.") @@ -1143,34 +1043,28 @@ with DAG( schedule=None, start_date=days_ago(1), catchup=False, - tags=['ytdlp', 'worker'], + tags=['ytdlp', 'worker', 'v2'], doc_md=__doc__, render_template_as_native_obj=True, is_paused_upon_creation=True, params={ - 'queue_name': Param(DEFAULT_QUEUE_NAME, type="string"), + # --- V2 Profile Management Parameters --- + 'redis_env': Param("sim_download", type="string", title="[V2 Profiles] Redis Environment", description="The environment for v2 profile management (e.g., 'sim_download'). Determines the Redis key prefix."), + 'profile_prefix': Param("download_user", type="string", title="[V2 Profiles] Profile Prefix", description="The prefix for download profiles that workers should attempt to lock."), + 'redis_conn_id': Param(DEFAULT_REDIS_CONN_ID, type="string"), 'machine_id': Param(None, type=["string", "null"]), - 'clients': Param('mweb,web_camoufox,tv', type="string", description="Comma-separated list of clients for token generation. e.g. mweb,tv,web_camoufox"), + 'clients': Param('tv_simply', type="string", description="Comma-separated list of clients for token generation. e.g. mweb,tv,web_camoufox"), 'output_path_template': Param("%(id)s.f%(format_id)s.%(ext)s", type="string", title="[Worker Param] Output Path Template", description="Output filename template for yt-dlp. 
It is highly recommended to include `%(format_id)s` to prevent filename collisions when downloading multiple formats."), 'retry_on_probe_failure': Param(False, type="boolean"), - 'skip_probe': Param(False, type="boolean", title="[Worker Param] Skip Probe", description="If True, skips the ffmpeg probe of downloaded files."), + 'skip_probe': Param(True, type="boolean", title="[Worker Param] Skip Probe", description="If True, skips the ffmpeg probe of downloaded files."), 'yt_dlp_cleanup_mode': Param(False, type="boolean", title="[Worker Param] yt-dlp Cleanup Mode", description="If True, creates a .empty file and deletes the original media file after successful download and probe."), - 'delay_between_formats_s': Param(15, type="integer", title="[Worker Param] Delay Between Formats (s)", description="Delay in seconds between downloading each format when multiple formats are specified. A 22s wait may be effective for batch downloads, while 6-12s may suffice if cookies are refreshed regularly."), + 'delay_between_formats_s': Param(0, type="integer", title="[Worker Param] Delay Between Formats (s)", description="No longer used in profile-first model, as each format is a separate task."), 'yt_dlp_test_mode': Param(False, type="boolean", title="[Worker Param] yt-dlp Test Mode", description="If True, runs yt-dlp with --test flag (dry run without downloading)."), - 'fragment_retries': Param(2, type="integer", title="[Worker Param] Fragment Retries", description="Number of retries for a fragment before giving up. Default is 2 to fail fast on expired tokens."), - 'limit_rate': Param('5M', type=["string", "null"], title="[Worker Param] Limit Rate", description="Download speed limit (e.g., 50K, 4.2M)."), - 'socket_timeout': Param(15, type="integer", title="[Worker Param] Socket Timeout", description="Timeout in seconds for socket operations."), - 'min_sleep_interval': Param(5, type="integer", title="[Worker Param] Min Sleep Interval", description="Minimum time to sleep between downloads (seconds)."), - 'max_sleep_interval': Param(10, type="integer", title="[Worker Param] Max Sleep Interval", description="Maximum time to sleep between downloads (seconds)."), - 'download_format': Param( - 'bestvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best', - type="string", - title="[Worker Param] Download Format", - description="Custom yt-dlp format string. Common presets: [1] 'bestvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best' (Default, best quality MP4). [2] '18-dashy/18,140-dashy/140,133-dashy/134-dashy/136-dashy/137-dashy/250-dashy/298-dashy/299-dashy' (Legacy formats). [3] '299-dashy/298-dashy/250-dashy/137-dashy/136-dashy/135-dashy/134-dashy/133-dashy' (High-framerate formats)." 
- ), + 'download_format': Param(None, type=["string", "null"], title="[DEPRECATED] Download Format", description="This is now specified in the granular task generated by the auth worker."), + 'pass_without_formats_splitting': Param(False, type="boolean", title="[DEPRECATED] Pass format string without splitting"), 'downloader': Param( - 'cli', + 'py', type="string", enum=['py', 'aria-rpc', 'cli'], title="Download Tool", @@ -1180,52 +1074,37 @@ with DAG( 'aria_port': Param(6800, type="integer", title="Aria2c Port", description="For 'aria-rpc' downloader: Port of the aria2c RPC server."), 'aria_secret': Param('SQGCQPLVFQIASMPNPOJYLVGJYLMIDIXDXAIXOTX', type="string", title="Aria2c Secret", description="For 'aria-rpc' downloader: Secret token."), 'yt_dlp_extra_args': Param( - '--no-part --restrict-filenames', + '--verbose --no-resize-buffer --buffer-size 4M --fragment-retries 2 --concurrent-fragments 8 --socket-timeout 15 --sleep-interval 5 --max-sleep-interval 10 --no-part --restrict-filenames', type=["string", "null"], title="Extra yt-dlp arguments", description="Extra command-line arguments for yt-dlp during download." ), # --- Manual Run / Internal Parameters --- - 'job_data': Param(None, type=["object", "string", "null"], title="[Internal] Job Data from Dispatcher", description="This parameter is set by the dispatcher DAG and should not be used for manual runs."), - 'worker_queue': Param(None, type=["string", "null"], title="[Internal] Worker Queue", description="This parameter is set by the dispatcher DAG and should not be used for manual runs."), + 'manual_job_input': Param(None, type=["string", "null"], title="[DEPRECATED] Job Input"), + 'job_data': Param(None, type=["object", "string", "null"], title="[Internal] Job Data from Dispatcher", description="This is no longer used. The worker finds its own job."), + 'worker_queue': Param(None, type=["string", "null"], title="[Internal] Worker Queue", description="This parameter is set by the dispatcher DAG."), } ) as dag: - job_data = get_download_job_from_conf() - - # --- Task Instantiation --- + # --- Task Instantiation for Profile-First Model --- - # Main success/failure handlers - fatal_error_task = handle_fatal_error() - report_failure_task = report_failure_and_continue() - continue_loop_task = continue_processing_loop() - unrecoverable_video_error_task = handle_unrecoverable_video_error() + # 1. Start by locking a profile and finding a task. + worker_data = lock_profile_and_find_task() - # --- Download and Processing Group --- + # 2. Define the download processing group. with TaskGroup("download_processing", tooltip="Download and media processing") as download_processing_group: - list_formats_task = list_available_formats(token_data=job_data) - download_task = download_and_probe( - token_data=job_data, - available_formats=list_formats_task, - ) - download_branch_task = handle_download_failure_branch.override(trigger_rule='one_failed')() - success_task = mark_url_as_success( - job_data=job_data, - downloaded_file_paths=download_task, - ) - + list_formats_task = list_available_formats(worker_data=worker_data) + download_task = download_and_probe(worker_data=worker_data) list_formats_task >> download_task - download_task >> download_branch_task - download_branch_task >> [report_failure_task, unrecoverable_video_error_task] - download_task >> success_task - success_task >> continue_loop_task - # If the initial job setup succeeds, proceed to the download group. - # If it fails, trigger the fatal error handler. 
This prevents fatal_error_task - # from being an "island" task that gets triggered by any other failure in the DAG. - job_data.operator >> download_processing_group - job_data.operator >> fatal_error_task + # 3. Define the final cleanup and loop continuation tasks. + unlock_profile_task = unlock_profile(worker_data=worker_data) + continue_loop_task = continue_processing_loop() - # Any failure or skip path should continue the loop to process the next URL. - report_failure_task >> continue_loop_task - fatal_error_task >> continue_loop_task - unrecoverable_video_error_task >> continue_loop_task + # --- DAG Dependencies --- + # Start -> Download Group -> Unlock -> Continue Loop + worker_data >> download_processing_group + download_processing_group >> unlock_profile_task + + # The loop continues regardless of whether the download succeeded or failed. + # The unlock_profile task (with trigger_rule='all_done') ensures it always runs. + unlock_profile_task >> continue_loop_task diff --git a/airflow/dags/ytdlp_s3_uploader.py b/airflow/dags/ytdlp_s3_uploader.py index ebbc637..8c3fb63 100644 --- a/airflow/dags/ytdlp_s3_uploader.py +++ b/airflow/dags/ytdlp_s3_uploader.py @@ -278,8 +278,9 @@ def run_s3_upload_batch(**context): try: for batch_dir_path in processed_batch_dirs: try: - # Use rsync with an empty source to efficiently delete the contents of the batch directory - # The trailing slash on both source and destination is important. + # Use rsync with an empty source to efficiently delete the contents of the batch directory. + # This is a performant alternative to `shutil.rmtree`, which can be slow with many small files. + # The trailing slash on both source and destination is important for rsync's behavior. rsync_cmd = [ 'rsync', '-a', '--delete', @@ -287,14 +288,21 @@ def run_s3_upload_batch(**context): f'{batch_dir_path}/' ] subprocess.run(rsync_cmd, check=True, capture_output=True, text=True) - + # After the contents are deleted, remove the now-empty directory os.rmdir(batch_dir_path) logger.info(f"Successfully removed {batch_dir_path}") except Exception as cleanup_e: - logger.error(f"Failed to remove directory {batch_dir_path}: {cleanup_e}", exc_info=True) - if isinstance(cleanup_e, subprocess.CalledProcessError): - logger.error(f"rsync STDERR: {cleanup_e.stderr}") + if isinstance(cleanup_e, OSError) and "Directory not empty" in str(cleanup_e): + # This can happen in a race condition where a download worker adds a new video + # to the batch directory after rsync has emptied it but before rmdir runs. + # We log it as a warning; the directory will be re-processed in the next cycle + # because this task rescans all directories on each run. + logger.warning(f"Could not remove directory {batch_dir_path}, it was not empty: {cleanup_e}. It will be re-processed on the next run.") + else: + logger.error(f"Failed to remove directory {batch_dir_path}: {cleanup_e}", exc_info=True) + if isinstance(cleanup_e, subprocess.CalledProcessError): + logger.error(f"rsync STDERR: {cleanup_e.stderr}") finally: # Clean up the temporary empty directory shutil.rmtree(empty_dir_for_rsync) diff --git a/ansible/README.md b/ansible/README.md index 343c56c..8a00cae 100644 --- a/ansible/README.md +++ b/ansible/README.md @@ -37,3 +37,10 @@ These playbooks are used for more specific tasks or are called by the main playb - `playbook-dl.yml`: Older worker deployment logic. Superseded by `playbook-worker.yml`. - `playbook-depricated.dl.yml`: Older worker deployment logic. Superseded by `playbook-worker.yml`. 
+## Current Goal: Disable Camoufox & Enable Aria2 + +The current objective is to modify the worker deployment (`playbook-worker.yml` and its role `roles/ytdlp-worker/tasks/main.yml`) to: +1. **Disable Camoufox**: Prevent the build, configuration generation, and startup of all `camoufox` services. +2. **Enable Aria2**: Ensure the `aria2-pro` service is built and started correctly on worker nodes. + +The `playbook-worker.yml` has already been updated to build the `aria2-pro` image. The next steps will involve modifying `roles/ytdlp-worker/tasks/main.yml` to remove the Camoufox-related tasks. diff --git a/ansible/playbook-install-local.yml b/ansible/playbook-install-local.yml new file mode 100644 index 0000000..87b4b62 --- /dev/null +++ b/ansible/playbook-install-local.yml @@ -0,0 +1,44 @@ +--- +- name: Install Local Development Packages + hosts: airflow_workers, airflow_master + gather_facts: no + vars_files: + - "{{ inventory_dir }}/group_vars/all/generated_vars.yml" + + tasks: + - name: Ensure python3-pip is installed + ansible.builtin.apt: + name: python3-pip + state: present + update_cache: yes + become: yes + + - name: Upgrade pip to the latest version (for systems without PEP 668) + ansible.builtin.command: python3 -m pip install --upgrade pip + register: pip_upgrade_old_systems + changed_when: "'Requirement already satisfied' not in pip_upgrade_old_systems.stdout" + failed_when: false # This task will fail on newer systems, which is expected. + become: yes + become_user: "{{ ansible_user }}" + + - name: Upgrade pip to the latest version (for systems with PEP 668) + ansible.builtin.command: python3 -m pip install --upgrade pip --break-system-packages + when: pip_upgrade_old_systems.rc != 0 and 'externally-managed-environment' in pip_upgrade_old_systems.stderr + changed_when: "'Requirement already satisfied' not in pip_upgrade_new_systems.stdout" + register: pip_upgrade_new_systems + become: yes + become_user: "{{ ansible_user }}" + + - name: Install or upgrade yt-dlp to the latest nightly version + ansible.builtin.command: python3 -m pip install -U --pre "yt-dlp[default]" --break-system-packages + register: ytdlp_install + changed_when: "'Requirement already satisfied' not in ytdlp_install.stdout" + become: yes + become_user: "{{ ansible_user }}" + + - name: Install requests library + ansible.builtin.command: python3 -m pip install requests==2.31.0 --break-system-packages + register: requests_install + changed_when: "'Requirement already satisfied' not in requests_install.stdout" + become: yes + become_user: "{{ ansible_user }}" diff --git a/ansible/playbook-sync-local.yml b/ansible/playbook-sync-local.yml index 472a4d2..7b15bfe 100644 --- a/ansible/playbook-sync-local.yml +++ b/ansible/playbook-sync-local.yml @@ -1,41 +1,22 @@ --- -- name: Sync Local Development Files to Workers - hosts: airflow_workers +- name: Sync Local Development Files to Workers and Master + hosts: airflow_workers, airflow_master gather_facts: no vars_files: - "{{ inventory_dir }}/group_vars/all/generated_vars.yml" + vars: + sync_dir: "{{ airflow_worker_dir if 'airflow_workers' in group_names else airflow_master_dir }}" pre_tasks: - name: Announce local sync debug: - msg: "Syncing local dev files to {{ inventory_hostname }} at {{ airflow_worker_dir }}" + msg: "Syncing local dev files to {{ inventory_hostname }} at {{ sync_dir }}" tasks: - - name: Ensure python3-pip is installed - ansible.builtin.apt: - name: python3-pip - state: present - update_cache: yes - become: yes - - - name: Check if yt-dlp is installed - 
ansible.builtin.command: which yt-dlp - register: ytdlp_check - changed_when: false - failed_when: false - become: yes - become_user: "{{ ansible_user }}" - - - name: Install yt-dlp if not found - ansible.builtin.command: python3 -m pip install -U "yt-dlp[default]" --break-system-packages - when: ytdlp_check.rc != 0 - become: yes - become_user: "{{ ansible_user }}" - - - name: Sync thrift_model directory to workers + - name: Sync thrift_model directory ansible.posix.synchronize: src: ../thrift_model/ - dest: "{{ airflow_worker_dir }}/thrift_model/" + dest: "{{ sync_dir }}/thrift_model/" rsync_opts: - "--delete" - "--exclude=.DS_Store" @@ -46,10 +27,10 @@ become: yes become_user: "{{ ansible_user }}" - - name: Sync pangramia package to workers + - name: Sync pangramia package ansible.posix.synchronize: src: ../pangramia/ - dest: "{{ airflow_worker_dir }}/pangramia/" + dest: "{{ sync_dir }}/pangramia/" rsync_opts: - "--delete" - "--exclude=.DS_Store" @@ -60,10 +41,10 @@ become: yes become_user: "{{ ansible_user }}" - - name: Sync ytops_client directory to workers + - name: Sync ytops_client directory ansible.posix.synchronize: src: ../ytops_client/ - dest: "{{ airflow_worker_dir }}/ytops_client/" + dest: "{{ sync_dir }}/ytops_client/" rsync_opts: - "--delete" - "--exclude=.DS_Store" @@ -74,10 +55,10 @@ become: yes become_user: "{{ ansible_user }}" - - name: Sync policies directory to workers + - name: Sync policies directory ansible.posix.synchronize: src: ../policies/ - dest: "{{ airflow_worker_dir }}/policies/" + dest: "{{ sync_dir }}/policies/" rsync_opts: - "--delete" - "--exclude=.DS_Store" @@ -88,22 +69,33 @@ become: yes become_user: "{{ ansible_user }}" - - name: Ensure bin directory exists on workers for client utilities + - name: Sync ytdlp.json + ansible.posix.synchronize: + src: ../ytdlp.json + dest: "{{ sync_dir }}/ytdlp.json" + perms: yes + become: yes + become_user: "{{ ansible_user }}" + + - name: Ensure bin directory exists for client utilities ansible.builtin.file: - path: "{{ airflow_worker_dir }}/bin" + path: "{{ sync_dir }}/bin" state: directory mode: '0755' become: yes become_user: "{{ ansible_user }}" - - name: Sync client utility scripts to workers + - name: Sync client utility scripts ansible.posix.synchronize: src: "../{{ item }}" - dest: "{{ airflow_worker_dir }}/{{ item }}" + dest: "{{ sync_dir }}/{{ item }}" perms: yes loop: - "cli.config" - "package_client.py" + - "setup.py" - "bin/ytops-client" + - "bin/build-yt-dlp-image" + - "VERSION.client" become: yes become_user: "{{ ansible_user }}" diff --git a/ansible/playbook-worker.yml b/ansible/playbook-worker.yml index c9cc841..cb26110 100644 --- a/ansible/playbook-worker.yml +++ b/ansible/playbook-worker.yml @@ -282,6 +282,120 @@ become: yes become_user: "{{ ansible_user }}" + - name: Install base system packages for tools + ansible.builtin.apt: + name: + - unzip + - wget + - xz-utils + state: present + update_cache: yes + become: yes + + - name: Install required Python packages + ansible.builtin.pip: + name: + - python-dotenv + - aria2p + - tabulate + - redis + - PyYAML + - aiothrift + - PySocks + state: present + extra_args: --break-system-packages + become: yes + + - name: Install pinned Python packages + ansible.builtin.pip: + name: + - brotli==1.1.0 + - certifi==2025.10.05 + - curl-cffi==0.13.0 + - mutagen==1.47.0 + - pycryptodomex==3.23.0 + - secretstorage==3.4.0 + - urllib3==2.5.0 + - websockets==15.0.1 + state: present + extra_args: --break-system-packages + become: yes + + - name: Upgrade yt-dlp and bgutil 
provider + ansible.builtin.shell: | + set -e + python3 -m pip install -U --pre "yt-dlp[default,curl-cffi]" --break-system-packages + python3 -m pip install --no-cache-dir -U bgutil-ytdlp-pot-provider --break-system-packages + args: + warn: false + become: yes + changed_when: true + + - name: Check for FFmpeg + stat: + path: /usr/local/bin/ffmpeg + register: ffmpeg_binary + become: yes + + - name: Install FFmpeg + when: not ffmpeg_binary.stat.exists + become: yes + block: + - name: Create ffmpeg directory + ansible.builtin.file: + path: /opt/ffmpeg + state: directory + mode: '0755' + + - name: Download and unarchive FFmpeg + ansible.builtin.unarchive: + src: "https://github.com/yt-dlp/FFmpeg-Builds/releases/download/latest/ffmpeg-master-latest-linux64-gpl.tar.xz" + dest: /opt/ffmpeg + remote_src: yes + extra_opts: [--strip-components=1] + + - name: Symlink ffmpeg and ffprobe + ansible.builtin.file: + src: "/opt/ffmpeg/bin/{{ item }}" + dest: "/usr/local/bin/{{ item }}" + state: link + force: yes + loop: + - ffmpeg + - ffprobe + + - name: Check for Deno + stat: + path: /usr/local/bin/deno + register: deno_binary + become: yes + + - name: Install Deno + when: not deno_binary.stat.exists + become: yes + block: + - name: Download and unarchive Deno + ansible.builtin.unarchive: + src: https://github.com/denoland/deno/releases/latest/download/deno-x86_64-unknown-linux-gnu.zip + dest: /usr/local/bin/ + remote_src: yes + mode: '0755' + + - name: Check if ytops_client requirements.txt exists + stat: + path: "{{ airflow_worker_dir }}/ytops_client/requirements.txt" + register: ytops_client_reqs + become: yes + become_user: "{{ ansible_user }}" + + - name: Install dependencies from ytops_client/requirements.txt + ansible.builtin.pip: + requirements: "{{ airflow_worker_dir }}/ytops_client/requirements.txt" + state: present + extra_args: --break-system-packages + when: ytops_client_reqs.stat.exists + become: yes + # Include Docker health check - name: Include Docker health check tasks include_tasks: tasks/docker_health_check.yml diff --git a/bin/build-yt-dlp-image b/bin/build-yt-dlp-image new file mode 100755 index 0000000..f3f6858 --- /dev/null +++ b/bin/build-yt-dlp-image @@ -0,0 +1,35 @@ +#!/usr/bin/env bash +# Script to build and tag the yt-dlp Docker image. + +set -e + +SCRIPT_DIR=$(dirname "$(realpath "$0")") +PROJECT_ROOT=$(realpath "$SCRIPT_DIR/..") +DOCKERFILE_DIR="$PROJECT_ROOT/ytops_client/youtube-dl" +IMAGE_NAME=${1:-"ytops/yt-dlp"} + +# The default version is 'latest'. If a release version file exists, use that for tagging. +VERSION="latest" +VERSION_FILE="$DOCKERFILE_DIR/release-versions/latest.txt" + +if [ -f "$VERSION_FILE" ]; then + VERSION=$(cat "$VERSION_FILE") + echo "Found version: $VERSION from $VERSION_FILE" +fi + +echo "Building Docker image: $IMAGE_NAME:$VERSION" +echo "Dockerfile location: $DOCKERFILE_DIR" + +docker build -t "$IMAGE_NAME:$VERSION" "$DOCKERFILE_DIR" + +if [ "$VERSION" != "latest" ]; then + echo "Also tagging as: $IMAGE_NAME:latest" + docker tag "$IMAGE_NAME:$VERSION" "$IMAGE_NAME:latest" +fi + +echo "Build complete." +echo "Image tags created:" +echo " - $IMAGE_NAME:$VERSION" +if [ "$VERSION" != "latest" ]; then + echo " - $IMAGE_NAME:latest" +fi diff --git a/bin/ytops-client b/bin/ytops-client index 46138aa..513fd14 100755 --- a/bin/ytops-client +++ b/bin/ytops-client @@ -1,10 +1,14 @@ -#!/bin/sh -set -e -# Find the directory where this script is located. -SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" -# Go up one level to the project root. 
-PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)" -# Set PYTHONPATH to include the project root, so we can import 'ytops_client' -export PYTHONPATH="$PROJECT_ROOT${PYTHONPATH:+:$PYTHONPATH}" -# Execute the Python CLI script as a module to handle relative imports -exec python3 -m ytops_client.cli "$@" +#!/usr/bin/env python3 +import os +import sys + +# Ensure the project root is in the Python path +SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__)) +PROJECT_ROOT = os.path.abspath(os.path.join(SCRIPT_DIR, '..')) +if PROJECT_ROOT not in sys.path: + sys.path.insert(0, PROJECT_ROOT) + +from ytops_client.cli import main + +if __name__ == "__main__": + sys.exit(main()) diff --git a/cli.auth.config b/cli.auth.config new file mode 100644 index 0000000..7bb7f64 --- /dev/null +++ b/cli.auth.config @@ -0,0 +1,17 @@ +# This is a yt-dlp configuration file. +# It contains one command-line option per line. + +#--no-progress +--format-sort "res,ext:mp4:m4a" +--recode-video mp4 +--no-playlist +--no-overwrites +--continue +--output "%(extractor)s - %(title)s.%(ext)s" +--no-mtime +--verbose +#--simulate +# Performance options +#--no-resize-buffer +#--buffer-size 4M +#--concurrent-fragments 8 diff --git a/cli.config b/cli.config index 2703daa..7bb7f64 100644 --- a/cli.config +++ b/cli.config @@ -1,41 +1,17 @@ -# yt-dlp configuration for format_download.py +# This is a yt-dlp configuration file. +# It contains one command-line option per line. -# Continue on broken downloads -#--continue - -# Do not simulate ---no-simulate - -# Do not write info.json file (we already have it) ---no-write-info-json - -# Continue on download errors ---ignore-errors - -# Do not download playlist +#--no-progress +--format-sort "res,ext:mp4:m4a" +--recode-video mp4 --no-playlist - -# Retry fragments 10 times ---fragment-retries 10 - -# Use a fixed buffer size to stabilize throughput and avoid traffic shaping ---no-resize-buffer ---buffer-size 4M - -# Socket timeout ---socket-timeout 15 - -# Sleep interval ---min-sleep-interval 5 ---max-sleep-interval 10 - -# Progress ---progress - -# Merge to mp4 by default ---merge-output-format mp4 - -# Don't use "NA" in filenames if metadata is missing ---output-na-placeholder "" - ---no-part +--no-overwrites +--continue +--output "%(extractor)s - %(title)s.%(ext)s" +--no-mtime +--verbose +#--simulate +# Performance options +#--no-resize-buffer +#--buffer-size 4M +#--concurrent-fragments 8 diff --git a/cli.download.config b/cli.download.config new file mode 100644 index 0000000..7bb7f64 --- /dev/null +++ b/cli.download.config @@ -0,0 +1,17 @@ +# This is a yt-dlp configuration file. +# It contains one command-line option per line. + +#--no-progress +--format-sort "res,ext:mp4:m4a" +--recode-video mp4 +--no-playlist +--no-overwrites +--continue +--output "%(extractor)s - %(title)s.%(ext)s" +--no-mtime +--verbose +#--simulate +# Performance options +#--no-resize-buffer +#--buffer-size 4M +#--concurrent-fragments 8 diff --git a/policies/10_direct_docker_auth_simulation.yaml b/policies/10_direct_docker_auth_simulation.yaml new file mode 100644 index 0000000..f0dac53 --- /dev/null +++ b/policies/10_direct_docker_auth_simulation.yaml @@ -0,0 +1,119 @@ +# Policy: Continuous Authentication Simulation via Direct Docker Exec +# +# This policy simulates a continuous stream of info.json fetch requests using +# the 'direct_docker_cli' mode. It calls a yt-dlp command inside a running +# Docker container, passing in a batch file and configuration. 
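The cli.auth.config and cli.download.config files above are plain yt-dlp option files (one option per line). A hedged sketch of how one of them could be exercised on its own, using yt-dlp's standard --config-locations option; the URL is a placeholder:

import subprocess

subprocess.run(
    ["yt-dlp",
     "--config-locations", "cli.download.config",   # load the options file shown above
     "https://www.youtube.com/watch?v=VIDEO_ID"],    # placeholder URL
    check=True,
)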
+# +# It uses a pool of managed profiles, locking one for each BATCH of requests. +# The host orchestrator prepares files, and docker exec runs yt-dlp. The container +# itself does not need to be Redis-aware. +# +name: direct_docker_auth_simulation + +settings: + mode: fetch_only + orchestration_mode: direct_docker_cli + profile_mode: from_pool_with_lock + urls_file: "inputfiles/urls.sky3.txt" + # The save directory MUST be inside the docker_host_mount_path for the download + # simulation to be able to find the files. + save_info_json_dir: "run/docker_mount/fetched_info_jsons/direct_docker_simulation" + +execution_control: + workers: 1 + # How long a worker should pause if it cannot find an available profile to lock. + worker_polling_interval_seconds: 1 + # No sleep between tasks; throughput is controlled by yt-dlp performance and profile availability. + +info_json_generation_policy: + profile_prefix: "user1" + +direct_docker_cli_policy: + # Which simulation environment's profiles to use for locking. + use_profile_env: "auth" + + # If true, a worker will try to lock a different profile than the one it just used. + avoid_immediate_profile_reuse: true + # How long the worker should wait for a different profile before re-using the same one. + avoid_reuse_max_wait_seconds: 5 + + # NOTE on Rate Limits: With the default yt-dlp settings, the rate limit for guest + # sessions is ~300 videos/hour (~1000 webpage/player requests per hour). + # For accounts, it is ~2000 videos/hour (~4000 webpage/player requests per hour). + # The enforcer policy (e.g., 8_unified_simulation_enforcer.yaml) should be + # configured to respect these limits via rotation and rest periods. + + # If true, extract the visitor_id from yt-dlp logs, save it per-profile, + # and inject it into subsequent requests for that profile. + #track_visitor_id: true + + # --- Docker Execution Settings --- + docker_image_name: "ytops/yt-dlp" # Image to use for `docker run` + docker_network_name: "airflow_proxynet" + # IMPORTANT: This path on the HOST will be mounted into the container at `docker_container_mount_path`. + docker_host_mount_path: "run/docker_mount" + docker_container_mount_path: "/config" # The mount point inside the container + + # Host path for persisting cache data (e.g., cookies, sigfuncs) between runs. + docker_host_cache_path: ".cache/direct_docker_simulation" + # Path inside the container where the cache is mounted. Should match HOME/.cache + docker_container_cache_path: "/config/.cache" + + # If true, create and use a persistent cookie jar per profile inside the cache dir. + # use_cookies: true + + # --- User-Agent Generation --- + # Template for generating User-Agent strings for new profiles. + # The '{major_version}' will be replaced by a version string. + user_agent_template: "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/{major_version}.0.0.0 Safari/537.36" + # Range of Chrome major versions to use for the template. A range suitable for TV devices. + user_agent_version_range: [110, 120] + + batch_size: 25 + + # A base config file can be used, with overrides applied from the policy. + # The orchestrator will inject 'proxy', 'batch-file', and 'output' keys into the overrides. 
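A small sketch of how the orchestrator presumably renders a User-Agent from the `user_agent_template` and `user_agent_version_range` settings above; the helper name is an assumption:

import random

def render_user_agent(template, version_range):
    # Pick a Chrome major version from the configured range and substitute it
    # into the '{major_version}' placeholder.
    major = random.randint(version_range[0], version_range[1])
    return template.format(major_version=major)

# render_user_agent(
#     "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) "
#     "Chrome/{major_version}.0.0.0 Safari/537.36",
#     [110, 120])
# -> "... Chrome/117.0.0.0 Safari/537.36" (for example)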
+ ytdlp_config_file: "cli.auth.config" + ytdlp_config_overrides: + skip-download: true + write-info-json: true + no-write-subs: true + no-color: true + ignore-errors: true + use-extractors: ["youtube"] + + ytdlp_raw_args: + - '--extractor-args "youtube:formats=duplicate;jsc_trace=true;player_client=tv_simply;pot_trace=true;skip=translated_subs,hls"' + - '--extractor-args "youtubepot-bgutilhttp:base_url=http://172.17.0.1:4416"' + - '--sleep-requests 0.75' + # --retry-sleep linear=1::2' + + # --- Live Error Parsing Rules --- + # These regex patterns are checked against yt-dlp's stderr in real-time. + # If a fatal error is detected, immediately ban the profile to stop the container + # and prevent further errors in the same batch. + ban_on_fatal_error_in_batch: true + fatal_error_patterns: + - "Sign in to confirm you’re not a bot" + - "rate-limited by YouTube" + - "This content isn't available, try again later" + - "HTTP Error 502" + + tolerated_error_patterns: + - "HTTP Error 429" + - "The uploader has not made this video available in your country" + - "This video has been removed by the uploader" + - "Private video" + - "This is a private video" + - "Video is private" + - "Video unavailable" + - "account associated with this video has been terminated" + - "members-only content" + - "Sign in to confirm your age" + + # Template for renaming the final info.json. + rename_file_template: "{video_id}-{profile_name}-{proxy}.info.json" + +simulation_parameters: + auth_env: "sim_auth" + download_env: "sim_download" diff --git a/policies/11_direct_docker_download_simulation.yaml b/policies/11_direct_docker_download_simulation.yaml new file mode 100644 index 0000000..811b684 --- /dev/null +++ b/policies/11_direct_docker_download_simulation.yaml @@ -0,0 +1,104 @@ +# Policy: Continuous Download Simulation via Direct Docker Exec +# +# This policy simulates a continuous stream of downloads using the +# 'direct_docker_cli' mode with `mode: download_only`. It finds task files +# (info.jsons) in a directory and invokes a yt-dlp command inside a running +# Docker container to perform the download. +# +name: direct_docker_download_simulation + +settings: + mode: download_only + orchestration_mode: direct_docker_cli + profile_mode: from_pool_with_lock + # This directory should contain info.json files generated by an auth simulation, + # like `10_direct_docker_auth_simulation`. + # It MUST be inside the docker_host_mount_path. + info_json_dir: "run/docker_mount/fetched_info_jsons/direct_docker_simulation" + # Regex to extract the profile name from a task filename. The first capture + # group is used. This is crucial for the task-first locking strategy. + # It looks for a component that starts with 'user' between two hyphens. + profile_extraction_regex: '^.+?-(user[^-]+)-' + +execution_control: + workers: 1 + # How long a worker should pause if it cannot find an available profile or task. + worker_polling_interval_seconds: 1 + +download_policy: + profile_prefix: "user1" + # Default cooldown in seconds if not specified by the enforcer in Redis. + # The value from Redis (set via `unlock_cooldown_seconds` in the enforcer policy) + # will always take precedence. This is a fallback. + # Can be an integer (e.g., 1) or a range (e.g., [1, 3]). + default_unlock_cooldown_seconds: 1 + # If true, check if the download URL in the info.json is expired before + # attempting to download. This is enabled by default. 
+ check_url_expiration: true + # --- Airflow Integration --- + # If true, move downloaded media and info.json to a timestamped, video-id-based + # directory structure that the Airflow DAGs can process. + output_to_airflow_ready_dir: true + airflow_ready_dir_base_path: "downloadfiles/videos/ready" + +simulation_parameters: + download_env: "sim_download" + +direct_docker_cli_policy: + # Which simulation environment's profiles to use for locking. + use_profile_env: "download" + + # If true, a worker will try to lock a different profile than the one it just used. + # This is disabled for downloads, as the cooldown mechanism is sufficient. + avoid_immediate_profile_reuse: false + # How long the worker should wait for a different profile before re-using the same one. + avoid_reuse_max_wait_seconds: 5 + + # NOTE on Rate Limits: With the default yt-dlp settings, the rate limit for guest + # sessions is ~300 videos/hour (~1000 webpage/player requests per hour). + # For accounts, it is ~2000 videos/hour (~4000 webpage/player requests per hour). + # This enforcer policy should be configured to respect these limits via + # rotation and rest periods. + + # --- Docker Execution Settings --- + docker_image_name: "ytops/yt-dlp" + docker_network_name: "airflow_proxynet" + # Host path mounted into the container for task files (info.json, config). + # IMPORTANT: This must be the SAME host path used for the `info_json_dir` above, + # or a parent directory of it, so the container can see the task files. + docker_host_mount_path: "run/docker_mount" + docker_container_mount_path: "/config" + + # Path on the HOST where downloaded files will be saved. + docker_host_download_path: "downloaded_media/direct_docker_simulation" + # Path inside the CONTAINER where `docker_host_download_path` is mounted. + docker_container_download_path: "/downloads" + + # A base config file can be used, with overrides applied from the policy. + # The orchestrator will inject 'proxy', 'load-info-json', and 'output' keys into the overrides. + ytdlp_config_file: "cli.download.config" + ytdlp_config_overrides: + format: "299-dashy/298-dashy/137-dashy/136-dashy/135-dashy/134-dashy/133-dashy,140-dashy/140-dashy-0/140" + #format: "bestvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best" + no-resize-buffer: true + buffer-size: "4M" + concurrent-fragments: 8 + + ytdlp_raw_args: [] + + # --- Live Error Parsing Rules --- + # If a fatal error is detected, immediately ban the profile to stop the container. + ban_on_fatal_error_in_batch: true + fatal_error_patterns: + - "HTTP Error 403" + - "HTTP Error 502" + + tolerated_error_patterns: + - "timed out" + - "Timeout" + - "connection reset by peer" + - "Invalid data found when processing input" + - "Error opening input files" + +simulation_parameters: + download_env: "sim_download" diff --git a/policies/1_fetch_only_policies.yaml b/policies/1_fetch_only_policies.yaml deleted file mode 100644 index 50d3ab8..0000000 --- a/policies/1_fetch_only_policies.yaml +++ /dev/null @@ -1,155 +0,0 @@ -# This file contains policies for testing only the info.json generation step. -# No downloads are performed. - ---- -# Policy: Basic fetch-only test for a TV client. -# This policy uses a single, static profile and has a rate limit to avoid being -# too aggressive. It saves the generated info.json files to a directory. 
-name: tv_downgraded_single_profile - -settings: - mode: fetch_only - urls_file: "urls.txt" - info_json_script: "bin/ytops-client get-info" - save_info_json_dir: "fetched_info_jsons/tv_downgraded" - # Use a single, static profile for all requests. - profile_prefix: "tv_downgraded_user" - profile_mode: per_worker # With 1 worker, this is effectively a single profile. - -execution_control: - run_until: { cycles: 1 } - workers: 1 - sleep_between_tasks: { min_seconds: 5, max_seconds: 10 } - -info_json_generation_policy: - client: tv_downgraded - # Safety rate limit: 450 requests per hour (7.5 req/min) - rate_limits: - per_ip: { max_requests: 450, per_minutes: 60 } - ---- -# Policy: Fetch-only test for an Android client using a cookie file. -# This demonstrates how to pass a cookie file for authenticated requests. -# It uses a single profile and stops if it encounters too many errors. -name: android_sdkless_with_cookies - -settings: - mode: fetch_only - urls_file: "urls.txt" - info_json_script: "bin/ytops-client get-info" - save_info_json_dir: "fetched_info_jsons/android_sdkless" - profile_prefix: "android_user_with_cookies" - profile_mode: per_worker - -execution_control: - run_until: { cycles: 1 } # Run through the URL list once. - workers: 1 - sleep_between_tasks: { min_seconds: 2, max_seconds: 4 } - -info_json_generation_policy: - client: android_sdkless - # Pass per-request parameters. This is how you specify a cookie file. - request_params: - cookies_file_path: "/path/to/your/android_cookies.txt" - -stop_conditions: - # Stop if we get more than 5 errors in any 10-minute window. - on_error_rate: { max_errors: 5, per_minutes: 10 } - ---- -# Policy: TV Fetch with Profile Cooldown (Pipeline Stage 1) -# Fetches info.json files using the 'tv' client. Each profile is limited -# to a certain number of requests before it is put into a cooldown period. -# The output of this policy is intended to be used by a 'download_only' policy. -name: tv_fetch_with_cooldown - -settings: - mode: fetch_only - urls_file: "urls.txt" - info_json_script: "bin/ytops-client get-info" - # Save the generated files to this directory for the download task to find. - save_info_json_dir: "live_jsons_tv" - profile_management: - prefix: "tv_user" - initial_pool_size: 10 - auto_expand_pool: true - max_requests_per_profile: 60 - sleep_minutes_on_exhaustion: 60 - -execution_control: - run_until: { cycles: 1 } - workers: 1 - sleep_between_tasks: { min_seconds: 2, max_seconds: 5 } - -info_json_generation_policy: - client: "tv" - request_params: - context_reuse_policy: { enabled: true, max_age_seconds: 86400 } - ---- -# Policy: MWeb with client rotation and rate limits. -# This demonstrates a more complex scenario with multiple clients and strict -# rate limiting, useful for simulating sophisticated user behavior. -name: mweb_client_rotation_and_rate_limits - -settings: - mode: fetch_only - urls_file: "urls.txt" - info_json_script: "bin/ytops-client get-info" - # Use the dynamic profile pool management system. - profile_management: - prefix: "mweb_user" - initial_pool_size: 10 - max_requests_per_profile: 100 - sleep_minutes_on_exhaustion: 15 - -execution_control: - run_until: { cycles: 1 } - workers: 10 - sleep_between_tasks: { min_seconds: 2, max_seconds: 5 } - -info_json_generation_policy: - # Enforce strict rate limits for both the entire IP and each individual profile. 
- rate_limits: - per_ip: { max_requests: 120, per_minutes: 10 } - per_profile: { max_requests: 10, per_minutes: 10 } - - # Rotate between a primary client (mweb) and a refresh client (web_camoufox) - # to keep sessions fresh. - client_rotation_policy: - major_client: "mweb" - major_client_params: - context_reuse_policy: { enabled: true, max_age_seconds: 1800 } - refresh_client: "web_camoufox" - refresh_every: { requests: 20, minutes: 10 } - ---- -# Policy: TV Simply, fetch-only test with per-worker profile rotation. -# Fetches info.json using tv_simply with multiple workers. Each worker gets a -# unique profile that is retired and replaced with a new generation after a -# set number of requests. -name: tv_simply_fetch_rotation - -settings: - mode: fetch_only - urls_file: "urls.txt" - info_json_script: "bin/ytops-client get-info" - save_info_json_dir: "fetched_info_jsons/tv_simply_rotation" - # Use the modern profile management system. - profile_mode: per_worker_with_rotation - profile_management: - prefix: "tv_simply_user" - # Rotate to a new profile generation after 250 requests. - max_requests_per_profile: 250 - -execution_control: - run_until: { cycles: 1 } # Run through the URL list once. - workers: 8 # Run with 8 parallel workers. - sleep_between_tasks: { min_seconds: 2, max_seconds: 5 } - # Optional: Override the assumed time for a fetch task to improve rate estimation. - # The default is 3 seconds for fetch_only mode. - # assumptions: - # fetch_task_duration: 2.5 - -info_json_generation_policy: - client: tv_simply diff --git a/policies/2_download_only_policies.yaml b/policies/2_download_only_policies.yaml deleted file mode 100644 index a5feb54..0000000 --- a/policies/2_download_only_policies.yaml +++ /dev/null @@ -1,58 +0,0 @@ -# This file contains policies for testing only the download step from -# existing info.json files. No new info.json files are generated. - ---- -# Policy: Basic profile-aware download test. -# This policy reads info.json files from a directory, groups them by a profile -# name extracted from the filename, and downloads them using multiple workers. -# Each worker handles one or more profiles sequentially. -name: basic_profile_aware_download - -settings: - mode: download_only - info_json_dir: "prefetched_info_jsons" - # Regex to extract profile names from filenames like '...-VIDEOID-my_profile_name.json'. - profile_extraction_regex: ".*-[a-zA-Z0-9_-]{11}-(.+)\\.json" - -execution_control: - run_until: { cycles: 1 } - # 'auto' sets workers to the number of profiles, capped by auto_workers_max. - workers: auto - auto_workers_max: 8 - # This sleep applies between each file downloaded by a single profile. - sleep_between_tasks: { min_seconds: 1, max_seconds: 2 } - -download_policy: - formats: "18,140,299/298/137/136/135/134/133" - downloader: "aria2c" - downloader_args: "aria2c:-x 4 -k 1M" - extra_args: "--cleanup --output-dir /tmp/downloads" - # This sleep applies between formats of a single video. - sleep_between_formats: { min_seconds: 0, max_seconds: 0 } - ---- -# Policy: Continuous download from a folder (Pipeline Stage 2). -# This policy watches a directory for new info.json files and processes them -# as they appear. It is designed to work as the second stage of a pipeline, -# consuming files generated by a 'fetch_only' policy like 'tv_fetch_with_cooldown'. 
-name: continuous_watch_download - -settings: - mode: download_only - info_json_dir: "live_info_jsons" - directory_scan_mode: continuous - mark_processed_files: true # Rename files to *.processed to avoid re-downloading. - max_files_per_cycle: 50 # Process up to 50 new files each time it checks. - sleep_if_no_new_files_seconds: 15 - -execution_control: - # Note: For 'continuous' mode, a time-based run_until (e.g., {minutes: 120}) - # is more typical. {cycles: 1} will cause it to scan the directory once - # for new files, process them, and then exit. - run_until: { cycles: 1 } - workers: 4 # Use a few workers to process files in parallel. - sleep_between_tasks: { min_seconds: 0, max_seconds: 0 } - -download_policy: - formats: "18,140" - extra_args: "--cleanup --output-dir /tmp/downloads" diff --git a/policies/3_full_stack_policies.yaml b/policies/3_full_stack_policies.yaml deleted file mode 100644 index 5ac43a0..0000000 --- a/policies/3_full_stack_policies.yaml +++ /dev/null @@ -1,158 +0,0 @@ -# This file contains policies for full-stack tests, which include both -# info.json generation and the subsequent download step. - ---- -# Policy: TV client with profile rotation. -# This test uses multiple parallel workers. Each worker gets its own profile -# that is automatically rotated (e.g., from tv_user_0_0 to tv_user_0_1) after -# a certain number of requests to simulate user churn. -name: tv_simply_profile_rotation - -settings: - mode: full_stack - urls_file: "urls.txt" - info_json_script: "bin/ytops-client get-info" - save_info_json_dir: "fetched_info_jsons/tv_simply_rotation" - # Use the modern profile management system. - profile_mode: per_worker_with_rotation - profile_management: - prefix: "tv_simply" - # Rotate to a new profile generation after 250 requests. - max_requests_per_profile: 250 - -execution_control: - run_until: { cycles: 1 } - workers: 8 # Run with 8 parallel workers. - sleep_between_tasks: { min_seconds: 2, max_seconds: 5 } - # Optional: Override assumptions to improve rate estimation. - # assumptions: - # fetch_task_duration: 10 # Est. seconds to get info.json - # download_task_duration: 20 # Est. seconds to download all formats for one video - -info_json_generation_policy: - client: tv_simply - -download_policy: - formats: "18,140" - extra_args: "--cleanup --output-dir downloads/tv_simply_rotation" - proxy: "socks5://127.0.0.1:1087" - downloader: "aria2c" - downloader_args: "aria2c:-x 8 -k 1M" - sleep_between_formats: { min_seconds: 2, max_seconds: 2 } - -stop_conditions: - on_cumulative_403: { max_errors: 5, per_minutes: 2 } - ---- -# Policy: TV Simply, full-stack test with per-worker profile rotation. -# Generates info.json using tv_simply and immediately attempts to download. -# This combines the fetch and download steps into a single workflow. -name: tv_simply_full_stack_rotation - -settings: - mode: full_stack - urls_file: "urls.txt" - info_json_script: "bin/ytops-client get-info" - profile_mode: per_worker_with_rotation - profile_management: - prefix: "tv_simply_worker" - max_requests_per_profile: 240 - -execution_control: - workers: 10 - run_until: { cycles: 1 } - sleep_between_tasks: { min_seconds: 5, max_seconds: 5 } - -info_json_generation_policy: - client: "tv_simply" - request_params: - context_reuse_policy: { enabled: false } - -download_policy: - formats: "18,140" - extra_args: "--output-dir downloads/tv_simply_downloads" - ---- -# Policy: MWeb client with multiple profiles, each with its own cookie file. 
-# This demonstrates how to run an authenticated test with a pool of accounts. -# The orchestrator will cycle through the cookie files, assigning one to each profile. -name: mweb_multi_profile_with_cookies - -settings: - mode: full_stack - urls_file: "urls.txt" - info_json_script: "bin/ytops-client get-info" - # Use the dynamic profile pool management system. - profile_management: - prefix: "mweb_user" - initial_pool_size: 3 # Start with 3 profiles. - auto_expand_pool: true # Create new profiles if the initial 3 are all rate-limited. - max_requests_per_profile: 100 # Let each profile make 100 requests... - sleep_minutes_on_exhaustion: 15 # ...then put it to sleep for 15 minutes. - # Assign a different cookie file to each profile in the pool. - # The tool will cycle through this list. - cookie_files: - - "/path/to/your/mweb_cookies_0.txt" - - "/path/to/your/mweb_cookies_1.txt" - - "/path/to/your/mweb_cookies_2.txt" - -execution_control: - run_until: { cycles: 1 } - workers: 3 # Match workers to the number of initial profiles. - sleep_between_tasks: { min_seconds: 1, max_seconds: 3 } - -info_json_generation_policy: - client: mweb - # This client uses youtubei.js, which generates PO tokens. - -download_policy: - formats: "18,140" - extra_args: "--cleanup --output-dir /tmp/downloads" - ---- -# Policy: TV client with profile rotation and aria2c RPC download. -# This test uses multiple parallel workers. Each worker gets its own profile -# that is automatically rotated. Downloads are submitted to an aria2c daemon -# via its RPC interface. -name: tv_simply_profile_rotation_aria2c_rpc - -settings: - mode: full_stack - urls_file: "urls.txt" - info_json_script: "bin/ytops-client get-info" - save_info_json_dir: "fetched_info_jsons/tv_simply_rotation_aria" - profile_mode: per_worker_with_rotation - profile_management: - prefix: "tv_simply_aria" - max_requests_per_profile: 250 - -execution_control: - run_until: { cycles: 1 } - workers: 8 - sleep_between_tasks: { min_seconds: 2, max_seconds: 5 } - -info_json_generation_policy: - client: tv_simply - -download_policy: - formats: "18,140" - # Use the aria2c RPC downloader - downloader: "aria2c_rpc" - # RPC server connection details - aria_host: "localhost" - aria_port: 6800 - # aria_secret: "your_secret" # Uncomment and set if needed - # Set to true to wait for each download and get a success/fail result. - # This is the default and recommended for monitoring success/failure. - # Set to false for maximum submission throughput ("fire-and-forget"), - # but you will lose per-download status reporting. - aria_wait: true - # The output directory is on the aria2c host machine - output_dir: "/downloads/tv_simply_rotation_aria" - # Pass custom arguments to aria2c in yt-dlp format for better performance. - # -x: max connections per server, -k: min split size. - downloader_args: "aria2c:[-x 8, -k 1M]" - sleep_between_formats: { min_seconds: 1, max_seconds: 2 } - -stop_conditions: - on_cumulative_403: { max_errors: 5, per_minutes: 2 } diff --git a/policies/4_custom_scenarios.yaml b/policies/4_custom_scenarios.yaml deleted file mode 100644 index 8d648d2..0000000 --- a/policies/4_custom_scenarios.yaml +++ /dev/null @@ -1,126 +0,0 @@ -# This file contains custom policies for specific testing scenarios. - ---- -# Policy: Fetch info.json with visitor ID rotation. -# This policy uses a single worker to fetch info.json files for a list of URLs. -# It simulates user churn by creating a new profile (and thus a new visitor_id and POT) -# every 250 requests. 
A short sleep is used between requests. -name: fetch_with_visitor_id_rotation - -settings: - mode: fetch_only - urls_file: "urls.txt" # Placeholder, should be overridden with --set - info_json_script: "bin/ytops-client get-info" - save_info_json_dir: "fetched_info_jsons/visitor_id_rotation" - # Use the modern profile management system to rotate visitor_id. - profile_mode: per_worker_with_rotation - profile_management: - prefix: "visitor_rotator" - # Rotate to a new profile generation after 250 requests. - max_requests_per_profile: 250 - -execution_control: - run_until: { cycles: 1 } # Run through the URL list once. - workers: 1 # Run with a single worker thread. - # A short, fixed sleep between each info.json request. - sleep_between_tasks: { min_seconds: 0.75, max_seconds: 0.75 } - -info_json_generation_policy: - # Use a standard client. The server will handle token generation. - client: tv_simply - ---- -# Policy: Full-stack test with visitor ID rotation and test download. -# This policy uses a single worker to fetch info.json files for a list of URLs, -# and then immediately performs a test download (first 10KB) of specified formats. -# It simulates user churn by creating a new profile (and thus a new visitor_id and POT) -# every 250 requests. A short sleep is used between requests. -name: full_stack_with_visitor_id_rotation - -settings: - mode: full_stack - urls_file: "urls.txt" # Placeholder, should be overridden with --set - info_json_script: "bin/ytops-client get-info" - # Use the modern profile management system to rotate visitor_id. - profile_mode: per_worker_with_rotation - profile_management: - prefix: "visitor_rotator" - # Rotate to a new profile generation after 250 requests. - max_requests_per_profile: 250 - -execution_control: - run_until: { cycles: 1 } # Run through the URL list once. - workers: 1 # Run with a single worker thread. - # A short, fixed sleep between each info.json request. - sleep_between_tasks: { min_seconds: 0.75, max_seconds: 0.75 } - -info_json_generation_policy: - # Use a standard client. The server will handle token generation. - client: tv_simply - -download_policy: - formats: "299-dashy/298-dashy/137-dashy/136-dashy/135-dashy/134-dashy/133-dashy" - downloader: "native-py" - extra_args: '--test --cleanup' - output_dir: "downloads/fetch_and_test" - sleep_between_formats: { min_seconds: 6, max_seconds: 6 } - ---- -# Policy: Download-only test from a fetch folder (Batch Mode). -# This policy scans a directory of existing info.json files once, and performs -# a test download (first 10KB) for specific formats. It is designed to run as -# a batch job after a 'fetch_only' policy has completed. -name: download_only_test_from_fetch_folder - -settings: - mode: download_only - # Directory of info.json files to process. - info_json_dir: "fetched_info_jsons/visitor_id_rotation" # Assumes output from 'fetch_with_visitor_id_rotation' - -execution_control: - run_until: { cycles: 1 } # Run through the info.json directory once. - workers: 1 # Run with a single worker thread. - # A longer, randomized sleep between processing each info.json file. - sleep_between_tasks: { min_seconds: 5, max_seconds: 10 } - -download_policy: - # A specific list of video-only DASH formats to test. - formats: "299-dashy/298-dashy/137-dashy/136-dashy/135-dashy/134-dashy/133-dashy" - downloader: "native-py" - # Pass extra arguments to perform a "test" download. 
- extra_args: '--test --cleanup' - output_dir: "downloads/dash_test" - ---- -# Policy: Live download from a watch folder (Continuous Mode). -# This policy continuously watches a directory for new info.json files and -# processes them as they appear. It is designed to work as the second stage -# of a pipeline, consuming files generated by a 'fetch_only' policy. -name: live_download_from_watch_folder - -settings: - mode: download_only - info_json_dir: "live_info_json" # A different directory for the live pipeline - directory_scan_mode: continuous - mark_processed_files: true # Rename files to *.processed to avoid re-downloading. - max_files_per_cycle: 50 # Process up to 50 new files each time it checks. - sleep_if_no_new_files_seconds: 15 - -execution_control: - # For 'continuous' mode, a time-based run_until is typical. - # {cycles: 1} will scan once, process new files, and exit. - # To run for 2 hours, for example, use: run_until: { minutes: 120 } - run_until: { cycles: 1 } - workers: 4 # Use a few workers to process files in parallel. - # sleep_between_tasks controls the pause between processing different info.json files. - # To pause before each download attempt starts, use 'pause_before_download_seconds' - # in the download_policy section below. - sleep_between_tasks: { min_seconds: 0, max_seconds: 0 } - -download_policy: - formats: "299-dashy/298-dashy/137-dashy/136-dashy/135-dashy/134-dashy/133-dashy" - downloader: "native-py" - # Example: Pause for a few seconds before starting each download attempt. - # pause_before_download_seconds: 2 - extra_args: '--test --cleanup' - output_dir: "downloads/live_dash_test" diff --git a/policies/5_ban_test_policies.yaml b/policies/5_ban_test_policies.yaml deleted file mode 100644 index a901fce..0000000 --- a/policies/5_ban_test_policies.yaml +++ /dev/null @@ -1,84 +0,0 @@ -# This file contains policies for testing ban rates and profile survival -# under high request counts. - ---- -# Policy: Single Profile Ban Test (500 Requests) -# This policy uses a single worker and a single, non-rotating profile to make -# 500 consecutive info.json requests. It is designed to test if and when a -# single profile/visitor_id gets banned or rate-limited by YouTube. -# -# It explicitly disables the server's automatic visitor ID rotation to ensure -# the same identity is used for all requests. -# -# The test will stop if it encounters 3 errors within any 1-minute window, -# or a total of 8 errors within any 60-minute window. -name: single_profile_ban_test_500 - -settings: - mode: fetch_only - urls_file: "urls.txt" # Override with --set settings.urls_file=... - info_json_script: "bin/ytops-client get-info" - save_info_json_dir: "fetched_info_jsons/ban_test_single_profile" - # Use one worker with one profile that does not rotate automatically. - profile_mode: per_worker_with_rotation - profile_management: - prefix: "ban_test_user" - # Set a high request limit to prevent the orchestrator from rotating the profile. - max_requests_per_profile: 1000 - -execution_control: - run_until: { requests: 500 } # Stop after 500 total requests. - workers: 1 - sleep_between_tasks: { min_seconds: 1, max_seconds: 2 } - -info_json_generation_policy: - client: "tv_simply" # A typical client for this kind of test. - # Explicitly disable the server's visitor ID rotation mechanism. - request_params: - session_params: - visitor_rotation_threshold: 0 - -stop_conditions: - # Stop if we get 3 or more errors in any 1-minute window (rapid failure). 
- on_error_rate: { max_errors: 3, per_minutes: 1 } - # Stop if we get 8 or more 403 errors in any 60-minute window (ban detection). - on_cumulative_403: { max_errors: 8, per_minutes: 60 } - ---- -# Policy: Multi-Profile Survival Test -# This policy uses 5 parallel workers, each with its own unique profile. -# It tests whether using multiple profiles with the server's default automatic -# visitor ID rotation (every 250 requests) can sustain a high request rate -# without getting banned. -# -# The test will run until 1250 total requests have been made (250 per worker), -# which should trigger one rotation for each profile. -name: multi_profile_survival_test - -settings: - mode: fetch_only - urls_file: "urls.txt" # Override with --set settings.urls_file=... - info_json_script: "bin/ytops-client get-info" - save_info_json_dir: "fetched_info_jsons/ban_test_multi_profile" - # Use 5 workers, each getting its own rotating profile. - profile_mode: per_worker_with_rotation - profile_management: - prefix: "survival_test_user" - # Use the default rotation threshold of 250 requests per profile. - max_requests_per_profile: 250 - -execution_control: - run_until: { requests: 1250 } # 5 workers * 250 requests/rotation = 1250 total. - workers: 5 - sleep_between_tasks: { min_seconds: 1, max_seconds: 2 } - -info_json_generation_policy: - client: "tv_simply" - # No request_params are needed here; we want to use the server's default - # visitor ID rotation behavior. - -stop_conditions: - # Stop if we get 3 or more errors in any 1-minute window (rapid failure). - on_error_rate: { max_errors: 3, per_minutes: 1 } - # Stop if we get 8 or more 403 errors in any 60-minute window (ban detection). - on_cumulative_403: { max_errors: 8, per_minutes: 60 } diff --git a/policies/6_profile_setup_policy.yaml b/policies/6_profile_setup_policy.yaml new file mode 100644 index 0000000..56b47f9 --- /dev/null +++ b/policies/6_profile_setup_policy.yaml @@ -0,0 +1,27 @@ +# Configuration for setting up profiles for a simulation or test run. +# This file is used by the `bin/ytops-client setup-profiles` command. +# It contains separate blocks for authentication and download simulations. + +simulation_parameters: + # --- Common Redis settings for all tools --- + # The environment name ('env') is now specified in each setup block below. + env_file: ".env" # Optional: path to a .env file. + +# --- Profile setup for the AUTHENTICATION simulation --- +auth_profile_setup: + env: "sim_auth" + cleanup_before_run: true + pools: + - prefix: "user1" + proxy: "sslocal-rust-1092:1092" + count: 1 + +# --- Profile setup for the DOWNLOAD simulation --- +download_profile_setup: + env: "sim_download" + cleanup_before_run: true + pools: + - prefix: "user1" + proxy: "sslocal-rust-1092:1092" + count: 1 + diff --git a/policies/8_unified_simulation_enforcer.yaml b/policies/8_unified_simulation_enforcer.yaml new file mode 100644 index 0000000..2c1b9c6 --- /dev/null +++ b/policies/8_unified_simulation_enforcer.yaml @@ -0,0 +1,162 @@ +# Policy for the unified simulation enforcer. +# This file is used by `bin/ytops-client policy-enforcer --live` to manage +# both the authentication and download simulation environments from a single process. + +# Policy for the unified simulation enforcer. +# This file is used by `bin/ytops-client policy-enforcer --live` to manage +# both the authentication and download simulation environments from a single process. 
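+#
+# Top-level blocks in this file:
+#   simulation_parameters            - shared Redis/env settings for both simulations
+#   auth_policy_enforcer_config      - rules applied to the "sim_auth" environment
+#   cross_simulation_sync            - keeps matching auth and download profiles in lockstep
+#   download_policy_enforcer_config  - rules applied to the "sim_download" environment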
+ +simulation_parameters: + # --- Common Redis settings for all tools --- + # The enforcer will connect to two different Redis environments (key prefixes) + # based on these settings, applying the corresponding policies to each. + env_file: ".env" + auth_env: "sim_auth" + download_env: "sim_download" + + # How often the enforcer should wake up and apply all policies. + interval_seconds: 2 + +# --- Policies for the Authentication Simulation --- +auth_policy_enforcer_config: + # Ban if 2 failures occur within a 1-minute window. + #ban_on_failures: 2 + #ban_on_failures_window_minutes: 1 + + # The standard rest policy is disabled, as rotation is handled by the profile group. + profile_prefix: "user1" + + # New rate limit policy to enforce requests-per-hour limits. + # For guest sessions, the limit is ~300 videos/hour. + rate_limit_requests: 280 + rate_limit_window_minutes: 60 + rate_limit_rest_duration_minutes: 5 + + rest_after_requests: 0 + rest_duration_minutes: 10 + + # NOTE on Rate Limits: With the default yt-dlp settings, the rate limit for guest + # sessions is ~300 videos/hour (~1000 webpage/player requests per hour). + # For accounts, it is ~2000 videos/hour (~4000 webpage/player requests per hour). + # The settings below should be configured to respect these limits. + + # A group of profiles that are managed together. + # The enforcer will ensure that no more than `max_active_profiles` from this + # group are in the ACTIVE state at any time. + profile_groups: + - name: "exclusive_auth_profiles" + prefix: "user1" + # Enforce that only 1 profile from this group can be active at a time. + max_active_profiles: 1 + # After an active profile has been used for this many requests, it will be + # rotated out and put into a RESTING state. + rotate_after_requests: 25 + # How long a profile rests after being rotated out. + rest_duration_minutes_on_rotation: 1 + + # If true, no new profile in this group will be activated while another + # one is in the 'waiting_downloads' state. + defer_activation_if_any_waiting: true + + # --- New settings for download wait feature --- + # When a profile is rotated, wait for its generated downloads to finish + # before it can be used again. + wait_download_finish_per_profile: true + # Safety net: max time to wait for downloads before forcing rotation. + # Should be aligned with info.json URL validity (e.g., 4 hours = 240 mins). + max_wait_for_downloads_minutes: 240 + + # Time-based proxy rules are disabled as they are not needed for this setup. + proxy_work_minutes: 0 + proxy_rest_duration_minutes: 0 + + # Global maximum time a proxy can be active before being rested, regardless of + # other rules. Acts as a safety net. Set to 0 to disable. + max_global_proxy_active_minutes: 0 + rest_duration_on_max_active: 10 + + # Proxy-level ban on failure burst is disabled. + proxy_ban_on_failures: 0 + proxy_ban_window_minutes: 2 + + # Clean up locks held for more than 16 minutes (960s) to prevent stuck workers. + # This should be longer than the docker container timeout (15m). + unlock_stale_locks_after_seconds: 960 + + # No post-task cooldown for auth simulation profiles. When a task is finished, + # the profile is immediately returned to the ACTIVE state. 
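+    # (For comparison, the download section at the bottom of this file uses a short
+    #  randomized cooldown instead, expressed as a [min, max] seconds range:
+    #  `unlock_cooldown_seconds: [2, 3]`.)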
+ unlock_cooldown_seconds: 0 + +# Cross-simulation synchronization +cross_simulation_sync: + # Link auth profiles to download profiles (by name) + # Both profiles should exist in their respective environments + profile_links: + - auth: "user1" + download: "user1" + # Which states to synchronize + #sync_states: + # - "RESTING" # Disabling to prevent deadlock when auth profile is waiting for downloads. + # The download profile must remain active to process them. + # - "BANNED" + # Whether to sync rotation (when auth is rotated due to rotate_after_requests) + #sync_rotation: true + # Whether download profile should be banned if auth is banned (even if download hasn't violated its own rules) + #enforce_auth_lead: true + # Ensures the same profile (e.g., user1_0) is active in both simulations. + # This will activate the correct download profile and rest any others in its group. + sync_active_profile: true + # When an auth profile is waiting for downloads, ensure the matching download profile is active + sync_waiting_downloads: true + +# --- Policies for the Download Simulation --- +download_policy_enforcer_config: + # Ban if 1 failure occurs within a 1-minute window. + ban_on_failures: 1 + ban_on_failures_window_minutes: 1 + + # Standard rest policy is disabled in favor of group rotation. + profile_prefix: "user1" + + # New rate limit policy to enforce requests-per-hour limits. + # For guest sessions, the limit is ~300 videos/hour. We set it slightly lower to be safe. + rate_limit_requests: 280 + rate_limit_window_minutes: 60 + rate_limit_rest_duration_minutes: 5 + # + rest_after_requests: 0 + rest_duration_minutes: 20 + + # NOTE on Rate Limits: With the default yt-dlp settings, the rate limit for guest + # sessions is ~300 videos/hour (~1000 webpage/player requests per hour). + # For accounts, it is ~2000 videos/hour (~4000 webpage/player requests per hour). + # The settings below should be configured to respect these limits. + + # A group of profiles that are mutually exclusive. Only one will be active at a time. + profile_groups: + - name: "exclusive_download_profiles" + prefix: "user1" + rotate_after_requests: 25 + rest_duration_minutes_on_rotation: 1 + max_active_profiles: 1 + + # Time-based proxy rules are disabled. + proxy_work_minutes: 50 + proxy_rest_duration_minutes: 10 + + # Global maximum time a proxy can be active before being rested, regardless of + # other rules. Acts as a safety net. Set to 0 to disable. + max_global_proxy_active_minutes: 0 + rest_duration_on_max_active: 10 + + # Proxy-level ban on failure burst is disabled. + proxy_ban_on_failures: 3 + proxy_ban_window_minutes: 1 + + # Clean up download locks held for more than 16 minutes (960s) to allow for long downloads. + # This should be longer than the docker container timeout (15m). + unlock_stale_locks_after_seconds: 960 + + # After a profile is used for a download, unlock it but put it in COOLDOWN + # state for 12-16s. This is enforced by the worker, which reads this config from Redis. + unlock_cooldown_seconds: [2, 3] diff --git a/policies/README.md b/policies/README.md deleted file mode 100644 index c590c79..0000000 --- a/policies/README.md +++ /dev/null @@ -1,28 +0,0 @@ -# Stress Test Policies - -This directory contains example policy files for the `stress_enhanced.py` orchestrator. Each file defines a specific testing strategy, organized by task type. - -## Authentication & Info.json Policies (`fetch_only` mode) - -These policies focus on testing the info.json generation service. 
- -- `info_json_rate_limit.yaml`: Tests the service with a focus on rate limits and client rotation. -- `auth_scenarios.yaml`: Contains specific scenarios for fetching info.json files, such as using a low-level command template for full control. - -## Download Policies (`download_only` mode) - -These policies focus on testing the download infrastructure using pre-existing info.json files. - -- `download_throughput.yaml`: Tests download/CDN infrastructure, focusing on throughput and error handling. -- `download_scenarios.yaml`: Contains specific scenarios for downloading, such as testing random formats from a directory of info.json files. - -## Full-Stack Policies (`full_stack` mode) - -These policies test the entire workflow from info.json generation through to downloading. - -- `regular_testing_scenarios.yaml`: Contains a collection of common, end-to-end testing scenarios, including: - - `mweb_per_request_profile`: A high-volume test that uses a new profile for every request. - - `mixed_client_profile_pool`: A complex test that alternates clients and reuses profiles from a pool. -- `tv_pipeline_scenarios.yaml`: A two-stage pipeline for fetching with the TV client and then continuously downloading. - -These files can be used as templates for creating custom test scenarios. diff --git a/setup.py b/setup.py index 9c38fe6..8fb27c9 100644 --- a/setup.py +++ b/setup.py @@ -1,42 +1,52 @@ -from setuptools import setup, find_packages import os -import xml.etree.ElementTree as ET +from setuptools import setup, find_packages -def get_version_from_pom(): - """Parse version from pom.xml""" - here = os.path.abspath(os.path.dirname(__file__)) - pom_path = os.path.join(here, 'thrift_model', 'pom.xml') - tree = ET.parse(pom_path) - root = tree.getroot() - - # XML namespaces - ns = {'mvn': 'http://maven.apache.org/POM/4.0.0'} - - version = root.find('mvn:version', ns).text - if version.endswith('-SNAPSHOT'): - version = version.replace('-SNAPSHOT', '.dev0') - return version +try: + with open(os.path.join(os.path.dirname(__file__), 'VERSION.client')) as f: + version = f.read().strip() +except IOError: + version = "0.0.1.dev0" + print(f"Warning: Could not read VERSION.client, falling back to version '{version}'") + +# find_packages() will automatically discover 'ytops_client' and 'yt_ops_services'. +# We manually add the 'pangramia' packages because they are in a separate directory structure. +pangramia_packages = [ + 'pangramia', + 'pangramia.base_service', + 'pangramia.yt', + 'pangramia.yt.common', + 'pangramia.yt.exceptions', + 'pangramia.yt.management', + 'pangramia.yt.tokens_ops', +] setup( - name='yt_ops_services', - version=get_version_from_pom(), - # find_packages() will now discover 'pangramia' via the symlink. - # 'server_fix' is excluded as it's no longer needed. - packages=find_packages(exclude=['tests*', 'server_fix']), - # package_data is not needed for pom.xml as it's only used at build time. 
- include_package_data=True, - # Add all dependencies from requirements.txt + name='ytops-client-tools', + version=version, + packages=find_packages(exclude=['thrift_model*', 'tests*']) + pangramia_packages, + package_dir={ + # This tells setuptools that the 'pangramia' package lives inside thrift_model/gen_py + 'pangramia': 'thrift_model/gen_py/pangramia', + }, + entry_points={ + 'console_scripts': [ + 'ytops-client=ytops_client.cli:main', + ], + }, install_requires=[ 'thrift>=0.16.0,<=0.20.0', 'python-dotenv>=1.0.0', 'psutil', 'flask', 'waitress', + 'yt_dlp>=2025.3.27', 'yt-dlp-get-pot==0.3.0', 'requests>=2.31.0', 'ffprobe3', 'redis', 'PySocks', + 'tabulate', + 'PyYAML', ], python_requires='>=3.9', ) diff --git a/tools/generate-inventory.py b/tools/generate-inventory.py index 6c05933..8f527e3 100755 --- a/tools/generate-inventory.py +++ b/tools/generate-inventory.py @@ -37,21 +37,10 @@ def generate_inventory(cluster_config, inventory_path): f.write(line + "\n") def generate_host_vars(cluster_config, host_vars_dir): - """Generate host-specific variables""" + """Generate host-specific variables. This function is non-destructive and will only create or overwrite files for hosts defined in the cluster config.""" # Create host_vars directory if it doesn't exist os.makedirs(host_vars_dir, exist_ok=True) - # Clear existing host_vars files to avoid stale configurations - for filename in os.listdir(host_vars_dir): - file_path = os.path.join(host_vars_dir, filename) - try: - if os.path.isfile(file_path) or os.path.islink(file_path): - os.unlink(file_path) - elif os.path.isdir(file_path): - shutil.rmtree(file_path) - except Exception as e: - print(f'Failed to delete {file_path}. Reason: {e}') - # Get master IP for Redis configuration from the new structure master_ip = list(cluster_config['master'].values())[0]['ip'] @@ -89,20 +78,15 @@ def generate_host_vars(cluster_config, host_vars_dir): for proxy in worker_proxies: f.write(f" - \"{proxy}\"\n") -def generate_group_vars(cluster_config, group_vars_dir): +def generate_group_vars(cluster_config, group_vars_path): """Generate group-level variables""" - # Create group_vars directory if it doesn't exist - os.makedirs(group_vars_dir, exist_ok=True) - - # Create group_vars/all directory if it doesn't exist - all_vars_dir = os.path.join(group_vars_dir, "all") + # Create parent directory if it doesn't exist + all_vars_dir = os.path.dirname(group_vars_path) os.makedirs(all_vars_dir, exist_ok=True) - # Define path for the generated file and remove it if it exists to avoid stale data. - # This is safer than removing the whole directory, which would delete vault.yml. - all_vars_file = os.path.join(all_vars_dir, "generated_vars.yml") - if os.path.exists(all_vars_file): - os.remove(all_vars_file) + # Remove the specific generated file if it exists to avoid stale data. 
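+    # For example, with a config named `cluster.stress.yml`, main() passes
+    # `ansible/group_vars/all/generated_vars.stress.yml` here, so only that
+    # generated file is ever deleted; hand-maintained files such as vault.yml
+    # are left alone.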
+ if os.path.exists(group_vars_path): + os.remove(group_vars_path) global_vars = cluster_config.get('global_vars', {}) external_ips = cluster_config.get('external_access_ips', []) @@ -122,7 +106,7 @@ def generate_group_vars(cluster_config, group_vars_dir): } generated_data.update(global_vars) - with open(all_vars_file, 'w') as f: + with open(group_vars_path, 'w') as f: f.write("---\n") f.write("# This file is auto-generated by tools/generate-inventory.py\n") f.write("# Do not edit – your changes will be overwritten.\n") @@ -130,7 +114,7 @@ def generate_group_vars(cluster_config, group_vars_dir): def main(): if len(sys.argv) != 2: - print("Usage: python3 generate-inventory.py ") + print("Usage: ./tools/generate-inventory.py ") sys.exit(1) config_path = sys.argv[1] @@ -139,12 +123,28 @@ def main(): if not os.path.exists(config_path): print(f"Error: Configuration file {config_path} not found") sys.exit(1) + + # Derive environment name from config filename (e.g., cluster.stress.yml -> stress) + base_name = os.path.basename(config_path) + if base_name == 'cluster.yml': + env_name = '' + elif base_name.startswith('cluster.') and base_name.endswith('.yml'): + env_name = base_name[len('cluster.'):-len('.yml')] + else: + print(f"Warning: Unconventional config file name '{base_name}'. Using base name as environment identifier.") + env_name = os.path.splitext(base_name)[0] + + # Define output paths based on environment + inventory_suffix = f".{env_name}" if env_name else "" + inventory_path = f"ansible/inventory{inventory_suffix}.ini" + + vars_suffix = f".{env_name}" if env_name else "" + group_vars_path = f"ansible/group_vars/all/generated_vars{vars_suffix}.yml" # Load cluster configuration cluster_config = load_cluster_config(config_path) # Generate inventory file - inventory_path = "ansible/inventory.ini" generate_inventory(cluster_config, inventory_path) print(f"Generated {inventory_path}") @@ -154,9 +154,8 @@ def main(): print(f"Generated host variables in {host_vars_dir}") # Generate group variables - group_vars_dir = "ansible/group_vars" - generate_group_vars(cluster_config, group_vars_dir) - print(f"Generated group variables in {group_vars_dir}") + generate_group_vars(cluster_config, group_vars_path) + print(f"Generated group variables in {os.path.dirname(group_vars_path)}") print("Inventory generation complete!") diff --git a/ytops_client/check_expiry_tool.py b/ytops_client/check_expiry_tool.py new file mode 100644 index 0000000..54d1ae8 --- /dev/null +++ b/ytops_client/check_expiry_tool.py @@ -0,0 +1,132 @@ +#!/usr/bin/env python3 +""" +Tool to check format URLs in an info.json for expiration. +""" + +import argparse +import json +import sys +import logging +import time +from datetime import datetime, timezone +from urllib.parse import urlparse, parse_qs + +from .stress_policy import utils as sp_utils + +logger = logging.getLogger('check_expiry_tool') + + +def add_check_expiry_parser(subparsers): + """Add the parser for the 'check-expiry' command.""" + parser = subparsers.add_parser( + 'check-expiry', + description='Check format URLs in an info.json for expiration.', + formatter_class=argparse.RawTextHelpFormatter, + help='Check if format URLs in an info.json are expired.', + epilog=""" +Exit Codes: + 0: All checked URLs are valid. + 1: At least one URL is expired or will expire within the specified time-shift. + 3: No URLs with expiration info were found to check. + 4: Input error (e.g., invalid JSON). 
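+
+Example (file name is illustrative):
+  ytops-client check-expiry --load-info-json video.info.json --time-shift-minutes 30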
+""" + ) + parser.add_argument( + '--load-info-json', + type=argparse.FileType('r', encoding='utf-8'), + default=sys.stdin, + help="Path to the info.json file. Reads from stdin if not provided." + ) + parser.add_argument( + '--time-shift-minutes', + type=int, + default=0, + help='Time shift in minutes. URLs expiring within this time are also reported as expired. Default: 0.' + ) + parser.add_argument( + '--check-all-formats', + action='store_true', + help='Check all available formats. By default, only the first format with an expiry timestamp is checked.' + ) + parser.add_argument('--verbose', action='store_true', help='Enable verbose logging.') + return parser + + +def main_check_expiry(args): + """Main logic for the 'check-expiry' command.""" + if args.verbose: + logging.getLogger().setLevel(logging.DEBUG) + logging.basicConfig(level=logging.DEBUG, format='%(levelname)s: %(message)s', stream=sys.stderr) + else: + logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s', stream=sys.stderr) + + try: + info_json_content = args.load_info_json.read() + if not info_json_content.strip(): + logger.error("Input is empty.") + return 4 + + info_data = json.loads(info_json_content) + except json.JSONDecodeError: + logger.error("Invalid JSON provided. Please check the input file.") + return 4 + except Exception as e: + logger.error(f"An unexpected error occurred while reading input: {e}", exc_info=args.verbose) + return 4 + + formats = info_data.get('formats', []) + if not formats: + logger.warning("No formats found in the provided info.json.") + return 3 + + overall_status = 'valid' + checked_any = False + min_time_left = float('inf') + worst_status_format_id = None + + for f in formats: + url = f.get('url') + format_id = f.get('format_id', 'N/A') + if not url: + logger.debug(f"Format {format_id} has no URL, skipping.") + continue + + status, time_left = sp_utils.check_url_expiry(url, args.time_shift_minutes) + + if status == 'no_expiry_info': + logger.debug(f"Format {format_id} has no expiration info in URL, skipping.") + continue + + checked_any = True + + if time_left < min_time_left: + min_time_left = time_left + worst_status_format_id = format_id + + # Determine the "worst" status seen so far. Expired > Valid. + if status == 'expired': + overall_status = 'expired' + + if not args.check_all_formats and overall_status != 'valid': + # If we found a problem and we're not checking all, we can stop. + break + + if not args.check_all_formats: + # If we checked one valid format and we're not checking all, we can stop. + break + + if not checked_any: + logger.warning("No formats with expiration timestamps were found to check.") + return 3 + + if overall_status == 'expired': + expire_datetime = datetime.fromtimestamp(time.time() + min_time_left, timezone.utc) + if min_time_left <= 0: + logger.error(f"URL for format '{worst_status_format_id}' is EXPIRED. It expired at {expire_datetime.strftime('%Y-%m-%d %H:%M:%S %Z')}.") + else: + logger.warning(f"URL for format '{worst_status_format_id}' is considered EXPIRED due to time-shift. It will expire in {min_time_left / 60:.1f} minutes (at {expire_datetime.strftime('%Y-%m-%d %H:%M:%S %Z')}).") + return 1 + else: # valid + expire_datetime = datetime.fromtimestamp(time.time() + min_time_left, timezone.utc) + logger.info(f"OK. 
The soonest-expiring URL (format '{worst_status_format_id}') is valid for another {min_time_left / 60:.1f} minutes (expires at {expire_datetime.strftime('%Y-%m-%d %H:%M:%S %Z')}).") + return 0 diff --git a/ytops_client/check_log_pattern_tool.py b/ytops_client/check_log_pattern_tool.py new file mode 100644 index 0000000..4e02f18 --- /dev/null +++ b/ytops_client/check_log_pattern_tool.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python3 +""" +CLI tool to check a log line against policy error patterns. +""" +import argparse +import re +import sys +import yaml +import os +from .stress_policy.utils import load_policy + +def add_check_log_pattern_parser(subparsers): + """Adds the parser for the 'check-log-pattern' command.""" + parser = subparsers.add_parser( + 'check-log-pattern', + help='Check a log line against policy error patterns.', + description='Checks a given log line against the fatal and tolerated error patterns defined in a policy file to determine how it would be classified.' + ) + parser.add_argument('--policy', required=True, help='Path to the YAML policy file.') + parser.add_argument('--policy-name', help='Name of the policy to use from a multi-policy file.') + parser.add_argument( + '--policy-section', + default='direct_docker_cli_policy', + help="The top-level key in the policy where error patterns are defined (e.g., 'direct_docker_cli_policy'). Default: direct_docker_cli_policy" + ) + parser.add_argument('log_line', help='The log line to check.') + +def main_check_log_pattern(args): + """Main logic for the 'check-log-pattern' command.""" + policy = load_policy(args.policy, args.policy_name) + if not policy: + return 1 + + policy_section = policy.get(args.policy_section, {}) + if not policy_section: + print(f"Error: Policy section '{args.policy_section}' not found in the policy.", file=sys.stderr) + return 1 + + fatal_patterns = policy_section.get('fatal_error_patterns', []) + tolerated_patterns = policy_section.get('tolerated_error_patterns', []) + + print(f"--- Checking Log Line ---") + print(f"Policy: {args.policy}" + (f" (name: {args.policy_name})" if args.policy_name else "")) + print(f"Policy Section: {args.policy_section}") + print(f"Log Line: '{args.log_line}'") + print("-" * 25) + + # 1. Check for fatal patterns. These take precedence. + for pattern in fatal_patterns: + if re.search(pattern, args.log_line, re.IGNORECASE): + print(f"Result: FATAL") + print(f"Reason: Matched fatal pattern: '{pattern}'") + return 0 + + # 2. Check for tolerated patterns. This is only relevant for lines that look like errors. + # The logic in stress_policy_tool checks for 'ERROR:' before checking tolerated patterns. + if 'ERROR:' in args.log_line: + for pattern in tolerated_patterns: + if re.search(pattern, args.log_line, re.IGNORECASE): + print(f"Result: TOLERATED") + print(f"Reason: Matched tolerated pattern: '{pattern}'") + return 0 + + # 3. If it's an ERROR line and not tolerated, it's a failure. + print(f"Result: FAILURE") + print(f"Reason: Contains 'ERROR:' but did not match any tolerated patterns.") + return 0 + + # 4. If it's not an error line and didn't match fatal, it's neutral. 
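+    # (For example, a routine progress line such as "[download] Destination: ..."
+    #  normally contains neither 'ERROR:' nor a fatal pattern, so it ends up here.)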
+ print(f"Result: NEUTRAL") + print(f"Reason: Does not contain 'ERROR:' and did not match any fatal patterns.") + return 0 diff --git a/ytops_client/cli.py b/ytops_client/cli.py index 3405eff..0b72f96 100644 --- a/ytops_client/cli.py +++ b/ytops_client/cli.py @@ -1,6 +1,25 @@ #!/usr/bin/env python3 import sys import argparse +import os +from datetime import datetime + +# --- Version Info --- +try: + # Get path relative to this file + script_dir = os.path.dirname(os.path.abspath(__file__)) + project_root = os.path.abspath(os.path.join(script_dir, '..')) + version_file_path = os.path.join(project_root, 'VERSION.client') + + with open(version_file_path, 'r') as f: + __version__ = f.read().strip() + + mod_time = os.path.getmtime(version_file_path) + __build_date__ = datetime.fromtimestamp(mod_time).strftime('%Y-%m-%d') +except Exception: + __version__ = "unknown" + __build_date__ = "unknown" + # Import the functions that define and execute the logic for each subcommand from .list_formats_tool import add_list_formats_parser, main_list_formats @@ -11,6 +30,19 @@ from .stress_formats_tool import add_stress_formats_parser, main_stress_formats from .cookie_tool import add_cookie_tool_parser, main_cookie_tool from .download_aria_tool import add_download_aria_parser, main_download_aria from .download_native_py_tool import add_download_native_py_parser, main_download_native_py +from .check_expiry_tool import add_check_expiry_parser, main_check_expiry +from .config_tool import add_flags_to_json_parser, main_flags_to_json, add_json_to_flags_parser, main_json_to_flags +from .manage_tool import add_manage_parser, main_manage +from .profile_manager_tool import add_profile_manager_parser, main_profile_manager +from .profile_allocator_tool import add_profile_allocator_parser, main_profile_allocator +from .policy_enforcer_tool import add_policy_enforcer_parser, main_policy_enforcer +from .profile_setup_tool import add_setup_profiles_parser, main_setup_profiles +from .simulation_tool import add_simulation_parser, main_simulation +from .locking_download_emulator_tool import add_locking_download_emulator_parser, main_locking_download_emulator +from .task_generator_tool import add_task_generator_parser, main_task_generator +from .yt_dlp_dummy_tool import add_yt_dlp_dummy_parser, main_yt_dlp_dummy +from .check_log_pattern_tool import add_check_log_pattern_parser, main_check_log_pattern + def main(): """ @@ -36,6 +68,11 @@ def main(): description="YT Ops Client Tools", formatter_class=argparse.RawTextHelpFormatter ) + parser.add_argument( + '--version', + action='version', + version=f'ytops-client version {__version__} (build date: {__build_date__})' + ) subparsers = parser.add_subparsers(dest='command', help='Available sub-commands') # Add subparsers from each tool module @@ -56,6 +93,19 @@ def main(): add_stress_policy_parser(subparsers) add_stress_formats_parser(subparsers) add_cookie_tool_parser(subparsers) + add_check_expiry_parser(subparsers) + add_flags_to_json_parser(subparsers) + add_json_to_flags_parser(subparsers) + add_manage_parser(subparsers) + add_profile_manager_parser(subparsers) + add_profile_allocator_parser(subparsers) + add_policy_enforcer_parser(subparsers) + add_setup_profiles_parser(subparsers) + add_simulation_parser(subparsers) + add_locking_download_emulator_parser(subparsers) + add_task_generator_parser(subparsers) + add_yt_dlp_dummy_parser(subparsers) + add_check_log_pattern_parser(subparsers) args = parser.parse_args() @@ -82,6 +132,32 @@ def main(): return main_stress_formats(args) 
elif args.command == 'convert-cookies': return main_cookie_tool(args) + elif args.command == 'check-expiry': + return main_check_expiry(args) + elif args.command == 'flags-to-json': + return main_flags_to_json(args) + elif args.command == 'json-to-flags': + return main_json_to_flags(args) + elif args.command == 'manage': + return main_manage(args) + elif args.command == 'profile': + return main_profile_manager(args) + elif args.command == 'profile-allocator': + return main_profile_allocator(args) + elif args.command == 'policy-enforcer': + return main_policy_enforcer(args) + elif args.command == 'setup-profiles': + return main_setup_profiles(args) + elif args.command == 'simulation': + return main_simulation(args) + elif args.command == 'download-emulator': + return main_locking_download_emulator(args) + elif args.command == 'task-generator': + return main_task_generator(args) + elif args.command == 'yt-dlp-dummy': + return main_yt_dlp_dummy(args) + elif args.command == 'check-log-pattern': + return main_check_log_pattern(args) # This path should not be reachable if a command is required or handled above. parser.print_help() diff --git a/ytops_client/config_tool.py b/ytops_client/config_tool.py new file mode 100644 index 0000000..60376f5 --- /dev/null +++ b/ytops_client/config_tool.py @@ -0,0 +1,317 @@ +#!/usr/bin/env python3 +""" +Tool to convert yt-dlp command-line flags to a JSON config using go-ytdlp. +""" + +import argparse +import json +import logging +import os +import shlex +import subprocess +import sys +from pathlib import Path +from typing import Dict, List + +logger = logging.getLogger('config_tool') + + +def get_go_ytdlp_path(user_path: str = None) -> str: + """ + Get the path to the go-ytdlp binary. + Checks in order: + 1. User-provided path + 2. 'go-ytdlp' in PATH + 3. Local binary in ytops_client/go_ytdlp_cli/go-ytdlp + 4. Binary in go-ytdlp/go-ytdlp (the library's built binary) + 5. Binary in /usr/local/bin/go-ytdlp + """ + def is_exe(fpath): + return os.path.isfile(fpath) and os.access(fpath, os.X_OK) + + if user_path: + if is_exe(user_path): + return user_path + # If user provided a path, we return it even if check fails, + # so subprocess can raise the appropriate error for that specific path. + return user_path + + # Check in PATH + import shutil + path_exe = shutil.which('go-ytdlp') + if path_exe: + return path_exe + + # Check local build directory + local_path = Path(__file__).parent / 'go_ytdlp_cli' / 'go-ytdlp' + if is_exe(str(local_path)): + return str(local_path) + + # Check the go-ytdlp library directory + project_root = Path(__file__).parent.parent + library_binary = project_root / 'go-ytdlp' / 'go-ytdlp' + if is_exe(str(library_binary)): + return str(library_binary) + + # Check /usr/local/bin + if is_exe('/usr/local/bin/go-ytdlp'): + return '/usr/local/bin/go-ytdlp' + + # Default to 'go-ytdlp' which will raise FileNotFoundError if not in PATH + return 'go-ytdlp' + +def convert_flags_to_json(flags: List[str], go_ytdlp_path: str = None) -> Dict: + """ + Converts a list of yt-dlp command-line flags to a JSON config dictionary. + + Args: + flags: A list of strings representing the command-line flags. + go_ytdlp_path: Path to the go-ytdlp executable. If None, will try to find it. + + Returns: + A dictionary representing the JSON config. + + Raises: + ValueError: If no flags are provided. + FileNotFoundError: If the go-ytdlp executable is not found. + subprocess.CalledProcessError: If go-ytdlp returns a non-zero exit code. 
+ json.JSONDecodeError: If the output from go-ytdlp is not valid JSON. + """ + if not flags: + raise ValueError("No flags provided to convert.") + + # Get the actual binary path + actual_path = get_go_ytdlp_path(go_ytdlp_path) + + # Use '--' to separate the subcommand flags from the flags to be converted. + # This prevents go-ytdlp from trying to parse the input flags as its own flags. + cmd = [actual_path, 'flags-to-json', '--'] + flags + + logger.debug(f"Executing command: {' '.join(shlex.quote(s) for s in cmd)}") + try: + process = subprocess.run(cmd, capture_output=True, check=True, encoding='utf-8') + + if process.stderr: + logger.info(f"go-ytdlp output on stderr:\n{process.stderr.strip()}") + + return json.loads(process.stdout) + except json.JSONDecodeError: + logger.error("Failed to parse JSON from go-ytdlp stdout.") + logger.error(f"Stdout was: {process.stdout.strip()}") + raise + except FileNotFoundError: + logger.error(f"Executable '{actual_path}' not found.") + logger.error("Please ensure go-ytdlp is installed and in your PATH.") + logger.error("You can run the 'bin/install-goytdlp.sh' script to install it.") + raise + except subprocess.CalledProcessError as e: + logger.error(f"go-ytdlp exited with error code {e.returncode}.") + logger.error(f"Stderr:\n{e.stderr.strip()}") + if "not supported" in e.stderr: + logger.error("NOTE: The installed version of go-ytdlp does not support converting flags to JSON.") + raise + except PermissionError: + logger.error(f"Permission denied executing '{actual_path}'.") + logger.error("Please ensure the file is executable (chmod +x).") + raise + + +def add_flags_to_json_parser(subparsers): + """Add the parser for the 'flags-to-json' command.""" + parser = subparsers.add_parser( + 'flags-to-json', + description='Convert yt-dlp command-line flags to a JSON config using go-ytdlp.', + formatter_class=argparse.RawTextHelpFormatter, + help='Convert yt-dlp flags to a JSON config.', + epilog=""" +Examples: + +# Convert flags from a string +ytops-client flags-to-json --from-string "-f best --no-playlist" + +# Convert flags from a file (like cli.config) +ytops-client flags-to-json --from-file cli.config + +# Convert flags passed directly as arguments +ytops-client flags-to-json -- --retries 5 --fragment-retries 5 + +# Combine sources (direct arguments override file/string) +ytops-client flags-to-json --from-file cli.config -- --retries 20 + +The go-ytdlp executable must be in your PATH. +You can install it by running the 'bin/install-goytdlp.sh' script. +""" + ) + source_group = parser.add_mutually_exclusive_group() + source_group.add_argument('--from-file', type=argparse.FileType('r', encoding='utf-8'), help='Read flags from a file (e.g., a yt-dlp config file).') + source_group.add_argument('--from-string', help='Read flags from a single string.') + + parser.add_argument('flags', nargs=argparse.REMAINDER, help='yt-dlp flags to convert. Use "--" to separate them from this script\'s own flags.') + parser.add_argument('--go-ytdlp-path', default='go-ytdlp', help='Path to the go-ytdlp executable. 
Defaults to "go-ytdlp" in PATH.') + parser.add_argument('--verbose', action='store_true', help='Enable verbose output for this script.') + return parser + +def main_flags_to_json(args): + """Main logic for the 'flags-to-json' command.""" + if args.verbose: + # Reconfigure root logger for verbose output to stderr + for handler in logging.root.handlers[:]: + logging.root.removeHandler(handler) + logging.basicConfig(level=logging.DEBUG, format='%(name)s - %(levelname)s - %(message)s', stream=sys.stderr) + else: + # Default to INFO level, also to stderr + for handler in logging.root.handlers[:]: + logging.root.removeHandler(handler) + logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s', stream=sys.stderr) + + flags = [] + if args.from_file: + logger.info(f"Reading flags from file: {args.from_file.name}") + content = args.from_file.read() + # A config file can have comments and one arg per line, or be a single line of args. + # shlex.split is good for single lines, but for multi-line we should split by line and filter. + lines = content.splitlines() + for line in lines: + line = line.strip() + if line and not line.startswith('#'): + # shlex.split can handle quoted arguments within the line + flags.extend(shlex.split(line)) + elif args.from_string: + logger.info("Reading flags from string.") + flags.extend(shlex.split(args.from_string)) + + if args.flags: + # The 'flags' remainder might contain '--' which we should remove if it's the first element. + remainder_flags = args.flags + if remainder_flags and remainder_flags[0] == '--': + remainder_flags = remainder_flags[1:] + + if remainder_flags: + logger.info("Appending flags from command-line arguments.") + flags.extend(remainder_flags) + + if not flags: + logger.error("No flags provided to convert.") + return 1 + + try: + json_output = convert_flags_to_json(flags, args.go_ytdlp_path) + # Print to actual stdout for piping. + print(json.dumps(json_output, indent=2)) + return 0 + except (ValueError, FileNotFoundError, subprocess.CalledProcessError, json.JSONDecodeError, PermissionError): + # Specific error is already logged by the helper function. + return 1 + except Exception as e: + logger.error(f"An unexpected error occurred: {e}", exc_info=args.verbose) + return 1 + + +def convert_json_to_flags(json_input: str, go_ytdlp_path: str = None) -> str: + """ + Converts a JSON config string to yt-dlp command-line flags. + + Args: + json_input: A string containing the JSON config. + go_ytdlp_path: Path to the go-ytdlp executable. If None, will try to find it. + + Returns: + A string of command-line flags. + + Raises: + ValueError: If the json_input is empty. + FileNotFoundError: If the go-ytdlp executable is not found. + subprocess.CalledProcessError: If go-ytdlp returns a non-zero exit code. 
+ """ + if not json_input: + raise ValueError("No JSON input provided to convert.") + + # Get the actual binary path + actual_path = get_go_ytdlp_path(go_ytdlp_path) + + cmd = [actual_path, 'json-to-flags'] + + logger.debug(f"Executing command: {' '.join(shlex.quote(s) for s in cmd)}") + try: + process = subprocess.run(cmd, input=json_input, capture_output=True, check=True, encoding='utf-8') + + if process.stderr: + logger.info(f"go-ytdlp output on stderr:\n{process.stderr.strip()}") + + return process.stdout.strip() + except FileNotFoundError: + logger.error(f"Executable '{actual_path}' not found.") + logger.error("Please ensure go-ytdlp is installed and in your PATH.") + logger.error("You can run the 'bin/install-goytdlp.sh' script to install it.") + raise + except subprocess.CalledProcessError as e: + logger.error(f"go-ytdlp exited with error code {e.returncode}.") + logger.error(f"Stderr:\n{e.stderr.strip()}") + raise + except PermissionError: + logger.error(f"Permission denied executing '{actual_path}'.") + logger.error("Please ensure the file is executable (chmod +x).") + raise + + +def add_json_to_flags_parser(subparsers): + """Add the parser for the 'json-to-flags' command.""" + parser = subparsers.add_parser( + 'json-to-flags', + description='Convert a JSON config to yt-dlp command-line flags using go-ytdlp.', + formatter_class=argparse.RawTextHelpFormatter, + help='Convert a JSON config to yt-dlp flags.', + epilog=""" +Examples: + +# Convert JSON from a string +ytops-client json-to-flags --from-string '{"postprocessor": {"ffmpeg": {"ppa": "SponsorBlock"}}}' + +# Convert JSON from a file +ytops-client json-to-flags --from-file config.json + +The go-ytdlp executable must be in your PATH. +You can install it by running the 'bin/install-goytdlp.sh' script. +""" + ) + source_group = parser.add_mutually_exclusive_group(required=True) + source_group.add_argument('--from-file', type=argparse.FileType('r', encoding='utf-8'), help='Read JSON from a file.') + source_group.add_argument('--from-string', help='Read JSON from a single string.') + + parser.add_argument('--go-ytdlp-path', default='go-ytdlp', help='Path to the go-ytdlp executable. Defaults to "go-ytdlp" in PATH.') + parser.add_argument('--verbose', action='store_true', help='Enable verbose output for this script.') + return parser + + +def main_json_to_flags(args): + """Main logic for the 'json-to-flags' command.""" + if args.verbose: + # Reconfigure root logger for verbose output to stderr + for handler in logging.root.handlers[:]: + logging.root.removeHandler(handler) + logging.basicConfig(level=logging.DEBUG, format='%(name)s - %(levelname)s - %(message)s', stream=sys.stderr) + else: + # Default to INFO level, also to stderr + for handler in logging.root.handlers[:]: + logging.root.removeHandler(handler) + logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s', stream=sys.stderr) + + json_input = "" + if args.from_file: + logger.info(f"Reading JSON from file: {args.from_file.name}") + json_input = args.from_file.read() + elif args.from_string: + logger.info("Reading JSON from string.") + json_input = args.from_string + + try: + flags_output = convert_json_to_flags(json_input, args.go_ytdlp_path) + # Print to actual stdout for piping. + print(flags_output) + return 0 + except (ValueError, FileNotFoundError, subprocess.CalledProcessError, PermissionError): + # Specific error is already logged by the helper function. 
+ return 1 + except Exception as e: + logger.error(f"An unexpected error occurred: {e}", exc_info=args.verbose) + return 1 diff --git a/ytops_client/download_emulator_tool.py b/ytops_client/download_emulator_tool.py new file mode 100644 index 0000000..e69de29 diff --git a/ytops_client/download_native_py_tool.py b/ytops_client/download_native_py_tool.py index 0622695..3c519cd 100644 --- a/ytops_client/download_native_py_tool.py +++ b/ytops_client/download_native_py_tool.py @@ -5,6 +5,7 @@ Tool to download a specified format using yt-dlp as a Python library. import argparse import contextlib +import copy import io import json import logging @@ -17,6 +18,7 @@ from datetime import datetime try: import yt_dlp + from yt_dlp.utils import match_filter_func except ImportError: print("yt-dlp is not installed. Please install it with: pip install yt-dlp", file=sys.stderr) sys.exit(1) @@ -29,11 +31,15 @@ class YTDLPLogger: self.final_filename = None self.is_403 = False self.is_timeout = False + self.has_errors = False def debug(self, msg): # yt-dlp logs the destination file path at the debug level. if msg.startswith('[download] Destination:'): self.final_filename = msg.split(':', 1)[1].strip() + elif msg.startswith('[Merger] Merging formats into "'): + # This captures the final filename after merging. + self.final_filename = msg.split('"')[1] elif msg.startswith('[download]') and 'has already been downloaded' in msg: match = re.search(r'\[download\]\s+(.*)\s+has already been downloaded', msg) if match: @@ -51,6 +57,7 @@ class YTDLPLogger: self.is_403 = True if "Read timed out" in msg: self.is_timeout = True + self.has_errors = True logger.error(msg) def ytdlp_progress_hook(d, ytdlp_logger): @@ -77,7 +84,7 @@ def add_download_native_py_parser(subparsers): parser.add_argument('--pause', type=int, default=0, help='Seconds to wait before starting the download.') parser.add_argument('--download-continue', action='store_true', help='Enable download continuation (--no-overwrites and --continue flags for yt-dlp).') parser.add_argument('--verbose', action='store_true', help='Enable verbose output for this script and yt-dlp.') - parser.add_argument('--cli-config', help='Path to a yt-dlp configuration file to load.') + parser.add_argument('--config', default=None, help='Path to a yt-dlp JSON configuration file (e.g., ytdlp.json). If not provided, searches for ytdlp.json.') parser.add_argument('--downloader', help='Name of the external downloader backend for yt-dlp to use (e.g., "aria2c", "native").') parser.add_argument('--downloader-args', help='Arguments to pass to the external downloader backend (e.g., "aria2c:-x 8").') parser.add_argument('--extra-ytdlp-args', help='A string of extra command-line arguments to pass to yt-dlp.') @@ -88,11 +95,87 @@ def add_download_native_py_parser(subparsers): parser.add_argument('--fragment-retries', type=int, help='Number of retries for each fragment (default: 10).') parser.add_argument('--socket-timeout', type=int, help='Timeout for socket operations in seconds (default: 20).') parser.add_argument('--add-header', action='append', help='Add a custom HTTP header for the download. Format: "Key: Value". 
Can be used multiple times.') + parser.add_argument('--concurrent-fragments', type=int, help='Number of fragments to download concurrently for each media.') # Arguments to pass through to yt-dlp parser.add_argument('--download-sections', help='yt-dlp --download-sections argument (e.g., "*0-10240").') parser.add_argument('--test', action='store_true', help='yt-dlp --test argument (download small part).') return parser + +def _download_single_format(format_id, info_data, base_ydl_opts, args): + """ + Download a single format ID from the given info_data. + This function filters info_data to only contain the requested format, + preventing yt-dlp from auto-merging with other streams. + + Returns a tuple: (success: bool, ytdlp_logger: YTDLPLogger) + """ + # Deep copy info_data so we can modify it without affecting other downloads + local_info_data = copy.deepcopy(info_data) + + available_formats = local_info_data.get('formats', []) + + # Find the exact format + target_format = next((f for f in available_formats if f.get('format_id') == format_id), None) + + if not target_format: + logger.error(f"Format '{format_id}' not found in info.json") + ytdlp_logger = YTDLPLogger() + ytdlp_logger.has_errors = True + return False, ytdlp_logger + + # Filter to only this format - this is the key to preventing auto-merge + local_info_data['formats'] = [target_format] + + # Clear any pre-selected format fields that might trigger merging + local_info_data.pop('requested_formats', None) + local_info_data.pop('format', None) + local_info_data.pop('format_id', None) + + logger.info(f"Filtered info_data to only contain format '{format_id}' (removed {len(available_formats) - 1} other formats)") + + # Create a fresh logger for this download + ytdlp_logger = YTDLPLogger() + + # Copy base options and update with this format's specifics + ydl_opts = dict(base_ydl_opts) + ydl_opts['format'] = format_id + ydl_opts['logger'] = ytdlp_logger + ydl_opts['progress_hooks'] = [lambda d, yl=ytdlp_logger: ytdlp_progress_hook(d, yl)] + + try: + download_buffer = None + if args.output_buffer: + download_buffer = io.BytesIO() + ctx_mgr = contextlib.redirect_stdout(download_buffer) + else: + ctx_mgr = contextlib.nullcontext() + + with ctx_mgr, yt_dlp.YoutubeDL(ydl_opts) as ydl: + ydl.process_ie_result(local_info_data) + + if ytdlp_logger.has_errors: + logger.error(f"Download of format '{format_id}' failed: yt-dlp reported an error during execution.") + return False, ytdlp_logger + + logger.info(f"Download of format '{format_id}' completed successfully.") + + if args.output_buffer and download_buffer: + sys.stdout.buffer.write(download_buffer.getvalue()) + sys.stdout.buffer.flush() + + return True, ytdlp_logger + + except yt_dlp.utils.DownloadError as e: + logger.error(f"yt-dlp DownloadError for format '{format_id}': {e}") + ytdlp_logger.has_errors = True + return False, ytdlp_logger + except Exception as e: + logger.exception(f"Unexpected error downloading format '{format_id}': {e}") + ytdlp_logger.has_errors = True + return False, ytdlp_logger + + def main_download_native_py(args): """Main logic for the 'download-native-py' command.""" # All logging should go to stderr to keep stdout clean for the final filename, or for binary data with --output-buffer. @@ -163,245 +246,422 @@ def main_download_native_py(args): logger.error("Invalid --proxy-rename format. Expected: s/pattern/replacement/") return 1 + # For library usage, ensure proxy URL has a scheme. Default to http if missing. 
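+    # (e.g. a bare "sslocal-rust-1092:1092" becomes "http://sslocal-rust-1092:1092",
+    #  while a proxy that already has a scheme, such as "socks5://...", is left as-is.)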
+ if proxy_url and '://' not in proxy_url: + original_proxy = proxy_url + proxy_url = 'http://' + proxy_url + logger.info(f"Proxy URL '{original_proxy}' has no scheme. Defaulting to '{proxy_url}'.") + # Build the yt-dlp options dictionary - # Start by parsing options from config file and extra args to establish a baseline. - base_opts_args = [] - if args.cli_config and os.path.exists(args.cli_config): - try: - with open(args.cli_config, 'r', encoding='utf-8') as f: - config_content = f.read() - base_opts_args.extend(shlex.split(config_content)) - logger.info(f"Loaded {len(base_opts_args)} arguments from config file: {args.cli_config}") - except Exception as e: - logger.error(f"Failed to read or parse config file {args.cli_config}: {e}") - return 1 - elif args.cli_config: - logger.warning(f"Config file '{args.cli_config}' not found. Ignoring.") + logger.info("--- Configuring yt-dlp options ---") - if args.extra_ytdlp_args: - extra_args_list = shlex.split(args.extra_ytdlp_args) - logger.info(f"Adding {len(extra_args_list)} extra arguments from --extra-ytdlp-args.") - base_opts_args.extend(extra_args_list) + param_sources = {} + ydl_opts = {} - ydl_opts = { - 'noresizebuffer': True, - 'buffersize': '4M', - } - if base_opts_args: - try: - logger.info(f"Parsing {len(base_opts_args)} arguments from config/extra_args...") - i = 0 - while i < len(base_opts_args): - arg = base_opts_args[i] - if not arg.startswith('--'): - logger.warning(f"Skipping non-option argument in extra args: {arg}") - i += 1 - continue + def _parse_ytdlp_args(args_list, source_name, opts_dict, sources_dict): + """Helper to parse a list of yt-dlp CLI-style args into an options dict.""" + i = 0 + while i < len(args_list): + arg = args_list[i] + if not arg.startswith('--'): + logger.warning(f"Skipping non-option argument from {source_name}: {arg}") + i += 1 + continue - key = arg.lstrip('-').replace('-', '_') - - # Handle flags (no value) - is_flag = i + 1 >= len(base_opts_args) or base_opts_args[i + 1].startswith('--') - - if key == 'resize_buffer': - ydl_opts['noresizebuffer'] = False - logger.debug(f"Parsed flag: noresizebuffer = False") - i += 1 - continue - elif key == 'no_resize_buffer': - ydl_opts['noresizebuffer'] = True - logger.debug(f"Parsed flag: noresizebuffer = True") - i += 1 - continue - - if is_flag: - if key.startswith('no_'): - # Handle --no-foo flags - ydl_opts[key[3:]] = False - else: - ydl_opts[key] = True - logger.debug(f"Parsed flag: {key} = {ydl_opts.get(key[3:] if key.startswith('no_') else key)}") - i += 1 - # Handle options with values + key_cli = arg.lstrip('-') + key_py = key_cli.replace('-', '_') + + is_flag = i + 1 >= len(args_list) or args_list[i + 1].startswith('--') + + if is_flag: + if key_py.startswith('no_'): + real_key = key_py[3:] + # Handle special cases where the Python option name is different + if real_key == 'resize_buffer': real_key = 'noresizebuffer' + opts_dict[real_key] = False + sources_dict[real_key] = source_name else: - value = base_opts_args[i + 1] - # Try to convert values to numbers, which yt-dlp expects. - # This includes parsing byte suffixes like 'K', 'M', 'G'. 
- if isinstance(value, str): - original_value = value - value_upper = value.upper() - multipliers = {'K': 1024, 'M': 1024**2, 'G': 1024**3, 'T': 1024**4} - - if value_upper and value_upper[-1] in multipliers: - try: - num = float(value[:-1]) - value = int(num * multipliers[value_upper[-1]]) - except (ValueError, TypeError): - value = original_value # fallback - else: - try: - value = int(value) - except (ValueError, TypeError): - try: - value = float(value) - except (ValueError, TypeError): - value = original_value # fallback + # Handle special cases + if key_py == 'resize_buffer': + opts_dict['noresizebuffer'] = False + sources_dict['noresizebuffer'] = source_name + else: + opts_dict[key_py] = True + sources_dict[key_py] = source_name + i += 1 + else: # Has a value + value = args_list[i + 1] + # Special key name conversions + if key_py == 'limit_rate': key_py = 'ratelimit' + elif key_py == 'buffer_size': key_py = 'buffersize' + + # Special value conversion for match_filter + if key_py == 'match_filter': + try: + value = match_filter_func(value) + except Exception as e: + logger.error(f"Failed to compile --match-filter '{value}': {e}") + # Skip this option + i += 2 + continue + else: + # Try to convert values to numbers, which yt-dlp expects + try: + value = int(value) + except (ValueError, TypeError): + try: + value = float(value) + except (ValueError, TypeError): + pass # Keep as string + + opts_dict[key_py] = value + sources_dict[key_py] = source_name + i += 2 - # Special handling for keys that differ from CLI arg, e.g. --limit-rate -> ratelimit - if key == 'limit_rate': - key = 'ratelimit' - elif key == 'buffer_size': - key = 'buffersize' - - ydl_opts[key] = value - logger.debug(f"Parsed option: {key} = {value}") - i += 2 - logger.info("Successfully parsed extra yt-dlp options.") - except Exception as e: - logger.error(f"Failed to parse options from config/extra_args: {e}", exc_info=True) + # 1. Load from JSON config file + config_path = args.config + log_msg = "" + if config_path: + log_msg = f"1. [Source: Config File] Loading from: {config_path}" + else: + if os.path.exists('ytdlp.json'): + config_path = 'ytdlp.json' + log_msg = f"1. [Source: Config File] No --config provided. Found and loading local '{config_path}'." + + if config_path and os.path.exists(config_path): + if log_msg: + logger.info(log_msg) + try: + with open(config_path, 'r', encoding='utf-8') as f: + config_data = json.load(f) + + # All yt-dlp options are expected under the 'ytdlp_params' key. + config_opts = config_data.get('ytdlp_params', {}) + + if config_opts: + logger.info(f"Parameters from config file ('{config_path}'):") + config_str = json.dumps(config_opts, indent=2) + for line in config_str.splitlines(): + logger.info(f" {line}") + + # Special handling for match_filter before updating ydl_opts + if 'match_filter' in config_opts and isinstance(config_opts['match_filter'], str): + logger.info(f" Compiling 'match_filter' string from config file...") + try: + config_opts['match_filter'] = match_filter_func(config_opts['match_filter']) + except Exception as e: + logger.error(f"Failed to compile match_filter from config: {e}") + del config_opts['match_filter'] + + ydl_opts.update(config_opts) + for key in config_opts: + param_sources[key] = "Config File" + except (json.JSONDecodeError, IOError) as e: + logger.error(f"Failed to read or parse JSON config file {config_path}: {e}") return 1 + elif args.config: + logger.warning(f"Config file '{args.config}' not found. Ignoring.") + + # 2. 
Load from extra command-line args + if args.extra_ytdlp_args: + logger.info(f"2. [Source: CLI Extra Args] Loading extra arguments...") + extra_args_list = shlex.split(args.extra_ytdlp_args) + _parse_ytdlp_args(extra_args_list, "CLI Extra Args", ydl_opts, param_sources) - # Now, layer the script's explicit arguments on top, as they have higher precedence. + # 3. Apply internal defaults if not already set + if 'noresizebuffer' not in ydl_opts: + ydl_opts['noresizebuffer'] = True + param_sources['noresizebuffer'] = "Internal Default" + if 'buffersize' not in ydl_opts: + ydl_opts['buffersize'] = '4M' + param_sources['buffersize'] = "Internal Default" + if 'force_progress' not in ydl_opts: + ydl_opts['force_progress'] = True + param_sources['force_progress'] = "Internal Default" + + # 4. Apply explicit arguments from this script's CLI (highest priority) + logger.info("3. [Source: CLI Explicit Args] Applying explicit overrides:") + os.makedirs(args.output_dir, exist_ok=True) - output_template = os.path.join(args.output_dir, '%(title)s [%(id)s].f%(format_id)s.%(ext)s') - ytdlp_logger = YTDLPLogger() - - # Use update to merge, so explicit args overwrite config/extra args. - ydl_opts.update({ - 'format': args.format, - 'outtmpl': '-' if args.output_buffer else output_template, - 'logger': ytdlp_logger, - 'progress_hooks': [lambda d: ytdlp_progress_hook(d, ytdlp_logger)], - 'verbose': args.verbose, - }) + ydl_opts['verbose'] = args.verbose + param_sources['verbose'] = "CLI Explicit" + # Handle output template ('outtmpl') + if args.output_buffer: + ydl_opts['outtmpl'] = '-' + param_sources['outtmpl'] = "CLI Explicit (Buffer)" + elif 'outtmpl' in ydl_opts: + # Respect outtmpl from config, but prepend the output directory + outtmpl_val = ydl_opts['outtmpl'] + if isinstance(outtmpl_val, dict): + # Prepend dir to each template in the dict + ydl_opts['outtmpl'] = {k: os.path.join(args.output_dir, v) for k, v in outtmpl_val.items()} + else: + # Prepend dir to the single template string + ydl_opts['outtmpl'] = os.path.join(args.output_dir, str(outtmpl_val)) + param_sources['outtmpl'] = "Config File (Path Applied)" + logger.info(f" Using 'outtmpl' from config, with output directory '{args.output_dir}' prepended.") + else: + # Use a default template if not specified in config + output_template = os.path.join(args.output_dir, '%(title)s [%(id)s].f%(format_id)s.%(ext)s') + ydl_opts['outtmpl'] = output_template + param_sources['outtmpl'] = "Internal Default" + logger.info(f" Using default 'outtmpl': {output_template}") + if args.temp_path: ydl_opts['paths'] = {'temp': args.temp_path} - logger.info(f"Using temporary path: {args.temp_path}") + param_sources['paths'] = "CLI Explicit" + logger.info(f" Temporary path: {args.temp_path}") if args.add_header: if 'http_headers' not in ydl_opts: ydl_opts['http_headers'] = {} elif not isinstance(ydl_opts['http_headers'], dict): - logger.warning(f"Overwriting non-dictionary http_headers from config with headers from command line.") + logger.warning(f"Overwriting non-dictionary http_headers with headers from command line.") ydl_opts['http_headers'] = {} - for header in args.add_header: - if ':' not in header: - logger.error(f"Invalid header format in --add-header: '{header}'. 
Expected 'Key: Value'.") - return 1 - key, value = header.split(':', 1) - ydl_opts['http_headers'][key.strip()] = value.strip() - logger.info(f"Adding/overwriting header: {key.strip()}: {value.strip()}") + if ':' in header: + key, value = header.split(':', 1) + ydl_opts['http_headers'][key.strip()] = value.strip() + else: + logger.error(f"Invalid header format: '{header}'. Expected 'Key: Value'.") + param_sources['http_headers'] = "CLI Explicit (Merged)" if args.download_continue: ydl_opts['continuedl'] = True ydl_opts['nooverwrites'] = True + param_sources['continuedl'] = "CLI Explicit" + param_sources['nooverwrites'] = "CLI Explicit" if proxy_url: ydl_opts['proxy'] = proxy_url + param_sources['proxy'] = "CLI Explicit" if args.downloader: ydl_opts['downloader'] = {args.downloader: None} + param_sources['downloader'] = "CLI Explicit" if args.downloader_args: - # yt-dlp expects a dict for downloader_args - # e.g., {'aria2c': ['-x', '8']} try: downloader_name, args_str = args.downloader_args.split(':', 1) ydl_opts.setdefault('downloader_args', {})[downloader_name] = shlex.split(args_str) + param_sources['downloader_args'] = "CLI Explicit" except ValueError: - logger.error(f"Invalid --downloader-args format. Expected 'downloader:args'. Got: '{args.downloader_args}'") + logger.error(f"Invalid --downloader-args format. Expected 'downloader:args'.") return 1 if args.merge_output_format: ydl_opts['merge_output_format'] = args.merge_output_format - + param_sources['merge_output_format'] = "CLI Explicit" if args.download_sections: ydl_opts['download_sections'] = args.download_sections - + param_sources['download_sections'] = "CLI Explicit" if args.test: ydl_opts['test'] = True - + param_sources['test'] = "CLI Explicit" if args.retries is not None: ydl_opts['retries'] = args.retries + param_sources['retries'] = "CLI Explicit" if args.fragment_retries is not None: ydl_opts['fragment_retries'] = args.fragment_retries + param_sources['fragment_retries'] = "CLI Explicit" if args.socket_timeout is not None: ydl_opts['socket_timeout'] = args.socket_timeout + param_sources['socket_timeout'] = "CLI Explicit" + if args.concurrent_fragments is not None: + ydl_opts['concurrent_fragments'] = args.concurrent_fragments + param_sources['concurrent_fragments'] = "CLI Explicit" + # To prevent timeouts on slow connections, ensure progress reporting is not disabled. + # The CLI wrapper enables this by default, so we match its behavior for robustness. + if ydl_opts.get('noprogress'): + logger.info("Overriding 'noprogress' option. Progress reporting is enabled to prevent network timeouts.") + ydl_opts['noprogress'] = False + param_sources['noprogress'] = "Internal Override" + + # Ensure byte-size options are integers for library use try: - logger.info(f"Starting download for format '{args.format}' using yt-dlp library...") + from yt_dlp.utils import parse_bytes + if 'buffersize' in ydl_opts and isinstance(ydl_opts['buffersize'], str): + ydl_opts['buffersize'] = parse_bytes(ydl_opts['buffersize']) + param_sources['buffersize'] = param_sources.get('buffersize', 'Unknown') + ' (Parsed)' + except (ImportError, Exception) as e: + logger.warning(f"Could not parse 'buffersize' option: {e}") + + # Force skip_download to False, as this script's purpose is to download. + if ydl_opts.get('skip_download'): + logger.info("Overriding 'skip_download: true' from config. 
This tool is for downloading.") + ydl_opts['skip_download'] = False + param_sources['skip_download'] = "Internal Override" + + # Log final effective options with sources + logger.info("=== Final Effective yt-dlp Options (base) ===") + for k in sorted(ydl_opts.keys()): + v = ydl_opts[k] + src = param_sources.get(k, "Unknown") + if k in ['logger', 'progress_hooks']: continue + logger.info(f" {k}: {v} [Source: {src}]") + + # --- Parse and process the format string --- + requested_format = args.format + available_formats = [str(f['format_id']) for f in info_data.get('formats', []) if 'format_id' in f] + + # Determine what kind of format string we have + # Keywords that yt-dlp treats as special selectors + selector_keywords = ('best', 'worst', 'bestvideo', 'bestaudio') + + # Split by comma to get individual format requests + # Each item could be a simple format ID or a fallback chain (with /) + format_items = [f.strip() for f in requested_format.split(',') if f.strip()] + + logger.info(f"Format string '{requested_format}' parsed into {len(format_items)} item(s): {format_items}") + + # Process each format item + all_success = True + final_filename = None + + for format_item in format_items: + logger.info(f"--- Processing format item: '{format_item}' ---") - download_buffer = None - if args.output_buffer: - # When downloading to buffer, we redirect stdout to capture the binary data. - download_buffer = io.BytesIO() - ctx_mgr = contextlib.redirect_stdout(download_buffer) - else: - # Otherwise, use a null context manager. - ctx_mgr = contextlib.nullcontext() - - with ctx_mgr, yt_dlp.YoutubeDL(ydl_opts) as ydl: - # The download() method is for URLs. For a pre-fetched info dict, - # we must use process_ie_result to bypass the info extraction step. - # It raises DownloadError on failure, which is caught by the outer try...except block. - ydl.process_ie_result(info_data) - # If process_ie_result completes without an exception, the download was successful. - retcode = 0 - - # The success path is now always taken if no exception was raised. - if retcode == 0: - if ytdlp_logger.is_403: - logger.error("Download failed: yt-dlp reported HTTP Error 403: Forbidden. The URL has likely expired.") - return 1 - if ytdlp_logger.is_timeout: - logger.error("Download failed: yt-dlp reported a timeout.") - return 1 - - logger.info("yt-dlp download completed successfully.") + # Check if this specific item is a simple format ID or a complex selector + item_has_complex_syntax = any(c in format_item for c in '/+[]()') or format_item.startswith(selector_keywords) + + if item_has_complex_syntax: + # This is a complex selector like "299/298/137" or "bestvideo+bestaudio" + # We need to handle fallback chains specially - if args.output_buffer: - # Write the captured binary data to the actual stdout. - sys.stdout.buffer.write(download_buffer.getvalue()) - sys.stdout.buffer.flush() - # Print the filename to stderr for the orchestrator. - if ytdlp_logger.final_filename: - print(ytdlp_logger.final_filename, file=sys.stderr) - else: - # Print the filename to stdout as usual. 
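
The loop above splits the --format string on commas and then branches on the shape of each item: a plain format ID, a slash-separated fallback chain, or a more complex selector that is handed to yt-dlp unchanged. Below is a minimal sketch of just that classification step; classify_format_item and its return labels are illustrative names, not part of the patch.

SELECTOR_KEYWORDS = ('best', 'worst', 'bestvideo', 'bestaudio')

def classify_format_item(item):
    """Classify one comma-separated piece of a --format string.

    Illustrative sketch only: the real loop also resolves suffixed IDs such as
    '140-0' and decides per item whether a fallback chain or a plain download runs.
    """
    if any(c in item for c in '/+[]()') or item.startswith(SELECTOR_KEYWORDS):
        # Complex syntax; a pure '/' chain without '+' is an ordered fallback list.
        if '/' in item and '+' not in item:
            return 'fallback-chain'   # e.g. '299/298/137': try each ID in order
        return 'complex-selector'     # e.g. 'bestvideo+bestaudio': passed straight to yt-dlp
    return 'simple-id'                # e.g. '140' or '299-dashy'

for fmt in [f.strip() for f in '299/298/137, 140, bestvideo+bestaudio'.split(',') if f.strip()]:
    print(f"{fmt} -> {classify_format_item(fmt)}")
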
- if ytdlp_logger.final_filename: - print(ytdlp_logger.final_filename, file=sys.stdout) + if '/' in format_item and '+' not in format_item: + # This is a fallback chain like "299/298/137" + # Try each format in order until one succeeds + fallback_formats = [f.strip() for f in format_item.split('/') if f.strip()] + logger.info(f"Detected fallback chain with {len(fallback_formats)} options: {fallback_formats}") - if args.cleanup: - downloaded_filepath = ytdlp_logger.final_filename - if downloaded_filepath and os.path.exists(downloaded_filepath): - try: - logger.info(f"Cleanup: Renaming and truncating '{downloaded_filepath}'") - timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') - directory, original_filename = os.path.split(downloaded_filepath) - filename_base, filename_ext = os.path.splitext(original_filename) - new_filename = f"{filename_base}_{timestamp}{filename_ext}.empty" - new_filepath = os.path.join(directory, new_filename) - os.rename(downloaded_filepath, new_filepath) - logger.info(f"Renamed to '{new_filepath}'") - with open(new_filepath, 'w') as f: - pass - logger.info(f"Truncated '{new_filepath}' to 0 bytes.") - except Exception as e: - logger.error(f"Cleanup failed: {e}") - return 1 # Treat cleanup failure as a script failure - elif not args.output_buffer: - logger.warning("Cleanup requested, but no downloaded file was found. Skipping cleanup.") - return 0 - else: - logger.error(f"yt-dlp download failed with internal exit code {retcode}.") - return 1 + item_success = False + for fallback_fmt in fallback_formats: + # Check if this fallback is a simple format ID that exists + if fallback_fmt in available_formats: + logger.info(f"Trying fallback format '{fallback_fmt}'...") + success, ytdlp_logger = _download_single_format(fallback_fmt, info_data, ydl_opts, args) + if success: + item_success = True + if ytdlp_logger.final_filename: + final_filename = ytdlp_logger.final_filename + break + else: + logger.warning(f"Fallback format '{fallback_fmt}' failed, trying next...") + else: + # Try to find a matching format with a suffix (e.g., "140" matches "140-0") + prefix_match_re = re.compile(rf'^{re.escape(fallback_fmt)}-\d+$') + first_match = next((af for af in available_formats if prefix_match_re.match(af)), None) + + if first_match: + logger.info(f"Fallback format '{fallback_fmt}' not found exactly. Using match: '{first_match}'...") + success, ytdlp_logger = _download_single_format(first_match, info_data, ydl_opts, args) + if success: + item_success = True + if ytdlp_logger.final_filename: + final_filename = ytdlp_logger.final_filename + break + else: + logger.warning(f"Fallback format '{first_match}' failed, trying next...") + else: + logger.warning(f"Fallback format '{fallback_fmt}' not available, trying next...") + + if not item_success: + logger.error(f"All fallback formats in '{format_item}' failed or were unavailable.") + all_success = False + else: + # This is a merge request or other complex selector + # We can't safely filter for these, so we pass through to yt-dlp + # but warn the user + logger.warning(f"Complex format selector '{format_item}' detected. 
Cannot prevent auto-merge for this type.") + logger.warning("If you experience merge errors, try specifying simple format IDs separated by commas.") + + # Use the original yt-dlp behavior for complex selectors + ytdlp_logger = YTDLPLogger() + local_ydl_opts = dict(ydl_opts) + local_ydl_opts['format'] = format_item + local_ydl_opts['logger'] = ytdlp_logger + local_ydl_opts['progress_hooks'] = [lambda d, yl=ytdlp_logger: ytdlp_progress_hook(d, yl)] + + try: + download_buffer = None + if args.output_buffer: + download_buffer = io.BytesIO() + ctx_mgr = contextlib.redirect_stdout(download_buffer) + else: + ctx_mgr = contextlib.nullcontext() - except yt_dlp.utils.DownloadError as e: - # This catches download-specific errors from yt-dlp - logger.error(f"yt-dlp DownloadError: {e}") - return 1 - except Exception as e: - logger.exception(f"An unexpected error occurred during yt-dlp execution: {e}") + with ctx_mgr, yt_dlp.YoutubeDL(local_ydl_opts) as ydl: + ydl.process_ie_result(copy.deepcopy(info_data)) + + if ytdlp_logger.has_errors: + logger.error(f"Download of '{format_item}' failed.") + all_success = False + else: + if ytdlp_logger.final_filename: + final_filename = ytdlp_logger.final_filename + if args.output_buffer and download_buffer: + sys.stdout.buffer.write(download_buffer.getvalue()) + sys.stdout.buffer.flush() + + except yt_dlp.utils.DownloadError as e: + logger.error(f"yt-dlp DownloadError for '{format_item}': {e}") + all_success = False + except Exception as e: + logger.exception(f"Unexpected error downloading '{format_item}': {e}") + all_success = False + else: + # This is a simple format ID like "299-dashy" or "140" + # Check if it exists in available formats + if format_item in available_formats: + success, ytdlp_logger = _download_single_format(format_item, info_data, ydl_opts, args) + if success: + if ytdlp_logger.final_filename: + final_filename = ytdlp_logger.final_filename + else: + all_success = False + else: + # Try to find a matching format with a suffix (e.g., "140" matches "140-0") + prefix_match_re = re.compile(rf'^{re.escape(format_item)}-\d+$') + first_match = next((af for af in available_formats if prefix_match_re.match(af)), None) + + if first_match: + logger.info(f"Requested format '{format_item}' not found. 
Using first available match: '{first_match}'.") + success, ytdlp_logger = _download_single_format(first_match, info_data, ydl_opts, args) + if success: + if ytdlp_logger.final_filename: + final_filename = ytdlp_logger.final_filename + else: + all_success = False + else: + logger.error(f"Requested format '{format_item}' not found in available formats: {available_formats}") + all_success = False + + # Final output + if all_success: + logger.info("All format downloads completed successfully.") + if final_filename: + print(final_filename, file=sys.stdout) + + if args.cleanup and final_filename and os.path.exists(final_filename): + try: + logger.info(f"Cleanup: Renaming and truncating '{final_filename}'") + timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') + directory, original_filename = os.path.split(final_filename) + filename_base, filename_ext = os.path.splitext(original_filename) + new_filename = f"{filename_base}_{timestamp}{filename_ext}.empty" + new_filepath = os.path.join(directory, new_filename) + os.rename(final_filename, new_filepath) + logger.info(f"Renamed to '{new_filepath}'") + with open(new_filepath, 'w') as f: + pass + logger.info(f"Truncated '{new_filepath}' to 0 bytes.") + except Exception as e: + logger.error(f"Cleanup failed: {e}") + return 1 + + return 0 + else: + logger.error("One or more format downloads failed.") return 1 diff --git a/ytops_client/download_tool.py b/ytops_client/download_tool.py index 7fdd303..3448579 100644 --- a/ytops_client/download_tool.py +++ b/ytops_client/download_tool.py @@ -194,7 +194,12 @@ def main_download(args): cmd.extend(['--proxy', proxy_url]) if args.lang: - cmd.extend(['--extractor-args', f'youtube:lang={args.lang}']) + lang = args.lang + if '-' in lang: + base_lang = lang.split('-')[0] + logger.warning(f"Language code '{lang}' includes a region, which may not be supported. 
Using base language '{base_lang}' instead.") + lang = base_lang + cmd.extend(['--extractor-args', f'youtube:lang={lang}']) if args.timezone: logger.warning(f"Timezone override ('{args.timezone}') is not supported by yt-dlp and will be ignored.") @@ -205,20 +210,27 @@ def main_download(args): if capture_output and not args.print_traffic: logger.info("Note: --cleanup or --log-file requires capturing output, which may affect progress bar display.") - logger.info(f"Executing yt-dlp command for format '{args.format}'") + logger.info("--- Configuring and Executing yt-dlp ---") + logger.info(f"Executing for format: '{args.format}'") - # Construct a display version of the command for logging - display_cmd_str = ' '.join(f"'{arg}'" if ' ' in arg else arg for arg in cmd) if os.path.exists(args.cli_config): try: with open(args.cli_config, 'r', encoding='utf-8') as f: - config_contents = ' '.join(f.read().split()) + config_contents = f.read().strip() if config_contents: - logger.info(f"cli.config contents: {config_contents}") + logger.info(f"Parameters from config file ('{args.cli_config}'):") + # Indent each line for readability + for line in config_contents.splitlines(): + if line.strip() and not line.strip().startswith('#'): + logger.info(f" {line.strip()}") except IOError as e: logger.warning(f"Could not read config file {args.cli_config}: {e}") + + logger.info("Note: Command-line arguments will override any conflicting settings from the config file.") - logger.info(f"Full command: {display_cmd_str}") + # Construct a display version of the command for logging + display_cmd_str = ' '.join(f"'{arg}'" if ' ' in arg else arg for arg in cmd) + logger.info(f"Final yt-dlp command: {display_cmd_str}") if capture_output: process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, encoding='utf-8') diff --git a/ytops_client/downloader.py b/ytops_client/downloader.py new file mode 100644 index 0000000..80d57e9 --- /dev/null +++ b/ytops_client/downloader.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python3 +""" +Downloader module for yt-ops-client. +""" + +import json +import logging +import subprocess +import sys +from typing import Dict + +logger = logging.getLogger(__name__) + + +def download_with_config(url: str, config: Dict) -> int: + """ + Download a video using yt-dlp with the given configuration. + + Args: + url: The URL to download + config: A dictionary of yt-dlp options + + Returns: + Exit code (0 for success, non-zero for failure) + """ + # Build the command + cmd = ['yt-dlp'] + + # Convert config to command-line arguments + for key, value in config.items(): + if isinstance(value, bool): + if value: + cmd.append(f'--{key}') + else: + cmd.append(f'--no-{key}') + elif isinstance(value, (int, float, str)): + cmd.append(f'--{key}') + cmd.append(str(value)) + elif isinstance(value, dict): + # Handle nested options (like postprocessor) + # For simplicity, convert to JSON string + cmd.append(f'--{key}') + cmd.append(json.dumps(value)) + elif value is None: + # Skip None values + continue + else: + logger.warning(f"Unsupported config value type for key '{key}': {type(value)}") + cmd.append(f'--{key}') + cmd.append(str(value)) + + cmd.append(url) + + # Run the command + logger.info(f"Running command: {' '.join(cmd)}") + try: + result = subprocess.run(cmd, check=True) + return result.returncode + except subprocess.CalledProcessError as e: + logger.error(f"yt-dlp failed with exit code {e.returncode}") + return e.returncode + except FileNotFoundError: + logger.error("yt-dlp not found. 
Please install yt-dlp first.") + return 1 + except Exception as e: + logger.error(f"Unexpected error: {e}") + return 1 diff --git a/ytops_client/get_info_tool.py b/ytops_client/get_info_tool.py index 15517f2..beec65e 100644 --- a/ytops_client/get_info_tool.py +++ b/ytops_client/get_info_tool.py @@ -31,25 +31,10 @@ from thrift.transport import TTransport from pangramia.yt.common.ttypes import TokenUpdateMode from pangramia.yt.exceptions.ttypes import PBServiceException, PBUserException from yt_ops_services.client_utils import get_thrift_client +from ytops_client.stress_policy import utils as sp_utils from ytops_client.request_params_help import REQUEST_PARAMS_HELP_STRING -def get_video_id(url: str) -> str: - """Extracts a YouTube video ID from a URL.""" - # For URLs like https://www.youtube.com/watch?v=VIDEO_ID - match = re.search(r"v=([0-9A-Za-z_-]{11})", url) - if match: - return match.group(1) - # For URLs like https://youtu.be/VIDEO_ID - match = re.search(r"youtu\.be\/([0-9A-Za-z_-]{11})", url) - if match: - return match.group(1) - # For plain video IDs - if re.fullmatch(r'[0-9A-Za-z_-]{11}', url): - return url - return "unknown_video_id" - - def parse_key_value_params(params_str: str) -> Dict[str, Any]: """Parses a comma-separated string of key=value pairs into a nested dict.""" params = {} @@ -123,10 +108,15 @@ the browser-based generation strategy.''') parser.add_argument('--show-ytdlp-log', action='store_true', help='Print the yt-dlp debug log from the server response.') parser.add_argument('--direct', action='store_true', help='Use the direct yt-dlp info.json generation method, bypassing Node.js token generation.') parser.add_argument('--print-info-out', action='store_true', help='Print the final info.json to stdout. By default, output is suppressed unless writing to a file.') - parser.add_argument('--request-params-json', help=REQUEST_PARAMS_HELP_STRING + '\nCan also be a comma-separated string of key=value pairs (e.g., "caching_policy.mode=force_refresh").') - parser.add_argument('--force-renew', help='Comma-separated list of items to force-renew: cookies, visitor_id, po_token, nsig_cache, info_json, all.') + # The new, more powerful argument for passing JSON config. It replaces --request-params-json. + parser.add_argument('--ytdlp-config-json', help=REQUEST_PARAMS_HELP_STRING) + parser.add_argument('--ytdlp-config-json-file', help='Path to a JSON file containing per-request parameters. Overrides other config arguments.') + parser.add_argument('--request-params-json', help='DEPRECATED: Use --ytdlp-config-json. Accepts JSON, a file path with @, or key=value pairs.') + parser.add_argument('--force-renew', help='Comma-separated list of items to force-renew: cookies, visitor_id, po_token, nsig_cache, info_json, all. Overrides settings in --ytdlp-config-json.') parser.add_argument('--lang', help='Language code for the request (e.g., "fr", "ja"). Affects metadata language.') parser.add_argument('--timezone', help='Timezone for the request (e.g., "UTC", "America/New_York"). 
Note: experimental, may not be fully supported.') + parser.add_argument('--prevent-cookie-rotation', action='store_true', help='Prevent the server from saving updated cookies for this profile on this request.') + parser.add_argument('--prevent-visitor-rotation', action='store_true', help='Prevent the server from changing the visitor_id for this profile on this request, if one already exists.') return parser def main_get_info(args): @@ -138,7 +128,7 @@ def main_get_info(args): logging.getLogger().setLevel(logging.DEBUG) if args.log_file_auto: - video_id = get_video_id(args.url) + video_id = sp_utils.get_video_id(args.url) timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') log_filename = f"{video_id}-{timestamp}.log" @@ -178,48 +168,118 @@ def main_get_info(args): machine_id = socket.gethostname() logger.info(f"No machine ID provided, using hostname: {machine_id}") + # --- JSON Config Handling --- request_params = {} - if args.request_params_json: - try: - request_params = json.loads(args.request_params_json) - except json.JSONDecodeError: - logger.info("Could not parse --request-params-json as JSON, trying as key-value string.") - request_params = parse_key_value_params(args.request_params_json) + json_config_source = None + + # Config source priority: + # 1. --ytdlp-config-json-file (explicit file path) + # 2. --ytdlp-config-json (JSON string or @file path) + # 3. --request-params-json (deprecated) + # 4. ytdlp.json in current directory (automatic) + + if args.ytdlp_config_json_file: + logger.info(f"Loading config from file specified by --ytdlp-config-json-file: {args.ytdlp_config_json_file}") + json_config_source = f"@{args.ytdlp_config_json_file}" + elif args.ytdlp_config_json: + json_config_source = args.ytdlp_config_json + elif args.request_params_json: + logger.warning("The --request-params-json argument is deprecated and will be removed in a future version. Please use --ytdlp-config-json or --ytdlp-config-json-file instead.") + json_config_source = args.request_params_json + else: + # Fallback to auto-discovery in order of precedence: local files, then user-level files. + home_dir = os.path.expanduser('~') + config_paths_to_check = [ + ('ytdlp.json', "current directory"), + (os.path.join(home_dir, '.config', 'yt-dlp', 'ytdlp.json'), "user config directory"), + ] + + # Find the first existing config file and use it + for path, location in config_paths_to_check: + if os.path.exists(path): + logger.info(f"No config argument provided. Found and loading '{path}' from {location}.") + json_config_source = f"@{path}" + break + else: # nobreak + logger.info("No config argument or default config file found. Proceeding with CLI flags and server defaults.") + if json_config_source: + if json_config_source.startswith('@'): + config_file = json_config_source[1:] + try: + with open(config_file, 'r', encoding='utf-8') as f: + request_params = json.load(f) + logger.info(f"Loaded request parameters from file: {config_file}") + except FileNotFoundError: + logger.error(f"Config file not found: {config_file}") + return 1 + except json.JSONDecodeError as e: + logger.error(f"Error parsing JSON from {config_file}: {e}") + return 1 + else: + # Try parsing as JSON first, then as key-value pairs for backward compatibility. 
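
The source resolution above accepts either an '@<path>' reference, an inline JSON string, or a legacy key=value string. The sketch below condenses that fallback order into a hypothetical load_request_params helper; the simplified parse_key_value_params is a stand-in for the module's own parser, and the real code additionally logs each branch and returns an exit code on error.

import json

def parse_key_value_params(params_str):
    """Nested-dict fallback for 'a.b=c,d=e' style strings (simplified stand-in)."""
    params = {}
    for pair in params_str.split(','):
        if '=' not in pair:
            continue
        dotted_key, value = pair.split('=', 1)
        node = params
        parts = dotted_key.strip().split('.')
        for part in parts[:-1]:
            node = node.setdefault(part, {})
        node[parts[-1]] = value.strip()
    return params

def load_request_params(source):
    """Resolve '@<path>', an inline JSON string, or a key=value string into a dict."""
    if source.startswith('@'):
        with open(source[1:], 'r', encoding='utf-8') as f:
            return json.load(f)
    try:
        return json.loads(source)
    except json.JSONDecodeError:
        return parse_key_value_params(source)

# Both spellings resolve to the same nested structure:
assert load_request_params('{"caching_policy": {"mode": "force_refresh"}}') == \
       load_request_params('caching_policy.mode=force_refresh')
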
+ try: + request_params = json.loads(json_config_source) + logger.info("Loaded request parameters from command-line JSON string.") + except json.JSONDecodeError: + logger.info("Could not parse config as JSON, trying as key=value string for backward compatibility.") + request_params = parse_key_value_params(json_config_source) + + # --- Override JSON with command-line flags for convenience --- + # Server-specific controls are placed under the 'ytops' key. + ytops_params = request_params.setdefault('ytops', {}) + if args.force_renew: items_to_renew = [item.strip() for item in args.force_renew.split(',')] - request_params['force_renew'] = items_to_renew - logger.info(f"Requesting force renew for: {items_to_renew}") + ytops_params['force_renew'] = items_to_renew + logger.info(f"Overriding force_renew with CLI value: {items_to_renew}") - if args.lang: - session_params = request_params.setdefault('session_params', {}) - session_params['lang'] = args.lang - logger.info(f"Requesting language: {args.lang}") + # Session parameters are also server-specific. + session_params = ytops_params.setdefault('session_params', {}) + if args.prevent_cookie_rotation: + session_params['prevent_cookie_rotation'] = True + logger.info("Requesting to prevent cookie rotation for this request.") + if args.prevent_visitor_rotation: + session_params['prevent_visitor_rotation'] = True + logger.info("Requesting to prevent visitor ID rotation for this request.") + + # yt-dlp parameters are at the top level. + ytdlp_params = request_params.setdefault('ytdlp_params', {}) + + # Language and timezone are yt-dlp extractor arguments. + if args.lang or args.timezone: + extractor_args = ytdlp_params.setdefault('extractor_args', {}) + youtube_args = extractor_args.setdefault('youtube', {}) + if args.lang: + # yt-dlp expects lang to be a list + youtube_args['lang'] = [args.lang] + logger.info(f"Overriding lang with CLI value: {args.lang}") + if args.timezone: + # yt-dlp expects timeZone to be a list + youtube_args['timeZone'] = [args.timezone] + logger.info(f"Overriding timezone with CLI value: {args.timezone}") - if args.timezone: - session_params = request_params.setdefault('session_params', {}) - session_params['timeZone'] = args.timezone - logger.info(f"Requesting timezone: {args.timezone}") - if args.verbose: - # Add verbose flag for yt-dlp on the server - ytdlp_params = request_params.setdefault('ytdlp_params', {}) ytdlp_params['verbose'] = True - logger.info("Verbose mode enabled, requesting verbose yt-dlp logs from server.") + logger.info("Overriding verbose to True due to CLI flag.") + + # --client is a special override for a nested extractor arg + if args.client: + extractor_args = ytdlp_params.setdefault('extractor_args', {}) + youtube_args = extractor_args.setdefault('youtube', {}) + # yt-dlp expects player_client to be a list + youtube_args['player_client'] = [args.client] + logger.info(f"Overriding player_client with CLI value: {args.client}") - thrift_args = { - 'accountId': args.profile, - 'updateType': TokenUpdateMode.AUTO, - 'url': args.url, - 'clients': args.client, - 'machineId': machine_id, - 'airflowLogContext': None, - 'requestParamsJson': json.dumps(request_params) if request_params else None, - 'assignedProxyUrl': args.assigned_proxy_url - } + # Determine the assigned proxy, with the CLI flag overriding any value from the JSON config. 
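
The proxy handling that follows gives the --assigned-proxy-url flag precedence over any ytops.assigned_proxy_url value from the JSON config, and may then rewrite the result with a sed-style --proxy-rename rule. The body of the rename branch is outside this hunk, so the regex substitution in the sketch below is only a plausible reading; resolve_assigned_proxy is an illustrative name.

import re

def resolve_assigned_proxy(cli_proxy, request_params, rename_rule=None):
    """CLI value beats the JSON config; an optional s/pattern/replacement/ rule is then applied."""
    proxy = cli_proxy or request_params.get('ytops', {}).get('assigned_proxy_url')
    if proxy and rename_rule:
        rule = rename_rule.strip("'\"")
        if not (rule.startswith('s/') and rule.count('/') >= 2):
            raise ValueError("Invalid rename rule, expected s/pattern/replacement/")
        body = rule[2:-1] if rule.endswith('/') else rule[2:]
        pattern, replacement = body.split('/', 1)
        # Assumption: the rule is applied as a regex substitution on the proxy URL.
        proxy = re.sub(pattern, replacement, proxy)
    return proxy

# Swap the gateway host while keeping credentials and port:
print(resolve_assigned_proxy('socks5://user:pw@gw-old:1080', {}, 's/gw-old/gw-new/'))
# -> socks5://user:pw@gw-new:1080
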
+ assigned_proxy = args.assigned_proxy_url or ytops_params.get('assigned_proxy_url') - # Handle proxy renaming - assigned_proxy = args.assigned_proxy_url + # If a proxy is assigned, ensure it's also set in the ytdlp_params for consistency. + if assigned_proxy: + ytdlp_params['proxy'] = assigned_proxy + logger.info(f"Setting ytdlp_params.proxy to assigned proxy: {assigned_proxy}") + + # Handle proxy renaming if requested if assigned_proxy and args.proxy_rename: rename_rule = args.proxy_rename.strip("'\"") if rename_rule.startswith('s/') and rename_rule.count('/') >= 2: @@ -239,7 +299,17 @@ def main_get_info(args): else: logger.error("Invalid --proxy-rename format. Expected: s/pattern/replacement/") return 1 - thrift_args['assignedProxyUrl'] = assigned_proxy + + thrift_args = { + 'accountId': args.profile, + 'updateType': TokenUpdateMode.AUTO, + 'url': args.url, + 'clients': args.client, # Kept for backward compatibility on server, though player_client in JSON is preferred. + 'machineId': machine_id, + 'airflowLogContext': None, + 'requestParamsJson': json.dumps(request_params) if request_params else None, + 'assignedProxyUrl': assigned_proxy + } if args.client: logger.info(f"Requesting to use specific client: {args.client}") @@ -343,6 +413,15 @@ def main_get_info(args): info_data = json.loads(info_json_str) if hasattr(token_data, 'socks') and token_data.socks: info_data['_proxy_url'] = token_data.socks + + # Add yt-ops metadata to the info.json for self-description + if isinstance(info_data, dict): + info_data['_ytops_metadata'] = { + 'profile_name': args.profile, + 'proxy_url': token_data.socks if hasattr(token_data, 'socks') and token_data.socks else None, + 'generation_timestamp_utc': datetime.utcnow().isoformat() + 'Z' + } + if isinstance(info_data, dict) and 'error' in info_data: error_code = info_data.get('errorCode', 'N/A') error_message = info_data.get('message', info_data.get('error', 'Unknown error')) @@ -387,7 +466,7 @@ def main_get_info(args): # Determine output file path if auto-naming is used output_file = args.output if args.output_auto or args.output_auto_url_only: - video_id = get_video_id(args.url) + video_id = sp_utils.get_video_id(args.url) suffix = args.output_auto_suffix or "" if args.output_auto: timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') diff --git a/ytops_client/go_ytdlp_cli/go-ytdlp b/ytops_client/go_ytdlp_cli/go-ytdlp new file mode 100755 index 0000000..c1b1267 Binary files /dev/null and b/ytops_client/go_ytdlp_cli/go-ytdlp differ diff --git a/ytops_client/go_ytdlp_cli/go.mod b/ytops_client/go_ytdlp_cli/go.mod new file mode 100644 index 0000000..656e2ff --- /dev/null +++ b/ytops_client/go_ytdlp_cli/go.mod @@ -0,0 +1,22 @@ +module github.com/yourproject/ytops_client/go_ytdlp_cli + +go 1.23.0 + +toolchain go1.24.4 + +require ( + github.com/lrstanley/go-ytdlp v0.0.0-00010101000000-000000000000 + github.com/spf13/cobra v1.8.0 +) + +replace github.com/lrstanley/go-ytdlp => ../../go-ytdlp + +require ( + github.com/ProtonMail/go-crypto v1.3.0 // indirect + github.com/cloudflare/circl v1.6.1 // indirect + github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/spf13/pflag v1.0.5 // indirect + github.com/ulikunitz/xz v0.5.13 // indirect + golang.org/x/crypto v0.41.0 // indirect + golang.org/x/sys v0.35.0 // indirect +) diff --git a/ytops_client/go_ytdlp_cli/go.sum b/ytops_client/go_ytdlp_cli/go.sum new file mode 100644 index 0000000..9a2693b --- /dev/null +++ b/ytops_client/go_ytdlp_cli/go.sum @@ -0,0 +1,27 @@ +github.com/ProtonMail/go-crypto v1.3.0 
h1:ILq8+Sf5If5DCpHQp4PbZdS1J7HDFRXz/+xKBiRGFrw= +github.com/ProtonMail/go-crypto v1.3.0/go.mod h1:9whxjD8Rbs29b4XWbB8irEcE8KHMqaR2e7GWU1R+/PE= +github.com/cloudflare/circl v1.6.1 h1:zqIqSPIndyBh1bjLVVDHMPpVKqp8Su/V+6MeDzzQBQ0= +github.com/cloudflare/circl v1.6.1/go.mod h1:uddAzsPgqdMAYatqJ0lsjX1oECcQLIlRpzZh3pJrofs= +github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0= +github.com/spf13/cobra v1.8.0/go.mod h1:WXLWApfZ71AjXPya3WOlMsY9yMs7YeiHhFVlvLyhcho= +github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= +github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/ulikunitz/xz v0.5.13 h1:ar98gWrjf4H1ev05fYP/o29PDZw9DrI3niHtnEqyuXA= +github.com/ulikunitz/xz v0.5.13/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14= +golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4= +golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc= +golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= +golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/ytops_client/go_ytdlp_cli/main.go b/ytops_client/go_ytdlp_cli/main.go new file mode 100644 index 0000000..09a434d --- /dev/null +++ b/ytops_client/go_ytdlp_cli/main.go @@ -0,0 +1,64 @@ +package main + +import ( + "encoding/json" + "fmt" + "io" + "os" + + "github.com/lrstanley/go-ytdlp" + "github.com/spf13/cobra" +) + +func main() { + cli := &cobra.Command{ + Use: "go-ytdlp", + Short: "A simple CLI wrapper for go-ytdlp.", + SilenceUsage: true, + SilenceErrors: true, + } + + cli.AddCommand(&cobra.Command{ + Use: "flags-to-json [flags...]", + Short: "Converts yt-dlp flags to a JSON config.", + Long: "Converts yt-dlp flags to a JSON config. Note that this does not validate the flags.", + Args: cobra.MinimumNArgs(1), + RunE: func(cmd *cobra.Command, args []string) (err error) { + // The go-ytdlp library documentation mentions FlagsToJSON and SetFlagConfig, + // but these methods are missing from the generated code in the current version. + // Therefore, we cannot implement this command yet. 
+ return fmt.Errorf("flags-to-json is not supported by the underlying go-ytdlp library") + }, + }) + + cli.AddCommand(&cobra.Command{ + Use: "json-to-flags", + Short: "Converts a JSON config to yt-dlp flags.", + Long: "Converts a JSON config to yt-dlp flags. Note that this does not validate the flags. Reads from stdin.", + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, args []string) (err error) { + var in []byte + in, err = io.ReadAll(cmd.InOrStdin()) + if err != nil { + return err + } + + // Manually unmarshal into FlagConfig since JSONToFlags helper is missing + var cfg ytdlp.FlagConfig + if err := json.Unmarshal(in, &cfg); err != nil { + return fmt.Errorf("failed to unmarshal JSON: %w", err) + } + + flags := cfg.ToFlags() + for _, flag := range flags { + fmt.Fprintln(cmd.OutOrStdout(), flag) + } + return nil + }, + }) + + if err := cli.Execute(); err != nil { + fmt.Fprintln(os.Stderr, "Error:", err) + os.Exit(1) + } +} diff --git a/ytops_client/list_formats_tool.py b/ytops_client/list_formats_tool.py index 8502f2b..7293017 100644 --- a/ytops_client/list_formats_tool.py +++ b/ytops_client/list_formats_tool.py @@ -9,24 +9,13 @@ import re from urllib.parse import urlparse, parse_qs from datetime import datetime, timezone +from .stress_policy import utils as sp_utils + try: import yt_dlp except ImportError: yt_dlp = None -def format_size(b): - """Format size in bytes to human-readable string.""" - if b is None: - return 'N/A' - if b < 1024: - return f"{b}B" - elif b < 1024**2: - return f"{b/1024:.2f}KiB" - elif b < 1024**3: - return f"{b/1024**2:.2f}MiB" - else: - return f"{b/1024**3:.2f}GiB" - def list_formats(info_json, requested_formats_str=None, file=sys.stdout): """Prints a table of available formats from info.json data.""" formats = info_json.get('formats', []) @@ -197,7 +186,7 @@ def list_formats(info_json, requested_formats_str=None, file=sys.stdout): str(fps) if fps else '', str(vcodec)[:18], str(acodec)[:18], - format_size(filesize), + sp_utils.format_size(filesize), f"{tbr:.0f}k" if tbr else 'N/A', partial_url, expire_date, diff --git a/ytops_client/locking_download_emulator_tool.py b/ytops_client/locking_download_emulator_tool.py new file mode 100644 index 0000000..907defe --- /dev/null +++ b/ytops_client/locking_download_emulator_tool.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +""" +Standalone worker tool for the distributed download simulation. +This tool is responsible for the "lock-execute-unlock" workflow for a single +download task based on an info.json file. It's designed to be called by an +orchestrator like `stress_policy_tool.py`. 
+""" + +import argparse +import json +import logging +import os +import sys +import time +from copy import deepcopy + +try: + from dotenv import load_dotenv +except ImportError: + load_dotenv = None + +# Temporarily add project root to path to allow importing from sibling packages +script_dir = os.path.dirname(os.path.abspath(__file__)) +project_root = os.path.abspath(os.path.join(script_dir, '..')) +if project_root not in sys.path: + sys.path.insert(0, project_root) + +from ytops_client.profile_manager_tool import ProfileManager +from ytops_client.stress_policy import utils as sp_utils +from ytops_client.stress_policy.state_manager import StateManager +from ytops_client.stress_policy.utils import load_policy, apply_overrides +from ytops_client.stress_policy.workers import _run_download_logic +from ytops_client.stress_policy_tool import shutdown_event + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + + +def add_locking_download_emulator_parser(subparsers): + """Adds the parser for the 'download-emulator' command.""" + parser = subparsers.add_parser( + 'download-emulator', + help='(Internal) Standalone download worker.', + description='Internal tool to run a single download task with profile locking. Not intended for direct user invocation.' + ) + # Since this is an internal tool, we expect one subcommand. + download_subparsers = parser.add_subparsers(dest='download_emulator_command', help='Action to perform', required=True) + + run_parser = download_subparsers.add_parser( + 'lock-and-run', + help='Lock a profile, run a download, and unlock it.', + formatter_class=argparse.RawTextHelpFormatter + ) + run_parser.add_argument('--policy-file', required=True, help='Path to the YAML policy file.') + run_parser.add_argument('--info-json-path', required=True, help='Path to the info.json file to process.') + run_parser.add_argument('--set', action='append', default=[], help="Override a policy setting using 'key.subkey=value'.") + + # Redis connection arguments, to be passed from the orchestrator + redis_group = run_parser.add_argument_group('Redis Connection') + redis_group.add_argument('--env-file', help='Path to a .env file.') + redis_group.add_argument('--redis-host', help='Redis host.') + redis_group.add_argument('--redis-port', type=int, help='Redis port.') + redis_group.add_argument('--redis-password', help='Redis password.') + redis_group.add_argument('--env', help="Environment name for Redis key prefix.") + redis_group.add_argument('--key-prefix', help='Explicit key prefix for Redis.') + + run_parser.add_argument('--verbose', action='store_true', help='Enable verbose logging.') + + +def main_locking_download_emulator(args): + """Main logic for the 'download-emulator' tool.""" + if args.verbose: + logging.getLogger().setLevel(logging.DEBUG) + + # --- Load Policy --- + policy = load_policy(args.policy_file) + if not policy: + return 1 + policy = apply_overrides(policy, args.set) + + # --- Load .env if specified --- + if load_dotenv and args.env_file: + if load_dotenv(args.env_file): + logger.info(f"Loaded environment variables from {args.env_file}") + else: + logger.error(f"Specified --env-file was not found: {args.env_file}") + return 1 + + # --- Setup ProfileManager --- + redis_host = args.redis_host or os.getenv('REDIS_HOST', 'localhost') + redis_port = args.redis_port or int(os.getenv('REDIS_PORT', 6379)) + redis_password = args.redis_password or 
os.getenv('REDIS_PASSWORD') + + if args.key_prefix: + key_prefix = args.key_prefix + elif args.env: + key_prefix = f"{args.env}_profile_mgmt_" + else: + logger.error("Must provide --env or --key-prefix for Redis connection.") + return 1 + + manager = ProfileManager( + redis_host=redis_host, + redis_port=redis_port, + redis_password=redis_password, + key_prefix=key_prefix + ) + + download_policy = policy.get('download_policy', {}) + profile_prefix = download_policy.get('profile_prefix') + if not profile_prefix: + logger.error("Policy file must specify 'download_policy.profile_prefix'.") + return 1 + + # --- Main Lock-Execute-Unlock Logic --- + owner_id = f"dl-emulator-{os.getpid()}" + locked_profile = None + lock_attempts = 0 + + try: + # --- 1. Lock a profile (with wait & backoff) --- + while not shutdown_event.is_set(): + locked_profile = manager.lock_profile(owner=owner_id, profile_prefix=profile_prefix) + if locked_profile: + logger.info(f"Locked profile '{locked_profile['name']}' with proxy '{locked_profile['proxy']}'.") + break + + # Simplified wait logic from stress_policy_tool + backoff_seconds = [3, 5, 9, 20, 50] + sleep_duration = backoff_seconds[min(lock_attempts, len(backoff_seconds) - 1)] + logger.info(f"No download profiles available. Waiting {sleep_duration}s... (attempt {lock_attempts + 1})") + time.sleep(sleep_duration) + lock_attempts += 1 + + if not locked_profile: + logger.warning("Could not lock a profile; shutting down.") + return 1 + + # --- 2. Read info.json --- + try: + with open(args.info_json_path, 'r', encoding='utf-8') as f: + info_json_content = f.read() + except (IOError, FileNotFoundError) as e: + logger.error(f"Could not read info.json file '{args.info_json_path}': {e}") + return 1 + + # --- 3. Execute download logic --- + # The locked profile's proxy MUST be used for the download. + local_policy = deepcopy(policy) + local_policy.setdefault('download_policy', {})['proxy'] = locked_profile['proxy'] + + # The StateManager is used by _run_download_logic for rate limiting and cooldowns, + # but for this standalone worker, we don't need its persistence features. + # We disable log writing to prevent creating state files. + dummy_state_manager = StateManager(policy_name="locking_emulator_run", disable_log_writing=True) + + results = _run_download_logic( + source=args.info_json_path, + info_json_content=info_json_content, + policy=local_policy, + state_manager=dummy_state_manager, + args=args, # Pass orchestrator args through + profile_name=locked_profile['name'], + profile_manager_instance=manager + ) + + # --- 4. Record overall task activity --- + # Note: Download-specific activity ('download'/'download_error') is recorded + # inside _run_download_logic -> process_info_json_cycle. + download_success = all(r.get('success') for r in results) if results else False + activity_type = 'success' if download_success else 'failure' + manager.record_activity(locked_profile['name'], activity_type) + + logger.info(f"Finished processing '{sp_utils.get_display_name(args.info_json_path)}' with profile '{locked_profile['name']}'. Overall success: {download_success}") + + return 0 if download_success else 1 + + except Exception as e: + logger.error(f"An unexpected error occurred: {e}", exc_info=True) + return 1 + finally: + # --- 5. 
Unlock the profile --- + if locked_profile: + cooldown_str = manager.get_config('unlock_cooldown_seconds') + cooldown_seconds = int(cooldown_str) if cooldown_str and cooldown_str.isdigit() else None + + if cooldown_seconds and cooldown_seconds > 0: + logger.info(f"Unlocking profile '{locked_profile['name']}' with a {cooldown_seconds}s cooldown.") + manager.unlock_profile(locked_profile['name'], owner=owner_id, rest_for_seconds=cooldown_seconds) + else: + logger.info(f"Unlocking profile '{locked_profile['name']}'.") + manager.unlock_profile(locked_profile['name'], owner=owner_id) + +if __name__ == '__main__': + # This is a simplified parser setup for direct execution, + # the real one is defined in `add_..._parser` for the main CLI. + parser = argparse.ArgumentParser() + subparsers = parser.add_subparsers(dest='command') + add_locking_download_emulator_parser(subparsers) + args = parser.parse_args() + + if hasattr(args, 'download_emulator_command'): + sys.exit(main_locking_download_emulator(args)) + else: + parser.print_help() diff --git a/ytops_client/manage_tool.py b/ytops_client/manage_tool.py new file mode 100644 index 0000000..f275a3a --- /dev/null +++ b/ytops_client/manage_tool.py @@ -0,0 +1,891 @@ +#!/usr/bin/env python3 +""" +Tool for managing the ytdlp-ops-server via its Thrift management interface. +""" + +import argparse +import json +import logging +import os +import re +import sys +import time +from datetime import datetime + +try: + from dotenv import load_dotenv +except ImportError: + load_dotenv = None + +try: + from tabulate import tabulate +except ImportError: + print("'tabulate' library not found. Please install it with: pip install tabulate", file=sys.stderr) + sys.exit(1) + +try: + import redis +except ImportError: + print("'redis' library not found. Please install it with: pip install redis", file=sys.stderr) + sys.exit(1) + + +# Add project's thrift gen_py path +script_dir = os.path.dirname(os.path.abspath(__file__)) +project_root = os.path.abspath(os.path.join(script_dir, '..')) +sys.path.insert(0, os.path.join(project_root, 'thrift_model', 'gen_py')) + +try: + from yt_ops_services.client_utils import get_thrift_client, format_timestamp + from pangramia.yt.exceptions.ttypes import PBServiceException, PBUserException + from .profile_manager_tool import ProfileManager, format_duration + from .stress_policy import utils as sp_utils +except ImportError: + print("Could not import Thrift modules. 
Ensure this script is run in an environment where 'yt_ops_services' is installed.", file=sys.stderr) + sys.exit(1) + +logger = logging.getLogger('manage_tool') + +# --- Helper Functions (adapted from regression.py) --- + +def _get_redis_client(redis_host, redis_port, redis_password): + """Gets a Redis client.""" + if not redis_host: + return None + try: + client = redis.Redis(host=redis_host, port=redis_port, password=redis_password, decode_responses=True) + client.ping() + logger.info(f"Successfully connected to Redis at {client.connection_pool.connection_kwargs.get('host')}:{client.connection_pool.connection_kwargs.get('port')}") + return client + except redis.exceptions.ConnectionError as e: + logger.error(f"Failed to connect to Redis: {e}") + return None + except Exception as e: + logger.error(f"An unexpected error occurred while connecting to Redis: {e}") + return None + +def _list_proxy_statuses(client, server_identity=None, auth_manager=None, download_manager=None): + """Lists proxy statuses by connecting directly to the Thrift service.""" + print("\n--- Proxy Statuses ---") + try: + statuses = client.getProxyStatus(server_identity) + if not statuses: + print("No proxy statuses found.") + return + + # Enrich with Redis data from simulation environments + all_proxy_urls = [s.proxyUrl for s in statuses] + auth_proxy_states, download_proxy_states = {}, {} + auth_work_minutes, download_work_minutes = 0, 0 + + if auth_manager: + auth_proxy_states = auth_manager.get_proxy_states(all_proxy_urls) + work_minutes_str = auth_manager.get_config('proxy_work_minutes') + if work_minutes_str and work_minutes_str.isdigit(): + auth_work_minutes = int(work_minutes_str) + + if download_manager: + download_proxy_states = download_manager.get_proxy_states(all_proxy_urls) + work_minutes_str = download_manager.get_config('proxy_work_minutes') + if work_minutes_str and work_minutes_str.isdigit(): + download_work_minutes = int(work_minutes_str) + + status_list = [] + # This is forward-compatible: it checks for new attributes before using them. 
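
The forward-compatibility check below relies on hasattr/getattr so that responses from older servers, whose ProxyStatus structs lack recentAccounts and recentMachines, still render. A standalone variation of that pattern is sketched here; SimpleNamespace stands in for the generated Thrift struct, and extended_columns is an illustrative name.

from types import SimpleNamespace

def extended_columns(status):
    """Return the optional columns only when the (possibly older) server sent them."""
    if not (hasattr(status, 'recentAccounts') or hasattr(status, 'recentMachines')):
        return {}
    accounts = getattr(status, 'recentAccounts', []) or []
    machines = getattr(status, 'recentMachines', []) or []
    return {
        "Recent Accounts": "\n".join(accounts) if accounts else "N/A",
        "Recent Machines": "\n".join(machines) if machines else "N/A",
    }

old_style = SimpleNamespace(proxyUrl='socks5://p1:1080', status='ACTIVE')
new_style = SimpleNamespace(proxyUrl='socks5://p2:1080', status='ACTIVE',
                            recentAccounts=['acct-01', 'acct-02'], recentMachines=[])
print(extended_columns(old_style))   # {}  (columns are simply omitted)
print(extended_columns(new_style))   # {'Recent Accounts': 'acct-01\nacct-02', 'Recent Machines': 'N/A'}
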
+ has_extended_info = hasattr(statuses[0], 'recentAccounts') or hasattr(statuses[0], 'recentMachines') + + headers = ["Server", "Proxy URL", "Status", "Success", "Failures", "Last Success", "Last Failure"] + if auth_manager: headers.append("Auth State") + if download_manager: headers.append("Download State") + if has_extended_info: + headers.extend(["Recent Accounts", "Recent Machines"]) + + for s in statuses: + status_item = { + "Server": s.serverIdentity, + "Proxy URL": s.proxyUrl, + "Status": s.status, + "Success": s.successCount, + "Failures": s.failureCount, + "Last Success": format_timestamp(s.lastSuccessTimestamp), + "Last Failure": format_timestamp(s.lastFailureTimestamp), + } + + now = time.time() + if auth_manager: + state_data = auth_proxy_states.get(s.proxyUrl, {}) + state = state_data.get('state', 'N/A') + rest_until = state_data.get('rest_until', 0) + work_start = state_data.get('work_start_timestamp', 0) + + state_str = state + if state == 'RESTING' and rest_until > now: + state_str += f"\n(ends in {format_duration(rest_until - now)})" + elif state == 'ACTIVE' and work_start > 0 and auth_work_minutes > 0: + work_end_time = work_start + (auth_work_minutes * 60) + if work_end_time > now: + state_str += f"\n(ends in {format_duration(work_end_time - now)})" + status_item["Auth State"] = state_str + + if download_manager: + state_data = download_proxy_states.get(s.proxyUrl, {}) + state = state_data.get('state', 'N/A') + rest_until = state_data.get('rest_until', 0) + work_start = state_data.get('work_start_timestamp', 0) + + state_str = state + if state == 'RESTING' and rest_until > now: + state_str += f"\n(ends in {format_duration(rest_until - now)})" + elif state == 'ACTIVE' and work_start > 0 and download_work_minutes > 0: + work_end_time = work_start + (download_work_minutes * 60) + if work_end_time > now: + state_str += f"\n(ends in {format_duration(work_end_time - now)})" + status_item["Download State"] = state_str + + if has_extended_info: + recent_accounts = getattr(s, 'recentAccounts', []) + recent_machines = getattr(s, 'recentMachines', []) + status_item["Recent Accounts"] = "\n".join(recent_accounts) if recent_accounts else "N/A" + status_item["Recent Machines"] = "\n".join(recent_machines) if recent_machines else "N/A" + status_list.append(status_item) + + # Stabilize column widths to reduce jitter in --watch mode + if status_list: + headers = list(status_list[0].keys()) + table_data = [[item.get(h) for h in headers] for item in status_list] + + # Calculate max width for each column based on its content, accounting for multi-line cells. 
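
The width calculation implemented just below measures the widest line of each multi-line cell and enforces a minimum for the variable columns, so the value passed as tabulate's maxcolwidths stays constant across --watch refreshes. A small self-contained version of that computation, with illustrative helper names:

def cell_width(cell):
    """Visible width of a cell that may contain embedded newlines."""
    return max((len(line) for line in str(cell).split('\n')), default=0) if cell else 0

def column_widths(headers, rows, minimums=None):
    """Per-column content width with optional per-header floors."""
    minimums = minimums or {}
    columns = list(zip(*([headers] + rows)))
    widths = [max(cell_width(cell) for cell in col) for col in columns]
    for header, floor in minimums.items():
        if header in headers:
            idx = headers.index(header)
            widths[idx] = max(floor, widths[idx])
    return widths

headers = ["Proxy URL", "Status", "Recent Accounts"]
rows = [["socks5://p1:1080", "ACTIVE", "acct-01\nacct-02"],
        ["socks5://p2:1080", "RESTING\n(ends in 3m)", "N/A"]]
print(column_widths(headers, rows, minimums={"Recent Accounts": 25}))
# -> [16, 12, 25]
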
+ columns = list(zip(*([headers] + table_data))) + maxwidths = [] + for col in columns: + max_w = 0 + for cell in col: + cell_w = max((len(line) for line in str(cell).split('\n')), default=0) if cell else 0 + if cell_w > max_w: + max_w = cell_w + maxwidths.append(max_w) + + # Enforce a minimum width for columns that can have variable content + MIN_WIDTH = 25 + if "Recent Accounts" in headers: + idx = headers.index("Recent Accounts") + maxwidths[idx] = max(MIN_WIDTH, maxwidths[idx]) + if "Recent Machines" in headers: + idx = headers.index("Recent Machines") + maxwidths[idx] = max(MIN_WIDTH, maxwidths[idx]) + + print(tabulate(table_data, headers=headers, tablefmt='grid', maxcolwidths=maxwidths)) + else: + print(tabulate(status_list, headers='keys', tablefmt='grid')) + + if not has_extended_info: + print("\nNOTE: To see Recent Accounts/Machines, the server's `getProxyStatus` method must be updated to return these fields.") + except (PBServiceException, PBUserException) as e: + logger.error(f"Failed to get proxy statuses: {e.message}") + except Exception as e: + logger.error(f"An unexpected error occurred while getting proxy statuses: {e}", exc_info=True) + +def _list_account_statuses(client, redis_client, account_id=None): + """Lists account statuses from Thrift, optionally enriched with live Redis data.""" + print(f"\n--- Account Statuses ---") + try: + statuses = client.getAccountStatus(accountId=account_id, accountPrefix=None) + if not statuses: + print("No account statuses found.") + return + + status_list = [] + for s in statuses: + status_str = s.status + if redis_client and 'RESTING' in status_str: + try: + status_key = f"account_status:{s.accountId}" + expiry_ts_bytes = redis_client.hget(status_key, "resting_until") + if expiry_ts_bytes: + expiry_ts = float(expiry_ts_bytes) + now = datetime.now().timestamp() + if now >= expiry_ts: + status_str = "ACTIVE (was RESTING)" + else: + remaining_seconds = int(expiry_ts - now) + if remaining_seconds > 3600: + status_str = f"RESTING (active in {remaining_seconds // 3600}h {remaining_seconds % 3600 // 60}m)" + elif remaining_seconds > 60: + status_str = f"RESTING (active in {remaining_seconds // 60}m {remaining_seconds % 60}s)" + else: + status_str = f"RESTING (active in {remaining_seconds}s)" + except Exception as e: + logger.warning(f"Could not parse resting time for {s.accountId} from Redis: {e}. 
Using server status.") + + last_success = float(s.lastSuccessTimestamp) if s.lastSuccessTimestamp else 0 + last_failure = float(s.lastFailureTimestamp) if s.lastFailureTimestamp else 0 + last_activity = max(last_success, last_failure) + + status_list.append({ + "Account ID": s.accountId, "Status": status_str, "Success": s.successCount, + "Failures": s.failureCount, "Last Success": format_timestamp(s.lastSuccessTimestamp), + "Last Failure": format_timestamp(s.lastFailureTimestamp), "Last Proxy": s.lastUsedProxy or "N/A", + "_last_activity": last_activity, + }) + + status_list.sort(key=lambda item: item.get('_last_activity', 0), reverse=True) + for item in status_list: + if '_last_activity' in item: + del item['_last_activity'] + + # Stabilize column widths to reduce jitter in --watch mode + if status_list: + headers = list(status_list[0].keys()) + table_data = [[item.get(h) for h in headers] for item in status_list] + + columns = list(zip(*([headers] + table_data))) + maxwidths = [max((len(str(x)) for x in col), default=0) if col else 0 for col in columns] + + # Enforce a minimum width for the Status column to prevent it from changing size + STATUS_MIN_WIDTH = 30 + if "Status" in headers: + idx = headers.index("Status") + maxwidths[idx] = max(STATUS_MIN_WIDTH, maxwidths[idx]) + + print(tabulate(table_data, headers=headers, tablefmt='grid', maxcolwidths=maxwidths)) + else: + print(tabulate(status_list, headers='keys', tablefmt='grid')) + except (PBServiceException, PBUserException) as e: + logger.error(f"Failed to get account statuses: {e.message}") + except Exception as e: + logger.error(f"An unexpected error occurred while getting account statuses: {e}", exc_info=True) + +def _list_client_statuses(redis_client): + """Lists client statistics from Redis.""" + if not redis_client: + return + print("\n--- Client Statuses (from Redis) ---") + try: + stats_key = "client_stats" + all_stats_raw = redis_client.hgetall(stats_key) + if not all_stats_raw: + print("No client stats found in Redis.") + return + + status_list = [] + for client, stats_json in all_stats_raw.items(): + try: + stats = json.loads(stats_json) + def format_latest(data): + if not data: return "N/A" + ts = format_timestamp(data.get('timestamp')) + url = data.get('url', 'N/A') + video_id = sp_utils.get_video_id(url) + if video_id == "unknown_video_id": + video_id = "N/A" + return f"{ts} ({video_id})" + + status_list.append({ + "Client": client, "Success": stats.get('success_count', 0), + "Failures": stats.get('failure_count', 0), + "Last Success": format_latest(stats.get('latest_success')), + "Last Failure": format_latest(stats.get('latest_failure')), + }) + except (json.JSONDecodeError, AttributeError): + status_list.append({"Client": client, "Success": "ERROR", "Failures": "ERROR", "Last Success": "Parse Error", "Last Failure": "Parse Error"}) + + status_list.sort(key=lambda item: item.get('Client', '')) + print(tabulate(status_list, headers='keys', tablefmt='grid')) + except Exception as e: + logger.error(f"An unexpected error occurred while getting client statuses: {e}", exc_info=True) + + +def _list_activity_counters(redis_client): + """Lists current activity rates for proxies and accounts from Redis.""" + if not redis_client: + print("\n--- Activity Counters ---\nRedis is not configured. 
Cannot show activity counters.\n---------------------------\n") + return + + print("\n--- Activity Counters ---") + now = time.time() + + def process_keys(pattern, entity_name): + try: + keys = redis_client.scan_iter(pattern) + except redis.exceptions.RedisError as e: + logger.error(f"Redis error scanning for keys with pattern '{pattern}': {e}") + return + + status_list = [] + for key in keys: + entity_id = key.split(':', 2)[-1] + + try: + + count_1m = redis_client.zcount(key, now - 60, now) + count_5m = redis_client.zcount(key, now - 300, now) + count_1h = redis_client.zcount(key, now - 3600, now) + + if count_1h == 0: # Don't show entities with no recent activity + continue + + status_list.append({ + entity_name: entity_id, + "Activity (Last 1m)": count_1m, + "Activity (Last 5m)": count_5m, + "Activity (Last 1h)": count_1h, + }) + except redis.exceptions.RedisError as e: + logger.error(f"Redis error processing key '{key}': {e}") + + status_list.sort(key=lambda item: item.get(entity_name, '')) + + print(f"\n--- {entity_name} Activity Counters ---") + if not status_list: + print(f"No recent activity found for {entity_name.lower()}s.") + else: + print(f"\n{tabulate(status_list, headers='keys', tablefmt='grid')}") + print("-----------------------------------\n") + + try: + process_keys("activity:per_proxy:*", "Proxy URL") + process_keys("activity:per_account:*", "Account ID") + except Exception as e: + logger.error(f"An unexpected error occurred while getting activity counters: {e}", exc_info=True) + print(f"\nERROR: An unexpected error occurred: {e}\n") + + +def get_system_status(args): + """Connects to services and prints status tables.""" + logger.info("--- Getting System Status ---") + client, transport = None, None + redis_client = _get_redis_client(args.redis_host, args.redis_port, args.redis_password) + + auth_manager, download_manager = None, None + + def _create_manager(env_name): + if not env_name or not args.redis_host: return None + if args.key_prefix: + key_prefix = args.key_prefix + else: + key_prefix = f"{env_name}_profile_mgmt_" + return ProfileManager(args.redis_host, args.redis_port, args.redis_password, key_prefix) + + # Precedence: --auth-env > --env + auth_env_to_use = args.auth_env or args.env + if auth_env_to_use: + auth_manager = _create_manager(auth_env_to_use) + + # Precedence: --download-env > --env + download_env_to_use = args.download_env or args.env + if download_env_to_use: + # If it's the same env, reuse the manager instance + if download_env_to_use == auth_env_to_use and auth_manager: + download_manager = auth_manager + else: + download_manager = _create_manager(download_env_to_use) + + try: + client, transport = get_thrift_client(args.host, args.port) + _list_proxy_statuses(client, args.server_identity, auth_manager=auth_manager, download_manager=download_manager) + _list_account_statuses(client, redis_client) + _list_client_statuses(redis_client) + except Exception as e: + logger.error(f"Could not get system status: {e}") + finally: + if transport and transport.isOpen(): + transport.close() + +def main_activity_counters(args): + """Main logic for the 'activity-counters' command.""" + redis_client = _get_redis_client(args.redis_host, args.redis_port, args.redis_password) + _list_activity_counters(redis_client) + return 0 + +def main_status(args): + """Main logic for the 'status' command.""" + if not args.watch: + get_system_status(args) + return 0 + + try: + while True: + os.system('cls' if os.name == 'nt' else 'clear') + print(f"--- System Status 
(auto-refreshing every {args.watch} seconds, press Ctrl+C to exit) ---") + print(f"--- Last updated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ---") + get_system_status(args) + + for i in range(args.watch, 0, -1): + sys.stdout.write(f"\rRefreshing in {i} seconds... ") + sys.stdout.flush() + time.sleep(1) + # Clear the countdown line before the next refresh + sys.stdout.write("\r" + " " * 30 + "\r") + sys.stdout.flush() + + except KeyboardInterrupt: + print("\nStopping status watch.") + return 0 + +def main_ban_proxy(args): + """Main logic for the 'ban-proxy' command.""" + client, transport = None, None + try: + client, transport = get_thrift_client(args.host, args.port) + success = client.banProxy(args.proxy_url, args.server_identity) + if success: + print(f"Successfully banned proxy '{args.proxy_url}' for server '{args.server_identity}'.") + return 0 + else: + logger.error("Server reported failure to ban proxy.") + return 1 + except (PBServiceException, PBUserException) as e: + logger.error(f"Failed to ban proxy: {e.message}") + return 1 + finally: + if transport and transport.isOpen(): + transport.close() + +def main_unban_proxy(args): + """Main logic for the 'unban-proxy' command.""" + client, transport = None, None + try: + client, transport = get_thrift_client(args.host, args.port) + success = client.unbanProxy(args.proxy_url, args.server_identity) + if success: + print(f"Successfully unbanned proxy '{args.proxy_url}' for server '{args.server_identity}'.") + return 0 + else: + logger.error("Server reported failure to unban proxy.") + return 1 + except (PBServiceException, PBUserException) as e: + logger.error(f"Failed to unban proxy: {e.message}") + return 1 + finally: + if transport and transport.isOpen(): + transport.close() + +def main_reset_proxies(args): + """Main logic for the 'reset-proxies' command.""" + client, transport = None, None + try: + client, transport = get_thrift_client(args.host, args.port) + success = client.resetAllProxyStatuses(args.server_identity) + if success: + print(f"Successfully reset all proxy statuses for server '{args.server_identity}'.") + return 0 + else: + logger.error("Server reported failure to reset proxies.") + return 1 + except (PBServiceException, PBUserException) as e: + logger.error(f"Failed to reset proxies: {e.message}") + return 1 + finally: + if transport and transport.isOpen(): + transport.close() + + +def main_ban_all_proxies(args): + """Main logic for the 'ban-all-proxies' command.""" + client, transport = None, None + try: + client, transport = get_thrift_client(args.host, args.port) + if args.server_identity: + print(f"Banning all proxies for server '{args.server_identity}'...") + client.banAllProxies(args.server_identity) + print(f"Successfully sent request to ban all proxies for '{args.server_identity}'.") + else: + print("Banning all proxies for ALL servers...") + all_statuses = client.getProxyStatus(None) + if not all_statuses: + print("No proxies found for any server. Nothing to ban.") + return 0 + + all_server_identities = sorted(list(set(s.serverIdentity for s in all_statuses))) + print(f"Found {len(all_server_identities)} server identities. 
Sending ban request for each...") + + success_count, fail_count = 0, 0 + for identity in all_server_identities: + try: + client.banAllProxies(identity) + print(f" - Sent ban_all for '{identity}'.") + success_count += 1 + except Exception as e: + logger.error(f" - Failed to ban all proxies for '{identity}': {e}") + fail_count += 1 + + print(f"\nSuccessfully sent ban_all requests for {success_count} server identities.") + if fail_count > 0: + print(f"Failed to send ban_all requests for {fail_count} server identities. See logs for details.") + return 0 + except (PBServiceException, PBUserException) as e: + logger.error(f"Failed to ban all proxies: {e.message}") + return 1 + finally: + if transport and transport.isOpen(): + transport.close() + + +def main_unban_all_proxies(args): + """Main logic for the 'unban-all-proxies' command.""" + client, transport = None, None + try: + client, transport = get_thrift_client(args.host, args.port) + if args.server_identity: + print(f"Unbanning all proxies for server '{args.server_identity}'...") + client.resetAllProxyStatuses(args.server_identity) + print(f"Successfully sent request to unban all proxies for '{args.server_identity}'.") + else: + print("Unbanning all proxies for ALL servers...") + all_statuses = client.getProxyStatus(None) + if not all_statuses: + print("No proxies found for any server. Nothing to unban.") + return 0 + + all_server_identities = sorted(list(set(s.serverIdentity for s in all_statuses))) + print(f"Found {len(all_server_identities)} server identities. Sending unban request for each...") + + success_count, fail_count = 0, 0 + for identity in all_server_identities: + try: + client.resetAllProxyStatuses(identity) + print(f" - Sent unban_all for '{identity}'.") + success_count += 1 + except Exception as e: + logger.error(f" - Failed to unban all proxies for '{identity}': {e}") + fail_count += 1 + + print(f"\nSuccessfully sent unban_all requests for {success_count} server identities.") + if fail_count > 0: + print(f"Failed to send unban_all requests for {fail_count} server identities. 
See logs for details.") + + return 0 + except (PBServiceException, PBUserException) as e: + logger.error(f"Failed to unban all proxies: {e.message}") + return 1 + finally: + if transport and transport.isOpen(): + transport.close() + + +def main_ban_account(args): + """Main logic for the 'ban-account' command.""" + client, transport = None, None + try: + client, transport = get_thrift_client(args.host, args.port) + reason = f"Manual ban from yt-ops-client by {os.getlogin() if hasattr(os, 'getlogin') else 'unknown_user'}" + client.banAccount(accountId=args.account_id, reason=reason) + print(f"Successfully sent request to ban account '{args.account_id}'.") + return 0 + except (PBServiceException, PBUserException) as e: + logger.error(f"Failed to ban account: {e.message}") + return 1 + finally: + if transport and transport.isOpen(): + transport.close() + + +def main_unban_account(args): + """Main logic for the 'unban-account' command.""" + client, transport = None, None + redis_client = _get_redis_client(args.redis_host, args.redis_port, args.redis_password) + if not redis_client: + logger.error("Redis connection is required to correctly unban an account (to reset success_count_at_activation).") + return 1 + try: + client, transport = get_thrift_client(args.host, args.port) + reason = f"Manual un-ban from yt-ops-client by {os.getlogin() if hasattr(os, 'getlogin') else 'unknown_user'}" + + statuses = client.getAccountStatus(accountId=args.account_id, accountPrefix=None) + if not statuses: + logger.error(f"Account '{args.account_id}' not found.") + return 1 + current_success_count = statuses[0].successCount or 0 + + client.unbanAccount(accountId=args.account_id, reason=reason) + print(f"Successfully sent request to unban account '{args.account_id}'.") + + redis_client.hset(f"account_status:{args.account_id}", "success_count_at_activation", current_success_count) + print(f"Set 'success_count_at_activation' for '{args.account_id}' to {current_success_count}.") + + return 0 + except (PBServiceException, PBUserException) as e: + logger.error(f"Failed to unban account: {e.message}") + return 1 + finally: + if transport and transport.isOpen(): + transport.close() + + +def main_unban_all_accounts(args): + """Main logic for the 'unban-all-accounts' command.""" + client, transport = None, None + redis_client = _get_redis_client(args.redis_host, args.redis_port, args.redis_password) + if not redis_client: + logger.error("Redis connection is required to correctly unban accounts.") + return 1 + + try: + client, transport = get_thrift_client(args.host, args.port) + account_prefix = args.account_id # can be prefix + + print(f"Unbanning all accounts (prefix: '{account_prefix or 'ALL'}')...") + all_statuses = client.getAccountStatus(accountId=None, accountPrefix=account_prefix) + if not all_statuses: + print(f"No accounts found with prefix '{account_prefix or 'ALL'}' to unban.") + return 0 + + accounts_to_unban = [s.accountId for s in all_statuses] + account_map = {s.accountId: s for s in all_statuses} + print(f"Found {len(accounts_to_unban)} accounts. 
Sending unban request for each...") + + unban_count, fail_count = 0, 0 + reason = f"Manual unban_all from yt-ops-client by {os.getlogin() if hasattr(os, 'getlogin') else 'unknown_user'}" + + for acc_id in accounts_to_unban: + try: + client.unbanAccount(accountId=acc_id, reason=reason) + current_success_count = account_map[acc_id].successCount or 0 + redis_client.hset(f"account_status:{acc_id}", "success_count_at_activation", current_success_count) + unban_count += 1 + except Exception as e: + logger.error(f" - Failed to unban account '{acc_id}': {e}") + fail_count += 1 + + print(f"\nSuccessfully sent unban requests for {unban_count} accounts.") + if fail_count > 0: + print(f"Failed to send unban requests for {fail_count} accounts. See logs for details.") + + return 0 + except (PBServiceException, PBUserException) as e: + logger.error(f"Failed to unban all accounts: {e.message}") + return 1 + finally: + if transport and transport.isOpen(): + transport.close() + + +def main_delete(args): + """Main logic for 'delete' commands.""" + if not args.yes: + print("This is a destructive action. Use --yes to confirm.", file=sys.stderr) + return 1 + + client, transport = None, None + try: + # For Redis-only actions, we don't need a Thrift client. + if args.delete_entity not in ['client-stats', 'activity-counters']: + client, transport = get_thrift_client(args.host, args.port) + + if args.delete_entity == 'proxies': + if args.proxy_url and args.server_identity: + print(f"Deleting proxy '{args.proxy_url}' for server '{args.server_identity}'...") + result = client.deleteProxyFromRedis(args.proxy_url, args.server_identity) + if result: + print(f"Successfully deleted proxy '{args.proxy_url}' for server '{args.server_identity}' from Redis.") + else: + print(f"Failed to delete proxy '{args.proxy_url}'. It may not have existed.") + else: + print(f"Deleting all proxies from Redis (server filter: {args.server_identity or 'ALL'})...") + result = client.deleteAllProxiesFromRedis(args.server_identity) + print(f"Successfully deleted {result} proxy key(s) from Redis.") + + elif args.delete_entity == 'accounts': + if args.account_id: + if args.prefix: + print(f"Deleting accounts with prefix '{args.account_id}' from Redis...") + result = client.deleteAllAccountsFromRedis(args.account_id) + print(f"Successfully deleted {result} account(s) with prefix '{args.account_id}' from Redis.") + else: + print(f"Deleting account '{args.account_id}' from Redis...") + result = client.deleteAccountFromRedis(args.account_id) + if result: + print(f"Successfully deleted account '{args.account_id}' from Redis.") + else: + print(f"Failed to delete account '{args.account_id}'. It may not have existed.") + else: # Delete all + print("Deleting ALL accounts from Redis...") + result = client.deleteAllAccountsFromRedis(None) + print(f"Successfully deleted {result} account(s) from Redis.") + + elif args.delete_entity == 'client-stats': + redis_client = _get_redis_client(args.redis_host, args.redis_port, args.redis_password) + if not redis_client: + logger.error("Redis connection is required to delete client stats.") + return 1 + print("Deleting all client stats from Redis...") + result = redis_client.delete("client_stats") + if result > 0: + print("Successfully deleted 'client_stats' key from Redis.") + else: + print("Key 'client_stats' not found in Redis. 
Nothing to delete.") + + elif args.delete_entity == 'activity-counters': + redis_client = _get_redis_client(args.redis_host, args.redis_port, args.redis_password) + if not redis_client: + logger.error("Redis connection is required to delete activity counters.") + return 1 + print("Deleting all activity counters from Redis...") + + deleted_count = 0 + for pattern in ["activity:per_proxy:*", "activity:per_account:*"]: + keys_to_delete_chunk = [] + for key in redis_client.scan_iter(pattern): + keys_to_delete_chunk.append(key) + if len(keys_to_delete_chunk) >= 500: + deleted_count += redis_client.delete(*keys_to_delete_chunk) + keys_to_delete_chunk = [] + if keys_to_delete_chunk: + deleted_count += redis_client.delete(*keys_to_delete_chunk) + + if deleted_count > 0: + print(f"Successfully deleted {deleted_count} activity counter keys from Redis.") + else: + print("No activity counter keys found to delete.") + + return 0 + + except (PBServiceException, PBUserException) as e: + logger.error(f"Thrift error performing delete action: {e.message}") + return 1 + finally: + if transport and transport.isOpen(): + transport.close() + + +def add_manage_parser(subparsers): + """Add the parser for the 'manage' command.""" + parser = subparsers.add_parser( + 'manage', + description='Manage the ytdlp-ops-server.', + formatter_class=argparse.RawTextHelpFormatter, + help='Manage the ytdlp-ops-server.' + ) + + # Common arguments for all manage subcommands + common_parser = argparse.ArgumentParser(add_help=False) + common_parser.add_argument('--env-file', help='Path to a .env file to load environment variables from.') + common_parser.add_argument('--host', default=None, help="Thrift management server host. Defaults to MASTER_HOST_IP env var or 127.0.0.1.") + common_parser.add_argument('--port', type=int, default=9090, help='Thrift management server port.') + common_parser.add_argument('--redis-host', default=None, help='Redis host for client stats. Defaults to REDIS_HOST env var.') + common_parser.add_argument('--redis-port', type=int, default=None, help='Redis port. Defaults to REDIS_PORT env var or 6379.') + common_parser.add_argument('--redis-password', default=None, help='Redis password. Defaults to REDIS_PASSWORD env var.') + common_parser.add_argument('--verbose', action='store_true', help='Enable verbose output.') + common_parser.add_argument('--env', default=None, help="Default environment name for Redis key prefix. Used if --auth-env or --download-env are not specified.") + common_parser.add_argument('--auth-env', help="Environment for the Auth simulation to enrich status from.") + common_parser.add_argument('--download-env', help="Environment for the Download simulation to enrich status from.") + common_parser.add_argument('--key-prefix', default=None, help='Explicit key prefix for Redis. Overrides --env and any defaults.') + + manage_subparsers = parser.add_subparsers(dest='manage_command', help='Available management commands', required=True) + + # --- Status and Listing Commands --- + status_parser = manage_subparsers.add_parser('status', help='View system status.', parents=[common_parser]) + status_parser.add_argument('--server-identity', help='Filter status for a specific server identity.') + status_parser.add_argument('--watch', type=int, nargs='?', const=5, help='Periodically refresh status every N seconds. 
Default: 5.') + + activity_parser = manage_subparsers.add_parser('activity-counters', help='View current activity rates for proxies and accounts.', parents=[common_parser]) + + # --- Proxy Management Commands --- + ban_proxy_parser = manage_subparsers.add_parser('ban-proxy', help='Ban a proxy.', parents=[common_parser]) + ban_proxy_parser.add_argument('proxy_url', help='The full proxy URL to ban (e.g., "socks5://host:port").') + ban_proxy_parser.add_argument('--server-identity', required=True, help='The server identity for which to ban the proxy.') + + unban_proxy_parser = manage_subparsers.add_parser('unban-proxy', help='Unban a proxy.', parents=[common_parser]) + unban_proxy_parser.add_argument('proxy_url', help='The full proxy URL to unban.') + unban_proxy_parser.add_argument('--server-identity', required=True, help='The server identity for which to unban the proxy.') + + ban_all_proxies_parser = manage_subparsers.add_parser('ban-all-proxies', help='Ban all proxies for one or all servers.', parents=[common_parser]) + ban_all_proxies_parser.add_argument('--server-identity', help='Optional server identity to ban all proxies for. If omitted, bans for all servers.') + + unban_all_proxies_parser = manage_subparsers.add_parser('unban-all-proxies', help='Unban all proxies for one or all servers.', parents=[common_parser]) + unban_all_proxies_parser.add_argument('--server-identity', help='Optional server identity to unban all proxies for. If omitted, unbans for all servers.') + + # --- Account Management Commands --- + ban_account_parser = manage_subparsers.add_parser('ban-account', help='Ban an account.', parents=[common_parser]) + ban_account_parser.add_argument('account_id', help='The account ID to ban.') + + unban_account_parser = manage_subparsers.add_parser('unban-account', help='Unban an account.', parents=[common_parser]) + unban_account_parser.add_argument('account_id', help='The account ID to unban.') + + unban_all_accounts_parser = manage_subparsers.add_parser('unban-all-accounts', help='Unban all accounts, optionally filtered by a prefix.', parents=[common_parser]) + unban_all_accounts_parser.add_argument('account_id', nargs='?', help='Optional account prefix to filter which accounts to unban.') + + # --- Destructive Delete Commands --- + delete_parser = manage_subparsers.add_parser('delete', help='(Destructive) Delete entities from Redis.') + delete_subparsers = delete_parser.add_subparsers(dest='delete_entity', help='Entity to delete', required=True) + + # Create a parent for the confirmation flag, so it can be used on sub-subcommands + confirm_parser = argparse.ArgumentParser(add_help=False) + confirm_parser.add_argument('--yes', action='store_true', help='Confirm the destructive action.') + + delete_proxies_parser = delete_subparsers.add_parser('proxies', help='Delete one or all proxies from Redis.', parents=[common_parser, confirm_parser]) + delete_proxies_parser.add_argument('--proxy-url', help='The proxy URL to delete.') + delete_proxies_parser.add_argument('--server-identity', help='The server identity of the proxy to delete. Required if --proxy-url is given. If omitted, deletes all proxies for all servers.') + + delete_accounts_parser = delete_subparsers.add_parser('accounts', help='Delete one or all accounts from Redis.', parents=[common_parser, confirm_parser]) + delete_accounts_parser.add_argument('--account-id', help='The account ID to delete. 
If --prefix is used, this is treated as a prefix.') + delete_accounts_parser.add_argument('--prefix', action='store_true', help='Treat --account-id as a prefix to delete multiple accounts.') + + delete_client_stats_parser = delete_subparsers.add_parser('client-stats', help='Delete all client stats from Redis.', parents=[common_parser, confirm_parser]) + + delete_activity_counters_parser = delete_subparsers.add_parser('activity-counters', help='Delete all activity counter stats from Redis.', description='Deletes all activity counter stats (keys matching "activity:*") from Redis. This does NOT delete account or proxy statuses.', parents=[common_parser, confirm_parser]) + + return parser + +def main_manage(args): + """Main dispatcher for 'manage' command.""" + if load_dotenv: + # load_dotenv() with no args will search for a .env file. + # If args.env_file is provided, it will use that specific file. + was_loaded = load_dotenv(args.env_file) + if was_loaded: + logger.info(f"Loaded environment variables from {args.env_file or '.env file'}") + elif args.env_file: + # If a specific file was requested but not found, it's an error. + logger.error(f"The specified --env-file was not found: {args.env_file}") + return 1 + elif args.env_file: + logger.warning("'python-dotenv' is not installed, but --env-file was provided. Please install it with: pip install python-dotenv") + else: + logger.info("'python-dotenv' not installed. Relying on shell environment variables and command-line arguments.") + + + # Set defaults for args that were not provided, now that .env is loaded. + if args.host is None: + args.host = os.getenv('MASTER_HOST_IP', '127.0.0.1') + if args.redis_host is None: + # Default REDIS_HOST to the management host if not explicitly set + args.redis_host = os.getenv('REDIS_HOST', args.host) + if args.redis_port is None: + args.redis_port = int(os.getenv('REDIS_PORT', 6379)) + if args.redis_password is None: + args.redis_password = os.getenv('REDIS_PASSWORD') + + if args.verbose: + logging.getLogger().setLevel(logging.DEBUG) + + # Log the effective connection parameters being used. + logger.info(f"Using management host: {args.host}:{args.port}") + if args.redis_host: + redis_password_status = "provided" if args.redis_password else "not provided" + logger.info(f"Using Redis host: {args.redis_host}:{args.redis_port} (password: {redis_password_status})") + else: + logger.warning("Redis host not configured (via --redis-host or REDIS_HOST env var). 
Redis-dependent features will be unavailable.") + + if args.manage_command == 'status': + return main_status(args) + elif args.manage_command == 'activity-counters': + return main_activity_counters(args) + elif args.manage_command == 'ban-proxy': + return main_ban_proxy(args) + elif args.manage_command == 'unban-proxy': + return main_unban_proxy(args) + elif args.manage_command == 'ban-all-proxies': + return main_ban_all_proxies(args) + elif args.manage_command == 'unban-all-proxies': + return main_unban_all_proxies(args) + elif args.manage_command == 'ban-account': + return main_ban_account(args) + elif args.manage_command == 'unban-account': + return main_unban_account(args) + elif args.manage_command == 'unban-all-accounts': + return main_unban_all_accounts(args) + elif args.manage_command == 'delete': + return main_delete(args) + + return 1 # Should not be reached diff --git a/ytops_client/policy_enforcer_tool.py b/ytops_client/policy_enforcer_tool.py new file mode 100644 index 0000000..4b517ee --- /dev/null +++ b/ytops_client/policy_enforcer_tool.py @@ -0,0 +1,1297 @@ +#!/usr/bin/env python3 +""" +CLI tool to enforce policies on profiles. +""" +import argparse +import json +import logging +import os +import signal +import sys +import time + +try: + import yaml +except ImportError: + print("PyYAML is not installed. Please install it with: pip install PyYAML", file=sys.stderr) + yaml = None + +try: + from dotenv import load_dotenv +except ImportError: + load_dotenv = None + +from .profile_manager_tool import ProfileManager + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + +# Graceful shutdown handler +shutdown_event = False +def handle_shutdown(sig, frame): + global shutdown_event + logger.info("Shutdown signal received. Stopping policy enforcer...") + shutdown_event = True + +class PolicyEnforcer: + def __init__(self, manager, dry_run=False): + self.manager = manager + self.dry_run = dry_run + self.actions_taken_this_cycle = 0 + + PROXY_REST_REASON = "Proxy resting" + + def apply_policies(self, args): + self.actions_taken_this_cycle = 0 + logger.debug(f"Applying policies... (Dry run: {self.dry_run})") + + # --- Phase 1: Policies that don't depend on a consistent profile state snapshot --- + # Manage proxy states and clean up stale locks before we fetch profile states. + self.enforce_proxy_group_rotation(getattr(args, 'proxy_groups', [])) + self.enforce_proxy_work_rest_cycle(args) + self.enforce_max_proxy_active_time(args) + self.enforce_proxy_policies(args) + if args.unlock_stale_locks_after_seconds and args.unlock_stale_locks_after_seconds > 0: + self.enforce_stale_lock_cleanup(args.unlock_stale_locks_after_seconds) + + # --- Phase 2: Policies that require a consistent, shared view of profile states --- + # Fetch all profile states ONCE to create a consistent snapshot for this cycle. + all_profiles_list = self.manager.list_profiles() + all_profiles_map = {p['name']: p for p in all_profiles_list} + + # Apply profile group policies (rotation, max_active). This will modify the local `all_profiles_map`. + self.enforce_profile_group_policies(getattr(args, 'profile_groups', []), all_profiles_map) + + # Un-rest profiles. This also reads from and modifies the local `all_profiles_map`. 
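+        # Both phase-2 policies share the single snapshot taken above, so a state
+        # change made by the group policy in this cycle (e.g. a profile moved to
+        # RESTING) is already visible to the un-rest logic that runs next.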
+ self.enforce_unrest_policy(getattr(args, 'profile_groups', []), all_profiles_map) + + # --- Phase 3: Apply policies to individual active profiles --- + # Use the now-updated snapshot to determine which profiles are active. + active_profiles = [p for p in all_profiles_map.values() if p['state'] == self.manager.STATE_ACTIVE] + + # Filter out profiles that are managed by a profile group, as their state is handled separately. + profile_groups = getattr(args, 'profile_groups', []) + if profile_groups: + grouped_profiles = set() + for group in profile_groups: + if 'profiles' in group: + for p_name in group['profiles']: + grouped_profiles.add(p_name) + elif 'prefix' in group: + prefix = group['prefix'] + for p in all_profiles_list: + if p['name'].startswith(prefix): + grouped_profiles.add(p['name']) + + original_count = len(active_profiles) + active_profiles = [p for p in active_profiles if p['name'] not in grouped_profiles] + if len(active_profiles) != original_count: + logger.debug(f"Filtered out {original_count - len(active_profiles)} profile(s) managed by profile groups.") + + for profile in active_profiles: + # Check for failure burst first, as it's more severe. + # If it's banned, no need to check other rules for it. + if self.enforce_failure_burst_policy(profile, args.ban_on_failures, args.ban_on_failures_window_minutes): + continue + if self.enforce_rate_limit_policy(profile, getattr(args, 'rate_limit_requests', 0), getattr(args, 'rate_limit_window_minutes', 0), getattr(args, 'rate_limit_rest_duration_minutes', 0)): + continue + self.enforce_failure_rate_policy(profile, args.max_failure_rate, args.min_requests_for_rate) + self.enforce_rest_policy(profile, args.rest_after_requests, args.rest_duration_minutes) + + return self.actions_taken_this_cycle > 0 + + def enforce_failure_burst_policy(self, profile, max_failures, window_minutes): + if not max_failures or not window_minutes or max_failures <= 0 or window_minutes <= 0: + return False + + window_seconds = window_minutes * 60 + # Count only fatal error types (auth, download) for the ban policy. + # Tolerated errors are excluded from this check. + error_count = ( + self.manager.get_activity_rate(profile['name'], 'failure', window_seconds) + + self.manager.get_activity_rate(profile['name'], 'download_error', window_seconds) + ) + + if error_count >= max_failures: + reason = f"Error burst detected: {error_count} errors in the last {window_minutes} minute(s) (threshold: {max_failures})" + logger.warning(f"Banning profile '{profile['name']}' due to error burst: {reason}") + if not self.dry_run: + self.manager.update_profile_state(profile['name'], self.manager.STATE_BANNED, reason) + self.actions_taken_this_cycle += 1 + return True # Indicates profile was banned + return False + + def enforce_rate_limit_policy(self, profile, max_requests, window_minutes, rest_duration_minutes): + if not max_requests or not window_minutes or max_requests <= 0 or window_minutes <= 0: + return False + + window_seconds = window_minutes * 60 + # Count all successful activities (auth, download) for rate limiting. + # We don't count failures, as they often don't hit the target server in the same way. 
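+        # Illustrative numbers: with --rate-limit-requests 60 and
+        # --rate-limit-window-minutes 10, a profile whose combined 'success' and
+        # 'download' activity reaches 60 events inside the 600s window is moved
+        # to RESTING for --rate-limit-rest-duration-minutes.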
+ activity_count = ( + self.manager.get_activity_rate(profile['name'], 'success', window_seconds) + + self.manager.get_activity_rate(profile['name'], 'download', window_seconds) + ) + + if activity_count >= max_requests: + reason = f"Rate limit hit: {activity_count} requests in last {window_minutes} minute(s) (limit: {max_requests})" + logger.info(f"Resting profile '{profile['name']}' for {rest_duration_minutes}m: {reason}") + if not self.dry_run: + self.manager.update_profile_state(profile['name'], self.manager.STATE_RESTING, reason) + rest_until = time.time() + rest_duration_minutes * 60 + self.manager.update_profile_field(profile['name'], 'rest_until', str(rest_until)) + self.actions_taken_this_cycle += 1 + return True # Indicates profile was rested + return False + + def enforce_unrest_policy(self, profile_groups, all_profiles_map): + all_profiles_list = list(all_profiles_map.values()) + resting_profiles = [p for p in all_profiles_list if p['state'] == self.manager.STATE_RESTING] + cooldown_profiles = [p for p in all_profiles_list if p['state'] == self.manager.STATE_COOLDOWN] + profiles_to_check = resting_profiles + cooldown_profiles + now = time.time() + + if not profiles_to_check: + return + + # Sort profiles to check by their rest_until timestamp, then by name. + # This creates a deterministic FIFO queue for activation. + profiles_to_check.sort(key=lambda p: (p.get('rest_until', 0), p.get('name', ''))) + + # --- Group-aware unrest logic --- + profile_to_group_map = {} + group_to_profiles_map = {} + if profile_groups: + for group in profile_groups: + group_name = group.get('name') + if not group_name: continue + + profiles_in_group = [] + if 'profiles' in group: + profiles_in_group = sorted(group['profiles']) + elif 'prefix' in group: + prefix = group['prefix'] + profiles_in_group = sorted([p['name'] for p in all_profiles_list if p['name'].startswith(prefix)]) + + group_to_profiles_map[group_name] = profiles_in_group + for p_name in profiles_in_group: + profile_to_group_map[p_name] = group_name + + # This will store the live count of active profiles for each group, + # preventing race conditions within a single enforcer run. + live_active_counts = {} + if profile_groups: + for group_name, profiles_in_group in group_to_profiles_map.items(): + count = 0 + for p_name in profiles_in_group: + profile_state = all_profiles_map.get(p_name, {}).get('state') + if profile_state in [self.manager.STATE_ACTIVE, self.manager.STATE_LOCKED, self.manager.STATE_COOLDOWN]: + count += 1 + live_active_counts[group_name] = count + # --- End group logic setup --- + + unique_proxies = sorted(list(set(p['proxy'] for p in profiles_to_check if p.get('proxy')))) + proxy_states = self.manager.get_proxy_states(unique_proxies) + + for profile in profiles_to_check: + # --- New logic for waiting_downloads --- + if profile.get('rest_reason') == 'waiting_downloads': + profile_name = profile['name'] + group_name = profile_to_group_map.get(profile_name) + group_policy = next((g for g in profile_groups if g.get('name') == group_name), {}) + + max_wait_minutes = group_policy.get('max_wait_for_downloads_minutes', 240) + wait_started_at = profile.get('wait_started_at', 0) + + downloads_pending = self.manager.get_pending_downloads(profile_name) + is_timed_out = (time.time() - wait_started_at) > (max_wait_minutes * 60) if wait_started_at > 0 else False + + if downloads_pending <= 0 or is_timed_out: + if is_timed_out: + logger.warning(f"Profile '{profile_name}' download wait timed out after {max_wait_minutes}m. 
Forcing rotation.") + else: + logger.info(f"All pending downloads for profile '{profile_name}' are complete. Proceeding with rotation.") + + self.actions_taken_this_cycle += 1 + + # Transition to a normal post-rotation rest period. + new_reason = "Rotation complete (downloads finished)" + rest_duration_minutes = group_policy.get('rest_duration_minutes_on_rotation', 0) + rest_until_ts = time.time() + (rest_duration_minutes * 60) + + if not self.dry_run: + self.manager.update_profile_field(profile['name'], 'rest_reason', new_reason) + self.manager.update_profile_field(profile['name'], 'rest_until', str(rest_until_ts)) + self.manager.clear_pending_downloads(profile_name) + # Reset counters to ensure a fresh start after the wait period. + self.manager.reset_profile_counters(profile_name) + + # Update local map so it can be activated in the same cycle if rest_duration is 0 + all_profiles_map[profile_name]['rest_reason'] = new_reason + all_profiles_map[profile_name]['rest_until'] = rest_until_ts + all_profiles_map[profile_name]['success_count'] = 0 + all_profiles_map[profile_name]['failure_count'] = 0 + all_profiles_map[profile_name]['tolerated_error_count'] = 0 + all_profiles_map[profile_name]['download_count'] = 0 + all_profiles_map[profile_name]['download_error_count'] = 0 + + # Let the rest of the unrest logic handle the activation now that rest_until is set. + profile['rest_until'] = rest_until_ts # Update profile in loop + else: + logger.debug(f"Profile '{profile_name}' is still waiting for {downloads_pending} download(s) to complete.") + continue # Skip to next profile, do not attempt to activate. + # --- End new logic --- + + rest_until = profile.get('rest_until', 0) + if now >= rest_until: + profile_name = profile['name'] + group_name = profile_to_group_map.get(profile_name) + + # --- Group-aware unrest check --- + if group_name: + group_policy = next((g for g in profile_groups if g.get('name') == group_name), None) + if not group_policy: + continue # Should not happen if maps are built correctly + + # --- New check: Defer activation if another profile in the group is waiting for downloads --- + defer_activation = group_policy.get('defer_activation_if_any_waiting', False) + if defer_activation: + profiles_in_group = group_to_profiles_map.get(group_name, []) + # Find if any profile in the group is currently in the waiting state. + # This check is crucial to ensure strict sequential processing. + waiting_profile = next( + (p for p_name, p in all_profiles_map.items() + if p_name in profiles_in_group and p.get('rest_reason') == 'waiting_downloads'), + None + ) + + if waiting_profile and waiting_profile['name'] != profile_name: + logger.debug(f"Profile '{profile_name}' rest ended, but profile '{waiting_profile['name']}' in group '{group_name}' is waiting for downloads. Deferring activation.") + continue # Do not activate, another is waiting. + # --- End new check --- + + max_active = group_policy.get('max_active_profiles', 1) + + # Check if the group is already at its capacity for active profiles. + # We use the live counter which is updated during this enforcer cycle. + + # Special handling for COOLDOWN profiles: they should be allowed to become ACTIVE + # even if the group is at capacity, because they are already counted as "active". + # We check if the group would be over capacity *without* this profile. 
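+                    # Example (illustrative): with max_active_profiles=1 and this
+                    # COOLDOWN profile as the group's only counted member, the
+                    # effective count is 1, the check value drops to 0, and
+                    # 0 >= 1 is false, so it can return to ACTIVE without
+                    # tripping the capacity guard.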
+ is_cooldown_profile = profile['state'] == self.manager.STATE_COOLDOWN + effective_active_count = live_active_counts.get(group_name, 0) + + # If we are considering a COOLDOWN profile, it's already in the count. + # The check should be if activating it would exceed the limit, assuming + # no *other* profile is active. + capacity_check_count = effective_active_count + if is_cooldown_profile: + capacity_check_count -= 1 + + if capacity_check_count >= max_active: + logger.debug(f"Profile '{profile_name}' rest ended, but group '{group_name}' is at capacity ({effective_active_count}/{max_active}). Deferring activation.") + + # If a profile's COOLDOWN ends but it can't be activated (because another + # profile is active), move it to RESTING so it's clear it's waiting for capacity. + if is_cooldown_profile: + reason = "Waiting for group capacity" + logger.info(f"Profile '{profile_name}' cooldown ended but group is full. Moving to RESTING to wait for a slot.") + self.actions_taken_this_cycle += 1 + if not self.dry_run: + # We need to update state but keep rest_until in the past. + self.manager.update_profile_state(profile['name'], self.manager.STATE_RESTING, reason) + self.manager.update_profile_field(profile['name'], 'rest_until', '0') + + # Update local map + all_profiles_map[profile_name]['state'] = self.manager.STATE_RESTING + all_profiles_map[profile_name]['rest_until'] = 0 + all_profiles_map[profile_name]['rest_reason'] = reason + + continue # Do not activate, group is full. + # --- End group check --- + + # Before activating, ensure the profile's proxy is not resting. + proxy_url = profile.get('proxy') + if proxy_url: + proxy_state_data = proxy_states.get(proxy_url, {}) + if proxy_state_data.get('state') == self.manager.STATE_RESTING: + logger.debug(f"Profile '{profile['name']}' rest period ended, but its proxy '{proxy_url}' is still resting. Deferring activation.") + + # Update reason for clarity in the UI when a profile is blocked by its proxy. + new_reason = "Waiting for proxy" + if profile.get('rest_reason') != new_reason: + logger.info(f"Updating profile '{profile['name']}' reason to '{new_reason}'.") + self.actions_taken_this_cycle += 1 + if not self.dry_run: + self.manager.update_profile_field(profile['name'], 'rest_reason', new_reason) + # Update local map for consistency within this enforcer cycle. + all_profiles_map[profile_name]['rest_reason'] = new_reason + + continue # Do not activate this profile yet. + + # Update group counter BEFORE making any changes, so subsequent checks in this cycle use the updated count + if group_name and profile['state'] == self.manager.STATE_RESTING: + # For RESTING profiles, they're becoming active, so increment the count + live_active_counts[group_name] = live_active_counts.get(group_name, 0) + 1 + # COOLDOWN profiles are already counted, no change needed + + logger.info(f"Activating profile '{profile['name']}' (rest period completed).") + self.actions_taken_this_cycle += 1 + + # Determine if this was a true rest or just a cooldown that was waiting for a slot. + is_waiting_after_cooldown = profile.get('rest_reason') == "Waiting for group capacity" + + if not self.dry_run: + # When un-resting from a long rest, reset counters to give it a fresh start. + # Do not reset for COOLDOWN or a profile that was waiting after cooldown. 
+ self.manager.update_profile_state(profile['name'], self.manager.STATE_ACTIVE, "Rest period completed") + if profile['state'] == self.manager.STATE_RESTING and not is_waiting_after_cooldown: + self.manager.reset_profile_counters(profile['name']) + + # Update the shared map to reflect the change immediately for this cycle. + all_profiles_map[profile_name]['state'] = self.manager.STATE_ACTIVE + if profile['state'] == self.manager.STATE_RESTING and not is_waiting_after_cooldown: + all_profiles_map[profile_name]['success_count'] = 0 + all_profiles_map[profile_name]['failure_count'] = 0 + all_profiles_map[profile_name]['tolerated_error_count'] = 0 + all_profiles_map[profile_name]['download_count'] = 0 + all_profiles_map[profile_name]['download_error_count'] = 0 + + def enforce_failure_rate_policy(self, profile, max_failure_rate, min_requests): + if max_failure_rate <= 0: + return + + success = profile.get('global_success_count', 0) + failure = profile.get('global_failure_count', 0) + total = success + failure + + if total < min_requests: + return + + current_failure_rate = failure / total if total > 0 else 0 + + if current_failure_rate >= max_failure_rate: + reason = f"Global failure rate {current_failure_rate:.2f} >= threshold {max_failure_rate} ({int(failure)}/{int(total)} failures)" + logger.warning(f"Banning profile '{profile['name']}' due to high failure rate: {reason}") + if not self.dry_run: + self.manager.update_profile_state(profile['name'], self.manager.STATE_BANNED, reason) + self.actions_taken_this_cycle += 1 + + def enforce_rest_policy(self, profile, rest_after_requests, rest_duration_minutes): + if not rest_after_requests or rest_after_requests <= 0 or not rest_duration_minutes or rest_duration_minutes <= 0: + return + + total_requests = ( + int(profile.get('success_count', 0)) + + int(profile.get('failure_count', 0)) + + int(profile.get('tolerated_error_count', 0)) + + int(profile.get('download_count', 0)) + + int(profile.get('download_error_count', 0)) + ) + + if total_requests >= rest_after_requests: + reason = f"Request count {total_requests} >= threshold {rest_after_requests}" + logger.info(f"Resting profile '{profile['name']}' for {rest_duration_minutes}m: {reason}") + self.actions_taken_this_cycle += 1 + if not self.dry_run: + self.manager.update_profile_state(profile['name'], self.manager.STATE_RESTING, reason) + rest_until = time.time() + rest_duration_minutes * 60 + self.manager.update_profile_field(profile['name'], 'rest_until', str(rest_until)) + + def enforce_stale_lock_cleanup(self, max_lock_seconds): + """Finds and unlocks profiles with stale locks.""" + if self.dry_run: + logger.info(f"[Dry Run] Would check for and clean up locks older than {max_lock_seconds} seconds.") + return + + cleaned_count = self.manager.cleanup_stale_locks(max_lock_seconds) + if cleaned_count > 0: + self.actions_taken_this_cycle += cleaned_count + + def enforce_profile_group_policies(self, profile_groups, all_profiles_map): + """ + Manages profiles within defined groups. This includes: + 1. Rotating out profiles that have met their request limit. + 2. Healing the group by ensuring no more than `max_active_profiles` are active. + 3. Initializing the group by activating a profile if none are active. + + This method operates on and modifies the `all_profiles_map` passed to it. 
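+
+        Illustrative group entry (keys are the ones read by the group policies;
+        the values and the YAML framing are examples only, assuming groups are
+        supplied via the policy file):
+
+            - name: group_a
+              prefix: profile_a_
+              max_active_profiles: 1
+              rotate_after_requests: 50
+              rest_duration_minutes_on_rotation: 30
+              wait_download_finish_per_profile: true
+              max_wait_for_downloads_minutes: 240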
+ """ + if not profile_groups: + return + + all_profiles_list = list(all_profiles_map.values()) + + for group in profile_groups: + group_name = group.get('name') + if not group_name: + logger.warning("Found a profile group without a 'name'. Skipping.") + continue + + profiles_in_group = set() + if 'profiles' in group: + profiles_in_group = set(group['profiles']) + elif 'prefix' in group: + prefix = group['prefix'] + profiles_in_group = {p['name'] for p in all_profiles_list if p['name'].startswith(prefix)} + + if not profiles_in_group: + logger.warning(f"Profile group '{group_name}' has no matching profiles. Skipping.") + continue + + # --- Persist group policy to Redis for observability --- + rotate_after_requests = group.get('rotate_after_requests') + max_active_profiles = group.get('max_active_profiles') + if not self.dry_run: + # This is a non-critical update, so we don't need to check for existence. + # We just update it on every cycle to ensure it's fresh. + self.manager.set_profile_group_state(group_name, { + 'rotate_after_requests': rotate_after_requests, + 'max_active_profiles': max_active_profiles, + 'prefix': group.get('prefix') # Store prefix for observability + }) + + # --- 1. Handle Rotation for Active Profiles --- + rotate_after_requests = group.get('rotate_after_requests') + if rotate_after_requests and rotate_after_requests > 0: + # Consider ACTIVE, LOCKED, and COOLDOWN profiles for rotation eligibility. + eligible_for_rotation_check = [ + p for p in all_profiles_list + if p['name'] in profiles_in_group and p['state'] in [self.manager.STATE_ACTIVE, self.manager.STATE_LOCKED, self.manager.STATE_COOLDOWN] + ] + + for profile in eligible_for_rotation_check: + total_requests = ( + int(profile.get('success_count', 0)) + + int(profile.get('failure_count', 0)) + + int(profile.get('tolerated_error_count', 0)) + + int(profile.get('download_count', 0)) + + int(profile.get('download_error_count', 0)) + ) + if total_requests >= rotate_after_requests: + # If a profile is LOCKED, we can't rotate it yet. + # Instead, we update its reason to show that a rotation is pending. + if profile['state'] == self.manager.STATE_LOCKED: + pending_reason = f"Pending Rotation (requests: {total_requests}/{rotate_after_requests})" + # Only update if the reason is not already set, to avoid spamming Redis. + if profile.get('reason') != pending_reason: + logger.info(f"Profile '{profile['name']}' in group '{group_name}' is due for rotation but is LOCKED. Marking as pending.") + self.actions_taken_this_cycle += 1 + if not self.dry_run: + self.manager.update_profile_field(profile['name'], 'reason', pending_reason) + else: + logger.debug(f"Profile '{profile['name']}' in group '{group_name}' is due for rotation but is currently LOCKED. Already marked as pending.") + continue + + # If the profile is ACTIVE or in COOLDOWN, we can rotate it immediately. 
+ reason = f"Rotated after {total_requests} requests (limit: {rotate_after_requests})" + logger.info(f"Rotating profile '{profile['name']}' in group '{group_name}': {reason}") + self.actions_taken_this_cycle += 1 + + wait_for_downloads = group.get('wait_download_finish_per_profile', False) + + new_reason = reason + rest_until_ts = 0 + + if wait_for_downloads: + new_reason = "waiting_downloads" + logger.info(f"Profile '{profile['name']}' will wait for pending downloads to complete.") + else: + rest_duration_minutes = group.get('rest_duration_minutes_on_rotation') + if rest_duration_minutes and rest_duration_minutes > 0: + rest_until_ts = time.time() + rest_duration_minutes * 60 + + if not self.dry_run: + self.manager.update_profile_state(profile['name'], self.manager.STATE_RESTING, new_reason) + + if wait_for_downloads: + self.manager.update_profile_field(profile['name'], 'wait_started_at', str(time.time())) + # Set rest_until to 0 to indicate it's not a time-based rest + self.manager.update_profile_field(profile['name'], 'rest_until', '0') + elif rest_until_ts > 0: + self.manager.update_profile_field(profile['name'], 'rest_until', str(rest_until_ts)) + + # Reset all session counters for the next active cycle + self.manager.reset_profile_counters(profile['name']) + + # Update our local map so subsequent policies in this cycle see the change immediately. + all_profiles_map[profile['name']]['state'] = self.manager.STATE_RESTING + all_profiles_map[profile['name']]['rest_reason'] = new_reason + if wait_for_downloads: + all_profiles_map[profile['name']]['wait_started_at'] = time.time() + all_profiles_map[profile['name']]['rest_until'] = 0 + else: + all_profiles_map[profile['name']]['rest_until'] = rest_until_ts + all_profiles_map[profile['name']]['success_count'] = 0 + all_profiles_map[profile['name']]['failure_count'] = 0 + all_profiles_map[profile['name']]['tolerated_error_count'] = 0 + all_profiles_map[profile['name']]['download_count'] = 0 + all_profiles_map[profile['name']]['download_error_count'] = 0 + + # --- 2. Self-Healing: Enforce max_active_profiles --- + max_active = group.get('max_active_profiles', 1) + + # Get the current list of active/locked profiles from our potentially modified local map + # A profile is considered "active" for group limits if it is ACTIVE, LOCKED, or in COOLDOWN. + current_active_or_locked_profiles = [ + p for name, p in all_profiles_map.items() + if name in profiles_in_group and p['state'] in [self.manager.STATE_ACTIVE, self.manager.STATE_LOCKED, self.manager.STATE_COOLDOWN] + ] + + num_active_or_locked = len(current_active_or_locked_profiles) + if num_active_or_locked > max_active: + logger.warning(f"Healing group '{group_name}': Found {num_active_or_locked} active/locked profiles, but max is {max_active}. Resting excess ACTIVE profiles.") + + # We can only rest profiles that are in the ACTIVE state, not LOCKED. + profiles_that_can_be_rested = [p for p in current_active_or_locked_profiles if p['state'] == self.manager.STATE_ACTIVE] + + # Sort to determine which profiles to rest. We prefer to rest profiles + # that have been used more. As a tie-breaker (especially for profiles + # with 0 requests), we rest the one that has been active the longest + # (oldest last_used timestamp). 
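+                # The two passes below rely on Python's stable sort: sorting by
+                # last_used first and then by request count means profiles with
+                # equal request counts keep their oldest-last_used-first order,
+                # so the longest-active tied profile is rested first.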
+ profiles_that_can_be_rested.sort(key=lambda p: p.get('last_used', 0)) # Oldest first + profiles_that_can_be_rested.sort(key=lambda p: ( + p.get('success_count', 0) + p.get('failure_count', 0) + + p.get('tolerated_error_count', 0) + + p.get('download_count', 0) + p.get('download_error_count', 0) + ), reverse=True) # Most requests first + + num_to_rest = num_active_or_locked - max_active + profiles_to_rest = profiles_that_can_be_rested[:num_to_rest] + for profile in profiles_to_rest: + req_count = ( + profile.get('success_count', 0) + profile.get('failure_count', 0) + + profile.get('tolerated_error_count', 0) + + profile.get('download_count', 0) + + profile.get('download_error_count', 0) + ) + logger.warning(f"Healing group '{group_name}': Resting profile '{profile['name']}' (request count: {req_count}).") + self.actions_taken_this_cycle += 1 + if not self.dry_run: + self.manager.update_profile_state(profile['name'], self.manager.STATE_RESTING, "Group max_active healing") + # Give it a rest time of 0, so it's immediately eligible for activation + # by the unrest logic if the group has capacity. + rest_until_ts = 0 + self.manager.update_profile_field(profile['name'], 'rest_until', str(rest_until_ts)) + + # --- 3. Initialization: Activate profiles if below capacity --- + # This is a fallback for initialization or if all profiles were rested/banned. + # The primary activation mechanism is in `enforce_unrest_policy`. + elif num_active_or_locked < max_active: + # Check if there are any non-active, non-banned, non-locked profiles to activate. + eligible_profiles = [p for name, p in all_profiles_map.items() if name in profiles_in_group and p['state'] not in [self.manager.STATE_ACTIVE, self.manager.STATE_BANNED, self.manager.STATE_LOCKED]] + if eligible_profiles: + # This is a simple initialization case. We don't activate here because + # `enforce_unrest_policy` will handle it more intelligently based on rest times. + # This block ensures that on the very first run, a group doesn't sit empty. + if num_active_or_locked == 0: + logger.debug(f"Group '{group_name}' has no active profiles. `enforce_unrest_policy` will attempt to activate one.") + + def enforce_proxy_group_rotation(self, proxy_groups): + """Manages mutually exclusive work cycles for proxies within defined groups.""" + if not proxy_groups: + return + + group_names = [g['name'] for g in proxy_groups if g.get('name')] + if not group_names: + return + + group_states = self.manager.get_proxy_group_states(group_names) + now = time.time() + + for group in proxy_groups: + group_name = group.get('name') + if not group_name: + logger.warning("Found a proxy group without a 'name'. Skipping.") + continue + + proxies_in_group = group.get('proxies', []) + if not proxies_in_group: + logger.warning(f"Proxy group '{group_name}' has no proxies defined. Skipping.") + continue + + work_minutes = group.get('work_minutes_per_proxy') + if not work_minutes or work_minutes <= 0: + logger.warning(f"Proxy group '{group_name}' is missing 'work_minutes_per_proxy'. Skipping.") + continue + + if not self.dry_run: + for proxy_url in proxies_in_group: + self.manager.set_proxy_group_membership(proxy_url, group_name, work_minutes) + + work_duration_seconds = work_minutes * 60 + state = group_states.get(group_name, {}) + + if not state: + # First run for this group, initialize it + logger.info(f"Initializing new proxy group '{group_name}'. 
Activating first proxy '{proxies_in_group[0]}'.") + self.actions_taken_this_cycle += 1 + active_proxy_index = 0 + next_rotation_ts = now + work_duration_seconds + + if not self.dry_run: + # Activate the first, rest the others + self.manager.set_proxy_state(proxies_in_group[0], self.manager.STATE_ACTIVE) + for i, proxy_url in enumerate(proxies_in_group): + if i != active_proxy_index: + # Rest indefinitely; group logic will activate it when its turn comes. + self.manager.set_proxy_state(proxy_url, self.manager.STATE_RESTING, rest_duration_minutes=99999) + + self.manager.set_proxy_group_state(group_name, active_proxy_index, next_rotation_ts) + + elif now >= state.get('next_rotation_timestamp', 0): + # Time to rotate + current_active_index = state.get('active_proxy_index', 0) + next_active_index = (current_active_index + 1) % len(proxies_in_group) + + old_active_proxy = proxies_in_group[current_active_index] + new_active_proxy = proxies_in_group[next_active_index] + + logger.info(f"Rotating proxy group '{group_name}': Deactivating '{old_active_proxy}', Activating '{new_active_proxy}'.") + self.actions_taken_this_cycle += 1 + + next_rotation_ts = now + work_duration_seconds + + if not self.dry_run: + # Rest the old proxy + self.manager.set_proxy_state(old_active_proxy, self.manager.STATE_RESTING, rest_duration_minutes=99999) + # Activate the new one + self.manager.set_proxy_state(new_active_proxy, self.manager.STATE_ACTIVE) + # Update group state + self.manager.set_proxy_group_state(group_name, next_active_index, next_rotation_ts) + + def enforce_proxy_work_rest_cycle(self, args): + """Enforces a work/rest cycle on proxies based on time.""" + work_minutes = args.proxy_work_minutes + rest_minutes = args.proxy_rest_duration_minutes + + if not work_minutes or work_minutes <= 0 or not rest_minutes or rest_minutes <= 0: + return + + # Get a flat list of all proxies managed by groups, so we can ignore them. + proxy_groups = getattr(args, 'proxy_groups', []) + grouped_proxies = set() + if proxy_groups: + for group in proxy_groups: + for proxy_url in group.get('proxies', []): + grouped_proxies.add(proxy_url) + + all_profiles = self.manager.list_profiles() + if not all_profiles: + return + + unique_proxies = sorted(list(set(p['proxy'] for p in all_profiles if p.get('proxy')))) + + # Filter out proxies that are managed by the group rotation logic + proxies_to_manage = [p for p in unique_proxies if p not in grouped_proxies] + if not proxies_to_manage: + logger.debug("All unique proxies are managed by proxy groups. 
Skipping individual work/rest cycle enforcement.") + return + + proxy_states = self.manager.get_proxy_states(proxies_to_manage) + now = time.time() + + for proxy_url, state_data in proxy_states.items(): + state = state_data.get('state', self.manager.STATE_ACTIVE) + + # Un-rest logic + if state == self.manager.STATE_RESTING: + rest_until = state_data.get('rest_until', 0) + if now >= rest_until: + logger.info(f"Activating proxy '{proxy_url}' (rest period complete).") + self.actions_taken_this_cycle += 1 + if not self.dry_run: + self.manager.set_proxy_state(proxy_url, self.manager.STATE_ACTIVE) + + # Also activate any profiles that were resting due to this proxy + profiles_for_proxy = [p for p in all_profiles if p.get('proxy') == proxy_url] + for profile in profiles_for_proxy: + if profile['state'] == self.manager.STATE_RESTING and profile.get('rest_reason') == self.PROXY_REST_REASON: + logger.info(f"Activating profile '{profile['name']}' as its proxy '{proxy_url}' is now active.") + self.actions_taken_this_cycle += 1 + if not self.dry_run: + self.manager.update_profile_state(profile['name'], self.manager.STATE_ACTIVE, "Proxy activated") + else: + # Proxy is still resting. Ensure any of its profiles that are ACTIVE are moved to RESTING. + # This catches profiles that were unlocked while their proxy was resting. + rest_until_ts = state_data.get('rest_until', 0) + profiles_for_proxy = [p for p in all_profiles if p.get('proxy') == proxy_url] + for profile in profiles_for_proxy: + if profile['state'] == self.manager.STATE_ACTIVE: + logger.info(f"Resting profile '{profile['name']}' as its proxy '{proxy_url}' is resting.") + self.actions_taken_this_cycle += 1 + if not self.dry_run: + self.manager.update_profile_state(profile['name'], self.manager.STATE_RESTING, self.PROXY_REST_REASON) + self.manager.update_profile_field(profile['name'], 'rest_until', str(rest_until_ts)) + + # Rest logic + elif state == self.manager.STATE_ACTIVE: + work_start = state_data.get('work_start_timestamp', 0) + if work_start == 0: # Proxy was just created, start its work cycle + if not self.dry_run: + self.manager.set_proxy_state(proxy_url, self.manager.STATE_ACTIVE) + continue + + work_duration_seconds = work_minutes * 60 + active_duration = now - work_start + logger.debug(f"Proxy '{proxy_url}' has been active for {active_duration:.0f}s (limit: {work_duration_seconds}s).") + if active_duration >= work_duration_seconds: + logger.info(f"Resting proxy '{proxy_url}' for {rest_minutes}m (work period of {work_minutes}m complete).") + self.actions_taken_this_cycle += 1 + + rest_until_ts = time.time() + rest_minutes * 60 + if not self.dry_run: + self.manager.set_proxy_state(proxy_url, self.manager.STATE_RESTING, rest_minutes) + + # Also rest any active profiles using this proxy + profiles_for_proxy = [p for p in all_profiles if p.get('proxy') == proxy_url] + for profile in profiles_for_proxy: + if profile['state'] == self.manager.STATE_ACTIVE: + logger.info(f"Resting profile '{profile['name']}' as its proxy '{proxy_url}' is resting.") + self.actions_taken_this_cycle += 1 + if not self.dry_run: + self.manager.update_profile_state(profile['name'], self.manager.STATE_RESTING, self.PROXY_REST_REASON) + self.manager.update_profile_field(profile['name'], 'rest_until', str(rest_until_ts)) + + def enforce_max_proxy_active_time(self, args): + """ + Enforces a global maximum active time for any proxy, regardless of group membership. + This acts as a safety net to prevent a proxy from being stuck in an ACTIVE state. 
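+
+        Illustrative flags (both defined in the parser below): running with
+        --max-global-proxy-active-minutes 90 --rest-duration-on-max-active 10
+        would rest any proxy that has been ACTIVE for more than 90 minutes,
+        together with its ACTIVE profiles, for 10 minutes.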
+ """ + max_active_minutes = args.max_global_proxy_active_minutes + rest_minutes = args.rest_duration_on_max_active + + if not max_active_minutes or max_active_minutes <= 0: + return + + all_profiles = self.manager.list_profiles() + if not all_profiles: + return + + unique_proxies = sorted(list(set(p['proxy'] for p in all_profiles if p.get('proxy')))) + if not unique_proxies: + return + + proxy_states = self.manager.get_proxy_states(unique_proxies) + now = time.time() + + for proxy_url, state_data in proxy_states.items(): + if state_data.get('state') == self.manager.STATE_ACTIVE: + work_start = state_data.get('work_start_timestamp', 0) + if work_start == 0: + continue # Just activated, timestamp not set yet. + + active_duration_seconds = now - work_start + max_active_seconds = max_active_minutes * 60 + + if active_duration_seconds >= max_active_seconds: + reason = f"Exceeded max active time of {max_active_minutes}m" + logger.warning(f"Resting proxy '{proxy_url}' for {rest_minutes}m: {reason}") + self.actions_taken_this_cycle += 1 + + rest_until_ts = now + rest_minutes * 60 + if not self.dry_run: + self.manager.set_proxy_state(proxy_url, self.manager.STATE_RESTING, rest_minutes) + + # Also rest any active profiles using this proxy + profiles_for_proxy = [p for p in all_profiles if p.get('proxy') == proxy_url] + for profile in profiles_for_proxy: + if profile['state'] == self.manager.STATE_ACTIVE: + logger.info(f"Resting profile '{profile['name']}' as its proxy '{proxy_url}' is resting due to max active time.") + self.actions_taken_this_cycle += 1 + if not self.dry_run: + self.manager.update_profile_state(profile['name'], self.manager.STATE_RESTING, self.PROXY_REST_REASON) + self.manager.update_profile_field(profile['name'], 'rest_until', str(rest_until_ts)) + + def enforce_proxy_policies(self, args): + proxy_ban_enabled = args.proxy_ban_on_failures and args.proxy_ban_on_failures > 0 + proxy_rate_limit_enabled = getattr(args, 'proxy_rate_limit_requests', 0) > 0 + if not proxy_ban_enabled and not proxy_rate_limit_enabled: + return + + all_profiles = self.manager.list_profiles() + if not all_profiles: + return + + unique_proxies = sorted(list(set(p['proxy'] for p in all_profiles if p.get('proxy')))) + + if not unique_proxies: + return + + logger.debug(f"Checking proxy policies for {len(unique_proxies)} unique proxies...") + + for proxy_url in unique_proxies: + profiles_for_proxy = [p for p in all_profiles if p.get('proxy') == proxy_url] + if self.enforce_proxy_failure_burst_policy( + proxy_url, + profiles_for_proxy, + args.proxy_ban_on_failures, + args.proxy_ban_window_minutes + ): + continue # Banned, no need for other checks + + self.enforce_proxy_rate_limit_policy( + proxy_url, + profiles_for_proxy, + getattr(args, 'proxy_rate_limit_requests', 0), + getattr(args, 'proxy_rate_limit_window_minutes', 0), + getattr(args, 'proxy_rate_limit_rest_duration_minutes', 0) + ) + + def enforce_proxy_failure_burst_policy(self, proxy_url, profiles_for_proxy, max_failures, window_minutes): + if not max_failures or not window_minutes or max_failures <= 0 or window_minutes <= 0: + return False + + window_seconds = window_minutes * 60 + failure_count = self.manager.get_proxy_activity_rate(proxy_url, 'failure', window_seconds) + + if failure_count >= max_failures: + reason = f"Proxy failure burst: {failure_count} failures in last {window_minutes}m (threshold: {max_failures})" + logger.warning(f"Banning {len(profiles_for_proxy)} profile(s) on proxy '{proxy_url}' due to failure burst: {reason}") + 
self.actions_taken_this_cycle += 1 + + if not self.dry_run: + for profile in profiles_for_proxy: + # Don't re-ban already banned profiles + if profile['state'] != self.manager.STATE_BANNED: + self.manager.update_profile_state(profile['name'], self.manager.STATE_BANNED, reason) + return True # Indicates action was taken + return False + + def enforce_proxy_rate_limit_policy(self, proxy_url, profiles_for_proxy, max_requests, window_minutes, rest_duration_minutes): + if not max_requests or not window_minutes or max_requests <= 0 or window_minutes <= 0: + return False + + window_seconds = window_minutes * 60 + # Count all successful activities for the proxy + activity_count = ( + self.manager.get_proxy_activity_rate(proxy_url, 'success', window_seconds) + + self.manager.get_proxy_activity_rate(proxy_url, 'download', window_seconds) + ) + + if activity_count >= max_requests: + reason = f"Proxy rate limit hit: {activity_count} requests in last {window_minutes}m (limit: {max_requests})" + logger.info(f"Resting proxy '{proxy_url}' for {rest_duration_minutes}m: {reason}") + self.actions_taken_this_cycle += 1 + + if not self.dry_run: + self.manager.set_proxy_state(proxy_url, self.manager.STATE_RESTING, rest_duration_minutes) + return True # Indicates action was taken + return False + +def add_policy_enforcer_parser(subparsers): + """Adds the parser for the 'policy-enforcer' command.""" + parser = subparsers.add_parser( + 'policy-enforcer', + description='Apply policies to profiles (ban, rest, etc.).', + formatter_class=argparse.RawTextHelpFormatter, + help='Apply policies to profiles (ban, rest, etc.).' + ) + + parser.add_argument('--policy', '--policy-file', dest='policy_file', help='Path to a YAML policy file to load default settings from.') + parser.add_argument('--env-file', help='Path to a .env file to load environment variables from.') + parser.add_argument('--redis-host', default=None, help='Redis host. Defaults to REDIS_HOST or MASTER_HOST_IP env var, or localhost.') + parser.add_argument('--redis-port', type=int, default=None, help='Redis port. Defaults to REDIS_PORT env var, or 6379.') + parser.add_argument('--redis-password', default=None, help='Redis password. Defaults to REDIS_PASSWORD env var.') + parser.add_argument('--env', default=None, help="Default environment name for Redis key prefix. Used if --auth-env or --download-env are not specified. Overrides policy file setting.") + parser.add_argument('--auth-env', help="Override the environment for the Auth simulation.") + parser.add_argument('--download-env', help="Override the environment for the Download simulation.") + parser.add_argument('--legacy', action='store_true', help="Use legacy key prefix ('profile_mgmt_') without environment.") + parser.add_argument('--key-prefix', default=None, help='Explicit key prefix for Redis. Overrides --env, --legacy and any defaults.') + parser.add_argument('--verbose', action='store_true', help='Enable verbose logging') + parser.add_argument('--dry-run', action='store_true', help='Show what would be done without making changes.') + + # Policy arguments + policy_group = parser.add_argument_group('Policy Rules') + policy_group.add_argument('--max-failure-rate', type=float, default=None, + help='Ban a profile if its failure rate exceeds this value (0.0 to 1.0). Default: 0.5') + policy_group.add_argument('--min-requests-for-rate', type=int, default=None, + help='Minimum total requests before failure rate is calculated. 
Default: 20') + policy_group.add_argument('--ban-on-failures', type=int, default=None, + help='Ban a profile if it has this many failures within the time window (0 to disable). Default: 0') + policy_group.add_argument('--ban-on-failures-window-minutes', type=int, default=None, + help='The time window in minutes for the failure burst check. Default: 5') + policy_group.add_argument('--rest-after-requests', type=int, default=None, + help='Move a profile to RESTING after this many total requests (0 to disable). Default: 0') + policy_group.add_argument('--rest-duration-minutes', type=int, default=None, + help='How long a profile should rest. Default: 15') + policy_group.add_argument('--rate-limit-requests', type=int, default=None, + help='Rest a profile if it exceeds this many requests in the time window (0 to disable).') + policy_group.add_argument('--rate-limit-window-minutes', type=int, default=None, + help='The time window in minutes for the rate limit check.') + policy_group.add_argument('--rate-limit-rest-duration-minutes', type=int, default=None, + help='How long a profile should rest after hitting the rate limit.') + policy_group.add_argument('--unlock-stale-locks-after-seconds', type=int, default=None, + help='Unlock profiles that have been in a LOCKED state for more than this many seconds (0 to disable). Default: 120') + + proxy_policy_group = parser.add_argument_group('Proxy Policy Rules') + proxy_policy_group.add_argument('--proxy-work-minutes', type=int, default=None, + help='Work duration for a proxy before it rests (0 to disable). Default: 0') + proxy_policy_group.add_argument('--proxy-rest-duration-minutes', type=int, default=None, + help='Rest duration for a proxy after its work period. Default: 0') + proxy_policy_group.add_argument('--proxy-ban-on-failures', type=int, default=None, + help='Ban a proxy (and all its profiles) if it has this many failures within the time window (0 to disable). Default: 0') + proxy_policy_group.add_argument('--proxy-ban-window-minutes', type=int, default=None, + help='The time window in minutes for the proxy failure burst check. Default: 10') + proxy_policy_group.add_argument('--proxy-rate-limit-requests', type=int, default=None, + help='Rest a proxy if it exceeds this many requests in the time window (0 to disable).') + proxy_policy_group.add_argument('--proxy-rate-limit-window-minutes', type=int, default=None, + help='The time window in minutes for the proxy rate limit check.') + proxy_policy_group.add_argument('--proxy-rate-limit-rest-duration-minutes', type=int, default=None, + help='How long a proxy should rest after hitting the rate limit.') + proxy_policy_group.add_argument('--max-global-proxy-active-minutes', type=int, default=None, + help='Global maximum time a proxy can be active before being rested (0 to disable). Acts as a safety net. Default: 0') + proxy_policy_group.add_argument('--rest-duration-on-max-active', type=int, default=None, + help='How long a proxy should rest after hitting the global max active time. Default: 10') + + # Execution control + exec_group = parser.add_argument_group('Execution Control') + exec_group.add_argument('--live', action='store_true', help='Run continuously, applying policies periodically.') + exec_group.add_argument('--interval-seconds', type=int, default=None, + help='When in --live mode, how often to apply policies. 
Default: 60') + exec_group.add_argument('--auth-only', action='store_true', help='Run enforcer for the auth simulation only.') + exec_group.add_argument('--download-only', action='store_true', help='Run enforcer for the download simulation only.') + + return parser + +def sync_cross_simulation(auth_manager, download_manager, sync_config, dry_run=False): + """Synchronize profile states between auth and download simulations.""" + if not sync_config: + return + + profile_links = sync_config.get('profile_links', []) + sync_states = sync_config.get('sync_states', []) + sync_rotation = sync_config.get('sync_rotation', False) + enforce_auth_lead = sync_config.get('enforce_auth_lead', False) + + if not profile_links: + return + + # --- Get all profiles once for efficiency --- + all_auth_profiles = {p['name']: p for p in auth_manager.list_profiles()} + all_download_profiles = {p['name']: p for p in download_manager.list_profiles()} + + # --- State and Rotation Sync (handles prefixes correctly) --- + for link in profile_links: + auth_prefix = link.get('auth') + download_prefix = link.get('download') + if not auth_prefix or not download_prefix: + continue + + auth_profiles_in_group = [p for name, p in all_auth_profiles.items() if name.startswith(auth_prefix)] + + for auth_profile in auth_profiles_in_group: + # Assume 1-to-1 name mapping (e.g., auth 'user1_0' maps to download 'user1_0') + download_profile_name = auth_profile['name'] + download_profile = all_download_profiles.get(download_profile_name) + + if not download_profile: + logger.debug(f"Auth profile '{auth_profile['name']}' has no corresponding download profile.") + continue + + auth_state = auth_profile.get('state') + download_state = download_profile.get('state') + + # Sync states from auth to download + if enforce_auth_lead and auth_state in sync_states and download_state != auth_state: + auth_reason = auth_profile.get('reason', '') + # If auth profile is waiting for downloads, we must NOT sync the RESTING state to the download profile, + # as that would prevent it from processing the very downloads we are waiting for. + if auth_state == auth_manager.STATE_RESTING and auth_reason == 'waiting_downloads': + logger.debug(f"Auth profile '{auth_profile['name']}' is waiting for downloads. Skipping state sync to download profile to prevent deadlock.") + else: + logger.info(f"Syncing download profile '{download_profile_name}' to state '{auth_state}' (auth lead)") + if not dry_run: + reason_to_sync = auth_reason or 'Synced from auth' + download_manager.update_profile_state(download_profile_name, auth_state, f"Synced from auth: {reason_to_sync}") + if auth_state == auth_manager.STATE_RESTING: + auth_rest_until = auth_profile.get('rest_until') + if auth_rest_until: + download_manager.update_profile_field(download_profile_name, 'rest_until', str(auth_rest_until)) + + # Handle rotation sync + if sync_rotation: + auth_reason = auth_profile.get('rest_reason', '') + + # If auth profile is waiting for downloads, we must NOT sync the RESTING state to the download profile, + # as that would prevent it from processing the very downloads we are waiting for. + if auth_reason == 'waiting_downloads': + logger.debug(f"Auth profile '{auth_profile['name']}' is waiting for downloads. 
Skipping rotation sync to download profile to prevent deadlock.") + elif auth_state == auth_manager.STATE_RESTING and 'rotate' in auth_reason.lower(): + if download_state != download_manager.STATE_RESTING: + logger.info(f"Rotating download profile '{download_profile_name}' due to auth rotation") + if not dry_run: + download_manager.update_profile_state(download_profile_name, download_manager.STATE_RESTING, f"Rotated due to auth rotation: {auth_reason}") + auth_rest_until = auth_profile.get('rest_until') + if auth_rest_until: + download_manager.update_profile_field(download_profile_name, 'rest_until', str(auth_rest_until)) + + # --- Active Profile Sync --- + sync_active = sync_config.get('sync_active_profile', False) + sync_waiting_downloads = sync_config.get('sync_waiting_downloads', False) + + if not (sync_active or sync_waiting_downloads): + return + + logger.debug("Syncing active profiles from Auth to Download simulation...") + + # Get profiles that should be active in the download simulation + target_active_download_profiles = set() + + # 1. Add profiles that are active in auth simulation (if sync_active is enabled) + if sync_active: + active_auth_profiles = [p for p in all_auth_profiles.values() if p['state'] in [auth_manager.STATE_ACTIVE, auth_manager.STATE_LOCKED]] + for auth_profile in active_auth_profiles: + target_active_download_profiles.add(auth_profile['name']) + + # 2. Add profiles that are waiting for downloads to complete (if sync_waiting_downloads is enabled) + if sync_waiting_downloads: + waiting_auth_profiles = [p for p in all_auth_profiles.values() + if p['state'] == auth_manager.STATE_RESTING + and p.get('rest_reason') == 'waiting_downloads'] + for auth_profile in waiting_auth_profiles: + target_active_download_profiles.add(auth_profile['name']) + logger.debug(f"Auth profile '{auth_profile['name']}' is waiting for downloads. 
Ensuring matching download profile is active.") + + if not target_active_download_profiles: + logger.debug("No auth profiles found that need active download profiles.") + return + + # Get download profile group info from Redis + dl_group_state_keys = [k for k in download_manager.redis.scan_iter(f"{download_manager.key_prefix}profile_group_state:*")] + dl_group_names = [k.split(':')[-1] for k in dl_group_state_keys] + dl_group_states = download_manager.get_profile_group_states(dl_group_names) + + dl_profile_to_group = {} + for name, state in dl_group_states.items(): + prefix = state.get('prefix') + if prefix: + for p_name in all_download_profiles: + if p_name.startswith(prefix): + dl_profile_to_group[p_name] = {'name': name, 'max_active': state.get('max_active_profiles', 1)} + + # Activate download profiles that should be active but aren't + for target_profile_name in target_active_download_profiles: + download_profile = all_download_profiles.get(target_profile_name) + if not download_profile: + logger.warning(f"Auth profile '{target_profile_name}' needs an active download profile, but no corresponding download profile found.") + continue + + if download_profile['state'] not in [download_manager.STATE_ACTIVE, download_manager.STATE_LOCKED]: + logger.info(f"Syncing active state: Activating download profile '{target_profile_name}' to match auth requirements.") + if not dry_run: + download_manager.update_profile_state(target_profile_name, download_manager.STATE_ACTIVE, "Synced from auth requirements") + download_manager.reset_profile_counters(target_profile_name) + + # Deactivate any download profiles that are active but shouldn't be + for dl_profile_name, dl_profile in all_download_profiles.items(): + if dl_profile['state'] == download_manager.STATE_ACTIVE and dl_profile_name not in target_active_download_profiles: + group_info = dl_profile_to_group.get(dl_profile_name) + if group_info: + logger.info(f"Syncing active state: Resting download profile '{dl_profile_name}' as it is no longer the active profile in its group.") + if not dry_run: + download_manager.update_profile_state(dl_profile_name, download_manager.STATE_RESTING, "Synced rotation from auth") + download_manager.update_profile_field(dl_profile_name, 'rest_until', '0') + +def main_policy_enforcer(args): + """Main dispatcher for 'policy-enforcer' command.""" + policy = {} + if args.policy_file: + if not yaml: + logger.error("Cannot load policy file because PyYAML is not installed.") + return 1 + try: + with open(args.policy_file, 'r') as f: + policy = yaml.safe_load(f) or {} + except (IOError, yaml.YAMLError) as e: + logger.error(f"Failed to load or parse policy file {args.policy_file}: {e}") + return 1 + + class Config: + def __init__(self, cli_args, policy_defaults, code_defaults): + for key, code_default in code_defaults.items(): + cli_val = getattr(cli_args, key, None) + policy_val = policy_defaults.get(key) + if cli_val is not None: + setattr(self, key, cli_val) + elif policy_val is not None: + setattr(self, key, policy_val) + else: + setattr(self, key, code_default) + + code_defaults = { + 'max_failure_rate': 0.0, 'min_requests_for_rate': 20, 'ban_on_failures': 0, + 'ban_on_failures_window_minutes': 5, 'rest_after_requests': 0, + 'rest_duration_minutes': 15, + 'rate_limit_requests': 0, 'rate_limit_window_minutes': 60, 'rate_limit_rest_duration_minutes': 5, + 'proxy_work_minutes': 0, + 'proxy_rest_duration_minutes': 0, 'proxy_ban_on_failures': 0, + 'proxy_ban_window_minutes': 10, + 'proxy_rate_limit_requests': 0, 
'proxy_rate_limit_window_minutes': 60, 'proxy_rate_limit_rest_duration_minutes': 10, + 'unlock_stale_locks_after_seconds': 120, + 'unlock_cooldown_seconds': 0, + 'max_global_proxy_active_minutes': 0, 'rest_duration_on_max_active': 10, + 'interval_seconds': 60, 'proxy_groups': [], 'profile_groups': [] + } + + sim_params = policy.get('simulation_parameters', {}) + env_file_from_policy = sim_params.get('env_file') + + if load_dotenv: + env_file = args.env_file or env_file_from_policy + if not env_file and args.env and '.env' in args.env and os.path.exists(args.env): + print(f"WARNING: --env should be an environment name, not a file path. Treating '{args.env}' as --env-file.", file=sys.stderr) + env_file = args.env + if env_file and load_dotenv(env_file): + print(f"Loaded environment variables from {env_file}", file=sys.stderr) + elif args.env_file and not os.path.exists(args.env_file): + print(f"ERROR: The specified --env-file was not found: {args.env_file}", file=sys.stderr) + return 1 + + redis_host = args.redis_host or os.getenv('REDIS_HOST', os.getenv('MASTER_HOST_IP', 'localhost')) + redis_port = args.redis_port if args.redis_port is not None else int(os.getenv('REDIS_PORT', 6379)) + redis_password = args.redis_password or os.getenv('REDIS_PASSWORD') + + if args.verbose: + logging.getLogger().setLevel(logging.DEBUG) + + signal.signal(signal.SIGINT, handle_shutdown) + signal.signal(signal.SIGTERM, handle_shutdown) + + enforcer_setups = [] + + def setup_enforcer(sim_type, env_cli_arg, policy_config_key, env_policy_key): + policy_config = policy.get(policy_config_key) + # Fallback for single-enforcer policy files + if policy_config is None and sim_type == 'Auth': + policy_config = policy.get('policy_enforcer_config', {}) + + if policy_config is None: + logger.debug(f"No config block found for {sim_type} simulation ('{policy_config_key}'). Skipping.") + return None + + logger.info(f"Setting up enforcer for {sim_type} simulation...") + config = Config(args, policy_config, code_defaults) + + # Determine the effective environment name with correct precedence: + # 1. Specific CLI arg (e.g., --auth-env) + # 2. General CLI arg (--env) + # 3. Specific policy setting (e.g., simulation_parameters.auth_env) + # 4. General policy setting (simulation_parameters.env) + # 5. 
Hardcoded default ('dev') + policy_env = sim_params.get(env_policy_key) + default_policy_env = sim_params.get('env') + effective_env = env_cli_arg or args.env or policy_env or default_policy_env or 'dev' + + logger.info(f"Using environment '{effective_env}' for {sim_type}.") + + if args.key_prefix: key_prefix = args.key_prefix + elif args.legacy: key_prefix = 'profile_mgmt_' + else: key_prefix = f"{effective_env}_profile_mgmt_" + + manager = ProfileManager(redis_host, redis_port, redis_password, key_prefix) + enforcer = PolicyEnforcer(manager, dry_run=args.dry_run) + + # Write any relevant config to Redis for workers to use + cooldown = getattr(config, 'unlock_cooldown_seconds', None) + if cooldown is not None and not args.dry_run: + # If it's a list or int, convert to JSON string to store in Redis + manager.set_config('unlock_cooldown_seconds', json.dumps(cooldown)) + + proxy_work_minutes = getattr(config, 'proxy_work_minutes', None) + if proxy_work_minutes is not None and not args.dry_run: + manager.set_config('proxy_work_minutes', proxy_work_minutes) + + proxy_rest_duration_minutes = getattr(config, 'proxy_rest_duration_minutes', None) + if proxy_rest_duration_minutes is not None and not args.dry_run: + manager.set_config('proxy_rest_duration_minutes', proxy_rest_duration_minutes) + + return {'name': sim_type, 'enforcer': enforcer, 'config': config} + + if not args.download_only: + auth_setup = setup_enforcer('Auth', args.auth_env, 'auth_policy_enforcer_config', 'auth_env') + if auth_setup: enforcer_setups.append(auth_setup) + + if not args.auth_only: + download_setup = setup_enforcer('Download', args.download_env, 'download_policy_enforcer_config', 'download_env') + if download_setup: enforcer_setups.append(download_setup) + + if not enforcer_setups: + logger.error("No policies to enforce. Check policy file and --auth-only/--download-only flags.") + return 1 + + # Determine interval. Precedence: CLI -> simulation_parameters -> per-setup config -> code default. + # The CLI arg is already handled by the Config objects, so we just need to check sim_params. + sim_params_interval = sim_params.get('interval_seconds') + if args.interval_seconds is None and sim_params_interval is not None: + interval = sim_params_interval + else: + interval = min(s['config'].interval_seconds for s in enforcer_setups) + + # Get cross-simulation sync configuration + cross_sync_config = policy.get('cross_simulation_sync', {}) + + if not args.live: + for setup in enforcer_setups: + logger.info(f"--- Applying policies for {setup['name']} Simulation ---") + setup['enforcer'].apply_policies(setup['config']) + + # Apply cross-simulation sync after all policies have been applied + if cross_sync_config and len(enforcer_setups) == 2: + # We need to identify which setup is auth and which is download + # Based on their names + auth_setup = next((s for s in enforcer_setups if s['name'] == 'Auth'), None) + download_setup = next((s for s in enforcer_setups if s['name'] == 'Download'), None) + if auth_setup and download_setup: + sync_cross_simulation( + auth_setup['enforcer'].manager, + download_setup['enforcer'].manager, + cross_sync_config, + dry_run=args.dry_run + ) + return 0 + + logger.info(f"Running in live mode. Applying policies every {interval} seconds. Press Ctrl+C to stop.") + if not args.verbose: + print("Each '.' 
represents a check cycle with no actions taken.", file=sys.stderr) + + while not shutdown_event: + had_action_in_cycle = False + for setup in enforcer_setups: + logger.debug(f"--- Applying policies for {setup['name']} Simulation ({setup['enforcer'].manager.key_prefix}) ---") + if setup['enforcer'].apply_policies(setup['config']): + had_action_in_cycle = True + + # Apply cross-simulation sync after all policies have been applied in this cycle + if cross_sync_config and len(enforcer_setups) == 2: + auth_setup = next((s for s in enforcer_setups if s['name'] == 'Auth'), None) + download_setup = next((s for s in enforcer_setups if s['name'] == 'Download'), None) + if auth_setup and download_setup: + sync_cross_simulation( + auth_setup['enforcer'].manager, + download_setup['enforcer'].manager, + cross_sync_config, + dry_run=args.dry_run + ) + # Note: sync_cross_simulation may take actions, but we don't track them for the dot indicator + # This is fine for now + + if had_action_in_cycle: + if not args.verbose: + # Print a newline to separate the action logs from subsequent dots + print(file=sys.stderr) + else: + if not args.verbose: + print(".", end="", file=sys.stderr) + sys.stderr.flush() + + sleep_end_time = time.time() + interval + while time.time() < sleep_end_time and not shutdown_event: + time.sleep(1) + + logger.info("Policy enforcer stopped.") + return 0 diff --git a/ytops_client/profile_allocator_tool.py b/ytops_client/profile_allocator_tool.py new file mode 100644 index 0000000..331129c --- /dev/null +++ b/ytops_client/profile_allocator_tool.py @@ -0,0 +1,147 @@ +#!/usr/bin/env python3 +""" +CLI tool for acquiring and releasing profile locks. +""" + +import argparse +import json +import logging +import os +import sys +import time + +try: + from dotenv import load_dotenv +except ImportError: + load_dotenv = None + +from .profile_manager_tool import ProfileManager + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + +def add_profile_allocator_parser(subparsers): + """Adds the parser for the 'profile-allocator' command.""" + parser = subparsers.add_parser( + 'profile-allocator', + description='Acquire and release profile locks.', + formatter_class=argparse.RawTextHelpFormatter, + help='Acquire and release profile locks.' + ) + + common_parser = argparse.ArgumentParser(add_help=False) + common_parser.add_argument('--env-file', help='Path to a .env file to load environment variables from.') + common_parser.add_argument('--redis-host', default=None, help='Redis host. Defaults to MASTER_HOST_IP or REDIS_HOST env var, or localhost.') + common_parser.add_argument('--redis-port', type=int, default=None, help='Redis port. Defaults to REDIS_PORT env var, or 6379.') + common_parser.add_argument('--redis-password', default=None, help='Redis password. Defaults to REDIS_PASSWORD env var.') + common_parser.add_argument('--env', default='dev', help="Environment name for Redis key prefix (e.g., 'stg', 'prod'). Defaults to 'dev'.") + common_parser.add_argument('--legacy', action='store_true', help="Use legacy key prefix ('profile_mgmt_') without environment.") + common_parser.add_argument('--key-prefix', default=None, help='Explicit key prefix for Redis. 
Overrides --env, --legacy and any defaults.') + common_parser.add_argument('--verbose', action='store_true', help='Enable verbose logging') + + allocator_subparsers = parser.add_subparsers(dest='allocator_command', help='Command to execute', required=True) + + # Lock command + lock_parser = allocator_subparsers.add_parser('lock', help='Find and lock an available profile', parents=[common_parser]) + lock_parser.add_argument('--owner', required=True, help='Identifier for the process locking the profile') + lock_parser.add_argument('--profile-prefix', help='Only lock profiles with this name prefix') + lock_parser.add_argument('--wait', action='store_true', help='Wait indefinitely for a profile to become available, with exponential backoff.') + + # Unlock command + unlock_parser = allocator_subparsers.add_parser('unlock', help='Unlock a profile', parents=[common_parser]) + unlock_parser.add_argument('name', help='Profile name to unlock') + unlock_parser.add_argument('--owner', help='Identifier of the owner. If provided, unlock will only succeed if owner matches.') + + # Cleanup command + cleanup_parser = allocator_subparsers.add_parser('cleanup-locks', help='Clean up stale locks', parents=[common_parser]) + cleanup_parser.add_argument('--max-age-seconds', type=int, default=3600, + help='Maximum lock age in seconds before it is considered stale (default: 3600)') + + return parser + +def main_profile_allocator(args): + """Main dispatcher for 'profile-allocator' command.""" + if load_dotenv: + env_file = args.env_file + if not env_file and args.env and '.env' in args.env and os.path.exists(args.env): + logger.warning(f"Warning: --env should be an environment name (e.g., 'dev'), not a file path. Treating '{args.env}' as --env-file. The environment name will default to 'dev'.") + env_file = args.env + args.env = 'dev' + + was_loaded = load_dotenv(env_file) + if was_loaded: + logger.info(f"Loaded environment variables from {env_file or '.env file'}") + elif args.env_file: + logger.error(f"The specified --env-file was not found: {args.env_file}") + return 1 + + if args.redis_host is None: + args.redis_host = os.getenv('MASTER_HOST_IP', os.getenv('REDIS_HOST', 'localhost')) + if args.redis_port is None: + args.redis_port = int(os.getenv('REDIS_PORT', 6379)) + if args.redis_password is None: + args.redis_password = os.getenv('REDIS_PASSWORD') + + if args.verbose: + logging.getLogger().setLevel(logging.DEBUG) + + if args.key_prefix: + key_prefix = args.key_prefix + elif args.legacy: + key_prefix = 'profile_mgmt_' + else: + key_prefix = f"{args.env}_profile_mgmt_" + + manager = ProfileManager( + redis_host=args.redis_host, + redis_port=args.redis_port, + redis_password=args.redis_password, + key_prefix=key_prefix + ) + + if args.allocator_command == 'lock': + if not args.wait: + profile = manager.lock_profile(args.owner, profile_prefix=args.profile_prefix) + if profile: + print(json.dumps(profile, indent=2, default=str)) + return 0 + else: + print("No available profile could be locked.", file=sys.stderr) + return 1 + + # With --wait, loop with backoff + lock_attempts = 0 + backoff_seconds = [3, 5, 9, 20, 50, 120, 300] + while True: + profile = manager.lock_profile(args.owner, profile_prefix=args.profile_prefix) + if profile: + print(json.dumps(profile, indent=2, default=str)) + return 0 + + sleep_duration = backoff_seconds[min(lock_attempts, len(backoff_seconds) - 1)] + logger.info(f"No available profile. Retrying in {sleep_duration}s... 
(attempt {lock_attempts + 1})") + + try: + time.sleep(sleep_duration) + except KeyboardInterrupt: + logger.warning("Wait for lock interrupted by user.") + # Use print for stderr as well, since logger might be configured differently by callers + print("\nWait for lock interrupted by user.", file=sys.stderr) + return 130 # Standard exit code for Ctrl+C + + lock_attempts += 1 + + elif args.allocator_command == 'unlock': + success = manager.unlock_profile(args.name, args.owner) + return 0 if success else 1 + + elif args.allocator_command == 'cleanup-locks': + cleaned_count = manager.cleanup_stale_locks(args.max_age_seconds) + print(f"Cleaned {cleaned_count} stale lock(s).") + return 0 + + return 1 # Should not be reached diff --git a/ytops_client/profile_manager_tool.py b/ytops_client/profile_manager_tool.py new file mode 100644 index 0000000..e465977 --- /dev/null +++ b/ytops_client/profile_manager_tool.py @@ -0,0 +1,1989 @@ +#!/usr/bin/env python3 +""" +Profile Management CLI Tool (v2) for yt-ops-client. +""" + +import argparse +import base64 +import json +import io +import logging +import os +import random +import signal +import sys +import threading +import time +from datetime import datetime +from typing import Dict, List, Optional, Any +import collections + +import redis + +try: + from dotenv import load_dotenv +except ImportError: + load_dotenv = None + +try: + from tabulate import tabulate +except ImportError: + print("'tabulate' library not found. Please install it with: pip install tabulate", file=sys.stderr) + tabulate = None + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + +# Graceful shutdown handler for live mode +shutdown_event = threading.Event() +def handle_shutdown(sig, frame): + """Sets the shutdown_event on SIGINT or SIGTERM.""" + if not shutdown_event.is_set(): + # Use print to stderr to avoid messing with the live display + print("\nShutdown signal received. 
Stopping live view...", file=sys.stderr) + shutdown_event.set() + + +class ProfileManager: + """Manages profiles in Redis with configurable prefix.""" + + # Profile states + STATE_ACTIVE = "ACTIVE" + STATE_PAUSED = "PAUSED" + STATE_RESTING = "RESTING" + STATE_BANNED = "BANNED" + STATE_LOCKED = "LOCKED" + STATE_COOLDOWN = "COOLDOWN" + + VALID_STATES = [STATE_ACTIVE, STATE_PAUSED, STATE_RESTING, STATE_BANNED, STATE_LOCKED, STATE_COOLDOWN] + + def __init__(self, redis_host='localhost', redis_port=6379, + redis_password=None, key_prefix='profile_mgmt_'): + """Initialize Redis connection and key prefix.""" + self.key_prefix = key_prefix + logger.info(f"Attempting to connect to Redis at {redis_host}:{redis_port}...") + try: + self.redis = redis.Redis( + host=redis_host, + port=redis_port, + password=redis_password, + decode_responses=True, + socket_connect_timeout=5, + socket_timeout=5 + ) + self.redis.ping() + logger.info(f"Successfully connected to Redis.") + logger.info(f"Using key prefix: {key_prefix}") + except redis.exceptions.ConnectionError as e: + logger.error(f"Failed to connect to Redis at {redis_host}:{redis_port}: {e}") + sys.exit(1) + + def _profile_key(self, profile_name: str) -> str: + """Get Redis key for a profile.""" + return f"{self.key_prefix}profile:{profile_name}" + + def _state_key(self, state: str) -> str: + """Get Redis key for a state index.""" + return f"{self.key_prefix}state:{state}" + + def _activity_key(self, profile_name: str, activity_type: str) -> str: + """Get Redis key for activity timeline.""" + return f"{self.key_prefix}activity:{profile_name}:{activity_type}" + + def _proxy_state_key(self, proxy_url: str) -> str: + """Get Redis key for proxy state hash.""" + encoded_proxy = base64.urlsafe_b64encode(proxy_url.encode()).decode() + return f"{self.key_prefix}proxy_state:{encoded_proxy}" + + def _proxy_group_state_key(self, group_name: str) -> str: + """Get Redis key for proxy group state hash.""" + return f"{self.key_prefix}proxy_group_state:{group_name}" + + def _profile_group_state_key(self, group_name: str) -> str: + """Get Redis key for profile group state hash.""" + return f"{self.key_prefix}profile_group_state:{group_name}" + + def _proxy_activity_key(self, proxy_url: str, activity_type: str) -> str: + """Get Redis key for proxy activity.""" + # Use base64 to handle special chars in URL + encoded_proxy = base64.urlsafe_b64encode(proxy_url.encode()).decode() + return f"{self.key_prefix}activity:proxy:{encoded_proxy}:{activity_type}" + + def _config_key(self) -> str: + """Get Redis key for shared configuration.""" + return f"{self.key_prefix}config" + + def _pending_downloads_key(self, profile_name: str) -> str: + """Get Redis key for a profile's pending downloads counter.""" + return f"{self.key_prefix}downloads_pending:{profile_name}" + + def increment_pending_downloads(self, profile_name: str, count: int = 1) -> Optional[int]: + """Atomically increments the pending downloads counter for a profile.""" + if count <= 0: + return None + key = self._pending_downloads_key(profile_name) + new_value = self.redis.incrby(key, count) + # Set a TTL on the key to prevent it from living forever if something goes wrong. + # 5 hours is a safe buffer for the 4-hour info.json validity. + self.redis.expire(key, 5 * 3600) + logger.info(f"Incremented pending downloads for '{profile_name}' by {count}. 
New count: {new_value}") + return new_value + + def decrement_pending_downloads(self, profile_name: str) -> Optional[int]: + """Atomically decrements the pending downloads counter for a profile.""" + key = self._pending_downloads_key(profile_name) + + # Only decrement if the key exists. This prevents stray calls from creating negative counters. + if not self.redis.exists(key): + logger.warning(f"Attempted to decrement pending downloads for '{profile_name}', but no counter exists. No action taken.") + return None + + new_value = self.redis.decr(key) + + logger.info(f"Decremented pending downloads for '{profile_name}'. New count: {new_value}") + if new_value <= 0: + # Clean up the key once it reaches zero. + self.redis.delete(key) + logger.info(f"Pending downloads for '{profile_name}' reached zero. Cleared counter key.") + + return new_value + + def get_pending_downloads(self, profile_name: str) -> int: + """Retrieves the current pending downloads count for a profile.""" + key = self._pending_downloads_key(profile_name) + value = self.redis.get(key) + return int(value) if value else 0 + + def clear_pending_downloads(self, profile_name: str) -> bool: + """Deletes the pending downloads counter key for a profile.""" + key = self._pending_downloads_key(profile_name) + deleted_count = self.redis.delete(key) + if deleted_count > 0: + logger.info(f"Cleared pending downloads counter for '{profile_name}'.") + return deleted_count > 0 + + def set_config(self, key: str, value: Any) -> bool: + """Sets a configuration value in Redis.""" + self.redis.hset(self._config_key(), key, str(value)) + logger.info(f"Set config '{key}' to '{value}'") + return True + + def get_config(self, key: str, default: Optional[Any] = None) -> Optional[Any]: + """Gets a configuration value from Redis.""" + value = self.redis.hget(self._config_key(), key) + if value is None: + return default + return value + + def _locks_key(self) -> str: + """Get Redis key for locks hash.""" + return f"{self.key_prefix}locks" + + def _failed_lock_attempts_key(self) -> str: + """Get Redis key for the failed lock attempts counter.""" + return f"{self.key_prefix}stats:failed_lock_attempts" + + def create_profile(self, name: str, proxy: str, initial_state: str = STATE_ACTIVE) -> bool: + """Create a new profile.""" + if initial_state not in self.VALID_STATES: + logger.error(f"Invalid initial state: {initial_state}") + return False + + profile_key = self._profile_key(name) + + # Check if profile already exists + if self.redis.exists(profile_key): + logger.error(f"Profile '{name}' already exists") + return False + + now = time.time() + profile_data = { + 'name': name, + 'proxy': proxy, + 'state': initial_state, + 'created_at': str(now), + 'last_used': str(now), + 'success_count': '0', + 'failure_count': '0', + 'tolerated_error_count': '0', + 'download_count': '0', + 'download_error_count': '0', + 'global_success_count': '0', + 'global_failure_count': '0', + 'global_tolerated_error_count': '0', + 'global_download_count': '0', + 'global_download_error_count': '0', + 'lock_timestamp': '0', + 'lock_owner': '', + 'rest_until': '0', + 'last_rest_timestamp': '0', + 'wait_started_at': '0', + 'ban_reason': '', + 'rest_reason': '', + 'reason': '', + 'notes': '' + } + + # Use pipeline for atomic operations + pipe = self.redis.pipeline() + pipe.hset(profile_key, mapping=profile_data) + # Add to state index + pipe.zadd(self._state_key(initial_state), {name: now}) + result = pipe.execute() + + if result[0] > 0: + logger.info(f"Created profile '{name}' with proxy 
'{proxy}' (state: {initial_state})") + return True + else: + logger.error(f"Failed to create profile '{name}'") + return False + + def get_profile(self, name: str) -> Optional[Dict[str, Any]]: + """Get profile details.""" + profile_key = self._profile_key(name) + data = self.redis.hgetall(profile_key) + + if not data: + return None + + # Convert numeric fields + numeric_fields = ['created_at', 'last_used', 'success_count', 'failure_count', + 'tolerated_error_count', 'download_count', 'download_error_count', + 'global_success_count', 'global_failure_count', + 'global_tolerated_error_count', 'global_download_count', + 'global_download_error_count', + 'lock_timestamp', 'rest_until', 'last_rest_timestamp', 'wait_started_at'] + for field in numeric_fields: + if field in data: + try: + data[field] = float(data[field]) + except (ValueError, TypeError): + data[field] = 0.0 + + return data + + def list_profiles(self, state_filter: Optional[str] = None, + proxy_filter: Optional[str] = None) -> List[Dict[str, Any]]: + """List profiles with optional filtering.""" + profiles = [] + + if state_filter: + # Get profiles from specific state index + state_key = self._state_key(state_filter) + profile_names = self.redis.zrange(state_key, 0, -1) + else: + # Get all profiles by scanning keys + pattern = self._profile_key('*') + keys = [] + cursor = 0 + while True: + cursor, found_keys = self.redis.scan(cursor=cursor, match=pattern, count=100) + keys.extend(found_keys) + if cursor == 0: + break + profile_names = [k.split(':')[-1] for k in keys] + + if not profile_names: + return [] + + # Use a pipeline to fetch all profile data at once for efficiency + pipe = self.redis.pipeline() + for name in profile_names: + pipe.hgetall(self._profile_key(name)) + all_profile_data = pipe.execute() + + # Also fetch pending download counts for all profiles + pipe = self.redis.pipeline() + for name in profile_names: + pipe.get(self._pending_downloads_key(name)) + all_pending_downloads = pipe.execute() + + numeric_fields = ['created_at', 'last_used', 'success_count', 'failure_count', + 'tolerated_error_count', 'download_count', 'download_error_count', + 'global_success_count', 'global_failure_count', + 'global_tolerated_error_count', 'global_download_count', + 'global_download_error_count', + 'lock_timestamp', 'rest_until', 'last_rest_timestamp', 'wait_started_at'] + + for i, data in enumerate(all_profile_data): + if not data: + continue + + # Add pending downloads count to the profile data + pending_downloads = all_pending_downloads[i] + data['pending_downloads'] = int(pending_downloads) if pending_downloads else 0 + + # Convert numeric fields + for field in numeric_fields: + if field in data: + try: + data[field] = float(data[field]) + except (ValueError, TypeError): + data[field] = 0.0 + + if proxy_filter and proxy_filter not in data.get('proxy', ''): + continue + + profiles.append(data) + + # Sort by creation time (newest first) + profiles.sort(key=lambda x: x.get('created_at', 0), reverse=True) + return profiles + + def update_profile_state(self, name: str, new_state: str, + reason: str = '') -> bool: + """Update profile state.""" + if new_state not in self.VALID_STATES: + logger.error(f"Invalid state: {new_state}") + return False + + profile = self.get_profile(name) + if not profile: + logger.error(f"Profile '{name}' not found") + return False + + old_state = profile['state'] + if old_state == new_state: + logger.info(f"Profile '{name}' already in state {new_state}") + return True + + now = time.time() + profile_key = 
self._profile_key(name) + + pipe = self.redis.pipeline() + + # Update profile hash + updates = {'state': new_state, 'last_used': str(now)} + + if new_state == self.STATE_BANNED and reason: + updates['ban_reason'] = reason + elif new_state == self.STATE_RESTING: + # Set rest_until to 1 hour from now by default + rest_until = now + 3600 + updates['rest_until'] = str(rest_until) + if reason: + updates['rest_reason'] = reason + + # Handle transitions into ACTIVE state + if new_state == self.STATE_ACTIVE: + # Clear any resting/banned state fields + updates['rest_until'] = '0' + updates['rest_reason'] = '' + updates['reason'] = '' + updates['ban_reason'] = '' # Clear ban reason on manual activation + if old_state in [self.STATE_RESTING, self.STATE_COOLDOWN]: + updates['last_rest_timestamp'] = str(now) + + # When activating a profile, ensure its proxy is also active. + proxy_url = profile.get('proxy') + if proxy_url: + logger.info(f"Activating associated proxy '{proxy_url}' for profile '{name}'.") + pipe.hset(self._proxy_state_key(proxy_url), mapping={ + 'state': self.STATE_ACTIVE, + 'rest_until': '0', + 'work_start_timestamp': str(now) + }) + + # If moving to any state that is not LOCKED, ensure any stale lock data is cleared. + # This makes manual state changes (like 'activate' or 'unban') more robust. + if new_state != self.STATE_LOCKED: + updates['lock_owner'] = '' + updates['lock_timestamp'] = '0' + pipe.hdel(self._locks_key(), name) + if old_state == self.STATE_LOCKED: + logger.info(f"Profile '{name}' was in LOCKED state. Clearing global lock.") + + pipe.hset(profile_key, mapping=updates) + + # Remove from old state index, add to new state index + if old_state in self.VALID_STATES: + pipe.zrem(self._state_key(old_state), name) + pipe.zadd(self._state_key(new_state), {name: now}) + + result = pipe.execute() + + logger.info(f"Updated profile '{name}' from {old_state} to {new_state}") + if reason: + logger.info(f"Reason: {reason}") + + return True + + def update_profile_field(self, name: str, field: str, value: str) -> bool: + """Update a specific field in profile.""" + profile = self.get_profile(name) + if not profile: + logger.error(f"Profile '{name}' not found") + return False + + profile_key = self._profile_key(name) + self.redis.hset(profile_key, field, value) + logger.info(f"Updated profile '{name}' field '{field}' to '{value}'") + return True + + def delete_profile(self, name: str) -> bool: + """Delete a profile and all associated data.""" + profile = self.get_profile(name) + if not profile: + logger.error(f"Profile '{name}' not found") + return False + + state = profile['state'] + + pipe = self.redis.pipeline() + + # Delete profile hash + profile_key = self._profile_key(name) + pipe.delete(profile_key) + + # Remove from state index + if state in self.VALID_STATES: + pipe.zrem(self._state_key(state), name) + + # Delete activity keys + for activity_type in ['success', 'failure', 'tolerated_error', 'download', 'download_error']: + activity_key = self._activity_key(name, activity_type) + pipe.delete(activity_key) + + # Remove from locks if present + locks_key = self._locks_key() + pipe.hdel(locks_key, name) + + result = pipe.execute() + + logger.info(f"Deleted profile '{name}' and all associated data") + return True + + def delete_all_data(self) -> int: + """Deletes all keys associated with the current manager's key_prefix.""" + logger.warning(f"Deleting all keys with prefix: {self.key_prefix}") + + keys_to_delete = [] + for key in self.redis.scan_iter(f"{self.key_prefix}*"): + 
keys_to_delete.append(key) + + if not keys_to_delete: + logger.info("No keys found to delete.") + return 0 + + total_deleted = 0 + chunk_size = 500 + for i in range(0, len(keys_to_delete), chunk_size): + chunk = keys_to_delete[i:i + chunk_size] + total_deleted += self.redis.delete(*chunk) + + logger.info(f"Deleted {total_deleted} key(s).") + return total_deleted + + def record_activity(self, name: str, activity_type: str, + timestamp: Optional[float] = None) -> bool: + """Record activity (success/failure) for a profile.""" + if activity_type not in ['success', 'failure', 'tolerated_error', 'download', 'download_error']: + logger.error(f"Invalid activity type: {activity_type}") + return False + + profile = self.get_profile(name) + if not profile: + logger.error(f"Profile '{name}' not found") + return False + + ts = timestamp or time.time() + activity_key = self._activity_key(name, activity_type) + + # Add to sorted set + self.redis.zadd(activity_key, {str(ts): ts}) + + # Update counters in profile + profile_key = self._profile_key(name) + counter_field = f"{activity_type}_count" + self.redis.hincrby(profile_key, counter_field, 1) + global_counter_field = f"global_{activity_type}_count" + self.redis.hincrby(profile_key, global_counter_field, 1) + + # Update last_used + self.redis.hset(profile_key, 'last_used', str(ts)) + + # Keep only last 1000 activities to prevent unbounded growth + self.redis.zremrangebyrank(activity_key, 0, -1001) + + # Also record activity for the proxy + proxy_url = profile.get('proxy') + if proxy_url: + proxy_activity_key = self._proxy_activity_key(proxy_url, activity_type) + pipe = self.redis.pipeline() + pipe.zadd(proxy_activity_key, {str(ts): ts}) + # Keep last 5000 activities per proxy (higher limit) + pipe.zremrangebyrank(proxy_activity_key, 0, -5001) + pipe.execute() + logger.debug(f"Recorded {activity_type} for proxy '{proxy_url}'") + + logger.debug(f"Recorded {activity_type} for profile '{name}' at {ts}") + return True + + def get_activity_rate(self, name: str, activity_type: str, + window_seconds: int) -> int: + """Get activity count within time window.""" + if activity_type not in ['success', 'failure', 'tolerated_error', 'download', 'download_error']: + return 0 + + activity_key = self._activity_key(name, activity_type) + now = time.time() + start = now - window_seconds + + count = self.redis.zcount(activity_key, start, now) + return count + + def get_proxy_activity_rate(self, proxy_url: str, activity_type: str, + window_seconds: int) -> int: + """Get proxy activity count within time window.""" + if activity_type not in ['success', 'failure', 'tolerated_error', 'download', 'download_error']: + return 0 + + activity_key = self._proxy_activity_key(proxy_url, activity_type) + now = time.time() + start = now - window_seconds + + count = self.redis.zcount(activity_key, start, now) + return count + + def reset_profile_counters(self, name: str) -> bool: + """Resets the session counters for a single profile (does not affect global counters).""" + profile = self.get_profile(name) + if not profile: + logger.error(f"Profile '{name}' not found") + return False + + profile_key = self._profile_key(name) + counters_to_reset = { + 'success_count': '0', + 'failure_count': '0', + 'tolerated_error_count': '0', + 'download_count': '0', + 'download_error_count': '0', + } + self.redis.hset(profile_key, mapping=counters_to_reset) + logger.info(f"Reset session counters for profile '{name}'.") + return True + + def get_failed_lock_attempts(self) -> int: + """Get the total count of 
failed lock attempts from Redis.""" + count = self.redis.get(self._failed_lock_attempts_key()) + return int(count) if count else 0 + + def get_global_stats(self) -> Dict[str, int]: + """Get aggregated global stats across all profiles.""" + profiles = self.list_profiles() + total_success = sum(int(p.get('global_success_count', 0)) for p in profiles) + total_failure = sum(int(p.get('global_failure_count', 0)) for p in profiles) + total_tolerated_error = sum(int(p.get('global_tolerated_error_count', 0)) for p in profiles) + total_downloads = sum(int(p.get('global_download_count', 0)) for p in profiles) + total_download_errors = sum(int(p.get('global_download_error_count', 0)) for p in profiles) + return { + 'total_success': total_success, + 'total_failure': total_failure, + 'total_tolerated_error': total_tolerated_error, + 'total_downloads': total_downloads, + 'total_download_errors': total_download_errors, + } + + def get_per_proxy_stats(self) -> Dict[str, Dict[str, Any]]: + """Get aggregated stats per proxy.""" + profiles = self.list_profiles() + proxy_stats = collections.defaultdict(lambda: { + 'success': 0, 'failure': 0, 'tolerated_error': 0, 'downloads': 0, 'download_errors': 0, 'profiles': 0 + }) + for p in profiles: + proxy = p.get('proxy') + if proxy: + proxy_stats[proxy]['success'] += int(p.get('global_success_count', 0)) + proxy_stats[proxy]['failure'] += int(p.get('global_failure_count', 0)) + proxy_stats[proxy]['tolerated_error'] += int(p.get('global_tolerated_error_count', 0)) + proxy_stats[proxy]['downloads'] += int(p.get('global_download_count', 0)) + proxy_stats[proxy]['download_errors'] += int(p.get('global_download_error_count', 0)) + proxy_stats[proxy]['profiles'] += 1 + return dict(proxy_stats) + + def reset_global_counters(self) -> int: + """Resets global, non-profile-specific counters.""" + logger.info("Resetting global counters...") + keys_to_delete = [self._failed_lock_attempts_key()] + + deleted_count = 0 + if keys_to_delete: + deleted_count = self.redis.delete(*keys_to_delete) + + logger.info(f"Deleted {deleted_count} global counter key(s).") + return deleted_count + + def set_proxy_state(self, proxy_url: str, state: str, rest_duration_minutes: Optional[int] = None) -> bool: + """Set the state of a proxy and propagates it to associated profiles.""" + if state not in [self.STATE_ACTIVE, self.STATE_RESTING]: + logger.error(f"Invalid proxy state: {state}. 
Only ACTIVE and RESTING are supported for proxies.") + return False + + proxy_key = self._proxy_state_key(proxy_url) + now = time.time() + updates = {'state': state} + + rest_until = 0 + if state == self.STATE_RESTING: + if not rest_duration_minutes or rest_duration_minutes <= 0: + logger.error("rest_duration_minutes is required when setting proxy state to RESTING.") + return False + rest_until = now + rest_duration_minutes * 60 + updates['rest_until'] = str(rest_until) + updates['work_start_timestamp'] = '0' # Clear work start time + else: # ACTIVE + updates['rest_until'] = '0' + updates['work_start_timestamp'] = str(now) + + self.redis.hset(proxy_key, mapping=updates) + logger.info(f"Set proxy '{proxy_url}' state to {state}.") + + # Now, update associated profiles + profiles_on_proxy = self.list_profiles(proxy_filter=proxy_url) + if not profiles_on_proxy: + return True + + if state == self.STATE_RESTING: + logger.info(f"Propagating RESTING state to profiles on proxy '{proxy_url}'.") + for profile in profiles_on_proxy: + if profile['state'] == self.STATE_ACTIVE: + self.update_profile_state(profile['name'], self.STATE_RESTING, "Proxy resting") + self.update_profile_field(profile['name'], 'rest_until', str(rest_until)) + elif state == self.STATE_ACTIVE: + logger.info(f"Propagating ACTIVE state to profiles on proxy '{proxy_url}'.") + for profile in profiles_on_proxy: + if profile['state'] == self.STATE_RESTING and profile.get('rest_reason') == "Proxy resting": + self.update_profile_state(profile['name'], self.STATE_ACTIVE, "Proxy activated") + + return True + + def get_proxy_states(self, proxy_urls: List[str]) -> Dict[str, Dict[str, Any]]: + """Get states for multiple proxies.""" + if not proxy_urls: + return {} + + pipe = self.redis.pipeline() + for proxy_url in proxy_urls: + pipe.hgetall(self._proxy_state_key(proxy_url)) + + results = pipe.execute() + + states = {} + for i, data in enumerate(results): + proxy_url = proxy_urls[i] + if data: + # Convert numeric fields + for field in ['rest_until', 'work_start_timestamp']: + if field in data: + try: + data[field] = float(data[field]) + except (ValueError, TypeError): + data[field] = 0.0 + states[proxy_url] = data + else: + # Default to ACTIVE if no state is found + states[proxy_url] = {'state': self.STATE_ACTIVE, 'rest_until': 0.0, 'work_start_timestamp': 0.0} + + return states + + def set_proxy_group_membership(self, proxy_url: str, group_name: str, work_minutes: int) -> bool: + """Records a proxy's membership in a rotation group by updating its state hash.""" + proxy_key = self._proxy_state_key(proxy_url) + updates = { + 'group_name': group_name, + 'group_work_minutes': str(work_minutes) + } + self.redis.hset(proxy_key, mapping=updates) + logger.debug(f"Set proxy '{proxy_url}' group membership to '{group_name}'.") + return True + + def set_proxy_group_state(self, group_name: str, active_proxy_index: int, next_rotation_timestamp: float) -> bool: + """Set the state of a proxy group.""" + group_key = self._proxy_group_state_key(group_name) + updates = { + 'active_proxy_index': str(active_proxy_index), + 'next_rotation_timestamp': str(next_rotation_timestamp) + } + self.redis.hset(group_key, mapping=updates) + logger.info(f"Set proxy group '{group_name}' state: active_index={active_proxy_index}, next_rotation at {format_timestamp(next_rotation_timestamp)}.") + return True + + def get_proxy_group_states(self, group_names: List[str]) -> Dict[str, Dict[str, Any]]: + """Get states for multiple proxy groups.""" + if not group_names: + return {} + + 
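# Fetch every group's state hash in a single pipelined round trip. +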
pipe = self.redis.pipeline() + for name in group_names: + pipe.hgetall(self._proxy_group_state_key(name)) + + results = pipe.execute() + + states = {} + for i, data in enumerate(results): + group_name = group_names[i] + if data: + # Convert numeric fields + for field in ['active_proxy_index', 'next_rotation_timestamp']: + if field in data: + try: + data[field] = float(data[field]) + except (ValueError, TypeError): + data[field] = 0.0 + if 'active_proxy_index' in data: + data['active_proxy_index'] = int(data['active_proxy_index']) + states[group_name] = data + else: + # Default to empty dict if no state is found + states[group_name] = {} + + return states + + def set_profile_group_state(self, group_name: str, state_data: Dict[str, Any]) -> bool: + """Set or update the state of a profile group.""" + group_key = self._profile_group_state_key(group_name) + # Ensure all values are strings for redis hset, and filter out None values + updates = {k: str(v) for k, v in state_data.items() if v is not None} + if not updates: + return True # Nothing to do + self.redis.hset(group_key, mapping=updates) + logger.debug(f"Set profile group '{group_name}' state: {updates}.") + return True + + def get_profile_group_states(self, group_names: List[str]) -> Dict[str, Dict[str, Any]]: + """Get states for multiple profile groups.""" + if not group_names: + return {} + + pipe = self.redis.pipeline() + for name in group_names: + pipe.hgetall(self._profile_group_state_key(name)) + + results = pipe.execute() + + states = {} + for i, data in enumerate(results): + group_name = group_names[i] + if data: + numeric_fields = { + 'active_profile_index': int, + 'rotate_after_requests': int, + 'max_active_profiles': int, + } + for field, type_converter in numeric_fields.items(): + if field in data: + try: + data[field] = type_converter(data[field]) + except (ValueError, TypeError): + data[field] = 0 + states[group_name] = data + else: + states[group_name] = {} + + return states + + def lock_profile(self, owner: str, profile_prefix: Optional[str] = None, specific_profile_name: Optional[str] = None) -> Optional[Dict[str, Any]]: + """ + Find and lock an available ACTIVE profile. + If `specific_profile_name` is provided, it will attempt to lock only that profile. + Otherwise, it scans for available profiles, optionally filtered by `profile_prefix`. + """ + profiles_to_check = [] + if specific_profile_name: + # If a specific profile is requested, we only check that one. + profiles_to_check = [specific_profile_name] + else: + # Original logic: find all active profiles, optionally filtered by prefix. + active_profiles = self.redis.zrange(self._state_key(self.STATE_ACTIVE), 0, -1) + if not active_profiles: + logger.warning("No active profiles available to lock.") + self.redis.incr(self._failed_lock_attempts_key()) + return None + + if profile_prefix: + profiles_to_check = [p for p in active_profiles if p.startswith(profile_prefix)] + if not profiles_to_check: + logger.warning(f"No active profiles with prefix '{profile_prefix}' available to lock.") + self.redis.incr(self._failed_lock_attempts_key()) + return None + else: + profiles_to_check = active_profiles + + # --- Filter by active proxy and prepare for locking --- + full_profiles = [self.get_profile(p) for p in profiles_to_check] + # Filter out any None profiles from a race condition with deletion, and ensure state is ACTIVE. + # This is especially important when locking a specific profile. 
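+ # Only profiles that are ACTIVE and whose proxy is also ACTIVE remain eligible; the actual lock is taken further down with an atomic HSETNX followed by a state re-check.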
+ full_profiles = [p for p in full_profiles if p and p.get('proxy') and p.get('state') == self.STATE_ACTIVE] + + if not full_profiles: + if specific_profile_name: + logger.warning(f"Profile '{specific_profile_name}' is not eligible for locking (e.g., not ACTIVE or missing).") + else: + logger.warning("No active profiles available to lock after filtering.") + self.redis.incr(self._failed_lock_attempts_key()) + return None + + unique_proxies = sorted(list(set(p['proxy'] for p in full_profiles))) + proxy_states = self.get_proxy_states(unique_proxies) + + eligible_profiles = [ + p['name'] for p in full_profiles + if proxy_states.get(p['proxy'], {}).get('state', self.STATE_ACTIVE) == self.STATE_ACTIVE + ] + + if not eligible_profiles: + logger.warning("No active profiles with an active proxy available to lock.") + self.redis.incr(self._failed_lock_attempts_key()) + return None + + # Make selection deterministic (use Redis's sorted set order) instead of random + # random.shuffle(active_profiles) + + locks_key = self._locks_key() + + for name in eligible_profiles: + # Try to acquire lock atomically + if self.redis.hsetnx(locks_key, name, owner): + # Lock acquired. Now, re-check state to avoid race condition with enforcer. + profile_key = self._profile_key(name) + current_state = self.redis.hget(profile_key, 'state') + + if current_state != self.STATE_ACTIVE: + # Another process (enforcer) changed the state. Release lock and try next. + self.redis.hdel(locks_key, name) + logger.warning(f"Aborted lock for '{name}'; state changed from ACTIVE to '{current_state}' during lock acquisition.") + continue + + # State is still ACTIVE, proceed with locking. + now = time.time() + + pipe = self.redis.pipeline() + # Update profile state and lock info + pipe.hset(profile_key, mapping={ + 'state': self.STATE_LOCKED, + 'lock_owner': owner, + 'lock_timestamp': str(now), + 'last_used': str(now) + }) + # Move from ACTIVE to LOCKED state index + pipe.zrem(self._state_key(self.STATE_ACTIVE), name) + pipe.zadd(self._state_key(self.STATE_LOCKED), {name: now}) + pipe.execute() + + logger.info(f"Locked profile '{name}' for owner '{owner}'") + return self.get_profile(name) + + logger.warning("Could not lock any active profile (all may have been locked by other workers).") + self.redis.incr(self._failed_lock_attempts_key()) + return None + + def unlock_profile(self, name: str, owner: Optional[str] = None, rest_for_seconds: Optional[int] = None) -> bool: + """Unlock a profile. If owner provided, it must match. Can optionally put profile into COOLDOWN state.""" + profile = self.get_profile(name) + if not profile: + logger.error(f"Profile '{name}' not found.") + return False + + if profile['state'] != self.STATE_LOCKED: + logger.warning(f"Profile '{name}' is not in LOCKED state (current: {profile['state']}).") + # Forcibly remove from locks hash if it's inconsistent + self.redis.hdel(self._locks_key(), name) + return False + + if owner and profile['lock_owner'] != owner: + logger.error(f"Owner mismatch: cannot unlock profile '{name}'. 
Locked by '{profile['lock_owner']}', attempted by '{owner}'.") + return False + + now = time.time() + profile_key = self._profile_key(name) + + pipe = self.redis.pipeline() + + updates = { + 'lock_owner': '', + 'lock_timestamp': '0', + 'last_used': str(now) + } + + if rest_for_seconds and rest_for_seconds > 0: + new_state = self.STATE_COOLDOWN + rest_until = now + rest_for_seconds + updates['rest_until'] = str(rest_until) + updates['rest_reason'] = 'Post-task cooldown' + logger_msg = f"Unlocked profile '{name}' into COOLDOWN for {rest_for_seconds}s." + else: + new_state = self.STATE_ACTIVE + # Clear any rest-related fields when moving to ACTIVE + updates['rest_until'] = '0' + updates['rest_reason'] = '' + updates['reason'] = '' + logger_msg = f"Unlocked profile '{name}'" + + updates['state'] = new_state + pipe.hset(profile_key, mapping=updates) + + # Move from LOCKED to the new state index + pipe.zrem(self._state_key(self.STATE_LOCKED), name) + pipe.zadd(self._state_key(new_state), {name: now}) + + # Remove from global locks hash + pipe.hdel(self._locks_key(), name) + + pipe.execute() + + logger.info(logger_msg) + return True + + def cleanup_stale_locks(self, max_lock_time_seconds: int) -> int: + """Find and unlock profiles with stale locks.""" + locks_key = self._locks_key() + all_locks = self.redis.hgetall(locks_key) + if not all_locks: + logger.debug("No active locks found to clean up.") + return 0 + + now = time.time() + cleaned_count = 0 + + for name, owner in all_locks.items(): + profile = self.get_profile(name) + if not profile: + # Lock exists but profile doesn't. Clean up the lock. + self.redis.hdel(locks_key, name) + logger.warning(f"Removed stale lock for non-existent profile '{name}'") + cleaned_count += 1 + continue + + lock_timestamp = profile.get('lock_timestamp', 0) + if lock_timestamp > 0 and (now - lock_timestamp) > max_lock_time_seconds: + logger.warning(f"Found stale lock for profile '{name}' (locked by '{owner}' for {now - lock_timestamp:.0f}s). Unlocking...") + if self.unlock_profile(name): + cleaned_count += 1 + + if cleaned_count > 0: + logger.info(f"Cleaned up {cleaned_count} stale lock(s).") + else: + logger.debug("No stale locks found to clean up.") + return cleaned_count + +def format_timestamp(ts: float) -> str: + """Format timestamp for display.""" + if not ts or ts == 0: + return "Never" + return datetime.fromtimestamp(ts).strftime('%Y-%m-%d %H:%M:%S') + +def format_duration(seconds: float) -> str: + """Format duration for display.""" + if seconds < 60: + return f"{seconds:.0f}s" + elif seconds < 3600: + return f"{seconds/60:.1f}m" + elif seconds < 86400: + return f"{seconds/3600:.1f}h" + else: + return f"{seconds/86400:.1f}d" + + +def add_profile_manager_parser(subparsers): + """Adds the parser for the 'profile' command.""" + parser = subparsers.add_parser( + 'profile', + description='Manage profiles (v2).', + formatter_class=argparse.RawTextHelpFormatter, + help='Manage profiles (v2).' + ) + + # Common arguments for all profile manager subcommands + common_parser = argparse.ArgumentParser(add_help=False) + common_parser.add_argument('--env-file', help='Path to a .env file to load environment variables from.') + common_parser.add_argument('--redis-host', default=None, help='Redis host. Defaults to REDIS_HOST or MASTER_HOST_IP env var, or localhost.') + common_parser.add_argument('--redis-port', type=int, default=None, help='Redis port. 
Defaults to REDIS_PORT env var, or 6379.') + common_parser.add_argument('--redis-password', default=None, help='Redis password. Defaults to REDIS_PASSWORD env var.') + common_parser.add_argument('--env', default='dev', help="Environment name for Redis key prefix (e.g., 'stg', 'prod'). Used by all non-list commands, and by 'list' for single-view mode. Defaults to 'dev'.") + common_parser.add_argument('--legacy', action='store_true', help="Use legacy key prefix ('profile_mgmt_') without environment.") + common_parser.add_argument('--key-prefix', default=None, help='Explicit key prefix for Redis. Overrides --env, --legacy and any defaults.') + common_parser.add_argument('--verbose', action='store_true', help='Enable verbose logging') + + subparsers = parser.add_subparsers(dest='profile_command', help='Command to execute', required=True) + + # Create command + create_parser = subparsers.add_parser('create', help='Create a new profile', parents=[common_parser]) + create_parser.add_argument('name', help='Profile name') + create_parser.add_argument('proxy', help='Proxy URL (e.g., sslocal-rust-1090:1090)') + create_parser.add_argument('--state', default='ACTIVE', + choices=['ACTIVE', 'PAUSED', 'RESTING', 'BANNED', 'COOLDOWN'], + help='Initial state (default: ACTIVE)') + + # List command + list_parser = subparsers.add_parser('list', help='List profiles', parents=[common_parser]) + list_parser.add_argument('--auth-env', help='Environment name for the Auth simulation monitor. Use with --download-env for a merged view.') + list_parser.add_argument('--download-env', help='Environment name for the Download simulation monitor. Use with --auth-env for a merged view.') + list_parser.add_argument('--separate-views', action='store_true', help='In dual-monitor mode, show two separate reports instead of a single merged view.') + list_parser.add_argument('--rest-after-requests', type=int, help='(For display) Show countdown to rest based on this request limit.') + list_parser.add_argument('--state', help='Filter by state') + list_parser.add_argument('--proxy', help='Filter by proxy (substring match)') + list_parser.add_argument('--show-proxy-activity', action='store_true', help='Show a detailed activity summary table for proxies. If --proxy is specified, shows details for that proxy only. Otherwise, shows a summary for all proxies.') + list_parser.add_argument('--format', choices=['table', 'json', 'csv'], default='table', + help='Output format (default: table)') + list_parser.add_argument('--live', action='store_true', help='Run continuously with a non-blinking live-updating display.') + list_parser.add_argument('--no-blink', action='store_true', help='Use ANSI escape codes for smoother screen updates in --live mode (experimental).') + list_parser.add_argument('--interval-seconds', type=int, default=5, help='When in --live mode, how often to refresh in seconds. 
Default: 5.') + list_parser.add_argument('--hide-active-state', action='store_true', help="Display 'ACTIVE' state as blank for cleaner UI.") + + # Get command + get_parser = subparsers.add_parser('get', help='Get profile details', parents=[common_parser]) + get_parser.add_argument('name', help='Profile name') + + # Set proxy state command + set_proxy_state_parser = subparsers.add_parser('set-proxy-state', help='Set the state of a proxy and propagate to its profiles.', parents=[common_parser]) + set_proxy_state_parser.add_argument('proxy_url', help='Proxy URL') + set_proxy_state_parser.add_argument('state', choices=['ACTIVE', 'RESTING'], help='New state for the proxy') + set_proxy_state_parser.add_argument('--duration-minutes', type=int, help='Duration for the RESTING state') + + # Update state command + update_state_parser = subparsers.add_parser('update-state', help='Update profile state', parents=[common_parser]) + update_state_parser.add_argument('name', help='Profile name') + update_state_parser.add_argument('state', choices=['ACTIVE', 'PAUSED', 'RESTING', 'BANNED', 'LOCKED', 'COOLDOWN'], + help='New state') + update_state_parser.add_argument('--reason', help='Reason for state change (especially for BAN)') + + # Update field command + update_field_parser = subparsers.add_parser('update-field', help='Update a profile field', parents=[common_parser]) + update_field_parser.add_argument('name', help='Profile name') + update_field_parser.add_argument('field', help='Field name to update') + update_field_parser.add_argument('value', help='New value') + + # Pause command (convenience) + pause_parser = subparsers.add_parser('pause', help='Pause a profile (sets state to PAUSED).', parents=[common_parser]) + pause_parser.add_argument('name', help='Profile name') + + # Activate command (convenience) + activate_parser = subparsers.add_parser('activate', help='Activate a profile (sets state to ACTIVE). 
Useful for resuming a PAUSED profile or fixing a stale LOCKED one.', parents=[common_parser]) + activate_parser.add_argument('name', help='Profile name') + + # Ban command (convenience) + ban_parser = subparsers.add_parser('ban', help='Ban a profile (sets state to BANNED).', parents=[common_parser]) + ban_parser.add_argument('name', help='Profile name') + ban_parser.add_argument('--reason', required=True, help='Reason for ban') + + # Unban command (convenience) + unban_parser = subparsers.add_parser('unban', help='Unban a profile (sets state to ACTIVE and resets session counters).', parents=[common_parser]) + unban_parser.add_argument('name', help='Profile name') + + # Delete command + delete_parser = subparsers.add_parser('delete', help='Delete a profile', parents=[common_parser]) + delete_parser.add_argument('name', help='Profile name') + delete_parser.add_argument('--confirm', action='store_true', + help='Confirm deletion (required)') + + # Delete all command + delete_all_parser = subparsers.add_parser('delete-all', help='(Destructive) Delete all profiles and data under the current key prefix.', parents=[common_parser]) + delete_all_parser.add_argument('--confirm', action='store_true', help='Confirm this highly destructive action (required)') + + # Reset global counters command + reset_global_parser = subparsers.add_parser('reset-global-counters', help='Reset global counters (e.g., failed_lock_attempts).', parents=[common_parser]) + + # Reset counters command + reset_counters_parser = subparsers.add_parser( + 'reset-counters', + help='Reset session counters for profiles or proxies.', + description="Resets session-specific counters (success, failure, etc.) for one or more profiles.\n\nWARNING: This only resets Redis counters. It does not affect any data stored on disk\n(e.g., downloaded files, logs) associated with the profile or proxy.", + formatter_class=argparse.RawTextHelpFormatter, + parents=[common_parser] + ) + reset_group = reset_counters_parser.add_mutually_exclusive_group(required=True) + reset_group.add_argument('--profile-name', help='The name of the single profile to reset.') + reset_group.add_argument('--proxy-url', help='Reset all profiles associated with this proxy.') + reset_group.add_argument('--all-profiles', action='store_true', help='Reset all profiles in the environment.') + + # Record activity command (for testing) + record_parser = subparsers.add_parser('record-activity', help='(Testing) Record a synthetic activity event for a profile.', parents=[common_parser]) + record_parser.add_argument('name', help='Profile name') + record_parser.add_argument('type', choices=['success', 'failure', 'tolerated_error', 'download', 'download_error'], help='Activity type') + record_parser.add_argument('--timestamp', type=float, help='Timestamp (default: now)') + + # Get rate command + rate_parser = subparsers.add_parser('get-rate', help='Get activity rate for a profile', parents=[common_parser]) + rate_parser.add_argument('name', help='Profile name') + rate_parser.add_argument('type', choices=['success', 'failure', 'tolerated_error', 'download', 'download_error'], help='Activity type') + rate_parser.add_argument('--window', type=int, default=3600, + help='Time window in seconds (default: 3600)') + return parser + +def _build_profile_groups_config(manager, profiles): + """Builds a configuration structure for profile groups by reading state from Redis.""" + group_state_keys = [k for k in manager.redis.scan_iter(f"{manager.key_prefix}profile_group_state:*")] + if not 
group_state_keys: + return [] + + group_names = [k.split(':')[-1] for k in group_state_keys] + group_states = manager.get_profile_group_states(group_names) + + config = [] + for name, state in group_states.items(): + profiles_in_group = [] + prefix = state.get('prefix') + if prefix: + profiles_in_group = [p['name'] for p in profiles if p['name'].startswith(prefix)] + + config.append({ + 'name': name, + 'profiles_in_group': profiles_in_group, + **state + }) + return config + + +def _render_all_proxies_activity_summary(manager, simulation_type, file=sys.stdout): + """Renders a summary of activity rates for all proxies.""" + if not manager: + return + + print(f"\n--- All Proxies Activity Summary ({simulation_type}) ---", file=file) + + all_profiles = manager.list_profiles() + if not all_profiles: + print("No profiles found to determine proxy list.", file=file) + return + + unique_proxies = sorted(list(set(p['proxy'] for p in all_profiles if p.get('proxy')))) + if not unique_proxies: + print("No proxies are currently associated with any profiles.", file=file) + return + + proxy_states = manager.get_proxy_states(unique_proxies) + + is_auth_sim = 'Auth' in simulation_type + # Sum up all relevant activity types for the rate columns + activity_types_to_sum = ['success', 'failure', 'tolerated_error'] if is_auth_sim else ['download', 'download_error', 'tolerated_error'] + + proxy_work_minutes_str = manager.get_config('proxy_work_minutes') + proxy_work_minutes = 0 + if proxy_work_minutes_str and proxy_work_minutes_str.isdigit(): + proxy_work_minutes = int(proxy_work_minutes_str) + + proxy_rest_minutes_str = manager.get_config('proxy_rest_duration_minutes') + proxy_rest_minutes = 0 + if proxy_rest_minutes_str and proxy_rest_minutes_str.isdigit(): + proxy_rest_minutes = int(proxy_rest_minutes_str) + + table_data = [] + headers = ['Proxy URL', 'State', 'Policy', 'State Ends In', 'Reqs (1m)', 'Reqs (5m)', 'Reqs (1h)'] + + for proxy_url in unique_proxies: + state_data = proxy_states.get(proxy_url, {}) + state = state_data.get('state', 'N/A') + rest_until = state_data.get('rest_until', 0) + work_start = state_data.get('work_start_timestamp', 0) + + state_str = state + countdown_str = "N/A" + now = time.time() + + policy_str = "N/A" + group_name = state_data.get('group_name') + work_minutes_for_countdown = 0 + + if group_name: + group_work_minutes = state_data.get('group_work_minutes', 0) + try: + group_work_minutes = int(group_work_minutes) + work_minutes_for_countdown = group_work_minutes + except (ValueError, TypeError): + group_work_minutes = 0 + policy_str = f"Group: {group_name}\n({group_work_minutes}m/proxy)" + elif proxy_work_minutes > 0: + policy_str = f"Work: {proxy_work_minutes}m\nRest: {proxy_rest_minutes}m" + work_minutes_for_countdown = proxy_work_minutes + + if state == 'RESTING' and rest_until > now: + countdown_str = format_duration(rest_until - now) + elif state == 'ACTIVE' and work_start > 0 and work_minutes_for_countdown > 0: + work_end_time = work_start + (work_minutes_for_countdown * 60) + if work_end_time > now: + countdown_str = format_duration(work_end_time - now) + else: + countdown_str = "Now" + + rate_1m = sum(manager.get_proxy_activity_rate(proxy_url, act_type, 60) for act_type in activity_types_to_sum) + rate_5m = sum(manager.get_proxy_activity_rate(proxy_url, act_type, 300) for act_type in activity_types_to_sum) + rate_1h = sum(manager.get_proxy_activity_rate(proxy_url, act_type, 3600) for act_type in activity_types_to_sum) + + row = [ + proxy_url, + state_str, + policy_str, + 
countdown_str, + rate_1m, + rate_5m, + rate_1h, + ] + table_data.append(row) + + if table_data: + print(tabulate(table_data, headers=headers, tablefmt='grid'), file=file) + + +def _render_proxy_activity_summary(manager, proxy_url, simulation_type, file=sys.stdout): + """Renders a detailed activity summary for a single proxy.""" + if not manager or not proxy_url: + return + + print(f"\n--- Activity Summary for Proxy: {proxy_url} ({simulation_type}) ---", file=file) + + proxy_work_minutes_str = manager.get_config('proxy_work_minutes') + proxy_work_minutes = 0 + if proxy_work_minutes_str and proxy_work_minutes_str.isdigit(): + proxy_work_minutes = int(proxy_work_minutes_str) + + proxy_rest_minutes_str = manager.get_config('proxy_rest_duration_minutes') + proxy_rest_minutes = 0 + if proxy_rest_minutes_str and proxy_rest_minutes_str.isdigit(): + proxy_rest_minutes = int(proxy_rest_minutes_str) + + proxy_state_data = manager.get_proxy_states([proxy_url]).get(proxy_url, {}) + state = proxy_state_data.get('state', 'N/A') + rest_until = proxy_state_data.get('rest_until', 0) + work_start = proxy_state_data.get('work_start_timestamp', 0) + + policy_str = "N/A" + group_name = proxy_state_data.get('group_name') + work_minutes_for_countdown = 0 + + if group_name: + group_work_minutes = proxy_state_data.get('group_work_minutes', 0) + try: + group_work_minutes = int(group_work_minutes) + work_minutes_for_countdown = group_work_minutes + except (ValueError, TypeError): + group_work_minutes = 0 + policy_str = f"Group: {group_name} ({group_work_minutes}m/proxy)" + elif proxy_work_minutes > 0: + policy_str = f"Work: {proxy_work_minutes}m, Rest: {proxy_rest_minutes}m" + work_minutes_for_countdown = proxy_work_minutes + + state_str = state + now = time.time() + if state == 'RESTING' and rest_until > now: + state_str += f" (ends in {format_duration(rest_until - now)})" + + active_duration_str = "N/A" + time_until_rest_str = "N/A" + if state == 'ACTIVE' and work_start > 0: + active_duration_str = format_duration(now - work_start) + if work_minutes_for_countdown > 0: + work_end_time = work_start + (work_minutes_for_countdown * 60) + if work_end_time > now: + time_until_rest_str = format_duration(work_end_time - now) + else: + time_until_rest_str = "Now" + + summary_data = [ + ("State", state_str), + ("Policy", policy_str), + ("Active Since", format_timestamp(work_start)), + ("Active Duration", active_duration_str), + ("Time Until Rest", time_until_rest_str), + ] + print(tabulate(summary_data, tablefmt='grid'), file=file) + + windows = { + "Last 1 Min": 60, + "Last 5 Min": 300, + "Last 1 Hour": 3600, + "Last 24 Hours": 86400, + } + + is_auth_sim = 'Auth' in simulation_type + activity_types = ['success', 'failure', 'tolerated_error'] if is_auth_sim else ['download', 'download_error', 'tolerated_error'] + + table_data = [] + headers = ['Window'] + [act_type.replace('_', ' ').title() for act_type in activity_types] + + for name, seconds in windows.items(): + row = [name] + for act_type in activity_types: + count = manager.get_proxy_activity_rate(proxy_url, act_type, seconds) + row.append(count) + table_data.append(row) + + if table_data: + print(tabulate(table_data, headers=headers, tablefmt='grid'), file=file) + + +def _render_profile_group_summary_table(manager, all_profiles, profile_groups_config, file=sys.stdout): + """Renders a summary table for profile groups.""" + if not profile_groups_config: + return + + print("\nProfile Group Status:", file=file) + table_data = [] + all_profiles_map = {p['name']: p for p in 
all_profiles} + + for group in profile_groups_config: + group_name = group.get('name', 'N/A') + profiles_in_group = group.get('profiles_in_group', []) + + active_profiles = [ + p_name for p_name in profiles_in_group + if all_profiles_map.get(p_name, {}).get('state') in [manager.STATE_ACTIVE, manager.STATE_LOCKED] + ] + + active_profiles_str = ', '.join(active_profiles) or "None" + + max_active = group.get('max_active_profiles', 1) + policy_str = f"{len(active_profiles)}/{max_active} Active" + + rotate_after = group.get('rotate_after_requests') + rotation_rule_str = f"After {rotate_after} reqs" if rotate_after else "N/A" + + reqs_left_str = "N/A" + if rotate_after and rotate_after > 0 and active_profiles: + # Show countdown for the first active profile + active_profile_name = active_profiles[0] + p = all_profiles_map.get(active_profile_name) + if p: + total_reqs = ( + p.get('success_count', 0) + p.get('failure_count', 0) + + p.get('tolerated_error_count', 0) + + p.get('download_count', 0) + p.get('download_error_count', 0) + ) + remaining_reqs = rotate_after - total_reqs + reqs_left_str = str(max(0, int(remaining_reqs))) + + table_data.append([ + group_name, + active_profiles_str, + policy_str, + rotation_rule_str, + reqs_left_str + ]) + + headers = ['Group Name', 'Active Profile(s)', 'Policy', 'Rotation Rule', 'Requests Left ↓'] + print(tabulate(table_data, headers=headers, tablefmt='grid'), file=file) + + +def _render_profile_details_table(manager, args, simulation_type, profile_groups_config, file=sys.stdout): + """Renders the detailed profile list table for a given manager.""" + if not manager: + print("Manager not configured.", file=file) + return + + profiles = manager.list_profiles(args.state, args.proxy) + if not profiles: + print("No profiles found matching the criteria.", file=file) + return + + table_data = [] + is_auth_sim = 'Auth' in simulation_type + + for p in profiles: + rest_until_str = 'N/A' + last_rest_ts = p.get('last_rest_timestamp', 0) + last_rest_str = format_timestamp(last_rest_ts) + + state_str = p.get('state', 'UNKNOWN') + + if state_str in ['RESTING', 'COOLDOWN']: + rest_until = p.get('rest_until', 0) + if rest_until > 0: + remaining = rest_until - time.time() + if remaining > 0: + rest_until_str = f"in {format_duration(remaining)}" + else: + rest_until_str = "Ending now" + + if last_rest_ts == 0: + last_rest_str = "NOW" + + pending_dl_count = p.get('pending_downloads', 0) + if p.get('rest_reason') == 'waiting_downloads' and pending_dl_count > 0: + state_str += f"\n({pending_dl_count} DLs)" + + countdown_str = 'N/A' + # Find the group this profile belongs to and get its rotation policy + profile_group = next((g for g in profile_groups_config if p['name'] in g.get('profiles_in_group', [])), None) + + rotate_after = 0 + if profile_group: + rotate_after = profile_group.get('rotate_after_requests') + elif args.rest_after_requests and args.rest_after_requests > 0: + rotate_after = args.rest_after_requests + + if rotate_after > 0 and state_str != manager.STATE_COOLDOWN: + total_reqs = ( + p.get('success_count', 0) + p.get('failure_count', 0) + + p.get('tolerated_error_count', 0) + + p.get('download_count', 0) + p.get('download_error_count', 0) + ) + remaining_reqs = rotate_after - total_reqs + countdown_str = str(max(0, int(remaining_reqs))) + + if args.hide_active_state and state_str == 'ACTIVE': + state_str = '' + + row = [ + p.get('name', 'MISSING_NAME'), + p.get('proxy', 'MISSING_PROXY'), + state_str, + format_timestamp(p.get('last_used', 0)), + ] + + if 
is_auth_sim: + row.extend([ + p.get('success_count', 0), + p.get('failure_count', 0), + p.get('tolerated_error_count', 0), + p.get('global_success_count', 0), + p.get('global_failure_count', 0), + ]) + else: # is_download_sim or unknown + row.extend([ + p.get('download_count', 0), + p.get('download_error_count', 0), + p.get('tolerated_error_count', 0), + p.get('global_download_count', 0), + p.get('global_download_error_count', 0), + ]) + + # Display generic 'reason' field as a fallback for 'rest_reason' + reason_str = p.get('rest_reason') or p.get('reason') or '' + row.extend([ + countdown_str, + rest_until_str, + reason_str, + p.get('ban_reason') or '' + ]) + table_data.append(row) + + headers = ['Name', 'Proxy', 'State', 'Last Used'] + + if is_auth_sim: + headers.extend(['AuthOK', 'AuthFail', 'Skip.Err', 'Tot.AuthOK', 'Tot.AuthFail']) + else: # is_download_sim or unknown + headers.extend(['DataOK', 'DownFail', 'Skip.Err', 'Tot.DataOK', 'Tot.DownFail']) + + headers.extend(['ReqCD ↓', 'RestCD ↓', 'R.Reason', 'B.Reason']) + + # Using `maxcolwidths` to control column width for backward compatibility + # with older versions of the `tabulate` library. This prevents content + # from making columns excessively wide, but does not guarantee a fixed width. + maxwidths = None + if table_data or headers: # Check headers in case table_data is empty + # Transpose table to get columns, including headers. + # This handles empty table_data correctly. + columns = list(zip(*([headers] + table_data))) + # Calculate max width for each column based on its content. + maxwidths = [max(len(str(x)) for x in col) if col else 0 for col in columns] + + # Enforce a minimum width for the reason columns to keep table width stable. + DEFAULT_REASON_WIDTH = 25 + try: + r_reason_idx = headers.index('R.Reason') + b_reason_idx = headers.index('B.Reason') + maxwidths[r_reason_idx] = max(DEFAULT_REASON_WIDTH, maxwidths[r_reason_idx]) + maxwidths[b_reason_idx] = max(DEFAULT_REASON_WIDTH, maxwidths[b_reason_idx]) + except (ValueError, IndexError): + # This should not happen if headers are constructed as expected. + pass + + print(tabulate(table_data, headers=headers, tablefmt='grid', maxcolwidths=maxwidths), file=file) + + +def _render_simulation_view(title, manager, args, file=sys.stdout): + """Helper function to render the list of profiles for a single simulation environment.""" + if not manager: + print(f"\n--- {title} (Environment Not Configured) ---", file=file) + return 0 + + if not tabulate: + print("'tabulate' library is required for table format. 
Please install it.", file=sys.stderr) + return 1 + + print(f"\n--- {title} ---", file=file) + profiles = manager.list_profiles(args.state, args.proxy) + + if args.format == 'json': + print(json.dumps(profiles, indent=2, default=str), file=file) + return 0 + elif args.format == 'csv': + if profiles: + headers = profiles[0].keys() + print(','.join(headers), file=file) + for p in profiles: + print(','.join(str(p.get(h, '')) for h in headers), file=file) + return 0 + + # --- Table Format with Summaries --- + + if args.show_proxy_activity: + if args.proxy: + _render_proxy_activity_summary(manager, args.proxy, title, file=file) + else: + _render_all_proxies_activity_summary(manager, title, file=file) + + profile_groups_config = _build_profile_groups_config(manager, profiles) + _render_profile_group_summary_table(manager, profiles, profile_groups_config, file=file) + + failed_lock_attempts = manager.get_failed_lock_attempts() + global_stats = manager.get_global_stats() + per_proxy_stats = manager.get_per_proxy_stats() + + unique_proxies = sorted(per_proxy_stats.keys()) + proxy_states = manager.get_proxy_states(unique_proxies) + + # Build global summary + total_reqs = global_stats['total_success'] + global_stats['total_failure'] + success_rate = (global_stats['total_success'] / total_reqs * 100) if total_reqs > 0 else 100 + global_summary_str = ( + f"Total Requests: {total_reqs} | " + f"Success: {global_stats['total_success']} | " + f"Failure: {global_stats['total_failure']} | " + f"Tolerated Error: {global_stats['total_tolerated_error']} | " + f"Downloads: {global_stats['total_downloads']} | " + f"Download Errors: {global_stats.get('total_download_errors', 0)} | " + f"Success Rate: {success_rate:.2f}% | " + f"Failed Lock Attempts: {failed_lock_attempts}" + ) + print("Global Stats:", global_summary_str, file=file) + + # Build per-proxy summary + if per_proxy_stats: + print("\nPer-Proxy Stats:", file=file) + proxy_table_data = [] + for proxy_url in unique_proxies: + stats = per_proxy_stats[proxy_url] + state_info = proxy_states.get(proxy_url, {}) + state = state_info.get('state', 'ACTIVE') + + cooldown_str = 'N/A' + if state == 'RESTING': + rest_until = state_info.get('rest_until', 0) + if rest_until > time.time(): + cooldown_str = f"in {format_duration(rest_until - time.time())}" + else: + cooldown_str = "Ending now" + + proxy_total_auth = stats['success'] + stats['failure'] + proxy_total_downloads = stats['downloads'] + stats['download_errors'] + proxy_total_reqs = proxy_total_auth + proxy_total_downloads + proxy_success_rate = (stats['success'] / proxy_total_auth * 100) if proxy_total_auth > 0 else 100 + + proxy_table_data.append([ + proxy_url, + state, + cooldown_str, + stats['profiles'], + proxy_total_reqs, + stats['success'], + stats['failure'], + stats['tolerated_error'], + stats['downloads'], + stats['download_errors'], + f"{proxy_success_rate:.1f}%" + ]) + proxy_headers = ['Proxy', 'State', 'Cooldown', 'Profiles', 'Total Reqs', 'AuthOK', 'AuthFail', 'Skip.Err', 'DataOK', 'DownFail', 'OK %'] + print(tabulate(proxy_table_data, headers=proxy_headers, tablefmt='grid'), file=file) + + print("\nProfile Details:", file=file) + _render_profile_details_table(manager, args, title, profile_groups_config, file=file) + return 0 + + +def _render_merged_view(auth_manager, download_manager, args, file=sys.stdout): + """Renders a merged, unified view for both auth and download simulations.""" + # --- 1. 
Fetch ALL data first to prevent delays during rendering --- + auth_stats = auth_manager.get_global_stats() + auth_failed_locks = auth_manager.get_failed_lock_attempts() + dl_stats = download_manager.get_global_stats() + dl_failed_locks = download_manager.get_failed_lock_attempts() + + auth_proxy_stats = auth_manager.get_per_proxy_stats() + dl_proxy_stats = download_manager.get_per_proxy_stats() + all_proxies = sorted(list(set(auth_proxy_stats.keys()) | set(dl_proxy_stats.keys()))) + + auth_proxy_states, dl_proxy_states = {}, {} + if all_proxies: + auth_proxy_states = auth_manager.get_proxy_states(all_proxies) + dl_proxy_states = download_manager.get_proxy_states(all_proxies) + + auth_profiles = auth_manager.list_profiles(args.state, args.proxy) + auth_groups_config = _build_profile_groups_config(auth_manager, auth_profiles) + + dl_profiles = download_manager.list_profiles(args.state, args.proxy) + dl_groups_config = _build_profile_groups_config(download_manager, dl_profiles) + + # --- 2. Prepare all display data using fetched information --- + total_reqs = auth_stats['total_success'] + auth_stats['total_failure'] + success_rate = (auth_stats['total_success'] / total_reqs * 100) if total_reqs > 0 else 100 + + total_dls = dl_stats['total_downloads'] + dl_stats['total_download_errors'] + dl_success_rate = (dl_stats['total_downloads'] / total_dls * 100) if total_dls > 0 else 100 + + global_summary_str = ( + f"Auth: {total_reqs} reqs ({auth_stats['total_success']} OK, {auth_stats['total_failure']} Fail, {auth_stats['total_tolerated_error']} Tol.Err) | " + f"OK Rate: {success_rate:.2f}% | " + f"Failed Locks: {auth_failed_locks} || " + f"Download: {total_dls} attempts ({dl_stats['total_downloads']} OK, {dl_stats['total_download_errors']} Fail) | " + f"OK Rate: {dl_success_rate:.2f}% | " + f"Failed Locks: {dl_failed_locks}" + ) + + proxy_table_data = [] + if all_proxies: + for proxy in all_proxies: + astats = auth_proxy_stats.get(proxy, {}) + dstats = dl_proxy_stats.get(proxy, {}) + astate = auth_proxy_states.get(proxy, {}) + dstate = dl_proxy_states.get(proxy, {}) + + state_str = f"{astate.get('state', 'N/A')} / {dstate.get('state', 'N/A')}" + + proxy_table_data.append([ + proxy, + state_str, + astats.get('profiles', 0), + dstats.get('profiles', 0), + astats.get('success', 0), + astats.get('failure', 0), + astats.get('tolerated_error', 0), + dstats.get('downloads', 0), + dstats.get('download_errors', 0), + dstats.get('tolerated_error', 0) + ]) + + # --- 3. Render everything to the buffer at once --- + print("--- Global Simulation Stats ---", file=file) + print(global_summary_str, file=file) + + if args.show_proxy_activity: + if args.proxy: + _render_proxy_activity_summary(auth_manager, args.proxy, "Auth", file=file) + _render_proxy_activity_summary(download_manager, args.proxy, "Download", file=file) + else: + # In merged view, it makes sense to show both summaries if requested. 
+ _render_all_proxies_activity_summary(auth_manager, "Auth", file=file) + _render_all_proxies_activity_summary(download_manager, "Download", file=file) + + if all_proxies: + print("\n--- Per-Proxy Stats (Merged) ---", file=file) + proxy_headers = ['Proxy', 'State (A/D)', 'Profiles (A)', 'Profiles (D)', 'AuthOK', 'AuthFail', 'Skip.Err(A)', 'DataOK', 'DownFail', 'Skip.Err(D)'] + print(tabulate(proxy_table_data, headers=proxy_headers, tablefmt='grid'), file=file) + + print(f"\n--- Auth Simulation Profile Details ({args.auth_env}) ---", file=file) + _render_profile_group_summary_table(auth_manager, auth_profiles, auth_groups_config, file=file) + _render_profile_details_table(auth_manager, args, "Auth", auth_groups_config, file=file) + + print(f"\n--- Download Simulation Profile Details ({args.download_env}) ---", file=file) + _render_profile_group_summary_table(download_manager, dl_profiles, dl_groups_config, file=file) + _render_profile_details_table(download_manager, args, "Download", dl_groups_config, file=file) + + return 0 + + +def _print_profile_list(manager, args, title="Profile Status"): + """Helper function to print the list of profiles in the desired format.""" + return _render_simulation_view(title, manager, args, file=sys.stdout) + + +def main_profile_manager(args): + """Main dispatcher for 'profile' command.""" + if load_dotenv: + env_file = args.env_file + if not env_file and args.env and '.env' in args.env and os.path.exists(args.env): + print(f"WARNING: --env should be an environment name (e.g., 'dev'), not a file path. Treating '{args.env}' as --env-file. The environment name will default to 'dev'.", file=sys.stderr) + env_file = args.env + args.env = 'dev' + + was_loaded = load_dotenv(env_file) + if was_loaded: + print(f"Loaded environment variables from {env_file or '.env file'}", file=sys.stderr) + elif args.env_file: + print(f"ERROR: The specified --env-file was not found: {args.env_file}", file=sys.stderr) + return 1 + + if args.redis_host is None: + args.redis_host = os.getenv('REDIS_HOST', os.getenv('MASTER_HOST_IP', 'localhost')) + if args.redis_port is None: + args.redis_port = int(os.getenv('REDIS_PORT', 6379)) + if args.redis_password is None: + args.redis_password = os.getenv('REDIS_PASSWORD') + + if args.verbose: + logging.getLogger().setLevel(logging.DEBUG) + + if args.key_prefix: + key_prefix = args.key_prefix + elif args.legacy: + key_prefix = 'profile_mgmt_' + else: + key_prefix = f"{args.env}_profile_mgmt_" + + manager = ProfileManager( + redis_host=args.redis_host, + redis_port=args.redis_port, + redis_password=args.redis_password, + key_prefix=key_prefix + ) + + if args.profile_command == 'create': + success = manager.create_profile(args.name, args.proxy, args.state) + return 0 if success else 1 + + elif args.profile_command == 'list': + is_dual_mode = args.auth_env and args.download_env + + def _create_manager(env_name, is_for_dual_mode): + if not env_name: + return None + + # For dual mode, we ignore --legacy and --key-prefix from CLI, and derive from env name. + # This is opinionated but makes dual-mode behavior predictable. 
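+ # Illustrative: with --auth-env stg and --download-env prod, the two managers use the 'stg_profile_mgmt_' and 'prod_profile_mgmt_' key prefixes respectively.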
+ if is_for_dual_mode: + key_prefix = f"{env_name}_profile_mgmt_" + else: + # Single mode respects all CLI flags + if args.key_prefix: + key_prefix = args.key_prefix + elif args.legacy: + key_prefix = 'profile_mgmt_' + else: + key_prefix = f"{env_name}_profile_mgmt_" + + return ProfileManager( + redis_host=args.redis_host, redis_port=args.redis_port, + redis_password=args.redis_password, key_prefix=key_prefix + ) + + if not args.live: + if is_dual_mode and not args.separate_views: + auth_manager = _create_manager(args.auth_env, is_for_dual_mode=True) + download_manager = _create_manager(args.download_env, is_for_dual_mode=True) + return _render_merged_view(auth_manager, download_manager, args) + elif is_dual_mode and args.separate_views: + auth_manager = _create_manager(args.auth_env, is_for_dual_mode=True) + download_manager = _create_manager(args.download_env, is_for_dual_mode=True) + _render_simulation_view(f"Auth Simulation ({args.auth_env})", auth_manager, args) + _render_simulation_view(f"Download Simulation ({args.download_env})", download_manager, args) + return 0 + else: + # Single view mode + single_env = args.auth_env or args.download_env or args.env + manager = _create_manager(single_env, is_for_dual_mode=False) + # Determine the title for correct table headers + title = f"Profile Status ({single_env})" + if args.auth_env: + title = f"Auth Simulation ({args.auth_env})" + elif args.download_env: + title = f"Download Simulation ({args.download_env})" + return _print_profile_list(manager, args, title=title) + + # --- Live Mode --- + pm_logger = logging.getLogger(__name__) + original_log_level = pm_logger.level + try: + if args.no_blink: + sys.stdout.write('\033[?25l') # Hide cursor + sys.stdout.flush() + + # Register signal handlers for graceful shutdown in live mode + signal.signal(signal.SIGINT, handle_shutdown) + signal.signal(signal.SIGTERM, handle_shutdown) + + while not shutdown_event.is_set(): + pm_logger.setLevel(logging.WARNING) # Suppress connection logs for cleaner UI + start_time = time.time() + + output_buffer = io.StringIO() + print(f"--- Profile Status (auto-refreshing every {args.interval_seconds}s, Ctrl+C to exit) | Last updated: {datetime.now().strftime('%H:%M:%S')} ---", file=output_buffer) + + if is_dual_mode and not args.separate_views: + auth_manager = _create_manager(args.auth_env, is_for_dual_mode=True) + download_manager = _create_manager(args.download_env, is_for_dual_mode=True) + _render_merged_view(auth_manager, download_manager, args, file=output_buffer) + elif is_dual_mode and args.separate_views: + auth_manager = _create_manager(args.auth_env, is_for_dual_mode=True) + download_manager = _create_manager(args.download_env, is_for_dual_mode=True) + _render_simulation_view(f"Auth Simulation ({args.auth_env})", auth_manager, args, file=output_buffer) + _render_simulation_view(f"Download Simulation ({args.download_env})", download_manager, args, file=output_buffer) + else: + # Single view mode + single_env = args.auth_env or args.download_env or args.env + manager = _create_manager(single_env, is_for_dual_mode=False) + # Determine the title for correct table headers + title = f"Profile Status ({single_env})" + if args.auth_env: + title = f"Auth Simulation ({args.auth_env})" + elif args.download_env: + title = f"Download Simulation ({args.download_env})" + _render_simulation_view(title, manager, args, file=output_buffer) + + pm_logger.setLevel(original_log_level) # Restore log level + fetch_and_render_duration = time.time() - start_time + + if 
args.no_blink: + sys.stdout.write('\033[2J\033[H') # Clear screen, move to top + else: + os.system('cls' if os.name == 'nt' else 'clear') + + sys.stdout.write(output_buffer.getvalue()) + sys.stdout.flush() + + # --- Adaptive Countdown --- + remaining_sleep = args.interval_seconds - fetch_and_render_duration + + if remaining_sleep > 0: + end_time = time.time() + remaining_sleep + while time.time() < end_time and not shutdown_event.is_set(): + time_left = end_time - time.time() + sys.stdout.write(f"\rRefreshing in {int(time_left)}s... (fetch took {fetch_and_render_duration:.2f}s) ") + sys.stdout.flush() + time.sleep(min(1, time_left if time_left > 0 else 1)) + elif not shutdown_event.is_set(): + sys.stdout.write(f"\rRefreshing now... (fetch took {fetch_and_render_duration:.2f}s, behind by {-remaining_sleep:.2f}s) ") + sys.stdout.flush() + time.sleep(0.5) # Brief pause to make message readable + + sys.stdout.write("\r" + " " * 80 + "\r") # Clear line + sys.stdout.flush() + + except KeyboardInterrupt: + # This can be triggered by Ctrl+C during a time.sleep(). + # The signal handler will have already set the shutdown_event and printed a message. + # This block is a fallback. + if not shutdown_event.is_set(): + print("\nKeyboardInterrupt received. Stopping live view...", file=sys.stderr) + shutdown_event.set() # Ensure event is set if handler didn't run + return 0 + finally: + pm_logger.setLevel(original_log_level) + if args.live and args.no_blink: + sys.stdout.write('\033[?25h') # Restore cursor + sys.stdout.flush() + + elif args.profile_command == 'set-proxy-state': + success = manager.set_proxy_state(args.proxy_url, args.state, args.duration_minutes) + return 0 if success else 1 + + elif args.profile_command == 'get': + profile = manager.get_profile(args.name) + if not profile: + print(f"Profile '{args.name}' not found") + return 1 + + print(f"Profile: {profile['name']}") + print(f"Proxy: {profile['proxy']}") + print(f"State: {profile['state']}") + print(f"Created: {format_timestamp(profile['created_at'])}") + print(f"Last Used: {format_timestamp(profile['last_used'])}") + print(f"Success Count: {profile['success_count']}") + print(f"Failure Count: {profile['failure_count']}") + + if profile.get('rest_until', 0) > 0: + remaining = profile['rest_until'] - time.time() + if remaining > 0: + print(f"Resting for: {format_duration(remaining)} more") + else: + print(f"Rest period ended: {format_timestamp(profile['rest_until'])}") + + if profile.get('ban_reason'): + print(f"Ban Reason: {profile['ban_reason']}") + + if profile.get('lock_timestamp', 0) > 0: + print(f"Locked since: {format_timestamp(profile['lock_timestamp'])}") + print(f"Lock Owner: {profile['lock_owner']}") + + if profile.get('notes'): + print(f"Notes: {profile['notes']}") + return 0 + + elif args.profile_command == 'update-state': + success = manager.update_profile_state(args.name, args.state, args.reason or '') + return 0 if success else 1 + + elif args.profile_command == 'update-field': + success = manager.update_profile_field(args.name, args.field, args.value) + return 0 if success else 1 + + elif args.profile_command == 'pause': + success = manager.update_profile_state(args.name, manager.STATE_PAUSED, 'Manual pause') + return 0 if success else 1 + + elif args.profile_command == 'activate': + success = manager.update_profile_state(args.name, manager.STATE_ACTIVE, 'Manual activation') + return 0 if success else 1 + + elif args.profile_command == 'ban': + success = manager.update_profile_state(args.name, manager.STATE_BANNED, 
args.reason) + return 0 if success else 1 + + elif args.profile_command == 'unban': + # First activate, then reset session counters. The ban reason is cleared by update_profile_state. + success = manager.update_profile_state(args.name, manager.STATE_ACTIVE, 'Manual unban') + if success: + manager.reset_profile_counters(args.name) + return 0 if success else 1 + + elif args.profile_command == 'delete': + if not args.confirm: + print("Error: --confirm flag is required for deletion", file=sys.stderr) + return 1 + success = manager.delete_profile(args.name) + return 0 if success else 1 + + elif args.profile_command == 'delete-all': + if not args.confirm: + print("Error: --confirm flag is required for this destructive action.", file=sys.stderr) + return 1 + deleted_count = manager.delete_all_data() + print(f"Deleted {deleted_count} key(s) with prefix '{manager.key_prefix}'.") + return 0 + + elif args.profile_command == 'reset-global-counters': + manager.reset_global_counters() + return 0 + + elif args.profile_command == 'reset-counters': + profiles_to_reset = [] + if args.profile_name: + profile = manager.get_profile(args.profile_name) + if profile: + profiles_to_reset.append(profile) + elif args.proxy_url: + profiles_to_reset = manager.list_profiles(proxy_filter=args.proxy_url) + elif args.all_profiles: + profiles_to_reset = manager.list_profiles() + + if not profiles_to_reset: + print("No profiles found to reset.", file=sys.stderr) + return 1 + + print(f"Found {len(profiles_to_reset)} profile(s) to reset. This action is not reversible.") + confirm = input("Continue? (y/N): ") + if confirm.lower() != 'y': + print("Aborted.") + return 1 + + success_count = 0 + for profile in profiles_to_reset: + if manager.reset_profile_counters(profile['name']): + success_count += 1 + + print(f"Successfully reset session counters for {success_count} profile(s).") + return 0 + + elif args.profile_command == 'record-activity': + success = manager.record_activity(args.name, args.type, args.timestamp) + return 0 if success else 1 + + elif args.profile_command == 'get-rate': + rate = manager.get_activity_rate(args.name, args.type, args.window) + print(f"{args.type.capitalize()} rate for '{args.name}' over {args.window}s: {rate}") + return 0 + + return 1 # Should not be reached diff --git a/ytops_client/profile_setup_tool.py b/ytops_client/profile_setup_tool.py new file mode 100644 index 0000000..be3cf3f --- /dev/null +++ b/ytops_client/profile_setup_tool.py @@ -0,0 +1,269 @@ +#!/usr/bin/env python3 +""" +CLI tool to set up profiles from a YAML policy file. +""" + +import argparse +import json +import logging +import os +import subprocess +import sys +from typing import List + +try: + import yaml +except ImportError: + print("PyYAML is not installed. 
Please install it with: pip install PyYAML", file=sys.stderr) + yaml = None + +try: + from dotenv import load_dotenv +except ImportError: + load_dotenv = None + +logger = logging.getLogger(__name__) + +def run_command(cmd: List[str], capture: bool = False): + """Runs a command and raises an exception on failure.""" + logger.debug(f"Running command: {' '.join(cmd)}") + # check=True will raise CalledProcessError on non-zero exit codes + result = subprocess.run( + cmd, + capture_output=capture, + text=True, + check=True + ) + return result + +def add_setup_profiles_parser(subparsers): + """Adds the parser for the 'setup-profiles' command.""" + parser = subparsers.add_parser( + 'setup-profiles', + description="Set up profiles for a simulation or test run based on a policy file.", + formatter_class=argparse.RawTextHelpFormatter, + help="Set up profiles from a policy file." + ) + parser.add_argument('--policy', '--policy-file', dest='policy_file', required=True, help="Path to the YAML profile setup policy file.") + parser.add_argument('--env', help="Override the environment name from the policy file. For multi-setup files, this will override the 'env' for ALL setups being run.") + parser.add_argument('--env-file', help="Override the env_file setting in the policy.") + parser.add_argument('--auth-only', action='store_true', help='In a multi-setup policy file, run only the auth_profile_setup.') + parser.add_argument('--download-only', action='store_true', help='In a multi-setup policy file, run only the download_profile_setup.') + parser.add_argument('--redis-host', default=None, help='Redis host. Overrides policy and .env file.') + parser.add_argument('--redis-port', type=int, default=None, help='Redis port. Overrides policy and .env file.') + parser.add_argument('--redis-password', default=None, help='Redis password. Overrides policy and .env file.') + parser.add_argument('--preserve-profiles', action='store_true', help="Do not clean up existing profiles; create only what is missing.") + parser.add_argument('--cleanup-prefix', action='append', help="Prefix of profiles to delete before setup. Can be specified multiple times. Overrides policy-based cleanup.") + parser.add_argument('--cleanup-all', action='store_true', help="(Destructive) Delete ALL data for the environment (profiles, proxies, counters) before setup. 
Overrides all other cleanup options.") + parser.add_argument('--reset-global-counters', action='store_true', help="Reset global counters like 'failed_lock_attempts'.") + parser.add_argument('--verbose', action='store_true', help="Enable verbose logging.") + return parser + +def _run_setup_for_env(profile_setup: dict, common_args: list, args: argparse.Namespace) -> int: + """Runs the profile setup logic for a given configuration block.""" + if args.cleanup_all: + logger.info("--- (DESTRUCTIVE) Cleaning up all data for the environment via --cleanup-all ---") + try: + cleanup_cmd = ['bin/ytops-client', 'profile', 'delete-all', '--confirm'] + common_args + run_command(cleanup_cmd) + except subprocess.CalledProcessError as e: + logger.error(f"Failed to clean up all data for the environment: {e}") + return 1 # Stop if cleanup fails + + # Disable other cleanup logic + profile_setup['cleanup_before_run'] = False + args.cleanup_prefix = None + + # If --cleanup-prefix is provided, it takes precedence over policy settings + if args.cleanup_prefix: + logger.info("--- Cleaning up profiles based on --cleanup-prefix ---") + for prefix in args.cleanup_prefix: + try: + list_cmd = ['bin/ytops-client', 'profile', 'list', '--format', 'json'] + common_args + result = run_command(list_cmd, capture=True) + profiles_to_delete = [p for p in json.loads(result.stdout) if p['name'].startswith(prefix)] + + if not profiles_to_delete: + logger.info(f"No profiles with prefix '{prefix}' found to delete.") + continue + + logger.info(f"Found {len(profiles_to_delete)} profiles with prefix '{prefix}' to delete.") + for p in profiles_to_delete: + delete_cmd = ['bin/ytops-client', 'profile', 'delete', p['name'], '--confirm'] + common_args + run_command(delete_cmd) + except (subprocess.CalledProcessError, json.JSONDecodeError) as e: + logger.warning(f"Could not list or parse existing profiles with prefix '{prefix}' for cleanup. Error: {e}") + + # Disable policy-based cleanup as we've handled it via CLI + profile_setup['cleanup_before_run'] = False + + if args.preserve_profiles: + if profile_setup.get('cleanup_before_run'): + logger.info("--preserve-profiles is set, overriding 'cleanup_before_run: true' from policy.") + profile_setup['cleanup_before_run'] = False + + if args.reset_global_counters: + logger.info("--- Resetting global counters ---") + try: + reset_cmd = ['bin/ytops-client', 'profile', 'reset-global-counters'] + common_args + run_command(reset_cmd) + except subprocess.CalledProcessError as e: + logger.error(f"Failed to reset global counters: {e}") + + if profile_setup.get('cleanup_before_run'): + logger.info("--- Cleaning up old profiles ---") + for pool in profile_setup.get('pools', []): + prefix = pool.get('prefix') + if prefix: + try: + list_cmd = ['bin/ytops-client', 'profile', 'list', '--format', 'json'] + common_args + result = run_command(list_cmd, capture=True) + profiles = [p for p in json.loads(result.stdout) if p['name'].startswith(prefix)] + + if not profiles: + logger.info(f"No profiles with prefix '{prefix}' found to delete.") + continue + + logger.info(f"Found {len(profiles)} profiles with prefix '{prefix}' to delete.") + for p in profiles: + delete_cmd = ['bin/ytops-client', 'profile', 'delete', p['name'], '--confirm'] + common_args + run_command(delete_cmd) + except (subprocess.CalledProcessError, json.JSONDecodeError) as e: + logger.warning(f"Could not list existing profiles with prefix '{prefix}' for cleanup. Assuming none exist. 
Error: {e}") + + existing_profiles = set() + if not profile_setup.get('cleanup_before_run'): + logger.info("--- Checking for existing profiles ---") + try: + list_cmd = ['bin/ytops-client', 'profile', 'list', '--format', 'json'] + common_args + result = run_command(list_cmd, capture=True) + profiles_data = json.loads(result.stdout) + existing_profiles = {p['name'] for p in profiles_data} + logger.info(f"Found {len(existing_profiles)} existing profiles.") + except (subprocess.CalledProcessError, json.JSONDecodeError) as e: + logger.error(f"Failed to list existing profiles. Will attempt to create all profiles. Error: {e}") + + logger.info("--- Creating new profiles (if needed) ---") + for pool in profile_setup.get('pools', []): + prefix = pool.get('prefix') + proxy = pool.get('proxy') + count = pool.get('count', 0) + start_in_rest_minutes = pool.get('start_in_rest_for_minutes') + + profiles_in_pool_created = 0 + # Get a list of all profiles that should exist for this pool + profile_names_in_pool = [f"{prefix}_{i}" for i in range(count)] + + for profile_name in profile_names_in_pool: + if profile_name in existing_profiles: + logger.debug(f"Profile '{profile_name}' already exists, preserving.") + continue + + try: + create_cmd = ['bin/ytops-client', 'profile', 'create', profile_name, proxy] + common_args + run_command(create_cmd) + profiles_in_pool_created += 1 + except subprocess.CalledProcessError as e: + logger.error(f"Failed to create profile '{profile_name}': {e}") + + if profiles_in_pool_created > 0: + logger.info(f"Created {profiles_in_pool_created} new profile(s) for pool '{prefix}'.") + elif count > 0: + logger.info(f"No new profiles needed for pool '{prefix}'. All {count} profile(s) already exist.") + + # If requested, put the proxy for this pool into a RESTING state. + # This is done even for existing profiles when --preserve-profiles is used. 
+ if start_in_rest_minutes and proxy: + logger.info(f"Setting proxy '{proxy}' for pool '{prefix}' to start in RESTING state for {start_in_rest_minutes} minutes.") + try: + # The 'set-proxy-state' command takes '--duration-minutes' + set_state_cmd = ['bin/ytops-client', 'profile', 'set-proxy-state', proxy, 'RESTING', + '--duration-minutes', str(start_in_rest_minutes)] + common_args + run_command(set_state_cmd) + except subprocess.CalledProcessError as e: + logger.error(f"Failed to set initial REST state for proxy '{proxy}': {e}") + + return 0 + + +def main_setup_profiles(args): + """Main logic for the 'setup-profiles' command.""" + if not yaml: return 1 + + if args.verbose: + logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') + logger.setLevel(logging.DEBUG) + else: + if not logging.getLogger().handlers: + logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') + logger.setLevel(logging.INFO) + + try: + with open(args.policy_file, 'r', encoding='utf-8') as f: + policy = yaml.safe_load(f) or {} + except (IOError, yaml.YAMLError) as e: + logger.error(f"Failed to load or parse policy file {args.policy_file}: {e}") + return 1 + + sim_params = policy.get('simulation_parameters', {}) + setups_to_run = [] + + if not args.download_only and 'auth_profile_setup' in policy: + setups_to_run.append(('Auth', policy['auth_profile_setup'])) + + if not args.auth_only and 'download_profile_setup' in policy: + setups_to_run.append(('Download', policy['download_profile_setup'])) + + # Backward compatibility for old single-block format + if not setups_to_run and 'profile_setup' in policy: + legacy_config = policy['profile_setup'] + # Synthesize the env from the global section for the legacy block + legacy_config['env'] = args.env or sim_params.get('env') + setups_to_run.append(('Legacy', legacy_config)) + + if not setups_to_run: + logger.error("No 'auth_profile_setup', 'download_profile_setup', or legacy 'profile_setup' block found in policy file.") + return 1 + + env_file = args.env_file or sim_params.get('env_file') + if load_dotenv: + if load_dotenv(env_file): + print(f"Loaded environment variables from {env_file or '.env file'}", file=sys.stderr) + elif args.env_file and not os.path.exists(args.env_file): + print(f"Error: The specified env_file was not found: {args.env_file}", file=sys.stderr) + return 1 + + base_common_args = [] + if env_file: base_common_args.extend(['--env-file', env_file]) + + redis_host = args.redis_host or os.getenv('REDIS_HOST') or os.getenv('MASTER_HOST_IP') or sim_params.get('redis_host') + if redis_host: base_common_args.extend(['--redis-host', redis_host]) + + redis_port = args.redis_port + if redis_port is None: + redis_port_env = os.getenv('REDIS_PORT') + redis_port = int(redis_port_env) if redis_port_env and redis_port_env.isdigit() else sim_params.get('redis_port') + if redis_port: base_common_args.extend(['--redis-port', str(redis_port)]) + + redis_password = args.redis_password or os.getenv('REDIS_PASSWORD') or sim_params.get('redis_password') + if redis_password: base_common_args.extend(['--redis-password', redis_password]) + + if args.verbose: base_common_args.append('--verbose') + + for setup_name, setup_config in setups_to_run: + logger.info(f"--- Running setup for {setup_name} simulation ---") + + effective_env = args.env or setup_config.get('env') + if not effective_env: + logger.error(f"Could not determine environment for '{setup_name}' setup. 
Please specify 'env' in the policy block or via --env.") + return 1 + + env_common_args = base_common_args + ['--env', effective_env] + + if _run_setup_for_env(setup_config, env_common_args, args) != 0: + return 1 + + logger.info("\n--- All profile setups complete. ---") + logger.info("You can now run the policy enforcer to manage the profiles:") + logger.info("e.g., bin/ytops-client policy-enforcer --policy-file policies/8_unified_simulation_enforcer.yaml --live") + return 0 diff --git a/ytops_client/request_params_help.py b/ytops_client/request_params_help.py index f9b1af7..7c01fcc 100644 --- a/ytops_client/request_params_help.py +++ b/ytops_client/request_params_help.py @@ -1,50 +1,55 @@ # Using a separate file for this long help message to keep the main script clean. -# It's imported by client tools that use the --request-params-json argument. +# It's imported by client tools that use the --ytdlp-config-json argument. -REQUEST_PARAMS_HELP_STRING = """JSON string with per-request parameters to override server defaults. -Example of a full configuration JSON showing default values (use single quotes to wrap it): -'{ - "_comment": "This JSON object allows overriding server-side defaults for a single request.", - "cookies_file_path": "/path/to/your/cookies.txt", +REQUEST_PARAMS_HELP_STRING = """JSON string or path to a JSON file (prefixed with '@') containing per-request parameters. +This allows overriding server-side defaults and passing a full yt-dlp options dictionary. +If this argument is not provided, the tool will automatically look for and load 'ytdlp.json' +in the current directory if it exists. - "context_reuse_policy": { - "enabled": true, - "max_age_seconds": 86400, - "reuse_visitor_id": true, - "reuse_cookies": true - }, - "_comment_context_reuse_policy": "Controls how the server reuses session context (cookies, visitor ID) from the account's previous successful request.", - "_comment_reuse_visitor_id": "If true, reuses the visitor ID from the last session to maintain a consistent identity to YouTube. This is automatically disabled for TV clients to avoid bot detection.", +The JSON structure is unified: +1. 'ytops': Parameters that control the behavior of the yt-ops-server itself (for 'get-info'). +2. 'ytdlp_params': A dictionary of options passed directly to yt-dlp. This is used by both + 'get-info' (server-side) and 'download py' (client-side). - "ytdlp_params": { - "use_curl_prefetch": false, - "skip_cache": false, - "visitor_id_override_enabled": true, - "webpo_bind_to_visitor_id": true, - "extractor_args": { - "youtubepot-bgutilhttp": { - "base_url": "http://172.17.0.1:4416" - }, - "youtube": { - "pot_trace": "true", - "formats": "duplicate", - "player_js_version": "actual" - } +Example 'ytdlp.json' for getting info.json and for downloads: +{ + // --- YTOPS: Server-Side Controls (for 'get-info') --- + "ytops": { + "assigned_proxy_url": "socks5://your.proxy.com:1080", // Optional: Assign a specific proxy + "force_renew": ["cookies", "visitor_id"], + "session_params": { + "visitor_rotation_threshold": 0, + "prevent_cookie_rotation": false, + "prevent_visitor_rotation": false } }, - "_comment_ytdlp_params": "Parameters passed directly to the yt-dlp wrapper for info.json generation.", - "_comment_webpo_bind_to_visitor_id": "If true (default), binds the PO Token cache to the visitor ID. 
Set to false for TV clients if caching issues occur, as this is not recommended for them.", - "_comment_visitor_id_override_enabled": "If true (default), the server validates the visitor ID from the token generator and creates a new one if it is invalid. Set to false to force using the provided visitor ID without validation, which is useful for debugging.", - "_comment_extractor_args": "Directly override yt-dlp extractor arguments. To use BGUtils in script mode, replace 'youtubepot-bgutilhttp' with 'youtubepot-bgutilscript'. The script path is '/opt/bgutil-ytdlp-pot-provider-server/build/generate_once.js'. To disable any explicit provider (like '--bgutils-mode none' on the server), remove both 'youtubepot-bgutilhttp' and 'youtubepot-bgutilscript' keys.", - "session_params": { - "lang": "en-US", - "timeZone": "UTC", - "location": "US", - "deviceCategory": "MOBILE", - "user_agent": "Mozilla/5.0 (iPad; CPU OS 16_7_10 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/16.6 Mobile/15E148 Safari/604.1,gzip(gfe)", - "visitor_rotation_threshold": 250 - }, - "_comment_session_params": "Parameters for the token generation session. `visitor_rotation_threshold` overrides the server's default request limit before a profile's visitor ID is rotated. Set to 0 to disable rotation.", - "_comment_lang_and_tz": "`lang` sets the 'hl' parameter for YouTube's API, affecting metadata language. `timeZone` is intended to set the timezone for requests, but is not fully supported by yt-dlp yet." -}'""" + // --- YTDLP: Parameters for yt-dlp (for 'get-info' and 'download py') --- + "ytdlp_params": { + "verbose": true, + "socket_timeout": 60, + "http_headers": { + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/108.0.0.0 Safari/537.36" + }, + "extractor_args": { + "youtube": { + "player_client": ["tv_simply"], + "skip": ["translated_subs", "hls"], + "pot_trace": ["true"], + "jsc_trace": ["true"], + "formats": ["duplicate"], + "lang": ["en-US"], + "timeZone": ["UTC"] + }, + "youtubepot-bgutilhttp": { + "base_url": ["http://172.17.0.1:4416"] + } + }, + "format_sort": ["res", "ext:mp4:m4a"], + "remuxvideo": "mp4" + } +} + +You can also pass a minimal JSON string directly on the command line: +'{"ytops": {"force_renew": ["all"]}, "ytdlp_params": {"verbose": true}}' +""" diff --git a/ytops_client/requirements.txt b/ytops_client/requirements.txt new file mode 100644 index 0000000..fdeedc4 --- /dev/null +++ b/ytops_client/requirements.txt @@ -0,0 +1,42 @@ +# Client-side dependencies for yt-ops-client tools. +# This file is separate from the root requirements.txt to avoid +# installing server-side dependencies on client-only machines. + +# For Thrift communication with the yt-ops-server +aiothrift + +# For the 'download aria-rpc' tool +aria2p + +# For reading .env files for configuration +python-dotenv==1.0.1 + +# For SOCKS proxy support in client tools +PySocks + +# For YAML policy files used by stress-policy, simulation, etc. +PyYAML + +# For connecting to Redis for profile management +redis + +# Dependency for aria2p +requests==2.32.5 + +# For 'manage' and 'profile' tools to display tables +tabulate + +# For yt-dlp integration in 'download py', 'list-formats', etc. +yt-dlp + +# --- Pinned yt-dlp dependencies --- +# These are pinned to match versions known to work with the server. +# This helps ensure consistent behavior. 
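+#
+# On a client-only machine, this dependency set (including the pins below) can be
+# installed on its own with, for example:
+#   python3 -m pip install -r ytops_client/requirements.txt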
+brotli==1.1.0 +certifi==2025.10.05 +curl-cffi==0.13.0 +mutagen==1.47.0 +pycryptodomex==3.23.0 +secretstorage==3.4.0 +urllib3==2.5.0 +websockets==15.0.1 diff --git a/ytops_client/simulation_tool.py b/ytops_client/simulation_tool.py new file mode 100644 index 0000000..dd7901b --- /dev/null +++ b/ytops_client/simulation_tool.py @@ -0,0 +1,136 @@ +#!/usr/bin/env python3 +""" +CLI tool to orchestrate multi-stage profile simulations. +""" +import argparse +import logging +import os +import signal +import subprocess +import sys +import time +from pathlib import Path +from types import SimpleNamespace + +try: + import yaml +except ImportError: + print("PyYAML is not installed. Please install it with: pip install PyYAML", file=sys.stderr) + yaml = None + +# Import the main functions from the tools we are wrapping +from .profile_setup_tool import main_setup_profiles +from .stress_policy_tool import main_stress_policy + +logger = logging.getLogger(__name__) + +# Define default policy paths relative to the project root +PROJECT_ROOT = Path(__file__).resolve().parent.parent +POLICY_DIR = PROJECT_ROOT / 'policies' +POLICY_FILE_SETUP = str(POLICY_DIR / '6_simulation_policy.yaml') +POLICY_FILE_AUTH = str(POLICY_DIR / '7_continuous_auth.yaml') +POLICY_FILE_DOWNLOAD = str(POLICY_DIR / '8_continuous_download.yaml') + +def add_simulation_parser(subparsers): + """Adds the parser for the 'simulation' command.""" + parser = subparsers.add_parser( + 'simulation', + description="Run multi-stage profile simulations (setup, auth, download). This provides a unified interface for the simulation workflow.", + formatter_class=argparse.RawTextHelpFormatter, + help="Run multi-stage profile simulations." + ) + + # Common arguments for all simulation subcommands + common_parser = argparse.ArgumentParser(add_help=False) + common_parser.add_argument('--env-file', default=None, help="Path to a .env file to load. Overrides setting from policy file.") + common_parser.add_argument('--redis-host', default=None, help='Redis host. Overrides policy and .env file.') + common_parser.add_argument('--redis-port', type=int, default=None, help='Redis port. Overrides policy and .env file.') + common_parser.add_argument('--redis-password', default=None, help='Redis password. Overrides policy and .env file.') + common_parser.add_argument('--env', default='sim', help="Environment name for Redis key prefix. Default: 'sim'.") + common_parser.add_argument('--expire-time-shift-minutes', type=int, default=None, help="Consider URLs expiring in N minutes as expired. Overrides policy.") + common_parser.add_argument('--verbose', action='store_true', help="Enable verbose logging.") + + sim_subparsers = parser.add_subparsers(dest='simulation_command', help='Simulation stage to run', required=True) + + # --- Setup --- + setup_parser = sim_subparsers.add_parser('setup', help='Set up profiles for a simulation.', parents=[common_parser]) + setup_parser.add_argument('--policy-file', dest='policy', default=POLICY_FILE_SETUP, help=f'Path to the setup policy YAML file. 
Default: {POLICY_FILE_SETUP}') + setup_parser.add_argument('--preserve-profiles', action='store_true', help="Do not clean up existing profiles.") + setup_parser.add_argument('--reset-global-counters', action='store_true', help="Reset global counters like 'failed_lock_attempts'.") + + # --- Auth --- + auth_parser = sim_subparsers.add_parser('auth', help='Run the authentication (get-info) part of the simulation.', parents=[common_parser]) + auth_parser.add_argument('--policy-file', dest='policy', default=POLICY_FILE_AUTH, help=f'Path to the auth simulation policy file. Default: {POLICY_FILE_AUTH}') + auth_parser.add_argument('--set', action='append', default=[], help="Override a policy setting using 'key.subkey=value' format.") + + # --- Download --- + download_parser = sim_subparsers.add_parser('download', help='Run the download part of the simulation.', parents=[common_parser]) + download_parser.add_argument('--policy-file', dest='policy', default=POLICY_FILE_DOWNLOAD, help=f'Path to the download simulation policy file. Default: {POLICY_FILE_DOWNLOAD}') + download_parser.add_argument('--set', action='append', default=[], help="Override a policy setting using 'key.subkey=value' format.") + +def main_simulation(args): + """Main dispatcher for 'simulation' command.""" + # --- Load policy to get simulation parameters --- + policy = {} + # The 'policy' attribute is guaranteed to exist by the arg parser for all subcommands + if not yaml: + logger.error("Cannot load policy file because PyYAML is not installed.") + return 1 + try: + with open(args.policy, 'r') as f: + # Load only the first document so multi-policy files (with '---' separators) are accepted + policy = next(yaml.safe_load_all(f), None) or {} + except (IOError, yaml.YAMLError) as e: + logger.error(f"Failed to load or parse policy file {args.policy}: {e}") + return 1 + + sim_params = policy.get('simulation_parameters', {}) + effective_env_file = args.env_file or sim_params.get('env_file') + + if args.simulation_command == 'setup': + # Create an args object that main_setup_profiles expects + setup_args = SimpleNamespace( + policy_file=args.policy, + env_file=effective_env_file, + preserve_profiles=args.preserve_profiles, + reset_global_counters=args.reset_global_counters, + verbose=args.verbose, + redis_host=args.redis_host, + redis_port=args.redis_port, + redis_password=args.redis_password + ) + return main_setup_profiles(setup_args) + + elif args.simulation_command == 'auth': + # This command runs the stress tool in auth (fetch_only) mode. + # It is expected that the policy-enforcer is run as a separate process. + stress_args = SimpleNamespace( + policy=args.policy, policy_name=None, list_policies=False, show_overrides=False, + set=args.set, profile_prefix=None, start_from_url_index=None, auto_merge_fragments=None, + remove_fragments_after_merge=None, fragments_dir=None, remote_dir=None, cleanup=None, + verbose=args.verbose, dry_run=False, disable_log_writing=False, + # Redis connection args + env_file=effective_env_file, redis_host=args.redis_host, redis_port=args.redis_port, + redis_password=args.redis_password, env=args.env, key_prefix=None, + expire_time_shift_minutes=args.expire_time_shift_minutes + ) + + logger.info("\n--- Starting Auth Simulation (stress-policy) ---") + return main_stress_policy(stress_args) + + elif args.simulation_command == 'download': + # This is simpler, just runs the stress tool in download mode.
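+        # Options that exist only on the stress-policy CLI (e.g. --policy-name,
+        # --show-overrides, the aria2c shortcuts) are passed with neutral defaults below.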
+ stress_args = SimpleNamespace( + policy=args.policy, policy_name=None, list_policies=False, show_overrides=False, + set=args.set, profile_prefix=None, start_from_url_index=None, auto_merge_fragments=None, + remove_fragments_after_merge=None, fragments_dir=None, remote_dir=None, cleanup=None, + verbose=args.verbose, dry_run=False, disable_log_writing=False, + # Redis connection args + env_file=effective_env_file, redis_host=args.redis_host, redis_port=args.redis_port, + redis_password=args.redis_password, env=args.env, key_prefix=None, + expire_time_shift_minutes=args.expire_time_shift_minutes + ) + logger.info("\n--- Starting Download Simulation (stress-policy) ---") + return main_stress_policy(stress_args) + + return 1 # Should not be reached diff --git a/ytops_client/stress_formats_tool.py b/ytops_client/stress_formats_tool.py index f45dbee..a6a8603 100644 --- a/ytops_client/stress_formats_tool.py +++ b/ytops_client/stress_formats_tool.py @@ -21,52 +21,32 @@ from datetime import datetime, timezone from pathlib import Path from urllib.parse import urlparse, parse_qs +from .stress_policy import utils as sp_utils + # Configure logging logger = logging.getLogger('stress_formats_tool') -def get_video_id(url: str) -> str: - """Extracts a YouTube video ID from a URL.""" - # For URLs like https://www.youtube.com/watch?v=VIDEO_ID - match = re.search(r"v=([0-9A-Za-z_-]{11})", url) - if match: - return match.group(1) - # For URLs like https://youtu.be/VIDEO_ID - match = re.search(r"youtu\.be\/([0-9A-Za-z_-]{11})", url) - if match: - return match.group(1) - # For plain video IDs - if re.fullmatch(r'[0-9A-Za-z_-]{11}', url): - return url - return "unknown_video_id" - - -def get_display_name(path_or_url): - """Returns a clean name for logging, either a filename or a video ID.""" - if isinstance(path_or_url, Path): - return path_or_url.name - - path_str = str(path_or_url) - video_id = get_video_id(path_str) - if video_id != "unknown_video_id": - return video_id - - # Fallback for file paths as strings or weird URLs - return Path(path_str).name - - -def format_size(b): - """Format size in bytes to human-readable string.""" - if b is None: - return 'N/A' - if b < 1024: - return f"{b}B" - elif b < 1024**2: - return f"{b/1024:.2f}KiB" - elif b < 1024**3: - return f"{b/1024**2:.2f}MiB" - else: - return f"{b/1024**3:.2f}GiB" +def run_command(cmd, input_data=None): + """Runs a command, captures its output, and returns status.""" + logger.debug(f"Running command: {' '.join(cmd)}") + try: + process = subprocess.Popen( + cmd, + stdin=subprocess.PIPE if input_data else None, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + encoding='utf-8' + ) + stdout, stderr = process.communicate(input=input_data) + return process.returncode, stdout, stderr + except FileNotFoundError: + logger.error(f"Command not found: {cmd[0]}. Make sure it's in your PATH.") + return -1, "", f"Command not found: {cmd[0]}" + except Exception as e: + logger.error(f"An error occurred while running command: {' '.join(cmd)}. 
Error: {e}") + return -1, "", str(e) class StatsTracker: @@ -151,7 +131,7 @@ class StatsTracker: # Download volume stats total_bytes = sum(e.get('downloaded_bytes', 0) for e in download_events if e['success']) if total_bytes > 0: - logger.info(f"Total data downloaded: {format_size(total_bytes)}") + logger.info(f"Total data downloaded: {sp_utils.format_size(total_bytes)}") if duration > 1: bytes_per_second = total_bytes / duration gb_per_hour = (bytes_per_second * 3600) / (1024**3) @@ -299,7 +279,7 @@ def run_download_worker(info_json_path, info_json_content, format_to_download, a """ # 1. Attempt download download_cmd = [ - sys.executable, '-m', 'ytops_client.cli', 'download', + sys.executable, '-m', 'ytops_client.cli', 'download', 'py', '-f', format_to_download ] if args.format_download_args: @@ -311,7 +291,7 @@ def run_download_worker(info_json_path, info_json_content, format_to_download, a # multiple items, assume they are already split by shell download_cmd.extend(args.format_download_args) - display_name = get_display_name(info_json_path) + display_name = sp_utils.get_display_name(info_json_path) logger.info(f"[{display_name} @ {format_to_download}] Kicking off download process...") retcode, stdout, stderr = run_command(download_cmd, input_data=info_json_content) @@ -385,7 +365,7 @@ def process_info_json_cycle(path, content, args, stats): """ results = [] should_stop_file = False - display_name = get_display_name(path) + display_name = sp_utils.get_display_name(path) # Determine formats to test based on the info.json content try: @@ -629,7 +609,7 @@ def main_stress_formats(args): try: # shlex.split handles quoted arguments in the template - video_id = get_video_id(url) + video_id = sp_utils.get_video_id(url) gen_cmd = [] template_args = shlex.split(gen_cmd_template) @@ -721,7 +701,7 @@ def main_stress_formats(args): for future in done: identifier = future_to_identifier[future] - identifier_name = get_display_name(identifier) + identifier_name = sp_utils.get_display_name(identifier) try: results = future.result() # Check if any result from this file triggers a global stop diff --git a/ytops_client/stress_policy/__init__.py b/ytops_client/stress_policy/__init__.py new file mode 100644 index 0000000..1c622ff --- /dev/null +++ b/ytops_client/stress_policy/__init__.py @@ -0,0 +1 @@ +# This package contains modules for the stress policy tool diff --git a/ytops_client/stress_policy/arg_parser.py b/ytops_client/stress_policy/arg_parser.py new file mode 100644 index 0000000..fbd7159 --- /dev/null +++ b/ytops_client/stress_policy/arg_parser.py @@ -0,0 +1,216 @@ +import argparse + +def add_stress_policy_parser(subparsers): + """Add the parser for the 'stress-policy' command.""" + parser = subparsers.add_parser( + 'stress-policy', + description="The primary, policy-driven stress-testing orchestrator.\nIt runs complex, multi-stage stress tests based on a YAML policy file.\nUse '--list-policies' to see available pre-configured scenarios.\n\nModes supported:\n- full_stack: Generate info.json and then download from it.\n- fetch_only: Only generate info.json files.\n- download_only: Only download from existing info.json files.", + formatter_class=argparse.RawTextHelpFormatter, + help='Run advanced, policy-driven stress tests (recommended).', + epilog=""" +Examples: + +1. 
Fetch info.jsons for a TV client with a single profile and a rate limit: + ytops-client stress-policy --policy policies/1_fetch_only_policies.yaml \\ + --policy-name tv_downgraded_single_profile \\ + --set settings.urls_file=my_urls.txt \\ + --set execution_control.run_until.minutes=30 + # This runs a 'fetch_only' test using the 'tv_downgraded' client. It uses a single, + # static profile for all requests and enforces a safety limit of 450 requests per hour. + +2. Fetch info.jsons for an Android client using cookies for authentication: + ytops-client stress-policy --policy policies/1_fetch_only_policies.yaml \\ + --policy-name android_sdkless_with_cookies \\ + --set settings.urls_file=my_urls.txt \\ + --set info_json_generation_policy.request_params.cookies_file_path=/path/to/my_cookies.txt + # This demonstrates an authenticated 'fetch_only' test. It passes the path to a + # Netscape cookie file, which the server will use for the requests. + +3. Download from a folder of info.jsons, grouped by profile, with auto-workers: + ytops-client stress-policy --policy policies/2_download_only_policies.yaml \\ + --policy-name basic_profile_aware_download \\ + --set settings.info_json_dir=/path/to/my/infojsons + # This runs a 'download_only' test. It scans a directory, extracts profile names from + # the filenames (e.g., 'tv_user_1' from '...-VIDEOID-tv_user_1.json'), and groups + # them. 'workers=auto' sets the number of workers to the number of unique profiles found. + +4. Full-stack test with multiple workers and profile rotation: + ytops-client stress-policy --policy policies/3_full_stack_policies.yaml \\ + --policy-name tv_simply_profile_rotation \\ + --set settings.urls_file=my_urls.txt \\ + --set execution_control.workers=4 \\ + --set settings.profile_management.max_requests_per_profile=500 + # This runs a 'full_stack' test with 4 parallel workers. Each worker gets a unique + # profile (e.g., tv_simply_user_0_0, tv_simply_user_1_0, etc.). After a profile is + # used 500 times, it is retired, and a new "generation" is created (e.g., tv_simply_user_0_1). + +5. Full-stack authenticated test with a pool of profiles and corresponding cookie files: + ytops-client stress-policy --policy policies/3_full_stack_policies.yaml \\ + --policy-name mweb_multi_profile_with_cookies \\ + --set settings.urls_file=my_urls.txt \\ + --set settings.profile_management.cookie_files='["/path/c1.txt","/path/c2.txt"]' + # This runs a 'full_stack' test using a pool of profiles (e.g., mweb_user_0, mweb_user_1). + # It uses the 'cookie_files' list to assign a specific cookie file to each profile in the + # pool, enabling multi-account authenticated testing. Note the JSON/YAML list format for the override. + +6. Full-stack test submitting downloads to an aria2c RPC server: + ytops-client stress-policy --policy policies/3_full_stack_policies.yaml \\ + --policy-name tv_simply_profile_rotation_aria2c_rpc \\ + --set settings.urls_file=my_urls.txt \\ + --set download_policy.aria_host=192.168.1.100 \\ + --set download_policy.aria_port=6801 + # This runs a test where downloads are not performed by the worker itself, but are + # sent to a remote aria2c daemon. The policy specifies 'downloader: aria2c_rpc' + # and provides connection details. This is useful for offloading download traffic. 
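+7. Preview the effective policy without running anything:
+   ytops-client stress-policy --policy policies/3_full_stack_policies.yaml \\
+     --policy-name tv_simply_profile_rotation \\
+     --set execution_control.workers=2 --dry-run
+   # '--dry-run' prints the effective policy (including any --set overrides) and exits without
+   # running the test, which is a quick way to sanity-check overrides before a long run.
+   # '--show-overrides' instead prints all of the policy's defined values as a single line of
+   # --set arguments, then exits.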
+ +-------------------------------------------------------------------------------- +Overridable Policy Parameters via --set: + + Key Description + -------------------------------------- ------------------------------------------------ + [settings] + settings.mode Test mode: 'full_stack', 'fetch_only', or 'download_only'. + settings.urls_file Path to file with URLs/video IDs. + settings.info_json_dir Path to directory with existing info.json files. + settings.profile_extraction_regex For 'download_only' stats, a regex to extract profile names from info.json filenames. The first capture group is used as the profile name. E.g., '.*-(.*?).json'. + settings.info_json_dir_sample_percent Randomly sample this %% of files from the directory (for 'once' scan mode). + settings.directory_scan_mode For 'download_only': 'once' (default) or 'continuous' to watch for new files. + settings.mark_processed_files For 'continuous' scan mode: if true, rename processed files to '*..processed' to avoid reprocessing. + settings.max_files_per_cycle For 'continuous' scan mode: max new files to process per cycle. + settings.sleep_if_no_new_files_seconds For 'continuous' scan mode: seconds to sleep if no new files are found (default: 10). + settings.profile_prefix (Legacy) Prefix for profile names (e.g., 'test_user'). + settings.profile_pool (Legacy) Size of the profile pool. + settings.profile_mode Profile strategy. 'per_request' (legacy), 'per_worker' (legacy), or 'per_worker_with_rotation' (requires profile_management). + settings.info_json_script Command to run the info.json generation script (e.g., 'bin/ytops-client get-info'). + settings.save_info_json_dir If set, save all successfully generated info.json files to this directory. + + [settings.profile_management] (New, preferred method for profile control) + profile_management.prefix Prefix for profile names (e.g., 'dyn_user'). + profile_management.suffix Suffix for profile names. Set to 'auto' for a timestamp, or provide a string. + profile_management.initial_pool_size The number of profiles to start with. + profile_management.auto_expand_pool If true, create new profiles when the initial pool is exhausted (all sleeping). + profile_management.max_requests_per_profile Max requests a profile can make before it must 'sleep'. + profile_management.sleep_minutes_on_exhaustion How many minutes a profile 'sleeps' after hitting its request limit. + profile_management.cookie_files A list of paths to cookie files. Used to assign a unique cookie file to each profile in a pool. + + [execution_control] + execution_control.workers Number of parallel worker threads. Set to "auto" to calculate from target_rate. + execution_control.target_rate.requests Target requests for 'auto' workers calculation. + execution_control.target_rate.per_minutes Period in minutes for target_rate. + execution_control.run_until.minutes Stop test after N minutes. Will continuously cycle through sources. + execution_control.run_until.cycles Stop test after N cycles. A cycle is one full pass through all sources. + execution_control.run_until.requests Stop test after N total info.json requests (cumulative across runs). + execution_control.sleep_between_tasks.{min,max}_seconds Min/max sleep time between tasks, per worker. + + [info_json_generation_policy] + info_json_generation_policy.client Client to use (e.g., 'mweb', 'tv_camoufox'). + info_json_generation_policy.auth_host Host for the auth/Thrift service. + info_json_generation_policy.auth_port Port for the auth/Thrift service. 
+ info_json_generation_policy.assigned_proxy_url A specific proxy to use for a request, overriding the server's proxy pool. + info_json_generation_policy.proxy_rename Regex substitution for the assigned proxy URL (e.g., 's/old/new/'). + info_json_generation_policy.command_template A full command template for the info.json script. Overrides other keys. + info_json_generation_policy.rate_limits.per_ip.max_requests Max requests for the given time period from one IP. + info_json_generation_policy.rate_limits.per_ip.per_minutes Time period in minutes for the per_ip rate limit. + info_json_generation_policy.rate_limits.per_profile.max_requests Max requests for a single profile in a time period. + info_json_generation_policy.rate_limits.per_profile.per_minutes Time period in minutes for the per_profile rate limit. + info_json_generation_policy.client_rotation_policy.major_client The primary client to use for most requests. + info_json_generation_policy.client_rotation_policy.refresh_client The client to use periodically to refresh context. + info_json_generation_policy.client_rotation_policy.refresh_every.requests Trigger refresh client after N requests for a profile. + + [download_policy] + download_policy.formats Formats to download (e.g., '18,140', 'random:50%%'). + download_policy.downloader Orchestrator script to use: 'native-py' (default, Python lib), 'native-cli' (legacy CLI wrapper), or 'aria2c_rpc'. + download_policy.external_downloader For 'native-py' or default, the backend yt-dlp should use (e.g., 'aria2c', 'native'). + download_policy.downloader_args Arguments for the external_downloader. For yt-dlp, e.g., 'aria2c:-x 8'. + download_policy.merge_output_format Container to merge to (e.g., 'mkv'). Defaults to 'mp4' via cli.config. + download_policy.temp_path For 'native-py', path to a directory for temporary files (e.g., a RAM disk like /dev/shm). + download_policy.output_to_buffer For 'native-py', download to an in-memory buffer and pipe to stdout instead of saving to a file (true/false). Best for single-file formats. + download_policy.proxy Proxy for direct downloads (e.g., "socks5://127.0.0.1:1080"). + download_policy.proxy_rename Regex substitution for the proxy URL (e.g., 's/old/new/'). + download_policy.pause_before_download_seconds Pause for N seconds before starting each download attempt. + download_policy.continue_downloads Enable download continuation (true/false). + download_policy.cleanup After success: for native downloaders, rename and truncate file to 0 bytes; for 'aria2c_rpc', remove file(s) from filesystem. + download_policy.extra_args A string of extra arguments for the download script (e.g., "--limit-rate 5M"). + download_policy.sleep_per_proxy_seconds Cooldown in seconds between downloads on the same proxy. + download_policy.rate_limits.per_proxy.max_requests Max downloads for a single proxy in a time period. + download_policy.rate_limits.per_proxy.per_minutes Time period in minutes for the per_proxy download rate limit. + # For downloader: 'aria2c_rpc' + download_policy.aria_host Hostname of the aria2c RPC server. + download_policy.aria_port Port of the aria2c RPC server. + download_policy.aria_secret Secret token for the aria2c RPC server. + download_policy.aria_wait Wait for aria2c downloads to complete (true/false). + download_policy.purge_on_complete On success, purge ALL completed/failed downloads from aria2c history. Use as a workaround for older aria2c versions where targeted removal fails. + download_policy.output_dir Output directory for downloads. 
+ download_policy.aria_remote_dir The absolute download path on the remote aria2c host. + download_policy.aria_fragments_dir The local path to find fragments for merging (if different from output_dir). + download_policy.auto_merge_fragments For fragmented downloads, automatically merge parts after download (true/false). Requires aria_wait=true. + download_policy.remove_fragments_after_merge For fragmented downloads, delete fragment files after a successful merge (true/false). Requires auto_merge_fragments=true. + + [stop_conditions] + stop_conditions.on_failure Stop on any download failure (true/false). + stop_conditions.on_http_403 Stop on any HTTP 403 error (true/false). + stop_conditions.on_error_rate.max_errors Stop test if more than N errors (of any type) occur within the time period. + stop_conditions.on_error_rate.per_minutes Time period in minutes for the error rate calculation. + stop_conditions.fatal_error_patterns A list of regex patterns. Errors matching these are always considered fatal and count towards 'on_error_rate', even if they also match a tolerated pattern. + stop_conditions.tolerated_error_patterns A list of regex patterns. Fetch errors matching these will be ignored by 'on_error_rate'. + stop_conditions.on_cumulative_403.max_errors Stop test if more than N HTTP 403 errors occur within the time period. + stop_conditions.on_cumulative_403.per_minutes Time period in minutes for the cumulative 403 calculation. + stop_conditions.on_quality_degradation.trigger_if_missing_formats A format ID or comma-separated list of IDs. Triggers if any are missing. + stop_conditions.on_quality_degradation.max_triggers Stop test if quality degradation is detected N times. + stop_conditions.on_quality_degradation.per_minutes Time period in minutes for the quality degradation calculation. +-------------------------------------------------------------------------------- +""" + ) + parser.add_argument('--policy', help='Path to the YAML policy file. Required unless --list-policies is used.') + parser.add_argument('--policy-name', help='Name of the policy to run from a multi-policy file (if it contains "---" separators).') + parser.add_argument('--list-policies', action='store_true', help='List all available policies from the default policies directory and exit.') + parser.add_argument('--show-overrides', action='store_true', help='Load the specified policy and print all its defined values as a single-line of --set arguments, then exit.') + parser.add_argument('--set', action='append', default=[], help="Override a policy setting using 'key.subkey=value' format.\n(e.g., --set execution_control.workers=5)") + parser.add_argument('--profile-prefix', help="Shortcut to override the profile prefix for profile locking mode. Affects both auth and download stages.") + parser.add_argument('--start-from-url-index', type=int, help='Start processing from this line number (1-based) in the urls_file. Overrides saved state.') + parser.add_argument('--expire-time-shift-minutes', type=int, help="Consider URLs expiring in N minutes as expired. 
Overrides policy.") + + # Add a group for aria2c-specific overrides for clarity in --help + aria_group = parser.add_argument_group('Aria2c RPC Downloader Overrides', 'Shortcuts for common --set options for the aria2c_rpc downloader.') + aria_group.add_argument('--auto-merge-fragments', action=argparse.BooleanOptionalAction, default=None, help='Shortcut to enable/disable download_policy.auto_merge_fragments.') + aria_group.add_argument('--remove-fragments-after-merge', action=argparse.BooleanOptionalAction, default=None, help='Shortcut to enable/disable download_policy.remove_fragments_after_merge.') + aria_group.add_argument('--fragments-dir', help='Shortcut for --set download_policy.aria_fragments_dir=PATH.') + aria_group.add_argument('--remote-dir', help='Shortcut for --set download_policy.aria_remote_dir=PATH.') + aria_group.add_argument('--cleanup', action=argparse.BooleanOptionalAction, default=None, help='Shortcut to enable/disable download_policy.cleanup.') + + parser.add_argument('--verbose', action='store_true', help='Enable verbose output for the orchestrator and underlying scripts.') + parser.add_argument('--print-downloader-log', action='store_true', help='Stream the live stdout/stderr from the download subprocess to the console.') + parser.add_argument('--dry-run', action='store_true', help='Print the effective policy and exit without running the test.') + parser.add_argument('--dummy', action='store_true', help='Simulate auth and download without running external commands. Used to test profile management logic.\nDummy behavior (e.g., failure rates, durations) can be configured in the policy file under settings.dummy_simulation_settings.') + parser.add_argument('--dummy-auth-failure-rate', type=float, default=0.0, help='[Dummy Mode] The probability (0.0 to 1.0) of a simulated auth request failing fatally.') + parser.add_argument('--dummy-auth-skipped-failure-rate', type=float, default=0.0, help='[Dummy Mode] The probability (0.0 to 1.0) of a simulated auth request having a tolerated failure (e.g., 429).') + parser.add_argument('--disable-log-writing', action='store_true', help='Disable writing state, stats, and log files. By default, files are created for each run.') + + # Add a group for download-specific utilities + download_util_group = parser.add_argument_group('Download Mode Utilities') + download_util_group.add_argument('--pre-cleanup-media', nargs='?', const='.', default=None, + help='Before running, delete media files (.mp4, .m4a, .webm, etc.) from a directory. ' + 'If a path is provided, cleans that directory. ' + 'If used without a path, cleans the directory specified in download_policy.output_dir or direct_docker_cli_policy.docker_host_download_path. ' + 'If no output_dir is set, it fails.') + download_util_group.add_argument('--reset-local-cache-folder', nargs='?', const='.', default=None, + help="Before running, delete the contents of the local cache folder used by direct_docker_cli mode. " + "The cache folder is defined by 'direct_docker_cli_policy.docker_host_cache_path' in the policy. " + "This is useful for forcing a fresh start for cookies, user-agents, etc. 
" + "If a path is provided, cleans that directory instead of the one from the policy.") + download_util_group.add_argument('--reset-infojson', action='store_true', + help="Before running, reset all '.processed' and '.LOCKED' info.json files in the source directory " + "back to '.json', allowing them to be re-processed.") + + # Add a group for Redis connection settings + redis_group = parser.add_argument_group('Redis Connection Overrides (for profile locking mode)') + redis_group.add_argument('--env-file', help='Path to a .env file to load environment variables from.') + redis_group.add_argument('--redis-host', default=None, help='Redis host. Defaults to REDIS_HOST or MASTER_HOST_IP env var, or localhost.') + redis_group.add_argument('--redis-port', type=int, default=None, help='Redis port. Defaults to REDIS_PORT env var, or 6379.') + redis_group.add_argument('--redis-password', default=None, help='Redis password. Defaults to REDIS_PASSWORD env var.') + redis_group.add_argument('--env', default=None, help="Default environment name for Redis key prefix (e.g., 'stg', 'prod'). Used if --auth-env or --download-env are not specified. Overrides policy file setting.") + redis_group.add_argument('--auth-env', help="Override the environment for the Auth simulation. Overrides --env.") + redis_group.add_argument('--download-env', help="Override the environment for the Download simulation. Overrides --env.") + redis_group.add_argument('--key-prefix', default=None, help='Explicit key prefix for Redis. Overrides --env and any defaults.') + + return parser diff --git a/ytops_client/stress_policy/process_runners.py b/ytops_client/stress_policy/process_runners.py new file mode 100644 index 0000000..4b3186c --- /dev/null +++ b/ytops_client/stress_policy/process_runners.py @@ -0,0 +1,283 @@ +import logging +import os +import shlex +import signal +import subprocess +import sys +import threading +import time + +try: + import docker +except ImportError: + docker = None + +logger = logging.getLogger(__name__) + +# Worker ID tracking +worker_id_map = {} +worker_id_counter = 0 +worker_id_lock = threading.Lock() + +def get_worker_id(): + """Assigns a stable, sequential ID to each worker thread.""" + global worker_id_counter + thread_id = threading.get_ident() + with worker_id_lock: + if thread_id not in worker_id_map: + worker_id_map[thread_id] = worker_id_counter + worker_id_counter += 1 + return worker_id_map[thread_id] + +def run_command(cmd, running_processes, process_lock, input_data=None, binary_stdout=False, stream_output=False, stream_prefix="", env=None): + """ + Runs a command, captures its output, and returns status. + If binary_stdout is True, stdout is returned as bytes. Otherwise, both are decoded strings. + If stream_output is True, the command's stdout/stderr are printed to the console in real-time. + """ + logger.debug(f"Running command: {' '.join(shlex.quote(s) for s in cmd)}") + if env: + logger.debug(f"With custom environment: {env}") + process = None + try: + # Combine with os.environ to ensure PATH etc. are inherited. + process_env = os.environ.copy() + if env: + # Ensure all values in the custom env are strings + process_env.update({k: str(v) for k, v in env.items()}) + + # Always open in binary mode to handle both cases. We will decode later. 
+ process = subprocess.Popen( + cmd, + stdin=subprocess.PIPE if input_data else None, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + preexec_fn=os.setsid, # Start in a new process group to isolate from terminal signals + env=process_env + ) + with process_lock: + running_processes.add(process) + + stdout_capture = [] + stderr_capture = [] + + def read_pipe(pipe, capture_list, display_pipe=None, prefix=""): + """Reads a pipe line by line (as bytes), appending to a list and optionally displaying.""" + for line in iter(pipe.readline, b''): + capture_list.append(line) + if display_pipe: + # Decode for display + display_line = line.decode('utf-8', errors='replace') + # Use print to ensure atomicity and proper handling of newlines + print(f"{prefix}{display_line.strip()}", file=display_pipe) + + stdout_display_pipe = sys.stdout if stream_output else None + stderr_display_pipe = sys.stderr if stream_output else None + + # We must read stdout and stderr in parallel to prevent deadlocks. + stdout_thread = threading.Thread(target=read_pipe, args=(process.stdout, stdout_capture, stdout_display_pipe, stream_prefix)) + stderr_thread = threading.Thread(target=read_pipe, args=(process.stderr, stderr_capture, stderr_display_pipe, stream_prefix)) + + stdout_thread.start() + stderr_thread.start() + + # Handle stdin after starting to read outputs to avoid deadlocks. + if input_data: + try: + process.stdin.write(input_data.encode('utf-8')) + process.stdin.close() + except (IOError, BrokenPipeError): + # This can happen if the process exits quickly or doesn't read stdin. + logger.debug(f"Could not write to stdin for command: {' '.join(cmd)}. Process may have already exited.") + + # Wait for the process to finish and for all output to be read. + # Add a timeout to prevent indefinite hangs. 15 minutes should be enough for any single download. + timeout_seconds = 15 * 60 + try: + retcode = process.wait(timeout=timeout_seconds) + except subprocess.TimeoutExpired: + logger.error(f"Command timed out after {timeout_seconds} seconds: {' '.join(cmd)}") + # Kill the entire process group to ensure child processes (like yt-dlp or ffmpeg) are also terminated. + try: + os.killpg(os.getpgid(process.pid), signal.SIGKILL) + except (ProcessLookupError, PermissionError): + pass # Process already finished or we lack permissions + retcode = -1 # Indicate failure + # Wait a moment for pipes to close after killing. + try: + process.wait(timeout=5) + except subprocess.TimeoutExpired: + logger.warning("Process did not terminate gracefully after SIGKILL.") + + stdout_thread.join(timeout=5) + stderr_thread.join(timeout=5) + + stdout_bytes = b"".join(stdout_capture) + stderr_bytes = b"".join(stderr_capture) + + # If we timed out, create a synthetic stderr message to ensure the failure is reported upstream. + if retcode == -1 and not stderr_bytes.strip(): + stderr_bytes = f"Command timed out after {timeout_seconds} seconds".encode('utf-8') + + stdout = stdout_bytes if binary_stdout else stdout_bytes.decode('utf-8', errors='replace') + stderr = stderr_bytes.decode('utf-8', errors='replace') + + return retcode, stdout, stderr + + except FileNotFoundError: + logger.error(f"Command not found: {cmd[0]}. Make sure it's in your PATH.") + return -1, "", f"Command not found: {cmd[0]}" + except Exception as e: + logger.error(f"An error occurred while running command: {' '.join(cmd)}. 
Error: {e}") + return -1, "", str(e) + finally: + if process: + with process_lock: + running_processes.discard(process) + + +def run_docker_container(image_name, command, volumes, stream_prefix="", network_name=None, log_callback=None, profile_manager=None, profile_name=None, environment=None, log_command_override=None): + """ + Runs a command in a new, ephemeral Docker container using docker-py. + Streams logs in real-time, allows for live log processing, and ensures cleanup. + Can monitor a profile and stop the container if the profile is BANNED or RESTING. + Returns a tuple of (exit_code, stdout_str, stderr_str, stop_reason). + """ + if not docker: + # This should be caught earlier, but as a safeguard: + return -1, "", "Docker SDK for Python is not installed. Please run: pip install docker", None + + logger.debug(f"Running docker container. Image: {image_name}, Command: {command}, Volumes: {volumes}, Network: {network_name}") + + # --- Construct and log the equivalent CLI command for debugging --- + try: + user_id = f"{os.getuid()}:{os.getgid()}" if os.name != 'nt' else None + cli_cmd = ['docker', 'run', '--rm'] + if user_id: + cli_cmd.extend(['-u', user_id]) + if network_name: + cli_cmd.extend(['--network', network_name]) + if environment: + for k, v in sorted(environment.items()): + cli_cmd.extend(['-e', f"{k}={v}"]) + if volumes: + for host_path, container_config in sorted(volumes.items()): + bind = container_config.get('bind') + mode = container_config.get('mode', 'rw') + cli_cmd.extend(['-v', f"{os.path.abspath(host_path)}:{bind}:{mode}"]) + cli_cmd.append(image_name) + cli_cmd.extend(command) + logger.info(f"Full docker command: {' '.join(shlex.quote(s) for s in cli_cmd)}") + if log_command_override: + # Build a more comprehensive, runnable command for logging + env_prefix_parts = [] + if environment: + for k, v in sorted(environment.items()): + env_prefix_parts.append(f"{k}={shlex.quote(str(v))}") + + env_prefix = ' '.join(env_prefix_parts) + + equivalent_ytdlp_cmd = ' '.join(shlex.quote(s) for s in log_command_override) + + full_equivalent_cmd = f"{env_prefix} {equivalent_ytdlp_cmd}".strip() + logger.info(f"Equivalent host command: {full_equivalent_cmd}") + except Exception as e: + logger.warning(f"Could not construct equivalent docker command for logging: {e}") + # --- End of logging --- + + container = None + monitor_thread = None + stop_monitor_event = threading.Event() + # Use a mutable object (dict) to share the stop reason between threads + stop_reason_obj = {'reason': None} + try: + client = docker.from_env() + + # Run container as current host user to avoid permission issues with volume mounts + user_id = f"{os.getuid()}:{os.getgid()}" if os.name != 'nt' else None + + container = client.containers.run( + image_name, + command=command, + volumes=volumes, + detach=True, + network=network_name, + user=user_id, + environment=environment, + # We use `remove` in `finally` instead of `auto_remove` to ensure we can get logs + # even if the container fails to start. + ) + + # Thread to monitor profile status and stop container if BANNED or RESTING + def monitor_profile(): + while not stop_monitor_event.is_set(): + try: + profile_info = profile_manager.get_profile(profile_name) + if profile_info: + state = profile_info.get('state') + if state in ['BANNED', 'RESTING']: + logger.warning(f"Profile '{profile_name}' is {state}. 
Stopping container {container.short_id}.") + stop_reason_obj['reason'] = f"Profile became {state}" + try: + container.stop(timeout=5) + except docker.errors.APIError as e: + logger.warning(f"Could not stop container {container.short_id}: {e}") + break # Stop monitoring + except Exception as e: + logger.error(f"Error in profile monitor thread: {e}") + + # Wait for 2 seconds or until stop event is set + stop_monitor_event.wait(2) + + if profile_manager and profile_name: + monitor_thread = threading.Thread(target=monitor_profile, daemon=True) + monitor_thread.start() + + # Stream logs in a separate thread to avoid blocking. + log_stream = container.logs(stream=True, follow=True, stdout=True, stderr=True) + for line_bytes in log_stream: + line_str = line_bytes.decode('utf-8', errors='replace').strip() + # Use logger.info to ensure output is captured by all handlers + logger.info(f"{stream_prefix}{line_str}") + if log_callback: + # The callback can return True to signal an immediate stop. + if log_callback(line_str): + logger.warning(f"Log callback requested to stop container {container.short_id}.") + stop_reason_obj['reason'] = "Stopped by log callback (fatal error)" + try: + container.stop(timeout=5) + except docker.errors.APIError as e: + logger.warning(f"Could not stop container {container.short_id}: {e}") + break # Stop reading logs + + result = container.wait(timeout=15 * 60) + exit_code = result.get('StatusCode', -1) + + # Get final logs to separate stdout and stderr. + final_stdout = container.logs(stdout=True, stderr=False) + final_stderr = container.logs(stdout=False, stderr=True) + + stdout_str = final_stdout.decode('utf-8', errors='replace') + stderr_str = final_stderr.decode('utf-8', errors='replace') + + return exit_code, stdout_str, stderr_str, stop_reason_obj['reason'] + + except docker.errors.ImageNotFound: + logger.error(f"Docker image not found: '{image_name}'. Please pull it first.") + return -1, "", f"Docker image not found: {image_name}", None + except docker.errors.APIError as e: + logger.error(f"Docker API error: {e}") + return -1, "", str(e), None + except Exception as e: + logger.error(f"An unexpected error occurred while running docker container: {e}", exc_info=True) + return -1, "", str(e), None + finally: + if monitor_thread: + stop_monitor_event.set() + monitor_thread.join(timeout=1) + if container: + try: + container.remove(force=True) + logger.debug(f"Removed container {container.short_id}") + except docker.errors.APIError as e: + logger.warning(f"Could not remove container {container.short_id}: {e}") diff --git a/ytops_client/stress_policy/state_manager.py b/ytops_client/stress_policy/state_manager.py new file mode 100644 index 0000000..6605b69 --- /dev/null +++ b/ytops_client/stress_policy/state_manager.py @@ -0,0 +1,794 @@ +import collections +import collections.abc +import json +import logging +import re +import threading +import time +from datetime import datetime +from pathlib import Path + +from . 
import utils as sp_utils + +logger = logging.getLogger(__name__) + + +class StateManager: + """Tracks statistics, manages rate limits, and persists state across runs.""" + def __init__(self, policy_name, disable_log_writing=False, shutdown_event=None): + self.disable_log_writing = disable_log_writing + self.state_file_path = Path(f"{policy_name}_state.json") + self.stats_file_path = Path(f"{policy_name}_stats.jsonl") + self.lock = threading.RLock() + self.start_time = time.time() + self.shutdown_event = shutdown_event or threading.Event() + self.events = [] + self.state = { + 'global_request_count': 0, + 'rate_limit_trackers': {}, # e.g., {'per_ip': [ts1, ts2], 'profile_foo': [ts3, ts4]} + 'profile_request_counts': {}, # for client rotation + 'profile_last_refresh_time': {}, # for client rotation + 'proxy_last_finish_time': {}, # for per-proxy sleep + 'processed_files': [], # For continuous download_only mode + # For dynamic profile cooldown strategy + 'profile_cooldown_counts': {}, + 'profile_cooldown_sleep_until': {}, + 'profile_pool_size': 0, + 'profile_run_suffix': None, + 'worker_profile_generations': {}, + 'last_url_index': 0, + # For batch modes + 'total_batches_processed': 0, + 'successful_batches': 0, + 'failed_batches': 0, + 'total_videos_processed': 0, + } + self.stats_file_handle = None + self._load_state() + self.print_historical_summary() + self._open_stats_log() + + def _load_state(self): + if self.disable_log_writing: + logger.info("Log writing is disabled. State will not be loaded from disk.") + return + if not self.state_file_path.exists(): + logger.info(f"State file not found at '{self.state_file_path}', starting fresh.") + return + try: + with open(self.state_file_path, 'r', encoding='utf-8') as f: + self.state = json.load(f) + # Ensure keys exist + self.state.setdefault('global_request_count', 0) + self.state.setdefault('rate_limit_trackers', {}) + self.state.setdefault('profile_request_counts', {}) + self.state.setdefault('profile_last_refresh_time', {}) + self.state.setdefault('proxy_last_finish_time', {}) + self.state.setdefault('processed_files', []) + # For dynamic profile cooldown strategy + self.state.setdefault('profile_cooldown_counts', {}) + self.state.setdefault('profile_cooldown_sleep_until', {}) + self.state.setdefault('profile_pool_size', 0) + self.state.setdefault('profile_run_suffix', None) + self.state.setdefault('worker_profile_generations', {}) + self.state.setdefault('last_url_index', 0) + # For batch modes + self.state.setdefault('total_batches_processed', 0) + self.state.setdefault('successful_batches', 0) + self.state.setdefault('failed_batches', 0) + self.state.setdefault('total_videos_processed', 0) + logger.info(f"Loaded state from {self.state_file_path}") + except (IOError, json.JSONDecodeError) as e: + logger.error(f"Could not load or parse state file {self.state_file_path}: {e}. 
Starting fresh.") + + def _save_state(self): + if self.disable_log_writing: + return + with self.lock: + try: + with open(self.state_file_path, 'w', encoding='utf-8') as f: + json.dump(self.state, f, indent=2) + logger.info(f"Saved state to {self.state_file_path}") + except IOError as e: + logger.error(f"Could not save state to {self.state_file_path}: {e}") + + def _open_stats_log(self): + if self.disable_log_writing: + return + try: + self.stats_file_handle = open(self.stats_file_path, 'a', encoding='utf-8') + except IOError as e: + logger.error(f"Could not open stats file {self.stats_file_path}: {e}") + + def close(self): + """Saves state and closes file handles.""" + self._save_state() + if self.stats_file_handle: + self.stats_file_handle.close() + self.stats_file_handle = None + + def mark_file_as_processed(self, file_path): + """Adds a file path to the list of processed files in the state.""" + with self.lock: + # Using a list and checking for existence is fine for moderate numbers of files. + # A set isn't JSON serializable. + processed = self.state.setdefault('processed_files', []) + file_str = str(file_path) + if file_str not in processed: + processed.append(file_str) + + def get_last_url_index(self): + """Gets the last URL index to start from.""" + with self.lock: + return self.state.get('last_url_index', 0) + + def get_next_url_batch(self, count, urls_list): + """Gets the next batch of URLs to process, updating the state.""" + with self.lock: + start_index = self.state.get('last_url_index', 0) + if start_index >= len(urls_list): + return [], start_index # No more URLs + + end_index = start_index + count + batch = urls_list[start_index:end_index] + + # Update state with the index of the *next* URL to be processed. + self.state['last_url_index'] = end_index + return batch, start_index + + def update_last_url_index(self, index, force=False): + """Updates the last processed URL index in the state. + + Args: + index: The index of the *next* URL to process. + force: If True, sets the index regardless of the current value. 
+ """ + with self.lock: + if force or index > self.state.get('last_url_index', 0): + self.state['last_url_index'] = index + + def get_processed_files(self): + """Returns a set of file paths that have been processed.""" + with self.lock: + return set(self.state.get('processed_files', [])) + + def record_batch_result(self, success, video_count, profile_name=None): + with self.lock: + self.state['total_batches_processed'] = self.state.get('total_batches_processed', 0) + 1 + self.state['total_videos_processed'] = self.state.get('total_videos_processed', 0) + video_count + if success: + self.state['successful_batches'] = self.state.get('successful_batches', 0) + 1 + else: + self.state['failed_batches'] = self.state.get('failed_batches', 0) + 1 + + # Print live counter + total = self.state['total_batches_processed'] + ok = self.state['successful_batches'] + fail = self.state['failed_batches'] + profile_log = f" [{profile_name}]" if profile_name else "" + logger.info(f"Batch #{total} complete.{profile_log} (Total OK: {ok}, Total Fail: {fail})") + + def print_historical_summary(self): + """Prints a summary based on the state loaded from disk, before new events.""" + with self.lock: + now = time.time() + rate_trackers = self.state.get('rate_limit_trackers', {}) + total_requests = self.state.get('global_request_count', 0) + + if not rate_trackers and not total_requests: + logger.info("No historical data found in state file.") + return + + logger.info("\n--- Summary From Previous Runs ---") + logger.info(f"Total info.json requests (all previous runs): {total_requests}") + + if rate_trackers: + for key, timestamps in sorted(rate_trackers.items()): + # Time windows in seconds + windows = { + 'last 10 min': 600, + 'last 60 min': 3600, + 'last 6 hours': 21600, + 'last 24 hours': 86400 + } + + rates_str_parts = [] + for name, seconds in windows.items(): + count = sum(1 for ts in timestamps if now - ts <= seconds) + # Calculate rate in requests per minute + rate_rpm = (count / seconds) * 60 if seconds > 0 else 0 + rates_str_parts.append(f"{count} req in {name} ({rate_rpm:.2f} rpm)") + + logger.info(f"Tracker '{key}': " + ", ".join(rates_str_parts)) + logger.info("------------------------------------") + + def log_event(self, event_data): + with self.lock: + event_data['timestamp'] = datetime.now().isoformat() + self.events.append(event_data) + if self.stats_file_handle: + self.stats_file_handle.write(json.dumps(event_data) + '\n') + self.stats_file_handle.flush() + + def get_request_count(self): + with self.lock: + return self.state.get('global_request_count', 0) + + def increment_request_count(self): + with self.lock: + self.state['global_request_count'] = self.state.get('global_request_count', 0) + 1 + + def check_cumulative_error_rate(self, max_errors, per_minutes, error_type=None): + """ + Checks if a cumulative error rate has been exceeded. + If error_type is None, checks for any failure. + Returns the number of errors found if the threshold is met, otherwise 0. 
+ """ + with self.lock: + now = time.time() + window_seconds = per_minutes * 60 + + if error_type: + recent_errors = [ + e for e in self.events + if e.get('error_type') == error_type and (now - datetime.fromisoformat(e['timestamp']).timestamp()) <= window_seconds + ] + else: # Generic failure check + recent_errors = [ + e for e in self.events + # Only count failures that are not explicitly tolerated + if not e.get('success') and not e.get('is_tolerated_error') and (now - datetime.fromisoformat(e['timestamp']).timestamp()) <= window_seconds + ] + + if len(recent_errors) >= max_errors: + return len(recent_errors) + return 0 + + def check_quality_degradation_rate(self, max_triggers, per_minutes): + """ + Checks if the quality degradation trigger rate has been exceeded. + Returns the number of triggers found if the threshold is met, otherwise 0. + """ + with self.lock: + now = time.time() + window_seconds = per_minutes * 60 + + recent_triggers = [ + e for e in self.events + if e.get('quality_degradation_trigger') and (now - datetime.fromisoformat(e['timestamp']).timestamp()) <= window_seconds + ] + + if len(recent_triggers) >= max_triggers: + return len(recent_triggers) + return 0 + + def check_and_update_rate_limit(self, profile_name, policy): + """ + Checks if a request is allowed based on policy rate limits. + If allowed, updates the internal state. Returns True if allowed, False otherwise. + """ + with self.lock: + now = time.time() + gen_policy = policy.get('info_json_generation_policy', {}) + rate_limits = gen_policy.get('rate_limits', {}) + + # Check per-IP limit + ip_limit = rate_limits.get('per_ip') + if ip_limit: + tracker_key = 'per_ip' + max_req = ip_limit.get('max_requests') + period_min = ip_limit.get('per_minutes') + if max_req and period_min: + timestamps = self.state['rate_limit_trackers'].get(tracker_key, []) + # Filter out old timestamps + timestamps = [ts for ts in timestamps if now - ts < period_min * 60] + if len(timestamps) >= max_req: + logger.warning("Per-IP rate limit reached. Skipping task.") + return False + self.state['rate_limit_trackers'][tracker_key] = timestamps + + # Check per-profile limit + profile_limit = rate_limits.get('per_profile') + if profile_limit and profile_name: + tracker_key = f"profile_{profile_name}" + max_req = profile_limit.get('max_requests') + period_min = profile_limit.get('per_minutes') + if max_req and period_min: + timestamps = self.state['rate_limit_trackers'].get(tracker_key, []) + timestamps = [ts for ts in timestamps if now - ts < period_min * 60] + if len(timestamps) >= max_req: + logger.warning(f"Per-profile rate limit for '{profile_name}' reached. Skipping task.") + return False + self.state['rate_limit_trackers'][tracker_key] = timestamps + + # If all checks pass, record the new request timestamp for all relevant trackers + if ip_limit and ip_limit.get('max_requests'): + self.state['rate_limit_trackers'].setdefault('per_ip', []).append(now) + if profile_limit and profile_limit.get('max_requests') and profile_name: + self.state['rate_limit_trackers'].setdefault(f"profile_{profile_name}", []).append(now) + + return True + + def get_client_for_request(self, profile_name, gen_policy): + """ + Determines which client to use based on the client_rotation_policy. + Returns a tuple: (client_name, request_params_dict). + """ + with self.lock: + rotation_policy = gen_policy.get('client_rotation_policy') + + # If no rotation policy, use the simple 'client' key. 
+ if not rotation_policy: + client = gen_policy.get('client') + logger.info(f"Using client '{client}' for profile '{profile_name}'.") + req_params = gen_policy.get('request_params') + return client, req_params + + # --- Rotation logic --- + now = time.time() + major_client = rotation_policy.get('major_client') + refresh_client = rotation_policy.get('refresh_client') + refresh_every = rotation_policy.get('refresh_every', {}) + + if not refresh_client or not refresh_every: + return major_client, rotation_policy.get('major_client_params') + + should_refresh = False + + # Check time-based refresh + refresh_minutes = refresh_every.get('minutes') + last_refresh_time = self.state['profile_last_refresh_time'].get(profile_name, 0) + if refresh_minutes and (now - last_refresh_time) > (refresh_minutes * 60): + should_refresh = True + + # Check request-count-based refresh + refresh_requests = refresh_every.get('requests') + request_count = self.state['profile_request_counts'].get(profile_name, 0) + if refresh_requests and request_count >= refresh_requests: + should_refresh = True + + if should_refresh: + logger.info(f"Profile '{profile_name}' is due for a refresh. Using refresh client '{refresh_client}'.") + self.state['profile_last_refresh_time'][profile_name] = now + self.state['profile_request_counts'][profile_name] = 0 # Reset counter + return refresh_client, rotation_policy.get('refresh_client_params') + else: + # Not refreshing, so increment request count for this profile + self.state['profile_request_counts'][profile_name] = request_count + 1 + return major_client, rotation_policy.get('major_client_params') + + def get_next_available_profile(self, policy): + """ + Finds or creates an available profile based on the dynamic cooldown policy. + Returns a profile name, or None if no profile is available. 
+ """ + with self.lock: + now = time.time() + settings = policy.get('settings', {}) + pm_policy = settings.get('profile_management') + + if not pm_policy: + return None + + prefix = pm_policy.get('prefix') + if not prefix: + logger.error("Profile management policy requires 'prefix'.") + return None + + # Determine and persist the suffix for this run to ensure profile names are stable + run_suffix = self.state.get('profile_run_suffix') + if not run_suffix: + suffix_config = pm_policy.get('suffix') + if suffix_config == 'auto': + run_suffix = datetime.now().strftime('%Y%m%d%H%M') + else: + run_suffix = suffix_config or '' + self.state['profile_run_suffix'] = run_suffix + + # Initialize pool size from policy if not already in state + if self.state.get('profile_pool_size', 0) == 0: + self.state['profile_pool_size'] = pm_policy.get('initial_pool_size', 1) + + max_reqs = pm_policy.get('max_requests_per_profile') + sleep_mins = pm_policy.get('sleep_minutes_on_exhaustion') + + # Loop until a profile is found or we decide we can't find one + while True: + # Try to find an existing, available profile + for i in range(self.state['profile_pool_size']): + profile_name = f"{prefix}_{run_suffix}_{i}" if run_suffix else f"{prefix}_{i}" + + # Check if sleeping + sleep_until = self.state['profile_cooldown_sleep_until'].get(profile_name, 0) + if now < sleep_until: + continue # Still sleeping + + # Check if it needs to be put to sleep + req_count = self.state['profile_cooldown_counts'].get(profile_name, 0) + if max_reqs and req_count >= max_reqs: + sleep_duration_seconds = (sleep_mins or 0) * 60 + self.state['profile_cooldown_sleep_until'][profile_name] = now + sleep_duration_seconds + self.state['profile_cooldown_counts'][profile_name] = 0 # Reset count for next time + logger.info(f"Profile '{profile_name}' reached request limit ({req_count}/{max_reqs}). Putting to sleep for {sleep_mins} minutes.") + continue # Now sleeping, try next profile + + # This profile is available + logger.info(f"Selected available profile '{profile_name}' (request count: {req_count}/{max_reqs if max_reqs else 'unlimited'}).") + return profile_name + + # If we get here, no existing profile was available + if pm_policy.get('auto_expand_pool'): + new_profile_index = self.state['profile_pool_size'] + self.state['profile_pool_size'] += 1 + profile_name = f"{prefix}_{run_suffix}_{new_profile_index}" if run_suffix else f"{prefix}_{new_profile_index}" + logger.info(f"Profile pool exhausted. Expanding pool to size {self.state['profile_pool_size']}. New profile: '{profile_name}'") + return profile_name + else: + # No available profiles and pool expansion is disabled + return None + + def get_or_rotate_worker_profile(self, worker_id, policy): + """ + Gets the current profile for a worker, rotating to a new generation if the lifetime limit is met. + This is used by the 'per_worker_with_rotation' profile mode. 
+ """ + with self.lock: + pm_policy = policy.get('settings', {}).get('profile_management', {}) + if not pm_policy: + logger.error("Profile mode 'per_worker_with_rotation' requires 'settings.profile_management' configuration in the policy.") + return f"error_profile_{worker_id}" + + prefix = pm_policy.get('prefix') + if not prefix: + logger.error("Profile management for 'per_worker_with_rotation' requires a 'prefix'.") + return f"error_profile_{worker_id}" + + max_reqs = pm_policy.get('max_requests_per_profile') + + generations = self.state.setdefault('worker_profile_generations', {}) + # worker_id is an int, but JSON keys must be strings + worker_id_str = str(worker_id) + current_gen = generations.get(worker_id_str, 0) + + profile_name = f"{prefix}_{worker_id}_{current_gen}" + + if not max_reqs: # No lifetime limit defined, so never rotate. + return profile_name + + req_count = self.state.get('profile_cooldown_counts', {}).get(profile_name, 0) + + if req_count >= max_reqs: + logger.info(f"Profile '{profile_name}' reached lifetime request limit ({req_count}/{max_reqs}). Rotating to new generation for worker {worker_id}.") + new_gen = current_gen + 1 + generations[worker_id_str] = new_gen + # The request counts for the old profile are implicitly left behind. + # The new profile will start with a count of 0. + profile_name = f"{prefix}_{worker_id}_{new_gen}" + + return profile_name + + def record_profile_request(self, profile_name): + """Increments the request counter for a profile for the cooldown policy.""" + with self.lock: + if not profile_name: + return + counts = self.state.setdefault('profile_cooldown_counts', {}) + counts[profile_name] = counts.get(profile_name, 0) + 1 + + def record_proxy_usage(self, proxy_url): + """Records a request timestamp for a given proxy URL for statistical purposes.""" + if not proxy_url: + return + with self.lock: + now = time.time() + # Use a prefix to avoid collisions with profile names or other keys + tracker_key = f"proxy_{proxy_url}" + self.state['rate_limit_trackers'].setdefault(tracker_key, []).append(now) + + def check_and_update_download_rate_limit(self, proxy_url, policy): + """Checks download rate limits. Returns True if allowed, False otherwise.""" + with self.lock: + now = time.time() + d_policy = policy.get('download_policy', {}) + rate_limits = d_policy.get('rate_limits', {}) + + # Check per-IP limit + ip_limit = rate_limits.get('per_ip') + if ip_limit: + tracker_key = 'download_per_ip' # Use a distinct key + max_req = ip_limit.get('max_requests') + period_min = ip_limit.get('per_minutes') + if max_req and period_min: + timestamps = self.state['rate_limit_trackers'].get(tracker_key, []) + timestamps = [ts for ts in timestamps if now - ts < period_min * 60] + if len(timestamps) >= max_req: + logger.warning("Per-IP download rate limit reached. Skipping task.") + return False + self.state['rate_limit_trackers'][tracker_key] = timestamps + + # Check per-proxy limit + proxy_limit = rate_limits.get('per_proxy') + if proxy_limit and proxy_url: + tracker_key = f"download_proxy_{proxy_url}" + max_req = proxy_limit.get('max_requests') + period_min = proxy_limit.get('per_minutes') + if max_req and period_min: + timestamps = self.state['rate_limit_trackers'].get(tracker_key, []) + timestamps = [ts for ts in timestamps if now - ts < period_min * 60] + if len(timestamps) >= max_req: + logger.warning(f"Per-proxy download rate limit for '{proxy_url}' reached. 
Skipping task.") + return False + self.state['rate_limit_trackers'][tracker_key] = timestamps + + # If all checks pass, record the new request timestamp for all relevant trackers + if ip_limit and ip_limit.get('max_requests'): + self.state['rate_limit_trackers'].setdefault('download_per_ip', []).append(now) + if proxy_limit and proxy_limit.get('max_requests') and proxy_url: + self.state['rate_limit_trackers'].setdefault(f"download_proxy_{proxy_url}", []).append(now) + + return True + + def wait_for_proxy_cooldown(self, proxy_url, policy): + """If a per-proxy sleep is defined, wait until the cooldown period has passed.""" + with self.lock: + d_policy = policy.get('download_policy', {}) + sleep_duration = d_policy.get('sleep_per_proxy_seconds', 0) + if not proxy_url or not sleep_duration > 0: + return + + last_finish = self.state.setdefault('proxy_last_finish_time', {}).get(proxy_url, 0) + elapsed = time.time() - last_finish + + if elapsed < sleep_duration: + time_to_sleep = sleep_duration - elapsed + logger.info(f"Proxy '{proxy_url}' was used recently. Sleeping for {time_to_sleep:.2f}s.") + # Interruptible sleep + sleep_end_time = time.time() + time_to_sleep + while time.time() < sleep_end_time: + if self.shutdown_event.is_set(): + logger.info("Shutdown requested during proxy cooldown sleep.") + break + time.sleep(0.2) + + def update_proxy_finish_time(self, proxy_url): + """Updates the last finish time for a proxy.""" + with self.lock: + if not proxy_url: + return + self.state.setdefault('proxy_last_finish_time', {})[proxy_url] = time.time() + + def print_summary(self, policy=None): + """Print a summary of the test run.""" + with self.lock: + # --- Cumulative Stats from State --- + now = time.time() + rate_trackers = self.state.get('rate_limit_trackers', {}) + if rate_trackers: + logger.info("\n--- Cumulative Rate Summary (All Runs, updated at end of run) ---") + logger.info("This shows the total number of requests/downloads over various time windows, including previous runs.") + + fetch_trackers = {k: v for k, v in rate_trackers.items() if not k.startswith('download_')} + download_trackers = {k: v for k, v in rate_trackers.items() if k.startswith('download_')} + + def print_tracker_stats(trackers, tracker_type): + if not trackers: + logger.info(f"No historical {tracker_type} trackers found.") + return + + logger.info(f"Historical {tracker_type} Trackers:") + for key, timestamps in sorted(trackers.items()): + windows = { + 'last 10 min': 600, 'last 60 min': 3600, + 'last 6 hours': 21600, 'last 24 hours': 86400 + } + rates_str_parts = [] + for name, seconds in windows.items(): + count = sum(1 for ts in timestamps if now - ts <= seconds) + rate_rpm = (count / seconds) * 60 if seconds > 0 else 0 + rates_str_parts.append(f"{count} in {name} ({rate_rpm:.2f}/min)") + + # Clean up key for display + display_key = key.replace('download_', '').replace('per_ip', 'all_proxies/ips') + logger.info(f" - Tracker '{display_key}': " + ", ".join(rates_str_parts)) + + print_tracker_stats(fetch_trackers, "Fetch Request") + print_tracker_stats(download_trackers, "Download Attempt") + + if not self.events: + logger.info("\nNo new events were recorded in this session.") + return + + duration = time.time() - self.start_time + fetch_events = [e for e in self.events if e.get('type') == 'fetch'] + batch_fetch_events = [e for e in self.events if e.get('type') == 'fetch_batch'] + download_events = [e for e in self.events if e.get('type') not in ['fetch', 'fetch_batch']] + + logger.info("\n--- Test Summary (This Run) 
---") + logger.info(f"Total duration: {duration:.2f} seconds") + + # Check for batch mode stats from state + if self.state.get('total_batches_processed', 0) > 0: + logger.info(f"Total batches processed (cumulative): {self.state['total_batches_processed']}") + logger.info(f" - Successful: {self.state['successful_batches']}") + logger.info(f" - Failed: {self.state['failed_batches']}") + logger.info(f"Total videos processed (cumulative): {self.state['total_videos_processed']}") + else: + logger.info(f"Total info.json requests (cumulative): {self.get_request_count()}") + + if policy: + logger.info("\n--- Test Configuration ---") + settings = policy.get('settings', {}) + d_policy = policy.get('download_policy', {}) + + if settings.get('urls_file'): + logger.info(f"URL source file: {settings['urls_file']}") + if settings.get('info_json_dir'): + logger.info(f"Info.json source dir: {settings['info_json_dir']}") + + if d_policy: + logger.info(f"Download formats: {d_policy.get('formats', 'N/A')}") + if d_policy.get('downloader'): + logger.info(f"Downloader: {d_policy.get('downloader')}") + if d_policy.get('downloader_args'): + logger.info(f"Downloader args: {d_policy.get('downloader_args')}") + if d_policy.get('pause_before_download_seconds'): + logger.info(f"Pause before download: {d_policy.get('pause_before_download_seconds')}s") + if d_policy.get('sleep_between_formats'): + sleep_cfg = d_policy.get('sleep_between_formats') + logger.info(f"Sleep between formats: {sleep_cfg.get('min_seconds', 0)}-{sleep_cfg.get('max_seconds', 0)}s") + + if fetch_events: + total_fetches = len(fetch_events) + successful_fetches = sum(1 for e in fetch_events if e['success']) + cancelled_fetches = sum(1 for e in fetch_events if e.get('error_type') == 'Cancelled') + failed_fetches = total_fetches - successful_fetches - cancelled_fetches + + logger.info("\n--- Fetch Summary (This Run) ---") + logger.info(f"Total info.json fetch attempts: {total_fetches}") + logger.info(f" - Successful: {successful_fetches}") + logger.info(f" - Failed: {failed_fetches}") + if cancelled_fetches > 0: + logger.info(f" - Cancelled: {cancelled_fetches}") + + completed_fetches = successful_fetches + failed_fetches + if completed_fetches > 0: + success_rate = (successful_fetches / completed_fetches) * 100 + logger.info(f"Success rate (of completed): {success_rate:.2f}%") + elif total_fetches > 0: + logger.info("Success rate: N/A (no tasks completed)") + + if duration > 1 and total_fetches > 0: + rpm = (total_fetches / duration) * 60 + logger.info(f"Actual fetch rate: {rpm:.2f} requests/minute") + + if failed_fetches > 0: + error_counts = collections.Counter( + e.get('error_type', 'Unknown') + for e in fetch_events if not e['success'] and e.get('error_type') != 'Cancelled' + ) + logger.info("Failure breakdown:") + for error_type, count in sorted(error_counts.items()): + logger.info(f" - {error_type}: {count}") + + profile_counts = collections.Counter(e.get('profile') for e in fetch_events if e.get('profile')) + if profile_counts: + logger.info("Requests per profile:") + for profile, count in sorted(profile_counts.items()): + logger.info(f" - {profile}: {count}") + + proxy_counts = collections.Counter(e.get('proxy_url') for e in fetch_events if e.get('proxy_url')) + if proxy_counts: + logger.info("Requests per proxy:") + for proxy, count in sorted(proxy_counts.items()): + logger.info(f" - {proxy}: {count}") + + if batch_fetch_events: + total_batches = len(batch_fetch_events) + successful_batches = sum(1 for e in batch_fetch_events if e['success']) 
+ failed_batches = total_batches - successful_batches + total_videos_this_run = sum(e.get('video_count', 0) for e in batch_fetch_events) + + logger.info("\n--- Batch Fetch Summary (This Run) ---") + logger.info(f"Total batches processed: {total_batches}") + logger.info(f"Total videos processed: {total_videos_this_run}") + logger.info(f" - Successful batches: {successful_batches}") + logger.info(f" - Failed batches: {failed_batches}") + + profile_counts = collections.Counter(e.get('profile') for e in batch_fetch_events if e.get('profile')) + if profile_counts: + logger.info("Batches per profile:") + for profile, count in sorted(profile_counts.items()): + logger.info(f" - {profile}: {count}") + + proxy_counts = collections.Counter(e.get('proxy_url') for e in batch_fetch_events if e.get('proxy_url')) + if proxy_counts: + logger.info("Batches per proxy:") + for proxy, count in sorted(proxy_counts.items()): + logger.info(f" - {proxy}: {count}") + + if download_events: + total_attempts = len(download_events) + successes = sum(1 for e in download_events if e['success']) + cancelled = sum(1 for e in download_events if e.get('error_type') == 'Cancelled') + failures = total_attempts - successes - cancelled + + # --- Profile Association for Download Events --- + download_profiles = [e.get('profile') for e in download_events] + + # For download_only mode, we might need to fall back to regex extraction + # if the profile wasn't passed down (e.g., no profile grouping). + profile_regex = None + if policy: + settings = policy.get('settings', {}) + if settings.get('mode') == 'download_only': + profile_regex = settings.get('profile_extraction_regex') + + if profile_regex: + for i, e in enumerate(download_events): + if not download_profiles[i]: # If profile wasn't set in the event + path = Path(e.get('path', '')) + match = re.search(profile_regex, path.name) + if match and match.groups(): + download_profiles[i] = match.group(1) + + # Replace any remaining Nones with 'unknown_profile' + download_profiles = [p or 'unknown_profile' for p in download_profiles] + + num_profiles_used = len(set(p for p in download_profiles if p != 'unknown_profile')) + + logger.info("\n--- Download Summary (This Run) ---") + if policy: + workers = policy.get('execution_control', {}).get('workers', 'N/A') + logger.info(f"Workers configured: {workers}") + + logger.info(f"Profiles utilized for downloads: {num_profiles_used}") + logger.info(f"Total download attempts: {total_attempts}") + logger.info(f" - Successful: {successes}") + logger.info(f" - Failed: {failures}") + if cancelled > 0: + logger.info(f" - Cancelled: {cancelled}") + + completed_downloads = successes + failures + if completed_downloads > 0: + success_rate = (successes / completed_downloads) * 100 + logger.info(f"Success rate (of completed): {success_rate:.2f}%") + elif total_attempts > 0: + logger.info("Success rate: N/A (no tasks completed)") + + duration_hours = duration / 3600.0 + if duration > 1 and total_attempts > 0: + dpm = (total_attempts / duration) * 60 + logger.info(f"Actual overall download rate: {dpm:.2f} attempts/minute") + + total_bytes = sum(e.get('downloaded_bytes', 0) for e in download_events if e['success']) + if total_bytes > 0: + logger.info(f"Total data downloaded: {sp_utils.format_size(total_bytes)}") + + if failures > 0: + error_counts = collections.Counter( + e.get('error_type', 'Unknown') + for e in download_events if not e['success'] and e.get('error_type') != 'Cancelled' + ) + logger.info("Failure breakdown:") + for error_type, count in 
sorted(error_counts.items()): + logger.info(f" - {error_type}: {count}") + + # Add profile to each download event for easier counting + for i, e in enumerate(download_events): + e['profile'] = download_profiles[i] + + profile_counts = collections.Counter(e.get('profile') for e in download_events if e.get('profile')) + if profile_counts: + logger.info("Downloads per profile:") + for profile, count in sorted(profile_counts.items()): + rate_per_hour = (count / duration_hours) if duration_hours > 0 else 0 + logger.info(f" - {profile}: {count} attempts (avg this run: {rate_per_hour:.2f}/hour)") + + proxy_counts = collections.Counter(e.get('proxy_url') for e in download_events if e.get('proxy_url')) + if proxy_counts: + logger.info("Downloads per proxy:") + for proxy, count in sorted(proxy_counts.items()): + rate_per_hour = (count / duration_hours) if duration_hours > 0 else 0 + logger.info(f" - {proxy}: {count} attempts (avg this run: {rate_per_hour:.2f}/hour)") + + logger.info("--------------------") diff --git a/ytops_client/stress_policy/utils.py b/ytops_client/stress_policy/utils.py new file mode 100644 index 0000000..a9f1932 --- /dev/null +++ b/ytops_client/stress_policy/utils.py @@ -0,0 +1,507 @@ +import collections.abc +import json +import logging +import os +import random +import re +import shlex +import sys +import time +from copy import deepcopy +from pathlib import Path +from urllib.parse import urlparse, parse_qs + +try: + import yaml +except ImportError: + print("PyYAML is not installed. Please install it with: pip install PyYAML", file=sys.stderr) + sys.exit(1) + +_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) +# This makes the project root the parent directory of 'ytops_client' +_PROJECT_ROOT = os.path.abspath(os.path.join(_SCRIPT_DIR, '..', '..')) + + +def get_video_id(url: str) -> str: + """Extracts a YouTube video ID from a URL.""" + match = re.search(r"v=([0-9A-Za-z_-]{11})", url) + if match: + return match.group(1) + match = re.search(r"youtu\.be\/([0-9A-Za-z_-]{11})", url) + if match: + return match.group(1) + if re.fullmatch(r'[0-9A-Za-z_-]{11}', url): + return url + return "unknown_video_id" + + +def get_display_name(path_or_url): + """Returns a clean name for logging, either a filename or a video ID.""" + if isinstance(path_or_url, Path): + return path_or_url.name + + path_str = str(path_or_url) + video_id = get_video_id(path_str) + if video_id != "unknown_video_id": + return video_id + + return Path(path_str).name + + +def format_size(b): + """Format size in bytes to human-readable string.""" + if b is None: + return 'N/A' + if b < 1024: + return f"{b}B" + elif b < 1024**2: + return f"{b/1024:.2f}KiB" + elif b < 1024**3: + return f"{b/1024**2:.2f}MiB" + else: + return f"{b/1024**3:.2f}GiB" + + +def flatten_dict(d, parent_key='', sep='.'): + """Flattens a nested dictionary.""" + items = {} + for k, v in d.items(): + new_key = parent_key + sep + k if parent_key else k + if isinstance(v, collections.abc.MutableMapping): + items.update(flatten_dict(v, new_key, sep=sep)) + else: + items[new_key] = v + return items + + +def print_policy_overrides(policy): + """Prints all policy values as a single-line of --set arguments.""" + # We don't want to include the 'name' key in the overrides. 
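
A small example of how flatten_dict (defined above) feeds the --set output, using two real policy keys with made-up values:

    nested = {"settings": {"mode": "fetch_only"}, "execution_control": {"workers": 4}}
    flat = flatten_dict(nested)
    # flat == {"settings.mode": "fetch_only", "execution_control.workers": 4}
    # print_policy_overrides() then emits:
    #   --set settings.mode=fetch_only --set execution_control.workers=4
    # and apply_overrides() parses the same dotted key paths back into the nested form.
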
+ policy_copy = deepcopy(policy) + policy_copy.pop('name', None) + + flat_policy = flatten_dict(policy_copy) + + set_args = [] + for key, value in sorted(flat_policy.items()): + if value is None: + value_str = 'null' + elif isinstance(value, bool): + value_str = str(value).lower() + elif isinstance(value, (list, dict)): + # Use compact JSON for lists/dicts + value_str = json.dumps(value, separators=(',', ':')) + else: + value_str = str(value) + + # Use shlex.quote to handle spaces and special characters safely + set_args.append(f"--set {shlex.quote(f'{key}={value_str}')}") + + print(' '.join(set_args)) + + +def _config_dict_to_flags_file_content(config_dict: dict) -> str: + """Converts a dictionary of yt-dlp options to a string for a config file.""" + config_lines = [] + for key, value in config_dict.items(): + flag = f'--{key.replace("_", "-")}' + if isinstance(value, bool): + if value: + config_lines.append(flag) + elif isinstance(value, list): + # Special case for --use-extractors which takes a comma-separated list + if key == 'use-extractors': + config_lines.append(flag) + config_lines.append(','.join(map(str, value))) + else: # Assume other lists mean repeated flags + for item in value: + config_lines.append(flag) + config_lines.append(str(item)) + elif isinstance(value, dict): # Primarily for extractor-args + for sub_key, sub_value in value.items(): + if isinstance(sub_value, str) and ';' in sub_value: + # Support user-friendly format: semicolon-separated values + items = [item.strip() for item in sub_value.split(';')] + for item in items: + if item: # Avoid empty strings + config_lines.append(flag) + config_lines.append(f"{sub_key}:{item}") + elif isinstance(sub_value, list): + for item in sub_value: + config_lines.append(flag) + config_lines.append(f"{sub_key}:{item}") + else: + config_lines.append(flag) + config_lines.append(f"{sub_key}:{sub_value}") + else: + config_lines.append(flag) + value_str = str(value) + # yt-dlp config files support quoting arguments. + # Let's quote any string that contains spaces to be safe. + if isinstance(value, str) and ' ' in value_str: + value_str = f'"{value_str}"' + config_lines.append(value_str) + return '\n'.join(config_lines) + + +def _config_dict_to_cli_flags(config_dict: dict) -> list: + """Converts a dictionary of yt-dlp options to a list of command-line arguments.""" + args = [] + for key, value in config_dict.items(): + flag = f'--{key.replace("_", "-")}' + if isinstance(value, bool): + if value: + args.append(flag) + elif isinstance(value, list): + if key == 'use-extractors': + args.append(flag) + args.append(','.join(map(str, value))) + else: + for item in value: + args.append(flag) + args.append(str(item)) + elif isinstance(value, dict): + for sub_key, sub_value in value.items(): + if isinstance(sub_value, str) and ';' in sub_value: + items = [item.strip() for item in sub_value.split(';')] + for item in items: + if item: + args.append(flag) + args.append(f"{sub_key}:{item}") + elif isinstance(sub_value, list): + for item in sub_value: + args.append(flag) + args.append(f"{sub_key}:{item}") + else: + args.append(flag) + args.append(f"{sub_key}:{sub_value}") + else: + args.append(flag) + args.append(str(value)) + return args + + +def _parse_config_file_to_cli_args(content: str) -> list: + """ + Parses yt-dlp config file content into a list of command-line arguments. + This is a best-effort parser for logging purposes. 
+ """ + args = [] + lines = content.splitlines() + for line in lines: + line = line.strip() + if not line or line.startswith('#'): + continue + + # yt-dlp config files can have options and values on separate lines. + # This simple parser assumes one argument per line (e.g., '--proxy', 'http://...'). + # shlex.split is good for handling quoted arguments on a single line. + try: + parts = shlex.split(line) + args.extend(parts) + except ValueError: + # Fallback for unterminated quotes or other shlex errors + args.extend(line.split()) + return args + + +def check_url_expiry(url: str, time_shift_minutes: int): + """ + Checks a single URL for expiration, considering a time shift. + Returns a tuple: (status, time_left_seconds) + status can be 'valid', 'expired', or 'no_expiry_info'. + A URL is considered 'expired' if it has expired or will expire within the time_shift_minutes. + """ + now = time.time() + parsed = urlparse(url) + query_params = parse_qs(parsed.query) + expire_ts_str = query_params.get('expire', [None])[0] + + if not expire_ts_str or not expire_ts_str.isdigit(): + return 'no_expiry_info', float('inf') + + expire_ts = int(expire_ts_str) + time_left = expire_ts - now + + if time_left <= time_shift_minutes * 60: + return 'expired', time_left + + return 'valid', time_left + + +def generate_user_agent_from_policy(policy): + """ + Generates a User-Agent string based on settings in the policy. + Checks 'direct_docker_cli_policy' and 'direct_batch_cli_policy'. + Falls back to a default if no policy is provided. + """ + # Check both possible policy keys for the settings. + direct_policy = policy.get('direct_docker_cli_policy', {}) or policy.get('direct_batch_cli_policy', {}) + template = direct_policy.get('user_agent_template') + version_range = direct_policy.get('user_agent_version_range') + + if template and version_range and isinstance(version_range, list) and len(version_range) == 2: + major_version = random.randint(version_range[0], version_range[1]) + return template.format(major_version=major_version) + + # Fallback to a generic UA if policy is not configured + return 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36' + + +def update_dict(d, u): + """Recursively update a dictionary.""" + for k, v in u.items(): + if isinstance(v, collections.abc.Mapping): + d[k] = update_dict(d.get(k, {}), v) + else: + d[k] = v + return d + + +def load_policy(policy_file, policy_name=None): + """Load a policy from a YAML file.""" + logger = logging.getLogger(__name__) + try: + with open(policy_file, 'r', encoding='utf-8') as f: + # If a policy name is given, look for that specific document + if policy_name: + docs = list(yaml.safe_load_all(f)) + for doc in docs: + if isinstance(doc, dict) and doc.get('name') == policy_name: + return doc + raise ValueError(f"Policy '{policy_name}' not found in {policy_file}") + # Otherwise, load the first document + return yaml.safe_load(f) + except (IOError, yaml.YAMLError, ValueError) as e: + logger.error(f"Failed to load policy file {policy_file}: {e}") + sys.exit(1) + + +def apply_overrides(policy, overrides): + """Apply command-line overrides to the policy.""" + logger = logging.getLogger(__name__) + for override in overrides: + try: + key, value = override.split('=', 1) + keys = key.split('.') + + # Try to parse as JSON/YAML if it looks like a list or dict, otherwise treat as scalar + if (value.startswith('[') and value.endswith(']')) or \ + (value.startswith('{') and value.endswith('}')): + try: + 
value = yaml.safe_load(value) + except yaml.YAMLError: + logger.warning(f"Could not parse override value '{value}' as YAML. Treating as a string.") + else: + # Try to auto-convert scalar value type + if value.lower() == 'true': + value = True + elif value.lower() == 'false': + value = False + elif value.lower() == 'null': + value = None + else: + try: + value = int(value) + except ValueError: + try: + value = float(value) + except ValueError: + pass # Keep as string + + d = policy + for k in keys[:-1]: + d = d.setdefault(k, {}) + d[keys[-1]] = value + except ValueError: + logger.error(f"Invalid override format: '{override}'. Use 'key.subkey=value'.") + sys.exit(1) + return policy + + +def display_effective_policy(policy, name, sources=None, profile_names=None, original_workers_setting=None): + """Prints a human-readable summary of the effective policy.""" + logger = logging.getLogger(__name__) + logger.info(f"--- Effective Policy: {name} ---") + settings = policy.get('settings', {}) + exec_control = policy.get('execution_control', {}) + orchestration_mode = settings.get('orchestration_mode') + + logger.info(f"Mode: {settings.get('mode', 'full_stack')}") + if profile_names: + num_profiles = len(profile_names) + logger.info(f"Profiles found: {num_profiles}") + if num_profiles > 0: + # Sort profiles for consistent display, show top 10 + sorted_profiles = sorted(profile_names) + profiles_to_show = sorted_profiles[:10] + logger.info(f" (e.g., {', '.join(profiles_to_show)}{'...' if num_profiles > 10 else ''})") + + workers_display = str(exec_control.get('workers', 1)) + if original_workers_setting == 'auto': + workers_display = f"auto (calculated: {workers_display})" + logger.info(f"Workers: {workers_display}") + + sleep_cfg = exec_control.get('sleep_between_tasks', {}) + sleep_min = sleep_cfg.get('min_seconds') + if sleep_min is not None: + sleep_max = sleep_cfg.get('max_seconds') + if sleep_max is None: + sleep_max = sleep_min + + if sleep_max < sleep_min: + logger.info(f"Sleep between tasks (per worker): {sleep_max}s (fixed; max < min)") + elif sleep_max > sleep_min: + logger.info(f"Sleep between tasks (per worker): {sleep_min}-{sleep_max}s (random)") + else: + logger.info(f"Sleep between tasks (per worker): {sleep_min}s") + + run_until = exec_control.get('run_until', {}) + run_conditions = [] + if 'minutes' in run_until: + run_conditions.append(f"for {run_until['minutes']} minutes") + if 'requests' in run_until: + run_conditions.append(f"until {run_until['requests']} total requests") + if 'cycles' in run_until: + run_conditions.append(f"for {run_until['cycles']} cycles") + + if run_conditions: + logger.info(f"Run condition: Stop after running {' or '.join(run_conditions)}.") + if 'minutes' in run_until and 'cycles' not in run_until: + logger.info("Will continuously cycle through sources until time limit is reached.") + elif orchestration_mode in ['direct_batch_cli', 'direct_download_cli', 'direct_docker_cli']: + logger.info("Run condition: Stop after all source URLs/tasks have been processed once.") + else: + logger.warning("WARNING: No 'run_until' condition is set. 
This test will run forever unless stopped manually.") + logger.info("Run condition: No stop condition defined, will run indefinitely (until Ctrl+C).") + + # --- Rate Calculation --- + if sources: + workers = exec_control.get('workers', 1) + num_sources = len(profile_names) if profile_names else len(sources) + + min_sleep = sleep_cfg.get('min_seconds', 0) + max_sleep = sleep_cfg.get('max_seconds') or min_sleep + avg_sleep_per_task = (min_sleep + max_sleep) / 2 + + # Assume an average task duration. This is a major assumption. + mode = settings.get('mode', 'full_stack') + assumptions = exec_control.get('assumptions', {}) + + assumed_fetch_duration = 0 + if mode in ['full_stack', 'fetch_only']: + assumed_fetch_duration = assumptions.get('fetch_task_duration', 12 if mode == 'full_stack' else 3) + + assumed_download_duration = 0 + if mode in ['full_stack', 'download_only']: + # This assumes the total time to download all formats for a single source. + assumed_download_duration = assumptions.get('download_task_duration', 60) + + total_assumed_task_duration = assumed_fetch_duration + assumed_download_duration + + if workers > 0 and total_assumed_task_duration > 0: + total_time_per_task = total_assumed_task_duration + avg_sleep_per_task + tasks_per_minute_per_worker = 60 / total_time_per_task + total_tasks_per_minute = tasks_per_minute_per_worker * workers + + logger.info("--- Rate Estimation ---") + logger.info(f"Source count: {num_sources}") + if mode in ['full_stack', 'fetch_only']: + logger.info(f"Est. fetch time per source: {assumed_fetch_duration}s (override via execution_control.assumptions.fetch_task_duration)") + if mode in ['full_stack', 'download_only']: + logger.info(f"Est. download time per source: {assumed_download_duration}s (override via execution_control.assumptions.download_task_duration)") + logger.info(" (Note: This assumes total time for all formats per source)") + + logger.info(f"Est. sleep per task: {avg_sleep_per_task:.1f}s") + logger.info(f"==> Expected task rate: ~{total_tasks_per_minute:.2f} tasks/minute ({workers} workers * {tasks_per_minute_per_worker:.2f} tasks/min/worker)") + + target_rate_cfg = exec_control.get('target_rate', {}) + target_reqs = target_rate_cfg.get('requests') + target_mins = target_rate_cfg.get('per_minutes') + if target_reqs and target_mins: + target_rpm = target_reqs / target_mins + logger.info(f"Target rate: {target_rpm:.2f} tasks/minute") + if total_tasks_per_minute < target_rpm * 0.8: + logger.warning("Warning: Expected rate is significantly lower than target rate.") + logger.warning("Consider increasing workers, reducing sleep, or checking task performance.") + + logger.info("---------------------------------") + time.sleep(2) # Give user time to read + + +def list_policies(): + """Scans the policies directory and prints a list of available policies.""" + policies_dir = os.path.join(_PROJECT_ROOT, 'policies') + + if not os.path.isdir(policies_dir): + print(f"Error: Policies directory not found at '{policies_dir}'", file=sys.stderr) + return 1 + + print("Available Policies:") + print("=" * 20) + + policy_files = sorted(Path(policies_dir).glob('*.yaml')) + if not policy_files: + print("No policy files (.yaml) found.") + return 0 + + for policy_file in policy_files: + print(f"\n--- File: {policy_file.relative_to(_PROJECT_ROOT)} ---") + try: + with open(policy_file, 'r', encoding='utf-8') as f: + content = f.read() + + # Split into documents. The separator is a line that is exactly '---'. 
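
Policy files may hold several named policies as separate YAML documents; a minimal sketch of the selection that load_policy and list_policies perform (the policy names here are invented):

    import yaml

    multi_doc = "name: fetch_smoke\nsettings:\n  mode: fetch_only\n---\nname: full_default\nsettings:\n  mode: full_stack\n"
    docs = [d for d in yaml.safe_load_all(multi_doc) if isinstance(d, dict)]
    chosen = next(d for d in docs if d.get("name") == "fetch_smoke")
    # Equivalent CLI usage: --policy <file> --policy-name fetch_smoke
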
+ documents = re.split(r'^\-\-\-$', content, flags=re.MULTILINE) + + found_any_in_file = False + for doc in documents: + doc = doc.strip() + if not doc: + continue + + lines = doc.split('\n') + policy_name = None + description_lines = [] + + # Find name and description + for i, line in enumerate(lines): + if line.strip().startswith('name:'): + policy_name = line.split(':', 1)[1].strip() + + # Look backwards for comments + j = i - 1 + current_desc_block = [] + while j >= 0 and lines[j].strip().startswith('#'): + comment = lines[j].strip().lstrip('#').strip() + current_desc_block.insert(0, comment) + j -= 1 + + if current_desc_block: + description_lines = current_desc_block + break + + if policy_name: + found_any_in_file = True + print(f" - Name: {policy_name}") + if description_lines: + # Heuristic to clean up "Policy: " prefix + if description_lines[0].lower().startswith('policy:'): + description_lines[0] = description_lines[0][len('policy:'):].strip() + + print(f" Description: {description_lines[0]}") + for desc_line in description_lines[1:]: + print(f" {desc_line}") + else: + print(" Description: (No description found)") + + relative_path = policy_file.relative_to(_PROJECT_ROOT) + print(f" Usage: --policy {relative_path} --policy-name {policy_name}") + + if not found_any_in_file: + print(" (No named policies found in this file)") + + except Exception as e: + print(f" Error parsing {policy_file.name}: {e}") + + return 0 diff --git a/ytops_client/stress_policy/workers.py b/ytops_client/stress_policy/workers.py new file mode 100644 index 0000000..42cb176 --- /dev/null +++ b/ytops_client/stress_policy/workers.py @@ -0,0 +1,2615 @@ +import collections +import json +import logging +import os +import random +import re +import shlex +import sys +import tempfile +import shutil +import threading +import time +from copy import deepcopy +from datetime import datetime, timezone +from pathlib import Path + +from . import utils as sp_utils +from .process_runners import run_command, run_docker_container, get_worker_id +from ..profile_manager_tool import ProfileManager + + +logger = logging.getLogger(__name__) + +# --- Auth Profile Manager Cache --- +# This is a cache to hold a ProfileManager instance for the auth simulation. +# It's needed so that download workers can decrement the correct pending download counter. +_auth_manager_cache = {} +_auth_manager_lock = threading.Lock() + +def _get_auth_manager(current_manager, auth_env: str): + """ + Gets a ProfileManager instance for a specific auth simulation environment. + It uses the auth_env provided from the info.json metadata. 
+ """ + with _auth_manager_lock: + if not auth_env: + return None + + if auth_env in _auth_manager_cache: + return _auth_manager_cache[auth_env] + + logger.info(f"Creating new ProfileManager for auth simulation env: '{auth_env}'") + try: + # Re-use connection settings from the current manager + redis_conn_kwargs = current_manager.redis.connection_pool.connection_kwargs + auth_key_prefix = f"{auth_env}_profile_mgmt_" + + auth_manager = ProfileManager( + redis_host=redis_conn_kwargs.get('host'), + redis_port=redis_conn_kwargs.get('port'), + redis_password=redis_conn_kwargs.get('password'), + key_prefix=auth_key_prefix + ) + _auth_manager_cache[auth_env] = auth_manager + return auth_manager + except Exception as e: + logger.error(f"Failed to create ProfileManager for auth env '{auth_env}': {e}") + return None + +def _run_download_logic(source, info_json_content, policy, state_manager, args, running_processes, process_lock, profile_name=None, profile_manager_instance=None): + """Shared download logic for a single info.json.""" + proxy_url = None + if info_json_content: + try: + info_data = json.loads(info_json_content) + proxy_url = info_data.get('_proxy_url') + except (json.JSONDecodeError, AttributeError): + logger.warning(f"[{sp_utils.get_display_name(source)}] Could not parse info.json to get proxy for download controls.") + + d_policy = policy.get('download_policy', {}) + temp_download_dir = None + local_policy = policy + + if d_policy.get('output_to_airflow_ready_dir'): + local_policy = deepcopy(policy) + temp_download_dir = tempfile.mkdtemp(prefix='stress-dl-video-') + # This modification is safe because it's on a deep copy + local_policy.setdefault('download_policy', {})['output_dir'] = temp_download_dir + logger.info(f"[{sp_utils.get_display_name(source)}] Using temporary download directory: {temp_download_dir}") + + try: + if not state_manager.check_and_update_download_rate_limit(proxy_url, local_policy): + return [] + + state_manager.wait_for_proxy_cooldown(proxy_url, local_policy) + results = process_info_json_cycle(source, info_json_content, local_policy, state_manager, args, running_processes, process_lock, proxy_url=proxy_url, profile_name=profile_name, +profile_manager_instance=profile_manager_instance) + state_manager.update_proxy_finish_time(proxy_url) + + # --- Post-download logic for Airflow dir --- + if d_policy.get('output_to_airflow_ready_dir'): + for result in results: + if result.get('success') and result.get('downloaded_filepath'): + try: + video_id = result.get('video_id') + if not video_id: + # Fallback: extract from info.json content + try: + info_data = json.loads(info_json_content) + video_id = info_data.get('id') + except (json.JSONDecodeError, AttributeError): + video_id = None + + if not video_id: + logger.error(f"[{sp_utils.get_display_name(source)}] Could not find video ID in result for moving file.") + continue + + now = datetime.now() + rounded_minute = (now.minute // 10) * 10 + timestamp_str = now.strftime('%Y%m%dT%H') + f"{rounded_minute:02d}" + + base_path = d_policy.get('airflow_ready_dir_base_path', 'downloadfiles/videos/ready') + if not os.path.isabs(base_path): + base_path = os.path.join(sp_utils._PROJECT_ROOT, base_path) + final_dir_base = os.path.join(base_path, timestamp_str) + final_dir_path = os.path.join(final_dir_base, video_id) + + os.makedirs(final_dir_path, exist_ok=True) + + downloaded_file = result['downloaded_filepath'] + if os.path.exists(downloaded_file): + shutil.move(downloaded_file, final_dir_path) + 
logger.info(f"[{sp_utils.get_display_name(source)}] Moved media file to {final_dir_path}") + + # The source is the path to the task/info.json file. + if isinstance(source, Path) and source.exists(): + new_info_json_name = f"info_{video_id}.json" + dest_info_json_path = os.path.join(final_dir_path, new_info_json_name) + shutil.copy(source, dest_info_json_path) + logger.info(f"[{sp_utils.get_display_name(source)}] Copied info.json to {dest_info_json_path}") + except Exception as e: + logger.error(f"[{sp_utils.get_display_name(source)}] Failed to move downloaded file to Airflow ready directory: {e}") + + return results + finally: + if temp_download_dir and os.path.exists(temp_download_dir): + shutil.rmtree(temp_download_dir) + logger.info(f"[{sp_utils.get_display_name(source)}] Cleaned up temporary directory: {temp_download_dir}") + + +def process_profile_task(profile_name, file_list, policy, state_manager, cycle_num, args, running_processes, process_lock, profile_manager_instance=None): + """Worker task for a profile, processing its files sequentially.""" + logger.info(f"Worker {get_worker_id()} starting task for profile '{profile_name}' with {len(file_list)} files.") + all_results = [] + for i, file_path in enumerate(file_list): + if state_manager.shutdown_event.is_set(): + logger.info(f"Shutdown requested, stopping task for profile '{profile_name}'.") + break + + try: + with open(file_path, 'r', encoding='utf-8') as f: + info_json_content = f.read() + except (IOError, FileNotFoundError) as e: + logger.error(f"[{sp_utils.get_display_name(file_path)}] Could not read info.json file: {e}") + continue # Skip this file + + results_for_file = _run_download_logic(file_path, info_json_content, policy, state_manager, args, running_processes, process_lock, profile_name=profile_name, profile_manager_instance=profile_manager_instance) + all_results.extend(results_for_file) + + # Mark file as processed if configured. This works for both 'once' and 'continuous' modes. 
+ settings = policy.get('settings', {}) + if settings.get('directory_scan_mode') == 'continuous': + state_manager.mark_file_as_processed(file_path) + + if settings.get('mark_processed_files'): + try: + timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') + new_path = file_path.parent / f"{file_path.name}.{timestamp}.processed" + file_path.rename(new_path) + logger.info(f"Marked '{file_path.name}' as processed by renaming to '{new_path.name}'") + except (IOError, OSError) as e: + logger.error(f"Failed to rename processed file '{file_path.name}': {e}") + + # Check for stop conditions after processing each file + should_stop_profile = False + for result in results_for_file: + if not result['success']: + s_conditions = policy.get('stop_conditions', {}) + if s_conditions.get('on_failure') or \ + (s_conditions.get('on_http_403') and result['error_type'] == 'HTTP 403') or \ + (s_conditions.get('on_timeout') and result['error_type'] == 'Timeout'): + logger.info(f"Stopping further processing for profile '{profile_name}' due to failure.") + should_stop_profile = True + break + if should_stop_profile: + break + + # Apply sleep between tasks for this profile + if i < len(file_list) - 1: + exec_control = policy.get('execution_control', {}) + sleep_cfg = exec_control.get('sleep_between_tasks', {}) + sleep_min = sleep_cfg.get('min_seconds', 0) + sleep_max_val = sleep_cfg.get('max_seconds') + + if sleep_min > 0 or sleep_max_val is not None: + sleep_max = sleep_min if sleep_max_val is None else sleep_max_val + + sleep_duration = 0 + if sleep_max < sleep_min: + logger.warning(f"sleep_between_tasks: max_seconds ({sleep_max}s) is less than min_seconds ({sleep_min}s). Using max_seconds as fixed sleep duration.") + sleep_duration = sleep_max + elif sleep_max > sleep_min: + sleep_duration = random.uniform(sleep_min, sleep_max) + else: # equal + sleep_duration = sleep_min + + if sleep_duration > 0: + logger.debug(f"Profile '{profile_name}' sleeping for {sleep_duration:.2f}s before next file.") + # Interruptible sleep + sleep_end_time = time.time() + sleep_duration + while time.time() < sleep_end_time: + if state_manager.shutdown_event.is_set(): + break + time.sleep(0.2) + + return all_results + + +def run_download_worker(info_json_path, info_json_content, format_to_download, policy, args, running_processes, process_lock, state_manager, profile_name=None): + """ + Performs a single download attempt. Designed to be run in a worker thread. + """ + worker_id = get_worker_id() + display_name = sp_utils.get_display_name(info_json_path) + profile_log_part = f" [Profile: {profile_name}]" if profile_name else "" + log_prefix = f"[Worker {worker_id}]{profile_log_part} [{display_name} @ {format_to_download}]" + + download_policy = policy.get('download_policy', {}) + settings = policy.get('settings', {}) + downloader = download_policy.get('downloader') + + # Get script command from settings, with fallback to download_policy for old format. + script_cmd_str = settings.get('download_script') + if not script_cmd_str: + script_cmd_str = download_policy.get('script') + + if script_cmd_str: + download_cmd = shlex.split(script_cmd_str) + elif downloader == 'aria2c_rpc': + download_cmd = [sys.executable, '-m', 'ytops_client.cli', 'download', 'aria-rpc'] + elif downloader == 'native-cli': + download_cmd = [sys.executable, '-m', 'ytops_client.cli', 'download', 'cli'] + else: + # Default to the new native-py downloader if downloader is 'native-py' or not specified. 
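
The sleep_between_tasks handling a few lines above follows a slightly unusual rule (a max below min is treated as a fixed sleep); a condensed sketch of the same decision, for clarity:

    import random

    def pick_sleep(min_seconds, max_seconds=None):
        """Mirrors the sleep_between_tasks logic above (illustrative helper, not used by the code)."""
        hi = min_seconds if max_seconds is None else max_seconds
        if hi < min_seconds:
            return hi                                   # max < min: use max as a fixed duration
        if hi > min_seconds:
            return random.uniform(min_seconds, hi)      # random sleep in [min, max]
        return min_seconds                              # equal: fixed duration
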
+ download_cmd = [sys.executable, '-m', 'ytops_client.cli', 'download', 'py'] + + download_cmd.extend(['-f', format_to_download]) + + if downloader == 'aria2c_rpc': + if download_policy.get('aria_host'): + download_cmd.extend(['--aria-host', str(download_policy['aria_host'])]) + if download_policy.get('aria_port'): + download_cmd.extend(['--aria-port', str(download_policy['aria_port'])]) + if download_policy.get('aria_secret'): + download_cmd.extend(['--aria-secret', str(download_policy['aria_secret'])]) + if download_policy.get('output_dir'): + download_cmd.extend(['--output-dir', str(download_policy['output_dir'])]) + if download_policy.get('aria_remote_dir'): + download_cmd.extend(['--remote-dir', str(download_policy['aria_remote_dir'])]) + if download_policy.get('aria_fragments_dir'): + download_cmd.extend(['--fragments-dir', str(download_policy['aria_fragments_dir'])]) + # For stress testing, waiting is the desired default to get a success/fail result. + # Allow disabling it by explicitly setting aria_wait: false in the policy. + if download_policy.get('aria_wait', True): + download_cmd.append('--wait') + + if download_policy.get('auto_merge_fragments'): + download_cmd.append('--auto-merge-fragments') + if download_policy.get('remove_fragments_after_merge'): + download_cmd.append('--remove-fragments-after-merge') + if download_policy.get('cleanup'): + download_cmd.append('--cleanup') + if download_policy.get('purge_on_complete'): + download_cmd.append('--purge-on-complete') + + downloader_args = download_policy.get('downloader_args') + proxy = download_policy.get('proxy') + if proxy: + # Note: proxy_rename is not supported for aria2c_rpc mode. + proxy_arg = f"--all-proxy {shlex.quote(str(proxy))}" + if downloader_args: + downloader_args = f"{downloader_args} {proxy_arg}" + else: + downloader_args = proxy_arg + + if downloader_args: + # For aria2c_rpc, the downloader_args value is passed directly to the script's --downloader-args option. + download_cmd.extend(['--downloader-args', downloader_args]) + elif downloader == 'native-cli': + # This is the logic for the legacy download_tool.py (yt-dlp CLI wrapper). + pause_seconds = download_policy.get('pause_before_download_seconds') + if pause_seconds and isinstance(pause_seconds, (int, float)) and pause_seconds > 0: + download_cmd.extend(['--pause', str(pause_seconds)]) + + if download_policy.get('continue_downloads'): + download_cmd.append('--download-continue') + + # Add proxy if specified directly in the policy + proxy = download_policy.get('proxy') + if proxy: + download_cmd.extend(['--proxy', str(proxy)]) + + proxy_rename = download_policy.get('proxy_rename') + if proxy_rename: + download_cmd.extend(['--proxy-rename', str(proxy_rename)]) + + extra_args = download_policy.get('extra_args') + if extra_args: + download_cmd.extend(shlex.split(extra_args)) + + # Note: 'downloader' here refers to yt-dlp's internal downloader, not our script. + # The policy key 'external_downloader' is more clear, but we support 'downloader' for backward compatibility. 
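
A compact sketch of the download_policy keys consumed by the aria2c_rpc branch above; every value shown is a placeholder:

    download_policy_example = {
        "downloader": "aria2c_rpc",
        "aria_host": "127.0.0.1",                 # placeholder
        "aria_port": 6800,                        # placeholder
        "aria_secret": "change-me",               # placeholder
        "output_dir": "/data/downloads",          # placeholder
        "aria_remote_dir": "/data/aria",          # placeholder
        "aria_fragments_dir": "/data/fragments",  # placeholder
        "aria_wait": True,                        # default: wait for a success/fail result
        "auto_merge_fragments": True,
        "remove_fragments_after_merge": True,
        "cleanup": False,
        "purge_on_complete": False,
        "downloader_args": "",
        "proxy": None,
    }
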
+ ext_downloader = download_policy.get('external_downloader') or download_policy.get('downloader') + if ext_downloader and ext_downloader not in ['native-cli', 'native-py', 'aria2c_rpc']: + download_cmd.extend(['--downloader', str(ext_downloader)]) + + downloader_args = download_policy.get('downloader_args') + if downloader_args: + download_cmd.extend(['--downloader-args', str(downloader_args)]) + + if download_policy.get('merge_output_format'): + download_cmd.extend(['--merge-output-format', str(download_policy['merge_output_format'])]) + + if download_policy.get('merge_output_format'): + download_cmd.extend(['--merge-output-format', str(download_policy['merge_output_format'])]) + + if download_policy.get('cleanup'): + download_cmd.append('--cleanup') + else: + # This is the default logic for the new native-py downloader. + if download_policy.get('output_to_buffer'): + download_cmd.append('--output-buffer') + else: + # --output-dir is only relevant if not outputting to buffer. + if download_policy.get('output_dir'): + download_cmd.extend(['--output-dir', str(download_policy['output_dir'])]) + + if download_policy.get('config'): + download_cmd.extend(['--config', str(download_policy['config'])]) + + if download_policy.get('temp_path'): + download_cmd.extend(['--temp-path', str(download_policy['temp_path'])]) + if download_policy.get('continue_downloads'): + download_cmd.append('--download-continue') + + pause_seconds = download_policy.get('pause_before_download_seconds') + if pause_seconds and isinstance(pause_seconds, (int, float)) and pause_seconds > 0: + download_cmd.extend(['--pause', str(pause_seconds)]) + + proxy = download_policy.get('proxy') + if proxy: + download_cmd.extend(['--proxy', str(proxy)]) + + proxy_rename = download_policy.get('proxy_rename') + if proxy_rename: + download_cmd.extend(['--proxy-rename', str(proxy_rename)]) + + # The 'extra_args' from the policy are for the download script itself, not for yt-dlp. + # We need to split them and add them to the command. + extra_args = download_policy.get('extra_args') + if extra_args: + download_cmd.extend(shlex.split(extra_args)) + + # Pass through downloader settings for yt-dlp to use + # e.g. 
to tell yt-dlp to use aria2c as its backend + ext_downloader = download_policy.get('external_downloader') + if ext_downloader: + download_cmd.extend(['--downloader', str(ext_downloader)]) + + downloader_args = download_policy.get('downloader_args') + if downloader_args: + download_cmd.extend(['--downloader-args', str(downloader_args)]) + + if args.dummy: + # Create a copy to add the info.json path for logging, without modifying the original + log_cmd = list(download_cmd) + if isinstance(info_json_path, Path) and info_json_path.exists(): + log_cmd.extend(['--load-info-json', str(info_json_path)]) + else: + log_cmd.extend(['--load-info-json', '']) + + logger.info(f"{log_prefix} Dummy mode: simulating download...") + logger.info(f"{log_prefix} Dummy mode: Would run command: {' '.join(shlex.quote(s) for s in log_cmd)}") + + dummy_settings = policy.get('settings', {}).get('dummy_simulation_settings', {}) + min_seconds = dummy_settings.get('download_min_seconds', 1.0) + max_seconds = dummy_settings.get('download_max_seconds', 3.0) + failure_rate = dummy_settings.get('download_failure_rate', 0.0) + skipped_rate = dummy_settings.get('download_skipped_failure_rate', 0.0) + + sleep_duration = random.uniform(min_seconds, max_seconds) + logger.info(f"{log_prefix} Dummy mode: simulating download for {sleep_duration:.2f}s (from policy range {min_seconds}-{max_seconds}s).") + time.sleep(sleep_duration) # Simulate work + + rand_val = random.random() + should_fail_skipped = rand_val < skipped_rate + should_fail_fatal = not should_fail_skipped and rand_val < (skipped_rate + failure_rate) + + if should_fail_skipped: + logger.warning(f"{log_prefix} Dummy mode: Injecting simulated skipped download failure.") + return { + 'type': 'download', + 'path': str(info_json_path), + 'format': format_to_download, + 'success': False, + 'error_type': 'DummySkippedFailure', + 'details': 'FAIL (Dummy mode, skipped)', + 'downloaded_bytes': 0, + 'profile': profile_name, + 'downloaded_filepath': None, + 'is_tolerated_error': True + } + + if should_fail_fatal: + logger.warning(f"{log_prefix} Dummy mode: Injecting simulated fatal download failure.") + return { + 'type': 'download', + 'path': str(info_json_path), + 'format': format_to_download, + 'success': False, + 'error_type': 'DummyFailure', + 'details': 'FAIL (Dummy mode, fatal)', + 'downloaded_bytes': 0, + 'profile': profile_name, + 'downloaded_filepath': None + } + + downloaded_filepath = f'/dev/null/{display_name}.mp4' + if download_policy.get('output_to_airflow_ready_dir'): + output_dir = download_policy.get('output_dir') + if output_dir and os.path.isdir(output_dir): + try: + dummy_path_obj = Path(output_dir) / f"{display_name}.mp4" + dummy_path_obj.touch() + downloaded_filepath = str(dummy_path_obj) + logger.info(f"{log_prefix} Dummy mode: created dummy file for Airflow move: {downloaded_filepath}") + except OSError as e: + logger.error(f"{log_prefix} Dummy mode: failed to create dummy file in '{output_dir}': {e}") + + return { + 'type': 'download', + 'path': str(info_json_path), + 'format': format_to_download, + 'success': True, + 'error_type': None, + 'details': 'OK (Dummy mode)', + 'downloaded_bytes': random.randint(100000, 5000000), + 'profile': profile_name, + 'downloaded_filepath': downloaded_filepath + } + + logger.info(f"{log_prefix} Kicking off download process...") + + temp_info_file_path = None + try: + if isinstance(info_json_path, Path) and info_json_path.exists(): + # The info.json is already in a file, pass its path directly. 
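
The dummy-mode failure injection above partitions a single random draw rather than flipping two independent coins; a sketch with illustrative rates (mirroring download_skipped_failure_rate and download_failure_rate):

    import random

    def simulate_outcome(skipped_rate=0.1, failure_rate=0.2):
        r = random.random()
        if r < skipped_rate:
            return "skipped"    # tolerated failure band, here [0.0, 0.1)
        if r < skipped_rate + failure_rate:
            return "fatal"      # fatal failure band, here [0.1, 0.3)
        return "success"        # remaining band, here [0.3, 1.0)
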
+ download_cmd.extend(['--load-info-json', str(info_json_path)]) + else: + # The info.json content is in memory, so write it to a temporary file. + import tempfile + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.json', encoding='utf-8') as temp_f: + temp_f.write(info_json_content) + temp_info_file_path = temp_f.name + download_cmd.extend(['--load-info-json', temp_info_file_path]) + + cmd_str_for_log = ' '.join(shlex.quote(s) for s in download_cmd) + logger.info(f"{log_prefix} Running download command: {cmd_str_for_log}") + output_to_buffer = download_policy.get('output_to_buffer', False) + retcode, stdout, stderr = run_command( + download_cmd, + running_processes, + process_lock, + binary_stdout=output_to_buffer, + stream_output=getattr(args, 'print_downloader_log', False), + stream_prefix=f"{log_prefix} | " + ) + finally: + if temp_info_file_path and os.path.exists(temp_info_file_path): + os.unlink(temp_info_file_path) + + is_403_error = "HTTP Error 403" in stderr + is_timeout_error = "Read timed out" in stderr + output_to_buffer = download_policy.get('output_to_buffer', False) + + # Parse stdout to find the downloaded file path. + # The download scripts print the final path, sometimes with a prefix. + downloaded_filepath = None + if stdout and not output_to_buffer: + lines = stdout.strip().split('\n') + if lines: + last_line = lines[-1].strip() + + # Handle aria-rpc output format: "Download... successful: " + aria_match = re.search(r'successful: (.+)', last_line) + if aria_match: + path_from_aria = aria_match.group(1).strip() + if os.path.exists(path_from_aria): + downloaded_filepath = path_from_aria + else: + logger.warning(f"[{display_name}] Path from aria-rpc output does not exist: '{path_from_aria}'") + # Handle native-py/cli output format (just the path) + elif os.path.exists(last_line): + downloaded_filepath = last_line + + result = { + 'type': 'download', + 'path': str(info_json_path), + 'format': format_to_download, + 'success': retcode == 0, + 'error_type': None, + 'details': '', + 'downloaded_bytes': 0, + 'profile': profile_name, + 'downloaded_filepath': downloaded_filepath + } + + if retcode == 0: + details_str = "OK" + size_in_bytes = 0 + if output_to_buffer: + # The most accurate size is the length of the stdout buffer. + size_in_bytes = len(stdout) # stdout is bytes + details_str += f" (Buffered {sp_utils.format_size(size_in_bytes)})" + else: + size_match = re.search(r'\[download\]\s+100%\s+of\s+~?([0-9.]+)(B|KiB|MiB|GiB)', stderr) + if size_match: + value = float(size_match.group(1)) + unit = size_match.group(2) + multipliers = {"B": 1, "KiB": 1024, "MiB": 1024**2, "GiB": 1024**3} + size_in_bytes = int(value * multipliers.get(unit, 1)) + details_str += f" ({size_match.group(1)}{unit})" + + result['downloaded_bytes'] = size_in_bytes + result['details'] = details_str + else: + # Check for shutdown first. This is the most likely cause for an abrupt non-zero exit. + if state_manager.shutdown_event.is_set(): + result['error_type'] = 'Cancelled' + result['details'] = 'Task cancelled during shutdown.' + else: + # Check both stdout and stderr for error messages, as logging might be directed to stdout. + full_output = f"{stdout}\n{stderr}" + # Look for common error indicators from both yt-dlp and our scripts. + error_lines = [ + line for line in full_output.strip().split('\n') + if 'ERROR:' in line or ' - ERROR - ' in line or 'DownloadError:' in line or 'Traceback' in line + ] + if error_lines: + # Try to get the most specific part of the error. 
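+ # Illustrative example (log format assumed from the comment below): a line such as + # '2024-01-01 12:00:00 - downloader - ERROR - HTTP Error 403: Forbidden' is reduced to + # 'HTTP Error 403: Forbidden' in the result details.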
+ details = error_lines[-1].strip() + # Our log format is "timestamp - logger - LEVEL - message" + if ' - ERROR - ' in details: + details = details.split(' - ERROR - ', 1)[-1] + result['details'] = details + else: + # Fallback to last non-empty line of stderr if no explicit "ERROR" line found + stderr_lines = [line for line in stderr.strip().split('\n') if line.strip()] + result['details'] = stderr_lines[-1].strip() if stderr_lines else "Unknown error" + + if is_403_error: + result['error_type'] = 'HTTP 403' + elif is_timeout_error: + result['error_type'] = 'Timeout' + else: + result['error_type'] = f'Exit Code {retcode}' + + return result + + +def process_info_json_cycle(path, content, policy, state_manager, args, running_processes, process_lock, proxy_url=None, profile_name=None, profile_manager_instance=None): + """ + Processes one info.json file for one cycle, downloading selected formats. + """ + results = [] + display_name = sp_utils.get_display_name(path) + d_policy = policy.get('download_policy', {}) + s_conditions = policy.get('stop_conditions', {}) + + try: + info_data = json.loads(content) + + # If the task file specifies a format, use it instead of the policy's format list. + # This is for granular tasks generated by the task-generator tool. + if '_ytops_download_format' in info_data: + format_selection = info_data['_ytops_download_format'] + logger.info(f"[{display_name}] Using format '{format_selection}' from task file.") + else: + format_selection = d_policy.get('formats', '') + available_formats = [f['format_id'] for f in info_data.get('formats', [])] + if not available_formats: + logger.warning(f"[{display_name}] No formats found in info.json. Skipping.") + return [] + + formats_to_test = [] + if format_selection == 'all': + formats_to_test = available_formats + elif format_selection.startswith('random:'): + percent = float(format_selection.split(':')[1].rstrip('%')) + count = max(1, int(len(available_formats) * (percent / 100.0))) + formats_to_test = random.sample(available_formats, k=count) + elif format_selection.startswith('random_from:'): + choices = [f.strip() for f in format_selection.split(':', 1)[1].split(',')] + valid_choices = [f for f in choices if f in available_formats] + if valid_choices: + formats_to_test = [random.choice(valid_choices)] + else: + # If the format selection contains complex selector characters (other than comma), + # treat the entire string as a single format selector for yt-dlp to interpret. + # Otherwise, split by comma to test each specified format ID individually. + if any(c in format_selection for c in '/+[]()'): + requested_formats = [format_selection] + else: + requested_formats = [f.strip() for f in format_selection.split(',') if f.strip()] + + formats_to_test = [] + selector_keywords = ('best', 'worst', 'bestvideo', 'bestaudio') + + for req_fmt in requested_formats: + # Treat as a selector and pass through if it contains special characters + # or starts with a known selector keyword. + if any(c in req_fmt for c in '/+[]()') or req_fmt.startswith(selector_keywords): + formats_to_test.append(req_fmt) + continue + + # Otherwise, treat as a specific format ID that must exist. + # Check for exact match first. + if req_fmt in available_formats: + formats_to_test.append(req_fmt) + continue + + # If no exact match, check for formats that start with this ID + '-' and then digits + # e.g., req_fmt '140' should match '140-0' but not '140-something'. 
+ prefix_match_re = re.compile(rf'^{re.escape(req_fmt)}-\d+$') + first_match = next((af for af in available_formats if prefix_match_re.match(af)), None) + + if first_match: + logger.info(f"[{display_name}] Requested format '{req_fmt}' not found. Using first available match: '{first_match}'.") + formats_to_test.append(first_match) + else: + logger.warning(f"[{display_name}] Requested format '{req_fmt}' not found in available formats and is not a recognized selector. Skipping this format.") + + except json.JSONDecodeError: + logger.error(f"[{display_name}] Failed to parse info.json. Skipping.") + return [] + + for i, format_id in enumerate(formats_to_test): + if state_manager.shutdown_event.is_set(): + logger.info(f"Shutdown requested, stopping further format tests for {display_name}.") + break + + # Check if the format URL is expired before attempting to download + format_details = next((f for f in info_data.get('formats', []) if f.get('format_id') == format_id), None) + + # If format_id is a complex selector, it won't be found directly. As a heuristic, + # check the expiration of the first format URL, as they typically share the same expiration. + if not format_details and any(c in format_id for c in '/+[]()'): + available_formats_list = info_data.get('formats', []) + if available_formats_list: + format_details = available_formats_list[0] + + # The check is enabled by default and can be disabled via policy. + if d_policy.get('check_url_expiration', True) and format_details and 'url' in format_details: + url_to_check = format_details['url'] + time_shift_minutes = d_policy.get('expire_time_shift_minutes', 0) + status, time_left_seconds = sp_utils.check_url_expiry(url_to_check, time_shift_minutes) + + logger.debug(f"[{display_name}] URL expiration check for format '{format_id}': status={status}, time_left={time_left_seconds:.0f}s") + + if status == 'expired': + details = "Download URL is expired" + if time_shift_minutes > 0 and time_left_seconds > 0: + logger.warning(f"[{display_name}] Skipping format '{format_id}' because its URL will expire in {time_left_seconds/60:.1f}m (within {time_shift_minutes}m time-shift).") + details = f"URL will expire within {time_shift_minutes}m time-shift" + else: + logger.warning(f"[{display_name}] Skipping format '{format_id}' because its URL is expired.") + + result = { + 'type': 'download', 'path': str(path), 'format': format_id, + 'success': False, 'error_type': 'Skipped (Expired URL)', + 'details': details, 'downloaded_bytes': 0, 'is_tolerated_error': True + } + if proxy_url: + result['proxy_url'] = proxy_url + + if profile_manager_instance and profile_name: + profile_manager_instance.record_activity(profile_name, 'tolerated_error') + + state_manager.log_event(result) + results.append(result) + continue # Move to the next format + + elif status == 'no_expiry_info': + logger.debug(f"[{display_name}] No valid 'expire' parameter found in format URL for '{format_id}'. Skipping expiration check.") + + result = run_download_worker(path, content, format_id, policy, args, running_processes, process_lock, state_manager, profile_name=profile_name) + if 'id' in info_data: + result['video_id'] = info_data['id'] + if proxy_url: + result['proxy_url'] = proxy_url + + # Record download attempt/error if a profile is being used. 
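+ # Summary of the mapping applied below: success -> 'download'; results flagged is_tolerated_error -> + # 'tolerated_error'; cancellations during shutdown are not recorded; anything else -> 'download_error'.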
+ if profile_manager_instance and profile_name: + if result.get('success'): + profile_manager_instance.record_activity(profile_name, 'download') + elif result.get('error_type') == 'Cancelled': + pass # Do not record cancellations + elif result.get('is_tolerated_error'): + profile_manager_instance.record_activity(profile_name, 'tolerated_error') + else: + profile_manager_instance.record_activity(profile_name, 'download_error') + + state_manager.log_event(result) + results.append(result) + + worker_id = get_worker_id() + status = "SUCCESS" if result['success'] else f"FAILURE ({result['error_type']})" + profile_log_part = f" [Profile: {profile_name}]" if profile_name else "" + logger.info(f"[Worker {worker_id}]{profile_log_part} Result for {display_name} (format {format_id}): {status} - {result.get('details', 'OK')}") + + if not result['success']: + if s_conditions.get('on_failure') or \ + (s_conditions.get('on_http_403') and result['error_type'] == 'HTTP 403') or \ + (s_conditions.get('on_timeout') and result['error_type'] == 'Timeout'): + logger.info(f"Stopping further format tests for {display_name} in this cycle due to failure.") + break + + sleep_cfg = d_policy.get('sleep_between_formats', {}) + sleep_min = sleep_cfg.get('min_seconds', 0) + if sleep_min > 0 and i < len(formats_to_test) - 1: + sleep_max = sleep_cfg.get('max_seconds') or sleep_min + if sleep_max > sleep_min: + sleep_duration = random.uniform(sleep_min, sleep_max) + else: + sleep_duration = sleep_min + + logger.debug(f"Sleeping for {sleep_duration:.2f}s between formats for {display_name}.") + # Interruptible sleep + sleep_end_time = time.time() + sleep_duration + while time.time() < sleep_end_time: + if state_manager.shutdown_event.is_set(): + break + time.sleep(0.2) + + return results + + +def run_throughput_worker(worker_id, policy, state_manager, args, profile_manager_instance, running_processes, process_lock): + """A persistent worker for the 'throughput' orchestration mode.""" + owner_id = f"throughput-worker-{worker_id}" + settings = policy.get('settings', {}) + exec_control = policy.get('execution_control', {}) + d_policy = policy.get('download_policy', {}) + + profile_prefix = d_policy.get('profile_prefix') + if not profile_prefix: + logger.error(f"[Worker {worker_id}] Throughput mode requires 'download_policy.profile_prefix'. Worker exiting.") + return [] + + no_task_streak = 0 + + while not state_manager.shutdown_event.is_set(): + locked_profile = None + claimed_task_path = None + try: + # 0. If no tasks were found previously, pause briefly. + if no_task_streak > 0: + polling_interval = exec_control.get('worker_polling_interval_seconds', 1) + logger.info(f"[Worker {worker_id}] No tasks found in previous attempt(s). Pausing for {polling_interval}s. (Streak: {no_task_streak})") + time.sleep(polling_interval) + if state_manager.shutdown_event.is_set(): continue + + # 1. Find a task and lock its associated profile + locked_profile, claimed_task_path = find_task_and_lock_profile( + profile_manager_instance, owner_id, profile_prefix, policy, worker_id + ) + + if not locked_profile: + # No task/profile combo was available. + no_task_streak += 1 + polling_interval = exec_control.get('worker_polling_interval_seconds', 1) + logger.info(f"[Worker {worker_id}] No available tasks found for any active profiles. Pausing for {polling_interval}s.") + time.sleep(polling_interval) + continue + + profile_name = locked_profile['name'] + + # We have a task and a lock. + if claimed_task_path: + no_task_streak = 0 # Reset streak + # 3. 
Process the task + try: + with open(claimed_task_path, 'r', encoding='utf-8') as f: + info_json_content = f.read() + except (IOError, FileNotFoundError) as e: + logger.error(f"[{sp_utils.get_display_name(claimed_task_path)}] Could not read claimed task file: {e}") + # Unlock profile and continue, file might be corrupted + profile_manager_instance.unlock_profile(profile_name, owner=owner_id) + locked_profile = None + # Clean up the bad file + try: claimed_task_path.unlink() + except OSError: pass + continue + + # The locked profile's proxy MUST be used for the download. + local_policy = deepcopy(policy) + local_policy.setdefault('download_policy', {})['proxy'] = locked_profile['proxy'] + + _run_download_logic( + source=claimed_task_path, + info_json_content=info_json_content, + policy=local_policy, + state_manager=state_manager, + args=args, + running_processes=running_processes, + process_lock=process_lock, + profile_name=profile_name, + profile_manager_instance=profile_manager_instance + ) + + # 4. Clean up the processed task file + try: + os.remove(claimed_task_path) + logger.debug(f"[{sp_utils.get_display_name(claimed_task_path)}] Removed processed task file.") + except OSError as e: + logger.error(f"Failed to remove processed task file '{claimed_task_path}': {e}") + else: + # This case should not be reached with the new task-first locking logic. + # If it is, it means find_task_and_lock_profile returned a profile but no task. + logger.warning(f"[Worker {worker_id}] Inconsistent state: locked profile '{profile_name}' but no task was claimed. Unlocking and continuing.") + + except Exception as e: + logger.error(f"[Worker {worker_id}] An unexpected error occurred in the worker loop: {e}", exc_info=True) + time.sleep(5) # Pause before retrying to avoid spamming errors + finally: + if locked_profile: + # 5. Unlock the profile. Only apply cooldown if a task was processed. + cooldown = None + if claimed_task_path: + # NOTE: Restarting the enforcer is enough to change this cooldown policy; the long-running + # auth/download workers never need to be restarted, so a single change applies across all + # workers and machines at once. + # DESIGN: The cooldown duration is not configured in the worker's policy. + # Instead, it is read from a central Redis key. This key is set by the + # policy-enforcer, making the enforcer the single source of truth for + # this policy. This allows changing the cooldown behavior without + # restarting the workers. + cooldown_config = profile_manager_instance.get_config('unlock_cooldown_seconds') + if cooldown_config: + try: + val = json.loads(cooldown_config) + if isinstance(val, list) and len(val) == 2 and val[0] < val[1]: + cooldown = random.randint(val[0], val[1]) + elif isinstance(val, int): + cooldown = val + except (json.JSONDecodeError, TypeError): + if cooldown_config.isdigit(): + cooldown = int(cooldown_config) + + if cooldown: + logger.info(f"[Worker {worker_id}] Putting profile '{locked_profile['name']}' into COOLDOWN for {cooldown}s.") + + profile_manager_instance.unlock_profile( + locked_profile['name'], + owner=owner_id, + rest_for_seconds=cooldown + ) + locked_profile = None + + # 6. Throughput is now controlled by the enforcer via the profile's + # 'unlock_cooldown_seconds' policy, which puts the profile into a + # RESTING state. The worker does not need to sleep here and can + # immediately try to lock a new profile to maximize throughput.
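+ # Illustrative values for that Redis-backed setting, given the parsing above: the enforcer may set + # 'unlock_cooldown_seconds' to a plain integer (e.g. '60') for a fixed cooldown, or to a JSON range + # (e.g. '[30, 90]') from which a random cooldown is drawn on each unlock.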
+ + logger.info(f"[Worker {worker_id}] Worker loop finished.") + return [] # This function doesn't return results directly + + +def _post_process_and_move_info_json(file_path, profile_name, proxy_url, policy, worker_id, profile_manager_instance=None): + """Helper to post-process a single info.json file and move it to the final directory.""" + direct_policy = policy.get('direct_docker_cli_policy', {}) + settings = policy.get('settings', {}) + save_dir = settings.get('save_info_json_dir') + if not save_dir: + return False + + video_id = "unknown" + try: + # Use a short delay and retry mechanism to handle cases where the file is not yet fully written. + for attempt in range(3): + try: + with open(file_path, 'r+', encoding='utf-8') as f: + info_data = json.load(f) + video_id = info_data.get('id', 'unknown') + env_name = profile_manager_instance.key_prefix.replace('_profile_mgmt_', '') if profile_manager_instance else 'unknown' + info_data['_ytops_metadata'] = { + 'profile_name': profile_name, + 'proxy_url': proxy_url, + 'generation_timestamp_utc': datetime.now(timezone.utc).isoformat(), + 'auth_env': env_name + } + f.seek(0) + json.dump(info_data, f, indent=2) + f.truncate() + break # Success + except (json.JSONDecodeError, IOError) as e: + if attempt < 2: + time.sleep(0.2) + else: + raise e + + final_path = Path(save_dir) / file_path.name + rename_template = direct_policy.get('rename_file_template') + if rename_template: + sanitized_proxy = re.sub(r'[:/]', '_', proxy_url) + new_name = rename_template.format( + video_id=video_id, profile_name=profile_name, proxy=sanitized_proxy + ) + final_path = Path(save_dir) / new_name + + # Use rename for atomic move + os.rename(str(file_path), str(final_path)) + logger.info(f"[Worker {worker_id}] Post-processed and moved info.json to '{final_path}'") + return True + except (IOError, json.JSONDecodeError, OSError) as e: + logger.error(f"[Worker {worker_id}] Error post-processing '{file_path.name}' (video: {video_id}): {e}") + return False + + +def find_task_and_lock_profile(profile_manager, owner_id, profile_prefix, policy, worker_id): + """ + Scans for an available task and locks the specific ACTIVE profile that generated it. + This preserves a 1-to-1 relationship between a profile and its tasks. + Returns a tuple of (locked_profile_dict, claimed_task_path_obj) or (None, None). + """ + settings = policy.get('settings', {}) + info_json_dir = settings.get('info_json_dir') + if not info_json_dir: + return None, None + + # 1. Get ACTIVE profiles from Redis. + active_profiles = profile_manager.list_profiles(state_filter='ACTIVE') + active_profile_names = {p['name'] for p in active_profiles if p['name'].startswith(profile_prefix)} + + if not active_profile_names: + return None, None + + # 2. Get all available task files. + try: + task_files = list(Path(info_json_dir).glob('*.json')) + except FileNotFoundError: + logger.warning(f"Info JSON directory not found during scan: {info_json_dir}") + return None, None + + if not task_files: + return None, None + + profile_regex_str = settings.get('profile_extraction_regex') + if not profile_regex_str: + logger.error(f"[Worker {worker_id}] The task-locking strategy requires 'settings.profile_extraction_regex' to be defined in the policy.") + return None, None + + try: + profile_regex = re.compile(profile_regex_str) + except re.error as e: + logger.error(f"Invalid profile_extraction_regex in policy: '{profile_regex_str}'. Error: {e}") + return None, None + + # 3. 
Shuffle tasks to distribute load if multiple workers are looking. + random.shuffle(task_files) + + # 4. Iterate through tasks and try to lock their corresponding ACTIVE profile. + for task_path in task_files: + match = profile_regex.search(task_path.name) + if not (match and match.groups()): + continue + + profile_name = match.group(1) + if profile_name in active_profile_names: + # Found a task for an active profile. Try to lock it. + locked_profile = profile_manager.lock_profile(owner=owner_id, specific_profile_name=profile_name) + if locked_profile: + # Success! Claim the file. + locked_path = task_path.with_suffix(f"{task_path.suffix}.LOCKED.{worker_id}") + try: + task_path.rename(locked_path) + logger.info(f"[Worker {worker_id}] Locked profile '{profile_name}' and claimed its task '{task_path.name}'.") + return locked_profile, locked_path + except FileNotFoundError: + logger.warning(f"[Worker {worker_id}] Task '{task_path.name}' was claimed by another worker. Unlocking '{profile_name}'.") + profile_manager.unlock_profile(profile_name, owner=owner_id) + continue # Try next task + except OSError as e: + logger.error(f"[Worker {worker_id}] Error claiming task file '{task_path.name}': {e}") + profile_manager.unlock_profile(profile_name, owner=owner_id) + continue + + # No suitable task/profile combo found. + logger.debug("Found task files, but none correspond to any currently ACTIVE profiles.") + return None, None + + +def run_direct_batch_worker(worker_id, policy, state_manager, args, profile_manager_instance, urls_list, running_processes, process_lock): + """A worker for the 'direct_batch_cli' orchestration mode.""" + owner_id = f"direct-batch-worker-{worker_id}" + settings = policy.get('settings', {}) + exec_control = policy.get('execution_control', {}) + gen_policy = policy.get('info_json_generation_policy', {}) + direct_policy = policy.get('direct_batch_cli_policy', {}) + + profile_prefix = gen_policy.get('profile_prefix') + if not profile_prefix: + logger.error(f"[Worker {worker_id}] Direct batch mode requires 'info_json_generation_policy.profile_prefix'. Worker exiting.") + return [] + + batch_size = direct_policy.get('batch_size') + if not batch_size: + logger.error(f"[Worker {worker_id}] Direct batch mode requires 'direct_batch_cli_policy.batch_size'. Worker exiting.") + return [] + + save_dir = settings.get('save_info_json_dir') + if not save_dir: + logger.error(f"[Worker {worker_id}] Direct batch mode requires 'settings.save_info_json_dir'. Worker exiting.") + return [] + + os.makedirs(save_dir, exist_ok=True) + + last_used_profile_name = None + while not state_manager.shutdown_event.is_set(): + locked_profile = None + temp_batch_file = None + # --- Variables for robust finalization --- + files_created = 0 + url_batch_len = 0 + batch_started = False + # --- + try: + # 1. Lock a profile + locked_profile = profile_manager_instance.lock_profile(owner=owner_id, profile_prefix=profile_prefix) + + # --- New logic to avoid immediate reuse --- + avoid_reuse = direct_policy.get('avoid_immediate_profile_reuse', False) + if avoid_reuse and locked_profile and last_used_profile_name and locked_profile['name'] == last_used_profile_name: + logger.info(f"[Worker {worker_id}] Re-locked same profile '{locked_profile['name']}'. Unlocking and pausing to allow for rotation.") + profile_manager_instance.unlock_profile(locked_profile['name'], owner=owner_id) + + wait_seconds = direct_policy.get('avoid_reuse_max_wait_seconds', 5) + time.sleep(wait_seconds) + + # After waiting, try to lock again. 
+ logger.info(f"[Worker {worker_id}] Attempting to lock a new profile after waiting.") + locked_profile = profile_manager_instance.lock_profile(owner=owner_id, profile_prefix=profile_prefix) + + if locked_profile and locked_profile['name'] == last_used_profile_name: + logger.warning(f"[Worker {worker_id}] Still locking the same profile '{locked_profile['name']}' after waiting. Proceeding to use it to avoid getting stuck.") + elif locked_profile: + logger.info(f"[Worker {worker_id}] Switched to a different profile after waiting: '{locked_profile['name']}'.") + # --- End new logic --- + + if not locked_profile: + polling_interval = exec_control.get('worker_polling_interval_seconds', 1) + # --- Add diagnostic logging --- + all_profiles_in_pool = profile_manager_instance.list_profiles() + profiles_in_prefix = [p for p in all_profiles_in_pool if p['name'].startswith(profile_prefix)] + if profiles_in_prefix: + state_counts = collections.Counter(p['state'] for p in profiles_in_prefix) + states_summary = ', '.join(f"{count} {state}" for state, count in sorted(state_counts.items())) + logger.info(f"[Worker {worker_id}] No auth profiles available to lock. Pool status ({profile_prefix}*): {states_summary}. Pausing for {polling_interval}s.") + else: + logger.info(f"[Worker {worker_id}] No auth profiles available to lock. No profiles found with prefix '{profile_prefix}'. Pausing for {polling_interval}s.") + # --- End diagnostic logging --- + time.sleep(polling_interval) + continue + + profile_name = locked_profile['name'] + proxy_url = locked_profile['proxy'] + + # 2. Get a batch of URLs from the shared list + url_batch, start_idx = state_manager.get_next_url_batch(batch_size, urls_list) + if not url_batch: + logger.info(f"[Worker {worker_id}] No more URLs to process. Worker exiting.") + break # Exit the while loop + + url_batch_len = len(url_batch) + batch_started = True + + # Preemptively increment the counter to avoid race conditions with download workers. + profile_manager_instance.increment_pending_downloads(profile_name, url_batch_len) + logger.info(f"[Worker {worker_id}] [{profile_name}] Preemptively incremented pending downloads by {url_batch_len} for the upcoming batch.") + + end_idx = start_idx + len(url_batch) + logger.info(f"[Worker {worker_id}] [{profile_name}] Processing batch of {len(url_batch)} URLs (lines {start_idx + 1}-{end_idx} from source).") + + video_ids_in_batch = {sp_utils.get_video_id(u) for u in url_batch} + + # 3. Write URLs to a temporary batch file + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt', encoding='utf-8') as f: + temp_batch_file = f.name + f.write('\n'.join(url_batch)) + + # 4. Construct and run the command + ytdlp_cmd_str = direct_policy.get('ytdlp_command') + if not ytdlp_cmd_str: + logger.error(f"[Worker {worker_id}] Direct batch mode requires 'direct_batch_cli_policy.ytdlp_command'.") + break + + cmd = shlex.split(ytdlp_cmd_str) + cmd.extend(['--batch-file', temp_batch_file]) + cmd.extend(['--proxy', proxy_url]) + + # The output template should not include the .info.json extension, as + # yt-dlp adds it automatically when --write-info-json is used. 
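+ # e.g. with the default template '%(id)s', yt-dlp writes '<video_id>.info.json' into the temporary output directory.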
+ output_template_str = direct_policy.get('ytdlp_output_template', '%(id)s') + ytdlp_args = direct_policy.get('ytdlp_args') + + custom_env = direct_policy.get('env_vars', {}).copy() + + # --- PYTHONPATH for custom yt-dlp module --- + ytdlp_module_path = direct_policy.get('ytdlp_module_path') + if ytdlp_module_path: + existing_pythonpath = custom_env.get('PYTHONPATH', os.environ.get('PYTHONPATH', '')) + # Prepend the custom path to PYTHONPATH to give it precedence + custom_env['PYTHONPATH'] = f"{ytdlp_module_path}{os.pathsep}{existing_pythonpath}".strip(os.pathsep) + logger.debug(f"[Worker {worker_id}] Using custom PYTHONPATH: {custom_env['PYTHONPATH']}") + + custom_env['YTDLP_PROFILE_NAME'] = profile_name + custom_env['YTDLP_PROXY_URL'] = proxy_url + env_name = profile_manager_instance.key_prefix.replace('_profile_mgmt_', '') + custom_env['YTDLP_SIM_MODE'] = env_name + + # Create a per-profile cache directory and set XDG_CACHE_HOME + cache_dir_base = direct_policy.get('cache_dir_base', '.cache') + profile_cache_dir = os.path.join(cache_dir_base, profile_name) + try: + os.makedirs(profile_cache_dir, exist_ok=True) + custom_env['XDG_CACHE_HOME'] = profile_cache_dir + except OSError as e: + logger.error(f"[Worker {worker_id}] Failed to create cache directory '{profile_cache_dir}': {e}") + + # --- Manage User-Agent --- + # Use a consistent User-Agent per profile, storing it in the profile's cache directory. + user_agent = None + user_agent_file = os.path.join(profile_cache_dir, 'user_agent.txt') + try: + if os.path.exists(user_agent_file): + with open(user_agent_file, 'r', encoding='utf-8') as f: + user_agent = f.read().strip() + + if not user_agent: # File doesn't exist or is empty + user_agent = sp_utils.generate_user_agent_from_policy(policy) + with open(user_agent_file, 'w', encoding='utf-8') as f: + f.write(user_agent) + logger.info(f"[{profile_name}] Generated and saved new User-Agent: '{user_agent}'") + else: + logger.info(f"[{profile_name}] Using existing User-Agent from cache: '{user_agent}'") + except IOError as e: + logger.error(f"[Worker {worker_id}] Error accessing User-Agent file '{user_agent_file}': {e}. Using generated UA for this run.") + user_agent = sp_utils.generate_user_agent_from_policy(policy) # fallback + + # Add proxy rename from policy if specified, for custom yt-dlp forks + proxy_rename = direct_policy.get('ytdlp_proxy_rename') + if proxy_rename: + custom_env['YTDLP_PROXY_RENAME'] = proxy_rename + + if user_agent: + cmd.extend(['--user-agent', user_agent]) + + if ytdlp_args: + cmd.extend(shlex.split(ytdlp_args)) + + if args.verbose and '--verbose' not in cmd: + cmd.append('--verbose') + + if args.dummy: + # In dummy mode, we replace the real yt-dlp command with our dummy script. + # The dummy script will handle Redis interactions (checking for bans, recording activity). + + # For logging, construct what the real command would have been + log_cmd = list(cmd) # cmd has most args now + log_cmd.extend(['-o', os.path.join('temp_dir', output_template_str)]) + logger.info(f"[Worker {worker_id}] [{profile_name}] DUMMY MODE: Would run real command: {' '.join(shlex.quote(s) for s in log_cmd)}") + logger.info(f"[Worker {worker_id}] [{profile_name}] DUMMY MODE: With environment for real command: {custom_env}") + + cmd = [ + sys.executable, '-m', 'ytops_client.cli', + 'yt-dlp-dummy' + ] + + # The orchestrator is still responsible for managing temp directories and post-processing. 
+ with tempfile.TemporaryDirectory(prefix=f"ytdlp-dummy-batch-{worker_id}-") as temp_output_dir: + output_template = os.path.join(temp_output_dir, output_template_str) + cmd.extend(['--batch-file', temp_batch_file]) + cmd.extend(['-o', output_template]) + if args.verbose: + cmd.append('--verbose') + + # Pass failure rates and Redis connection info to the dummy script via environment + dummy_settings = policy.get('settings', {}).get('dummy_simulation_settings', {}) + auth_failure_rate = dummy_settings.get('auth_failure_rate', 0.0) + auth_skipped_rate = dummy_settings.get('auth_skipped_failure_rate', 0.0) + custom_env['YTDLP_DUMMY_FAILURE_RATE'] = auth_failure_rate + custom_env['YTDLP_DUMMY_SKIPPED_FAILURE_RATE'] = auth_skipped_rate + custom_env['REDIS_HOST'] = profile_manager_instance.redis.connection_pool.connection_kwargs.get('host') + custom_env['REDIS_PORT'] = profile_manager_instance.redis.connection_pool.connection_kwargs.get('port') + redis_password = profile_manager_instance.redis.connection_pool.connection_kwargs.get('password') + if redis_password: + custom_env['REDIS_PASSWORD'] = redis_password + + logger.info(f"[Worker {worker_id}] [{profile_name}] DUMMY MODE: Running dummy yt-dlp script with updated environment: {custom_env}") + retcode, stdout, stderr = run_command( + cmd, running_processes, process_lock, env=custom_env, stream_output=args.verbose, + stream_prefix=f"[Worker {worker_id} | yt-dlp-dummy] " + ) + + # --- Post-processing is the same as in non-dummy mode --- + processed_files = list(Path(temp_output_dir).glob('*.json')) + + for temp_path in processed_files: + files_created += 1 + video_id = "unknown" + try: + # The orchestrator injects its own metadata after the fact. + with open(temp_path, 'r+', encoding='utf-8') as f: + info_data = json.load(f) + video_id = info_data.get('id', 'unknown') + env_name = profile_manager_instance.key_prefix.replace('_profile_mgmt_', '') + info_data['_ytops_metadata'] = { + 'profile_name': profile_name, + 'proxy_url': proxy_url, + 'generation_timestamp_utc': datetime.now(timezone.utc).isoformat(), + 'auth_env': env_name + } + f.seek(0) + json.dump(info_data, f, indent=2) + f.truncate() + + final_path = Path(save_dir) / temp_path.name + rename_template = direct_policy.get('rename_file_template') + if rename_template: + sanitized_proxy = re.sub(r'[:/]', '_', proxy_url) + new_name = rename_template.format( + video_id=video_id, profile_name=profile_name, proxy=sanitized_proxy + ) + final_path = Path(save_dir) / new_name + + shutil.move(str(temp_path), str(final_path)) + logger.info(f"[Worker {worker_id}] Post-processed and moved info.json to '{final_path}'") + except (IOError, json.JSONDecodeError, OSError) as e: + logger.error(f"[Worker {worker_id}] DUMMY MODE: Error post-processing '{temp_path.name}' (video: {video_id}): {e}") + + # The orchestrator still determines overall batch success and logs its own event. + # It does NOT call record_activity, as the dummy script did that per-URL. + success = (retcode == 0 and files_created > 0) + + if not success: + reason = f"exit code was {retcode}" if retcode != 0 else f"0 files created" + logger.warning(f"[Worker {worker_id}] [{profile_name}] DUMMY MODE: Marking batch as FAILED. Reason: {reason}.") + + # Record batch stats + state_manager.record_batch_result(success, len(url_batch), profile_name=profile_name) + + event_details = f"Dummy batch completed. Files created: {files_created}/{len(url_batch)}." 
+ if not success and stderr: + event_details += f" Stderr: {stderr.strip().splitlines()[-1]}" + event = { 'type': 'fetch_batch', 'profile': profile_name, 'proxy_url': proxy_url, 'success': success, 'details': event_details, 'video_count': len(url_batch) } + state_manager.log_event(event) + + else: + with tempfile.TemporaryDirectory(prefix=f"ytdlp-batch-{worker_id}-") as temp_output_dir: + output_template = os.path.join(temp_output_dir, output_template_str) + cmd.extend(['-o', output_template]) + + logger.info(f"[Worker {worker_id}] [{profile_name}] Processing batch of {len(url_batch)} URLs...") + logger.info(f"[Worker {worker_id}] [{profile_name}] Running command: {' '.join(shlex.quote(s) for s in cmd)}") + logger.info(f"[Worker {worker_id}] [{profile_name}] With environment: {custom_env}") + retcode, stdout, stderr = run_command( + cmd, running_processes, process_lock, env=custom_env, stream_output=args.verbose, + stream_prefix=f"[Worker {worker_id} | yt-dlp] " + ) + + is_bot_error = "Sign in to confirm you're not a bot" in stderr + if is_bot_error: + logger.warning(f"[Worker {worker_id}] [{profile_name}] Bot detection occurred during batch. Marking as failure.") + + processed_files = list(Path(temp_output_dir).glob('*.json')) + + for temp_path in processed_files: + files_created += 1 + video_id = "unknown" + try: + with open(temp_path, 'r+', encoding='utf-8') as f: + info_data = json.load(f) + video_id = info_data.get('id', 'unknown') + env_name = profile_manager_instance.key_prefix.replace('_profile_mgmt_', '') + info_data['_ytops_metadata'] = { + 'profile_name': profile_name, + 'proxy_url': proxy_url, + 'generation_timestamp_utc': datetime.now(timezone.utc).isoformat(), + 'auth_env': env_name + } + f.seek(0) + json.dump(info_data, f, indent=2) + f.truncate() + + final_path = Path(save_dir) / temp_path.name + rename_template = direct_policy.get('rename_file_template') + if rename_template: + sanitized_proxy = re.sub(r'[:/]', '_', proxy_url) + new_name = rename_template.format( + video_id=video_id, profile_name=profile_name, proxy=sanitized_proxy + ) + final_path = Path(save_dir) / new_name + + shutil.move(str(temp_path), str(final_path)) + logger.info(f"[Worker {worker_id}] Post-processed and moved info.json to '{final_path}'") + + except (IOError, json.JSONDecodeError, OSError) as e: + logger.error(f"[Worker {worker_id}] Error post-processing '{temp_path.name}' (video: {video_id}): {e}") + + # The orchestrator records per-URL success/failure for the profile. + # A batch is considered an overall success for logging if it had no fatal errors + # and produced at least one file. + success = (files_created > 0 and not is_bot_error) + + if not success: + reason = "bot detection occurred" if is_bot_error else f"0 files created out of {len(url_batch)}" + logger.warning(f"[Worker {worker_id}] [{profile_name}] Marking batch as FAILED. Reason: {reason}.") + + # Record batch stats for overall orchestrator health + state_manager.record_batch_result(success, len(url_batch), profile_name=profile_name) + + # In this mode, the custom yt-dlp script is responsible for recording + # per-URL activity ('success', 'failure', 'tolerated_error') directly into Redis. + # The orchestrator does not record activity here to avoid double-counting. + logger.info(f"[Worker {worker_id}] [{profile_name}] Batch finished. Per-URL activity was recorded by the yt-dlp script.") + + event_details = f"Batch completed. Exit: {retcode}. Files created: {files_created}/{len(url_batch)}." 
+ if not success and stderr: + if is_bot_error: + event_details += " Stderr: Bot detection occurred." + else: + event_details += f" Stderr: {stderr.strip().splitlines()[-1]}" + + event = { 'type': 'fetch_batch', 'profile': profile_name, 'proxy_url': proxy_url, 'success': success, 'details': event_details, 'video_count': len(url_batch) } + state_manager.log_event(event) + + except Exception as e: + logger.error(f"[Worker {worker_id}] Unexpected error in worker loop: {e}", exc_info=True) + if locked_profile: + profile_manager_instance.record_activity(locked_profile['name'], 'failure') + finally: + if locked_profile and batch_started: + # --- Reconcile pending downloads counter --- + # This is in the finally block to guarantee it runs even if post-processing fails. + adjustment = files_created - url_batch_len + if adjustment != 0: + logger.warning(f"[Worker {worker_id}] [{profile_name}] Reconciling pending downloads. Batch created {files_created}/{url_batch_len} files. Adjusting by {adjustment}.") + profile_manager_instance.increment_pending_downloads(locked_profile['name'], adjustment) + + if locked_profile: + last_used_profile_name = locked_profile['name'] + cooldown = None + # DESIGN: The cooldown duration is not configured in the worker's policy. + # Instead, it is read from a central Redis key. This key is set by the + # policy-enforcer, making the enforcer the single source of truth for + # this policy. This allows changing the cooldown behavior without + # restarting the workers. + cooldown_config = profile_manager_instance.get_config('unlock_cooldown_seconds') + if cooldown_config: + try: + val = json.loads(cooldown_config) + if isinstance(val, list) and len(val) == 2 and val[0] < val[1]: + cooldown = random.randint(val[0], val[1]) + elif isinstance(val, int): + cooldown = val + except (json.JSONDecodeError, TypeError): + if cooldown_config.isdigit(): + cooldown = int(cooldown_config) + + if cooldown: + logger.info(f"[Worker {worker_id}] Putting profile '{locked_profile['name']}' into COOLDOWN for {cooldown}s.") + + profile_manager_instance.unlock_profile( + locked_profile['name'], + owner=owner_id, + rest_for_seconds=cooldown + ) + if temp_batch_file and os.path.exists(temp_batch_file): + os.unlink(temp_batch_file) + + logger.info(f"[Worker {worker_id}] Worker loop finished.") + return [] + + +def run_direct_docker_worker(worker_id, policy, state_manager, args, profile_manager_instance, urls_list, running_processes, process_lock): + """A worker for the 'direct_docker_cli' orchestration mode (fetch_only).""" + owner_id = f"direct-docker-worker-{worker_id}" + settings = policy.get('settings', {}) + exec_control = policy.get('execution_control', {}) + gen_policy = policy.get('info_json_generation_policy', {}) + direct_policy = policy.get('direct_docker_cli_policy', {}) + + profile_prefix = gen_policy.get('profile_prefix') + if not profile_prefix: + logger.error(f"[Worker {worker_id}] Direct docker mode requires 'info_json_generation_policy.profile_prefix'. Worker exiting.") + return [] + + batch_size = direct_policy.get('batch_size') + if not batch_size: + logger.error(f"[Worker {worker_id}] Direct docker mode requires 'direct_docker_cli_policy.batch_size'. Worker exiting.") + return [] + + save_dir = settings.get('save_info_json_dir') + if not save_dir: + logger.error(f"[Worker {worker_id}] Direct docker mode requires 'settings.save_info_json_dir'. 
Worker exiting.") + return [] + os.makedirs(save_dir, exist_ok=True) + + # --- Docker specific config --- + image_name = direct_policy.get('docker_image_name') + host_mount_path = os.path.abspath(direct_policy.get('docker_host_mount_path')) + container_mount_path = direct_policy.get('docker_container_mount_path') + host_cache_path = direct_policy.get('docker_host_cache_path') + if host_cache_path: host_cache_path = os.path.abspath(host_cache_path) + container_cache_path = direct_policy.get('docker_container_cache_path') + network_name = direct_policy.get('docker_network_name') + + if not all([image_name, host_mount_path, container_mount_path]): + logger.error(f"[Worker {worker_id}] Direct docker mode requires 'docker_image_name', 'docker_host_mount_path', and 'docker_container_mount_path'. Worker exiting.") + return [] + + try: + os.makedirs(host_mount_path, exist_ok=True) + except OSError as e: + logger.error(f"[Worker {worker_id}] Could not create docker_host_mount_path '{host_mount_path}': {e}. Worker exiting.") + return [] + + last_used_profile_name = None + while not state_manager.shutdown_event.is_set(): + locked_profile = None + temp_task_dir_host = None + # --- Variables for robust finalization --- + live_success_count = 0 + url_batch_len = 0 + batch_started = False + # --- + try: + # 1. Lock a profile + locked_profile = profile_manager_instance.lock_profile(owner=owner_id, profile_prefix=profile_prefix) + + # --- New logic to avoid immediate reuse --- + avoid_reuse = direct_policy.get('avoid_immediate_profile_reuse', False) + if avoid_reuse and locked_profile and last_used_profile_name and locked_profile['name'] == last_used_profile_name: + logger.info(f"[Worker {worker_id}] Re-locked same profile '{locked_profile['name']}'. Unlocking and pausing to allow for rotation.") + profile_manager_instance.unlock_profile(locked_profile['name'], owner=owner_id) + + wait_seconds = direct_policy.get('avoid_reuse_max_wait_seconds', 5) + time.sleep(wait_seconds) + + # After waiting, try to lock again. + logger.info(f"[Worker {worker_id}] Attempting to lock a new profile after waiting.") + locked_profile = profile_manager_instance.lock_profile(owner=owner_id, profile_prefix=profile_prefix) + + if locked_profile and locked_profile['name'] == last_used_profile_name: + logger.warning(f"[Worker {worker_id}] Still locking the same profile '{locked_profile['name']}' after waiting. Proceeding to use it to avoid getting stuck.") + elif locked_profile: + logger.info(f"[Worker {worker_id}] Switched to a different profile after waiting: '{locked_profile['name']}'.") + # --- End new logic --- + + if not locked_profile: + polling_interval = exec_control.get('worker_polling_interval_seconds', 1) + # --- Add diagnostic logging --- + all_profiles_in_pool = profile_manager_instance.list_profiles() + profiles_in_prefix = [p for p in all_profiles_in_pool if p['name'].startswith(profile_prefix)] + if profiles_in_prefix: + state_counts = collections.Counter(p['state'] for p in profiles_in_prefix) + states_summary = ', '.join(f"{count} {state}" for state, count in sorted(state_counts.items())) + logger.info(f"[Worker {worker_id}] No auth profiles available to lock. Pool status ({profile_prefix}*): {states_summary}. Pausing for {polling_interval}s.") + else: + logger.info(f"[Worker {worker_id}] No auth profiles available to lock. No profiles found with prefix '{profile_prefix}'. 
Pausing for {polling_interval}s.") + # --- End diagnostic logging --- + time.sleep(polling_interval) + continue + + profile_name = locked_profile['name'] + proxy_url = locked_profile['proxy'] + + # --- Manage User-Agent, Visitor ID, and Cache Directory --- + user_agent = None + visitor_id = None + environment = {} + profile_cache_dir_host = None + if host_cache_path: + profile_cache_dir_host = os.path.join(host_cache_path, profile_name) + try: + os.makedirs(profile_cache_dir_host, exist_ok=True) + if container_cache_path: + environment['XDG_CACHE_HOME'] = container_cache_path + + # --- User-Agent --- + user_agent_file = os.path.join(profile_cache_dir_host, 'user_agent.txt') + if os.path.exists(user_agent_file): + with open(user_agent_file, 'r', encoding='utf-8') as f: + user_agent = f.read().strip() + + if not user_agent: + user_agent = sp_utils.generate_user_agent_from_policy(policy) + with open(user_agent_file, 'w', encoding='utf-8') as f: + f.write(user_agent) + logger.info(f"[{profile_name}] Generated and saved new User-Agent: '{user_agent}'") + else: + logger.info(f"[{profile_name}] Using existing User-Agent from cache: '{user_agent}'") + + # --- Visitor ID --- + if direct_policy.get('track_visitor_id'): + visitor_id_file = os.path.join(profile_cache_dir_host, 'visitor_id.txt') + if os.path.exists(visitor_id_file): + with open(visitor_id_file, 'r', encoding='utf-8') as f: + visitor_id = f.read().strip() + if visitor_id: + logger.info(f"[{profile_name}] Using existing Visitor ID from cache: '{visitor_id}'") + + except IOError as e: + logger.error(f"[Worker {worker_id}] Error accessing cache file in '{profile_cache_dir_host}': {e}. Using generated UA for this run.") + user_agent = sp_utils.generate_user_agent_from_policy(policy) # Fallback for UA + else: + # Fallback if no cache is configured for auth simulation + user_agent = sp_utils.generate_user_agent_from_policy(policy) + + # 2. Get a batch of URLs + url_batch, start_idx = state_manager.get_next_url_batch(batch_size, urls_list) + if not url_batch: + logger.info(f"[Worker {worker_id}] No more URLs to process. Worker exiting.") + break + + url_batch_len = len(url_batch) + batch_started = True + + # Preemptively increment the counter to avoid race conditions with download workers. + profile_manager_instance.increment_pending_downloads(profile_name, url_batch_len) + logger.info(f"[Worker {worker_id}] [{profile_name}] Preemptively incremented pending downloads by {url_batch_len} for the upcoming batch.") + + end_idx = start_idx + len(url_batch) + logger.info(f"[Worker {worker_id}] [{profile_name}] Processing batch of {len(url_batch)} URLs (lines {start_idx + 1}-{end_idx} from source).") + + # 3. Prepare files on the host + temp_task_dir_host = tempfile.mkdtemp(prefix=f"docker-task-{worker_id}-", dir=host_mount_path) + task_dir_name = os.path.basename(temp_task_dir_host) + task_dir_container = os.path.join(container_mount_path, task_dir_name) + + # Set XDG_CONFIG_HOME for yt-dlp to find the config automatically + environment['XDG_CONFIG_HOME'] = task_dir_container + + # Write batch file + temp_batch_file_host = os.path.join(temp_task_dir_host, 'batch.txt') + with open(temp_batch_file_host, 'w', encoding='utf-8') as f: + f.write('\n'.join(url_batch)) + + # Write yt-dlp config file + base_config_content = "" + base_config_file = direct_policy.get('ytdlp_config_file') + if base_config_file: + # Try path as-is first, then relative to project root. 
+ config_path_to_read = Path(base_config_file) + if not config_path_to_read.exists(): + config_path_to_read = Path(sp_utils._PROJECT_ROOT) / base_config_file + + if config_path_to_read.exists(): + try: + with open(config_path_to_read, 'r', encoding='utf-8') as f: + base_config_content = f.read() + logger.info(f"[Worker {worker_id}] [{profile_name}] Loaded base config from '{config_path_to_read}'") + except IOError as e: + logger.error(f"[Worker {worker_id}] Could not read ytdlp_config_file '{config_path_to_read}': {e}") + else: + logger.error(f"[Worker {worker_id}] Could not find ytdlp_config_file: '{base_config_file}'") + + config_overrides = direct_policy.get('ytdlp_config_overrides', {}).copy() + + if direct_policy.get('use_cookies') and host_cache_path and container_cache_path: + try: + cookie_file_host = os.path.join(profile_cache_dir_host, 'cookies.txt') + # Ensure the file exists and has the Netscape header if it's empty. + if not os.path.exists(cookie_file_host) or os.path.getsize(cookie_file_host) == 0: + with open(cookie_file_host, 'w', encoding='utf-8') as f: + f.write("# Netscape HTTP Cookie File\n") + logger.info(f"[{profile_name}] Created/initialized cookie file with header: {cookie_file_host}") + + cookie_file_container = os.path.join(container_cache_path, 'cookies.txt') + config_overrides['cookies'] = cookie_file_container + logger.info(f"[{profile_name}] Using persistent cookie jar: {cookie_file_host}") + except (IOError, OSError) as e: + logger.error(f"[Worker {worker_id}] Could not create cookie file in '{profile_cache_dir_host}': {e}") + + # Inject per-task values into overrides + config_overrides['proxy'] = proxy_url + config_overrides['batch-file'] = os.path.join(task_dir_container, 'batch.txt') + # The output template should not include the .info.json extension, as + # yt-dlp adds it automatically when --write-info-json is used. + config_overrides['output'] = os.path.join(task_dir_container, '%(id)s') + if user_agent: + config_overrides['user-agent'] = user_agent + + overrides_content = sp_utils._config_dict_to_flags_file_content(config_overrides) + raw_args_from_policy = direct_policy.get('ytdlp_raw_args', []) + + # --- Inject visitor_id into raw args if available --- + if visitor_id: + # Start with a copy of the raw args from policy + new_raw_args = list(raw_args_from_policy) + + # --- Handle youtube extractor args --- + youtube_arg_index = -1 + original_youtube_value = None + for i, arg in enumerate(new_raw_args): + if arg.startswith('--extractor-args') and 'youtube:' in arg: + youtube_arg_index = i + try: + parts = shlex.split(arg) + if len(parts) == 2 and parts[1].startswith('youtube:'): + original_youtube_value = parts[1] + except ValueError: + logger.warning(f"Could not parse extractor-arg, will not modify: {arg}") + break # Found it, stop searching + + if youtube_arg_index != -1 and original_youtube_value: + # Modify existing youtube arg + new_value = f'{original_youtube_value.rstrip()};visitor_data={visitor_id}' + if 'skip=' in new_value: + new_value = re.sub(r'skip=([^;\'"]*)', r'skip=\1,webpage,configs', new_value) + else: + new_value += ';skip=webpage,configs' + new_raw_args[youtube_arg_index] = f'--extractor-args "{new_value}"' + else: + # Add new youtube arg + logger.warning(f"[{profile_name}] No existing '--extractor-args youtube:...' found. 
Adding a new one for visitor_id.") + new_raw_args.append(f'--extractor-args "youtube:visitor_data={visitor_id};skip=webpage,configs"') + + # --- Handle youtubetab extractor args --- + youtubetab_arg_index = -1 + for i, arg in enumerate(new_raw_args): + if arg.startswith('--extractor-args') and 'youtubetab:' in arg: + youtubetab_arg_index = i + break + + # The request is to set/replace this argument + new_youtubetab_arg = '--extractor-args "youtubetab:skip=webpage"' + if youtubetab_arg_index != -1: + # Replace existing + new_raw_args[youtubetab_arg_index] = new_youtubetab_arg + else: + # Add new + new_raw_args.append(new_youtubetab_arg) + + raw_args_from_policy = new_raw_args + # --- End visitor_id injection --- + + raw_args_content = '\n'.join(raw_args_from_policy) + + config_content = f"{base_config_content.strip()}\n\n# --- Overrides from policy ---\n{overrides_content}" + if raw_args_content: + config_content += f"\n\n# --- Raw args from policy ---\n{raw_args_content}" + + logger.info(f"[Worker {worker_id}] [{profile_name}] Generated yt-dlp config file content:\n---config---\n{config_content}\n------------") + + # Create the directory structure yt-dlp expects inside the temp task dir + ytdlp_config_dir_host = os.path.join(temp_task_dir_host, 'yt-dlp') + os.makedirs(ytdlp_config_dir_host, exist_ok=True) + temp_config_file_host = os.path.join(ytdlp_config_dir_host, 'config') + with open(temp_config_file_host, 'w', encoding='utf-8') as f: + f.write(config_content) + + # 4. Construct and run the 'docker run' command + volumes = { + host_mount_path: { + 'bind': container_mount_path, + 'mode': 'rw' + } + } + if host_cache_path and container_cache_path: + profile_cache_dir_host = os.path.join(host_cache_path, profile_name) + os.makedirs(profile_cache_dir_host, exist_ok=True) + volumes[profile_cache_dir_host] = { + 'bind': container_cache_path, + 'mode': 'rw' + } + + # The command tells yt-dlp where to find the config file we created. + # We still set XDG_CONFIG_HOME for any other config it might look for. 
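+ # The config written to '<temp_task_dir_host>/yt-dlp/config' on the host is visible inside the container + # at '<task_dir_container>/yt-dlp/config' through the bind mount defined in the volumes mapping above.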
+ command = ['yt-dlp', '--config-locations', os.path.join(task_dir_container, 'yt-dlp/config')] + logger.info(f"[Worker {worker_id}] [{profile_name}] Running docker command: {' '.join(shlex.quote(s) for s in command)}") + + # For logging purposes, construct the full equivalent command line with host paths + log_config_overrides_for_host = config_overrides.copy() + log_config_overrides_for_host['batch-file'] = temp_batch_file_host + log_config_overrides_for_host['output'] = os.path.join(temp_task_dir_host, '%(id)s') + if 'cookies' in log_config_overrides_for_host and host_cache_path: + log_config_overrides_for_host['cookies'] = os.path.join(profile_cache_dir_host, 'cookies.txt') + + log_command_override = ['yt-dlp'] + if base_config_content: + log_command_override.extend(sp_utils._parse_config_file_to_cli_args(base_config_content)) + log_command_override.extend(sp_utils._config_dict_to_cli_flags(log_config_overrides_for_host)) + for raw_arg in raw_args_from_policy: + log_command_override.extend(shlex.split(raw_arg)) + + # --- Live log parsing and activity recording --- + live_failure_count = 0 + live_tolerated_count = 0 + activity_lock = threading.Lock() + + # Get error patterns from policy for live parsing + tolerated_error_patterns = direct_policy.get('tolerated_error_patterns', []) + fatal_error_patterns = direct_policy.get('fatal_error_patterns', []) + + def log_parser_callback(line): + nonlocal live_success_count, live_failure_count, live_tolerated_count + + # --- Visitor ID Extraction --- + if direct_policy.get('track_visitor_id') and profile_cache_dir_host: + # e.g., [debug] [youtube] [pot:cache] TRACE: Retrieved cache spec PoTokenCacheSpec(key_bindings={'t': 'webpo', 'cb': '...', 'cbt': 'visitor_id', ... + match = re.search(r"'cb': '([^']*)', 'cbt': 'visitor_id'", line) + if match: + new_visitor_id = match.group(1) + logger.info(f"[Worker {worker_id}] [{profile_name}] Detected new Visitor ID: {new_visitor_id}") + try: + visitor_id_file = os.path.join(profile_cache_dir_host, 'visitor_id.txt') + with open(visitor_id_file, 'w', encoding='utf-8') as f: + f.write(new_visitor_id) + logger.info(f"[{profile_name}] Saved new Visitor ID to cache.") + except IOError as e: + logger.error(f"[{profile_name}] Failed to save new Visitor ID to cache: {e}") + + # Success is the highest priority check + if '[info] Writing video metadata as JSON to:' in line: + with activity_lock: + live_success_count += 1 + logger.info(f"[Worker {worker_id}] [{profile_name}] Live success #{live_success_count} detected from log.") + profile_manager_instance.record_activity(profile_name, 'success') + + # --- Immediate post-processing --- + try: + path_match = re.search(r"Writing video metadata as JSON to: '?([^']+)'?$", line) + if not path_match: + path_match = re.search(r"Writing video metadata as JSON to: (.*)$", line) + + if path_match: + container_file_path = path_match.group(1).strip() + + if container_file_path.startswith(container_mount_path): + relative_path = os.path.relpath(container_file_path, container_mount_path) + host_file_path = os.path.join(host_mount_path, relative_path) + + # The file might not exist immediately. 
+ for _ in range(5): # Retry for up to 0.5s + if os.path.exists(host_file_path): + break + time.sleep(0.1) + + if os.path.exists(host_file_path): + _post_process_and_move_info_json( + Path(host_file_path), profile_name, proxy_url, policy, worker_id, + profile_manager_instance=profile_manager_instance + ) + else: + logger.warning(f"File from log not found on host for immediate processing: {host_file_path}") + except Exception as e: + logger.error(f"Error during immediate post-processing from log line: {e}") + # --- End immediate post-processing --- + return False + + # Check for fatal patterns (e.g., bot detection) which might not start with ERROR: + for pattern in fatal_error_patterns: + if re.search(pattern, line, re.IGNORECASE): + with activity_lock: + live_failure_count += 1 + logger.error(f"[Worker {worker_id}] [{profile_name}] Live FATAL error #{live_failure_count} detected from log: {line}") + profile_manager_instance.record_activity(profile_name, 'failure') + if direct_policy.get('ban_on_fatal_error_in_batch'): + logger.warning(f"Banning profile '{profile_name}' immediately due to fatal error to stop container.") + profile_manager_instance.update_profile_state(profile_name, 'BANNED', 'Fatal error during batch') + return True # Signal to stop container + return False # Do not stop if ban_on_fatal_error_in_batch is false + + # Only process lines that contain ERROR: for tolerated/generic failures + if 'ERROR:' not in line: + return False + + # Check if it's a tolerated error + for pattern in tolerated_error_patterns: + if re.search(pattern, line, re.IGNORECASE): + with activity_lock: + live_tolerated_count += 1 + logger.warning(f"[Worker {worker_id}] [{profile_name}] Live TOLERATED error #{live_tolerated_count} detected from log: {line}") + profile_manager_instance.record_activity(profile_name, 'tolerated_error') + return False + + # If it's an ERROR: line and not tolerated, it's a failure + with activity_lock: + live_failure_count += 1 + logger.warning(f"[Worker {worker_id}] [{profile_name}] Live failure #{live_failure_count} detected from log: {line}") + profile_manager_instance.record_activity(profile_name, 'failure') + + return False + + retcode, stdout, stderr, stop_reason = run_docker_container( + image_name=image_name, + command=command, + volumes=volumes, + stream_prefix=f"[Worker {worker_id} | docker-ytdlp] ", + network_name=network_name, + log_callback=log_parser_callback, + profile_manager=profile_manager_instance, + profile_name=profile_name, + environment=environment, + log_command_override=log_command_override + ) + + # 5. Post-process results + logger.info(f"[Worker {worker_id}] [{profile_name}] Docker container finished. Post-processing results...") + full_output = f"{stdout}\n{stderr}" + is_bot_error = "Sign in to confirm you're not a bot" in full_output + if is_bot_error: + logger.warning(f"[Worker {worker_id}] [{profile_name}] Bot detection occurred during batch. Marking as failure.") + + # Fallback post-processing for any files missed by the live parser. + # The live parser moves files, so this loop should only find leftovers. 
+ processed_files = list(Path(temp_task_dir_host).glob('*.json')) + if processed_files: + logger.info(f"[Worker {worker_id}] Found {len(processed_files)} leftover file(s) to process after live parsing.") + for temp_path in processed_files: + _post_process_and_move_info_json(temp_path, profile_name, proxy_url, policy, worker_id, profile_manager_instance=profile_manager_instance) + + # A batch is considered an overall success for logging if it had no fatal errors. + # The per-URL activity has already been recorded live. + # We use live_success_count for a more accurate success metric. + success = (live_success_count > 0 and not is_bot_error) + + if not success: + reason = "bot detection occurred" if is_bot_error else f"0 successful files created out of {len(url_batch)}" + logger.warning(f"[Worker {worker_id}] [{profile_name}] Marking batch as FAILED. Reason: {reason}.") + + # Record batch stats for overall orchestrator health + state_manager.record_batch_result(success, len(url_batch), profile_name=profile_name) + + # If live parsing didn't catch all activity (e.g., yt-dlp exits before printing logs), + # we reconcile the counts here based on files created. + with activity_lock: + processed_count = live_success_count + live_failure_count + live_tolerated_count + + # Failures are harder to reconcile from file counts alone. + # We assume live parsing caught them. The total number of failures is + # (batch_size - files_created), but we don't know if they were already recorded. + # The current live parsing is the most reliable source for failures. + unaccounted_failures = len(url_batch) - processed_count + if unaccounted_failures > 0: + logger.info(f"[Worker {worker_id}] [{profile_name}] Reconciling activity: {unaccounted_failures} unaccounted failure(s).") + for _ in range(unaccounted_failures): + profile_manager_instance.record_activity(profile_name, 'failure') + + + if stop_reason: + logger.warning(f"[Worker {worker_id}] [{profile_name}] Batch aborted due to: {stop_reason}. Adjusting URL index.") + # The batch was from start_idx to end_idx. + # We processed `processed_count` URLs. + # The next batch should start from `start_idx + processed_count`. + # `get_next_url_batch` updated `last_url_index` to `end_idx`. We need to rewind it. + with activity_lock: + processed_count = live_success_count + live_failure_count + live_tolerated_count + next_start_index = start_idx + processed_count + state_manager.update_last_url_index(next_start_index, force=True) + logger.info(f"[Worker {worker_id}] Rewound URL index to {next_start_index} for next worker.") + + event_details = f"Docker batch completed. Exit: {retcode}. Files created: {live_success_count}/{len(url_batch)}. (Live successes: {live_success_count}, Live failures: {live_failure_count}, Live tolerated: {live_tolerated_count})" + if not success and stderr: + event_details += f" Stderr: {stderr.strip().splitlines()[-1] if stderr.strip() else 'N/A'}" + if stop_reason: + event_details += f" Aborted: {stop_reason}." + + event = { 'type': 'fetch_batch', 'profile': profile_name, 'proxy_url': proxy_url, 'success': success, 'details': event_details, 'video_count': len(url_batch) } + state_manager.log_event(event) + + logger.info(f"[Worker {worker_id}] [{profile_name}] Batch processing complete. 
Worker will now unlock profile and attempt next batch.") + + except Exception as e: + logger.error(f"[Worker {worker_id}] Unexpected error in worker loop: {e}", exc_info=True) + if locked_profile: + profile_manager_instance.record_activity(locked_profile['name'], 'failure') + finally: + if locked_profile and batch_started: + # --- Reconcile pending downloads counter --- + # This is in the finally block to guarantee it runs even if post-processing fails. + adjustment = live_success_count - url_batch_len + if adjustment != 0: + logger.warning(f"[Worker {worker_id}] [{profile_name}] Reconciling pending downloads. Batch created {live_success_count}/{url_batch_len} files. Adjusting by {adjustment}.") + profile_manager_instance.increment_pending_downloads(locked_profile['name'], adjustment) + + if locked_profile: + last_used_profile_name = locked_profile['name'] + profile_manager_instance.unlock_profile(locked_profile['name'], owner=owner_id) + if temp_task_dir_host and os.path.exists(temp_task_dir_host): + # If shutdown is requested, a batch might have been interrupted after files were + # created but before they were post-processed. We preserve the temp directory + # to allow for manual recovery of the info.json files. + if state_manager.shutdown_event.is_set() and any(Path(temp_task_dir_host).iterdir()): + logger.warning(f"Shutdown requested. Preserving temporary task directory for manual recovery: {temp_task_dir_host}") + else: + shutil.rmtree(temp_task_dir_host) + + logger.info(f"[Worker {worker_id}] Worker loop finished.") + return [] + + +def run_direct_docker_download_worker(worker_id, policy, state_manager, args, profile_manager_instance, running_processes, process_lock): + """A worker for the 'direct_docker_cli' orchestration mode with `mode: download_only`.""" + owner_id = f"direct-docker-dl-worker-{worker_id}" + settings = policy.get('settings', {}) + exec_control = policy.get('execution_control', {}) + d_policy = policy.get('download_policy', {}) + direct_policy = policy.get('direct_docker_cli_policy', {}) + + profile_prefix = d_policy.get('profile_prefix') + if not profile_prefix: + logger.error(f"[Worker {worker_id}] Direct docker download mode requires 'download_policy.profile_prefix'. Worker exiting.") + return [] + + # --- Docker specific config --- + image_name = direct_policy.get('docker_image_name') + host_mount_path = direct_policy.get('docker_host_mount_path') + container_mount_path = direct_policy.get('docker_container_mount_path') + host_download_path = direct_policy.get('docker_host_download_path') + container_download_path = direct_policy.get('docker_container_download_path') + network_name = direct_policy.get('docker_network_name') + + if not all([image_name, host_mount_path, container_mount_path, host_download_path, container_download_path]): + logger.error(f"[Worker {worker_id}] Direct docker download mode requires all docker_* keys in 'direct_docker_cli_policy'. Worker exiting.") + return [] + + try: + os.makedirs(host_mount_path, exist_ok=True) + os.makedirs(host_download_path, exist_ok=True) + except OSError as e: + logger.error(f"[Worker {worker_id}] Could not create required host directories: {e}. 
Worker exiting.") + return [] + + no_task_streak = 0 + last_used_profile_name = None + while not state_manager.shutdown_event.is_set(): + locked_profile = None + claimed_task_path_host = None + temp_config_dir_host = None + was_banned_by_parser = False + try: + if no_task_streak > 0: + polling_interval = exec_control.get('worker_polling_interval_seconds', 1) + # --- Add diagnostic logging --- + all_profiles_in_pool = profile_manager_instance.list_profiles() + profiles_in_prefix = [p for p in all_profiles_in_pool if p['name'].startswith(profile_prefix)] + if profiles_in_prefix: + state_counts = collections.Counter(p['state'] for p in profiles_in_prefix) + states_summary = ', '.join(f"{count} {state}" for state, count in sorted(state_counts.items())) + logger.info(f"[Worker {worker_id}] No tasks found or profiles available. Pool status ({profile_prefix}*): {states_summary}. Pausing for {polling_interval}s. (Streak: {no_task_streak})") + else: + logger.info(f"[Worker {worker_id}] No tasks found or profiles available. No profiles found with prefix '{profile_prefix}'. Pausing for {polling_interval}s. (Streak: {no_task_streak})") + # --- End diagnostic logging --- + time.sleep(polling_interval) + if state_manager.shutdown_event.is_set(): continue + + # 1. Find a task and lock its associated profile + locked_profile, claimed_task_path_host = find_task_and_lock_profile( + profile_manager_instance, owner_id, profile_prefix, policy, worker_id + ) + + if not locked_profile: + no_task_streak += 1 + # The main loop will pause if the streak continues. + continue + + profile_name = locked_profile['name'] + # We have a task and a lock. + + # User-Agent is not used for download simulation. + user_agent = None + + if claimed_task_path_host: + no_task_streak = 0 + auth_profile_name, auth_env = None, None + info_data = None + + # --- Read info.json content and metadata first --- + try: + with open(claimed_task_path_host, 'r', encoding='utf-8') as f: + info_data = json.load(f) + # This is critical for decrementing the counter in the finally block + metadata = info_data.get('_ytops_metadata', {}) + auth_profile_name = metadata.get('profile_name') + auth_env = metadata.get('auth_env') + except (IOError, json.JSONDecodeError) as e: + logger.error(f"CRITICAL: Could not read or parse task file '{claimed_task_path_host.name}': {e}. 
This task will be skipped, but the pending downloads counter CANNOT be decremented.") + continue # Skip to finally block to unlock profile + + # --- Check for URL expiration before running Docker --- + if d_policy.get('check_url_expiration', True): + # Heuristic: check the first available format URL + first_format = next((f for f in info_data.get('formats', []) if 'url' in f), None) + if first_format: + url_to_check = first_format['url'] + time_shift_minutes = d_policy.get('expire_time_shift_minutes', 0) + status, time_left_seconds = sp_utils.check_url_expiry(url_to_check, time_shift_minutes) + + logger.debug(f"[Worker {worker_id}] [{profile_name}] URL expiration check for task '{claimed_task_path_host.name}': status={status}, time_left={time_left_seconds:.0f}s") + + if status == 'expired': + details = "Download URL is expired" + if time_shift_minutes > 0 and time_left_seconds > 0: + logger.warning(f"[Worker {worker_id}] [{profile_name}] Skipping task '{claimed_task_path_host.name}' because its URL will expire in {time_left_seconds/60:.1f}m (within {time_shift_minutes}m time-shift).") + details = f"URL will expire within {time_shift_minutes}m time-shift" + else: + logger.warning(f"[Worker {worker_id}] [{profile_name}] Skipping task '{claimed_task_path_host.name}' because its URL is expired.") + + profile_manager_instance.record_activity(profile_name, 'tolerated_error') + + event = { + 'type': 'direct_docker_download', 'profile': profile_name, + 'proxy_url': locked_profile['proxy'], 'success': False, + 'error_type': 'Skipped (Expired URL)', 'details': details, + 'is_tolerated_error': True + } + state_manager.log_event(event) + + try: + base_path_str = str(claimed_task_path_host).rsplit('.LOCKED.', 1)[0] + processed_path = Path(f"{base_path_str}.processed") + claimed_task_path_host.rename(processed_path) + logger.debug(f"Renamed expired task file to '{processed_path.name}'.") + except (OSError, IndexError) as e: + logger.error(f"Failed to rename expired task file '{claimed_task_path_host}': {e}") + + continue # Skip to the finally block + + # The path to the task file inside the container needs to be relative to the host mount root. + # We must make the task path absolute first to correctly calculate the relative path from the absolute mount path. + relative_task_path = os.path.relpath(os.path.abspath(claimed_task_path_host), host_mount_path) + task_path_container = os.path.join(container_mount_path, relative_task_path) + + # 3. 
Prepare config file on host in a temporary directory + temp_config_dir_host = tempfile.mkdtemp(prefix=f"docker-dl-config-{worker_id}-", dir=host_mount_path) + config_dir_name = os.path.basename(temp_config_dir_host) + config_dir_container = os.path.join(container_mount_path, config_dir_name) + + environment = {'XDG_CONFIG_HOME': config_dir_container} + + base_config_content = "" + base_config_file = direct_policy.get('ytdlp_config_file') + if base_config_file: + config_path_to_read = Path(base_config_file) + if not config_path_to_read.exists(): + config_path_to_read = Path(sp_utils._PROJECT_ROOT) / base_config_file + if config_path_to_read.exists(): + try: + with open(config_path_to_read, 'r', encoding='utf-8') as base_f: + base_config_content = base_f.read() + except IOError as e: + logger.error(f"[Worker {worker_id}] Could not read ytdlp_config_file '{config_path_to_read}': {e}") + + config_overrides = direct_policy.get('ytdlp_config_overrides', {}).copy() + config_overrides['proxy'] = locked_profile['proxy'] + config_overrides['load-info-json'] = task_path_container + config_overrides['output'] = os.path.join(container_download_path, '%(id)s.f%(format_id)s.%(ext)s') + + # Prevent yt-dlp from using a cache directory. + config_overrides['no-cache-dir'] = True + + overrides_content = sp_utils._config_dict_to_flags_file_content(config_overrides) + raw_args_from_policy = direct_policy.get('ytdlp_raw_args', []) + raw_args_content = '\n'.join(raw_args_from_policy) + + config_content = f"{base_config_content.strip()}\n\n# --- Overrides from policy ---\n{overrides_content}" + if raw_args_content: + config_content += f"\n\n# --- Raw args from policy ---\n{raw_args_content}" + + logger.info(f"[Worker {worker_id}] [{profile_name}] Generated yt-dlp config:\n---config---\n{config_content}\n------------") + + ytdlp_config_dir_host = os.path.join(temp_config_dir_host, 'yt-dlp') + os.makedirs(ytdlp_config_dir_host, exist_ok=True) + temp_config_file_host = os.path.join(ytdlp_config_dir_host, 'config') + with open(temp_config_file_host, 'w', encoding='utf-8') as f: + f.write(config_content) + + # 4. Construct and run docker run command + volumes = { + os.path.abspath(host_mount_path): {'bind': container_mount_path, 'mode': 'ro'}, + os.path.abspath(host_download_path): {'bind': container_download_path, 'mode': 'rw'} + } + # The command tells yt-dlp where to find the config file we created. + # We still set XDG_CONFIG_HOME for any other config it might look for. 
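+ # Illustrative sketch of the generated config file (hypothetical values; the real
+ # flags come from the base config, the policy overrides and ytdlp_raw_args above):
+ #   --proxy socks5://user:pass@203.0.113.7:1080
+ #   --load-info-json <container_mount_path>/<claimed task file>
+ #   --output <container_download_path>/%(id)s.f%(format_id)s.%(ext)s
+ #   --no-cache-dir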
+ command = ['yt-dlp', '--config-locations', os.path.join(config_dir_container, 'yt-dlp/config')] + logger.info(f"[Worker {worker_id}] [{profile_name}] Running docker command: {' '.join(shlex.quote(s) for s in command)}") + + # For logging purposes, construct the full equivalent command line with host paths + log_config_overrides_for_host = config_overrides.copy() + log_config_overrides_for_host['load-info-json'] = str(claimed_task_path_host) + log_config_overrides_for_host['output'] = os.path.join(host_download_path, '%(id)s.f%(format_id)s.%(ext)s') + + log_command_override = ['yt-dlp'] + if base_config_content: + log_command_override.extend(sp_utils._parse_config_file_to_cli_args(base_config_content)) + log_command_override.extend(sp_utils._config_dict_to_cli_flags(log_config_overrides_for_host)) + raw_args_from_policy = direct_policy.get('ytdlp_raw_args', []) + for raw_arg in raw_args_from_policy: + log_command_override.extend(shlex.split(raw_arg)) + + # --- Live log parsing and activity recording --- + live_success_count = 0 + live_failure_count = 0 + live_tolerated_count = 0 + activity_lock = threading.Lock() + + tolerated_error_patterns = direct_policy.get('tolerated_error_patterns', []) + fatal_error_patterns = direct_policy.get('fatal_error_patterns', []) + + def log_parser_callback(line): + nonlocal live_success_count, live_failure_count, live_tolerated_count, was_banned_by_parser + + # Success is a high-priority check. Only record one success per task. + if '[download] 100% of' in line or 'has already been downloaded' in line: + with activity_lock: + # Only count one success per task + if live_success_count == 0: + live_success_count += 1 + logger.info(f"[Worker {worker_id}] [{profile_name}] Live download success detected from log.") + profile_manager_instance.record_activity(profile_name, 'download') + return False + + # Check for fatal patterns + for pattern in fatal_error_patterns: + if re.search(pattern, line, re.IGNORECASE): + with activity_lock: + live_failure_count += 1 + logger.error(f"[Worker {worker_id}] [{profile_name}] Live FATAL download error #{live_failure_count} detected from log: {line}") + profile_manager_instance.record_activity(profile_name, 'download_error') + if direct_policy.get('ban_on_fatal_error_in_batch'): + logger.warning(f"Banning profile '{profile_name}' immediately due to fatal download error to stop container.") + profile_manager_instance.update_profile_state(profile_name, 'BANNED', 'Fatal error during download') + was_banned_by_parser = True + return True # Signal to stop container + return False # Do not stop if ban_on_fatal_error_in_batch is false + + # Only process lines that contain ERROR: for tolerated/generic failures + if 'ERROR:' not in line: + return False + + # Check if it's a tolerated error + for pattern in tolerated_error_patterns: + if re.search(pattern, line, re.IGNORECASE): + with activity_lock: + live_tolerated_count += 1 + logger.warning(f"[Worker {worker_id}] [{profile_name}] Live TOLERATED download error #{live_tolerated_count} detected from log: {line}") + profile_manager_instance.record_activity(profile_name, 'tolerated_error') + return False + + # If it's an ERROR: line and not tolerated, it's a failure + with activity_lock: + live_failure_count += 1 + logger.warning(f"[Worker {worker_id}] [{profile_name}] Live download failure #{live_failure_count} detected from log: {line}") + profile_manager_instance.record_activity(profile_name, 'download_error') + + return False + + retcode, stdout, stderr, stop_reason = 
run_docker_container( + image_name=image_name, + command=command, + volumes=volumes, + stream_prefix=f"[Worker {worker_id} | docker-ytdlp] ", + network_name=network_name, + log_callback=log_parser_callback, + profile_manager=profile_manager_instance, + profile_name=profile_name, + environment=environment, + log_command_override=log_command_override + ) + + # 5. Post-process and record activity + full_output = f"{stdout}\n{stderr}" + is_bot_error = "Sign in to confirm you're not a bot" in full_output + if is_bot_error: + logger.warning(f"[Worker {worker_id}] [{profile_name}] Bot detection occurred during download. Marking as failure.") + + # --- Final Outcome Determination --- + # Activity is now recorded live by the log parser. This block just determines + # the overall success/failure for logging and event reporting. + success = False + final_outcome = "unknown" + with activity_lock: + if live_success_count > 0: + success = True + final_outcome = "download" + elif live_failure_count > 0 or is_bot_error: + final_outcome = "download_error" + elif live_tolerated_count > 0: + final_outcome = "tolerated_error" + elif retcode == 0: + # Fallback if no logs were matched but exit was clean. + success = True + final_outcome = "download" + logger.warning(f"[Worker {worker_id}] [{profile_name}] No specific success/error log line matched, but exit code is 0. Assuming success, but this may indicate a parsing issue.") + # We record a success here as a fallback, in case the log parser missed it. + profile_manager_instance.record_activity(profile_name, 'download') + else: + final_outcome = "download_error" + logger.warning(f"[Worker {worker_id}] [{profile_name}] No specific error log line matched, but exit code was {retcode}. Recording a generic download_error.") + profile_manager_instance.record_activity(profile_name, 'download_error') + + # --- Airflow Directory Logic --- + if success and d_policy.get('output_to_airflow_ready_dir'): + # Find the downloaded file path from yt-dlp's output + downloaded_filename = None + # Order of checks is important: Merger -> VideoConvertor -> Destination + merge_match = re.search(r'\[Merger\] Merging formats into "([^"]+)"', stdout) + if merge_match: + downloaded_filename = os.path.basename(merge_match.group(1)) + else: + convertor_match = re.search(r'\[VideoConvertor\].*?; Destination: (.*)', stdout) + if convertor_match: + downloaded_filename = os.path.basename(convertor_match.group(1).strip()) + else: + dest_match = re.search(r'\[download\] Destination: (.*)', stdout) + if dest_match: + downloaded_filename = os.path.basename(dest_match.group(1).strip()) + + if downloaded_filename: + try: + # Get video_id from the info.json + with open(claimed_task_path_host, 'r', encoding='utf-8') as f: + info_data = json.load(f) + video_id = info_data.get('id') + + if not video_id: + logger.error(f"[{profile_name}] Could not find video ID in '{claimed_task_path_host.name}' for moving file.") + else: + now = datetime.now() + rounded_minute = (now.minute // 10) * 10 + timestamp_str = now.strftime('%Y%m%dT%H') + f"{rounded_minute:02d}" + + base_path = d_policy.get('airflow_ready_dir_base_path', 'downloadfiles/videos/ready') + if not os.path.isabs(base_path): + base_path = os.path.join(sp_utils._PROJECT_ROOT, base_path) + final_dir_base = os.path.join(base_path, timestamp_str) + final_dir_path = os.path.join(final_dir_base, video_id) + + os.makedirs(final_dir_path, exist_ok=True) + + downloaded_file_host_path = os.path.join(host_download_path, downloaded_filename) + if 
os.path.exists(downloaded_file_host_path): + shutil.move(downloaded_file_host_path, final_dir_path) + logger.info(f"[{profile_name}] Moved media file to {final_dir_path}") + + new_info_json_name = f"info_{video_id}.json" + dest_info_json_path = os.path.join(final_dir_path, new_info_json_name) + shutil.copy(claimed_task_path_host, dest_info_json_path) + logger.info(f"[{profile_name}] Copied info.json to {dest_info_json_path}") + except Exception as e: + logger.error(f"[{profile_name}] Failed to move downloaded file to Airflow ready directory: {e}") + else: + logger.warning(f"[{profile_name}] Download succeeded, but could not parse final filename from output to move to Airflow dir.") + + event_details = f"Docker download finished. Exit: {retcode}. Final Outcome: {final_outcome}. (Live successes: {live_success_count}, Live failures: {live_failure_count}, Live tolerated: {live_tolerated_count})" + if not success and stderr: + event_details += f" Stderr: {stderr.strip().splitlines()[-1] if stderr.strip() else 'N/A'}" + if stop_reason: + event_details += f" Aborted: {stop_reason}." + + event = { 'type': 'direct_docker_download', 'profile': profile_name, 'proxy_url': locked_profile['proxy'], 'success': success, 'details': event_details } + state_manager.log_event(event) + + logger.info(f"[Worker {worker_id}] [{profile_name}] Task processing complete. Worker will now unlock profile and attempt next task.") + + # 6. Clean up task file by renaming to .processed + try: + # The claimed_task_path_host has a .LOCKED suffix, remove it before adding .processed + base_path_str = str(claimed_task_path_host).rsplit('.LOCKED.', 1)[0] + processed_path = Path(f"{base_path_str}.processed") + claimed_task_path_host.rename(processed_path) + logger.debug(f"[{sp_utils.get_display_name(claimed_task_path_host)}] Renamed processed task file to '{processed_path.name}'.") + except (OSError, IndexError) as e: + logger.error(f"Failed to rename processed task file '{claimed_task_path_host}': {e}") + # After this point, claimed_task_path_host is no longer valid. + # The metadata has already been read into auth_profile_name and auth_env. + else: + # This case should not be reached with the new task-first locking logic. + logger.warning(f"[Worker {worker_id}] Inconsistent state: locked profile '{profile_name}' but no task was claimed. Unlocking and continuing.") + + except Exception as e: + logger.error(f"[Worker {worker_id}] An unexpected error occurred in the worker loop: {e}", exc_info=True) + if locked_profile: + profile_manager_instance.record_activity(locked_profile['name'], 'failure') # Generic failure + time.sleep(5) + finally: + if locked_profile: + if claimed_task_path_host: + # The auth_profile_name and auth_env variables were populated in the `try` block + # before the task file was renamed or deleted. + if auth_profile_name and auth_env: + auth_manager = _get_auth_manager(profile_manager_instance, auth_env) + if auth_manager: + auth_manager.decrement_pending_downloads(auth_profile_name) + else: + logger.error(f"Could not get auth profile manager for env '{auth_env}'. Pending downloads counter will not be decremented.") + else: + logger.warning(f"Could not find auth profile name and/or auth_env in info.json metadata. Pending downloads counter will not be decremented. (Profile: {auth_profile_name}, Env: {auth_env})") + + if was_banned_by_parser: + logger.info(f"[Worker {worker_id}] Profile '{locked_profile['name']}' was already banned by the log parser. 
Skipping unlock/cooldown.") + else: + last_used_profile_name = locked_profile['name'] + cooldown = None + # Only apply cooldown if a task was actually claimed and processed. + if claimed_task_path_host: + # Enforcer is the only point where we configure to apply different policies, + # since we might restart enforcer, but won't restart stress-policy working on auth and downloads simultaneously. + # This is like applying a policy across multiple workers/machines without needing to restart each of them. + # DESIGN: The cooldown duration is not configured in the worker's policy. + # Instead, it is read from a central Redis key. This key is set by the + # policy-enforcer, making the enforcer the single source of truth for + # this policy. This allows changing the cooldown behavior without + # restarting the workers. + cooldown_source_value = profile_manager_instance.get_config('unlock_cooldown_seconds') + source_description = "Redis config" + + if cooldown_source_value is None: + cooldown_source_value = d_policy.get('default_unlock_cooldown_seconds') + source_description = "local policy" + + if cooldown_source_value is not None: + try: + # If from Redis, it's a string that needs parsing. + # If from local policy, it's already an int or list. + val = cooldown_source_value + if isinstance(val, str): + val = json.loads(val) + + if isinstance(val, list) and len(val) == 2 and val[0] < val[1]: + cooldown = random.randint(val[0], val[1]) + elif isinstance(val, int): + cooldown = val + + if cooldown is not None: + logger.debug(f"Determined cooldown from {source_description}: {cooldown_source_value}") + + except (json.JSONDecodeError, TypeError): + if isinstance(cooldown_source_value, str) and cooldown_source_value.isdigit(): + cooldown = int(cooldown_source_value) + logger.debug(f"Determined cooldown from {source_description}: {cooldown_source_value}") + + if cooldown: + logger.info(f"[Worker {worker_id}] Putting profile '{locked_profile['name']}' into COOLDOWN for {cooldown}s.") + + profile_manager_instance.unlock_profile( + locked_profile['name'], + owner=owner_id, + rest_for_seconds=cooldown + ) + if claimed_task_path_host and os.path.exists(claimed_task_path_host): + try: os.remove(claimed_task_path_host) + except OSError: pass + if temp_config_dir_host and os.path.exists(temp_config_dir_host): + try: + shutil.rmtree(temp_config_dir_host) + except OSError: pass + + logger.info(f"[Worker {worker_id}] Worker loop finished.") + return [] + + +def run_direct_download_worker(worker_id, policy, state_manager, args, profile_manager_instance, running_processes, process_lock): + """A persistent worker for the 'direct_download_cli' orchestration mode.""" + owner_id = f"direct-dl-worker-{worker_id}" + settings = policy.get('settings', {}) + exec_control = policy.get('execution_control', {}) + d_policy = policy.get('download_policy', {}) + direct_policy = policy.get('direct_download_cli_policy', {}) + + profile_prefix = d_policy.get('profile_prefix') + if not profile_prefix: + logger.error(f"[Worker {worker_id}] Direct download mode requires 'download_policy.profile_prefix'. Worker exiting.") + return [] + + output_dir = direct_policy.get('output_dir') + if not output_dir: + logger.error(f"[Worker {worker_id}] Direct download mode requires 'direct_download_cli_policy.output_dir'. Worker exiting.") + return [] + + os.makedirs(output_dir, exist_ok=True) + no_task_streak = 0 + + while not state_manager.shutdown_event.is_set(): + locked_profile = None + claimed_task_path = None + try: + # 0. 
If no tasks were found, pause briefly. + if no_task_streak > 0: + polling_interval = exec_control.get('worker_polling_interval_seconds', 1) + # --- Add diagnostic logging --- + all_profiles_in_pool = profile_manager_instance.list_profiles() + profiles_in_prefix = [p for p in all_profiles_in_pool if p['name'].startswith(profile_prefix)] + if profiles_in_prefix: + state_counts = collections.Counter(p['state'] for p in profiles_in_prefix) + states_summary = ', '.join(f"{count} {state}" for state, count in sorted(state_counts.items())) + logger.info(f"[Worker {worker_id}] No tasks found for available profiles. Pool status ({profile_prefix}*): {states_summary}. Pausing for {polling_interval}s. (Streak: {no_task_streak})") + else: + logger.info(f"[Worker {worker_id}] No tasks found for available profiles. No profiles found with prefix '{profile_prefix}'. Pausing for {polling_interval}s. (Streak: {no_task_streak})") + # --- End diagnostic logging --- + time.sleep(polling_interval) + if state_manager.shutdown_event.is_set(): continue + + # 1. Find a task and lock its associated profile + locked_profile, claimed_task_path = find_task_and_lock_profile( + profile_manager_instance, owner_id, profile_prefix, policy, worker_id + ) + + if not locked_profile: + no_task_streak += 1 + # The main loop will pause if the streak continues. + continue + + profile_name = locked_profile['name'] + # We have a task and a lock. + + if claimed_task_path: + no_task_streak = 0 # Reset streak + auth_profile_name, auth_env = None, None + + # --- Read metadata before processing/deleting file --- + try: + with open(claimed_task_path, 'r', encoding='utf-8') as f: + info_data = json.load(f) + metadata = info_data.get('_ytops_metadata', {}) + auth_profile_name = metadata.get('profile_name') + auth_env = metadata.get('auth_env') + except (IOError, json.JSONDecodeError) as e: + logger.error(f"Could not read info.json to get auth profile for decrementing counter: {e}") + + # 3. 
Construct and run the command + ytdlp_cmd_str = direct_policy.get('ytdlp_command') + if not ytdlp_cmd_str: + logger.error(f"[Worker {worker_id}] Direct download mode requires 'direct_download_cli_policy.ytdlp_command'.") + break + + proxy_url = locked_profile['proxy'] + proxy_rename = direct_policy.get('proxy_rename') + if proxy_rename: + rename_rule = proxy_rename.strip("'\"") + if rename_rule.startswith('s/') and rename_rule.count('/') >= 2: + try: + parts = rename_rule.split('/') + proxy_url = re.sub(parts[1], parts[2], proxy_url) + except (re.error, IndexError): + logger.error(f"[Worker {worker_id}] Invalid proxy_rename rule: {proxy_rename}") + + output_template = os.path.join(output_dir, '%(title)s - %(id)s.%(ext)s') + + cmd = shlex.split(ytdlp_cmd_str) + cmd.extend(['--load-info-json', str(claimed_task_path)]) + cmd.extend(['--proxy', proxy_url]) + cmd.extend(['-o', output_template]) + + ytdlp_args = direct_policy.get('ytdlp_args') + if ytdlp_args: + cmd.extend(shlex.split(ytdlp_args)) + + if args.verbose and '--verbose' not in cmd: + cmd.append('--verbose') + + custom_env = direct_policy.get('env_vars', {}).copy() + + # --- PYTHONPATH for custom yt-dlp module --- + ytdlp_module_path = direct_policy.get('ytdlp_module_path') + if ytdlp_module_path: + existing_pythonpath = custom_env.get('PYTHONPATH', os.environ.get('PYTHONPATH', '')) + custom_env['PYTHONPATH'] = f"{ytdlp_module_path}{os.pathsep}{existing_pythonpath}".strip(os.pathsep) + logger.debug(f"[Worker {worker_id}] Using custom PYTHONPATH: {custom_env['PYTHONPATH']}") + + # Pass profile info to the custom yt-dlp process + custom_env['YTDLP_PROFILE_NAME'] = profile_name + custom_env['YTDLP_PROXY_URL'] = locked_profile['proxy'] # Original proxy + env_name = profile_manager_instance.key_prefix.replace('_profile_mgmt_', '') + custom_env['YTDLP_SIM_MODE'] = env_name + + # Create a per-profile cache directory and set XDG_CACHE_HOME + cache_dir_base = direct_policy.get('cache_dir_base', '.cache') + profile_cache_dir = os.path.join(cache_dir_base, profile_name) + try: + os.makedirs(profile_cache_dir, exist_ok=True) + custom_env['XDG_CACHE_HOME'] = profile_cache_dir + except OSError as e: + logger.error(f"[Worker {worker_id}] Failed to create cache directory '{profile_cache_dir}': {e}") + + logger.info(f"[Worker {worker_id}] [{profile_name}] Processing task '{claimed_task_path.name}'...") + if args.dummy: + logger.info(f"========== [Worker {worker_id}] BEGIN DUMMY DIRECT DOWNLOAD ==========") + logger.info(f"[Worker {worker_id}] Profile: {profile_name} | Task: {claimed_task_path.name}") + logger.info(f"[Worker {worker_id}] Would run command: {' '.join(shlex.quote(s) for s in cmd)}") + logger.info(f"[Worker {worker_id}] With environment: {custom_env}") + + dummy_settings = policy.get('settings', {}).get('dummy_simulation_settings', {}) + min_seconds = dummy_settings.get('download_min_seconds', 0.5) + max_seconds = dummy_settings.get('download_max_seconds', 1.5) + failure_rate = dummy_settings.get('download_failure_rate', 0.0) + skipped_rate = dummy_settings.get('download_skipped_failure_rate', 0.0) + + time.sleep(random.uniform(min_seconds, max_seconds)) + + rand_val = random.random() + should_fail_skipped = rand_val < skipped_rate + should_fail_fatal = not should_fail_skipped and rand_val < (skipped_rate + failure_rate) + + if should_fail_skipped: + logger.warning(f"[Worker {worker_id}] DUMMY: Simulating skipped download failure.") + # A skipped/tolerated failure in yt-dlp usually results in exit code 0. 
+ # The orchestrator will see this as a success but the stderr can be used for context. + retcode = 0 + stderr = "Dummy skipped failure" + elif should_fail_fatal: + logger.warning(f"[Worker {worker_id}] DUMMY: Simulating fatal download failure.") + retcode = 1 + stderr = "Dummy fatal failure" + else: + logger.info(f"[Worker {worker_id}] DUMMY: Simulating download success.") + retcode = 0 + stderr = "" + logger.info(f"========== [Worker {worker_id}] END DUMMY DIRECT DOWNLOAD ==========") + else: + logger.info(f"[Worker {worker_id}] [{profile_name}] Running command: {' '.join(shlex.quote(s) for s in cmd)}") + logger.info(f"[Worker {worker_id}] [{profile_name}] With environment: {custom_env}") + retcode, stdout, stderr = run_command( + cmd, running_processes, process_lock, env=custom_env, stream_output=args.verbose, + stream_prefix=f"[Worker {worker_id} | yt-dlp] " + ) + + # 4. Record activity + success = (retcode == 0) + activity_type = 'download' if success else 'download_error' + logger.info(f"[Worker {worker_id}] Recording '{activity_type}' for profile '{profile_name}'.") + profile_manager_instance.record_activity(profile_name, activity_type) + + event_details = f"Download finished. Exit code: {retcode}." + if not success and stderr: + event_details += f" Stderr: {stderr.strip().splitlines()[-1]}" + + event = {'type': 'direct_download', 'profile': profile_name, 'proxy_url': proxy_url, 'success': success, 'details': event_details} + state_manager.log_event(event) + + # 5. Clean up the processed task file + try: + os.remove(claimed_task_path) + logger.debug(f"[{sp_utils.get_display_name(claimed_task_path)}] Removed processed task file.") + except OSError as e: + logger.error(f"Failed to remove processed task file '{claimed_task_path}': {e}") + else: + no_task_streak += 1 + logger.info(f"[Worker {worker_id}] No tasks found for profile '{profile_name}'.") + + except Exception as e: + logger.error(f"[Worker {worker_id}] An unexpected error occurred in the worker loop: {e}", exc_info=True) + if locked_profile: + profile_manager_instance.record_activity(locked_profile['name'], 'failure') # Generic failure + time.sleep(5) + finally: + if locked_profile: + if claimed_task_path: + # The auth_profile_name and auth_env variables were populated in the `try` block + # before the task file was deleted. + if auth_profile_name and auth_env: + auth_manager = _get_auth_manager(profile_manager_instance, auth_env) + if auth_manager: + auth_manager.decrement_pending_downloads(auth_profile_name) + else: + logger.error(f"Could not get auth profile manager for env '{auth_env}'. Pending downloads counter will not be decremented.") + else: + logger.warning(f"Could not find auth profile name and/or auth_env in info.json metadata. Pending downloads counter will not be decremented. (Profile: {auth_profile_name}, Env: {auth_env})") + + cooldown = None + if claimed_task_path: + # Enforcer is the only point where we configure to apply different policies, + # since we might restart enforcer, but won't restart stress-policy working on auth and downloads simultaneously. + # This is like applying a policy across multiple workers/machines without needing to restart each of them. + # DESIGN: The cooldown duration is not configured in the worker's policy. + # Instead, it is read from a central Redis key. This key is set by the + # policy-enforcer, making the enforcer the single source of truth for + # this policy. This allows changing the cooldown behavior without + # restarting the workers. 
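+ # Illustrative values for the shared 'unlock_cooldown_seconds' Redis config key,
+ # matching the parsing below: "45" -> fixed 45s cooldown; "[30, 90]" -> random 30-90s.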
+ cooldown_config = profile_manager_instance.get_config('unlock_cooldown_seconds') + if cooldown_config: + try: + val = json.loads(cooldown_config) + if isinstance(val, list) and len(val) == 2 and val[0] < val[1]: + cooldown = random.randint(val[0], val[1]) + elif isinstance(val, int): + cooldown = val + except (json.JSONDecodeError, TypeError): + if cooldown_config.isdigit(): + cooldown = int(cooldown_config) + + if cooldown: + logger.info(f"[Worker {worker_id}] Putting profile '{locked_profile['name']}' into COOLDOWN for {cooldown}s.") + + profile_manager_instance.unlock_profile( + locked_profile['name'], + owner=owner_id, + rest_for_seconds=cooldown + ) + locked_profile = None + + logger.info(f"[Worker {worker_id}] Worker loop finished.") + return [] diff --git a/ytops_client/stress_policy_tool.py b/ytops_client/stress_policy_tool.py index 069282a..1741298 100644 --- a/ytops_client/stress_policy_tool.py +++ b/ytops_client/stress_policy_tool.py @@ -1,11 +1,77 @@ #!/usr/bin/env python3 """ + +Architectural Overview for the Stress Policy Tool: + + + +This file, stress_policy_tool.py, is the main entry point and orchestrator. It is responsible for: + +- Parsing command-line arguments. + +- Setting up logging and the main shutdown handler. + +- Initializing the StateManager and ProfileManager. + +- Running the main execution loop (ThreadPoolExecutor) based on the chosen orchestration mode. + +- Delegating the actual work to functions in the `workers.py` module. + + + +The core logic has been refactored into the following modules within `ytops_client/stress_policy/`: + + + +- arg_parser.py: Defines the command-line interface for the 'stress-policy' command using argparse. + +- workers.py: Contains all core worker functions that are executed by the ThreadPoolExecutor, such as `process_task`, `run_direct_batch_worker`, and their helpers. This is where the main logic for fetching info.json +and running downloads resides. + +- state_manager.py: Manages run state, statistics, rate limits, and persistence between runs (e.g., `_state.json`, `_stats.jsonl`). + +- process_runners.py: A low-level module that handles the execution of external subprocesses (`run_command`) and Docker containers (`run_docker_container`). + +- utils.py: Provides stateless utility functions shared across the tool, such as loading YAML policies, applying overrides, and formatting. + +""" +""" Policy-driven stress-testing orchestrator for video format downloads. + +This tool orchestrates complex, multi-stage stress tests based on a YAML policy file. +It supports several modes of operation: + +- full_stack: A complete workflow that first fetches an info.json for a given URL + using a profile, and then uses that info.json to perform one or more downloads. + +- fetch_only: Only performs the info.json generation step. This is useful for + simulating user authentication and browsing behavior. + +- download_only: Only performs the download step, using a directory of pre-existing + info.json files as its source. + +- direct_batch_cli (fetch_only): A high-throughput mode for generating info.json files + by calling a custom, Redis-aware yt-dlp command-line tool directly in batch mode. + This mode bypasses the get-info Thrift service. The workflow is as follows: + 1. The orchestrator worker locks a profile from the auth pool. + 2. It takes a 'batch' of URLs from the source file. + 3. It invokes the configured yt-dlp command, passing the profile name and proxy via + environment variables. + 4. 
The custom yt-dlp process then does the following for each URL in the batch: + a. Checks Redis to ensure the profile has not been externally BANNED. + b. Fetches the info.json. + c. Records 'success', 'failure', or 'tolerated_error' for the profile in Redis. + 5. After the yt-dlp process finishes, the orchestrator worker post-processes the + generated info.json files to inject metadata (profile name, proxy). + 6. The worker unlocks the profile. + 7. The worker repeats this cycle with a new profile and the next batch of URLs. + +The tool uses a profile management system (v2) based on Redis for coordinating +state between multiple workers and enforcing policies (e.g., rate limits, cooldowns). """ import argparse import collections -import collections.abc import concurrent.futures import json import logging @@ -14,20 +80,37 @@ import random import re import shlex import signal -import subprocess import sys +import tempfile +import shutil import threading import time from copy import deepcopy from datetime import datetime, timezone from pathlib import Path -from urllib.parse import urlparse, parse_qs try: - import yaml + from dotenv import load_dotenv except ImportError: - print("PyYAML is not installed. Please install it with: pip install PyYAML", file=sys.stderr) - sys.exit(1) + load_dotenv = None + +try: + import docker +except ImportError: + docker = None + + +from .profile_manager_tool import ProfileManager +from .stress_policy.state_manager import StateManager +from .stress_policy.process_runners import run_command, run_docker_container, get_worker_id +from .stress_policy import utils as sp_utils +from .stress_policy.workers import ( + _run_download_logic, process_profile_task, run_download_worker, process_info_json_cycle, + run_throughput_worker, _post_process_and_move_info_json, run_direct_batch_worker, + run_direct_docker_worker, find_task_and_lock_profile, run_direct_docker_download_worker, + run_direct_download_worker +) +from .stress_policy.arg_parser import add_stress_policy_parser # Add a global event for graceful shutdown shutdown_event = threading.Event() @@ -36,1746 +119,14 @@ shutdown_event = threading.Event() running_processes = set() process_lock = threading.Lock() -# Globals for assigning a stable ID to each worker thread -worker_id_map = {} -worker_id_counter = 0 -worker_id_lock = threading.Lock() - # Configure logging logger = logging.getLogger('stress_policy_tool') -def get_worker_id(): - """Assigns a stable, sequential ID to each worker thread.""" - global worker_id_counter - thread_id = threading.get_ident() - with worker_id_lock: - if thread_id not in worker_id_map: - worker_id_map[thread_id] = worker_id_counter - worker_id_counter += 1 - return worker_id_map[thread_id] - - -def get_video_id(url: str) -> str: - """Extracts a YouTube video ID from a URL.""" - match = re.search(r"v=([0-9A-Za-z_-]{11})", url) - if match: - return match.group(1) - match = re.search(r"youtu\.be\/([0-9A-Za-z_-]{11})", url) - if match: - return match.group(1) - if re.fullmatch(r'[0-9A-Za-z_-]{11}', url): - return url - return "unknown_video_id" - - -def get_display_name(path_or_url): - """Returns a clean name for logging, either a filename or a video ID.""" - if isinstance(path_or_url, Path): - return path_or_url.name - - path_str = str(path_or_url) - video_id = get_video_id(path_str) - if video_id != "unknown_video_id": - return video_id - - return Path(path_str).name - - -def format_size(b): - """Format size in bytes to human-readable string.""" - if b is None: - return 'N/A' - if b < 
1024: - return f"{b}B" - elif b < 1024**2: - return f"{b/1024:.2f}KiB" - elif b < 1024**3: - return f"{b/1024**2:.2f}MiB" - else: - return f"{b/1024**3:.2f}GiB" - - -def flatten_dict(d, parent_key='', sep='.'): - """Flattens a nested dictionary.""" - items = {} - for k, v in d.items(): - new_key = parent_key + sep + k if parent_key else k - if isinstance(v, collections.abc.MutableMapping): - items.update(flatten_dict(v, new_key, sep=sep)) - else: - items[new_key] = v - return items - - -def print_policy_overrides(policy): - """Prints all policy values as a single-line of --set arguments.""" - # We don't want to include the 'name' key in the overrides. - policy_copy = deepcopy(policy) - policy_copy.pop('name', None) - - flat_policy = flatten_dict(policy_copy) - - set_args = [] - for key, value in sorted(flat_policy.items()): - if value is None: - value_str = 'null' - elif isinstance(value, bool): - value_str = str(value).lower() - elif isinstance(value, (list, dict)): - # Use compact JSON for lists/dicts - value_str = json.dumps(value, separators=(',', ':')) - else: - value_str = str(value) - - # Use shlex.quote to handle spaces and special characters safely - set_args.append(f"--set {shlex.quote(f'{key}={value_str}')}") - - print(' '.join(set_args)) - - -def get_profile_from_filename(path, regex_pattern): - """Extracts a profile name from a filename using a regex.""" - if not regex_pattern: - return None - match = re.search(regex_pattern, path.name) - if match: - # Assume the first capturing group is the profile name - if match.groups(): - return match.group(1) - return None - - -class StateManager: - """Tracks statistics, manages rate limits, and persists state across runs.""" - def __init__(self, policy_name, disable_log_writing=False): - self.disable_log_writing = disable_log_writing - self.state_file_path = Path(f"{policy_name}_state.json") - self.stats_file_path = Path(f"{policy_name}_stats.jsonl") - self.lock = threading.RLock() - self.start_time = time.time() - self.events = [] - self.state = { - 'global_request_count': 0, - 'rate_limit_trackers': {}, # e.g., {'per_ip': [ts1, ts2], 'profile_foo': [ts3, ts4]} - 'profile_request_counts': {}, # for client rotation - 'profile_last_refresh_time': {}, # for client rotation - 'proxy_last_finish_time': {}, # for per-proxy sleep - 'processed_files': [], # For continuous download_only mode - # For dynamic profile cooldown strategy - 'profile_cooldown_counts': {}, - 'profile_cooldown_sleep_until': {}, - 'profile_pool_size': 0, - 'profile_run_suffix': None, - 'worker_profile_generations': {} - } - self.stats_file_handle = None - self._load_state() - self.print_historical_summary() - self._open_stats_log() - - def _load_state(self): - if self.disable_log_writing: - logger.info("Log writing is disabled. 
State will not be loaded from disk.") - return - if not self.state_file_path.exists(): - logger.info(f"State file not found at '{self.state_file_path}', starting fresh.") - return - try: - with open(self.state_file_path, 'r', encoding='utf-8') as f: - self.state = json.load(f) - # Ensure keys exist - self.state.setdefault('global_request_count', 0) - self.state.setdefault('rate_limit_trackers', {}) - self.state.setdefault('profile_request_counts', {}) - self.state.setdefault('profile_last_refresh_time', {}) - self.state.setdefault('proxy_last_finish_time', {}) - self.state.setdefault('processed_files', []) - # For dynamic profile cooldown strategy - self.state.setdefault('profile_cooldown_counts', {}) - self.state.setdefault('profile_cooldown_sleep_until', {}) - self.state.setdefault('profile_pool_size', 0) - self.state.setdefault('profile_run_suffix', None) - self.state.setdefault('worker_profile_generations', {}) - logger.info(f"Loaded state from {self.state_file_path}") - except (IOError, json.JSONDecodeError) as e: - logger.error(f"Could not load or parse state file {self.state_file_path}: {e}. Starting fresh.") - - def _save_state(self): - if self.disable_log_writing: - return - with self.lock: - try: - with open(self.state_file_path, 'w', encoding='utf-8') as f: - json.dump(self.state, f, indent=2) - logger.info(f"Saved state to {self.state_file_path}") - except IOError as e: - logger.error(f"Could not save state to {self.state_file_path}: {e}") - - def _open_stats_log(self): - if self.disable_log_writing: - return - try: - self.stats_file_handle = open(self.stats_file_path, 'a', encoding='utf-8') - except IOError as e: - logger.error(f"Could not open stats file {self.stats_file_path}: {e}") - - def close(self): - """Saves state and closes file handles.""" - self._save_state() - if self.stats_file_handle: - self.stats_file_handle.close() - self.stats_file_handle = None - - def mark_file_as_processed(self, file_path): - """Adds a file path to the list of processed files in the state.""" - with self.lock: - # Using a list and checking for existence is fine for moderate numbers of files. - # A set isn't JSON serializable. 
- processed = self.state.setdefault('processed_files', []) - file_str = str(file_path) - if file_str not in processed: - processed.append(file_str) - - def get_processed_files(self): - """Returns a set of file paths that have been processed.""" - with self.lock: - return set(self.state.get('processed_files', [])) - - def print_historical_summary(self): - """Prints a summary based on the state loaded from disk, before new events.""" - with self.lock: - now = time.time() - rate_trackers = self.state.get('rate_limit_trackers', {}) - total_requests = self.state.get('global_request_count', 0) - - if not rate_trackers and not total_requests: - logger.info("No historical data found in state file.") - return - - logger.info("\n--- Summary From Previous Runs ---") - logger.info(f"Total info.json requests (all previous runs): {total_requests}") - - if rate_trackers: - for key, timestamps in sorted(rate_trackers.items()): - # Time windows in seconds - windows = { - 'last 10 min': 600, - 'last 60 min': 3600, - 'last 6 hours': 21600, - 'last 24 hours': 86400 - } - - rates_str_parts = [] - for name, seconds in windows.items(): - count = sum(1 for ts in timestamps if now - ts <= seconds) - # Calculate rate in requests per minute - rate_rpm = (count / seconds) * 60 if seconds > 0 else 0 - rates_str_parts.append(f"{count} req in {name} ({rate_rpm:.2f} rpm)") - - logger.info(f"Tracker '{key}': " + ", ".join(rates_str_parts)) - logger.info("------------------------------------") - - def log_event(self, event_data): - with self.lock: - event_data['timestamp'] = datetime.now().isoformat() - self.events.append(event_data) - if self.stats_file_handle: - self.stats_file_handle.write(json.dumps(event_data) + '\n') - self.stats_file_handle.flush() - - def get_request_count(self): - with self.lock: - return self.state.get('global_request_count', 0) - - def increment_request_count(self): - with self.lock: - self.state['global_request_count'] = self.state.get('global_request_count', 0) + 1 - - def check_cumulative_error_rate(self, max_errors, per_minutes, error_type=None): - """ - Checks if a cumulative error rate has been exceeded. - If error_type is None, checks for any failure. - Returns the number of errors found if the threshold is met, otherwise 0. - """ - with self.lock: - now = time.time() - window_seconds = per_minutes * 60 - - if error_type: - recent_errors = [ - e for e in self.events - if e.get('error_type') == error_type and (now - datetime.fromisoformat(e['timestamp']).timestamp()) <= window_seconds - ] - else: # Generic failure check - recent_errors = [ - e for e in self.events - if not e.get('success') and (now - datetime.fromisoformat(e['timestamp']).timestamp()) <= window_seconds - ] - - if len(recent_errors) >= max_errors: - return len(recent_errors) - return 0 - - def check_quality_degradation_rate(self, max_triggers, per_minutes): - """ - Checks if the quality degradation trigger rate has been exceeded. - Returns the number of triggers found if the threshold is met, otherwise 0. - """ - with self.lock: - now = time.time() - window_seconds = per_minutes * 60 - - recent_triggers = [ - e for e in self.events - if e.get('quality_degradation_trigger') and (now - datetime.fromisoformat(e['timestamp']).timestamp()) <= window_seconds - ] - - if len(recent_triggers) >= max_triggers: - return len(recent_triggers) - return 0 - - def check_and_update_rate_limit(self, profile_name, policy): - """ - Checks if a request is allowed based on policy rate limits. - If allowed, updates the internal state. 
Returns True if allowed, False otherwise. - """ - with self.lock: - now = time.time() - gen_policy = policy.get('info_json_generation_policy', {}) - rate_limits = gen_policy.get('rate_limits', {}) - - # Check per-IP limit - ip_limit = rate_limits.get('per_ip') - if ip_limit: - tracker_key = 'per_ip' - max_req = ip_limit.get('max_requests') - period_min = ip_limit.get('per_minutes') - if max_req and period_min: - timestamps = self.state['rate_limit_trackers'].get(tracker_key, []) - # Filter out old timestamps - timestamps = [ts for ts in timestamps if now - ts < period_min * 60] - if len(timestamps) >= max_req: - logger.warning("Per-IP rate limit reached. Skipping task.") - return False - self.state['rate_limit_trackers'][tracker_key] = timestamps - - # Check per-profile limit - profile_limit = rate_limits.get('per_profile') - if profile_limit and profile_name: - tracker_key = f"profile_{profile_name}" - max_req = profile_limit.get('max_requests') - period_min = profile_limit.get('per_minutes') - if max_req and period_min: - timestamps = self.state['rate_limit_trackers'].get(tracker_key, []) - timestamps = [ts for ts in timestamps if now - ts < period_min * 60] - if len(timestamps) >= max_req: - logger.warning(f"Per-profile rate limit for '{profile_name}' reached. Skipping task.") - return False - self.state['rate_limit_trackers'][tracker_key] = timestamps - - # If all checks pass, record the new request timestamp for all relevant trackers - if ip_limit and ip_limit.get('max_requests'): - self.state['rate_limit_trackers'].setdefault('per_ip', []).append(now) - if profile_limit and profile_limit.get('max_requests') and profile_name: - self.state['rate_limit_trackers'].setdefault(f"profile_{profile_name}", []).append(now) - - return True - - def get_client_for_request(self, profile_name, gen_policy): - """ - Determines which client to use based on the client_rotation_policy. - Returns a tuple: (client_name, request_params_dict). - """ - with self.lock: - rotation_policy = gen_policy.get('client_rotation_policy') - - # If no rotation policy, use the simple 'client' key. - if not rotation_policy: - client = gen_policy.get('client') - logger.info(f"Using client '{client}' for profile '{profile_name}'.") - req_params = gen_policy.get('request_params') - return client, req_params - - # --- Rotation logic --- - now = time.time() - major_client = rotation_policy.get('major_client') - refresh_client = rotation_policy.get('refresh_client') - refresh_every = rotation_policy.get('refresh_every', {}) - - if not refresh_client or not refresh_every: - return major_client, rotation_policy.get('major_client_params') - - should_refresh = False - - # Check time-based refresh - refresh_minutes = refresh_every.get('minutes') - last_refresh_time = self.state['profile_last_refresh_time'].get(profile_name, 0) - if refresh_minutes and (now - last_refresh_time) > (refresh_minutes * 60): - should_refresh = True - - # Check request-count-based refresh - refresh_requests = refresh_every.get('requests') - request_count = self.state['profile_request_counts'].get(profile_name, 0) - if refresh_requests and request_count >= refresh_requests: - should_refresh = True - - if should_refresh: - logger.info(f"Profile '{profile_name}' is due for a refresh. 
Using refresh client '{refresh_client}'.") - self.state['profile_last_refresh_time'][profile_name] = now - self.state['profile_request_counts'][profile_name] = 0 # Reset counter - return refresh_client, rotation_policy.get('refresh_client_params') - else: - # Not refreshing, so increment request count for this profile - self.state['profile_request_counts'][profile_name] = request_count + 1 - return major_client, rotation_policy.get('major_client_params') - - def get_next_available_profile(self, policy): - """ - Finds or creates an available profile based on the dynamic cooldown policy. - Returns a profile name, or None if no profile is available. - """ - with self.lock: - now = time.time() - settings = policy.get('settings', {}) - pm_policy = settings.get('profile_management') - - if not pm_policy: - return None - - prefix = pm_policy.get('prefix') - if not prefix: - logger.error("Profile management policy requires 'prefix'.") - return None - - # Determine and persist the suffix for this run to ensure profile names are stable - run_suffix = self.state.get('profile_run_suffix') - if not run_suffix: - suffix_config = pm_policy.get('suffix') - if suffix_config == 'auto': - run_suffix = datetime.now().strftime('%Y%m%d%H%M') - else: - run_suffix = suffix_config or '' - self.state['profile_run_suffix'] = run_suffix - - # Initialize pool size from policy if not already in state - if self.state.get('profile_pool_size', 0) == 0: - self.state['profile_pool_size'] = pm_policy.get('initial_pool_size', 1) - - max_reqs = pm_policy.get('max_requests_per_profile') - sleep_mins = pm_policy.get('sleep_minutes_on_exhaustion') - - # Loop until a profile is found or we decide we can't find one - while True: - # Try to find an existing, available profile - for i in range(self.state['profile_pool_size']): - profile_name = f"{prefix}_{run_suffix}_{i}" if run_suffix else f"{prefix}_{i}" - - # Check if sleeping - sleep_until = self.state['profile_cooldown_sleep_until'].get(profile_name, 0) - if now < sleep_until: - continue # Still sleeping - - # Check if it needs to be put to sleep - req_count = self.state['profile_cooldown_counts'].get(profile_name, 0) - if max_reqs and req_count >= max_reqs: - sleep_duration_seconds = (sleep_mins or 0) * 60 - self.state['profile_cooldown_sleep_until'][profile_name] = now + sleep_duration_seconds - self.state['profile_cooldown_counts'][profile_name] = 0 # Reset count for next time - logger.info(f"Profile '{profile_name}' reached request limit ({req_count}/{max_reqs}). Putting to sleep for {sleep_mins} minutes.") - continue # Now sleeping, try next profile - - # This profile is available - logger.info(f"Selected available profile '{profile_name}' (request count: {req_count}/{max_reqs if max_reqs else 'unlimited'}).") - return profile_name - - # If we get here, no existing profile was available - if pm_policy.get('auto_expand_pool'): - new_profile_index = self.state['profile_pool_size'] - self.state['profile_pool_size'] += 1 - profile_name = f"{prefix}_{run_suffix}_{new_profile_index}" if run_suffix else f"{prefix}_{new_profile_index}" - logger.info(f"Profile pool exhausted. Expanding pool to size {self.state['profile_pool_size']}. New profile: '{profile_name}'") - return profile_name - else: - # No available profiles and pool expansion is disabled - return None - - def get_or_rotate_worker_profile(self, worker_id, policy): - """ - Gets the current profile for a worker, rotating to a new generation if the lifetime limit is met. 
- This is used by the 'per_worker_with_rotation' profile mode. - """ - with self.lock: - pm_policy = policy.get('settings', {}).get('profile_management', {}) - if not pm_policy: - logger.error("Profile mode 'per_worker_with_rotation' requires 'settings.profile_management' configuration in the policy.") - return f"error_profile_{worker_id}" - - prefix = pm_policy.get('prefix') - if not prefix: - logger.error("Profile management for 'per_worker_with_rotation' requires a 'prefix'.") - return f"error_profile_{worker_id}" - - max_reqs = pm_policy.get('max_requests_per_profile') - - generations = self.state.setdefault('worker_profile_generations', {}) - # worker_id is an int, but JSON keys must be strings - worker_id_str = str(worker_id) - current_gen = generations.get(worker_id_str, 0) - - profile_name = f"{prefix}_{worker_id}_{current_gen}" - - if not max_reqs: # No lifetime limit defined, so never rotate. - return profile_name - - req_count = self.state.get('profile_cooldown_counts', {}).get(profile_name, 0) - - if req_count >= max_reqs: - logger.info(f"Profile '{profile_name}' reached lifetime request limit ({req_count}/{max_reqs}). Rotating to new generation for worker {worker_id}.") - new_gen = current_gen + 1 - generations[worker_id_str] = new_gen - # The request counts for the old profile are implicitly left behind. - # The new profile will start with a count of 0. - profile_name = f"{prefix}_{worker_id}_{new_gen}" - - return profile_name - - def record_profile_request(self, profile_name): - """Increments the request counter for a profile for the cooldown policy.""" - with self.lock: - if not profile_name: - return - counts = self.state.setdefault('profile_cooldown_counts', {}) - counts[profile_name] = counts.get(profile_name, 0) + 1 - - def record_proxy_usage(self, proxy_url): - """Records a request timestamp for a given proxy URL for statistical purposes.""" - if not proxy_url: - return - with self.lock: - now = time.time() - # Use a prefix to avoid collisions with profile names or other keys - tracker_key = f"proxy_{proxy_url}" - self.state['rate_limit_trackers'].setdefault(tracker_key, []).append(now) - - def check_and_update_download_rate_limit(self, proxy_url, policy): - """Checks download rate limits. Returns True if allowed, False otherwise.""" - with self.lock: - now = time.time() - d_policy = policy.get('download_policy', {}) - rate_limits = d_policy.get('rate_limits', {}) - - # Check per-IP limit - ip_limit = rate_limits.get('per_ip') - if ip_limit: - tracker_key = 'download_per_ip' # Use a distinct key - max_req = ip_limit.get('max_requests') - period_min = ip_limit.get('per_minutes') - if max_req and period_min: - timestamps = self.state['rate_limit_trackers'].get(tracker_key, []) - timestamps = [ts for ts in timestamps if now - ts < period_min * 60] - if len(timestamps) >= max_req: - logger.warning("Per-IP download rate limit reached. Skipping task.") - return False - self.state['rate_limit_trackers'][tracker_key] = timestamps - - # Check per-proxy limit - proxy_limit = rate_limits.get('per_proxy') - if proxy_limit and proxy_url: - tracker_key = f"download_proxy_{proxy_url}" - max_req = proxy_limit.get('max_requests') - period_min = proxy_limit.get('per_minutes') - if max_req and period_min: - timestamps = self.state['rate_limit_trackers'].get(tracker_key, []) - timestamps = [ts for ts in timestamps if now - ts < period_min * 60] - if len(timestamps) >= max_req: - logger.warning(f"Per-proxy download rate limit for '{proxy_url}' reached. 
Skipping task.") - return False - self.state['rate_limit_trackers'][tracker_key] = timestamps - - # If all checks pass, record the new request timestamp for all relevant trackers - if ip_limit and ip_limit.get('max_requests'): - self.state['rate_limit_trackers'].setdefault('download_per_ip', []).append(now) - if proxy_limit and proxy_limit.get('max_requests') and proxy_url: - self.state['rate_limit_trackers'].setdefault(f"download_proxy_{proxy_url}", []).append(now) - - return True - - def wait_for_proxy_cooldown(self, proxy_url, policy): - """If a per-proxy sleep is defined, wait until the cooldown period has passed.""" - with self.lock: - d_policy = policy.get('download_policy', {}) - sleep_duration = d_policy.get('sleep_per_proxy_seconds', 0) - if not proxy_url or not sleep_duration > 0: - return - - last_finish = self.state.setdefault('proxy_last_finish_time', {}).get(proxy_url, 0) - elapsed = time.time() - last_finish - - if elapsed < sleep_duration: - time_to_sleep = sleep_duration - elapsed - logger.info(f"Proxy '{proxy_url}' was used recently. Sleeping for {time_to_sleep:.2f}s.") - # Interruptible sleep - sleep_end_time = time.time() + time_to_sleep - while time.time() < sleep_end_time: - if shutdown_event.is_set(): - logger.info("Shutdown requested during proxy cooldown sleep.") - break - time.sleep(0.2) - - def update_proxy_finish_time(self, proxy_url): - """Updates the last finish time for a proxy.""" - with self.lock: - if not proxy_url: - return - self.state.setdefault('proxy_last_finish_time', {})[proxy_url] = time.time() - - def print_summary(self, policy=None): - """Print a summary of the test run.""" - with self.lock: - # --- Cumulative Stats from State --- - now = time.time() - rate_trackers = self.state.get('rate_limit_trackers', {}) - if rate_trackers: - logger.info("\n--- Cumulative Rate Summary (All Runs, updated at end of run) ---") - logger.info("This shows the total number of requests/downloads over various time windows, including previous runs.") - - fetch_trackers = {k: v for k, v in rate_trackers.items() if not k.startswith('download_')} - download_trackers = {k: v for k, v in rate_trackers.items() if k.startswith('download_')} - - def print_tracker_stats(trackers, tracker_type): - if not trackers: - logger.info(f"No historical {tracker_type} trackers found.") - return - - logger.info(f"Historical {tracker_type} Trackers:") - for key, timestamps in sorted(trackers.items()): - windows = { - 'last 10 min': 600, 'last 60 min': 3600, - 'last 6 hours': 21600, 'last 24 hours': 86400 - } - rates_str_parts = [] - for name, seconds in windows.items(): - count = sum(1 for ts in timestamps if now - ts <= seconds) - rate_rpm = (count / seconds) * 60 if seconds > 0 else 0 - rates_str_parts.append(f"{count} in {name} ({rate_rpm:.2f}/min)") - - # Clean up key for display - display_key = key.replace('download_', '').replace('per_ip', 'all_proxies/ips') - logger.info(f" - Tracker '{display_key}': " + ", ".join(rates_str_parts)) - - print_tracker_stats(fetch_trackers, "Fetch Request") - print_tracker_stats(download_trackers, "Download Attempt") - - if not self.events: - logger.info("\nNo new events were recorded in this session.") - return - - duration = time.time() - self.start_time - fetch_events = [e for e in self.events if e.get('type') == 'fetch'] - download_events = [e for e in self.events if e.get('type') != 'fetch'] - - logger.info("\n--- Test Summary (This Run) ---") - logger.info(f"Total duration: {duration:.2f} seconds") - logger.info(f"Total info.json requests 
(cumulative): {self.get_request_count()}") - - if policy: - logger.info("\n--- Test Configuration ---") - settings = policy.get('settings', {}) - d_policy = policy.get('download_policy', {}) - - if settings.get('urls_file'): - logger.info(f"URL source file: {settings['urls_file']}") - if settings.get('info_json_dir'): - logger.info(f"Info.json source dir: {settings['info_json_dir']}") - - if d_policy: - logger.info(f"Download formats: {d_policy.get('formats', 'N/A')}") - if d_policy.get('downloader'): - logger.info(f"Downloader: {d_policy.get('downloader')}") - if d_policy.get('downloader_args'): - logger.info(f"Downloader args: {d_policy.get('downloader_args')}") - if d_policy.get('pause_before_download_seconds'): - logger.info(f"Pause before download: {d_policy.get('pause_before_download_seconds')}s") - if d_policy.get('sleep_between_formats'): - sleep_cfg = d_policy.get('sleep_between_formats') - logger.info(f"Sleep between formats: {sleep_cfg.get('min_seconds', 0)}-{sleep_cfg.get('max_seconds', 0)}s") - - if fetch_events: - total_fetches = len(fetch_events) - successful_fetches = sum(1 for e in fetch_events if e['success']) - cancelled_fetches = sum(1 for e in fetch_events if e.get('error_type') == 'Cancelled') - failed_fetches = total_fetches - successful_fetches - cancelled_fetches - - logger.info("\n--- Fetch Summary (This Run) ---") - logger.info(f"Total info.json fetch attempts: {total_fetches}") - logger.info(f" - Successful: {successful_fetches}") - logger.info(f" - Failed: {failed_fetches}") - if cancelled_fetches > 0: - logger.info(f" - Cancelled: {cancelled_fetches}") - - completed_fetches = successful_fetches + failed_fetches - if completed_fetches > 0: - success_rate = (successful_fetches / completed_fetches) * 100 - logger.info(f"Success rate (of completed): {success_rate:.2f}%") - elif total_fetches > 0: - logger.info("Success rate: N/A (no tasks completed)") - - if duration > 1 and total_fetches > 0: - rpm = (total_fetches / duration) * 60 - logger.info(f"Actual fetch rate: {rpm:.2f} requests/minute") - - if failed_fetches > 0: - error_counts = collections.Counter( - e.get('error_type', 'Unknown') - for e in fetch_events if not e['success'] and e.get('error_type') != 'Cancelled' - ) - logger.info("Failure breakdown:") - for error_type, count in sorted(error_counts.items()): - logger.info(f" - {error_type}: {count}") - - profile_counts = collections.Counter(e.get('profile') for e in fetch_events if e.get('profile')) - if profile_counts: - logger.info("Requests per profile:") - for profile, count in sorted(profile_counts.items()): - logger.info(f" - {profile}: {count}") - - proxy_counts = collections.Counter(e.get('proxy_url') for e in fetch_events if e.get('proxy_url')) - if proxy_counts: - logger.info("Requests per proxy:") - for proxy, count in sorted(proxy_counts.items()): - logger.info(f" - {proxy}: {count}") - - profile_counts = collections.Counter(e.get('profile') for e in fetch_events if e.get('profile')) - if profile_counts: - logger.info("Requests per profile:") - for profile, count in sorted(profile_counts.items()): - logger.info(f" - {profile}: {count}") - - proxy_counts = collections.Counter(e.get('proxy_url') for e in fetch_events if e.get('proxy_url')) - if proxy_counts: - logger.info("Requests per proxy:") - for proxy, count in sorted(proxy_counts.items()): - logger.info(f" - {proxy}: {count}") - - if download_events: - total_attempts = len(download_events) - successes = sum(1 for e in download_events if e['success']) - cancelled = sum(1 for e in 
download_events if e.get('error_type') == 'Cancelled') - failures = total_attempts - successes - cancelled - - # --- Profile Association for Download Events --- - download_profiles = [e.get('profile') for e in download_events] - - # For download_only mode, we might need to fall back to regex extraction - # if the profile wasn't passed down (e.g., no profile grouping). - profile_regex = None - if policy: - settings = policy.get('settings', {}) - if settings.get('mode') == 'download_only': - profile_regex = settings.get('profile_extraction_regex') - - if profile_regex: - for i, e in enumerate(download_events): - if not download_profiles[i]: # If profile wasn't set in the event - path = Path(e.get('path', '')) - match = re.search(profile_regex, path.name) - if match and match.groups(): - download_profiles[i] = match.group(1) - - # Replace any remaining Nones with 'unknown_profile' - download_profiles = [p or 'unknown_profile' for p in download_profiles] - - num_profiles_used = len(set(p for p in download_profiles if p != 'unknown_profile')) - - logger.info("\n--- Download Summary (This Run) ---") - if policy: - workers = policy.get('execution_control', {}).get('workers', 'N/A') - logger.info(f"Workers configured: {workers}") - - logger.info(f"Profiles utilized for downloads: {num_profiles_used}") - logger.info(f"Total download attempts: {total_attempts}") - logger.info(f" - Successful: {successes}") - logger.info(f" - Failed: {failures}") - if cancelled > 0: - logger.info(f" - Cancelled: {cancelled}") - - completed_downloads = successes + failures - if completed_downloads > 0: - success_rate = (successes / completed_downloads) * 100 - logger.info(f"Success rate (of completed): {success_rate:.2f}%") - elif total_attempts > 0: - logger.info("Success rate: N/A (no tasks completed)") - - duration_hours = duration / 3600.0 - if duration > 1 and total_attempts > 0: - dpm = (total_attempts / duration) * 60 - logger.info(f"Actual overall download rate: {dpm:.2f} attempts/minute") - - total_bytes = sum(e.get('downloaded_bytes', 0) for e in download_events if e['success']) - if total_bytes > 0: - logger.info(f"Total data downloaded: {format_size(total_bytes)}") - - if failures > 0: - error_counts = collections.Counter( - e.get('error_type', 'Unknown') - for e in download_events if not e['success'] and e.get('error_type') != 'Cancelled' - ) - logger.info("Failure breakdown:") - for error_type, count in sorted(error_counts.items()): - logger.info(f" - {error_type}: {count}") - - # Add profile to each download event for easier counting - for i, e in enumerate(download_events): - e['profile'] = download_profiles[i] - - profile_counts = collections.Counter(e.get('profile') for e in download_events if e.get('profile')) - if profile_counts: - logger.info("Downloads per profile:") - for profile, count in sorted(profile_counts.items()): - rate_per_hour = (count / duration_hours) if duration_hours > 0 else 0 - logger.info(f" - {profile}: {count} attempts (avg this run: {rate_per_hour:.2f}/hour)") - - proxy_counts = collections.Counter(e.get('proxy_url') for e in download_events if e.get('proxy_url')) - if proxy_counts: - logger.info("Downloads per proxy:") - for proxy, count in sorted(proxy_counts.items()): - rate_per_hour = (count / duration_hours) if duration_hours > 0 else 0 - logger.info(f" - {proxy}: {count} attempts (avg this run: {rate_per_hour:.2f}/hour)") - - logger.info("--------------------") - - -def _run_download_logic(source, info_json_content, policy, state_manager, profile_name=None): - """Shared 
download logic for a single info.json.""" - proxy_url = None - if info_json_content: - try: - info_data = json.loads(info_json_content) - proxy_url = info_data.get('_proxy_url') - except (json.JSONDecodeError, AttributeError): - logger.warning(f"[{get_display_name(source)}] Could not parse info.json to get proxy for download controls.") - - if not state_manager.check_and_update_download_rate_limit(proxy_url, policy): - return [] - - state_manager.wait_for_proxy_cooldown(proxy_url, policy) - results = process_info_json_cycle(source, info_json_content, policy, state_manager, proxy_url=proxy_url, profile_name=profile_name) - state_manager.update_proxy_finish_time(proxy_url) - return results - - -def process_profile_task(profile_name, file_list, policy, state_manager, cycle_num): - """Worker task for a profile, processing its files sequentially.""" - logger.info(f"Worker {get_worker_id()} starting task for profile '{profile_name}' with {len(file_list)} files.") - all_results = [] - for i, file_path in enumerate(file_list): - if shutdown_event.is_set(): - logger.info(f"Shutdown requested, stopping task for profile '{profile_name}'.") - break - - try: - with open(file_path, 'r', encoding='utf-8') as f: - info_json_content = f.read() - except (IOError, FileNotFoundError) as e: - logger.error(f"[{get_display_name(file_path)}] Could not read info.json file: {e}") - continue # Skip this file - - results_for_file = _run_download_logic(file_path, info_json_content, policy, state_manager, profile_name=profile_name) - all_results.extend(results_for_file) - - # Check for stop conditions after processing each file - should_stop_profile = False - for result in results_for_file: - if not result['success']: - s_conditions = policy.get('stop_conditions', {}) - if s_conditions.get('on_failure') or \ - (s_conditions.get('on_http_403') and result['error_type'] == 'HTTP 403') or \ - (s_conditions.get('on_timeout') and result['error_type'] == 'Timeout'): - logger.info(f"Stopping further processing for profile '{profile_name}' due to failure.") - should_stop_profile = True - break - if should_stop_profile: - break - - # Apply sleep between tasks for this profile - if i < len(file_list) - 1: - exec_control = policy.get('execution_control', {}) - sleep_cfg = exec_control.get('sleep_between_tasks', {}) - sleep_min = sleep_cfg.get('min_seconds', 0) - - if sleep_min > 0: - sleep_max = sleep_cfg.get('max_seconds') or sleep_min - sleep_duration = random.uniform(sleep_min, sleep_max) if sleep_max > sleep_min else sleep_min - - logger.debug(f"Profile '{profile_name}' sleeping for {sleep_duration:.2f}s before next file.") - # Interruptible sleep - sleep_end_time = time.time() + sleep_duration - while time.time() < sleep_end_time: - if shutdown_event.is_set(): - break - time.sleep(0.2) - - return all_results - - -def run_command(cmd, input_data=None, binary_stdout=False): - """ - Runs a command, captures its output, and returns status. - If binary_stdout is True, stdout is returned as bytes. Otherwise, both are decoded strings. - """ - logger.debug(f"Running command: {' '.join(cmd)}") - process = None - try: - # Always open in binary mode to handle both cases. We will decode later. 
- process = subprocess.Popen( - cmd, - stdin=subprocess.PIPE if input_data else None, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - preexec_fn=os.setsid # Start in a new process group to isolate from terminal signals - ) - with process_lock: - running_processes.add(process) - - stdout_capture = [] - stderr_capture = [] - - def read_pipe(pipe, capture_list, display_pipe=None): - """Reads a pipe line by line (as bytes), appending to a list and optionally displaying.""" - for line in iter(pipe.readline, b''): - capture_list.append(line) - if display_pipe: - # Decode for display - display_line = line.decode('utf-8', errors='replace') - display_pipe.write(display_line) - display_pipe.flush() - - # We must read stdout and stderr in parallel to prevent deadlocks. - stdout_thread = threading.Thread(target=read_pipe, args=(process.stdout, stdout_capture)) - # Display stderr in real-time as it often contains progress info. - stderr_thread = threading.Thread(target=read_pipe, args=(process.stderr, stderr_capture, sys.stderr)) - - stdout_thread.start() - stderr_thread.start() - - # Handle stdin after starting to read outputs to avoid deadlocks. - if input_data: - try: - process.stdin.write(input_data.encode('utf-8')) - process.stdin.close() - except (IOError, BrokenPipeError): - # This can happen if the process exits quickly or doesn't read stdin. - logger.debug(f"Could not write to stdin for command: {' '.join(cmd)}. Process may have already exited.") - - # Wait for the process to finish and for all output to be read. - retcode = process.wait() - stdout_thread.join() - stderr_thread.join() - - stdout_bytes = b"".join(stdout_capture) - stderr_bytes = b"".join(stderr_capture) - - stdout = stdout_bytes if binary_stdout else stdout_bytes.decode('utf-8', errors='replace') - stderr = stderr_bytes.decode('utf-8', errors='replace') - - return retcode, stdout, stderr - - except FileNotFoundError: - logger.error(f"Command not found: {cmd[0]}. Make sure it's in your PATH.") - return -1, "", f"Command not found: {cmd[0]}" - except Exception as e: - logger.error(f"An error occurred while running command: {' '.join(cmd)}. Error: {e}") - return -1, "", str(e) - finally: - if process: - with process_lock: - running_processes.discard(process) - - -def run_download_worker(info_json_path, info_json_content, format_to_download, policy, profile_name=None): - """ - Performs a single download attempt. Designed to be run in a worker thread. - """ - download_policy = policy.get('download_policy', {}) - settings = policy.get('settings', {}) - downloader = download_policy.get('downloader') - - # Get script command from settings, with fallback to download_policy for old format. - script_cmd_str = settings.get('download_script') - if not script_cmd_str: - script_cmd_str = download_policy.get('script') - - if script_cmd_str: - download_cmd = shlex.split(script_cmd_str) - elif downloader == 'aria2c_rpc': - download_cmd = [sys.executable, '-m', 'ytops_client.cli', 'download', 'aria-rpc'] - elif downloader == 'native-cli': - download_cmd = [sys.executable, '-m', 'ytops_client.cli', 'download', 'cli'] - else: - # Default to the new native-py downloader if downloader is 'native-py' or not specified. 
- download_cmd = [sys.executable, '-m', 'ytops_client.cli', 'download', 'py'] - - download_cmd.extend(['-f', format_to_download]) - - if downloader == 'aria2c_rpc': - if download_policy.get('aria_host'): - download_cmd.extend(['--aria-host', str(download_policy['aria_host'])]) - if download_policy.get('aria_port'): - download_cmd.extend(['--aria-port', str(download_policy['aria_port'])]) - if download_policy.get('aria_secret'): - download_cmd.extend(['--aria-secret', str(download_policy['aria_secret'])]) - if download_policy.get('output_dir'): - download_cmd.extend(['--output-dir', str(download_policy['output_dir'])]) - if download_policy.get('aria_remote_dir'): - download_cmd.extend(['--remote-dir', str(download_policy['aria_remote_dir'])]) - if download_policy.get('aria_fragments_dir'): - download_cmd.extend(['--fragments-dir', str(download_policy['aria_fragments_dir'])]) - # For stress testing, waiting is the desired default to get a success/fail result. - # Allow disabling it by explicitly setting aria_wait: false in the policy. - if download_policy.get('aria_wait', True): - download_cmd.append('--wait') - - if download_policy.get('auto_merge_fragments'): - download_cmd.append('--auto-merge-fragments') - if download_policy.get('remove_fragments_after_merge'): - download_cmd.append('--remove-fragments-after-merge') - if download_policy.get('cleanup'): - download_cmd.append('--cleanup') - if download_policy.get('purge_on_complete'): - download_cmd.append('--purge-on-complete') - - downloader_args = download_policy.get('downloader_args') - proxy = download_policy.get('proxy') - if proxy: - # Note: proxy_rename is not supported for aria2c_rpc mode. - proxy_arg = f"--all-proxy {shlex.quote(str(proxy))}" - if downloader_args: - downloader_args = f"{downloader_args} {proxy_arg}" - else: - downloader_args = proxy_arg - - if downloader_args: - # For aria2c_rpc, the downloader_args value is passed directly to the script's --downloader-args option. - download_cmd.extend(['--downloader-args', downloader_args]) - elif downloader == 'native-cli': - # This is the logic for the legacy download_tool.py (yt-dlp CLI wrapper). - pause_seconds = download_policy.get('pause_before_download_seconds') - if pause_seconds and isinstance(pause_seconds, (int, float)) and pause_seconds > 0: - download_cmd.extend(['--pause', str(pause_seconds)]) - - if download_policy.get('continue_downloads'): - download_cmd.append('--download-continue') - - # Add proxy if specified directly in the policy - proxy = download_policy.get('proxy') - if proxy: - download_cmd.extend(['--proxy', str(proxy)]) - - proxy_rename = download_policy.get('proxy_rename') - if proxy_rename: - download_cmd.extend(['--proxy-rename', str(proxy_rename)]) - - extra_args = download_policy.get('extra_args') - if extra_args: - download_cmd.extend(shlex.split(extra_args)) - - # Note: 'downloader' here refers to yt-dlp's internal downloader, not our script. - # The policy key 'external_downloader' is more clear, but we support 'downloader' for backward compatibility. 
- ext_downloader = download_policy.get('external_downloader') or download_policy.get('downloader') - if ext_downloader and ext_downloader not in ['native-cli', 'native-py', 'aria2c_rpc']: - download_cmd.extend(['--downloader', str(ext_downloader)]) - - downloader_args = download_policy.get('downloader_args') - if downloader_args: - download_cmd.extend(['--downloader-args', str(downloader_args)]) - - if download_policy.get('merge_output_format'): - download_cmd.extend(['--merge-output-format', str(download_policy['merge_output_format'])]) - - if download_policy.get('merge_output_format'): - download_cmd.extend(['--merge-output-format', str(download_policy['merge_output_format'])]) - - if download_policy.get('cleanup'): - download_cmd.append('--cleanup') - else: - # This is the default logic for the new native-py downloader. - if download_policy.get('output_to_buffer'): - download_cmd.append('--output-buffer') - else: - # --output-dir is only relevant if not outputting to buffer. - if download_policy.get('output_dir'): - download_cmd.extend(['--output-dir', str(download_policy['output_dir'])]) - - if download_policy.get('temp_path'): - download_cmd.extend(['--temp-path', str(download_policy['temp_path'])]) - if download_policy.get('continue_downloads'): - download_cmd.append('--download-continue') - - pause_seconds = download_policy.get('pause_before_download_seconds') - if pause_seconds and isinstance(pause_seconds, (int, float)) and pause_seconds > 0: - download_cmd.extend(['--pause', str(pause_seconds)]) - - proxy = download_policy.get('proxy') - if proxy: - download_cmd.extend(['--proxy', str(proxy)]) - - proxy_rename = download_policy.get('proxy_rename') - if proxy_rename: - download_cmd.extend(['--proxy-rename', str(proxy_rename)]) - - # The 'extra_args' from the policy are for the download script itself, not for yt-dlp. - # We need to split them and add them to the command. - extra_args = download_policy.get('extra_args') - if extra_args: - download_cmd.extend(shlex.split(extra_args)) - - # Pass through downloader settings for yt-dlp to use - # e.g. to tell yt-dlp to use aria2c as its backend - ext_downloader = download_policy.get('external_downloader') - if ext_downloader: - download_cmd.extend(['--downloader', str(ext_downloader)]) - - downloader_args = download_policy.get('downloader_args') - if downloader_args: - download_cmd.extend(['--downloader-args', str(downloader_args)]) - - worker_id = get_worker_id() - display_name = get_display_name(info_json_path) - profile_log_part = f" [Profile: {profile_name}]" if profile_name else "" - log_prefix = f"[Worker {worker_id}]{profile_log_part} [{display_name} @ {format_to_download}]" - logger.info(f"{log_prefix} Kicking off download process...") - - temp_info_file_path = None - try: - if isinstance(info_json_path, Path) and info_json_path.exists(): - # The info.json is already in a file, pass its path directly. - download_cmd.extend(['--load-info-json', str(info_json_path)]) - else: - # The info.json content is in memory, so write it to a temporary file. 
- import tempfile - with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.json', encoding='utf-8') as temp_f: - temp_f.write(info_json_content) - temp_info_file_path = temp_f.name - download_cmd.extend(['--load-info-json', temp_info_file_path]) - - cmd_str_for_log = ' '.join(shlex.quote(s) for s in download_cmd) - logger.info(f"{log_prefix} Running download command: {cmd_str_for_log}") - output_to_buffer = download_policy.get('output_to_buffer', False) - retcode, stdout, stderr = run_command(download_cmd, binary_stdout=output_to_buffer) - finally: - if temp_info_file_path and os.path.exists(temp_info_file_path): - os.unlink(temp_info_file_path) - - is_403_error = "HTTP Error 403" in stderr - is_timeout_error = "Read timed out" in stderr - output_to_buffer = download_policy.get('output_to_buffer', False) - - result = { - 'type': 'download', - 'path': str(info_json_path), - 'format': format_to_download, - 'success': retcode == 0, - 'error_type': None, - 'details': '', - 'downloaded_bytes': 0, - 'profile': profile_name - } - - if retcode == 0: - details_str = "OK" - size_in_bytes = 0 - if output_to_buffer: - # The most accurate size is the length of the stdout buffer. - size_in_bytes = len(stdout) # stdout is bytes - details_str += f" (Buffered {format_size(size_in_bytes)})" - else: - size_match = re.search(r'\[download\]\s+100%\s+of\s+~?([0-9.]+)(B|KiB|MiB|GiB)', stderr) - if size_match: - value = float(size_match.group(1)) - unit = size_match.group(2) - multipliers = {"B": 1, "KiB": 1024, "MiB": 1024**2, "GiB": 1024**3} - size_in_bytes = int(value * multipliers.get(unit, 1)) - details_str += f" ({size_match.group(1)}{unit})" - - result['downloaded_bytes'] = size_in_bytes - result['details'] = details_str - else: - # Check both stdout and stderr for error messages, as logging might be directed to stdout. - full_output = f"{stdout}\n{stderr}" - error_lines = [line for line in full_output.strip().split('\n') if 'ERROR:' in line] - result['details'] = error_lines[-1].strip() if error_lines else "Unknown error" - - if is_403_error: - result['error_type'] = 'HTTP 403' - elif is_timeout_error: - result['error_type'] = 'Timeout' - else: - result['error_type'] = f'Exit Code {retcode}' - - return result - - -def process_info_json_cycle(path, content, policy, state_manager, proxy_url=None, profile_name=None): - """ - Processes one info.json file for one cycle, downloading selected formats. - """ - results = [] - display_name = get_display_name(path) - d_policy = policy.get('download_policy', {}) - s_conditions = policy.get('stop_conditions', {}) - format_selection = d_policy.get('formats', '') - - try: - info_data = json.loads(content) - available_formats = [f['format_id'] for f in info_data.get('formats', [])] - if not available_formats: - logger.warning(f"[{display_name}] No formats found in info.json. 
Skipping.") - return [] - - formats_to_test = [] - if format_selection == 'all': - formats_to_test = available_formats - elif format_selection.startswith('random:'): - percent = float(format_selection.split(':')[1].rstrip('%')) - count = max(1, int(len(available_formats) * (percent / 100.0))) - formats_to_test = random.sample(available_formats, k=count) - elif format_selection.startswith('random_from:'): - choices = [f.strip() for f in format_selection.split(':', 1)[1].split(',')] - valid_choices = [f for f in choices if f in available_formats] - if valid_choices: - formats_to_test = [random.choice(valid_choices)] - else: - requested_formats = [f.strip() for f in format_selection.split(',') if f.strip()] - formats_to_test = [] - for req_fmt in requested_formats: - # If it's a complex selector with slashes, don't try to validate it against available formats. - if '/' in req_fmt: - formats_to_test.append(req_fmt) - continue - - # Check for exact match first - if req_fmt in available_formats: - formats_to_test.append(req_fmt) - continue - - # If no exact match, check for formats that start with this ID + '-' - # e.g., req_fmt '140' should match '140-0' - prefix_match = f"{req_fmt}-" - first_match = next((af for af in available_formats if af.startswith(prefix_match)), None) - - if first_match: - logger.info(f"[{display_name}] Requested format '{req_fmt}' not found. Using first available match: '{first_match}'.") - formats_to_test.append(first_match) - else: - # This could be a complex selector like 'bestvideo' or '299/298', so keep it. - if req_fmt not in available_formats: - logger.warning(f"[{display_name}] Requested format '{req_fmt}' not found in available formats.") - formats_to_test.append(req_fmt) - - except json.JSONDecodeError: - logger.error(f"[{display_name}] Failed to parse info.json. 
Skipping.") - return [] - - for i, format_id in enumerate(formats_to_test): - if shutdown_event.is_set(): - logger.info(f"Shutdown requested, stopping further format tests for {display_name}.") - break - - # Check if the format URL is expired before attempting to download - format_details = next((f for f in info_data.get('formats', []) if f.get('format_id') == format_id), None) - if format_details and 'url' in format_details: - parsed_url = urlparse(format_details['url']) - query_params = parse_qs(parsed_url.query) - expire_ts_str = query_params.get('expire', [None])[0] - if expire_ts_str and expire_ts_str.isdigit(): - expire_ts = int(expire_ts_str) - if expire_ts < time.time(): - logger.warning(f"[{display_name}] Skipping format '{format_id}' because its URL is expired.") - result = { - 'type': 'download', 'path': str(path), 'format': format_id, - 'success': True, 'error_type': 'Skipped', - 'details': 'Download URL is expired', 'downloaded_bytes': 0 - } - if proxy_url: - result['proxy_url'] = proxy_url - state_manager.log_event(result) - results.append(result) - continue # Move to the next format - - result = run_download_worker(path, content, format_id, policy, profile_name=profile_name) - if proxy_url: - result['proxy_url'] = proxy_url - state_manager.log_event(result) - results.append(result) - - worker_id = get_worker_id() - status = "SUCCESS" if result['success'] else f"FAILURE ({result['error_type']})" - profile_log_part = f" [Profile: {profile_name}]" if profile_name else "" - logger.info(f"[Worker {worker_id}]{profile_log_part} Result for {display_name} (format {format_id}): {status} - {result.get('details', 'OK')}") - - if not result['success']: - if s_conditions.get('on_failure') or \ - (s_conditions.get('on_http_403') and result['error_type'] == 'HTTP 403') or \ - (s_conditions.get('on_timeout') and result['error_type'] == 'Timeout'): - logger.info(f"Stopping further format tests for {display_name} in this cycle due to failure.") - break - - sleep_cfg = d_policy.get('sleep_between_formats', {}) - sleep_min = sleep_cfg.get('min_seconds', 0) - if sleep_min > 0 and i < len(formats_to_test) - 1: - sleep_max = sleep_cfg.get('max_seconds') or sleep_min - if sleep_max > sleep_min: - sleep_duration = random.uniform(sleep_min, sleep_max) - else: - sleep_duration = sleep_min - - logger.debug(f"Sleeping for {sleep_duration:.2f}s between formats for {display_name}.") - # Interruptible sleep - sleep_end_time = time.time() + sleep_duration - while time.time() < sleep_end_time: - if shutdown_event.is_set(): - break - time.sleep(0.2) - - return results - - -def update_dict(d, u): - """Recursively update a dictionary.""" - for k, v in u.items(): - if isinstance(v, collections.abc.Mapping): - d[k] = update_dict(d.get(k, {}), v) - else: - d[k] = v - return d - - -def load_policy(policy_file, policy_name=None): - """Load a policy from a YAML file.""" - try: - with open(policy_file, 'r', encoding='utf-8') as f: - # If a policy name is given, look for that specific document - if policy_name: - docs = list(yaml.safe_load_all(f)) - for doc in docs: - if isinstance(doc, dict) and doc.get('name') == policy_name: - return doc - raise ValueError(f"Policy '{policy_name}' not found in {policy_file}") - # Otherwise, load the first document - return yaml.safe_load(f) - except (IOError, yaml.YAMLError, ValueError) as e: - logger.error(f"Failed to load policy file {policy_file}: {e}") - sys.exit(1) - - -def apply_overrides(policy, overrides): - """Apply command-line overrides to the policy.""" - for override in 
overrides: - try: - key, value = override.split('=', 1) - keys = key.split('.') - - # Try to parse as JSON/YAML if it looks like a list or dict, otherwise treat as scalar - if (value.startswith('[') and value.endswith(']')) or \ - (value.startswith('{') and value.endswith('}')): - try: - value = yaml.safe_load(value) - except yaml.YAMLError: - logger.warning(f"Could not parse override value '{value}' as YAML. Treating as a string.") - else: - # Try to auto-convert scalar value type - if value.lower() == 'true': - value = True - elif value.lower() == 'false': - value = False - elif value.lower() == 'null': - value = None - else: - try: - value = int(value) - except ValueError: - try: - value = float(value) - except ValueError: - pass # Keep as string - - d = policy - for k in keys[:-1]: - d = d.setdefault(k, {}) - d[keys[-1]] = value - except ValueError: - logger.error(f"Invalid override format: '{override}'. Use 'key.subkey=value'.") - sys.exit(1) - return policy - - -def display_effective_policy(policy, name, sources=None, profile_names=None, original_workers_setting=None): - """Prints a human-readable summary of the effective policy.""" - logger.info(f"--- Effective Policy: {name} ---") - settings = policy.get('settings', {}) - exec_control = policy.get('execution_control', {}) - - logger.info(f"Mode: {settings.get('mode', 'full_stack')}") - if profile_names: - num_profiles = len(profile_names) - logger.info(f"Profiles found: {num_profiles}") - if num_profiles > 0: - # Sort profiles for consistent display, show top 10 - sorted_profiles = sorted(profile_names) - profiles_to_show = sorted_profiles[:10] - logger.info(f" (e.g., {', '.join(profiles_to_show)}{'...' if num_profiles > 10 else ''})") - - workers_display = str(exec_control.get('workers', 1)) - if original_workers_setting == 'auto': - workers_display = f"auto (calculated: {workers_display})" - logger.info(f"Workers: {workers_display}") - - sleep_cfg = exec_control.get('sleep_between_tasks', {}) - sleep_min = sleep_cfg.get('min_seconds') - if sleep_min is not None: - sleep_max = sleep_cfg.get('max_seconds') or sleep_min - if sleep_max > sleep_min: - logger.info(f"Sleep between tasks (per worker): {sleep_min}-{sleep_max}s (random)") - else: - logger.info(f"Sleep between tasks (per worker): {sleep_min}s") - - run_until = exec_control.get('run_until', {}) - run_conditions = [] - if 'minutes' in run_until: - run_conditions.append(f"for {run_until['minutes']} minutes") - if 'requests' in run_until: - run_conditions.append(f"until {run_until['requests']} total requests") - if 'cycles' in run_until: - run_conditions.append(f"for {run_until['cycles']} cycles") - - if run_conditions: - logger.info(f"Run condition: Stop after running {' or '.join(run_conditions)}.") - if 'minutes' in run_until and 'cycles' not in run_until: - logger.info("Will continuously cycle through sources until time limit is reached.") - else: - logger.warning("WARNING: No 'run_until' condition is set. This test will run forever unless stopped manually.") - logger.info("Run condition: No stop condition defined, will run indefinitely (until Ctrl+C).") - - # --- Rate Calculation --- - if sources: - workers = exec_control.get('workers', 1) - num_sources = len(profile_names) if profile_names else len(sources) - - min_sleep = sleep_cfg.get('min_seconds', 0) - max_sleep = sleep_cfg.get('max_seconds') or min_sleep - avg_sleep_per_task = (min_sleep + max_sleep) / 2 - - # Assume an average task duration. This is a major assumption. 
- mode = settings.get('mode', 'full_stack') - assumptions = exec_control.get('assumptions', {}) - - assumed_fetch_duration = 0 - if mode in ['full_stack', 'fetch_only']: - assumed_fetch_duration = assumptions.get('fetch_task_duration', 12 if mode == 'full_stack' else 3) - - assumed_download_duration = 0 - if mode in ['full_stack', 'download_only']: - # This assumes the total time to download all formats for a single source. - assumed_download_duration = assumptions.get('download_task_duration', 60) - - total_assumed_task_duration = assumed_fetch_duration + assumed_download_duration - - if workers > 0 and total_assumed_task_duration > 0: - total_time_per_task = total_assumed_task_duration + avg_sleep_per_task - tasks_per_minute_per_worker = 60 / total_time_per_task - total_tasks_per_minute = tasks_per_minute_per_worker * workers - - logger.info("--- Rate Estimation ---") - logger.info(f"Source count: {num_sources}") - if mode in ['full_stack', 'fetch_only']: - logger.info(f"Est. fetch time per source: {assumed_fetch_duration}s (override via execution_control.assumptions.fetch_task_duration)") - if mode in ['full_stack', 'download_only']: - logger.info(f"Est. download time per source: {assumed_download_duration}s (override via execution_control.assumptions.download_task_duration)") - logger.info(" (Note: This assumes total time for all formats per source)") - - logger.info(f"Est. sleep per task: {avg_sleep_per_task:.1f}s") - logger.info(f"==> Expected task rate: ~{total_tasks_per_minute:.2f} tasks/minute ({workers} workers * {tasks_per_minute_per_worker:.2f} tasks/min/worker)") - - target_rate_cfg = exec_control.get('target_rate', {}) - target_reqs = target_rate_cfg.get('requests') - target_mins = target_rate_cfg.get('per_minutes') - if target_reqs and target_mins: - target_rpm = target_reqs / target_mins - logger.info(f"Target rate: {target_rpm:.2f} tasks/minute") - if total_tasks_per_minute < target_rpm * 0.8: - logger.warning("Warning: Expected rate is significantly lower than target rate.") - logger.warning("Consider increasing workers, reducing sleep, or checking task performance.") - - logger.info("---------------------------------") - time.sleep(2) # Give user time to read - - -def add_stress_policy_parser(subparsers): - """Add the parser for the 'stress-policy' command.""" - parser = subparsers.add_parser( - 'stress-policy', - description="The primary, policy-driven stress-testing orchestrator.\nIt runs complex, multi-stage stress tests based on a YAML policy file.\nUse '--list-policies' to see available pre-configured scenarios.\n\nModes supported:\n- full_stack: Generate info.json and then download from it.\n- fetch_only: Only generate info.json files.\n- download_only: Only download from existing info.json files.", - formatter_class=argparse.RawTextHelpFormatter, - help='Run advanced, policy-driven stress tests (recommended).', - epilog=""" -Examples: - -1. Fetch info.jsons for a TV client with a single profile and a rate limit: - ytops-client stress-policy --policy policies/1_fetch_only_policies.yaml \\ - --policy-name tv_downgraded_single_profile \\ - --set settings.urls_file=my_urls.txt \\ - --set execution_control.run_until.minutes=30 - # This runs a 'fetch_only' test using the 'tv_downgraded' client. It uses a single, - # static profile for all requests and enforces a safety limit of 450 requests per hour. - -2. 
Fetch info.jsons for an Android client using cookies for authentication: - ytops-client stress-policy --policy policies/1_fetch_only_policies.yaml \\ - --policy-name android_sdkless_with_cookies \\ - --set settings.urls_file=my_urls.txt \\ - --set info_json_generation_policy.request_params.cookies_file_path=/path/to/my_cookies.txt - # This demonstrates an authenticated 'fetch_only' test. It passes the path to a - # Netscape cookie file, which the server will use for the requests. - -3. Download from a folder of info.jsons, grouped by profile, with auto-workers: - ytops-client stress-policy --policy policies/2_download_only_policies.yaml \\ - --policy-name basic_profile_aware_download \\ - --set settings.info_json_dir=/path/to/my/infojsons - # This runs a 'download_only' test. It scans a directory, extracts profile names from - # the filenames (e.g., 'tv_user_1' from '...-VIDEOID-tv_user_1.json'), and groups - # them. 'workers=auto' sets the number of workers to the number of unique profiles found. - -4. Full-stack test with multiple workers and profile rotation: - ytops-client stress-policy --policy policies/3_full_stack_policies.yaml \\ - --policy-name tv_simply_profile_rotation \\ - --set settings.urls_file=my_urls.txt \\ - --set execution_control.workers=4 \\ - --set settings.profile_management.max_requests_per_profile=500 - # This runs a 'full_stack' test with 4 parallel workers. Each worker gets a unique - # profile (e.g., tv_simply_user_0_0, tv_simply_user_1_0, etc.). After a profile is - # used 500 times, it is retired, and a new "generation" is created (e.g., tv_simply_user_0_1). - -5. Full-stack authenticated test with a pool of profiles and corresponding cookie files: - ytops-client stress-policy --policy policies/3_full_stack_policies.yaml \\ - --policy-name mweb_multi_profile_with_cookies \\ - --set settings.urls_file=my_urls.txt \\ - --set settings.profile_management.cookie_files='["/path/c1.txt","/path/c2.txt"]' - # This runs a 'full_stack' test using a pool of profiles (e.g., mweb_user_0, mweb_user_1). - # It uses the 'cookie_files' list to assign a specific cookie file to each profile in the - # pool, enabling multi-account authenticated testing. Note the JSON/YAML list format for the override. - -6. Full-stack test submitting downloads to an aria2c RPC server: - ytops-client stress-policy --policy policies/3_full_stack_policies.yaml \\ - --policy-name tv_simply_profile_rotation_aria2c_rpc \\ - --set settings.urls_file=my_urls.txt \\ - --set download_policy.aria_host=192.168.1.100 \\ - --set download_policy.aria_port=6801 - # This runs a test where downloads are not performed by the worker itself, but are - # sent to a remote aria2c daemon. The policy specifies 'downloader: aria2c_rpc' - # and provides connection details. This is useful for offloading download traffic. - --------------------------------------------------------------------------------- -Overridable Policy Parameters via --set: - - Key Description - -------------------------------------- ------------------------------------------------ - [settings] - settings.mode Test mode: 'full_stack', 'fetch_only', or 'download_only'. - settings.urls_file Path to file with URLs/video IDs. - settings.info_json_dir Path to directory with existing info.json files. - settings.profile_extraction_regex For 'download_only' mode, a regex to extract profile names from info.json filenames. The first capture group is used as the profile name. E.g., '.*-(.*?).json'. This enables profile-aware sequential downloading. 
- settings.info_json_dir_sample_percent Randomly sample this %% of files from the directory (for 'once' scan mode). - settings.directory_scan_mode For 'download_only': 'once' (default) or 'continuous' to watch for new files. - settings.mark_processed_files For 'continuous' scan mode: if true, rename processed files to '*..processed' to avoid reprocessing. - settings.max_files_per_cycle For 'continuous' scan mode: max new files to process per cycle. - settings.sleep_if_no_new_files_seconds For 'continuous' scan mode: seconds to sleep if no new files are found (default: 10). - settings.profile_prefix (Legacy) Prefix for profile names (e.g., 'test_user'). - settings.profile_pool (Legacy) Size of the profile pool. - settings.profile_mode Profile strategy. 'per_request' (legacy), 'per_worker' (legacy), or 'per_worker_with_rotation' (requires profile_management). - settings.info_json_script Command to run the info.json generation script (e.g., 'bin/ytops-client get-info'). - settings.save_info_json_dir If set, save all successfully generated info.json files to this directory. - - [settings.profile_management] (New, preferred method for profile control) - profile_management.prefix Prefix for profile names (e.g., 'dyn_user'). - profile_management.suffix Suffix for profile names. Set to 'auto' for a timestamp, or provide a string. - profile_management.initial_pool_size The number of profiles to start with. - profile_management.auto_expand_pool If true, create new profiles when the initial pool is exhausted (all sleeping). - profile_management.max_requests_per_profile Max requests a profile can make before it must 'sleep'. - profile_management.sleep_minutes_on_exhaustion How many minutes a profile 'sleeps' after hitting its request limit. - profile_management.cookie_files A list of paths to cookie files. Used to assign a unique cookie file to each profile in a pool. - - [execution_control] - execution_control.workers Number of parallel worker threads. Set to "auto" to calculate from target_rate or number of profiles. - execution_control.auto_workers_max The maximum number of workers to use when 'workers' is 'auto' in profile-aware download mode (default: 8). - execution_control.target_rate.requests Target requests for 'auto' workers calculation. - execution_control.target_rate.per_minutes Period in minutes for target_rate. - execution_control.run_until.minutes Stop test after N minutes. Will continuously cycle through sources. - execution_control.run_until.cycles Stop test after N cycles. A cycle is one full pass through all sources. - execution_control.run_until.requests Stop test after N total info.json requests (cumulative across runs). - execution_control.sleep_between_tasks.min_seconds Min sleep time between tasks, per worker. - - [info_json_generation_policy] - info_json_generation_policy.client Client to use (e.g., 'mweb', 'tv_camoufox'). - info_json_generation_policy.auth_host Host for the auth/Thrift service. - info_json_generation_policy.auth_port Port for the auth/Thrift service. - info_json_generation_policy.assigned_proxy_url A specific proxy to use for a request, overriding the server's proxy pool. - info_json_generation_policy.proxy_rename Regex substitution for the assigned proxy URL (e.g., 's/old/new/'). - info_json_generation_policy.command_template A full command template for the info.json script. Overrides other keys. - info_json_generation_policy.rate_limits.per_ip.max_requests Max requests for the given time period from one IP. 
- info_json_generation_policy.rate_limits.per_ip.per_minutes Time period in minutes for the per_ip rate limit. - info_json_generation_policy.rate_limits.per_profile.max_requests Max requests for a single profile in a time period. - info_json_generation_policy.rate_limits.per_profile.per_minutes Time period in minutes for the per_profile rate limit. - info_json_generation_policy.client_rotation_policy.major_client The primary client to use for most requests. - info_json_generation_policy.client_rotation_policy.refresh_client The client to use periodically to refresh context. - info_json_generation_policy.client_rotation_policy.refresh_every.requests Trigger refresh client after N requests for a profile. - - [download_policy] - download_policy.formats Formats to download (e.g., '18,140', 'random:50%%'). - download_policy.downloader Orchestrator script to use: 'native-py' (default, Python lib), 'native-cli' (legacy CLI wrapper), or 'aria2c_rpc'. - download_policy.external_downloader For 'native-py' or default, the backend yt-dlp should use (e.g., 'aria2c', 'native'). - download_policy.downloader_args Arguments for the external_downloader. For yt-dlp, e.g., 'aria2c:-x 8'. - download_policy.merge_output_format Container to merge to (e.g., 'mkv'). Defaults to 'mp4' via cli.config. - download_policy.temp_path For 'native-py', path to a directory for temporary files (e.g., a RAM disk like /dev/shm). - download_policy.output_to_buffer For 'native-py', download to an in-memory buffer and pipe to stdout instead of saving to a file (true/false). Best for single-file formats. - download_policy.proxy Proxy for direct downloads (e.g., "socks5://127.0.0.1:1080"). - download_policy.proxy_rename Regex substitution for the proxy URL (e.g., 's/old/new/'). - download_policy.pause_before_download_seconds Pause for N seconds before starting each download attempt. - download_policy.continue_downloads Enable download continuation (true/false). - download_policy.cleanup After success: for native downloaders, rename and truncate file to 0 bytes; for 'aria2c_rpc', remove file(s) from filesystem. - download_policy.extra_args A string of extra arguments for the download script (e.g., "--limit-rate 5M"). - download_policy.sleep_per_proxy_seconds Cooldown in seconds between downloads on the same proxy. - download_policy.rate_limits.per_proxy.max_requests Max downloads for a single proxy in a time period. - download_policy.rate_limits.per_proxy.per_minutes Time period in minutes for the per_proxy download rate limit. - # For downloader: 'aria2c_rpc' - download_policy.aria_host Hostname of the aria2c RPC server. - download_policy.aria_port Port of the aria2c RPC server. - download_policy.aria_secret Secret token for the aria2c RPC server. - download_policy.aria_wait Wait for aria2c downloads to complete (true/false). - download_policy.cleanup Remove downloaded file(s) from the filesystem on success. Requires script access to the download directory. - download_policy.purge_on_complete On success, purge ALL completed/failed downloads from aria2c history. Use as a workaround for older aria2c versions where targeted removal fails. - download_policy.output_dir Output directory for downloads. - download_policy.aria_remote_dir The absolute download path on the remote aria2c host. - download_policy.aria_fragments_dir The local path to find fragments for merging (if different from output_dir). - download_policy.auto_merge_fragments For fragmented downloads, automatically merge parts after download (true/false). 
Requires aria_wait=true. - download_policy.remove_fragments_after_merge For fragmented downloads, delete fragment files after a successful merge (true/false). Requires auto_merge_fragments=true. - - [stop_conditions] - stop_conditions.on_failure Stop on any download failure (true/false). - stop_conditions.on_http_403 Stop on any HTTP 403 error (true/false). - stop_conditions.on_error_rate.max_errors Stop test if more than N errors (of any type) occur within the time period. - stop_conditions.on_error_rate.per_minutes Time period in minutes for the error rate calculation. - stop_conditions.on_cumulative_403.max_errors Stop test if more than N HTTP 403 errors occur within the time period. - stop_conditions.on_cumulative_403.per_minutes Time period in minutes for the cumulative 403 calculation. - stop_conditions.on_quality_degradation.trigger_if_missing_formats A format ID or comma-separated list of IDs. Triggers if any are missing. - stop_conditions.on_quality_degradation.max_triggers Stop test if quality degradation is detected N times. - stop_conditions.on_quality_degradation.per_minutes Time period in minutes for the quality degradation calculation. --------------------------------------------------------------------------------- -""" - ) - parser.add_argument('--policy', help='Path to the YAML policy file. Required unless --list-policies is used.') - parser.add_argument('--policy-name', help='Name of the policy to run from a multi-policy file (if it contains "---" separators).') - parser.add_argument('--list-policies', action='store_true', help='List all available policies from the default policies directory and exit.') - parser.add_argument('--show-overrides', action='store_true', help='Load the specified policy and print all its defined values as a single-line of --set arguments, then exit.') - parser.add_argument('--set', action='append', default=[], help="Override a policy setting using 'key.subkey=value' format.\n(e.g., --set execution_control.workers=5)") - - # Add a group for aria2c-specific overrides for clarity in --help - aria_group = parser.add_argument_group('Aria2c RPC Downloader Overrides', 'Shortcuts for common --set options for the aria2c_rpc downloader.') - aria_group.add_argument('--auto-merge-fragments', action=argparse.BooleanOptionalAction, default=None, help='Shortcut to enable/disable download_policy.auto_merge_fragments.') - aria_group.add_argument('--remove-fragments-after-merge', action=argparse.BooleanOptionalAction, default=None, help='Shortcut to enable/disable download_policy.remove_fragments_after_merge.') - aria_group.add_argument('--fragments-dir', help='Shortcut for --set download_policy.aria_fragments_dir=PATH.') - aria_group.add_argument('--remote-dir', help='Shortcut for --set download_policy.aria_remote_dir=PATH.') - aria_group.add_argument('--cleanup', action=argparse.BooleanOptionalAction, default=None, help='Shortcut to enable/disable download_policy.cleanup.') - - parser.add_argument('--verbose', action='store_true', help='Enable verbose output for the orchestrator and underlying scripts.') - parser.add_argument('--dry-run', action='store_true', help='Print the effective policy and exit without running the test.') - parser.add_argument('--disable-log-writing', action='store_true', help='Disable writing state, stats, and log files. 
By default, files are created for each run.') - return parser - - -def list_policies(): - """Scans the policies directory and prints a list of available policies.""" - script_dir = os.path.dirname(os.path.abspath(__file__)) - project_root = os.path.abspath(os.path.join(script_dir, '..')) - policies_dir = os.path.join(project_root, 'policies') - - if not os.path.isdir(policies_dir): - print(f"Error: Policies directory not found at '{policies_dir}'", file=sys.stderr) - return 1 - - print("Available Policies:") - print("=" * 20) - - policy_files = sorted(Path(policies_dir).glob('*.yaml')) - if not policy_files: - print("No policy files (.yaml) found.") - return 0 - - for policy_file in policy_files: - print(f"\n--- File: {policy_file.relative_to(project_root)} ---") - try: - with open(policy_file, 'r', encoding='utf-8') as f: - content = f.read() - - # Split into documents. The separator is a line that is exactly '---'. - documents = re.split(r'^\-\-\-$', content, flags=re.MULTILINE) - - found_any_in_file = False - for doc in documents: - doc = doc.strip() - if not doc: - continue - - lines = doc.split('\n') - policy_name = None - description_lines = [] - - # Find name and description - for i, line in enumerate(lines): - if line.strip().startswith('name:'): - policy_name = line.split(':', 1)[1].strip() - - # Look backwards for comments - j = i - 1 - current_desc_block = [] - while j >= 0 and lines[j].strip().startswith('#'): - comment = lines[j].strip().lstrip('#').strip() - current_desc_block.insert(0, comment) - j -= 1 - - if current_desc_block: - description_lines = current_desc_block - break - - if policy_name: - found_any_in_file = True - print(f" - Name: {policy_name}") - if description_lines: - # Heuristic to clean up "Policy: " prefix - if description_lines[0].lower().startswith('policy:'): - description_lines[0] = description_lines[0][len('policy:'):].strip() - - print(f" Description: {description_lines[0]}") - for desc_line in description_lines[1:]: - print(f" {desc_line}") - else: - print(" Description: (No description found)") - - relative_path = policy_file.relative_to(project_root) - print(f" Usage: --policy {relative_path} --policy-name {policy_name}") - - if not found_any_in_file: - print(" (No named policies found in this file)") - - except Exception as e: - print(f" Error parsing {policy_file.name}: {e}") - - return 0 - - def main_stress_policy(args): """Main logic for the 'stress-policy' command.""" if args.list_policies: - return list_policies() + return sp_utils.list_policies() if not args.policy: print("Error: --policy is required unless using --list-policies.", file=sys.stderr) @@ -1783,15 +134,65 @@ def main_stress_policy(args): # Handle --show-overrides early, as it doesn't run the test. if args.show_overrides: - policy = load_policy(args.policy, args.policy_name) + policy = sp_utils.load_policy(args.policy, args.policy_name) if not policy: return 1 # load_policy prints its own error - print_policy_overrides(policy) + sp_utils.print_policy_overrides(policy) return 0 - policy = load_policy(args.policy, args.policy_name) - policy = apply_overrides(policy, args.set) + policy = sp_utils.load_policy(args.policy, args.policy_name) + policy = sp_utils.apply_overrides(policy, args.set) + + # If orchestrator is verbose, make downloaders verbose too by passing it through. 
+ if args.verbose: + d_policy = policy.setdefault('download_policy', {}) + extra_args = d_policy.get('extra_args', '') + if '--verbose' not in extra_args: + d_policy['extra_args'] = f"{extra_args} --verbose".strip() + + # --- Set safe defaults --- + settings = policy.get('settings', {}) + mode = settings.get('mode', 'full_stack') + # For continuous download mode, it is almost always desired to mark files as + # processed to avoid an infinite loop on the same files. We make this the + # default and issue a warning if it's not explicitly set. + if mode == 'download_only' and settings.get('directory_scan_mode') == 'continuous': + if 'mark_processed_files' not in settings: + # Use print because logger is not yet configured. + print("WARNING: In 'continuous' download mode, 'settings.mark_processed_files' was not set.", file=sys.stderr) + print(" Defaulting to 'true' to prevent reprocessing files.", file=sys.stderr) + print(" Set it to 'false' explicitly in your policy to disable this behavior.", file=sys.stderr) + settings['mark_processed_files'] = True + + # Load .env file *after* loading policy to respect env_file from policy. + if load_dotenv: + sim_params = policy.get('simulation_parameters', {}) + # Coalesce from CLI, then policy. An explicit CLI arg takes precedence. + env_file = args.env_file or sim_params.get('env_file') + + if not env_file and args.env and '.env' in args.env and os.path.exists(args.env): + # Use print because logger is not yet configured. + print(f"Warning: --env should be an environment name (e.g., 'sim'), not a file path. Treating '{args.env}' as --env-file. The environment name will default to 'sim'.", file=sys.stderr) + env_file = args.env + args.env = 'sim' + + was_loaded = load_dotenv(env_file) + if was_loaded: + # Use print because logger is not yet configured. + print(f"Loaded environment variables from {env_file or '.env file'}", file=sys.stderr) + elif args.env_file: # Only error if user explicitly passed it + print(f"Error: The specified --env-file was not found: {args.env_file}", file=sys.stderr) + return 1 + + if args.profile_prefix: + # This shortcut overrides the profile_prefix for all relevant stages. + # Useful for simple fetch_only or download_only runs. + policy.setdefault('info_json_generation_policy', {})['profile_prefix'] = args.profile_prefix + policy.setdefault('download_policy', {})['profile_prefix'] = args.profile_prefix + # Use print because logger is not yet configured. + print(f"Overriding profile_prefix for all stages with CLI arg: {args.profile_prefix}", file=sys.stderr) + # Apply direct CLI overrides after --set, so they have final precedence. 
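A short editorial illustration may help here: the `--set key.subkey=value` overrides applied above by `sp_utils.apply_overrides` walk a dotted path into the loaded policy dict, and the dedicated shortcut flags handled just below win over both the policy file and `--set`. A minimal sketch of that dotted-path behaviour, under the assumption that the real `sp_utils.apply_overrides` may parse values differently:

    import yaml

    def apply_dotted_override(policy: dict, assignment: str) -> None:
        """Apply one 'a.b.c=value' override, e.g. 'execution_control.workers=5'."""
        key_path, _, raw_value = assignment.partition('=')
        *parents, leaf = key_path.split('.')
        node = policy
        for part in parents:
            node = node.setdefault(part, {})  # create intermediate dicts as needed
        node[leaf] = yaml.safe_load(raw_value)  # '5' -> 5, 'true' -> True, plain text stays text

    policy = {'execution_control': {'workers': 1}}
    apply_dotted_override(policy, 'execution_control.workers=5')
    assert policy['execution_control']['workers'] == 5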
if args.auto_merge_fragments is not None: policy.setdefault('download_policy', {})['auto_merge_fragments'] = args.auto_merge_fragments @@ -1804,6 +205,9 @@ def main_stress_policy(args): if args.cleanup is not None: policy.setdefault('download_policy', {})['cleanup'] = args.cleanup + if args.expire_time_shift_minutes is not None: + policy.setdefault('download_policy', {})['expire_time_shift_minutes'] = args.expire_time_shift_minutes + policy_name = policy.get('name', args.policy_name or Path(args.policy).stem) # --- Logging Setup --- @@ -1813,6 +217,9 @@ def main_stress_policy(args): root_logger = logging.getLogger() root_logger.setLevel(log_level) + + # Silence noisy loggers from dependencies like docker-py + logging.getLogger('urllib3.connectionpool').setLevel(logging.INFO if args.verbose else logging.WARNING) # Remove any existing handlers to avoid duplicate logs for handler in root_logger.handlers[:]: @@ -1824,10 +231,11 @@ def main_stress_policy(args): root_logger.addHandler(console_handler) if not args.disable_log_writing: - timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') + timestamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f') log_filename = f"stress-policy-{timestamp}-{policy_name}.log" try: - file_handler = logging.FileHandler(log_filename, encoding='utf-8') + # Open in append mode to be safe, though timestamp should be unique. + file_handler = logging.FileHandler(log_filename, mode='a', encoding='utf-8') file_handler.setFormatter(logging.Formatter(log_format, datefmt=date_format)) root_logger.addHandler(file_handler) # Use print because logger is just being set up. @@ -1835,7 +243,111 @@ def main_stress_policy(args): except IOError as e: print(f"Error: Could not open log file {log_filename}: {e}", file=sys.stderr) - state_manager = StateManager(policy_name, disable_log_writing=args.disable_log_writing) + state_manager = StateManager(policy_name, disable_log_writing=args.disable_log_writing, shutdown_event=shutdown_event) + + if args.reset_infojson: + info_json_dir = settings.get('info_json_dir') + if not info_json_dir: + logger.error("--reset-infojson requires 'settings.info_json_dir' to be set in the policy.") + return 1 + + logger.info(f"--- Resetting info.json files in '{info_json_dir}' ---") + source_dir = Path(info_json_dir) + if not source_dir.is_dir(): + logger.warning(f"Source directory for reset does not exist: {source_dir}. Skipping reset.") + else: + processed_files = list(source_dir.rglob('*.json.processed')) + locked_files = list(source_dir.rglob('*.json.LOCKED.*')) + files_to_reset = processed_files + locked_files + + if not files_to_reset: + logger.info("No processed or locked files found to reset.") + else: + reset_count = 0 + for file_to_reset in files_to_reset: + original_path = None + if file_to_reset.name.endswith('.processed'): + original_path_str = str(file_to_reset).removesuffix('.processed') + original_path = Path(original_path_str) + elif '.LOCKED.' in file_to_reset.name: + original_path_str = str(file_to_reset).split('.LOCKED.')[0] + original_path = Path(original_path_str) + + if original_path: + try: + if original_path.exists(): + logger.warning(f"Original file '{original_path.name}' already exists. 
Deleting '{file_to_reset.name}' instead of renaming.") + file_to_reset.unlink() + else: + file_to_reset.rename(original_path) + logger.debug(f"Reset '{file_to_reset.name}' to '{original_path.name}'") + reset_count += 1 + except (IOError, OSError) as e: + logger.error(f"Failed to reset '{file_to_reset.name}': {e}") + logger.info(f"Reset {reset_count} info.json file(s).") + + if args.pre_cleanup_media is not None: + cleanup_path_str = args.pre_cleanup_media + d_policy = policy.get('download_policy', {}) + direct_docker_policy = policy.get('direct_docker_cli_policy', {}) + + if cleanup_path_str == '.': # Special value from `const` + # Determine path from policy + if direct_docker_policy.get('docker_host_download_path'): + cleanup_path_str = direct_docker_policy['docker_host_download_path'] + elif d_policy.get('output_dir'): + cleanup_path_str = d_policy['output_dir'] + else: + logger.error("--pre-cleanup-media was used without a path, but could not determine a download directory from the policy.") + return 1 + + cleanup_path = Path(cleanup_path_str) + if not cleanup_path.is_dir(): + logger.warning(f"Directory for media cleanup does not exist, skipping: {cleanup_path}") + else: + logger.info(f"--- Cleaning up media files in '{cleanup_path}' ---") + media_extensions = ['.mp4', '.m4a', '.webm', '.mkv', '.part', '.ytdl'] + files_deleted = 0 + for ext in media_extensions: + for media_file in cleanup_path.rglob(f'*{ext}'): + try: + media_file.unlink() + logger.debug(f"Deleted {media_file}") + files_deleted += 1 + except OSError as e: + logger.error(f"Failed to delete media file '{media_file}': {e}") + logger.info(f"Deleted {files_deleted} media file(s).") + + if args.reset_local_cache_folder is not None: + cache_path_str = args.reset_local_cache_folder + direct_docker_policy = policy.get('direct_docker_cli_policy', {}) + + if cache_path_str == '.': # Special value from `const` + if direct_docker_policy.get('docker_host_cache_path'): + cache_path_str = direct_docker_policy['docker_host_cache_path'] + else: + logger.error("--reset-local-cache-folder was used without a path, but 'direct_docker_cli_policy.docker_host_cache_path' is not set in the policy.") + return 1 + + cache_path = Path(cache_path_str) + if not cache_path.is_dir(): + logger.warning(f"Local cache directory for reset does not exist, skipping: {cache_path}") + else: + logger.info(f"--- Resetting local cache folder '{cache_path}' ---") + try: + shutil.rmtree(cache_path) + os.makedirs(cache_path) + logger.info(f"Successfully deleted and recreated cache folder '{cache_path}'.") + except OSError as e: + logger.error(f"Failed to reset cache folder '{cache_path}': {e}") + + if policy.get('name') in ['continuous_auth_simulation', 'continuous_download_simulation']: + logger.warning("This policy is part of a multi-stage simulation.") + if 'auth' in policy.get('name', ''): + logger.warning("It is recommended to run this auth policy using: ./bin/run-profile-simulation") + if 'download' in policy.get('name', ''): + logger.warning("It is recommended to run this download policy using: ./bin/run-download-simulation") + time.sleep(2) # --- Graceful shutdown handler --- def shutdown_handler(signum, frame): @@ -1846,20 +358,19 @@ def main_stress_policy(args): # Save state immediately to prevent loss on interrupt. logger.info("Attempting to save state before shutdown...") state_manager.close() - - # Kill running subprocesses to unblock workers + logger.info("Shutdown requested. Allowing in-progress tasks to complete. No new tasks will be started. 
Press Ctrl+C again to force exit.") + else: + logger.info("Second signal received, forcing exit.") + # On second signal, forcefully terminate subprocesses. with process_lock: if running_processes: - logger.info(f"Terminating {len(running_processes)} running subprocess(es)...") + logger.info(f"Forcefully terminating {len(running_processes)} running subprocess(es)...") for p in running_processes: try: # Kill the entire process group to ensure child processes (like yt-dlp) are terminated. os.killpg(os.getpgid(p.pid), signal.SIGKILL) except (ProcessLookupError, PermissionError): pass # Process already finished or we lack permissions - logger.info("Subprocesses terminated. Waiting for workers to finish. Press Ctrl+C again to force exit.") - else: - logger.info("Second signal received, forcing exit.") # Use os._exit for a hard exit that doesn't run cleanup handlers, # which can deadlock if locks are held. os._exit(1) @@ -1868,9 +379,362 @@ def main_stress_policy(args): signal.signal(signal.SIGTERM, shutdown_handler) settings = policy.get('settings', {}) - - # --- Load sources based on mode --- + exec_control = policy.get('execution_control', {}) mode = settings.get('mode', 'full_stack') + orchestration_mode = settings.get('orchestration_mode') + + # --- Profile Manager Setup for Locking Mode --- + profile_manager = None + profile_managers = {} + if settings.get('profile_mode') == 'from_pool_with_lock': + logger.info("--- Profile Locking Mode Enabled ---") + logger.info("This mode requires profiles to be set up and managed by the policy enforcer.") + logger.info("1. Ensure you have run: bin/setup-profiles-from-policy") + logger.info("2. Ensure the policy enforcer is running in the background: bin/ytops-client policy-enforcer --live") + logger.info(" (e.g. using policies/8_unified_simulation_enforcer.yaml)") + logger.info("3. To monitor profiles, use: bin/ytops-client profile list --live") + logger.info("------------------------------------") + + # Coalesce Redis settings from CLI args, .env file, and defaults + redis_host = args.redis_host or os.getenv('REDIS_HOST', os.getenv('MASTER_HOST_IP', 'localhost')) + redis_port = args.redis_port if args.redis_port is not None else int(os.getenv('REDIS_PORT', 6379)) + redis_password = args.redis_password or os.getenv('REDIS_PASSWORD') + + sim_params = policy.get('simulation_parameters', {}) + + def setup_manager(sim_type, env_cli_arg, env_policy_key): + # Determine the effective environment name with correct precedence: + # 1. Specific CLI arg (e.g., --auth-env) + # 2. General CLI arg (--env) + # 3. Specific policy setting (e.g., simulation_parameters.auth_env) + # 4. General policy setting (simulation_parameters.env) + # 5. 
Hardcoded default ('sim') + policy_env = sim_params.get(env_policy_key) + default_policy_env = sim_params.get('env') + effective_env = env_cli_arg or args.env or policy_env or default_policy_env or 'sim' + + logger.info(f"Setting up ProfileManager for {sim_type} simulation using env: '{effective_env}'") + + if args.key_prefix: + key_prefix = args.key_prefix + else: + key_prefix = f"{effective_env}_profile_mgmt_" + + return ProfileManager( + redis_host=redis_host, redis_port=redis_port, + redis_password=redis_password, key_prefix=key_prefix + ) + + # Determine which managers are needed based on mode and orchestration mode + needs_auth = False + needs_download = False + + if mode in ['full_stack', 'fetch_only']: + needs_auth = True + if mode in ['full_stack', 'download_only']: + needs_download = True + + if orchestration_mode == 'direct_batch_cli': + direct_policy = policy.get('direct_batch_cli_policy', {}) + use_env = direct_policy.get('use_profile_env', 'auth') + if use_env == 'download': + needs_download = True + else: # auth is default + needs_auth = True + + if needs_auth: + # For backward compatibility, policy might have 'env' instead of 'auth_env' + auth_env_key = 'auth_env' if 'auth_env' in sim_params else 'env' + profile_managers['auth'] = setup_manager('Auth', args.auth_env, auth_env_key) + + if needs_download: + download_env_key = 'download_env' if 'download_env' in sim_params else 'env' + profile_managers['download'] = setup_manager('Download', args.download_env, download_env_key) + + # For modes with only one manager, set the legacy `profile_manager` variable + # for components that haven't been updated to use the `profile_managers` dict. + if len(profile_managers) == 1: + profile_manager = list(profile_managers.values())[0] + + # --- Throughput Orchestration Mode --- + if orchestration_mode == 'throughput': + logger.info("--- Throughput Orchestration Mode Enabled ---") + if mode != 'download_only' or settings.get('profile_mode') != 'from_pool_with_lock': + logger.error("Orchestration mode 'throughput' is only compatible with 'download_only' mode and 'from_pool_with_lock' profile mode.") + return 1 + + download_manager = profile_managers.get('download') + if not download_manager: + logger.error("Throughput mode requires a download profile manager.") + return 1 + + original_workers_setting = exec_control.get('workers') + if original_workers_setting == 'auto': + d_policy = policy.get('download_policy', {}) + profile_prefix = d_policy.get('profile_prefix') + if not profile_prefix: + logger.error("Cannot calculate 'auto' workers for throughput mode without 'download_policy.profile_prefix'.") + return 1 + + all_profiles = download_manager.list_profiles() + matching_profiles = [p for p in all_profiles if p['name'].startswith(profile_prefix)] + calculated_workers = len(matching_profiles) + + if calculated_workers == 0: + logger.error(f"Cannot use 'auto' workers: No profiles found with prefix '{profile_prefix}'. 
Please run setup-profiles.") + return 1 + + exec_control['workers'] = calculated_workers + logger.info(f"Calculated 'auto' workers for throughput mode: {calculated_workers} (based on {len(matching_profiles)} profiles with prefix '{profile_prefix}').") + + sp_utils.display_effective_policy(policy, policy_name, sources=[], original_workers_setting=original_workers_setting) + if args.dry_run: return 0 + + workers = exec_control.get('workers', 1) + with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor: + futures = [ + executor.submit(run_throughput_worker, i, policy, state_manager, args, download_manager, running_processes, process_lock) + for i in range(workers) + ] + # Wait for shutdown signal + shutdown_event.wait() + logger.info("Shutdown signal received, waiting for throughput workers to finish current tasks...") + # The workers will exit their loops upon seeing the shutdown_event. + # We don't need complex shutdown logic here; the main `finally` block will handle summary. + concurrent.futures.wait(futures) + + # In this mode, the main loop is handled by workers. So we return here. + state_manager.print_summary(policy) + state_manager.close() + return 0 + + # --- Direct Batch CLI Orchestration Mode --- + elif orchestration_mode == 'direct_batch_cli': + logger.info("--- Direct Batch CLI Orchestration Mode Enabled ---") + if mode != 'fetch_only' or settings.get('profile_mode') != 'from_pool_with_lock': + logger.error("Orchestration mode 'direct_batch_cli' is only compatible with 'fetch_only' mode and 'from_pool_with_lock' profile mode.") + return 1 + + direct_policy = policy.get('direct_batch_cli_policy', {}) + use_env = direct_policy.get('use_profile_env', 'auth') # Default to auth for backward compatibility + + profile_manager_instance = profile_managers.get(use_env) + if not profile_manager_instance: + logger.error(f"Direct batch CLI mode requires a '{use_env}' profile manager, but it was not configured.") + logger.error("Check 'simulation_parameters' in your policy and the 'mode' setting.") + return 1 + + urls_file = settings.get('urls_file') + if not urls_file: + logger.error("Direct batch CLI mode requires 'settings.urls_file'.") + return 1 + + try: + with open(urls_file, 'r', encoding='utf-8') as f: + urls_list = [line.strip() for line in f if line.strip()] + except IOError as e: + logger.error(f"Could not read urls_file '{urls_file}': {e}") + return 1 + + if not urls_list: + logger.error(f"URL file '{urls_file}' is empty. Nothing to do.") + return 1 + + # Handle starting from a specific index + start_index = 0 + if args.start_from_url_index is not None: + start_index = max(0, args.start_from_url_index - 1) + state_manager.update_last_url_index(start_index, force=True) + else: + start_index = state_manager.get_last_url_index() + + if start_index >= len(urls_list) and len(urls_list) > 0: + logger.warning("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") + logger.warning("!!! ALL URLS HAVE BEEN PROCESSED IN PREVIOUS RUNS (based on state file) !!!") + logger.warning(f"!!! State file indicates start index {start_index + 1}, but URL file has only {len(urls_list)} URLs. !!!") + logger.warning("!!! Deleting state file and stopping. Please run the command again to start from the beginning. 
!!!") + logger.warning("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") + if not args.dry_run and not args.disable_log_writing: + state_manager.close() # ensure it's closed before deleting + try: + os.remove(state_manager.state_file_path) + logger.info(f"Deleted state file: {state_manager.state_file_path}") + except OSError as e: + logger.error(f"Failed to delete state file: {e}") + else: + logger.info("[Dry Run] Would have deleted state file and stopped.") + + return 0 # Stop execution. + + if start_index > 0: + logger.info(f"Starting/resuming from URL index {start_index + 1}.") + # The worker's get_next_url_batch will respect this starting index. + + sp_utils.display_effective_policy(policy, policy_name, sources=urls_list) + if args.dry_run: return 0 + + workers = exec_control.get('workers', 1) + with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor: + futures = [ + executor.submit(run_direct_batch_worker, i, policy, state_manager, args, profile_manager_instance, urls_list, running_processes, process_lock) + for i in range(workers) + ] + # Wait for all workers to complete. They will exit their loops when no URLs are left. + concurrent.futures.wait(futures) + if shutdown_event.is_set(): + logger.info("Shutdown signal received, workers have finished.") + + state_manager.print_summary(policy) + state_manager.close() + return 0 + + # --- Direct Docker CLI Orchestration Mode --- + elif orchestration_mode == 'direct_docker_cli': + logger.info("--- Direct Docker CLI Orchestration Mode Enabled ---") + if not docker: + logger.error("The 'direct_docker_cli' orchestration mode requires the Docker SDK for Python.") + logger.error("Please install it with: pip install docker") + return 1 + + if mode not in ['fetch_only', 'download_only'] or settings.get('profile_mode') != 'from_pool_with_lock': + logger.error("Orchestration mode 'direct_docker_cli' is only compatible with 'fetch_only' or 'download_only' modes and 'from_pool_with_lock' profile mode.") + return 1 + + direct_policy = policy.get('direct_docker_cli_policy', {}) + use_env = direct_policy.get('use_profile_env', 'auth' if mode == 'fetch_only' else 'download') + + profile_manager_instance = profile_managers.get(use_env) + if not profile_manager_instance: + logger.error(f"Direct docker CLI mode requires a '{use_env}' profile manager, but it was not configured.") + return 1 + + workers = exec_control.get('workers', 1) + + if mode == 'fetch_only': + urls_file = settings.get('urls_file') + if not urls_file: + logger.error("Direct docker CLI (fetch) mode requires 'settings.urls_file'.") + return 1 + + try: + with open(urls_file, 'r', encoding='utf-8') as f: + urls_list = [line.strip() for line in f if line.strip()] + except IOError as e: + logger.error(f"Could not read urls_file '{urls_file}': {e}") + return 1 + + if not urls_list: + logger.error(f"URL file '{urls_file}' is empty. Nothing to do.") + return 1 + + start_index = 0 + if args.start_from_url_index is not None: + start_index = max(0, args.start_from_url_index - 1) + state_manager.update_last_url_index(start_index, force=True) + else: + start_index = state_manager.get_last_url_index() + + if start_index >= len(urls_list) and len(urls_list) > 0: + logger.warning("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") + logger.warning("!!! ALL URLS HAVE BEEN PROCESSED IN PREVIOUS RUNS (based on state file) !!!") + logger.warning(f"!!! 
State file indicates start index {start_index + 1}, but URL file has only {len(urls_list)} URLs. !!!") + logger.warning("!!! Deleting state file and stopping. Please run the command again to start from the beginning. !!!") + logger.warning("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") + if not args.dry_run and not args.disable_log_writing: + state_manager.close() + try: + os.remove(state_manager.state_file_path) + logger.info(f"Deleted state file: {state_manager.state_file_path}") + except OSError as e: + logger.error(f"Failed to delete state file: {e}") + else: + logger.info("[Dry Run] Would have deleted state file and stopped.") + return 0 + + if start_index > 0: + logger.info(f"Starting/resuming from URL index {start_index + 1}.") + + sp_utils.display_effective_policy(policy, policy_name, sources=urls_list) + if args.dry_run: return 0 + + with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor: + futures = [ + executor.submit(run_direct_docker_worker, i, policy, state_manager, args, profile_manager_instance, urls_list, running_processes, process_lock) + for i in range(workers) + ] + concurrent.futures.wait(futures) + if shutdown_event.is_set(): + logger.info("Shutdown signal received, workers have finished.") + + elif mode == 'download_only': + info_json_dir = settings.get('info_json_dir') + if not info_json_dir: + logger.error("Direct docker CLI (download) mode requires 'settings.info_json_dir'.") + return 1 + try: + os.makedirs(info_json_dir, exist_ok=True) + except OSError as e: + logger.error(f"Failed to create info.json directory '{info_json_dir}': {e}") + return 1 + + sp_utils.display_effective_policy(policy, policy_name, sources=[]) + if args.dry_run: return 0 + + with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor: + futures = [ + executor.submit(run_direct_docker_download_worker, i, policy, state_manager, args, profile_manager_instance, running_processes, process_lock) + for i in range(workers) + ] + # This worker runs until shutdown + shutdown_event.wait() + logger.info("Shutdown signal received, waiting for direct docker download workers to finish...") + concurrent.futures.wait(futures) + + state_manager.print_summary(policy) + state_manager.close() + return 0 + + # --- Direct Download CLI Orchestration Mode --- + elif orchestration_mode == 'direct_download_cli': + logger.info("--- Direct Download CLI Orchestration Mode Enabled ---") + if mode != 'download_only' or settings.get('profile_mode') != 'from_pool_with_lock': + logger.error("Orchestration mode 'direct_download_cli' is only compatible with 'download_only' mode and 'from_pool_with_lock' profile mode.") + return 1 + + download_manager = profile_managers.get('download') + if not download_manager: + logger.error("Direct download CLI mode requires a download profile manager.") + return 1 + + info_json_dir = settings.get('info_json_dir') + if not info_json_dir: + logger.error("Direct download CLI mode requires 'settings.info_json_dir'.") + return 1 + + try: + os.makedirs(info_json_dir, exist_ok=True) + except OSError as e: + logger.error(f"Failed to create info.json directory '{info_json_dir}': {e}") + return 1 + + sp_utils.display_effective_policy(policy, policy_name, sources=[]) + if args.dry_run: return 0 + + workers = exec_control.get('workers', 1) + with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor: + futures = [ + executor.submit(run_direct_download_worker, i, policy, state_manager, args, download_manager, 
running_processes, process_lock) + for i in range(workers) + ] + shutdown_event.wait() + logger.info("Shutdown signal received, waiting for direct download workers to finish...") + concurrent.futures.wait(futures) + + state_manager.print_summary(policy) + state_manager.close() + return 0 + + # --- Default (Task-First) Orchestration Mode --- sources = [] # This will be a list of URLs or Path objects if mode in ['full_stack', 'fetch_only']: urls_file = settings.get('urls_file') @@ -1933,71 +797,65 @@ def main_stress_policy(args): logger.error("No sources (URLs or info.json files) to process. Exiting.") return 1 - # Grouping of sources by profile is now handled inside the main loop to support continuous mode. - profile_extraction_regex = settings.get('profile_extraction_regex') + start_index = 0 + if mode in ['full_stack', 'fetch_only']: + if args.start_from_url_index is not None: + # User provided a 1-based index via CLI + start_index = max(0, args.start_from_url_index - 1) + logger.info(f"Starting from URL index {start_index + 1} as requested by --start-from-url-index.") + # When user specifies it, we should overwrite the saved state. + state_manager.update_last_url_index(start_index, force=True) + else: + start_index = state_manager.get_last_url_index() + if start_index > 0: + logger.info(f"Resuming from URL index {start_index + 1} based on saved state.") - # For 'auto' worker calculation and initial display, we need to group sources once. - # This will be re-calculated inside the loop for continuous mode. - profile_tasks = None - if mode == 'download_only' and profile_extraction_regex: - profile_tasks = collections.defaultdict(list) - for source_path in sources: - profile_name = get_profile_from_filename(source_path, profile_extraction_regex) - if profile_name: - profile_tasks[profile_name].append(source_path) - else: - profile_tasks['unmatched_profile'].append(source_path) + if start_index >= len(sources): + logger.warning(f"Start index ({start_index + 1}) is beyond the end of the URL list ({len(sources)}). Nothing to process.") + sources = [] # --- Auto-calculate workers if needed --- - exec_control = policy.get('execution_control', {}) original_workers_setting = exec_control.get('workers') if original_workers_setting == 'auto': - if mode == 'download_only' and profile_tasks is not None: - num_profiles = len(profile_tasks) - # Use auto_workers_max from policy, with a default of 8. - max_workers = exec_control.get('auto_workers_max', 8) - num_workers = min(num_profiles, max_workers) - exec_control['workers'] = max(1, num_workers) - logger.info(f"Calculated 'auto' workers based on {num_profiles} profiles (max: {max_workers}): {exec_control['workers']}") - else: - target_rate_cfg = exec_control.get('target_rate', {}) - target_reqs = target_rate_cfg.get('requests') - target_mins = target_rate_cfg.get('per_minutes') - if target_reqs and target_mins and sources: - target_rpm = target_reqs / target_mins - num_sources = len(sources) - sleep_cfg = exec_control.get('sleep_between_tasks', {}) - avg_sleep = (sleep_cfg.get('min_seconds', 0) + sleep_cfg.get('max_seconds', 0)) / 2 - assumed_task_duration = 12 # Must match assumption in display_effective_policy + # In this simplified model, 'auto' is based on target rate, not profiles. 
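+ # Worked example (illustrative numbers, not from any real policy): with 100 sources,
+ # a target_rate of 30 requests per 1 minute, sleep_between_tasks averaging 5s, and the
+ # assumed 12s task duration, work_time_available = 60*100/30 - 5 = 195s and
+ # total_work_seconds = 100*12 = 1200s, so workers = ceil(1200/195) = 7.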
+ target_rate_cfg = exec_control.get('target_rate', {}) + target_reqs = target_rate_cfg.get('requests') + target_mins = target_rate_cfg.get('per_minutes') + if target_reqs and target_mins and sources: + target_rpm = target_reqs / target_mins + num_sources = len(sources) + sleep_cfg = exec_control.get('sleep_between_tasks', {}) + avg_sleep = (sleep_cfg.get('min_seconds', 0) + sleep_cfg.get('max_seconds', 0)) / 2 + assumed_task_duration = 12 # Must match assumption in display_effective_policy - # Formula: workers = (total_work_seconds) / (total_time_for_work) - # total_time_for_work is derived from the target rate: - # (total_cycle_time) = (60 * num_sources) / target_rpm - # total_time_for_work = total_cycle_time - avg_sleep - work_time_available = (60 * num_sources / target_rpm) - avg_sleep + # Formula: workers = (total_work_seconds) / (total_time_for_work) + # total_time_for_work is derived from the target rate: + # (total_cycle_time) = (60 * num_sources) / target_rpm + # total_time_for_work = total_cycle_time - avg_sleep + work_time_available = (60 * num_sources / target_rpm) - avg_sleep - if work_time_available <= 0: - # The sleep time alone makes the target rate impossible. - # Set workers to max parallelism as a best-effort. - num_workers = num_sources - logger.warning(f"Target rate of {target_rpm} req/min is likely unachievable due to sleep time of {avg_sleep}s.") - logger.warning(f"Setting workers to max parallelism ({num_workers}) as a best effort.") - else: - total_work_seconds = num_sources * assumed_task_duration - num_workers = total_work_seconds / work_time_available - - calculated_workers = max(1, int(num_workers + 0.99)) # Ceiling - exec_control['workers'] = calculated_workers - logger.info(f"Calculated 'auto' workers based on target rate: {calculated_workers}") + if work_time_available <= 0: + # The sleep time alone makes the target rate impossible. + # Set workers to max parallelism as a best-effort. + num_workers = num_sources + logger.warning(f"Target rate of {target_rpm} req/min is likely unachievable due to sleep time of {avg_sleep}s.") + logger.warning(f"Setting workers to max parallelism ({num_workers}) as a best effort.") else: - logger.warning("Cannot calculate 'auto' workers: 'target_rate' or sources are not defined. Defaulting to 1 worker.") - exec_control['workers'] = 1 + total_work_seconds = num_sources * assumed_task_duration + num_workers = total_work_seconds / work_time_available + + calculated_workers = max(1, int(num_workers + 0.99)) # Ceiling + exec_control['workers'] = calculated_workers + logger.info(f"Calculated 'auto' workers based on target rate: {calculated_workers}") + else: + logger.warning("Cannot calculate 'auto' workers: 'target_rate' or sources are not defined. 
Defaulting to 1 worker.") + exec_control['workers'] = 1 - display_effective_policy( + sp_utils.display_effective_policy( policy, policy_name, sources=sources, - profile_names=list(profile_tasks.keys()) if profile_tasks is not None else None, + profile_names=None, # Profile grouping is removed original_workers_setting=original_workers_setting ) @@ -2015,294 +873,6 @@ def main_stress_policy(args): # --- Main test loop --- cycles = 0 try: - def process_task(source, source_index, cycle_num): - """Worker task for one source (URL or file path).""" - try: - if shutdown_event.is_set(): - return [] # Shutdown initiated, do not start new work - - # --- Step 1: Get info.json content --- - info_json_content = None - profile_name = None - if mode in ['full_stack', 'fetch_only']: - gen_policy = policy.get('info_json_generation_policy', {}) - cmd_template = gen_policy.get('command_template') - - # --- Profile Generation --- - profile_mode = settings.get('profile_mode') - pm_policy = settings.get('profile_management') - - if profile_mode == 'per_worker_with_rotation': - if not pm_policy: - logger.error("Profile mode 'per_worker_with_rotation' requires 'settings.profile_management' configuration.") - # Log a failure event and skip - event = {'type': 'fetch', 'path': str(source), 'success': False, 'error_type': 'ConfigError', 'details': 'Missing profile_management section'} - state_manager.log_event(event) - return [] - worker_id = get_worker_id() - profile_name = state_manager.get_or_rotate_worker_profile(worker_id, policy) - elif pm_policy: - # This is the existing dynamic cooldown logic - profile_name = state_manager.get_next_available_profile(policy) - if not profile_name: - logger.warning("No available profiles to run task. Skipping.") - return [] - else: - # This is the legacy logic - profile_prefix = settings.get('profile_prefix') - if profile_prefix: - if profile_mode == 'per_request': - timestamp = datetime.now().strftime('%Y%m%d%H%M%S%f') - profile_name = f"{profile_prefix}_{timestamp}_{source_index}" - elif profile_mode == 'per_worker': - worker_index = get_worker_id() - profile_name = f"{profile_prefix}_{worker_index}" - else: # Default to pool logic - profile_pool = settings.get('profile_pool') - if profile_pool: - profile_name = f"{profile_prefix}_{source_index % profile_pool}" - else: - profile_name = "default" # A final fallback - - # --- Rate Limit Check --- - if not state_manager.check_and_update_rate_limit(profile_name, policy): - return [] # Rate limited, skip this task - - # --- Command Generation --- - gen_cmd = [] - save_dir = settings.get('save_info_json_dir') - save_path = None - - if cmd_template: - # Low-level template mode. The user is responsible for output. - video_id = get_video_id(source) - - # A heuristic to add '--' if the video ID looks like an option. - # We split the template, find the standalone '{url}' placeholder, - # and insert '--' before it. This assumes it's a positional argument. - template_parts = shlex.split(cmd_template) - try: - # Find from the end, in case it's used in an option value earlier. - url_index = len(template_parts) - 1 - template_parts[::-1].index('{url}') - if video_id.startswith('-'): - template_parts.insert(url_index, '--') - except ValueError: - # '{url}' not found as a standalone token, do nothing special. - pass - - # Rejoin and then format the whole string. 
- gen_cmd_str = ' '.join(template_parts) - gen_cmd_str = gen_cmd_str.format(url=video_id, profile=profile_name) - gen_cmd = shlex.split(gen_cmd_str) - if args.verbose and '--verbose' not in gen_cmd: - gen_cmd.append('--verbose') - else: - # High-level policy mode. Orchestrator builds the command. - script_cmd_str = settings.get('info_json_script') - if not script_cmd_str: - logger.error("High-level policy requires 'settings.info_json_script'.") - return [] - gen_cmd = shlex.split(script_cmd_str) - video_id = get_video_id(source) - - client_to_use, request_params = state_manager.get_client_for_request(profile_name, gen_policy) - - # --- Multi-Cookie File Logic --- - if pm_policy: - cookie_files = pm_policy.get('cookie_files') - if cookie_files and isinstance(cookie_files, list) and len(cookie_files) > 0: - profile_index = -1 - # Extract index from profile name. Matches _ or __ - match = re.search(r'_(\d+)(?:_(\d+))?$', profile_name) - if match: - # For rotation mode, the first group is worker_id. For pool mode, it's the profile index. - profile_index = int(match.group(1)) - - if profile_index != -1: - cookie_file_path = cookie_files[profile_index % len(cookie_files)] - if not request_params: - request_params = {} - request_params['cookies_file_path'] = cookie_file_path - logger.info(f"[{source}] Assigned cookie file '{os.path.basename(cookie_file_path)}' to profile '{profile_name}'") - else: - logger.warning(f"[{source}] Could not determine index for profile '{profile_name}' to assign cookie file.") - - if client_to_use: - gen_cmd.extend(['--client', str(client_to_use)]) - if gen_policy.get('auth_host'): - gen_cmd.extend(['--auth-host', str(gen_policy.get('auth_host'))]) - if gen_policy.get('auth_port'): - gen_cmd.extend(['--auth-port', str(gen_policy.get('auth_port'))]) - if profile_name != "default": - gen_cmd.extend(['--profile', profile_name]) - - # Add --print-proxy so we can track it for stats - if '--print-proxy' not in gen_cmd: - gen_cmd.append('--print-proxy') - - if request_params: - gen_cmd.extend(['--request-params-json', json.dumps(request_params)]) - if gen_policy.get('assigned_proxy_url'): - gen_cmd.extend(['--assigned-proxy-url', str(gen_policy.get('assigned_proxy_url'))]) - if gen_policy.get('proxy_rename'): - gen_cmd.extend(['--proxy-rename', str(gen_policy.get('proxy_rename'))]) - - if args.verbose: - gen_cmd.append('--verbose') - - # If saving is enabled, delegate saving to the client script. - if save_dir: - try: - os.makedirs(save_dir, exist_ok=True) - timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') - # Note: Using a timestamped filename to avoid race conditions. - filename = f"{timestamp}-{video_id}-{profile_name}.json" - save_path = Path(save_dir) / filename - gen_cmd.extend(['--output', str(save_path)]) - # No longer need to suppress, it's the default. - except IOError as e: - logger.error(f"[{source}] Could not prepare save path in '{save_dir}': {e}") - # Continue without saving - save_path = None - - # If not saving to a file, we need the output on stdout for the download step. - if not save_dir: - gen_cmd.append('--print-info-out') - - # The positional video_id argument must come after all options. - # Use '--' to ensure it's not parsed as an option if it starts with a dash. 
- if video_id.startswith('-'): - gen_cmd.append('--') - gen_cmd.append(video_id) - - worker_id = get_worker_id() - profile_log_part = f" [Profile: {profile_name}]" if profile_name else "" - logger.info(f"[Worker {worker_id}]{profile_log_part} [{source}] Running info.json command: {' '.join(shlex.quote(s) for s in gen_cmd)}") - retcode, stdout, stderr = run_command(gen_cmd) - info_json_content = stdout - - # --- Extract proxy from stderr and record it for stats --- - proxy_url = None - proxy_match = re.search(r"Proxy used: (.*)", stderr) - if proxy_match: - proxy_url = proxy_match.group(1).strip() - state_manager.record_proxy_usage(proxy_url) - - if retcode == 0: - # If the client script saved the file, stdout will be empty. - # If we need the content for a download step, we must read it back. - if not info_json_content.strip(): - # Check stderr for the success message to confirm save. - saved_path_match = re.search(r"Successfully saved info.json to (.*)", stderr) - if saved_path_match: - output_file_str = saved_path_match.group(1).strip().strip("'\"") - logger.info(f"[{source}] -> {saved_path_match.group(0).strip()}") - - # If this is a full_stack test, we need the content for the download worker. - if mode == 'full_stack': - try: - with open(output_file_str, 'r', encoding='utf-8') as f: - info_json_content = f.read() - except IOError as e: - logger.error(f"Could not read back info.json from '{output_file_str}': {e}") - retcode = -1 # Treat as failure - elif save_path: - # Command was told to save, but didn't confirm. Assume it worked if exit code is 0. - logger.info(f"[{source}] -> Client script exited 0, assuming info.json was saved to '{save_path}'") - if mode == 'full_stack': - try: - with open(save_path, 'r', encoding='utf-8') as f: - info_json_content = f.read() - except IOError as e: - logger.error(f"Could not read back info.json from '{save_path}': {e}") - retcode = -1 - # If stdout is empty and we weren't saving, it's an issue. - elif not save_path and not cmd_template: - logger.error(f"[{source}] info.json generation gave no stdout and was not asked to save to a file.") - retcode = -1 - else: - logger.info(f"[{source}] -> Successfully fetched info.json to memory/stdout.") - - event = {'type': 'fetch', 'path': str(source), 'profile': profile_name} - if proxy_url: - event['proxy_url'] = proxy_url - - if retcode != 0: - error_lines = [line for line in stderr.strip().split('\n') if 'error' in line.lower()] - error_msg = error_lines[-1] if error_lines else stderr.strip().split('\n')[-1] - logger.error(f"[{source}] Failed to generate info.json: {error_msg}") - event.update({'success': False, 'error_type': 'GetInfoJsonFail', 'details': error_msg}) - state_manager.log_event(event) - return [] - - # Check for quality degradation before logging success - s_conditions = policy.get('stop_conditions', {}) - quality_policy = s_conditions.get('on_quality_degradation') - if quality_policy and info_json_content: - try: - info_data = json.loads(info_json_content) - available_formats = {f.get('format_id') for f in info_data.get('formats', [])} - - required_formats = quality_policy.get('trigger_if_missing_formats') - if required_formats: - # Can be a single string, a comma-separated string, or a list of strings. - if isinstance(required_formats, str): - required_formats = [f.strip() for f in required_formats.split(',')] - - missing_formats = [f for f in required_formats if f not in available_formats] - - if missing_formats: - logger.warning(f"[{source}] Quality degradation detected. 
Missing required formats: {', '.join(missing_formats)}.") - event['quality_degradation_trigger'] = True - event['missing_formats'] = missing_formats - except (json.JSONDecodeError, TypeError): - logger.warning(f"[{source}] Could not parse info.json or find formats to check for quality degradation.") - - # Record request for profile cooldown policy if active - if pm_policy: - state_manager.record_profile_request(profile_name) - - state_manager.increment_request_count() - event.update({'success': True, 'details': 'OK'}) - state_manager.log_event(event) - - # Saving is now delegated to the client script when a save_dir is provided. - # The orchestrator no longer saves the file itself. - - elif mode == 'download_only': - # This path is for non-profile-grouped download_only mode. - try: - with open(source, 'r', encoding='utf-8') as f: - info_json_content = f.read() - except (IOError, FileNotFoundError) as e: - logger.error(f"[{get_display_name(source)}] Could not read info.json file: {e}") - return [] - - if mode != 'fetch_only': - return _run_download_logic(source, info_json_content, policy, state_manager, profile_name=profile_name) - - return [] - finally: - # Sleep after the task is completed to space out requests from this worker. - exec_control = policy.get('execution_control', {}) - sleep_cfg = exec_control.get('sleep_between_tasks', {}) - sleep_min = sleep_cfg.get('min_seconds', 0) - - if sleep_min > 0: - sleep_max = sleep_cfg.get('max_seconds') or sleep_min - if sleep_max > sleep_min: - sleep_duration = random.uniform(sleep_min, sleep_max) - else: - sleep_duration = sleep_min - - logger.debug(f"Worker sleeping for {sleep_duration:.2f}s after task for {get_display_name(source)}.") - # Interruptible sleep - sleep_end_time = time.time() + sleep_duration - while time.time() < sleep_end_time: - if shutdown_event.is_set(): - break - time.sleep(0.2) - while not shutdown_event.is_set(): if duration_seconds and (time.time() - start_time) > duration_seconds: logger.info("Reached duration limit. Stopping.") @@ -2351,16 +921,6 @@ def main_stress_policy(args): # --- Group sources for this cycle --- task_items = sources - profile_tasks = None - if mode == 'download_only' and profile_extraction_regex: - profile_tasks = collections.defaultdict(list) - for source_path in sources: - profile_name = get_profile_from_filename(source_path, profile_extraction_regex) - if profile_name: - profile_tasks[profile_name].append(source_path) - else: - profile_tasks['unmatched_profile'].append(source_path) - task_items = list(profile_tasks.items()) # If there's nothing to do this cycle, skip. 
if not task_items: @@ -2379,21 +939,17 @@ def main_stress_policy(args): logger.info(f"--- Cycle #{cycles} (Total Requests: {state_manager.get_request_count()}) ---") with concurrent.futures.ThreadPoolExecutor(max_workers=exec_control.get('workers', 1)) as executor: - if mode == 'download_only' and profile_tasks is not None: - # New: submit profile tasks - future_to_source = { - executor.submit(process_profile_task, profile_name, file_list, policy, state_manager, cycles): profile_name - for profile_name, file_list in task_items - } - else: - # Old: submit individual file/url tasks - future_to_source = { - executor.submit(process_task, source, i, cycles): source - for i, source in enumerate(task_items) + # Submit one task per source URL or info.json file + future_to_task_info = { + executor.submit(process_task, source, i, cycles, policy, state_manager, args, profile_managers, running_processes, process_lock): { + 'source': source, + 'abs_index': i } + for i, source in enumerate(task_items) if i >= start_index + } should_stop = False - pending_futures = set(future_to_source.keys()) + pending_futures = set(future_to_task_info.keys()) while pending_futures and not should_stop: done, pending_futures = concurrent.futures.wait( @@ -2405,14 +961,24 @@ def main_stress_policy(args): should_stop = True break - source = future_to_source[future] + task_info = future_to_task_info[future] + source = task_info['source'] + abs_index = task_info.get('abs_index') + try: results = future.result() - - # Mark file as processed in continuous download mode - if mode == 'download_only' and settings.get('directory_scan_mode') == 'continuous': - state_manager.mark_file_as_processed(source) + if abs_index is not None and mode in ['full_stack', 'fetch_only']: + # Update state to resume from the *next* URL. + state_manager.update_last_url_index(abs_index + 1) + + # --- Mark file as processed --- + # This is the central place to mark a source as complete for download_only mode. + if mode == 'download_only': + # In continuous mode, we add to state file to prevent re-picking in same run. + if settings.get('directory_scan_mode') == 'continuous': + state_manager.mark_file_as_processed(source) + # If marking by rename is on, do that. if settings.get('mark_processed_files'): try: timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') @@ -2421,6 +987,10 @@ def main_stress_policy(args): logger.info(f"Marked '{source.name}' as processed by renaming to '{new_path.name}'") except (IOError, OSError) as e: logger.error(f"Failed to rename processed file '{source.name}': {e}") + + # When using profile-aware mode, the file processing (including marking as + # processed) is handled inside process_profile_task. + # For non-profile mode, this logic was incorrect and has been moved. for result in results: if not result['success']: @@ -2429,11 +999,11 @@ def main_stress_policy(args): if s_conditions.get('on_failure') or \ (s_conditions.get('on_http_403') and not is_cumulative_403_active and result['error_type'] == 'HTTP 403') or \ (s_conditions.get('on_timeout') and result['error_type'] == 'Timeout'): - logger.info(f"!!! STOP CONDITION MET: Immediate stop on failure '{result['error_type']}' for {get_display_name(source)}. Shutting down all workers. !!!") + logger.info(f"!!! STOP CONDITION MET: Immediate stop on failure '{result['error_type']}' for {sp_utils.get_display_name(source)}. Shutting down all workers. 
!!!") should_stop = True break except concurrent.futures.CancelledError: - logger.info(f"Task for {get_display_name(source)} was cancelled during shutdown.") + logger.info(f"Task for {sp_utils.get_display_name(source)} was cancelled during shutdown.") event = { 'type': 'fetch' if mode != 'download_only' else 'download', 'path': str(source), @@ -2443,51 +1013,58 @@ def main_stress_policy(args): } state_manager.log_event(event) except Exception as exc: - logger.error(f'{get_display_name(source)} generated an exception: {exc}') + logger.error(f'{sp_utils.get_display_name(source)} generated an exception: {exc}') if should_stop: break - # Check for cumulative error rate stop conditions + # Check for all stop conditions after each task completes. + + # 1. Max requests limit + if not should_stop and max_requests > 0 and state_manager.get_request_count() >= max_requests: + logger.info(f"!!! STOP CONDITION MET: Reached request limit ({max_requests}). Shutting down. !!!") + should_stop = True + + # 2. Duration limit + if not should_stop and duration_seconds and (time.time() - start_time) > duration_seconds: + logger.info(f"!!! STOP CONDITION MET: Reached duration limit ({run_until_cfg.get('minutes')} minutes). Shutting down. !!!") + should_stop = True + + # 3. Cumulative error rate limits s_conditions = policy.get('stop_conditions', {}) error_rate_policy = s_conditions.get('on_error_rate') - if error_rate_policy and not should_stop: + if not should_stop and error_rate_policy: max_errors = error_rate_policy.get('max_errors') per_minutes = error_rate_policy.get('per_minutes') if max_errors and per_minutes: error_count = state_manager.check_cumulative_error_rate(max_errors, per_minutes) if error_count > 0: - logger.info(f"!!! STOP CONDITION MET: Error rate exceeded: {error_count} errors in the last {per_minutes} minute(s). Shutting down. !!!") + logger.info(f"!!! STOP CONDITION MET: Error rate exceeded ({error_count} errors in last {per_minutes}m). Shutting down. !!!") should_stop = True cumulative_403_policy = s_conditions.get('on_cumulative_403') - if cumulative_403_policy and not should_stop: + if not should_stop and cumulative_403_policy: max_errors = cumulative_403_policy.get('max_errors') per_minutes = cumulative_403_policy.get('per_minutes') if max_errors and per_minutes: error_count = state_manager.check_cumulative_error_rate(max_errors, per_minutes, error_type='HTTP 403') if error_count > 0: - logger.info(f"!!! STOP CONDITION MET: Cumulative 403 error rate exceeded: {error_count} errors in the last {per_minutes} minute(s). Shutting down. !!!") + logger.info(f"!!! STOP CONDITION MET: Cumulative 403 rate exceeded ({error_count} in last {per_minutes}m). Shutting down. !!!") should_stop = True quality_degradation_policy = s_conditions.get('on_quality_degradation') - if quality_degradation_policy and not should_stop: + if not should_stop and quality_degradation_policy: max_triggers = quality_degradation_policy.get('max_triggers') per_minutes = quality_degradation_policy.get('per_minutes') if max_triggers and per_minutes: trigger_count = state_manager.check_quality_degradation_rate(max_triggers, per_minutes) if trigger_count > 0: - logger.info(f"!!! STOP CONDITION MET: Quality degradation triggered {trigger_count} times in the last {per_minutes} minute(s). Shutting down. !!!") + logger.info(f"!!! STOP CONDITION MET: Quality degradation triggered {trigger_count} times in last {per_minutes}m. Shutting down. 
!!!") should_stop = True if should_stop: break - # Check for duration limit after each task completes - if duration_seconds and (time.time() - start_time) > duration_seconds: - logger.info("Reached duration limit. Cancelling remaining tasks.") - should_stop = True - if should_stop and pending_futures: logger.info(f"Cancelling {len(pending_futures)} outstanding task(s).") for future in pending_futures: @@ -2498,12 +1075,109 @@ def main_stress_policy(args): if max_cycles > 0 and cycles >= max_cycles: break + # If the run is not time-based (i.e., it's limited by cycles or requests) + # and it's not a continuous directory scan, we should stop after one pass. + # This makes the behavior of --set run_until.requests=N more intuitive: it acts + # as an upper limit for a single pass, not a trigger for multiple passes. + if settings.get('directory_scan_mode') != 'continuous' and not duration_seconds: + logger.info("Run is not time-based. Halting after one full pass through sources.") + break + logger.info("Cycle complete.") except KeyboardInterrupt: logger.info("\nForceful shutdown requested...") finally: + # --- Graceful Shutdown URL Reporting --- + if shutdown_event.is_set(): + orchestration_mode = settings.get('orchestration_mode') + if orchestration_mode in ['direct_batch_cli', 'direct_docker_cli'] and mode == 'fetch_only': + urls_file = settings.get('urls_file') + # Check if urls_list was loaded for the relevant mode + if urls_file and 'urls_list' in locals() and urls_list: + last_index = state_manager.get_last_url_index() + # The index points to the *next* URL to be processed. + # If a batch was aborted, it might have been rewound. + # We should save all URLs from this index onwards. + if last_index < len(urls_list): + unprocessed_urls = urls_list[last_index:] + unprocessed_filename = f"unprocessed_urls_{policy_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt" + try: + with open(unprocessed_filename, 'w', encoding='utf-8') as f: + f.write('\n'.join(unprocessed_urls)) + logger.warning(f"--- GRACEFUL SHUTDOWN ---") + logger.warning(f"Saved {len(unprocessed_urls)} unprocessed URLs to '{unprocessed_filename}'.") + logger.warning(f"Last processed URL index was {last_index}. Next run should start from index {last_index + 1}.") + logger.warning(f"-------------------------") + except IOError as e: + logger.error(f"Could not save unprocessed URLs: {e}") + state_manager.print_summary(policy) state_manager.close() return 0 + + +def process_task(source, index, cycle_num, policy, state_manager, args, profile_managers, running_processes, process_lock): + """ + Worker task for a single source (URL or info.json path). + This function is the main entry point for the 'task-first' orchestration mode. + """ + settings = policy.get('settings', {}) + mode = settings.get('mode', 'full_stack') + profile_mode = settings.get('profile_mode') + + auth_manager = profile_managers.get('auth') + download_manager = profile_managers.get('download') + + # --- Full Stack Mode --- + if mode == 'full_stack': + # 1. Fetch info.json + if not auth_manager: + logger.error("Full-stack mode requires an 'auth' profile manager.") + return [] + + # This part of the logic is simplified and does not exist in the provided codebase. + # It would involve locking an auth profile, fetching info.json, and then unlocking. + # For now, we'll assume a placeholder logic. 
+ logger.error("Full-stack mode (task-first) is not fully implemented in this version.") + return [] + + # --- Fetch Only Mode --- + elif mode == 'fetch_only': + if not auth_manager: + logger.error("Fetch-only mode requires an 'auth' profile manager.") + return [] + logger.error("Fetch-only mode (task-first) is not fully implemented in this version.") + return [] + + # --- Download Only Mode --- + elif mode == 'download_only': + if profile_mode == 'from_pool_with_lock': + if not download_manager: + logger.error("Download-only with locking requires a 'download' profile manager.") + return [] + # In this mode, we process one file per profile. + return process_profile_task( + profile_name=None, # Profile is locked inside the task + file_list=[source], + policy=policy, + state_manager=state_manager, + cycle_num=cycle_num, + args=args, + running_processes=running_processes, + process_lock=process_lock, + profile_manager_instance=download_manager + ) + else: + # Legacy mode without profile locking + try: + with open(source, 'r', encoding='utf-8') as f: + info_json_content = f.read() + except (IOError, FileNotFoundError) as e: + logger.error(f"[{sp_utils.get_display_name(source)}] Could not read info.json file: {e}") + return [] + + return _run_download_logic(source, info_json_content, policy, state_manager, args, running_processes, process_lock) + + return [] diff --git a/ytops_client/task_generator_tool.py b/ytops_client/task_generator_tool.py new file mode 100644 index 0000000..cec933c --- /dev/null +++ b/ytops_client/task_generator_tool.py @@ -0,0 +1,203 @@ +#!/usr/bin/env python3 +""" +CLI tool to generate granular download task files from a directory of info.json files. +""" +import argparse +import json +import logging +import os +import re +import sys +from pathlib import Path + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + +def sanitize_format_for_filename(format_str: str) -> str: + """Sanitizes a format selector string to be filesystem-friendly.""" + # Replace common problematic characters with underscores + sanitized = re.sub(r'[\\/+:\[\]\s]', '_', format_str) + # Remove any trailing characters that might be problematic + sanitized = sanitized.strip('._-') + return sanitized + +def add_task_generator_parser(subparsers): + """Adds the parser for the 'task-generator' command.""" + parser = subparsers.add_parser( + 'task-generator', + description="Generate granular download task files from info.jsons.", + formatter_class=argparse.RawTextHelpFormatter, + help="Generate granular download task files." + ) + + # All functionality is under subcommands for extensibility. + generate_subparsers = parser.add_subparsers(dest='task_generator_command', help='Action to perform', required=True) + + gen_parser = generate_subparsers.add_parser( + 'generate', + help='Generate task files from a source directory.', + description='Reads info.json files from a source directory and creates one task file per format in an output directory.' 
+ ) + gen_parser.add_argument('--source-dir', required=True, help='Directory containing the source info.json files.') + gen_parser.add_argument('--output-dir', required=True, help='Directory where the generated task files will be saved.') + gen_parser.add_argument('--formats', required=True, help='A comma-separated list of format IDs or selectors to generate tasks for (e.g., "18,140,bestvideo").') + gen_parser.add_argument('--verbose', action='store_true', help='Enable verbose logging.') + + reset_parser = generate_subparsers.add_parser( + 'reset', + help='Reset processed source files.', + description='Finds all *.processed files in the source directory and renames them back to *.json to allow re-generation.' + ) + reset_parser.add_argument('--source-dir', required=True, help='Directory containing the source info.json files to reset.') + reset_parser.add_argument('--verbose', action='store_true', help='Enable verbose logging.') + + +def _main_task_generator_reset(args): + """Main logic for the 'reset' command.""" + source_dir = Path(args.source_dir) + if not source_dir.is_dir(): + logger.error(f"Source directory does not exist or is not a directory: {source_dir}") + return 1 + + logger.info(f"Scanning for *.processed and *.LOCKED.* files in '{source_dir}' (recursively) to reset...") + # Use rglob for recursive search + processed_files = list(source_dir.rglob('*.json.processed')) + locked_files = list(source_dir.rglob('*.json.LOCKED.*')) + files_to_reset = processed_files + locked_files + + if not files_to_reset: + logger.info("No processed or locked files found to reset.") + return 0 + + reset_count = 0 + for file_to_reset in files_to_reset: + original_path = None + if file_to_reset.name.endswith('.processed'): + # Handles cases like file.json.processed + original_path_str = str(file_to_reset).removesuffix('.processed') + original_path = Path(original_path_str) + elif '.LOCKED.' in file_to_reset.name: + # Handles cases like file.json.LOCKED.0 + original_path_str = str(file_to_reset).split('.LOCKED.')[0] + original_path = Path(original_path_str) + + if original_path: + try: + if original_path.exists(): + logger.warning(f"Original file '{original_path.name}' already exists. Deleting '{file_to_reset.name}' instead of renaming.") + file_to_reset.unlink() + else: + file_to_reset.rename(original_path) + logger.debug(f"Reset '{file_to_reset.name}' to '{original_path.name}'") + reset_count += 1 + except (IOError, OSError) as e: + logger.error(f"Failed to reset '{file_to_reset.name}': {e}") + else: + logger.warning(f"Could not determine original filename for '{file_to_reset.name}'. 
Skipping.") + + logger.info(f"Successfully reset {reset_count} file(s).") + return 0 + + +def main_task_generator(args): + """Main logic for the 'task-generator' tool.""" + if args.task_generator_command == 'generate': + return _main_task_generator_generate(args) + elif args.task_generator_command == 'reset': + return _main_task_generator_reset(args) + return 1 + + +def _main_task_generator_generate(args): + if args.verbose: + logging.getLogger().setLevel(logging.DEBUG) + + source_dir = Path(args.source_dir) + output_dir = Path(args.output_dir) + formats_to_generate = [f.strip() for f in args.formats.split(',') if f.strip()] + + if not source_dir.is_dir(): + logger.error(f"Source directory does not exist or is not a directory: {source_dir}") + return 1 + + try: + output_dir.mkdir(parents=True, exist_ok=True) + except OSError as e: + logger.error(f"Could not create output directory '{output_dir}': {e}") + return 1 + + logger.info(f"Scanning for info.json files in '{source_dir}' (recursively)...") + source_files = list(source_dir.rglob('*.json')) + + if not source_files: + logger.info(f"No .json files found in '{source_dir}'. Nothing to do.") + return 0 + + logger.info(f"Found {len(source_files)} source file(s). Generating tasks for formats: {', '.join(formats_to_generate)}...") + + total_tasks_generated = 0 + for source_file in source_files: + try: + with open(source_file, 'r', encoding='utf-8') as f: + info_json_content = json.load(f) + except (IOError, json.JSONDecodeError) as e: + logger.warning(f"Skipping file '{source_file.name}' due to read/parse error: {e}") + continue + + try: + tasks_generated_this_run = 0 + + # Use metadata to create a profile-specific subdirectory for better organization. + profile_name_from_meta = info_json_content.get('_ytops_metadata', {}).get('profile_name') + final_output_dir = output_dir + if profile_name_from_meta: + final_output_dir = output_dir / profile_name_from_meta + # Ensure subdirectory exists. This is done once per source file. + try: + final_output_dir.mkdir(parents=True, exist_ok=True) + except OSError as e: + logger.error(f"Could not create profile subdirectory '{final_output_dir}': {e}. Skipping tasks for this source file.") + continue + + for format_str in formats_to_generate: + task_data = info_json_content.copy() + # Add the target format to the task data itself. This makes the task file self-contained. 
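+                # Illustrative example: with --formats "18,140" and a source file named
+                # 'dQw4w9WgXcQ.json', this loop writes 'dQw4w9WgXcQ-format-18.json' and
+                # 'dQw4w9WgXcQ-format-140.json' (into the profile subdirectory, if any),
+                # each a full copy of the info.json plus the '_ytops_download_format' key.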
+ task_data['_ytops_download_format'] = format_str + + # Create a unique filename for the task + original_stem = source_file.stem + safe_format_str = sanitize_format_for_filename(format_str) + task_filename = f"{original_stem}-format-{safe_format_str}.json" + output_path = final_output_dir / task_filename + + # Check if this specific task file already exists to avoid re-writing + if output_path.exists(): + logger.debug(f"Task file already exists, skipping generation: {output_path}") + continue + + with open(output_path, 'w', encoding='utf-8') as f: + json.dump(task_data, f, indent=2) + logger.debug(f"Generated task file: {output_path}") + tasks_generated_this_run += 1 + + if tasks_generated_this_run > 0: + total_tasks_generated += tasks_generated_this_run + + # Mark source file as processed by renaming + try: + processed_path = source_file.with_suffix(f"{source_file.suffix}.processed") + source_file.rename(processed_path) + logger.debug(f"Marked '{source_file.name}' as processed.") + except (IOError, OSError) as e: + logger.error(f"Failed to mark source file '{source_file.name}' as processed: {e}") + + except IOError as e: + logger.error(f"An I/O error occurred while generating tasks for '{source_file.name}': {e}. It will be retried on the next run.") + # The file is not renamed, so it will be picked up again + + logger.info(f"Successfully generated {total_tasks_generated} new task file(s) in '{output_dir}'.") + return 0 diff --git a/ytops_client/youtube-dl/Dockerfile b/ytops_client/youtube-dl/Dockerfile new file mode 100644 index 0000000..fbf65a1 --- /dev/null +++ b/ytops_client/youtube-dl/Dockerfile @@ -0,0 +1,71 @@ +# https://github.com/Jeeaaasus/youtube-dl/blob/master/Dockerfile based on, excluded services +FROM debian:12-slim + +ENV PATH="/opt/yt-dlp-venv/bin:$PATH" \ + HOME="/config" \ + PUID="911" \ + PGID="911" \ + UMASK="022" \ + OPENSSL_CONF= + +RUN set -x && \ + addgroup --gid "$PGID" abc && \ + adduser \ + --gecos "" \ + --disabled-password \ + --uid "$PUID" \ + --ingroup abc \ + --shell /bin/bash \ + abc + +RUN set -x && \ + apt update && \ + apt install -y \ + brotli \ + file \ + wget \ + unzip \ + python3 \ + python3-venv \ + python3-pip && \ + apt clean && \ + python3 -m venv /opt/yt-dlp-venv && \ + rm -rf \ + /var/lib/apt/lists/* \ + /tmp/* + +RUN set -x && \ + arch=`uname -m` && \ + if [ "$arch" = "x86_64" ]; then \ + wget -q 'https://github.com/yt-dlp/FFmpeg-Builds/releases/download/latest/ffmpeg-master-latest-linux64-gpl.tar.xz' -O - | tar -xJ -C /tmp/ --one-top-level=ffmpeg && \ + chmod -R a+x /tmp/ffmpeg/* && \ + mv $(find /tmp/ffmpeg/* -name ffmpeg) /usr/local/bin/ && \ + mv $(find /tmp/ffmpeg/* -name ffprobe) /usr/local/bin/ && \ + mv $(find /tmp/ffmpeg/* -name ffplay) /usr/local/bin/ && \ + rm -rf /tmp/* ; \ + else \ + if [ "$arch" = "aarch64" ]; then arch='arm64'; fi && \ + wget -q "https://johnvansickle.com/ffmpeg/builds/ffmpeg-git-${arch}-static.tar.xz" -O - | tar -xJ -C /tmp/ --one-top-level=ffmpeg && \ + chmod -R a+x /tmp/ffmpeg/* && \ + mv $(find /tmp/ffmpeg/* -name ffmpeg) /usr/local/bin/ && \ + mv $(find /tmp/ffmpeg/* -name ffprobe) /usr/local/bin/ && \ + rm -rf /tmp/* ; \ + fi + +RUN set -x && \ + arch=`uname -m` && \ + wget -q "https://github.com/denoland/deno/releases/latest/download/deno-${arch}-unknown-linux-gnu.zip" -O /tmp/deno.zip && \ + unzip /tmp/deno.zip -d /tmp/deno/ && \ + chmod -R a+x /tmp/deno/* && \ + mv $(find /tmp/deno/* -name deno) /usr/local/bin/ && \ + rm -rf /tmp/* + +RUN set -x && \ + /opt/yt-dlp-venv/bin/pip --no-cache-dir install -U 
--pre yt-dlp[default] bgutil-ytdlp-pot-provider && \ + chmod -R a+rx /opt/yt-dlp-venv + +VOLUME /config /downloads + +WORKDIR /config + +CMD ["yt-dlp", "--version"] diff --git a/ytops_client/youtube-dl/README.md b/ytops_client/youtube-dl/README.md new file mode 100644 index 0000000..28189b1 --- /dev/null +++ b/ytops_client/youtube-dl/README.md @@ -0,0 +1,33 @@ +# yt-dlp Docker Image + +**A yt-dlp Docker image for downloading YouTube subscriptions and for use with yt-ops-client.** + +yt-dlp documentation [here](https://github.com/yt-dlp/yt-dlp). + +# Building the Image + +A helper script is provided to build the Docker image. + +```bash +# From the root of the repository +./bin/build-yt-dlp-image +``` + +This will build the image and tag it based on the version in `ytops_client/youtube-dl/release-versions/latest.txt`. For example, if the file contains `2025.12.08`, the script will create the tags `ytops/yt-dlp:2025.12.08` and `ytops/yt-dlp:latest`. + +You can also specify a custom image name: +```bash +./bin/build-yt-dlp-image my-registry/my-yt-dlp +``` + +## Updating yt-dlp + +To update the version of `yt-dlp` used in the image: + +1. Modify the version string in `ytops_client/youtube-dl/release-versions/latest.txt`. +2. Rebuild the image using the build script: + ```bash + ./bin/build-yt-dlp-image + ``` +3. If you have a running container, you will need to stop, remove, and recreate it to use the new image. + diff --git a/ytops_client/youtube-dl/release-versions/latest.txt b/ytops_client/youtube-dl/release-versions/latest.txt new file mode 100644 index 0000000..e4e2387 --- /dev/null +++ b/ytops_client/youtube-dl/release-versions/latest.txt @@ -0,0 +1 @@ +2025.12.08 diff --git a/ytops_client/yt_dlp_dummy_tool.py b/ytops_client/yt_dlp_dummy_tool.py new file mode 100644 index 0000000..78e54b6 --- /dev/null +++ b/ytops_client/yt_dlp_dummy_tool.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python3 +""" +(Internal) A dummy yt-dlp script that simulates Redis interactions for testing. +""" + +import argparse +import json +import logging +import os +import random +import re +import sys +import time +from datetime import datetime, timezone +from pathlib import Path + +# Add project root to path to import ProfileManager and other utils +script_dir = os.path.dirname(os.path.abspath(__file__)) +project_root = os.path.abspath(os.path.join(script_dir, '..')) +if project_root not in sys.path: + sys.path.insert(0, project_root) + +from ytops_client.profile_manager_tool import ProfileManager +from ytops_client.stress_policy import utils as sp_utils + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + + +def add_yt_dlp_dummy_parser(subparsers): + """Adds the parser for the 'yt-dlp-dummy' command.""" + parser = subparsers.add_parser( + 'yt-dlp-dummy', + description='(Internal) A dummy yt-dlp script that simulates Redis interactions for testing.', + formatter_class=argparse.RawTextHelpFormatter, + help='(Internal) Dummy yt-dlp for simulation.' 
+ ) + # Mimic a subset of yt-dlp's arguments required by the orchestrator + parser.add_argument('--batch-file', required=True, help='File containing URLs to process.') + parser.add_argument('-o', '--output', dest='output_template', required=True, help='Output template for info.json files.') + parser.add_argument('--proxy', help='Proxy URL to use (for logging purposes).') + parser.add_argument('--verbose', action='store_true', help='Enable verbose logging.') + # Note: Other yt-dlp args passed by the orchestrator are safely ignored. + + +def main_yt_dlp_dummy(args): + """Main logic for the 'yt-dlp-dummy' tool.""" + if args.verbose: + logging.getLogger().setLevel(logging.DEBUG) + + # --- Read configuration from environment variables --- + profile_name = os.environ.get('YTDLP_PROFILE_NAME') + sim_mode_env = os.environ.get('YTDLP_SIM_MODE') + drop_on_banned = os.environ.get('YTDLP_DROP_BATCH_ON_BANNED_PROFILE_EVEN_EXTERNALLY_MANAGED') == '1' + + failure_rate = float(os.environ.get('YTDLP_DUMMY_FAILURE_RATE', 0.0)) + tolerated_rate = float(os.environ.get('YTDLP_DUMMY_SKIPPED_FAILURE_RATE', 0.0)) + + # These env vars are set by the orchestrator for Redis connection + redis_host = os.getenv('REDIS_HOST', 'localhost') + redis_port = int(os.getenv('REDIS_PORT', 6379)) + redis_password = os.getenv('REDIS_PASSWORD') + + if not profile_name or not sim_mode_env: + logger.error("Missing required environment variables: YTDLP_PROFILE_NAME and YTDLP_SIM_MODE") + return 1 + + # --- Connect to Redis --- + key_prefix = f"{sim_mode_env}_profile_mgmt_" + manager = ProfileManager( + redis_host=redis_host, redis_port=redis_port, + redis_password=redis_password, key_prefix=key_prefix + ) + + # --- Read URLs from batch file --- + try: + with open(args.batch_file, 'r', encoding='utf-8') as f: + urls = [line.strip() for line in f if line.strip()] + except IOError as e: + logger.error(f"Failed to read batch file '{args.batch_file}': {e}") + return 1 + + logger.info(f"Dummy yt-dlp starting batch for profile '{profile_name}'. Processing {len(urls)} URLs.") + + files_created = 0 + hard_failures = 0 + + for url in urls: + time.sleep(random.uniform(0.1, 0.3)) # Simulate work per URL + + # 1. Check if profile has been banned externally + if drop_on_banned: + profile_data = manager.get_profile(profile_name) + if profile_data and profile_data.get('state') == manager.STATE_BANNED: + logger.warning(f"Profile '{profile_name}' is BANNED. Stopping batch as per policy.") + return 1 + + # 2. Simulate success/failure and record activity in Redis + rand_val = random.random() + + if rand_val < failure_rate: + logger.warning(f"Simulating HARD failure for URL '{sp_utils.get_video_id(url)}'.") + logger.info(f"Recording 'failure' for profile '{profile_name}' in Redis.") + manager.record_activity(profile_name, 'failure') + hard_failures += 1 + continue + elif rand_val < (failure_rate + tolerated_rate): + logger.warning(f"Simulating TOLERATED failure for URL '{sp_utils.get_video_id(url)}'.") + logger.info(f"Recording 'tolerated_error' for profile '{profile_name}' in Redis.") + manager.record_activity(profile_name, 'tolerated_error') + continue + else: + # Success + logger.info(f"Simulating SUCCESS for URL '{sp_utils.get_video_id(url)}'. Recording 'success' for profile '{profile_name}' in Redis.") + manager.record_activity(profile_name, 'success') + + # 3. 
Create dummy info.json file + video_id = sp_utils.get_video_id(url) + dummy_formats = [ + '18', '140', '299-dashy', '298-dashy', '137-dashy', + '136-dashy', '135-dashy', '134-dashy', '133-dashy' + ] + info_data = { + 'id': video_id, + 'formats': [ + {'format_id': f_id, 'url': f'http://dummy.url/{video_id}/{f_id}'} + for f_id in dummy_formats + ], + } + + # This is a simplified version of yt-dlp's output template handling + output_path_str = args.output_template.replace('%(id)s', video_id) + # Real yt-dlp adds .info.json when --write-info-json is used, so we do too. + if not output_path_str.endswith('.info.json'): + output_path_str += '.info.json' + output_path = Path(output_path_str) + + output_path.parent.mkdir(parents=True, exist_ok=True) + + try: + with open(output_path, 'w', encoding='utf-8') as f: + json.dump(info_data, f, indent=2) + logger.debug(f"Created dummy info.json: {output_path}") + files_created += 1 + except (IOError, OSError) as e: + logger.error(f"Failed to write dummy info.json to '{output_path}': {e}") + hard_failures += 1 + + logger.info(f"Dummy yt-dlp finished batch. Created {files_created} files. Had {hard_failures} hard failures.") + + # yt-dlp exits 0 with --ignore-errors. Our dummy does the same. + # The orchestrator judges batch success based on files_created. + return 0
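For reference, a minimal sketch of how an orchestrator might drive this dummy tool, assuming the subcommand is exposed on the `ytops-client` CLI as `yt-dlp-dummy` and that the `REDIS_*` variables point at a reachable instance; the profile name, simulated failure rates, and output path are illustrative only.

```python
# Minimal driver sketch (assumed CLI entry point: `ytops-client yt-dlp-dummy`).
import os
import subprocess
import tempfile

# Write a small batch file of URLs for the dummy to "process".
with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as batch:
    batch.write("https://www.youtube.com/watch?v=dQw4w9WgXcQ\n")
    batch_file = batch.name

env = os.environ.copy()
env.update({
    "YTDLP_PROFILE_NAME": "profile_01",          # assumed profile name in Redis
    "YTDLP_SIM_MODE": "download",                # becomes the '<mode>_profile_mgmt_' key prefix
    "YTDLP_DUMMY_FAILURE_RATE": "0.10",          # ~10% simulated hard failures
    "YTDLP_DUMMY_SKIPPED_FAILURE_RATE": "0.05",  # ~5% simulated tolerated failures
    "REDIS_HOST": "localhost",
    "REDIS_PORT": "6379",
})

# One dummy <id>.info.json is written per simulated success.
subprocess.run(
    [
        "ytops-client", "yt-dlp-dummy",
        "--batch-file", batch_file,
        "-o", "/tmp/dummy-info/%(id)s",
        "--verbose",
    ],
    env=env,
    check=False,  # the dummy mirrors yt-dlp's exit-0-on-ignored-errors behaviour
)
```

As the closing comment notes, the exit code is not meaningful on its own; a driver should judge batch success by counting the `*.info.json` files created under the output directory.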