yt-dlp-dags/dags/ytdlp_ops_worker_per_url.py
2025-08-15 18:00:26 +03:00

552 lines
23 KiB
Python

# -*- coding: utf-8 -*-
# vim:fenc=utf-8
#
# Copyright © 2024 rl <rl@rlmbp>
#
# Distributed under terms of the MIT license.
"""
DAG for processing a single YouTube URL passed via DAG run configuration.
This is the "Worker" part of a Sensor/Worker pattern.
This DAG has been refactored to use the TaskFlow API to implement worker affinity,
ensuring all tasks for a single URL run on the same machine.
"""
from __future__ import annotations
from airflow.decorators import task, task_group
from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.models import Variable
from airflow.models.dag import DAG
from airflow.models.param import Param
from airflow.models.xcom_arg import XComArg
from airflow.operators.dummy import DummyOperator
from airflow.utils.dates import days_ago
from airflow.api.common.trigger_dag import trigger_dag
from datetime import datetime, timedelta
import json
import logging
import os
import random
import re
import socket
import time
import traceback
import uuid
# Import utility functions and Thrift modules
from utils.redis_utils import _get_redis_client
from pangramia.yt.common.ttypes import TokenUpdateMode
from pangramia.yt.exceptions.ttypes import PBServiceException, PBUserException
from pangramia.yt.tokens_ops import YTTokenOpService
from thrift.protocol import TBinaryProtocol
from thrift.transport import TSocket, TTransport
from thrift.transport.TTransport import TTransportException
# Configure logging
logger = logging.getLogger(__name__)

# Default settings from Airflow Variables or hardcoded fallbacks.
# These are evaluated at module import time.
DEFAULT_QUEUE_NAME = 'video_queue'
DEFAULT_REDIS_CONN_ID = 'redis_default'
DEFAULT_TIMEOUT = 600
DEFAULT_YT_AUTH_SERVICE_IP = Variable.get("YT_AUTH_SERVICE_IP", default_var="172.17.0.1")
# Variable.get() hands back the stored value as a string when the Variable
# exists (only the fallback is an int), while this value is used as the default
# of an integer-typed DAG Param. Normalise to int so both cases behave the same.
DEFAULT_YT_AUTH_SERVICE_PORT = int(Variable.get("YT_AUTH_SERVICE_PORT", default_var=9080))

# The queue is set to a fallback here. The actual worker-specific queue is
# assigned just-in-time by the task_instance_mutation_hook in airflow_local_settings.py,
# which reads the 'worker_queue' from the DAG run configuration.
DEFAULT_ARGS = {
    'owner': 'airflow',
    'retries': 0,
    'queue': 'queue-dl',  # Fallback queue. Will be overridden by the policy hook.
}
# --- Helper Functions ---
def _get_thrift_client(host, port, timeout):
    """
    Open a framed, binary-protocol Thrift connection to the token service.

    Args:
        host: Hostname or IP of the Thrift server.
        port: TCP port of the Thrift server.
        timeout: Timeout in seconds; multiplied by 1000 before being handed
            to the socket's ``setTimeout``.

    Returns:
        A ``(client, transport)`` tuple. The transport is already open and
        the caller is responsible for closing it.
    """
    raw_socket = TSocket.TSocket(host, port)
    raw_socket.setTimeout(timeout * 1000)
    framed_transport = TTransport.TFramedTransport(raw_socket)
    service_client = YTTokenOpService.Client(
        TBinaryProtocol.TBinaryProtocol(framed_transport)
    )
    framed_transport.open()
    logger.info(f"Connected to Thrift server at {host}:{port}")
    return service_client, framed_transport
