2929import static org .mockito .Mockito .when ;
3030
3131import com .google .api .core .ApiFutures ;
32+ import com .google .cloud .NoCredentials ;
33+ import com .google .cloud .grpc .GrpcTransportOptions .ExecutorFactory ;
3234import com .google .cloud .spanner .SessionClient .SessionConsumer ;
3335import java .io .PrintWriter ;
3436import java .io .StringWriter ;
3739import java .time .Duration ;
3840import java .time .Instant ;
3941import java .util .Map ;
42+ import java .util .concurrent .Executors ;
43+ import java .util .concurrent .ScheduledExecutorService ;
44+ import java .util .concurrent .TimeUnit ;
4045import java .util .concurrent .atomic .AtomicInteger ;
4146import java .util .concurrent .atomic .AtomicReference ;
4247import org .junit .After ;
@@ -309,68 +314,58 @@ public void testGrpcGcpSingleUseDoesNotReserveBitsetChannelHint() throws Excepti
309314
310315 @ Test
311316 public void testCloseRemovesChannelUsageEntryWhenLastClientCloses () throws Exception {
312- SessionClient sessionClient = createSessionClient ();
317+ try (SpannerImpl spanner = createTestSpanner ();
318+ SessionClient sessionClient = createSessionClient (spanner )) {
319+ MultiplexedSessionDatabaseClient client =
320+ new MultiplexedSessionDatabaseClient (sessionClient , Clock .systemUTC ());
313321
314- MultiplexedSessionDatabaseClient client =
315- new MultiplexedSessionDatabaseClient (sessionClient , Clock .systemUTC ());
316-
317- assertEquals (1 , getChannelUsage ().size ());
322+ assertEquals (1 , getChannelUsage ().size ());
318323
319- client .close ();
324+ client .close ();
320325
321- assertEquals (0 , getChannelUsage ().size ());
326+ assertEquals (0 , getChannelUsage ().size ());
327+ }
322328 }
323329
324330 @ Test
325331 public void testCloseKeepsChannelUsageEntryWhileAnotherClientIsUsingSameSpanner ()
326332 throws Exception {
327- SpannerImpl spanner = mock (SpannerImpl .class );
328- SessionClient firstSessionClient = createSessionClient (spanner );
329- SessionClient secondSessionClient = createSessionClient (spanner );
333+ try (SpannerImpl spanner = createTestSpanner ();
334+ SessionClient firstSessionClient = createSessionClient (spanner );
335+ SessionClient secondSessionClient = createSessionClient (spanner )) {
336+ MultiplexedSessionDatabaseClient firstClient =
337+ new MultiplexedSessionDatabaseClient (firstSessionClient , Clock .systemUTC ());
338+ MultiplexedSessionDatabaseClient secondClient =
339+ new MultiplexedSessionDatabaseClient (secondSessionClient , Clock .systemUTC ());
330340
331- MultiplexedSessionDatabaseClient firstClient =
332- new MultiplexedSessionDatabaseClient (firstSessionClient , Clock .systemUTC ());
333- MultiplexedSessionDatabaseClient secondClient =
334- new MultiplexedSessionDatabaseClient (secondSessionClient , Clock .systemUTC ());
341+ assertEquals (1 , getChannelUsage ().size ());
335342
336- assertEquals (1 , getChannelUsage ().size ());
343+ firstClient .close ();
344+ assertEquals (1 , getChannelUsage ().size ());
337345
338- firstClient .close ();
339- assertEquals (1 , getChannelUsage ().size ());
340-
341- secondClient .close ();
342- assertEquals (0 , getChannelUsage ().size ());
343- }
344-
345- private SessionClient createSessionClient () {
346- return createSessionClient (mock (SpannerImpl .class ));
346+ secondClient .close ();
347+ assertEquals (0 , getChannelUsage ().size ());
348+ }
347349 }
348350
349351 private SessionClient createSessionClient (SpannerImpl spanner ) {
350- SessionClient sessionClient = mock (SessionClient .class );
351- SpannerOptions spannerOptions = mock (SpannerOptions .class );
352- SessionPoolOptions sessionPoolOptions = mock (SessionPoolOptions .class );
352+ return new FailingMultiplexedSessionClient (spanner );
353+ }
353354
354- when (sessionClient .getSpanner ()).thenReturn (spanner );
355- when (spanner .getOptions ()).thenReturn (spannerOptions );
356- when (spannerOptions .getNumChannels ()).thenReturn (4 );
357- when (spannerOptions .getSessionPoolOptions ()).thenReturn (sessionPoolOptions );
358- when (sessionPoolOptions .getMultiplexedSessionMaintenanceDuration ())
359- .thenReturn (Duration .ofDays (7 ));
360- when (sessionPoolOptions .getWaitForMinSessions ()).thenReturn (Duration .ZERO );
361- doAnswer (
362- (Answer <?>)
363- invocationOnMock -> {
364- SessionConsumer consumer = invocationOnMock .getArgument (0 );
365- consumer .onSessionCreateFailure (
366- SpannerExceptionFactory .newSpannerException (
367- ErrorCode .UNAUTHENTICATED , "test" ),
368- 1 );
369- return null ;
370- })
371- .when (sessionClient )
372- .asyncCreateMultiplexedSession (any (SessionConsumer .class ));
373- return sessionClient ;
355+ private SpannerImpl createTestSpanner () {
356+ SessionPoolOptions sessionPoolOptions =
357+ SessionPoolOptions .newBuilder ()
358+ .setMultiplexedSessionMaintenanceDuration (Duration .ofDays (7 ))
359+ .setWaitForMinSessionsDuration (Duration .ZERO )
360+ .build ();
361+ SpannerOptions options =
362+ SpannerOptions .newBuilder ()
363+ .setProjectId ("test-project" )
364+ .setCredentials (NoCredentials .getInstance ())
365+ .setNumChannels (4 )
366+ .setSessionPoolOption (sessionPoolOptions )
367+ .build ();
368+ return new SpannerImpl (options );
374369 }
375370
376371 @ SuppressWarnings ("unchecked" )
@@ -391,4 +386,37 @@ private boolean isJava8() {
391386 private boolean isWindows () {
392387 return System .getProperty ("os.name" ).toLowerCase ().contains ("windows" );
393388 }
389+
390+ private static final class TestExecutorFactory
391+ implements ExecutorFactory <ScheduledExecutorService > {
392+ @ Override
393+ public ScheduledExecutorService get () {
394+ return Executors .newSingleThreadScheduledExecutor ();
395+ }
396+
397+ @ Override
398+ public void release (ScheduledExecutorService executor ) {
399+ executor .shutdown ();
400+ try {
401+ executor .awaitTermination (10L , TimeUnit .SECONDS );
402+ } catch (InterruptedException e ) {
403+ throw new RuntimeException (e );
404+ }
405+ }
406+ }
407+
408+ private static final class FailingMultiplexedSessionClient extends SessionClient {
409+ private static final DatabaseId TEST_DATABASE_ID =
410+ DatabaseId .of ("test-project" , "test-instance" , "test-database" );
411+
412+ private FailingMultiplexedSessionClient (SpannerImpl spanner ) {
413+ super (spanner , TEST_DATABASE_ID , new TestExecutorFactory ());
414+ }
415+
416+ @ Override
417+ void asyncCreateMultiplexedSession (SessionConsumer consumer ) {
418+ consumer .onSessionCreateFailure (
419+ SpannerExceptionFactory .newSpannerException (ErrorCode .UNAUTHENTICATED , "test" ), 1 );
420+ }
421+ }
394422}
0 commit comments