Initial import of cf-temp-email deploy CLI

This commit is contained in:
mmc
2026-03-26 08:06:02 +08:00
commit 4100e9cf72
29 changed files with 6703 additions and 0 deletions

View File

@@ -0,0 +1,6 @@
"""Cloudflare Temp Email automated deployment package."""
__all__ = ["__version__"]
__version__ = "0.1.0"

View File

@@ -0,0 +1,6 @@
from cf_temp_email_deploy.cli import main
if __name__ == "__main__":
raise SystemExit(main())

View File

@@ -0,0 +1,221 @@
"""Application admin API helpers."""
from __future__ import annotations
import time
from typing import Any
import httpx
from cf_temp_email_deploy.errors import ApplicationAPIError
LINUXDO_OAUTH2_ICON = (
'<svg viewBox="0 0 16 16" xmlns="http://www.w3.org/2000/svg" width="1em" height="1em">'
'<g><path d="m7.44,0s.09,0,.13,0c.09,0,.19,0,.28,0,.14,0,.29,0,.43,0,.09,0,.18,0,.27,0q.12,0,.25,0t.26.08c.15.03.29.06.44.08,1.97.38,3.78,1.47,4.95,3.11.04.06.09.12.13.18.67.96,1.15,2.11,1.3,3.28q0,.19.09.26c0,.15,0,.29,0,.44,0,.04,0,.09,0,.13,0,.09,0,.19,0,.28,0,.14,0,.29,0,.43,0,.09,0,.18,0,.27,0,.08,0,.17,0,.25q0,.19-.08.26c-.03.15-.06.29-.08.44-.38,1.97-1.47,3.78-3.11,4.95-.06.04-.12.09-.18.13-.96.67-2.11,1.15-3.28,1.3q-.19,0-.26.09c-.15,0-.29,0-.44,0-.04,0-.09,0-.13,0-.09,0-.19,0-.28,0-.14,0-.29,0-.43,0-.09,0-.18,0-.27,0-.08,0-.17,0-.25,0q-.19,0-.26-.08c-.15-.03-.29-.06-.44-.08-1.97-.38-3.78-1.47-4.95-3.11q-.07-.09-.13-.18c-.67-.96-1.15-2.11-1.3-3.28q0-.19-.09-.26c0-.15,0-.29,0-.44,0-.04,0-.09,0-.13,0-.09,0-.19,0-.28,0-.14,0-.29,0-.43,0-.09,0-.18,0-.27,0-.08,0-.17,0-.25q0-.19.08-.26c.03-.15.06-.29.08-.44.38-1.97,1.47-3.78,3.11-4.95.06-.04.12-.09.18-.13C4.42.73,5.57.26,6.74.1,7,.07,7.15,0,7.44,0Z" fill="#EFEFEF"/>'
'<path d="m1.27,11.33h13.45c-.94,1.89-2.51,3.21-4.51,3.88-1.99.59-3.96.37-5.8-.57-1.25-.7-2.67-1.9-3.14-3.3Z" fill="#FEB005"/>'
'<path d="m12.54,1.99c.87.7,1.82,1.59,2.18,2.68H1.27c.87-1.74,2.33-3.13,4.2-3.78,2.44-.79,5-.47,7.07,1.1Z" fill="#1D1D1F"/></g></svg>'
)
LINUXDO_OAUTH2_NAME = "LINUX DO"
LINUXDO_AUTHORIZATION_URL = "https://connect.linux.do/oauth2/authorize"
LINUXDO_ACCESS_TOKEN_URL = "https://connect.linux.do/oauth2/token"
LINUXDO_USER_INFO_URL = "https://connect.linux.do/api/user"
LINUXDO_SCOPE = "user:email"
USER_SETTINGS_DEFAULTS = {
"enableMailVerify": False,
"verifyMailSender": "",
"enableMailAllowList": False,
"mailAllowList": [],
"maxAddressCount": 5,
"enableEmailCheckRegex": False,
"emailCheckRegex": "",
}
def linuxdo_oauth_callback_url(pages_domain: str) -> str:
return f"https://{pages_domain}/user/oauth2/callback"
def merge_user_settings(current: object, allow_user_register: bool) -> dict[str, Any]:
merged = dict(current) if isinstance(current, dict) else {}
for key, value in USER_SETTINGS_DEFAULTS.items():
merged.setdefault(key, value)
merged["enable"] = allow_user_register
return merged
def build_linuxdo_oauth2_setting(
*,
pages_domain: str,
client_id: str,
client_secret: str,
) -> dict[str, Any]:
return {
"name": LINUXDO_OAUTH2_NAME,
"icon": LINUXDO_OAUTH2_ICON,
"clientID": client_id,
"clientSecret": client_secret,
"authorizationURL": LINUXDO_AUTHORIZATION_URL,
"accessTokenURL": LINUXDO_ACCESS_TOKEN_URL,
"accessTokenFormat": "urlencoded",
"userInfoURL": LINUXDO_USER_INFO_URL,
"redirectURL": linuxdo_oauth_callback_url(pages_domain),
"logoutURL": "",
"userEmailKey": "id",
"enableEmailFormat": True,
"userEmailFormat": "^(.+)$",
"userEmailReplace": "linux_do_$1@oauth.linux.do",
"scope": LINUXDO_SCOPE,
"enableMailAllowList": False,
"mailAllowList": [],
}
def is_linuxdo_oauth2_setting(setting: object) -> bool:
if not isinstance(setting, dict):
return False
name = str(setting.get("name", "")).strip().lower()
authorization_url = str(setting.get("authorizationURL", "")).strip()
access_token_url = str(setting.get("accessTokenURL", "")).strip()
user_info_url = str(setting.get("userInfoURL", "")).strip()
return name in {"linux do", "linuxdo"} or (
authorization_url == LINUXDO_AUTHORIZATION_URL
and access_token_url == LINUXDO_ACCESS_TOKEN_URL
and user_info_url == LINUXDO_USER_INFO_URL
)
def merge_linuxdo_oauth2_settings(
current: object,
*,
pages_domain: str,
client_id: str,
client_secret: str,
) -> list[dict[str, Any]]:
settings = current if isinstance(current, list) else []
desired = build_linuxdo_oauth2_setting(
pages_domain=pages_domain,
client_id=client_id,
client_secret=client_secret,
)
updated: list[dict[str, Any]] = []
replaced = False
for item in settings:
if not isinstance(item, dict):
continue
if is_linuxdo_oauth2_setting(item):
merged = dict(item)
merged.update(desired)
for key in ("enableMailAllowList", "mailAllowList", "logoutURL"):
if key in item:
merged[key] = item[key]
updated.append(merged)
replaced = True
continue
updated.append(item)
if not replaced:
updated.append(desired)
return updated
class ApplicationAdminClient:
"""Minimal client for post-deployment admin API configuration."""
def __init__(
self,
base_url: str,
admin_password: str,
*,
timeout: float = 30.0,
transport: httpx.BaseTransport | None = None,
) -> None:
normalized_base_url = base_url.rstrip("/")
self.base_url = normalized_base_url
self.client = httpx.Client(
base_url=normalized_base_url,
timeout=timeout,
follow_redirects=True,
trust_env=False,
transport=transport,
headers={"x-admin-auth": admin_password},
)
def close(self) -> None:
self.client.close()
def __enter__(self) -> "ApplicationAdminClient":
return self
def __exit__(self, exc_type: object, exc: object, tb: object) -> None:
self.close()
def wait_until_ready(
self,
*,
timeout_seconds: float = 180.0,
poll_interval_seconds: float = 5.0,
) -> dict[str, Any]:
deadline = time.monotonic() + timeout_seconds
last_error: Exception | None = None
while True:
try:
return self.get_user_settings()
except (ApplicationAPIError, httpx.HTTPError) as exc:
last_error = exc
if time.monotonic() >= deadline:
raise ApplicationAPIError(
f"管理员接口在限定时间内未就绪: {self.base_url}/admin/user_settings"
) from last_error
time.sleep(poll_interval_seconds)
def get_user_settings(self) -> dict[str, Any]:
payload = self._request_json("GET", "/admin/user_settings")
return payload if isinstance(payload, dict) else {}
def sync_user_settings(self, *, allow_user_register: bool) -> dict[str, Any]:
merged = merge_user_settings(self.get_user_settings(), allow_user_register)
payload = self._request_json("POST", "/admin/user_settings", json=merged)
return payload if isinstance(payload, dict) else merged
def get_user_oauth2_settings(self) -> list[dict[str, Any]]:
payload = self._request_json("GET", "/admin/user_oauth2_settings")
if not isinstance(payload, list):
return []
return [item for item in payload if isinstance(item, dict)]
def sync_linuxdo_oauth2(
self,
*,
pages_domain: str,
client_id: str,
client_secret: str,
) -> list[dict[str, Any]]:
merged = merge_linuxdo_oauth2_settings(
self.get_user_oauth2_settings(),
pages_domain=pages_domain,
client_id=client_id,
client_secret=client_secret,
)
payload = self._request_json("POST", "/admin/user_oauth2_settings", json=merged)
if not isinstance(payload, list):
return merged
return [item for item in payload if isinstance(item, dict)]
def _request_json(self, method: str, path: str, *, json: Any | None = None) -> Any:
try:
response = self.client.request(method, path, json=json)
except httpx.HTTPError as exc: # pragma: no cover
raise ApplicationAPIError(f"请求管理员接口失败: {self.base_url}{path}") from exc
if response.status_code >= 400:
body = response.text.strip()
raise ApplicationAPIError(
f"管理员接口返回错误: {method.upper()} {path} -> {response.status_code} {body}".strip(),
status_code=response.status_code,
)
try:
return response.json()
except ValueError as exc:
raise ApplicationAPIError(f"管理员接口返回了非法 JSON: {method.upper()} {path}") from exc

View File

