Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import io.a2a.client.transport.spi.interceptors.auth.AuthInterceptor;
import io.a2a.common.A2AHeaders;
import io.a2a.grpc.A2AServiceGrpc;
import io.a2a.grpc.A2AServiceGrpc.A2AServiceBlockingV2Stub;
import io.a2a.grpc.A2AServiceGrpc.A2AServiceStub;
import io.a2a.grpc.CancelTaskRequest;
import io.a2a.grpc.CreateTaskPushNotificationConfigRequest;
import io.a2a.grpc.DeleteTaskPushNotificationConfigRequest;
Expand All @@ -28,7 +30,8 @@
import io.a2a.grpc.SendMessageResponse;
import io.a2a.grpc.StreamResponse;
import io.a2a.grpc.TaskSubscriptionRequest;

import io.a2a.grpc.utils.ProtoUtils.FromProto;
import io.a2a.grpc.utils.ProtoUtils.ToProto;
import io.a2a.spec.A2AClientException;
import io.a2a.spec.AgentCard;
import io.a2a.spec.DeleteTaskPushNotificationConfigParams;
Expand All @@ -49,6 +52,7 @@
import io.grpc.StatusRuntimeException;
import io.grpc.stub.MetadataUtils;
import io.grpc.stub.StreamObserver;
import org.jspecify.annotations.Nullable;

