# Files
# duckmail_gui/account_store.py
# 2026-03-21 09:25:24 +08:00
#
# 424 lines
# 14 KiB
# Python
import json
import os
import threading
# Database drivers are optional: each import falls back to None so this module
# still loads without them; _get_db_driver raises a message naming the package
# to install when the selected database type has no driver.
try:
    import psycopg  # psycopg 3
except ImportError:
    psycopg = None
try:
    import psycopg2
except ImportError:
    psycopg2 = None
try:
    import pymysql
except ImportError:
    pymysql = None
# In-process lock taken by save_config (config file writes) and save_account
# (account upserts).
_LOCK = threading.Lock()
# Accounts table name used when the config does not specify one.
DEFAULT_TABLE_NAME = "registered_accounts"
def _safe_read_json(path):
if not os.path.exists(path):
return {}
try:
with open(path, "r", encoding="utf-8") as f:
return json.load(f)
except Exception:
return {}
def _safe_write_json(path, data):
tmp_path = f"{path}.tmp"
with open(tmp_path, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=2)
os.replace(tmp_path, path)
def save_config(path, config):
    """Persist *config* to *path* as JSON.

    The write goes through ``_safe_write_json`` while the module-level lock
    is held, so concurrent callers in this process are serialized.
    """
    with _LOCK:
        _safe_write_json(path, config)
def load_config(path):
    """Return the JSON object stored at *path*.

    Yields ``{}`` when the file is missing, unreadable, or holds a JSON value
    that is not an object (e.g. a list).
    """
    loaded = _safe_read_json(path)
    return loaded if isinstance(loaded, dict) else {}
def _to_text(value):
if value is None:
return ""
return str(value)
def _normalize_optional_text(value):
    """Return *value* as stripped text, or ``None`` if nothing is left."""
    stripped = _to_text(value).strip()
    if not stripped:
        return None
    return stripped
def _normalize_config(config):
    """Build the canonical ``db_*`` settings dict from a raw config mapping.

    Both the current ``db_*`` keys and the legacy ``pg_*`` keys (``pg_host``,
    ``pg_port``, ``pg_db``, ``pg_user``, ``pg_password``,
    ``pg_connect_timeout``, ``pg_enabled``) are accepted; the ``db_*`` key
    wins when both are set. A key whose value is ``None`` is treated as
    missing, so it falls back to the legacy key / default instead of becoming
    the literal string ``"None"``. Blank port/timeout strings fall back to
    their defaults instead of crashing ``int()``.

    Args:
        config: mapping loaded from the config file, or ``None``.

    Returns:
        dict with keys ``db_enabled``, ``db_type``, ``db_host``, ``db_port``,
        ``db_name``, ``db_user``, ``db_password``, ``db_table``,
        ``db_auto_create`` and ``db_connect_timeout``.

    Raises:
        ValueError: a port/timeout value is non-blank but not an integer.
    """
    data = dict(config or {})

    def _pick(key, legacy_key, default):
        # First non-None value among the new key and its legacy alias.
        for candidate in (key, legacy_key):
            if candidate is None:
                continue
            value = data.get(candidate)
            if value is not None:
                return value
        return default

    def _pick_text(key, legacy_key, default=""):
        return _to_text(_pick(key, legacy_key, default)).strip()

    def _pick_int(key, legacy_key, default):
        raw = _pick(key, legacy_key, default)
        if isinstance(raw, str):
            raw = raw.strip()
            if not raw:
                # e.g. a cleared input box: use the default rather than int("").
                return default
        # 0 is not a usable port/timeout, so it also maps to the default.
        return int(raw) or default

    # A blank type defaults to PostgreSQL; unsupported values are passed
    # through and rejected later by _get_db_driver.
    db_type = _to_text(data.get("db_type") or "").strip().lower() or "postgresql"
    return {
        "db_enabled": bool(data.get("db_enabled", data.get("pg_enabled", True))),
        "db_type": db_type,
        "db_host": _pick_text("db_host", "pg_host"),
        "db_port": _pick_int("db_port", "pg_port", 3306 if db_type == "mysql" else 5432),
        "db_name": _pick_text("db_name", "pg_db", "mail_accounts_db"),
        "db_user": _pick_text("db_user", "pg_user"),
        # NOTE(review): the password is stripped like every other text field,
        # so leading/trailing spaces in a real password are lost.
        "db_password": _pick_text("db_password", "pg_password"),
        "db_table": _pick_text("db_table", None, DEFAULT_TABLE_NAME) or DEFAULT_TABLE_NAME,
        "db_auto_create": bool(data.get("db_auto_create", False)),
        "db_connect_timeout": _pick_int("db_connect_timeout", "pg_connect_timeout", 10),
    }