@@ -0,0 +1,264 @@
"""Command line entrypoint."""
from __future__ import annotations
import argparse
import sys
from pathlib import Path
from typing import Any
from cf_temp_email_deploy.cloudflare import CloudflareClient
from cf_temp_email_deploy.config import (
apply_overrides,
config_from_document,
default_config_document,
load_config,
load_state,
load_toml_document,
parse_toml_value,
save_toml_document,
save_state,
)
from cf_temp_email_deploy.discovery import discover_config
from cf_temp_email_deploy.deployment import run_deployment
from cf_temp_email_deploy.environment import check_required_tools
from cf_temp_email_deploy.errors import ConfigError, DeployError
from cf_temp_email_deploy.logging_utils import configure_logging, get_logger, log_stage
from cf_temp_email_deploy.subprocess_runner import CommandRunner
COMMAND_NAMES = {"init-config", "check", "deploy", "resume", "discover-config"}
def build_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(prog="cf-temp-email")
subparsers = parser.add_subparsers(dest="command", required=True)
init_parser = subparsers.add_parser("init-config", help="生成示例配置文件。")
init_parser.add_argument("--config", type=Path, default=Path("config.toml"))
init_parser.add_argument("--force", action="store_true")
init_parser.set_defaults(handler=handle_init_config)
for command_name, help_text, handler in (
("check", "执行环境与配置检查。", handle_check),
("deploy", "执行部署流程。", handle_deploy),
("resume", "从状态文件恢复部署流程。", handle_resume),
("discover-config", "从 Cloudflare 现有资源反推并写入配置。", handle_discover_config),
):
command_parser = subparsers.add_parser(command_name, help=help_text)
add_shared_arguments(command_parser)
command_parser.set_defaults(handler=handler)
return parser
def add_shared_arguments(parser: argparse.ArgumentParser) -> None:
parser.add_argument("--config", type=Path, default=Path("config.toml"))
parser.add_argument("--profile")
parser.add_argument("--verbose", action="store_true")
parser.add_argument("--api-token")
parser.add_argument("--account-id")
parser.add_argument("--account-name")
parser.add_argument("--zone-name")
parser.add_argument("--repo-ref")
parser.add_argument("--pages-domain")
parser.add_argument("--worker-name")
parser.add_argument("--d1-name")
parser.add_argument("--destination-address")
parser.add_argument("--set", dest="set_values", action="append", default=[])
def _qualify_override_key(dotted_key: str, profile: str | None = None) -> str:
if not profile:
return dotted_key
return f"profiles.{profile}.{dotted_key}"
def collect_cli_overrides(args: argparse.Namespace) -> dict[str, Any]:
overrides: dict[str, Any] = {}
profile = getattr(args, "profile", None)
explicit_mapping = {
"api_token": "cloudflare.api_token",
"account_id": "cloudflare.account_id",
"account_name": "cloudflare.account_name",
"zone_name": "cloudflare.zone_name",
"repo_ref": "source.repo_ref",
"pages_domain": "pages.custom_domain",
"worker_name": "worker.script_name",
"d1_name": "d1.database_name",
"destination_address": "mail.verified_destination_address",
}
for argument_name, dotted_key in explicit_mapping.items():
value = getattr(args, argument_name, None)
if value not in (None, ""):
overrides[_qualify_override_key(dotted_key, profile)] = value
for raw_item in getattr(args, "set_values", []):
if "=" not in raw_item:
raise ConfigError(f"--set 参数缺少等号: {raw_item}")
dotted_key, raw_value = raw_item.split("=", 1)
overrides[_qualify_override_key(dotted_key, profile)] = parse_toml_value(raw_value)
return overrides
def apply_runtime_overrides(config_path: Path, args: argparse.Namespace) -> None:
overrides = collect_cli_overrides(args)
if not overrides:
return
document = load_toml_document(config_path)
apply_overrides(document, overrides)
save_toml_document(config_path, document)
def resolve_state_path(config_path: Path, profile: str | None = None) -> Path:
deploy_dir = config_path.parent / ".deploy"
if profile:
return deploy_dir / "profiles" / profile / "state.toml"
return deploy_dir / "state.toml"
def apply_profile_runtime_defaults(config: Any, profile: str | None = None) -> None:
if not profile:
return
if config.source.mode != "clone":
return
if config.source.workspace_dir != ".deploy/workspace":
return
config.source.workspace_dir = str(Path(".deploy") / "profiles" / profile / "workspace")
def handle_init_config(args: argparse.Namespace) -> int:
configure_logging()
config_path = args.config
if config_path.exists() and not args.force:
raise ConfigError(f"配置文件已存在: {config_path}")
save_toml_document(config_path, default_config_document())
get_logger("cli").info("已生成配置文件: %s", config_path)
return 0
def handle_check(args: argparse.Namespace) -> int:
configure_logging(args.verbose)
apply_runtime_overrides(args.config, args)
config = load_config(args.config, profile=args.profile)
apply_profile_runtime_defaults(config, args.profile)
state_path = resolve_state_path(args.config, args.profile)
state = load_state(state_path)
log_stage("开始检查本地环境。")
runner = CommandRunner()
versions = check_required_tools(runner)
if config.source.mode == "local":
source_path = Path(config.source.local_path)
if not source_path.exists():
raise ConfigError(f"本地源码目录不存在: {source_path}")
if not config.cloudflare.resolved_api_token():
raise ConfigError("cloudflare.api_token 或 cloudflare.api_token_env 需要提供其一。")
with CloudflareClient(config.cloudflare) as cloudflare:
token_info = cloudflare.verify_token()
account: dict[str, str] = {}
if config.cloudflare.account_id or config.cloudflare.account_name:
account = cloudflare.resolve_account(
account_id=config.cloudflare.account_id,
account_name=config.cloudflare.account_name,
)
zone = cloudflare.resolve_zone(
zone_name=config.cloudflare.zone_name,
account_id=account.get("id", ""),
)
if not account:
zone_account = zone.get("account")
if isinstance(zone_account, dict):
account = {
"id": str(zone_account.get("id", "")),
"name": str(zone_account.get("name", "")),
}
for tool_name, version in versions.items():
get_logger("check").info("[tool] %s=%s", tool_name, version.version)
state.cloudflare.account_id = account.get("id", "")
state.cloudflare.account_name = account.get("name", "")
state.cloudflare.zone_id = zone.get("id", "")
state.cloudflare.zone_name = zone.get("name", config.cloudflare.zone_name)
state.mark_checkpoint("check_completed")
save_state(state_path, state)
get_logger("check").info("[cloudflare] token_status=%s", token_info.get("status", "unknown"))
get_logger("check").info("[cloudflare] account_id=%s", state.cloudflare.account_id or "<provided>")
get_logger("check").info("[cloudflare] zone_id=%s", state.cloudflare.zone_id)
get_logger("check").info("[config] worker_vars=%s", sorted(config.derived_worker_vars().keys()))
log_stage("本地环境检查完成。")
return 0
def handle_discover_config(args: argparse.Namespace) -> int:
configure_logging(args.verbose)
discover_config(
config_path=args.config,
profile=args.profile,
cli_overrides=collect_cli_overrides(args),
cloudflare_client_factory=CloudflareClient,
)
get_logger("discover").info("已写入发现结果: %s", args.config)
return 0
def handle_deploy(args: argparse.Namespace) -> int:
return _handle_deployment_command(args, is_resume=False)
def handle_resume(args: argparse.Namespace) -> int:
return _handle_deployment_command(args, is_resume=True)
def _handle_deployment_command(args: argparse.Namespace, *, is_resume: bool) -> int:
configure_logging(args.verbose)
apply_runtime_overrides(args.config, args)
config = load_config(args.config, profile=args.profile)
apply_profile_runtime_defaults(config, args.profile)
state_path = resolve_state_path(args.config, args.profile)
state = load_state(state_path)
if is_resume and not state.checkpoint:
raise DeployError("状态文件中缺少检查点,无法继续恢复部署。")
runner = CommandRunner()
run_deployment(
config_path=args.config,
config=config,
state_path=state_path,
state=state,
runner=runner,
cloudflare_client_factory=CloudflareClient,
is_resume=is_resume,
)
save_state(state_path, state)
return 0
def resolve_cli_argv(argv: list[str] | None) -> list[str]:
raw_args = list(sys.argv[1:] if argv is None else argv)
if not raw_args:
return ["deploy"]
if raw_args[0] in COMMAND_NAMES:
return raw_args
if raw_args[0] in {"-h", "--help"}:
return raw_args
if raw_args[0].startswith("-"):
return ["deploy", *raw_args]
return raw_args
def main(argv: list[str] | None = None) -> int:
parser = build_parser()
args = parser.parse_args(resolve_cli_argv(argv))
try:
return args.handler(args)
except DeployError as exc:
configure_logging(getattr(args, "verbose", False))
get_logger("error").error("%s", exc)
return 1

View File

