From 4f7b898a378341db2be3a7f45e75c27942a1c35d Mon Sep 17 00:00:00 2001 From: Diogo Andre Santos Date: Thu, 9 Apr 2026 06:52:12 +0100 Subject: [PATCH] feat: add public setter methods for ClientSession callbacks Add set_sampling_callback(), set_elicitation_callback(), and set_list_roots_callback() methods to ClientSession, allowing callbacks to be updated at runtime after initialization without mutating private attributes directly. Also removes the # pragma: no cover from _default_elicitation_callback and adds coverage via the new test for set_elicitation_callback(None). Reported-by: dgenio Github-Issue: #2379 --- src/mcp/client/session.py | 44 ++++++++++++++++++- tests/client/test_elicitation_callback.py | 51 +++++++++++++++++++++++ tests/client/test_list_roots_callback.py | 39 +++++++++++++++++ tests/client/test_sampling_callback.py | 44 +++++++++++++++++++ 4 files changed, 177 insertions(+), 1 deletion(-) create mode 100644 tests/client/test_elicitation_callback.py diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 0cea454a7..aa2c03c14 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -74,7 +74,7 @@ async def _default_elicitation_callback( context: RequestContext[ClientSession], params: types.ElicitRequestParams, ) -> types.ElicitResult | types.ErrorData: - return types.ErrorData( # pragma: no cover + return types.ErrorData( code=types.INVALID_REQUEST, message="Elicitation not supported", ) @@ -216,6 +216,48 @@ def experimental(self) -> ExperimentalClientFeatures: self._experimental_features = ExperimentalClientFeatures(self) return self._experimental_features + def set_sampling_callback(self, callback: SamplingFnT | None) -> None: + """Update the sampling callback. + + Note: Client capabilities are advertised to the server during :meth:`initialize` + and will not be re-negotiated when this setter is called. If a sampling + callback is set after initialization, the server may not be aware of the + capability. + + Args: + callback: The new sampling callback, or ``None`` to restore the default + (which rejects all sampling requests with an error). + """ + self._sampling_callback = callback or _default_sampling_callback + + def set_elicitation_callback(self, callback: ElicitationFnT | None) -> None: + """Update the elicitation callback. + + Note: Client capabilities are advertised to the server during :meth:`initialize` + and will not be re-negotiated when this setter is called. If an elicitation + callback is set after initialization, the server may not be aware of the + capability. + + Args: + callback: The new elicitation callback, or ``None`` to restore the default + (which rejects all elicitation requests with an error). + """ + self._elicitation_callback = callback or _default_elicitation_callback + + def set_list_roots_callback(self, callback: ListRootsFnT | None) -> None: + """Update the list roots callback. + + Note: Client capabilities are advertised to the server during :meth:`initialize` + and will not be re-negotiated when this setter is called. If a list-roots + callback is set after initialization, the server may not be aware of the + capability. + + Args: + callback: The new list roots callback, or ``None`` to restore the default + (which rejects all list-roots requests with an error). + """ + self._list_roots_callback = callback or _default_list_roots_callback + async def send_ping(self, *, meta: RequestParamsMeta | None = None) -> types.EmptyResult: """Send a ping request.""" return await self.send_request(types.PingRequest(params=types.RequestParams(_meta=meta)), types.EmptyResult) diff --git a/tests/client/test_elicitation_callback.py b/tests/client/test_elicitation_callback.py new file mode 100644 index 000000000..7d3dbeb98 --- /dev/null +++ b/tests/client/test_elicitation_callback.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +import pytest +from pydantic import BaseModel, Field + +from mcp import Client +from mcp.client.session import ClientSession +from mcp.server.mcpserver import Context, MCPServer +from mcp.shared._context import RequestContext +from mcp.types import ElicitRequestParams, ElicitResult, TextContent + + +class AnswerSchema(BaseModel): + answer: str = Field(description="The user's answer") + + +@pytest.mark.anyio +async def test_set_elicitation_callback(): + server = MCPServer("test") + + updated_answer = "Updated answer" + + async def updated_callback( + context: RequestContext[ClientSession], + params: ElicitRequestParams, + ) -> ElicitResult: + return ElicitResult(action="accept", content={"answer": updated_answer}) + + @server.tool("ask") + async def ask(prompt: str, ctx: Context) -> str: + result = await ctx.elicit(message=prompt, schema=AnswerSchema) + if result.action == "accept" and result.data: + return result.data.answer + return "no answer" # pragma: no cover + + async with Client(server) as client: + # Before setting callback — default rejects with error + result = await client.call_tool("ask", {"prompt": "question?"}) + assert result.is_error is True + + # Set new callback — should succeed + client.session.set_elicitation_callback(updated_callback) + result = await client.call_tool("ask", {"prompt": "question?"}) + assert result.is_error is False + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == updated_answer + + # Reset to None — back to default error + client.session.set_elicitation_callback(None) + result = await client.call_tool("ask", {"prompt": "question?"}) + assert result.is_error is True diff --git a/tests/client/test_list_roots_callback.py b/tests/client/test_list_roots_callback.py index be4b9a97b..c7f86e536 100644 --- a/tests/client/test_list_roots_callback.py +++ b/tests/client/test_list_roots_callback.py @@ -45,3 +45,42 @@ async def test_list_roots(context: Context, message: str): assert result.is_error is True assert isinstance(result.content[0], TextContent) assert result.content[0].text == "Error executing tool test_list_roots: List roots not supported" + + +@pytest.mark.anyio +async def test_set_list_roots_callback(): + server = MCPServer("test") + + updated_result = ListRootsResult( + roots=[ + Root(uri=FileUrl("file://users/fake/updated"), name="Updated Root"), + ] + ) + + async def updated_callback( + context: RequestContext[ClientSession], + ) -> ListRootsResult: + return updated_result + + @server.tool("get_roots") + async def get_roots(context: Context, param: str) -> bool: + roots = await context.session.list_roots() + assert roots == updated_result + return True + + async with Client(server) as client: + # Before setting callback — default rejects with error + result = await client.call_tool("get_roots", {"param": "x"}) + assert result.is_error is True + + # Set new callback — should succeed + client.session.set_list_roots_callback(updated_callback) + result = await client.call_tool("get_roots", {"param": "x"}) + assert result.is_error is False + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "true" + + # Reset to None — back to default error + client.session.set_list_roots_callback(None) + result = await client.call_tool("get_roots", {"param": "x"}) + assert result.is_error is True diff --git a/tests/client/test_sampling_callback.py b/tests/client/test_sampling_callback.py index 6efcac0a5..74c094353 100644 --- a/tests/client/test_sampling_callback.py +++ b/tests/client/test_sampling_callback.py @@ -57,6 +57,50 @@ async def test_sampling_tool(message: str, ctx: Context) -> bool: assert result.content[0].text == "Error executing tool test_sampling: Sampling not supported" +@pytest.mark.anyio +async def test_set_sampling_callback(): + server = MCPServer("test") + + updated_return = CreateMessageResult( + role="assistant", + content=TextContent(type="text", text="Updated response"), + model="updated-model", + stop_reason="endTurn", + ) + + async def updated_callback( + context: RequestContext[ClientSession], + params: CreateMessageRequestParams, + ) -> CreateMessageResult: + return updated_return + + @server.tool("do_sample") + async def do_sample(message: str, ctx: Context) -> bool: + value = await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(type="text", text=message))], + max_tokens=100, + ) + assert value == updated_return + return True + + async with Client(server) as client: + # Before setting callback — default rejects with error + result = await client.call_tool("do_sample", {"message": "test"}) + assert result.is_error is True + + # Set new callback — should succeed + client.session.set_sampling_callback(updated_callback) + result = await client.call_tool("do_sample", {"message": "test"}) + assert result.is_error is False + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "true" + + # Reset to None — back to default error + client.session.set_sampling_callback(None) + result = await client.call_tool("do_sample", {"message": "test"}) + assert result.is_error is True + + @pytest.mark.anyio async def test_create_message_backwards_compat_single_content(): """Test backwards compatibility: create_message without tools returns single content."""