Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import io.a2a.client.transport.spi.interceptors.ClientCallContext;
import io.a2a.client.transport.spi.ClientTransport;
import io.a2a.common.A2AHeaders;
import io.a2a.grpc.A2AServiceGrpc;
import io.a2a.grpc.CancelTaskRequest;
import io.a2a.grpc.CreateTaskPushNotificationConfigRequest;
Expand All @@ -37,8 +38,9 @@
import io.a2a.spec.TaskPushNotificationConfig;
import io.a2a.spec.TaskQueryParams;
import io.grpc.Channel;

import io.grpc.Metadata;
import io.grpc.StatusRuntimeException;
import io.grpc.stub.MetadataUtils;
import io.grpc.stub.StreamObserver;

public class GrpcTransport implements ClientTransport {
Expand All @@ -61,7 +63,8 @@ public EventKind sendMessage(MessageSendParams request, ClientCallContext contex
SendMessageRequest sendMessageRequest = createGrpcSendMessageRequest(request, context);

try {
SendMessageResponse response = blockingStub.sendMessage(sendMessageRequest);
A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata(context);
SendMessageResponse response = stubWithMetadata.sendMessage(sendMessageRequest);
if (response.hasMsg()) {
return FromProto.message(response.getMsg());
} else if (response.hasTask()) {
Expand All @@ -83,7 +86,8 @@ public void sendMessageStreaming(MessageSendParams request, Consumer<StreamingEv
StreamObserver<StreamResponse> streamObserver = new EventStreamObserver(eventConsumer, errorConsumer);

try {
asyncStub.sendStreamingMessage(grpcRequest, streamObserver);
A2AServiceStub stubWithMetadata = createAsyncStubWithMetadata(context);
stubWithMetadata.sendStreamingMessage(grpcRequest, streamObserver);
} catch (StatusRuntimeException e) {
throw GrpcErrorMapper.mapGrpcError(e, "Failed to send streaming message request: ");
}
Expand All @@ -101,7 +105,8 @@ public Task getTask(TaskQueryParams request, ClientCallContext context) throws A
GetTaskRequest getTaskRequest = requestBuilder.build();

try {
return FromProto.task(blockingStub.getTask(getTaskRequest));
A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata(context);
return FromProto.task(stubWithMetadata.getTask(getTaskRequest));
} catch (StatusRuntimeException e) {
throw GrpcErrorMapper.mapGrpcError(e, "Failed to get task: ");
}
Expand All @@ -116,7 +121,8 @@ public Task cancelTask(TaskIdParams request, ClientCallContext context) throws A
.build();

try {
return FromProto.task(blockingStub.cancelTask(cancelTaskRequest));
A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata(context);
return FromProto.task(stubWithMetadata.cancelTask(cancelTaskRequest));
} catch (StatusRuntimeException e) {
throw GrpcErrorMapper.mapGrpcError(e, "Failed to cancel task: ");
}
Expand All @@ -135,7 +141,8 @@ public TaskPushNotificationConfig setTaskPushNotificationConfiguration(TaskPushN
.build();

try {
return FromProto.taskPushNotificationConfig(blockingStub.createTaskPushNotificationConfig(grpcRequest));
A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata(context);
return FromProto.taskPushNotificationConfig(stubWithMetadata.createTaskPushNotificationConfig(grpcRequest));
} catch (StatusRuntimeException e) {
throw GrpcErrorMapper.mapGrpcError(e, "Failed to create task push notification config: ");
}
Expand All @@ -152,7 +159,8 @@ public TaskPushNotificationConfig getTaskPushNotificationConfiguration(
.build();

try {
return FromProto.taskPushNotificationConfig(blockingStub.getTaskPushNotificationConfig(grpcRequest));
A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata(context);
return FromProto.taskPushNotificationConfig(stubWithMetadata.getTaskPushNotificationConfig(grpcRequest));
} catch (StatusRuntimeException e) {
throw GrpcErrorMapper.mapGrpcError(e, "Failed to get task push notification config: ");
}
Expand All @@ -169,7 +177,8 @@ public List<TaskPushNotificationConfig> listTaskPushNotificationConfigurations(
.build();

try {
return blockingStub.listTaskPushNotificationConfig(grpcRequest).getConfigsList().stream()
A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata(context);
return stubWithMetadata.listTaskPushNotificationConfig(grpcRequest).getConfigsList().stream()
.map(FromProto::taskPushNotificationConfig)
.collect(Collectors.toList());
} catch (StatusRuntimeException e) {
Expand All @@ -187,7 +196,8 @@ public void deleteTaskPushNotificationConfigurations(DeleteTaskPushNotificationC
.build();

try {
blockingStub.deleteTaskPushNotificationConfig(grpcRequest);
A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata(context);
stubWithMetadata.deleteTaskPushNotificationConfig(grpcRequest);
} catch (StatusRuntimeException e) {
throw GrpcErrorMapper.mapGrpcError(e, "Failed to delete task push notification config: ");
}
Expand All @@ -206,7 +216,8 @@ public void resubscribe(TaskIdParams request, Consumer<StreamingEventKind> event
StreamObserver<StreamResponse> streamObserver = new EventStreamObserver(eventConsumer, errorConsumer);

try {
asyncStub.taskSubscription(grpcRequest, streamObserver);
A2AServiceStub stubWithMetadata = createAsyncStubWithMetadata(context);
stubWithMetadata.taskSubscription(grpcRequest, streamObserver);
} catch (StatusRuntimeException e) {
throw GrpcErrorMapper.mapGrpcError(e, "Failed to resubscribe task push notification config: ");
}
Expand Down Expand Up @@ -234,6 +245,50 @@ private SendMessageRequest createGrpcSendMessageRequest(MessageSendParams messag
return builder.build();
}

/**
* Creates gRPC metadata from ClientCallContext headers.
* Extracts headers like X-A2A-Extensions and sets them as gRPC metadata.
*/
private Metadata createGrpcMetadata(ClientCallContext context) {
Metadata metadata = new Metadata();

if (context != null && context.getHeaders() != null) {
// Set X-A2A-Extensions header if present
String extensionsHeader = context.getHeaders().get(A2AHeaders.X_A2A_EXTENSIONS);
if (extensionsHeader != null) {
Metadata.Key<String> extensionsKey = Metadata.Key.of(A2AHeaders.X_A2A_EXTENSIONS, Metadata.ASCII_STRING_MARSHALLER);
metadata.put(extensionsKey, extensionsHeader);
}

// Add other headers as needed in the future
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was just thinking some more about this and realized that I can probably modify #292 to add the authorization/api-key headers for gRPC in a similar way.

Currently, we are relying on the user configuring a gRPC ClientInterceptor to specify the token/API key. Instead, I think I can update GrpcTransport to make use of the ClientCallInterceptors so that the AuthInterceptor could be used for gRPC like we use it for the other transports to add the appropriate header. And then, I can update this method to add the appropriate header to the gRPC metadata if present.

// For now, we only handle X-A2A-Extensions
}

return metadata;
}

/**
* Creates a blocking stub with metadata attached from the ClientCallContext.
*
* @param context the client call context
* @return blocking stub with metadata interceptor
*/
private A2AServiceBlockingV2Stub createBlockingStubWithMetadata(ClientCallContext context) {
Metadata metadata = createGrpcMetadata(context);
return blockingStub.withInterceptors(MetadataUtils.newAttachHeadersInterceptor(metadata));
}

/**
* Creates an async stub with metadata attached from the ClientCallContext.
*
* @param context the client call context
* @return async stub with metadata interceptor
*/
private A2AServiceStub createAsyncStubWithMetadata(ClientCallContext context) {
Metadata metadata = createGrpcMetadata(context);
return asyncStub.withInterceptors(MetadataUtils.newAttachHeadersInterceptor(metadata));
}

private String getTaskPushNotificationConfigName(GetTaskPushNotificationConfigParams params) {
return getTaskPushNotificationConfigName(params.id(), params.pushNotificationConfigId());
}
Expand Down
17 changes: 17 additions & 0 deletions common/src/main/java/io/a2a/common/A2AHeaders.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package io.a2a.common;

/**
* Common A2A protocol headers and constants.
*/
public final class A2AHeaders {

/**
* HTTP header name for A2A extensions.
* Used to communicate which extensions are requested by the client.
*/
public static final String X_A2A_EXTENSIONS = "X-A2A-Extensions";

private A2AHeaders() {
// Utility class
}
}
4 changes: 4 additions & 0 deletions reference/grpc/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
<groupId>${project.groupId}</groupId>
<artifactId>a2a-java-sdk-reference-common</artifactId>
</dependency>
<dependency>
<groupId>${project.groupId}</groupId>
<artifactId>a2a-java-sdk-common</artifactId>
</dependency>
<dependency>
<groupId>${project.groupId}</groupId>
<artifactId>a2a-java-sdk-transport-grpc</artifactId>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package io.a2a.server.grpc.quarkus;

import jakarta.enterprise.context.ApplicationScoped;
import io.grpc.Context;
import io.grpc.Contexts;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.a2a.common.A2AHeaders;
import io.a2a.transport.grpc.context.GrpcContextKeys;

/**
* gRPC server interceptor that captures request metadata and context information,
* providing equivalent functionality to Python's grpc.aio.ServicerContext.
*
* This interceptor:
* - Extracts A2A extension headers from incoming requests
* - Captures ServerCall and Metadata for rich context access
* - Stores context information in gRPC Context for service method access
* - Provides proper equivalence to Python's ServicerContext
*/
@ApplicationScoped
public class A2AExtensionsInterceptor implements ServerInterceptor {


@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
ServerCall<ReqT, RespT> serverCall,
Metadata metadata,
ServerCallHandler<ReqT, RespT> serverCallHandler) {

// Extract A2A extensions header
Metadata.Key<String> extensionsKey =
Metadata.Key.of(A2AHeaders.X_A2A_EXTENSIONS, Metadata.ASCII_STRING_MARSHALLER);
String extensions = metadata.get(extensionsKey);

// Create enhanced context with rich information (equivalent to Python's ServicerContext)
Context context = Context.current()
// Store complete metadata for full header access
.withValue(GrpcContextKeys.METADATA_KEY, metadata)
// Store method name (equivalent to Python's context.method())
.withValue(GrpcContextKeys.METHOD_NAME_KEY, serverCall.getMethodDescriptor().getFullMethodName())
// Store peer information for client connection details
.withValue(GrpcContextKeys.PEER_INFO_KEY, getPeerInfo(serverCall));

// Store A2A extensions if present
if (extensions != null) {
context = context.withValue(GrpcContextKeys.EXTENSIONS_HEADER_KEY, extensions);
}

// Proceed with the call in the enhanced context
return Contexts.interceptCall(context, serverCall, metadata, serverCallHandler);
}

/**
* Safely extracts peer information from the ServerCall.
*
* @param serverCall the gRPC ServerCall
* @return peer information string, or "unknown" if not available
*/
private String getPeerInfo(ServerCall<?, ?> serverCall) {
try {
Object remoteAddr = serverCall.getAttributes().get(io.grpc.Grpc.TRANSPORT_ATTR_REMOTE_ADDR);
return remoteAddr != null ? remoteAddr.toString() : "unknown";
} catch (Exception e) {
return "unknown";
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
import io.a2a.transport.grpc.handler.CallContextFactory;
import io.a2a.transport.grpc.handler.GrpcHandler;
import io.quarkus.grpc.GrpcService;
import io.quarkus.grpc.RegisterInterceptor;

@GrpcService
@RegisterInterceptor(A2AExtensionsInterceptor.class)
public class QuarkusGrpcHandler extends GrpcHandler {

private final AgentCard agentCard;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import static jakarta.ws.rs.core.MediaType.APPLICATION_JSON;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Executor;
Expand All @@ -19,9 +20,11 @@
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.io.JsonEOFException;
import com.fasterxml.jackson.databind.JsonNode;
import io.a2a.common.A2AHeaders;
import io.a2a.server.ServerCallContext;
import io.a2a.server.auth.UnauthenticatedUser;
import io.a2a.server.auth.User;
import io.a2a.server.extensions.A2AExtensions;
import io.a2a.server.util.async.Internal;
import io.a2a.spec.AgentCard;
import io.a2a.spec.CancelTaskRequest;
Expand Down Expand Up @@ -241,7 +244,11 @@ public String getUsername() {
headerNames.forEach(name -> headers.put(name, rc.request().getHeader(name)));
state.put("headers", headers);

return new ServerCallContext(user, state);
// Extract requested extensions from X-A2A-Extensions header
List<String> extensionHeaderValues = rc.request().headers().getAll(A2AHeaders.X_A2A_EXTENSIONS);
Set<String> requestedExtensions = A2AExtensions.getRequestedExtensions(extensionHeaderValues);

return new ServerCallContext(user, state, requestedExtensions);
} else {
CallContextFactory builder = callContextFactory.get();
return builder.build(rc);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import jakarta.inject.Inject;
import jakarta.inject.Singleton;

import io.a2a.common.A2AHeaders;
import io.a2a.server.ServerCallContext;
import io.a2a.server.auth.UnauthenticatedUser;
import io.a2a.server.auth.User;
Expand All @@ -34,9 +35,13 @@
import io.vertx.core.http.HttpServerResponse;
import io.vertx.ext.web.RoutingContext;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import io.a2a.server.extensions.A2AExtensions;

@Singleton
public class A2AServerRoutes {

Expand Down Expand Up @@ -308,7 +313,11 @@ public String getUsername() {
headerNames.forEach(name -> headers.put(name, rc.request().getHeader(name)));
state.put("headers", headers);

return new ServerCallContext(user, state);
// Extract requested extensions from X-A2A-Extensions header
List<String> extensionHeaderValues = rc.request().headers().getAll(A2AHeaders.X_A2A_EXTENSIONS);
Set<String> requestedExtensions = A2AExtensions.getRequestedExtensions(extensionHeaderValues);

return new ServerCallContext(user, state, requestedExtensions);
} else {
CallContextFactory builder = callContextFactory.get();
return builder.build(rc);
Expand Down
32 changes: 31 additions & 1 deletion server-common/src/main/java/io/a2a/server/ServerCallContext.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package io.a2a.server;

import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

import io.a2a.server.auth.User;
Expand All @@ -10,10 +12,14 @@ public class ServerCallContext {
private final Map<Object, Object> modelConfig = new ConcurrentHashMap<>();
private final Map<String, Object> state;
private final User user;
private final Set<String> requestedExtensions;
private final Set<String> activatedExtensions;

public ServerCallContext(User user, Map<String, Object> state) {
public ServerCallContext(User user, Map<String, Object> state, Set<String> requestedExtensions) {
this.user = user;
this.state = state;
this.requestedExtensions = new HashSet<>(requestedExtensions);
this.activatedExtensions = new HashSet<>(); // Always starts empty, populated later by application code
}

public Map<String, Object> getState() {
Expand All @@ -23,4 +29,28 @@ public Map<String, Object> getState() {
public User getUser() {
return user;
}

public Set<String> getRequestedExtensions() {
return new HashSet<>(requestedExtensions);
}

public Set<String> getActivatedExtensions() {
return new HashSet<>(activatedExtensions);
}

public void activateExtension(String extensionUri) {
activatedExtensions.add(extensionUri);
}

public void deactivateExtension(String extensionUri) {
activatedExtensions.remove(extensionUri);
}

public boolean isExtensionActivated(String extensionUri) {
return activatedExtensions.contains(extensionUri);
}

public boolean isExtensionRequested(String extensionUri) {
return requestedExtensions.contains(extensionUri);
}
}
Loading