|
2 | 2 |
|
3 | 3 | import com.coze.loop.exception.AuthException; |
4 | 4 | import com.coze.loop.exception.ErrorCode; |
| 5 | +import com.coze.loop.http.HttpClient; |
| 6 | +import com.coze.loop.internal.JsonUtils; |
5 | 7 | import com.coze.loop.internal.ValidationUtils; |
| 8 | +import com.fasterxml.jackson.annotation.JsonProperty; |
6 | 9 | import io.jsonwebtoken.Jwts; |
7 | 10 | import io.jsonwebtoken.SignatureAlgorithm; |
| 11 | +import okhttp3.MediaType; |
| 12 | +import okhttp3.Request; |
| 13 | +import okhttp3.RequestBody; |
8 | 14 | import org.slf4j.Logger; |
9 | 15 | import org.slf4j.LoggerFactory; |
10 | 16 |
|
|
13 | 19 | import java.security.spec.PKCS8EncodedKeySpec; |
14 | 20 | import java.util.Base64; |
15 | 21 | import java.util.Date; |
| 22 | +import java.util.HashMap; |
| 23 | +import java.util.Map; |
| 24 | +import java.util.UUID; |
16 | 25 | import java.util.concurrent.locks.ReentrantReadWriteLock; |
17 | 26 |
|
18 | 27 | /** |
|
22 | 31 | public class JWTOAuthAuth implements Auth { |
23 | 32 | private static final Logger logger = LoggerFactory.getLogger(JWTOAuthAuth.class); |
24 | 33 | private static final String AUTH_TYPE = "Bearer"; |
25 | | - private static final long TOKEN_EXPIRY_MINUTES = 55; // 55 minutes, JWT typically expires in 1 hour |
26 | | - private static final long REFRESH_BUFFER_MINUTES = 5; // Refresh 5 minutes before expiry |
| 34 | + private static final String GRANT_TYPE_JWT = "urn:ietf:params:oauth:grant-type:jwt-bearer"; |
| 35 | + private static final String TOKEN_PATH = "/api/permission/oauth2/token"; |
| 36 | + private static final long REFRESH_BUFFER_MS = 5 * 60 * 1000; // 5 minutes |
| 37 | + private static final long DEFAULT_JWT_TTL_MS = 15 * 60 * 1000; // 15 minutes |
27 | 38 |
|
28 | 39 | private final String clientId; |
29 | 40 | private final PrivateKey privateKey; |
30 | 41 | private final String publicKeyId; |
| 42 | + private String baseUrl; |
| 43 | + private HttpClient httpClient; |
31 | 44 |
|
32 | 45 | private volatile String currentToken; |
33 | 46 | private volatile long tokenExpiryTime; |
34 | 47 |
|
35 | 48 | private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock(); |
36 | 49 |
|
37 | 50 | /** |
38 | | - * Create a JWTOAuthAuth instance. |
| 51 | + * Create a JWTOAuthAuth instance with default base URL. |
39 | 52 | * |
40 | 53 | * @param clientId the client ID |
41 | 54 | * @param privateKeyPem the private key in PEM format (base64 encoded PKCS8) |
42 | 55 | * @param publicKeyId the public key ID |
43 | 56 | */ |
44 | 57 | public JWTOAuthAuth(String clientId, String privateKeyPem, String publicKeyId) { |
| 58 | + this(clientId, privateKeyPem, publicKeyId, "https://api.coze.cn"); |
| 59 | + } |
| 60 | + |
| 61 | + /** |
| 62 | + * Create a JWTOAuthAuth instance with custom base URL. |
| 63 | + * |
| 64 | + * @param clientId the client ID |
| 65 | + * @param privateKeyPem the private key in PEM format (base64 encoded PKCS8) |
| 66 | + * @param publicKeyId the public key ID |
| 67 | + * @param baseUrl the base URL for token refresh |
| 68 | + */ |
| 69 | + public JWTOAuthAuth(String clientId, String privateKeyPem, String publicKeyId, String baseUrl) { |
45 | 70 | ValidationUtils.requireNonEmpty(clientId, "clientId"); |
46 | 71 | ValidationUtils.requireNonEmpty(privateKeyPem, "privateKeyPem"); |
47 | 72 | ValidationUtils.requireNonEmpty(publicKeyId, "publicKeyId"); |
48 | 73 |
|
49 | 74 | this.clientId = clientId; |
50 | 75 | this.publicKeyId = publicKeyId; |
51 | 76 | this.privateKey = parsePrivateKey(privateKeyPem); |
52 | | - |
53 | | - // Generate initial token |
54 | | - refreshToken(); |
| 77 | + this.baseUrl = baseUrl != null ? baseUrl : "https://api.coze.cn"; |
| 78 | + this.httpClient = new HttpClient(null); |
| 79 | + } |
| 80 | + |
| 81 | + /** |
| 82 | + * Set the base URL for token refresh. |
| 83 | + * |
| 84 | + * @param baseUrl the base URL |
| 85 | + */ |
| 86 | + public void setBaseUrl(String baseUrl) { |
| 87 | + this.baseUrl = baseUrl; |
| 88 | + } |
| 89 | + |
| 90 | + /** |
| 91 | + * Set the HttpClient for token refresh. |
| 92 | + * |
| 93 | + * @param httpClient the HttpClient |
| 94 | + */ |
| 95 | + public void setHttpClient(HttpClient httpClient) { |
| 96 | + this.httpClient = httpClient; |
55 | 97 | } |
56 | 98 |
|
57 | 99 | @Override |
@@ -91,35 +133,65 @@ public String getType() { |
91 | 133 | * Check if token should be refreshed. |
92 | 134 | */ |
93 | 135 | private boolean shouldRefreshToken(long currentTime) { |
94 | | - return currentTime >= (tokenExpiryTime - REFRESH_BUFFER_MINUTES * 60 * 1000); |
| 136 | + return currentToken == null || currentTime >= (tokenExpiryTime - REFRESH_BUFFER_MS); |
95 | 137 | } |
96 | 138 |
|
97 | 139 | /** |
98 | | - * Refresh the JWT token. |
| 140 | + * Refresh the OAuth token by exchanging a JWT for an access token. |
99 | 141 | * Must be called with write lock held. |
100 | 142 | */ |
101 | 143 | private void refreshToken() { |
102 | 144 | try { |
103 | 145 | long currentTime = System.currentTimeMillis(); |
104 | | - long expiryTime = currentTime + TOKEN_EXPIRY_MINUTES * 60 * 1000; |
| 146 | + String host = new java.net.URL(baseUrl).getHost(); |
105 | 147 |
|
| 148 | + // 1. Generate JWT |
106 | 149 | String jwt = Jwts.builder() |
107 | 150 | .setIssuer(clientId) |
108 | | - .setAudience("coze") |
| 151 | + .setAudience(host) |
109 | 152 | .setIssuedAt(new Date(currentTime)) |
110 | | - .setExpiration(new Date(expiryTime)) |
| 153 | + .setExpiration(new Date(currentTime + DEFAULT_JWT_TTL_MS)) |
111 | 154 | .setHeaderParam("kid", publicKeyId) |
| 155 | + .setHeaderParam("typ", "JWT") |
| 156 | + .setHeaderParam("alg", "RS256") |
| 157 | + .setId(UUID.randomUUID().toString()) |
112 | 158 | .signWith(privateKey, SignatureAlgorithm.RS256) |
113 | 159 | .compact(); |
114 | 160 |
|
115 | | - this.currentToken = jwt; |
116 | | - this.tokenExpiryTime = expiryTime; |
| 161 | + // 2. Exchange JWT for Access Token using HttpClient |
| 162 | + Map<String, Object> body = new HashMap<>(); |
| 163 | + body.put("client_id", clientId); |
| 164 | + body.put("grant_type", GRANT_TYPE_JWT); |
| 165 | + |
| 166 | + Map<String, String> headers = new HashMap<>(); |
| 167 | + headers.put("Authorization", "Bearer " + jwt); |
| 168 | + |
| 169 | + String responseJson = httpClient.post(baseUrl + TOKEN_PATH, body, headers); |
| 170 | + TokenResponse resp = JsonUtils.fromJson(responseJson, TokenResponse.class); |
117 | 171 |
|
118 | | - logger.debug("JWT token refreshed, expires at: {}", new Date(expiryTime)); |
| 172 | + if (resp == null || resp.accessToken == null) { |
| 173 | + throw new AuthException(ErrorCode.AUTH_FAILED, "Invalid response from token endpoint"); |
| 174 | + } |
| 175 | + |
| 176 | + this.currentToken = resp.accessToken; |
| 177 | + this.tokenExpiryTime = System.currentTimeMillis() + resp.expiresIn * 1000; |
| 178 | + |
| 179 | + logger.debug("OAuth token refreshed, expires at: {}", new Date(this.tokenExpiryTime)); |
119 | 180 | } catch (Exception e) { |
120 | | - throw new AuthException(ErrorCode.AUTH_FAILED, "Failed to generate JWT token", e); |
| 181 | + throw new AuthException(ErrorCode.AUTH_FAILED, "Failed to refresh OAuth token", e); |
121 | 182 | } |
122 | 183 | } |
| 184 | + |
| 185 | + private static class TokenResponse { |
| 186 | + @JsonProperty("access_token") |
| 187 | + private String accessToken; |
| 188 | + |
| 189 | + @JsonProperty("expires_in") |
| 190 | + private long expiresIn; |
| 191 | + |
| 192 | + @JsonProperty("refresh_token") |
| 193 | + private String refreshToken; |
| 194 | + } |
123 | 195 |
|
124 | 196 | /** |
125 | 197 | * Parse private key from PEM format. |
|
0 commit comments