# yt-dlp-dags/dags/ytdlp_mgmt_proxy_account.py
# Export header (2025-08-06 18:02:44 +03:00 -- 406 lines, 19 KiB, Python)
# converted to comments so the module remains valid Python.

"""
DAG to manage the state of proxies and accounts used by the ytdlp-ops-server.
"""
from __future__ import annotations
import logging
from datetime import datetime
import socket
from airflow.exceptions import AirflowException
from airflow.models.dag import DAG
from airflow.models.param import Param
from airflow.operators.python import PythonOperator
from airflow.utils.dates import days_ago
from airflow.models.variable import Variable
from airflow.providers.redis.hooks.redis import RedisHook
# Configure logging
logger = logging.getLogger(__name__)
# Import and apply Thrift exceptions patch for Airflow compatibility
try:
from thrift_exceptions_patch import patch_thrift_exceptions
patch_thrift_exceptions()
logger.info("Applied Thrift exceptions patch for Airflow compatibility.")
except ImportError:
logger.warning("Could not import thrift_exceptions_patch. Compatibility may be affected.")
except Exception as e:
logger.error(f"Error applying Thrift exceptions patch: {e}")
# Thrift imports
try:
from thrift.transport import TSocket, TTransport
from thrift.protocol import TBinaryProtocol
from pangramia.yt.tokens_ops import YTTokenOpService
from pangramia.yt.exceptions.ttypes import PBServiceException, PBUserException
except ImportError as e:
logger.critical(f"Could not import Thrift modules: {e}. Ensure ytdlp-ops-auth package is installed.")
# Fail DAG parsing if thrift modules are not available
raise
# Thrift service endpoint defaults; overridable via the Airflow Variables
# YT_AUTH_SERVICE_IP / YT_AUTH_SERVICE_PORT.
DEFAULT_YT_AUTH_SERVICE_IP = Variable.get("YT_AUTH_SERVICE_IP", default_var="16.162.82.212")
# NOTE(review): Variable.get returns a str when the Variable exists but the int
# 9080 otherwise; TSocket appears to accept both -- confirm.
DEFAULT_YT_AUTH_SERVICE_PORT = Variable.get("YT_AUTH_SERVICE_PORT", default_var=9080)
# Airflow connection ID for the Redis server used by the 'remove_all' action.
DEFAULT_REDIS_CONN_ID = "redis_default"
# Helper function to connect to Redis, similar to other DAGs
def _get_redis_client(redis_conn_id: str):
    """Return a redis.Redis client built from an Airflow connection.

    Args:
        redis_conn_id: Airflow connection ID pointing at the Redis server.

    Raises:
        AirflowException: if the hook cannot be created or the connection fails.
    """
    try:
        redis_hook = RedisHook(redis_conn_id=redis_conn_id)
        # get_conn returns a redis.Redis client
        return redis_hook.get_conn()
    except Exception as e:
        logger.error(f"Failed to connect to Redis using connection '{redis_conn_id}': {e}")
        # Chain the original exception so the root cause is preserved in tracebacks.
        raise AirflowException(f"Redis connection failed: {e}") from e
def format_timestamp(ts_str: str) -> str:
    """Render an epoch-seconds string as 'YYYY-MM-DD HH:MM:SS'.

    Empty, zero, or negative timestamps produce "", while values that
    cannot be interpreted at all are returned unchanged.
    """
    if not ts_str:
        return ""
    try:
        epoch_seconds = float(ts_str)
        if epoch_seconds <= 0:
            return ""
        # Local-time rendering via datetime.fromtimestamp.
        return datetime.fromtimestamp(epoch_seconds).strftime('%Y-%m-%d %H:%M:%S')
    except (ValueError, TypeError):
        # Not parseable as a number -- hand the raw value back to the caller.
        return ts_str