@@ -0,0 +1,727 @@
"""Cloudflare API client helpers."""
from __future__ import annotations
import time
from typing import Any, Literal
import httpx
from cf_temp_email_deploy import __version__
from cf_temp_email_deploy.errors import CloudflareAPIError, ConfigError
from cf_temp_email_deploy.models import CloudflareConfig
AuthMode = Literal["token", "global_key"]
RETRYABLE_STATUS_CODES = {429, 500, 502, 503, 504}
PAGES_ACTIVE_STATUSES = {"active", "verified"}
EMAIL_ROUTING_READY_STATUSES = {"active", "enabled", "verified", "success"}
class CloudflareClient:
"""Minimal Cloudflare API client used by deployment commands."""
def __init__(
self,
config: CloudflareConfig,
*,
timeout: float = 30.0,
max_attempts: int = 3,
transport: httpx.BaseTransport | None = None,
) -> None:
self.config = config
self.max_attempts = max_attempts
self.client = httpx.Client(
base_url="https://api.cloudflare.com/client/v4",
timeout=timeout,
transport=transport,
trust_env=False,
headers={"User-Agent": f"cf-temp-email/{__version__}"},
)
def close(self) -> None:
self.client.close()
def __enter__(self) -> "CloudflareClient":
return self
def __exit__(self, exc_type: object, exc: object, tb: object) -> None:
self.close()
def verify_token(self) -> dict[str, Any]:
payload = self.request("GET", "/user/tokens/verify", auth_mode="token")
result = payload.get("result")
if not isinstance(result, dict):
raise CloudflareAPIError("Token 校验返回结果格式异常。")
return result
def resolve_account(self, *, account_id: str = "", account_name: str = "") -> dict[str, Any]:
if account_id:
return {"id": account_id, "name": account_name}
if not account_name:
raise ConfigError("cloudflare.account_id 与 cloudflare.account_name 至少需要提供其一。")
accounts = self.list_paginated(
"/accounts",
params={"name": account_name},
auth_mode="global_key",
)
matches = [item for item in accounts if item.get("name") == account_name]
if not matches:
raise CloudflareAPIError(f"未找到匹配的 Cloudflare 账户: {account_name}")
if len(matches) > 1:
raise CloudflareAPIError(f"存在多个同名 Cloudflare 账户: {account_name}")
return matches[0]
def resolve_zone(self, *, zone_name: str, account_id: str = "") -> dict[str, Any]:
params: dict[str, Any] = {"name": zone_name}
if account_id:
params["account.id"] = account_id
zones = self.list_paginated("/zones", params=params, auth_mode="token")
matches = [item for item in zones if item.get("name") == zone_name]
if not matches:
raise CloudflareAPIError(f"未找到匹配的 Zone: {zone_name}")
if len(matches) > 1:
raise CloudflareAPIError(f"匹配到多个同名 Zone: {zone_name}")
return matches[0]
def get_pages_project(self, *, account_id: str, project_name: str) -> dict[str, Any] | None:
return self._request_result_or_none(
"GET",
f"/accounts/{account_id}/pages/projects/{project_name}",
auth_mode="token",
)
def list_pages_projects(self, *, account_id: str) -> list[dict[str, Any]]:
payload = self.request(
"GET",
f"/accounts/{account_id}/pages/projects",
auth_mode="token",
)
result = payload.get("result")
if not isinstance(result, list):
raise CloudflareAPIError("Pages 项目列表返回结果格式异常。")
return [item for item in result if isinstance(item, dict)]
def ensure_pages_project(
self,
*,
account_id: str,
project_name: str,
production_branch: str,
) -> dict[str, Any]:
existing = self.get_pages_project(account_id=account_id, project_name=project_name)
if existing is not None:
return existing
try:
payload = self.request(
"POST",
f"/accounts/{account_id}/pages/projects",
json={"name": project_name, "production_branch": production_branch},
auth_mode="token",
)
except CloudflareAPIError as exc:
if "already exists" not in str(exc).lower():
raise
existing = self.get_pages_project(account_id=account_id, project_name=project_name)
if existing is not None:
return existing
raise
return self._result_dict(payload, "Pages 项目创建")
def get_pages_domain(
self,
*,
account_id: str,
project_name: str,
domain_name: str,
) -> dict[str, Any] | None:
return self._request_result_or_none(
"GET",
f"/accounts/{account_id}/pages/projects/{project_name}/domains/{domain_name}",
auth_mode="token",
)
def ensure_pages_domain(
self,
*,
account_id: str,
project_name: str,
domain_name: str,
) -> dict[str, Any]:
existing = self.get_pages_domain(
account_id=account_id,
project_name=project_name,
domain_name=domain_name,
)
if existing is not None:
return existing
try:
payload = self.request(
"POST",
f"/accounts/{account_id}/pages/projects/{project_name}/domains",
json={"name": domain_name},
auth_mode="token",
)
except CloudflareAPIError as exc:
if "already exists" not in str(exc).lower():
raise
existing = self.get_pages_domain(
account_id=account_id,
project_name=project_name,
domain_name=domain_name,
)
if existing is not None:
return existing
raise
return self._result_dict(payload, "Pages 域名绑定")
def wait_for_pages_domain_active(
self,
*,
account_id: str,
project_name: str,
domain_name: str,
timeout_seconds: float = 300.0,
poll_interval_seconds: float = 5.0,
) -> dict[str, Any]:
deadline = time.monotonic() + timeout_seconds
while True:
domain = self.get_pages_domain(
account_id=account_id,
project_name=project_name,
domain_name=domain_name,
)
if domain is None:
raise CloudflareAPIError(f"Pages 自定义域名不存在: {domain_name}")
status = str(domain.get("status", "")).lower()
if status in PAGES_ACTIVE_STATUSES:
return domain
if time.monotonic() >= deadline:
raise CloudflareAPIError(f"等待 Pages 自定义域名激活超时: {domain_name}")
time.sleep(poll_interval_seconds)
def list_dns_records(
self,
zone_id: str,
*,
name: str = "",
record_type: str = "",
) -> list[dict[str, Any]]:
params: dict[str, Any] = {}
if name:
params["name"] = name
if record_type:
params["type"] = record_type
return self.list_paginated(
f"/zones/{zone_id}/dns_records",
params=params,
auth_mode="token",
)
def ensure_cname_record(
self,
*,
zone_id: str,
name: str,
content: str,
proxied: bool = True,
ttl: int = 1,
) -> dict[str, Any]:
records = [record for record in self.list_dns_records(zone_id, name=name) if record.get("name") == name]
conflicts = [record for record in records if record.get("type") != "CNAME"]
if conflicts:
raise CloudflareAPIError(f"DNS 记录冲突: {name} 已存在非 CNAME 记录。")
cname_records = [record for record in records if record.get("type") == "CNAME"]
if len(cname_records) > 1:
raise CloudflareAPIError(f"DNS 记录冲突: {name} 存在多条 CNAME 记录。")
if not cname_records:
return self._create_dns_record(
zone_id=zone_id,
payload={"type": "CNAME", "name": name, "content": content, "proxied": proxied, "ttl": ttl},
)
record = cname_records[0]
if (
record.get("content") == content
and bool(record.get("proxied", proxied)) == proxied
and int(record.get("ttl", ttl)) == ttl
):
return record
return self._update_dns_record(
zone_id=zone_id,
record_id=str(record.get("id", "")),
payload={"type": "CNAME", "name": name, "content": content, "proxied": proxied, "ttl": ttl},
)
def ensure_dns_record(
self,
*,
zone_id: str,
record_type: str,
name: str,
content: str,
proxied: bool | None = None,
ttl: int | None = 1,
priority: int | None = None,
) -> dict[str, Any]:
records = [
record
for record in self.list_dns_records(zone_id, name=name, record_type=record_type)
if record.get("name") == name and record.get("type") == record_type
]
exact = next(
(
record
for record in records
if self._dns_record_matches(
record,
content=content,
proxied=proxied,
ttl=ttl,
priority=priority,
)
),
None,
)
if exact is not None:
return exact
payload: dict[str, Any] = {"type": record_type, "name": name, "content": content}
if proxied is not None:
payload["proxied"] = proxied
if ttl is not None:
payload["ttl"] = ttl
if priority is not None:
payload["priority"] = priority
if len(records) == 1 and record_type in {"TXT", "CNAME"}:
return self._update_dns_record(
zone_id=zone_id,
record_id=str(records[0].get("id", "")),
payload=payload,
)
return self._create_dns_record(zone_id=zone_id, payload=payload)
def list_d1_databases(self, *, account_id: str, database_name: str = "") -> list[dict[str, Any]]:
databases = self.list_paginated(
f"/accounts/{account_id}/d1/database",
auth_mode="token",
)
if not database_name:
return databases
return [item for item in databases if item.get("name") == database_name]
def ensure_d1_database(
self,
*,
account_id: str,
database_name: str,
jurisdiction: str = "",
) -> dict[str, Any]:
existing = self.list_d1_databases(account_id=account_id, database_name=database_name)
if existing:
if len(existing) > 1:
raise CloudflareAPIError(f"存在多个同名 D1 数据库: {database_name}")
return existing[0]
payload_data: dict[str, Any] = {"name": database_name}
if jurisdiction:
payload_data["primary_location_hint"] = jurisdiction
payload = self.request(
"POST",
f"/accounts/{account_id}/d1/database",
json=payload_data,
auth_mode="token",
)
return self._result_dict(payload, "D1 数据库创建")
def query_d1(
self,
*,
account_id: str,
database_id: str,
sql: str,
params: list[Any] | None = None,
) -> list[dict[str, Any]]:
payload = self.request(
"POST",
f"/accounts/{account_id}/d1/database/{database_id}/query",
json={"sql": sql, "params": params or []},
auth_mode="token",
)
result = payload.get("result")
if not isinstance(result, list):
raise CloudflareAPIError("D1 查询返回结果格式异常。")
if not result:
return []
first = result[0]
if not isinstance(first, dict):
raise CloudflareAPIError("D1 查询结果项格式异常。")
rows = first.get("results", [])
if not isinstance(rows, list):
raise CloudflareAPIError("D1 查询结果行格式异常。")
return [row for row in rows if isinstance(row, dict)]
def get_workers_subdomain(self, *, account_id: str) -> dict[str, Any] | None:
return self._request_result_or_none(
"GET",
f"/accounts/{account_id}/workers/subdomain",
auth_mode="token",
)
def get_worker_script(self, *, account_id: str, script_name: str) -> dict[str, Any] | None:
scripts = self.list_paginated(
f"/accounts/{account_id}/workers/scripts",
auth_mode="token",
)
for script in scripts:
if str(script.get("id", "")) == script_name:
return script
return None
def list_worker_scripts(self, *, account_id: str) -> list[dict[str, Any]]:
return self.list_paginated(
f"/accounts/{account_id}/workers/scripts",
auth_mode="token",
)
@staticmethod
def worker_script_supports_email(script: dict[str, Any]) -> bool:
handlers = script.get("handlers")
if not isinstance(handlers, list):
return False
return "email" in {str(handler).strip().lower() for handler in handlers}
def list_email_routing_addresses(self, *, account_id: str) -> list[dict[str, Any]]:
return self._list_email_routing_paginated(f"/accounts/{account_id}/email/routing/addresses")
def get_email_routing_dns(self, *, zone_id: str) -> list[dict[str, Any]]:
payload = self.request_email_routing("GET", f"/zones/{zone_id}/email/routing/dns")
result = payload.get("result")
if isinstance(result, list):
return [item for item in result if isinstance(item, dict)]
if isinstance(result, dict):
records = result.get("records")
if isinstance(records, list):
return [item for item in records if isinstance(item, dict)]
raise CloudflareAPIError("Email Routing DNS 返回结果格式异常。")
def get_catch_all(self, *, zone_id: str) -> dict[str, Any] | None:
return self._request_email_routing_result_or_none(
"GET",
f"/zones/{zone_id}/email/routing/rules/catch_all",
)
def ensure_catch_all_worker(self, *, zone_id: str, script_name: str) -> dict[str, Any]:
current = self.get_catch_all(zone_id=zone_id)
if current is not None and self.catch_all_points_to_worker(current, script_name):
return current
payload = self.request_email_routing(
"PUT",
f"/zones/{zone_id}/email/routing/rules/catch_all",
json={
"matchers": [{"type": "all"}],
"actions": [{"type": "worker", "value": [script_name]}],
"enabled": True,
},
)
return self._result_dict(payload, "Catch-all 更新")
@staticmethod
def catch_all_points_to_worker(rule: dict[str, Any], script_name: str) -> bool:
actions = rule.get("actions")
if not isinstance(actions, list):
return False
for action in actions:
if not isinstance(action, dict):
continue
if action.get("type") != "worker":
continue
if script_name in CloudflareClient.extract_worker_targets(action.get("value")):
return True
return False
@staticmethod
def email_address_ready(address: dict[str, Any], target: str) -> bool:
if str(address.get("email", "")).lower() != target.lower():
return False
status = str(address.get("status", "")).lower()
return status in EMAIL_ROUTING_READY_STATUSES
@staticmethod
def extract_worker_targets(value: Any) -> set[str]:
stack = [value]
targets: set[str] = set()
while stack:
current = stack.pop()
if current is None:
continue
if isinstance(current, str):
candidate = current.strip()
if candidate:
targets.add(candidate)
continue
if isinstance(current, list):
stack.extend(current)
continue
if isinstance(current, dict):
for key in ("worker", "name", "script", "service", "value"):
if key in current:
stack.append(current[key])
return targets
@staticmethod
def is_authentication_error(error: Exception) -> bool:
if isinstance(error, ConfigError):
message = str(error).lower()
return "旧式鉴权" in str(error) or "api_email" in message or "global_api_key" in message
if not isinstance(error, CloudflareAPIError):
return False
return CloudflareClient._should_fallback_to_global_key(error)
def request_email_routing(
self,
method: str,
path: str,
*,
params: dict[str, Any] | None = None,
json: dict[str, Any] | None = None,
) -> dict[str, Any]:
try:
return self.request(method, path, params=params, json=json, auth_mode="token")
except CloudflareAPIError as exc:
if not self._should_fallback_to_global_key(exc):
raise
return self.request(method, path, params=params, json=json, auth_mode="global_key")
def list_paginated(
self,
path: str,
*,
params: dict[str, Any] | None = None,
auth_mode: AuthMode,
) -> list[dict[str, Any]]:
collected: list[dict[str, Any]] = []
page = 1
per_page = 50
while True:
current_params = dict(params or {})
current_params.setdefault("page", page)
current_params.setdefault("per_page", per_page)
payload = self.request("GET", path, params=current_params, auth_mode=auth_mode)
result = payload.get("result")
if not isinstance(result, list):
raise CloudflareAPIError(f"分页接口返回结果格式异常: {path}")
collected.extend(result)
result_info = payload.get("result_info") or {}
total_pages = result_info.get("total_pages")
if not total_pages or page >= total_pages:
break
page += 1
return collected
def request(
self,
method: str,
path: str,
*,
params: dict[str, Any] | None = None,
json: dict[str, Any] | None = None,
auth_mode: AuthMode,
) -> dict[str, Any]:
last_error: CloudflareAPIError | None = None
for attempt in range(1, self.max_attempts + 1):
try:
response = self.client.request(
method,
path,
params=params,
json=json,
headers=self._headers(auth_mode),
)
except httpx.RequestError as exc:
last_error = CloudflareAPIError(f"Cloudflare 请求失败: {exc}") # pragma: no cover
if attempt < self.max_attempts:
time.sleep(0.2 * attempt)
continue
raise last_error from exc
if response.status_code in RETRYABLE_STATUS_CODES and attempt < self.max_attempts:
time.sleep(0.2 * attempt)
continue
payload = self._parse_payload(response)
if response.status_code >= 400 or payload.get("success") is False:
error = CloudflareAPIError(
self._build_error_message(payload, response),
status_code=response.status_code,
)
if response.status_code in RETRYABLE_STATUS_CODES and attempt < self.max_attempts:
last_error = error
time.sleep(0.2 * attempt)
continue
raise error
return payload
if last_error is not None:
raise last_error
raise CloudflareAPIError("Cloudflare 请求失败,且没有返回可用结果。")
def _request_result_or_none(
self,
method: str,
path: str,
*,
params: dict[str, Any] | None = None,
json: dict[str, Any] | None = None,
auth_mode: AuthMode,
) -> dict[str, Any] | None:
try:
payload = self.request(method, path, params=params, json=json, auth_mode=auth_mode)
except CloudflareAPIError as exc:
if exc.status_code == 404:
return None
raise
return self._result_dict(payload, path)
def _request_email_routing_result_or_none(
self,
method: str,
path: str,
*,
params: dict[str, Any] | None = None,
json: dict[str, Any] | None = None,
) -> dict[str, Any] | None:
try:
payload = self.request_email_routing(method, path, params=params, json=json)
except CloudflareAPIError as exc:
if exc.status_code == 404:
return None
raise
return self._result_dict(payload, path)
def _list_email_routing_paginated(self, path: str) -> list[dict[str, Any]]:
collected: list[dict[str, Any]] = []
page = 1
per_page = 50
while True:
payload = self.request_email_routing(
"GET",
path,
params={"page": page, "per_page": per_page},
)
result = payload.get("result")
if not isinstance(result, list):
raise CloudflareAPIError(f"分页接口返回结果格式异常: {path}")
collected.extend(item for item in result if isinstance(item, dict))
result_info = payload.get("result_info") or {}
total_pages = result_info.get("total_pages")
if not total_pages or page >= total_pages:
break
page += 1
return collected
def _create_dns_record(self, *, zone_id: str, payload: dict[str, Any]) -> dict[str, Any]:
response = self.request(
"POST",
f"/zones/{zone_id}/dns_records",
json=payload,
auth_mode="token",
)
return self._result_dict(response, "DNS 记录创建")
def _update_dns_record(self, *, zone_id: str, record_id: str, payload: dict[str, Any]) -> dict[str, Any]:
response = self.request(
"PUT",
f"/zones/{zone_id}/dns_records/{record_id}",
json=payload,
auth_mode="token",
)
return self._result_dict(response, "DNS 记录更新")
@staticmethod
def _dns_record_matches(
record: dict[str, Any],
*,
content: str,
proxied: bool | None,
ttl: int | None,
priority: int | None,
) -> bool:
if record.get("content") != content:
return False
if proxied is not None and bool(record.get("proxied", proxied)) != proxied:
return False
if ttl is not None:
try:
if int(record.get("ttl", ttl)) != ttl:
return False
except (TypeError, ValueError):
return False
if priority is not None:
try:
if int(record.get("priority", priority)) != priority:
return False
except (TypeError, ValueError):
return False
return True
@staticmethod
def _result_dict(payload: dict[str, Any], action: str) -> dict[str, Any]:
result = payload.get("result")
if not isinstance(result, dict):
raise CloudflareAPIError(f"{action} 返回结果格式异常。")
return result
def _headers(self, auth_mode: AuthMode) -> dict[str, str]:
if auth_mode == "token":
token = self.config.resolved_api_token()
if not token:
raise ConfigError("缺少 Cloudflare API Token。")
return {"Authorization": f"Bearer {token}"}
email = self.config.resolved_api_email()
api_key = self.config.resolved_global_api_key()
if not email or not api_key:
raise ConfigError("旧式鉴权需要同时提供 api_email 与 global_api_key。")
return {
"X-Auth-Email": email,
"X-Auth-Key": api_key,
}
@staticmethod
def _parse_payload(response: httpx.Response) -> dict[str, Any]:
try:
payload = response.json()
except ValueError as exc:
raise CloudflareAPIError("Cloudflare 返回内容不是合法 JSON。", status_code=response.status_code) from exc
if not isinstance(payload, dict):
raise CloudflareAPIError("Cloudflare 返回内容格式异常。", status_code=response.status_code)
return payload
@staticmethod
def _build_error_message(payload: dict[str, Any], response: httpx.Response) -> str:
errors = payload.get("errors") or []
fragments = [
error.get("message", str(error))
for error in errors
if isinstance(error, dict)
]
if not fragments and response.text:
fragments.append(response.text.strip())
message = "; ".join(fragment for fragment in fragments if fragment)
if not message:
message = f"HTTP {response.status_code}"
return f"Cloudflare API 返回错误: {message}"
@staticmethod
def _should_fallback_to_global_key(error: CloudflareAPIError) -> bool:
if error.status_code not in {400, 401, 403}:
return False
message = str(error).lower()
return any(
fragment in message
for fragment in ("x-auth-key", "x-auth-email", "global api key", "authentication")
)