def _validate_identifier(value, label):
text = str(value or "").strip()
if not text:
raise ValueError(f"{label} 不能为空")
if not text.replace("_", "").isalnum():
raise ValueError(f"{label} 只能包含字母、数字和下划线")
return text
def _get_db_driver(db_type):
    """Name the installed driver to use for *db_type*.

    Returns ``"psycopg"`` (preferred) or ``"psycopg2"`` for ``"postgresql"``,
    and ``"pymysql"`` for ``"mysql"``.

    Raises:
        RuntimeError: the matching driver package is not installed.
        ValueError: *db_type* is neither ``"postgresql"`` nor ``"mysql"``.
    """
    if db_type == "mysql":
        if pymysql is None:
            raise RuntimeError("未安装 MySQL 驱动,请安装 PyMySQL")
        return "pymysql"
    if db_type != "postgresql":
        raise ValueError("仅支持 PostgreSQL 或 MySQL")
    if psycopg is not None:
        return "psycopg"
    if psycopg2 is not None:
        return "psycopg2"
    raise RuntimeError("未安装 PostgreSQL 驱动,请安装 psycopg[binary] 或 psycopg2-binary")
def _connection_kwargs(db_config, include_database=True):
host = db_config["db_host"]
port = int(db_config["db_port"])
dbname = db_config["db_name"]
user = db_config["db_user"]
password = db_config["db_password"]
connect_timeout = int(db_config["db_connect_timeout"])
if not host or not user:
raise ValueError("数据库配置不完整,请填写 Host 和 User")
if include_database and not dbname:
raise ValueError("数据库名不能为空")
if db_config["db_type"] == "postgresql":
kwargs = {
"host": host,
"port": port,
"user": user,
"password": password,
"connect_timeout": connect_timeout,
}
if include_database:
kwargs["dbname"] = dbname
return kwargs
kwargs = {
"host": host,
"port": port,
"user": user,
"password": password,
"connect_timeout": connect_timeout,
"charset": "utf8mb4",
"autocommit": False,
}
if include_database:
kwargs["database"] = dbname
return kwargs
_PSYCOPG2_CLOSING_CONNECTION = None


def _psycopg2_connection_factory():
    """Return a psycopg2 connection class whose ``with`` exit also closes it.

    psycopg2's own connection context manager only commits or rolls back the
    current transaction on exit and leaves the connection open. This subclass
    keeps that behaviour and then closes the connection. The class is built
    lazily, because psycopg2 may not be installed, and cached in a module
    global.
    """
    global _PSYCOPG2_CLOSING_CONNECTION
    if _PSYCOPG2_CLOSING_CONNECTION is None:

        class _ClosingConnection(psycopg2.extensions.connection):
            def __exit__(self, exc_type, exc_value, traceback):
                try:
                    return super().__exit__(exc_type, exc_value, traceback)
                finally:
                    self.close()

        _PSYCOPG2_CLOSING_CONNECTION = _ClosingConnection
    return _PSYCOPG2_CLOSING_CONNECTION


def _connect(db_config, include_database=True):
    """Open a new connection for *db_config* with autocommit turned off.

    Args:
        db_config: normalized settings dict.
        include_database: forwarded to ``_connection_kwargs``; when False the
            database name is left out of the connection arguments.

    Returns:
        A psycopg, psycopg2 or PyMySQL connection. Callers use it as
        ``with _connect(...) as conn``. For psycopg2 the connection class is
        swapped so that leaving that block closes the connection, instead of
        only ending the transaction.

    Raises:
        RuntimeError / ValueError from driver selection or config
        validation, and any driver error raised while connecting.
    """
    db_type = db_config["db_type"]
    driver = _get_db_driver(db_type)
    kwargs = _connection_kwargs(db_config, include_database=include_database)
    if db_type == "postgresql":
        if driver == "psycopg":
            conn = psycopg.connect(**kwargs)
        else:
            conn = psycopg2.connect(
                connection_factory=_psycopg2_connection_factory(), **kwargs
            )
        conn.autocommit = False
        return conn
    conn = pymysql.connect(**kwargs)
    # PyMySQL exposes autocommit as a method rather than an attribute.
    conn.autocommit(False)
    return conn