public class GrpcTransport implements ClientTransport {

Expand All @@ -60,23 +64,24 @@ public class GrpcTransport implements ClientTransport {
Metadata.ASCII_STRING_MARSHALLER);
private final A2AServiceBlockingV2Stub blockingStub;
private final A2AServiceStub asyncStub;
private final List<ClientCallInterceptor> interceptors;
private final @Nullable List<ClientCallInterceptor> interceptors;
private AgentCard agentCard;

public GrpcTransport(Channel channel, AgentCard agentCard) {
this(channel, agentCard, null);
}

public GrpcTransport(Channel channel, AgentCard agentCard, List<ClientCallInterceptor> interceptors) {
public GrpcTransport(Channel channel, AgentCard agentCard, @Nullable List<ClientCallInterceptor> interceptors) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For improved robustness and to align with the non-null contract of the agentCard field (as per @NullMarked on the package), it's good practice to validate the agentCard parameter for nullity in this public constructor. This ensures that the GrpcTransport is always instantiated in a valid state.

Consider adding checkNotNullParam("agentCard", agentCard); at the start of the constructor body.

checkNotNullParam("channel", channel);
checkNotNullParam("agentCard", agentCard);
this.asyncStub = A2AServiceGrpc.newStub(channel);
this.blockingStub = A2AServiceGrpc.newBlockingV2Stub(channel);
this.agentCard = agentCard;
this.interceptors = interceptors;
}

@Override
public EventKind sendMessage(MessageSendParams request, ClientCallContext context) throws A2AClientException {
public EventKind sendMessage(MessageSendParams request, @Nullable ClientCallContext context) throws A2AClientException {
checkNotNullParam("request", request);

SendMessageRequest sendMessageRequest = createGrpcSendMessageRequest(request, context);
Expand All @@ -100,7 +105,7 @@ public EventKind sendMessage(MessageSendParams request, ClientCallContext contex

@Override
public void sendMessageStreaming(MessageSendParams request, Consumer<StreamingEventKind> eventConsumer,
Consumer<Throwable> errorConsumer, ClientCallContext context) throws A2AClientException {
Consumer<Throwable> errorConsumer, @Nullable ClientCallContext context) throws A2AClientException {
checkNotNullParam("request", request);
checkNotNullParam("eventConsumer", eventConsumer);
SendMessageRequest grpcRequest = createGrpcSendMessageRequest(request, context);
Expand All @@ -117,7 +122,7 @@ public void sendMessageStreaming(MessageSendParams request, Consumer<StreamingEv
}

@Override
public Task getTask(TaskQueryParams request, ClientCallContext context) throws A2AClientException {
public Task getTask(TaskQueryParams request, @Nullable ClientCallContext context) throws A2AClientException {
checkNotNullParam("request", request);

GetTaskRequest.Builder requestBuilder = GetTaskRequest.newBuilder();
Expand All @@ -138,7 +143,7 @@ public Task getTask(TaskQueryParams request, ClientCallContext context) throws A
}

@Override
public Task cancelTask(TaskIdParams request, ClientCallContext context) throws A2AClientException {
public Task cancelTask(TaskIdParams request, @Nullable ClientCallContext context) throws A2AClientException {
checkNotNullParam("request", request);

CancelTaskRequest cancelTaskRequest = CancelTaskRequest.newBuilder()
Expand All @@ -157,7 +162,7 @@ public Task cancelTask(TaskIdParams request, ClientCallContext context) throws A

@Override
public TaskPushNotificationConfig setTaskPushNotificationConfiguration(TaskPushNotificationConfig request,
ClientCallContext context) throws A2AClientException {
@Nullable ClientCallContext context) throws A2AClientException {
checkNotNullParam("request", request);

String configId = request.pushNotificationConfig().id();
Expand All @@ -180,7 +185,7 @@ public TaskPushNotificationConfig setTaskPushNotificationConfiguration(TaskPushN
@Override
public TaskPushNotificationConfig getTaskPushNotificationConfiguration(
GetTaskPushNotificationConfigParams request,
ClientCallContext context) throws A2AClientException {
@Nullable ClientCallContext context) throws A2AClientException {
checkNotNullParam("request", request);

GetTaskPushNotificationConfigRequest grpcRequest = GetTaskPushNotificationConfigRequest.newBuilder()
Expand All @@ -200,7 +205,7 @@ public TaskPushNotificationConfig getTaskPushNotificationConfiguration(
@Override
public List<TaskPushNotificationConfig> listTaskPushNotificationConfigurations(
ListTaskPushNotificationConfigParams request,
ClientCallContext context) throws A2AClientException {
@Nullable ClientCallContext context) throws A2AClientException {
checkNotNullParam("request", request);

ListTaskPushNotificationConfigRequest grpcRequest = ListTaskPushNotificationConfigRequest.newBuilder()
Expand All @@ -221,7 +226,7 @@ public List<TaskPushNotificationConfig> listTaskPushNotificationConfigurations(

@Override
public void deleteTaskPushNotificationConfigurations(DeleteTaskPushNotificationConfigParams request,
ClientCallContext context) throws A2AClientException {
@Nullable ClientCallContext context) throws A2AClientException {
checkNotNullParam("request", request);

DeleteTaskPushNotificationConfigRequest grpcRequest = DeleteTaskPushNotificationConfigRequest.newBuilder()
Expand All @@ -240,7 +245,7 @@ public void deleteTaskPushNotificationConfigurations(DeleteTaskPushNotificationC

@Override
public void resubscribe(TaskIdParams request, Consumer<StreamingEventKind> eventConsumer,
Consumer<Throwable> errorConsumer, ClientCallContext context) throws A2AClientException {
Consumer<Throwable> errorConsumer, @Nullable ClientCallContext context) throws A2AClientException {
checkNotNullParam("request", request);
checkNotNullParam("eventConsumer", eventConsumer);

Expand All @@ -261,7 +266,7 @@ public void resubscribe(TaskIdParams request, Consumer<StreamingEventKind> event
}

@Override
public AgentCard getAgentCard(ClientCallContext context) throws A2AClientException {
public AgentCard getAgentCard(@Nullable ClientCallContext context) throws A2AClientException {
// TODO: Determine how to handle retrieving the authenticated extended agent card
return agentCard;
}
Expand All @@ -270,7 +275,7 @@ public AgentCard getAgentCard(ClientCallContext context) throws A2AClientExcepti
public void close() {
}

private SendMessageRequest createGrpcSendMessageRequest(MessageSendParams messageSendParams, ClientCallContext context) {
private SendMessageRequest createGrpcSendMessageRequest(MessageSendParams messageSendParams, @Nullable ClientCallContext context) {
SendMessageRequest.Builder builder = SendMessageRequest.newBuilder();
builder.setRequest(ToProto.message(messageSendParams.message()));
if (messageSendParams.configuration() != null) {
Expand All @@ -285,8 +290,11 @@ private SendMessageRequest createGrpcSendMessageRequest(MessageSendParams messag
/**
* Creates gRPC metadata from ClientCallContext headers.
* Extracts headers like X-A2A-Extensions and sets them as gRPC metadata.
* @param context the client call context containing headers, may be null
* @param payloadAndHeaders the payload and headers wrapper, may be null
* @return the gRPC metadata
*/
private Metadata createGrpcMetadata(ClientCallContext context, PayloadAndHeaders payloadAndHeaders) {
private Metadata createGrpcMetadata(@Nullable ClientCallContext context, @Nullable PayloadAndHeaders payloadAndHeaders) {
Metadata metadata = new Metadata();

if (context != null && context.getHeaders() != null) {
Expand Down Expand Up @@ -328,7 +336,7 @@ private Metadata createGrpcMetadata(ClientCallContext context, PayloadAndHeaders
* @param payloadAndHeaders the payloadAndHeaders after applying any interceptors
* @return blocking stub with metadata interceptor
*/
private A2AServiceBlockingV2Stub createBlockingStubWithMetadata(ClientCallContext context,
private A2AServiceBlockingV2Stub createBlockingStubWithMetadata(@Nullable ClientCallContext context,
PayloadAndHeaders payloadAndHeaders) {
Metadata metadata = createGrpcMetadata(context, payloadAndHeaders);
return blockingStub.withInterceptors(MetadataUtils.newAttachHeadersInterceptor(metadata));
Expand All @@ -341,7 +349,7 @@ private A2AServiceBlockingV2Stub createBlockingStubWithMetadata(ClientCallContex
* @param payloadAndHeaders the payloadAndHeaders after applying any interceptors
* @return async stub with metadata interceptor
*/
private A2AServiceStub createAsyncStubWithMetadata(ClientCallContext context,
private A2AServiceStub createAsyncStubWithMetadata(@Nullable ClientCallContext context,
PayloadAndHeaders payloadAndHeaders) {
Metadata metadata = createGrpcMetadata(context, payloadAndHeaders);
return asyncStub.withInterceptors(MetadataUtils.newAttachHeadersInterceptor(metadata));
Expand All @@ -351,7 +359,7 @@ private String getTaskPushNotificationConfigName(GetTaskPushNotificationConfigPa
return getTaskPushNotificationConfigName(params.id(), params.pushNotificationConfigId());
}

private String getTaskPushNotificationConfigName(String taskId, String pushNotificationConfigId) {
private String getTaskPushNotificationConfigName(String taskId, @Nullable String pushNotificationConfigId) {
StringBuilder name = new StringBuilder();
name.append("tasks/");
name.append(taskId);
Expand All @@ -366,7 +374,7 @@ private String getTaskPushNotificationConfigName(String taskId, String pushNotif
}

private PayloadAndHeaders applyInterceptors(String methodName, Object payload,
AgentCard agentCard, ClientCallContext clientCallContext) {
AgentCard agentCard, @Nullable ClientCallContext clientCallContext) {
PayloadAndHeaders payloadAndHeaders = new PayloadAndHeaders(payload,
clientCallContext != null ? clientCallContext.getHeaders() : null);
if (interceptors != null && ! interceptors.isEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@

import java.util.function.Function;

import org.jspecify.annotations.Nullable;

public class GrpcTransportConfigBuilder extends ClientTransportConfigBuilder<GrpcTransportConfig, GrpcTransportConfigBuilder> {

private Function<String, Channel> channelFactory;
private @Nullable Function<String, Channel> channelFactory;

public GrpcTransportConfigBuilder channelFactory(Function<String, Channel> channelFactory) {
Assert.checkNotNullParam("channelFactory", channelFactory);
Expand All @@ -20,6 +22,9 @@ public GrpcTransportConfigBuilder channelFactory(Function<String, Channel> chann

@Override
public GrpcTransportConfig build() {
if (channelFactory == null) {
throw new IllegalStateException("channelFactory must be set");
}
GrpcTransportConfig config = new GrpcTransportConfig(channelFactory);
config.setInterceptors(interceptors);
return config;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
@NullMarked
package io.a2a.client.transport.grpc;

import org.jspecify.annotations.NullMarked;
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ public class PayloadAndHeaders {
private final @Nullable Object payload;
private final Map<String, String> headers;

public PayloadAndHeaders(@Nullable Object payload, Map<String, String> headers) {
public PayloadAndHeaders(@Nullable Object payload, @Nullable Map<String, String> headers) {
this.payload = payload;
this.headers = headers == null ? Collections.emptyMap() : new HashMap<>(headers);
}
Expand Down