Skip to content

Commit e18cc20

Browse files
authored
fix: unbind manual affinity keys after terminal calls (#232)
* fix: unbind manual affinity keys after terminal calls * run fomat
1 parent ad0b4ee commit e18cc20

3 files changed

Lines changed: 189 additions & 2 deletions

File tree

grpc-gcp/src/main/java/com/google/cloud/grpc/GcpClientCall.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import io.grpc.MethodDescriptor;
2828
import io.grpc.Status;
2929
import java.util.ArrayDeque;
30+
import java.util.Collections;
3031
import java.util.List;
3132
import java.util.Queue;
3233
import java.util.concurrent.atomic.AtomicBoolean;
@@ -222,17 +223,24 @@ public void onMessage(RespT message) {
222223
*/
223224
public static class SimpleGcpClientCall<ReqT, RespT> extends ForwardingClientCall<ReqT, RespT> {
224225

226+
private final GcpManagedChannel delegateChannel;
225227
private final GcpManagedChannel.ChannelRef channelRef;
226228
private final ClientCall<ReqT, RespT> delegateCall;
229+
@Nullable private final String affinityKey;
230+
private final boolean unbindOnComplete;
227231
private long startNanos = 0;
228232

229233
private final AtomicBoolean decremented = new AtomicBoolean(false);
230234

231235
protected SimpleGcpClientCall(
236+
GcpManagedChannel delegateChannel,
232237
GcpManagedChannel.ChannelRef channelRef,
233238
MethodDescriptor<ReqT, RespT> methodDescriptor,
234239
CallOptions callOptions) {
240+
this.delegateChannel = delegateChannel;
235241
this.channelRef = channelRef;
242+
this.affinityKey = callOptions.getOption(GcpManagedChannel.AFFINITY_KEY);
243+
this.unbindOnComplete = callOptions.getOption(GcpManagedChannel.UNBIND_AFFINITY_KEY);
236244
// Set the actual channel ID in callOptions so downstream interceptors can access it.
237245
CallOptions callOptionsWithChannelId =
238246
callOptions.withOption(GcpManagedChannel.CHANNEL_ID_KEY, channelRef.getId());
@@ -257,6 +265,12 @@ public void onClose(Status status, Metadata trailers) {
257265
if (!decremented.getAndSet(true)) {
258266
channelRef.activeStreamsCountDecr(startNanos, status, false);
259267
}
268+
// Unbind the affinity key when the caller explicitly requests it
269+
// (e.g., on terminal RPCs like Commit or Rollback) to prevent
270+
// unbounded growth of the affinity map.
271+
if (unbindOnComplete && affinityKey != null) {
272+
delegateChannel.unbind(Collections.singletonList(affinityKey));
273+
}
260274
super.onClose(status, trailers);
261275
}
262276

@@ -276,6 +290,10 @@ public void cancel(String message, Throwable cause) {
276290
if (!decremented.getAndSet(true)) {
277291
channelRef.activeStreamsCountDecr(startNanos, Status.CANCELLED, true);
278292
}
293+
// Always unbind on cancel — the transaction is being abandoned.
294+
if (affinityKey != null) {
295+
delegateChannel.unbind(Collections.singletonList(affinityKey));
296+
}
279297
delegateCall.cancel(message, cause);
280298
}
281299
}

grpc-gcp/src/main/java/com/google/cloud/grpc/GcpManagedChannel.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,10 @@ public class GcpManagedChannel extends ManagedChannel {
9494
public static final Context.Key<String> AFFINITY_CTX_KEY = Context.key("AffinityKey");
9595
public static final CallOptions.Key<String> AFFINITY_KEY = CallOptions.Key.create("AffinityKey");
9696

97+
/** When set to true, the affinity key will be unbound after the call completes. */
98+
public static final CallOptions.Key<Boolean> UNBIND_AFFINITY_KEY =
99+
CallOptions.Key.createWithDefault("UnbindAffinityKey", false);
100+
97101
/**
98102
* CallOptions key that will be set by grpc-gcp with the actual channel ID used for the call. This
99103
* can be read by downstream interceptors to get the real channel ID after channel selection.
@@ -1848,7 +1852,7 @@ public <ReqT, RespT> ClientCall<ReqT, RespT> newCall(
18481852
logger.finest(log("Channel affinity is disabled via context or call options."));
18491853
}
18501854
return new GcpClientCall.SimpleGcpClientCall<>(
1851-
getChannelRef(null), methodDescriptor, callOptions);
1855+
this, getChannelRef(null), methodDescriptor, callOptions);
18521856
}
18531857

18541858
AffinityConfig affinity = methodToAffinity.get(methodDescriptor.getFullMethodName());
@@ -1858,7 +1862,7 @@ public <ReqT, RespT> ClientCall<ReqT, RespT> newCall(
18581862
}
18591863

18601864
return new GcpClientCall.SimpleGcpClientCall<>(
1861-
getChannelRef(key), methodDescriptor, callOptions);
1865+
this, getChannelRef(key), methodDescriptor, callOptions);
18621866
}
18631867

18641868
@Nullable
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
/*
2+
* Copyright 2026 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.cloud.grpc;
18+
19+
import static com.google.common.truth.Truth.assertThat;
20+
import static org.mockito.ArgumentMatchers.any;
21+
import static org.mockito.ArgumentMatchers.anyBoolean;
22+
import static org.mockito.ArgumentMatchers.eq;
23+
import static org.mockito.Mockito.verify;
24+
import static org.mockito.Mockito.when;
25+
26+
import io.grpc.CallOptions;
27+
import io.grpc.ClientCall;
28+
import io.grpc.ConnectivityState;
29+
import io.grpc.ManagedChannel;
30+
import io.grpc.ManagedChannelBuilder;
31+
import io.grpc.Metadata;
32+
import io.grpc.MethodDescriptor;
33+
import io.grpc.Status;
34+
import java.io.InputStream;
35+
import java.util.Collections;
36+
import org.junit.After;
37+
import org.junit.Before;
38+
import org.junit.Test;
39+
import org.junit.runner.RunWith;
40+
import org.mockito.ArgumentCaptor;
41+
import org.mockito.Mock;
42+
import org.mockito.junit.MockitoJUnitRunner;
43+
44+
@RunWith(MockitoJUnitRunner.class)
45+
public final class GcpClientCallTest {
46+
47+
private static final class FakeMarshaller<T> implements MethodDescriptor.Marshaller<T> {
48+
@Override
49+
public InputStream stream(T value) {
50+
return null;
51+
}
52+
53+
@Override
54+
public T parse(InputStream stream) {
55+
return null;
56+
}
57+
}
58+
59+
private static final MethodDescriptor<String, String> METHOD_DESCRIPTOR =
60+
MethodDescriptor.<String, String>newBuilder()
61+
.setType(MethodDescriptor.MethodType.UNARY)
62+
.setFullMethodName("test/method")
63+
.setRequestMarshaller(new FakeMarshaller<>())
64+
.setResponseMarshaller(new FakeMarshaller<>())
65+
.build();
66+
67+
@Mock private ManagedChannel delegateChannel;
68+
@Mock private ClientCall<String, String> delegateCall;
69+
70+
private GcpManagedChannel gcpChannel;
71+
private GcpManagedChannel.ChannelRef channelRef;
72+
73+
@Before
74+
public void setUp() {
75+
ManagedChannelBuilder<?> builder = ManagedChannelBuilder.forAddress("localhost", 443);
76+
gcpChannel = (GcpManagedChannel) GcpManagedChannelBuilder.forDelegateBuilder(builder).build();
77+
78+
when(delegateChannel.getState(anyBoolean())).thenReturn(ConnectivityState.IDLE);
79+
when(delegateChannel.newCall(eq(METHOD_DESCRIPTOR), any(CallOptions.class)))
80+
.thenReturn(delegateCall);
81+
82+
channelRef = gcpChannel.new ChannelRef(delegateChannel);
83+
}
84+
85+
@After
86+
public void tearDown() {
87+
gcpChannel.shutdownNow();
88+
}
89+
90+
@SuppressWarnings("unchecked")
91+
@Test
92+
public void simpleCallUnbindsAffinityKeyOnCloseWhenRequested() {
93+
String affinityKey = "txn-1";
94+
gcpChannel.bind(channelRef, Collections.singletonList(affinityKey));
95+
96+
GcpClientCall.SimpleGcpClientCall<String, String> call =
97+
new GcpClientCall.SimpleGcpClientCall<>(
98+
gcpChannel,
99+
channelRef,
100+
METHOD_DESCRIPTOR,
101+
CallOptions.DEFAULT
102+
.withOption(GcpManagedChannel.AFFINITY_KEY, affinityKey)
103+
.withOption(GcpManagedChannel.UNBIND_AFFINITY_KEY, true));
104+
105+
call.start(new ClientCall.Listener<String>() {}, new Metadata());
106+
107+
ArgumentCaptor<ClientCall.Listener<String>> listenerCaptor =
108+
(ArgumentCaptor<ClientCall.Listener<String>>)
109+
(ArgumentCaptor<?>) ArgumentCaptor.forClass(ClientCall.Listener.class);
110+
verify(delegateCall).start(listenerCaptor.capture(), any(Metadata.class));
111+
112+
assertThat(gcpChannel.affinityKeyToChannelRef).containsKey(affinityKey);
113+
114+
listenerCaptor.getValue().onClose(Status.OK, new Metadata());
115+
116+
assertThat(gcpChannel.affinityKeyToChannelRef).doesNotContainKey(affinityKey);
117+
assertThat(channelRef.getAffinityCount()).isEqualTo(0);
118+
}
119+
120+
@SuppressWarnings("unchecked")
121+
@Test
122+
public void simpleCallKeepsAffinityKeyOnCloseWhenUnbindNotRequested() {
123+
String affinityKey = "txn-2";
124+
gcpChannel.bind(channelRef, Collections.singletonList(affinityKey));
125+
126+
GcpClientCall.SimpleGcpClientCall<String, String> call =
127+
new GcpClientCall.SimpleGcpClientCall<>(
128+
gcpChannel,
129+
channelRef,
130+
METHOD_DESCRIPTOR,
131+
CallOptions.DEFAULT.withOption(GcpManagedChannel.AFFINITY_KEY, affinityKey));
132+
133+
call.start(new ClientCall.Listener<String>() {}, new Metadata());
134+
135+
ArgumentCaptor<ClientCall.Listener<String>> listenerCaptor =
136+
(ArgumentCaptor<ClientCall.Listener<String>>)
137+
(ArgumentCaptor<?>) ArgumentCaptor.forClass(ClientCall.Listener.class);
138+
verify(delegateCall).start(listenerCaptor.capture(), any(Metadata.class));
139+
140+
listenerCaptor.getValue().onClose(Status.OK, new Metadata());
141+
142+
assertThat(gcpChannel.affinityKeyToChannelRef).containsEntry(affinityKey, channelRef);
143+
assertThat(channelRef.getAffinityCount()).isEqualTo(1);
144+
}
145+
146+
@Test
147+
public void simpleCallUnbindsAffinityKeyOnCancel() {
148+
String affinityKey = "txn-3";
149+
gcpChannel.bind(channelRef, Collections.singletonList(affinityKey));
150+
151+
GcpClientCall.SimpleGcpClientCall<String, String> call =
152+
new GcpClientCall.SimpleGcpClientCall<>(
153+
gcpChannel,
154+
channelRef,
155+
METHOD_DESCRIPTOR,
156+
CallOptions.DEFAULT.withOption(GcpManagedChannel.AFFINITY_KEY, affinityKey));
157+
158+
call.start(new ClientCall.Listener<String>() {}, new Metadata());
159+
call.cancel("cancelled", null);
160+
161+
assertThat(gcpChannel.affinityKeyToChannelRef).doesNotContainKey(affinityKey);
162+
assertThat(channelRef.getAffinityCount()).isEqualTo(0);
163+
verify(delegateCall).cancel("cancelled", null);
164+
}
165+
}

0 commit comments

Comments
 (0)