|
7 | 7 | type CooldownReason, |
8 | 8 | type RateLimitStateV3, |
9 | 9 | findMatchingAccountIndex, |
| 10 | + withAccountStorageTransaction, |
10 | 11 | } from "./storage.js"; |
11 | 12 | import type { AccountIdSource, OAuthAuthDetails } from "./types.js"; |
12 | 13 | import { MODEL_FAMILIES, type ModelFamily } from "./prompts/codex.js"; |
@@ -81,6 +82,81 @@ function initFamilyState(defaultValue: number): Record<ModelFamily, number> { |
81 | 82 | ) as Record<ModelFamily, number>; |
82 | 83 | } |
83 | 84 |
|
| 85 | +type AccountIdentityCandidate = Pick< |
| 86 | + ManagedAccount, |
| 87 | + "accountId" | "email" | "refreshToken" |
| 88 | +> & { |
| 89 | + index?: number; |
| 90 | +}; |
| 91 | + |
| 92 | +function getAuthIdentityCandidate( |
| 93 | + auth: OAuthAuthDetails | undefined, |
| 94 | +): AccountIdentityCandidate { |
| 95 | + const accountId = extractAccountId(auth?.access)?.trim() || undefined; |
| 96 | + const email = sanitizeEmail(extractAccountEmail(auth?.access)); |
| 97 | + return { |
| 98 | + accountId, |
| 99 | + email, |
| 100 | + refreshToken: auth?.refresh, |
| 101 | + }; |
| 102 | +} |
| 103 | + |
| 104 | +function buildAccountIdentityCandidates( |
| 105 | + source: AccountIdentityCandidate, |
| 106 | + auth?: OAuthAuthDetails, |
| 107 | +): AccountIdentityCandidate[] { |
| 108 | + const derived = getAuthIdentityCandidate(auth); |
| 109 | + const candidates: AccountIdentityCandidate[] = []; |
| 110 | + const seen = new Set<string>(); |
| 111 | + |
| 112 | + const pushCandidate = (candidate: AccountIdentityCandidate): void => { |
| 113 | + const key = `${candidate.accountId ?? ""}|${candidate.email ?? ""}|${candidate.refreshToken ?? ""}`; |
| 114 | + if (seen.has(key)) return; |
| 115 | + seen.add(key); |
| 116 | + candidates.push(candidate); |
| 117 | + }; |
| 118 | + |
| 119 | + pushCandidate(source); |
| 120 | + pushCandidate({ |
| 121 | + accountId: source.accountId ?? derived.accountId, |
| 122 | + email: source.email ?? derived.email, |
| 123 | + refreshToken: source.refreshToken, |
| 124 | + index: source.index, |
| 125 | + }); |
| 126 | + pushCandidate({ |
| 127 | + accountId: derived.accountId ?? source.accountId, |
| 128 | + email: derived.email ?? source.email, |
| 129 | + refreshToken: source.refreshToken, |
| 130 | + index: source.index, |
| 131 | + }); |
| 132 | + pushCandidate({ |
| 133 | + accountId: derived.accountId ?? source.accountId, |
| 134 | + email: derived.email ?? source.email, |
| 135 | + refreshToken: derived.refreshToken ?? source.refreshToken, |
| 136 | + index: source.index, |
| 137 | + }); |
| 138 | + |
| 139 | + return candidates; |
| 140 | +} |
| 141 | + |
| 142 | +function findAccountIndexByIdentity< |
| 143 | + T extends Pick<AccountIdentityCandidate, "accountId" | "email" | "refreshToken">, |
| 144 | +>( |
| 145 | + accounts: readonly T[], |
| 146 | + source: AccountIdentityCandidate, |
| 147 | + auth?: OAuthAuthDetails, |
| 148 | +): number | undefined { |
| 149 | + for (const candidate of buildAccountIdentityCandidates(source, auth)) { |
| 150 | + const matchIndex = findMatchingAccountIndex(accounts, candidate, { |
| 151 | + allowUniqueAccountIdFallbackWithoutEmail: true, |
| 152 | + }); |
| 153 | + if (matchIndex !== undefined) { |
| 154 | + return matchIndex; |
| 155 | + } |
| 156 | + } |
| 157 | + return undefined; |
| 158 | +} |
| 159 | + |
84 | 160 | export interface Workspace { |
85 | 161 | id: string; |
86 | 162 | name?: string; |
@@ -642,6 +718,17 @@ export class AccountManager { |
642 | 718 | account.consecutiveAuthFailures = 0; |
643 | 719 | } |
644 | 720 |
|
| 721 | + getAccountByIdentity( |
| 722 | + candidate: AccountIdentityCandidate, |
| 723 | + auth?: OAuthAuthDetails, |
| 724 | + ): ManagedAccount | null { |
| 725 | + const index = findAccountIndexByIdentity(this.accounts, candidate, auth); |
| 726 | + if (index === undefined) { |
| 727 | + return null; |
| 728 | + } |
| 729 | + return this.accounts[index] ?? null; |
| 730 | + } |
| 731 | + |
645 | 732 | shouldShowAccountToast(accountIndex: number, debounceMs = 30000): boolean { |
646 | 733 | const now = nowMs(); |
647 | 734 | if (accountIndex === this.lastToastAccountIndex && now - this.lastToastTime < debounceMs) { |
@@ -670,6 +757,112 @@ export class AccountManager { |
670 | 757 | account.email = sanitizeEmail(extractAccountEmail(auth.access)) ?? account.email; |
671 | 758 | } |
672 | 759 |
|
| 760 | + private buildStorageSnapshot(): AccountStorageV3 { |
| 761 | + const activeIndexByFamily: Partial<Record<ModelFamily, number>> = {}; |
| 762 | + for (const family of MODEL_FAMILIES) { |
| 763 | + const raw = this.currentAccountIndexByFamily[family]; |
| 764 | + activeIndexByFamily[family] = clampNonNegativeInt(raw, 0); |
| 765 | + } |
| 766 | + |
| 767 | + const activeIndex = clampNonNegativeInt(activeIndexByFamily.codex, 0); |
| 768 | + |
| 769 | + return { |
| 770 | + version: 3, |
| 771 | + accounts: this.accounts.map((account) => ({ |
| 772 | + accountId: account.accountId, |
| 773 | + accountIdSource: account.accountIdSource, |
| 774 | + accountLabel: account.accountLabel, |
| 775 | + email: account.email, |
| 776 | + refreshToken: account.refreshToken, |
| 777 | + accessToken: account.access, |
| 778 | + expiresAt: account.expires, |
| 779 | + enabled: account.enabled === false ? false : undefined, |
| 780 | + addedAt: account.addedAt, |
| 781 | + lastUsed: account.lastUsed, |
| 782 | + lastSwitchReason: account.lastSwitchReason, |
| 783 | + rateLimitResetTimes: |
| 784 | + Object.keys(account.rateLimitResetTimes).length > 0 ? account.rateLimitResetTimes : undefined, |
| 785 | + coolingDownUntil: account.coolingDownUntil, |
| 786 | + cooldownReason: account.cooldownReason, |
| 787 | + workspaces: account.workspaces, |
| 788 | + currentWorkspaceIndex: account.currentWorkspaceIndex, |
| 789 | + })), |
| 790 | + activeIndex, |
| 791 | + activeIndexByFamily, |
| 792 | + }; |
| 793 | + } |
| 794 | + |
| 795 | + async commitRefreshedAuth( |
| 796 | + source: Pick< |
| 797 | + ManagedAccount, |
| 798 | + "index" | "accountId" | "email" | "refreshToken" |
| 799 | + >, |
| 800 | + auth: OAuthAuthDetails, |
| 801 | + ): Promise<ManagedAccount | null> { |
| 802 | + const nextAccountId = extractAccountId(auth.access)?.trim() || undefined; |
| 803 | + const nextEmail = sanitizeEmail(extractAccountEmail(auth.access)); |
| 804 | + |
| 805 | + await withAccountStorageTransaction(async (current, persist) => { |
| 806 | + const nextStorage = structuredClone( |
| 807 | + current ?? this.buildStorageSnapshot(), |
| 808 | + ) as AccountStorageV3; |
| 809 | + const storageIndex = findAccountIndexByIdentity( |
| 810 | + nextStorage.accounts, |
| 811 | + source, |
| 812 | + auth, |
| 813 | + ); |
| 814 | + if (storageIndex === undefined) { |
| 815 | + log.warn("Unable to resolve refreshed account for persistence", { |
| 816 | + sourceIndex: source.index, |
| 817 | + email: source.email, |
| 818 | + }); |
| 819 | + return; |
| 820 | + } |
| 821 | + |
| 822 | + const storedAccount = nextStorage.accounts[storageIndex]; |
| 823 | + if (!storedAccount) { |
| 824 | + return; |
| 825 | + } |
| 826 | + |
| 827 | + storedAccount.refreshToken = auth.refresh; |
| 828 | + storedAccount.accessToken = auth.access; |
| 829 | + storedAccount.expiresAt = auth.expires; |
| 830 | + if ( |
| 831 | + nextAccountId && |
| 832 | + shouldUpdateAccountIdFromToken( |
| 833 | + storedAccount.accountIdSource, |
| 834 | + storedAccount.accountId, |
| 835 | + ) |
| 836 | + ) { |
| 837 | + storedAccount.accountId = nextAccountId; |
| 838 | + storedAccount.accountIdSource = "token"; |
| 839 | + } |
| 840 | + if (nextEmail) { |
| 841 | + storedAccount.email = nextEmail; |
| 842 | + } |
| 843 | + storedAccount.enabled = undefined; |
| 844 | + delete storedAccount.coolingDownUntil; |
| 845 | + delete storedAccount.cooldownReason; |
| 846 | + |
| 847 | + await persist(nextStorage); |
| 848 | + }); |
| 849 | + |
| 850 | + const liveAccount = this.getAccountByIdentity(source, auth); |
| 851 | + if (!liveAccount) { |
| 852 | + log.warn("Unable to resolve refreshed live account after persistence", { |
| 853 | + sourceIndex: source.index, |
| 854 | + email: source.email, |
| 855 | + }); |
| 856 | + return null; |
| 857 | + } |
| 858 | + |
| 859 | + this.updateFromAuth(liveAccount, auth); |
| 860 | + liveAccount.enabled = true; |
| 861 | + this.clearAccountCooldown(liveAccount); |
| 862 | + this.clearAuthFailures(liveAccount); |
| 863 | + return liveAccount; |
| 864 | + } |
| 865 | + |
673 | 866 | toAuthDetails(account: ManagedAccount): Auth { |
674 | 867 | return { |
675 | 868 | type: "oauth", |
@@ -780,40 +973,7 @@ export class AccountManager { |
780 | 973 | } |
781 | 974 |
|
782 | 975 | async saveToDisk(): Promise<void> { |
783 | | - const activeIndexByFamily: Partial<Record<ModelFamily, number>> = {}; |
784 | | - for (const family of MODEL_FAMILIES) { |
785 | | - const raw = this.currentAccountIndexByFamily[family]; |
786 | | - activeIndexByFamily[family] = clampNonNegativeInt(raw, 0); |
787 | | - } |
788 | | - |
789 | | - const activeIndex = clampNonNegativeInt(activeIndexByFamily.codex, 0); |
790 | | - |
791 | | - const storage: AccountStorageV3 = { |
792 | | - version: 3, |
793 | | - accounts: this.accounts.map((account) => ({ |
794 | | - accountId: account.accountId, |
795 | | - accountIdSource: account.accountIdSource, |
796 | | - accountLabel: account.accountLabel, |
797 | | - email: account.email, |
798 | | - refreshToken: account.refreshToken, |
799 | | - accessToken: account.access, |
800 | | - expiresAt: account.expires, |
801 | | - enabled: account.enabled === false ? false : undefined, |
802 | | - addedAt: account.addedAt, |
803 | | - lastUsed: account.lastUsed, |
804 | | - lastSwitchReason: account.lastSwitchReason, |
805 | | - rateLimitResetTimes: |
806 | | - Object.keys(account.rateLimitResetTimes).length > 0 ? account.rateLimitResetTimes : undefined, |
807 | | - coolingDownUntil: account.coolingDownUntil, |
808 | | - cooldownReason: account.cooldownReason, |
809 | | - workspaces: account.workspaces, |
810 | | - currentWorkspaceIndex: account.currentWorkspaceIndex, |
811 | | - })), |
812 | | - activeIndex, |
813 | | - activeIndexByFamily, |
814 | | - }; |
815 | | - |
816 | | - await saveAccounts(storage); |
| 976 | + await saveAccounts(this.buildStorageSnapshot()); |
817 | 977 | } |
818 | 978 |
|
819 | 979 | saveToDiskDebounced(delayMs = 500): void { |
|
0 commit comments