fix: harden tooling session reuse and event routing
Ensure session reuse is disabled for tooling contexts, include tool config in cache keys, and stabilize tool event merge/routing with expanded bridge tests. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -6,6 +6,31 @@ import types
|
||||
import unittest
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
|
||||
class _FakeSessionCache:
|
||||
def __init__(self) -> None:
|
||||
self.enabled = True
|
||||
self.keys: list[str] = []
|
||||
self.get_calls: list[str] = []
|
||||
self.put_calls: list[tuple[str, str, str]] = []
|
||||
self.invalidate_calls: list[str] = []
|
||||
|
||||
def build_key(self, api_key: str, messages: list[dict], *, tool_config=None) -> str:
|
||||
marker = "with_tool" if tool_config is not None else "no_tool"
|
||||
key = f"{api_key}:{len(messages)}:{marker}"
|
||||
self.keys.append(key)
|
||||
return key
|
||||
|
||||
async def get(self, key: str):
|
||||
self.get_calls.append(key)
|
||||
return None
|
||||
|
||||
async def put(self, key: str, session_id: str, instance_name: str = "") -> None:
|
||||
self.put_calls.append((key, session_id, instance_name))
|
||||
|
||||
async def invalidate(self, key: str) -> None:
|
||||
self.invalidate_calls.append(key)
|
||||
|
||||
# app.main imports playwright via auto_login; tests don't exercise that path.
|
||||
# Inject a lightweight stub so unit tests run without installing playwright.
|
||||
_playwright = types.ModuleType("playwright")
|
||||
@@ -119,6 +144,38 @@ async def _collect_stream(response) -> str:
|
||||
return "".join(chunks)
|
||||
|
||||
|
||||
class _SpyClient(_FakeClient):
|
||||
def __init__(self, *, stream_events: list[dict], complete_result: dict) -> None:
|
||||
super().__init__(stream_events=stream_events, complete_result=complete_result)
|
||||
self.last_complete_kwargs: dict = {}
|
||||
self.last_stream_kwargs: dict = {}
|
||||
|
||||
async def chat_complete(self, *args, **kwargs) -> dict:
|
||||
self.last_complete_kwargs = dict(kwargs)
|
||||
return await super().chat_complete(*args, **kwargs)
|
||||
|
||||
async def chat_stream(self, *args, **kwargs):
|
||||
self.last_stream_kwargs = dict(kwargs)
|
||||
async for event in super().chat_stream(*args, **kwargs):
|
||||
yield event
|
||||
|
||||
|
||||
class _SettingsPatch:
|
||||
def __init__(self, **kwargs) -> None:
|
||||
self._kwargs = kwargs
|
||||
|
||||
def __enter__(self):
|
||||
self._patchers = [patch.object(main.settings, k, v) for k, v in self._kwargs.items()]
|
||||
for p in self._patchers:
|
||||
p.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
for p in reversed(self._patchers):
|
||||
p.stop()
|
||||
return False
|
||||
|
||||
|
||||
class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_openai_non_stream_bridges_tool_calls(self) -> None:
|
||||
fake_client = _FakeClient(
|
||||
@@ -156,6 +213,7 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase):
|
||||
message = payload["choices"][0]["message"]
|
||||
self.assertEqual(message["content"], "done")
|
||||
self.assertIsInstance(message["tool_calls"], list)
|
||||
self.assertEqual(payload["choices"][0]["finish_reason"], "tool_calls")
|
||||
self.assertEqual(message["tool_calls"][0]["function"]["name"], "search_docs")
|
||||
self.assertEqual(
|
||||
json.loads(message["tool_calls"][0]["function"]["arguments"]),
|
||||
@@ -195,6 +253,7 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
self.assertIn('"tool_calls"', body)
|
||||
self.assertIn('"content": "hello"', body)
|
||||
self.assertIn('"finish_reason": "tool_calls"', body)
|
||||
self.assertIn('"usage"', body)
|
||||
self.assertIn("data: [DONE]", body)
|
||||
|
||||
@@ -239,9 +298,132 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase):
|
||||
payload = json.loads(response.body)
|
||||
types = [item["type"] for item in payload["content"]]
|
||||
self.assertEqual(types, ["text", "tool_use", "tool_result"])
|
||||
self.assertEqual(payload["stop_reason"], "end_turn")
|
||||
self.assertEqual(payload["content"][1]["name"], "lookup")
|
||||
self.assertEqual(payload["content"][2]["tool_use_id"], "toolu_1")
|
||||
|
||||
async def test_openai_stream_tool_call_indices_are_stable(self) -> None:
|
||||
fake_client = _FakeClient(
|
||||
stream_events=[
|
||||
{
|
||||
"type": "tool",
|
||||
"tool": {
|
||||
"id": "call_a",
|
||||
"name": "read_file",
|
||||
"input": {"path": "README.md"},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "tool",
|
||||
"tool": {
|
||||
"id": "call_b",
|
||||
"name": "search_docs",
|
||||
"input": {"query": "gateway"},
|
||||
},
|
||||
},
|
||||
],
|
||||
complete_result={},
|
||||
)
|
||||
req = ChatCompletionsRequest(
|
||||
model="org_auto",
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
stream=True,
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))),
|
||||
patch.object(main, "chat_guard", _FakeGuard()),
|
||||
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
|
||||
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
|
||||
):
|
||||
response = await main.v1_chat_completions(req, _make_request("/v1/chat/completions"))
|
||||
body = await _collect_stream(response)
|
||||
|
||||
self.assertIn('"id": "call_a"', body)
|
||||
self.assertIn('"id": "call_b"', body)
|
||||
self.assertIn('"index": 0', body)
|
||||
self.assertIn('"index": 1', body)
|
||||
|
||||
async def test_anthropic_non_stream_returns_tool_use_stop_reason_when_result_missing(self) -> None:
|
||||
fake_client = _FakeClient(
|
||||
stream_events=[],
|
||||
complete_result={
|
||||
"text": "",
|
||||
"toolEvents": [
|
||||
{
|
||||
"name": "lookup",
|
||||
"input": {"k": "v"},
|
||||
}
|
||||
],
|
||||
"sessionId": "sess-2",
|
||||
},
|
||||
)
|
||||
req = AnthropicMessagesRequest(
|
||||
model="claude-3-5-sonnet-20241022",
|
||||
max_tokens=256,
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
stream=False,
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))),
|
||||
patch.object(main, "chat_guard", _FakeGuard()),
|
||||
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
|
||||
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
|
||||
patch.object(main.settings, "api_keys", ["test-key"]),
|
||||
):
|
||||
response = await main.v1_messages(
|
||||
req,
|
||||
_make_request(
|
||||
"/v1/messages",
|
||||
headers={"x-api-key": "test-key", "anthropic-version": "2023-06-01"},
|
||||
),
|
||||
)
|
||||
|
||||
payload = json.loads(response.body)
|
||||
self.assertEqual(payload["stop_reason"], "tool_use")
|
||||
self.assertEqual(len(payload["content"]), 1)
|
||||
self.assertEqual(payload["content"][0]["type"], "tool_use")
|
||||
|
||||
async def test_anthropic_stream_returns_tool_use_stop_reason_when_result_missing(self) -> None:
|
||||
fake_client = _FakeClient(
|
||||
stream_events=[
|
||||
{
|
||||
"type": "tool",
|
||||
"tool": {
|
||||
"name": "read",
|
||||
"input": {"file": "a.txt"},
|
||||
},
|
||||
}
|
||||
],
|
||||
complete_result={},
|
||||
)
|
||||
req = AnthropicMessagesRequest(
|
||||
model="claude-3-5-sonnet-20241022",
|
||||
max_tokens=256,
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
stream=True,
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))),
|
||||
patch.object(main, "chat_guard", _FakeGuard()),
|
||||
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
|
||||
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
|
||||
patch.object(main.settings, "api_keys", ["test-key"]),
|
||||
):
|
||||
response = await main.v1_messages(
|
||||
req,
|
||||
_make_request(
|
||||
"/v1/messages",
|
||||
headers={"x-api-key": "test-key", "anthropic-version": "2023-06-01"},
|
||||
),
|
||||
)
|
||||
body = await _collect_stream(response)
|
||||
|
||||
self.assertIn('"type": "tool_use"', body)
|
||||
self.assertIn('"stop_reason": "tool_use"', body)
|
||||
|
||||
async def test_anthropic_stream_bridges_tool_and_text_events(self) -> None:
|
||||
fake_client = _FakeClient(
|
||||
stream_events=[
|
||||
@@ -284,11 +466,262 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertIn("event: message_start", body)
|
||||
self.assertIn('"type": "tool_use"', body)
|
||||
self.assertIn('"type": "tool_result"', body)
|
||||
self.assertIn('"stop_reason": "end_turn"', body)
|
||||
self.assertIn('"type": "text_delta"', body)
|
||||
self.assertIn("event: message_stop", body)
|
||||
|
||||
|
||||
class LingmaClientToolEventExtractionTests(unittest.TestCase):
|
||||
|
||||
async def test_openai_non_stream_forwards_tool_config_when_enabled(self) -> None:
|
||||
spy_client = _SpyClient(stream_events=[], complete_result={"text": "ok", "toolEvents": []})
|
||||
req = ChatCompletionsRequest(
|
||||
model="org_auto",
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
stream=False,
|
||||
tools=[{"type": "function", "function": {"name": "lookup", "parameters": {}}}],
|
||||
tool_choice={"type": "function", "function": {"name": "lookup"}},
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(main, "pool", _FakePool(_FakeInstance(spy_client))),
|
||||
patch.object(main, "chat_guard", _FakeGuard()),
|
||||
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
|
||||
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
|
||||
_SettingsPatch(tool_forward_enabled=True),
|
||||
):
|
||||
await main.v1_chat_completions(req, _make_request("/v1/chat/completions"))
|
||||
|
||||
self.assertIn("tool_config", spy_client.last_complete_kwargs)
|
||||
cfg = spy_client.last_complete_kwargs["tool_config"]
|
||||
self.assertEqual(cfg["provider"], "openai")
|
||||
self.assertEqual(len(cfg["tools"]), 1)
|
||||
self.assertIsInstance(cfg["tool_choice"], dict)
|
||||
|
||||
async def test_openai_non_stream_does_not_forward_tool_config_when_disabled(self) -> None:
|
||||
spy_client = _SpyClient(stream_events=[], complete_result={"text": "ok", "toolEvents": []})
|
||||
req = ChatCompletionsRequest(
|
||||
model="org_auto",
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
stream=False,
|
||||
tools=[{"type": "function", "function": {"name": "lookup", "parameters": {}}}],
|
||||
tool_choice={"type": "function", "function": {"name": "lookup"}},
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(main, "pool", _FakePool(_FakeInstance(spy_client))),
|
||||
patch.object(main, "chat_guard", _FakeGuard()),
|
||||
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
|
||||
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
|
||||
_SettingsPatch(tool_forward_enabled=False),
|
||||
):
|
||||
await main.v1_chat_completions(req, _make_request("/v1/chat/completions"))
|
||||
|
||||
self.assertIn("tool_config", spy_client.last_complete_kwargs)
|
||||
self.assertIsNone(spy_client.last_complete_kwargs["tool_config"])
|
||||
|
||||
|
||||
async def test_openai_tooling_context_disables_session_reuse_cache(self) -> None:
|
||||
fake_cache = _FakeSessionCache()
|
||||
fake_client = _FakeClient(
|
||||
stream_events=[],
|
||||
complete_result={"text": "ok", "toolEvents": [], "sessionId": "sess-3"},
|
||||
)
|
||||
req = ChatCompletionsRequest(
|
||||
model="org_auto",
|
||||
messages=[
|
||||
{"role": "user", "content": "turn-1"},
|
||||
{"role": "user", "content": "turn-2"},
|
||||
],
|
||||
stream=False,
|
||||
tools=[{"type": "function", "function": {"name": "lookup", "parameters": {}}}],
|
||||
tool_choice={"type": "function", "function": {"name": "lookup"}},
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(main, "session_cache", fake_cache),
|
||||
patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))),
|
||||
patch.object(main, "chat_guard", _FakeGuard()),
|
||||
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
|
||||
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
|
||||
_SettingsPatch(tool_forward_enabled=True),
|
||||
):
|
||||
await main.v1_chat_completions(req, _make_request("/v1/chat/completions"))
|
||||
|
||||
self.assertEqual(fake_cache.keys, [])
|
||||
self.assertEqual(fake_cache.get_calls, [])
|
||||
self.assertEqual(fake_cache.put_calls, [])
|
||||
|
||||
async def test_anthropic_tooling_context_disables_session_reuse_cache(self) -> None:
|
||||
fake_cache = _FakeSessionCache()
|
||||
fake_client = _FakeClient(
|
||||
stream_events=[],
|
||||
complete_result={"text": "ok", "toolEvents": [], "sessionId": "sess-4"},
|
||||
)
|
||||
req = AnthropicMessagesRequest(
|
||||
model="claude-3-5-sonnet-20241022",
|
||||
max_tokens=128,
|
||||
messages=[
|
||||
{"role": "user", "content": "turn-1"},
|
||||
{"role": "user", "content": "turn-2"},
|
||||
],
|
||||
stream=False,
|
||||
tools=[{"name": "lookup", "input_schema": {"type": "object", "properties": {}}}],
|
||||
tool_choice={"type": "auto"},
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(main, "session_cache", fake_cache),
|
||||
patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))),
|
||||
patch.object(main, "chat_guard", _FakeGuard()),
|
||||
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
|
||||
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
|
||||
patch.object(main.settings, "api_keys", ["test-key"]),
|
||||
):
|
||||
await main.v1_messages(
|
||||
req,
|
||||
_make_request(
|
||||
"/v1/messages",
|
||||
headers={"x-api-key": "test-key", "anthropic-version": "2023-06-01"},
|
||||
),
|
||||
)
|
||||
|
||||
self.assertEqual(fake_cache.keys, [])
|
||||
self.assertEqual(fake_cache.get_calls, [])
|
||||
self.assertEqual(fake_cache.put_calls, [])
|
||||
|
||||
|
||||
class SessionCacheToolFingerprintTests(unittest.TestCase):
|
||||
def test_build_key_changes_with_tool_config(self) -> None:
|
||||
from app.session_cache import SessionCache
|
||||
|
||||
cache = SessionCache(max_entries=8, ttl_sec=60)
|
||||
messages = [
|
||||
{"role": "system", "content": "sys"},
|
||||
{"role": "user", "content": "hello"},
|
||||
]
|
||||
|
||||
cfg_a = {
|
||||
"provider": "openai",
|
||||
"tools": [{"type": "function", "function": {"name": "lookup", "parameters": {}}}],
|
||||
"tool_choice": {"type": "function", "function": {"name": "lookup"}},
|
||||
}
|
||||
cfg_a_reordered = {
|
||||
"tool_choice": {"function": {"name": "lookup"}, "type": "function"},
|
||||
"tools": [{"function": {"parameters": {}, "name": "lookup"}, "type": "function"}],
|
||||
"provider": "openai",
|
||||
}
|
||||
cfg_b = {
|
||||
"provider": "openai",
|
||||
"tools": [{"type": "function", "function": {"name": "lookup_v2", "parameters": {}}}],
|
||||
"tool_choice": {"type": "function", "function": {"name": "lookup_v2"}},
|
||||
}
|
||||
|
||||
key_no_tool = cache.build_key("api-key", messages)
|
||||
key_a = cache.build_key("api-key", messages, tool_config=cfg_a)
|
||||
key_a_reordered = cache.build_key("api-key", messages, tool_config=cfg_a_reordered)
|
||||
key_b = cache.build_key("api-key", messages, tool_config=cfg_b)
|
||||
|
||||
self.assertNotEqual(key_no_tool, key_a)
|
||||
self.assertEqual(key_a, key_a_reordered)
|
||||
self.assertNotEqual(key_a, key_b)
|
||||
|
||||
|
||||
def test_handle_server_message_drops_unroutable_tool_event_without_request_id(self) -> None:
|
||||
from app.lingma_client import LspWsRpcClient
|
||||
|
||||
rpc = LspWsRpcClient("ws://127.0.0.1:1")
|
||||
|
||||
async def run() -> None:
|
||||
rpc.create_stream("req-1")
|
||||
await rpc._handle_server_message(
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"method": "tool/invoke",
|
||||
"params": {
|
||||
"name": "lookup",
|
||||
"parameters": {"q": "x"},
|
||||
},
|
||||
}
|
||||
)
|
||||
stream = rpc._chat_streams["req-1"]
|
||||
self.assertEqual(stream["tool_order"], [])
|
||||
self.assertEqual(stream["tool_states"], {})
|
||||
self.assertTrue(stream["chunks"].empty())
|
||||
|
||||
import asyncio
|
||||
|
||||
asyncio.run(run())
|
||||
|
||||
def test_handle_server_message_routes_by_tool_map_without_request_id(self) -> None:
|
||||
from app.lingma_client import LspWsRpcClient
|
||||
|
||||
rpc = LspWsRpcClient("ws://127.0.0.1:1")
|
||||
|
||||
async def run() -> None:
|
||||
rpc.create_stream("req-1")
|
||||
|
||||
await rpc._handle_server_message(
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"method": "tool/invoke",
|
||||
"params": {
|
||||
"requestId": "req-1",
|
||||
"toolCallId": "call-1",
|
||||
"name": "lookup",
|
||||
"parameters": {"q": "a"},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
await rpc._handle_server_message(
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"method": "tool/invokeResult",
|
||||
"params": {
|
||||
"toolCallId": "call-1",
|
||||
"result": {"ok": True},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
result = rpc.get_stream_result("req-1")
|
||||
self.assertEqual(len(result["toolEvents"]), 1)
|
||||
self.assertEqual(result["toolEvents"][0]["id"], "call-1")
|
||||
self.assertEqual(result["toolEvents"][0]["input"], {"q": "a"})
|
||||
self.assertEqual(result["toolEvents"][0]["result"], {"ok": True})
|
||||
|
||||
import asyncio
|
||||
|
||||
asyncio.run(run())
|
||||
|
||||
def test_handle_server_message_dedupes_identical_repeated_tool_events(self) -> None:
|
||||
from app.lingma_client import LspWsRpcClient
|
||||
|
||||
rpc = LspWsRpcClient("ws://127.0.0.1:1")
|
||||
|
||||
async def run() -> None:
|
||||
rpc.create_stream("req-1")
|
||||
msg = {
|
||||
"jsonrpc": "2.0",
|
||||
"method": "tool/invoke",
|
||||
"params": {
|
||||
"requestId": "req-1",
|
||||
"toolCallId": "call-dup",
|
||||
"name": "lookup",
|
||||
"parameters": {"q": "dup"},
|
||||
},
|
||||
}
|
||||
await rpc._handle_server_message(msg)
|
||||
await rpc._handle_server_message(msg)
|
||||
|
||||
stream = rpc._chat_streams["req-1"]
|
||||
self.assertEqual(stream["tool_order"], ["call-dup"])
|
||||
self.assertEqual(stream["chunks"].qsize(), 1)
|
||||
|
||||
import asyncio
|
||||
|
||||
asyncio.run(run())
|
||||
|
||||
def test_extracts_tool_event_from_results_and_parameters(self) -> None:
|
||||
from app.lingma_client import LspWsRpcClient
|
||||
|
||||
|
||||
Reference in New Issue
Block a user