def _connect_admin(db_config):
    """Open a connection that does not target the configured database.

    For PostgreSQL the connection goes to the ``postgres`` database instead;
    for MySQL it is opened without selecting any database.
    """
    if db_config["db_type"] != "postgresql":
        return _connect(db_config, include_database=False)
    return _connect({**db_config, "db_name": "postgres"}, include_database=True)
def _can_connect_to_configured_database(db_config):
    """Return True if ``SELECT 1`` succeeds against the configured database.

    Every failure (missing driver, invalid config, bad credentials, missing
    database, network error, ...) is reported as False rather than raised.
    """
    try:
        with _connect(db_config) as conn, conn.cursor() as cur:
            cur.execute("SELECT 1")
            cur.fetchone()
    except Exception:
        return False
    return True
def _try_database_exists_via_configured_db(db_config):
    """Infer that the configured database exists from a successful connection.

    True is conclusive. False is not: the direct probe also fails for wrong
    credentials, network problems or a missing driver.
    """
    return bool(_can_connect_to_configured_database(db_config))
def _has_admin_database_access(db_config):
    """Return True if ``SELECT 1`` succeeds over an admin connection.

    Any failure is reported as False. NOTE(review): for MySQL the admin
    connection merely omits the database, so this is True for any account
    that can log in; it does not prove CREATE DATABASE privilege.
    """
    try:
        with _connect_admin(db_config) as conn, conn.cursor() as cur:
            cur.execute("SELECT 1")
            cur.fetchone()
    except Exception:
        return False
    return True
def _database_exists(db_config):
    """Return True if the configured database exists on the server.

    First tries to connect to it directly. If that fails, the server catalog
    is queried over an admin connection (``pg_database`` for PostgreSQL,
    ``INFORMATION_SCHEMA.SCHEMATA`` for MySQL). Errors from the admin
    connection propagate.

    Raises:
        ValueError: the database name is empty or not a plain identifier.
    """
    db_name = _validate_identifier(db_config["db_name"], "数据库名")
    if _try_database_exists_via_configured_db(db_config):
        return True
    if db_config["db_type"] == "postgresql":
        lookup = "SELECT 1 FROM pg_database WHERE datname = %s"
    else:
        lookup = "SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME = %s"
    with _connect_admin(db_config) as conn:
        with conn.cursor() as cur:
            cur.execute(lookup, (db_name,))
            return cur.fetchone() is not None
def _table_exists(db_config):
    """Return True if the accounts table exists where unqualified SQL finds it.

    PostgreSQL: looks in ``current_schema()``, the first schema on the
    session's ``search_path``. That is where the unqualified
    ``CREATE TABLE`` in ``_create_table`` puts the table, and where the
    unqualified queries in this module look first. Hard-coding ``public``
    would report a false negative when the search path starts with another
    schema (e.g. a per-user schema).
    MySQL: looks in the configured database.

    Raises:
        ValueError: the table name is empty or not a plain identifier.
    """
    table_name = _validate_identifier(db_config["db_table"], "表名")
    with _connect(db_config) as conn:
        with conn.cursor() as cur:
            if db_config["db_type"] == "postgresql":
                cur.execute(
                    "SELECT 1 FROM information_schema.tables "
                    "WHERE table_schema = current_schema() AND table_name = %s",
                    (table_name,),
                )
            else:
                cur.execute(
                    "SELECT 1 FROM information_schema.tables WHERE table_schema = %s AND table_name = %s",
                    (db_config["db_name"], table_name),
                )
            return cur.fetchone() is not None
