# -*- coding: utf-8 -*- # vim:fenc=utf-8 # # Copyright © 2024 rl # # Distributed under terms of the MIT license. """ DAG for processing a single YouTube URL passed via DAG run configuration. This is the "Worker" part of a Sensor/Worker pattern. This DAG has been refactored to use the TaskFlow API to implement worker affinity, ensuring all tasks for a single URL run on the same machine. """ from __future__ import annotations from airflow.decorators import task, task_group from airflow.exceptions import AirflowException, AirflowSkipException from airflow.models import Variable from airflow.models.dag import DAG from airflow.models.param import Param from airflow.models.xcom_arg import XComArg from airflow.operators.dummy import DummyOperator from airflow.utils.dates import days_ago from airflow.api.common.trigger_dag import trigger_dag from datetime import datetime, timedelta import json import logging import os import random import re import socket import time import traceback import uuid # Import utility functions and Thrift modules from utils.redis_utils import _get_redis_client from pangramia.yt.common.ttypes import TokenUpdateMode from pangramia.yt.exceptions.ttypes import PBServiceException, PBUserException from pangramia.yt.tokens_ops import YTTokenOpService from thrift.protocol import TBinaryProtocol from thrift.transport import TSocket, TTransport from thrift.transport.TTransport import TTransportException # Configure logging logger = logging.getLogger(__name__) # Default settings from Airflow Variables or hardcoded fallbacks DEFAULT_QUEUE_NAME = 'video_queue' DEFAULT_REDIS_CONN_ID = 'redis_default' DEFAULT_TIMEOUT = 600 DEFAULT_YT_AUTH_SERVICE_IP = Variable.get("YT_AUTH_SERVICE_IP", default_var="172.17.0.1") DEFAULT_YT_AUTH_SERVICE_PORT = Variable.get("YT_AUTH_SERVICE_PORT", default_var=9080) # The queue is set to a fallback here. The actual worker-specific queue is # assigned just-in-time by the task_instance_mutation_hook in airflow_local_settings.py, # which reads the 'worker_queue' from the DAG run configuration. DEFAULT_ARGS = { 'owner': 'airflow', 'retries': 0, 'queue': 'queue-dl', # Fallback queue. Will be overridden by the policy hook. } # --- Helper Functions --- def _get_thrift_client(host, port, timeout): """Helper to create and connect a Thrift client.""" transport = TSocket.TSocket(host, port) transport.setTimeout(timeout * 1000) transport = TTransport.TFramedTransport(transport) protocol = TBinaryProtocol.TBinaryProtocol(transport) client = YTTokenOpService.Client(protocol) transport.open() logger.info(f"Connected to Thrift server at {host}:{port}") return client, transport def _extract_video_id(url): """Extracts YouTube video ID from URL.""" if not url or not isinstance(url, str): return None patterns = [r'v=([a-zA-Z0-9_-]{11})', r'youtu\.be/([a-zA-Z0-9_-]{11})'] for pattern in patterns: match = re.search(pattern, url) if match: return match.group(1) return None def _get_account_pool(params: dict) -> list: """ Gets the list of accounts to use for processing, filtering out banned/resting accounts. Supports explicit list, prefix-based generation, and single account modes. """ account_pool_str = params.get('account_pool', 'default_account') accounts = [] is_prefix_mode = False if ',' in account_pool_str: accounts = [acc.strip() for acc in account_pool_str.split(',') if acc.strip()] else: prefix = account_pool_str pool_size_param = params.get('account_pool_size') if pool_size_param is not None: is_prefix_mode = True pool_size = int(pool_size_param) accounts = [f"{prefix}_{i:02d}" for i in range(1, pool_size + 1)] else: accounts = [prefix] if not accounts: raise AirflowException("Initial account pool is empty.") redis_conn_id = params.get('redis_conn_id', DEFAULT_REDIS_CONN_ID) try: redis_client = _get_redis_client(redis_conn_id) active_accounts = [] for account in accounts: status_bytes = redis_client.hget(f"account_status:{account}", "status") status = status_bytes.decode('utf-8') if status_bytes else "ACTIVE" if status not in ['BANNED'] and 'RESTING' not in status: active_accounts.append(account) if not active_accounts and accounts: auto_create = params.get('auto_create_new_accounts_on_exhaustion', False) if auto_create and is_prefix_mode: new_account_id = f"{account_pool_str}-auto-{str(uuid.uuid4())[:8]}" logger.warning(f"Account pool exhausted. Auto-creating new account: '{new_account_id}'") active_accounts.append(new_account_id) else: raise AirflowException("All accounts in the configured pool are currently exhausted.") accounts = active_accounts except Exception as e: logger.error(f"Could not filter accounts from Redis. Using unfiltered pool. Error: {e}", exc_info=True) if not accounts: raise AirflowException("Account pool is empty after filtering.") logger.info(f"Final active account pool with {len(accounts)} accounts.") return accounts # ============================================================================= # TASK DEFINITIONS (TaskFlow API) # ============================================================================= @task def get_url_and_assign_account(**context): """ Gets the URL to process from the DAG run configuration and assigns an active account. This is the first task in the pinned-worker DAG. """ params = context['params'] # The URL is passed by the dispatcher DAG. url_to_process = params.get('url_to_process') if not url_to_process: raise AirflowException("'url_to_process' was not found in the DAG run configuration.") logger.info(f"Received URL '{url_to_process}' to process.") # Account assignment logic is the same as before. account_id = random.choice(_get_account_pool(params)) logger.info(f"Selected account '{account_id}' for this run.") return { 'url_to_process': url_to_process, 'account_id': account_id, 'accounts_tried': [account_id], } @task def get_token(initial_data: dict, **context): """Makes a single attempt to get a token from the Thrift service.""" ti = context['task_instance'] params = context['params'] account_id = initial_data['account_id'] url = initial_data['url_to_process'] info_json_dir = Variable.get('DOWNLOADS_TEMP', '/opt/airflow/downloadfiles') host, port, timeout = params['service_ip'], int(params['service_port']), int(params.get('timeout', DEFAULT_TIMEOUT)) machine_id = params.get('machine_id') or socket.gethostname() logger.info(f"--- Attempting to get token for URL '{url}' with account '{account_id}' ---") client, transport = None, None try: client, transport = _get_thrift_client(host, port, timeout) token_data = client.getOrRefreshToken(accountId=account_id, updateType=TokenUpdateMode.AUTO, url=url, clients=params.get('clients'), machineId=machine_id) info_json = getattr(token_data, 'infoJson', None) if not (info_json and json.loads(info_json)): raise AirflowException("Service returned success but info.json was empty or invalid.") video_id = _extract_video_id(url) os.makedirs(info_json_dir, exist_ok=True) # Use a readable timestamp for a unique filename on each attempt. timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") info_json_path = os.path.join(info_json_dir, f"info_{video_id or 'unknown'}_{account_id}_{timestamp}.json") with open(info_json_path, 'w', encoding='utf-8') as f: f.write(info_json) proxy_attr = next((attr for attr in ['socks5Proxy', 'socksProxy', 'socks'] if hasattr(token_data, attr)), None) return { 'info_json_path': info_json_path, 'socks_proxy': getattr(token_data, proxy_attr) if proxy_attr else None, 'ytdlp_command': getattr(token_data, 'ytdlpCommand', None), 'successful_account_id': account_id, 'original_url': url, # Include original URL for fallback } except (PBServiceException, PBUserException, TTransportException) as e: error_context = getattr(e, 'context', None) if isinstance(error_context, str): try: error_context = json.loads(error_context.replace("'", "\"")) except: pass error_details = { 'error_message': getattr(e, 'message', str(e)), 'error_code': getattr(e, 'errorCode', 'TRANSPORT_ERROR'), 'proxy_url': error_context.get('proxy_url') if isinstance(error_context, dict) else None } logger.error(f"Thrift call failed for account '{account_id}'. Exception: {error_details['error_message']}") ti.xcom_push(key='error_details', value=error_details) raise AirflowException(f"Thrift call failed: {error_details['error_message']}") finally: if transport and transport.isOpen(): transport.close() @task.branch def handle_bannable_error_branch(task_id_to_check: str, **context): """Inspects a failed task and routes to retry logic if the error is bannable.""" ti = context['task_instance'] params = context['params'] error_details = ti.xcom_pull(task_ids=task_id_to_check, key='error_details') if not error_details: return None # Let DAG fail for unexpected errors error_code = error_details.get('error_code', '').strip() policy = params.get('on_bannable_failure', 'retry_with_new_account') is_bannable = error_code in ["SOCKS5_CONNECTION_FAILED", "SOCKET_TIMEOUT", "BOT_DETECTED", "BOT_DETECTION_SIGN_IN_REQUIRED"] logger.info(f"Handling failure from '{task_id_to_check}'. Error code: '{error_code}', Policy: '{policy}'") if is_bannable and policy in ['retry_with_new_account', 'retry_and_ban_account_only']: return 'ban_account_and_prepare_for_retry' if is_bannable and policy in ['retry_on_connection_error', 'retry_without_ban']: return 'assign_new_account_for_retry' if is_bannable: # stop_loop return 'ban_account_and_fail' return None # Not a bannable error, let DAG fail @task def ban_account(initial_data: dict, reason: str, **context): """Bans a single account via the Thrift service.""" params = context['params'] account_id = initial_data['account_id'] client, transport = None, None try: host, port, timeout = params['service_ip'], int(params['service_port']), int(params.get('timeout', DEFAULT_TIMEOUT)) client, transport = _get_thrift_client(host, port, timeout) logger.warning(f"Banning account '{account_id}'. Reason: {reason}") client.banAccount(accountId=account_id, reason=reason) except Exception as e: logger.error(f"Failed to issue ban for account '{account_id}': {e}", exc_info=True) finally: if transport and transport.isOpen(): transport.close() @task def assign_new_account_for_retry(initial_data: dict, **context): """Selects a new, unused account for the retry attempt.""" params = context['params'] accounts_tried = initial_data['accounts_tried'] account_pool = _get_account_pool(params) available_for_retry = [acc for acc in account_pool if acc not in accounts_tried] if not available_for_retry: raise AirflowException("No other accounts available in the pool for a retry.") new_account_id = random.choice(available_for_retry) accounts_tried.append(new_account_id) logger.info(f"Selected new account for retry: '{new_account_id}'") # Return updated initial_data with new account return { 'url_to_process': initial_data['url_to_process'], 'account_id': new_account_id, 'accounts_tried': accounts_tried, } @task def ban_and_fail(initial_data: dict, reason: str, **context): """Bans an account and then intentionally fails the task to stop the DAG.""" ban_account(initial_data, reason, **context) raise AirflowException(f"Failing task as per policy. Reason: {reason}") @task def download_and_probe(token_data: dict, **context): """ Uses the retrieved token data to download and probe the media file. This version uses subprocess directly with an argument list for better security and clarity. """ import subprocess params = context['params'] info_json_path = token_data.get('info_json_path') proxy = token_data.get('socks_proxy') original_url = token_data.get('original_url') download_dir = Variable.get('DOWNLOADS_TEMP', '/opt/airflow/downloadfiles/video') download_format = params.get('download_format', 'ba[ext=m4a]/bestaudio/best') output_template = params.get('output_path_template', "%(title)s [%(id)s].%(ext)s") full_output_path = os.path.join(download_dir, output_template) retry_on_probe_failure = params.get('retry_on_probe_failure', False) if not (info_json_path and os.path.exists(info_json_path)): raise AirflowException(f"Error: info.json path is missing or file does not exist ({info_json_path}).") def run_yt_dlp(): """Constructs and runs the yt-dlp command, returning the final filename.""" cmd = [ 'yt-dlp', '--load-info-json', info_json_path, '-f', download_format, '-o', full_output_path, '--print', 'filename', '--continue', '--no-progress', '--no-simulate', '--no-write-info-json', '--ignore-errors', '--no-playlist', ] if proxy: cmd.extend(['--proxy', proxy]) # Crucially, add the original URL to allow yt-dlp to refresh expired download links, # which is the most common cause of HTTP 403 errors. if original_url: cmd.append(original_url) logger.info(f"Executing yt-dlp command: {' '.join(cmd)}") process = subprocess.run(cmd, capture_output=True, text=True, timeout=1800) if process.returncode != 0: logger.error(f"yt-dlp failed with exit code {process.returncode}") logger.error(f"STDOUT: {process.stdout}") logger.error(f"STDERR: {process.stderr}") raise AirflowException("yt-dlp command failed.") # Get the last line of stdout, which should be the filename final_filename = process.stdout.strip().split('\n')[-1] if not (final_filename and os.path.exists(final_filename)): logger.error(f"Download command finished but the output file does not exist: '{final_filename}'") logger.error(f"Full STDOUT:\n{process.stdout}") logger.error(f"Full STDERR:\n{process.stderr}") raise AirflowException(f"Download failed or did not produce a file: {final_filename}") logger.info(f"SUCCESS: Download complete. Final file at: {final_filename}") return final_filename def run_ffmpeg_probe(filename): """Probes the given file with ffmpeg to check for corruption.""" logger.info(f"Probing downloaded file: {filename}") try: subprocess.run(['ffmpeg', '-v', 'error', '-i', filename, '-f', 'null', '-'], check=True, capture_output=True, text=True) logger.info("SUCCESS: Probe confirmed valid media file.") except subprocess.CalledProcessError as e: logger.error(f"ffmpeg probe check failed for '{filename}'. The file might be corrupt.") logger.error(f"ffmpeg STDERR: {e.stderr}") raise AirflowException("ffmpeg probe failed.") # --- Main Execution Logic --- final_filename = run_yt_dlp() try: run_ffmpeg_probe(final_filename) return final_filename except AirflowException as e: if "probe failed" in str(e) and retry_on_probe_failure: logger.warning("Probe failed. Attempting one re-download...") try: # Rename the failed file to allow for a fresh download attempt part_file = f"{final_filename}.part" os.rename(final_filename, part_file) logger.info(f"Renamed corrupted file to {part_file}") except OSError as rename_err: logger.error(f"Could not rename corrupted file: {rename_err}") final_filename_retry = run_yt_dlp() run_ffmpeg_probe(final_filename_retry) return final_filename_retry else: # Re-raise the original exception if no retry is attempted raise @task def mark_url_as_success(initial_data: dict, downloaded_file_path: str, token_data: dict, **context): """Records the successful result in Redis.""" params = context['params'] url = initial_data['url_to_process'] result_data = { 'status': 'success', 'end_time': time.time(), 'url': url, 'downloaded_file_path': downloaded_file_path, **token_data, 'dag_run_id': context['dag_run'].run_id, } client = _get_redis_client(params['redis_conn_id']) client.hset(f"{params['queue_name']}_result", url, json.dumps(result_data)) logger.info(f"Stored success result for URL '{url}'.") @task(trigger_rule='one_failed') def handle_generic_failure(**context): """Handles any failure in the DAG by recording a detailed error report to Redis.""" # This task is simplified for brevity. The original's detailed logic can be ported here. logger.error("A failure occurred in the DAG. See previous task logs for details.") # In a real scenario, this would pull XComs and build a rich report like the original. raise AirflowException("Failing task to mark DAG run as failed after error.") @task(trigger_rule='one_success') def continue_processing_loop(**context): """ After a successful run, triggers a new dispatcher to continue the processing loop, effectively asking for the next URL to be processed. """ params = context['params'] dag_run = context['dag_run'] # Create a new unique run_id for the dispatcher, tied to this worker's run. new_dispatcher_run_id = f"retriggered_by_{dag_run.run_id}" # Pass all original parameters from the orchestrator through to the new dispatcher run. conf_to_pass = {k: v for k, v in params.items() if v is not None} # The new dispatcher will pull its own URL and determine its own queue, so we don't pass these. conf_to_pass.pop('url_to_process', None) conf_to_pass.pop('worker_queue', None) logger.info(f"Worker finished successfully. Triggering a new dispatcher ('{new_dispatcher_run_id}') to continue the loop.") trigger_dag( dag_id='ytdlp_ops_dispatcher', run_id=new_dispatcher_run_id, conf=conf_to_pass, replace_microseconds=False ) @task(trigger_rule='one_success') def coalesce_token_data(get_token_result=None, retry_get_token_result=None): """ Selects the successful token data from either the first attempt or the retry. The task that did not run or failed will have a result of None. """ if retry_get_token_result: logger.info("Using token data from retry attempt.") return retry_get_token_result if get_token_result: logger.info("Using token data from initial attempt.") return get_token_result # This should not be reached if trigger_rule='one_success' is working correctly. raise AirflowException("Could not find a successful token result from any attempt.") # ============================================================================= # DAG Definition # ============================================================================= with DAG( dag_id='ytdlp_ops_worker_per_url', default_args=DEFAULT_ARGS, schedule=None, start_date=days_ago(1), catchup=False, tags=['ytdlp', 'worker'], doc_md=__doc__, render_template_as_native_obj=True, params={ 'queue_name': Param(DEFAULT_QUEUE_NAME, type="string"), 'redis_conn_id': Param(DEFAULT_REDIS_CONN_ID, type="string"), 'service_ip': Param(DEFAULT_YT_AUTH_SERVICE_IP, type="string"), 'service_port': Param(DEFAULT_YT_AUTH_SERVICE_PORT, type="integer"), 'account_pool': Param('default_account', type="string"), 'account_pool_size': Param(None, type=["integer", "null"]), 'machine_id': Param(None, type=["string", "null"]), 'clients': Param('mweb', type="string"), 'timeout': Param(DEFAULT_TIMEOUT, type="integer"), 'download_format': Param('ba[ext=m4a]/bestaudio/best', type="string"), 'output_path_template': Param("%(title)s [%(id)s].%(ext)s", type="string"), 'on_bannable_failure': Param('retry_with_new_account', type="string", enum=['stop_loop', 'retry_with_new_account', 'retry_without_ban', 'retry_and_ban_account_only', 'retry_on_connection_error']), 'retry_on_probe_failure': Param(False, type="boolean"), 'auto_create_new_accounts_on_exhaustion': Param(True, type="boolean"), # Internal params passed from dispatcher 'url_to_process': Param(None, type=["string", "null"]), 'worker_queue': Param(None, type=["string", "null"]), } ) as dag: initial_data = get_url_and_assign_account() # First attempt at getting token first_token_attempt = get_token(initial_data) # Branch task to handle errors branch_task = handle_bannable_error_branch.override(trigger_rule='one_failed')( task_id_to_check=first_token_attempt.operator.task_id ) # Retry path tasks ban_task = ban_account.override(task_id='ban_account_and_prepare_for_retry')( initial_data=initial_data, reason="Banned by Airflow worker on first attempt" ) new_account_task = assign_new_account_for_retry.override()( initial_data=initial_data ) retry_token_task = get_token.override(task_id='retry_get_token')( initial_data=new_account_task ) # Stop path ban_and_fail_task = ban_and_fail.override()( initial_data=initial_data, reason="Banned by Airflow worker (policy is stop_loop)" ) # Set up dependencies for retry logic first_token_attempt >> branch_task branch_task >> ban_task >> new_account_task >> retry_token_task branch_task >> new_account_task # For policies that don't ban branch_task >> ban_and_fail_task # Coalesce results from the two possible token tasks token_data = coalesce_token_data( get_token_result=first_token_attempt, retry_get_token_result=retry_token_task ) download_task = download_and_probe(token_data=token_data) success_task = mark_url_as_success( initial_data=initial_data, downloaded_file_path=download_task, token_data=token_data ) failure_task = handle_generic_failure() # Main pipeline token_data >> download_task >> success_task # On success, trigger a new dispatcher to continue the loop. success_task >> continue_processing_loop() # Failure handling [first_token_attempt, retry_token_task, download_task] >> failure_task