def _extract_video_id(url):
"""Extracts YouTube video ID from URL."""
if not url or not isinstance(url, str):
return None
patterns = [r'v=([a-zA-Z0-9_-]{11})', r'youtu\.be/([a-zA-Z0-9_-]{11})']
for pattern in patterns:
match = re.search(pattern, url)
if match:
return match.group(1)
return None
def _get_account_pool(params: dict) -> list:
    """
    Build the list of account IDs that may be used for this run.

    The initial pool is derived from ``params['account_pool']``:

    * a comma-separated value is treated as an explicit list of accounts;
    * otherwise, if ``params['account_pool_size']`` is set, the value is a
      prefix and accounts ``<prefix>_01`` .. ``<prefix>_NN`` are generated;
    * otherwise the value is a single account ID.

    The pool is then filtered using the ``status`` field of the Redis hash
    ``account_status:<account>``: accounts whose status is ``BANNED`` or
    contains ``RESTING`` are dropped, and a missing status counts as active.
    If Redis cannot be queried the unfiltered pool is used (best effort).

    Args:
        params: The DAG run parameters.

    Returns:
        A non-empty list of account IDs.

    Raises:
        AirflowException: If the initial pool is empty, or if filtering
            succeeded but removed every account and no replacement account
            could be auto-created.
    """
    account_pool_str = params.get('account_pool', 'default_account')
    is_prefix_mode = False
    if ',' in account_pool_str:
        accounts = [acc.strip() for acc in account_pool_str.split(',') if acc.strip()]
    else:
        pool_size_param = params.get('account_pool_size')
        if pool_size_param is not None:
            is_prefix_mode = True
            pool_size = int(pool_size_param)
            accounts = [f"{account_pool_str}_{i:02d}" for i in range(1, pool_size + 1)]
        else:
            accounts = [account_pool_str]
    if not accounts:
        raise AirflowException("Initial account pool is empty.")

    redis_conn_id = params.get('redis_conn_id', DEFAULT_REDIS_CONN_ID)
    # None means "filtering could not be performed"; an empty list means
    # "filtering worked and nothing is usable". The two must not be confused.
    active_accounts = None
    try:
        redis_client = _get_redis_client(redis_conn_id)
        filtered = []
        for account in accounts:
            raw_status = redis_client.hget(f"account_status:{account}", "status")
            if not raw_status:
                status = "ACTIVE"
            elif isinstance(raw_status, bytes):
                status = raw_status.decode('utf-8')
            else:
                # Clients configured to decode responses return str already.
                status = str(raw_status)
            if status != 'BANNED' and 'RESTING' not in status:
                filtered.append(account)
        active_accounts = filtered
    except Exception as e:
        # Deliberate best effort: a Redis problem must not block processing.
        logger.error(f"Could not filter accounts from Redis. Using unfiltered pool. Error: {e}", exc_info=True)

    if active_accounts is not None:
        if not active_accounts:
            # This check lives outside the try block on purpose: raising it
            # inside would be swallowed by the best-effort handler above and
            # the banned/resting accounts would silently be reused.
            auto_create = params.get('auto_create_new_accounts_on_exhaustion', False)
            if auto_create and is_prefix_mode:
                new_account_id = f"{account_pool_str}-auto-{str(uuid.uuid4())[:8]}"
                logger.warning(f"Account pool exhausted. Auto-creating new account: '{new_account_id}'")
                active_accounts.append(new_account_id)
            else:
                raise AirflowException("All accounts in the configured pool are currently exhausted.")
        accounts = active_accounts

    if not accounts:
        raise AirflowException("Account pool is empty after filtering.")
    logger.info(f"Final active account pool with {len(accounts)} accounts.")
    return accounts