View File

@@ -0,0 +1,278 @@
"""Configuration and state file utilities."""
from __future__ import annotations
from copy import deepcopy
import os
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Any, Mapping
import tomlkit
from tomlkit import TOMLDocument
from tomlkit.exceptions import ParseError
from tomlkit.items import AbstractTable
from cf_temp_email_deploy.errors import ConfigError
from cf_temp_email_deploy.models import DeploymentConfig, DeploymentState
DEFAULT_CONFIG_TOML = """# 必填参数一览
# [必填] cloudflare.zone_name
# [必填] cloudflare.api_token 或 cloudflare.api_token_env
# [必填] mail.domains
# [必填] pages.custom_domain
# [必填] worker.vars.ADMIN_PASSWORDS
# [条件必填] source.local_path当 source.mode = "local" 时填写
# [选填] source.repo_ref留空时自动跟随上游默认分支最新提交如需固定版本可填写 tag、branch 或 commit
# 主配置版本号,供 CLI 与状态加载逻辑识别。
config_version = 1
[source]
# [选填] "clone" 表示自动拉取远端仓库到 workspace_dir。
# [条件必填] "local" 表示直接使用 local_path 指向的本地源码目录。
mode = "clone"
# [选填] 固定源码仓库地址。
repo_url = "https://github.com/dreamhunter2333/cloudflare_temp_email.git"
# [选填] 留空时自动跟随上游默认分支最新提交。
repo_ref = ""
# [选填] 克隆源码、安装依赖、构建产物时使用的本地工作目录。
workspace_dir = ".deploy/workspace"
# [条件必填] 仅在 mode = "local" 时生效。
local_path = ""
[cloudflare]
# [选填] 已知 account_id 时优先填写该字段。
account_id = ""
# [选填] 只有在 account_id 未提供时才需要 account_name。
account_name = ""
# [必填] 承载前端域名、Worker 域名与邮件域名的 Zone。
zone_name = "example.com"
# [必填] 本工具默认使用的主 API Token。
api_token = ""
# [必填-二选一] 共享环境中可改为从环境变量读取 Token。
api_token_env = ""
# [选填] Email Routing 部分接口需要旧式鉴权时,可填写邮箱与 Global API Key。
api_email = ""
api_email_env = ""
global_api_key = ""
global_api_key_env = ""
[mail]
# [必填] 邮件接收域名,会同步写入 Email Routing 与 Worker 的 DOMAINS。
# [必填] 该主机名必须与 pages.custom_domain 分离。
domains = ["example.com"]
# [选填] 如需严格校验 Email Routing 目标地址,可填写一个已验证的真实邮箱。
# [选填] 留空时部署会跳过目标地址校验,仍继续配置 MX/SPF 与 Catch-all Worker。
verified_destination_address = "inbox@example.net"
[d1]
# [选填] 当前部署要创建或复用的远端 D1 数据库名称。
database_name = "cf-temp-email"
# [选填] 可选的 D1 地域提示;留空时使用默认位置。
jurisdiction = ""
# [选填] 仅在接管已有数据库且缺少 __deploy_history 时开启。
adopt_existing_schema = false
[user_access]
# [选填] 是否要求登录后才能创建邮箱;启用 Linux.do OAuth2 时会自动强制为 true。
require_login_to_create = true
# [选填] 是否允许用户自行注册。
allow_user_register = false
[linuxdo]
# [选填] 是否启用 LINUX DO OAuth2 登录。
linuxdo_oauth = false
# [条件必填] 当 linuxdo_oauth = true 时填写。
client_id = ""
# [条件必填] 当 linuxdo_oauth = true 时填写。
client_secret = ""
[worker]
# [选填] Worker 服务名称,会同时用于 wrangler、Pages 服务绑定与 Catch-all。
script_name = "cloudflare-temp-email"
# [选填] 保留 workers.dev 地址,用于调试与回退检查。
use_workers_dev = true
# [选填] 生产环境中的 Worker 自定义域名。
custom_domain = ""
# [选填] 写入 worker/wrangler.toml 的 compatibility_date。
compatibility_date = "2024-09-23"
[worker.vars]
# [选填] 写入 worker/wrangler.toml 的普通环境变量。
PREFIX = "tmp"
ENABLE_USER_CREATE_EMAIL = true
ENABLE_USER_DELETE_EMAIL = true
DEFAULT_LANG = "zh"
# [必填] 管理员密码列表至少保留一个值,部署后会使用首项调用管理员接口。
ADMIN_PASSWORDS = ["change-me"]
[worker.secrets]
# [选填] 标量秘密值通过 `wrangler secret put` 写入 Cloudflare。
# [选填] JWT_SECRET 留空时,会在首次部署 Worker 前自动生成安全随机值并写回 config.toml。
JWT_SECRET = ""
[pages]
# [选填] 目标账户中要创建或复用的 Pages 项目名称。
project_name = "cf-temp-email-pages"
# [必填] 前端访问域名,必须与 mail.domains 分离。
custom_domain = "email.example.com"
# [选填] "pages" 构建标准前端;"pages:nopwa" 关闭 PWA 产物。
build_mode = "pages"
production_branch = "production"
# [选填] 多账号/多环境场景可在 profiles 下定义命名配置。
# CLI 使用 `--profile <name>` 选择对应配置,未覆盖的字段会继承根配置。
#
# [profiles.prod-cn.cloudflare]
# account_id = "acc-prod-cn"
# zone_name = "kotei.asia"
# api_token_env = "CF_API_TOKEN_CN"
#
# [profiles.prod-cn.mail]
# domains = ["mail.kotei.asia"]
#
# [profiles.prod-cn.pages]
# custom_domain = "email.kotei.asia"
"""
def default_config_document() -> TOMLDocument:
"""Return the default deployment configuration document."""
return tomlkit.parse(DEFAULT_CONFIG_TOML)
def _ensure_parent_directory(path: Path) -> None:
path.parent.mkdir(parents=True, exist_ok=True)
def write_text_atomic(path: Path, text: str) -> None:
"""Write a text file atomically."""
_ensure_parent_directory(path)
with NamedTemporaryFile("w", encoding="utf-8", dir=path.parent, delete=False) as handle:
handle.write(text)
handle.flush()
os.fsync(handle.fileno())
temp_path = Path(handle.name)
temp_path.replace(path)
def load_toml_document(path: Path) -> TOMLDocument:
"""Load a TOML document from disk."""
try:
return tomlkit.parse(path.read_text(encoding="utf-8"))
except FileNotFoundError as exc:
raise ConfigError(f"配置文件不存在: {path}") from exc
except ParseError as exc:
raise ConfigError(f"TOML 解析失败: {path}") from exc
def save_toml_document(path: Path, document: TOMLDocument) -> None:
"""Persist a TOML document atomically."""
write_text_atomic(path, tomlkit.dumps(document))
def _deep_merge_dicts(base: Mapping[str, Any], override: Mapping[str, Any]) -> dict[str, Any]:
merged = deepcopy(dict(base))
for key, value in override.items():
current = merged.get(key)
if isinstance(current, dict) and isinstance(value, Mapping):
merged[key] = _deep_merge_dicts(current, value)
continue
merged[key] = deepcopy(value)
return merged
def _resolve_profile_payload(payload: Mapping[str, Any], profile: str | None = None) -> dict[str, Any]:
base_payload = {key: deepcopy(value) for key, value in payload.items() if key != "profiles"}
if not profile:
return base_payload
profiles = payload.get("profiles", {})
if not isinstance(profiles, Mapping):
raise ConfigError("配置校验失败: profiles 必须是表。")
selected = profiles.get(profile)
if selected is None:
raise ConfigError(f"配置校验失败: 未找到 profile: {profile}")
if not isinstance(selected, Mapping):
raise ConfigError(f"配置校验失败: profiles.{profile} 必须是表。")
return _deep_merge_dicts(base_payload, selected)
def load_config(path: Path, *, profile: str | None = None) -> DeploymentConfig:
"""Load and validate the deployment configuration."""
document = load_toml_document(path)
return config_from_document(document, profile=profile)
def config_from_document(document: TOMLDocument, *, profile: str | None = None) -> DeploymentConfig:
"""Build and validate a deployment configuration from a TOML document."""
try:
payload = _resolve_profile_payload(document.unwrap(), profile)
return DeploymentConfig.model_validate(payload)
except ValueError as exc:
raise ConfigError(f"配置校验失败: {exc}") from exc
def parse_toml_value(raw_value: str) -> Any:
"""Parse a CLI value using TOML literal syntax."""
snippet = f"value = {raw_value}\n"
try:
item = tomlkit.parse(snippet)["value"]
except ParseError:
return raw_value
return item.unwrap() if hasattr(item, "unwrap") else item
def set_dotted_value(document: TOMLDocument, dotted_key: str, value: Any) -> None:
"""Set a nested value inside a TOML document."""
parts = dotted_key.split(".")
if not all(parts):
raise ConfigError(f"非法配置键: {dotted_key}")
current: TOMLDocument | AbstractTable = document
for part in parts[:-1]:
if part not in current:
current[part] = tomlkit.table()
next_value = current[part]
if not isinstance(next_value, AbstractTable):
raise ConfigError(f"配置键 {dotted_key} 的父节点不是表: {part}")
current = next_value
current[parts[-1]] = tomlkit.item(value)
def apply_overrides(document: TOMLDocument, overrides: Mapping[str, Any]) -> TOMLDocument:
"""Apply CLI override values to a configuration document."""
for dotted_key, value in overrides.items():
set_dotted_value(document, dotted_key, value)
return document
def load_state(path: Path) -> DeploymentState:
"""Load deployment state or return the default state when absent."""
if not path.exists():
return DeploymentState()
try:
document = tomlkit.parse(path.read_text(encoding="utf-8"))
return DeploymentState.model_validate(document.unwrap())
except (ParseError, ValueError) as exc:
raise ConfigError(f"状态文件校验失败: {path}") from exc
def save_state(path: Path, state: DeploymentState) -> None:
"""Persist deployment state atomically."""
write_text_atomic(path, tomlkit.dumps(state.model_dump(mode="python")))

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,205 @@
"""Discovery helpers for importing existing Cloudflare resources into config files."""
from __future__ import annotations
from typing import Any
from cf_temp_email_deploy.cloudflare import CloudflareClient
from cf_temp_email_deploy.config import (
apply_overrides,
config_from_document,
default_config_document,
load_toml_document,
save_toml_document,
)
from cf_temp_email_deploy.errors import CloudflareAPIError, ConfigError
from cf_temp_email_deploy.logging_utils import get_logger
LOGGER = get_logger("discover")
def discover_config(
*,
config_path,
profile: str | None,
cli_overrides: dict[str, Any],
cloudflare_client_factory: type[CloudflareClient] = CloudflareClient,
) -> None:
document = load_toml_document(config_path) if config_path.exists() else default_config_document()
if cli_overrides:
apply_overrides(document, cli_overrides)
config = config_from_document(document, profile=profile)
if not config.cloudflare.resolved_api_token():
raise ConfigError("discover-config 需要 cloudflare.api_token 或 cloudflare.api_token_env。")
if not config.cloudflare.zone_name:
raise ConfigError("discover-config 需要 cloudflare.zone_name。")
with cloudflare_client_factory(config.cloudflare) as cloudflare:
discovered = _discover_overrides(cloudflare, config.cloudflare.zone_name, config.cloudflare.account_id)
target_overrides = {
_qualify_profile_key(profile, dotted_key): value for dotted_key, value in discovered.items()
}
apply_overrides(document, target_overrides)
save_toml_document(config_path, document)
def _qualify_profile_key(profile: str | None, dotted_key: str) -> str:
if not profile:
return dotted_key
return f"profiles.{profile}.{dotted_key}"
def _discover_overrides(
cloudflare: CloudflareClient,
zone_name: str,
configured_account_id: str = "",
) -> dict[str, Any]:
token_info = cloudflare.verify_token()
LOGGER.info("[discover] token_status=%s", token_info.get("status", "unknown"))
zone = cloudflare.resolve_zone(zone_name=zone_name, account_id=configured_account_id)
zone_account = zone.get("account") if isinstance(zone.get("account"), dict) else {}
account_id = configured_account_id or str(zone_account.get("id", ""))
account_name = str(zone_account.get("name", ""))
if not account_id:
raise CloudflareAPIError("无法从 Zone 信息中解析 account_id。")
overrides: dict[str, Any] = {
"cloudflare.zone_name": str(zone.get("name", zone_name)),
"cloudflare.account_id": account_id,
}
if account_name:
overrides["cloudflare.account_name"] = account_name
zone_id = str(zone.get("id", ""))
dns_records = cloudflare.list_dns_records(zone_id)
mail_domains = _discover_mail_domains(dns_records)
if mail_domains:
overrides["mail.domains"] = mail_domains
verified_destination = _discover_verified_destination(cloudflare, account_id)
if verified_destination:
overrides["mail.verified_destination_address"] = verified_destination
d1_name = _discover_d1_name(cloudflare, account_id)
if d1_name:
overrides["d1.database_name"] = d1_name
pages = _discover_pages(cloudflare, account_id, dns_records)
if pages.get("project_name"):
overrides["pages.project_name"] = pages["project_name"]
if pages.get("custom_domain"):
overrides["pages.custom_domain"] = pages["custom_domain"]
worker = _discover_worker(cloudflare, zone_id, account_id)
if worker.get("script_name"):
overrides["worker.script_name"] = worker["script_name"]
return overrides
def _discover_mail_domains(dns_records: list[dict[str, Any]]) -> list[str]:
domains = {
str(record.get("name", "")).strip()
for record in dns_records
if str(record.get("type", "")).upper() == "MX"
and "mx.cloudflare.net" in str(record.get("content", "")).lower()
and str(record.get("name", "")).strip()
}
return sorted(domains)
def _discover_verified_destination(cloudflare: CloudflareClient, account_id: str) -> str:
try:
addresses = cloudflare.list_email_routing_addresses(account_id=account_id)
except Exception as exc:
if cloudflare.is_authentication_error(exc):
LOGGER.warning("[discover] 当前鉴权无法读取 Email Routing 地址,跳过 verified_destination_address。")
return ""
raise
for address in addresses:
email = str(address.get("email", "")).strip()
if email and cloudflare.email_address_ready(address, email):
return email
return ""
def _discover_d1_name(cloudflare: CloudflareClient, account_id: str) -> str:
databases = cloudflare.list_d1_databases(account_id=account_id)
if len(databases) == 1:
return str(databases[0].get("name", "")).strip()
if len(databases) > 1:
LOGGER.warning("[discover] 检测到多个 D1 数据库,跳过自动写入 database_name。")
return ""
def _discover_pages(
cloudflare: CloudflareClient,
account_id: str,
dns_records: list[dict[str, Any]],
) -> dict[str, str]:
projects = cloudflare.list_pages_projects(account_id=account_id)
if not projects:
return {}
if len(projects) > 1:
LOGGER.warning("[discover] 检测到多个 Pages 项目,将尽量通过 DNS 反推。")
cname_by_content = {
str(record.get("content", "")).strip().lower(): str(record.get("name", "")).strip()
for record in dns_records
if str(record.get("type", "")).upper() == "CNAME"
}
matches: list[dict[str, str]] = []
for project in projects:
subdomain = str(project.get("subdomain", "")).strip().lower()
if not subdomain:
continue
custom_domain = cname_by_content.get(subdomain, "")
matches.append(
{
"project_name": str(project.get("name", "")).strip(),
"custom_domain": custom_domain,
}
)
exact = [item for item in matches if item["project_name"] and item["custom_domain"]]
if len(exact) == 1:
return exact[0]
if len(projects) == 1:
return {
"project_name": str(projects[0].get("name", "")).strip(),
"custom_domain": exact[0]["custom_domain"] if exact else "",
}
return {}
def _discover_worker(cloudflare: CloudflareClient, zone_id: str, account_id: str) -> dict[str, str]:
catch_all = None
try:
catch_all = cloudflare.get_catch_all(zone_id=zone_id)
except Exception as exc:
if not cloudflare.is_authentication_error(exc):
raise
LOGGER.warning("[discover] 当前鉴权无法读取 Catch-all尝试从 Worker 列表反推脚本。")
if isinstance(catch_all, dict):
actions = catch_all.get("actions")
if isinstance(actions, list):
for action in actions:
if not isinstance(action, dict):
continue
if action.get("type") != "worker":
continue
targets = sorted(cloudflare.extract_worker_targets(action.get("value")))
if targets:
return {"script_name": targets[0]}
scripts = cloudflare.list_worker_scripts(account_id=account_id)
email_scripts = [item for item in scripts if cloudflare.worker_script_supports_email(item)]
if len(email_scripts) == 1:
return {"script_name": str(email_scripts[0].get("id", "")).strip()}
if len(email_scripts) > 1:
LOGGER.warning("[discover] 检测到多个带 email handler 的 Worker跳过自动写入 script_name。")
return {}

