Initial import of cf-temp-email deploy CLI
This commit is contained in:
6
src/cf_temp_email_deploy/__init__.py
Normal file
6
src/cf_temp_email_deploy/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""Cloudflare Temp Email automated deployment package."""
|
||||
|
||||
__all__ = ["__version__"]
|
||||
|
||||
__version__ = "0.1.0"
|
||||
|
||||
6
src/cf_temp_email_deploy/__main__.py
Normal file
6
src/cf_temp_email_deploy/__main__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from cf_temp_email_deploy.cli import main
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
|
||||
221
src/cf_temp_email_deploy/app_admin.py
Normal file
221
src/cf_temp_email_deploy/app_admin.py
Normal file
@@ -0,0 +1,221 @@
|
||||
"""Application admin API helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from cf_temp_email_deploy.errors import ApplicationAPIError
|
||||
|
||||
LINUXDO_OAUTH2_ICON = (
|
||||
'<svg viewBox="0 0 16 16" xmlns="http://www.w3.org/2000/svg" width="1em" height="1em">'
|
||||
'<g><path d="m7.44,0s.09,0,.13,0c.09,0,.19,0,.28,0,.14,0,.29,0,.43,0,.09,0,.18,0,.27,0q.12,0,.25,0t.26.08c.15.03.29.06.44.08,1.97.38,3.78,1.47,4.95,3.11.04.06.09.12.13.18.67.96,1.15,2.11,1.3,3.28q0,.19.09.26c0,.15,0,.29,0,.44,0,.04,0,.09,0,.13,0,.09,0,.19,0,.28,0,.14,0,.29,0,.43,0,.09,0,.18,0,.27,0,.08,0,.17,0,.25q0,.19-.08.26c-.03.15-.06.29-.08.44-.38,1.97-1.47,3.78-3.11,4.95-.06.04-.12.09-.18.13-.96.67-2.11,1.15-3.28,1.3q-.19,0-.26.09c-.15,0-.29,0-.44,0-.04,0-.09,0-.13,0-.09,0-.19,0-.28,0-.14,0-.29,0-.43,0-.09,0-.18,0-.27,0-.08,0-.17,0-.25,0q-.19,0-.26-.08c-.15-.03-.29-.06-.44-.08-1.97-.38-3.78-1.47-4.95-3.11q-.07-.09-.13-.18c-.67-.96-1.15-2.11-1.3-3.28q0-.19-.09-.26c0-.15,0-.29,0-.44,0-.04,0-.09,0-.13,0-.09,0-.19,0-.28,0-.14,0-.29,0-.43,0-.09,0-.18,0-.27,0-.08,0-.17,0-.25q0-.19.08-.26c.03-.15.06-.29.08-.44.38-1.97,1.47-3.78,3.11-4.95.06-.04.12-.09.18-.13C4.42.73,5.57.26,6.74.1,7,.07,7.15,0,7.44,0Z" fill="#EFEFEF"/>'
|
||||
'<path d="m1.27,11.33h13.45c-.94,1.89-2.51,3.21-4.51,3.88-1.99.59-3.96.37-5.8-.57-1.25-.7-2.67-1.9-3.14-3.3Z" fill="#FEB005"/>'
|
||||
'<path d="m12.54,1.99c.87.7,1.82,1.59,2.18,2.68H1.27c.87-1.74,2.33-3.13,4.2-3.78,2.44-.79,5-.47,7.07,1.1Z" fill="#1D1D1F"/></g></svg>'
|
||||
)
|
||||
LINUXDO_OAUTH2_NAME = "LINUX DO"
|
||||
LINUXDO_AUTHORIZATION_URL = "https://connect.linux.do/oauth2/authorize"
|
||||
LINUXDO_ACCESS_TOKEN_URL = "https://connect.linux.do/oauth2/token"
|
||||
LINUXDO_USER_INFO_URL = "https://connect.linux.do/api/user"
|
||||
LINUXDO_SCOPE = "user:email"
|
||||
USER_SETTINGS_DEFAULTS = {
|
||||
"enableMailVerify": False,
|
||||
"verifyMailSender": "",
|
||||
"enableMailAllowList": False,
|
||||
"mailAllowList": [],
|
||||
"maxAddressCount": 5,
|
||||
"enableEmailCheckRegex": False,
|
||||
"emailCheckRegex": "",
|
||||
}
|
||||
|
||||
|
||||
def linuxdo_oauth_callback_url(pages_domain: str) -> str:
|
||||
return f"https://{pages_domain}/user/oauth2/callback"
|
||||
|
||||
|
||||
def merge_user_settings(current: object, allow_user_register: bool) -> dict[str, Any]:
|
||||
merged = dict(current) if isinstance(current, dict) else {}
|
||||
for key, value in USER_SETTINGS_DEFAULTS.items():
|
||||
merged.setdefault(key, value)
|
||||
merged["enable"] = allow_user_register
|
||||
return merged
|
||||
|
||||
|
||||
def build_linuxdo_oauth2_setting(
|
||||
*,
|
||||
pages_domain: str,
|
||||
client_id: str,
|
||||
client_secret: str,
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"name": LINUXDO_OAUTH2_NAME,
|
||||
"icon": LINUXDO_OAUTH2_ICON,
|
||||
"clientID": client_id,
|
||||
"clientSecret": client_secret,
|
||||
"authorizationURL": LINUXDO_AUTHORIZATION_URL,
|
||||
"accessTokenURL": LINUXDO_ACCESS_TOKEN_URL,
|
||||
"accessTokenFormat": "urlencoded",
|
||||
"userInfoURL": LINUXDO_USER_INFO_URL,
|
||||
"redirectURL": linuxdo_oauth_callback_url(pages_domain),
|
||||
"logoutURL": "",
|
||||
"userEmailKey": "id",
|
||||
"enableEmailFormat": True,
|
||||
"userEmailFormat": "^(.+)$",
|
||||
"userEmailReplace": "linux_do_$1@oauth.linux.do",
|
||||
"scope": LINUXDO_SCOPE,
|
||||
"enableMailAllowList": False,
|
||||
"mailAllowList": [],
|
||||
}
|
||||
|
||||
|
||||
def is_linuxdo_oauth2_setting(setting: object) -> bool:
|
||||
if not isinstance(setting, dict):
|
||||
return False
|
||||
name = str(setting.get("name", "")).strip().lower()
|
||||
authorization_url = str(setting.get("authorizationURL", "")).strip()
|
||||
access_token_url = str(setting.get("accessTokenURL", "")).strip()
|
||||
user_info_url = str(setting.get("userInfoURL", "")).strip()
|
||||
return name in {"linux do", "linuxdo"} or (
|
||||
authorization_url == LINUXDO_AUTHORIZATION_URL
|
||||
and access_token_url == LINUXDO_ACCESS_TOKEN_URL
|
||||
and user_info_url == LINUXDO_USER_INFO_URL
|
||||
)
|
||||
|
||||
|
||||
def merge_linuxdo_oauth2_settings(
|
||||
current: object,
|
||||
*,
|
||||
pages_domain: str,
|
||||
client_id: str,
|
||||
client_secret: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
settings = current if isinstance(current, list) else []
|
||||
desired = build_linuxdo_oauth2_setting(
|
||||
pages_domain=pages_domain,
|
||||
client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
)
|
||||
|
||||
updated: list[dict[str, Any]] = []
|
||||
replaced = False
|
||||
for item in settings:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
if is_linuxdo_oauth2_setting(item):
|
||||
merged = dict(item)
|
||||
merged.update(desired)
|
||||
for key in ("enableMailAllowList", "mailAllowList", "logoutURL"):
|
||||
if key in item:
|
||||
merged[key] = item[key]
|
||||
updated.append(merged)
|
||||
replaced = True
|
||||
continue
|
||||
updated.append(item)
|
||||
|
||||
if not replaced:
|
||||
updated.append(desired)
|
||||
return updated
|
||||
|
||||
|
||||
class ApplicationAdminClient:
|
||||
"""Minimal client for post-deployment admin API configuration."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str,
|
||||
admin_password: str,
|
||||
*,
|
||||
timeout: float = 30.0,
|
||||
transport: httpx.BaseTransport | None = None,
|
||||
) -> None:
|
||||
normalized_base_url = base_url.rstrip("/")
|
||||
self.base_url = normalized_base_url
|
||||
self.client = httpx.Client(
|
||||
base_url=normalized_base_url,
|
||||
timeout=timeout,
|
||||
follow_redirects=True,
|
||||
trust_env=False,
|
||||
transport=transport,
|
||||
headers={"x-admin-auth": admin_password},
|
||||
)
|
||||
|
||||
def close(self) -> None:
|
||||
self.client.close()
|
||||
|
||||
def __enter__(self) -> "ApplicationAdminClient":
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type: object, exc: object, tb: object) -> None:
|
||||
self.close()
|
||||
|
||||
def wait_until_ready(
|
||||
self,
|
||||
*,
|
||||
timeout_seconds: float = 180.0,
|
||||
poll_interval_seconds: float = 5.0,
|
||||
) -> dict[str, Any]:
|
||||
deadline = time.monotonic() + timeout_seconds
|
||||
last_error: Exception | None = None
|
||||
while True:
|
||||
try:
|
||||
return self.get_user_settings()
|
||||
except (ApplicationAPIError, httpx.HTTPError) as exc:
|
||||
last_error = exc
|
||||
if time.monotonic() >= deadline:
|
||||
raise ApplicationAPIError(
|
||||
f"管理员接口在限定时间内未就绪: {self.base_url}/admin/user_settings"
|
||||
) from last_error
|
||||
time.sleep(poll_interval_seconds)
|
||||
|
||||
def get_user_settings(self) -> dict[str, Any]:
|
||||
payload = self._request_json("GET", "/admin/user_settings")
|
||||
return payload if isinstance(payload, dict) else {}
|
||||
|
||||
def sync_user_settings(self, *, allow_user_register: bool) -> dict[str, Any]:
|
||||
merged = merge_user_settings(self.get_user_settings(), allow_user_register)
|
||||
payload = self._request_json("POST", "/admin/user_settings", json=merged)
|
||||
return payload if isinstance(payload, dict) else merged
|
||||
|
||||
def get_user_oauth2_settings(self) -> list[dict[str, Any]]:
|
||||
payload = self._request_json("GET", "/admin/user_oauth2_settings")
|
||||
if not isinstance(payload, list):
|
||||
return []
|
||||
return [item for item in payload if isinstance(item, dict)]
|
||||
|
||||
def sync_linuxdo_oauth2(
|
||||
self,
|
||||
*,
|
||||
pages_domain: str,
|
||||
client_id: str,
|
||||
client_secret: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
merged = merge_linuxdo_oauth2_settings(
|
||||
self.get_user_oauth2_settings(),
|
||||
pages_domain=pages_domain,
|
||||
client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
)
|
||||
payload = self._request_json("POST", "/admin/user_oauth2_settings", json=merged)
|
||||
if not isinstance(payload, list):
|
||||
return merged
|
||||
return [item for item in payload if isinstance(item, dict)]
|
||||
|
||||
def _request_json(self, method: str, path: str, *, json: Any | None = None) -> Any:
|
||||
try:
|
||||
response = self.client.request(method, path, json=json)
|
||||
except httpx.HTTPError as exc: # pragma: no cover
|
||||
raise ApplicationAPIError(f"请求管理员接口失败: {self.base_url}{path}") from exc
|
||||
|
||||
if response.status_code >= 400:
|
||||
body = response.text.strip()
|
||||
raise ApplicationAPIError(
|
||||
f"管理员接口返回错误: {method.upper()} {path} -> {response.status_code} {body}".strip(),
|
||||
status_code=response.status_code,
|
||||
)
|
||||
|
||||
try:
|
||||
return response.json()
|
||||
except ValueError as exc:
|
||||
raise ApplicationAPIError(f"管理员接口返回了非法 JSON: {method.upper()} {path}") from exc
|
||||
264
src/cf_temp_email_deploy/cli.py
Normal file
264
src/cf_temp_email_deploy/cli.py
Normal file
@@ -0,0 +1,264 @@
|
||||
"""Command line entrypoint."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from cf_temp_email_deploy.cloudflare import CloudflareClient
|
||||
from cf_temp_email_deploy.config import (
|
||||
apply_overrides,
|
||||
config_from_document,
|
||||
default_config_document,
|
||||
load_config,
|
||||
load_state,
|
||||
load_toml_document,
|
||||
parse_toml_value,
|
||||
save_toml_document,
|
||||
save_state,
|
||||
)
|
||||
from cf_temp_email_deploy.discovery import discover_config
|
||||
from cf_temp_email_deploy.deployment import run_deployment
|
||||
from cf_temp_email_deploy.environment import check_required_tools
|
||||
from cf_temp_email_deploy.errors import ConfigError, DeployError
|
||||
from cf_temp_email_deploy.logging_utils import configure_logging, get_logger, log_stage
|
||||
from cf_temp_email_deploy.subprocess_runner import CommandRunner
|
||||
|
||||
COMMAND_NAMES = {"init-config", "check", "deploy", "resume", "discover-config"}
|
||||
|
||||
|
||||
def build_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(prog="cf-temp-email")
|
||||
subparsers = parser.add_subparsers(dest="command", required=True)
|
||||
|
||||
init_parser = subparsers.add_parser("init-config", help="生成示例配置文件。")
|
||||
init_parser.add_argument("--config", type=Path, default=Path("config.toml"))
|
||||
init_parser.add_argument("--force", action="store_true")
|
||||
init_parser.set_defaults(handler=handle_init_config)
|
||||
|
||||
for command_name, help_text, handler in (
|
||||
("check", "执行环境与配置检查。", handle_check),
|
||||
("deploy", "执行部署流程。", handle_deploy),
|
||||
("resume", "从状态文件恢复部署流程。", handle_resume),
|
||||
("discover-config", "从 Cloudflare 现有资源反推并写入配置。", handle_discover_config),
|
||||
):
|
||||
command_parser = subparsers.add_parser(command_name, help=help_text)
|
||||
add_shared_arguments(command_parser)
|
||||
command_parser.set_defaults(handler=handler)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def add_shared_arguments(parser: argparse.ArgumentParser) -> None:
|
||||
parser.add_argument("--config", type=Path, default=Path("config.toml"))
|
||||
parser.add_argument("--profile")
|
||||
parser.add_argument("--verbose", action="store_true")
|
||||
parser.add_argument("--api-token")
|
||||
parser.add_argument("--account-id")
|
||||
parser.add_argument("--account-name")
|
||||
parser.add_argument("--zone-name")
|
||||
parser.add_argument("--repo-ref")
|
||||
parser.add_argument("--pages-domain")
|
||||
parser.add_argument("--worker-name")
|
||||
parser.add_argument("--d1-name")
|
||||
parser.add_argument("--destination-address")
|
||||
parser.add_argument("--set", dest="set_values", action="append", default=[])
|
||||
|
||||
|
||||
def _qualify_override_key(dotted_key: str, profile: str | None = None) -> str:
|
||||
if not profile:
|
||||
return dotted_key
|
||||
return f"profiles.{profile}.{dotted_key}"
|
||||
|
||||
|
||||
def collect_cli_overrides(args: argparse.Namespace) -> dict[str, Any]:
|
||||
overrides: dict[str, Any] = {}
|
||||
profile = getattr(args, "profile", None)
|
||||
explicit_mapping = {
|
||||
"api_token": "cloudflare.api_token",
|
||||
"account_id": "cloudflare.account_id",
|
||||
"account_name": "cloudflare.account_name",
|
||||
"zone_name": "cloudflare.zone_name",
|
||||
"repo_ref": "source.repo_ref",
|
||||
"pages_domain": "pages.custom_domain",
|
||||
"worker_name": "worker.script_name",
|
||||
"d1_name": "d1.database_name",
|
||||
"destination_address": "mail.verified_destination_address",
|
||||
}
|
||||
for argument_name, dotted_key in explicit_mapping.items():
|
||||
value = getattr(args, argument_name, None)
|
||||
if value not in (None, ""):
|
||||
overrides[_qualify_override_key(dotted_key, profile)] = value
|
||||
|
||||
for raw_item in getattr(args, "set_values", []):
|
||||
if "=" not in raw_item:
|
||||
raise ConfigError(f"--set 参数缺少等号: {raw_item}")
|
||||
dotted_key, raw_value = raw_item.split("=", 1)
|
||||
overrides[_qualify_override_key(dotted_key, profile)] = parse_toml_value(raw_value)
|
||||
|
||||
return overrides
|
||||
|
||||
|
||||
def apply_runtime_overrides(config_path: Path, args: argparse.Namespace) -> None:
|
||||
overrides = collect_cli_overrides(args)
|
||||
if not overrides:
|
||||
return
|
||||
document = load_toml_document(config_path)
|
||||
apply_overrides(document, overrides)
|
||||
save_toml_document(config_path, document)
|
||||
|
||||
|
||||
def resolve_state_path(config_path: Path, profile: str | None = None) -> Path:
|
||||
deploy_dir = config_path.parent / ".deploy"
|
||||
if profile:
|
||||
return deploy_dir / "profiles" / profile / "state.toml"
|
||||
return deploy_dir / "state.toml"
|
||||
|
||||
|
||||
def apply_profile_runtime_defaults(config: Any, profile: str | None = None) -> None:
|
||||
if not profile:
|
||||
return
|
||||
if config.source.mode != "clone":
|
||||
return
|
||||
if config.source.workspace_dir != ".deploy/workspace":
|
||||
return
|
||||
config.source.workspace_dir = str(Path(".deploy") / "profiles" / profile / "workspace")
|
||||
|
||||
|
||||
def handle_init_config(args: argparse.Namespace) -> int:
|
||||
configure_logging()
|
||||
config_path = args.config
|
||||
if config_path.exists() and not args.force:
|
||||
raise ConfigError(f"配置文件已存在: {config_path}")
|
||||
|
||||
save_toml_document(config_path, default_config_document())
|
||||
get_logger("cli").info("已生成配置文件: %s", config_path)
|
||||
return 0
|
||||
|
||||
|
||||
def handle_check(args: argparse.Namespace) -> int:
|
||||
configure_logging(args.verbose)
|
||||
apply_runtime_overrides(args.config, args)
|
||||
config = load_config(args.config, profile=args.profile)
|
||||
apply_profile_runtime_defaults(config, args.profile)
|
||||
state_path = resolve_state_path(args.config, args.profile)
|
||||
state = load_state(state_path)
|
||||
|
||||
log_stage("开始检查本地环境。")
|
||||
runner = CommandRunner()
|
||||
versions = check_required_tools(runner)
|
||||
|
||||
if config.source.mode == "local":
|
||||
source_path = Path(config.source.local_path)
|
||||
if not source_path.exists():
|
||||
raise ConfigError(f"本地源码目录不存在: {source_path}")
|
||||
|
||||
if not config.cloudflare.resolved_api_token():
|
||||
raise ConfigError("cloudflare.api_token 或 cloudflare.api_token_env 需要提供其一。")
|
||||
|
||||
with CloudflareClient(config.cloudflare) as cloudflare:
|
||||
token_info = cloudflare.verify_token()
|
||||
account: dict[str, str] = {}
|
||||
if config.cloudflare.account_id or config.cloudflare.account_name:
|
||||
account = cloudflare.resolve_account(
|
||||
account_id=config.cloudflare.account_id,
|
||||
account_name=config.cloudflare.account_name,
|
||||
)
|
||||
zone = cloudflare.resolve_zone(
|
||||
zone_name=config.cloudflare.zone_name,
|
||||
account_id=account.get("id", ""),
|
||||
)
|
||||
if not account:
|
||||
zone_account = zone.get("account")
|
||||
if isinstance(zone_account, dict):
|
||||
account = {
|
||||
"id": str(zone_account.get("id", "")),
|
||||
"name": str(zone_account.get("name", "")),
|
||||
}
|
||||
|
||||
for tool_name, version in versions.items():
|
||||
get_logger("check").info("[tool] %s=%s", tool_name, version.version)
|
||||
|
||||
state.cloudflare.account_id = account.get("id", "")
|
||||
state.cloudflare.account_name = account.get("name", "")
|
||||
state.cloudflare.zone_id = zone.get("id", "")
|
||||
state.cloudflare.zone_name = zone.get("name", config.cloudflare.zone_name)
|
||||
state.mark_checkpoint("check_completed")
|
||||
save_state(state_path, state)
|
||||
|
||||
get_logger("check").info("[cloudflare] token_status=%s", token_info.get("status", "unknown"))
|
||||
get_logger("check").info("[cloudflare] account_id=%s", state.cloudflare.account_id or "<provided>")
|
||||
get_logger("check").info("[cloudflare] zone_id=%s", state.cloudflare.zone_id)
|
||||
get_logger("check").info("[config] worker_vars=%s", sorted(config.derived_worker_vars().keys()))
|
||||
log_stage("本地环境检查完成。")
|
||||
return 0
|
||||
|
||||
|
||||
def handle_discover_config(args: argparse.Namespace) -> int:
|
||||
configure_logging(args.verbose)
|
||||
discover_config(
|
||||
config_path=args.config,
|
||||
profile=args.profile,
|
||||
cli_overrides=collect_cli_overrides(args),
|
||||
cloudflare_client_factory=CloudflareClient,
|
||||
)
|
||||
get_logger("discover").info("已写入发现结果: %s", args.config)
|
||||
return 0
|
||||
|
||||
|
||||
def handle_deploy(args: argparse.Namespace) -> int:
|
||||
return _handle_deployment_command(args, is_resume=False)
|
||||
|
||||
|
||||
def handle_resume(args: argparse.Namespace) -> int:
|
||||
return _handle_deployment_command(args, is_resume=True)
|
||||
|
||||
|
||||
def _handle_deployment_command(args: argparse.Namespace, *, is_resume: bool) -> int:
|
||||
configure_logging(args.verbose)
|
||||
apply_runtime_overrides(args.config, args)
|
||||
config = load_config(args.config, profile=args.profile)
|
||||
apply_profile_runtime_defaults(config, args.profile)
|
||||
state_path = resolve_state_path(args.config, args.profile)
|
||||
state = load_state(state_path)
|
||||
if is_resume and not state.checkpoint:
|
||||
raise DeployError("状态文件中缺少检查点,无法继续恢复部署。")
|
||||
|
||||
runner = CommandRunner()
|
||||
run_deployment(
|
||||
config_path=args.config,
|
||||
config=config,
|
||||
state_path=state_path,
|
||||
state=state,
|
||||
runner=runner,
|
||||
cloudflare_client_factory=CloudflareClient,
|
||||
is_resume=is_resume,
|
||||
)
|
||||
save_state(state_path, state)
|
||||
return 0
|
||||
|
||||
|
||||
def resolve_cli_argv(argv: list[str] | None) -> list[str]:
|
||||
raw_args = list(sys.argv[1:] if argv is None else argv)
|
||||
if not raw_args:
|
||||
return ["deploy"]
|
||||
if raw_args[0] in COMMAND_NAMES:
|
||||
return raw_args
|
||||
if raw_args[0] in {"-h", "--help"}:
|
||||
return raw_args
|
||||
if raw_args[0].startswith("-"):
|
||||
return ["deploy", *raw_args]
|
||||
return raw_args
|
||||
|
||||
|
||||
def main(argv: list[str] | None = None) -> int:
|
||||
parser = build_parser()
|
||||
args = parser.parse_args(resolve_cli_argv(argv))
|
||||
try:
|
||||
return args.handler(args)
|
||||
except DeployError as exc:
|
||||
configure_logging(getattr(args, "verbose", False))
|
||||
get_logger("error").error("%s", exc)
|
||||
return 1
|
||||
727
src/cf_temp_email_deploy/cloudflare.py
Normal file
727
src/cf_temp_email_deploy/cloudflare.py
Normal file
@@ -0,0 +1,727 @@
|
||||
"""Cloudflare API client helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import Any, Literal
|
||||
|
||||
import httpx
|
||||
|
||||
from cf_temp_email_deploy import __version__
|
||||
from cf_temp_email_deploy.errors import CloudflareAPIError, ConfigError
|
||||
from cf_temp_email_deploy.models import CloudflareConfig
|
||||
|
||||
AuthMode = Literal["token", "global_key"]
|
||||
RETRYABLE_STATUS_CODES = {429, 500, 502, 503, 504}
|
||||
PAGES_ACTIVE_STATUSES = {"active", "verified"}
|
||||
EMAIL_ROUTING_READY_STATUSES = {"active", "enabled", "verified", "success"}
|
||||
|
||||
|
||||
class CloudflareClient:
|
||||
"""Minimal Cloudflare API client used by deployment commands."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: CloudflareConfig,
|
||||
*,
|
||||
timeout: float = 30.0,
|
||||
max_attempts: int = 3,
|
||||
transport: httpx.BaseTransport | None = None,
|
||||
) -> None:
|
||||
self.config = config
|
||||
self.max_attempts = max_attempts
|
||||
self.client = httpx.Client(
|
||||
base_url="https://api.cloudflare.com/client/v4",
|
||||
timeout=timeout,
|
||||
transport=transport,
|
||||
trust_env=False,
|
||||
headers={"User-Agent": f"cf-temp-email/{__version__}"},
|
||||
)
|
||||
|
||||
def close(self) -> None:
|
||||
self.client.close()
|
||||
|
||||
def __enter__(self) -> "CloudflareClient":
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type: object, exc: object, tb: object) -> None:
|
||||
self.close()
|
||||
|
||||
def verify_token(self) -> dict[str, Any]:
|
||||
payload = self.request("GET", "/user/tokens/verify", auth_mode="token")
|
||||
result = payload.get("result")
|
||||
if not isinstance(result, dict):
|
||||
raise CloudflareAPIError("Token 校验返回结果格式异常。")
|
||||
return result
|
||||
|
||||
def resolve_account(self, *, account_id: str = "", account_name: str = "") -> dict[str, Any]:
|
||||
if account_id:
|
||||
return {"id": account_id, "name": account_name}
|
||||
if not account_name:
|
||||
raise ConfigError("cloudflare.account_id 与 cloudflare.account_name 至少需要提供其一。")
|
||||
|
||||
accounts = self.list_paginated(
|
||||
"/accounts",
|
||||
params={"name": account_name},
|
||||
auth_mode="global_key",
|
||||
)
|
||||
matches = [item for item in accounts if item.get("name") == account_name]
|
||||
if not matches:
|
||||
raise CloudflareAPIError(f"未找到匹配的 Cloudflare 账户: {account_name}")
|
||||
if len(matches) > 1:
|
||||
raise CloudflareAPIError(f"存在多个同名 Cloudflare 账户: {account_name}")
|
||||
return matches[0]
|
||||
|
||||
def resolve_zone(self, *, zone_name: str, account_id: str = "") -> dict[str, Any]:
|
||||
params: dict[str, Any] = {"name": zone_name}
|
||||
if account_id:
|
||||
params["account.id"] = account_id
|
||||
zones = self.list_paginated("/zones", params=params, auth_mode="token")
|
||||
matches = [item for item in zones if item.get("name") == zone_name]
|
||||
if not matches:
|
||||
raise CloudflareAPIError(f"未找到匹配的 Zone: {zone_name}")
|
||||
if len(matches) > 1:
|
||||
raise CloudflareAPIError(f"匹配到多个同名 Zone: {zone_name}")
|
||||
return matches[0]
|
||||
|
||||
def get_pages_project(self, *, account_id: str, project_name: str) -> dict[str, Any] | None:
|
||||
return self._request_result_or_none(
|
||||
"GET",
|
||||
f"/accounts/{account_id}/pages/projects/{project_name}",
|
||||
auth_mode="token",
|
||||
)
|
||||
|
||||
def list_pages_projects(self, *, account_id: str) -> list[dict[str, Any]]:
|
||||
payload = self.request(
|
||||
"GET",
|
||||
f"/accounts/{account_id}/pages/projects",
|
||||
auth_mode="token",
|
||||
)
|
||||
result = payload.get("result")
|
||||
if not isinstance(result, list):
|
||||
raise CloudflareAPIError("Pages 项目列表返回结果格式异常。")
|
||||
return [item for item in result if isinstance(item, dict)]
|
||||
|
||||
def ensure_pages_project(
|
||||
self,
|
||||
*,
|
||||
account_id: str,
|
||||
project_name: str,
|
||||
production_branch: str,
|
||||
) -> dict[str, Any]:
|
||||
existing = self.get_pages_project(account_id=account_id, project_name=project_name)
|
||||
if existing is not None:
|
||||
return existing
|
||||
try:
|
||||
payload = self.request(
|
||||
"POST",
|
||||
f"/accounts/{account_id}/pages/projects",
|
||||
json={"name": project_name, "production_branch": production_branch},
|
||||
auth_mode="token",
|
||||
)
|
||||
except CloudflareAPIError as exc:
|
||||
if "already exists" not in str(exc).lower():
|
||||
raise
|
||||
existing = self.get_pages_project(account_id=account_id, project_name=project_name)
|
||||
if existing is not None:
|
||||
return existing
|
||||
raise
|
||||
return self._result_dict(payload, "Pages 项目创建")
|
||||
|
||||
def get_pages_domain(
|
||||
self,
|
||||
*,
|
||||
account_id: str,
|
||||
project_name: str,
|
||||
domain_name: str,
|
||||
) -> dict[str, Any] | None:
|
||||
return self._request_result_or_none(
|
||||
"GET",
|
||||
f"/accounts/{account_id}/pages/projects/{project_name}/domains/{domain_name}",
|
||||
auth_mode="token",
|
||||
)
|
||||
|
||||
def ensure_pages_domain(
|
||||
self,
|
||||
*,
|
||||
account_id: str,
|
||||
project_name: str,
|
||||
domain_name: str,
|
||||
) -> dict[str, Any]:
|
||||
existing = self.get_pages_domain(
|
||||
account_id=account_id,
|
||||
project_name=project_name,
|
||||
domain_name=domain_name,
|
||||
)
|
||||
if existing is not None:
|
||||
return existing
|
||||
try:
|
||||
payload = self.request(
|
||||
"POST",
|
||||
f"/accounts/{account_id}/pages/projects/{project_name}/domains",
|
||||
json={"name": domain_name},
|
||||
auth_mode="token",
|
||||
)
|
||||
except CloudflareAPIError as exc:
|
||||
if "already exists" not in str(exc).lower():
|
||||
raise
|
||||
existing = self.get_pages_domain(
|
||||
account_id=account_id,
|
||||
project_name=project_name,
|
||||
domain_name=domain_name,
|
||||
)
|
||||
if existing is not None:
|
||||
return existing
|
||||
raise
|
||||
return self._result_dict(payload, "Pages 域名绑定")
|
||||
|
||||
def wait_for_pages_domain_active(
|
||||
self,
|
||||
*,
|
||||
account_id: str,
|
||||
project_name: str,
|
||||
domain_name: str,
|
||||
timeout_seconds: float = 300.0,
|
||||
poll_interval_seconds: float = 5.0,
|
||||
) -> dict[str, Any]:
|
||||
deadline = time.monotonic() + timeout_seconds
|
||||
while True:
|
||||
domain = self.get_pages_domain(
|
||||
account_id=account_id,
|
||||
project_name=project_name,
|
||||
domain_name=domain_name,
|
||||
)
|
||||
if domain is None:
|
||||
raise CloudflareAPIError(f"Pages 自定义域名不存在: {domain_name}")
|
||||
status = str(domain.get("status", "")).lower()
|
||||
if status in PAGES_ACTIVE_STATUSES:
|
||||
return domain
|
||||
if time.monotonic() >= deadline:
|
||||
raise CloudflareAPIError(f"等待 Pages 自定义域名激活超时: {domain_name}")
|
||||
time.sleep(poll_interval_seconds)
|
||||
|
||||
def list_dns_records(
|
||||
self,
|
||||
zone_id: str,
|
||||
*,
|
||||
name: str = "",
|
||||
record_type: str = "",
|
||||
) -> list[dict[str, Any]]:
|
||||
params: dict[str, Any] = {}
|
||||
if name:
|
||||
params["name"] = name
|
||||
if record_type:
|
||||
params["type"] = record_type
|
||||
return self.list_paginated(
|
||||
f"/zones/{zone_id}/dns_records",
|
||||
params=params,
|
||||
auth_mode="token",
|
||||
)
|
||||
|
||||
def ensure_cname_record(
|
||||
self,
|
||||
*,
|
||||
zone_id: str,
|
||||
name: str,
|
||||
content: str,
|
||||
proxied: bool = True,
|
||||
ttl: int = 1,
|
||||
) -> dict[str, Any]:
|
||||
records = [record for record in self.list_dns_records(zone_id, name=name) if record.get("name") == name]
|
||||
conflicts = [record for record in records if record.get("type") != "CNAME"]
|
||||
if conflicts:
|
||||
raise CloudflareAPIError(f"DNS 记录冲突: {name} 已存在非 CNAME 记录。")
|
||||
|
||||
cname_records = [record for record in records if record.get("type") == "CNAME"]
|
||||
if len(cname_records) > 1:
|
||||
raise CloudflareAPIError(f"DNS 记录冲突: {name} 存在多条 CNAME 记录。")
|
||||
if not cname_records:
|
||||
return self._create_dns_record(
|
||||
zone_id=zone_id,
|
||||
payload={"type": "CNAME", "name": name, "content": content, "proxied": proxied, "ttl": ttl},
|
||||
)
|
||||
|
||||
record = cname_records[0]
|
||||
if (
|
||||
record.get("content") == content
|
||||
and bool(record.get("proxied", proxied)) == proxied
|
||||
and int(record.get("ttl", ttl)) == ttl
|
||||
):
|
||||
return record
|
||||
|
||||
return self._update_dns_record(
|
||||
zone_id=zone_id,
|
||||
record_id=str(record.get("id", "")),
|
||||
payload={"type": "CNAME", "name": name, "content": content, "proxied": proxied, "ttl": ttl},
|
||||
)
|
||||
|
||||
def ensure_dns_record(
|
||||
self,
|
||||
*,
|
||||
zone_id: str,
|
||||
record_type: str,
|
||||
name: str,
|
||||
content: str,
|
||||
proxied: bool | None = None,
|
||||
ttl: int | None = 1,
|
||||
priority: int | None = None,
|
||||
) -> dict[str, Any]:
|
||||
records = [
|
||||
record
|
||||
for record in self.list_dns_records(zone_id, name=name, record_type=record_type)
|
||||
if record.get("name") == name and record.get("type") == record_type
|
||||
]
|
||||
exact = next(
|
||||
(
|
||||
record
|
||||
for record in records
|
||||
if self._dns_record_matches(
|
||||
record,
|
||||
content=content,
|
||||
proxied=proxied,
|
||||
ttl=ttl,
|
||||
priority=priority,
|
||||
)
|
||||
),
|
||||
None,
|
||||
)
|
||||
if exact is not None:
|
||||
return exact
|
||||
|
||||
payload: dict[str, Any] = {"type": record_type, "name": name, "content": content}
|
||||
if proxied is not None:
|
||||
payload["proxied"] = proxied
|
||||
if ttl is not None:
|
||||
payload["ttl"] = ttl
|
||||
if priority is not None:
|
||||
payload["priority"] = priority
|
||||
|
||||
if len(records) == 1 and record_type in {"TXT", "CNAME"}:
|
||||
return self._update_dns_record(
|
||||
zone_id=zone_id,
|
||||
record_id=str(records[0].get("id", "")),
|
||||
payload=payload,
|
||||
)
|
||||
return self._create_dns_record(zone_id=zone_id, payload=payload)
|
||||
|
||||
def list_d1_databases(self, *, account_id: str, database_name: str = "") -> list[dict[str, Any]]:
|
||||
databases = self.list_paginated(
|
||||
f"/accounts/{account_id}/d1/database",
|
||||
auth_mode="token",
|
||||
)
|
||||
if not database_name:
|
||||
return databases
|
||||
return [item for item in databases if item.get("name") == database_name]
|
||||
|
||||
def ensure_d1_database(
|
||||
self,
|
||||
*,
|
||||
account_id: str,
|
||||
database_name: str,
|
||||
jurisdiction: str = "",
|
||||
) -> dict[str, Any]:
|
||||
existing = self.list_d1_databases(account_id=account_id, database_name=database_name)
|
||||
if existing:
|
||||
if len(existing) > 1:
|
||||
raise CloudflareAPIError(f"存在多个同名 D1 数据库: {database_name}")
|
||||
return existing[0]
|
||||
|
||||
payload_data: dict[str, Any] = {"name": database_name}
|
||||
if jurisdiction:
|
||||
payload_data["primary_location_hint"] = jurisdiction
|
||||
payload = self.request(
|
||||
"POST",
|
||||
f"/accounts/{account_id}/d1/database",
|
||||
json=payload_data,
|
||||
auth_mode="token",
|
||||
)
|
||||
return self._result_dict(payload, "D1 数据库创建")
|
||||
|
||||
def query_d1(
|
||||
self,
|
||||
*,
|
||||
account_id: str,
|
||||
database_id: str,
|
||||
sql: str,
|
||||
params: list[Any] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
payload = self.request(
|
||||
"POST",
|
||||
f"/accounts/{account_id}/d1/database/{database_id}/query",
|
||||
json={"sql": sql, "params": params or []},
|
||||
auth_mode="token",
|
||||
)
|
||||
result = payload.get("result")
|
||||
if not isinstance(result, list):
|
||||
raise CloudflareAPIError("D1 查询返回结果格式异常。")
|
||||
if not result:
|
||||
return []
|
||||
first = result[0]
|
||||
if not isinstance(first, dict):
|
||||
raise CloudflareAPIError("D1 查询结果项格式异常。")
|
||||
rows = first.get("results", [])
|
||||
if not isinstance(rows, list):
|
||||
raise CloudflareAPIError("D1 查询结果行格式异常。")
|
||||
return [row for row in rows if isinstance(row, dict)]
|
||||
|
||||
def get_workers_subdomain(self, *, account_id: str) -> dict[str, Any] | None:
|
||||
return self._request_result_or_none(
|
||||
"GET",
|
||||
f"/accounts/{account_id}/workers/subdomain",
|
||||
auth_mode="token",
|
||||
)
|
||||
|
||||
def get_worker_script(self, *, account_id: str, script_name: str) -> dict[str, Any] | None:
|
||||
scripts = self.list_paginated(
|
||||
f"/accounts/{account_id}/workers/scripts",
|
||||
auth_mode="token",
|
||||
)
|
||||
for script in scripts:
|
||||
if str(script.get("id", "")) == script_name:
|
||||
return script
|
||||
return None
|
||||
|
||||
def list_worker_scripts(self, *, account_id: str) -> list[dict[str, Any]]:
|
||||
return self.list_paginated(
|
||||
f"/accounts/{account_id}/workers/scripts",
|
||||
auth_mode="token",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def worker_script_supports_email(script: dict[str, Any]) -> bool:
|
||||
handlers = script.get("handlers")
|
||||
if not isinstance(handlers, list):
|
||||
return False
|
||||
return "email" in {str(handler).strip().lower() for handler in handlers}
|
||||
|
||||
def list_email_routing_addresses(self, *, account_id: str) -> list[dict[str, Any]]:
|
||||
return self._list_email_routing_paginated(f"/accounts/{account_id}/email/routing/addresses")
|
||||
|
||||
def get_email_routing_dns(self, *, zone_id: str) -> list[dict[str, Any]]:
|
||||
payload = self.request_email_routing("GET", f"/zones/{zone_id}/email/routing/dns")
|
||||
result = payload.get("result")
|
||||
if isinstance(result, list):
|
||||
return [item for item in result if isinstance(item, dict)]
|
||||
if isinstance(result, dict):
|
||||
records = result.get("records")
|
||||
if isinstance(records, list):
|
||||
return [item for item in records if isinstance(item, dict)]
|
||||
raise CloudflareAPIError("Email Routing DNS 返回结果格式异常。")
|
||||
|
||||
def get_catch_all(self, *, zone_id: str) -> dict[str, Any] | None:
|
||||
return self._request_email_routing_result_or_none(
|
||||
"GET",
|
||||
f"/zones/{zone_id}/email/routing/rules/catch_all",
|
||||
)
|
||||
|
||||
def ensure_catch_all_worker(self, *, zone_id: str, script_name: str) -> dict[str, Any]:
|
||||
current = self.get_catch_all(zone_id=zone_id)
|
||||
if current is not None and self.catch_all_points_to_worker(current, script_name):
|
||||
return current
|
||||
payload = self.request_email_routing(
|
||||
"PUT",
|
||||
f"/zones/{zone_id}/email/routing/rules/catch_all",
|
||||
json={
|
||||
"matchers": [{"type": "all"}],
|
||||
"actions": [{"type": "worker", "value": [script_name]}],
|
||||
"enabled": True,
|
||||
},
|
||||
)
|
||||
return self._result_dict(payload, "Catch-all 更新")
|
||||
|
||||
@staticmethod
|
||||
def catch_all_points_to_worker(rule: dict[str, Any], script_name: str) -> bool:
|
||||
actions = rule.get("actions")
|
||||
if not isinstance(actions, list):
|
||||
return False
|
||||
for action in actions:
|
||||
if not isinstance(action, dict):
|
||||
continue
|
||||
if action.get("type") != "worker":
|
||||
continue
|
||||
if script_name in CloudflareClient.extract_worker_targets(action.get("value")):
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def email_address_ready(address: dict[str, Any], target: str) -> bool:
|
||||
if str(address.get("email", "")).lower() != target.lower():
|
||||
return False
|
||||
status = str(address.get("status", "")).lower()
|
||||
return status in EMAIL_ROUTING_READY_STATUSES
|
||||
|
||||
@staticmethod
|
||||
def extract_worker_targets(value: Any) -> set[str]:
|
||||
stack = [value]
|
||||
targets: set[str] = set()
|
||||
while stack:
|
||||
current = stack.pop()
|
||||
if current is None:
|
||||
continue
|
||||
if isinstance(current, str):
|
||||
candidate = current.strip()
|
||||
if candidate:
|
||||
targets.add(candidate)
|
||||
continue
|
||||
if isinstance(current, list):
|
||||
stack.extend(current)
|
||||
continue
|
||||
if isinstance(current, dict):
|
||||
for key in ("worker", "name", "script", "service", "value"):
|
||||
if key in current:
|
||||
stack.append(current[key])
|
||||
return targets
|
||||
|
||||
@staticmethod
|
||||
def is_authentication_error(error: Exception) -> bool:
|
||||
if isinstance(error, ConfigError):
|
||||
message = str(error).lower()
|
||||
return "旧式鉴权" in str(error) or "api_email" in message or "global_api_key" in message
|
||||
if not isinstance(error, CloudflareAPIError):
|
||||
return False
|
||||
return CloudflareClient._should_fallback_to_global_key(error)
|
||||
|
||||
def request_email_routing(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
*,
|
||||
params: dict[str, Any] | None = None,
|
||||
json: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
try:
|
||||
return self.request(method, path, params=params, json=json, auth_mode="token")
|
||||
except CloudflareAPIError as exc:
|
||||
if not self._should_fallback_to_global_key(exc):
|
||||
raise
|
||||
return self.request(method, path, params=params, json=json, auth_mode="global_key")
|
||||
|
||||
def list_paginated(
|
||||
self,
|
||||
path: str,
|
||||
*,
|
||||
params: dict[str, Any] | None = None,
|
||||
auth_mode: AuthMode,
|
||||
) -> list[dict[str, Any]]:
|
||||
collected: list[dict[str, Any]] = []
|
||||
page = 1
|
||||
per_page = 50
|
||||
while True:
|
||||
current_params = dict(params or {})
|
||||
current_params.setdefault("page", page)
|
||||
current_params.setdefault("per_page", per_page)
|
||||
payload = self.request("GET", path, params=current_params, auth_mode=auth_mode)
|
||||
result = payload.get("result")
|
||||
if not isinstance(result, list):
|
||||
raise CloudflareAPIError(f"分页接口返回结果格式异常: {path}")
|
||||
collected.extend(result)
|
||||
result_info = payload.get("result_info") or {}
|
||||
total_pages = result_info.get("total_pages")
|
||||
if not total_pages or page >= total_pages:
|
||||
break
|
||||
page += 1
|
||||
return collected
|
||||
|
||||
def request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
*,
|
||||
params: dict[str, Any] | None = None,
|
||||
json: dict[str, Any] | None = None,
|
||||
auth_mode: AuthMode,
|
||||
) -> dict[str, Any]:
|
||||
last_error: CloudflareAPIError | None = None
|
||||
for attempt in range(1, self.max_attempts + 1):
|
||||
try:
|
||||
response = self.client.request(
|
||||
method,
|
||||
path,
|
||||
params=params,
|
||||
json=json,
|
||||
headers=self._headers(auth_mode),
|
||||
)
|
||||
except httpx.RequestError as exc:
|
||||
last_error = CloudflareAPIError(f"Cloudflare 请求失败: {exc}") # pragma: no cover
|
||||
if attempt < self.max_attempts:
|
||||
time.sleep(0.2 * attempt)
|
||||
continue
|
||||
raise last_error from exc
|
||||
|
||||
if response.status_code in RETRYABLE_STATUS_CODES and attempt < self.max_attempts:
|
||||
time.sleep(0.2 * attempt)
|
||||
continue
|
||||
|
||||
payload = self._parse_payload(response)
|
||||
if response.status_code >= 400 or payload.get("success") is False:
|
||||
error = CloudflareAPIError(
|
||||
self._build_error_message(payload, response),
|
||||
status_code=response.status_code,
|
||||
)
|
||||
if response.status_code in RETRYABLE_STATUS_CODES and attempt < self.max_attempts:
|
||||
last_error = error
|
||||
time.sleep(0.2 * attempt)
|
||||
continue
|
||||
raise error
|
||||
return payload
|
||||
|
||||
if last_error is not None:
|
||||
raise last_error
|
||||
raise CloudflareAPIError("Cloudflare 请求失败,且没有返回可用结果。")
|
||||
|
||||
def _request_result_or_none(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
*,
|
||||
params: dict[str, Any] | None = None,
|
||||
json: dict[str, Any] | None = None,
|
||||
auth_mode: AuthMode,
|
||||
) -> dict[str, Any] | None:
|
||||
try:
|
||||
payload = self.request(method, path, params=params, json=json, auth_mode=auth_mode)
|
||||
except CloudflareAPIError as exc:
|
||||
if exc.status_code == 404:
|
||||
return None
|
||||
raise
|
||||
return self._result_dict(payload, path)
|
||||
|
||||
def _request_email_routing_result_or_none(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
*,
|
||||
params: dict[str, Any] | None = None,
|
||||
json: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
try:
|
||||
payload = self.request_email_routing(method, path, params=params, json=json)
|
||||
except CloudflareAPIError as exc:
|
||||
if exc.status_code == 404:
|
||||
return None
|
||||
raise
|
||||
return self._result_dict(payload, path)
|
||||
|
||||
def _list_email_routing_paginated(self, path: str) -> list[dict[str, Any]]:
|
||||
collected: list[dict[str, Any]] = []
|
||||
page = 1
|
||||
per_page = 50
|
||||
while True:
|
||||
payload = self.request_email_routing(
|
||||
"GET",
|
||||
path,
|
||||
params={"page": page, "per_page": per_page},
|
||||
)
|
||||
result = payload.get("result")
|
||||
if not isinstance(result, list):
|
||||
raise CloudflareAPIError(f"分页接口返回结果格式异常: {path}")
|
||||
collected.extend(item for item in result if isinstance(item, dict))
|
||||
result_info = payload.get("result_info") or {}
|
||||
total_pages = result_info.get("total_pages")
|
||||
if not total_pages or page >= total_pages:
|
||||
break
|
||||
page += 1
|
||||
return collected
|
||||
|
||||
def _create_dns_record(self, *, zone_id: str, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
response = self.request(
|
||||
"POST",
|
||||
f"/zones/{zone_id}/dns_records",
|
||||
json=payload,
|
||||
auth_mode="token",
|
||||
)
|
||||
return self._result_dict(response, "DNS 记录创建")
|
||||
|
||||
def _update_dns_record(self, *, zone_id: str, record_id: str, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
response = self.request(
|
||||
"PUT",
|
||||
f"/zones/{zone_id}/dns_records/{record_id}",
|
||||
json=payload,
|
||||
auth_mode="token",
|
||||
)
|
||||
return self._result_dict(response, "DNS 记录更新")
|
||||
|
||||
@staticmethod
|
||||
def _dns_record_matches(
|
||||
record: dict[str, Any],
|
||||
*,
|
||||
content: str,
|
||||
proxied: bool | None,
|
||||
ttl: int | None,
|
||||
priority: int | None,
|
||||
) -> bool:
|
||||
if record.get("content") != content:
|
||||
return False
|
||||
if proxied is not None and bool(record.get("proxied", proxied)) != proxied:
|
||||
return False
|
||||
if ttl is not None:
|
||||
try:
|
||||
if int(record.get("ttl", ttl)) != ttl:
|
||||
return False
|
||||
except (TypeError, ValueError):
|
||||
return False
|
||||
if priority is not None:
|
||||
try:
|
||||
if int(record.get("priority", priority)) != priority:
|
||||
return False
|
||||
except (TypeError, ValueError):
|
||||
return False
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _result_dict(payload: dict[str, Any], action: str) -> dict[str, Any]:
|
||||
result = payload.get("result")
|
||||
if not isinstance(result, dict):
|
||||
raise CloudflareAPIError(f"{action} 返回结果格式异常。")
|
||||
return result
|
||||
|
||||
def _headers(self, auth_mode: AuthMode) -> dict[str, str]:
|
||||
if auth_mode == "token":
|
||||
token = self.config.resolved_api_token()
|
||||
if not token:
|
||||
raise ConfigError("缺少 Cloudflare API Token。")
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
email = self.config.resolved_api_email()
|
||||
api_key = self.config.resolved_global_api_key()
|
||||
if not email or not api_key:
|
||||
raise ConfigError("旧式鉴权需要同时提供 api_email 与 global_api_key。")
|
||||
return {
|
||||
"X-Auth-Email": email,
|
||||
"X-Auth-Key": api_key,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _parse_payload(response: httpx.Response) -> dict[str, Any]:
|
||||
try:
|
||||
payload = response.json()
|
||||
except ValueError as exc:
|
||||
raise CloudflareAPIError("Cloudflare 返回内容不是合法 JSON。", status_code=response.status_code) from exc
|
||||
if not isinstance(payload, dict):
|
||||
raise CloudflareAPIError("Cloudflare 返回内容格式异常。", status_code=response.status_code)
|
||||
return payload
|
||||
|
||||
@staticmethod
|
||||
def _build_error_message(payload: dict[str, Any], response: httpx.Response) -> str:
|
||||
errors = payload.get("errors") or []
|
||||
fragments = [
|
||||
error.get("message", str(error))
|
||||
for error in errors
|
||||
if isinstance(error, dict)
|
||||
]
|
||||
if not fragments and response.text:
|
||||
fragments.append(response.text.strip())
|
||||
message = "; ".join(fragment for fragment in fragments if fragment)
|
||||
if not message:
|
||||
message = f"HTTP {response.status_code}"
|
||||
return f"Cloudflare API 返回错误: {message}"
|
||||
|
||||
@staticmethod
|
||||
def _should_fallback_to_global_key(error: CloudflareAPIError) -> bool:
|
||||
if error.status_code not in {400, 401, 403}:
|
||||
return False
|
||||
message = str(error).lower()
|
||||
return any(
|
||||
fragment in message
|
||||
for fragment in ("x-auth-key", "x-auth-email", "global api key", "authentication")
|
||||
)
|
||||
278
src/cf_temp_email_deploy/config.py
Normal file
278
src/cf_temp_email_deploy/config.py
Normal file
@@ -0,0 +1,278 @@
|
||||
"""Configuration and state file utilities."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
import os
|
||||
from pathlib import Path
|
||||
from tempfile import NamedTemporaryFile
|
||||
from typing import Any, Mapping
|
||||
|
||||
import tomlkit
|
||||
from tomlkit import TOMLDocument
|
||||
from tomlkit.exceptions import ParseError
|
||||
from tomlkit.items import AbstractTable
|
||||
|
||||
from cf_temp_email_deploy.errors import ConfigError
|
||||
from cf_temp_email_deploy.models import DeploymentConfig, DeploymentState
|
||||
|
||||
DEFAULT_CONFIG_TOML = """# 必填参数一览
|
||||
# [必填] cloudflare.zone_name
|
||||
# [必填] cloudflare.api_token 或 cloudflare.api_token_env
|
||||
# [必填] mail.domains
|
||||
# [必填] pages.custom_domain
|
||||
# [必填] worker.vars.ADMIN_PASSWORDS
|
||||
# [条件必填] source.local_path:当 source.mode = "local" 时填写
|
||||
# [选填] source.repo_ref:留空时自动跟随上游默认分支最新提交;如需固定版本可填写 tag、branch 或 commit
|
||||
|
||||
# 主配置版本号,供 CLI 与状态加载逻辑识别。
|
||||
config_version = 1
|
||||
|
||||
[source]
|
||||
# [选填] "clone" 表示自动拉取远端仓库到 workspace_dir。
|
||||
# [条件必填] "local" 表示直接使用 local_path 指向的本地源码目录。
|
||||
mode = "clone"
|
||||
# [选填] 固定源码仓库地址。
|
||||
repo_url = "https://github.com/dreamhunter2333/cloudflare_temp_email.git"
|
||||
# [选填] 留空时自动跟随上游默认分支最新提交。
|
||||
repo_ref = ""
|
||||
# [选填] 克隆源码、安装依赖、构建产物时使用的本地工作目录。
|
||||
workspace_dir = ".deploy/workspace"
|
||||
# [条件必填] 仅在 mode = "local" 时生效。
|
||||
local_path = ""
|
||||
|
||||
[cloudflare]
|
||||
# [选填] 已知 account_id 时优先填写该字段。
|
||||
account_id = ""
|
||||
# [选填] 只有在 account_id 未提供时才需要 account_name。
|
||||
account_name = ""
|
||||
# [必填] 承载前端域名、Worker 域名与邮件域名的 Zone。
|
||||
zone_name = "example.com"
|
||||
# [必填] 本工具默认使用的主 API Token。
|
||||
api_token = ""
|
||||
# [必填-二选一] 共享环境中可改为从环境变量读取 Token。
|
||||
api_token_env = ""
|
||||
# [选填] Email Routing 部分接口需要旧式鉴权时,可填写邮箱与 Global API Key。
|
||||
api_email = ""
|
||||
api_email_env = ""
|
||||
global_api_key = ""
|
||||
global_api_key_env = ""
|
||||
|
||||
[mail]
|
||||
# [必填] 邮件接收域名,会同步写入 Email Routing 与 Worker 的 DOMAINS。
|
||||
# [必填] 该主机名必须与 pages.custom_domain 分离。
|
||||
domains = ["example.com"]
|
||||
# [选填] 如需严格校验 Email Routing 目标地址,可填写一个已验证的真实邮箱。
|
||||
# [选填] 留空时部署会跳过目标地址校验,仍继续配置 MX/SPF 与 Catch-all Worker。
|
||||
verified_destination_address = "inbox@example.net"
|
||||
|
||||
[d1]
|
||||
# [选填] 当前部署要创建或复用的远端 D1 数据库名称。
|
||||
database_name = "cf-temp-email"
|
||||
# [选填] 可选的 D1 地域提示;留空时使用默认位置。
|
||||
jurisdiction = ""
|
||||
# [选填] 仅在接管已有数据库且缺少 __deploy_history 时开启。
|
||||
adopt_existing_schema = false
|
||||
|
||||
[user_access]
|
||||
# [选填] 是否要求登录后才能创建邮箱;启用 Linux.do OAuth2 时会自动强制为 true。
|
||||
require_login_to_create = true
|
||||
# [选填] 是否允许用户自行注册。
|
||||
allow_user_register = false
|
||||
|
||||
[linuxdo]
|
||||
# [选填] 是否启用 LINUX DO OAuth2 登录。
|
||||
linuxdo_oauth = false
|
||||
# [条件必填] 当 linuxdo_oauth = true 时填写。
|
||||
client_id = ""
|
||||
# [条件必填] 当 linuxdo_oauth = true 时填写。
|
||||
client_secret = ""
|
||||
|
||||
[worker]
|
||||
# [选填] Worker 服务名称,会同时用于 wrangler、Pages 服务绑定与 Catch-all。
|
||||
script_name = "cloudflare-temp-email"
|
||||
# [选填] 保留 workers.dev 地址,用于调试与回退检查。
|
||||
use_workers_dev = true
|
||||
# [选填] 生产环境中的 Worker 自定义域名。
|
||||
custom_domain = ""
|
||||
# [选填] 写入 worker/wrangler.toml 的 compatibility_date。
|
||||
compatibility_date = "2024-09-23"
|
||||
|
||||
[worker.vars]
|
||||
# [选填] 写入 worker/wrangler.toml 的普通环境变量。
|
||||
PREFIX = "tmp"
|
||||
ENABLE_USER_CREATE_EMAIL = true
|
||||
ENABLE_USER_DELETE_EMAIL = true
|
||||
DEFAULT_LANG = "zh"
|
||||
# [必填] 管理员密码列表至少保留一个值,部署后会使用首项调用管理员接口。
|
||||
ADMIN_PASSWORDS = ["change-me"]
|
||||
|
||||
[worker.secrets]
|
||||
# [选填] 标量秘密值通过 `wrangler secret put` 写入 Cloudflare。
|
||||
# [选填] JWT_SECRET 留空时,会在首次部署 Worker 前自动生成安全随机值并写回 config.toml。
|
||||
JWT_SECRET = ""
|
||||
|
||||
[pages]
|
||||
# [选填] 目标账户中要创建或复用的 Pages 项目名称。
|
||||
project_name = "cf-temp-email-pages"
|
||||
# [必填] 前端访问域名,必须与 mail.domains 分离。
|
||||
custom_domain = "email.example.com"
|
||||
# [选填] "pages" 构建标准前端;"pages:nopwa" 关闭 PWA 产物。
|
||||
build_mode = "pages"
|
||||
production_branch = "production"
|
||||
|
||||
# [选填] 多账号/多环境场景可在 profiles 下定义命名配置。
|
||||
# CLI 使用 `--profile <name>` 选择对应配置,未覆盖的字段会继承根配置。
|
||||
#
|
||||
# [profiles.prod-cn.cloudflare]
|
||||
# account_id = "acc-prod-cn"
|
||||
# zone_name = "kotei.asia"
|
||||
# api_token_env = "CF_API_TOKEN_CN"
|
||||
#
|
||||
# [profiles.prod-cn.mail]
|
||||
# domains = ["mail.kotei.asia"]
|
||||
#
|
||||
# [profiles.prod-cn.pages]
|
||||
# custom_domain = "email.kotei.asia"
|
||||
"""
|
||||
|
||||
|
||||
def default_config_document() -> TOMLDocument:
|
||||
"""Return the default deployment configuration document."""
|
||||
|
||||
return tomlkit.parse(DEFAULT_CONFIG_TOML)
|
||||
|
||||
|
||||
def _ensure_parent_directory(path: Path) -> None:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def write_text_atomic(path: Path, text: str) -> None:
|
||||
"""Write a text file atomically."""
|
||||
|
||||
_ensure_parent_directory(path)
|
||||
with NamedTemporaryFile("w", encoding="utf-8", dir=path.parent, delete=False) as handle:
|
||||
handle.write(text)
|
||||
handle.flush()
|
||||
os.fsync(handle.fileno())
|
||||
temp_path = Path(handle.name)
|
||||
temp_path.replace(path)
|
||||
|
||||
|
||||
def load_toml_document(path: Path) -> TOMLDocument:
|
||||
"""Load a TOML document from disk."""
|
||||
|
||||
try:
|
||||
return tomlkit.parse(path.read_text(encoding="utf-8"))
|
||||
except FileNotFoundError as exc:
|
||||
raise ConfigError(f"配置文件不存在: {path}") from exc
|
||||
except ParseError as exc:
|
||||
raise ConfigError(f"TOML 解析失败: {path}") from exc
|
||||
|
||||
|
||||
def save_toml_document(path: Path, document: TOMLDocument) -> None:
|
||||
"""Persist a TOML document atomically."""
|
||||
|
||||
write_text_atomic(path, tomlkit.dumps(document))
|
||||
|
||||
|
||||
def _deep_merge_dicts(base: Mapping[str, Any], override: Mapping[str, Any]) -> dict[str, Any]:
|
||||
merged = deepcopy(dict(base))
|
||||
for key, value in override.items():
|
||||
current = merged.get(key)
|
||||
if isinstance(current, dict) and isinstance(value, Mapping):
|
||||
merged[key] = _deep_merge_dicts(current, value)
|
||||
continue
|
||||
merged[key] = deepcopy(value)
|
||||
return merged
|
||||
|
||||
|
||||
def _resolve_profile_payload(payload: Mapping[str, Any], profile: str | None = None) -> dict[str, Any]:
|
||||
base_payload = {key: deepcopy(value) for key, value in payload.items() if key != "profiles"}
|
||||
if not profile:
|
||||
return base_payload
|
||||
|
||||
profiles = payload.get("profiles", {})
|
||||
if not isinstance(profiles, Mapping):
|
||||
raise ConfigError("配置校验失败: profiles 必须是表。")
|
||||
|
||||
selected = profiles.get(profile)
|
||||
if selected is None:
|
||||
raise ConfigError(f"配置校验失败: 未找到 profile: {profile}")
|
||||
if not isinstance(selected, Mapping):
|
||||
raise ConfigError(f"配置校验失败: profiles.{profile} 必须是表。")
|
||||
return _deep_merge_dicts(base_payload, selected)
|
||||
|
||||
|
||||
def load_config(path: Path, *, profile: str | None = None) -> DeploymentConfig:
|
||||
"""Load and validate the deployment configuration."""
|
||||
|
||||
document = load_toml_document(path)
|
||||
return config_from_document(document, profile=profile)
|
||||
|
||||
|
||||
def config_from_document(document: TOMLDocument, *, profile: str | None = None) -> DeploymentConfig:
|
||||
"""Build and validate a deployment configuration from a TOML document."""
|
||||
|
||||
try:
|
||||
payload = _resolve_profile_payload(document.unwrap(), profile)
|
||||
return DeploymentConfig.model_validate(payload)
|
||||
except ValueError as exc:
|
||||
raise ConfigError(f"配置校验失败: {exc}") from exc
|
||||
|
||||
|
||||
def parse_toml_value(raw_value: str) -> Any:
|
||||
"""Parse a CLI value using TOML literal syntax."""
|
||||
|
||||
snippet = f"value = {raw_value}\n"
|
||||
try:
|
||||
item = tomlkit.parse(snippet)["value"]
|
||||
except ParseError:
|
||||
return raw_value
|
||||
return item.unwrap() if hasattr(item, "unwrap") else item
|
||||
|
||||
|
||||
def set_dotted_value(document: TOMLDocument, dotted_key: str, value: Any) -> None:
|
||||
"""Set a nested value inside a TOML document."""
|
||||
|
||||
parts = dotted_key.split(".")
|
||||
if not all(parts):
|
||||
raise ConfigError(f"非法配置键: {dotted_key}")
|
||||
|
||||
current: TOMLDocument | AbstractTable = document
|
||||
for part in parts[:-1]:
|
||||
if part not in current:
|
||||
current[part] = tomlkit.table()
|
||||
next_value = current[part]
|
||||
if not isinstance(next_value, AbstractTable):
|
||||
raise ConfigError(f"配置键 {dotted_key} 的父节点不是表: {part}")
|
||||
current = next_value
|
||||
|
||||
current[parts[-1]] = tomlkit.item(value)
|
||||
|
||||
|
||||
def apply_overrides(document: TOMLDocument, overrides: Mapping[str, Any]) -> TOMLDocument:
|
||||
"""Apply CLI override values to a configuration document."""
|
||||
|
||||
for dotted_key, value in overrides.items():
|
||||
set_dotted_value(document, dotted_key, value)
|
||||
return document
|
||||
|
||||
|
||||
def load_state(path: Path) -> DeploymentState:
|
||||
"""Load deployment state or return the default state when absent."""
|
||||
|
||||
if not path.exists():
|
||||
return DeploymentState()
|
||||
|
||||
try:
|
||||
document = tomlkit.parse(path.read_text(encoding="utf-8"))
|
||||
return DeploymentState.model_validate(document.unwrap())
|
||||
except (ParseError, ValueError) as exc:
|
||||
raise ConfigError(f"状态文件校验失败: {path}") from exc
|
||||
|
||||
|
||||
def save_state(path: Path, state: DeploymentState) -> None:
|
||||
"""Persist deployment state atomically."""
|
||||
|
||||
write_text_atomic(path, tomlkit.dumps(state.model_dump(mode="python")))
|
||||
1089
src/cf_temp_email_deploy/deployment.py
Normal file
1089
src/cf_temp_email_deploy/deployment.py
Normal file
File diff suppressed because it is too large
Load Diff
205
src/cf_temp_email_deploy/discovery.py
Normal file
205
src/cf_temp_email_deploy/discovery.py
Normal file
@@ -0,0 +1,205 @@
|
||||
"""Discovery helpers for importing existing Cloudflare resources into config files."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from cf_temp_email_deploy.cloudflare import CloudflareClient
|
||||
from cf_temp_email_deploy.config import (
|
||||
apply_overrides,
|
||||
config_from_document,
|
||||
default_config_document,
|
||||
load_toml_document,
|
||||
save_toml_document,
|
||||
)
|
||||
from cf_temp_email_deploy.errors import CloudflareAPIError, ConfigError
|
||||
from cf_temp_email_deploy.logging_utils import get_logger
|
||||
|
||||
LOGGER = get_logger("discover")
|
||||
|
||||
|
||||
def discover_config(
|
||||
*,
|
||||
config_path,
|
||||
profile: str | None,
|
||||
cli_overrides: dict[str, Any],
|
||||
cloudflare_client_factory: type[CloudflareClient] = CloudflareClient,
|
||||
) -> None:
|
||||
document = load_toml_document(config_path) if config_path.exists() else default_config_document()
|
||||
if cli_overrides:
|
||||
apply_overrides(document, cli_overrides)
|
||||
|
||||
config = config_from_document(document, profile=profile)
|
||||
if not config.cloudflare.resolved_api_token():
|
||||
raise ConfigError("discover-config 需要 cloudflare.api_token 或 cloudflare.api_token_env。")
|
||||
if not config.cloudflare.zone_name:
|
||||
raise ConfigError("discover-config 需要 cloudflare.zone_name。")
|
||||
|
||||
with cloudflare_client_factory(config.cloudflare) as cloudflare:
|
||||
discovered = _discover_overrides(cloudflare, config.cloudflare.zone_name, config.cloudflare.account_id)
|
||||
|
||||
target_overrides = {
|
||||
_qualify_profile_key(profile, dotted_key): value for dotted_key, value in discovered.items()
|
||||
}
|
||||
apply_overrides(document, target_overrides)
|
||||
save_toml_document(config_path, document)
|
||||
|
||||
|
||||
def _qualify_profile_key(profile: str | None, dotted_key: str) -> str:
|
||||
if not profile:
|
||||
return dotted_key
|
||||
return f"profiles.{profile}.{dotted_key}"
|
||||
|
||||
|
||||
def _discover_overrides(
|
||||
cloudflare: CloudflareClient,
|
||||
zone_name: str,
|
||||
configured_account_id: str = "",
|
||||
) -> dict[str, Any]:
|
||||
token_info = cloudflare.verify_token()
|
||||
LOGGER.info("[discover] token_status=%s", token_info.get("status", "unknown"))
|
||||
|
||||
zone = cloudflare.resolve_zone(zone_name=zone_name, account_id=configured_account_id)
|
||||
zone_account = zone.get("account") if isinstance(zone.get("account"), dict) else {}
|
||||
account_id = configured_account_id or str(zone_account.get("id", ""))
|
||||
account_name = str(zone_account.get("name", ""))
|
||||
if not account_id:
|
||||
raise CloudflareAPIError("无法从 Zone 信息中解析 account_id。")
|
||||
|
||||
overrides: dict[str, Any] = {
|
||||
"cloudflare.zone_name": str(zone.get("name", zone_name)),
|
||||
"cloudflare.account_id": account_id,
|
||||
}
|
||||
if account_name:
|
||||
overrides["cloudflare.account_name"] = account_name
|
||||
|
||||
zone_id = str(zone.get("id", ""))
|
||||
dns_records = cloudflare.list_dns_records(zone_id)
|
||||
mail_domains = _discover_mail_domains(dns_records)
|
||||
if mail_domains:
|
||||
overrides["mail.domains"] = mail_domains
|
||||
|
||||
verified_destination = _discover_verified_destination(cloudflare, account_id)
|
||||
if verified_destination:
|
||||
overrides["mail.verified_destination_address"] = verified_destination
|
||||
|
||||
d1_name = _discover_d1_name(cloudflare, account_id)
|
||||
if d1_name:
|
||||
overrides["d1.database_name"] = d1_name
|
||||
|
||||
pages = _discover_pages(cloudflare, account_id, dns_records)
|
||||
if pages.get("project_name"):
|
||||
overrides["pages.project_name"] = pages["project_name"]
|
||||
if pages.get("custom_domain"):
|
||||
overrides["pages.custom_domain"] = pages["custom_domain"]
|
||||
|
||||
worker = _discover_worker(cloudflare, zone_id, account_id)
|
||||
if worker.get("script_name"):
|
||||
overrides["worker.script_name"] = worker["script_name"]
|
||||
|
||||
return overrides
|
||||
|
||||
|
||||
def _discover_mail_domains(dns_records: list[dict[str, Any]]) -> list[str]:
|
||||
domains = {
|
||||
str(record.get("name", "")).strip()
|
||||
for record in dns_records
|
||||
if str(record.get("type", "")).upper() == "MX"
|
||||
and "mx.cloudflare.net" in str(record.get("content", "")).lower()
|
||||
and str(record.get("name", "")).strip()
|
||||
}
|
||||
return sorted(domains)
|
||||
|
||||
|
||||
def _discover_verified_destination(cloudflare: CloudflareClient, account_id: str) -> str:
|
||||
try:
|
||||
addresses = cloudflare.list_email_routing_addresses(account_id=account_id)
|
||||
except Exception as exc:
|
||||
if cloudflare.is_authentication_error(exc):
|
||||
LOGGER.warning("[discover] 当前鉴权无法读取 Email Routing 地址,跳过 verified_destination_address。")
|
||||
return ""
|
||||
raise
|
||||
for address in addresses:
|
||||
email = str(address.get("email", "")).strip()
|
||||
if email and cloudflare.email_address_ready(address, email):
|
||||
return email
|
||||
return ""
|
||||
|
||||
|
||||
def _discover_d1_name(cloudflare: CloudflareClient, account_id: str) -> str:
|
||||
databases = cloudflare.list_d1_databases(account_id=account_id)
|
||||
if len(databases) == 1:
|
||||
return str(databases[0].get("name", "")).strip()
|
||||
if len(databases) > 1:
|
||||
LOGGER.warning("[discover] 检测到多个 D1 数据库,跳过自动写入 database_name。")
|
||||
return ""
|
||||
|
||||
|
||||
def _discover_pages(
|
||||
cloudflare: CloudflareClient,
|
||||
account_id: str,
|
||||
dns_records: list[dict[str, Any]],
|
||||
) -> dict[str, str]:
|
||||
projects = cloudflare.list_pages_projects(account_id=account_id)
|
||||
if not projects:
|
||||
return {}
|
||||
if len(projects) > 1:
|
||||
LOGGER.warning("[discover] 检测到多个 Pages 项目,将尽量通过 DNS 反推。")
|
||||
|
||||
cname_by_content = {
|
||||
str(record.get("content", "")).strip().lower(): str(record.get("name", "")).strip()
|
||||
for record in dns_records
|
||||
if str(record.get("type", "")).upper() == "CNAME"
|
||||
}
|
||||
matches: list[dict[str, str]] = []
|
||||
for project in projects:
|
||||
subdomain = str(project.get("subdomain", "")).strip().lower()
|
||||
if not subdomain:
|
||||
continue
|
||||
custom_domain = cname_by_content.get(subdomain, "")
|
||||
matches.append(
|
||||
{
|
||||
"project_name": str(project.get("name", "")).strip(),
|
||||
"custom_domain": custom_domain,
|
||||
}
|
||||
)
|
||||
|
||||
exact = [item for item in matches if item["project_name"] and item["custom_domain"]]
|
||||
if len(exact) == 1:
|
||||
return exact[0]
|
||||
if len(projects) == 1:
|
||||
return {
|
||||
"project_name": str(projects[0].get("name", "")).strip(),
|
||||
"custom_domain": exact[0]["custom_domain"] if exact else "",
|
||||
}
|
||||
return {}
|
||||
|
||||
|
||||
def _discover_worker(cloudflare: CloudflareClient, zone_id: str, account_id: str) -> dict[str, str]:
|
||||
catch_all = None
|
||||
try:
|
||||
catch_all = cloudflare.get_catch_all(zone_id=zone_id)
|
||||
except Exception as exc:
|
||||
if not cloudflare.is_authentication_error(exc):
|
||||
raise
|
||||
LOGGER.warning("[discover] 当前鉴权无法读取 Catch-all,尝试从 Worker 列表反推脚本。")
|
||||
|
||||
if isinstance(catch_all, dict):
|
||||
actions = catch_all.get("actions")
|
||||
if isinstance(actions, list):
|
||||
for action in actions:
|
||||
if not isinstance(action, dict):
|
||||
continue
|
||||
if action.get("type") != "worker":
|
||||
continue
|
||||
targets = sorted(cloudflare.extract_worker_targets(action.get("value")))
|
||||
if targets:
|
||||
return {"script_name": targets[0]}
|
||||
|
||||
scripts = cloudflare.list_worker_scripts(account_id=account_id)
|
||||
email_scripts = [item for item in scripts if cloudflare.worker_script_supports_email(item)]
|
||||
if len(email_scripts) == 1:
|
||||
return {"script_name": str(email_scripts[0].get("id", "")).strip()}
|
||||
if len(email_scripts) > 1:
|
||||
LOGGER.warning("[discover] 检测到多个带 email handler 的 Worker,跳过自动写入 script_name。")
|
||||
return {}
|
||||
53
src/cf_temp_email_deploy/environment.py
Normal file
53
src/cf_temp_email_deploy/environment.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""Local environment checks."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
|
||||
from cf_temp_email_deploy.errors import EnvironmentCheckError
|
||||
from cf_temp_email_deploy.subprocess_runner import CommandRunner, CommandSpec
|
||||
|
||||
MINIMUM_NODE_VERSION = (20, 19, 0)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ToolVersion:
|
||||
name: str
|
||||
version: str
|
||||
|
||||
|
||||
def parse_semver(raw_version: str) -> tuple[int, int, int]:
|
||||
"""Extract a semantic version tuple from command output."""
|
||||
|
||||
match = re.search(r"(\d+)\.(\d+)\.(\d+)", raw_version)
|
||||
if not match:
|
||||
raise EnvironmentCheckError(f"无法解析版本号: {raw_version.strip()}")
|
||||
return tuple(int(part) for part in match.groups())
|
||||
|
||||
|
||||
def check_tool_version(runner: CommandRunner, command_name: str, version_flag: str = "--version") -> ToolVersion:
|
||||
"""Execute a tool version command and return the parsed output."""
|
||||
|
||||
result = runner.run_checked(CommandSpec(args=(command_name, version_flag)))
|
||||
version_text = (result.stdout or result.stderr).strip()
|
||||
if not version_text:
|
||||
raise EnvironmentCheckError(f"无法读取 {command_name} 的版本输出。")
|
||||
return ToolVersion(name=command_name, version=version_text)
|
||||
|
||||
|
||||
def check_required_tools(runner: CommandRunner) -> dict[str, ToolVersion]:
|
||||
"""Validate required local tools and versions."""
|
||||
|
||||
versions = {
|
||||
"git": check_tool_version(runner, "git"),
|
||||
"node": check_tool_version(runner, "node"),
|
||||
"npm": check_tool_version(runner, "npm"),
|
||||
}
|
||||
node_version = parse_semver(versions["node"].version)
|
||||
if node_version < MINIMUM_NODE_VERSION:
|
||||
minimum = ".".join(str(part) for part in MINIMUM_NODE_VERSION)
|
||||
current = ".".join(str(part) for part in node_version)
|
||||
raise EnvironmentCheckError(f"Node.js 版本过低,当前为 {current},最低要求为 {minimum}。")
|
||||
return versions
|
||||
|
||||
54
src/cf_temp_email_deploy/errors.py
Normal file
54
src/cf_temp_email_deploy/errors.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""Project specific exceptions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class DeployError(Exception):
|
||||
"""Base exception for the deployment CLI."""
|
||||
|
||||
|
||||
class ConfigError(DeployError):
|
||||
"""Raised when configuration content is invalid."""
|
||||
|
||||
|
||||
class EnvironmentCheckError(DeployError):
|
||||
"""Raised when the local environment does not satisfy requirements."""
|
||||
|
||||
|
||||
class CloudflareAPIError(DeployError):
|
||||
"""Raised when Cloudflare API responses cannot be accepted."""
|
||||
|
||||
def __init__(self, message: str, *, status_code: int | None = None) -> None:
|
||||
super().__init__(message)
|
||||
self.status_code = status_code
|
||||
|
||||
|
||||
class CommandExecutionError(DeployError):
|
||||
"""Raised when a subprocess returns an unsuccessful result."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
*,
|
||||
command: tuple[str, ...],
|
||||
returncode: int | None,
|
||||
stdout: str,
|
||||
stderr: str,
|
||||
) -> None:
|
||||
super().__init__(message)
|
||||
self.command = command
|
||||
self.returncode = returncode
|
||||
self.stdout = stdout
|
||||
self.stderr = stderr
|
||||
|
||||
|
||||
class AcceptanceCheckError(DeployError):
|
||||
"""Raised when post-deployment acceptance checks fail."""
|
||||
|
||||
|
||||
class ApplicationAPIError(DeployError):
|
||||
"""Raised when application admin API responses cannot be accepted."""
|
||||
|
||||
def __init__(self, message: str, *, status_code: int | None = None) -> None:
|
||||
super().__init__(message)
|
||||
self.status_code = status_code
|
||||
62
src/cf_temp_email_deploy/logging_utils.py
Normal file
62
src/cf_temp_email_deploy/logging_utils.py
Normal file
@@ -0,0 +1,62 @@
|
||||
"""Logging helpers used across the project."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import shlex
|
||||
from pathlib import Path
|
||||
|
||||
LOGGER_NAME = "cf_temp_email_deploy"
|
||||
LOG_FORMAT = "%(asctime)s | %(levelname)s | %(message)s"
|
||||
|
||||
|
||||
def configure_logging(verbose: bool = False) -> logging.Logger:
|
||||
"""Configure and return the package root logger."""
|
||||
|
||||
logger = logging.getLogger(LOGGER_NAME)
|
||||
level = logging.DEBUG if verbose else logging.INFO
|
||||
logger.setLevel(level)
|
||||
logger.propagate = False
|
||||
|
||||
if not logger.handlers:
|
||||
handler = logging.StreamHandler()
|
||||
handler.setFormatter(logging.Formatter(LOG_FORMAT))
|
||||
logger.addHandler(handler)
|
||||
|
||||
for handler in logger.handlers:
|
||||
handler.setLevel(level)
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
def get_logger(name: str | None = None) -> logging.Logger:
|
||||
"""Return a child logger inside the package namespace."""
|
||||
|
||||
if not name:
|
||||
return logging.getLogger(LOGGER_NAME)
|
||||
return logging.getLogger(f"{LOGGER_NAME}.{name}")
|
||||
|
||||
|
||||
def log_stage(message: str) -> None:
|
||||
"""Emit a stage-level log message."""
|
||||
|
||||
get_logger("stage").info("[stage] %s", message)
|
||||
|
||||
|
||||
def log_command(command: tuple[str, ...], cwd: Path | None) -> None:
|
||||
"""Emit a log line before a subprocess starts."""
|
||||
|
||||
location = str(cwd) if cwd else "."
|
||||
get_logger("command").info("[command] cwd=%s cmd=%s", location, shlex.join(command))
|
||||
|
||||
|
||||
def log_command_result(command: tuple[str, ...], returncode: int, duration_seconds: float) -> None:
|
||||
"""Emit a log line after a subprocess finishes."""
|
||||
|
||||
get_logger("command").info(
|
||||
"[command-result] code=%s duration=%.3fs cmd=%s",
|
||||
returncode,
|
||||
duration_seconds,
|
||||
shlex.join(command),
|
||||
)
|
||||
|
||||
258
src/cf_temp_email_deploy/models.py
Normal file
258
src/cf_temp_email_deploy/models.py
Normal file
@@ -0,0 +1,258 @@
|
||||
"""Typed configuration and state models."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
|
||||
def _now_isoformat() -> str:
|
||||
return datetime.now(UTC).isoformat(timespec="seconds")
|
||||
|
||||
|
||||
class StrictModel(BaseModel):
|
||||
"""Shared Pydantic configuration."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid", validate_assignment=True)
|
||||
|
||||
|
||||
class SourceConfig(StrictModel):
|
||||
mode: Literal["clone", "local"] = "clone"
|
||||
repo_url: str = "https://github.com/dreamhunter2333/cloudflare_temp_email.git"
|
||||
repo_ref: str = ""
|
||||
workspace_dir: str = ".deploy/workspace"
|
||||
local_path: str = ""
|
||||
|
||||
|
||||
class CloudflareConfig(StrictModel):
|
||||
account_id: str = ""
|
||||
account_name: str = ""
|
||||
zone_name: str = "example.com"
|
||||
api_token: str = ""
|
||||
api_token_env: str = ""
|
||||
api_email: str = ""
|
||||
api_email_env: str = ""
|
||||
global_api_key: str = ""
|
||||
global_api_key_env: str = ""
|
||||
|
||||
def resolved_api_token(self, environ: dict[str, str] | None = None) -> str:
|
||||
environment = environ or os.environ
|
||||
if self.api_token:
|
||||
return self.api_token
|
||||
if self.api_token_env:
|
||||
return environment.get(self.api_token_env, "")
|
||||
return ""
|
||||
|
||||
def resolved_api_email(self, environ: dict[str, str] | None = None) -> str:
|
||||
environment = environ or os.environ
|
||||
if self.api_email:
|
||||
return self.api_email
|
||||
if self.api_email_env:
|
||||
return environment.get(self.api_email_env, "")
|
||||
return ""
|
||||
|
||||
def resolved_global_api_key(self, environ: dict[str, str] | None = None) -> str:
|
||||
environment = environ or os.environ
|
||||
if self.global_api_key:
|
||||
return self.global_api_key
|
||||
if self.global_api_key_env:
|
||||
return environment.get(self.global_api_key_env, "")
|
||||
return ""
|
||||
|
||||
|
||||
class MailConfig(StrictModel):
|
||||
domains: list[str] = Field(default_factory=lambda: ["mail.example.com"])
|
||||
verified_destination_address: str = "inbox@example.net"
|
||||
|
||||
|
||||
class D1Config(StrictModel):
|
||||
database_name: str = "cf-temp-email"
|
||||
jurisdiction: str = ""
|
||||
adopt_existing_schema: bool = False
|
||||
|
||||
|
||||
class UserAccessConfig(StrictModel):
|
||||
require_login_to_create: bool = True
|
||||
allow_user_register: bool = False
|
||||
|
||||
|
||||
class LinuxdoConfig(StrictModel):
|
||||
linuxdo_oauth: bool = False
|
||||
client_id: str = ""
|
||||
client_secret: str = ""
|
||||
|
||||
|
||||
class WorkerConfig(StrictModel):
|
||||
script_name: str = "cloudflare-temp-email"
|
||||
use_workers_dev: bool = True
|
||||
custom_domain: str = ""
|
||||
compatibility_date: str = "2024-09-23"
|
||||
vars: dict[str, Any] = Field(
|
||||
default_factory=lambda: {
|
||||
"PREFIX": "tmp",
|
||||
"ENABLE_USER_CREATE_EMAIL": True,
|
||||
"ENABLE_USER_DELETE_EMAIL": True,
|
||||
"DEFAULT_LANG": "zh",
|
||||
"ADMIN_PASSWORDS": ["change-me"],
|
||||
}
|
||||
)
|
||||
secrets: dict[str, str] = Field(default_factory=lambda: {"JWT_SECRET": ""})
|
||||
|
||||
|
||||
class PagesConfig(StrictModel):
|
||||
project_name: str = "cf-temp-email-pages"
|
||||
custom_domain: str = "mail.example.com"
|
||||
build_mode: Literal["pages", "pages:nopwa"] = "pages"
|
||||
production_branch: str = "production"
|
||||
|
||||
|
||||
class DeploymentConfig(StrictModel):
|
||||
config_version: int = 1
|
||||
source: SourceConfig = Field(default_factory=SourceConfig)
|
||||
cloudflare: CloudflareConfig = Field(default_factory=CloudflareConfig)
|
||||
mail: MailConfig = Field(default_factory=MailConfig)
|
||||
d1: D1Config = Field(default_factory=D1Config)
|
||||
user_access: UserAccessConfig = Field(default_factory=UserAccessConfig)
|
||||
linuxdo: LinuxdoConfig = Field(default_factory=LinuxdoConfig)
|
||||
worker: WorkerConfig = Field(default_factory=WorkerConfig)
|
||||
pages: PagesConfig = Field(default_factory=PagesConfig)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_model(self) -> "DeploymentConfig":
|
||||
if not self.mail.domains:
|
||||
raise ValueError("mail.domains 至少需要一个域名。")
|
||||
if self.source.mode == "local" and not self.source.local_path:
|
||||
raise ValueError("source.mode=local 时需要填写 source.local_path。")
|
||||
if not self.cloudflare.zone_name:
|
||||
raise ValueError("cloudflare.zone_name 不能为空。")
|
||||
if not self.pages.custom_domain:
|
||||
raise ValueError("pages.custom_domain 不能为空。")
|
||||
if not self.worker.custom_domain and not self.worker.use_workers_dev:
|
||||
raise ValueError("worker.custom_domain 与 worker.use_workers_dev 不能同时关闭。")
|
||||
if not self.admin_passwords():
|
||||
raise ValueError("worker.vars.ADMIN_PASSWORDS 至少需要一个非空密码。")
|
||||
if self.linuxdo.linuxdo_oauth and not self.linuxdo.client_id.strip():
|
||||
raise ValueError("linuxdo.linuxdo_oauth=true 时需要填写 linuxdo.client_id。")
|
||||
if self.linuxdo.linuxdo_oauth and not self.linuxdo.client_secret.strip():
|
||||
raise ValueError("linuxdo.linuxdo_oauth=true 时需要填写 linuxdo.client_secret。")
|
||||
return self
|
||||
|
||||
def build_frontend_url(self, pages_hostname: str | None = None) -> str:
|
||||
hostname = self.pages.custom_domain or pages_hostname or f"{self.pages.project_name}.pages.dev"
|
||||
return f"https://{hostname}"
|
||||
|
||||
def admin_passwords(self) -> list[str]:
|
||||
raw_value = self.worker.vars.get("ADMIN_PASSWORDS", [])
|
||||
if isinstance(raw_value, str):
|
||||
value = raw_value.strip()
|
||||
return [value] if value else []
|
||||
if not isinstance(raw_value, list):
|
||||
return []
|
||||
passwords: list[str] = []
|
||||
for item in raw_value:
|
||||
value = str(item).strip()
|
||||
if value:
|
||||
passwords.append(value)
|
||||
return passwords
|
||||
|
||||
def effective_require_login_to_create(self) -> bool:
|
||||
return self.user_access.require_login_to_create or self.linuxdo.linuxdo_oauth
|
||||
|
||||
def linuxdo_callback_url(self) -> str:
|
||||
return f"https://{self.pages.custom_domain}/user/oauth2/callback"
|
||||
|
||||
def derived_worker_vars(self, pages_hostname: str | None = None) -> dict[str, Any]:
|
||||
values = dict(self.worker.vars)
|
||||
values["DOMAINS"] = list(self.mail.domains)
|
||||
values.setdefault("DEFAULT_DOMAINS", list(self.mail.domains))
|
||||
values["FRONTEND_URL"] = self.build_frontend_url(pages_hostname)
|
||||
values["DISABLE_ANONYMOUS_USER_CREATE_EMAIL"] = self.effective_require_login_to_create()
|
||||
return values
|
||||
|
||||
|
||||
class CloudflareState(StrictModel):
|
||||
account_id: str = ""
|
||||
account_name: str = ""
|
||||
zone_id: str = ""
|
||||
zone_name: str = ""
|
||||
|
||||
|
||||
class SourceState(StrictModel):
|
||||
source_dir: str = ""
|
||||
commit_sha: str = ""
|
||||
|
||||
|
||||
class PagesState(StrictModel):
|
||||
project_id: str = ""
|
||||
project_name: str = ""
|
||||
subdomain: str = ""
|
||||
custom_domain: str = ""
|
||||
custom_domain_status: str = ""
|
||||
cname_record_id: str = ""
|
||||
|
||||
|
||||
class D1MigrationState(StrictModel):
|
||||
file_name: str
|
||||
sha256: str
|
||||
applied_at: str = Field(default_factory=_now_isoformat)
|
||||
|
||||
|
||||
class D1State(StrictModel):
|
||||
database_id: str = ""
|
||||
database_name: str = ""
|
||||
schema_version: str = ""
|
||||
migrations: list[D1MigrationState] = Field(default_factory=list)
|
||||
|
||||
|
||||
class WorkerState(StrictModel):
|
||||
script_name: str = ""
|
||||
workers_dev_url: str = ""
|
||||
|
||||
|
||||
class ApplicationState(StrictModel):
|
||||
configured: bool = False
|
||||
admin_base_url: str = ""
|
||||
allow_user_register: bool = False
|
||||
require_login_to_create: bool = False
|
||||
linuxdo_oauth_enabled: bool = False
|
||||
linuxdo_redirect_url: str = ""
|
||||
|
||||
|
||||
class EmailRoutingState(StrictModel):
|
||||
destination_address: str = ""
|
||||
dns_record_ids: list[str] = Field(default_factory=list)
|
||||
catch_all_enabled: bool = False
|
||||
catch_all_rule_id: str = ""
|
||||
catch_all_worker: str = ""
|
||||
|
||||
|
||||
class DeploymentEvent(StrictModel):
|
||||
stage: str
|
||||
status: Literal["started", "completed", "failed"]
|
||||
message: str = ""
|
||||
timestamp: str = Field(default_factory=_now_isoformat)
|
||||
|
||||
|
||||
class DeploymentState(StrictModel):
|
||||
config_version: int = 1
|
||||
checkpoint: str = ""
|
||||
last_updated_at: str = Field(default_factory=_now_isoformat)
|
||||
cloudflare: CloudflareState = Field(default_factory=CloudflareState)
|
||||
source: SourceState = Field(default_factory=SourceState)
|
||||
pages: PagesState = Field(default_factory=PagesState)
|
||||
d1: D1State = Field(default_factory=D1State)
|
||||
worker: WorkerState = Field(default_factory=WorkerState)
|
||||
application: ApplicationState = Field(default_factory=ApplicationState)
|
||||
email_routing: EmailRoutingState = Field(default_factory=EmailRoutingState)
|
||||
events: list[DeploymentEvent] = Field(default_factory=list)
|
||||
|
||||
def mark_checkpoint(self, checkpoint: str) -> None:
|
||||
self.checkpoint = checkpoint
|
||||
self.last_updated_at = _now_isoformat()
|
||||
|
||||
def record_event(self, stage: str, status: Literal["started", "completed", "failed"], message: str = "") -> None:
|
||||
self.events.append(DeploymentEvent(stage=stage, status=status, message=message))
|
||||
self.last_updated_at = _now_isoformat()
|
||||
79
src/cf_temp_email_deploy/project_layout.py
Normal file
79
src/cf_temp_email_deploy/project_layout.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""Helpers for locating the upstream project layout."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
from cf_temp_email_deploy.errors import ConfigError
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ProjectLayout:
|
||||
"""Resolved directory layout for the upstream project."""
|
||||
|
||||
root_dir: Path
|
||||
worker_dir: Path
|
||||
frontend_dir: Path
|
||||
pages_dir: Path
|
||||
db_dir: Path
|
||||
worker_wrangler_template: Path
|
||||
worker_wrangler_path: Path
|
||||
worker_patch_path: Path
|
||||
pages_wrangler_path: Path
|
||||
schema_path: Path
|
||||
migration_paths: tuple[Path, ...]
|
||||
|
||||
@property
|
||||
def frontend_dist_dir(self) -> Path:
|
||||
return self.frontend_dir / "dist"
|
||||
|
||||
@property
|
||||
def telegraf_dir(self) -> Path:
|
||||
return self.worker_dir / "node_modules" / "telegraf"
|
||||
|
||||
|
||||
def detect_project_layout(root_dir: Path) -> ProjectLayout:
|
||||
"""Resolve the expected Cloudflare Temp Email project structure."""
|
||||
|
||||
base_dir = root_dir.expanduser().resolve()
|
||||
worker_dir = _require_directory(base_dir / "worker")
|
||||
frontend_dir = _require_directory(base_dir / "frontend")
|
||||
pages_dir = _require_directory(base_dir / "pages")
|
||||
db_dir = _require_directory(base_dir / "db")
|
||||
worker_wrangler_template = _require_file(worker_dir / "wrangler.toml.template")
|
||||
pages_wrangler_path = pages_dir / "wrangler.toml"
|
||||
schema_path = _require_file(db_dir / "schema.sql")
|
||||
worker_patch_path = _require_file(worker_dir / "patches" / "telegraf@4.16.3.patch")
|
||||
migration_paths = tuple(
|
||||
sorted(
|
||||
path
|
||||
for path in db_dir.glob("*.sql")
|
||||
if path.name != "schema.sql"
|
||||
)
|
||||
)
|
||||
return ProjectLayout(
|
||||
root_dir=base_dir,
|
||||
worker_dir=worker_dir,
|
||||
frontend_dir=frontend_dir,
|
||||
pages_dir=pages_dir,
|
||||
db_dir=db_dir,
|
||||
worker_wrangler_template=worker_wrangler_template,
|
||||
worker_wrangler_path=worker_dir / "wrangler.toml",
|
||||
worker_patch_path=worker_patch_path,
|
||||
pages_wrangler_path=pages_wrangler_path,
|
||||
schema_path=schema_path,
|
||||
migration_paths=migration_paths,
|
||||
)
|
||||
|
||||
|
||||
def _require_directory(path: Path) -> Path:
|
||||
if not path.is_dir():
|
||||
raise ConfigError(f"缺少目录: {path}")
|
||||
return path
|
||||
|
||||
|
||||
def _require_file(path: Path) -> Path:
|
||||
if not path.is_file():
|
||||
raise ConfigError(f"缺少文件: {path}")
|
||||
return path
|
||||
117
src/cf_temp_email_deploy/source.py
Normal file
117
src/cf_temp_email_deploy/source.py
Normal file
@@ -0,0 +1,117 @@
|
||||
"""Source preparation helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
from cf_temp_email_deploy.errors import ConfigError
|
||||
from cf_temp_email_deploy.models import DeploymentConfig, DeploymentState, SourceConfig
|
||||
from cf_temp_email_deploy.subprocess_runner import CommandRunner, CommandSpec
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PreparedSource:
|
||||
"""Prepared source directory and its commit identity."""
|
||||
|
||||
source_dir: Path
|
||||
commit_sha: str
|
||||
|
||||
def apply_to_state(self, state: DeploymentState) -> None:
|
||||
state.source.source_dir = str(self.source_dir)
|
||||
state.source.commit_sha = self.commit_sha
|
||||
|
||||
|
||||
def prepare_source(config: DeploymentConfig, runner: CommandRunner) -> PreparedSource:
|
||||
"""Prepare deployment source code according to the configured mode."""
|
||||
|
||||
if config.source.mode == "clone":
|
||||
return prepare_cloned_source(config.source, runner)
|
||||
return prepare_local_source(config.source, runner)
|
||||
|
||||
|
||||
def prepare_cloned_source(source: SourceConfig, runner: CommandRunner) -> PreparedSource:
|
||||
"""Clone or refresh the configured repository and checkout the target ref."""
|
||||
|
||||
workspace_dir = Path(source.workspace_dir).expanduser().resolve()
|
||||
repo_dir = workspace_dir / "source"
|
||||
if repo_dir.exists() and not (repo_dir / ".git").exists():
|
||||
raise ConfigError(f"源码目录存在但不是 Git 仓库: {repo_dir}")
|
||||
|
||||
if not repo_dir.exists():
|
||||
runner.run_checked(CommandSpec(args=("git", "clone", source.repo_url, str(repo_dir))))
|
||||
|
||||
runner.run_checked(CommandSpec(args=("git", "-C", str(repo_dir), "fetch", "--tags", "origin")))
|
||||
if not _checkout_ref(repo_dir, source.repo_ref, runner):
|
||||
raise ConfigError(f"无法切换到目标源码版本: {source.repo_ref}")
|
||||
return PreparedSource(source_dir=repo_dir, commit_sha=resolve_commit_sha(repo_dir, runner))
|
||||
|
||||
|
||||
def prepare_local_source(source: SourceConfig, runner: CommandRunner) -> PreparedSource:
|
||||
"""Validate a local source directory and collect its commit SHA."""
|
||||
|
||||
if not source.local_path:
|
||||
raise ConfigError("source.mode=local 时需要填写 source.local_path。")
|
||||
|
||||
source_dir = Path(source.local_path).expanduser().resolve()
|
||||
if not source_dir.exists():
|
||||
raise ConfigError(f"本地源码目录不存在: {source_dir}")
|
||||
if not (source_dir / ".git").exists():
|
||||
raise ConfigError(f"本地源码目录不是 Git 仓库: {source_dir}")
|
||||
return PreparedSource(source_dir=source_dir, commit_sha=resolve_commit_sha(source_dir, runner))
|
||||
|
||||
|
||||
def resolve_commit_sha(repo_dir: Path, runner: CommandRunner) -> str:
|
||||
"""Resolve the current HEAD commit SHA for a Git repository."""
|
||||
|
||||
result = runner.run_checked(CommandSpec(args=("git", "-C", str(repo_dir), "rev-parse", "HEAD")))
|
||||
commit_sha = result.stdout.strip()
|
||||
if not commit_sha:
|
||||
raise ConfigError(f"无法读取 Git 提交哈希: {repo_dir}")
|
||||
return commit_sha
|
||||
|
||||
|
||||
def _checkout_ref(repo_dir: Path, repo_ref: str, runner: CommandRunner) -> bool:
|
||||
if not repo_ref:
|
||||
return _checkout_latest_remote_head(repo_dir, runner)
|
||||
|
||||
candidates = (
|
||||
("git", "-C", str(repo_dir), "checkout", "--detach", repo_ref),
|
||||
("git", "-C", str(repo_dir), "checkout", "--detach", f"origin/{repo_ref}"),
|
||||
)
|
||||
for args in candidates:
|
||||
result = runner.run(CommandSpec(args=args))
|
||||
if result.returncode == 0:
|
||||
return True
|
||||
fetch_result = runner.run(
|
||||
CommandSpec(args=("git", "-C", str(repo_dir), "fetch", "origin", repo_ref))
|
||||
)
|
||||
if fetch_result.returncode != 0:
|
||||
return False
|
||||
checkout_result = runner.run(
|
||||
CommandSpec(args=("git", "-C", str(repo_dir), "checkout", "--detach", "FETCH_HEAD"))
|
||||
)
|
||||
return checkout_result.returncode == 0
|
||||
|
||||
|
||||
def _checkout_latest_remote_head(repo_dir: Path, runner: CommandRunner) -> bool:
|
||||
symbolic_ref = runner.run(
|
||||
CommandSpec(args=("git", "-C", str(repo_dir), "symbolic-ref", "refs/remotes/origin/HEAD"))
|
||||
)
|
||||
candidates: list[tuple[str, ...]] = []
|
||||
remote_head = symbolic_ref.stdout.strip()
|
||||
if symbolic_ref.returncode == 0 and remote_head:
|
||||
candidates.append(("git", "-C", str(repo_dir), "checkout", "--detach", remote_head))
|
||||
|
||||
candidates.extend(
|
||||
[
|
||||
("git", "-C", str(repo_dir), "checkout", "--detach", "origin/HEAD"),
|
||||
("git", "-C", str(repo_dir), "checkout", "--detach", "origin/main"),
|
||||
("git", "-C", str(repo_dir), "checkout", "--detach", "origin/master"),
|
||||
]
|
||||
)
|
||||
for args in candidates:
|
||||
result = runner.run(CommandSpec(args=args))
|
||||
if result.returncode == 0:
|
||||
return True
|
||||
return False
|
||||
87
src/cf_temp_email_deploy/subprocess_runner.py
Normal file
87
src/cf_temp_email_deploy/subprocess_runner.py
Normal file
@@ -0,0 +1,87 @@
|
||||
"""Subprocess execution helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Mapping, Sequence
|
||||
|
||||
from cf_temp_email_deploy.errors import CommandExecutionError
|
||||
from cf_temp_email_deploy.logging_utils import log_command, log_command_result
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CommandSpec:
|
||||
args: Sequence[str]
|
||||
cwd: Path | None = None
|
||||
env: Mapping[str, str] | None = None
|
||||
timeout: float | None = None
|
||||
input_text: str | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CommandResult:
|
||||
args: tuple[str, ...]
|
||||
returncode: int
|
||||
stdout: str
|
||||
stderr: str
|
||||
duration_seconds: float = field(compare=False)
|
||||
|
||||
|
||||
class CommandRunner:
|
||||
"""Execute subprocess commands with logging and structured errors."""
|
||||
|
||||
def run(self, spec: CommandSpec) -> CommandResult:
|
||||
command = tuple(spec.args)
|
||||
environment = os.environ.copy()
|
||||
if spec.env:
|
||||
environment.update(spec.env)
|
||||
|
||||
log_command(command, spec.cwd)
|
||||
started_at = time.perf_counter()
|
||||
try:
|
||||
completed = subprocess.run(
|
||||
command,
|
||||
cwd=spec.cwd,
|
||||
env=environment,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
input=spec.input_text,
|
||||
timeout=spec.timeout,
|
||||
check=False,
|
||||
)
|
||||
except subprocess.TimeoutExpired as exc:
|
||||
raise CommandExecutionError(
|
||||
f"命令执行超时: {command[0]}",
|
||||
command=command,
|
||||
returncode=None,
|
||||
stdout=exc.stdout or "",
|
||||
stderr=exc.stderr or "",
|
||||
) from exc
|
||||
|
||||
duration = time.perf_counter() - started_at
|
||||
log_command_result(command, completed.returncode, duration)
|
||||
return CommandResult(
|
||||
args=command,
|
||||
returncode=completed.returncode,
|
||||
stdout=completed.stdout,
|
||||
stderr=completed.stderr,
|
||||
duration_seconds=duration,
|
||||
)
|
||||
|
||||
def run_checked(self, spec: CommandSpec) -> CommandResult:
|
||||
"""Execute a command and require a zero exit code."""
|
||||
|
||||
result = self.run(spec)
|
||||
if result.returncode != 0:
|
||||
raise CommandExecutionError(
|
||||
f"命令执行失败: {result.args[0]}",
|
||||
command=result.args,
|
||||
returncode=result.returncode,
|
||||
stdout=result.stdout,
|
||||
stderr=result.stderr,
|
||||
)
|
||||
return result
|
||||
36
src/cf_temp_email_deploy/wrangler.py
Normal file
36
src/cf_temp_email_deploy/wrangler.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""Wrangler command helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def build_wrangler_command(*args: str) -> tuple[str, ...]:
|
||||
"""Build a local wrangler invocation using npm exec."""
|
||||
|
||||
return ("npm", "exec", "--", "wrangler", *args)
|
||||
|
||||
|
||||
def build_wrangler_secret_put_command(secret_name: str) -> tuple[str, ...]:
|
||||
"""Build a command that stores a Worker secret from standard input."""
|
||||
|
||||
return build_wrangler_command("secret", "put", secret_name)
|
||||
|
||||
|
||||
def build_wrangler_d1_execute_command(database_name: str, sql_file: Path) -> tuple[str, ...]:
|
||||
"""Build a command that executes a SQL file against a remote D1 database."""
|
||||
|
||||
return build_wrangler_command(
|
||||
"d1",
|
||||
"execute",
|
||||
database_name,
|
||||
"--file",
|
||||
str(sql_file),
|
||||
"--remote",
|
||||
)
|
||||
|
||||
|
||||
def build_wrangler_pages_deploy_command(branch: str) -> tuple[str, ...]:
|
||||
"""Build a command that deploys a Pages project using the configured wrangler file."""
|
||||
|
||||
return build_wrangler_command("pages", "deploy", "--branch", branch)
|
||||
Reference in New Issue
Block a user