"""Unit tests for the ``app.auth`` request guards and ``app.concurrency.InFlightGuard``."""

from __future__ import annotations

import asyncio
import unittest

from fastapi import HTTPException
from starlette.requests import Request

from app.auth import (
    AnthropicAuthError,
    require_anthropic_key,
    require_bearer,
    require_metrics_access,
)
from app.concurrency import BackpressureRejected, InFlightGuard


def _req(headers: dict[str, str] | None = None) -> Request:
    """Build a bare Starlette ``Request`` for ``GET /x`` carrying *headers*.

    Header names are lower-cased, and both names and values are encoded to
    latin-1 bytes. Passing ``None`` (the default) yields a request with no
    headers.
    """
    raw_headers = [
        (name.lower().encode("latin-1"), value.encode("latin-1"))
        for name, value in (headers or {}).items()
    ]
    return Request(
        {
            "type": "http",
            "http_version": "1.1",
            "method": "GET",
            "scheme": "http",
            "path": "/x",
            "raw_path": b"/x",
            "query_string": b"",
            "headers": raw_headers,
            "client": ("test", 1),
            "server": ("test", 80),
            "root_path": "",
        }
    )


class AuthAndConcurrencyTests(unittest.IsolatedAsyncioTestCase):
    """Covers credential checks in ``app.auth`` and ``InFlightGuard`` accounting."""

    def test_require_bearer_accepts_valid_token(self) -> None:
        """A bearer token that is in the allowed list passes without raising."""
        require_bearer(_req({"authorization": "Bearer good"}), ["good"])

    def test_require_bearer_rejects_invalid_token(self) -> None:
        """An unknown bearer token raises a 401 ``invalid_api_key`` HTTPException."""
        req = _req({"authorization": "Bearer bad"})
        with self.assertRaises(HTTPException) as caught:
            require_bearer(req, ["good"])
        err = caught.exception
        self.assertEqual(err.status_code, 401)
        self.assertEqual(err.detail["error"]["code"], "invalid_api_key")

    def test_require_anthropic_key_accepts_x_api_key_or_bearer(self) -> None:
        """Either an ``x-api-key`` header or a bearer token is accepted."""
        for hdrs, allowed in (
            ({"x-api-key": "k1"}, ["k1"]),
            ({"authorization": "Bearer k2"}, ["k2"]),
        ):
            require_anthropic_key(_req(hdrs), allowed)

    def test_require_anthropic_key_raises_on_missing(self) -> None:
        """A request with no credentials raises a 401 ``authentication_error``."""
        req = _req()
        with self.assertRaises(AnthropicAuthError) as caught:
            require_anthropic_key(req, ["k"])
        err = caught.exception
        self.assertEqual(err.status_code, 401)
        self.assertEqual(err.error_type, "authentication_error")

    def test_require_metrics_access_503_when_no_tokens_configured(self) -> None:
        """With no API keys and an empty metrics token, non-public metrics give 503."""
        req = _req({"authorization": "Bearer any"})
        with self.assertRaises(HTTPException) as caught:
            require_metrics_access(req, api_keys=[], metrics_token="", public=False)
        err = caught.exception
        self.assertEqual(err.status_code, 503)
        self.assertEqual(err.detail["error"]["code"], "metrics_disabled")

    async def test_inflight_guard_unlimited_and_release_idempotent(self) -> None:
        """``max_in_flight=0`` still admits a request, and a double release frees only one slot."""
        guard = InFlightGuard(max_in_flight=0, queue_timeout_sec=0.01)
        ticket = await guard.try_acquire()
        self.assertEqual(guard.in_flight, 1)
        # Releasing twice must leave the counter at 0, not drive it negative.
        for _ in range(2):
            ticket.release()
        self.assertEqual(guard.in_flight, 0)
        self.assertEqual(guard.accepted_total, 1)

    async def test_inflight_guard_rejects_when_queue_timeout(self) -> None:
        """While the single slot is held, a second acquire raises and is counted as rejected."""
        guard = InFlightGuard(max_in_flight=1, queue_timeout_sec=0.01)
        held = await guard.try_acquire()
        with self.assertRaises(BackpressureRejected):
            await guard.try_acquire()
        self.assertEqual(guard.rejected_total, 1)
        held.release()
        self.assertEqual(guard.in_flight, 0)


if __name__ == "__main__":
    unittest.main()