View File

@@ -0,0 +1,53 @@
"""Local environment checks."""
from __future__ import annotations
import re
from dataclasses import dataclass
from cf_temp_email_deploy.errors import EnvironmentCheckError
from cf_temp_email_deploy.subprocess_runner import CommandRunner, CommandSpec
MINIMUM_NODE_VERSION = (20, 19, 0)
@dataclass(frozen=True)
class ToolVersion:
name: str
version: str
def parse_semver(raw_version: str) -> tuple[int, int, int]:
"""Extract a semantic version tuple from command output."""
match = re.search(r"(\d+)\.(\d+)\.(\d+)", raw_version)
if not match:
raise EnvironmentCheckError(f"无法解析版本号: {raw_version.strip()}")
return tuple(int(part) for part in match.groups())
def check_tool_version(runner: CommandRunner, command_name: str, version_flag: str = "--version") -> ToolVersion:
"""Execute a tool version command and return the parsed output."""
result = runner.run_checked(CommandSpec(args=(command_name, version_flag)))
version_text = (result.stdout or result.stderr).strip()
if not version_text:
raise EnvironmentCheckError(f"无法读取 {command_name} 的版本输出。")
return ToolVersion(name=command_name, version=version_text)
def check_required_tools(runner: CommandRunner) -> dict[str, ToolVersion]:
"""Validate required local tools and versions."""
versions = {
"git": check_tool_version(runner, "git"),
"node": check_tool_version(runner, "node"),
"npm": check_tool_version(runner, "npm"),
}
node_version = parse_semver(versions["node"].version)
if node_version < MINIMUM_NODE_VERSION:
minimum = ".".join(str(part) for part in MINIMUM_NODE_VERSION)
current = ".".join(str(part) for part in node_version)
raise EnvironmentCheckError(f"Node.js 版本过低,当前为 {current},最低要求为 {minimum}")
return versions