def get_thrift_client(host: str, port: int):
    """Open a framed binary-protocol connection to the YTTokenOpService.

    Returns a (client, transport) pair; the caller owns the transport and
    must close it when done.
    """
    raw_socket = TSocket.TSocket(host, port)
    raw_socket.setTimeout(30 * 1000)  # timeout is in milliseconds -> 30s
    framed_transport = TTransport.TFramedTransport(raw_socket)
    service_client = YTTokenOpService.Client(
        TBinaryProtocol.TBinaryProtocol(framed_transport)
    )
    framed_transport.open()
    logger.info(f"Connected to Thrift server at {host}:{port}")
    return service_client, framed_transport
def _list_proxy_statuses(client, server_identity):
    """Fetch proxy statuses from the Thrift service and print them as a table.

    Args:
        client: connected YTTokenOpService client.
        server_identity: server to filter by, or None/empty for all servers.
    """
    logger.info(f"Listing proxy statuses for server: {server_identity or 'ALL'}")
    statuses = client.getProxyStatus(server_identity)
    if not statuses:
        logger.info("No proxy statuses found.")
        print("No proxy statuses found.")
        return
    from tabulate import tabulate
    # Forward-compatible: only include the extra columns when the server
    # actually returns these fields on its status structs.
    has_extended_info = hasattr(statuses[0], 'recentAccounts') or hasattr(statuses[0], 'recentMachines')
    # Fix: the original also built a separate `headers` list here, but it was
    # never used -- tabulate is invoked with headers='keys' below.
    status_list = []
    for s in statuses:
        status_item = {
            "Server": s.serverIdentity,
            "Proxy URL": s.proxyUrl,
            "Status": s.status,
            "Success": s.successCount,
            "Failures": s.failureCount,
            "Last Success": format_timestamp(s.lastSuccessTimestamp),
            "Last Failure": format_timestamp(s.lastFailureTimestamp),
        }
        if has_extended_info:
            recent_accounts = getattr(s, 'recentAccounts', [])
            recent_machines = getattr(s, 'recentMachines', [])
            status_item["Recent Accounts"] = "\n".join(recent_accounts) if recent_accounts else "N/A"
            status_item["Recent Machines"] = "\n".join(recent_machines) if recent_machines else "N/A"
        status_list.append(status_item)
    print("\n--- Proxy Statuses ---")
    # The f-string with a newline ensures the table starts on a new line in the logs.
    print(f"\n{tabulate(status_list, headers='keys', tablefmt='grid')}")
    print("----------------------\n")
    if not has_extended_info:
        logger.warning("Server does not seem to support 'recentAccounts' or 'recentMachines' fields yet.")
        print("NOTE: To see Recent Accounts/Machines, the server's `getProxyStatus` method must be updated to return these fields.")
def _list_account_statuses(client, account_id):
    """Fetch account statuses (one account, or all) and print them as a table."""
    logger.info(f"Listing account statuses for account: {account_id or 'ALL'}")
    try:
        # The thrift method takes accountId (specific) or accountPrefix.
        # Leaving both as None returns every account.
        statuses = client.getAccountStatus(accountId=account_id, accountPrefix=None)
        if not statuses:
            logger.info("No account statuses found.")
            print("\n--- Account Statuses ---\nNo account statuses found.\n------------------------\n")
            return
        from tabulate import tabulate

        def latest_activity(record):
            # Most recent of the success/failure timestamps; blanks count as 0.
            success_ts = float(record.lastSuccessTimestamp) if record.lastSuccessTimestamp else 0
            failure_ts = float(record.lastFailureTimestamp) if record.lastFailureTimestamp else 0
            return max(success_ts, failure_ts)

        # Show the most recently active accounts first.
        ordered = sorted(statuses, key=latest_activity, reverse=True)
        status_list = [
            {
                "Account ID": record.accountId,
                "Status": record.status,
                "Success": record.successCount,
                "Failures": record.failureCount,
                "Last Success": format_timestamp(record.lastSuccessTimestamp),
                "Last Failure": format_timestamp(record.lastFailureTimestamp),
                "Last Proxy": record.lastUsedProxy or "N/A",
                "Last Machine": record.lastUsedMachine or "N/A",
            }
            for record in ordered
        ]
        print("\n--- Account Statuses ---")
        # The f-string with a newline ensures the table starts on a new line in the logs.
        print(f"\n{tabulate(status_list, headers='keys', tablefmt='grid')}")
        print("------------------------\n")
    except (PBServiceException, PBUserException) as e:
        logger.error(f"Failed to get account statuses: {e.message}", exc_info=True)
        print(f"\nERROR: Could not retrieve account statuses. Server returned: {e.message}\n")
    except Exception as e:
        logger.error(f"An unexpected error occurred while getting account statuses: {e}", exc_info=True)
        print(f"\nERROR: An unexpected error occurred: {e}\n")