# =============================================================================
# TASK DEFINITIONS (TaskFlow API)
# =============================================================================
@task
def get_url_and_assign_account(**context):
    """
    Read the target URL from the DAG run parameters and pick an account for it.

    The account is chosen at random from the pool returned by
    ``_get_account_pool``.

    Returns:
        A dict with the URL (``url_to_process``), the chosen account
        (``account_id``) and the list of accounts used so far
        (``accounts_tried``).

    Raises:
        AirflowException: If ``url_to_process`` is missing or empty.
    """
    run_params = context['params']

    # The dispatcher DAG supplies the URL through the run configuration.
    target_url = run_params.get('url_to_process')
    if not target_url:
        raise AirflowException("'url_to_process' was not found in the DAG run configuration.")
    logger.info(f"Received URL '{target_url}' to process.")

    candidate_accounts = _get_account_pool(run_params)
    chosen_account = random.choice(candidate_accounts)
    logger.info(f"Selected account '{chosen_account}' for this run.")

    assignment = {
        'url_to_process': target_url,
        'account_id': chosen_account,
        'accounts_tried': [chosen_account],
    }
    return assignment
@task
def get_token(initial_data: dict, **context):
    """
    Make a single attempt to obtain token data from the Thrift service.

    On success the returned info.json is written to a timestamped file under
    the ``DOWNLOADS_TEMP`` directory and a summary dict is returned.

    On a Thrift service/user/transport exception the error details are pushed
    to XCom under the key ``error_details`` before the task is failed.

    Args:
        initial_data: Dict providing ``account_id`` and ``url_to_process``.

    Returns:
        Dict with ``info_json_path``, ``socks_proxy``, ``ytdlp_command``,
        ``successful_account_id`` and ``original_url``.

    Raises:
        AirflowException: If the Thrift call fails, or if it succeeds but the
            info.json payload is missing, empty or not valid JSON.
    """
    ti = context['task_instance']
    params = context['params']
    account_id = initial_data['account_id']
    url = initial_data['url_to_process']
    info_json_dir = Variable.get('DOWNLOADS_TEMP', '/opt/airflow/downloadfiles')
    host = params['service_ip']
    port = int(params['service_port'])
    timeout = int(params.get('timeout', DEFAULT_TIMEOUT))
    machine_id = params.get('machine_id') or socket.gethostname()
    logger.info(f"--- Attempting to get token for URL '{url}' with account '{account_id}' ---")
    client, transport = None, None
    try:
        client, transport = _get_thrift_client(host, port, timeout)
        token_data = client.getOrRefreshToken(
            accountId=account_id,
            updateType=TokenUpdateMode.AUTO,
            url=url,
            clients=params.get('clients'),
            machineId=machine_id,
        )
        info_json = getattr(token_data, 'infoJson', None)
        # A malformed payload must produce the same clear error as an empty
        # one, rather than leaking a raw JSON decoding exception.
        try:
            info_json_is_valid = bool(info_json and json.loads(info_json))
        except (TypeError, ValueError):
            info_json_is_valid = False
        if not info_json_is_valid:
            raise AirflowException("Service returned success but info.json was empty or invalid.")
        video_id = _extract_video_id(url)
        os.makedirs(info_json_dir, exist_ok=True)
        # Use a readable timestamp for a unique filename on each attempt.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        info_json_path = os.path.join(info_json_dir, f"info_{video_id or 'unknown'}_{account_id}_{timestamp}.json")
        with open(info_json_path, 'w', encoding='utf-8') as f:
            f.write(info_json)
        # Take the first proxy attribute that actually holds a value; an
        # attribute that exists but is unset must not hide a later one.
        socks_proxy = next(
            (
                value
                for value in (getattr(token_data, attr, None) for attr in ('socks5Proxy', 'socksProxy', 'socks'))
                if value
            ),
            None,
        )
        return {
            'info_json_path': info_json_path,
            'socks_proxy': socks_proxy,
            'ytdlp_command': getattr(token_data, 'ytdlpCommand', None),
            'successful_account_id': account_id,
            'original_url': url,  # Include original URL for fallback
        }
    except (PBServiceException, PBUserException, TTransportException) as e:
        error_context = getattr(e, 'context', None)
        if isinstance(error_context, str):
            try:
                error_context = json.loads(error_context.replace("'", "\""))
            except ValueError:
                # Best effort: keep the raw string if it is not JSON-like.
                pass
        error_details = {
            # 'message' may exist but be unset; fall back to str(e) then.
            'error_message': getattr(e, 'message', None) or str(e),
            'error_code': getattr(e, 'errorCode', 'TRANSPORT_ERROR'),
            'proxy_url': error_context.get('proxy_url') if isinstance(error_context, dict) else None
        }
        logger.error(f"Thrift call failed for account '{account_id}'. Exception: {error_details['error_message']}")
        # The branch task reads this XCom to decide how to react.
        ti.xcom_push(key='error_details', value=error_details)
        raise AirflowException(f"Thrift call failed: {error_details['error_message']}") from e
    finally:
        if transport and transport.isOpen():
            transport.close()
