Skip to content

Commit 3e0d68d

Browse files
authored
Merge pull request auth0#167 from auth0/add-keyprovider-kid
Refactor KeyProvider to receive the "Key Id"
2 parents 5c9bcc5 + 11e96ad commit 3e0d68d

25 files changed

+1023
-362
lines changed

README.md

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,65 @@ The library implements JWT Verification and Signing using the following algorith
4646

4747
## Usage
4848

49+
### Pick the Algorithm
50+
51+
The Algorithm defines how a token is signed and verified. It can be instantiated with the raw value of the secret in the case of HMAC algorithms, or the key pairs or `KeyProvider` in the case of RSA and ECDSA algorithms. Once created, the instance is reusable for token signing and verification operations.
52+
53+
#### Using static secrets or keys:
54+
55+
```java
56+
//HMAC
57+
Algorithm algorithmHS = Algorithm.HMAC256("secret");
58+
59+
//RSA
60+
RSAPublicKey publicKey = //Get the key instance
61+
RSAPrivateKey privateKey = //Get the key instance
62+
Algorithm algorithmRS = Algorithm.RSA256(publicKey, privateKey);
63+
```
64+
65+
#### Using a KeyProvider:
66+
67+
By using a `KeyProvider` you can change in runtime the key used either to verify the token signature or to sign a new token for RSA or ECDSA algorithms. This is achieved by implementing either `RSAKeyProvider` or `ECDSAKeyProvider` methods:
68+
69+
- `getPublicKeyById(String kid)`: Its called during token signature verification and it should return the key used to verify the token. If key rotation is being used, e.g. [JWK](https://tools.ietf.org/html/rfc7517) it can fetch the correct rotation key using the id. (Or just return the same key all the time).
70+
- `getPrivateKey()`: Its called during token signing and it should return the key that will be used to sign the JWT.
71+
- `getPrivateKeyId()`: Its called during token signing and it should return the id of the key that identifies the one returned by `getPrivateKey()`. This value is preferred over the one set in the `JWTCreator.Builder#withKeyId(String)` method. If you don't need to set a `kid` value avoid instantiating an Algorithm using a `KeyProvider`.
72+
73+
74+
The following snippet uses example classes showing how this would work:
75+
76+
77+
```java
78+
final JwkStore jwkStore = new JwkStore("{JWKS_FILE_HOST}");
79+
final RSAPrivateKey privateKey = //Get the key instance
80+
final String privateKeyId = //Create an Id for the above key
81+
82+
RSAKeyProvider keyProvider = new RSAKeyProvider() {
83+
@Override
84+
public RSAPublicKey getPublicKeyById(String kid) {
85+
//Received 'kid' value might be null if it wasn't defined in the Token's header
86+
RSAPublicKey publicKey = jwkStore.get(kid);
87+
return (RSAPublicKey) publicKey;
88+
}
89+
90+
@Override
91+
public RSAPrivateKey getPrivateKey() {
92+
return privateKey;
93+
}
94+
95+
@Override
96+
public String getPrivateKeyId() {
97+
return privateKeyId;
98+
}
99+
};
100+
101+
Algorithm algorithm = Algorithm.RSA256(keyProvider);
102+
//Use the Algorithm to create and verify JWTs.
103+
```
104+
105+
> For simple key rotation using JWKs try the [jwks-rsa-java](https://github.com/auth0/jwks-rsa-java) library.
106+
107+
49108
### Create and Sign a Token
50109

51110
You'll first need to create a `JWTCreator` instance by calling `JWT.create()`. Use the builder to define the custom Claims your token needs to have. Finally to get the String token call `sign()` and pass the `Algorithm` instance.
@@ -220,7 +279,7 @@ When creating a Token with the `JWT.create()` you can specify header Claims by c
220279

221280
```java
222281
Map<String, Object> headerClaims = new HashMap();
223-
headerclaims.put("owner", "auth0");
282+
headerClaims.put("owner", "auth0");
224283
String token = JWT.create()
225284
.withHeader(headerClaims)
226285
.sign(algorithm);

lib/src/main/java/com/auth0/jwt/JWT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import com.auth0.jwt.interfaces.Verification;
77

88
@SuppressWarnings("WeakerAccess")
9-
public abstract class JWT implements DecodedJWT {
9+
public abstract class JWT {
1010

1111
/**
1212
* Decode a given Json Web Token.

lib/src/main/java/com/auth0/jwt/JWTCreator.java

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ public Builder withHeader(Map<String, Object> headerClaims) {
7777

7878
/**
7979
* Add a specific Key Id ("kid") claim to the Header.
80+
* If the {@link Algorithm} used to sign this token was instantiated with a KeyProvider, the 'kid' value will be taken from that provider and this one will be ignored.
8081
*
8182
* @param keyId the Key Id value.
8283
* @return this same Builder instance.
@@ -303,6 +304,10 @@ public String sign(Algorithm algorithm) throws IllegalArgumentException, JWTCrea
303304
}
304305
headerClaims.put(PublicClaims.ALGORITHM, algorithm.getName());
305306
headerClaims.put(PublicClaims.TYPE, "JWT");
307+
String signingKeyId = algorithm.getSigningKeyId();
308+
if (signingKeyId != null) {
309+
withKeyId(signingKeyId);
310+
}
306311
return new JWTCreator(algorithm, headerClaims, payloadClaims).sign();
307312
}
308313

@@ -322,8 +327,8 @@ private void addClaim(String name, Object value) {
322327
}
323328

324329
private String sign() throws SignatureGenerationException {
325-
String header = Base64.encodeBase64URLSafeString((headerJson.getBytes(StandardCharsets.UTF_8)));
326-
String payload = Base64.encodeBase64URLSafeString((payloadJson.getBytes(StandardCharsets.UTF_8)));
330+
String header = Base64.encodeBase64URLSafeString(headerJson.getBytes(StandardCharsets.UTF_8));
331+
String payload = Base64.encodeBase64URLSafeString(payloadJson.getBytes(StandardCharsets.UTF_8));
327332
String content = String.format("%s.%s", header, payload);
328333

329334
byte[] signatureBytes = algorithm.sign(content.getBytes(StandardCharsets.UTF_8));

lib/src/main/java/com/auth0/jwt/JWTDecoder.java

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import com.auth0.jwt.exceptions.JWTDecodeException;
44
import com.auth0.jwt.impl.JWTParser;
55
import com.auth0.jwt.interfaces.Claim;
6+
import com.auth0.jwt.interfaces.DecodedJWT;
67
import com.auth0.jwt.interfaces.Header;
78
import com.auth0.jwt.interfaces.Payload;
89
import org.apache.commons.codec.binary.Base64;
@@ -16,20 +17,14 @@
1617
* The JWTDecoder class holds the decode method to parse a given JWT token into it's JWT representation.
1718
*/
1819
@SuppressWarnings("WeakerAccess")
19-
final class JWTDecoder extends JWT {
20+
final class JWTDecoder implements DecodedJWT {
2021

21-
private final String token;
22-
private Header header;
23-
private Payload payload;
24-
private String signature;
22+
private final String[] parts;
23+
private final Header header;
24+
private final Payload payload;
2525

2626
JWTDecoder(String jwt) throws JWTDecodeException {
27-
this.token = jwt;
28-
parseToken(jwt);
29-
}
30-
31-
private void parseToken(String token) throws JWTDecodeException {
32-
final String[] parts = TokenUtils.splitToken(token);
27+
parts = TokenUtils.splitToken(jwt);
3328
final JWTParser converter = new JWTParser();
3429
String headerJson;
3530
String payloadJson;
@@ -41,7 +36,6 @@ private void parseToken(String token) throws JWTDecodeException {
4136
}
4237
header = converter.parseHeader(headerJson);
4338
payload = converter.parsePayload(payloadJson);
44-
signature = parts[2];
4539
}
4640

4741
@Override
@@ -114,13 +108,23 @@ public Map<String, Claim> getClaims() {
114108
return payload.getClaims();
115109
}
116110

111+
@Override
112+
public String getHeader() {
113+
return parts[0];
114+
}
115+
116+
@Override
117+
public String getPayload() {
118+
return parts[1];
119+
}
120+
117121
@Override
118122
public String getSignature() {
119-
return signature;
123+
return parts[2];
120124
}
121125

122126
@Override
123127
public String getToken() {
124-
return token;
128+
return String.format("%s.%s.%s", parts[0], parts[1], parts[2]);
125129
}
126130
}

lib/src/main/java/com/auth0/jwt/JWTVerifier.java

Lines changed: 30 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,13 @@
11
package com.auth0.jwt;
22

33
import com.auth0.jwt.algorithms.Algorithm;
4-
import com.auth0.jwt.exceptions.AlgorithmMismatchException;
5-
import com.auth0.jwt.exceptions.InvalidClaimException;
6-
import com.auth0.jwt.exceptions.JWTVerificationException;
7-
import com.auth0.jwt.exceptions.SignatureVerificationException;
8-
import com.auth0.jwt.exceptions.TokenExpiredException;
4+
import com.auth0.jwt.exceptions.*;
95
import com.auth0.jwt.impl.PublicClaims;
106
import com.auth0.jwt.interfaces.Claim;
117
import com.auth0.jwt.interfaces.Clock;
128
import com.auth0.jwt.interfaces.DecodedJWT;
139
import com.auth0.jwt.interfaces.Verification;
14-
import org.apache.commons.codec.binary.Base64;
1510

16-
import java.nio.charset.StandardCharsets;
1711
import java.util.*;
1812

1913
/**
@@ -349,29 +343,26 @@ private void requireClaim(String name, Object value) {
349343
*
350344
* @param token to verify.
351345
* @return a verified and decoded JWT.
352-
* @throws JWTVerificationException if any of the required contents inside the JWT is invalid.
346+
* @throws AlgorithmMismatchException if the algorithm stated in the token's header it's not equal to the one defined in the {@link JWTVerifier}.
347+
* @throws SignatureVerificationException if the signature is invalid.
348+
* @throws TokenExpiredException if the token has expired.
349+
* @throws InvalidClaimException if a claim contained a different value than the expected one.
353350
*/
354351
public DecodedJWT verify(String token) throws JWTVerificationException {
355-
DecodedJWT jwt = JWTDecoder.decode(token);
352+
DecodedJWT jwt = JWT.decode(token);
356353
verifyAlgorithm(jwt, algorithm);
357-
verifySignature(TokenUtils.splitToken(token));
354+
algorithm.verify(jwt);
358355
verifyClaims(jwt, claims);
359356
return jwt;
360357
}
361358

362-
private void verifySignature(String[] parts) throws SignatureVerificationException {
363-
byte[] content = String.format("%s.%s", parts[0], parts[1]).getBytes(StandardCharsets.UTF_8);
364-
byte[] signature = Base64.decodeBase64(parts[2]);
365-
algorithm.verify(content, signature);
366-
}
367-
368359
private void verifyAlgorithm(DecodedJWT jwt, Algorithm expectedAlgorithm) throws AlgorithmMismatchException {
369360
if (!expectedAlgorithm.getName().equals(jwt.getAlgorithm())) {
370361
throw new AlgorithmMismatchException("The provided Algorithm doesn't match the one defined in the JWT's Header.");
371362
}
372363
}
373364

374-
private void verifyClaims(DecodedJWT jwt, Map<String, Object> claims) {
365+
private void verifyClaims(DecodedJWT jwt, Map<String, Object> claims) throws TokenExpiredException, InvalidClaimException {
375366
for (Map.Entry<String, Object> entry : claims.entrySet()) {
376367
switch (entry.getKey()) {
377368
case PublicClaims.AUDIENCE:
@@ -435,31 +426,28 @@ private void assertValidStringClaim(String claimName, String value, String expec
435426
}
436427

437428
private void assertValidDateClaim(Date date, long leeway, boolean shouldBeFuture) {
438-
Date today = clock.getToday();
439-
today.setTime((long) Math.floor((today.getTime() / 1000) * 1000)); // truncate
440-
// millis
441-
if (shouldBeFuture) {
442-
assertDateIsFuture(date, leeway, today);
443-
} else {
444-
assertDateIsPast(date, leeway, today);
445-
}
446-
}
447-
448-
private void assertDateIsFuture(Date date, long leeway, Date today) {
449-
450-
today.setTime(today.getTime() - leeway * 1000);
451-
if (date != null && today.after(date)) {
452-
throw new TokenExpiredException(String.format("The Token has expired on %s.", date));
453-
}
454-
}
455-
456-
private void assertDateIsPast(Date date, long leeway, Date today) {
457-
today.setTime(today.getTime() + leeway * 1000);
458-
if(date!=null && today.before(date)) {
459-
throw new InvalidClaimException(String.format("The Token can't be used before %s.", date));
460-
}
461-
462-
}
429+
Date today = clock.getToday();
430+
today.setTime((long) Math.floor((today.getTime() / 1000) * 1000)); // truncate millis
431+
if (shouldBeFuture) {
432+
assertDateIsFuture(date, leeway, today);
433+
} else {
434+
assertDateIsPast(date, leeway, today);
435+
}
436+
}
437+
438+
private void assertDateIsFuture(Date date, long leeway, Date today) {
439+
today.setTime(today.getTime() - leeway * 1000);
440+
if (date != null && today.after(date)) {
441+
throw new TokenExpiredException(String.format("The Token has expired on %s.", date));
442+
}
443+
}
444+
445+
private void assertDateIsPast(Date date, long leeway, Date today) {
446+
today.setTime(today.getTime() + leeway * 1000);
447+
if (date != null && today.before(date)) {
448+
throw new InvalidClaimException(String.format("The Token can't be used before %s.", date));
449+
}
450+
}
463451

464452
private void assertValidAudienceClaim(List<String> audience, List<String> value) {
465453
if (audience == null || !audience.containsAll(value)) {

lib/src/main/java/com/auth0/jwt/algorithms/Algorithm.java

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
import com.auth0.jwt.exceptions.SignatureGenerationException;
44
import com.auth0.jwt.exceptions.SignatureVerificationException;
5-
import com.auth0.jwt.interfaces.ECKeyProvider;
5+
import com.auth0.jwt.interfaces.DecodedJWT;
6+
import com.auth0.jwt.interfaces.ECDSAKeyProvider;
67
import com.auth0.jwt.interfaces.RSAKeyProvider;
78

89
import java.io.UnsupportedEncodingException;
@@ -207,7 +208,7 @@ public static Algorithm HMAC512(byte[] secret) throws IllegalArgumentException {
207208
* @return a valid ECDSA256 Algorithm.
208209
* @throws IllegalArgumentException if the Key Provider is null.
209210
*/
210-
public static Algorithm ECDSA256(ECKeyProvider keyProvider) throws IllegalArgumentException {
211+
public static Algorithm ECDSA256(ECDSAKeyProvider keyProvider) throws IllegalArgumentException {
211212
return new ECDSAAlgorithm("ES256", "SHA256withECDSA", 32, keyProvider);
212213
}
213214

@@ -229,7 +230,7 @@ public static Algorithm ECDSA256(ECPublicKey publicKey, ECPrivateKey privateKey)
229230
* @param key the key to use in the verify or signing instance.
230231
* @return a valid ECDSA256 Algorithm.
231232
* @throws IllegalArgumentException if the provided Key is null.
232-
* @deprecated use {@link #ECDSA256(ECPublicKey, ECPrivateKey)} or {@link #ECDSA256(ECKeyProvider)}
233+
* @deprecated use {@link #ECDSA256(ECPublicKey, ECPrivateKey)} or {@link #ECDSA256(ECDSAKeyProvider)}
233234
*/
234235
@Deprecated
235236
public static Algorithm ECDSA256(ECKey key) throws IllegalArgumentException {
@@ -245,7 +246,7 @@ public static Algorithm ECDSA256(ECKey key) throws IllegalArgumentException {
245246
* @return a valid ECDSA384 Algorithm.
246247
* @throws IllegalArgumentException if the Key Provider is null.
247248
*/
248-
public static Algorithm ECDSA384(ECKeyProvider keyProvider) throws IllegalArgumentException {
249+
public static Algorithm ECDSA384(ECDSAKeyProvider keyProvider) throws IllegalArgumentException {
249250
return new ECDSAAlgorithm("ES384", "SHA384withECDSA", 48, keyProvider);
250251
}
251252

@@ -267,7 +268,7 @@ public static Algorithm ECDSA384(ECPublicKey publicKey, ECPrivateKey privateKey)
267268
* @param key the key to use in the verify or signing instance.
268269
* @return a valid ECDSA384 Algorithm.
269270
* @throws IllegalArgumentException if the provided Key is null.
270-
* @deprecated use {@link #ECDSA384(ECPublicKey, ECPrivateKey)} or {@link #ECDSA384(ECKeyProvider)}
271+
* @deprecated use {@link #ECDSA384(ECPublicKey, ECPrivateKey)} or {@link #ECDSA384(ECDSAKeyProvider)}
271272
*/
272273
@Deprecated
273274
public static Algorithm ECDSA384(ECKey key) throws IllegalArgumentException {
@@ -283,7 +284,7 @@ public static Algorithm ECDSA384(ECKey key) throws IllegalArgumentException {
283284
* @return a valid ECDSA512 Algorithm.
284285
* @throws IllegalArgumentException if the Key Provider is null.
285286
*/
286-
public static Algorithm ECDSA512(ECKeyProvider keyProvider) throws IllegalArgumentException {
287+
public static Algorithm ECDSA512(ECDSAKeyProvider keyProvider) throws IllegalArgumentException {
287288
return new ECDSAAlgorithm("ES512", "SHA512withECDSA", 66, keyProvider);
288289
}
289290

@@ -305,7 +306,7 @@ public static Algorithm ECDSA512(ECPublicKey publicKey, ECPrivateKey privateKey)
305306
* @param key the key to use in the verify or signing instance.
306307
* @return a valid ECDSA512 Algorithm.
307308
* @throws IllegalArgumentException if the provided Key is null.
308-
* @deprecated use {@link #ECDSA512(ECPublicKey, ECPrivateKey)} or {@link #ECDSA512(ECKeyProvider)}
309+
* @deprecated use {@link #ECDSA512(ECPublicKey, ECPrivateKey)} or {@link #ECDSA512(ECDSAKeyProvider)}
309310
*/
310311
@Deprecated
311312
public static Algorithm ECDSA512(ECKey key) throws IllegalArgumentException {
@@ -324,6 +325,15 @@ protected Algorithm(String name, String description) {
324325
this.description = description;
325326
}
326327

328+
/**
329+
* Getter for the Id of the Private Key used to sign the tokens. This is usually specified as the `kid` claim in the Header.
330+
*
331+
* @return the Key Id that identifies the Signing Key or null if it's not specified.
332+
*/
333+
public String getSigningKeyId() {
334+
return null;
335+
}
336+
327337
/**
328338
* Getter for the name of this Algorithm, as defined in the JWT Standard. i.e. "HS256"
329339
*
@@ -348,13 +358,12 @@ public String toString() {
348358
}
349359

350360
/**
351-
* Verify the given content using this Algorithm instance.
361+
* Verify the given token using this Algorithm instance.
352362
*
353-
* @param contentBytes an array of bytes representing the base64 encoded content to be verified against the signature.
354-
* @param signatureBytes an array of bytes representing the base64 encoded signature to compare the content against.
363+
* @param jwt the already decoded JWT that it's going to be verified.
355364
* @throws SignatureVerificationException if the Token's Signature is invalid, meaning that it doesn't match the signatureBytes, or if the Key is invalid.
356365
*/
357-
public abstract void verify(byte[] contentBytes, byte[] signatureBytes) throws SignatureVerificationException;
366+
public abstract void verify(DecodedJWT jwt) throws SignatureVerificationException;
358367

359368
/**
360369
* Sign the given content using this Algorithm instance.

0 commit comments

Comments
 (0)