def _remove_all_account_keys(params):
    """Delete all matching 'account_status:*' keys directly from Redis (destructive)."""
    confirm = params.get("confirm_remove_all_accounts", False)
    if not confirm:
        message = "FATAL: 'remove_all' action requires 'confirm_remove_all_accounts' to be set to True. No accounts were removed."
        logger.error(message)
        print(f"\nERROR: {message}\n")
        raise ValueError(message)
    redis_conn_id = params["redis_conn_id"]
    account_prefix = params.get("account_id")  # Repurpose account_id param as an optional prefix
    redis_client = _get_redis_client(redis_conn_id)
    pattern = f"account_status:{account_prefix}*" if account_prefix else "account_status:*"
    logger.warning(f"Searching for account status keys in Redis with pattern: '{pattern}'")
    # scan_iter returns bytes, so we don't need to decode for deletion
    keys_to_delete = list(redis_client.scan_iter(pattern))
    if not keys_to_delete:
        logger.info(f"No account keys found matching pattern '{pattern}'. Nothing to do.")
        print(f"\nNo accounts found matching pattern '{pattern}'.\n")
        return
    logger.warning(f"Found {len(keys_to_delete)} account keys to delete. This is a destructive operation!")
    print(f"\nWARNING: Found {len(keys_to_delete)} accounts to remove from Redis.")
    # Decode a small sample for the log output.
    for key in keys_to_delete[:10]:
        print(f" - {key.decode('utf-8')}")
    if len(keys_to_delete) > 10:
        print(f" ... and {len(keys_to_delete) - 10} more.")
    deleted_count = redis_client.delete(*keys_to_delete)
    logger.info(f"Successfully deleted {deleted_count} account keys from Redis.")
    print(f"\nSuccessfully removed {deleted_count} accounts from Redis.\n")


def _proxy_action(client, action, server_identity, proxy_url):
    """Execute one proxy management action ('list'/'ban'/'unban'/'reset_all') over Thrift."""
    if action == "list":
        _list_proxy_statuses(client, server_identity)
    elif action == "ban":
        if not proxy_url:
            raise ValueError("A 'proxy_url' is required.")
        logger.info(f"Banning proxy '{proxy_url}' for server '{server_identity}'...")
        client.banProxy(proxy_url, server_identity)
        print(f"Successfully sent request to ban proxy '{proxy_url}'.")
    elif action == "unban":
        if not proxy_url:
            raise ValueError("A 'proxy_url' is required.")
        logger.info(f"Unbanning proxy '{proxy_url}' for server '{server_identity}'...")
        client.unbanProxy(proxy_url, server_identity)
        print(f"Successfully sent request to unban proxy '{proxy_url}'.")
    elif action == "reset_all":
        logger.info(f"Resetting all proxy statuses for server '{server_identity}'...")
        client.resetAllProxyStatuses(server_identity)
        print(f"Successfully sent request to reset all proxy statuses for '{server_identity}'.")
    else:
        raise ValueError(f"Invalid action '{action}' for entity 'proxy'.")