@task.branch
def handle_bannable_error_branch(task_id_to_check: str, **context):
    """
    Inspect a failed task's error details and choose the follow-up path.

    The ``error_details`` XCom pushed by ``task_id_to_check`` is read. If the
    error code is one of the bannable codes, the ``on_bannable_failure``
    parameter selects which task to follow.

    Args:
        task_id_to_check: ID of the task whose ``error_details`` XCom is read.

    Returns:
        The task_id to follow, or ``None`` when there are no error details or
        the error is not bannable (nothing downstream is followed).
    """
    ti = context['task_instance']
    params = context['params']
    error_details = ti.xcom_pull(task_ids=task_id_to_check, key='error_details')
    if not error_details:
        return None  # Let DAG fail for unexpected errors

    # The code may be present but None (or not a string); normalise it first
    # so that .strip() cannot raise.
    error_code = str(error_details.get('error_code') or '').strip()
    policy = params.get('on_bannable_failure', 'retry_with_new_account')
    bannable_codes = (
        "SOCKS5_CONNECTION_FAILED",
        "SOCKET_TIMEOUT",
        "BOT_DETECTED",
        "BOT_DETECTION_SIGN_IN_REQUIRED",
    )
    is_bannable = error_code in bannable_codes
    logger.info(f"Handling failure from '{task_id_to_check}'. Error code: '{error_code}', Policy: '{policy}'")

    if not is_bannable:
        return None  # Not a bannable error, let DAG fail
    if policy in ['retry_with_new_account', 'retry_and_ban_account_only']:
        return 'ban_account_and_prepare_for_retry'
    if policy in ['retry_on_connection_error', 'retry_without_ban']:
        return 'assign_new_account_for_retry'
    # stop_loop: must match the task_id of the ban_and_fail task, which is
    # instantiated in the DAG without a task_id override.
    return 'ban_and_fail'
@task
def ban_account(initial_data: dict, reason: str, **context):
    """
    Ask the Thrift service to ban the account named in ``initial_data``.

    Any error raised while connecting or issuing the ban is logged and
    suppressed, so this task does not fail because of it.

    Args:
        initial_data: Dict providing the ``account_id`` to ban.
        reason: Text passed to the service as the ban reason.
    """
    run_params = context['params']
    target_account = initial_data['account_id']
    thrift_client = None
    thrift_transport = None
    try:
        service_host = run_params['service_ip']
        service_port = int(run_params['service_port'])
        call_timeout = int(run_params.get('timeout', DEFAULT_TIMEOUT))
        thrift_client, thrift_transport = _get_thrift_client(service_host, service_port, call_timeout)
        logger.warning(f"Banning account '{target_account}'. Reason: {reason}")
        thrift_client.banAccount(accountId=target_account, reason=reason)
    except Exception as ban_error:
        # Failing to ban is logged only; it is not allowed to fail the task.
        logger.error(f"Failed to issue ban for account '{target_account}': {ban_error}", exc_info=True)
    finally:
        if thrift_transport and thrift_transport.isOpen():
            thrift_transport.close()
