@@ -745,6 +745,150 @@ struct FeatherPostgresDatabaseTestSuite {
745745 }
746746 }
747747
748+ @Test
749+ func concurrentTransactionUpdates( ) async throws {
750+ try await runUsingTestDatabaseClient { database in
751+ let suffix = randomTableSuffix ( )
752+ let table = " sessions_ \( suffix) "
753+ let sessionID = " session_ \( suffix) "
754+
755+ enum TestError : Error {
756+ case missingRow
757+ }
758+
759+ try await database. execute (
760+ query: #"""
761+ DROP TABLE IF EXISTS " \#( unescaped: table) " CASCADE;
762+ """#
763+ )
764+ try await database. execute (
765+ query: #"""
766+ CREATE TABLE " \#( unescaped: table) " (
767+ "id" TEXT NOT NULL PRIMARY KEY,
768+ "access_token" TEXT NOT NULL,
769+ "access_expires_at" TIMESTAMPTZ NOT NULL,
770+ "refresh_token" TEXT NOT NULL,
771+ "refresh_count" INTEGER NOT NULL DEFAULT 0
772+ );
773+ """#
774+ )
775+
776+ // set an expired token
777+ try await database. execute (
778+ query: #"""
779+ INSERT INTO " \#( unescaped: table) "
780+ ("id", "access_token", "access_expires_at", "refresh_token", "refresh_count")
781+ VALUES
782+ (
783+ \#( sessionID) ,
784+ 'stale',
785+ NOW() - INTERVAL '5 minutes',
786+ 'refresh',
787+ 0
788+ );
789+ """#
790+ )
791+
792+ func getValidAccessToken( sessionID: String ) async throws -> String {
793+ try await database. transaction { connection in
794+ let result = try await connection. execute (
795+ query: #"""
796+ SELECT
797+ "access_token",
798+ "refresh_count",
799+ "access_expires_at" > NOW() + INTERVAL '60 seconds' AS "is_valid"
800+ FROM " \#( unescaped: table) "
801+ WHERE "id" = \#( sessionID)
802+ FOR UPDATE;
803+ """#
804+ )
805+ let rows = try await result. collect ( )
806+
807+ guard let row = rows. first else {
808+ throw TestError . missingRow
809+ }
810+
811+ let isValid = try row. decode (
812+ column: " is_valid " ,
813+ as: Bool . self
814+ )
815+ if isValid {
816+ // token was valid, must be called X times
817+ return try row. decode (
818+ column: " access_token " ,
819+ as: String . self
820+ )
821+ }
822+
823+ // refresh, this branch can only be called 1 time
824+ let refreshCount = try row. decode (
825+ column: " refresh_count " ,
826+ as: Int . self
827+ )
828+ let newRefreshCount = refreshCount + 1
829+ let newToken = " token_ \( newRefreshCount) "
830+
831+ try await Task . sleep ( for: . milliseconds( 40 ) )
832+
833+ _ = try await connection. execute (
834+ query: #"""
835+ UPDATE " \#( unescaped: table) "
836+ SET
837+ "access_token" = \#( newToken) ,
838+ "access_expires_at" = NOW() + INTERVAL '10 minutes',
839+ "refresh_count" = \#( newRefreshCount)
840+ WHERE "id" = \#( sessionID) ;
841+ """#
842+ )
843+
844+ return newToken
845+ }
846+ }
847+
848+ let workerCount = 80
849+ var tokens : [ String ] = [ ]
850+ try await withThrowingTaskGroup ( of: String . self) { group in
851+ for _ in 0 ..< workerCount {
852+ group. addTask {
853+ try await getValidAccessToken ( sessionID: sessionID)
854+ }
855+ }
856+ for try await token in group {
857+ tokens. append ( token)
858+ }
859+ }
860+
861+ #expect( Set ( tokens) . count == 1 )
862+
863+ let result =
864+ try await database. execute (
865+ query: #"""
866+ SELECT
867+ "access_token",
868+ "refresh_count",
869+ "access_expires_at" > NOW() AS "is_valid"
870+ FROM " \#( unescaped: table) "
871+ WHERE "id" = \#( sessionID) ;
872+ """#
873+ )
874+ . collect ( )
875+
876+ #expect( result. count == 1 )
877+ #expect(
878+ try result [ 0 ] . decode ( column: " refresh_count " , as: Int . self)
879+ == 1
880+ )
881+ #expect(
882+ try result [ 0 ] . decode ( column: " access_token " , as: String . self)
883+ == " token_1 "
884+ )
885+ #expect(
886+ try result [ 0 ] . decode ( column: " is_valid " , as: Bool . self)
887+ == true
888+ )
889+ }
890+ }
891+
748892 @Test
749893 func doubleRoundTrip( ) async throws {
750894 try await runUsingTestDatabaseClient { database in
0 commit comments