def _account_action(client, action, account_id):
    """Execute one account management action ('list'/'ban'/'unban'/'reset_all') over Thrift."""
    if action == "list":
        _list_account_statuses(client, account_id)
    elif action == "ban":
        if not account_id:
            raise ValueError("An 'account_id' is required.")
        reason = f"Manual ban from Airflow mgmt DAG by {socket.gethostname()}"
        logger.info(f"Banning account '{account_id}'...")
        client.banAccount(accountId=account_id, reason=reason)
        print(f"Successfully sent request to ban account '{account_id}'.")
    elif action == "unban":
        if not account_id:
            raise ValueError("An 'account_id' is required.")
        reason = f"Manual un-ban from Airflow mgmt DAG by {socket.gethostname()}"
        logger.info(f"Unbanning account '{account_id}'...")
        client.unbanAccount(accountId=account_id, reason=reason)
        print(f"Successfully sent request to unban account '{account_id}'.")
    elif action == "reset_all":
        account_prefix = account_id  # Repurpose account_id param as an optional prefix
        logger.info(f"Resetting all account statuses to ACTIVE (prefix: '{account_prefix or 'ALL'}')...")
        all_statuses = client.getAccountStatus(accountId=None, accountPrefix=account_prefix)
        if not all_statuses:
            print(f"No accounts found with prefix '{account_prefix or 'ALL'}' to reset.")
            return
        accounts_to_reset = [s.accountId for s in all_statuses]
        logger.info(f"Found {len(accounts_to_reset)} accounts to reset.")
        print(f"Found {len(accounts_to_reset)} accounts. Sending unban request for each...")
        reset_count = 0
        fail_count = 0
        # There is no bulk-reset endpoint visible here, so unban each account individually.
        for acc_id in accounts_to_reset:
            try:
                reason = f"Manual reset from Airflow mgmt DAG by {socket.gethostname()}"
                client.unbanAccount(accountId=acc_id, reason=reason)
                logger.info(f" - Sent reset (unban) for '{acc_id}'.")
                reset_count += 1
            except Exception as e:
                logger.error(f" - Failed to reset account '{acc_id}': {e}")
                fail_count += 1
        print(f"\nSuccessfully sent reset requests for {reset_count} accounts.")
        if fail_count > 0:
            print(f"Failed to send reset requests for {fail_count} accounts. See logs for details.")
        # Optionally, list statuses again to confirm
        print("\n--- Listing statuses after reset ---")
        _list_account_statuses(client, account_prefix)
    else:
        raise ValueError(f"Invalid action '{action}' for entity 'account'.")


def manage_system_callable(**context):
    """Main callable to interact with the system management endpoints.

    Reads entity/action/connection details from DAG params, validates them,
    then dispatches: 'remove_all' goes straight to Redis, everything else
    goes through a Thrift connection that is always closed afterwards.
    """
    params = context["params"]
    entity = params["entity"]
    action = params["action"]
    host = params["host"]
    port = params["port"]
    server_identity = params.get("server_identity")
    proxy_url = params.get("proxy_url")
    account_id = params.get("account_id")
    # Validate required identifiers up front, before opening any connection.
    if action in ["ban", "unban", "reset_all"] and entity == "proxy" and not server_identity:
        raise ValueError(f"A 'server_identity' is required for proxy action '{action}'.")
    if action in ["ban", "unban"] and entity == "account" and not account_id:
        raise ValueError(f"An 'account_id' is required for account action '{action}'.")
    # Handle direct Redis action separately to avoid creating an unnecessary Thrift connection.
    if entity == "account" and action == "remove_all":
        _remove_all_account_keys(params)
        return
    client, transport = None, None
    try:
        client, transport = get_thrift_client(host, port)
        if entity == "proxy":
            _proxy_action(client, action, server_identity, proxy_url)
        elif entity == "account":
            _account_action(client, action, account_id)
        elif entity == "all":
            if action == "list":
                print("\nListing all entities...")
                _list_proxy_statuses(client, server_identity)
                _list_account_statuses(client, account_id)
            else:
                raise ValueError(f"Action '{action}' is not supported for entity 'all'. Only 'list' is supported.")
    except (PBServiceException, PBUserException) as e:
        logger.error(f"Thrift error performing action '{action}': {e.message}", exc_info=True)
        raise
    except NotImplementedError as e:
        logger.error(f"Feature not implemented: {e}", exc_info=True)
        raise
    except Exception as e:
        logger.error(f"Error performing action '{action}': {e}", exc_info=True)
        raise
    finally:
        # Always release the connection, even when the action failed.
        if transport and transport.isOpen():
            transport.close()
            logger.info("Thrift connection closed.")
