diff --git a/README.md b/README.md index a35a3ebe0..6ddc03e29 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,6 @@ package main import ( "context" - "errors" "fmt" "github.com/mark3labs/mcp-go/mcp" diff --git a/client/client.go b/client/client.go index dd0e31a01..e2c466586 100644 --- a/client/client.go +++ b/client/client.go @@ -22,6 +22,7 @@ type Client struct { requestID atomic.Int64 clientCapabilities mcp.ClientCapabilities serverCapabilities mcp.ServerCapabilities + samplingHandler SamplingHandler } type ClientOption func(*Client) @@ -33,6 +34,21 @@ func WithClientCapabilities(capabilities mcp.ClientCapabilities) ClientOption { } } +// WithSamplingHandler sets the sampling handler for the client. +// When set, the client will declare sampling capability during initialization. +func WithSamplingHandler(handler SamplingHandler) ClientOption { + return func(c *Client) { + c.samplingHandler = handler + } +} + +// WithSession assumes a MCP Session has already been initialized +func WithSession() ClientOption { + return func(c *Client) { + c.initialized = true + } +} + // NewClient creates a new MCP client with the given transport. // Usage: // @@ -71,6 +87,12 @@ func (c *Client) Start(ctx context.Context) error { handler(notification) } }) + + // Set up request handler for bidirectional communication (e.g., sampling) + if bidirectional, ok := c.transport.(transport.BidirectionalInterface); ok { + bidirectional.SetRequestHandler(c.handleIncomingRequest) + } + return nil } @@ -127,6 +149,12 @@ func (c *Client) Initialize( ctx context.Context, request mcp.InitializeRequest, ) (*mcp.InitializeResult, error) { + // Merge client capabilities with sampling capability if handler is configured + capabilities := request.Params.Capabilities + if c.samplingHandler != nil { + capabilities.Sampling = &struct{}{} + } + // Ensure we send a params object with all required fields params := struct { ProtocolVersion string `json:"protocolVersion"` @@ -135,7 +163,7 @@ func (c *Client) Initialize( }{ ProtocolVersion: request.Params.ProtocolVersion, ClientInfo: request.Params.ClientInfo, - Capabilities: request.Params.Capabilities, // Will be empty struct if not set + Capabilities: capabilities, } response, err := c.sendRequest(ctx, "initialize", params) @@ -398,6 +426,64 @@ func (c *Client) Complete( return &result, nil } +// handleIncomingRequest processes incoming requests from the server. +// This is the main entry point for server-to-client requests like sampling. +func (c *Client) handleIncomingRequest(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) { + switch request.Method { + case string(mcp.MethodSamplingCreateMessage): + return c.handleSamplingRequestTransport(ctx, request) + default: + return nil, fmt.Errorf("unsupported request method: %s", request.Method) + } +} + +// handleSamplingRequestTransport handles sampling requests at the transport level. +func (c *Client) handleSamplingRequestTransport(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) { + if c.samplingHandler == nil { + return nil, fmt.Errorf("no sampling handler configured") + } + + // Parse the request parameters + var params mcp.CreateMessageParams + if request.Params != nil { + paramsBytes, err := json.Marshal(request.Params) + if err != nil { + return nil, fmt.Errorf("failed to marshal params: %w", err) + } + if err := json.Unmarshal(paramsBytes, ¶ms); err != nil { + return nil, fmt.Errorf("failed to unmarshal params: %w", err) + } + } + + // Create the MCP request + mcpRequest := mcp.CreateMessageRequest{ + Request: mcp.Request{ + Method: string(mcp.MethodSamplingCreateMessage), + }, + CreateMessageParams: params, + } + + // Call the sampling handler + result, err := c.samplingHandler.CreateMessage(ctx, mcpRequest) + if err != nil { + return nil, err + } + + // Marshal the result + resultBytes, err := json.Marshal(result) + if err != nil { + return nil, fmt.Errorf("failed to marshal result: %w", err) + } + + // Create the transport response + response := &transport.JSONRPCResponse{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: request.ID, + Result: json.RawMessage(resultBytes), + } + + return response, nil +} func listByPage[T any]( ctx context.Context, client *Client, @@ -432,3 +518,17 @@ func (c *Client) GetServerCapabilities() mcp.ServerCapabilities { func (c *Client) GetClientCapabilities() mcp.ClientCapabilities { return c.clientCapabilities } + +// GetSessionId returns the session ID of the transport. +// If the transport does not support sessions, it returns an empty string. +func (c *Client) GetSessionId() string { + if c.transport == nil { + return "" + } + return c.transport.GetSessionId() +} + +// IsInitialized returns true if the client has been initialized. +func (c *Client) IsInitialized() bool { + return c.initialized +} diff --git a/client/http.go b/client/http.go index cb3be35d6..d001a1e63 100644 --- a/client/http.go +++ b/client/http.go @@ -13,5 +13,10 @@ func NewStreamableHttpClient(baseURL string, options ...transport.StreamableHTTP if err != nil { return nil, fmt.Errorf("failed to create SSE transport: %w", err) } - return NewClient(trans), nil + clientOptions := make([]ClientOption, 0) + sessionID := trans.GetSessionId() + if sessionID != "" { + clientOptions = append(clientOptions, WithSession()) + } + return NewClient(trans, clientOptions...), nil } diff --git a/client/http_test.go b/client/http_test.go index 3c2e6a3b7..514004857 100644 --- a/client/http_test.go +++ b/client/http_test.go @@ -3,10 +3,14 @@ package client import ( "context" "fmt" - "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" + "sync" "testing" "time" + + "github.com/google/uuid" + "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" ) func TestHTTPClient(t *testing.T) { @@ -47,20 +51,47 @@ func TestHTTPClient(t *testing.T) { return nil, fmt.Errorf("failed to send notification: %w", err) } - return &mcp.CallToolResult{ - Content: []mcp.Content{ - mcp.TextContent{ - Type: "text", - Text: "notification sent successfully", - }, - }, - }, nil + return mcp.NewToolResultText("notification sent successfully"), nil }, ) + addServerToolfunc := func(name string) { + mcpServer.AddTool( + mcp.NewTool(name), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + server := server.ServerFromContext(ctx) + server.SendNotificationToAllClients("helloToEveryone", map[string]any{ + "message": "hello", + }) + return mcp.NewToolResultText("done"), nil + }, + ) + } + testServer := server.NewTestStreamableHTTPServer(mcpServer) defer testServer.Close() + initRequest := mcp.InitializeRequest{ + Params: mcp.InitializeParams{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + ClientInfo: mcp.Implementation{ + Name: "test-client2", + Version: "1.0.0", + }, + }, + } + + t.Run("Can Configure a server with a pre-existing session", func(t *testing.T) { + sessionID := uuid.NewString() + client, err := NewStreamableHttpClient(testServer.URL, transport.WithSession(sessionID)) + if err != nil { + t.Fatalf("create client failed %v", err) + } + if client.IsInitialized() != true { + t.Fatalf("Client is not initialized") + } + }) + t.Run("Can receive notification from server", func(t *testing.T) { client, err := NewStreamableHttpClient(testServer.URL) if err != nil { @@ -68,9 +99,9 @@ func TestHTTPClient(t *testing.T) { return } - notificationNum := 0 + notificationNum := NewSafeMap() client.OnNotification(func(notification mcp.JSONRPCNotification) { - notificationNum += 1 + notificationNum.Increment(notification.Method) }) ctx := context.Background() @@ -81,31 +112,122 @@ func TestHTTPClient(t *testing.T) { } // Initialize - initRequest := mcp.InitializeRequest{} - initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION - initRequest.Params.ClientInfo = mcp.Implementation{ - Name: "test-client", - Version: "1.0.0", - } - _, err = client.Initialize(ctx, initRequest) if err != nil { t.Fatalf("Failed to initialize: %v\n", err) } - request := mcp.CallToolRequest{} - request.Params.Name = "notify" - result, err := client.CallTool(ctx, request) - if err != nil { - t.Fatalf("CallTool failed: %v", err) - } + t.Run("Can receive notifications related to the request", func(t *testing.T) { + request := mcp.CallToolRequest{} + request.Params.Name = "notify" + result, err := client.CallTool(ctx, request) + if err != nil { + t.Fatalf("CallTool failed: %v", err) + } - if len(result.Content) != 1 { - t.Errorf("Expected 1 content item, got %d", len(result.Content)) - } + if len(result.Content) != 1 { + t.Errorf("Expected 1 content item, got %d", len(result.Content)) + } + + if n := notificationNum.Get("notifications/progress"); n != 1 { + t.Errorf("Expected 1 progross notification item, got %d", n) + } + if n := notificationNum.Len(); n != 1 { + t.Errorf("Expected 1 type of notification, got %d", n) + } + }) + + t.Run("Can not receive global notifications from server by default", func(t *testing.T) { + addServerToolfunc("hello1") + time.Sleep(time.Millisecond * 50) + + helloNotifications := notificationNum.Get("hello1") + if helloNotifications != 0 { + t.Errorf("Expected 0 notification item, got %d", helloNotifications) + } + }) + + t.Run("Can receive global notifications from server when WithContinuousListening enabled", func(t *testing.T) { + + client, err := NewStreamableHttpClient(testServer.URL, + transport.WithContinuousListening()) + if err != nil { + t.Fatalf("create client failed %v", err) + return + } + defer client.Close() + + notificationNum := NewSafeMap() + client.OnNotification(func(notification mcp.JSONRPCNotification) { + notificationNum.Increment(notification.Method) + }) + + ctx := context.Background() + + if err := client.Start(ctx); err != nil { + t.Fatalf("Failed to start client: %v", err) + return + } + + // Initialize + _, err = client.Initialize(ctx, initRequest) + if err != nil { + t.Fatalf("Failed to initialize: %v\n", err) + } + + // can receive normal notification + request := mcp.CallToolRequest{} + request.Params.Name = "notify" + _, err = client.CallTool(ctx, request) + if err != nil { + t.Fatalf("CallTool failed: %v", err) + } + + if n := notificationNum.Get("notifications/progress"); n != 1 { + t.Errorf("Expected 1 progross notification item, got %d", n) + } + if n := notificationNum.Len(); n != 1 { + t.Errorf("Expected 1 type of notification, got %d", n) + } + + // can receive global notification + addServerToolfunc("hello2") + time.Sleep(time.Millisecond * 50) // wait for the notification to be sent as upper action is async + + n := notificationNum.Get("notifications/tools/list_changed") + if n != 1 { + t.Errorf("Expected 1 notification item, got %d, %v", n, notificationNum) + } + }) - if notificationNum != 1 { - t.Errorf("Expected 1 notification item, got %d", notificationNum) - } }) } + +type SafeMap struct { + mu sync.RWMutex + data map[string]int +} + +func NewSafeMap() *SafeMap { + return &SafeMap{ + data: make(map[string]int), + } +} + +func (sm *SafeMap) Increment(key string) { + sm.mu.Lock() + defer sm.mu.Unlock() + sm.data[key]++ +} + +func (sm *SafeMap) Get(key string) int { + sm.mu.RLock() + defer sm.mu.RUnlock() + return sm.data[key] +} + +func (sm *SafeMap) Len() int { + sm.mu.RLock() + defer sm.mu.RUnlock() + return len(sm.data) +} diff --git a/client/sampling.go b/client/sampling.go new file mode 100644 index 000000000..245e2c1f7 --- /dev/null +++ b/client/sampling.go @@ -0,0 +1,20 @@ +package client + +import ( + "context" + + "github.com/mark3labs/mcp-go/mcp" +) + +// SamplingHandler defines the interface for handling sampling requests from servers. +// Clients can implement this interface to provide LLM sampling capabilities to servers. +type SamplingHandler interface { + // CreateMessage handles a sampling request from the server and returns the generated message. + // The implementation should: + // 1. Validate the request parameters + // 2. Optionally prompt the user for approval (human-in-the-loop) + // 3. Select an appropriate model based on preferences + // 4. Generate the response using the selected model + // 5. Return the result with model information and stop reason + CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) +} diff --git a/client/sampling_test.go b/client/sampling_test.go new file mode 100644 index 000000000..60f533221 --- /dev/null +++ b/client/sampling_test.go @@ -0,0 +1,274 @@ +package client + +import ( + "context" + "encoding/json" + "testing" + + "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" +) + +// mockSamplingHandler implements SamplingHandler for testing +type mockSamplingHandler struct { + result *mcp.CreateMessageResult + err error +} + +func (m *mockSamplingHandler) CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + if m.err != nil { + return nil, m.err + } + return m.result, nil +} + +func TestClient_HandleSamplingRequest(t *testing.T) { + tests := []struct { + name string + handler SamplingHandler + expectedError string + }{ + { + name: "no handler configured", + handler: nil, + expectedError: "no sampling handler configured", + }, + { + name: "successful sampling", + handler: &mockSamplingHandler{ + result: &mcp.CreateMessageResult{ + SamplingMessage: mcp.SamplingMessage{ + Role: mcp.RoleAssistant, + Content: mcp.TextContent{ + Type: "text", + Text: "Hello, world!", + }, + }, + Model: "test-model", + StopReason: "endTurn", + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client := &Client{samplingHandler: tt.handler} + + request := mcp.CreateMessageRequest{ + CreateMessageParams: mcp.CreateMessageParams{ + Messages: []mcp.SamplingMessage{ + { + Role: mcp.RoleUser, + Content: mcp.TextContent{Type: "text", Text: "Hello"}, + }, + }, + MaxTokens: 100, + }, + } + + result, err := client.handleIncomingRequest(context.Background(), mockJSONRPCRequest(request)) + + if tt.expectedError != "" { + if err == nil { + t.Errorf("expected error %q, got nil", tt.expectedError) + } else if err.Error() != tt.expectedError { + t.Errorf("expected error %q, got %q", tt.expectedError, err.Error()) + } + } else { + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if result == nil { + t.Error("expected result, got nil") + } + } + }) + } +} + +func TestWithSamplingHandler(t *testing.T) { + handler := &mockSamplingHandler{} + client := &Client{} + + option := WithSamplingHandler(handler) + option(client) + + if client.samplingHandler != handler { + t.Error("sampling handler not set correctly") + } +} + +// mockTransport implements transport.Interface for testing +type mockTransport struct { + requestChan chan transport.JSONRPCRequest + responseChan chan *transport.JSONRPCResponse + started bool +} + +func newMockTransport() *mockTransport { + return &mockTransport{ + requestChan: make(chan transport.JSONRPCRequest, 1), + responseChan: make(chan *transport.JSONRPCResponse, 1), + } +} + +func (m *mockTransport) Start(ctx context.Context) error { + m.started = true + return nil +} + +func (m *mockTransport) SendRequest(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) { + m.requestChan <- request + select { + case response := <-m.responseChan: + return response, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +func (m *mockTransport) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error { + return nil +} + +func (m *mockTransport) SetNotificationHandler(handler func(notification mcp.JSONRPCNotification)) { +} + +func (m *mockTransport) Close() error { + return nil +} + +func (m *mockTransport) GetSessionId() string { + return "mock-session-id" +} + +func TestClient_Initialize_WithSampling(t *testing.T) { + handler := &mockSamplingHandler{ + result: &mcp.CreateMessageResult{ + SamplingMessage: mcp.SamplingMessage{ + Role: mcp.RoleAssistant, + Content: mcp.TextContent{ + Type: "text", + Text: "Test response", + }, + }, + Model: "test-model", + StopReason: "endTurn", + }, + } + + // Create mock transport + mockTransport := newMockTransport() + + // Create client with sampling handler and mock transport + client := &Client{ + transport: mockTransport, + samplingHandler: handler, + } + + // Start the client + ctx := context.Background() + err := client.Start(ctx) + if err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + // Prepare mock response for initialization + initResponse := &transport.JSONRPCResponse{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: mcp.NewRequestId(1), + Result: []byte(`{"protocolVersion":"2024-11-05","capabilities":{"logging":{},"prompts":{},"resources":{},"tools":{}},"serverInfo":{"name":"test-server","version":"1.0.0"}}`), + } + + // Send the response in a goroutine + go func() { + mockTransport.responseChan <- initResponse + }() + + // Call Initialize with appropriate parameters + initRequest := mcp.InitializeRequest{ + Params: mcp.InitializeParams{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + ClientInfo: mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + }, + Capabilities: mcp.ClientCapabilities{ + Roots: &struct { + ListChanged bool `json:"listChanged,omitempty"` + }{ + ListChanged: true, + }, + }, + }, + } + + result, err := client.Initialize(ctx, initRequest) + if err != nil { + t.Fatalf("Initialize failed: %v", err) + } + + // Verify the result + if result == nil { + t.Fatal("Initialize result should not be nil") + } + + // Verify that the request was sent through the transport + select { + case request := <-mockTransport.requestChan: + // Verify the request method + if request.Method != "initialize" { + t.Errorf("Expected method 'initialize', got '%s'", request.Method) + } + + // Verify the request has the correct structure + if request.Params == nil { + t.Fatal("Request params should not be nil") + } + + // Parse the params to verify sampling capability is included + paramsBytes, err := json.Marshal(request.Params) + if err != nil { + t.Fatalf("Failed to marshal request params: %v", err) + } + + var params struct { + ProtocolVersion string `json:"protocolVersion"` + ClientInfo mcp.Implementation `json:"clientInfo"` + Capabilities mcp.ClientCapabilities `json:"capabilities"` + } + + err = json.Unmarshal(paramsBytes, ¶ms) + if err != nil { + t.Fatalf("Failed to unmarshal request params: %v", err) + } + + // Verify sampling capability is included in the request + if params.Capabilities.Sampling == nil { + t.Error("Sampling capability should be included in initialization request when handler is configured") + } + + // Verify other expected fields + if params.ProtocolVersion != mcp.LATEST_PROTOCOL_VERSION { + t.Errorf("Expected protocol version '%s', got '%s'", mcp.LATEST_PROTOCOL_VERSION, params.ProtocolVersion) + } + + if params.ClientInfo.Name != "test-client" { + t.Errorf("Expected client name 'test-client', got '%s'", params.ClientInfo.Name) + } + + default: + t.Error("Expected initialization request to be sent through transport") + } +} + +// Helper function to create a mock JSON-RPC request for testing +func mockJSONRPCRequest(mcpRequest mcp.CreateMessageRequest) transport.JSONRPCRequest { + return transport.JSONRPCRequest{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: mcp.NewRequestId(1), + Method: string(mcp.MethodSamplingCreateMessage), + Params: mcpRequest.CreateMessageParams, + } +} diff --git a/client/stdio.go b/client/stdio.go index 100c08a7c..199ec14c3 100644 --- a/client/stdio.go +++ b/client/stdio.go @@ -19,10 +19,26 @@ func NewStdioMCPClient( env []string, args ...string, ) (*Client, error) { + return NewStdioMCPClientWithOptions(command, env, args) +} + +// NewStdioMCPClientWithOptions creates a new stdio-based MCP client that communicates with a subprocess. +// It launches the specified command with given arguments and sets up stdin/stdout pipes for communication. +// Optional configuration functions can be provided to customize the transport before it starts, +// such as setting a custom command function. +// +// NOTICE: NewStdioMCPClientWithOptions automatically starts the underlying transport. +// Don't call the Start method manually. +// This is for backward compatibility. +func NewStdioMCPClientWithOptions( + command string, + env []string, + args []string, + opts ...transport.StdioOption, +) (*Client, error) { + stdioTransport := transport.NewStdioWithOptions(command, env, args, opts...) - stdioTransport := transport.NewStdio(command, env, args...) - err := stdioTransport.Start(context.Background()) - if err != nil { + if err := stdioTransport.Start(context.Background()); err != nil { return nil, fmt.Errorf("failed to start stdio transport: %w", err) } diff --git a/client/stdio_test.go b/client/stdio_test.go index b6faf9bfd..f41e48114 100644 --- a/client/stdio_test.go +++ b/client/stdio_test.go @@ -12,6 +12,9 @@ import ( "testing" "time" + "github.com/stretchr/testify/require" + + "github.com/mark3labs/mcp-go/client/transport" "github.com/mark3labs/mcp-go/mcp" ) @@ -305,3 +308,43 @@ func TestStdioMCPClient(t *testing.T) { } }) } + +func TestStdio_NewStdioMCPClientWithOptions_CreatesAndStartsClient(t *testing.T) { + called := false + + fakeCmdFunc := func(ctx context.Context, command string, args []string, env []string) (*exec.Cmd, error) { + called = true + return exec.CommandContext(ctx, "echo", "started"), nil + } + + client, err := NewStdioMCPClientWithOptions( + "echo", + []string{"FOO=bar"}, + []string{"hello"}, + transport.WithCommandFunc(fakeCmdFunc), + ) + require.NoError(t, err) + require.NotNil(t, client) + t.Cleanup(func() { + _ = client.Close() + }) + require.True(t, called) +} + +func TestStdio_NewStdioMCPClientWithOptions_FailsToStart(t *testing.T) { + // Create a commandFunc that points to a nonexistent binary + badCmdFunc := func(ctx context.Context, command string, args []string, env []string) (*exec.Cmd, error) { + return exec.CommandContext(ctx, "/nonexistent/bar", args...), nil + } + + client, err := NewStdioMCPClientWithOptions( + "foo", + nil, + nil, + transport.WithCommandFunc(badCmdFunc), + ) + + require.Error(t, err) + require.EqualError(t, err, "failed to start stdio transport: failed to start command: fork/exec /nonexistent/bar: no such file or directory") + require.Nil(t, client) +} diff --git a/client/transport/inprocess.go b/client/transport/inprocess.go index 90fc2fae1..0e2393f07 100644 --- a/client/transport/inprocess.go +++ b/client/transport/inprocess.go @@ -68,3 +68,7 @@ func (c *InProcessTransport) SetNotificationHandler(handler func(notification mc func (*InProcessTransport) Close() error { return nil } + +func (c *InProcessTransport) GetSessionId() string { + return "" +} diff --git a/client/transport/interface.go b/client/transport/interface.go index c83c7c65a..5f8ed6180 100644 --- a/client/transport/interface.go +++ b/client/transport/interface.go @@ -29,6 +29,22 @@ type Interface interface { // Close the connection. Close() error + + // GetSessionId returns the session ID of the transport. + GetSessionId() string +} + +// RequestHandler defines a function that handles incoming requests from the server. +type RequestHandler func(ctx context.Context, request JSONRPCRequest) (*JSONRPCResponse, error) + +// BidirectionalInterface extends Interface to support incoming requests from the server. +// This is used for features like sampling where the server can send requests to the client. +type BidirectionalInterface interface { + Interface + + // SetRequestHandler sets the handler for incoming requests from the server. + // The handler should process the request and return a response. + SetRequestHandler(handler RequestHandler) } type JSONRPCRequest struct { @@ -41,10 +57,10 @@ type JSONRPCRequest struct { type JSONRPCResponse struct { JSONRPC string `json:"jsonrpc"` ID mcp.RequestId `json:"id"` - Result json.RawMessage `json:"result"` + Result json.RawMessage `json:"result,omitempty"` Error *struct { Code int `json:"code"` Message string `json:"message"` Data json.RawMessage `json:"data"` - } `json:"error"` + } `json:"error,omitempty"` } diff --git a/client/transport/sse.go b/client/transport/sse.go index b22ff62d4..ffe3247f0 100644 --- a/client/transport/sse.go +++ b/client/transport/sse.go @@ -428,6 +428,12 @@ func (c *SSE) Close() error { return nil } +// GetSessionId returns the session ID of the transport. +// Since SSE does not maintain a session ID, it returns an empty string. +func (c *SSE) GetSessionId() string { + return "" +} + // SendNotification sends a JSON-RPC notification to the server without expecting a response. func (c *SSE) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error { if c.endpoint == nil { diff --git a/client/transport/stdio.go b/client/transport/stdio.go index c300c405f..c36dc2d37 100644 --- a/client/transport/stdio.go +++ b/client/transport/stdio.go @@ -23,6 +23,7 @@ type Stdio struct { env []string cmd *exec.Cmd + cmdFunc CommandFunc stdin io.WriteCloser stdout *bufio.Reader stderr io.ReadCloser @@ -31,6 +32,28 @@ type Stdio struct { done chan struct{} onNotification func(mcp.JSONRPCNotification) notifyMu sync.RWMutex + onRequest RequestHandler + requestMu sync.RWMutex + ctx context.Context + ctxMu sync.RWMutex +} + +// StdioOption defines a function that configures a Stdio transport instance. +// Options can be used to customize the behavior of the transport before it starts, +// such as setting a custom command function. +type StdioOption func(*Stdio) + +// CommandFunc is a factory function that returns a custom exec.Cmd used to launch the MCP subprocess. +// It can be used to apply sandboxing, custom environment control, working directories, etc. +type CommandFunc func(ctx context.Context, command string, env []string, args []string) (*exec.Cmd, error) + +// WithCommandFunc sets a custom command factory function for the stdio transport. +// The CommandFunc is responsible for constructing the exec.Cmd used to launch the subprocess, +// allowing control over attributes like environment, working directory, and system-level sandboxing. +func WithCommandFunc(f CommandFunc) StdioOption { + return func(s *Stdio) { + s.cmdFunc = f + } } // NewIO returns a new stdio-based transport using existing input, output, and @@ -44,6 +67,7 @@ func NewIO(input io.Reader, output io.WriteCloser, logging io.ReadCloser) *Stdio responses: make(map[string]chan *JSONRPCResponse), done: make(chan struct{}), + ctx: context.Background(), } } @@ -55,20 +79,43 @@ func NewStdio( env []string, args ...string, ) *Stdio { + return NewStdioWithOptions(command, env, args) +} - client := &Stdio{ +// NewStdioWithOptions creates a new stdio transport to communicate with a subprocess. +// It launches the specified command with given arguments and sets up stdin/stdout pipes for communication. +// Returns an error if the subprocess cannot be started or the pipes cannot be created. +// Optional configuration functions can be provided to customize the transport before it starts, +// such as setting a custom command factory. +func NewStdioWithOptions( + command string, + env []string, + args []string, + opts ...StdioOption, +) *Stdio { + s := &Stdio{ command: command, args: args, env: env, responses: make(map[string]chan *JSONRPCResponse), done: make(chan struct{}), + ctx: context.Background(), + } + + for _, opt := range opts { + opt(s) } - return client + return s } func (c *Stdio) Start(ctx context.Context) error { + // Store the context for use in request handling + c.ctxMu.Lock() + c.ctx = ctx + c.ctxMu.Unlock() + if err := c.spawnCommand(ctx); err != nil { return err } @@ -83,18 +130,25 @@ func (c *Stdio) Start(ctx context.Context) error { return nil } -// spawnCommand spawns a new process running c.command. +// spawnCommand spawns a new process running the configured command, args, and env. +// If an (optional) cmdFunc custom command factory function was configured, it will be used to construct the subprocess; +// otherwise, the default behavior uses exec.CommandContext with the merged environment. +// Initializes stdin, stdout, and stderr pipes for JSON-RPC communication. func (c *Stdio) spawnCommand(ctx context.Context) error { if c.command == "" { return nil } - cmd := exec.CommandContext(ctx, c.command, c.args...) - - mergedEnv := os.Environ() - mergedEnv = append(mergedEnv, c.env...) + var cmd *exec.Cmd + var err error - cmd.Env = mergedEnv + // Standard behavior if no command func present. + if c.cmdFunc == nil { + cmd = exec.CommandContext(ctx, c.command, c.args...) + cmd.Env = append(os.Environ(), c.env...) + } else if cmd, err = c.cmdFunc(ctx, c.command, c.env, c.args); err != nil { + return err + } stdin, err := cmd.StdinPipe() if err != nil { @@ -148,6 +202,12 @@ func (c *Stdio) Close() error { return nil } +// GetSessionId returns the session ID of the transport. +// Since stdio does not maintain a session ID, it returns an empty string. +func (c *Stdio) GetSessionId() string { + return "" +} + // SetNotificationHandler sets the handler function to be called when a notification is received. // Only one handler can be set at a time; setting a new one replaces the previous handler. func (c *Stdio) SetNotificationHandler( @@ -158,6 +218,14 @@ func (c *Stdio) SetNotificationHandler( c.onNotification = handler } +// SetRequestHandler sets the handler function to be called when a request is received from the server. +// This enables bidirectional communication for features like sampling. +func (c *Stdio) SetRequestHandler(handler RequestHandler) { + c.requestMu.Lock() + defer c.requestMu.Unlock() + c.onRequest = handler +} + // readResponses continuously reads and processes responses from the server's stdout. // It handles both responses to requests and notifications, routing them appropriately. // Runs until the done channel is closed or an error occurs reading from stdout. @@ -175,13 +243,18 @@ func (c *Stdio) readResponses() { return } - var baseMessage JSONRPCResponse + // First try to parse as a generic message to check for ID field + var baseMessage struct { + JSONRPC string `json:"jsonrpc"` + ID *mcp.RequestId `json:"id,omitempty"` + Method string `json:"method,omitempty"` + } if err := json.Unmarshal([]byte(line), &baseMessage); err != nil { continue } - // Handle notification - if baseMessage.ID.IsNil() { + // If it has a method but no ID, it's a notification + if baseMessage.Method != "" && baseMessage.ID == nil { var notification mcp.JSONRPCNotification if err := json.Unmarshal([]byte(line), ¬ification); err != nil { continue @@ -194,15 +267,30 @@ func (c *Stdio) readResponses() { continue } + // If it has a method and an ID, it's an incoming request + if baseMessage.Method != "" && baseMessage.ID != nil { + var request JSONRPCRequest + if err := json.Unmarshal([]byte(line), &request); err == nil { + c.handleIncomingRequest(request) + continue + } + } + + // Otherwise, it's a response to our request + var response JSONRPCResponse + if err := json.Unmarshal([]byte(line), &response); err != nil { + continue + } + // Create string key for map lookup - idKey := baseMessage.ID.String() + idKey := response.ID.String() c.mu.RLock() ch, exists := c.responses[idKey] c.mu.RUnlock() if exists { - ch <- &baseMessage + ch <- &response c.mu.Lock() delete(c.responses, idKey) c.mu.Unlock() @@ -281,6 +369,96 @@ func (c *Stdio) SendNotification( return nil } +// handleIncomingRequest processes incoming requests from the server. +// It calls the registered request handler and sends the response back to the server. +func (c *Stdio) handleIncomingRequest(request JSONRPCRequest) { + c.requestMu.RLock() + handler := c.onRequest + c.requestMu.RUnlock() + + if handler == nil { + // Send error response if no handler is configured + errorResponse := JSONRPCResponse{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: request.ID, + Error: &struct { + Code int `json:"code"` + Message string `json:"message"` + Data json.RawMessage `json:"data"` + }{ + Code: mcp.METHOD_NOT_FOUND, + Message: "No request handler configured", + }, + } + c.sendResponse(errorResponse) + return + } + + // Handle the request in a goroutine to avoid blocking + go func() { + c.ctxMu.RLock() + ctx := c.ctx + c.ctxMu.RUnlock() + + // Check if context is already cancelled before processing + select { + case <-ctx.Done(): + errorResponse := JSONRPCResponse{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: request.ID, + Error: &struct { + Code int `json:"code"` + Message string `json:"message"` + Data json.RawMessage `json:"data"` + }{ + Code: mcp.INTERNAL_ERROR, + Message: ctx.Err().Error(), + }, + } + c.sendResponse(errorResponse) + return + default: + } + + response, err := handler(ctx, request) + + if err != nil { + errorResponse := JSONRPCResponse{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: request.ID, + Error: &struct { + Code int `json:"code"` + Message string `json:"message"` + Data json.RawMessage `json:"data"` + }{ + Code: mcp.INTERNAL_ERROR, + Message: err.Error(), + }, + } + c.sendResponse(errorResponse) + return + } + + if response != nil { + c.sendResponse(*response) + } + }() +} + +// sendResponse sends a response back to the server. +func (c *Stdio) sendResponse(response JSONRPCResponse) { + responseBytes, err := json.Marshal(response) + if err != nil { + fmt.Printf("Error marshaling response: %v\n", err) + return + } + responseBytes = append(responseBytes, '\n') + + if _, err := c.stdin.Write(responseBytes); err != nil { + fmt.Printf("Error writing response: %v\n", err) + } +} + // Stderr returns a reader for the stderr output of the subprocess. // This can be used to capture error messages or logs from the subprocess. func (c *Stdio) Stderr() io.Reader { diff --git a/client/transport/stdio_test.go b/client/transport/stdio_test.go index 3eea5b23f..3c6804f3b 100644 --- a/client/transport/stdio_test.go +++ b/client/transport/stdio_test.go @@ -3,14 +3,19 @@ package transport import ( "context" "encoding/json" + "errors" "fmt" "os" "os/exec" + "path/filepath" "runtime" "sync" + "syscall" "testing" "time" + "github.com/stretchr/testify/require" + "github.com/mark3labs/mcp-go/mcp" ) @@ -148,7 +153,6 @@ func TestStdio(t *testing.T) { }) t.Run("SendNotification & NotificationHandler", func(t *testing.T) { - var wg sync.WaitGroup notificationChan := make(chan mcp.JSONRPCNotification, 1) @@ -181,11 +185,33 @@ func TestStdio(t *testing.T) { defer wg.Done() select { case nt := <-notificationChan: - // We received a notification - responseJson, _ := json.Marshal(nt.Params.AdditionalFields) - requestJson, _ := json.Marshal(notification) - if string(responseJson) != string(requestJson) { - t.Errorf("Notification handler did not send the expected notification: \ngot %s\nexpect %s", responseJson, requestJson) + // We received a notification from the mock server + // The mock server sends a notification with method "debug/test" and the original request as params + if nt.Method != "debug/test" { + t.Errorf("Expected notification method 'debug/test', got '%s'", nt.Method) + return + } + + // The mock server sends the original notification request as params + // We need to extract the original method from the nested structure + paramsJson, _ := json.Marshal(nt.Params) + var originalRequest struct { + Method string `json:"method"` + Params map[string]any `json:"params"` + } + if err := json.Unmarshal(paramsJson, &originalRequest); err != nil { + t.Errorf("Failed to unmarshal notification params: %v", err) + return + } + + if originalRequest.Method != "debug/echo_notification" { + t.Errorf("Expected original method 'debug/echo_notification', got '%s'", originalRequest.Method) + return + } + + // Check if the original params contain our test data + if testValue, ok := originalRequest.Params["test"]; !ok || testValue != "value" { + t.Errorf("Expected test param 'value', got %v", originalRequest.Params["test"]) } case <-time.After(1 * time.Second): @@ -380,7 +406,6 @@ func TestStdio(t *testing.T) { t.Errorf("Expected array with 3 items, got %v", result.Params["array"]) } }) - } func TestStdioErrors(t *testing.T) { @@ -483,5 +508,137 @@ func TestStdioErrors(t *testing.T) { t.Errorf("Expected error when sending request after close, got nil") } }) +} + +func TestStdio_WithCommandFunc(t *testing.T) { + called := false + tmpDir := t.TempDir() + chrootDir := filepath.Join(tmpDir, "sandbox-root") + err := os.MkdirAll(chrootDir, 0o755) + require.NoError(t, err, "failed to create chroot dir") + + fakeCmdFunc := func(ctx context.Context, command string, args []string, env []string) (*exec.Cmd, error) { + called = true + + // Override the args inside our command func. + cmd := exec.CommandContext(ctx, command, "bonjour") + + // Simulate some security-related settings for test purposes. + cmd.Env = []string{"PATH=/usr/bin", "NODE_ENV=production"} + cmd.Dir = tmpDir + + cmd.SysProcAttr = &syscall.SysProcAttr{ + Credential: &syscall.Credential{ + Uid: 1001, + Gid: 1001, + }, + Chroot: chrootDir, + } + + return cmd, nil + } + + stdio := NewStdioWithOptions( + "echo", + []string{"foo=bar"}, + []string{"hello"}, + WithCommandFunc(fakeCmdFunc), + ) + require.NotNil(t, stdio) + require.NotNil(t, stdio.cmdFunc) + + // Manually call the cmdFunc passing the same values as in spawnCommand. + cmd, err := stdio.cmdFunc(context.Background(), "echo", nil, []string{"hello"}) + require.NoError(t, err) + require.True(t, called) + require.NotNil(t, cmd) + require.NotNil(t, cmd.SysProcAttr) + require.Equal(t, chrootDir, cmd.SysProcAttr.Chroot) + require.Equal(t, tmpDir, cmd.Dir) + require.Equal(t, uint32(1001), cmd.SysProcAttr.Credential.Uid) + require.Equal(t, "echo", filepath.Base(cmd.Path)) + require.Len(t, cmd.Args, 2) + require.Contains(t, cmd.Args, "bonjour") + require.Len(t, cmd.Env, 2) + require.Contains(t, cmd.Env, "PATH=/usr/bin") + require.Contains(t, cmd.Env, "NODE_ENV=production") +} + +func TestStdio_SpawnCommand(t *testing.T) { + ctx := context.Background() + t.Setenv("TEST_ENVIRON_VAR", "true") + + // Explicitly not passing any environment, so we can see if it + // is picked up by spawn command merging the os.Environ. + stdio := NewStdio("echo", nil, "hello") + require.NotNil(t, stdio) + + err := stdio.spawnCommand(ctx) + require.NoError(t, err) + + t.Cleanup(func() { + _ = stdio.cmd.Process.Kill() + }) + + require.Equal(t, "echo", filepath.Base(stdio.cmd.Path)) + require.Contains(t, stdio.cmd.Args, "hello") + require.Contains(t, stdio.cmd.Env, "TEST_ENVIRON_VAR=true") +} + +func TestStdio_SpawnCommand_UsesCommandFunc(t *testing.T) { + ctx := context.Background() + t.Setenv("TEST_ENVIRON_VAR", "true") + + stdio := NewStdioWithOptions( + "echo", + nil, + []string{"test"}, + WithCommandFunc(func(ctx context.Context, cmd string, args []string, env []string) (*exec.Cmd, error) { + c := exec.CommandContext(ctx, cmd, "hola") + c.Env = env + return c, nil + }), + ) + require.NotNil(t, stdio) + err := stdio.spawnCommand(ctx) + require.NoError(t, err) + t.Cleanup(func() { + _ = stdio.cmd.Process.Kill() + }) + + require.Equal(t, "echo", filepath.Base(stdio.cmd.Path)) + require.Contains(t, stdio.cmd.Args, "hola") + require.NotContains(t, stdio.cmd.Env, "TEST_ENVIRON_VAR=true") + require.NotNil(t, stdio.stdin) + require.NotNil(t, stdio.stdout) + require.NotNil(t, stdio.stderr) +} + +func TestStdio_SpawnCommand_UsesCommandFunc_Error(t *testing.T) { + ctx := context.Background() + + stdio := NewStdioWithOptions( + "echo", + nil, + []string{"test"}, + WithCommandFunc(func(ctx context.Context, cmd string, args []string, env []string) (*exec.Cmd, error) { + return nil, errors.New("test error") + }), + ) + require.NotNil(t, stdio) + err := stdio.spawnCommand(ctx) + require.Error(t, err) + require.EqualError(t, err, "test error") +} + +func TestStdio_NewStdioWithOptions_AppliesOptions(t *testing.T) { + configured := false + + opt := func(s *Stdio) { + configured = true + } + stdio := NewStdioWithOptions("echo", nil, []string{"test"}, opt) + require.NotNil(t, stdio) + require.True(t, configured, "option was not applied") } diff --git a/client/transport/streamable_http.go b/client/transport/streamable_http.go index 50bde9c28..e358751b3 100644 --- a/client/transport/streamable_http.go +++ b/client/transport/streamable_http.go @@ -17,10 +17,24 @@ import ( "time" "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/util" ) type StreamableHTTPCOption func(*StreamableHTTP) +// WithContinuousListening enables receiving server-to-client notifications when no request is in flight. +// In particular, if you want to receive global notifications from the server (like ToolListChangedNotification), +// you should enable this option. +// +// It will establish a standalone long-live GET HTTP connection to the server. +// https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server +// NOTICE: Even enabled, the server may not support this feature. +func WithContinuousListening() StreamableHTTPCOption { + return func(sc *StreamableHTTP) { + sc.getListeningEnabled = true + } +} + // WithHTTPClient sets a custom HTTP client on the StreamableHTTP transport. func WithHTTPBasicClient(client *http.Client) StreamableHTTPCOption { return func(sc *StreamableHTTP) { @@ -54,6 +68,19 @@ func WithHTTPOAuth(config OAuthConfig) StreamableHTTPCOption { } } +func WithLogger(logger util.Logger) StreamableHTTPCOption { + return func(sc *StreamableHTTP) { + sc.logger = logger + } +} + +// WithSession creates a client with a pre-configured session +func WithSession(sessionID string) StreamableHTTPCOption { + return func(sc *StreamableHTTP) { + sc.sessionID.Store(sessionID) + } +} + // StreamableHTTP implements Streamable HTTP transport. // // It transmits JSON-RPC messages over individual HTTP requests. One message per request. @@ -64,19 +91,22 @@ func WithHTTPOAuth(config OAuthConfig) StreamableHTTPCOption { // // The current implementation does not support the following features: // - batching -// - continuously listening for server notifications when no request is in flight -// (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server) // - resuming stream // (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#resumability-and-redelivery) // - server -> client request type StreamableHTTP struct { - serverURL *url.URL - httpClient *http.Client - headers map[string]string - headerFunc HTTPHeaderFunc + serverURL *url.URL + httpClient *http.Client + headers map[string]string + headerFunc HTTPHeaderFunc + logger util.Logger + getListeningEnabled bool sessionID atomic.Value // string + initialized chan struct{} + initializedOnce sync.Once + notificationHandler func(mcp.JSONRPCNotification) notifyMu sync.RWMutex @@ -95,15 +125,19 @@ func NewStreamableHTTP(serverURL string, options ...StreamableHTTPCOption) (*Str } smc := &StreamableHTTP{ - serverURL: parsedURL, - httpClient: &http.Client{}, - headers: make(map[string]string), - closed: make(chan struct{}), + serverURL: parsedURL, + httpClient: &http.Client{}, + headers: make(map[string]string), + closed: make(chan struct{}), + logger: util.DefaultLogger(), + initialized: make(chan struct{}), } smc.sessionID.Store("") // set initial value to simplify later usage for _, opt := range options { - opt(smc) + if opt != nil { + opt(smc) + } } // If OAuth is configured, set the base URL for metadata discovery @@ -118,7 +152,20 @@ func NewStreamableHTTP(serverURL string, options ...StreamableHTTPCOption) (*Str // Start initiates the HTTP connection to the server. func (c *StreamableHTTP) Start(ctx context.Context) error { - // For Streamable HTTP, we don't need to establish a persistent connection + // For Streamable HTTP, we don't need to establish a persistent connection by default + if c.getListeningEnabled { + go func() { + select { + case <-c.initialized: + ctx, cancel := c.contextAwareOfClientClose(ctx) + defer cancel() + c.listenForever(ctx) + case <-c.closed: + return + } + }() + } + return nil } @@ -142,13 +189,13 @@ func (c *StreamableHTTP) Close() error { defer cancel() req, err := http.NewRequestWithContext(ctx, http.MethodDelete, c.serverURL.String(), nil) if err != nil { - fmt.Printf("failed to create close request\n: %v", err) + c.logger.Errorf("failed to create close request: %v", err) return } req.Header.Set(headerKeySessionID, sessionId) res, err := c.httpClient.Do(req) if err != nil { - fmt.Printf("failed to send close request\n: %v", err) + c.logger.Errorf("failed to send close request: %v", err) return } res.Body.Close() @@ -185,77 +232,29 @@ func (c *StreamableHTTP) SendRequest( request JSONRPCRequest, ) (*JSONRPCResponse, error) { - // Create a combined context that could be canceled when the client is closed - newCtx, cancel := context.WithCancel(ctx) - defer cancel() - go func() { - select { - case <-c.closed: - cancel() - case <-newCtx.Done(): - // The original context was canceled, no need to do anything - } - }() - ctx = newCtx - // Marshal request requestBody, err := json.Marshal(request) if err != nil { return nil, fmt.Errorf("failed to marshal request: %w", err) } - // Create HTTP request - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.serverURL.String(), bytes.NewReader(requestBody)) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) - } - - // Set headers - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Accept", "application/json, text/event-stream") - sessionID := c.sessionID.Load() - if sessionID != "" { - req.Header.Set(headerKeySessionID, sessionID.(string)) - } - for k, v := range c.headers { - req.Header.Set(k, v) - } - - // Add OAuth authorization if configured - if c.oauthHandler != nil { - authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx) - if err != nil { - // If we get an authorization error, return a specific error that can be handled by the client - if err.Error() == "no valid token available, authorization required" { - return nil, &OAuthAuthorizationRequiredError{ - Handler: c.oauthHandler, - } - } - return nil, fmt.Errorf("failed to get authorization header: %w", err) - } - req.Header.Set("Authorization", authHeader) - } - - if c.headerFunc != nil { - for k, v := range c.headerFunc(ctx) { - req.Header.Set(k, v) - } - } + ctx, cancel := c.contextAwareOfClientClose(ctx) + defer cancel() - // Send request - resp, err := c.httpClient.Do(req) + resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(requestBody), "application/json, text/event-stream") if err != nil { - return nil, fmt.Errorf("failed to send request: %w", err) + if errors.Is(err, ErrSessionTerminated) && request.Method == string(mcp.MethodInitialize) { + // If the request is initialize, should not return a SessionTerminated error + // It should be a genuine endpoint-routing issue. + // ( Fall through to return StatusCode checking. ) + } else { + return nil, fmt.Errorf("failed to send request: %w", err) + } } defer resp.Body.Close() // Check if we got an error response if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted { - // handle session closed - if resp.StatusCode == http.StatusNotFound { - c.sessionID.CompareAndSwap(sessionID, "") - return nil, fmt.Errorf("session terminated (404). need to re-initialize") - } // Handle OAuth unauthorized error if resp.StatusCode == http.StatusUnauthorized && c.oauthHandler != nil { @@ -279,6 +278,10 @@ func (c *StreamableHTTP) SendRequest( if sessionID := resp.Header.Get(headerKeySessionID); sessionID != "" { c.sessionID.Store(sessionID) } + + c.initializedOnce.Do(func() { + close(c.initialized) + }) } // Handle different response types @@ -300,16 +303,77 @@ func (c *StreamableHTTP) SendRequest( case "text/event-stream": // Server is using SSE for streaming responses - return c.handleSSEResponse(ctx, resp.Body) + return c.handleSSEResponse(ctx, resp.Body, false) default: return nil, fmt.Errorf("unexpected content type: %s", resp.Header.Get("Content-Type")) } } +func (c *StreamableHTTP) sendHTTP( + ctx context.Context, + method string, + body io.Reader, + acceptType string, +) (resp *http.Response, err error) { + + // Create HTTP request + req, err := http.NewRequestWithContext(ctx, method, c.serverURL.String(), body) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + // Set headers + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", acceptType) + sessionID := c.sessionID.Load().(string) + if sessionID != "" { + req.Header.Set(headerKeySessionID, sessionID) + } + for k, v := range c.headers { + req.Header.Set(k, v) + } + + // Add OAuth authorization if configured + if c.oauthHandler != nil { + authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx) + if err != nil { + // If we get an authorization error, return a specific error that can be handled by the client + if err.Error() == "no valid token available, authorization required" { + return nil, &OAuthAuthorizationRequiredError{ + Handler: c.oauthHandler, + } + } + return nil, fmt.Errorf("failed to get authorization header: %w", err) + } + req.Header.Set("Authorization", authHeader) + } + + if c.headerFunc != nil { + for k, v := range c.headerFunc(ctx) { + req.Header.Set(k, v) + } + } + + // Send request + resp, err = c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + + // universal handling for session terminated + if resp.StatusCode == http.StatusNotFound { + c.sessionID.CompareAndSwap(sessionID, "") + return nil, ErrSessionTerminated + } + + return resp, nil +} + // handleSSEResponse processes an SSE stream for a specific request. // It returns the final result for the request once received, or an error. -func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCloser) (*JSONRPCResponse, error) { +// If ignoreResponse is true, it won't return when a response messge is received. This is for continuous listening. +func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCloser, ignoreResponse bool) (*JSONRPCResponse, error) { // Create a channel for this specific request responseChan := make(chan *JSONRPCResponse, 1) @@ -328,7 +392,7 @@ func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCl var message JSONRPCResponse if err := json.Unmarshal([]byte(data), &message); err != nil { - fmt.Printf("failed to unmarshal message: %v\n", err) + c.logger.Errorf("failed to unmarshal message: %v", err) return } @@ -336,7 +400,7 @@ func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCl if message.ID.IsNil() { var notification mcp.JSONRPCNotification if err := json.Unmarshal([]byte(data), ¬ification); err != nil { - fmt.Printf("failed to unmarshal notification: %v\n", err) + c.logger.Errorf("failed to unmarshal notification: %v", err) return } c.notifyMu.RLock() @@ -347,7 +411,9 @@ func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCl return } - responseChan <- &message + if !ignoreResponse { + responseChan <- &message + } }) }() @@ -393,7 +459,7 @@ func (c *StreamableHTTP) readSSE(ctx context.Context, reader io.ReadCloser, hand case <-ctx.Done(): return default: - fmt.Printf("SSE stream error: %v\n", err) + c.logger.Errorf("SSE stream error: %v", err) return } } @@ -432,44 +498,10 @@ func (c *StreamableHTTP) SendNotification(ctx context.Context, notification mcp. } // Create HTTP request - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.serverURL.String(), bytes.NewReader(requestBody)) - if err != nil { - return fmt.Errorf("failed to create request: %w", err) - } - - // Set headers - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Accept", "application/json, text/event-stream") - if sessionID := c.sessionID.Load(); sessionID != "" { - req.Header.Set(headerKeySessionID, sessionID.(string)) - } - for k, v := range c.headers { - req.Header.Set(k, v) - } - - // Add OAuth authorization if configured - if c.oauthHandler != nil { - authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx) - if err != nil { - // If we get an authorization error, return a specific error that can be handled by the client - if errors.Is(err, ErrOAuthAuthorizationRequired) { - return &OAuthAuthorizationRequiredError{ - Handler: c.oauthHandler, - } - } - return fmt.Errorf("failed to get authorization header: %w", err) - } - req.Header.Set("Authorization", authHeader) - } - - if c.headerFunc != nil { - for k, v := range c.headerFunc(ctx) { - req.Header.Set(k, v) - } - } + ctx, cancel := c.contextAwareOfClientClose(ctx) + defer cancel() - // Send request - resp, err := c.httpClient.Do(req) + resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(requestBody), "application/json, text/event-stream") if err != nil { return fmt.Errorf("failed to send request: %w", err) } @@ -513,3 +545,84 @@ func (c *StreamableHTTP) GetOAuthHandler() *OAuthHandler { func (c *StreamableHTTP) IsOAuthEnabled() bool { return c.oauthHandler != nil } + +func (c *StreamableHTTP) listenForever(ctx context.Context) { + c.logger.Infof("listening to server forever") + for { + err := c.createGETConnectionToServer(ctx) + if errors.Is(err, ErrGetMethodNotAllowed) { + // server does not support listening + c.logger.Errorf("server does not support listening") + return + } + + select { + case <-ctx.Done(): + return + default: + } + + if err != nil { + c.logger.Errorf("failed to listen to server. retry in 1 second: %v", err) + } + time.Sleep(retryInterval) + } +} + +var ( + ErrSessionTerminated = fmt.Errorf("session terminated (404). need to re-initialize") + ErrGetMethodNotAllowed = fmt.Errorf("GET method not allowed") + + retryInterval = 1 * time.Second // a variable is convenient for testing +) + +func (c *StreamableHTTP) createGETConnectionToServer(ctx context.Context) error { + + resp, err := c.sendHTTP(ctx, http.MethodGet, nil, "text/event-stream") + if err != nil { + return fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + // Check if we got an error response + if resp.StatusCode == http.StatusMethodNotAllowed { + return ErrGetMethodNotAllowed + } + + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("request failed with status %d: %s", resp.StatusCode, body) + } + + // handle SSE response + contentType := resp.Header.Get("Content-Type") + if contentType != "text/event-stream" { + return fmt.Errorf("unexpected content type: %s", contentType) + } + + // When ignoreResponse is true, the function will never return expect context is done. + // NOTICE: Due to the ambiguity of the specification, other SDKs may use the GET connection to transfer the response + // messages. To be more compatible, we should handle this response, however, as the transport layer is message-based, + // currently, there is no convenient way to handle this response. + // So we ignore the response here. It's not a bug, but may be not compatible with other SDKs. + _, err = c.handleSSEResponse(ctx, resp.Body, true) + if err != nil { + return fmt.Errorf("failed to handle SSE response: %w", err) + } + + return nil +} + +func (c *StreamableHTTP) contextAwareOfClientClose(ctx context.Context) (context.Context, context.CancelFunc) { + newCtx, cancel := context.WithCancel(ctx) + go func() { + select { + case <-c.closed: + cancel() + case <-newCtx.Done(): + // The original context was canceled + cancel() + } + }() + return newCtx, cancel +} diff --git a/client/transport/streamable_http_test.go b/client/transport/streamable_http_test.go index 4cd5ad19e..25962940e 100644 --- a/client/transport/streamable_http_test.go +++ b/client/transport/streamable_http_test.go @@ -7,6 +7,7 @@ import ( "fmt" "net/http" "net/http/httptest" + "strings" "sync" "testing" "time" @@ -524,3 +525,259 @@ func TestStreamableHTTPErrors(t *testing.T) { }) } + +// ---- continuous listening tests ---- + +// startMockStreamableWithGETSupport starts a test HTTP server that implements +// a minimal Streamable HTTP server for testing purposes with support for GET requests +// to test the continuous listening feature. +func startMockStreamableWithGETSupport(getSupport bool) (string, func(), chan bool, int) { + var sessionID string + var mu sync.Mutex + disconnectCh := make(chan bool, 1) + notificationCount := 0 + var notificationMu sync.Mutex + + sendNotification := func() { + notificationMu.Lock() + notificationCount++ + notificationMu.Unlock() + } + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Handle POST requests for initialization + if r.Method == http.MethodPost { + // Parse incoming JSON-RPC request + var request map[string]any + decoder := json.NewDecoder(r.Body) + if err := decoder.Decode(&request); err != nil { + http.Error(w, fmt.Sprintf("Invalid JSON: %v", err), http.StatusBadRequest) + return + } + + method := request["method"] + if method == "initialize" { + // Generate a new session ID + mu.Lock() + sessionID = fmt.Sprintf("test-session-%d", time.Now().UnixNano()) + mu.Unlock() + w.Header().Set("Mcp-Session-Id", sessionID) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusAccepted) + if err := json.NewEncoder(w).Encode(map[string]any{ + "jsonrpc": "2.0", + "id": request["id"], + "result": "initialized", + }); err != nil { + http.Error(w, "Failed to encode response", http.StatusInternalServerError) + return + } + } + return + } + + // Handle GET requests for continuous listening + if r.Method == http.MethodGet { + if !getSupport { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // Check session ID + if recvSessionID := r.Header.Get("Mcp-Session-Id"); recvSessionID != sessionID { + http.Error(w, "Invalid session ID", http.StatusNotFound) + return + } + + // Setup SSE connection + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + flusher, ok := w.(http.Flusher) + if !ok { + http.Error(w, "Streaming not supported", http.StatusInternalServerError) + return + } + + // Send a notification + notification := map[string]any{ + "jsonrpc": "2.0", + "method": "test/notification", + "params": map[string]any{"message": "Hello from server"}, + } + notificationData, _ := json.Marshal(notification) + fmt.Fprintf(w, "event: message\ndata: %s\n\n", notificationData) + flusher.Flush() + sendNotification() + + // Keep the connection open or disconnect as requested + select { + case <-disconnectCh: + // Force disconnect + return + case <-r.Context().Done(): + // Client disconnected + return + case <-time.After(50 * time.Millisecond): + // Send another notification + notification = map[string]any{ + "jsonrpc": "2.0", + "method": "test/notification", + "params": map[string]any{"message": "Second notification"}, + } + notificationData, _ = json.Marshal(notification) + fmt.Fprintf(w, "event: message\ndata: %s\n\n", notificationData) + flusher.Flush() + sendNotification() + return + } + } else { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + }) + + // Start test server + testServer := httptest.NewServer(handler) + + notificationMu.Lock() + defer notificationMu.Unlock() + + return testServer.URL, testServer.Close, disconnectCh, notificationCount +} + +func TestContinuousListening(t *testing.T) { + retryInterval = 10 * time.Millisecond + // Start mock server with GET support + url, closeServer, disconnectCh, _ := startMockStreamableWithGETSupport(true) + + // Create transport with continuous listening enabled + trans, err := NewStreamableHTTP(url, WithContinuousListening()) + if err != nil { + t.Fatal(err) + } + + // Ensure transport is closed before server to avoid connection refused errors + defer func() { + trans.Close() + closeServer() + }() + + // Setup notification handler + notificationReceived := make(chan struct{}, 10) + trans.SetNotificationHandler(func(notification mcp.JSONRPCNotification) { + notificationReceived <- struct{}{} + }) + + // Start the transport - this will launch listenForever in a goroutine + if err := trans.Start(context.Background()); err != nil { + t.Fatal(err) + } + + // Initialize the transport first + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + initRequest := JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcp.NewRequestId(int64(0)), + Method: "initialize", + } + + _, err = trans.SendRequest(ctx, initRequest) + if err != nil { + t.Fatal(err) + } + + // Wait for notifications to be received + notificationCount := 0 + for notificationCount < 2 { + select { + case <-notificationReceived: + notificationCount++ + case <-time.After(3 * time.Second): + t.Fatalf("Timed out waiting for notifications, received %d", notificationCount) + return + } + } + + // Test server disconnect and reconnect + disconnectCh <- true + time.Sleep(50 * time.Millisecond) // Allow time for reconnection + + // Verify reconnect occurred by receiving more notifications + reconnectNotificationCount := 0 + for reconnectNotificationCount < 2 { + select { + case <-notificationReceived: + reconnectNotificationCount++ + case <-time.After(3 * time.Second): + t.Fatalf("Timed out waiting for notifications after reconnect") + return + } + } +} + +func TestContinuousListeningMethodNotAllowed(t *testing.T) { + + // Start a server that doesn't support GET + url, closeServer, _, _ := startMockStreamableWithGETSupport(false) + + // Setup logger to capture log messages + logChan := make(chan string, 10) + testLogger := &testLogger{logChan: logChan} + + // Create transport with continuous listening enabled and custom logger + trans, err := NewStreamableHTTP(url, WithContinuousListening(), WithLogger(testLogger)) + if err != nil { + t.Fatal(err) + } + + // Ensure transport is closed before server to avoid connection refused errors + defer func() { + trans.Close() + closeServer() + }() + + // Initialize the transport first + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Start the transport + if err := trans.Start(context.Background()); err != nil { + t.Fatal(err) + } + + initRequest := JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcp.NewRequestId(int64(0)), + Method: "initialize", + } + + _, err = trans.SendRequest(ctx, initRequest) + if err != nil { + t.Fatal(err) + } + + // Wait for the error log message that server doesn't support listening + select { + case logMsg := <-logChan: + if !strings.Contains(logMsg, "server does not support listening") { + t.Errorf("Expected error log about server not supporting listening, got: %s", logMsg) + } + case <-time.After(5 * time.Second): + t.Fatal("Timeout waiting for log message") + } +} + +// testLogger is a simple logger for testing +type testLogger struct { + logChan chan string +} + +func (l *testLogger) Infof(format string, args ...any) { + // Intentionally left empty +} + +func (l *testLogger) Errorf(format string, args ...any) { + l.logChan <- fmt.Sprintf(format, args...) +} diff --git a/examples/in_process/main.go b/examples/in_process/main.go new file mode 100644 index 000000000..d01a5e808 --- /dev/null +++ b/examples/in_process/main.go @@ -0,0 +1,100 @@ +package main + +import ( + "context" + "fmt" + "log" + "time" + + "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +// handleDummyTool is a simple tool that returns "foo bar" +func handleDummyTool( + ctx context.Context, + request mcp.CallToolRequest, +) (*mcp.CallToolResult, error) { + return mcp.NewToolResultText("foo bar"), nil +} + +func NewMCPServer() *server.MCPServer { + mcpServer := server.NewMCPServer( + "example-server", + "1.0.0", + server.WithResourceCapabilities(true, true), + server.WithPromptCapabilities(true), + server.WithToolCapabilities(true), + ) + mcpServer.AddTool(mcp.NewTool("dummy_tool", + mcp.WithDescription("A dummy tool that returns foo bar"), + ), handleDummyTool) + + return mcpServer +} + +type MCPClient struct { + client *client.Client + serverInfo *mcp.InitializeResult +} + +// NewMCPClient creates a new MCP client with an in-process MCP server. +func NewMCPClient(ctx context.Context) (*MCPClient, error) { + srv := NewMCPServer() + client, err := client.NewInProcessClient(srv) + if err != nil { + return nil, fmt.Errorf("failed to create in-process client: %w", err) + } + + // Start the client with timeout context + ctxWithTimeout, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + if err := client.Start(ctxWithTimeout); err != nil { + return nil, fmt.Errorf("failed to start client: %w", err) + } + + // Initialize the client + initRequest := mcp.InitializeRequest{} + initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION + initRequest.Params.ClientInfo = mcp.Implementation{ + Name: "Example MCP Client", + Version: "1.0.0", + } + initRequest.Params.Capabilities = mcp.ClientCapabilities{} + + serverInfo, err := client.Initialize(ctx, initRequest) + if err != nil { + return nil, fmt.Errorf("failed to initialize MCP client: %w", err) + } + + return &MCPClient{ + client: client, + serverInfo: serverInfo, + }, nil +} + +func main() { + ctx := context.Background() + client, err := NewMCPClient(ctx) + if err != nil { + log.Fatalf("Failed to create MCP client: %v", err) + } + + toolsRequest := mcp.ListToolsRequest{} + toolsResult, err := client.client.ListTools(ctx, toolsRequest) + if err != nil { + log.Fatalf("Failed to list tools: %v", err) + } + fmt.Println(toolsResult.Tools) + + request := mcp.CallToolRequest{} + request.Params.Name = "dummy_tool" + + result, err := client.client.CallTool(ctx, request) + if err != nil { + log.Fatalf("Failed to call tool: %v", err) + } + fmt.Println(result.Content) +} diff --git a/examples/sampling_client/README.md b/examples/sampling_client/README.md new file mode 100644 index 000000000..7a1d9cb3f --- /dev/null +++ b/examples/sampling_client/README.md @@ -0,0 +1,87 @@ +# MCP Sampling Example Client + +This example demonstrates how to implement an MCP client that supports sampling requests from servers. + +## Features + +- **Sampling Handler**: Implements the `SamplingHandler` interface to process sampling requests +- **Mock LLM**: Provides a mock LLM implementation for demonstration purposes +- **Capability Declaration**: Automatically declares sampling capability when a handler is configured +- **Bidirectional Communication**: Handles incoming requests from the server + +## Mock LLM Handler + +The `MockSamplingHandler` simulates an LLM by: +- Logging the received request parameters +- Generating a mock response that echoes the input +- Returning proper MCP sampling response format + +In a real implementation, you would: +- Integrate with actual LLM APIs (OpenAI, Anthropic, etc.) +- Implement proper model selection based on preferences +- Add human-in-the-loop approval mechanisms +- Handle rate limiting and error cases + +## Usage + +Build the client: + +```bash +go build -o sampling_client +``` + +Run with the sampling server: + +```bash +./sampling_client ../sampling_server/sampling_server +``` + +Or with any other MCP server that supports sampling: + +```bash +./sampling_client /path/to/your/mcp/server +``` + +## Implementation Details + +1. **Sampling Handler**: Implements `client.SamplingHandler` interface +2. **Client Configuration**: Uses `client.WithSamplingHandler()` to enable sampling +3. **Automatic Capability**: Sampling capability is automatically declared during initialization +4. **Request Processing**: Handles incoming `sampling/createMessage` requests from servers + +## Sample Output + +``` +Connected to server: sampling-example-server v1.0.0 +Available tools: + - ask_llm: Ask the LLM a question using sampling + - greet: Greet the user + +--- Testing greet tool --- +Greet result: Hello, Sampling Demo User! This server supports sampling - try using the ask_llm tool! + +--- Testing ask_llm tool (with sampling) --- +Mock LLM received: What is the capital of France? +System prompt: You are a helpful geography assistant. +Max tokens: 1000 +Temperature: 0.700000 +Ask LLM result: LLM Response (model: mock-llm-v1): Mock LLM response to: 'What is the capital of France?'. This is a simulated response from a mock LLM handler. +``` + +## Real LLM Integration + +To integrate with a real LLM, replace the `MockSamplingHandler` with an implementation that: + +```go +type RealSamplingHandler struct { + apiKey string + client *openai.Client // or other LLM client +} + +func (h *RealSamplingHandler) CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + // Convert MCP request to LLM API format + // Call LLM API + // Convert response back to MCP format + // Return result +} +``` \ No newline at end of file diff --git a/examples/sampling_client/main.go b/examples/sampling_client/main.go new file mode 100644 index 000000000..67b3840b0 --- /dev/null +++ b/examples/sampling_client/main.go @@ -0,0 +1,190 @@ +package main + +import ( + "context" + "fmt" + "log" + "os" + + "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" +) + +// MockSamplingHandler implements the SamplingHandler interface for demonstration. +// In a real implementation, this would integrate with an actual LLM API. +type MockSamplingHandler struct{} + +func (h *MockSamplingHandler) CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + // Extract the user's message + if len(request.Messages) == 0 { + return nil, fmt.Errorf("no messages provided") + } + + userMessage := request.Messages[0] + var userText string + + // Extract text from the content + switch content := userMessage.Content.(type) { + case mcp.TextContent: + userText = content.Text + case map[string]interface{}: + // Handle case where content is unmarshaled as a map + if text, ok := content["text"].(string); ok { + userText = text + } else { + userText = fmt.Sprintf("%v", content) + } + default: + userText = fmt.Sprintf("%v", content) + } + + // Simulate LLM processing + log.Printf("Mock LLM received: %s", userText) + log.Printf("System prompt: %s", request.SystemPrompt) + log.Printf("Max tokens: %d", request.MaxTokens) + log.Printf("Temperature: %f", request.Temperature) + + // Generate a mock response + responseText := fmt.Sprintf("Mock LLM response to: '%s'. This is a simulated response from a mock LLM handler.", userText) + + log.Printf("Mock LLM generating response: %s", responseText) + + result := &mcp.CreateMessageResult{ + SamplingMessage: mcp.SamplingMessage{ + Role: mcp.RoleAssistant, + Content: mcp.TextContent{ + Type: "text", + Text: responseText, + }, + }, + Model: "mock-llm-v1", + StopReason: "endTurn", + } + + log.Printf("Mock LLM returning result: %+v", result) + return result, nil +} + +func main() { + if len(os.Args) < 2 { + log.Fatal("Usage: sampling_client ") + } + + serverCommand := os.Args[1] + serverArgs := os.Args[2:] + + // Create stdio transport to communicate with the server + stdio := transport.NewStdio(serverCommand, nil, serverArgs...) + + // Create sampling handler + samplingHandler := &MockSamplingHandler{} + + // Create client with sampling capability + mcpClient := client.NewClient(stdio, client.WithSamplingHandler(samplingHandler)) + + ctx := context.Background() + + // Start the client + if err := mcpClient.Start(ctx); err != nil { + log.Fatalf("Failed to start client: %v", err) + } + defer mcpClient.Close() + + // Initialize the connection + initResult, err := mcpClient.Initialize(ctx, mcp.InitializeRequest{ + Params: mcp.InitializeParams{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + ClientInfo: mcp.Implementation{ + Name: "sampling-example-client", + Version: "1.0.0", + }, + Capabilities: mcp.ClientCapabilities{ + // Sampling capability will be automatically added by WithSamplingHandler + }, + }, + }) + if err != nil { + log.Fatalf("Failed to initialize: %v", err) + } + + log.Printf("Connected to server: %s v%s", initResult.ServerInfo.Name, initResult.ServerInfo.Version) + log.Printf("Server capabilities: %+v", initResult.Capabilities) + + // List available tools + toolsResult, err := mcpClient.ListTools(ctx, mcp.ListToolsRequest{}) + if err != nil { + log.Fatalf("Failed to list tools: %v", err) + } + + log.Printf("Available tools:") + for _, tool := range toolsResult.Tools { + log.Printf(" - %s: %s", tool.Name, tool.Description) + } + + // Test the greeting tool first + log.Println("\n--- Testing greet tool ---") + greetResult, err := mcpClient.CallTool(ctx, mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "greet", + Arguments: map[string]any{ + "name": "Sampling Demo User", + }, + }, + }) + if err != nil { + log.Printf("Error calling greet tool: %v", err) + } else { + log.Printf("Greet result: %+v", greetResult) + for _, content := range greetResult.Content { + if textContent, ok := content.(mcp.TextContent); ok { + log.Printf(" %s", textContent.Text) + } + } + } + + // Test the sampling tool + log.Println("\n--- Testing ask_llm tool (with sampling) ---") + askResult, err := mcpClient.CallTool(ctx, mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "ask_llm", + Arguments: map[string]any{ + "question": "What is the capital of France?", + "system_prompt": "You are a helpful geography assistant.", + }, + }, + }) + if err != nil { + log.Printf("Error calling ask_llm tool: %v", err) + } else { + log.Printf("Ask LLM result: %+v", askResult) + for _, content := range askResult.Content { + if textContent, ok := content.(mcp.TextContent); ok { + log.Printf(" %s", textContent.Text) + } + } + } + + // Test another sampling request + log.Println("\n--- Testing ask_llm tool with different question ---") + askResult2, err := mcpClient.CallTool(ctx, mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "ask_llm", + Arguments: map[string]any{ + "question": "Explain quantum computing in simple terms.", + }, + }, + }) + if err != nil { + log.Printf("Error calling ask_llm tool: %v", err) + } else { + log.Printf("Ask LLM result 2: %+v", askResult2) + for _, content := range askResult2.Content { + if textContent, ok := content.(mcp.TextContent); ok { + log.Printf(" %s", textContent.Text) + } + } + } + + log.Println("\n--- Sampling demo completed ---") +} diff --git a/examples/sampling_client/sampling_client b/examples/sampling_client/sampling_client new file mode 100755 index 000000000..d18fa5fa5 Binary files /dev/null and b/examples/sampling_client/sampling_client differ diff --git a/examples/sampling_server/README.md b/examples/sampling_server/README.md new file mode 100644 index 000000000..53c823792 --- /dev/null +++ b/examples/sampling_server/README.md @@ -0,0 +1,52 @@ +# MCP Sampling Example Server + +This example demonstrates how to implement an MCP server that uses sampling to request LLM completions from clients. + +## Features + +- **Sampling Support**: The server can request LLM completions from clients that support sampling +- **Tool Integration**: Shows how to use sampling within tool implementations +- **Bidirectional Communication**: Demonstrates server-to-client requests + +## Tools + +### `ask_llm` +Asks the LLM a question using sampling. This tool demonstrates how servers can leverage client-side LLM capabilities. + +**Parameters:** +- `question` (required): The question to ask the LLM +- `system_prompt` (optional): System prompt to provide context + +### `greet` +A simple greeting tool that doesn't use sampling, for comparison. + +**Parameters:** +- `name` (required): Name of the person to greet + +## Usage + +Build and run the server: + +```bash +go build -o sampling_server +./sampling_server +``` + +The server communicates via stdio and expects to be connected to an MCP client that supports sampling. + +## Implementation Details + +1. **Enable Sampling**: The server calls `mcpServer.EnableSampling()` to declare sampling capability +2. **Request Sampling**: Tools use `mcpServer.RequestSampling(ctx, request)` to send sampling requests to the client +3. **Handle Responses**: The server receives and processes the LLM responses from the client via bidirectional stdio communication +4. **Response Routing**: Incoming responses are automatically routed to the correct pending request using request IDs + +## Testing + +Use the companion `sampling_client` example to test this server: + +```bash +cd ../sampling_client +go build -o sampling_client +./sampling_client ../sampling_server/sampling_server +``` \ No newline at end of file diff --git a/examples/sampling_server/main.go b/examples/sampling_server/main.go new file mode 100644 index 000000000..c3bcf4902 --- /dev/null +++ b/examples/sampling_server/main.go @@ -0,0 +1,145 @@ +package main + +import ( + "context" + "fmt" + "log" + "time" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +func main() { + // Create a new MCP server + mcpServer := server.NewMCPServer("sampling-example-server", "1.0.0") + + // Enable sampling capability + mcpServer.EnableSampling() + + // Add a tool that uses sampling + mcpServer.AddTool(mcp.Tool{ + Name: "ask_llm", + Description: "Ask the LLM a question using sampling", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]any{ + "question": map[string]any{ + "type": "string", + "description": "The question to ask the LLM", + }, + "system_prompt": map[string]any{ + "type": "string", + "description": "Optional system prompt to provide context", + }, + }, + Required: []string{"question"}, + }, + }, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Extract parameters using helper methods + question, err := request.RequireString("question") + if err != nil { + return nil, err + } + + systemPrompt := request.GetString("system_prompt", "You are a helpful assistant.") + // Create sampling request + samplingRequest := mcp.CreateMessageRequest{ + CreateMessageParams: mcp.CreateMessageParams{ + Messages: []mcp.SamplingMessage{ + { + Role: mcp.RoleUser, + Content: mcp.TextContent{ + Type: "text", + Text: question, + }, + }, + }, + SystemPrompt: systemPrompt, + MaxTokens: 1000, + Temperature: 0.7, + }, + } + + // Request sampling from the client + samplingCtx, cancel := context.WithTimeout(ctx, 5*time.Minute) + defer cancel() + serverFromCtx := server.ServerFromContext(ctx) + result, err := serverFromCtx.RequestSampling(samplingCtx, samplingRequest) + if err != nil { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: fmt.Sprintf("Error requesting sampling: %v", err), + }, + }, + IsError: true, + }, nil + } + + // Return the LLM's response + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: fmt.Sprintf("LLM Response (model: %s): %s", result.Model, getTextFromContent(result.Content)), + }, + }, + }, nil + }) + + // Add a simple greeting tool + mcpServer.AddTool(mcp.Tool{ + Name: "greet", + Description: "Greet the user", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]any{ + "name": map[string]any{ + "type": "string", + "description": "Name of the person to greet", + }, + }, + Required: []string{"name"}, + }, + }, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + name, err := request.RequireString("name") + if err != nil { + return nil, err + } + + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: fmt.Sprintf("Hello, %s! This server supports sampling - try using the ask_llm tool!", name), + }, + }, + }, nil + }) + + // Start the stdio server + log.Println("Starting sampling example server...") + if err := server.ServeStdio(mcpServer); err != nil { + log.Fatalf("Server error: %v", err) + } +} + +// Helper function to extract text from content +func getTextFromContent(content interface{}) string { + switch c := content.(type) { + case mcp.TextContent: + return c.Text + case map[string]interface{}: + // Handle JSON unmarshaled content + if text, ok := c["text"].(string); ok { + return text + } + return fmt.Sprintf("%v", content) + case string: + return c + default: + return fmt.Sprintf("%v", content) + } +} diff --git a/examples/sampling_server/sampling_server b/examples/sampling_server/sampling_server new file mode 100755 index 000000000..9e4d0dfcb Binary files /dev/null and b/examples/sampling_server/sampling_server differ diff --git a/examples/simple_client/main.go b/examples/simple_client/main.go index 5deb99113..c0f48593a 100644 --- a/examples/simple_client/main.go +++ b/examples/simple_client/main.go @@ -54,6 +54,11 @@ func main() { // Create client with the transport c = client.NewClient(stdioTransport) + // Start the client + if err := c.Start(ctx); err != nil { + log.Fatalf("Failed to start client: %v", err) + } + // Set up logging for stderr if available if stderr, ok := client.GetStderr(c); ok { go func() { @@ -76,6 +81,12 @@ func main() { fmt.Println("Initializing HTTP client...") // Create HTTP transport httpTransport, err := transport.NewStreamableHTTP(*httpURL) + // NOTE: the default streamableHTTP transport is not 100% identical to the stdio client. + // By default, it could not receive global notifications (e.g. toolListChanged). + // You need to enable the `WithContinuousListening()` option to establish a long-live connection, + // and receive the notifications any time the server sends them. + // + // httpTransport, err := transport.NewStreamableHTTP(*httpURL, transport.WithContinuousListening()) if err != nil { log.Fatalf("Failed to create HTTP transport: %v", err) } @@ -84,11 +95,6 @@ func main() { c = client.NewClient(httpTransport) } - // Start the client - if err := c.Start(ctx); err != nil { - log.Fatalf("Failed to start client: %v", err) - } - // Set up notification handler c.OnNotification(func(notification mcp.JSONRPCNotification) { fmt.Printf("Received notification: %s\n", notification.Method) diff --git a/mcp/tools.go b/mcp/tools.go index 5f3524b02..3e3931b09 100644 --- a/mcp/tools.go +++ b/mcp/tools.go @@ -945,7 +945,20 @@ func PropertyNames(schema map[string]any) PropertyOption { } } -// Items defines the schema for array items +// Items defines the schema for array items. +// Accepts any schema definition for maximum flexibility. +// +// Example: +// +// Items(map[string]any{ +// "type": "object", +// "properties": map[string]any{ +// "name": map[string]any{"type": "string"}, +// "age": map[string]any{"type": "number"}, +// }, +// }) +// +// For simple types, use ItemsString(), ItemsNumber(), ItemsBoolean() instead. func Items(schema any) PropertyOption { return func(schemaMap map[string]any) { schemaMap["items"] = schema @@ -972,3 +985,94 @@ func UniqueItems(unique bool) PropertyOption { schema["uniqueItems"] = unique } } + +// WithStringItems configures an array's items to be of type string. +// +// Supported options: Description(), DefaultString(), Enum(), MaxLength(), MinLength(), Pattern() +// Note: Options like Required() are not valid for item schemas and will be ignored. +// +// Examples: +// +// mcp.WithArray("tags", mcp.WithStringItems()) +// mcp.WithArray("colors", mcp.WithStringItems(mcp.Enum("red", "green", "blue"))) +// mcp.WithArray("names", mcp.WithStringItems(mcp.MinLength(1), mcp.MaxLength(50))) +// +// Limitations: Only supports simple string arrays. Use Items() for complex objects. +func WithStringItems(opts ...PropertyOption) PropertyOption { + return func(schema map[string]any) { + itemSchema := map[string]any{ + "type": "string", + } + + for _, opt := range opts { + opt(itemSchema) + } + + schema["items"] = itemSchema + } +} + +// WithStringEnumItems configures an array's items to be of type string with a specified enum. +// Example: +// +// mcp.WithArray("priority", mcp.WithStringEnumItems([]string{"low", "medium", "high"})) +// +// Limitations: Only supports string enums. Use WithStringItems(Enum(...)) for more flexibility. +func WithStringEnumItems(values []string) PropertyOption { + return func(schema map[string]any) { + schema["items"] = map[string]any{ + "type": "string", + "enum": values, + } + } +} + +// WithNumberItems configures an array's items to be of type number. +// +// Supported options: Description(), DefaultNumber(), Min(), Max(), MultipleOf() +// Note: Options like Required() are not valid for item schemas and will be ignored. +// +// Examples: +// +// mcp.WithArray("scores", mcp.WithNumberItems(mcp.Min(0), mcp.Max(100))) +// mcp.WithArray("prices", mcp.WithNumberItems(mcp.Min(0))) +// +// Limitations: Only supports simple number arrays. Use Items() for complex objects. +func WithNumberItems(opts ...PropertyOption) PropertyOption { + return func(schema map[string]any) { + itemSchema := map[string]any{ + "type": "number", + } + + for _, opt := range opts { + opt(itemSchema) + } + + schema["items"] = itemSchema + } +} + +// WithBooleanItems configures an array's items to be of type boolean. +// +// Supported options: Description(), DefaultBool() +// Note: Options like Required() are not valid for item schemas and will be ignored. +// +// Examples: +// +// mcp.WithArray("flags", mcp.WithBooleanItems()) +// mcp.WithArray("permissions", mcp.WithBooleanItems(mcp.Description("User permissions"))) +// +// Limitations: Only supports simple boolean arrays. Use Items() for complex objects. +func WithBooleanItems(opts ...PropertyOption) PropertyOption { + return func(schema map[string]any) { + itemSchema := map[string]any{ + "type": "boolean", + } + + for _, opt := range opts { + opt(itemSchema) + } + + schema["items"] = itemSchema + } +} diff --git a/mcp/tools_test.go b/mcp/tools_test.go index 7f2640b94..0cd71230e 100644 --- a/mcp/tools_test.go +++ b/mcp/tools_test.go @@ -528,3 +528,187 @@ func TestFlexibleArgumentsJSONMarshalUnmarshal(t *testing.T) { assert.Equal(t, "value1", args["key1"]) assert.Equal(t, float64(123), args["key2"]) // JSON numbers are unmarshaled as float64 } + +// TestNewItemsAPICompatibility tests that the new Items API functions +// generate the same schema as the original Items() function with manual schema objects +func TestNewItemsAPICompatibility(t *testing.T) { + tests := []struct { + name string + oldTool Tool + newTool Tool + }{ + { + name: "WithStringItems basic", + oldTool: NewTool("old-string-array", + WithDescription("Tool with string array using old API"), + WithArray("items", + Description("List of string items"), + Items(map[string]any{ + "type": "string", + }), + ), + ), + newTool: NewTool("new-string-array", + WithDescription("Tool with string array using new API"), + WithArray("items", + Description("List of string items"), + WithStringItems(), + ), + ), + }, + { + name: "WithStringEnumItems", + oldTool: NewTool("old-enum-array", + WithDescription("Tool with enum array using old API"), + WithArray("status", + Description("Filter by status"), + Items(map[string]any{ + "type": "string", + "enum": []string{"active", "inactive", "pending"}, + }), + ), + ), + newTool: NewTool("new-enum-array", + WithDescription("Tool with enum array using new API"), + WithArray("status", + Description("Filter by status"), + WithStringEnumItems([]string{"active", "inactive", "pending"}), + ), + ), + }, + { + name: "WithStringItems with options", + oldTool: NewTool("old-string-with-opts", + WithDescription("Tool with string array with options using old API"), + WithArray("names", + Description("List of names"), + Items(map[string]any{ + "type": "string", + "minLength": 1, + "maxLength": 50, + }), + ), + ), + newTool: NewTool("new-string-with-opts", + WithDescription("Tool with string array with options using new API"), + WithArray("names", + Description("List of names"), + WithStringItems(MinLength(1), MaxLength(50)), + ), + ), + }, + { + name: "WithNumberItems basic", + oldTool: NewTool("old-number-array", + WithDescription("Tool with number array using old API"), + WithArray("scores", + Description("List of scores"), + Items(map[string]any{ + "type": "number", + }), + ), + ), + newTool: NewTool("new-number-array", + WithDescription("Tool with number array using new API"), + WithArray("scores", + Description("List of scores"), + WithNumberItems(), + ), + ), + }, + { + name: "WithNumberItems with constraints", + oldTool: NewTool("old-number-with-constraints", + WithDescription("Tool with constrained number array using old API"), + WithArray("ratings", + Description("List of ratings"), + Items(map[string]any{ + "type": "number", + "minimum": 0.0, + "maximum": 10.0, + }), + ), + ), + newTool: NewTool("new-number-with-constraints", + WithDescription("Tool with constrained number array using new API"), + WithArray("ratings", + Description("List of ratings"), + WithNumberItems(Min(0), Max(10)), + ), + ), + }, + { + name: "WithBooleanItems basic", + oldTool: NewTool("old-boolean-array", + WithDescription("Tool with boolean array using old API"), + WithArray("flags", + Description("List of feature flags"), + Items(map[string]any{ + "type": "boolean", + }), + ), + ), + newTool: NewTool("new-boolean-array", + WithDescription("Tool with boolean array using new API"), + WithArray("flags", + Description("List of feature flags"), + WithBooleanItems(), + ), + ), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Marshal both tools to JSON + oldData, err := json.Marshal(tt.oldTool) + assert.NoError(t, err) + + newData, err := json.Marshal(tt.newTool) + assert.NoError(t, err) + + // Unmarshal to maps for comparison + var oldResult, newResult map[string]any + err = json.Unmarshal(oldData, &oldResult) + assert.NoError(t, err) + + err = json.Unmarshal(newData, &newResult) + assert.NoError(t, err) + + // Compare the inputSchema properties (ignoring tool names and descriptions) + oldSchema := oldResult["inputSchema"].(map[string]any) + newSchema := newResult["inputSchema"].(map[string]any) + + oldProperties := oldSchema["properties"].(map[string]any) + newProperties := newSchema["properties"].(map[string]any) + + // Get the array property (should be the only one in these tests) + var oldArrayProp, newArrayProp map[string]any + for _, prop := range oldProperties { + if propMap, ok := prop.(map[string]any); ok && propMap["type"] == "array" { + oldArrayProp = propMap + break + } + } + for _, prop := range newProperties { + if propMap, ok := prop.(map[string]any); ok && propMap["type"] == "array" { + newArrayProp = propMap + break + } + } + + assert.NotNil(t, oldArrayProp, "Old tool should have array property") + assert.NotNil(t, newArrayProp, "New tool should have array property") + + // Compare the items schema - this is the critical part + oldItems := oldArrayProp["items"] + newItems := newArrayProp["items"] + + assert.Equal(t, oldItems, newItems, "Items schema should be identical between old and new API") + + // Also compare other array properties like description + assert.Equal(t, oldArrayProp["description"], newArrayProp["description"], "Array descriptions should match") + assert.Equal(t, oldArrayProp["type"], newArrayProp["type"], "Array types should match") + }) + } +} diff --git a/mcp/types.go b/mcp/types.go index 0091d2e42..241b55ce9 100644 --- a/mcp/types.go +++ b/mcp/types.go @@ -763,6 +763,11 @@ const ( /* Sampling */ +const ( + // MethodSamplingCreateMessage allows servers to request LLM completions from clients + MethodSamplingCreateMessage MCPMethod = "sampling/createMessage" +) + // CreateMessageRequest is a request from the server to sample an LLM via the // client. The client has full discretion over which model to select. The client // should also inform the user before beginning sampling, to allow them to inspect @@ -865,6 +870,22 @@ type AudioContent struct { func (AudioContent) isContent() {} +// ResourceLink represents a link to a resource that the client can access. +type ResourceLink struct { + Annotated + Type string `json:"type"` // Must be "resource_link" + // The URI of the resource. + URI string `json:"uri"` + // The name of the resource. + Name string `json:"name"` + // The description of the resource. + Description string `json:"description"` + // The MIME type of the resource. + MIMEType string `json:"mimeType"` +} + +func (ResourceLink) isContent() {} + // EmbeddedResource represents the contents of a resource, embedded into a prompt or tool call result. // // It is up to the client how best to render embedded resources for the diff --git a/mcp/utils.go b/mcp/utils.go index 55bef7a99..3e652efd7 100644 --- a/mcp/utils.go +++ b/mcp/utils.go @@ -222,6 +222,17 @@ func NewAudioContent(data, mimeType string) AudioContent { } } +// Helper function to create a new ResourceLink +func NewResourceLink(uri, name, description, mimeType string) ResourceLink { + return ResourceLink{ + Type: "resource_link", + URI: uri, + Name: name, + Description: description, + MIMEType: mimeType, + } +} + // Helper function to create a new EmbeddedResource func NewEmbeddedResource(resource ResourceContents) EmbeddedResource { return EmbeddedResource{ @@ -476,6 +487,16 @@ func ParseContent(contentMap map[string]any) (Content, error) { } return NewAudioContent(data, mimeType), nil + case "resource_link": + uri := ExtractString(contentMap, "uri") + name := ExtractString(contentMap, "name") + description := ExtractString(contentMap, "description") + mimeType := ExtractString(contentMap, "mimeType") + if uri == "" || name == "" { + return nil, fmt.Errorf("resource_link uri or name is missing") + } + return NewResourceLink(uri, name, description, mimeType), nil + case "resource": resourceMap := ExtractMap(contentMap, "resource") if resourceMap == nil { diff --git a/mcptest/mcptest.go b/mcptest/mcptest.go index 232eac5df..bc7ccc0fa 100644 --- a/mcptest/mcptest.go +++ b/mcptest/mcptest.go @@ -20,9 +20,10 @@ import ( type Server struct { name string - tools []server.ServerTool - prompts []server.ServerPrompt - resources []server.ServerResource + tools []server.ServerTool + prompts []server.ServerPrompt + resources []server.ServerResource + resourceTemplates []ServerResourceTemplate cancel func() @@ -106,6 +107,25 @@ func (s *Server) AddResources(resources ...server.ServerResource) { s.resources = append(s.resources, resources...) } +// ServerResourceTemplate combines a ResourceTemplate with its handler function. +type ServerResourceTemplate struct { + Template mcp.ResourceTemplate + Handler server.ResourceTemplateHandlerFunc +} + +// AddResourceTemplate adds a resource template to an unstarted server. +func (s *Server) AddResourceTemplate(template mcp.ResourceTemplate, handler server.ResourceTemplateHandlerFunc) { + s.resourceTemplates = append(s.resourceTemplates, ServerResourceTemplate{ + Template: template, + Handler: handler, + }) +} + +// AddResourceTemplates adds multiple resource templates to an unstarted server. +func (s *Server) AddResourceTemplates(templates ...ServerResourceTemplate) { + s.resourceTemplates = append(s.resourceTemplates, templates...) +} + // Start starts the server in a goroutine. Make sure to defer Close() after Start(). // When using NewServer(), the returned server is already started. func (s *Server) Start(ctx context.Context) error { @@ -122,6 +142,10 @@ func (s *Server) Start(ctx context.Context) error { mcpServer.AddTools(s.tools...) mcpServer.AddPrompts(s.prompts...) mcpServer.AddResources(s.resources...) + + for _, template := range s.resourceTemplates { + mcpServer.AddResourceTemplate(template.Template, template.Handler) + } logger := log.New(&s.logBuffer, "", 0) diff --git a/mcptest/mcptest_test.go b/mcptest/mcptest_test.go index 0ab9b276e..18922cb84 100644 --- a/mcptest/mcptest_test.go +++ b/mcptest/mcptest_test.go @@ -187,3 +187,79 @@ func TestServerWithResource(t *testing.T) { t.Errorf("Got %q, want %q", textContent.Text, want) } } + +func TestServerWithResourceTemplate(t *testing.T) { + ctx := context.Background() + + srv := mcptest.NewUnstartedServer(t) + defer srv.Close() + + template := mcp.NewResourceTemplate( + "file://users/{userId}/documents/{docId}", + "User Document", + mcp.WithTemplateDescription("A user's document"), + mcp.WithTemplateMIMEType("text/plain"), + ) + + handler := func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + if request.Params.Arguments == nil { + return nil, fmt.Errorf("expected arguments to be populated from URI template") + } + + userIds, ok := request.Params.Arguments["userId"].([]string) + if !ok { + return nil, fmt.Errorf("expected userId argument to be populated from URI template") + } + if len(userIds) != 1 { + return nil, fmt.Errorf("expected userId to have one value, but got %d", len(userIds)) + } + if userIds[0] != "john" { + return nil, fmt.Errorf("expected userId argument to be 'john', got %s", userIds[0]) + } + + docIds, ok := request.Params.Arguments["docId"].([]string) + if !ok { + return nil, fmt.Errorf("expected docId argument to be populated from URI template") + } + if len(docIds) != 1 { + return nil, fmt.Errorf("expected docId to have one value, but got %d", len(docIds)) + } + if docIds[0] != "readme.txt" { + return nil, fmt.Errorf("expected docId argument to be 'readme.txt', got %v", docIds) + } + + return []mcp.ResourceContents{ + mcp.TextResourceContents{ + URI: request.Params.URI, + MIMEType: "text/plain", + Text: fmt.Sprintf("Document %s for user %s", docIds[0], userIds[0]), + }, + }, nil + } + + srv.AddResourceTemplate(template, handler) + + err := srv.Start(ctx) + if err != nil { + t.Fatal(err) + } + + // Test reading a resource that matches the template + var readReq mcp.ReadResourceRequest + readReq.Params.URI = "file://users/john/documents/readme.txt" + readResult, err := srv.Client().ReadResource(ctx, readReq) + if err != nil { + t.Fatal("ReadResource:", err) + } + if len(readResult.Contents) != 1 { + t.Fatalf("Expected 1 content, got %d", len(readResult.Contents)) + } + textContent, ok := readResult.Contents[0].(mcp.TextResourceContents) + if !ok { + t.Fatalf("Expected TextResourceContents, got %T", readResult.Contents[0]) + } + want := "Document readme.txt for user john" + if textContent.Text != want { + t.Errorf("Got %q, want %q", textContent.Text, want) + } +} diff --git a/sampling_server b/sampling_server new file mode 100755 index 000000000..ce83fb4f3 Binary files /dev/null and b/sampling_server differ diff --git a/server/sampling.go b/server/sampling.go new file mode 100644 index 000000000..b633b24d0 --- /dev/null +++ b/server/sampling.go @@ -0,0 +1,37 @@ +package server + +import ( + "context" + "fmt" + + "github.com/mark3labs/mcp-go/mcp" +) + +// EnableSampling enables sampling capabilities for the server. +// This allows the server to send sampling requests to clients that support it. +func (s *MCPServer) EnableSampling() { + s.capabilitiesMu.Lock() + defer s.capabilitiesMu.Unlock() +} + +// RequestSampling sends a sampling request to the client. +// The client must have declared sampling capability during initialization. +func (s *MCPServer) RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + session := ClientSessionFromContext(ctx) + if session == nil { + return nil, fmt.Errorf("no active session") + } + + // Check if the session supports sampling requests + if samplingSession, ok := session.(SessionWithSampling); ok { + return samplingSession.RequestSampling(ctx, request) + } + + return nil, fmt.Errorf("session does not support sampling") +} + +// SessionWithSampling extends ClientSession to support sampling requests. +type SessionWithSampling interface { + ClientSession + RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) +} diff --git a/server/sampling_test.go b/server/sampling_test.go new file mode 100644 index 000000000..c69ac6cb5 --- /dev/null +++ b/server/sampling_test.go @@ -0,0 +1,115 @@ +package server + +import ( + "context" + "testing" + + "github.com/mark3labs/mcp-go/mcp" +) + +func TestMCPServer_RequestSampling_NoSession(t *testing.T) { + server := NewMCPServer("test", "1.0.0") + server.EnableSampling() + + request := mcp.CreateMessageRequest{ + CreateMessageParams: mcp.CreateMessageParams{ + Messages: []mcp.SamplingMessage{ + {Role: mcp.RoleUser, Content: mcp.TextContent{Type: "text", Text: "Test"}}, + }, + MaxTokens: 100, + }, + } + + _, err := server.RequestSampling(context.Background(), request) + + if err == nil { + t.Error("expected error when no session available") + } + + expectedError := "no active session" + if err.Error() != expectedError { + t.Errorf("expected error %q, got %q", expectedError, err.Error()) + } +} + +// mockSession implements ClientSession for testing +type mockSession struct { + sessionID string +} + +func (m *mockSession) SessionID() string { + return m.sessionID +} + +func (m *mockSession) NotificationChannel() chan<- mcp.JSONRPCNotification { + return make(chan mcp.JSONRPCNotification, 1) +} + +func (m *mockSession) Initialize() {} + +func (m *mockSession) Initialized() bool { + return true +} + +// mockSamplingSession implements SessionWithSampling for testing +type mockSamplingSession struct { + mockSession + result *mcp.CreateMessageResult + err error +} + +func (m *mockSamplingSession) RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + if m.err != nil { + return nil, m.err + } + return m.result, nil +} + +func TestMCPServer_RequestSampling_Success(t *testing.T) { + server := NewMCPServer("test", "1.0.0") + server.EnableSampling() + + // Create a mock sampling session + mockSession := &mockSamplingSession{ + mockSession: mockSession{sessionID: "test-session"}, + result: &mcp.CreateMessageResult{ + SamplingMessage: mcp.SamplingMessage{ + Role: mcp.RoleAssistant, + Content: mcp.TextContent{ + Type: "text", + Text: "Test response", + }, + }, + Model: "test-model", + StopReason: "endTurn", + }, + } + + // Create context with session + ctx := context.Background() + ctx = server.WithContext(ctx, mockSession) + + request := mcp.CreateMessageRequest{ + CreateMessageParams: mcp.CreateMessageParams{ + Messages: []mcp.SamplingMessage{ + {Role: mcp.RoleUser, Content: mcp.TextContent{Type: "text", Text: "Test"}}, + }, + MaxTokens: 100, + }, + } + + result, err := server.RequestSampling(ctx, request) + + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + if result == nil { + t.Error("expected result, got nil") + return + } + + if result.Model != "test-model" { + t.Errorf("expected model %q, got %q", "test-model", result.Model) + } +} diff --git a/server/stdio.go b/server/stdio.go index 746a7d96f..33ac9bb88 100644 --- a/server/stdio.go +++ b/server/stdio.go @@ -9,6 +9,7 @@ import ( "log" "os" "os/signal" + "sync" "sync/atomic" "syscall" @@ -51,10 +52,21 @@ func WithStdioContextFunc(fn StdioContextFunc) StdioOption { // stdioSession is a static client session, since stdio has only one client. type stdioSession struct { - notifications chan mcp.JSONRPCNotification - initialized atomic.Bool - loggingLevel atomic.Value - clientInfo atomic.Value // stores session-specific client info + notifications chan mcp.JSONRPCNotification + initialized atomic.Bool + loggingLevel atomic.Value + clientInfo atomic.Value // stores session-specific client info + writer io.Writer // for sending requests to client + requestID atomic.Int64 // for generating unique request IDs + mu sync.RWMutex // protects writer + pendingRequests map[int64]chan *samplingResponse // for tracking pending sampling requests + pendingMu sync.RWMutex // protects pendingRequests +} + +// samplingResponse represents a response to a sampling request +type samplingResponse struct { + result *mcp.CreateMessageResult + err error } func (s *stdioSession) SessionID() string { @@ -100,14 +112,86 @@ func (s *stdioSession) GetLogLevel() mcp.LoggingLevel { return level.(mcp.LoggingLevel) } +// RequestSampling sends a sampling request to the client and waits for the response. +func (s *stdioSession) RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + s.mu.RLock() + writer := s.writer + s.mu.RUnlock() + + if writer == nil { + return nil, fmt.Errorf("no writer available for sending requests") + } + + // Generate a unique request ID + id := s.requestID.Add(1) + + // Create a response channel for this request + responseChan := make(chan *samplingResponse, 1) + s.pendingMu.Lock() + s.pendingRequests[id] = responseChan + s.pendingMu.Unlock() + + // Cleanup function to remove the pending request + cleanup := func() { + s.pendingMu.Lock() + delete(s.pendingRequests, id) + s.pendingMu.Unlock() + } + defer cleanup() + + // Create the JSON-RPC request + jsonRPCRequest := struct { + JSONRPC string `json:"jsonrpc"` + ID int64 `json:"id"` + Method string `json:"method"` + Params mcp.CreateMessageParams `json:"params"` + }{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: id, + Method: string(mcp.MethodSamplingCreateMessage), + Params: request.CreateMessageParams, + } + + // Marshal and send the request + requestBytes, err := json.Marshal(jsonRPCRequest) + if err != nil { + return nil, fmt.Errorf("failed to marshal sampling request: %w", err) + } + requestBytes = append(requestBytes, '\n') + + if _, err := writer.Write(requestBytes); err != nil { + return nil, fmt.Errorf("failed to write sampling request: %w", err) + } + + // Wait for the response or context cancellation + select { + case <-ctx.Done(): + return nil, ctx.Err() + case response := <-responseChan: + if response.err != nil { + return nil, response.err + } + return response.result, nil + } +} + +// SetWriter sets the writer for sending requests to the client. +func (s *stdioSession) SetWriter(writer io.Writer) { + s.mu.Lock() + defer s.mu.Unlock() + s.writer = writer +} + var ( _ ClientSession = (*stdioSession)(nil) _ SessionWithLogging = (*stdioSession)(nil) _ SessionWithClientInfo = (*stdioSession)(nil) + _ SessionWithSampling = (*stdioSession)(nil) ) var stdioSessionInstance = stdioSession{ - notifications: make(chan mcp.JSONRPCNotification, 100), + notifications: make(chan mcp.JSONRPCNotification, 100), + pendingRequests: make(map[int64]chan *samplingResponse), } // NewStdioServer creates a new stdio server wrapper around an MCPServer. @@ -224,6 +308,9 @@ func (s *StdioServer) Listen( defer s.server.UnregisterSession(ctx, stdioSessionInstance.SessionID()) ctx = s.server.WithContext(ctx, &stdioSessionInstance) + // Set the writer for sending requests to the client + stdioSessionInstance.SetWriter(stdout) + // Add in any custom context. if s.contextFunc != nil { ctx = s.contextFunc(ctx) @@ -256,7 +343,29 @@ func (s *StdioServer) processMessage( return s.writeResponse(response, writer) } - // Handle the message using the wrapped server + // Check if this is a response to a sampling request + if s.handleSamplingResponse(rawMessage) { + return nil + } + + // Check if this is a tool call that might need sampling (and thus should be processed concurrently) + var baseMessage struct { + Method string `json:"method"` + } + if json.Unmarshal(rawMessage, &baseMessage) == nil && baseMessage.Method == "tools/call" { + // Process tool calls concurrently to avoid blocking on sampling requests + go func() { + response := s.server.HandleMessage(ctx, rawMessage) + if response != nil { + if err := s.writeResponse(response, writer); err != nil { + s.errLogger.Printf("Error writing tool response: %v", err) + } + } + }() + return nil + } + + // Handle other messages synchronously response := s.server.HandleMessage(ctx, rawMessage) // Only write response if there is one (not for notifications) @@ -269,6 +378,65 @@ func (s *StdioServer) processMessage( return nil } +// handleSamplingResponse checks if the message is a response to a sampling request +// and routes it to the appropriate pending request channel. +func (s *StdioServer) handleSamplingResponse(rawMessage json.RawMessage) bool { + return stdioSessionInstance.handleSamplingResponse(rawMessage) +} + +// handleSamplingResponse handles incoming sampling responses for this session +func (s *stdioSession) handleSamplingResponse(rawMessage json.RawMessage) bool { + // Try to parse as a JSON-RPC response + var response struct { + JSONRPC string `json:"jsonrpc"` + ID json.Number `json:"id"` + Result json.RawMessage `json:"result,omitempty"` + Error *struct { + Code int `json:"code"` + Message string `json:"message"` + } `json:"error,omitempty"` + } + + if err := json.Unmarshal(rawMessage, &response); err != nil { + return false + } + // Parse the ID as int64 + idInt64, err := response.ID.Int64() + if err != nil || (response.Result == nil && response.Error == nil) { + return false + } + + // Look for a pending request with this ID + s.pendingMu.RLock() + responseChan, exists := s.pendingRequests[idInt64] + s.pendingMu.RUnlock() + + if !exists { + return false + } // Parse and send the response + samplingResp := &samplingResponse{} + + if response.Error != nil { + samplingResp.err = fmt.Errorf("sampling request failed: %s", response.Error.Message) + } else { + var result mcp.CreateMessageResult + if err := json.Unmarshal(response.Result, &result); err != nil { + samplingResp.err = fmt.Errorf("failed to unmarshal sampling response: %w", err) + } else { + samplingResp.result = &result + } + } + + // Send the response (non-blocking) + select { + case responseChan <- samplingResp: + default: + // Channel is full or closed, ignore + } + + return true +} + // writeResponse marshals and writes a JSON-RPC response message followed by a newline. // Returns an error if marshaling or writing fails. func (s *StdioServer) writeResponse( diff --git a/server/streamable_http.go b/server/streamable_http.go index e9a011fb1..1312c9753 100644 --- a/server/streamable_http.go +++ b/server/streamable_http.go @@ -40,7 +40,9 @@ func WithEndpointPath(endpointPath string) StreamableHTTPOption { // to StatelessSessionIdManager. func WithStateLess(stateLess bool) StreamableHTTPOption { return func(s *StreamableHTTPServer) { - s.sessionIdManager = &StatelessSessionIdManager{} + if stateLess { + s.sessionIdManager = &StatelessSessionIdManager{} + } } } @@ -374,7 +376,7 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request) w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") - w.WriteHeader(http.StatusAccepted) + w.WriteHeader(http.StatusOK) flusher, ok := w.(http.Flusher) if !ok { diff --git a/server/streamable_http_test.go b/server/streamable_http_test.go index aad48fc3a..0e1c7a65b 100644 --- a/server/streamable_http_test.go +++ b/server/streamable_http_test.go @@ -529,8 +529,8 @@ func TestStreamableHTTP_GET(t *testing.T) { } defer resp.Body.Close() - if resp.StatusCode != http.StatusAccepted { - t.Errorf("Expected status 202, got %d", resp.StatusCode) + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status 200, got %d", resp.StatusCode) } if resp.Header.Get("content-type") != "text/event-stream" { diff --git a/www/docs/pages/clients/advanced-sampling.mdx b/www/docs/pages/clients/advanced-sampling.mdx new file mode 100644 index 000000000..51172f16b --- /dev/null +++ b/www/docs/pages/clients/advanced-sampling.mdx @@ -0,0 +1,467 @@ +# Sampling + +Learn how to implement MCP clients that can handle sampling requests from servers, enabling bidirectional communication where clients provide LLM capabilities to servers. + +## Overview + +Sampling allows MCP clients to respond to LLM completion requests from servers. When a server needs to generate content, answer questions, or perform reasoning tasks, it can send a sampling request to the client, which then processes it using an LLM and returns the result. + +## Implementing a Sampling Handler + +Create a sampling handler by implementing the `SamplingHandler` interface: + +```go +package main + +import ( + "context" + "fmt" + + "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/mcp" +) + +type MySamplingHandler struct { + // Add fields for your LLM client (OpenAI, Anthropic, etc.) +} + +func (h *MySamplingHandler) CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + // Extract request parameters + messages := request.Messages + systemPrompt := request.SystemPrompt + maxTokens := request.MaxTokens + temperature := request.Temperature + + // Process with your LLM + response, err := h.callLLM(ctx, messages, systemPrompt, maxTokens, temperature) + if err != nil { + return nil, fmt.Errorf("LLM call failed: %w", err) + } + + // Return MCP-formatted result + return &mcp.CreateMessageResult{ + Model: "your-model-name", + Role: mcp.RoleAssistant, + Content: mcp.TextContent{ + Type: "text", + Text: response, + }, + StopReason: "endTurn", + }, nil +} + +func (h *MySamplingHandler) callLLM(ctx context.Context, messages []mcp.SamplingMessage, systemPrompt string, maxTokens int, temperature float64) (string, error) { + // Implement your LLM integration here + // This is where you'd call OpenAI, Anthropic, or other LLM APIs + return "Your LLM response here", nil +} +``` + +## Configuring the Client + +Enable sampling by providing a handler when creating the client: + +```go +func main() { + // Create sampling handler + samplingHandler := &MySamplingHandler{} + + // Create client with sampling support + mcpClient, err := client.NewStdioClient( + "/path/to/mcp/server", + client.WithSamplingHandler(samplingHandler), + ) + if err != nil { + log.Fatalf("Failed to create client: %v", err) + } + defer mcpClient.Close() + + // Connect to server + ctx := context.Background() + if err := mcpClient.Connect(ctx); err != nil { + log.Fatalf("Failed to connect: %v", err) + } + + // The client will now automatically handle sampling requests + // from the server using your handler +} +``` + +## Mock Implementation Example + +Here's a complete mock implementation for testing: + +```go +package main + +import ( +import ( + "context" + "fmt" + "log" + "strings" + "os" + + "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/mcp" +) +type MockSamplingHandler struct{} + +func (h *MockSamplingHandler) CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + // Log the request for debugging + log.Printf("Mock LLM received sampling request:") + log.Printf(" System prompt: %s", request.SystemPrompt) + log.Printf(" Max tokens: %d", request.MaxTokens) + log.Printf(" Temperature: %f", request.Temperature) + + // Extract the user's message + var userMessage string + for _, msg := range request.Messages { + if msg.Role == mcp.RoleUser { + if textContent, ok := msg.Content.(mcp.TextContent); ok { + userMessage = textContent.Text + log.Printf(" User message: %s", userMessage) + break + } + } + } + + // Generate a mock response + mockResponse := fmt.Sprintf( + "Mock LLM response to: '%s'. This is a simulated response from a mock LLM handler.", + userMessage, + ) + + return &mcp.CreateMessageResult{ + Model: "mock-llm-v1", + Role: mcp.RoleAssistant, + Content: mcp.TextContent{ + Type: "text", + Text: mockResponse, + }, + StopReason: "endTurn", + }, nil +} + +func main() { + if len(os.Args) < 2 { + log.Fatal("Usage: sampling_client ") + } + + serverPath := os.Args[1] + + // Create client with mock sampling handler + mcpClient, err := client.NewStdioClient( + serverPath, + client.WithSamplingHandler(&MockSamplingHandler{}), + ) + if err != nil { + log.Fatalf("Failed to create client: %v", err) + } + defer mcpClient.Close() + + // Connect and test + ctx := context.Background() + if err := mcpClient.Connect(ctx); err != nil { + log.Fatalf("Failed to connect: %v", err) + } + + // Test server tools that use sampling + result, err := mcpClient.CallTool(ctx, "ask_llm", map[string]any{ + "question": "What is the capital of France?", + "system_prompt": "You are a helpful geography assistant.", + }) + if err != nil { + log.Fatalf("Tool call failed: %v", err) + } + + fmt.Printf("Tool result: %+v\\n", result) +} +``` + +## Real LLM Integration + +### OpenAI Integration + +```go +import ( + "github.com/sashabaranov/go-openai" +) + +type OpenAISamplingHandler struct { + client *openai.Client +} + +func NewOpenAISamplingHandler(apiKey string) *OpenAISamplingHandler { + return &OpenAISamplingHandler{ + client: openai.NewClient(apiKey), + } +} + +func (h *OpenAISamplingHandler) CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + // Convert MCP messages to OpenAI format + var messages []openai.ChatCompletionMessage + + // Add system message if provided + if request.SystemPrompt != "" { + messages = append(messages, openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleSystem, + Content: request.SystemPrompt, + }) + } + + // Convert MCP messages + for _, msg := range request.Messages { + var role string + switch msg.Role { + case mcp.RoleUser: + role = openai.ChatMessageRoleUser + case mcp.RoleAssistant: + role = openai.ChatMessageRoleAssistant + } + + if textContent, ok := msg.Content.(mcp.TextContent); ok { + messages = append(messages, openai.ChatCompletionMessage{ + Role: role, + Content: textContent.Text, + }) + } + } + + // Create OpenAI request + req := openai.ChatCompletionRequest{ + Model: openai.GPT3Dot5Turbo, + Messages: messages, + MaxTokens: request.MaxTokens, + Temperature: float32(request.Temperature), + } + + // Call OpenAI API + resp, err := h.client.CreateChatCompletion(ctx, req) + if err != nil { + return nil, fmt.Errorf("OpenAI API call failed: %w", err) + } + + if len(resp.Choices) == 0 { + return nil, fmt.Errorf("no response from OpenAI") + } + + choice := resp.Choices[0] + + // Convert stop reason + var stopReason string + switch choice.FinishReason { + case "stop": + stopReason = "endTurn" + case "length": + stopReason = "maxTokens" + default: + stopReason = "other" + } + + return &mcp.CreateMessageResult{ + Model: resp.Model, + Role: mcp.RoleAssistant, + Content: mcp.TextContent{ + Type: "text", + Text: choice.Message.Content, + }, + StopReason: stopReason, + }, nil +} +``` + +### Anthropic Integration + +```go +import ( + "bytes" + "encoding/json" + "net/http" +) + +type AnthropicSamplingHandler struct { + apiKey string + client *http.Client +} + +func NewAnthropicSamplingHandler(apiKey string) *AnthropicSamplingHandler { + return &AnthropicSamplingHandler{ + apiKey: apiKey, + client: &http.Client{}, + } +} + +func (h *AnthropicSamplingHandler) CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + // Convert to Anthropic format + anthropicReq := map[string]any{ + "model": "claude-3-sonnet-20240229", + "max_tokens": request.MaxTokens, + "messages": h.convertMessages(request.Messages), + } + + if request.SystemPrompt != "" { + anthropicReq["system"] = request.SystemPrompt + } + + if request.Temperature > 0 { + anthropicReq["temperature"] = request.Temperature + } + + // Make API call + reqBody, _ := json.Marshal(anthropicReq) + httpReq, _ := http.NewRequestWithContext(ctx, "POST", + "https://api.anthropic.com/v1/messages", bytes.NewBuffer(reqBody)) + + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("x-api-key", h.apiKey) + httpReq.Header.Set("anthropic-version", "2023-06-01") + + resp, err := h.client.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("Anthropic API call failed: %w", err) + } + defer resp.Body.Close() + + var anthropicResp struct { + Content []struct { + Text string `json:"text"` + Type string `json:"type"` + } `json:"content"` + Model string `json:"model"` + StopReason string `json:"stop_reason"` + } + + if err := json.NewDecoder(resp.Body).Decode(&anthropicResp); err != nil { + return nil, fmt.Errorf("failed to decode Anthropic response: %w", err) + } + + // Extract text content + var text string + for _, content := range anthropicResp.Content { + if content.Type == "text" { + text += content.Text + } + } + + return &mcp.CreateMessageResult{ + Model: anthropicResp.Model, + Role: mcp.RoleAssistant, + Content: mcp.TextContent{ + Type: "text", + Text: text, + }, + StopReason: anthropicResp.StopReason, + }, nil +} + +func (h *AnthropicSamplingHandler) convertMessages(messages []mcp.SamplingMessage) []map[string]any { + var result []map[string]any + for _, msg := range messages { + if textContent, ok := msg.Content.(mcp.TextContent); ok { + result = append(result, map[string]any{ + "role": string(msg.Role), + "content": textContent.Text, + }) + } + } + return result +} +``` + +## Automatic Capability Declaration + +When you provide a sampling handler, the client automatically declares the sampling capability during initialization: + +```go +// This automatically adds sampling capability +mcpClient, err := client.NewStdioClient( + serverPath, + client.WithSamplingHandler(handler), // Enables sampling capability +) +``` + +The client will include this in the initialization request: + +```json +{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": { + "sampling": {} + }, + "clientInfo": { + "name": "your-client", + "version": "1.0.0" + } + } +} +``` + +## Error Handling + +Handle errors gracefully in your sampling handler: + +```go +func (h *MySamplingHandler) CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + // Validate request + if len(request.Messages) == 0 { + return nil, fmt.Errorf("no messages provided") + } + + // Check for context cancellation + if err := ctx.Err(); err != nil { + return nil, fmt.Errorf("request cancelled: %w", err) + } + + // Call LLM with error handling + response, err := h.callLLM(ctx, request) + if err != nil { + // Log error for debugging + log.Printf("LLM call failed: %v", err) + + // Return appropriate error + if strings.Contains(err.Error(), "rate limit") { + return nil, fmt.Errorf("rate limit exceeded, please try again later") + } + return nil, fmt.Errorf("LLM service unavailable: %w", err) + } + + return response, nil +} +``` + +## Best Practices + +1. **Implement Proper Error Handling**: Always handle LLM API errors gracefully +2. **Respect Rate Limits**: Implement rate limiting and backoff strategies +3. **Validate Inputs**: Check message content and parameters before processing +4. **Use Context**: Respect context cancellation and timeouts +5. **Log Appropriately**: Log requests for debugging but avoid logging sensitive data +6. **Model Selection**: Allow configuration of which LLM model to use +7. **Content Filtering**: Implement content filtering if required by your use case + +## Testing Your Implementation + +Test your sampling handler with the sampling server example: + +```bash +# Build the sampling server +cd examples/sampling_server +go build -o sampling_server + +# Build your client +go build -o my_client + +# Test the integration +./my_client ./sampling_server +``` + +## Next Steps + +- Learn about [server-side sampling implementation](/servers/advanced-sampling) +- Explore [client operations](/clients/operations) +- Check out the [sampling examples](https://github.com/mark3labs/mcp-go/tree/main/examples/sampling_client) \ No newline at end of file diff --git a/www/docs/pages/clients/operations.mdx b/www/docs/pages/clients/operations.mdx index 3cea1e3ec..8734486d1 100644 --- a/www/docs/pages/clients/operations.mdx +++ b/www/docs/pages/clients/operations.mdx @@ -909,6 +909,52 @@ func demonstrateSubscriptionManager(c client.Client) { } ``` +## Advanced: Sampling Support + +Sampling is an advanced feature that allows clients to respond to LLM completion requests from servers. This enables servers to leverage client-side LLM capabilities for content generation and reasoning. + +> **Note**: Sampling is an advanced feature that most clients don't need. Only implement sampling if you're building a client that provides LLM capabilities to servers. + +### When to Implement Sampling + +Consider implementing sampling when your client: +- Has access to LLM APIs (OpenAI, Anthropic, etc.) +- Wants to provide LLM capabilities to servers +- Needs to support servers that generate dynamic content + +### Basic Implementation + +```go +import "github.com/mark3labs/mcp-go/client" + +// Implement the SamplingHandler interface +type MySamplingHandler struct { + // Add your LLM client here +} + +func (h *MySamplingHandler) CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + // Process the request with your LLM + // Return the result in MCP format + return &mcp.CreateMessageResult{ + Model: "your-model", + Role: mcp.RoleAssistant, + Content: mcp.TextContent{ + Type: "text", + Text: "Your LLM response here", + }, + StopReason: "endTurn", + }, nil +} + +// Create client with sampling support +mcpClient, err := client.NewStdioClient( + "/path/to/server", + client.WithSamplingHandler(&MySamplingHandler{}), +) +``` + +For complete sampling documentation, see **[Client Sampling Guide](/clients/advanced-sampling)**. + ## Next Steps - **[Client Transports](/clients/transports)** - Learn transport-specific client features diff --git a/www/docs/pages/clients/transports.mdx b/www/docs/pages/clients/transports.mdx index efef67cf8..af25fb65a 100644 --- a/www/docs/pages/clients/transports.mdx +++ b/www/docs/pages/clients/transports.mdx @@ -389,6 +389,30 @@ func (pool *StreamableHTTPClientPool) CallTool(ctx context.Context, req mcp.Call } ``` +### StreamableHTTP With Preconfigured Session +You can also create a StreamableHTTP client with a preconfigured session, which allows you to reuse the same session across multiple requests + +```go +func createStreamableHTTPClientWithSession() { + // Create StreamableHTTP client with options + sessionID := // fetch existing session ID + c := client.NewStreamableHttpClient("https://api.example.com/mcp", + transport.WithSession(sessionID), + ) + defer c.Close() + + ctx := context.Background() + // Use client... + _, err := c.ListTools(ctx) + // If the session is terminated, you must reinitialize the client + if errors.Is(err, transport.ErrSessionTerminated) { + c.Initialize(ctx) // Reinitialize if session is terminated + // The session ID should change after reinitialization + sessionID = c.GetSessionId() // Update session ID + } +} +``` + ## SSE Client SSE (Server-Sent Events) clients provide real-time communication with servers. diff --git a/www/docs/pages/servers/advanced-sampling.mdx b/www/docs/pages/servers/advanced-sampling.mdx new file mode 100644 index 000000000..4f10ab933 --- /dev/null +++ b/www/docs/pages/servers/advanced-sampling.mdx @@ -0,0 +1,359 @@ +# Sampling + +Learn how to implement MCP servers that can request LLM completions from clients using the sampling capability. + +## Overview + +Sampling allows MCP servers to request LLM completions from clients, enabling bidirectional communication where servers can leverage client-side LLM capabilities. This is particularly useful for tools that need to generate content, answer questions, or perform reasoning tasks. + +## Enabling Sampling + +To enable sampling in your server, call `EnableSampling()` during server setup: + +```go +package main + +import ( + "context" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +func main() { + // Create server + mcpServer := server.NewMCPServer("my-server", "1.0.0") + + // Enable sampling capability + mcpServer.EnableSampling() + + // Add tools that use sampling... + + // Start server + server.ServeStdio(mcpServer) +} +``` + +## Requesting Sampling + +Use `RequestSampling()` within tool handlers to request LLM completions: + +```go +mcpServer.AddTool(mcp.Tool{ + Name: "ask_llm", + Description: "Ask the LLM a question using sampling", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]any{ + "question": map[string]any{ + "type": "string", + "description": "The question to ask the LLM", + }, + "system_prompt": map[string]any{ + "type": "string", + "description": "Optional system prompt", + }, + }, + Required: []string{"question"}, + }, +}, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Extract parameters + question, err := request.RequireString("question") + if err != nil { + return nil, err + } + + systemPrompt := request.GetString("system_prompt", "You are a helpful assistant.") + + // Create sampling request + samplingRequest := mcp.CreateMessageRequest{ + CreateMessageParams: mcp.CreateMessageParams{ + Messages: []mcp.SamplingMessage{ + { + Role: mcp.RoleUser, + Content: mcp.TextContent{ + Type: "text", + Text: question, + }, + }, + }, + SystemPrompt: systemPrompt, + MaxTokens: 1000, + Temperature: 0.7, + }, + } + + // Request sampling from client + result, err := mcpServer.RequestSampling(ctx, samplingRequest) + if err != nil { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: fmt.Sprintf("Error requesting sampling: %v", err), + }, + }, + IsError: true, + }, nil + } + + // Return the LLM response + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: fmt.Sprintf("LLM Response: %s", getTextFromContent(result.Content)), + }, + }, + }, nil +}) +``` + +## Sampling Request Parameters + +The `CreateMessageRequest` supports various parameters to control LLM behavior: + +```go +samplingRequest := mcp.CreateMessageRequest{ + CreateMessageParams: mcp.CreateMessageParams{ + // Required: Messages to send to the LLM + Messages: []mcp.SamplingMessage{ + { + Role: mcp.RoleUser, // or mcp.RoleAssistant + Content: mcp.TextContent{ // or mcp.ImageContent + Type: "text", + Text: "Your message here", + }, + }, + }, + + // Optional: System prompt for context + SystemPrompt: "You are a helpful assistant.", + + // Optional: Maximum tokens to generate + MaxTokens: 1000, + + // Optional: Temperature for randomness (0.0 to 1.0) + Temperature: 0.7, + + // Optional: Top-p sampling parameter + TopP: 0.9, + + // Optional: Stop sequences + StopSequences: []string{"\\n\\n"}, + }, +} +``` + +## Message Types + +Sampling supports different message roles and content types: + +### Message Roles + +```go +// User message +{ + Role: mcp.RoleUser, + Content: mcp.TextContent{ + Type: "text", + Text: "What is the capital of France?", + }, +} + +// Assistant message (for conversation context) +{ + Role: mcp.RoleAssistant, + Content: mcp.TextContent{ + Type: "text", + Text: "The capital of France is Paris.", + }, +} +``` + +### Content Types + +#### Text Content + +```go +mcp.TextContent{ + Type: "text", + Text: "Your text content here", +} +``` + +#### Image Content + +```go +mcp.ImageContent{ + Type: "image", + Data: "base64-encoded-image-data", + MimeType: "image/jpeg", +} +``` + +## Error Handling + +Always handle sampling errors gracefully: + +```go +result, err := mcpServer.RequestSampling(ctx, samplingRequest) +if err != nil { + // Log the error + log.Printf("Sampling request failed: %v", err) + + // Return appropriate error response + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: "Sorry, I couldn't process your request at this time.", + }, + }, + IsError: true, + }, nil +} +``` + +## Context and Timeouts + +Use context for timeout control: + +```go +// Set a timeout for the sampling request +ctx, cancel := context.WithTimeout(ctx, 30*time.Second) +defer cancel() + +result, err := mcpServer.RequestSampling(ctx, samplingRequest) +``` + +## Best Practices + +1. **Enable Sampling Early**: Call `EnableSampling()` during server initialization +2. **Handle Timeouts**: Set appropriate timeouts for sampling requests +3. **Graceful Errors**: Always provide meaningful error messages to users +4. **Content Extraction**: Use helper functions to extract text from responses +5. **System Prompts**: Use clear system prompts to guide LLM behavior +6. **Parameter Validation**: Validate tool parameters before making sampling requests + +## Complete Example + +Here's a complete example server with sampling: + +```go +package main + +import ( + "context" + "fmt" + "log" + "time" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +func main() { + // Create server + mcpServer := server.NewMCPServer("sampling-example-server", "1.0.0") + + // Enable sampling capability + mcpServer.EnableSampling() + + // Add sampling tool + mcpServer.AddTool(mcp.Tool{ + Name: "ask_llm", + Description: "Ask the LLM a question using sampling", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]any{ + "question": map[string]any{ + "type": "string", + "description": "The question to ask the LLM", + }, + "system_prompt": map[string]any{ + "type": "string", + "description": "Optional system prompt", + }, + }, + Required: []string{"question"}, + }, + }, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + question, err := request.RequireString("question") + if err != nil { + return nil, err + } + + systemPrompt := request.GetString("system_prompt", "You are a helpful assistant.") + + // Create sampling request + samplingRequest := mcp.CreateMessageRequest{ + CreateMessageParams: mcp.CreateMessageParams{ + Messages: []mcp.SamplingMessage{ + { + Role: mcp.RoleUser, + Content: mcp.TextContent{ + Type: "text", + Text: question, + }, + }, + }, + SystemPrompt: systemPrompt, + MaxTokens: 1000, + Temperature: 0.7, + }, + } + + // Request sampling with timeout + samplingCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + result, err := mcpServer.RequestSampling(samplingCtx, samplingRequest) + if err != nil { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: fmt.Sprintf("Error requesting sampling: %v", err), + }, + }, + IsError: true, + }, nil + } + + // Return the LLM response + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: fmt.Sprintf("LLM Response (model: %s): %s", + result.Model, getTextFromContent(result.Content)), + }, + }, + }, nil + }) + + // Start server + log.Println("Starting sampling example server...") + if err := server.ServeStdio(mcpServer); err != nil { + log.Fatalf("Server error: %v", err) + } +} + +// Helper function to extract text from content +func getTextFromContent(content interface{}) string { + switch c := content.(type) { + case mcp.TextContent: + return c.Text + case string: + return c + default: + return fmt.Sprintf("%v", content) + } +} +``` + +## Next Steps + +- Learn about [client-side sampling implementation](/clients/advanced-sampling) +- Explore [advanced server features](/servers/advanced) +- Check out the [sampling examples](https://github.com/mark3labs/mcp-go/tree/main/examples/sampling_server) \ No newline at end of file diff --git a/www/docs/pages/servers/advanced.mdx b/www/docs/pages/servers/advanced.mdx index d2e3b4a8f..5751c8220 100644 --- a/www/docs/pages/servers/advanced.mdx +++ b/www/docs/pages/servers/advanced.mdx @@ -821,6 +821,91 @@ func startWithGracefulShutdown(s *server.MCPServer) { } ``` +## Sampling (Advanced) + +Sampling is an advanced feature that allows servers to request LLM completions from clients. This enables bidirectional communication where servers can leverage client-side LLM capabilities for content generation, reasoning, and question answering. + +> **Note**: Sampling is an advanced feature that most servers don't need. Only implement sampling if your server specifically needs to generate content using the client's LLM. + +### When to Use Sampling + +Consider sampling when your server needs to: +- Generate content based on user input +- Answer questions using LLM reasoning +- Perform text analysis or summarization +- Create dynamic responses that require LLM capabilities + +### Basic Implementation + +```go +// Enable sampling capability +mcpServer.EnableSampling() + +// Add a tool that uses sampling +mcpServer.AddTool(mcp.Tool{ + Name: "ask_llm", + Description: "Ask the LLM a question using sampling", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]any{ + "question": map[string]any{ + "type": "string", + "description": "The question to ask the LLM", + }, + }, + Required: []string{"question"}, + }, +}, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + question, err := request.RequireString("question") + if err != nil { + return nil, err + } + + // Create sampling request + samplingRequest := mcp.CreateMessageRequest{ + CreateMessageParams: mcp.CreateMessageParams{ + Messages: []mcp.SamplingMessage{ + { + Role: mcp.RoleUser, + Content: mcp.TextContent{ + Type: "text", + Text: question, + }, + }, + }, + SystemPrompt: "You are a helpful assistant.", + MaxTokens: 1000, + Temperature: 0.7, + }, + } + + // Request sampling from client + result, err := mcpServer.RequestSampling(ctx, samplingRequest) + if err != nil { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: fmt.Sprintf("Error: %v", err), + }, + }, + IsError: true, + }, nil + } + + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: fmt.Sprintf("LLM Response: %s", result.Content), + }, + }, + }, nil +}) +``` + +For complete sampling documentation, see **[Server Sampling Guide](/servers/advanced-sampling)**. + ## Next Steps - **[Client Development](/clients)** - Learn to build MCP clients diff --git a/www/docs/pages/servers/index.mdx b/www/docs/pages/servers/index.mdx index e0249172a..d0983d13a 100644 --- a/www/docs/pages/servers/index.mdx +++ b/www/docs/pages/servers/index.mdx @@ -12,7 +12,7 @@ MCP servers expose tools, resources, and prompts to LLM clients. MCP-Go makes it - **[Resources](/servers/resources)** - Exposing data to LLMs - **[Tools](/servers/tools)** - Providing functionality to LLMs - **[Prompts](/servers/prompts)** - Creating reusable interaction templates -- **[Advanced Features](/servers/advanced)** - Typed tools, middleware, hooks, and more +- **[Advanced Features](/servers/advanced)** - Typed tools, middleware, hooks, sampling, and more ## Quick Example diff --git a/www/docs/pages/transports/http.mdx b/www/docs/pages/transports/http.mdx index 1d0430d6a..7dc38a857 100644 --- a/www/docs/pages/transports/http.mdx +++ b/www/docs/pages/transports/http.mdx @@ -43,7 +43,7 @@ import ( func main() { s := server.NewMCPServer("StreamableHTTP API Server", "1.0.0", server.WithToolCapabilities(true), - server.WithResourceCapabilities(true), + server.WithResourceCapabilities(true, true), ) // Add RESTful tools @@ -60,7 +60,7 @@ func main() { mcp.WithDescription("Create a new user"), mcp.WithString("name", mcp.Required()), mcp.WithString("email", mcp.Required()), - mcp.WithInteger("age", mcp.Minimum(0)), + mcp.WithNumber("age", mcp.Min(0)), ), handleCreateUser, ) @@ -69,8 +69,8 @@ func main() { mcp.NewTool("search_users", mcp.WithDescription("Search users with filters"), mcp.WithString("query", mcp.Description("Search query")), - mcp.WithInteger("limit", mcp.Default(10), mcp.Maximum(100)), - mcp.WithInteger("offset", mcp.Default(0), mcp.Minimum(0)), + mcp.WithNumber("limit", mcp.DefaultNumber(10), mcp.Max(100)), + mcp.WithNumber("offset", mcp.DefaultNumber(0), mcp.Min(0)), ), handleSearchUsers, ) @@ -95,7 +95,10 @@ func main() { } func handleGetUser(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - userID := req.Params.Arguments["user_id"].(string) + userID := req.GetString("user_id", "") + if userID == "" { + return nil, fmt.Errorf("user_id is required") + } // Simulate database lookup user, err := getUserFromDB(userID) @@ -103,13 +106,18 @@ func handleGetUser(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolR return nil, fmt.Errorf("user not found: %s", userID) } - return mcp.NewToolResultJSON(user), nil + return mcp.NewToolResultText(fmt.Sprintf(`{"id":"%s","name":"%s","email":"%s","age":%d}`, + user.ID, user.Name, user.Email, user.Age)), nil } func handleCreateUser(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - name := req.Params.Arguments["name"].(string) - email := req.Params.Arguments["email"].(string) - age := int(req.Params.Arguments["age"].(float64)) + name := req.GetString("name", "") + email := req.GetString("email", "") + age := req.GetInt("age", 0) + + if name == "" || email == "" { + return nil, fmt.Errorf("name and email are required") + } // Validate input if !isValidEmail(email) { @@ -129,11 +137,8 @@ func handleCreateUser(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallTo return nil, fmt.Errorf("failed to create user: %w", err) } - return mcp.NewToolResultJSON(map[string]interface{}{ - "id": user.ID, - "message": "User created successfully", - "user": user, - }), nil + return mcp.NewToolResultText(fmt.Sprintf(`{"id":"%s","message":"User created successfully","user":{"id":"%s","name":"%s","email":"%s","age":%d}}`, + user.ID, user.ID, user.Name, user.Email, user.Age)), nil } // Helper functions and types for the examples @@ -156,7 +161,6 @@ func getUserFromDB(userID string) (*User, error) { } func isValidEmail(email string) bool { - // Simple email validation return strings.Contains(email, "@") && strings.Contains(email, ".") } @@ -171,9 +175,9 @@ func saveUserToDB(user *User) error { } func handleSearchUsers(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - query := getStringParam(req.Params.Arguments, "query", "") - limit := int(getFloatParam(req.Params.Arguments, "limit", 10)) - offset := int(getFloatParam(req.Params.Arguments, "offset", 0)) + query := req.GetString("query", "") + limit := req.GetInt("limit", 10) + offset := req.GetInt("offset", 0) // Search users with pagination users, total, err := searchUsersInDB(query, limit, offset) @@ -181,16 +185,11 @@ func handleSearchUsers(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallT return nil, fmt.Errorf("search failed: %w", err) } - return mcp.NewToolResultJSON(map[string]interface{}{ - "users": users, - "total": total, - "limit": limit, - "offset": offset, - "query": query, - }), nil + return mcp.NewToolResultText(fmt.Sprintf(`{"users":[{"id":"1","name":"John Doe","email":"john@example.com","age":30},{"id":"2","name":"Jane Smith","email":"jane@example.com","age":25}],"total":%d,"limit":%d,"offset":%d,"query":"%s"}`, + total, limit, offset, query)), nil } -func handleUserResource(ctx context.Context, req mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error) { +func handleUserResource(ctx context.Context, req mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { userID := extractUserIDFromURI(req.Params.URI) user, err := getUserFromDB(userID) @@ -198,27 +197,16 @@ func handleUserResource(ctx context.Context, req mcp.ReadResourceRequest) (*mcp. return nil, fmt.Errorf("user not found: %s", userID) } - return mcp.NewResourceResultJSON(user), nil -} - -// Additional helper functions for parameter handling -func getStringParam(args map[string]interface{}, key, defaultValue string) string { - if val, ok := args[key]; ok && val != nil { - if str, ok := val.(string); ok { - return str - } - } - return defaultValue + return []mcp.ResourceContents{ + mcp.TextResourceContents{ + URI: req.Params.URI, + MIMEType: "application/json", + Text: fmt.Sprintf(`{"id":"%s","name":"%s","email":"%s","age":%d}`, user.ID, user.Name, user.Email, user.Age), + }, + }, nil } -func getFloatParam(args map[string]interface{}, key string, defaultValue float64) float64 { - if val, ok := args[key]; ok && val != nil { - if f, ok := val.(float64); ok { - return f - } - } - return defaultValue -} +// Additional helper functions func searchUsersInDB(query string, limit, offset int) ([]*User, int, error) { // Placeholder implementation @@ -231,9 +219,8 @@ func searchUsersInDB(query string, limit, offset int) ([]*User, int, error) { func extractUserIDFromURI(uri string) string { // Extract user ID from URI like "users://123" - parts := strings.Split(uri, "://") - if len(parts) > 1 { - return parts[1] + if len(uri) > 8 && uri[:8] == "users://" { + return uri[8:] } return uri } @@ -244,39 +231,24 @@ func extractUserIDFromURI(uri string) string { ```go func main() { s := server.NewMCPServer("Advanced StreamableHTTP Server", "1.0.0", - server.WithAllCapabilities(), - server.WithRecovery(), - server.WithHooks(&server.Hooks{ - OnToolCall: logToolCall, - OnResourceRead: logResourceRead, - }), + server.WithResourceCapabilities(true, true), + server.WithPromptCapabilities(true), + server.WithToolCapabilities(true), + server.WithLogging(), ) - // Configure StreamableHTTP-specific options - streamableHTTPOptions := server.StreamableHTTPOptions{ - BasePath: "/api/v1/mcp", - ReadTimeout: 30 * time.Second, - WriteTimeout: 30 * time.Second, - IdleTimeout: 60 * time.Second, - MaxBodySize: 10 * 1024 * 1024, // 10MB - EnableCORS: true, - AllowedOrigins: []string{"https://myapp.com", "http://localhost:3000"}, - AllowedMethods: []string{"GET", "POST", "OPTIONS"}, - AllowedHeaders: []string{"Content-Type", "Authorization"}, - EnableGzip: true, - TrustedProxies: []string{"10.0.0.0/8", "172.16.0.0/12"}, - } - - // Add middleware - addStreamableHTTPMiddleware(s) - - // Add comprehensive tools + // Add comprehensive tools and resources addCRUDTools(s) addBatchTools(s) addAnalyticsTools(s) log.Println("Starting advanced StreamableHTTP server on :8080") - httpServer := server.NewStreamableHTTPServer(s, streamableHTTPOptions...) + httpServer := server.NewStreamableHTTPServer(s, + server.WithEndpointPath("/api/v1/mcp"), + server.WithHeartbeatInterval(30*time.Second), + server.WithStateLess(false), + ) + if err := httpServer.Start(":8080"); err != nil { log.Fatal(err) } diff --git a/www/docs/pages/transports/inprocess.mdx b/www/docs/pages/transports/inprocess.mdx index dce982357..bfdb1536b 100644 --- a/www/docs/pages/transports/inprocess.mdx +++ b/www/docs/pages/transports/inprocess.mdx @@ -56,19 +56,34 @@ func main() { ) // Create in-process client - client := client.NewInProcessClient(s) - defer client.Close() + mcpClient, err := client.NewInProcessClient(s) + if err != nil { + log.Fatal(err) + } + defer mcpClient.Close() ctx := context.Background() // Initialize - if err := client.Initialize(ctx); err != nil { + _, err = mcpClient.Initialize(ctx, mcp.InitializeRequest{ + Params: mcp.InitializeRequestParams{ + ProtocolVersion: "2024-11-05", + Capabilities: mcp.ClientCapabilities{ + Tools: &mcp.ToolsCapability{}, + }, + ClientInfo: mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + }, + }, + }) + if err != nil { log.Fatal(err) } // Use the calculator - result, err := client.CallTool(ctx, mcp.CallToolRequest{ - Params: mcp.CallToolRequestParams{ + result, err := mcpClient.CallTool(ctx, mcp.CallToolRequest{ + Params: mcp.CallToolParams{ Name: "calculate", Arguments: map[string]interface{}{ "operation": "add", @@ -81,13 +96,18 @@ func main() { log.Fatal(err) } - fmt.Printf("Result: %s\n", result.Content[0].Text) + // Extract text from the first content item + if len(result.Content) > 0 { + if textContent, ok := mcp.AsTextContent(result.Content[0]); ok { + fmt.Printf("Result: %s\n", textContent.Text) + } + } } func handleCalculate(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - operation := req.Params.Arguments["operation"].(string) - x := req.Params.Arguments["x"].(float64) - y := req.Params.Arguments["y"].(float64) + operation := req.GetString("operation", "") + x := req.GetFloat("x", 0) + y := req.GetFloat("y", 0) var result float64 switch operation { @@ -134,7 +154,11 @@ func NewApplication(config *Config) *Application { app.addApplicationTools() // Create in-process client for internal use - app.mcpClient = client.NewInProcessClient(app.mcpServer) + var err error + app.mcpClient, err = client.NewInProcessClient(app.mcpServer) + if err != nil { + panic(err) + } return app } @@ -151,12 +175,8 @@ func (app *Application) addApplicationTools() { mcp.WithDescription("Get current application status"), ), func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - status := map[string]interface{}{ - "app_name": app.config.AppName, - "debug": app.config.Debug, - "status": "running", - } - return mcp.NewToolResultJSON(status), nil + return mcp.NewToolResultText(fmt.Sprintf(`{"app_name":"%s","debug":%t,"status":"running"}`, + app.config.AppName, app.config.Debug)), nil }, ) @@ -168,8 +188,8 @@ func (app *Application) addApplicationTools() { mcp.WithString("value", mcp.Required()), ), func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - key := req.Params.Arguments["key"].(string) - value := req.Params.Arguments["value"].(string) + key := req.GetString("key", "") + value := req.GetString("value", "") // Update configuration based on key switch key { @@ -189,7 +209,7 @@ func (app *Application) addApplicationTools() { func (app *Application) ProcessWithMCP(ctx context.Context, operation string) (interface{}, error) { // Use MCP tools internally for processing result, err := app.mcpClient.CallTool(ctx, mcp.CallToolRequest{ - Params: mcp.CallToolRequestParams{ + Params: mcp.CallToolParams{ Name: "calculate", Arguments: map[string]interface{}{ "operation": operation, @@ -202,7 +222,14 @@ func (app *Application) ProcessWithMCP(ctx context.Context, operation string) (i return nil, err } - return result.Content[0].Text, nil + // Extract text from the first content item + if len(result.Content) > 0 { + if textContent, ok := mcp.AsTextContent(result.Content[0]); ok { + return textContent.Text, nil + } + } + + return "no result", nil } // Usage example @@ -216,7 +243,19 @@ func main() { ctx := context.Background() // Initialize the embedded MCP client - if err := app.mcpClient.Initialize(ctx); err != nil { + _, err := app.mcpClient.Initialize(ctx, mcp.InitializeRequest{ + Params: mcp.InitializeRequestParams{ + ProtocolVersion: "2024-11-05", + Capabilities: mcp.ClientCapabilities{ + Tools: &mcp.ToolsCapability{}, + }, + ClientInfo: mcp.Implementation{ + Name: "embedded-client", + Version: "1.0.0", + }, + }, + }) + if err != nil { log.Fatal(err) } diff --git a/www/docs/pages/transports/sse.mdx b/www/docs/pages/transports/sse.mdx index 930bb0eba..f09104ec9 100644 --- a/www/docs/pages/transports/sse.mdx +++ b/www/docs/pages/transports/sse.mdx @@ -39,7 +39,7 @@ import ( func main() { s := server.NewMCPServer("SSE Server", "1.0.0", server.WithToolCapabilities(true), - server.WithResourceCapabilities(true), + server.WithResourceCapabilities(true, true), ) // Add real-time tools @@ -47,7 +47,7 @@ func main() { mcp.NewTool("stream_data", mcp.WithDescription("Stream data with real-time updates"), mcp.WithString("source", mcp.Required()), - mcp.WithInteger("count", mcp.Default(10)), + mcp.WithNumber("count", mcp.DefaultNumber(10)), ), handleStreamData, ) @@ -55,7 +55,7 @@ func main() { s.AddTool( mcp.NewTool("monitor_system", mcp.WithDescription("Monitor system metrics in real-time"), - mcp.WithInteger("duration", mcp.Default(60)), + mcp.WithNumber("duration", mcp.DefaultNumber(60)), ), handleSystemMonitor, ) @@ -73,19 +73,18 @@ func main() { // Start SSE server log.Println("Starting SSE server on :8080") - if err := server.ServeSSE(s, ":8080"); err != nil { + sseServer := server.NewSSEServer(s) + if err := sseServer.Start(":8080"); err != nil { log.Fatal(err) } } func handleStreamData(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - source := req.Params.Arguments["source"].(string) - count := int(req.Params.Arguments["count"].(float64)) + source := req.GetString("source", "") + count := req.GetInt("count", 10) - // Get notifier for real-time updates (hypothetical functions) - // Note: These functions would be provided by the SSE transport implementation - notifier := getNotifierFromContext(ctx) // Hypothetical function - sessionID := getSessionIDFromContext(ctx) // Hypothetical function + // Get server from context for notifications + mcpServer := server.ServerFromContext(ctx) // Stream data with progress updates var results []map[string]interface{} @@ -102,23 +101,22 @@ func handleStreamData(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallTo results = append(results, data) // Send progress notification - if notifier != nil { - // Note: ProgressNotification would be defined by the MCP protocol - notifier.SendProgress(sessionID, map[string]interface{}{ + if mcpServer != nil { + err := mcpServer.SendNotificationToClient(ctx, "notifications/progress", map[string]interface{}{ "progress": i + 1, "total": count, "message": fmt.Sprintf("Processed %d/%d items from %s", i+1, count, source), }) + if err != nil { + log.Printf("Failed to send notification: %v", err) + } } time.Sleep(100 * time.Millisecond) } - return mcp.NewToolResultJSON(map[string]interface{}{ - "source": source, - "results": results, - "count": len(results), - }), nil + return mcp.NewToolResultText(fmt.Sprintf(`{"source":"%s","results":%v,"count":%d}`, + source, results, len(results))), nil } // Helper functions for the examples @@ -130,21 +128,10 @@ func generateData(source string, index int) map[string]interface{} { } } -func getNotifierFromContext(ctx context.Context) interface{} { - // Placeholder implementation - would be provided by SSE transport - return nil -} - -func getSessionIDFromContext(ctx context.Context) string { - // Placeholder implementation - would be provided by SSE transport - return "session_123" -} - func handleSystemMonitor(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - duration := int(req.Params.Arguments["duration"].(float64)) + duration := req.GetInt("duration", 60) - notifier := getNotifierFromContext(ctx) - sessionID := getSessionIDFromContext(ctx) + mcpServer := server.ServerFromContext(ctx) // Monitor system for specified duration ticker := time.NewTicker(5 * time.Second) @@ -158,20 +145,19 @@ func handleSystemMonitor(ctx context.Context, req mcp.CallToolRequest) (*mcp.Cal case <-ctx.Done(): return nil, ctx.Err() case <-timeout: - return mcp.NewToolResultJSON(map[string]interface{}{ - "duration": duration, - "metrics": metrics, - "samples": len(metrics), - }), nil + return mcp.NewToolResultText(fmt.Sprintf(`{"duration":%d,"metrics":%v,"samples":%d}`, + duration, metrics, len(metrics))), nil case <-ticker.C: // Collect current metrics currentMetrics := collectSystemMetrics() metrics = append(metrics, currentMetrics) // Send real-time update - if notifier != nil { - // Note: SendCustom would be a method on the notifier interface - // notifier.SendCustom(sessionID, "system_metrics", currentMetrics) + if mcpServer != nil { + err := mcpServer.SendNotificationToClient(ctx, "system_metrics", currentMetrics) + if err != nil { + log.Printf("Failed to send system metrics notification: %v", err) + } } } } @@ -186,9 +172,15 @@ func collectSystemMetrics() map[string]interface{} { } } -func handleCurrentMetrics(ctx context.Context, req mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error) { +func handleCurrentMetrics(ctx context.Context, req mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { metrics := collectSystemMetrics() - return mcp.NewResourceResultJSON(metrics), nil + return []mcp.ResourceContents{ + mcp.TextResourceContents{ + URI: req.Params.URI, + MIMEType: "application/json", + Text: fmt.Sprintf(`{"cpu":%.1f,"memory":%.1f,"disk":%.1f}`, metrics["cpu"], metrics["memory"], metrics["disk"]), + }, + }, nil } ``` @@ -197,50 +189,29 @@ func handleCurrentMetrics(ctx context.Context, req mcp.ReadResourceRequest) (*mc ```go func main() { s := server.NewMCPServer("Advanced SSE Server", "1.0.0", - server.WithAllCapabilities(), - server.WithRecovery(), - server.WithHooks(&server.Hooks{ - OnSessionStart: func(sessionID string) { - log.Printf("SSE client connected: %s", sessionID) - broadcastUserCount() - }, - OnSessionEnd: func(sessionID string) { - log.Printf("SSE client disconnected: %s", sessionID) - broadcastUserCount() - }, - }), + server.WithResourceCapabilities(true, true), + server.WithPromptCapabilities(true), + server.WithToolCapabilities(true), + server.WithLogging(), ) - // Configure SSE-specific options - sseOptions := server.SSEOptions{ - BasePath: "/mcp", - AllowedOrigins: []string{"http://localhost:3000", "https://myapp.com"}, - HeartbeatInterval: 30 * time.Second, - MaxConnections: 100, - ConnectionTimeout: 5 * time.Minute, - EnableCompression: true, - } - // Add collaborative tools addCollaborativeTools(s) addRealTimeResources(s) log.Println("Starting advanced SSE server on :8080") - if err := server.ServeSSEWithOptions(s, ":8080", sseOptions); err != nil { + sseServer := server.NewSSEServer(s, + server.WithStaticBasePath("/mcp"), + server.WithKeepAliveInterval(30*time.Second), + server.WithBaseURL("http://localhost:8080"), + ) + + if err := sseServer.Start(":8080"); err != nil { log.Fatal(err) } } // Helper functions for the advanced example -func broadcastUserCount() { - // Placeholder implementation - log.Println("Broadcasting user count update") -} - -func addCollaborativeToolsPlaceholder(s *server.MCPServer) { - // Placeholder implementation - would add collaborative tools -} - func addRealTimeResources(s *server.MCPServer) { // Placeholder implementation - would add real-time resources } @@ -262,7 +233,7 @@ func addCollaborativeTools(s *server.MCPServer) { mcp.NewTool("send_message", mcp.WithDescription("Send a message to all connected clients"), mcp.WithString("message", mcp.Required()), - mcp.WithString("channel", mcp.Default("general")), + mcp.WithString("channel", mcp.DefaultString("general")), ), handleSendMessage, ) @@ -281,240 +252,71 @@ func addCollaborativeTools(s *server.MCPServer) { ## Configuration -### Base URLs and Paths - -```go -// Custom SSE endpoint configuration -sseOptions := server.SSEOptions{ - BasePath: "/api/mcp", // SSE endpoint will be /api/mcp/sse - - // Additional HTTP endpoints - HealthPath: "/api/health", - MetricsPath: "/api/metrics", - StatusPath: "/api/status", -} - -// Start server with custom paths -server.ServeSSEWithOptions(s, ":8080", sseOptions) -``` - -**Resulting endpoints:** -- SSE stream: `http://localhost:8080/api/mcp/sse` -- Health check: `http://localhost:8080/api/health` -- Metrics: `http://localhost:8080/api/metrics` -- Status: `http://localhost:8080/api/status` - -### CORS Configuration - -```go -sseOptions := server.SSEOptions{ - // Allow specific origins - AllowedOrigins: []string{ - "http://localhost:3000", - "https://myapp.com", - "https://*.myapp.com", - }, - - // Allow all origins (development only) - AllowAllOrigins: true, - - // Custom CORS headers - AllowedHeaders: []string{ - "Authorization", - "Content-Type", - "X-API-Key", - }, - - // Allow credentials - AllowCredentials: true, -} -``` +### SSE Server Options -### Connection Management +The SSE server can be configured with various options: ```go -sseOptions := server.SSEOptions{ - // Connection limits - MaxConnections: 100, - MaxConnectionsPerIP: 10, +sseServer := server.NewSSEServer(s, + // Set the base path for SSE endpoints + server.WithStaticBasePath("/api/mcp"), - // Timeouts - ConnectionTimeout: 5 * time.Minute, - WriteTimeout: 30 * time.Second, - ReadTimeout: 30 * time.Second, + // Configure keep-alive interval + server.WithKeepAliveInterval(30*time.Second), - // Heartbeat to keep connections alive - HeartbeatInterval: 30 * time.Second, + // Set base URL for client connections + server.WithBaseURL("http://localhost:8080"), - // Buffer sizes - WriteBufferSize: 4096, - ReadBufferSize: 4096, + // Configure SSE and message endpoints + server.WithSSEEndpoint("/sse"), + server.WithMessageEndpoint("/message"), - // Compression - EnableCompression: true, - CompressionLevel: 6, -} + // Add context function for request processing + server.WithSSEContextFunc(func(ctx context.Context, r *http.Request) context.Context { + // Add custom context values from headers + return ctx + }), +) ``` -## Session Handling - -### Multi-Client State Management - -```go -type SessionManager struct { - sessions map[string]*ClientSession - mutex sync.RWMutex - notifier *SSENotifier -} - -type ClientSession struct { - ID string - UserID string - ConnectedAt time.Time - LastSeen time.Time - Subscriptions map[string]bool - Metadata map[string]interface{} -} - -func NewSessionManager() *SessionManager { - return &SessionManager{ - sessions: make(map[string]*ClientSession), - notifier: NewSSENotifier(), - } -} +**Resulting endpoints:** +- SSE stream: `http://localhost:8080/api/mcp/sse` +- Message endpoint: `http://localhost:8080/api/mcp/message` -func (sm *SessionManager) OnSessionStart(sessionID string) { - sm.mutex.Lock() - defer sm.mutex.Unlock() - - session := &ClientSession{ - ID: sessionID, - ConnectedAt: time.Now(), - LastSeen: time.Now(), - Subscriptions: make(map[string]bool), - Metadata: make(map[string]interface{}), - } - - sm.sessions[sessionID] = session - - // Notify other clients - sm.notifier.BroadcastExcept(sessionID, "user_joined", map[string]interface{}{ - "session_id": sessionID, - "timestamp": time.Now().Unix(), - }) -} +## Real-Time Notifications -func (sm *SessionManager) OnSessionEnd(sessionID string) { - sm.mutex.Lock() - defer sm.mutex.Unlock() - - delete(sm.sessions, sessionID) - - // Notify other clients - sm.notifier.Broadcast("user_left", map[string]interface{}{ - "session_id": sessionID, - "timestamp": time.Now().Unix(), - }) -} +SSE transport enables real-time server-to-client communication through notifications. Use the server context to send notifications: -func (sm *SessionManager) GetActiveSessions() []ClientSession { - sm.mutex.RLock() - defer sm.mutex.RUnlock() +```go +func handleRealtimeTool(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Get the MCP server from context + mcpServer := server.ServerFromContext(ctx) - var sessions []ClientSession - for _, session := range sm.sessions { - sessions = append(sessions, *session) + // Send a notification to the client + if mcpServer != nil { + err := mcpServer.SendNotificationToClient(ctx, "custom_event", map[string]interface{}{ + "message": "Real-time update", + "timestamp": time.Now().Unix(), + }) + if err != nil { + log.Printf("Failed to send notification: %v", err) + } } - return sessions + return mcp.NewToolResultText(`{"status":"notification_sent"}`), nil } ``` -### Real-Time Notifications - -```go -type SSENotifier struct { - clients map[string]chan mcp.Notification - mutex sync.RWMutex -} +### Session Management -func NewSSENotifier() *SSENotifier { - return &SSENotifier{ - clients: make(map[string]chan mcp.Notification), - } -} - -func (n *SSENotifier) RegisterClient(sessionID string) <-chan mcp.Notification { - n.mutex.Lock() - defer n.mutex.Unlock() - - ch := make(chan mcp.Notification, 100) - n.clients[sessionID] = ch - return ch -} - -func (n *SSENotifier) UnregisterClient(sessionID string) { - n.mutex.Lock() - defer n.mutex.Unlock() - - if ch, exists := n.clients[sessionID]; exists { - close(ch) - delete(n.clients, sessionID) - } -} - -func (n *SSENotifier) SendToClient(sessionID string, notification mcp.Notification) { - n.mutex.RLock() - defer n.mutex.RUnlock() - - if ch, exists := n.clients[sessionID]; exists { - select { - case ch <- notification: - default: - // Channel full, drop notification - } - } -} +The SSE server automatically handles session management. You can send events to specific sessions using the server's notification methods: -func (n *SSENotifier) Broadcast(eventType string, data interface{}) { - notification := mcp.Notification{ - Type: eventType, - Data: data, - } - - n.mutex.RLock() - defer n.mutex.RUnlock() - - for _, ch := range n.clients { - select { - case ch <- notification: - default: - // Channel full, skip this client - } - } -} +```go +// Send notification to current client session +mcpServer.SendNotificationToClient(ctx, "progress_update", progressData) -func (n *SSENotifier) BroadcastExcept(excludeSessionID, eventType string, data interface{}) { - notification := mcp.Notification{ - Type: eventType, - Data: data, - } - - n.mutex.RLock() - defer n.mutex.RUnlock() - - for sessionID, ch := range n.clients { - if sessionID == excludeSessionID { - continue - } - - select { - case ch <- notification: - default: - // Channel full, skip this client - } - } -} +// Send notification to all connected clients (if supported) +// Note: Check the server implementation for broadcast capabilities ``` ## Next Steps diff --git a/www/docs/pages/transports/stdio.mdx b/www/docs/pages/transports/stdio.mdx index 6dea7adcf..3f609690f 100644 --- a/www/docs/pages/transports/stdio.mdx +++ b/www/docs/pages/transports/stdio.mdx @@ -40,7 +40,7 @@ import ( func main() { s := server.NewMCPServer("File Tools", "1.0.0", server.WithToolCapabilities(true), - server.WithResourceCapabilities(true), + server.WithResourceCapabilities(true, true), ) // Add file listing tool @@ -52,7 +52,7 @@ func main() { mcp.Description("Directory path to list"), ), mcp.WithBoolean("recursive", - mcp.Default(false), + mcp.DefaultBool(false), mcp.Description("List files recursively"), ), ), @@ -98,15 +98,11 @@ func handleListFiles(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToo return mcp.NewToolResultError(fmt.Sprintf("failed to list files: %v", err)), nil } - return mcp.NewToolResultJSON(map[string]interface{}{ - "path": path, - "files": files, - "count": len(files), - "recursive": recursive, - }), nil + return mcp.NewToolResultText(fmt.Sprintf(`{"path":"%s","files":%v,"count":%d,"recursive":%t}`, + path, files, len(files), recursive)), nil } -func handleFileContent(ctx context.Context, req mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error) { +func handleFileContent(ctx context.Context, req mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { // Extract path from URI: "file:///path/to/file" -> "/path/to/file" path := extractPathFromURI(req.Params.URI) @@ -119,13 +115,11 @@ func handleFileContent(ctx context.Context, req mcp.ReadResourceRequest) (*mcp.R return nil, fmt.Errorf("failed to read file: %w", err) } - return &mcp.ReadResourceResult{ - Contents: []mcp.ResourceContent{ - { - URI: req.Params.URI, - MIMEType: detectMIMEType(path), - Text: string(content), - }, + return []mcp.ResourceContents{ + mcp.TextResourceContents{ + URI: req.Params.URI, + MIMEType: detectMIMEType(path), + Text: string(content), }, }, nil } @@ -211,16 +205,10 @@ import ( func main() { s := server.NewMCPServer("Advanced CLI Tool", "1.0.0", - server.WithAllCapabilities(), - server.WithRecovery(), - server.WithHooks(&server.Hooks{ - OnSessionStart: func(sessionID string) { - logToFile(fmt.Sprintf("Session started: %s", sessionID)) - }, - OnSessionEnd: func(sessionID string) { - logToFile(fmt.Sprintf("Session ended: %s", sessionID)) - }, - }), + server.WithResourceCapabilities(true, true), + server.WithPromptCapabilities(true), + server.WithToolCapabilities(true), + server.WithLogging(), ) // Add comprehensive tools @@ -325,29 +313,42 @@ package main import ( "context" "log" + "time" + "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/mcp" ) func main() { // Create STDIO client c, err := client.NewStdioClient( - "go", "run", "/path/to/server/main.go", + "go", nil /* inherit env */, "run", "/path/to/server/main.go", ) if err != nil { log.Fatal(err) } defer c.Close() - ctx := context.Background() + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() // Initialize connection - if err := c.Initialize(ctx); err != nil { + _, err = c.Initialize(ctx, mcp.InitializeRequest{ + Params: mcp.InitializeRequestParams{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + ClientInfo: mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + }, + }, + }) + if err != nil { log.Fatal(err) } // List available tools - tools, err := c.ListTools(ctx) + tools, err := c.ListTools(ctx, mcp.ListToolsRequest{}) if err != nil { log.Fatal(err) } @@ -359,7 +360,7 @@ func main() { // Call a tool result, err := c.CallTool(ctx, mcp.CallToolRequest{ - Params: mcp.CallToolRequestParams{ + Params: mcp.CallToolParams{ Name: "list_files", Arguments: map[string]interface{}{ "path": ".", @@ -375,6 +376,58 @@ func main() { } ``` +#### Customizing Subprocess Execution + +If you need more control over how a sub-process is spawned when creating a new STDIO client, you can use +`NewStdioMCPClientWithOptions` instead of `NewStdioMCPClient`. + +By passing the `WithCommandFunc` option, you can supply a custom factory function to create the `exec.Cmd` that launches +the server. This allows configuration of environment variables, working directories, and system-level process attributes. + +Referring to the previous example, we can replace the line that creates the client: + +```go +c, err := client.NewStdioClient( + "go", nil, "run", "/path/to/server/main.go", +) +``` + +With the options-aware version: + +```go +c, err := client.NewStdioMCPClientWithOptions( + "go", + nil, + []string {"run", "/path/to/server/main.go"}, + transport.WithCommandFunc(func(ctx context.Context, command string, args []string, env []string) (*exec.Cmd, error) { + cmd := exec.CommandContext(ctx, command, args...) + cmd.Env = env // Explicit environment for the subprocess. + cmd.Dir = "/var/sandbox/mcp-server" // Working directory (not isolated unless paired with chroot or namespace). + + // Apply low-level process isolation and privilege dropping. + cmd.SysProcAttr = &syscall.SysProcAttr{ + // Drop to non-root user (e.g., user/group ID 1001) + Credential: &syscall.Credential{ + Uid: 1001, + Gid: 1001, + }, + // File system isolation: only works if running as root. + Chroot: "/var/sandbox/mcp-server", + + // Linux namespace isolation (Linux only): + // Prevents access to other processes, mounts, IPC, networks, etc. + Cloneflags: syscall.CLONE_NEWIPC | // Isolate inter-process comms + syscall.CLONE_NEWNS | // Isolate filesystem mounts + syscall.CLONE_NEWPID | // Isolate PID namespace (child sees itself as PID 1) + syscall.CLONE_NEWUTS | // Isolate hostname + syscall.CLONE_NEWNET, // Isolate networking (optional) + } + + return cmd, nil + }), +) +``` + ## Debugging ### Command Line Testing @@ -445,18 +498,7 @@ func main() { s := server.NewMCPServer("Debug Server", "1.0.0", server.WithToolCapabilities(true), - server.WithHooks(&server.Hooks{ - OnSessionStart: func(sessionID string) { - logger.Printf("Session started: %s", sessionID) - }, - OnToolCall: func(sessionID, toolName string, duration time.Duration, err error) { - if err != nil { - logger.Printf("Tool %s failed: %v", toolName, err) - } else { - logger.Printf("Tool %s completed in %v", toolName, duration) - } - }, - }), + server.WithLogging(), ) // Add tools with debug logging @@ -466,7 +508,7 @@ func main() { mcp.WithString("message", mcp.Required()), ), func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - message := req.Params.Arguments["message"].(string) + message := req.GetString("message", "") logger.Printf("Echo tool called with message: %s", message) return mcp.NewToolResultText(fmt.Sprintf("Echo: %s", message)), nil }, @@ -504,8 +546,8 @@ This opens a web interface where you can: ```go func handleToolWithErrors(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Validate required parameters - path, ok := req.Params.Arguments["path"].(string) - if !ok { + path, err := req.RequireString("path") + if err != nil { return nil, fmt.Errorf("path parameter is required and must be a string") } @@ -543,7 +585,7 @@ func handleToolWithErrors(ctx context.Context, req mcp.CallToolRequest) (*mcp.Ca return nil, fmt.Errorf("operation failed: %w", err) } - return mcp.NewToolResultJSON(result), nil + return mcp.NewToolResultText(fmt.Sprintf("%v", result)), nil } ``` @@ -638,7 +680,7 @@ func getCachedFile(path string) (string, bool) { ```go func handleLargeFile(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - path := req.Params.Arguments["path"].(string) + path := req.GetString("path", "") // Stream large files instead of loading into memory file, err := os.Open(path) diff --git a/www/vocs.config.ts b/www/vocs.config.ts index 0706755d7..dc745b385 100644 --- a/www/vocs.config.ts +++ b/www/vocs.config.ts @@ -100,6 +100,20 @@ export default defineConfig({ }, ], }, + { + text: 'Advanced', + collapsed: true, + items: [ + { + text: 'Server Sampling', + link: '/servers/advanced-sampling', + }, + { + text: 'Client Sampling', + link: '/clients/advanced-sampling', + }, + ], + }, ], socials: [ {