Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 113 additions & 20 deletions gemini-api/src/main/java/swiss/ameri/gemini/api/GenAi.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,23 @@
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Stream;

// todo decide if thread-safe or not, once responses are stored
import static java.util.Collections.emptyList;


/**
* Entry point for all interactions with Gemini API.
* Note that some methods store state (e.g. {@link #generateContent(GenerativeModel)} or {@link #generateContentStream(GenerativeModel)}).
* Call the {@link #close()} method to clean up the state.
* This class is thread safe.
*/
public class GenAi {
public class GenAi implements AutoCloseable {

private static final String STREAM_LINE_PREFIX = "data: ";
private static final int STREAM_LINE_PREFIX_LENGTH = STREAM_LINE_PREFIX.length();
Expand All @@ -27,6 +35,7 @@ public class GenAi {
private final String apiKey;
private final HttpClient client;
private final JsonParser jsonParser;
private final Map<UUID, GenerateContentResponse> responseById = new ConcurrentHashMap<>();

public GenAi(
String apiKey,
Expand Down Expand Up @@ -95,19 +104,47 @@ public Model getModel(String model) {
});
}

/**
 * Looks up the usage metadata of the response stored for a {@link GeneratedContent#id()}.
 *
 * @param id of the corresponding {@link GeneratedContent}
 * @return the usage metadata of the stored response, or an empty {@link Optional} if no response
 *         is stored for {@code id} or the stored response has no usage metadata
 */
public Optional<UsageMetadata> usageMetadata(UUID id) {
    GenerateContentResponse stored = responseById.get(id);
    if (stored == null) {
        // nothing recorded for this id
        return Optional.empty();
    }
    return Optional.ofNullable(stored.usageMetadata());
}

/**
 * Get the safety ratings of a {@link GeneratedContent#id()}.
 * The ratings of all candidates of the stored response are returned as one flat list.
 *
 * @param id of the corresponding {@link GeneratedContent}
 * @return the corresponding safety ratings, or an empty list if no response is stored for
 *         {@code id} or the stored response has no candidates or no ratings
 */
public List<SafetyRating> safetyRatings(UUID id) {
    GenerateContentResponse response = responseById.get(id);
    // the response is deserialized from JSON, so absent fields may show up as null
    if (response == null || response.candidates() == null) {
        return emptyList();
    }
    return response.candidates().stream()
            .map(candidate -> candidate.safetyRatings())
            // a candidate without ratings contributes nothing instead of failing the lookup
            .filter(ratings -> ratings != null)
            .flatMap(ratings -> ratings.stream())
            .toList();
}

/**
* Generates a response from Gemini API based on the given {@code model}. The response is streamed in chunks of text. The
* stream items are delivered as they arrive.
* Once the call has been completed, metadata and safety ratings can be obtained by calling
* {@link #usageMetadata(UUID)} or {@link #safetyRatings(UUID)}. If those methods are called while the stream is still
* active, the last available statistics are returned.
*
* @param model with the necessary information for Gemini API to generate content
* @return A live stream of the response, as it arrives
* @see #generateContent(GenerativeModel) which returns the whole response at once (asynchronously)
*/
public Stream<GeneratedContent> generateContentStream(GenerativeModel model) {
// todo, keep responses in the state.
// add up the usageMetadata
// store the safety ratings
return execute(() -> {
UUID uuid = UUID.randomUUID();
HttpRequest request = HttpRequest.newBuilder()
.POST(HttpRequest.BodyPublishers.ofString(
jsonParser.toJson(convert(model))
Expand All @@ -124,7 +161,9 @@ public Stream<GeneratedContent> generateContentStream(GenerativeModel model) {
.map(line -> {
try {
var gcr = jsonParser.fromJson(line.substring(STREAM_LINE_PREFIX_LENGTH), GenerateContentResponse.class);
return new GeneratedContent(gcr.candidates().get(0).content().parts().get(0).text());
// each element can just replace the previous one
this.responseById.put(uuid, gcr);
return new GeneratedContent(uuid, gcr.candidates().get(0).content().parts().get(0).text());
} catch (Exception e) {
throw new RuntimeException("Unexpected line:\n" + line, e);
}
Expand All @@ -134,6 +173,8 @@ public Stream<GeneratedContent> generateContentStream(GenerativeModel model) {

/**
* Generates a response from Gemini API based on the given {@code model}.
* Once the call has been completed, metadata and safety ratings can be obtained by calling
* {@link #usageMetadata(UUID)} or {@link #safetyRatings(UUID)}.
*
* @param model with the necessary information for Gemini API to generate content
* @return a {@link CompletableFuture} which completes once the response from Gemini API has arrived. The {@link CompletableFuture}
Expand All @@ -142,6 +183,7 @@ public Stream<GeneratedContent> generateContentStream(GenerativeModel model) {
*/
public CompletableFuture<GeneratedContent> generateContent(GenerativeModel model) {
return execute(() -> {
UUID uuid = UUID.randomUUID();
CompletableFuture<HttpResponse<String>> response = client.sendAsync(
HttpRequest.newBuilder()
.POST(HttpRequest.BodyPublishers.ofString(
Expand All @@ -156,7 +198,8 @@ public CompletableFuture<GeneratedContent> generateContent(GenerativeModel model
.thenApply(body -> {
try {
var gcr = jsonParser.fromJson(body, GenerateContentResponse.class);
return new GeneratedContent(gcr.candidates().get(0).content().parts().get(0).text());
responseById.put(uuid, gcr);
return new GeneratedContent(uuid, gcr.candidates().get(0).content().parts().get(0).text());
} catch (Exception e) {
throw new RuntimeException("Unexpected body:\n" + body, e);
}
Expand Down Expand Up @@ -230,14 +273,77 @@ private <T> T execute(ThrowingSupplier<T> supplier) {
}
}

/**
 * Clears the internal state.
 * All stored responses are removed, so later calls to {@link #usageMetadata(UUID)} return an
 * empty optional and later calls to {@link #safetyRatings(UUID)} return an empty list for ids
 * handed out before this call. Nothing else is released here.
 */
@Override
public void close() {
// drop every response recorded by the generate methods
responseById.clear();
}

/**
 * Content generated by Gemini API, tagged with the id of the request that produced it.
 *
 * @param id   the id of the request, for subsequent queries regarding metadata of the query
 * @param text the generated text
 */
public record GeneratedContent(UUID id, String text) {
}

/**
 * Token usage reported for a given request.
 *
 * @param promptTokenCount     Number of tokens in the prompt.
 * @param candidatesTokenCount Total number of tokens for the generated response.
 * @param totalTokenCount      Total token count for the generation request (prompt + candidates).
 */
public record UsageMetadata(int promptTokenCount, int candidatesTokenCount, int totalTokenCount) {
}

/**
 * Safety rating for a given response, with both values kept as plain strings.
 *
 * @param category    The category for this rating. see {@link swiss.ameri.gemini.api.SafetySetting.HarmCategory}
 * @param probability The probability of harm for this content. see {@link swiss.ameri.gemini.api.SafetySetting.HarmProbability}
 */
public record SafetyRating(String category, String probability) {

    /**
     * Convert the safety rating to a typed safety rating by resolving both strings to their
     * enum constants by name.
     * Might crash if Gemini API changes, and an enum value is missing.
     *
     * @return the TypedSafetyRating
     * @throws IllegalArgumentException if either string has no enum constant of that name
     */
    public TypedSafetyRating toTypedSafetyRating() {
        SafetySetting.HarmCategory typedCategory = SafetySetting.HarmCategory.valueOf(category);
        SafetySetting.HarmProbability typedProbability = SafetySetting.HarmProbability.valueOf(probability);
        return new TypedSafetyRating(typedCategory, typedProbability);
    }

    /**
     * Typed values. This is done separately, since enum values might be missing compared to Gemini API
     *
     * @param harmCategory of this rating
     * @param probability  of this rating
     */
    public record TypedSafetyRating(SafetySetting.HarmCategory harmCategory, SafetySetting.HarmProbability probability) {
    }

}

private record GenerateContentResponse(
UsageMetadata usageMetadata,
List<ResponseCandidate> candidates
Expand All @@ -252,19 +358,6 @@ private record ResponseCandidate(
) {
}

private record SafetyRating(
String category,
String probability
) {
}

private record UsageMetadata(
int promptTokenCount,
int candidatesTokenCount,
int totalTokenCount
) {
}

private record GenerateContentRequest(
List<GenerationContent> contents,
List<SafetySetting> safetySettings,
Expand Down
33 changes: 33 additions & 0 deletions gemini-api/src/main/java/swiss/ameri/gemini/api/SafetySetting.java
Original file line number Diff line number Diff line change
Expand Up @@ -99,4 +99,37 @@ public enum HarmBlockThreshold {
BLOCK_NONE
}

/**
 * The probability that a piece of content is harmful.
 * The classification system gives the probability of the content being unsafe.
 * This does not indicate the severity of harm for a piece of content.
 * <p>
 * Constants are resolved by name ({@code valueOf}) when a {@code GenAi.SafetyRating} is converted
 * to its typed form, so each constant name has to match the rating's probability string exactly.
 */
public enum HarmProbability {

/**
 * Probability is unspecified.
 */
HARM_PROBABILITY_UNSPECIFIED,

/**
 * Content has a negligible chance of being unsafe.
 */
NEGLIGIBLE,

/**
 * Content has a low chance of being unsafe.
 */
LOW,

/**
 * Content has a medium chance of being unsafe.
 */
MEDIUM,

/**
 * Content has a high chance of being unsafe.
 */
HIGH
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,17 @@ public static void main(String[] args) throws Exception {
JsonParser parser = new GsonJsonParser();
String apiKey = args[0];

GenAi genAi = new GenAi(
apiKey,
parser
);
try (var genAi = new GenAi(apiKey, parser)) {
// each method represents an example usage
listModels(genAi);
getModel(genAi);
generateContent(genAi);
generateContentStream(genAi);
multiChatTurn(genAi);
textAndImage(genAi);
}


// each method represents an example usage
listModels(genAi);
getModel(genAi);
generateContent(genAi);
generateContentStream(genAi);
multiChatTurn(genAi);
textAndImage(genAi);
}

private static void multiChatTurn(GenAi genAi) {
Expand All @@ -66,17 +65,27 @@ private static void multiChatTurn(GenAi genAi) {
}

private static void generateContentStream(GenAi genAi) {
System.out.println("----- Generate content (streaming)");
System.out.println("----- Generate content (streaming) -- with usage meta data");
var model = createStoryModel();
genAi.generateContentStream(model)
.forEach(System.out::println);
.forEach(x -> {
System.out.println(x);
// note that the usage metadata is updated as it arrives
System.out.println(genAi.usageMetadata(x.id()));
System.out.println(genAi.safetyRatings(x.id()));
});
}

private static void generateContent(GenAi genAi) throws InterruptedException, ExecutionException, TimeoutException {
var model = createStoryModel();
System.out.println("----- Generate content (blocking)");
genAi.generateContent(model)
.thenAccept(System.out::println)
.thenAccept(gcr -> {
System.out.println(gcr);
System.out.println("----- Generate content (blocking) usage meta data & safety ratings");
System.out.println(genAi.usageMetadata(gcr.id()));
System.out.println(genAi.safetyRatings(gcr.id()).stream().map(GenAi.SafetyRating::toTypedSafetyRating).toList());
})
.get(20, TimeUnit.SECONDS);
}

Expand Down