Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions client/base/src/main/java/io/a2a/client/ClientBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ private ClientTransport buildClientTransport() throws A2AClientException {

// Get the transport provider associated to the protocol
ClientTransportProvider clientTransportProvider = transportProviderRegistry.get(agentInterface.transport());
if (clientTransportProvider == null) {
throw new A2AClientException("No client available for " + agentInterface.transport());
}
Class<? extends ClientTransport> transportProtocolClass = clientTransportProvider.getTransportProtocolClass();

// Retrieve the configuration associated to the preferred transport
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,33 +48,20 @@ public static A2AClientException mapRestError(String body, int code) {
}

public static A2AClientException mapRestError(String className, String errorMessage, int code) {
switch (className) {
case "io.a2a.spec.TaskNotFoundError":
return new A2AClientException(errorMessage, new TaskNotFoundError());
case "io.a2a.spec.AuthenticatedExtendedCardNotConfiguredError":
return new A2AClientException(errorMessage, new AuthenticatedExtendedCardNotConfiguredError());
case "io.a2a.spec.ContentTypeNotSupportedError":
return new A2AClientException(errorMessage, new ContentTypeNotSupportedError(null, null, errorMessage));
case "io.a2a.spec.InternalError":
return new A2AClientException(errorMessage, new InternalError(errorMessage));
case "io.a2a.spec.InvalidAgentResponseError":
return new A2AClientException(errorMessage, new InvalidAgentResponseError(null, null, errorMessage));
case "io.a2a.spec.InvalidParamsError":
return new A2AClientException(errorMessage, new InvalidParamsError());
case "io.a2a.spec.InvalidRequestError":
return new A2AClientException(errorMessage, new InvalidRequestError());
case "io.a2a.spec.JSONParseError":
return new A2AClientException(errorMessage, new JSONParseError());
case "io.a2a.spec.MethodNotFoundError":
return new A2AClientException(errorMessage, new MethodNotFoundError());
case "io.a2a.spec.PushNotificationNotSupportedError":
return new A2AClientException(errorMessage, new PushNotificationNotSupportedError());
case "io.a2a.spec.TaskNotCancelableError":
return new A2AClientException(errorMessage, new TaskNotCancelableError());
case "io.a2a.spec.UnsupportedOperationError":
return new A2AClientException(errorMessage, new UnsupportedOperationError());
default:
return new A2AClientException(errorMessage);
}
return switch (className) {
case "io.a2a.spec.TaskNotFoundError" -> new A2AClientException(errorMessage, new TaskNotFoundError());
case "io.a2a.spec.AuthenticatedExtendedCardNotConfiguredError" -> new A2AClientException(errorMessage, new AuthenticatedExtendedCardNotConfiguredError());
case "io.a2a.spec.ContentTypeNotSupportedError" -> new A2AClientException(errorMessage, new ContentTypeNotSupportedError(null, null, errorMessage));
case "io.a2a.spec.InternalError" -> new A2AClientException(errorMessage, new InternalError(errorMessage));
case "io.a2a.spec.InvalidAgentResponseError" -> new A2AClientException(errorMessage, new InvalidAgentResponseError(null, null, errorMessage));
case "io.a2a.spec.InvalidParamsError" -> new A2AClientException(errorMessage, new InvalidParamsError());
case "io.a2a.spec.InvalidRequestError" -> new A2AClientException(errorMessage, new InvalidRequestError());
case "io.a2a.spec.JSONParseError" -> new A2AClientException(errorMessage, new JSONParseError());
case "io.a2a.spec.MethodNotFoundError" -> new A2AClientException(errorMessage, new MethodNotFoundError());
case "io.a2a.spec.PushNotificationNotSupportedError" -> new A2AClientException(errorMessage, new PushNotificationNotSupportedError());
case "io.a2a.spec.TaskNotCancelableError" -> new A2AClientException(errorMessage, new TaskNotCancelableError());
case "io.a2a.spec.UnsupportedOperationError" -> new A2AClientException(errorMessage, new UnsupportedOperationError());
default -> new A2AClientException(errorMessage);
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,40 +38,38 @@
import io.a2a.spec.SetTaskPushNotificationConfigRequest;
import io.a2a.util.Utils;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.logging.Logger;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import org.jspecify.annotations.Nullable;

public class RestTransport implements ClientTransport {

private static final Logger log = Logger.getLogger(RestTransport.class.getName());
private final A2AHttpClient httpClient;
private final String agentUrl;
private final List<ClientCallInterceptor> interceptors;
private @Nullable final List<ClientCallInterceptor> interceptors;
private AgentCard agentCard;
private boolean needsExtendedCard = false;

public RestTransport(String agentUrl) {
this(null, null, agentUrl, null);
}

public RestTransport(AgentCard agentCard) {
this(null, agentCard, agentCard.url(), null);
}

public RestTransport(A2AHttpClient httpClient, AgentCard agentCard,
String agentUrl, List<ClientCallInterceptor> interceptors) {
public RestTransport(@Nullable A2AHttpClient httpClient, AgentCard agentCard,
String agentUrl, @Nullable List<ClientCallInterceptor> interceptors) {
this.httpClient = httpClient == null ? new JdkA2AHttpClient() : httpClient;
this.agentCard = agentCard;
this.agentUrl = agentUrl.endsWith("/") ? agentUrl.substring(0, agentUrl.length() - 1) : agentUrl;
this.interceptors = interceptors;
}

@Override
public EventKind sendMessage(MessageSendParams messageSendParams, ClientCallContext context) throws A2AClientException {
public EventKind sendMessage(MessageSendParams messageSendParams, @Nullable ClientCallContext context) throws A2AClientException {
checkNotNullParam("messageSendParams", messageSendParams);
io.a2a.grpc.SendMessageRequest.Builder builder = io.a2a.grpc.SendMessageRequest.newBuilder(ProtoUtils.ToProto.sendMessageRequest(messageSendParams));
PayloadAndHeaders payloadAndHeaders = applyInterceptors(io.a2a.spec.SendMessageRequest.METHOD, builder, agentCard, context);
Expand All @@ -94,7 +92,7 @@ public EventKind sendMessage(MessageSendParams messageSendParams, ClientCallCont
}

@Override
public void sendMessageStreaming(MessageSendParams messageSendParams, Consumer<StreamingEventKind> eventConsumer, Consumer<Throwable> errorConsumer, ClientCallContext context) throws A2AClientException {
public void sendMessageStreaming(MessageSendParams messageSendParams, Consumer<StreamingEventKind> eventConsumer, Consumer<Throwable> errorConsumer, @Nullable ClientCallContext context) throws A2AClientException {
checkNotNullParam("request", messageSendParams);
checkNotNullParam("eventConsumer", eventConsumer);
checkNotNullParam("messageSendParams", messageSendParams);
Expand All @@ -119,7 +117,7 @@ public void sendMessageStreaming(MessageSendParams messageSendParams, Consumer<S
}

@Override
public Task getTask(TaskQueryParams taskQueryParams, ClientCallContext context) throws A2AClientException {
public Task getTask(TaskQueryParams taskQueryParams, @Nullable ClientCallContext context) throws A2AClientException {
checkNotNullParam("taskQueryParams", taskQueryParams);
GetTaskRequest.Builder builder = GetTaskRequest.newBuilder();
builder.setName("tasks/" + taskQueryParams.id());
Expand Down Expand Up @@ -154,7 +152,7 @@ public Task getTask(TaskQueryParams taskQueryParams, ClientCallContext context)
}

@Override
public Task cancelTask(TaskIdParams taskIdParams, ClientCallContext context) throws A2AClientException {
public Task cancelTask(TaskIdParams taskIdParams, @Nullable ClientCallContext context) throws A2AClientException {
checkNotNullParam("taskIdParams", taskIdParams);
CancelTaskRequest.Builder builder = CancelTaskRequest.newBuilder();
builder.setName("tasks/" + taskIdParams.id());
Expand All @@ -173,7 +171,7 @@ public Task cancelTask(TaskIdParams taskIdParams, ClientCallContext context) thr
}

@Override
public TaskPushNotificationConfig setTaskPushNotificationConfiguration(TaskPushNotificationConfig request, ClientCallContext context) throws A2AClientException {
public TaskPushNotificationConfig setTaskPushNotificationConfiguration(TaskPushNotificationConfig request, @Nullable ClientCallContext context) throws A2AClientException {
checkNotNullParam("request", request);
CreateTaskPushNotificationConfigRequest.Builder builder = CreateTaskPushNotificationConfigRequest.newBuilder();
builder.setConfig(ProtoUtils.ToProto.taskPushNotificationConfig(request))
Expand All @@ -195,7 +193,7 @@ public TaskPushNotificationConfig setTaskPushNotificationConfiguration(TaskPushN
}

@Override
public TaskPushNotificationConfig getTaskPushNotificationConfiguration(GetTaskPushNotificationConfigParams request, ClientCallContext context) throws A2AClientException {
public TaskPushNotificationConfig getTaskPushNotificationConfiguration(GetTaskPushNotificationConfigParams request, @Nullable ClientCallContext context) throws A2AClientException {
checkNotNullParam("request", request);
GetTaskPushNotificationConfigRequest.Builder builder = GetTaskPushNotificationConfigRequest.newBuilder();
builder.setName(String.format("/tasks/%1s/pushNotificationConfigs/%2s", request.id(), request.pushNotificationConfigId()));
Expand Down Expand Up @@ -225,7 +223,7 @@ public TaskPushNotificationConfig getTaskPushNotificationConfiguration(GetTaskPu
}

@Override
public List<TaskPushNotificationConfig> listTaskPushNotificationConfigurations(ListTaskPushNotificationConfigParams request, ClientCallContext context) throws A2AClientException {
public List<TaskPushNotificationConfig> listTaskPushNotificationConfigurations(ListTaskPushNotificationConfigParams request, @Nullable ClientCallContext context) throws A2AClientException {
checkNotNullParam("request", request);
ListTaskPushNotificationConfigRequest.Builder builder = ListTaskPushNotificationConfigRequest.newBuilder();
builder.setParent(String.format("/tasks/%1s/pushNotificationConfigs", request.id()));
Expand Down Expand Up @@ -255,7 +253,7 @@ public List<TaskPushNotificationConfig> listTaskPushNotificationConfigurations(L
}

@Override
public void deleteTaskPushNotificationConfigurations(DeleteTaskPushNotificationConfigParams request, ClientCallContext context) throws A2AClientException {
public void deleteTaskPushNotificationConfigurations(DeleteTaskPushNotificationConfigParams request, @Nullable ClientCallContext context) throws A2AClientException {
checkNotNullParam("request", request);
io.a2a.grpc.DeleteTaskPushNotificationConfigRequestOrBuilder builder = io.a2a.grpc.DeleteTaskPushNotificationConfigRequest.newBuilder();
PayloadAndHeaders payloadAndHeaders = applyInterceptors(io.a2a.spec.DeleteTaskPushNotificationConfigRequest.METHOD, builder,
Expand All @@ -281,7 +279,7 @@ public void deleteTaskPushNotificationConfigurations(DeleteTaskPushNotificationC

@Override
public void resubscribe(TaskIdParams request, Consumer<StreamingEventKind> eventConsumer,
Consumer<Throwable> errorConsumer, ClientCallContext context) throws A2AClientException {
Consumer<Throwable> errorConsumer, @Nullable ClientCallContext context) throws A2AClientException {
checkNotNullParam("request", request);
io.a2a.grpc.TaskSubscriptionRequest.Builder builder = io.a2a.grpc.TaskSubscriptionRequest.newBuilder();
builder.setName("tasks/" + request.id());
Expand All @@ -306,7 +304,7 @@ public void resubscribe(TaskIdParams request, Consumer<StreamingEventKind> event
}

@Override
public AgentCard getAgentCard(ClientCallContext context) throws A2AClientException {
public AgentCard getAgentCard(@Nullable ClientCallContext context) throws A2AClientException {
A2ACardResolver resolver;
try {
if (agentCard == null) {
Expand Down Expand Up @@ -346,8 +344,8 @@ public void close() {
// no-op
}

private PayloadAndHeaders applyInterceptors(String methodName, MessageOrBuilder payload,
AgentCard agentCard, ClientCallContext clientCallContext) {
private PayloadAndHeaders applyInterceptors(String methodName, @Nullable MessageOrBuilder payload,
AgentCard agentCard, @Nullable ClientCallContext clientCallContext) {
PayloadAndHeaders payloadAndHeaders = new PayloadAndHeaders(payload, getHttpHeaders(clientCallContext));
if (interceptors != null && !interceptors.isEmpty()) {
for (ClientCallInterceptor interceptor : interceptors) {
Expand Down Expand Up @@ -383,7 +381,7 @@ private A2AHttpClient.PostBuilder createPostBuilder(String url, PayloadAndHeader
return postBuilder;
}

private Map<String, String> getHttpHeaders(ClientCallContext context) {
return context != null ? context.getHeaders() : null;
private Map<String, String> getHttpHeaders(@Nullable ClientCallContext context) {
return context != null ? context.getHeaders() : Collections.emptyMap();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

import io.a2a.client.http.A2AHttpClient;
import io.a2a.client.transport.spi.ClientTransportConfig;
import org.jspecify.annotations.Nullable;

public class RestTransportConfig extends ClientTransportConfig<RestTransport> {

private final A2AHttpClient httpClient;
private final @Nullable A2AHttpClient httpClient;

public RestTransportConfig() {
this.httpClient = null;
Expand All @@ -15,7 +16,7 @@ public RestTransportConfig(A2AHttpClient httpClient) {
this.httpClient = httpClient;
}

public A2AHttpClient getHttpClient() {
public @Nullable A2AHttpClient getHttpClient() {
return httpClient;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
import io.a2a.client.http.A2AHttpClient;
import io.a2a.client.http.JdkA2AHttpClient;
import io.a2a.client.transport.spi.ClientTransportConfigBuilder;
import org.jspecify.annotations.Nullable;

public class RestTransportConfigBuilder extends ClientTransportConfigBuilder<RestTransportConfig, RestTransportConfigBuilder> {

private A2AHttpClient httpClient;
private @Nullable A2AHttpClient httpClient;

public RestTransportConfigBuilder httpClient(A2AHttpClient httpClient) {
this.httpClient = httpClient;

return this;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
@NullMarked
package io.a2a.client.transport.rest;

import org.jspecify.annotations.NullMarked;

Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import io.a2a.grpc.StreamResponse;
import io.a2a.grpc.utils.ProtoUtils;
import io.a2a.spec.StreamingEventKind;
import org.jspecify.annotations.Nullable;

public class RestSSEEventListener {

Expand All @@ -28,43 +29,42 @@ public RestSSEEventListener(Consumer<StreamingEventKind> eventHandler,
this.errorHandler = errorHandler;
}

public void onMessage(String message, Future<Void> completableFuture) {
public void onMessage(String message, @Nullable Future<Void> completableFuture) {
try {
System.out.println("Streaming message received: " + message);
log.fine("Streaming message received: " + message);
io.a2a.grpc.StreamResponse.Builder builder = io.a2a.grpc.StreamResponse.newBuilder();
JsonFormat.parser().merge(message, builder);
handleMessage(builder.build(), completableFuture);
handleMessage(builder.build());
} catch (InvalidProtocolBufferException e) {
errorHandler.accept(RestErrorMapper.mapRestError(message, 500));
}
}

public void onError(Throwable throwable, Future<Void> future) {
public void onError(Throwable throwable, @Nullable Future<Void> future) {
if (errorHandler != null) {
errorHandler.accept(throwable);
}
future.cancel(true); // close SSE channel
if (future != null) {
future.cancel(true); // close SSE channel
}
}

private void handleMessage(StreamResponse response, Future<Void> future) {
private void handleMessage(StreamResponse response) {
StreamingEventKind event;
switch (response.getPayloadCase()) {
case MSG:
case MSG ->
event = ProtoUtils.FromProto.message(response.getMsg());
break;
case TASK:
case TASK ->
event = ProtoUtils.FromProto.task(response.getTask());
break;
case STATUS_UPDATE:
case STATUS_UPDATE ->
event = ProtoUtils.FromProto.taskStatusUpdateEvent(response.getStatusUpdate());
break;
case ARTIFACT_UPDATE:
case ARTIFACT_UPDATE ->
event = ProtoUtils.FromProto.taskArtifactUpdateEvent(response.getArtifactUpdate());
break;
default:
default -> {
log.warning("Invalid stream response " + response.getPayloadCase());
errorHandler.accept(new IllegalStateException("Invalid stream response from server: " + response.getPayloadCase()));
return;
}
}
eventHandler.accept(event);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
@NullMarked
package io.a2a.client.transport.rest.sse;

import org.jspecify.annotations.NullMarked;

Loading