@task
def assign_new_account_for_retry(initial_data: dict, **context):
    """
    Select an account that has not been tried yet for the retry attempt.

    Args:
        initial_data: Dict providing ``url_to_process`` and the
            ``accounts_tried`` list. It is not modified.

    Returns:
        A new dict with the same ``url_to_process``, the newly chosen
        ``account_id`` and ``accounts_tried`` extended by that account.

    Raises:
        AirflowException: If every account in the pool was already tried.
    """
    params = context['params']
    # Work on a copy so the caller's list is never mutated.
    accounts_tried = list(initial_data['accounts_tried'])
    already_tried = set(accounts_tried)
    available_for_retry = [acc for acc in _get_account_pool(params) if acc not in already_tried]
    if not available_for_retry:
        raise AirflowException("No other accounts available in the pool for a retry.")
    new_account_id = random.choice(available_for_retry)
    accounts_tried.append(new_account_id)
    logger.info(f"Selected new account for retry: '{new_account_id}'")
    # Return updated initial_data with new account
    return {
        'url_to_process': initial_data['url_to_process'],
        'account_id': new_account_id,
        'accounts_tried': accounts_tried,
    }
@task
def ban_and_fail(initial_data: dict, reason: str, **context):
    """
    Ban the account and then intentionally fail the task to stop the DAG.

    Args:
        initial_data: Dict providing the ``account_id`` to ban.
        reason: Text passed on as the ban reason and included in the error.

    Raises:
        AirflowException: Always, after the ban attempt.
    """
    # ban_account is wrapped by @task: calling it directly would only build a
    # new operator/XComArg and never run the ban. Invoke the undecorated
    # function through __wrapped__ instead.
    ban_account.__wrapped__(initial_data, reason, **context)
    raise AirflowException(f"Failing task as per policy. Reason: {reason}")
