Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import java.io.IOException;
import java.time.Duration;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.locks.ReentrantLock;

Expand Down Expand Up @@ -256,39 +255,31 @@ private ServerResponse handleSseConnection(ServerRequest request) {
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down");
}

String sessionId = UUID.randomUUID().toString();
logger.debug("Creating new SSE connection for session: {}", sessionId);

// Send initial endpoint event
try {
return ServerResponse.sse(sseBuilder -> {
sseBuilder.onComplete(() -> {
logger.debug("SSE connection completed for session: {}", sessionId);
sessions.remove(sessionId);
});
sseBuilder.onTimeout(() -> {
logger.debug("SSE connection timed out for session: {}", sessionId);
sessions.remove(sessionId);
});

WebMvcMcpSessionTransport sessionTransport = new WebMvcMcpSessionTransport(sessionId, sseBuilder);
McpServerSession session = sessionFactory.create(sessionTransport);
this.sessions.put(sessionId, session);
return ServerResponse.sse(sseBuilder -> {
WebMvcMcpSessionTransport sessionTransport = new WebMvcMcpSessionTransport(sseBuilder);
McpServerSession session = sessionFactory.create(sessionTransport);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice to have this.

String sessionId = session.getId();
logger.debug("Creating new SSE connection for session: {}", sessionId);
sseBuilder.onComplete(() -> {
logger.debug("SSE connection completed for session: {}", sessionId);
sessions.remove(sessionId);
});
sseBuilder.onTimeout(() -> {
logger.debug("SSE connection timed out for session: {}", sessionId);
sessions.remove(sessionId);
});
this.sessions.put(sessionId, session);

try {
sseBuilder.id(sessionId).event(ENDPOINT_EVENT_TYPE).data(buildEndpointUrl(sessionId));
}
catch (Exception e) {
logger.error("Failed to send initial endpoint event: {}", e.getMessage());
sseBuilder.error(e);
}
}, Duration.ZERO);
}
catch (Exception e) {
logger.error("Failed to send initial endpoint event to session {}: {}", sessionId, e.getMessage());
sessions.remove(sessionId);
return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).build();
}
try {
sseBuilder.event(ENDPOINT_EVENT_TYPE).data(buildEndpointUrl(sessionId));
}
catch (Exception e) {
logger.error("Failed to send initial endpoint event: {}", e.getMessage());
this.sessions.remove(sessionId);
sseBuilder.error(e);
}
}, Duration.ZERO);
}

/**
Expand Down Expand Up @@ -363,8 +354,6 @@ private ServerResponse handleMessage(ServerRequest request) {
*/
private class WebMvcMcpSessionTransport implements McpServerTransport {

private final String sessionId;

private final SseBuilder sseBuilder;

/**
Expand All @@ -374,14 +363,11 @@ private class WebMvcMcpSessionTransport implements McpServerTransport {
private final ReentrantLock sseBuilderLock = new ReentrantLock();

/**
* Creates a new session transport with the specified ID and SSE builder.
* @param sessionId The unique identifier for this session
* Creates a new session transport with the specified SSE builder.
* @param sseBuilder The SSE builder for sending server events to the client
*/
WebMvcMcpSessionTransport(String sessionId, SseBuilder sseBuilder) {
this.sessionId = sessionId;
WebMvcMcpSessionTransport(SseBuilder sseBuilder) {
this.sseBuilder = sseBuilder;
logger.debug("Session transport {} initialized with SSE builder", sessionId);
}

/**
Expand All @@ -395,11 +381,10 @@ public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {
sseBuilderLock.lock();
try {
String jsonText = jsonMapper.writeValueAsString(message);
sseBuilder.id(sessionId).event(MESSAGE_EVENT_TYPE).data(jsonText);
logger.debug("Message sent to session {}", sessionId);
sseBuilder.event(MESSAGE_EVENT_TYPE).data(jsonText);
}
catch (Exception e) {
logger.error("Failed to send message to session {}: {}", sessionId, e.getMessage());
logger.error("Failed to send message: {}", e.getMessage());
sseBuilder.error(e);
}
finally {
Expand Down Expand Up @@ -427,14 +412,12 @@ public <T> T unmarshalFrom(Object data, TypeRef<T> typeRef) {
@Override
public Mono<Void> closeGracefully() {
return Mono.fromRunnable(() -> {
logger.debug("Closing session transport: {}", sessionId);
sseBuilderLock.lock();
try {
sseBuilder.complete();
logger.debug("Successfully completed SSE builder for session {}", sessionId);
}
catch (Exception e) {
logger.warn("Failed to complete SSE builder for session {}: {}", sessionId, e.getMessage());
logger.warn("Failed to complete SSE builder: {}", e.getMessage());
}
finally {
sseBuilderLock.unlock();
Expand All @@ -450,10 +433,9 @@ public void close() {
sseBuilderLock.lock();
try {
sseBuilder.complete();
logger.debug("Successfully completed SSE builder for session {}", sessionId);
}
catch (Exception e) {
logger.warn("Failed to complete SSE builder for session {}: {}", sessionId, e.getMessage());
logger.warn("Failed to complete SSE builder: {}", e.getMessage());
}
finally {
sseBuilderLock.unlock();
Expand Down