diff --git a/pkg/transport/http.go b/pkg/transport/http.go index 5bbe1a2..f55f9a5 100644 --- a/pkg/transport/http.go +++ b/pkg/transport/http.go @@ -78,17 +78,64 @@ func NewHTTP(address string) (*HTTP, error) { return nil, fmt.Errorf("timeout waiting for SSE response") } - return &HTTP{ + client := &HTTP{ // Use the SSE message address as the base address for the HTTP transport - address: address + messageAddress, + address: address + "/sse" + messageAddress, nextID: 1, debug: debug, eventCh: eventCh, - }, nil + } + + // Send initialize request + _, err = client.Execute("initialize", map[string]any{ + "clientInfo": map[string]any{ + "name": "mcp-client", + "version": "0.1.0", + }, + "capabilities": map[string]any{}, + "protocolVersion": "2024-11-05", + }) + if err != nil { + return nil, fmt.Errorf("error sending initialize request: %w", err) + } + + // Send intialized notification + if err := client.send("notifications/initialized", nil); err != nil { + return nil, fmt.Errorf("error sending initialized notification: %w", err) + } + + return client, nil } // Execute implements the Transport via JSON-RPC over HTTP. func (t *HTTP) Execute(method string, params any) (map[string]any, error) { + if err := t.send(method, params); err != nil { + return nil, err + } + + // After sending the request, we listen the SSE channel for the response + var response Response + select { + case msg := <-t.eventCh: + if unmarshalErr := json.Unmarshal([]byte(msg), &response); unmarshalErr != nil { + return nil, fmt.Errorf("error unmarshaling response: %w, response: %s", unmarshalErr, msg) + } + case <-time.After(10 * time.Second): + return nil, fmt.Errorf("timeout waiting for SSE response") + } + + if response.Error != nil { + return nil, fmt.Errorf("RPC error %d: %s", response.Error.Code, response.Error.Message) + } + + if t.debug { + fmt.Fprintf(os.Stderr, "DEBUG: Successfully parsed response\n") + } + + return response.Result, nil +} + +func (t *HTTP) send(method string, params any) error { if t.debug { fmt.Fprintf(os.Stderr, "DEBUG: Connecting to server: %s\n", t.address) } @@ -103,7 +150,7 @@ func (t *HTTP) Execute(method string, params any) (map[string]any, error) { requestJSON, err := json.Marshal(request) if err != nil { - return nil, fmt.Errorf("error marshaling request: %w", err) + return fmt.Errorf("error marshaling request: %w", err) } requestJSON = append(requestJSON, '\n') @@ -114,7 +161,7 @@ func (t *HTTP) Execute(method string, params any) (map[string]any, error) { resp, err := http.Post(t.address, "application/json", bytes.NewBuffer(requestJSON)) if err != nil { - return nil, fmt.Errorf("error sending request: %w", err) + return fmt.Errorf("error sending request: %w", err) } if t.debug { @@ -129,35 +176,12 @@ func (t *HTTP) Execute(method string, params any) (map[string]any, error) { body, err := io.ReadAll(resp.Body) if err != nil { - return nil, fmt.Errorf("error reading response: %w", err) + return fmt.Errorf("error reading response: %w", err) } if t.debug { fmt.Fprintf(os.Stderr, "DEBUG: Read from server: %s\n", string(body)) } - if len(body) == 0 { - return nil, fmt.Errorf("no response from server") - } - - // After sending the request, we listen the SSE channel for the response - var response Response - select { - case msg := <-t.eventCh: - if unmarshalErr := json.Unmarshal([]byte(msg), &response); unmarshalErr != nil { - return nil, fmt.Errorf("error unmarshaling response: %w, response: %s", unmarshalErr, msg) - } - case <-time.After(10 * time.Second): - return nil, fmt.Errorf("timeout waiting for SSE response") - } - - if response.Error != nil { - return nil, fmt.Errorf("RPC error %d: %s", response.Error.Code, response.Error.Message) - } - - if t.debug { - fmt.Fprintf(os.Stderr, "DEBUG: Successfully parsed response\n") - } - - return response.Result, nil + return nil }