def _create_database(db_config):
    """Create the configured database over an admin connection.

    MySQL databases are created with utf8mb4 / utf8mb4_unicode_ci.

    Raises:
        ValueError: the database name is empty or not a plain identifier.
    """
    db_name = _validate_identifier(db_config["db_name"], "数据库名")
    is_postgres = db_config["db_type"] == "postgresql"
    if is_postgres:
        statement = f'CREATE DATABASE "{db_name}"'
    else:
        statement = f"CREATE DATABASE `{db_name}` CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci"
    with _connect_admin(db_config) as conn:
        if is_postgres:
            # PostgreSQL refuses CREATE DATABASE inside a transaction block.
            conn.autocommit = True
        with conn.cursor() as cur:
            cur.execute(statement)
        if not is_postgres:
            conn.commit()
def _create_table(db_config):
    """Create the accounts table in the configured database if it is missing.

    Columns: ``email`` (primary key), ``mail_password``, ``mail_token``,
    ``chatgpt_password``, ``name``, ``birthdate`` and ``created_at``, which
    defaults to the insertion time.

    Raises:
        ValueError: the table name is empty or not a plain identifier.
    """
    table_name = _validate_identifier(db_config["db_table"], "表名")
    if db_config["db_type"] != "postgresql":
        ddl = f"""
CREATE TABLE IF NOT EXISTS `{table_name}` (
email VARCHAR(255) PRIMARY KEY,
mail_password TEXT NULL,
mail_token TEXT NULL,
chatgpt_password TEXT NULL,
name VARCHAR(255) NULL,
birthdate VARCHAR(64) NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci
"""
    else:
        ddl = f"""
CREATE TABLE IF NOT EXISTS "{table_name}" (
email TEXT PRIMARY KEY,
mail_password TEXT,
mail_token TEXT,
chatgpt_password TEXT,
name TEXT,
birthdate TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
"""
    with _connect(db_config) as conn, conn.cursor() as cur:
        cur.execute(ddl)
        conn.commit()
def test_connection(config):
    """Probe the configured database server and report what is reachable.

    Two probes run ``SELECT 1``. One is a direct connection to the configured
    database. The other is an admin connection: the ``postgres`` database
    for PostgreSQL, or no database for MySQL. Their results decide the
    report; no further reconnect is needed to re-check reachability.

    Returns:
        dict with ``success`` (always True when returned), ``db_type``,
        ``database_exists``, ``table_exists`` (only checked when the direct
        connection works, otherwise False) and ``admin_access``.

    Raises:
        RuntimeError: neither probe could connect.
    """
    db_config = _normalize_config(config)
    can_connect_directly = _can_connect_to_configured_database(db_config)
    admin_access = _has_admin_database_access(db_config)
    if not can_connect_directly and not admin_access:
        raise RuntimeError("无法连接到配置的数据库,且当前账号也没有数据库管理权限")
    # A working direct connection proves the database exists; otherwise ask
    # the server catalog through the admin connection.
    database_exists = can_connect_directly or _database_exists(db_config)
    table_exists = can_connect_directly and _table_exists(db_config)
    return {
        "success": True,
        "db_type": db_config["db_type"],
        "database_exists": database_exists,
        "table_exists": table_exists,
        "admin_access": admin_access,
    }
def ensure_database_and_table(config):
    """Create the configured database and accounts table when missing.

    Returns:
        dict with ``success``, ``db_type``, ``database_exists`` and
        ``table_exists`` (both True once this returns) and ``admin_access``.

    Raises:
        RuntimeError: the database does not exist and the admin probe
            failed, so it cannot be created from here.
    """
    settings = _normalize_config(config)
    direct_ok = _can_connect_to_configured_database(settings)
    has_admin = _has_admin_database_access(settings)
    db_present = direct_ok or (has_admin and _database_exists(settings))
    if not db_present and not has_admin:
        raise RuntimeError("当前账号无法创建数据库,请先手动创建数据库,或改用具备管理权限的账号")
    if not db_present:
        _create_database(settings)
    # The database exists at this point (found or just created).
    if not _table_exists(settings):
        _create_table(settings)
    return {
        "success": True,
        "db_type": settings["db_type"],
        "database_exists": True,
        "table_exists": True,
        "admin_access": has_admin,
    }
def _ensure_ready_if_needed(config):
    """Normalize *config*; when ``db_auto_create`` is set, also create the DB/table.

    Returns the normalized settings dict.
    """
    normalized = _normalize_config(config)
    if not normalized["db_auto_create"]:
        return normalized
    ensure_database_and_table(normalized)
    return normalized