# Manually-triggered utility DAG (schedule=None): all behavior is driven by the
# run-time params below and executed by a single PythonOperator.
with DAG(
    dag_id="ytdlp_mgmt_proxy_account",
    start_date=days_ago(1),
    schedule=None,
    catchup=False,
    tags=["ytdlp", "utility", "proxy", "account", "management"],
    doc_md="""
### YT-DLP Proxy and Account Manager DAG
This DAG provides tools to manage the state of **proxies and accounts** used by the `ytdlp-ops-server`.
**Parameters:**
- `host`, `port`: Connection details for the `ytdlp-ops-server` Thrift service.
- `entity`: The type of resource to manage (`proxy`, `account`, or `all`).
- `action`: The operation to perform.
- `list`: View statuses. For `entity: all`, lists both proxies and accounts.
- `ban`: Ban a specific proxy or account.
- `unban`: Un-ban a specific proxy or account.
- `reset_all`: Reset all proxies for a server (or all accounts) to `ACTIVE`.
- `remove_all`: **Deletes all account status keys** from Redis for a given prefix. This is a destructive action.
- `server_identity`: Required for most proxy actions.
- `proxy_url`: Required for banning/unbanning a specific proxy.
- `account_id`: Required for managing a specific account. For `action: reset_all` or `remove_all` on `entity: account`, this can be used as an optional prefix to filter which accounts to act on.
- `confirm_remove_all_accounts`: **Required for `remove_all` action.** Must be set to `True` to confirm deletion.
""",
    params={
        "host": Param(DEFAULT_YT_AUTH_SERVICE_IP, type="string", description="The hostname of the ytdlp-ops-server service. Default is from Airflow variable YT_AUTH_SERVICE_IP or hardcoded."),
        "port": Param(DEFAULT_YT_AUTH_SERVICE_PORT, type="integer", description="The port of the ytdlp-ops-server service (Envoy load balancer). Default is from Airflow variable YT_AUTH_SERVICE_PORT or hardcoded."),
        "entity": Param(
            "all",
            type="string",
            enum=["proxy", "account", "all"],
            description="The type of entity to manage. Use 'all' with action 'list' to see both.",
        ),
        "action": Param(
            "list",
            type="string",
            enum=["list", "ban", "unban", "reset_all", "remove_all"],
            description="The management action to perform. `reset_all` for proxies/accounts. `remove_all` for accounts only.",
        ),
        "server_identity": Param(
            "ytdlp-ops-airflow-service",
            type=["null", "string"],
            description="The identity of the server instance (for proxy management).",
        ),
        "proxy_url": Param(
            None,
            type=["null", "string"],
            description="The proxy URL to act upon (e.g., 'socks5://host:port').",
        ),
        "account_id": Param(
            None,
            type=["null", "string"],
            description="The account ID to act upon. For `reset_all` or `remove_all` on accounts, this can be an optional prefix.",
        ),
        "confirm_remove_all_accounts": Param(
            False,
            type="boolean",
            title="[remove_all] Confirm Deletion",
            description="Must be set to True to execute the 'remove_all' action for accounts. This is a destructive operation.",
        ),
        "redis_conn_id": Param(
            DEFAULT_REDIS_CONN_ID,
            type="string",
            title="Redis Connection ID",
            description="The Airflow connection ID for the Redis server (used for 'remove_all').",
        ),
    },
) as dag:
    # Single task; manage_system_callable receives the params via **context.
    system_management_task = PythonOperator(
        task_id="system_management_task",
        python_callable=manage_system_callable,
    )