# yt-dlp-dags/airflow/dags/ytdlp_ops_worker_per_url.py
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
#
# Copyright © 2024 rl <rl@rlmbp>
#
# Distributed under terms of the MIT license.
"""
DAG for processing a single YouTube URL passed via DAG run configuration.
This is the "Worker" part of a Sensor/Worker pattern.
This DAG has been refactored to use the TaskFlow API to implement worker affinity,
ensuring all tasks for a single URL run on the same machine.
"""
from __future__ import annotations
from airflow.decorators import task, task_group
from airflow.exceptions import AirflowException
from airflow.models import Variable
from airflow.models.dag import DAG
from airflow.models.param import Param
from airflow.utils.dates import days_ago
from airflow.utils.task_group import TaskGroup
from airflow.api.common.trigger_dag import trigger_dag
from datetime import datetime
import json
import logging
import os
import random
import re
import socket
import time
import uuid
# Import utility functions and Thrift modules
from utils.redis_utils import _get_redis_client
from pangramia.yt.common.ttypes import TokenUpdateMode
from pangramia.yt.exceptions.ttypes import PBServiceException, PBUserException
from pangramia.yt.tokens_ops import YTTokenOpService
from thrift.protocol import TBinaryProtocol
from thrift.transport import TSocket, TTransport
from thrift.transport.TTransport import TTransportException
# Configure logging
logger = logging.getLogger(__name__)
# Default settings from Airflow Variables or hardcoded fallbacks
DEFAULT_QUEUE_NAME = 'video_queue'
DEFAULT_REDIS_CONN_ID = 'redis_default'
DEFAULT_TIMEOUT = 3600
DEFAULT_YT_AUTH_SERVICE_IP = Variable.get("YT_AUTH_SERVICE_IP", default_var="172.17.0.1")
DEFAULT_YT_AUTH_SERVICE_PORT = int(Variable.get("YT_AUTH_SERVICE_PORT", default_var=9080))
# The queue is set to a fallback here. The actual worker-specific queue is
# assigned just-in-time by the task_instance_mutation_hook (see: airflow/config/custom_task_hooks.py),
# which parses the target queue from the DAG run_id.
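# Illustrative example (run_id prefix is arbitrary; only the '_q_' suffix matters):
#   run_id "manual__2024-01-01T12:00:00_q_queue-dl2" -> task queue "queue-dl2"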
DEFAULT_ARGS = {
'owner': 'airflow',
'retries': 0,
'queue': 'queue-dl', # Fallback queue. Will be overridden by the policy hook.
}
# --- Helper Functions ---
def _get_thrift_client(host, port, timeout):
"""Helper to create and connect a Thrift client."""
transport = TSocket.TSocket(host, port)
transport.setTimeout(timeout * 1000)
transport = TTransport.TFramedTransport(transport)
protocol = TBinaryProtocol.TBinaryProtocol(transport)
client = YTTokenOpService.Client(protocol)
transport.open()
logger.info(f"Connected to Thrift server at {host}:{port}")
return client, transport
def _extract_video_id(url):
"""Extracts YouTube video ID from URL."""
if not url or not isinstance(url, str):
return None
patterns = [r'v=([a-zA-Z0-9_-]{11})', r'youtu\.be/([a-zA-Z0-9_-]{11})']
for pattern in patterns:
match = re.search(pattern, url)
if match:
return match.group(1)
return None
def _get_account_pool(params: dict) -> list:
"""
Gets the list of accounts to use for processing, filtering out banned/resting accounts.
Supports explicit list, prefix-based generation, and single account modes.
"""
account_pool_str = params.get('account_pool', 'default_account')
accounts = []
is_prefix_mode = False
if ',' in account_pool_str:
accounts = [acc.strip() for acc in account_pool_str.split(',') if acc.strip()]
else:
prefix = account_pool_str
pool_size_param = params.get('account_pool_size')
if pool_size_param is not None:
is_prefix_mode = True
pool_size = int(pool_size_param)
accounts = [f"{prefix}_{i:02d}" for i in range(1, pool_size + 1)]
else:
accounts = [prefix]
if not accounts:
raise AirflowException("Initial account pool is empty.")
redis_conn_id = params.get('redis_conn_id', DEFAULT_REDIS_CONN_ID)
try:
redis_client = _get_redis_client(redis_conn_id)
active_accounts = []
for account in accounts:
status_bytes = redis_client.hget(f"account_status:{account}", "status")
status = status_bytes.decode('utf-8') if status_bytes else "ACTIVE"
if status not in ['BANNED'] and 'RESTING' not in status:
active_accounts.append(account)
if not active_accounts and accounts:
auto_create = params.get('auto_create_new_accounts_on_exhaustion', False)
if auto_create and is_prefix_mode:
new_account_id = f"{account_pool_str}-auto-{str(uuid.uuid4())[:8]}"
logger.warning(f"Account pool exhausted. Auto-creating new account: '{new_account_id}'")
active_accounts.append(new_account_id)
else:
raise AirflowException("All accounts in the configured pool are currently exhausted.")
accounts = active_accounts
except Exception as e:
logger.error(f"Could not filter accounts from Redis. Using unfiltered pool. Error: {e}", exc_info=True)
if not accounts:
raise AirflowException("Account pool is empty after filtering.")
logger.info(f"Final active account pool with {len(accounts)} accounts.")
return accounts
# =============================================================================
# TASK DEFINITIONS (TaskFlow API)
# =============================================================================
@task
def get_url_and_assign_account(**context):
"""
Gets the URL to process from the DAG run configuration and assigns an active account.
This is the first task in the pinned-worker DAG.
"""
params = context['params']
ti = context['task_instance']
# --- Worker Pinning Verification ---
# This is a safeguard against a known Airflow issue where clearing a task
# can cause the task_instance_mutation_hook to be skipped, breaking pinning.
# See: https://github.com/apache/airflow/issues/20143
expected_queue = None
if ti.run_id and '_q_' in ti.run_id:
expected_queue = ti.run_id.split('_q_')[-1]
if not expected_queue:
# Fallback to conf if run_id parsing fails for some reason
expected_queue = params.get('worker_queue')
if expected_queue and ti.queue != expected_queue:
error_msg = (
f"WORKER PINNING FAILURE: Task is running on queue '{ti.queue}' but was expected on '{expected_queue}'. "
"This usually happens after manually clearing a task, which is not the recommended recovery method for this DAG. "
"To recover a failed URL, let the DAG run fail, use the 'ytdlp_mgmt_queues' DAG to requeue the URL, "
"and use the 'ytdlp_ops_orchestrator' to start a new worker loop if needed."
)
logger.error(error_msg)
raise AirflowException(error_msg)
elif expected_queue:
logger.info(f"Worker pinning verified. Task is correctly running on queue '{ti.queue}'.")
# --- End Verification ---
# The URL is passed by the dispatcher DAG.
url_to_process = params.get('url_to_process')
if not url_to_process:
raise AirflowException("'url_to_process' was not found in the DAG run configuration.")
logger.info(f"Received URL '{url_to_process}' to process.")
# Account assignment logic is the same as before.
account_id = random.choice(_get_account_pool(params))
logger.info(f"Selected account '{account_id}' for this run.")
return {
'url_to_process': url_to_process,
'account_id': account_id,
'accounts_tried': [account_id],
}
@task
def get_token(initial_data: dict, **context):
"""Makes a single attempt to get a token from the Thrift service."""
ti = context['task_instance']
params = context['params']
account_id = initial_data['account_id']
url = initial_data['url_to_process']
info_json_dir = Variable.get('DOWNLOADS_TEMP', '/opt/airflow/downloadfiles')
host, port, timeout = params['service_ip'], int(params['service_port']), int(params.get('timeout', DEFAULT_TIMEOUT))
machine_id = params.get('machine_id') or socket.gethostname()
logger.info(f"--- Attempting to get token for URL '{url}' with account '{account_id}' ---")
client, transport = None, None
try:
client, transport = _get_thrift_client(host, port, timeout)
token_data = client.getOrRefreshToken(accountId=account_id, updateType=TokenUpdateMode.AUTO, url=url, clients=params.get('clients'), machineId=machine_id)
        info_json = getattr(token_data, 'infoJson', None)
        try:
            if not (info_json and json.loads(info_json)):
                raise AirflowException("Service returned success but info.json was empty or invalid.")
        except json.JSONDecodeError:
            raise AirflowException("Service returned success but info.json was not valid JSON.")
video_id = _extract_video_id(url)
os.makedirs(info_json_dir, exist_ok=True)
# Use a readable timestamp for a unique filename on each attempt.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
info_json_path = os.path.join(info_json_dir, f"info_{video_id or 'unknown'}_{account_id}_{timestamp}.json")
with open(info_json_path, 'w', encoding='utf-8') as f:
f.write(info_json)
proxy_attr = next((attr for attr in ['socks5Proxy', 'socksProxy', 'socks'] if hasattr(token_data, attr)), None)
return {
'info_json_path': info_json_path,
'socks_proxy': getattr(token_data, proxy_attr) if proxy_attr else None,
'ytdlp_command': getattr(token_data, 'ytdlpCommand', None),
'successful_account_id': account_id,
'original_url': url, # Include original URL for fallback
}
except (PBServiceException, PBUserException, TTransportException) as e:
error_context = getattr(e, 'context', None)
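        # The service may return 'context' as a Python-repr style dict string; attempt a
        # best-effort JSON parse after swapping single quotes for double quotes.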
        if isinstance(error_context, str):
            try:
                error_context = json.loads(error_context.replace("'", "\""))
            except ValueError:
                pass
error_details = {
'error_message': getattr(e, 'message', str(e)),
'error_code': getattr(e, 'errorCode', 'TRANSPORT_ERROR'),
'proxy_url': error_context.get('proxy_url') if isinstance(error_context, dict) else None
}
logger.error(f"Thrift call failed for account '{account_id}'. Exception: {error_details['error_message']}")
ti.xcom_push(key='error_details', value=error_details)
raise AirflowException(f"Thrift call failed: {error_details['error_message']}")
finally:
if transport and transport.isOpen():
transport.close()
@task.branch
def handle_bannable_error_branch(task_id_to_check: str, **context):
"""
Inspects a failed task and routes to retry logic if the error is retryable.
Routes to a fatal error handler for non-retryable infrastructure issues.
"""
ti = context['task_instance']
params = context['params']
error_details = ti.xcom_pull(task_ids=task_id_to_check, key='error_details')
if not error_details:
logger.error(f"Task {task_id_to_check} failed without error details. Marking as fatal.")
return 'handle_fatal_error'
error_message = error_details.get('error_message', '').strip()
error_code = error_details.get('error_code', '').strip()
policy = params.get('on_bannable_failure', 'retry_with_new_account')
# Check if this is an age confirmation error - should not stop the loop
if "Sign in to confirm your age" in error_message or "confirm your age" in error_message.lower():
logger.info(f"Age confirmation error detected for '{task_id_to_check}'. This is a content restriction, not a bot detection issue.")
return 'handle_age_restriction_error'
# Fatal Thrift connection errors that should stop all processing.
if error_code == 'TRANSPORT_ERROR':
logger.error(f"Fatal Thrift connection error from '{task_id_to_check}'. Stopping processing.")
return 'handle_fatal_error'
# Service-side connection errors that are potentially retryable.
connection_errors = ['SOCKS5_CONNECTION_FAILED', 'SOCKET_TIMEOUT', 'CAMOUFOX_TIMEOUT']
    if error_code in connection_errors:
        logger.info(f"Handling connection error '{error_code}' from '{task_id_to_check}'. Policy: '{policy}'")
        if policy == 'stop_loop':
            logger.warning("Connection error with 'stop_loop' policy. Marking as fatal.")
            return 'handle_fatal_error'
        logger.info("Retrying with a new account without banning.")
        return 'retry_logic.assign_new_account_for_direct_retry'
    # Bannable errors (e.g., bot detection) that can be retried with a new account.
    # Note: branch targets that live inside TaskGroups must use their group-prefixed task ids.
    is_bannable = error_code in ["BOT_DETECTED", "BOT_DETECTION_SIGN_IN_REQUIRED"]
    logger.info(f"Handling failure from '{task_id_to_check}'. Error code: '{error_code}', Policy: '{policy}'")
    if is_bannable:
        if policy in ['retry_with_new_account', 'retry_and_ban_account_only']:
            return 'retry_logic.ban_account_and_prepare_for_retry.check_sliding_window_for_ban'
        if policy in ['retry_on_connection_error', 'retry_without_ban']:
            return 'retry_logic.assign_new_account_for_direct_retry'
        if policy == 'stop_loop':
            return 'initial_attempt.ban_and_report_immediately'
    # Any other error is considered fatal for this run.
    logger.error(f"Unhandled or non-retryable error '{error_code}' from '{task_id_to_check}'. Marking as fatal.")
    return 'handle_fatal_error'
@task_group(group_id='ban_and_retry_logic')
def ban_and_retry_logic(initial_data: dict):
"""
Task group that checks for sliding window failures before banning an account.
If the account meets ban criteria, it's banned. Otherwise, the ban is skipped
but the retry proceeds.
"""
@task.branch
def check_sliding_window_for_ban(data: dict, **context):
"""
Checks Redis for recent failures. If thresholds are met, proceeds to ban.
Otherwise, proceeds to a dummy task to allow retry without ban.
"""
        params = context['params']
        ti = context['task_instance']
        account_id = data['account_id']
        redis_conn_id = params.get('redis_conn_id', DEFAULT_REDIS_CONN_ID)
        # Branch targets must be full task ids; sibling tasks share this task's group prefix.
        group_prefix = ti.task_id.rsplit('.', 1)[0]
        # These thresholds should ideally be Airflow Variables to be configurable.
        failure_window_seconds = 3600  # 1 hour
        failure_threshold_count = 5
        failure_threshold_unique_proxies = 3
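        # Example (illustrative) of sourcing these from Airflow Variables instead:
        #   failure_threshold_count = int(Variable.get("ACCOUNT_FAILURE_THRESHOLD", default_var=5))
        #   failure_window_seconds = int(Variable.get("ACCOUNT_FAILURE_WINDOW_SECONDS", default_var=3600))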
try:
redis_client = _get_redis_client(redis_conn_id)
failure_key = f"account_failures:{account_id}"
now = time.time()
window_start = now - failure_window_seconds
# 1. Remove old failures and get recent ones
redis_client.zremrangebyscore(failure_key, '-inf', window_start)
recent_failures = redis_client.zrange(failure_key, 0, -1)
if len(recent_failures) >= failure_threshold_count:
# Decode from bytes to string for processing
recent_failures_str = [f.decode('utf-8') for f in recent_failures]
# Failure format is "context:job_id:timestamp"
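                # e.g. (illustrative) entry "proxyA:job_17:1718000000" -> context "proxyA"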
unique_proxies = {f.split(':')[0] for f in recent_failures_str}
if len(unique_proxies) >= failure_threshold_unique_proxies:
logger.warning(
f"Account {account_id} has failed {len(recent_failures)} times "
f"with {len(unique_proxies)} unique contexts in the last hour. Proceeding to ban."
)
                    return f"{group_prefix}.ban_account_task"
else:
logger.info(
f"Account {account_id} has {len(recent_failures)} failures, but only "
f"from {len(unique_proxies)} unique contexts (threshold is {failure_threshold_unique_proxies}). Skipping ban."
)
else:
logger.info(f"Account {account_id} has {len(recent_failures)} failures (threshold is {failure_threshold_count}). Skipping ban.")
except Exception as e:
logger.error(f"Error during sliding window check for account {account_id}: {e}. Skipping ban as a precaution.", exc_info=True)
        return f"{group_prefix}.skip_ban_task"
@task(task_id='ban_account_task')
def ban_account_task(data: dict, **context):
"""Wrapper task to call the main ban_account function."""
ban_account(initial_data=data, reason="Banned by Airflow worker after sliding window check", **context)
@task(task_id='skip_ban_task')
def skip_ban_task():
"""Dummy task to represent the 'skip ban' path."""
pass
check_task = check_sliding_window_for_ban(data=initial_data)
ban_task_in_group = ban_account_task(data=initial_data)
skip_task = skip_ban_task()
check_task >> [ban_task_in_group, skip_task]
@task
def ban_account(initial_data: dict, reason: str, **context):
"""Bans a single account via the Thrift service."""
params = context['params']
account_id = initial_data['account_id']
client, transport = None, None
try:
host, port, timeout = params['service_ip'], int(params['service_port']), int(params.get('timeout', DEFAULT_TIMEOUT))
client, transport = _get_thrift_client(host, port, timeout)
logger.warning(f"Banning account '{account_id}'. Reason: {reason}")
client.banAccount(accountId=account_id, reason=reason)
except Exception as e:
logger.error(f"Failed to issue ban for account '{account_id}': {e}", exc_info=True)
finally:
if transport and transport.isOpen():
transport.close()
@task
def assign_new_account_for_direct_retry(initial_data: dict, **context):
"""Selects a new, unused account for a direct retry (e.g., after connection error)."""
params = context['params']
accounts_tried = initial_data['accounts_tried']
account_pool = _get_account_pool(params)
available_for_retry = [acc for acc in account_pool if acc not in accounts_tried]
if not available_for_retry:
raise AirflowException("No other accounts available in the pool for a retry.")
new_account_id = random.choice(available_for_retry)
accounts_tried.append(new_account_id)
logger.info(f"Selected new account for retry: '{new_account_id}'")
# Return updated initial_data with new account
return {
'url_to_process': initial_data['url_to_process'],
'account_id': new_account_id,
'accounts_tried': accounts_tried,
}
@task
def assign_new_account_after_ban_check(initial_data: dict, **context):
"""Selects a new, unused account for the retry attempt after a ban check."""
params = context['params']
accounts_tried = initial_data['accounts_tried']
account_pool = _get_account_pool(params)
available_for_retry = [acc for acc in account_pool if acc not in accounts_tried]
if not available_for_retry:
raise AirflowException("No other accounts available in the pool for a retry.")
new_account_id = random.choice(available_for_retry)
accounts_tried.append(new_account_id)
logger.info(f"Selected new account for retry: '{new_account_id}'")
# Return updated initial_data with new account
return {
'url_to_process': initial_data['url_to_process'],
'account_id': new_account_id,
'accounts_tried': accounts_tried,
}
@task
def ban_and_report_immediately(initial_data: dict, reason: str, **context):
"""Bans an account and prepares for failure reporting and continuing the loop."""
    # ban_account is @task-decorated; invoke its underlying callable directly.
    ban_account.function(initial_data, reason, **context)
logger.info(f"Account '{initial_data.get('account_id')}' banned. Proceeding to report failure.")
# This task is a leaf in its path and is followed by the failure reporting task.
return initial_data # Pass data along if needed by reporting
@task
def download_and_probe(token_data: dict, **context):
"""
Uses the retrieved token data to download and probe the media file.
This version uses subprocess directly with an argument list for better security and clarity.
"""
import subprocess
import shlex
params = context['params']
info_json_path = token_data.get('info_json_path')
proxy = token_data.get('socks_proxy')
original_url = token_data.get('original_url')
download_dir = Variable.get('DOWNLOADS_TEMP', '/opt/airflow/downloadfiles/video')
download_format = params.get('download_format', 'ba[ext=m4a]/bestaudio/best')
output_template = params.get('output_path_template', "%(title)s [%(id)s].%(ext)s")
full_output_path = os.path.join(download_dir, output_template)
retry_on_probe_failure = params.get('retry_on_probe_failure', False)
if not (info_json_path and os.path.exists(info_json_path)):
raise AirflowException(f"Error: info.json path is missing or file does not exist ({info_json_path}).")
def run_yt_dlp():
"""Constructs and runs the yt-dlp command, returning the final filename."""
cmd = [
'yt-dlp',
'--verbose',
'--load-info-json', info_json_path,
'-f', download_format,
'-o', full_output_path,
'--print', 'filename',
'--continue',
'--no-progress',
'--no-simulate',
'--no-write-info-json',
'--ignore-errors',
'--no-playlist',
]
if proxy:
cmd.extend(['--proxy', proxy])
# Crucially, add the original URL to allow yt-dlp to refresh expired download links,
# which is the most common cause of HTTP 403 errors.
if original_url:
cmd.append(original_url)
copy_paste_cmd = ' '.join(shlex.quote(arg) for arg in cmd)
logger.info(f"Executing yt-dlp command: {copy_paste_cmd}")
process = subprocess.run(cmd, capture_output=True, text=True, timeout=3600)
if process.returncode != 0:
logger.error(f"yt-dlp failed with exit code {process.returncode}")
logger.error(f"STDOUT: {process.stdout}")
logger.error(f"STDERR: {process.stderr}")
raise AirflowException("yt-dlp command failed.")
# Get the last line of stdout, which should be the filename
final_filename = process.stdout.strip().split('\n')[-1]
if not (final_filename and os.path.exists(final_filename)):
logger.error(f"Download command finished but the output file does not exist: '{final_filename}'")
logger.error(f"Full STDOUT:\n{process.stdout}")
logger.error(f"Full STDERR:\n{process.stderr}")
raise AirflowException(f"Download failed or did not produce a file: {final_filename}")
logger.info(f"SUCCESS: Download complete. Final file at: {final_filename}")
return final_filename
def run_ffmpeg_probe(filename):
"""Probes the given file with ffmpeg to check for corruption."""
logger.info(f"Probing downloaded file: {filename}")
try:
subprocess.run(['ffmpeg', '-v', 'error', '-i', filename, '-f', 'null', '-'], check=True, capture_output=True, text=True)
logger.info("SUCCESS: Probe confirmed valid media file.")
except subprocess.CalledProcessError as e:
logger.error(f"ffmpeg probe check failed for '{filename}'. The file might be corrupt.")
logger.error(f"ffmpeg STDERR: {e.stderr}")
raise AirflowException("ffmpeg probe failed.")
# --- Main Execution Logic ---
final_filename = run_yt_dlp()
try:
run_ffmpeg_probe(final_filename)
return final_filename
except AirflowException as e:
if "probe failed" in str(e) and retry_on_probe_failure:
logger.warning("Probe failed. Attempting one re-download...")
try:
# Rename the failed file to allow for a fresh download attempt
part_file = f"{final_filename}.part"
os.rename(final_filename, part_file)
logger.info(f"Renamed corrupted file to {part_file}")
except OSError as rename_err:
logger.error(f"Could not rename corrupted file: {rename_err}")
final_filename_retry = run_yt_dlp()
run_ffmpeg_probe(final_filename_retry)
return final_filename_retry
else:
# Re-raise the original exception if no retry is attempted
raise
@task
def mark_url_as_success(initial_data: dict, downloaded_file_path: str, token_data: dict, **context):
"""Records the successful result in Redis."""
params = context['params']
url = initial_data['url_to_process']
result_data = {
'status': 'success', 'end_time': time.time(), 'url': url,
'downloaded_file_path': downloaded_file_path, **token_data,
'dag_run_id': context['dag_run'].run_id,
}
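    # Results live in a Redis hash keyed by URL, named "<queue_name>_result"
    # (e.g. "video_queue_result" with the default queue name).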
client = _get_redis_client(params['redis_conn_id'])
client.hset(f"{params['queue_name']}_result", url, json.dumps(result_data))
logger.info(f"Stored success result for URL '{url}'.")
@task(trigger_rule='one_failed')
def report_failure_and_continue(**context):
"""
Handles a failed URL processing attempt by recording a detailed error report to Redis.
This is a common endpoint for various failure paths that should not stop the overall dispatcher loop.
"""
params = context['params']
ti = context['task_instance']
url = params.get('url_to_process', 'unknown')
# Collect error details from XCom
error_details = {}
# Check for error details from get_token tasks
    # Token tasks live inside TaskGroups, so their full ids carry the group prefix.
    first_token_task_id = 'initial_attempt.get_token'
    retry_token_task_id = 'retry_logic.retry_get_token'
first_token_error = ti.xcom_pull(task_ids=first_token_task_id, key='error_details')
retry_token_error = ti.xcom_pull(task_ids=retry_token_task_id, key='error_details')
# Use the most recent error details
if retry_token_error:
error_details = retry_token_error
elif first_token_error:
error_details = first_token_error
    else:
        # No token-level error details were recorded; the failure came from a later
        # stage (e.g. download or probe), which logs its own diagnostics.
        pass
logger.error(f"A failure occurred while processing URL '{url}'. Reporting to Redis.")
result_data = {
'status': 'failed',
'end_time': time.time(),
'url': url,
'dag_run_id': context['dag_run'].run_id,
'error_details': error_details
}
try:
client = _get_redis_client(params['redis_conn_id'])
result_queue = f"{params['queue_name']}_result"
fail_queue = f"{params['queue_name']}_fail"
with client.pipeline() as pipe:
pipe.hset(result_queue, url, json.dumps(result_data))
pipe.hset(fail_queue, url, json.dumps(result_data))
pipe.execute()
logger.info(f"Stored failure result for URL '{url}' in '{result_queue}' and '{fail_queue}'.")
except Exception as e:
logger.error(f"Could not report failure to Redis: {e}", exc_info=True)
@task(trigger_rule='one_failed')
def handle_fatal_error(**context):
"""
Handles fatal, non-retryable errors (e.g., infrastructure issues).
This task reports the failure to Redis before failing the DAG run to ensure
failed URLs are queued for later reprocessing, then stops the processing loop.
"""
params = context['params']
ti = context['task_instance']
url = params.get('url_to_process', 'unknown')
# Collect error details
error_details = {}
    # Token tasks live inside TaskGroups, so their full ids carry the group prefix.
    first_token_task_id = 'initial_attempt.get_token'
    retry_token_task_id = 'retry_logic.retry_get_token'
first_token_error = ti.xcom_pull(task_ids=first_token_task_id, key='error_details')
retry_token_error = ti.xcom_pull(task_ids=retry_token_task_id, key='error_details')
# Use the most recent error details
if retry_token_error:
error_details = retry_token_error
elif first_token_error:
error_details = first_token_error
logger.error(f"A fatal, non-retryable error occurred for URL '{url}'. See previous task logs for details.")
# Report failure to Redis so the URL can be reprocessed later
try:
result_data = {
'status': 'failed',
'end_time': time.time(),
'url': url,
'dag_run_id': context['dag_run'].run_id,
'error': 'fatal_error',
'error_message': 'Fatal non-retryable error occurred',
'error_details': error_details
}
client = _get_redis_client(params['redis_conn_id'])
result_queue = f"{params['queue_name']}_result"
fail_queue = f"{params['queue_name']}_fail"
with client.pipeline() as pipe:
pipe.hset(result_queue, url, json.dumps(result_data))
pipe.hset(fail_queue, url, json.dumps(result_data))
pipe.execute()
logger.info(f"Stored fatal error result for URL '{url}' in '{result_queue}' and '{fail_queue}' for later reprocessing.")
except Exception as e:
logger.error(f"Could not report fatal error to Redis: {e}", exc_info=True)
# Fail the DAG run to prevent automatic continuation of the processing loop
raise AirflowException("Failing DAG due to fatal error. The dispatcher loop will stop.")
@task(trigger_rule='one_success')
def continue_processing_loop(**context):
"""
After a successful run, triggers a new dispatcher to continue the processing loop,
effectively asking for the next URL to be processed.
"""
params = context['params']
dag_run = context['dag_run']
# Create a new unique run_id for the dispatcher.
# Using a timestamp and UUID ensures the ID is unique and does not grow in length over time,
# preventing database errors.
new_dispatcher_run_id = f"retriggered_by_worker_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{str(uuid.uuid4())[:8]}"
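    # e.g. "retriggered_by_worker_20240101_120000_1a2b3c4d" (timestamp and UUID fragment vary)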
# Pass all original parameters from the orchestrator through to the new dispatcher run.
conf_to_pass = {k: v for k, v in params.items() if v is not None}
# The new dispatcher will pull its own URL and determine its own queue, so we don't pass these.
conf_to_pass.pop('url_to_process', None)
conf_to_pass.pop('worker_queue', None)
logger.info(f"Worker finished successfully. Triggering a new dispatcher ('{new_dispatcher_run_id}') to continue the loop.")
trigger_dag(
dag_id='ytdlp_ops_dispatcher',
run_id=new_dispatcher_run_id,
conf=conf_to_pass,
replace_microseconds=False
)
@task.branch(trigger_rule='one_failed')
def handle_retry_failure_branch(task_id_to_check: str, **context):
"""
Inspects a failed retry attempt and decides on the final action.
On retry, most errors are considered fatal for the URL, but not for the system.
"""
ti = context['task_instance']
error_details = ti.xcom_pull(task_ids=task_id_to_check, key='error_details')
if not error_details:
return 'handle_fatal_error'
error_message = error_details.get('error_message', '').strip()
error_code = error_details.get('error_code', '').strip()
# Check if this is an age confirmation error - should not stop the loop
if "Sign in to confirm your age" in error_message or "confirm your age" in error_message.lower():
logger.info(f"Age confirmation error detected on retry from '{task_id_to_check}'. Reporting failure and continuing loop.")
return 'report_failure_and_continue'
if error_code == 'TRANSPORT_ERROR':
logger.error(f"Fatal Thrift connection error on retry from '{task_id_to_check}'.")
return 'handle_fatal_error'
is_bannable = error_code in ["BOT_DETECTED", "BOT_DETECTION_SIGN_IN_REQUIRED"]
if is_bannable:
logger.warning(f"Bannable error '{error_code}' on retry. Banning account and reporting failure.")
return 'ban_and_report_after_retry'
logger.error(f"URL failed on retry with code '{error_code}'. Reporting failure and continuing loop.")
return 'report_failure_and_continue'
@task
def ban_and_report_after_retry(retry_data: dict, reason: str, **context):
"""Bans the account used in a failed retry and prepares for failure reporting."""
    # The account to ban is the one from the retry attempt; ban_account is
    # @task-decorated, so invoke its underlying callable.
    ban_account.function(retry_data, reason, **context)
logger.info(f"Account '{retry_data.get('account_id')}' banned after retry failed. Proceeding to report failure.")
return retry_data
@task.branch(trigger_rule='one_failed')
def handle_download_failure_branch(**context):
"""If download or probe fails, routes to the standard failure reporting."""
logger.warning("Download or probe failed. Reporting failure and continuing loop.")
return 'report_failure_and_continue'
@task(trigger_rule='one_success')
def coalesce_token_data(get_token_result=None, retry_get_token_result=None):
"""
Selects the successful token data from either the first attempt or the retry.
The task that did not run or failed will have a result of None.
"""
if retry_get_token_result:
logger.info("Using token data from retry attempt.")
return retry_get_token_result
if get_token_result:
logger.info("Using token data from initial attempt.")
return get_token_result
# This should not be reached if trigger_rule='one_success' is working correctly.
raise AirflowException("Could not find a successful token result from any attempt.")
@task(trigger_rule='one_failed')
def handle_age_restriction_error(**context):
"""
Handles age restriction errors specifically. These are content restrictions
that cannot be bypassed by using different accounts, so we report the failure
and continue the processing loop rather than stopping it.
"""
params = context['params']
ti = context['task_instance']
url = params.get('url_to_process', 'unknown')
# Collect error details
error_details = {}
    # Token tasks live inside TaskGroups, so their full ids carry the group prefix.
    first_token_task_id = 'initial_attempt.get_token'
    retry_token_task_id = 'retry_logic.retry_get_token'
first_token_error = ti.xcom_pull(task_ids=first_token_task_id, key='error_details')
retry_token_error = ti.xcom_pull(task_ids=retry_token_task_id, key='error_details')
# Use the most recent error details
if retry_token_error:
error_details = retry_token_error
elif first_token_error:
error_details = first_token_error
logger.error(f"Age restriction error for URL '{url}'. This content requires age confirmation and cannot be bypassed.")
# Report failure to Redis so the URL can be marked as failed
try:
result_data = {
'status': 'failed',
'end_time': time.time(),
'url': url,
'dag_run_id': context['dag_run'].run_id,
'error': 'age_restriction',
'error_message': 'Content requires age confirmation',
'error_details': error_details
}
client = _get_redis_client(params['redis_conn_id'])
result_queue = f"{params['queue_name']}_result"
fail_queue = f"{params['queue_name']}_fail"
with client.pipeline() as pipe:
pipe.hset(result_queue, url, json.dumps(result_data))
pipe.hset(fail_queue, url, json.dumps(result_data))
pipe.execute()
logger.info(f"Stored age restriction error for URL '{url}' in '{result_queue}' and '{fail_queue}'.")
except Exception as e:
logger.error(f"Could not report age restriction error to Redis: {e}", exc_info=True)
# This is NOT a fatal error for the processing loop - we just continue with the next URL
# =============================================================================
# DAG Definition with TaskGroups
# =============================================================================
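# Illustrative trigger conf (field names come from the params below; values are examples only):
#   {
#       "url_to_process": "https://www.youtube.com/watch?v=<id>",
#       "worker_queue": "queue-dl",
#       "account_pool": "default_account",
#   }
# The dispatcher additionally encodes the worker queue in the run_id via the "_q_" marker.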
with DAG(
dag_id='ytdlp_ops_worker_per_url',
default_args=DEFAULT_ARGS,
schedule=None,
start_date=days_ago(1),
catchup=False,
tags=['ytdlp', 'worker'],
doc_md=__doc__,
render_template_as_native_obj=True,
params={
'queue_name': Param(DEFAULT_QUEUE_NAME, type="string"),
'redis_conn_id': Param(DEFAULT_REDIS_CONN_ID, type="string"),
'service_ip': Param(DEFAULT_YT_AUTH_SERVICE_IP, type="string"),
'service_port': Param(DEFAULT_YT_AUTH_SERVICE_PORT, type="integer"),
'account_pool': Param('default_account', type="string"),
'account_pool_size': Param(None, type=["integer", "null"]),
'machine_id': Param(None, type=["string", "null"]),
'clients': Param('web', type="string"),
'timeout': Param(DEFAULT_TIMEOUT, type="integer"),
'download_format': Param('ba[ext=m4a]/bestaudio/best', type="string"),
'output_path_template': Param("%(title)s [%(id)s].%(ext)s", type="string"),
'on_bannable_failure': Param('retry_with_new_account', type="string", enum=['stop_loop', 'retry_with_new_account', 'retry_without_ban', 'retry_and_ban_account_only', 'retry_on_connection_error']),
'retry_on_probe_failure': Param(False, type="boolean"),
'auto_create_new_accounts_on_exhaustion': Param(True, type="boolean"),
# Internal params passed from dispatcher
'url_to_process': Param(None, type=["string", "null"]),
'worker_queue': Param(None, type=["string", "null"]),
}
) as dag:
initial_data = get_url_and_assign_account()
# --- Task Instantiation with TaskGroups ---
# Main success/failure handlers (outside groups for clear end points)
fatal_error_task = handle_fatal_error()
report_failure_task = report_failure_and_continue()
continue_loop_task = continue_processing_loop()
age_restriction_task = handle_age_restriction_error()
# --- Task Group 1: Initial Attempt ---
with TaskGroup("initial_attempt", tooltip="Initial token acquisition attempt") as initial_attempt_group:
first_token_attempt = get_token(initial_data)
initial_branch_task = handle_bannable_error_branch.override(trigger_rule='one_failed')(
task_id_to_check=first_token_attempt.operator.task_id
)
# Tasks for the "stop_loop" policy on initial attempt
ban_and_report_immediately_task = ban_and_report_immediately.override(task_id='ban_and_report_immediately')(
initial_data=initial_data,
reason="Banned by Airflow worker (policy is stop_loop)"
)
first_token_attempt >> initial_branch_task
initial_branch_task >> [fatal_error_task, ban_and_report_immediately_task, age_restriction_task]
# --- Task Group 2: Retry Logic ---
with TaskGroup("retry_logic", tooltip="Retry logic with account management") as retry_logic_group:
# Retry path tasks
ban_and_retry_group = ban_and_retry_logic.override(group_id='ban_account_and_prepare_for_retry')(
initial_data=initial_data
)
# This task is for retries after a ban check
after_ban_account_task = assign_new_account_after_ban_check.override(task_id='assign_new_account_after_ban_check')(
initial_data=initial_data
)
# This task is for direct retries (e.g., on connection error)
direct_retry_account_task = assign_new_account_for_direct_retry.override(task_id='assign_new_account_for_direct_retry')(
initial_data=initial_data
)
@task(trigger_rule='one_success')
def coalesce_retry_data(direct_retry_data=None, after_ban_data=None):
"""Coalesces account data from one of the two mutually exclusive retry paths."""
if direct_retry_data:
return direct_retry_data
if after_ban_data:
return after_ban_data
raise AirflowException("Could not find valid account data for retry.")
coalesced_retry_data = coalesce_retry_data(
direct_retry_data=direct_retry_account_task,
after_ban_data=after_ban_account_task
)
retry_token_task = get_token.override(task_id='retry_get_token')(
initial_data=coalesced_retry_data
)
# Retry failure branch and its tasks
retry_branch_task = handle_retry_failure_branch.override(trigger_rule='one_failed')(
task_id_to_check=retry_token_task.operator.task_id
)
ban_after_retry_report_task = ban_and_report_after_retry.override(task_id='ban_and_report_after_retry')(
retry_data=coalesced_retry_data,
reason="Banned by Airflow worker after failed retry"
)
# Internal dependencies within retry group
ban_and_retry_group >> after_ban_account_task
after_ban_account_task >> coalesced_retry_data
direct_retry_account_task >> coalesced_retry_data
coalesced_retry_data >> retry_token_task
retry_token_task >> retry_branch_task
retry_branch_task >> [fatal_error_task, report_failure_task, ban_after_retry_report_task, age_restriction_task]
ban_after_retry_report_task >> report_failure_task
# --- Task Group 3: Download and Processing ---
with TaskGroup("download_processing", tooltip="Download and media processing") as download_processing_group:
# Coalesce, download, and success tasks
token_data = coalesce_token_data(
get_token_result=first_token_attempt,
retry_get_token_result=retry_token_task
)
download_task = download_and_probe(token_data=token_data)
download_branch_task = handle_download_failure_branch.override(trigger_rule='one_failed')()
success_task = mark_url_as_success(
initial_data=initial_data,
downloaded_file_path=download_task,
token_data=token_data
)
# Internal dependencies within download group
first_token_attempt >> token_data
retry_token_task >> token_data
token_data >> download_task
download_task >> download_branch_task
download_branch_task >> report_failure_task
download_task >> success_task
success_task >> continue_loop_task
# --- DAG Dependencies between TaskGroups ---
    # The initial branch can also route into the retry logic group; its direct
    # failure edges were already declared inside the initial_attempt group.
    initial_branch_task >> retry_logic_group
# Retry logic leads to download processing on success or failure reporting on failure
retry_branch_task >> [download_processing_group, report_failure_task]
# Ban and report immediately leads to failure reporting
ban_and_report_immediately_task >> report_failure_task
# Age restriction error leads to failure reporting and continues the loop
age_restriction_task >> continue_loop_task