def load_accounts(config):
    """Return every stored account, newest first.

    Rows are ordered by ``created_at`` descending, then e-mail ascending.
    Text fields are returned as stripped strings (``None`` becomes ``""``)
    and ``created_at`` is the driver value rendered with ``str``. Rows whose
    e-mail is empty or whitespace-only are skipped. Each dict also carries
    ``source`` set to the configured ``db_type``.

    Raises:
        ValueError: the table name is empty or not a plain identifier.
        Driver errors from connecting or querying propagate.
    """
    db_config = _ensure_ready_if_needed(config)
    table_name = _validate_identifier(db_config["db_table"], "表名")
    if db_config["db_type"] == "postgresql":
        quoted_table = f'"{table_name}"'
    else:
        quoted_table = f"`{table_name}`"
    query = (
        "select email, mail_password, mail_token, chatgpt_password, name, birthdate, created_at "
        f"from {quoted_table} order by created_at desc, email asc"
    )
    with _connect(db_config) as conn:
        with conn.cursor() as cur:
            cur.execute(query)
            rows = cur.fetchall()

    source = db_config["db_type"]
    accounts = []
    for email, mail_password, mail_token, chatgpt_password, name, birthdate, created_at in rows:
        # Strip before the emptiness check so whitespace-only e-mails are skipped too.
        email_text = _to_text(email).strip()
        if not email_text:
            continue
        accounts.append(
            {
                "email": email_text,
                "mail_password": _to_text(mail_password).strip(),
                "mail_token": _to_text(mail_token).strip(),
                "chatgpt_password": _to_text(chatgpt_password).strip(),
                "name": _to_text(name).strip(),
                "birthdate": _to_text(birthdate).strip(),
                "created_at": _to_text(created_at).strip(),
                "source": source,
            }
        )
    return accounts
def save_account(
    config,
    email,
    mail_password,
    mail_token,
    chatgpt_password=None,
    name=None,
    birthdate=None,
):
    """Insert or update one account row keyed by its e-mail address.

    The e-mail is stripped before use, so it matches the stripped value that
    ``load_accounts`` returns. An existing row with the same e-mail has all
    other columns overwritten. ``chatgpt_password``, ``name`` and
    ``birthdate`` are stored stripped, with blank values stored as NULL.
    The write happens while the module-level lock is held.

    Raises:
        ValueError: *email* is empty/whitespace, or the table name is invalid.
        Driver errors from connecting or writing propagate.
    """
    normalized_email = _to_text(email).strip()
    if not normalized_email:
        # Fail fast instead of storing a row that load_accounts would hide
        # (empty e-mail) or hitting a NULL primary-key error (None).
        raise ValueError("邮箱 不能为空")
    db_config = _ensure_ready_if_needed(config)
    table_name = _validate_identifier(db_config["db_table"], "表名")
    params = (
        normalized_email,
        mail_password,
        mail_token,
        _normalize_optional_text(chatgpt_password),
        _normalize_optional_text(name),
        _normalize_optional_text(birthdate),
    )
    if db_config["db_type"] == "postgresql":
        query = (
            f'insert into "{table_name}" '
            "(email, mail_password, mail_token, chatgpt_password, name, birthdate) "
            "values (%s, %s, %s, %s, %s, %s) "
            "on conflict (email) do update set "
            "mail_password = excluded.mail_password, "
            "mail_token = excluded.mail_token, "
            "chatgpt_password = excluded.chatgpt_password, "
            "name = excluded.name, "
            "birthdate = excluded.birthdate"
        )
    else:
        # NOTE(review): VALUES() inside ON DUPLICATE KEY UPDATE is deprecated
        # in newer MySQL releases but still accepted. The row-alias
        # replacement is not available everywhere (e.g. MariaDB), so it is kept.
        query = (
            f"insert into `{table_name}` "
            "(email, mail_password, mail_token, chatgpt_password, name, birthdate) "
            "values (%s, %s, %s, %s, %s, %s) "
            "on duplicate key update "
            "mail_password = values(mail_password), "
            "mail_token = values(mail_token), "
            "chatgpt_password = values(chatgpt_password), "
            "name = values(name), "
            "birthdate = values(birthdate)"
        )
    with _LOCK:
        with _connect(db_config) as conn:
            with conn.cursor() as cur:
                cur.execute(query, params)
            conn.commit()