@task
def download_and_probe(token_data: dict, **context):
    """
    Download the media with yt-dlp and verify the result with ffmpeg.

    yt-dlp is run from the saved info.json (plus the original URL when
    available) using an argument list, not a shell string. The downloaded
    file is then decoded by ffmpeg to a null output to detect corruption.
    If the probe fails and ``retry_on_probe_failure`` is set, one more
    download and probe is attempted.

    Args:
        token_data: Dict providing ``info_json_path`` and optionally
            ``socks_proxy`` and ``original_url``.

    Returns:
        Path of the downloaded file as printed by yt-dlp.

    Raises:
        AirflowException: If the info.json file is missing, yt-dlp fails,
            times out or produces no file, or the ffmpeg probe fails.
    """
    import subprocess

    params = context['params']
    info_json_path = token_data.get('info_json_path')
    proxy = token_data.get('socks_proxy')
    original_url = token_data.get('original_url')
    download_dir = Variable.get('DOWNLOADS_TEMP', '/opt/airflow/downloadfiles/video')
    download_format = params.get('download_format', 'ba[ext=m4a]/bestaudio/best')
    output_template = params.get('output_path_template', "%(title)s [%(id)s].%(ext)s")
    full_output_path = os.path.join(download_dir, output_template)
    retry_on_probe_failure = params.get('retry_on_probe_failure', False)
    if not (info_json_path and os.path.exists(info_json_path)):
        raise AirflowException(f"Error: info.json path is missing or file does not exist ({info_json_path}).")

    def run_yt_dlp():
        """Constructs and runs the yt-dlp command, returning the final filename."""
        cmd = [
            'yt-dlp',
            '--load-info-json', info_json_path,
            '-f', download_format,
            '-o', full_output_path,
            '--print', 'filename',
            '--continue',
            '--no-progress',
            '--no-simulate',
            '--no-write-info-json',
            '--ignore-errors',
            '--no-playlist',
        ]
        if proxy:
            cmd.extend(['--proxy', proxy])
        # Crucially, add the original URL to allow yt-dlp to refresh expired download links,
        # which is the most common cause of HTTP 403 errors.
        if original_url:
            cmd.append(original_url)
        logger.info(f"Executing yt-dlp command: {' '.join(cmd)}")
        try:
            process = subprocess.run(cmd, capture_output=True, text=True, timeout=1800)
        except subprocess.TimeoutExpired as e:
            # Report a timeout explicitly instead of surfacing a raw
            # subprocess exception with no log line.
            logger.error(f"yt-dlp did not finish within {e.timeout} seconds and was terminated.")
            raise AirflowException("yt-dlp command timed out.") from e
        if process.returncode != 0:
            logger.error(f"yt-dlp failed with exit code {process.returncode}")
            logger.error(f"STDOUT: {process.stdout}")
            logger.error(f"STDERR: {process.stderr}")
            raise AirflowException("yt-dlp command failed.")
        # Get the last line of stdout, which should be the filename
        final_filename = process.stdout.strip().split('\n')[-1]
        if not (final_filename and os.path.exists(final_filename)):
            logger.error(f"Download command finished but the output file does not exist: '{final_filename}'")
            logger.error(f"Full STDOUT:\n{process.stdout}")
            logger.error(f"Full STDERR:\n{process.stderr}")
            raise AirflowException(f"Download failed or did not produce a file: {final_filename}")
        logger.info(f"SUCCESS: Download complete. Final file at: {final_filename}")
        return final_filename

    def run_ffmpeg_probe(filename):
        """Probes the given file with ffmpeg to check for corruption."""
        logger.info(f"Probing downloaded file: {filename}")
        try:
            subprocess.run(
                ['ffmpeg', '-v', 'error', '-i', filename, '-f', 'null', '-'],
                check=True, capture_output=True, text=True,
            )
            logger.info("SUCCESS: Probe confirmed valid media file.")
        except subprocess.CalledProcessError as e:
            logger.error(f"ffmpeg probe check failed for '{filename}'. The file might be corrupt.")
            logger.error(f"ffmpeg STDERR: {e.stderr}")
            raise AirflowException("ffmpeg probe failed.") from e

    # --- Main Execution Logic ---
    final_filename = run_yt_dlp()
    try:
        run_ffmpeg_probe(final_filename)
    except AirflowException:
        # run_ffmpeg_probe raises AirflowException only for a failed probe,
        # so the exception type alone identifies the retryable case; there is
        # no need to match on the message text.
        if not retry_on_probe_failure:
            raise
    else:
        return final_filename

    logger.warning("Probe failed. Attempting one re-download...")
    try:
        # Move the failed file out of the way before downloading again.
        # NOTE(review): the file is renamed to '<name>.part' and yt-dlp is run
        # with --continue, so it presumably resumes from this file rather
        # than starting from scratch - confirm that this is intended.
        part_file = f"{final_filename}.part"
        os.rename(final_filename, part_file)
        logger.info(f"Renamed corrupted file to {part_file}")
    except OSError as rename_err:
        # Best effort: the re-download is attempted even if the rename failed.
        logger.error(f"Could not rename corrupted file: {rename_err}")
    final_filename_retry = run_yt_dlp()
    run_ffmpeg_probe(final_filename_retry)
    return final_filename_retry
@task
def mark_url_as_success(initial_data: dict, downloaded_file_path: str, token_data: dict, **context):
    """
    Store a success record for the processed URL in Redis.

    The record is JSON-encoded and written to the hash
    ``<queue_name>_result`` under the URL as field name.

    Args:
        initial_data: Dict providing ``url_to_process``.
        downloaded_file_path: Path of the downloaded file.
        token_data: Token data dict; its entries are merged into the record.
    """
    run_params = context['params']
    processed_url = initial_data['url_to_process']

    # Base fields first, then the token data, then the run id last.
    record = {
        'status': 'success',
        'end_time': time.time(),
        'url': processed_url,
        'downloaded_file_path': downloaded_file_path,
    }
    record.update(token_data)
    record['dag_run_id'] = context['dag_run'].run_id

    redis_client = _get_redis_client(run_params['redis_conn_id'])
    result_hash = f"{run_params['queue_name']}_result"
    redis_client.hset(result_hash, processed_url, json.dumps(record))
    logger.info(f"Stored success result for URL '{processed_url}'.")
