Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 35 additions & 20 deletions mcp/event.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,8 @@ func writeEvent(w io.Writer, evt Event) (int, error) {
//
// TODO(rfindley): consider a different API here that makes failure modes more
// apparent.
func scanEvents(r io.Reader) iter.Seq2[Event, error] {
scanner := bufio.NewScanner(r)
const maxTokenSize = 1 * 1024 * 1024 // 1 MiB max line size
scanner.Buffer(nil, maxTokenSize)
func scanEvents(ctx context.Context, r io.Reader) iter.Seq2[Event, error] {
scanner := bufio.NewReader(r)

// TODO: investigate proper behavior when events are out of order, or have
// non-standard names.
Expand Down Expand Up @@ -100,15 +98,40 @@ func scanEvents(r io.Reader) iter.Seq2[Event, error] {
dataBuf = nil
}
}
for scanner.Scan() {
line := scanner.Bytes()
emitEvent := func() bool {
flushData()
if evt.Empty() {
return true
}
if !yield(evt, nil) {
return false
}
evt = Event{}
return true
}
for {
line, err := scanner.ReadBytes('\n')
if err != nil {
if errors.Is(err, io.EOF) {
// Handle EOF below
} else if ctx.Err() != nil {
yield(Event{}, fmt.Errorf("context done: %w", ctx.Err()))
return
} else {
yield(Event{}, fmt.Errorf("error reading event: %w", err))
return
}
}
line = bytes.TrimRight(line, "\r\n")
isEOF := errors.Is(err, io.EOF)

if len(line) == 0 {
flushData()
// \n\n is the record delimiter
if !evt.Empty() && !yield(evt, nil) {
if !emitEvent() {
return
}
if isEOF {
return
}
evt = Event{}
continue
}
before, after, found := bytes.Cut(line, []byte{':'})
Expand Down Expand Up @@ -136,19 +159,11 @@ func scanEvents(r io.Reader) iter.Seq2[Event, error] {
dataBuf.Write(data)
}
}
}
if err := scanner.Err(); err != nil {
if errors.Is(err, bufio.ErrTooLong) {
err = fmt.Errorf("event exceeded max line length of %d", maxTokenSize)
}
if !yield(Event{}, err) {
if isEOF {
emitEvent()
return
}
}
flushData()
if !evt.Empty() {
yield(evt, nil)
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion mcp/event_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func TestScanEvents(t *testing.T) {
r := strings.NewReader(tt.input)
var got []Event
var err error
for e, err2 := range scanEvents(r) {
for e, err2 := range scanEvents(t.Context(), r) {
if err2 != nil {
err = err2
break
Expand Down
31 changes: 22 additions & 9 deletions mcp/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) {

msgEndpoint, err := func() (*url.URL, error) {
var evt Event
for evt, err = range scanEvents(resp.Body) {
for evt, err = range scanEvents(ctx, resp.Body) {
break
}
if err != nil {
Expand All @@ -382,20 +382,24 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) {
s := &sseClientConn{
client: httpClient,
msgEndpoint: msgEndpoint,
incoming: make(chan []byte, 100),
incoming: make(chan sseMessage, 100),
body: resp.Body,
done: make(chan struct{}),
}

go func() {
defer s.Close() // close the transport when the GET exits

for evt, err := range scanEvents(resp.Body) {
for evt, err := range scanEvents(ctx, resp.Body) {
if err != nil {
select {
case s.incoming <- sseMessage{err: err}:
case <-s.done:
}
return
}
select {
case s.incoming <- evt.Data:
case s.incoming <- sseMessage{data: evt.Data}:
case <-s.done:
return
}
Expand All @@ -405,15 +409,21 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) {
return s, nil
}

// sseMessage represents a message or error from the SSE stream.
type sseMessage struct {
data []byte
err error
}

// An sseClientConn is a logical jsonrpc2 connection that implements the client
// half of the SSE protocol:
// - Writes are POSTS to the session endpoint.
// - Reads are SSE 'message' events, and pushes them onto a buffered channel.
// - Close terminates the GET request.
type sseClientConn struct {
client *http.Client // HTTP client to use for requests
msgEndpoint *url.URL // session endpoint for POSTs
incoming chan []byte // queue of incoming messages
client *http.Client // HTTP client to use for requests
msgEndpoint *url.URL // session endpoint for POSTs
incoming chan sseMessage // queue of incoming messages or errors

mu sync.Mutex
body io.ReadCloser // body of the hanging GET
Expand All @@ -438,12 +448,15 @@ func (c *sseClientConn) Read(ctx context.Context) (jsonrpc.Message, error) {
case <-c.done:
return nil, io.EOF

case data := <-c.incoming:
case m := <-c.incoming:
if m.err != nil {
return nil, m.err
}
// TODO(rfindley): do we really need to check this? We receive from c.done above.
if c.isDone() {
return nil, io.EOF
}
msg, err := jsonrpc2.DecodeMessage(data)
msg, err := jsonrpc2.DecodeMessage(m.data)
if err != nil {
return nil, err
}
Expand Down
8 changes: 6 additions & 2 deletions mcp/streamable.go
Original file line number Diff line number Diff line change
Expand Up @@ -1854,12 +1854,16 @@ func (c *streamableClientConn) processStream(ctx context.Context, requestSummary
io.Copy(io.Discard, resp.Body)
resp.Body.Close()
}()
for evt, err := range scanEvents(resp.Body) {
for evt, err := range scanEvents(ctx, resp.Body) {
if err != nil {
if ctx.Err() != nil {
return "", 0, true // don't reconnect: client cancelled
}
break

// Network errors during reading should trigger reconnection, not permanent failure.
// Return from processStream so handleSSE can attempt to reconnect.
c.logger.Debug(fmt.Sprintf("%s: stream read error (will attempt reconnect): %v", requestSummary, err))
return lastEventID, reconnectDelay, false
}

if evt.ID != "" {
Expand Down
4 changes: 2 additions & 2 deletions mcp/streamable_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1425,7 +1425,7 @@ func (s streamableRequest) do(ctx context.Context, serverURL, sessionID string,
var respBody []byte
if strings.HasPrefix(contentType, "text/event-stream") {
r := readerInto{resp.Body, new(bytes.Buffer)}
for evt, err := range scanEvents(r) {
for evt, err := range scanEvents(ctx, r) {
if err != nil {
return newSessionID, resp.StatusCode, nil, fmt.Errorf("reading events: %v", err)
}
Expand Down Expand Up @@ -2143,7 +2143,7 @@ data: {"jsonrpc":"2.0","method":"test2","params":{}}
var events []Event

// Scan all events
for evt, err := range scanEvents(reader) {
for evt, err := range scanEvents(t.Context(), reader) {
if err != nil {
if err != io.EOF {
t.Fatalf("scanEvents error: %v", err)
Expand Down