View File

@@ -0,0 +1,54 @@
"""Project specific exceptions."""
from __future__ import annotations
class DeployError(Exception):
"""Base exception for the deployment CLI."""
class ConfigError(DeployError):
"""Raised when configuration content is invalid."""
class EnvironmentCheckError(DeployError):
"""Raised when the local environment does not satisfy requirements."""
class CloudflareAPIError(DeployError):
"""Raised when Cloudflare API responses cannot be accepted."""
def __init__(self, message: str, *, status_code: int | None = None) -> None:
super().__init__(message)
self.status_code = status_code
class CommandExecutionError(DeployError):
"""Raised when a subprocess returns an unsuccessful result."""
def __init__(
self,
message: str,
*,
command: tuple[str, ...],
returncode: int | None,
stdout: str,
stderr: str,
) -> None:
super().__init__(message)
self.command = command
self.returncode = returncode
self.stdout = stdout
self.stderr = stderr
class AcceptanceCheckError(DeployError):
"""Raised when post-deployment acceptance checks fail."""
class ApplicationAPIError(DeployError):
"""Raised when application admin API responses cannot be accepted."""
def __init__(self, message: str, *, status_code: int | None = None) -> None:
super().__init__(message)
self.status_code = status_code

View File

@@ -0,0 +1,62 @@
"""Logging helpers used across the project."""
from __future__ import annotations
import logging
import shlex
from pathlib import Path
LOGGER_NAME = "cf_temp_email_deploy"
LOG_FORMAT = "%(asctime)s | %(levelname)s | %(message)s"
def configure_logging(verbose: bool = False) -> logging.Logger:
"""Configure and return the package root logger."""
logger = logging.getLogger(LOGGER_NAME)
level = logging.DEBUG if verbose else logging.INFO
logger.setLevel(level)
logger.propagate = False
if not logger.handlers:
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter(LOG_FORMAT))
logger.addHandler(handler)
for handler in logger.handlers:
handler.setLevel(level)
return logger
def get_logger(name: str | None = None) -> logging.Logger:
"""Return a child logger inside the package namespace."""
if not name:
return logging.getLogger(LOGGER_NAME)
return logging.getLogger(f"{LOGGER_NAME}.{name}")
def log_stage(message: str) -> None:
"""Emit a stage-level log message."""
get_logger("stage").info("[stage] %s", message)
def log_command(command: tuple[str, ...], cwd: Path | None) -> None:
"""Emit a log line before a subprocess starts."""
location = str(cwd) if cwd else "."
get_logger("command").info("[command] cwd=%s cmd=%s", location, shlex.join(command))
def log_command_result(command: tuple[str, ...], returncode: int, duration_seconds: float) -> None:
"""Emit a log line after a subprocess finishes."""
get_logger("command").info(
"[command-result] code=%s duration=%.3fs cmd=%s",
returncode,
duration_seconds,
shlex.join(command),
)