@task(trigger_rule='one_failed')
def handle_generic_failure(**context):
    """
    Log that an upstream task failed and fail this task as well.

    Runs under the ``one_failed`` trigger rule. This is a simplified
    placeholder: it does not pull XComs or write a report anywhere.

    Raises:
        AirflowException: Always.
    """
    log_message = "A failure occurred in the DAG. See previous task logs for details."
    failure_message = "Failing task to mark DAG run as failed after error."
    logger.error(log_message)
    # Raising here is what makes this task end in the failed state.
    raise AirflowException(failure_message)
@task(trigger_rule='one_success')
def continue_processing_loop(**context):
    """
    Trigger a new ``ytdlp_ops_dispatcher`` run so the next URL gets processed.

    The new run receives every non-None parameter of this run except
    ``url_to_process`` and ``worker_queue``, and its run_id is derived from
    this run's run_id.
    """
    run_params = context['params']
    current_run = context['dag_run']

    # Tie the dispatcher's run_id to this worker run.
    dispatcher_run_id = f"retriggered_by_{current_run.run_id}"

    # The dispatcher pulls its own URL and determines its own queue, so those
    # two entries are left out; unset (None) parameters are dropped too.
    not_forwarded = ('url_to_process', 'worker_queue')
    dispatcher_conf = {
        name: value
        for name, value in run_params.items()
        if value is not None and name not in not_forwarded
    }

    logger.info(f"Worker finished successfully. Triggering a new dispatcher ('{dispatcher_run_id}') to continue the loop.")
    trigger_dag(
        dag_id='ytdlp_ops_dispatcher',
        run_id=dispatcher_run_id,
        conf=dispatcher_conf,
        replace_microseconds=False,
    )
@task(trigger_rule='one_success')
def coalesce_token_data(get_token_result=None, retry_get_token_result=None):
    """
    Return the token data of whichever token attempt produced a result.

    The retry result takes precedence over the initial one. An attempt that
    did not run or failed is expected to arrive here as ``None``.

    Raises:
        AirflowException: If neither argument holds a result.
    """
    # Ordered by priority: retry first, then the initial attempt.
    candidates = (
        (retry_get_token_result, "Using token data from retry attempt."),
        (get_token_result, "Using token data from initial attempt."),
    )
    for result, log_message in candidates:
        if result:
            logger.info(log_message)
            return result
    # This should not be reached if trigger_rule='one_success' is working correctly.
    raise AirflowException("Could not find a successful token result from any attempt.")
# =============================================================================
# DAG Definition
# =============================================================================
@task(trigger_rule='none_failed_min_one_success')
def select_retry_account(after_ban=None, without_ban=None):
    """
    Return the retry assignment from whichever assignment task actually ran.

    Exactly one of the two upstream assignment tasks is expected to run; the
    other one is skipped and arrives here as ``None``.

    Raises:
        AirflowException: If neither argument holds an assignment.
    """
    selected = after_ban or without_ban
    if not selected:
        raise AirflowException("No retry account was assigned by either retry path.")
    return selected


