diff --git a/packages/assets-controllers/src/TokenBalancesController.test.ts b/packages/assets-controllers/src/TokenBalancesController.test.ts index 1d722b421c..49390c64fe 100644 --- a/packages/assets-controllers/src/TokenBalancesController.test.ts +++ b/packages/assets-controllers/src/TokenBalancesController.test.ts @@ -1,8 +1,10 @@ import { ControllerMessenger } from '@metamask/base-controller'; import { toHex } from '@metamask/controller-utils'; +import type { InternalAccount } from '@metamask/keyring-api'; import BN from 'bn.js'; import { flushPromises } from '../../../tests/helpers'; +import { createMockInternalAccount } from '../../accounts-controller/src/tests/mocks'; import type { AllowedActions, AllowedEvents, @@ -31,19 +33,63 @@ function getMessenger( ): TokenBalancesControllerMessenger { return controllerMessenger.getRestricted({ name: controllerName, - allowedActions: ['PreferencesController:getState'], + allowedActions: ['AccountsController:getSelectedAccount'], allowedEvents: ['TokensController:stateChange'], }); } -describe('TokenBalancesController', () => { - let controllerMessenger: ControllerMessenger; - let messenger: TokenBalancesControllerMessenger; +const setupController = ({ + config, + mock, +}: { + config?: Partial[0]>; + mock: { + getBalanceOf?: BN; + selectedAccount: InternalAccount; + }; +}): { + controller: TokenBalancesController; + messenger: TokenBalancesControllerMessenger; + mockSelectedAccount: jest.Mock; + mockGetERC20BalanceOf: jest.Mock; + triggerTokensStateChange: (state: TokensControllerState) => Promise; +} => { + const controllerMessenger = new ControllerMessenger< + AllowedActions, + AllowedEvents + >(); + const messenger = getMessenger(controllerMessenger); + + const mockSelectedAccount = jest.fn().mockReturnValue(mock.selectedAccount); + const mockGetERC20BalanceOf = jest.fn().mockReturnValue(mock.getBalanceOf); + + controllerMessenger.registerActionHandler( + 'AccountsController:getSelectedAccount', + mockSelectedAccount, + ); + + const controller = new TokenBalancesController({ + getERC20BalanceOf: mockGetERC20BalanceOf, + messenger, + ...config, + }); + + const triggerTokensStateChange = async (state: TokensControllerState) => { + controllerMessenger.publish('TokensController:stateChange', state, []); + }; + return { + controller, + messenger, + mockSelectedAccount, + mockGetERC20BalanceOf, + triggerTokensStateChange, + }; +}; + +describe('TokenBalancesController', () => { beforeEach(() => { jest.useFakeTimers(); - controllerMessenger = new ControllerMessenger(); - messenger = getMessenger(controllerMessenger); }); afterEach(() => { @@ -51,23 +97,16 @@ describe('TokenBalancesController', () => { }); it('should set default state', () => { - controllerMessenger.registerActionHandler( - 'PreferencesController:getState', - jest.fn().mockReturnValue({ selectedAddress: '0x1234' }), - ); - const controller = new TokenBalancesController({ - getERC20BalanceOf: jest.fn(), - messenger, + const { controller } = setupController({ + mock: { + selectedAccount: createMockInternalAccount({ address: '0x1234' }), + }, }); expect(controller.state).toStrictEqual({ contractBalances: {} }); }); it('should poll and update balances in the right interval', async () => { - controllerMessenger.registerActionHandler( - 'PreferencesController:getState', - jest.fn().mockReturnValue({ selectedAddress: '0x1234' }), - ); const updateBalancesSpy = jest.spyOn( TokenBalancesController.prototype, 'updateBalances', @@ -76,7 +115,7 @@ describe('TokenBalancesController', () => { new TokenBalancesController({ interval: 10, getERC20BalanceOf: jest.fn(), - messenger, + messenger: getMessenger(new ControllerMessenger()), }); await flushPromises(); @@ -90,16 +129,16 @@ describe('TokenBalancesController', () => { it('should update balances if enabled', async () => { const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; - controllerMessenger.registerActionHandler( - 'PreferencesController:getState', - jest.fn().mockReturnValue({ selectedAddress: '0x1234' }), - ); - const controller = new TokenBalancesController({ - disabled: false, - tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }], - interval: 10, - getERC20BalanceOf: jest.fn().mockReturnValue(new BN(1)), - messenger, + const { controller } = setupController({ + config: { + disabled: false, + tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }], + interval: 10, + }, + mock: { + getBalanceOf: new BN(1), + selectedAccount: createMockInternalAccount({ address: '0x1234' }), + }, }); await controller.updateBalances(); @@ -111,16 +150,16 @@ describe('TokenBalancesController', () => { it('should not update balances if disabled', async () => { const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; - controllerMessenger.registerActionHandler( - 'PreferencesController:getState', - jest.fn().mockReturnValue({ selectedAddress: '0x1234' }), - ); - const controller = new TokenBalancesController({ - disabled: true, - tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }], - interval: 10, - getERC20BalanceOf: jest.fn().mockReturnValue(new BN(1)), - messenger, + const { controller } = setupController({ + config: { + disabled: true, + tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }], + interval: 10, + }, + mock: { + selectedAccount: createMockInternalAccount({ address: '0x1234' }), + getBalanceOf: new BN(1), + }, }); await controller.updateBalances(); @@ -130,16 +169,16 @@ describe('TokenBalancesController', () => { it('should update balances if controller is manually enabled', async () => { const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; - controllerMessenger.registerActionHandler( - 'PreferencesController:getState', - jest.fn().mockReturnValue({ selectedAddress: '0x1234' }), - ); - const controller = new TokenBalancesController({ - disabled: true, - tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }], - interval: 10, - getERC20BalanceOf: jest.fn().mockReturnValue(new BN(1)), - messenger, + const { controller } = setupController({ + config: { + disabled: true, + tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }], + interval: 10, + }, + mock: { + selectedAccount: createMockInternalAccount({ address: '0x1234' }), + getBalanceOf: new BN(1), + }, }); await controller.updateBalances(); @@ -156,16 +195,16 @@ describe('TokenBalancesController', () => { it('should not update balances if controller is manually disabled', async () => { const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; - controllerMessenger.registerActionHandler( - 'PreferencesController:getState', - jest.fn().mockReturnValue({ selectedAddress: '0x1234' }), - ); - const controller = new TokenBalancesController({ - disabled: false, - tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }], - interval: 10, - getERC20BalanceOf: jest.fn().mockReturnValue(new BN(1)), - messenger, + const { controller } = setupController({ + config: { + disabled: false, + tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }], + interval: 10, + }, + mock: { + selectedAccount: createMockInternalAccount({ address: '0x1234' }), + getBalanceOf: new BN(1), + }, }); await controller.updateBalances(); @@ -184,20 +223,17 @@ describe('TokenBalancesController', () => { it('should update balances if tokens change and controller is manually enabled', async () => { const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; - controllerMessenger.registerActionHandler( - 'PreferencesController:getState', - jest.fn().mockReturnValue({ selectedAddress: '0x1234' }), - ); - const controller = new TokenBalancesController({ - disabled: true, - tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }], - interval: 10, - getERC20BalanceOf: jest.fn().mockReturnValue(new BN(1)), - messenger, + const { controller, triggerTokensStateChange } = setupController({ + config: { + disabled: true, + tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }], + interval: 10, + }, + mock: { + selectedAccount: createMockInternalAccount({ address: '0x1234' }), + getBalanceOf: new BN(1), + }, }); - const triggerTokensStateChange = async (state: TokensControllerState) => { - controllerMessenger.publish('TokensController:stateChange', state, []); - }; await controller.updateBalances(); @@ -222,20 +258,17 @@ describe('TokenBalancesController', () => { it('should not update balances if tokens change and controller is manually disabled', async () => { const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; - controllerMessenger.registerActionHandler( - 'PreferencesController:getState', - jest.fn().mockReturnValue({ selectedAddress: '0x1234' }), - ); - const controller = new TokenBalancesController({ - disabled: false, - tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }], - interval: 10, - getERC20BalanceOf: jest.fn().mockReturnValue(new BN(1)), - messenger, + const { controller, triggerTokensStateChange } = setupController({ + config: { + disabled: false, + tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }], + interval: 10, + }, + mock: { + selectedAccount: createMockInternalAccount({ address: '0x1234' }), + getBalanceOf: new BN(1), + }, }); - const triggerTokensStateChange = async (state: TokensControllerState) => { - controllerMessenger.publish('TokensController:stateChange', state, []); - }; await controller.updateBalances(); @@ -261,14 +294,14 @@ describe('TokenBalancesController', () => { }); it('should clear previous interval', async () => { - controllerMessenger.registerActionHandler( - 'PreferencesController:getState', - jest.fn().mockReturnValue({ selectedAddress: '0x1234' }), - ); - const controller = new TokenBalancesController({ - interval: 1337, - getERC20BalanceOf: jest.fn(), - messenger, + const { controller } = setupController({ + config: { + interval: 1337, + }, + mock: { + selectedAccount: createMockInternalAccount({ address: '0x1234' }), + getBalanceOf: new BN(1), + }, }); const mockClearTimeout = jest.spyOn(global, 'clearTimeout'); @@ -291,15 +324,17 @@ describe('TokenBalancesController', () => { aggregators: [], }, ]; - controllerMessenger.registerActionHandler( - 'PreferencesController:getState', - jest.fn().mockReturnValue({ selectedAddress }), - ); - const controller = new TokenBalancesController({ - interval: 1337, - tokens, - getERC20BalanceOf: jest.fn().mockReturnValue(new BN(1)), - messenger, + const { controller } = setupController({ + config: { + interval: 1337, + tokens, + }, + mock: { + selectedAccount: createMockInternalAccount({ + address: selectedAddress, + }), + getBalanceOf: new BN(1), + }, }); expect(controller.state.contractBalances).toStrictEqual({}); @@ -314,9 +349,6 @@ describe('TokenBalancesController', () => { it('should handle `getERC20BalanceOf` error case', async () => { const errorMsg = 'Failed to get balance'; const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; - const getERC20BalanceOfStub = jest - .fn() - .mockReturnValue(Promise.reject(new Error(errorMsg))); const tokens: Token[] = [ { address, @@ -326,17 +358,21 @@ describe('TokenBalancesController', () => { }, ]; - controllerMessenger.registerActionHandler( - 'PreferencesController:getState', - jest.fn().mockReturnValue({}), - ); - const controller = new TokenBalancesController({ - interval: 1337, - tokens, - getERC20BalanceOf: getERC20BalanceOfStub, - messenger, + const { controller, mockGetERC20BalanceOf } = setupController({ + config: { + interval: 1337, + tokens, + }, + mock: { + selectedAccount: createMockInternalAccount({ + address, + }), + }, }); + // @ts-expect-error Testing error case + mockGetERC20BalanceOf.mockReturnValueOnce(new Error(errorMsg)); + expect(controller.state.contractBalances).toStrictEqual({}); await controller.updateBalances(); @@ -344,8 +380,7 @@ describe('TokenBalancesController', () => { expect(tokens[0].hasBalanceError).toBe(true); expect(controller.state.contractBalances[address]).toBe(toHex(0)); - getERC20BalanceOfStub.mockReturnValue(new BN(1)); - + mockGetERC20BalanceOf.mockReturnValueOnce(new BN(1)); await controller.updateBalances(); expect(tokens[0].hasBalanceError).toBe(false); @@ -354,18 +389,18 @@ describe('TokenBalancesController', () => { }); it('should update balances when tokens change', async () => { - controllerMessenger.registerActionHandler( - 'PreferencesController:getState', - jest.fn().mockReturnValue({ selectedAddress: '0x1234' }), - ); - const controller = new TokenBalancesController({ - getERC20BalanceOf: jest.fn(), - interval: 1337, - messenger, + const { controller, triggerTokensStateChange } = setupController({ + config: { + interval: 1337, + }, + mock: { + selectedAccount: createMockInternalAccount({ + address: '0x1234', + }), + getBalanceOf: new BN(1), + }, }); - const triggerTokensStateChange = async (state: TokensControllerState) => { - controllerMessenger.publish('TokensController:stateChange', state, []); - }; + const updateBalancesSpy = jest.spyOn(controller, 'updateBalances'); await triggerTokensStateChange({ @@ -383,18 +418,18 @@ describe('TokenBalancesController', () => { }); it('should update token balances when detected tokens are added', async () => { - controllerMessenger.registerActionHandler( - 'PreferencesController:getState', - jest.fn().mockReturnValue({ selectedAddress: '0x1234' }), - ); - const controller = new TokenBalancesController({ - interval: 1337, - getERC20BalanceOf: jest.fn().mockReturnValue(new BN(1)), - messenger, + const { controller, triggerTokensStateChange } = setupController({ + config: { + interval: 1337, + }, + mock: { + selectedAccount: createMockInternalAccount({ + address: '0x1234', + }), + getBalanceOf: new BN(1), + }, }); - const triggerTokensStateChange = async (state: TokensControllerState) => { - controllerMessenger.publish('TokensController:stateChange', state, []); - }; + expect(controller.state.contractBalances).toStrictEqual({}); await triggerTokensStateChange({ diff --git a/packages/assets-controllers/src/TokenBalancesController.ts b/packages/assets-controllers/src/TokenBalancesController.ts index 323544f813..a1b58a4340 100644 --- a/packages/assets-controllers/src/TokenBalancesController.ts +++ b/packages/assets-controllers/src/TokenBalancesController.ts @@ -1,3 +1,4 @@ +import { type AccountsControllerGetSelectedAccountAction } from '@metamask/accounts-controller'; import { type RestrictedControllerMessenger, type ControllerGetStateAction, @@ -5,7 +6,6 @@ import { BaseController, } from '@metamask/base-controller'; import { safelyExecute, toHex } from '@metamask/controller-utils'; -import type { PreferencesControllerGetStateAction } from '@metamask/preferences-controller'; import type { AssetsContractController } from './AssetsContractController'; import type { Token } from './TokenRatesController'; @@ -56,7 +56,7 @@ export type TokenBalancesControllerGetStateAction = ControllerGetStateAction< export type TokenBalancesControllerActions = TokenBalancesControllerGetStateAction; -export type AllowedActions = PreferencesControllerGetStateAction; +export type AllowedActions = AccountsControllerGetSelectedAccountAction; export type TokenBalancesControllerStateChangeEvent = ControllerStateChangeEvent< @@ -201,16 +201,18 @@ export class TokenBalancesController extends BaseController< if (this.#disabled) { return; } - - const { selectedAddress } = this.messagingSystem.call( - 'PreferencesController:getState', + const selectedInternalAccount = this.messagingSystem.call( + 'AccountsController:getSelectedAccount', ); const newContractBalances: ContractBalances = {}; for (const token of this.#tokens) { const { address } = token; try { - const balance = await this.#getERC20BalanceOf(address, selectedAddress); + const balance = await this.#getERC20BalanceOf( + address, + selectedInternalAccount.address, + ); newContractBalances[address] = toHex(balance); token.hasBalanceError = false; } catch (error) { diff --git a/packages/assets-controllers/src/TokenDetectionController.test.ts b/packages/assets-controllers/src/TokenDetectionController.test.ts index 1012531be9..2e23ceb5fd 100644 --- a/packages/assets-controllers/src/TokenDetectionController.test.ts +++ b/packages/assets-controllers/src/TokenDetectionController.test.ts @@ -27,6 +27,7 @@ import nock from 'nock'; import * as sinon from 'sinon'; import { advanceTime } from '../../../tests/helpers'; +import { createMockInternalAccount } from '../../accounts-controller/src/tests/mocks'; import { formatAggregatorNames } from './assetsUtil'; import { TOKEN_END_POINT_API } from './token-service'; import type { @@ -144,6 +145,7 @@ function buildTokenDetectionControllerMessenger( return controllerMessenger.getRestricted({ name: controllerName, allowedActions: [ + 'AccountsController:getAccount', 'AccountsController:getSelectedAccount', 'KeyringController:getState', 'NetworkController:getNetworkClientById', @@ -155,7 +157,7 @@ function buildTokenDetectionControllerMessenger( 'PreferencesController:getState', ], allowedEvents: [ - 'AccountsController:selectedAccountChange', + 'AccountsController:selectedEvmAccountChange', 'KeyringController:lock', 'KeyringController:unlock', 'NetworkController:networkDidChange', @@ -166,6 +168,8 @@ function buildTokenDetectionControllerMessenger( } describe('TokenDetectionController', () => { + const defaultSelectedAccount = createMockInternalAccount(); + beforeEach(async () => { nock(TOKEN_END_POINT_API) .get(getTokensPath(ChainId.mainnet)) @@ -207,6 +211,10 @@ describe('TokenDetectionController', () => { await withController( { isKeyringUnlocked: false, + options: {}, + mocks: { + getSelectedAccount: defaultSelectedAccount, + }, }, async ({ controller }) => { const mockTokens = sinon.stub(controller, 'detectTokens'); @@ -225,6 +233,10 @@ describe('TokenDetectionController', () => { await withController( { isKeyringUnlocked: false, + options: {}, + mocks: { + getSelectedAccount: defaultSelectedAccount, + }, }, async ({ controller, triggerKeyringUnlock }) => { const mockTokens = sinon.stub(controller, 'detectTokens'); @@ -259,16 +271,24 @@ describe('TokenDetectionController', () => { }); it('should poll and detect tokens on interval while on supported networks', async () => { - await withController(async ({ controller }) => { - const mockTokens = sinon.stub(controller, 'detectTokens'); - controller.setIntervalLength(10); + await withController( + { + options: {}, + mocks: { + getSelectedAccount: defaultSelectedAccount, + }, + }, + async ({ controller }) => { + const mockTokens = sinon.stub(controller, 'detectTokens'); + controller.setIntervalLength(10); - await controller.start(); + await controller.start(); - expect(mockTokens.calledOnce).toBe(true); - await advanceTime({ clock, duration: 15 }); - expect(mockTokens.calledTwice).toBe(true); - }); + expect(mockTokens.calledOnce).toBe(true); + await advanceTime({ clock, duration: 15 }); + expect(mockTokens.calledTwice).toBe(true); + }, + ); }); it('should not autodetect while not on supported networks', async () => { @@ -280,6 +300,9 @@ describe('TokenDetectionController', () => { options: { getBalancesInSingleCall: mockGetBalancesInSingleCall, }, + mocks: { + getSelectedAccount: defaultSelectedAccount, + }, }, async ({ controller, mockNetworkState }) => { mockNetworkState({ @@ -297,12 +320,17 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + }, + mocks: { + getAccount: selectedAccount, + getSelectedAccount: selectedAccount, }, }, async ({ controller, mockTokenListGetState, callActionSpy }) => { @@ -333,7 +361,7 @@ describe('TokenDetectionController', () => { [sampleTokenA], { chainId: ChainId.mainnet, - selectedAddress, + selectedAddress: selectedAccount.address, }, ); }, @@ -344,12 +372,17 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + }, + mocks: { + getAccount: selectedAccount, + getSelectedAccount: selectedAccount, }, }, async ({ @@ -397,7 +430,7 @@ describe('TokenDetectionController', () => { [sampleTokenA], { chainId: '0x89', - selectedAddress, + selectedAddress: selectedAccount.address, }, ); }, @@ -409,14 +442,19 @@ describe('TokenDetectionController', () => { [sampleTokenA.address]: new BN(1), [sampleTokenB.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); const interval = 100; await withController( { options: { getBalancesInSingleCall: mockGetBalancesInSingleCall, interval, - selectedAddress, + }, + mocks: { + getAccount: selectedAccount, + getSelectedAccount: selectedAccount, }, }, async ({ controller, mockTokenListGetState, callActionSpy }) => { @@ -459,7 +497,7 @@ describe('TokenDetectionController', () => { [sampleTokenA, sampleTokenB], { chainId: ChainId.mainnet, - selectedAddress, + selectedAddress: selectedAccount.address, }, ); }, @@ -470,12 +508,17 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + }, + mocks: { + getAccount: selectedAccount, + getSelectedAccount: selectedAccount, }, }, async ({ @@ -526,6 +569,9 @@ describe('TokenDetectionController', () => { options: { getBalancesInSingleCall: mockGetBalancesInSingleCall, }, + mocks: { + getSelectedAccount: defaultSelectedAccount, + }, }, async ({ controller, mockTokenListGetState, callActionSpy }) => { mockTokenListGetState({ @@ -573,19 +619,24 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const firstSelectedAddress = - '0x0000000000000000000000000000000000000001'; - const secondSelectedAddress = - '0x0000000000000000000000000000000000000002'; + const firstSelectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); + const secondSelectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000002', + }); await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress: firstSelectedAddress, + }, + mocks: { + getSelectedAccount: firstSelectedAccount, }, }, async ({ + mockGetAccount, mockTokenListGetState, triggerSelectedAccountChange, callActionSpy, @@ -610,9 +661,8 @@ describe('TokenDetectionController', () => { }, }); - triggerSelectedAccountChange({ - address: secondSelectedAddress, - } as InternalAccount); + mockGetAccount(secondSelectedAccount); + triggerSelectedAccountChange(secondSelectedAccount); await advanceTime({ clock, duration: 1 }); expect(callActionSpy).toHaveBeenCalledWith( @@ -620,7 +670,7 @@ describe('TokenDetectionController', () => { [sampleTokenA], { chainId: ChainId.mainnet, - selectedAddress: secondSelectedAddress, + selectedAddress: secondSelectedAccount.address, }, ); }, @@ -631,13 +681,17 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + }, + mocks: { + getSelectedAccount: selectedAccount, }, }, async ({ @@ -666,7 +720,7 @@ describe('TokenDetectionController', () => { }); triggerSelectedAccountChange({ - address: selectedAddress, + address: selectedAccount.address, } as InternalAccount); await advanceTime({ clock, duration: 1 }); @@ -682,16 +736,20 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const firstSelectedAddress = - '0x0000000000000000000000000000000000000001'; - const secondSelectedAddress = - '0x0000000000000000000000000000000000000002'; + const firstSelectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); + const secondSelectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000002', + }); await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress: firstSelectedAddress, + }, + mocks: { + getSelectedAccount: firstSelectedAccount, }, isKeyringUnlocked: false, }, @@ -721,7 +779,7 @@ describe('TokenDetectionController', () => { }); triggerSelectedAccountChange({ - address: secondSelectedAddress, + address: secondSelectedAccount.address, } as InternalAccount); await advanceTime({ clock, duration: 1 }); @@ -739,16 +797,20 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const firstSelectedAddress = - '0x0000000000000000000000000000000000000001'; - const secondSelectedAddress = - '0x0000000000000000000000000000000000000002'; + const firstSelectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); + const secondSelectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000002', + }); await withController( { options: { disabled: true, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress: firstSelectedAddress, + }, + mocks: { + getSelectedAccount: firstSelectedAccount, }, }, async ({ @@ -777,7 +839,7 @@ describe('TokenDetectionController', () => { }); triggerSelectedAccountChange({ - address: secondSelectedAddress, + address: secondSelectedAccount.address, } as InternalAccount); await advanceTime({ clock, duration: 1 }); @@ -805,21 +867,27 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const firstSelectedAddress = - '0x0000000000000000000000000000000000000001'; - const secondSelectedAddress = - '0x0000000000000000000000000000000000000002'; + const firstSelectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); + const secondSelectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000002', + }); await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress: firstSelectedAddress, + }, + mocks: { + getSelectedAccount: firstSelectedAccount, }, }, async ({ + mockGetAccount, mockTokenListGetState, triggerPreferencesStateChange, + triggerSelectedAccountChange, callActionSpy, }) => { mockTokenListGetState({ @@ -844,17 +912,18 @@ describe('TokenDetectionController', () => { triggerPreferencesStateChange({ ...getDefaultPreferencesState(), - selectedAddress: secondSelectedAddress, useTokenDetection: true, }); + mockGetAccount(secondSelectedAccount); + triggerSelectedAccountChange(secondSelectedAccount); await advanceTime({ clock, duration: 1 }); - expect(callActionSpy).toHaveBeenCalledWith( + expect(callActionSpy).toHaveBeenLastCalledWith( 'TokensController:addDetectedTokens', [sampleTokenA], { chainId: ChainId.mainnet, - selectedAddress: secondSelectedAddress, + selectedAddress: secondSelectedAccount.address, }, ); }, @@ -865,20 +934,26 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + }, + mocks: { + getSelectedAccount: selectedAccount, }, }, async ({ + mockGetAccount, mockTokenListGetState, triggerPreferencesStateChange, callActionSpy, }) => { + mockGetAccount(selectedAccount); mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { @@ -901,14 +976,12 @@ describe('TokenDetectionController', () => { triggerPreferencesStateChange({ ...getDefaultPreferencesState(), - selectedAddress, useTokenDetection: false, }); await advanceTime({ clock, duration: 1 }); triggerPreferencesStateChange({ ...getDefaultPreferencesState(), - selectedAddress, useTokenDetection: true, }); await advanceTime({ clock, duration: 1 }); @@ -918,7 +991,7 @@ describe('TokenDetectionController', () => { [sampleTokenA], { chainId: ChainId.mainnet, - selectedAddress, + selectedAddress: selectedAccount.address, }, ); }, @@ -929,23 +1002,30 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const firstSelectedAddress = - '0x0000000000000000000000000000000000000001'; - const secondSelectedAddress = - '0x0000000000000000000000000000000000000002'; + const firstSelectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); + const secondSelectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000002', + }); await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress: firstSelectedAddress, + }, + mocks: { + getSelectedAccount: firstSelectedAccount, }, }, async ({ + mockGetAccount, mockTokenListGetState, + triggerSelectedAccountChange, triggerPreferencesStateChange, callActionSpy, }) => { + mockGetAccount(firstSelectedAccount); mockTokenListGetState({ ...getDefaultTokenListState(), tokenList: { @@ -963,9 +1043,10 @@ describe('TokenDetectionController', () => { triggerPreferencesStateChange({ ...getDefaultPreferencesState(), - selectedAddress: secondSelectedAddress, useTokenDetection: false, }); + mockGetAccount(secondSelectedAccount); + triggerSelectedAccountChange(secondSelectedAccount); await advanceTime({ clock, duration: 1 }); expect(callActionSpy).not.toHaveBeenCalledWith( @@ -979,13 +1060,18 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + }, + mocks: { + getAccount: selectedAccount, + getSelectedAccount: selectedAccount, }, }, async ({ @@ -1010,7 +1096,6 @@ describe('TokenDetectionController', () => { triggerPreferencesStateChange({ ...getDefaultPreferencesState(), - selectedAddress, useTokenDetection: true, }); await advanceTime({ clock, duration: 1 }); @@ -1021,113 +1106,124 @@ describe('TokenDetectionController', () => { }, ); }); + }); - describe('when keyring is locked', () => { - it('should not detect new tokens after switching between accounts', async () => { - const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ - [sampleTokenA.address]: new BN(1), - }); - const firstSelectedAddress = - '0x0000000000000000000000000000000000000001'; - const secondSelectedAddress = - '0x0000000000000000000000000000000000000002'; - await withController( - { - options: { - disabled: false, - getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress: firstSelectedAddress, - }, - isKeyringUnlocked: false, + describe('when keyring is locked', () => { + it('should not detect new tokens after switching between accounts', async () => { + const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ + [sampleTokenA.address]: new BN(1), + }); + const firstSelectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); + const secondSelectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000002', + }); + await withController( + { + options: { + disabled: false, + getBalancesInSingleCall: mockGetBalancesInSingleCall, }, - async ({ - mockTokenListGetState, - triggerPreferencesStateChange, - callActionSpy, - }) => { - mockTokenListGetState({ - ...getDefaultTokenListState(), - tokenList: { - [sampleTokenA.address]: { - name: sampleTokenA.name, - symbol: sampleTokenA.symbol, - decimals: sampleTokenA.decimals, - address: sampleTokenA.address, - occurrences: 1, - aggregators: sampleTokenA.aggregators, - iconUrl: sampleTokenA.image, - }, + mocks: { + getSelectedAccount: firstSelectedAccount, + getAccount: firstSelectedAccount, + }, + isKeyringUnlocked: false, + }, + async ({ + mockGetAccount, + mockTokenListGetState, + triggerPreferencesStateChange, + triggerSelectedAccountChange, + callActionSpy, + }) => { + mockTokenListGetState({ + ...getDefaultTokenListState(), + tokenList: { + [sampleTokenA.address]: { + name: sampleTokenA.name, + symbol: sampleTokenA.symbol, + decimals: sampleTokenA.decimals, + address: sampleTokenA.address, + occurrences: 1, + aggregators: sampleTokenA.aggregators, + iconUrl: sampleTokenA.image, }, - }); + }, + }); - triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - selectedAddress: secondSelectedAddress, - useTokenDetection: true, - }); - await advanceTime({ clock, duration: 1 }); + triggerPreferencesStateChange({ + ...getDefaultPreferencesState(), + useTokenDetection: true, + }); + mockGetAccount(secondSelectedAccount); + triggerSelectedAccountChange(secondSelectedAccount); + await advanceTime({ clock, duration: 1 }); - expect(callActionSpy).not.toHaveBeenCalledWith( - 'TokensController:addDetectedTokens', - ); - }, - ); - }); + expect(callActionSpy).not.toHaveBeenCalledWith( + 'TokensController:addDetectedTokens', + ); + }, + ); + }); - it('should not detect new tokens after enabling token detection', async () => { - const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ - [sampleTokenA.address]: new BN(1), - }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; - await withController( - { - options: { - disabled: false, - getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, - }, - isKeyringUnlocked: false, + it('should not detect new tokens after enabling token detection', async () => { + const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ + [sampleTokenA.address]: new BN(1), + }); + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); + await withController( + { + options: { + disabled: false, + getBalancesInSingleCall: mockGetBalancesInSingleCall, }, - async ({ - mockTokenListGetState, - triggerPreferencesStateChange, - callActionSpy, - }) => { - mockTokenListGetState({ - ...getDefaultTokenListState(), - tokenList: { - [sampleTokenA.address]: { - name: sampleTokenA.name, - symbol: sampleTokenA.symbol, - decimals: sampleTokenA.decimals, - address: sampleTokenA.address, - occurrences: 1, - aggregators: sampleTokenA.aggregators, - iconUrl: sampleTokenA.image, - }, + isKeyringUnlocked: false, + mocks: { + getSelectedAccount: selectedAccount, + getAccount: selectedAccount, + }, + }, + async ({ + mockTokenListGetState, + triggerPreferencesStateChange, + callActionSpy, + }) => { + mockTokenListGetState({ + ...getDefaultTokenListState(), + tokenList: { + [sampleTokenA.address]: { + name: sampleTokenA.name, + symbol: sampleTokenA.symbol, + decimals: sampleTokenA.decimals, + address: sampleTokenA.address, + occurrences: 1, + aggregators: sampleTokenA.aggregators, + iconUrl: sampleTokenA.image, }, - }); + }, + }); - triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - selectedAddress, - useTokenDetection: false, - }); - await advanceTime({ clock, duration: 1 }); + triggerPreferencesStateChange({ + ...getDefaultPreferencesState(), + useTokenDetection: false, + }); + await advanceTime({ clock, duration: 1 }); - triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - selectedAddress, - useTokenDetection: true, - }); - await advanceTime({ clock, duration: 1 }); + triggerPreferencesStateChange({ + ...getDefaultPreferencesState(), + useTokenDetection: true, + }); + await advanceTime({ clock, duration: 1 }); - expect(callActionSpy).not.toHaveBeenCalledWith( - 'TokensController:addDetectedTokens', - ); - }, - ); - }); + expect(callActionSpy).not.toHaveBeenCalledWith( + 'TokensController:addDetectedTokens', + ); + }, + ); }); }); @@ -1136,21 +1232,28 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const firstSelectedAddress = - '0x0000000000000000000000000000000000000001'; - const secondSelectedAddress = - '0x0000000000000000000000000000000000000002'; + const firstSelectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); + const secondSelectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000002', + }); await withController( { options: { disabled: true, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress: firstSelectedAddress, + }, + mocks: { + getAccount: firstSelectedAccount, + getSelectedAccount: firstSelectedAccount, }, }, async ({ + mockGetAccount, mockTokenListGetState, triggerPreferencesStateChange, + triggerSelectedAccountChange, callActionSpy, }) => { mockTokenListGetState({ @@ -1170,9 +1273,10 @@ describe('TokenDetectionController', () => { triggerPreferencesStateChange({ ...getDefaultPreferencesState(), - selectedAddress: secondSelectedAddress, useTokenDetection: true, }); + mockGetAccount(secondSelectedAccount); + triggerSelectedAccountChange(secondSelectedAccount); await advanceTime({ clock, duration: 1 }); expect(callActionSpy).not.toHaveBeenCalledWith( @@ -1186,13 +1290,18 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { disabled: true, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + }, + mocks: { + getAccount: selectedAccount, + getSelectedAccount: selectedAccount, }, }, async ({ @@ -1217,14 +1326,12 @@ describe('TokenDetectionController', () => { triggerPreferencesStateChange({ ...getDefaultPreferencesState(), - selectedAddress, useTokenDetection: false, }); await advanceTime({ clock, duration: 1 }); triggerPreferencesStateChange({ ...getDefaultPreferencesState(), - selectedAddress, useTokenDetection: true, }); await advanceTime({ clock, duration: 1 }); @@ -1253,13 +1360,18 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + }, + mocks: { + getAccount: selectedAccount, + getSelectedAccount: selectedAccount, }, }, async ({ @@ -1298,7 +1410,7 @@ describe('TokenDetectionController', () => { [sampleTokenA], { chainId: '0x89', - selectedAddress, + selectedAddress: selectedAccount.address, }, ); }, @@ -1309,13 +1421,18 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + }, + mocks: { + getAccount: selectedAccount, + getSelectedAccount: selectedAccount, }, }, async ({ @@ -1360,13 +1477,18 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + }, + mocks: { + getAccount: selectedAccount, + getSelectedAccount: selectedAccount, }, }, async ({ @@ -1407,15 +1529,20 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, }, isKeyringUnlocked: false, + mocks: { + getAccount: selectedAccount, + getSelectedAccount: selectedAccount, + }, }, async ({ mockTokenListGetState, @@ -1457,13 +1584,18 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { disabled: true, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + }, + mocks: { + getAccount: selectedAccount, + getSelectedAccount: selectedAccount, }, }, async ({ @@ -1516,13 +1648,18 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + }, + mocks: { + getSelectedAccount: selectedAccount, + getAccount: selectedAccount, }, }, async ({ @@ -1561,7 +1698,7 @@ describe('TokenDetectionController', () => { [sampleTokenA], { chainId: ChainId.mainnet, - selectedAddress, + selectedAddress: selectedAccount.address, }, ); }, @@ -1572,13 +1709,18 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + }, + mocks: { + getSelectedAccount: selectedAccount, + getAccount: selectedAccount, }, }, async ({ @@ -1607,15 +1749,20 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, }, isKeyringUnlocked: false, + mocks: { + getSelectedAccount: selectedAccount, + getAccount: selectedAccount, + }, }, async ({ mockTokenListGetState, @@ -1655,13 +1802,18 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { disabled: true, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + }, + mocks: { + getSelectedAccount: selectedAccount, + getAccount: selectedAccount, }, }, async ({ @@ -1711,13 +1863,18 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + }, + mocks: { + getSelectedAccount: selectedAccount, + getAccount: selectedAccount, }, }, async ({ controller, mockTokenListGetState }) => { @@ -1777,13 +1934,18 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + }, + mocks: { + getSelectedAccount: selectedAccount, + getAccount: selectedAccount, }, }, async ({ @@ -1802,7 +1964,7 @@ describe('TokenDetectionController', () => { }); await controller.detectTokens({ networkClientId: NetworkType.goerli, - selectedAddress, + selectedAddress: selectedAccount.address, }); expect(callActionSpy).not.toHaveBeenCalledWith( 'TokensController:addDetectedTokens', @@ -1821,13 +1983,18 @@ describe('TokenDetectionController', () => { {}, ), ); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + }, + mocks: { + getSelectedAccount: selectedAccount, + getAccount: selectedAccount, }, }, async ({ @@ -1841,7 +2008,7 @@ describe('TokenDetectionController', () => { }); await controller.detectTokens({ networkClientId: NetworkType.mainnet, - selectedAddress, + selectedAddress: selectedAccount.address, }); expect(callActionSpy).toHaveBeenLastCalledWith( 'TokensController:addDetectedTokens', @@ -1854,7 +2021,7 @@ describe('TokenDetectionController', () => { }; }), { - selectedAddress, + selectedAddress: selectedAccount.address, chainId: ChainId.mainnet, }, ); @@ -1866,13 +2033,18 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + }, + mocks: { + getSelectedAccount: selectedAccount, + getAccount: selectedAccount, }, }, async ({ controller, mockTokenListGetState, callActionSpy }) => { @@ -1898,7 +2070,7 @@ describe('TokenDetectionController', () => { await controller.detectTokens({ networkClientId: NetworkType.mainnet, - selectedAddress, + selectedAddress: selectedAccount.address, }); expect(callActionSpy).toHaveBeenCalledWith( @@ -1906,7 +2078,7 @@ describe('TokenDetectionController', () => { [sampleTokenA], { chainId: ChainId.mainnet, - selectedAddress, + selectedAddress: selectedAccount.address, }, ); }, @@ -1917,7 +2089,9 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); const mockTrackMetaMetricsEvent = jest.fn(); await withController( @@ -1926,7 +2100,10 @@ describe('TokenDetectionController', () => { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, trackMetaMetricsEvent: mockTrackMetaMetricsEvent, - selectedAddress, + }, + mocks: { + getSelectedAccount: selectedAccount, + getAccount: selectedAccount, }, }, async ({ controller, mockTokenListGetState }) => { @@ -1952,7 +2129,7 @@ describe('TokenDetectionController', () => { await controller.detectTokens({ networkClientId: NetworkType.mainnet, - selectedAddress, + selectedAddress: selectedAccount.address, }); expect(mockTrackMetaMetricsEvent).toHaveBeenCalledWith({ @@ -1971,6 +2148,85 @@ describe('TokenDetectionController', () => { }, ); }); + + it('does not trigger `TokensController:addDetectedTokens` action when selectedAccount is not found', async () => { + const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ + [sampleTokenA.address]: new BN(1), + }); + + const mockTrackMetaMetricsEvent = jest.fn(); + + await withController( + { + options: { + disabled: false, + getBalancesInSingleCall: mockGetBalancesInSingleCall, + trackMetaMetricsEvent: mockTrackMetaMetricsEvent, + }, + }, + async ({ + controller, + mockGetAccount, + mockTokenListGetState, + callActionSpy, + }) => { + // @ts-expect-error forcing an undefined value + mockGetAccount(undefined); + mockTokenListGetState({ + ...getDefaultTokenListState(), + tokensChainsCache: { + '0x1': { + timestamp: 0, + data: { + [sampleTokenA.address]: { + name: sampleTokenA.name, + symbol: sampleTokenA.symbol, + decimals: sampleTokenA.decimals, + address: sampleTokenA.address, + occurrences: 1, + aggregators: sampleTokenA.aggregators, + iconUrl: sampleTokenA.image, + }, + }, + }, + }, + }); + + await controller.detectTokens({ + networkClientId: NetworkType.mainnet, + }); + + expect(callActionSpy).toHaveBeenLastCalledWith( + 'TokensController:addDetectedTokens', + [ + { + address: '0x514910771AF9Ca656af840dff83E8264EcF986CA', + aggregators: [ + 'Paraswap', + 'PMM', + 'AirswapLight', + '0x', + 'Bancor', + 'CoinGecko', + 'Zapper', + 'Kleros', + 'Zerion', + 'CMC', + '1inch', + ], + decimals: 18, + image: + 'https://static.cx.metamask.io/api/v1/tokenIcons/1/0x514910771af9ca656af840dff83e8264ecf986ca.png', + isERC721: false, + name: 'Chainlink', + symbol: 'LINK', + }, + ], + { chainId: '0x1', selectedAddress: '' }, + ); + }, + ); + }); }); }); @@ -1990,6 +2246,7 @@ function getTokensPath(chainId: Hex) { type WithControllerCallback = ({ controller, + mockGetAccount, mockGetSelectedAccount, mockKeyringGetState, mockTokensGetState, @@ -2007,6 +2264,7 @@ type WithControllerCallback = ({ triggerNetworkDidChange, }: { controller: TokenDetectionController; + mockGetAccount: (internalAccount: InternalAccount) => void; mockGetSelectedAccount: (address: string) => void; mockKeyringGetState: (state: KeyringControllerState) => void; mockTokensGetState: (state: TokensControllerState) => void; @@ -2034,6 +2292,10 @@ type WithControllerOptions = { options?: Partial[0]>; isKeyringUnlocked?: boolean; messenger?: ControllerMessenger; + mocks?: { + getAccount?: InternalAccount; + getSelectedAccount?: InternalAccount; + }; }; type WithControllerArgs = @@ -2053,16 +2315,25 @@ async function withController( ...args: WithControllerArgs ): Promise { const [{ ...rest }, fn] = args.length === 2 ? args : [{}, args[0]]; - const { options, isKeyringUnlocked, messenger } = rest; + const { options, isKeyringUnlocked, messenger, mocks } = rest; const controllerMessenger = messenger ?? new ControllerMessenger(); + const mockGetAccount = jest.fn(); + controllerMessenger.registerActionHandler( + 'AccountsController:getAccount', + mockGetAccount.mockReturnValue( + mocks?.getAccount ?? createMockInternalAccount({ address: '0x1' }), + ), + ); + const mockGetSelectedAccount = jest.fn(); controllerMessenger.registerActionHandler( 'AccountsController:getSelectedAccount', - mockGetSelectedAccount.mockReturnValue({ - address: '0x1', - } as InternalAccount), + mockGetSelectedAccount.mockReturnValue( + mocks?.getSelectedAccount ?? + createMockInternalAccount({ address: '0x1' }), + ), ); const mockKeyringState = jest.fn(); controllerMessenger.registerActionHandler( @@ -2140,6 +2411,9 @@ async function withController( try { return await fn({ controller, + mockGetAccount: (internalAccount: InternalAccount) => { + mockGetAccount.mockReturnValue(internalAccount); + }, mockGetSelectedAccount: (address: string) => { mockGetSelectedAccount.mockReturnValue({ address } as InternalAccount); }, @@ -2195,7 +2469,7 @@ async function withController( }, triggerSelectedAccountChange: (account: InternalAccount) => { controllerMessenger.publish( - 'AccountsController:selectedAccountChange', + 'AccountsController:selectedEvmAccountChange', account, ); }, diff --git a/packages/assets-controllers/src/TokenDetectionController.ts b/packages/assets-controllers/src/TokenDetectionController.ts index e432999867..64572b9a43 100644 --- a/packages/assets-controllers/src/TokenDetectionController.ts +++ b/packages/assets-controllers/src/TokenDetectionController.ts @@ -1,6 +1,7 @@ import type { AccountsControllerGetSelectedAccountAction, - AccountsControllerSelectedAccountChangeEvent, + AccountsControllerGetAccountAction, + AccountsControllerSelectedEvmAccountChangeEvent, } from '@metamask/accounts-controller'; import type { RestrictedControllerMessenger, @@ -105,6 +106,7 @@ export type TokenDetectionControllerActions = export type AllowedActions = | AccountsControllerGetSelectedAccountAction + | AccountsControllerGetAccountAction | NetworkControllerGetNetworkClientByIdAction | NetworkControllerGetNetworkConfigurationByNetworkClientId | NetworkControllerGetStateAction @@ -121,7 +123,7 @@ export type TokenDetectionControllerEvents = TokenDetectionControllerStateChangeEvent; export type AllowedEvents = - | AccountsControllerSelectedAccountChangeEvent + | AccountsControllerSelectedEvmAccountChangeEvent | NetworkControllerNetworkDidChangeEvent | TokenListStateChange | KeyringControllerLockEvent @@ -153,7 +155,7 @@ export class TokenDetectionController extends StaticIntervalPollingController< > { #intervalId?: ReturnType; - #selectedAddress: string; + #selectedAccountId: string; #networkClientId: NetworkClientId; @@ -190,19 +192,16 @@ export class TokenDetectionController extends StaticIntervalPollingController< * @param options.messenger - The controller messaging system. * @param options.disabled - If set to true, all network requests are blocked. * @param options.interval - Polling interval used to fetch new token rates - * @param options.selectedAddress - Vault selected address * @param options.getBalancesInSingleCall - Gets the balances of a list of tokens for the given address. * @param options.trackMetaMetricsEvent - Sets options for MetaMetrics event tracking. */ constructor({ - selectedAddress, interval = DEFAULT_INTERVAL, disabled = true, getBalancesInSingleCall, trackMetaMetricsEvent, messenger, }: { - selectedAddress?: string; interval?: number; disabled?: boolean; getBalancesInSingleCall: AssetsContractController['getBalancesInSingleCall']; @@ -231,10 +230,7 @@ export class TokenDetectionController extends StaticIntervalPollingController< this.#disabled = disabled; this.setIntervalLength(interval); - this.#selectedAddress = - selectedAddress ?? - this.messagingSystem.call('AccountsController:getSelectedAccount') - .address; + this.#selectedAccountId = this.#getSelectedAccount().id; const { chainId, networkClientId } = this.#getCorrectChainIdAndNetworkClientId(); @@ -291,34 +287,32 @@ export class TokenDetectionController extends StaticIntervalPollingController< 'PreferencesController:stateChange', // TODO: Either fix this lint violation or explain why it's necessary to ignore. // eslint-disable-next-line @typescript-eslint/no-misused-promises - async ({ selectedAddress: newSelectedAddress, useTokenDetection }) => { - const isSelectedAddressChanged = - this.#selectedAddress !== newSelectedAddress; + async ({ useTokenDetection }) => { + const selectedAccount = this.#getSelectedAccount(); const isDetectionChangedFromPreferences = this.#isDetectionEnabledFromPreferences !== useTokenDetection; - this.#selectedAddress = newSelectedAddress; this.#isDetectionEnabledFromPreferences = useTokenDetection; - if (isSelectedAddressChanged || isDetectionChangedFromPreferences) { + if (isDetectionChangedFromPreferences) { await this.#restartTokenDetection({ - selectedAddress: this.#selectedAddress, + selectedAddress: selectedAccount.address, }); } }, ); this.messagingSystem.subscribe( - 'AccountsController:selectedAccountChange', + 'AccountsController:selectedEvmAccountChange', // TODO: Either fix this lint violation or explain why it's necessary to ignore. // eslint-disable-next-line @typescript-eslint/no-misused-promises - async ({ address: newSelectedAddress }) => { - const isSelectedAddressChanged = - this.#selectedAddress !== newSelectedAddress; - if (isSelectedAddressChanged) { - this.#selectedAddress = newSelectedAddress; + async (selectedAccount) => { + const isSelectedAccountIdChanged = + this.#selectedAccountId !== selectedAccount.id; + if (isSelectedAccountIdChanged) { + this.#selectedAccountId = selectedAccount.id; await this.#restartTokenDetection({ - selectedAddress: this.#selectedAddress, + selectedAddress: selectedAccount.address, }); } }, @@ -493,7 +487,7 @@ export class TokenDetectionController extends StaticIntervalPollingController< } const addressAgainstWhichToDetect = - selectedAddress ?? this.#selectedAddress; + selectedAddress ?? this.#getSelectedAddress(); const { chainId, networkClientId: selectedNetworkClientId } = this.#getCorrectChainIdAndNetworkClientId(networkClientId); const chainIdAgainstWhichToDetect = chainId; @@ -637,6 +631,19 @@ export class TokenDetectionController extends StaticIntervalPollingController< } }); } + + #getSelectedAccount() { + return this.messagingSystem.call('AccountsController:getSelectedAccount'); + } + + #getSelectedAddress() { + // If the address is not defined (or empty), we fallback to the currently selected account's address + const account = this.messagingSystem.call( + 'AccountsController:getAccount', + this.#selectedAccountId, + ); + return account?.address || ''; + } } export default TokenDetectionController; diff --git a/packages/assets-controllers/src/TokenRatesController.test.ts b/packages/assets-controllers/src/TokenRatesController.test.ts index df8833ff4e..3f8404ae36 100644 --- a/packages/assets-controllers/src/TokenRatesController.test.ts +++ b/packages/assets-controllers/src/TokenRatesController.test.ts @@ -1,3 +1,4 @@ +import { createMockInternalAccount } from '@metamask/accounts-controller/src/tests/mocks'; import type { AddApprovalRequest } from '@metamask/approval-controller'; import { ControllerMessenger } from '@metamask/base-controller'; import { @@ -7,16 +8,13 @@ import { toChecksumHexAddress, toHex, } from '@metamask/controller-utils'; +import type { InternalAccount } from '@metamask/keyring-api'; import type { NetworkClientId, NetworkState, } from '@metamask/network-controller'; import { defaultState as defaultNetworkState } from '@metamask/network-controller'; import type { NetworkClientConfiguration } from '@metamask/network-controller/src/types'; -import { - getDefaultPreferencesState, - type PreferencesState, -} from '@metamask/preferences-controller'; import type { Hex } from '@metamask/utils'; import { add0x } from '@metamask/utils'; import assert from 'assert'; @@ -45,6 +43,9 @@ import { getDefaultTokensState } from './TokensController'; import type { TokensControllerState } from './TokensController'; const defaultSelectedAddress = '0x0000000000000000000000000000000000000001'; +const defaultSelectedAccount = createMockInternalAccount({ + address: defaultSelectedAddress, +}); const mockTokenAddress = '0x0000000000000000000000000000000000000010'; const defaultSelectedNetworkClientId = 'AAAA-BBBB-CCCC-DDDD'; @@ -68,12 +69,13 @@ function buildTokenRatesControllerMessenger( 'TokensController:getState', 'NetworkController:getNetworkClientById', 'NetworkController:getState', - 'PreferencesController:getState', + 'AccountsController:getAccount', + 'AccountsController:getSelectedAccount', ], allowedEvents: [ - 'PreferencesController:stateChange', 'TokensController:stateChange', 'NetworkController:stateChange', + 'AccountsController:selectedEvmAccountChange', ], }); } @@ -992,6 +994,9 @@ describe('TokenRatesController', () => { it('should update exchange rates when selected address changes', async () => { const alternateSelectedAddress = '0x0000000000000000000000000000000000000002'; + const alternateSelectedAccount = createMockInternalAccount({ + address: alternateSelectedAddress, + }); await withController( { options: { @@ -1018,69 +1023,26 @@ describe('TokenRatesController', () => { }, }, }, - async ({ controller, triggerPreferencesStateChange }) => { + async ({ controller, triggerSelectedAccountChange }) => { await controller.start(); const updateExchangeRatesSpy = jest .spyOn(controller, 'updateExchangeRates') .mockResolvedValue(); - triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - selectedAddress: alternateSelectedAddress, - }); + triggerSelectedAccountChange(alternateSelectedAccount); expect(updateExchangeRatesSpy).toHaveBeenCalledTimes(1); }, ); }); - - it('should not update exchange rates when preferences state changes without selected address changing', async () => { - await withController( - { - options: { - interval: 100, - }, - mockTokensControllerState: { - allTokens: { - '0x1': { - [defaultSelectedAddress]: [ - { - address: '0x02', - decimals: 0, - symbol: '', - aggregators: [], - }, - { - address: '0x03', - decimals: 0, - symbol: '', - aggregators: [], - }, - ], - }, - }, - }, - }, - async ({ controller, triggerPreferencesStateChange }) => { - await controller.start(); - const updateExchangeRatesSpy = jest - .spyOn(controller, 'updateExchangeRates') - .mockResolvedValue(); - triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - selectedAddress: defaultSelectedAddress, - openSeaEnabled: false, - }); - - expect(updateExchangeRatesSpy).not.toHaveBeenCalled(); - }, - ); - }); }); describe('when polling is inactive', () => { - it('should not update exchange rates when selected address changes', async () => { + it('does not update exchange rates when selected account changes', async () => { const alternateSelectedAddress = '0x0000000000000000000000000000000000000002'; + const alternateSelectedAccount = createMockInternalAccount({ + address: alternateSelectedAddress, + }); await withController( { options: { @@ -1107,14 +1069,11 @@ describe('TokenRatesController', () => { }, }, }, - async ({ controller, triggerPreferencesStateChange }) => { + async ({ controller, triggerSelectedAccountChange }) => { const updateExchangeRatesSpy = jest .spyOn(controller, 'updateExchangeRates') .mockResolvedValue(); - triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - selectedAddress: alternateSelectedAddress, - }); + triggerSelectedAccountChange(alternateSelectedAccount); expect(updateExchangeRatesSpy).not.toHaveBeenCalled(); }, @@ -2303,12 +2262,12 @@ describe('TokenRatesController', () => { */ type WithControllerCallback = ({ controller, - triggerPreferencesStateChange, + triggerSelectedAccountChange, triggerTokensStateChange, triggerNetworkStateChange, }: { controller: TokenRatesController; - triggerPreferencesStateChange: (state: PreferencesState) => void; + triggerSelectedAccountChange: (state: InternalAccount) => void; triggerTokensStateChange: (state: TokensControllerState) => void; triggerNetworkStateChange: (state: NetworkState) => void; }) => Promise | ReturnValue; @@ -2377,13 +2336,16 @@ async function withController( }), ); - const mockPreferencesState = jest.fn(); + const mockGetSelectedAccount = jest.fn(); controllerMessenger.registerActionHandler( - 'PreferencesController:getState', - mockPreferencesState.mockReturnValue({ - ...getDefaultPreferencesState(), - selectedAddress: defaultSelectedAddress, - }), + 'AccountsController:getSelectedAccount', + mockGetSelectedAccount.mockReturnValue(defaultSelectedAccount), + ); + + const mockGetAccount = jest.fn(); + controllerMessenger.registerActionHandler( + 'AccountsController:getAccount', + mockGetAccount.mockReturnValue(defaultSelectedAccount), ); const controller = new TokenRatesController({ @@ -2394,13 +2356,13 @@ async function withController( try { return await fn({ controller, - triggerPreferencesStateChange: (state: PreferencesState) => { + triggerSelectedAccountChange: (account: InternalAccount) => { controllerMessenger.publish( - 'PreferencesController:stateChange', - state, - [], + 'AccountsController:selectedEvmAccountChange', + account, ); }, + triggerTokensStateChange: (state: TokensControllerState) => { controllerMessenger.publish('TokensController:stateChange', state, []); }, diff --git a/packages/assets-controllers/src/TokenRatesController.ts b/packages/assets-controllers/src/TokenRatesController.ts index c7380171fb..534ca176bf 100644 --- a/packages/assets-controllers/src/TokenRatesController.ts +++ b/packages/assets-controllers/src/TokenRatesController.ts @@ -1,3 +1,8 @@ +import type { + AccountsControllerGetAccountAction, + AccountsControllerGetSelectedAccountAction, + AccountsControllerSelectedEvmAccountChangeEvent, +} from '@metamask/accounts-controller'; import type { ControllerGetStateAction, ControllerStateChangeEvent, @@ -9,6 +14,7 @@ import { FALL_BACK_VS_CURRENCY, toHex, } from '@metamask/controller-utils'; +import type { InternalAccount } from '@metamask/keyring-api'; import type { NetworkClientId, NetworkControllerGetNetworkClientByIdAction, @@ -16,10 +22,6 @@ import type { NetworkControllerStateChangeEvent, } from '@metamask/network-controller'; import { StaticIntervalPollingController } from '@metamask/polling-controller'; -import type { - PreferencesControllerGetStateAction, - PreferencesControllerStateChangeEvent, -} from '@metamask/preferences-controller'; import { createDeferredPromise, type Hex } from '@metamask/utils'; import { isEqual } from 'lodash'; @@ -103,15 +105,16 @@ export type AllowedActions = | TokensControllerGetStateAction | NetworkControllerGetNetworkClientByIdAction | NetworkControllerGetStateAction - | PreferencesControllerGetStateAction; + | AccountsControllerGetAccountAction + | AccountsControllerGetSelectedAccountAction; /** * The external events available to the {@link TokenRatesController}. */ export type AllowedEvents = - | PreferencesControllerStateChangeEvent | TokensControllerStateChangeEvent - | NetworkControllerStateChangeEvent; + | NetworkControllerStateChangeEvent + | AccountsControllerSelectedEvmAccountChangeEvent; /** * The name of the {@link TokenRatesController}. @@ -235,7 +238,7 @@ export class TokenRatesController extends StaticIntervalPollingController< #inProcessExchangeRateUpdates: Record<`${Hex}:${string}`, Promise> = {}; - #selectedAddress: string; + #selectedAccountId: string; #disabled: boolean; @@ -289,36 +292,17 @@ export class TokenRatesController extends StaticIntervalPollingController< this.#chainId = currentChainId; this.#ticker = currentTicker; - this.#selectedAddress = this.#getSelectedAddress(); + this.#selectedAccountId = this.#getSelectedAccount().id; const { allTokens, allDetectedTokens } = this.#getTokensControllerState(); this.#allTokens = allTokens; this.#allDetectedTokens = allDetectedTokens; - this.#subscribeToPreferencesStateChange(); - this.#subscribeToTokensStateChange(); this.#subscribeToNetworkStateChange(); - } - #subscribeToPreferencesStateChange() { - this.messagingSystem.subscribe( - 'PreferencesController:stateChange', - // TODO: Either fix this lint violation or explain why it's necessary to ignore. - // eslint-disable-next-line @typescript-eslint/no-misused-promises - async (selectedAddress: string) => { - if (this.#selectedAddress !== selectedAddress) { - this.#selectedAddress = selectedAddress; - if (this.#pollState === PollState.Active) { - await this.updateExchangeRates(); - } - } - }, - ({ selectedAddress }) => { - return selectedAddress; - }, - ); + this.#subscribeToAccountChange(); } #subscribeToTokensStateChange() { @@ -372,6 +356,22 @@ export class TokenRatesController extends StaticIntervalPollingController< ); } + #subscribeToAccountChange() { + this.messagingSystem.subscribe( + 'AccountsController:selectedEvmAccountChange', + // TODO: Either fix this lint violation or explain why it's necessary to ignore. + // eslint-disable-next-line @typescript-eslint/no-misused-promises + async (selectedAccount) => { + if (this.#selectedAccountId !== selectedAccount.id) { + this.#selectedAccountId = selectedAccount.id; + if (this.#pollState === PollState.Active) { + await this.updateExchangeRates(); + } + } + }, + ); + } + /** * Get the user's tokens for the given chain. * @@ -379,9 +379,14 @@ export class TokenRatesController extends StaticIntervalPollingController< * @returns The list of tokens addresses for the current chain */ #getTokenAddresses(chainId: Hex): Hex[] { - const tokens = this.#allTokens[chainId]?.[this.#selectedAddress] || []; + const selectedAccount = this.messagingSystem.call( + 'AccountsController:getAccount', + this.#selectedAccountId, + ); + const selectedAddress = selectedAccount?.address ?? ''; + const tokens = this.#allTokens[chainId]?.[selectedAddress] || []; const detectedTokens = - this.#allDetectedTokens[chainId]?.[this.#selectedAddress] || []; + this.#allDetectedTokens[chainId]?.[selectedAddress] || []; return [ ...new Set( @@ -423,12 +428,12 @@ export class TokenRatesController extends StaticIntervalPollingController< this.#pollState = PollState.Inactive; } - #getSelectedAddress(): string { - const { selectedAddress } = this.messagingSystem.call( - 'PreferencesController:getState', + #getSelectedAccount(): InternalAccount { + const selectedAccount = this.messagingSystem.call( + 'AccountsController:getSelectedAccount', ); - return selectedAddress; + return selectedAccount; } #getChainIdAndTicker(): { diff --git a/packages/assets-controllers/src/TokensController.test.ts b/packages/assets-controllers/src/TokensController.test.ts index 2d8fb47cb1..38692b6404 100644 --- a/packages/assets-controllers/src/TokensController.test.ts +++ b/packages/assets-controllers/src/TokensController.test.ts @@ -13,18 +13,18 @@ import { convertHexToDecimal, InfuraNetworkType, } from '@metamask/controller-utils'; +import type { InternalAccount } from '@metamask/keyring-api'; import type { NetworkClientConfiguration, NetworkClientId, } from '@metamask/network-controller'; import { defaultState as defaultNetworkState } from '@metamask/network-controller'; -import type { PreferencesState } from '@metamask/preferences-controller'; -import { getDefaultPreferencesState } from '@metamask/preferences-controller'; import nock from 'nock'; import * as sinon from 'sinon'; import { v1 as uuidV1 } from 'uuid'; import { FakeProvider } from '../../../tests/fake-provider'; +import { createMockInternalAccount } from '../../accounts-controller/src/tests/mocks'; import type { ExtractAvailableAction, ExtractAvailableEvent, @@ -39,12 +39,17 @@ import { TOKEN_END_POINT_API } from './token-service'; import type { Token } from './TokenRatesController'; import { TokensController } from './TokensController'; import type { + AllowedActions, + AllowedEvents, TokensControllerMessenger, TokensControllerState, } from './TokensController'; jest.mock('@ethersproject/contracts'); -jest.mock('uuid'); +jest.mock('uuid', () => ({ + ...jest.requireActual('uuid'), + v1: jest.fn(), +})); jest.mock('./Standards/ERC20Standard'); jest.mock('./Standards/NftStandards/ERC1155/ERC1155Standard'); @@ -58,6 +63,10 @@ const uuidV1Mock = jest.mocked(uuidV1); const ERC20StandardMock = jest.mocked(ERC20Standard); const ERC1155StandardMock = jest.mocked(ERC1155Standard); +const defaultMockInternalAccount = createMockInternalAccount({ + address: '0x1', +}); + describe('TokensController', () => { beforeEach(() => { uuidV1Mock.mockReturnValue('9b1deb4d-3b7d-4bad-9bdd-2b0d7b3dcb6d'); @@ -265,33 +274,36 @@ describe('TokensController', () => { }); it('should add token by selected address', async () => { + const firstAddress = '0x123'; + const firstAccount = createMockInternalAccount({ + address: firstAddress, + }); + const secondAddress = '0x321'; + const secondAccount = createMockInternalAccount({ + address: secondAddress, + }); await withController( - async ({ controller, triggerPreferencesStateChange }) => { + { + mocks: { + getAccount: firstAccount, + getSelectedAccount: firstAccount, + }, + }, + async ({ controller, triggerSelectedAccountChange }) => { ContractMock.mockReturnValue( buildMockEthersERC721Contract({ supportsInterface: false }), ); - const firstAddress = '0x123'; - const secondAddress = '0x321'; - triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - selectedAddress: firstAddress, - }); + triggerSelectedAccountChange(firstAccount); await controller.addToken({ address: '0x01', symbol: 'bar', decimals: 2, }); - triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - selectedAddress: secondAddress, - }); + triggerSelectedAccountChange(secondAccount); expect(controller.state.tokens).toHaveLength(0); - triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - selectedAddress: firstAddress, - }); + triggerSelectedAccountChange(firstAccount); expect(controller.state.tokens[0]).toStrictEqual({ address: '0x01', decimals: 2, @@ -407,26 +419,33 @@ describe('TokensController', () => { }); it('should remove token by selected address', async () => { + const firstAddress = '0x123'; + const firstAccount = createMockInternalAccount({ + address: firstAddress, + }); + const secondAddress = '0x321'; + const secondAccount = createMockInternalAccount({ + address: secondAddress, + }); await withController( - async ({ controller, triggerPreferencesStateChange }) => { + { + mocks: { + getAccount: firstAccount, + getSelectedAccount: firstAccount, + }, + }, + async ({ controller, triggerSelectedAccountChange }) => { ContractMock.mockReturnValue( buildMockEthersERC721Contract({ supportsInterface: false }), ); - const firstAddress = '0x123'; - const secondAddress = '0x321'; - triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - selectedAddress: firstAddress, - }); + + triggerSelectedAccountChange(firstAccount); await controller.addToken({ address: '0x02', symbol: 'baz', decimals: 2, }); - triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - selectedAddress: secondAddress, - }); + triggerSelectedAccountChange(secondAccount); await controller.addToken({ address: '0x01', symbol: 'bar', @@ -436,10 +455,7 @@ describe('TokensController', () => { controller.ignoreTokens(['0x01']); expect(controller.state.tokens).toHaveLength(0); - triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - selectedAddress: firstAddress, - }); + triggerSelectedAccountChange(firstAccount); expect(controller.state.tokens[0]).toStrictEqual({ address: '0x02', decimals: 2, @@ -519,17 +535,19 @@ describe('TokensController', () => { }); it('should remove a token from the ignoredTokens/allIgnoredTokens lists if re-added as part of a bulk addTokens add', async () => { + const selectedAddress = '0x0001'; + const selectedAccount = createMockInternalAccount({ + address: selectedAddress, + }); await withController( - async ({ - controller, - triggerPreferencesStateChange, - changeNetwork, - }) => { - const selectedAddress = '0x0001'; - triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - selectedAddress, - }); + { + mocks: { + getSelectedAccount: selectedAccount, + getAccount: selectedAccount, + }, + }, + async ({ controller, triggerSelectedAccountChange, changeNetwork }) => { + triggerSelectedAccountChange(selectedAccount); changeNetwork({ selectedNetworkClientId: InfuraNetworkType.sepolia }); await controller.addToken({ address: '0x01', @@ -566,17 +584,19 @@ describe('TokensController', () => { }); it('should be able to clear the ignoredTokens list', async () => { + const selectedAddress = '0x0001'; + const selectedAccount = createMockInternalAccount({ + address: selectedAddress, + }); await withController( - async ({ - controller, - triggerPreferencesStateChange, - changeNetwork, - }) => { - const selectedAddress = '0x0001'; - triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - selectedAddress, - }); + { + mocks: { + getSelectedAccount: selectedAccount, + getAccount: selectedAccount, + }, + }, + async ({ controller, triggerSelectedAccountChange, changeNetwork }) => { + triggerSelectedAccountChange(selectedAccount); changeNetwork({ selectedNetworkClientId: InfuraNetworkType.sepolia }); await controller.addToken({ address: '0x01', @@ -603,18 +623,23 @@ describe('TokensController', () => { }); it('should ignore tokens by [chainID][accountAddress]', async () => { + const selectedAddress1 = '0x0001'; + const selectedAccount1 = createMockInternalAccount({ + address: selectedAddress1, + }); + const selectedAddress2 = '0x0002'; + const selectedAccount2 = createMockInternalAccount({ + address: selectedAddress2, + }); await withController( - async ({ - controller, - triggerPreferencesStateChange, - changeNetwork, - }) => { - const selectedAddress1 = '0x0001'; - const selectedAddress2 = '0x0002'; - triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - selectedAddress: selectedAddress1, - }); + { + mocks: { + getSelectedAccount: selectedAccount1, + getAccount: selectedAccount1, + }, + }, + async ({ controller, triggerSelectedAccountChange, changeNetwork }) => { + triggerSelectedAccountChange(selectedAccount1); changeNetwork({ selectedNetworkClientId: InfuraNetworkType.sepolia }); await controller.addToken({ address: '0x01', @@ -638,10 +663,7 @@ describe('TokensController', () => { controller.ignoreTokens(['0x02']); expect(controller.state.ignoredTokens).toStrictEqual(['0x02']); - triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - selectedAddress: selectedAddress2, - }); + triggerSelectedAccountChange(selectedAccount2); expect(controller.state.ignoredTokens).toHaveLength(0); await controller.addToken({ @@ -889,7 +911,9 @@ describe('TokensController', () => { symbol: 'LINK', decimals: 18, }); - changeNetwork({ selectedNetworkClientId: InfuraNetworkType.goerli }); + changeNetwork({ + selectedNetworkClientId: InfuraNetworkType.goerli, + }); await expect(addTokenPromise).rejects.toThrow( 'TokensController Error: Switched networks while adding token', @@ -969,12 +993,17 @@ describe('TokensController', () => { }); it('should add tokens to the correct chainId/selectedAddress on which they were detected even if its not the currently configured chainId/selectedAddress', async () => { + const CONFIGURED_ADDRESS = '0xConfiguredAddress'; + const configuredAccount = createMockInternalAccount({ + address: CONFIGURED_ADDRESS, + }); await withController( - async ({ - controller, - changeNetwork, - triggerPreferencesStateChange, - }) => { + { + mocks: { + getAccount: configuredAccount, + }, + }, + async ({ controller, changeNetwork, triggerSelectedAccountChange }) => { ContractMock.mockReturnValue( buildMockEthersERC721Contract({ supportsInterface: false }), ); @@ -982,14 +1011,11 @@ describe('TokensController', () => { // The currently configured chain + address const CONFIGURED_CHAIN = ChainId.sepolia; const CONFIGURED_NETWORK_CLIENT_ID = InfuraNetworkType.sepolia; - const CONFIGURED_ADDRESS = '0xConfiguredAddress'; + changeNetwork({ selectedNetworkClientId: CONFIGURED_NETWORK_CLIENT_ID, }); - triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - selectedAddress: CONFIGURED_ADDRESS, - }); + triggerSelectedAccountChange(configuredAccount); // A different chain + address const OTHER_CHAIN = '0xOtherChainId'; @@ -1572,7 +1598,6 @@ describe('TokensController', () => { buildMockEthersERC721Contract({ supportsInterface: false }), ); uuidV1Mock.mockReturnValue(requestId); - await controller.watchAsset({ asset, type: 'ERC20' }); expect(controller.state.tokens).toHaveLength(1); @@ -1723,7 +1748,6 @@ describe('TokensController', () => { buildMockEthersERC721Contract({ supportsInterface: false }), ); uuidV1Mock.mockReturnValue(requestId); - await expect( controller.watchAsset({ asset, type: 'ERC20' }), ).rejects.toThrow(errorMessage); @@ -1845,15 +1869,23 @@ describe('TokensController', () => { describe('when PreferencesController:stateChange is published', () => { it('should update tokens list when set address changes', async () => { + const selectedAccount = createMockInternalAccount({ address: '0x1' }); + const selectedAccount2 = createMockInternalAccount({ + address: '0x2', + }); await withController( - async ({ controller, triggerPreferencesStateChange }) => { + { + mocks: { + getAccount: selectedAccount, + getSelectedAccount: selectedAccount, + }, + }, + async ({ controller, triggerSelectedAccountChange }) => { ContractMock.mockReturnValue( buildMockEthersERC721Contract({ supportsInterface: false }), ); - triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - selectedAddress: '0x1', - }); + + triggerSelectedAccountChange(selectedAccount); await controller.addToken({ address: '0x01', symbol: 'A', @@ -1864,10 +1896,7 @@ describe('TokensController', () => { symbol: 'B', decimals: 5, }); - triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - selectedAddress: '0x2', - }); + triggerSelectedAccountChange(selectedAccount2); expect(controller.state.tokens).toStrictEqual([]); await controller.addToken({ @@ -1875,10 +1904,7 @@ describe('TokensController', () => { symbol: 'C', decimals: 6, }); - triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - selectedAddress: '0x1', - }); + triggerSelectedAccountChange(selectedAccount); expect(controller.state.tokens).toStrictEqual([ { address: '0x01', @@ -1902,10 +1928,7 @@ describe('TokensController', () => { }, ]); - triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - selectedAddress: '0x2', - }); + triggerSelectedAccountChange(selectedAccount2); expect(controller.state.tokens).toStrictEqual([ { address: '0x03', @@ -2012,6 +2035,9 @@ describe('TokensController', () => { describe('Clearing nested lists', () => { it('should clear nest allTokens under chain ID and selected address when an added token is ignored', async () => { const selectedAddress = '0x1'; + const selectedAccount = createMockInternalAccount({ + address: selectedAddress, + }); const tokenAddress = '0x01'; const dummyTokens = [ { @@ -2027,7 +2053,9 @@ describe('TokensController', () => { { options: { chainId: ChainId.mainnet, - selectedAddress, + }, + mocks: { + getSelectedAccount: selectedAccount, }, }, async ({ controller }) => { @@ -2043,6 +2071,9 @@ describe('TokensController', () => { it('should clear nest allIgnoredTokens under chain ID and selected address when an ignored token is re-added', async () => { const selectedAddress = '0x1'; + const selectedAccount = createMockInternalAccount({ + address: selectedAddress, + }); const tokenAddress = '0x01'; const dummyTokens = [ { @@ -2058,7 +2089,9 @@ describe('TokensController', () => { { options: { chainId: ChainId.mainnet, - selectedAddress, + }, + mocks: { + getSelectedAccount: selectedAccount, }, }, async ({ controller }) => { @@ -2075,6 +2108,9 @@ describe('TokensController', () => { it('should clear nest allDetectedTokens under chain ID and selected address when an detected token is added to tokens list', async () => { const selectedAddress = '0x1'; + const selectedAccount = createMockInternalAccount({ + address: selectedAddress, + }); const tokenAddress = '0x01'; const dummyTokens = [ { @@ -2090,7 +2126,9 @@ describe('TokensController', () => { { options: { chainId: ChainId.mainnet, - selectedAddress, + }, + mocks: { + getSelectedAccount: selectedAccount, }, }, async ({ controller }) => { @@ -2160,6 +2198,117 @@ describe('TokensController', () => { }); }); }); + + describe('when selectedAccountId is not set or account not found', () => { + describe('detectTokens', () => { + it('updates the token states to empty arrays if the selectedAccountId account is undefined', async () => { + await withController(async ({ controller, changeNetwork }) => { + ContractMock.mockReturnValue( + buildMockEthersERC721Contract({ supportsInterface: false }), + ); + changeNetwork({ selectedNetworkClientId: InfuraNetworkType.sepolia }); + + expect(controller.state.tokens).toStrictEqual([]); + expect(controller.state.ignoredTokens).toStrictEqual([]); + expect(controller.state.detectedTokens).toStrictEqual([]); + }); + }); + }); + + describe('addToken', () => { + it('handles undefined selected account', async () => { + await withController(async ({ controller, getAccountHandler }) => { + getAccountHandler.mockReturnValue(undefined); + const contractAddresses = Object.keys(contractMaps); + const erc721ContractAddresses = contractAddresses.filter( + (contractAddress) => contractMaps[contractAddress].erc721 === true, + ); + const address = erc721ContractAddresses[0]; + const { symbol, decimals } = contractMaps[address]; + + await controller.addToken({ address, symbol, decimals }); + + expect(controller.state.tokens).toStrictEqual([ + { + address, + aggregators: [], + decimals, + image: + 'https://static.cx.metamask.io/api/v1/tokenIcons/1/0x9c8ff314c9bc7f6e59a9d9225fb22946427edc03.png', + isERC721: true, + name: undefined, + symbol, + }, + ]); + }); + }); + }); + + describe('addDetectedTokens', () => { + it('handles an undefined selected account', async () => { + await withController(async ({ controller, getAccountHandler }) => { + getAccountHandler.mockReturnValue(undefined); + const mockToken = { + address: '0x01', + symbol: 'barA', + decimals: 2, + aggregators: [], + }; + await controller.addDetectedTokens([mockToken]); + expect(controller.state.detectedTokens[0]).toStrictEqual({ + ...mockToken, + image: undefined, + isERC721: undefined, + name: undefined, + }); + }); + }); + }); + + describe('watchAsset', () => { + it('handles undefined selected account', async () => { + await withController( + async ({ controller, approvalController, getAccountHandler }) => { + const requestId = '12345'; + const addAndShowApprovalRequestSpy = jest + .spyOn(approvalController, 'addAndShowApprovalRequest') + .mockResolvedValue(undefined); + const asset = buildToken(); + ContractMock.mockReturnValue( + buildMockEthersERC721Contract({ supportsInterface: false }), + ); + uuidV1Mock.mockReturnValue(requestId); + getAccountHandler.mockReturnValue(undefined); + await controller.watchAsset({ asset, type: 'ERC20' }); + + expect(controller.state.tokens).toHaveLength(1); + expect(controller.state.tokens).toStrictEqual([ + { + address: '0x000000000000000000000000000000000000dEaD', + aggregators: [], + decimals: 12, + image: 'image', + isERC721: false, + name: undefined, + symbol: 'TOKEN', + }, + ]); + expect(addAndShowApprovalRequestSpy).toHaveBeenCalledTimes(1); + expect(addAndShowApprovalRequestSpy).toHaveBeenCalledWith({ + id: requestId, + origin: ORIGIN_METAMASK, + type: ApprovalType.WatchAsset, + requestData: { + id: requestId, + interactingAddress: '', // this is the default value if account is not found + asset, + }, + }); + }, + ); + }); + }); + }); }); type WithControllerCallback = ({ @@ -2167,7 +2316,7 @@ type WithControllerCallback = ({ changeNetwork, messenger, approvalController, - triggerPreferencesStateChange, + triggerSelectedAccountChange, }: { controller: TokensController; changeNetwork: (networkControllerState: { @@ -2175,9 +2324,16 @@ type WithControllerCallback = ({ }) => void; messenger: UnrestrictedMessenger; approvalController: ApprovalController; - triggerPreferencesStateChange: (state: PreferencesState) => void; + triggerSelectedAccountChange: (internalAccount: InternalAccount) => void; + getAccountHandler: jest.Mock; + getSelectedAccountHandler: jest.Mock; }) => Promise | ReturnValue; +type WithControllerMockArgs = { + getAccount?: InternalAccount; + getSelectedAccount?: InternalAccount; +}; + type WithControllerArgs = | [WithControllerCallback] | [ @@ -2187,6 +2343,7 @@ type WithControllerArgs = NetworkClientId, NetworkClientConfiguration >; + mocks?: WithControllerMockArgs; }, WithControllerCallback, ]; @@ -2201,17 +2358,22 @@ type WithControllerArgs = * @param args.mockNetworkClientConfigurationsByNetworkClientId - Used to construct * mock versions of network clients and ultimately mock the * `NetworkController:getNetworkClientById` action. + * @param args.mocks - Move values for actions to be mocked. * @returns A collection of test controllers and mocks. */ async function withController( ...args: WithControllerArgs ): Promise { const [ - { options = {}, mockNetworkClientConfigurationsByNetworkClientId = {} }, + { + options = {}, + mockNetworkClientConfigurationsByNetworkClientId = {}, + mocks = {} as WithControllerMockArgs, + }, fn, ] = args.length === 2 ? args : [{}, args[0]]; - const messenger: UnrestrictedMessenger = new ControllerMessenger(); + const messenger = new ControllerMessenger(); const approvalControllerMessenger = messenger.getRestricted({ name: 'ApprovalController', @@ -2229,16 +2391,34 @@ async function withController( allowedActions: [ 'ApprovalController:addRequest', 'NetworkController:getNetworkClientById', + 'AccountsController:getAccount', + 'AccountsController:getSelectedAccount', ], allowedEvents: [ 'NetworkController:networkDidChange', - 'PreferencesController:stateChange', + 'AccountsController:selectedEvmAccountChange', 'TokenListController:stateChange', ], }); + + const getAccountHandler = jest.fn(); + messenger.registerActionHandler( + 'AccountsController:getAccount', + getAccountHandler.mockReturnValue( + mocks?.getAccount ?? defaultMockInternalAccount, + ), + ); + + const getSelectedAccountHandler = jest.fn(); + messenger.registerActionHandler( + 'AccountsController:getSelectedAccount', + getSelectedAccountHandler.mockReturnValue( + mocks?.getSelectedAccount ?? defaultMockInternalAccount, + ), + ); + const controller = new TokensController({ chainId: ChainId.mainnet, - selectedAddress: '0x1', // The tests assume that this is set, but they shouldn't make that // assumption. But we have to do this due to a bug in TokensController // where the provider can possibly be `undefined` if `networkClientId` is @@ -2248,8 +2428,12 @@ async function withController( ...options, }); - const triggerPreferencesStateChange = (state: PreferencesState) => { - messenger.publish('PreferencesController:stateChange', state, []); + const triggerSelectedAccountChange = (internalAccount: InternalAccount) => { + getAccountHandler.mockReturnValue(internalAccount); + messenger.publish( + 'AccountsController:selectedEvmAccountChange', + internalAccount, + ); }; const changeNetwork = ({ @@ -2276,7 +2460,9 @@ async function withController( changeNetwork, messenger, approvalController, - triggerPreferencesStateChange, + triggerSelectedAccountChange, + getAccountHandler, + getSelectedAccountHandler, }); } diff --git a/packages/assets-controllers/src/TokensController.ts b/packages/assets-controllers/src/TokensController.ts index b90df2f781..9345f8b334 100644 --- a/packages/assets-controllers/src/TokensController.ts +++ b/packages/assets-controllers/src/TokensController.ts @@ -1,5 +1,10 @@ import { Contract } from '@ethersproject/contracts'; import { Web3Provider } from '@ethersproject/providers'; +import type { + AccountsControllerGetAccountAction, + AccountsControllerGetSelectedAccountAction, + AccountsControllerSelectedEvmAccountChangeEvent, +} from '@metamask/accounts-controller'; import type { AddApprovalRequest } from '@metamask/approval-controller'; import type { RestrictedControllerMessenger, @@ -19,6 +24,7 @@ import { isValidHexAddress, safelyExecute, } from '@metamask/controller-utils'; +import type { InternalAccount } from '@metamask/keyring-api'; import { abiERC721 } from '@metamask/metamask-eth-abis'; import type { NetworkClientId, @@ -27,10 +33,6 @@ import type { NetworkState, Provider, } from '@metamask/network-controller'; -import type { - PreferencesControllerStateChangeEvent, - PreferencesState, -} from '@metamask/preferences-controller'; import { rpcErrors } from '@metamask/rpc-errors'; import type { Hex } from '@metamask/utils'; import { Mutex } from 'async-mutex'; @@ -136,7 +138,9 @@ export type TokensControllerAddDetectedTokensAction = { */ export type AllowedActions = | AddApprovalRequest - | NetworkControllerGetNetworkClientByIdAction; + | NetworkControllerGetNetworkClientByIdAction + | AccountsControllerGetAccountAction + | AccountsControllerGetSelectedAccountAction; export type TokensControllerStateChangeEvent = ControllerStateChangeEvent< typeof controllerName, @@ -147,8 +151,8 @@ export type TokensControllerEvents = TokensControllerStateChangeEvent; export type AllowedEvents = | NetworkControllerNetworkDidChangeEvent - | PreferencesControllerStateChangeEvent - | TokenListStateChange; + | TokenListStateChange + | AccountsControllerSelectedEvmAccountChangeEvent; /** * The messenger of the {@link TokensController}. @@ -184,7 +188,7 @@ export class TokensController extends BaseController< #chainId: Hex; - #selectedAddress: string; + #selectedAccountId: string; #provider: Provider | undefined; @@ -194,20 +198,17 @@ export class TokensController extends BaseController< * Tokens controller options * @param options - Constructor options. * @param options.chainId - The chain ID of the current network. - * @param options.selectedAddress - Vault selected address * @param options.provider - Network provider. * @param options.state - Initial state to set on this controller. * @param options.messenger - The controller messenger. */ constructor({ chainId: initialChainId, - selectedAddress, provider, state, messenger, }: { chainId: Hex; - selectedAddress: string; provider: Provider | undefined; state?: Partial; messenger: TokensControllerMessenger; @@ -226,7 +227,7 @@ export class TokensController extends BaseController< this.#provider = provider; - this.#selectedAddress = selectedAddress; + this.#selectedAccountId = this.#getSelectedAccount().id; this.#abortController = new AbortController(); @@ -236,8 +237,8 @@ export class TokensController extends BaseController< ); this.messagingSystem.subscribe( - 'PreferencesController:stateChange', - this.#onPreferenceControllerStateChange.bind(this), + 'AccountsController:selectedEvmAccountChange', + this.#onSelectedAccountChange.bind(this), ); this.messagingSystem.subscribe( @@ -273,29 +274,28 @@ export class TokensController extends BaseController< this.#abortController.abort(); this.#abortController = new AbortController(); this.#chainId = chainId; + const selectedAddress = this.#getSelectedAddress(); this.update((state) => { - state.tokens = allTokens[chainId]?.[this.#selectedAddress] || []; - state.ignoredTokens = - allIgnoredTokens[chainId]?.[this.#selectedAddress] || []; + state.tokens = allTokens[chainId]?.[selectedAddress] || []; + state.ignoredTokens = allIgnoredTokens[chainId]?.[selectedAddress] || []; state.detectedTokens = - allDetectedTokens[chainId]?.[this.#selectedAddress] || []; + allDetectedTokens[chainId]?.[selectedAddress] || []; }); } /** - * Handles the state change of the preference controller. - * @param preferencesState - The new state of the preference controller. - * @param preferencesState.selectedAddress - The current selected address of the preference controller. + * Handles the selected account change in the accounts controller. + * @param selectedAccount - The new selected account */ - #onPreferenceControllerStateChange({ selectedAddress }: PreferencesState) { + #onSelectedAccountChange(selectedAccount: InternalAccount) { const { allTokens, allIgnoredTokens, allDetectedTokens } = this.state; - this.#selectedAddress = selectedAddress; + this.#selectedAccountId = selectedAccount.id; this.update((state) => { - state.tokens = allTokens[this.#chainId]?.[selectedAddress] ?? []; + state.tokens = allTokens[this.#chainId]?.[selectedAccount.address] ?? []; state.ignoredTokens = - allIgnoredTokens[this.#chainId]?.[selectedAddress] ?? []; + allIgnoredTokens[this.#chainId]?.[selectedAccount.address] ?? []; state.detectedTokens = - allDetectedTokens[this.#chainId]?.[selectedAddress] ?? []; + allDetectedTokens[this.#chainId]?.[selectedAccount.address] ?? []; }); } @@ -357,7 +357,6 @@ export class TokensController extends BaseController< networkClientId?: NetworkClientId; }): Promise { const chainId = this.#chainId; - const selectedAddress = this.#selectedAddress; const releaseLock = await this.#mutex.acquire(); const { allTokens, allIgnoredTokens, allDetectedTokens } = this.state; let currentChainId = chainId; @@ -368,9 +367,10 @@ export class TokensController extends BaseController< ).configuration.chainId; } - const accountAddress = interactingAddress || selectedAddress; - const isInteractingWithWalletAccount = accountAddress === selectedAddress; - + const accountAddress = + this.#getAddressOrSelectedAddress(interactingAddress); + const isInteractingWithWalletAccount = + this.#isInteractingWithWallet(accountAddress); try { address = toChecksumHexAddress(address); const tokens = allTokens[currentChainId]?.[accountAddress] || []; @@ -578,10 +578,10 @@ export class TokensController extends BaseController< ) { const releaseLock = await this.#mutex.acquire(); - // Get existing tokens for the chain + account const chainId = detectionDetails?.chainId ?? this.#chainId; + // Previously selectedAddress could be an empty string. This is to preserve the behaviour const accountAddress = - detectionDetails?.selectedAddress ?? this.#selectedAddress; + detectionDetails?.selectedAddress ?? this.#getSelectedAddress(); const { allTokens, allDetectedTokens, allIgnoredTokens } = this.state; let newTokens = [...(allTokens?.[chainId]?.[accountAddress] ?? [])]; @@ -648,9 +648,11 @@ export class TokensController extends BaseController< // We may be detecting tokens on a different chain/account pair than are currently configured. // Re-point `tokens` and `detectedTokens` to keep them referencing the current chain/account. - newTokens = newAllTokens?.[this.#chainId]?.[this.#selectedAddress] || []; + const selectedAddress = this.#getSelectedAddress(); + + newTokens = newAllTokens?.[this.#chainId]?.[selectedAddress] || []; newDetectedTokens = - newAllDetectedTokens?.[this.#chainId]?.[this.#selectedAddress] || []; + newAllDetectedTokens?.[this.#chainId]?.[selectedAddress] || []; this.update((state) => { state.tokens = newTokens; @@ -806,6 +808,9 @@ export class TokensController extends BaseController< throw rpcErrors.invalidParams(`Invalid address "${asset.address}"`); } + const selectedAddress = + this.#getAddressOrSelectedAddress(interactingAddress); + // Validate contract if (await this.#detectIsERC721(asset.address, networkClientId)) { @@ -906,7 +911,7 @@ export class TokensController extends BaseController< id: this.#generateRandomId(), time: Date.now(), type, - interactingAddress: interactingAddress || this.#selectedAddress, + interactingAddress: selectedAddress, }; await this.#requestApproval(suggestedAssetMeta); @@ -951,7 +956,9 @@ export class TokensController extends BaseController< } = params; const { allTokens, allIgnoredTokens, allDetectedTokens } = this.state; - const userAddressToAddTokens = interactingAddress ?? this.#selectedAddress; + const userAddressToAddTokens = + this.#getAddressOrSelectedAddress(interactingAddress); + const chainIdToAddTokens = interactingChainId ?? this.#chainId; let newAllTokens = allTokens; @@ -1013,6 +1020,20 @@ export class TokensController extends BaseController< return { newAllTokens, newAllIgnoredTokens, newAllDetectedTokens }; } + #getAddressOrSelectedAddress(address: string | undefined): string { + if (address) { + return address; + } + + return this.#getSelectedAddress(); + } + + #isInteractingWithWallet(address: string | undefined) { + const selectedAddress = this.#getSelectedAddress(); + + return selectedAddress === address; + } + /** * Removes all tokens from the ignored list. */ @@ -1044,6 +1065,19 @@ export class TokensController extends BaseController< true, ); } + + #getSelectedAccount() { + return this.messagingSystem.call('AccountsController:getSelectedAccount'); + } + + #getSelectedAddress() { + // If the address is not defined (or empty), we fallback to the currently selected account's address + const account = this.messagingSystem.call( + 'AccountsController:getAccount', + this.#selectedAccountId, + ); + return account?.address || ''; + } } export default TokensController;