View File

@@ -0,0 +1,258 @@
"""Typed configuration and state models."""
from __future__ import annotations
import os
from datetime import UTC, datetime
from typing import Any, Literal
from pydantic import BaseModel, ConfigDict, Field, model_validator
def _now_isoformat() -> str:
return datetime.now(UTC).isoformat(timespec="seconds")
class StrictModel(BaseModel):
"""Shared Pydantic configuration."""
model_config = ConfigDict(extra="forbid", validate_assignment=True)
class SourceConfig(StrictModel):
mode: Literal["clone", "local"] = "clone"
repo_url: str = "https://github.com/dreamhunter2333/cloudflare_temp_email.git"
repo_ref: str = ""
workspace_dir: str = ".deploy/workspace"
local_path: str = ""
class CloudflareConfig(StrictModel):
account_id: str = ""
account_name: str = ""
zone_name: str = "example.com"
api_token: str = ""
api_token_env: str = ""
api_email: str = ""
api_email_env: str = ""
global_api_key: str = ""
global_api_key_env: str = ""
def resolved_api_token(self, environ: dict[str, str] | None = None) -> str:
environment = environ or os.environ
if self.api_token:
return self.api_token
if self.api_token_env:
return environment.get(self.api_token_env, "")
return ""
def resolved_api_email(self, environ: dict[str, str] | None = None) -> str:
environment = environ or os.environ
if self.api_email:
return self.api_email
if self.api_email_env:
return environment.get(self.api_email_env, "")
return ""
def resolved_global_api_key(self, environ: dict[str, str] | None = None) -> str:
environment = environ or os.environ
if self.global_api_key:
return self.global_api_key
if self.global_api_key_env:
return environment.get(self.global_api_key_env, "")
return ""
class MailConfig(StrictModel):
domains: list[str] = Field(default_factory=lambda: ["mail.example.com"])
verified_destination_address: str = "inbox@example.net"
class D1Config(StrictModel):
database_name: str = "cf-temp-email"
jurisdiction: str = ""
adopt_existing_schema: bool = False
class UserAccessConfig(StrictModel):
require_login_to_create: bool = True
allow_user_register: bool = False
class LinuxdoConfig(StrictModel):
linuxdo_oauth: bool = False
client_id: str = ""
client_secret: str = ""
class WorkerConfig(StrictModel):
script_name: str = "cloudflare-temp-email"
use_workers_dev: bool = True
custom_domain: str = ""
compatibility_date: str = "2024-09-23"
vars: dict[str, Any] = Field(
default_factory=lambda: {
"PREFIX": "tmp",
"ENABLE_USER_CREATE_EMAIL": True,
"ENABLE_USER_DELETE_EMAIL": True,
"DEFAULT_LANG": "zh",
"ADMIN_PASSWORDS": ["change-me"],
}
)
secrets: dict[str, str] = Field(default_factory=lambda: {"JWT_SECRET": ""})
class PagesConfig(StrictModel):
project_name: str = "cf-temp-email-pages"
custom_domain: str = "mail.example.com"
build_mode: Literal["pages", "pages:nopwa"] = "pages"
production_branch: str = "production"
class DeploymentConfig(StrictModel):
config_version: int = 1
source: SourceConfig = Field(default_factory=SourceConfig)
cloudflare: CloudflareConfig = Field(default_factory=CloudflareConfig)
mail: MailConfig = Field(default_factory=MailConfig)
d1: D1Config = Field(default_factory=D1Config)
user_access: UserAccessConfig = Field(default_factory=UserAccessConfig)
linuxdo: LinuxdoConfig = Field(default_factory=LinuxdoConfig)
worker: WorkerConfig = Field(default_factory=WorkerConfig)
pages: PagesConfig = Field(default_factory=PagesConfig)
@model_validator(mode="after")
def validate_model(self) -> "DeploymentConfig":
if not self.mail.domains:
raise ValueError("mail.domains 至少需要一个域名。")
if self.source.mode == "local" and not self.source.local_path:
raise ValueError("source.mode=local 时需要填写 source.local_path。")
if not self.cloudflare.zone_name:
raise ValueError("cloudflare.zone_name 不能为空。")
if not self.pages.custom_domain:
raise ValueError("pages.custom_domain 不能为空。")
if not self.worker.custom_domain and not self.worker.use_workers_dev:
raise ValueError("worker.custom_domain 与 worker.use_workers_dev 不能同时关闭。")
if not self.admin_passwords():
raise ValueError("worker.vars.ADMIN_PASSWORDS 至少需要一个非空密码。")
if self.linuxdo.linuxdo_oauth and not self.linuxdo.client_id.strip():
raise ValueError("linuxdo.linuxdo_oauth=true 时需要填写 linuxdo.client_id。")
if self.linuxdo.linuxdo_oauth and not self.linuxdo.client_secret.strip():
raise ValueError("linuxdo.linuxdo_oauth=true 时需要填写 linuxdo.client_secret。")
return self
def build_frontend_url(self, pages_hostname: str | None = None) -> str:
hostname = self.pages.custom_domain or pages_hostname or f"{self.pages.project_name}.pages.dev"
return f"https://{hostname}"
def admin_passwords(self) -> list[str]:
raw_value = self.worker.vars.get("ADMIN_PASSWORDS", [])
if isinstance(raw_value, str):
value = raw_value.strip()
return [value] if value else []
if not isinstance(raw_value, list):
return []
passwords: list[str] = []
for item in raw_value:
value = str(item).strip()
if value:
passwords.append(value)
return passwords
def effective_require_login_to_create(self) -> bool:
return self.user_access.require_login_to_create or self.linuxdo.linuxdo_oauth
def linuxdo_callback_url(self) -> str:
return f"https://{self.pages.custom_domain}/user/oauth2/callback"
def derived_worker_vars(self, pages_hostname: str | None = None) -> dict[str, Any]:
values = dict(self.worker.vars)
values["DOMAINS"] = list(self.mail.domains)
values.setdefault("DEFAULT_DOMAINS", list(self.mail.domains))
values["FRONTEND_URL"] = self.build_frontend_url(pages_hostname)
values["DISABLE_ANONYMOUS_USER_CREATE_EMAIL"] = self.effective_require_login_to_create()
return values
class CloudflareState(StrictModel):
account_id: str = ""
account_name: str = ""
zone_id: str = ""
zone_name: str = ""
class SourceState(StrictModel):
source_dir: str = ""
commit_sha: str = ""
class PagesState(StrictModel):
project_id: str = ""
project_name: str = ""
subdomain: str = ""
custom_domain: str = ""
custom_domain_status: str = ""
cname_record_id: str = ""
class D1MigrationState(StrictModel):
file_name: str
sha256: str
applied_at: str = Field(default_factory=_now_isoformat)
class D1State(StrictModel):
database_id: str = ""
database_name: str = ""
schema_version: str = ""
migrations: list[D1MigrationState] = Field(default_factory=list)
class WorkerState(StrictModel):
script_name: str = ""
workers_dev_url: str = ""
class ApplicationState(StrictModel):
configured: bool = False
admin_base_url: str = ""
allow_user_register: bool = False
require_login_to_create: bool = False
linuxdo_oauth_enabled: bool = False
linuxdo_redirect_url: str = ""
class EmailRoutingState(StrictModel):
destination_address: str = ""
dns_record_ids: list[str] = Field(default_factory=list)
catch_all_enabled: bool = False
catch_all_rule_id: str = ""
catch_all_worker: str = ""
class DeploymentEvent(StrictModel):
stage: str
status: Literal["started", "completed", "failed"]
message: str = ""
timestamp: str = Field(default_factory=_now_isoformat)
class DeploymentState(StrictModel):
config_version: int = 1
checkpoint: str = ""
last_updated_at: str = Field(default_factory=_now_isoformat)
cloudflare: CloudflareState = Field(default_factory=CloudflareState)
source: SourceState = Field(default_factory=SourceState)
pages: PagesState = Field(default_factory=PagesState)
d1: D1State = Field(default_factory=D1State)
worker: WorkerState = Field(default_factory=WorkerState)
application: ApplicationState = Field(default_factory=ApplicationState)
email_routing: EmailRoutingState = Field(default_factory=EmailRoutingState)
events: list[DeploymentEvent] = Field(default_factory=list)
def mark_checkpoint(self, checkpoint: str) -> None:
self.checkpoint = checkpoint
self.last_updated_at = _now_isoformat()
def record_event(self, stage: str, status: Literal["started", "completed", "failed"], message: str = "") -> None:
self.events.append(DeploymentEvent(stage=stage, status=status, message=message))
self.last_updated_at = _now_isoformat()