with DAG(
    dag_id='ytdlp_ops_worker_per_url',
    default_args=DEFAULT_ARGS,
    schedule=None,
    start_date=days_ago(1),
    catchup=False,
    tags=['ytdlp', 'worker'],
    doc_md=__doc__,
    render_template_as_native_obj=True,
    params={
        'queue_name': Param(DEFAULT_QUEUE_NAME, type="string"),
        'redis_conn_id': Param(DEFAULT_REDIS_CONN_ID, type="string"),
        'service_ip': Param(DEFAULT_YT_AUTH_SERVICE_IP, type="string"),
        'service_port': Param(DEFAULT_YT_AUTH_SERVICE_PORT, type="integer"),
        'account_pool': Param('default_account', type="string"),
        'account_pool_size': Param(None, type=["integer", "null"]),
        'machine_id': Param(None, type=["string", "null"]),
        'clients': Param('mweb', type="string"),
        'timeout': Param(DEFAULT_TIMEOUT, type="integer"),
        'download_format': Param('ba[ext=m4a]/bestaudio/best', type="string"),
        'output_path_template': Param("%(title)s [%(id)s].%(ext)s", type="string"),
        'on_bannable_failure': Param('retry_with_new_account', type="string", enum=['stop_loop', 'retry_with_new_account', 'retry_without_ban', 'retry_and_ban_account_only', 'retry_on_connection_error']),
        'retry_on_probe_failure': Param(False, type="boolean"),
        'auto_create_new_accounts_on_exhaustion': Param(True, type="boolean"),
        # Internal params passed from dispatcher
        'url_to_process': Param(None, type=["string", "null"]),
        'worker_queue': Param(None, type=["string", "null"]),
    }
) as dag:
    initial_data = get_url_and_assign_account()

    # First attempt at getting token
    first_token_attempt = get_token(initial_data)

    # Branch task to handle errors; it only runs when the first attempt failed.
    branch_task = handle_bannable_error_branch.override(trigger_rule='one_failed')(
        task_id_to_check=first_token_attempt.operator.task_id
    )

    # Retry path tasks.
    #
    # The two retry policies need separate assignment tasks. With a single
    # assignment task placed downstream of both the branch and the ban task,
    # choosing the "no ban" path makes the branch skip the ban task, and the
    # assignment task (default all_success rule) is then skipped as well, so
    # the retry never happens. Simply relaxing its trigger rule is not an
    # option either: initial_data is also an upstream, so the task would then
    # run on the happy path where the branch itself is skipped.
    ban_task = ban_account.override(task_id='ban_account_and_prepare_for_retry')(
        initial_data=initial_data,
        reason="Banned by Airflow worker on first attempt"
    )
    # Runs only after a ban (skipped together with the ban task otherwise).
    assign_after_ban_task = assign_new_account_for_retry.override(
        task_id='assign_new_account_after_ban'
    )(
        initial_data=initial_data
    )
    # Runs only when the branch selects it directly. The task_id must stay
    # 'assign_new_account_for_retry' because the branch returns that ID.
    assign_without_ban_task = assign_new_account_for_retry.override(
        task_id='assign_new_account_for_retry'
    )(
        initial_data=initial_data
    )
    # Join of the two assignment paths; skipped when both of them are skipped.
    retry_account = select_retry_account(
        after_ban=assign_after_ban_task,
        without_ban=assign_without_ban_task
    )
    retry_token_task = get_token.override(task_id='retry_get_token')(
        initial_data=retry_account
    )

    # Stop path (task_id defaults to the function name, 'ban_and_fail').
    ban_and_fail_task = ban_and_fail.override()(
        initial_data=initial_data,
        reason="Banned by Airflow worker (policy is stop_loop)"
    )

    # Set up dependencies for retry logic
    first_token_attempt >> branch_task
    branch_task >> ban_task >> assign_after_ban_task
    branch_task >> assign_without_ban_task  # For policies that don't ban
    branch_task >> ban_and_fail_task

    # Coalesce results from the two possible token tasks
    token_data = coalesce_token_data(
        get_token_result=first_token_attempt,
        retry_get_token_result=retry_token_task
    )
    download_task = download_and_probe(token_data=token_data)
    success_task = mark_url_as_success(
        initial_data=initial_data,
        downloaded_file_path=download_task,
        token_data=token_data
    )
    failure_task = handle_generic_failure()

    # Main pipeline
    token_data >> download_task >> success_task

    # On success, trigger a new dispatcher to continue the loop.
    success_task >> continue_processing_loop()

    # Failure handling
    # NOTE(review): failure_task uses trigger_rule='one_failed' and always
    # raises, so it presumably also fires (and fails the run) when the first
    # token attempt fails but the retry later succeeds - confirm intended.
    [first_token_attempt, retry_token_task, download_task] >> failure_task