View File

@@ -0,0 +1,79 @@
"""Helpers for locating the upstream project layout."""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from cf_temp_email_deploy.errors import ConfigError
@dataclass(frozen=True)
class ProjectLayout:
"""Resolved directory layout for the upstream project."""
root_dir: Path
worker_dir: Path
frontend_dir: Path
pages_dir: Path
db_dir: Path
worker_wrangler_template: Path
worker_wrangler_path: Path
worker_patch_path: Path
pages_wrangler_path: Path
schema_path: Path
migration_paths: tuple[Path, ...]
@property
def frontend_dist_dir(self) -> Path:
return self.frontend_dir / "dist"
@property
def telegraf_dir(self) -> Path:
return self.worker_dir / "node_modules" / "telegraf"
def detect_project_layout(root_dir: Path) -> ProjectLayout:
"""Resolve the expected Cloudflare Temp Email project structure."""
base_dir = root_dir.expanduser().resolve()
worker_dir = _require_directory(base_dir / "worker")
frontend_dir = _require_directory(base_dir / "frontend")
pages_dir = _require_directory(base_dir / "pages")
db_dir = _require_directory(base_dir / "db")
worker_wrangler_template = _require_file(worker_dir / "wrangler.toml.template")
pages_wrangler_path = pages_dir / "wrangler.toml"
schema_path = _require_file(db_dir / "schema.sql")
worker_patch_path = _require_file(worker_dir / "patches" / "telegraf@4.16.3.patch")
migration_paths = tuple(
sorted(
path
for path in db_dir.glob("*.sql")
if path.name != "schema.sql"
)
)
return ProjectLayout(
root_dir=base_dir,
worker_dir=worker_dir,
frontend_dir=frontend_dir,
pages_dir=pages_dir,
db_dir=db_dir,
worker_wrangler_template=worker_wrangler_template,
worker_wrangler_path=worker_dir / "wrangler.toml",
worker_patch_path=worker_patch_path,
pages_wrangler_path=pages_wrangler_path,
schema_path=schema_path,
migration_paths=migration_paths,
)
def _require_directory(path: Path) -> Path:
if not path.is_dir():
raise ConfigError(f"缺少目录: {path}")
return path
def _require_file(path: Path) -> Path:
if not path.is_file():
raise ConfigError(f"缺少文件: {path}")
return path

View File

@@ -0,0 +1,117 @@
"""Source preparation helpers."""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from cf_temp_email_deploy.errors import ConfigError
from cf_temp_email_deploy.models import DeploymentConfig, DeploymentState, SourceConfig
from cf_temp_email_deploy.subprocess_runner import CommandRunner, CommandSpec
@dataclass(frozen=True)
class PreparedSource:
"""Prepared source directory and its commit identity."""
source_dir: Path
commit_sha: str
def apply_to_state(self, state: DeploymentState) -> None:
state.source.source_dir = str(self.source_dir)
state.source.commit_sha = self.commit_sha
def prepare_source(config: DeploymentConfig, runner: CommandRunner) -> PreparedSource:
"""Prepare deployment source code according to the configured mode."""
if config.source.mode == "clone":
return prepare_cloned_source(config.source, runner)
return prepare_local_source(config.source, runner)
def prepare_cloned_source(source: SourceConfig, runner: CommandRunner) -> PreparedSource:
"""Clone or refresh the configured repository and checkout the target ref."""
workspace_dir = Path(source.workspace_dir).expanduser().resolve()
repo_dir = workspace_dir / "source"
if repo_dir.exists() and not (repo_dir / ".git").exists():
raise ConfigError(f"源码目录存在但不是 Git 仓库: {repo_dir}")
if not repo_dir.exists():
runner.run_checked(CommandSpec(args=("git", "clone", source.repo_url, str(repo_dir))))
runner.run_checked(CommandSpec(args=("git", "-C", str(repo_dir), "fetch", "--tags", "origin")))
if not _checkout_ref(repo_dir, source.repo_ref, runner):
raise ConfigError(f"无法切换到目标源码版本: {source.repo_ref}")
return PreparedSource(source_dir=repo_dir, commit_sha=resolve_commit_sha(repo_dir, runner))
def prepare_local_source(source: SourceConfig, runner: CommandRunner) -> PreparedSource:
"""Validate a local source directory and collect its commit SHA."""
if not source.local_path:
raise ConfigError("source.mode=local 时需要填写 source.local_path。")
source_dir = Path(source.local_path).expanduser().resolve()
if not source_dir.exists():
raise ConfigError(f"本地源码目录不存在: {source_dir}")
if not (source_dir / ".git").exists():
raise ConfigError(f"本地源码目录不是 Git 仓库: {source_dir}")
return PreparedSource(source_dir=source_dir, commit_sha=resolve_commit_sha(source_dir, runner))
def resolve_commit_sha(repo_dir: Path, runner: CommandRunner) -> str:
"""Resolve the current HEAD commit SHA for a Git repository."""
result = runner.run_checked(CommandSpec(args=("git", "-C", str(repo_dir), "rev-parse", "HEAD")))
commit_sha = result.stdout.strip()
if not commit_sha:
raise ConfigError(f"无法读取 Git 提交哈希: {repo_dir}")
return commit_sha
def _checkout_ref(repo_dir: Path, repo_ref: str, runner: CommandRunner) -> bool:
if not repo_ref:
return _checkout_latest_remote_head(repo_dir, runner)
candidates = (
("git", "-C", str(repo_dir), "checkout", "--detach", repo_ref),
("git", "-C", str(repo_dir), "checkout", "--detach", f"origin/{repo_ref}"),
)
for args in candidates:
result = runner.run(CommandSpec(args=args))
if result.returncode == 0:
return True
fetch_result = runner.run(
CommandSpec(args=("git", "-C", str(repo_dir), "fetch", "origin", repo_ref))
)
if fetch_result.returncode != 0:
return False
checkout_result = runner.run(
CommandSpec(args=("git", "-C", str(repo_dir), "checkout", "--detach", "FETCH_HEAD"))
)
return checkout_result.returncode == 0
def _checkout_latest_remote_head(repo_dir: Path, runner: CommandRunner) -> bool:
symbolic_ref = runner.run(
CommandSpec(args=("git", "-C", str(repo_dir), "symbolic-ref", "refs/remotes/origin/HEAD"))
)
candidates: list[tuple[str, ...]] = []
remote_head = symbolic_ref.stdout.strip()
if symbolic_ref.returncode == 0 and remote_head:
candidates.append(("git", "-C", str(repo_dir), "checkout", "--detach", remote_head))
candidates.extend(
[
("git", "-C", str(repo_dir), "checkout", "--detach", "origin/HEAD"),
("git", "-C", str(repo_dir), "checkout", "--detach", "origin/main"),
("git", "-C", str(repo_dir), "checkout", "--detach", "origin/master"),
]
)
for args in candidates:
result = runner.run(CommandSpec(args=args))
if result.returncode == 0:
return True
return False

View File

@@ -0,0 +1,87 @@
"""Subprocess execution helpers."""
from __future__ import annotations
import os
import subprocess
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Mapping, Sequence
from cf_temp_email_deploy.errors import CommandExecutionError
from cf_temp_email_deploy.logging_utils import log_command, log_command_result
@dataclass(frozen=True)
class CommandSpec:
args: Sequence[str]
cwd: Path | None = None
env: Mapping[str, str] | None = None
timeout: float | None = None
input_text: str | None = None
@dataclass(frozen=True)
class CommandResult:
args: tuple[str, ...]
returncode: int
stdout: str
stderr: str
duration_seconds: float = field(compare=False)
class CommandRunner:
"""Execute subprocess commands with logging and structured errors."""
def run(self, spec: CommandSpec) -> CommandResult:
command = tuple(spec.args)
environment = os.environ.copy()
if spec.env:
environment.update(spec.env)
log_command(command, spec.cwd)
started_at = time.perf_counter()
try:
completed = subprocess.run(
command,
cwd=spec.cwd,
env=environment,
capture_output=True,
text=True,
input=spec.input_text,
timeout=spec.timeout,
check=False,
)
except subprocess.TimeoutExpired as exc:
raise CommandExecutionError(
f"命令执行超时: {command[0]}",
command=command,
returncode=None,
stdout=exc.stdout or "",
stderr=exc.stderr or "",
) from exc
duration = time.perf_counter() - started_at
log_command_result(command, completed.returncode, duration)
return CommandResult(
args=command,
returncode=completed.returncode,
stdout=completed.stdout,
stderr=completed.stderr,
duration_seconds=duration,
)
def run_checked(self, spec: CommandSpec) -> CommandResult:
"""Execute a command and require a zero exit code."""
result = self.run(spec)
if result.returncode != 0:
raise CommandExecutionError(
f"命令执行失败: {result.args[0]}",
command=result.args,
returncode=result.returncode,
stdout=result.stdout,
stderr=result.stderr,
)
return result

View File

@@ -0,0 +1,36 @@
"""Wrangler command helpers."""
from __future__ import annotations
from pathlib import Path
def build_wrangler_command(*args: str) -> tuple[str, ...]:
"""Build a local wrangler invocation using npm exec."""
return ("npm", "exec", "--", "wrangler", *args)
def build_wrangler_secret_put_command(secret_name: str) -> tuple[str, ...]:
"""Build a command that stores a Worker secret from standard input."""
return build_wrangler_command("secret", "put", secret_name)
def build_wrangler_d1_execute_command(database_name: str, sql_file: Path) -> tuple[str, ...]:
"""Build a command that executes a SQL file against a remote D1 database."""
return build_wrangler_command(
"d1",
"execute",
database_name,
"--file",
str(sql_file),
"--remote",
)
def build_wrangler_pages_deploy_command(branch: str) -> tuple[str, ...]:
"""Build a command that deploys a Pages project using the configured wrangler file."""
return build_wrangler_command("pages", "deploy", "--branch", branch)