From 1e319a0f0632ee0b346b545ede90bea928d455f2 Mon Sep 17 00:00:00 2001 From: Monte Lai Date: Wed, 5 Jun 2024 20:51:15 +0800 Subject: [PATCH 01/22] refactor: selectedAddress to selectedAccountId --- .../src/TokenBalancesController.test.ts | 81 ++- .../src/TokenBalancesController.ts | 14 +- .../src/TokenDetectionController.test.ts | 420 +++++++++---- .../src/TokenDetectionController.ts | 67 ++- .../src/TokenRatesController.test.ts | 555 ++++++++++++------ .../src/TokenRatesController.ts | 39 +- .../src/TokensController.test.ts | 246 +++++--- .../src/TokensController.ts | 105 ++-- 8 files changed, 1000 insertions(+), 527 deletions(-) diff --git a/packages/assets-controllers/src/TokenBalancesController.test.ts b/packages/assets-controllers/src/TokenBalancesController.test.ts index 1d722b421c..5ac19788a4 100644 --- a/packages/assets-controllers/src/TokenBalancesController.test.ts +++ b/packages/assets-controllers/src/TokenBalancesController.test.ts @@ -3,6 +3,7 @@ import { toHex } from '@metamask/controller-utils'; import BN from 'bn.js'; import { flushPromises } from '../../../tests/helpers'; +import { createMockInternalAccount } from '../../accounts-controller/src/tests/mocks'; import type { AllowedActions, AllowedEvents, @@ -31,7 +32,7 @@ function getMessenger( ): TokenBalancesControllerMessenger { return controllerMessenger.getRestricted({ name: controllerName, - allowedActions: ['PreferencesController:getState'], + allowedActions: ['AccountsController:getSelectedAccount'], allowedEvents: ['TokensController:stateChange'], }); } @@ -52,8 +53,10 @@ describe('TokenBalancesController', () => { it('should set default state', () => { controllerMessenger.registerActionHandler( - 'PreferencesController:getState', - jest.fn().mockReturnValue({ selectedAddress: '0x1234' }), + 'AccountsController:getSelectedAccount', + jest + .fn() + .mockReturnValue(createMockInternalAccount({ address: '0x1234' })), ); const controller = new TokenBalancesController({ getERC20BalanceOf: jest.fn(), @@ -65,8 +68,10 @@ describe('TokenBalancesController', () => { it('should poll and update balances in the right interval', async () => { controllerMessenger.registerActionHandler( - 'PreferencesController:getState', - jest.fn().mockReturnValue({ selectedAddress: '0x1234' }), + 'AccountsController:getSelectedAccount', + jest + .fn() + .mockReturnValue(createMockInternalAccount({ address: '0x1234' })), ); const updateBalancesSpy = jest.spyOn( TokenBalancesController.prototype, @@ -91,8 +96,10 @@ describe('TokenBalancesController', () => { it('should update balances if enabled', async () => { const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; controllerMessenger.registerActionHandler( - 'PreferencesController:getState', - jest.fn().mockReturnValue({ selectedAddress: '0x1234' }), + 'AccountsController:getSelectedAccount', + jest + .fn() + .mockReturnValue(createMockInternalAccount({ address: '0x1234' })), ); const controller = new TokenBalancesController({ disabled: false, @@ -112,8 +119,10 @@ describe('TokenBalancesController', () => { it('should not update balances if disabled', async () => { const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; controllerMessenger.registerActionHandler( - 'PreferencesController:getState', - jest.fn().mockReturnValue({ selectedAddress: '0x1234' }), + 'AccountsController:getSelectedAccount', + jest + .fn() + .mockReturnValue(createMockInternalAccount({ address: '0x1234' })), ); const controller = new TokenBalancesController({ disabled: true, @@ -131,8 +140,10 @@ describe('TokenBalancesController', () => { it('should update balances if controller is manually enabled', async () => { const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; controllerMessenger.registerActionHandler( - 'PreferencesController:getState', - jest.fn().mockReturnValue({ selectedAddress: '0x1234' }), + 'AccountsController:getSelectedAccount', + jest + .fn() + .mockReturnValue(createMockInternalAccount({ address: '0x1234' })), ); const controller = new TokenBalancesController({ disabled: true, @@ -157,8 +168,10 @@ describe('TokenBalancesController', () => { it('should not update balances if controller is manually disabled', async () => { const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; controllerMessenger.registerActionHandler( - 'PreferencesController:getState', - jest.fn().mockReturnValue({ selectedAddress: '0x1234' }), + 'AccountsController:getSelectedAccount', + jest + .fn() + .mockReturnValue(createMockInternalAccount({ address: '0x1234' })), ); const controller = new TokenBalancesController({ disabled: false, @@ -185,8 +198,10 @@ describe('TokenBalancesController', () => { it('should update balances if tokens change and controller is manually enabled', async () => { const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; controllerMessenger.registerActionHandler( - 'PreferencesController:getState', - jest.fn().mockReturnValue({ selectedAddress: '0x1234' }), + 'AccountsController:getSelectedAccount', + jest + .fn() + .mockReturnValue(createMockInternalAccount({ address: '0x1234' })), ); const controller = new TokenBalancesController({ disabled: true, @@ -223,8 +238,10 @@ describe('TokenBalancesController', () => { it('should not update balances if tokens change and controller is manually disabled', async () => { const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; controllerMessenger.registerActionHandler( - 'PreferencesController:getState', - jest.fn().mockReturnValue({ selectedAddress: '0x1234' }), + 'AccountsController:getSelectedAccount', + jest + .fn() + .mockReturnValue(createMockInternalAccount({ address: '0x1234' })), ); const controller = new TokenBalancesController({ disabled: false, @@ -262,8 +279,10 @@ describe('TokenBalancesController', () => { it('should clear previous interval', async () => { controllerMessenger.registerActionHandler( - 'PreferencesController:getState', - jest.fn().mockReturnValue({ selectedAddress: '0x1234' }), + 'AccountsController:getSelectedAccount', + jest + .fn() + .mockReturnValue(createMockInternalAccount({ address: '0x1234' })), ); const controller = new TokenBalancesController({ interval: 1337, @@ -292,8 +311,12 @@ describe('TokenBalancesController', () => { }, ]; controllerMessenger.registerActionHandler( - 'PreferencesController:getState', - jest.fn().mockReturnValue({ selectedAddress }), + 'AccountsController:getSelectedAccount', + jest + .fn() + .mockReturnValue( + createMockInternalAccount({ address: selectedAddress }), + ), ); const controller = new TokenBalancesController({ interval: 1337, @@ -327,8 +350,8 @@ describe('TokenBalancesController', () => { ]; controllerMessenger.registerActionHandler( - 'PreferencesController:getState', - jest.fn().mockReturnValue({}), + 'AccountsController:getSelectedAccount', + jest.fn().mockReturnValue(createMockInternalAccount({ address })), ); const controller = new TokenBalancesController({ interval: 1337, @@ -355,8 +378,10 @@ describe('TokenBalancesController', () => { it('should update balances when tokens change', async () => { controllerMessenger.registerActionHandler( - 'PreferencesController:getState', - jest.fn().mockReturnValue({ selectedAddress: '0x1234' }), + 'AccountsController:getSelectedAccount', + jest + .fn() + .mockReturnValue(createMockInternalAccount({ address: '0x1234' })), ); const controller = new TokenBalancesController({ getERC20BalanceOf: jest.fn(), @@ -384,8 +409,10 @@ describe('TokenBalancesController', () => { it('should update token balances when detected tokens are added', async () => { controllerMessenger.registerActionHandler( - 'PreferencesController:getState', - jest.fn().mockReturnValue({ selectedAddress: '0x1234' }), + 'AccountsController:getSelectedAccount', + jest + .fn() + .mockReturnValue(createMockInternalAccount({ address: '0x1234' })), ); const controller = new TokenBalancesController({ interval: 1337, diff --git a/packages/assets-controllers/src/TokenBalancesController.ts b/packages/assets-controllers/src/TokenBalancesController.ts index 280793ef68..1251b4101c 100644 --- a/packages/assets-controllers/src/TokenBalancesController.ts +++ b/packages/assets-controllers/src/TokenBalancesController.ts @@ -1,3 +1,4 @@ +import { type AccountsControllerGetSelectedAccountAction } from '@metamask/accounts-controller'; import { type RestrictedControllerMessenger, type ControllerGetStateAction, @@ -5,7 +6,6 @@ import { BaseController, } from '@metamask/base-controller'; import { safelyExecute, toHex } from '@metamask/controller-utils'; -import type { PreferencesControllerGetStateAction } from '@metamask/preferences-controller'; import type { AssetsContractController } from './AssetsContractController'; import type { Token } from './TokenRatesController'; @@ -56,7 +56,7 @@ export type TokenBalancesControllerGetStateAction = ControllerGetStateAction< export type TokenBalancesControllerActions = TokenBalancesControllerGetStateAction; -export type AllowedActions = PreferencesControllerGetStateAction; +export type AllowedActions = AccountsControllerGetSelectedAccountAction; export type TokenBalancesControllerStateChangeEvent = ControllerStateChangeEvent< @@ -195,16 +195,18 @@ export class TokenBalancesController extends BaseController< if (this.#disabled) { return; } - - const { selectedAddress } = this.messagingSystem.call( - 'PreferencesController:getState', + const selectedInternalAccount = this.messagingSystem.call( + 'AccountsController:getSelectedAccount', ); const newContractBalances: ContractBalances = {}; for (const token of this.#tokens) { const { address } = token; try { - const balance = await this.#getERC20BalanceOf(address, selectedAddress); + const balance = await this.#getERC20BalanceOf( + address, + selectedInternalAccount.address, + ); newContractBalances[address] = toHex(balance); token.hasBalanceError = false; } catch (error) { diff --git a/packages/assets-controllers/src/TokenDetectionController.test.ts b/packages/assets-controllers/src/TokenDetectionController.test.ts index 7b040fefee..a21c05ff46 100644 --- a/packages/assets-controllers/src/TokenDetectionController.test.ts +++ b/packages/assets-controllers/src/TokenDetectionController.test.ts @@ -27,6 +27,7 @@ import nock from 'nock'; import * as sinon from 'sinon'; import { advanceTime } from '../../../tests/helpers'; +import { createMockInternalAccount } from '../../accounts-controller/src/tests/mocks'; import { formatAggregatorNames } from './assetsUtil'; import { TOKEN_END_POINT_API } from './token-service'; import type { @@ -144,6 +145,7 @@ function buildTokenDetectionControllerMessenger( return controllerMessenger.getRestricted({ name: controllerName, allowedActions: [ + 'AccountsController:getAccount', 'AccountsController:getSelectedAccount', 'KeyringController:getState', 'NetworkController:getNetworkClientById', @@ -155,7 +157,7 @@ function buildTokenDetectionControllerMessenger( 'PreferencesController:getState', ], allowedEvents: [ - 'AccountsController:selectedAccountChange', + 'AccountsController:selectedEvmAccountChange', 'KeyringController:lock', 'KeyringController:unlock', 'NetworkController:networkDidChange', @@ -166,6 +168,8 @@ function buildTokenDetectionControllerMessenger( } describe('TokenDetectionController', () => { + const defaultSelectedAccount = createMockInternalAccount(); + beforeEach(async () => { nock(TOKEN_END_POINT_API) .get(getTokensPath(ChainId.mainnet)) @@ -203,6 +207,7 @@ describe('TokenDetectionController', () => { await withController( { isKeyringUnlocked: false, + options: { selectedAccountId: defaultSelectedAccount.id }, }, async ({ controller }) => { const mockTokens = sinon.stub(controller, 'detectTokens'); @@ -221,8 +226,12 @@ describe('TokenDetectionController', () => { await withController( { isKeyringUnlocked: false, + options: { + selectedAccountId: defaultSelectedAccount.id, + }, }, - async ({ controller, triggerKeyringUnlock }) => { + async ({ controller, mockGetAccount, triggerKeyringUnlock }) => { + mockGetAccount(defaultSelectedAccount); const mockTokens = sinon.stub(controller, 'detectTokens'); await controller.start(); @@ -255,16 +264,24 @@ describe('TokenDetectionController', () => { }); it('should poll and detect tokens on interval while on supported networks', async () => { - await withController(async ({ controller }) => { - const mockTokens = sinon.stub(controller, 'detectTokens'); - controller.setIntervalLength(10); + await withController( + { + options: { + selectedAccountId: defaultSelectedAccount.id, + }, + }, + async ({ controller, mockGetAccount }) => { + mockGetAccount(defaultSelectedAccount); + const mockTokens = sinon.stub(controller, 'detectTokens'); + controller.setIntervalLength(10); - await controller.start(); + await controller.start(); - expect(mockTokens.calledOnce).toBe(true); - await advanceTime({ clock, duration: 15 }); - expect(mockTokens.calledTwice).toBe(true); - }); + expect(mockTokens.calledOnce).toBe(true); + await advanceTime({ clock, duration: 15 }); + expect(mockTokens.calledTwice).toBe(true); + }, + ); }); it('should not autodetect while not on supported networks', async () => { @@ -275,9 +292,11 @@ describe('TokenDetectionController', () => { { options: { getBalancesInSingleCall: mockGetBalancesInSingleCall, + selectedAccountId: defaultSelectedAccount.id, }, }, - async ({ controller, mockNetworkState }) => { + async ({ controller, mockGetAccount, mockNetworkState }) => { + mockGetAccount(defaultSelectedAccount); mockNetworkState({ ...defaultNetworkState, selectedNetworkClientId: NetworkType.goerli, @@ -293,15 +312,23 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + selectedAccountId: selectedAccount.id, }, }, - async ({ controller, mockTokenListGetState, callActionSpy }) => { + async ({ + controller, + mockGetAccount, + mockTokenListGetState, + callActionSpy, + }) => { + mockGetAccount(selectedAccount); mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { @@ -329,7 +356,7 @@ describe('TokenDetectionController', () => { [sampleTokenA], { chainId: ChainId.mainnet, - selectedAddress, + selectedAddress: selectedAccount.address, }, ); }, @@ -340,21 +367,25 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + selectedAccountId: selectedAccount.id, }, }, async ({ controller, + mockGetAccount, mockTokenListGetState, mockNetworkState, mockGetNetworkClientById, callActionSpy, }) => { + mockGetAccount(selectedAccount); mockNetworkState({ ...defaultNetworkState, selectedNetworkClientId: 'polygon', @@ -393,7 +424,7 @@ describe('TokenDetectionController', () => { [sampleTokenA], { chainId: '0x89', - selectedAddress, + selectedAddress: selectedAccount.address, }, ); }, @@ -405,17 +436,25 @@ describe('TokenDetectionController', () => { [sampleTokenA.address]: new BN(1), [sampleTokenB.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); const interval = 100; await withController( { options: { getBalancesInSingleCall: mockGetBalancesInSingleCall, interval, - selectedAddress, + selectedAccountId: selectedAccount.id, }, }, - async ({ controller, mockTokenListGetState, callActionSpy }) => { + async ({ + controller, + mockGetAccount, + mockTokenListGetState, + callActionSpy, + }) => { + mockGetAccount(selectedAccount); const tokenListState = { ...getDefaultTokenListState(), tokensChainsCache: { @@ -455,7 +494,7 @@ describe('TokenDetectionController', () => { [sampleTokenA, sampleTokenB], { chainId: ChainId.mainnet, - selectedAddress, + selectedAddress: selectedAccount.address, }, ); }, @@ -466,20 +505,24 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + selectedAccountId: selectedAccount.id, }, }, async ({ controller, + mockGetAccount, mockTokensGetState, mockTokenListGetState, callActionSpy, }) => { + mockGetAccount(selectedAccount); mockTokensGetState({ ...getDefaultTokensState(), ignoredTokens: [sampleTokenA.address], @@ -521,9 +564,16 @@ describe('TokenDetectionController', () => { { options: { getBalancesInSingleCall: mockGetBalancesInSingleCall, + selectedAccountId: defaultSelectedAccount.id, }, }, - async ({ controller, mockTokenListGetState, callActionSpy }) => { + async ({ + controller, + mockGetAccount, + mockTokenListGetState, + callActionSpy, + }) => { + mockGetAccount(defaultSelectedAccount); mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { @@ -569,23 +619,27 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const firstSelectedAddress = - '0x0000000000000000000000000000000000000001'; - const secondSelectedAddress = - '0x0000000000000000000000000000000000000002'; + const firstSelectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); + const secondSelectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000002', + }); await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress: firstSelectedAddress, + selectedAccountId: firstSelectedAccount.id, }, }, async ({ + mockGetAccount, mockTokenListGetState, triggerSelectedAccountChange, callActionSpy, }) => { + mockGetAccount(firstSelectedAccount); mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { @@ -606,9 +660,8 @@ describe('TokenDetectionController', () => { }, }); - triggerSelectedAccountChange({ - address: secondSelectedAddress, - } as InternalAccount); + mockGetAccount(secondSelectedAccount); + triggerSelectedAccountChange(secondSelectedAccount); await advanceTime({ clock, duration: 1 }); expect(callActionSpy).toHaveBeenCalledWith( @@ -616,7 +669,7 @@ describe('TokenDetectionController', () => { [sampleTokenA], { chainId: ChainId.mainnet, - selectedAddress: secondSelectedAddress, + selectedAddress: secondSelectedAccount.address, }, ); }, @@ -627,13 +680,15 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + selectedAccountId: selectedAccount.id, }, }, async ({ @@ -662,7 +717,7 @@ describe('TokenDetectionController', () => { }); triggerSelectedAccountChange({ - address: selectedAddress, + address: selectedAccount.address, } as InternalAccount); await advanceTime({ clock, duration: 1 }); @@ -678,16 +733,18 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const firstSelectedAddress = - '0x0000000000000000000000000000000000000001'; - const secondSelectedAddress = - '0x0000000000000000000000000000000000000002'; + const firstSelectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); + const secondSelectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000002', + }); await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress: firstSelectedAddress, + selectedAccountId: firstSelectedAccount.id, }, isKeyringUnlocked: false, }, @@ -717,7 +774,7 @@ describe('TokenDetectionController', () => { }); triggerSelectedAccountChange({ - address: secondSelectedAddress, + address: secondSelectedAccount.address, } as InternalAccount); await advanceTime({ clock, duration: 1 }); @@ -735,16 +792,18 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const firstSelectedAddress = - '0x0000000000000000000000000000000000000001'; - const secondSelectedAddress = - '0x0000000000000000000000000000000000000002'; + const firstSelectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); + const secondSelectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000002', + }); await withController( { options: { disabled: true, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress: firstSelectedAddress, + selectedAccountId: firstSelectedAccount.id, }, }, async ({ @@ -773,7 +832,7 @@ describe('TokenDetectionController', () => { }); triggerSelectedAccountChange({ - address: secondSelectedAddress, + address: secondSelectedAccount.address, } as InternalAccount); await advanceTime({ clock, duration: 1 }); @@ -801,23 +860,28 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const firstSelectedAddress = - '0x0000000000000000000000000000000000000001'; - const secondSelectedAddress = - '0x0000000000000000000000000000000000000002'; + const firstSelectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); + const secondSelectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000002', + }); await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress: firstSelectedAddress, + selectedAccountId: firstSelectedAccount.id, }, }, async ({ + mockGetAccount, mockTokenListGetState, triggerPreferencesStateChange, + triggerSelectedAccountChange, callActionSpy, }) => { + mockGetAccount(firstSelectedAccount); mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { @@ -840,17 +904,18 @@ describe('TokenDetectionController', () => { triggerPreferencesStateChange({ ...getDefaultPreferencesState(), - selectedAddress: secondSelectedAddress, useTokenDetection: true, }); + mockGetAccount(secondSelectedAccount); + triggerSelectedAccountChange(secondSelectedAccount); await advanceTime({ clock, duration: 1 }); - expect(callActionSpy).toHaveBeenCalledWith( + expect(callActionSpy).toHaveBeenLastCalledWith( 'TokensController:addDetectedTokens', [sampleTokenA], { chainId: ChainId.mainnet, - selectedAddress: secondSelectedAddress, + selectedAddress: secondSelectedAccount.address, }, ); }, @@ -861,20 +926,24 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + selectedAccountId: selectedAccount.id, }, }, async ({ + mockGetAccount, mockTokenListGetState, triggerPreferencesStateChange, callActionSpy, }) => { + mockGetAccount(selectedAccount); mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { @@ -897,14 +966,12 @@ describe('TokenDetectionController', () => { triggerPreferencesStateChange({ ...getDefaultPreferencesState(), - selectedAddress, useTokenDetection: false, }); await advanceTime({ clock, duration: 1 }); triggerPreferencesStateChange({ ...getDefaultPreferencesState(), - selectedAddress, useTokenDetection: true, }); await advanceTime({ clock, duration: 1 }); @@ -914,7 +981,7 @@ describe('TokenDetectionController', () => { [sampleTokenA], { chainId: ChainId.mainnet, - selectedAddress, + selectedAddress: selectedAccount.address, }, ); }, @@ -925,23 +992,28 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const firstSelectedAddress = - '0x0000000000000000000000000000000000000001'; - const secondSelectedAddress = - '0x0000000000000000000000000000000000000002'; + const firstSelectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); + const secondSelectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000002', + }); await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress: firstSelectedAddress, + selectedAccountId: firstSelectedAccount.id, }, }, async ({ + mockGetAccount, mockTokenListGetState, + triggerSelectedAccountChange, triggerPreferencesStateChange, callActionSpy, }) => { + mockGetAccount(firstSelectedAccount); mockTokenListGetState({ ...getDefaultTokenListState(), tokenList: { @@ -959,9 +1031,10 @@ describe('TokenDetectionController', () => { triggerPreferencesStateChange({ ...getDefaultPreferencesState(), - selectedAddress: secondSelectedAddress, useTokenDetection: false, }); + mockGetAccount(secondSelectedAccount); + triggerSelectedAccountChange(secondSelectedAccount); await advanceTime({ clock, duration: 1 }); expect(callActionSpy).not.toHaveBeenCalledWith( @@ -975,20 +1048,24 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + selectedAccountId: selectedAccount.id, }, }, async ({ + mockGetAccount, mockTokenListGetState, triggerPreferencesStateChange, callActionSpy, }) => { + mockGetAccount(selectedAccount); mockTokenListGetState({ ...getDefaultTokenListState(), tokenList: { @@ -1006,7 +1083,6 @@ describe('TokenDetectionController', () => { triggerPreferencesStateChange({ ...getDefaultPreferencesState(), - selectedAddress, useTokenDetection: true, }); await advanceTime({ clock, duration: 1 }); @@ -1023,24 +1099,29 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const firstSelectedAddress = - '0x0000000000000000000000000000000000000001'; - const secondSelectedAddress = - '0x0000000000000000000000000000000000000002'; + const firstSelectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); + const secondSelectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000002', + }); await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress: firstSelectedAddress, + selectedAccountId: firstSelectedAccount.id, }, isKeyringUnlocked: false, }, async ({ + mockGetAccount, mockTokenListGetState, triggerPreferencesStateChange, + triggerSelectedAccountChange, callActionSpy, }) => { + mockGetAccount(firstSelectedAccount); mockTokenListGetState({ ...getDefaultTokenListState(), tokenList: { @@ -1058,9 +1139,10 @@ describe('TokenDetectionController', () => { triggerPreferencesStateChange({ ...getDefaultPreferencesState(), - selectedAddress: secondSelectedAddress, useTokenDetection: true, }); + mockGetAccount(secondSelectedAccount); + triggerSelectedAccountChange(secondSelectedAccount); await advanceTime({ clock, duration: 1 }); expect(callActionSpy).not.toHaveBeenCalledWith( @@ -1074,21 +1156,25 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + selectedAccountId: selectedAccount.id, }, isKeyringUnlocked: false, }, async ({ + mockGetAccount, mockTokenListGetState, triggerPreferencesStateChange, callActionSpy, }) => { + mockGetAccount(selectedAccount); mockTokenListGetState({ ...getDefaultTokenListState(), tokenList: { @@ -1106,14 +1192,12 @@ describe('TokenDetectionController', () => { triggerPreferencesStateChange({ ...getDefaultPreferencesState(), - selectedAddress, useTokenDetection: false, }); await advanceTime({ clock, duration: 1 }); triggerPreferencesStateChange({ ...getDefaultPreferencesState(), - selectedAddress, useTokenDetection: true, }); await advanceTime({ clock, duration: 1 }); @@ -1132,23 +1216,28 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const firstSelectedAddress = - '0x0000000000000000000000000000000000000001'; - const secondSelectedAddress = - '0x0000000000000000000000000000000000000002'; + const firstSelectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); + const secondSelectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000002', + }); await withController( { options: { disabled: true, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress: firstSelectedAddress, + selectedAccountId: firstSelectedAccount.id, }, }, async ({ + mockGetAccount, mockTokenListGetState, triggerPreferencesStateChange, + triggerSelectedAccountChange, callActionSpy, }) => { + mockGetAccount(firstSelectedAccount); mockTokenListGetState({ ...getDefaultTokenListState(), tokenList: { @@ -1166,9 +1255,10 @@ describe('TokenDetectionController', () => { triggerPreferencesStateChange({ ...getDefaultPreferencesState(), - selectedAddress: secondSelectedAddress, useTokenDetection: true, }); + mockGetAccount(secondSelectedAccount); + triggerSelectedAccountChange(secondSelectedAccount); await advanceTime({ clock, duration: 1 }); expect(callActionSpy).not.toHaveBeenCalledWith( @@ -1182,20 +1272,24 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { disabled: true, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + selectedAccountId: selectedAccount.id, }, }, async ({ + mockGetAccount, mockTokenListGetState, triggerPreferencesStateChange, callActionSpy, }) => { + mockGetAccount(selectedAccount); mockTokenListGetState({ ...getDefaultTokenListState(), tokenList: { @@ -1213,14 +1307,12 @@ describe('TokenDetectionController', () => { triggerPreferencesStateChange({ ...getDefaultPreferencesState(), - selectedAddress, useTokenDetection: false, }); await advanceTime({ clock, duration: 1 }); triggerPreferencesStateChange({ ...getDefaultPreferencesState(), - selectedAddress, useTokenDetection: true, }); await advanceTime({ clock, duration: 1 }); @@ -1249,20 +1341,24 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + selectedAccountId: selectedAccount.id, }, }, async ({ + mockGetAccount, mockTokenListGetState, callActionSpy, triggerNetworkDidChange, }) => { + mockGetAccount(selectedAccount); mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { @@ -1294,7 +1390,7 @@ describe('TokenDetectionController', () => { [sampleTokenA], { chainId: '0x89', - selectedAddress, + selectedAddress: selectedAccount.address, }, ); }, @@ -1305,20 +1401,24 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + selectedAccountId: selectedAccount.id, }, }, async ({ + mockGetAccount, mockTokenListGetState, callActionSpy, triggerNetworkDidChange, }) => { + mockGetAccount(selectedAccount); mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { @@ -1356,20 +1456,24 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + selectedAccountId: selectedAccount.id, }, }, async ({ + mockGetAccount, mockTokenListGetState, callActionSpy, triggerNetworkDidChange, }) => { + mockGetAccount(selectedAccount); mockTokenListGetState({ ...getDefaultTokenListState(), tokenList: { @@ -1403,21 +1507,25 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + selectedAccountId: selectedAccount.id, }, isKeyringUnlocked: false, }, async ({ + mockGetAccount, mockTokenListGetState, callActionSpy, triggerNetworkDidChange, }) => { + mockGetAccount(selectedAccount); mockTokenListGetState({ ...getDefaultTokenListState(), tokenList: { @@ -1453,20 +1561,24 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { disabled: true, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + selectedAccountId: selectedAccount.id, }, }, async ({ + mockGetAccount, mockTokenListGetState, callActionSpy, triggerNetworkDidChange, }) => { + mockGetAccount(selectedAccount); mockTokenListGetState({ ...getDefaultTokenListState(), tokenList: { @@ -1512,20 +1624,24 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + selectedAccountId: selectedAccount.id, }, }, async ({ + mockGetAccount, mockTokenListGetState, callActionSpy, triggerTokenListStateChange, }) => { + mockGetAccount(selectedAccount); const tokenList = { [sampleTokenA.address]: { name: sampleTokenA.name, @@ -1557,7 +1673,7 @@ describe('TokenDetectionController', () => { [sampleTokenA], { chainId: ChainId.mainnet, - selectedAddress, + selectedAddress: selectedAccount.address, }, ); }, @@ -1568,20 +1684,24 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + selectedAccountId: selectedAccount.id, }, }, async ({ + mockGetAccount, mockTokenListGetState, callActionSpy, triggerTokenListStateChange, }) => { + mockGetAccount(selectedAccount); const tokenListState = { ...getDefaultTokenListState(), tokenList: {}, @@ -1603,21 +1723,25 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + selectedAccountId: selectedAccount.id, }, isKeyringUnlocked: false, }, async ({ + mockGetAccount, mockTokenListGetState, callActionSpy, triggerTokenListStateChange, }) => { + mockGetAccount(selectedAccount); const tokenListState = { ...getDefaultTokenListState(), tokenList: { @@ -1651,20 +1775,24 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { disabled: true, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + selectedAccountId: selectedAccount.id, }, }, async ({ + mockGetAccount, mockTokenListGetState, callActionSpy, triggerTokenListStateChange, }) => { + mockGetAccount(selectedAccount); const tokenListState = { ...getDefaultTokenListState(), tokenList: { @@ -1707,13 +1835,15 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + selectedAccountId: selectedAccount.id, }, }, async ({ controller, mockTokenListGetState }) => { @@ -1773,13 +1903,15 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + selectedAccountId: selectedAccount.id, }, }, async ({ @@ -1787,7 +1919,9 @@ describe('TokenDetectionController', () => { mockNetworkState, triggerPreferencesStateChange, callActionSpy, + mockGetAccount, }) => { + mockGetAccount(selectedAccount); mockNetworkState({ ...defaultNetworkState, selectedNetworkClientId: NetworkType.goerli, @@ -1798,7 +1932,7 @@ describe('TokenDetectionController', () => { }); await controller.detectTokens({ networkClientId: NetworkType.goerli, - selectedAddress, + selectedAddress: selectedAccount.address, }); expect(callActionSpy).not.toHaveBeenCalledWith( 'TokensController:addDetectedTokens', @@ -1817,27 +1951,31 @@ describe('TokenDetectionController', () => { {}, ), ); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + selectedAccountId: selectedAccount.id, }, }, async ({ controller, + mockGetAccount, triggerPreferencesStateChange, callActionSpy, }) => { + mockGetAccount(selectedAccount); triggerPreferencesStateChange({ ...getDefaultPreferencesState(), useTokenDetection: false, }); await controller.detectTokens({ networkClientId: NetworkType.mainnet, - selectedAddress, + selectedAddress: selectedAccount.address, }); expect(callActionSpy).toHaveBeenLastCalledWith( 'TokensController:addDetectedTokens', @@ -1850,7 +1988,7 @@ describe('TokenDetectionController', () => { }; }), { - selectedAddress, + selectedAddress: selectedAccount.address, chainId: ChainId.mainnet, }, ); @@ -1862,16 +2000,24 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + selectedAccountId: selectedAccount.id, }, }, - async ({ controller, mockTokenListGetState, callActionSpy }) => { + async ({ + controller, + mockGetAccount, + mockTokenListGetState, + callActionSpy, + }) => { + mockGetAccount(selectedAccount); mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { @@ -1894,7 +2040,7 @@ describe('TokenDetectionController', () => { await controller.detectTokens({ networkClientId: NetworkType.mainnet, - selectedAddress, + selectedAddress: selectedAccount.address, }); expect(callActionSpy).toHaveBeenCalledWith( @@ -1902,7 +2048,7 @@ describe('TokenDetectionController', () => { [sampleTokenA], { chainId: ChainId.mainnet, - selectedAddress, + selectedAddress: selectedAccount.address, }, ); }, @@ -1913,7 +2059,9 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); const mockTrackMetaMetricsEvent = jest.fn(); await withController( @@ -1922,10 +2070,11 @@ describe('TokenDetectionController', () => { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, trackMetaMetricsEvent: mockTrackMetaMetricsEvent, - selectedAddress, + selectedAccountId: selectedAccount.id, }, }, - async ({ controller, mockTokenListGetState }) => { + async ({ controller, mockGetAccount, mockTokenListGetState }) => { + mockGetAccount(selectedAccount); mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { @@ -1948,7 +2097,7 @@ describe('TokenDetectionController', () => { await controller.detectTokens({ networkClientId: NetworkType.mainnet, - selectedAddress, + selectedAddress: selectedAccount.address, }); expect(mockTrackMetaMetricsEvent).toHaveBeenCalledWith({ @@ -1980,6 +2129,7 @@ function getTokensPath(chainId: Hex) { type WithControllerCallback = ({ controller, + mockGetAccount, mockGetSelectedAccount, mockKeyringGetState, mockTokensGetState, @@ -1997,6 +2147,7 @@ type WithControllerCallback = ({ triggerNetworkDidChange, }: { controller: TokenDetectionController; + mockGetAccount: (internalAccount: InternalAccount) => void; mockGetSelectedAccount: (address: string) => void; mockKeyringGetState: (state: KeyringControllerState) => void; mockTokensGetState: (state: TokensControllerState) => void; @@ -2047,6 +2198,12 @@ async function withController( const controllerMessenger = messenger ?? new ControllerMessenger(); + const mockGetAccount = jest.fn(); + controllerMessenger.registerActionHandler( + 'AccountsController:getAccount', + mockGetAccount, + ); + const mockGetSelectedAccount = jest.fn(); controllerMessenger.registerActionHandler( 'AccountsController:getSelectedAccount', @@ -2130,6 +2287,9 @@ async function withController( try { return await fn({ controller, + mockGetAccount: (internalAccount: InternalAccount) => { + mockGetAccount.mockReturnValue(internalAccount); + }, mockGetSelectedAccount: (address: string) => { mockGetSelectedAccount.mockReturnValue({ address } as InternalAccount); }, @@ -2185,7 +2345,7 @@ async function withController( }, triggerSelectedAccountChange: (account: InternalAccount) => { controllerMessenger.publish( - 'AccountsController:selectedAccountChange', + 'AccountsController:selectedEvmAccountChange', account, ); }, diff --git a/packages/assets-controllers/src/TokenDetectionController.ts b/packages/assets-controllers/src/TokenDetectionController.ts index bbebdca480..a4f914fb86 100644 --- a/packages/assets-controllers/src/TokenDetectionController.ts +++ b/packages/assets-controllers/src/TokenDetectionController.ts @@ -1,6 +1,7 @@ import type { AccountsControllerGetSelectedAccountAction, - AccountsControllerSelectedAccountChangeEvent, + AccountsControllerGetAccountAction, + AccountsControllerSelectedEvmAccountChangeEvent, } from '@metamask/accounts-controller'; import type { RestrictedControllerMessenger, @@ -105,6 +106,7 @@ export type TokenDetectionControllerActions = export type AllowedActions = | AccountsControllerGetSelectedAccountAction + | AccountsControllerGetAccountAction | NetworkControllerGetNetworkClientByIdAction | NetworkControllerGetNetworkConfigurationByNetworkClientId | NetworkControllerGetStateAction @@ -121,7 +123,7 @@ export type TokenDetectionControllerEvents = TokenDetectionControllerStateChangeEvent; export type AllowedEvents = - | AccountsControllerSelectedAccountChangeEvent + | AccountsControllerSelectedEvmAccountChangeEvent | NetworkControllerNetworkDidChangeEvent | TokenListStateChange | KeyringControllerLockEvent @@ -153,7 +155,7 @@ export class TokenDetectionController extends StaticIntervalPollingController< > { #intervalId?: ReturnType; - #selectedAddress: string; + #selectedAccountId: string; #networkClientId: NetworkClientId; @@ -186,19 +188,19 @@ export class TokenDetectionController extends StaticIntervalPollingController< * @param options.messenger - The controller messaging system. * @param options.disabled - If set to true, all network requests are blocked. * @param options.interval - Polling interval used to fetch new token rates - * @param options.selectedAddress - Vault selected address + * @param options.selectedAccountId - Vault selected address * @param options.getBalancesInSingleCall - Gets the balances of a list of tokens for the given address. * @param options.trackMetaMetricsEvent - Sets options for MetaMetrics event tracking. */ constructor({ - selectedAddress, + selectedAccountId, interval = DEFAULT_INTERVAL, disabled = true, getBalancesInSingleCall, trackMetaMetricsEvent, messenger, }: { - selectedAddress?: string; + selectedAccountId?: string; interval?: number; disabled?: boolean; getBalancesInSingleCall: AssetsContractController['getBalancesInSingleCall']; @@ -223,10 +225,9 @@ export class TokenDetectionController extends StaticIntervalPollingController< this.#disabled = disabled; this.setIntervalLength(interval); - this.#selectedAddress = - selectedAddress ?? - this.messagingSystem.call('AccountsController:getSelectedAccount') - .address; + this.#selectedAccountId = + selectedAccountId ?? + this.messagingSystem.call('AccountsController:getSelectedAccount').id; const { chainId, networkClientId } = this.#getCorrectChainIdAndNetworkClientId(); @@ -277,32 +278,32 @@ export class TokenDetectionController extends StaticIntervalPollingController< this.messagingSystem.subscribe( 'PreferencesController:stateChange', - async ({ selectedAddress: newSelectedAddress, useTokenDetection }) => { - const isSelectedAddressChanged = - this.#selectedAddress !== newSelectedAddress; + async ({ useTokenDetection }) => { + const selectedAccount = this.messagingSystem.call( + 'AccountsController:getSelectedAccount', + ); const isDetectionChangedFromPreferences = this.#isDetectionEnabledFromPreferences !== useTokenDetection; - this.#selectedAddress = newSelectedAddress; this.#isDetectionEnabledFromPreferences = useTokenDetection; - if (isSelectedAddressChanged || isDetectionChangedFromPreferences) { + if (isDetectionChangedFromPreferences) { await this.#restartTokenDetection({ - selectedAddress: this.#selectedAddress, + selectedAccountId: selectedAccount.id, }); } }, ); this.messagingSystem.subscribe( - 'AccountsController:selectedAccountChange', - async ({ address: newSelectedAddress }) => { - const isSelectedAddressChanged = - this.#selectedAddress !== newSelectedAddress; - if (isSelectedAddressChanged) { - this.#selectedAddress = newSelectedAddress; + 'AccountsController:selectedEvmAccountChange', + async (internalAccount) => { + const didSelectedAccountIdChanged = + this.#selectedAccountId !== internalAccount.id; + if (didSelectedAccountIdChanged) { + this.#selectedAccountId = internalAccount.id; await this.#restartTokenDetection({ - selectedAddress: this.#selectedAddress, + selectedAccountId: this.#selectedAccountId, }); } }, @@ -436,16 +437,23 @@ export class TokenDetectionController extends StaticIntervalPollingController< * in case of address change or user session initialization. * * @param options - Options for restart token detection. - * @param options.selectedAddress - the selectedAddress against which to detect for token balances + * @param options.selectedAccountId - the id of the InternalAccount against which to detect for token balances * @param options.networkClientId - The ID of the network client to use. */ async #restartTokenDetection({ - selectedAddress, + selectedAccountId, networkClientId, }: { - selectedAddress?: string; + selectedAccountId?: string; networkClientId?: NetworkClientId; } = {}): Promise { + const internalAccount = this.messagingSystem.call( + 'AccountsController:getAccount', + selectedAccountId ?? this.#selectedAccountId, + ); + + const selectedAddress = internalAccount?.address || ''; + await this.detectTokens({ networkClientId, selectedAddress, @@ -472,8 +480,13 @@ export class TokenDetectionController extends StaticIntervalPollingController< return; } + const selectedInternalAccount = this.messagingSystem.call( + 'AccountsController:getAccount', + this.#selectedAccountId, + ); + const addressAgainstWhichToDetect = - selectedAddress ?? this.#selectedAddress; + selectedAddress ?? selectedInternalAccount?.address ?? ''; const { chainId, networkClientId: selectedNetworkClientId } = this.#getCorrectChainIdAndNetworkClientId(networkClientId); const chainIdAgainstWhichToDetect = chainId; diff --git a/packages/assets-controllers/src/TokenRatesController.test.ts b/packages/assets-controllers/src/TokenRatesController.test.ts index 78fcb2a051..5cf53a0bae 100644 --- a/packages/assets-controllers/src/TokenRatesController.test.ts +++ b/packages/assets-controllers/src/TokenRatesController.test.ts @@ -5,13 +5,13 @@ import { toChecksumHexAddress, toHex, } from '@metamask/controller-utils'; +import type { InternalAccount } from '@metamask/keyring-api'; import type { NetworkClientConfiguration, NetworkClientId, NetworkState, } from '@metamask/network-controller'; import { defaultState as defaultNetworkState } from '@metamask/network-controller'; -import type { PreferencesState } from '@metamask/preferences-controller'; import type { Hex } from '@metamask/utils'; import { add0x } from '@metamask/utils'; import assert from 'assert'; @@ -19,6 +19,7 @@ import nock from 'nock'; import { useFakeTimers } from 'sinon'; import { advanceTime, flushPromises } from '../../../tests/helpers'; +import { createMockInternalAccount } from '../../accounts-controller/src/tests/mocks'; import { buildCustomNetworkClientConfiguration, buildMockGetNetworkClientById, @@ -37,7 +38,9 @@ import type { } from './TokenRatesController'; import type { TokensControllerState } from './TokensController'; -const defaultSelectedAddress = '0x0000000000000000000000000000000000000001'; +const defaultMockInternalAccount = createMockInternalAccount({ + address: '0xA', +}); const mockTokenAddress = '0x0000000000000000000000000000000000000010'; describe('TokenRatesController', () => { @@ -59,10 +62,11 @@ describe('TokenRatesController', () => { it('should set default state', () => { const controller = new TokenRatesController({ getNetworkClientById: jest.fn(), + getInternalAccount: jest.fn(), chainId: '0x1', ticker: NetworksTicker.mainnet, - selectedAddress: defaultSelectedAddress, - onPreferencesStateChange: jest.fn(), + selectedAccountId: defaultMockInternalAccount.id, + onSelectedAccountChange: jest.fn(), onTokensStateChange: jest.fn(), onNetworkStateChange: jest.fn(), tokenPricesService: buildMockTokenPricesService(), @@ -75,10 +79,11 @@ describe('TokenRatesController', () => { it('should initialize with the default config', () => { const controller = new TokenRatesController({ getNetworkClientById: jest.fn(), + getInternalAccount: jest.fn(), chainId: '0x1', ticker: NetworksTicker.mainnet, - selectedAddress: defaultSelectedAddress, - onPreferencesStateChange: jest.fn(), + selectedAccountId: defaultMockInternalAccount.id, + onSelectedAccountChange: jest.fn(), onTokensStateChange: jest.fn(), onNetworkStateChange: jest.fn(), tokenPricesService: buildMockTokenPricesService(), @@ -91,7 +96,7 @@ describe('TokenRatesController', () => { disabled: false, nativeCurrency: NetworksTicker.mainnet, chainId: '0x1', - selectedAddress: defaultSelectedAddress, + selectedAccountId: defaultMockInternalAccount.id, }); }); @@ -100,10 +105,11 @@ describe('TokenRatesController', () => { new TokenRatesController({ interval: 100, getNetworkClientById: jest.fn(), + getInternalAccount: jest.fn(), chainId: '0x1', ticker: NetworksTicker.mainnet, - selectedAddress: defaultSelectedAddress, - onPreferencesStateChange: jest.fn(), + selectedAccountId: defaultMockInternalAccount.id, + onSelectedAccountChange: jest.fn(), onTokensStateChange: jest.fn(), onNetworkStateChange: jest.fn(), tokenPricesService: buildMockTokenPricesService(), @@ -129,18 +135,22 @@ describe('TokenRatesController', () => { describe('when legacy polling is active', () => { it('should update exchange rates when any of the addresses in the "all tokens" collection change', async () => { const chainId = '0xC'; - const selectedAddress = '0xA'; + const selectedAccount = defaultMockInternalAccount; const tokenAddresses = ['0xE1', '0xE2']; + const mockGetInternalAccount = jest + .fn() + .mockReturnValue(selectedAccount); await withController( { options: { chainId, - selectedAddress, + selectedAccountId: selectedAccount.id, + getInternalAccount: mockGetInternalAccount, }, config: { allTokens: { [chainId]: { - [selectedAddress]: [ + [selectedAccount.address]: [ { address: tokenAddresses[0], decimals: 0, @@ -163,7 +173,7 @@ describe('TokenRatesController', () => { controllerEvents.tokensStateChange({ allTokens: { [chainId]: { - [selectedAddress]: [ + [selectedAccount.address]: [ { address: tokenAddresses[1], decimals: 0, @@ -184,19 +194,23 @@ describe('TokenRatesController', () => { it('should update exchange rates when any of the addresses in the "all detected tokens" collection change', async () => { const chainId = '0xC'; - const selectedAddress = '0xA'; + const selectedAccount = createMockInternalAccount({ address: '0xA' }); const tokenAddresses = ['0xE1', '0xE2']; + const mockGetInternalAccount = jest + .fn() + .mockReturnValue(selectedAccount); await withController( { options: { chainId, - selectedAddress, + selectedAccountId: selectedAccount.id, + getInternalAccount: mockGetInternalAccount, }, config: { allTokens: {}, allDetectedTokens: { [chainId]: { - [selectedAddress]: [ + [selectedAccount.address]: [ { address: tokenAddresses[0], decimals: 0, @@ -219,7 +233,7 @@ describe('TokenRatesController', () => { allTokens: {}, allDetectedTokens: { [chainId]: { - [selectedAddress]: [ + [selectedAccount.address]: [ { address: tokenAddresses[1], decimals: 0, @@ -239,11 +253,14 @@ describe('TokenRatesController', () => { it('should not update exchange rates if both the "all tokens" or "all detected tokens" are exactly the same', async () => { const chainId = '0xC'; - const selectedAddress = '0xA'; + const selectedAccount = createMockInternalAccount({ address: '0xA' }); + const mockGetInternalAccount = jest + .fn() + .mockReturnValue(selectedAccount); const tokensState = { allTokens: { [chainId]: { - [selectedAddress]: [ + [selectedAccount.address]: [ { address: mockTokenAddress, decimals: 0, @@ -259,7 +276,8 @@ describe('TokenRatesController', () => { { options: { chainId, - selectedAddress, + selectedAccountId: selectedAccount.id, + getInternalAccount: mockGetInternalAccount, }, config: tokensState, }, @@ -280,10 +298,10 @@ describe('TokenRatesController', () => { it('should not update exchange rates if all of the tokens in "all tokens" just move to "all detected tokens"', async () => { const chainId = '0xC'; - const selectedAddress = '0xA'; + const selectedAccount = createMockInternalAccount({ address: '0xA' }); const tokens = { [chainId]: { - [selectedAddress]: [ + [selectedAccount.address]: [ { address: mockTokenAddress, decimals: 0, @@ -297,11 +315,14 @@ describe('TokenRatesController', () => { { options: { chainId, - selectedAddress, + getInternalAccount: jest + .fn() + .mockReturnValue(defaultMockInternalAccount), }, config: { allTokens: tokens, allDetectedTokens: {}, + selectedAccountId: selectedAccount.id, }, }, async ({ controller, controllerEvents }) => { @@ -324,17 +345,21 @@ describe('TokenRatesController', () => { it('should not update exchange rates if a new token is added to "all detected tokens" but is already present in "all tokens"', async () => { const chainId = '0xC'; - const selectedAddress = '0xA'; + const selectedAccount = createMockInternalAccount({ address: '0xA' }); + const mockGetInternalAccount = jest + .fn() + .mockReturnValue(selectedAccount); await withController( { options: { chainId, - selectedAddress, + selectedAccountId: selectedAccount.id, + getInternalAccount: mockGetInternalAccount, }, config: { allTokens: { [chainId]: { - [selectedAddress]: [ + [selectedAccount.address]: [ { address: mockTokenAddress, decimals: 0, @@ -357,7 +382,7 @@ describe('TokenRatesController', () => { controllerEvents.tokensStateChange({ allTokens: { [chainId]: { - [selectedAddress]: [ + [selectedAccount.address]: [ { address: mockTokenAddress, decimals: 0, @@ -369,7 +394,7 @@ describe('TokenRatesController', () => { }, allDetectedTokens: { [chainId]: { - [selectedAddress]: [ + [selectedAccount.address]: [ { address: mockTokenAddress, decimals: 0, @@ -389,18 +414,22 @@ describe('TokenRatesController', () => { it('should not update exchange rates if a new token is added to "all tokens" but is already present in "all detected tokens"', async () => { const chainId = '0xC'; - const selectedAddress = '0xA'; + const selectedAccount = createMockInternalAccount({ address: '0xA' }); + const mockGetInternalAccount = jest + .fn() + .mockReturnValue(selectedAccount); await withController( { options: { chainId, - selectedAddress, + selectedAccountId: selectedAccount.id, + getInternalAccount: mockGetInternalAccount, }, config: { allTokens: {}, allDetectedTokens: { [chainId]: { - [selectedAddress]: [ + [selectedAccount.address]: [ { address: mockTokenAddress, decimals: 0, @@ -422,7 +451,7 @@ describe('TokenRatesController', () => { controllerEvents.tokensStateChange({ allTokens: { [chainId]: { - [selectedAddress]: [ + [selectedAccount.address]: [ { address: mockTokenAddress, decimals: 0, @@ -434,7 +463,7 @@ describe('TokenRatesController', () => { }, allDetectedTokens: { [chainId]: { - [selectedAddress]: [ + [selectedAccount.address]: [ { address: mockTokenAddress, decimals: 0, @@ -454,18 +483,22 @@ describe('TokenRatesController', () => { it('should not update exchange rates if none of the addresses in "all tokens" or "all detected tokens" change, even if other parts of the token change', async () => { const chainId = '0xC'; - const selectedAddress = '0xA'; + const selectedAccount = createMockInternalAccount({ address: '0xA' }); + const mockGetInternalAccount = jest + .fn() + .mockReturnValue(selectedAccount); await withController( { options: { chainId, - selectedAddress, + selectedAccountId: selectedAccount.id, + getInternalAccount: mockGetInternalAccount, }, config: { allTokens: {}, allDetectedTokens: { [chainId]: { - [selectedAddress]: [ + [selectedAccount.address]: [ { address: mockTokenAddress, decimals: 3, @@ -488,7 +521,7 @@ describe('TokenRatesController', () => { allTokens: {}, allDetectedTokens: { [chainId]: { - [selectedAddress]: [ + [selectedAccount.address]: [ { address: mockTokenAddress, decimals: 7, @@ -508,18 +541,24 @@ describe('TokenRatesController', () => { it('should not update exchange rates if none of the addresses in "all tokens" or "all detected tokens" change, when normalized to checksum addresses', async () => { const chainId = '0xC'; - const selectedAddress = '0xAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA'; + const selectedAccount = createMockInternalAccount({ + address: '0xAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA', + }); + const mockGetInternalAccount = jest + .fn() + .mockReturnValue(selectedAccount); await withController( { options: { chainId, - selectedAddress, + selectedAccountId: selectedAccount.id, + getInternalAccount: mockGetInternalAccount, }, config: { allTokens: {}, allDetectedTokens: { [chainId]: { - [selectedAddress]: [ + [selectedAccount.address]: [ { address: '0x0EEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEE2', decimals: 3, @@ -542,7 +581,7 @@ describe('TokenRatesController', () => { allTokens: {}, allDetectedTokens: { [chainId]: { - [selectedAddress]: [ + [selectedAccount.address]: [ { address: '0x0eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee2', decimals: 7, @@ -562,18 +601,22 @@ describe('TokenRatesController', () => { it('should not update exchange rates if any of the addresses in "all tokens" or "all detected tokens" merely change order', async () => { const chainId = '0xC'; - const selectedAddress = '0xA'; + const selectedAccount = createMockInternalAccount({ address: '0xA' }); + const mockGetInternalAccount = jest + .fn() + .mockReturnValue(selectedAccount); await withController( { options: { chainId, - selectedAddress, + selectedAccountId: selectedAccount.id, + getInternalAccount: mockGetInternalAccount, }, config: { allTokens: {}, allDetectedTokens: { [chainId]: { - [selectedAddress]: [ + [selectedAccount.address]: [ { address: '0xE1', decimals: 0, @@ -602,7 +645,7 @@ describe('TokenRatesController', () => { allTokens: {}, allDetectedTokens: { [chainId]: { - [selectedAddress]: [ + [selectedAccount.address]: [ { address: '0xE2', decimals: 0, @@ -630,18 +673,22 @@ describe('TokenRatesController', () => { describe('when legacy polling is inactive', () => { it('should not update exchange rates when any of the addresses in the "all tokens" collection change', async () => { const chainId = '0xC'; - const selectedAddress = '0xA'; + const selectedAccount = createMockInternalAccount({ address: '0xA' }); + const mockGetInternalAccount = jest + .fn() + .mockReturnValue(selectedAccount); const tokenAddresses = ['0xE1', '0xE2']; await withController( { options: { chainId, - selectedAddress, + selectedAccountId: selectedAccount.id, + getInternalAccount: mockGetInternalAccount, }, config: { allTokens: { [chainId]: { - [selectedAddress]: [ + [selectedAccount.address]: [ { address: tokenAddresses[0], decimals: 0, @@ -663,7 +710,7 @@ describe('TokenRatesController', () => { controllerEvents.tokensStateChange({ allTokens: { [chainId]: { - [selectedAddress]: [ + [selectedAccount.address]: [ { address: tokenAddresses[1], decimals: 0, @@ -683,19 +730,23 @@ describe('TokenRatesController', () => { it('should not update exchange rates when any of the addresses in the "all detected tokens" collection change', async () => { const chainId = '0xC'; - const selectedAddress = '0xA'; + const selectedAccount = createMockInternalAccount({ address: '0xA' }); + const mockGetInternalAccount = jest + .fn() + .mockReturnValue(selectedAccount); const tokenAddresses = ['0xE1', '0xE2']; await withController( { options: { chainId, - selectedAddress, + selectedAccountId: selectedAccount.id, + getInternalAccount: mockGetInternalAccount, }, config: { allTokens: {}, allDetectedTokens: { [chainId]: { - [selectedAddress]: [ + [selectedAccount.address]: [ { address: tokenAddresses[0], decimals: 0, @@ -717,7 +768,7 @@ describe('TokenRatesController', () => { allTokens: {}, allDetectedTokens: { [chainId]: { - [selectedAddress]: [ + [selectedAccount.address]: [ { address: tokenAddresses[1], decimals: 0, @@ -763,11 +814,12 @@ describe('TokenRatesController', () => { }); const controller = new TokenRatesController({ interval: 100, + getInternalAccount: jest.fn(), getNetworkClientById, chainId: toHex(1337), ticker: 'TEST', - selectedAddress: defaultSelectedAddress, - onPreferencesStateChange: jest.fn(), + selectedAccountId: defaultMockInternalAccount.id, + onSelectedAccountChange: jest.fn(), onTokensStateChange: jest.fn(), onNetworkStateChange, tokenPricesService: buildMockTokenPricesService(), @@ -801,11 +853,12 @@ describe('TokenRatesController', () => { }); const controller = new TokenRatesController({ interval: 100, + getInternalAccount: jest.fn(), getNetworkClientById, chainId: toHex(1337), ticker: 'TEST', - selectedAddress: defaultSelectedAddress, - onPreferencesStateChange: jest.fn(), + selectedAccountId: defaultMockInternalAccount.id, + onSelectedAccountChange: jest.fn(), onTokensStateChange: jest.fn(), onNetworkStateChange, tokenPricesService: buildMockTokenPricesService(), @@ -839,11 +892,12 @@ describe('TokenRatesController', () => { }); const controller = new TokenRatesController({ interval: 100, + getInternalAccount: jest.fn(), getNetworkClientById, chainId: toHex(1337), ticker: 'TEST', - selectedAddress: defaultSelectedAddress, - onPreferencesStateChange: jest.fn(), + selectedAccountId: defaultMockInternalAccount.id, + onSelectedAccountChange: jest.fn(), onTokensStateChange: jest.fn(), onNetworkStateChange, tokenPricesService: buildMockTokenPricesService(), @@ -875,11 +929,12 @@ describe('TokenRatesController', () => { }); const controller = new TokenRatesController({ interval: 100, + getInternalAccount: jest.fn(), getNetworkClientById, chainId: toHex(1337), ticker: 'TEST', - selectedAddress: defaultSelectedAddress, - onPreferencesStateChange: jest.fn(), + selectedAccountId: defaultMockInternalAccount.id, + onSelectedAccountChange: jest.fn(), onTokensStateChange: jest.fn(), onNetworkStateChange, tokenPricesService: buildMockTokenPricesService(), @@ -911,11 +966,12 @@ describe('TokenRatesController', () => { }); const controller = new TokenRatesController({ interval: 100, + getInternalAccount: jest.fn(), getNetworkClientById, chainId: toHex(1337), ticker: 'TEST', - selectedAddress: defaultSelectedAddress, - onPreferencesStateChange: jest.fn(), + selectedAccountId: defaultMockInternalAccount.id, + onSelectedAccountChange: jest.fn(), onTokensStateChange: jest.fn(), onNetworkStateChange, tokenPricesService: buildMockTokenPricesService(), @@ -951,11 +1007,12 @@ describe('TokenRatesController', () => { }); const controller = new TokenRatesController({ interval: 100, + getInternalAccount: jest.fn(), getNetworkClientById, chainId: toHex(1337), ticker: 'TEST', - selectedAddress: defaultSelectedAddress, - onPreferencesStateChange: jest.fn(), + selectedAccountId: defaultMockInternalAccount.id, + onSelectedAccountChange: jest.fn(), onTokensStateChange: jest.fn(), onNetworkStateChange, tokenPricesService: buildMockTokenPricesService(), @@ -988,11 +1045,12 @@ describe('TokenRatesController', () => { }); const controller = new TokenRatesController({ interval: 100, + getInternalAccount: jest.fn(), getNetworkClientById, chainId: toHex(1337), ticker: 'TEST', - selectedAddress: defaultSelectedAddress, - onPreferencesStateChange: jest.fn(), + selectedAccountId: defaultMockInternalAccount.id, + onSelectedAccountChange: jest.fn(), onTokensStateChange: jest.fn(), onNetworkStateChange, tokenPricesService: buildMockTokenPricesService(), @@ -1025,11 +1083,12 @@ describe('TokenRatesController', () => { }); const controller = new TokenRatesController({ interval: 100, + getInternalAccount: jest.fn(), getNetworkClientById, chainId: toHex(1337), ticker: 'TEST', - selectedAddress: defaultSelectedAddress, - onPreferencesStateChange: jest.fn(), + selectedAccountId: defaultMockInternalAccount.id, + onSelectedAccountChange: jest.fn(), onTokensStateChange: jest.fn(), onNetworkStateChange, tokenPricesService: buildMockTokenPricesService(), @@ -1060,11 +1119,12 @@ describe('TokenRatesController', () => { }); const controller = new TokenRatesController({ interval: 100, + getInternalAccount: jest.fn(), getNetworkClientById, chainId: toHex(1337), ticker: 'TEST', - selectedAddress: defaultSelectedAddress, - onPreferencesStateChange: jest.fn(), + selectedAccountId: defaultMockInternalAccount.id, + onSelectedAccountChange: jest.fn(), onTokensStateChange: jest.fn(), onNetworkStateChange, tokenPricesService: buildMockTokenPricesService(), @@ -1082,37 +1142,41 @@ describe('TokenRatesController', () => { }); }); - describe('PreferencesController::stateChange', () => { + describe('onSelectedAccountChange', () => { let clock: sinon.SinonFakeTimers; - beforeEach(() => { clock = useFakeTimers({ now: Date.now() }); }); - afterEach(() => { clock.restore(); }); describe('when polling is active', () => { it('should update exchange rates when selected address changes', async () => { - // TODO: Replace `any` with type - // eslint-disable-next-line @typescript-eslint/no-explicit-any - let preferencesStateChangeListener: (state: any) => Promise; - const onPreferencesStateChange = jest + const alternateSelectedAddress = + '0x0000000000000000000000000000000000000002'; + const alternativeAccount = createMockInternalAccount({ + address: alternateSelectedAddress, + }); + + let selectedAccountChangeListener: ( + interalAccount: InternalAccount, + ) => Promise; + const onSelectedAccountChange = jest .fn() .mockImplementation((listener) => { - preferencesStateChangeListener = listener; + selectedAccountChangeListener = listener; }); - const alternateSelectedAddress = - '0x0000000000000000000000000000000000000002'; + const controller = new TokenRatesController( { interval: 100, getNetworkClientById: jest.fn(), + getInternalAccount: jest.fn(), chainId: '0x1', ticker: NetworksTicker.mainnet, - selectedAddress: defaultSelectedAddress, - onPreferencesStateChange, + selectedAccountId: defaultMockInternalAccount.id, + onSelectedAccountChange, onTokensStateChange: jest.fn(), onNetworkStateChange: jest.fn(), tokenPricesService: buildMockTokenPricesService(), @@ -1134,30 +1198,31 @@ describe('TokenRatesController', () => { .mockResolvedValue(); // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - await preferencesStateChangeListener!({ - selectedAddress: alternateSelectedAddress, - }); + await selectedAccountChangeListener!(alternativeAccount); - expect(updateExchangeRatesSpy).toHaveBeenCalledTimes(1); + expect(updateExchangeRatesSpy).toHaveBeenCalled(); }); it('should not update exchange rates when preferences state changes without selected address changing', async () => { // TODO: Replace `any` with type - // eslint-disable-next-line @typescript-eslint/no-explicit-any - let preferencesStateChangeListener: (state: any) => Promise; - const onPreferencesStateChange = jest + + let selectedAccountChangeListener: ( + interalAccount: InternalAccount, + ) => Promise; + const onSelectedAccountChange = jest .fn() .mockImplementation((listener) => { - preferencesStateChangeListener = listener; + selectedAccountChangeListener = listener; }); const controller = new TokenRatesController( { interval: 100, + getInternalAccount: jest.fn(), getNetworkClientById: jest.fn(), chainId: '0x1', ticker: NetworksTicker.mainnet, - selectedAddress: defaultSelectedAddress, - onPreferencesStateChange, + selectedAccountId: defaultMockInternalAccount.id, + onSelectedAccountChange, onTokensStateChange: jest.fn(), onNetworkStateChange: jest.fn(), tokenPricesService: buildMockTokenPricesService(), @@ -1165,7 +1230,7 @@ describe('TokenRatesController', () => { { allTokens: { '0x1': { - [defaultSelectedAddress]: [ + [defaultMockInternalAccount.address]: [ { address: '0x02', decimals: 0, symbol: '', aggregators: [] }, { address: '0x03', decimals: 0, symbol: '', aggregators: [] }, ], @@ -1179,10 +1244,7 @@ describe('TokenRatesController', () => { .mockResolvedValue(); // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - await preferencesStateChangeListener!({ - selectedAddress: defaultSelectedAddress, - exampleConfig: 'exampleValue', - }); + await selectedAccountChangeListener!(defaultMockInternalAccount); expect(updateExchangeRatesSpy).not.toHaveBeenCalled(); }); @@ -1190,24 +1252,29 @@ describe('TokenRatesController', () => { describe('when polling is inactive', () => { it('should not update exchange rates when selected address changes', async () => { - // TODO: Replace `any` with type - // eslint-disable-next-line @typescript-eslint/no-explicit-any - let preferencesStateChangeListener: (state: any) => Promise; - const onPreferencesStateChange = jest + const alternateSelectedAddress = + '0x0000000000000000000000000000000000000002'; + const alternateAccount = createMockInternalAccount({ + address: alternateSelectedAddress, + }); + let selectedAccountChangeListener: ( + interalAccount: InternalAccount, + ) => Promise; + const onSelectedAccountChange = jest .fn() .mockImplementation((listener) => { - preferencesStateChangeListener = listener; + selectedAccountChangeListener = listener; }); - const alternateSelectedAddress = - '0x0000000000000000000000000000000000000002'; + const controller = new TokenRatesController( { interval: 100, + getInternalAccount: jest.fn(), getNetworkClientById: jest.fn(), chainId: '0x1', ticker: NetworksTicker.mainnet, - selectedAddress: defaultSelectedAddress, - onPreferencesStateChange, + selectedAccountId: defaultMockInternalAccount.id, + onSelectedAccountChange, onTokensStateChange: jest.fn(), onNetworkStateChange: jest.fn(), tokenPricesService: buildMockTokenPricesService(), @@ -1215,7 +1282,7 @@ describe('TokenRatesController', () => { { allTokens: { '0x1': { - [alternateSelectedAddress]: [ + [alternateAccount.address]: [ { address: '0x02', decimals: 0, symbol: '', aggregators: [] }, { address: '0x03', decimals: 0, symbol: '', aggregators: [] }, ], @@ -1228,9 +1295,7 @@ describe('TokenRatesController', () => { .mockResolvedValue(); // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - await preferencesStateChangeListener!({ - selectedAddress: alternateSelectedAddress, - }); + await selectedAccountChangeListener!(alternateAccount); expect(updateExchangeRatesSpy).not.toHaveBeenCalled(); }); @@ -1257,10 +1322,13 @@ describe('TokenRatesController', () => { { interval, getNetworkClientById: jest.fn(), + getInternalAccount: jest + .fn() + .mockReturnValue(defaultMockInternalAccount), chainId: '0x1', ticker: NetworksTicker.mainnet, - selectedAddress: defaultSelectedAddress, - onPreferencesStateChange: jest.fn(), + selectedAccountId: defaultMockInternalAccount.id, + onSelectedAccountChange: jest.fn(), onTokensStateChange: jest.fn(), onNetworkStateChange: jest.fn(), tokenPricesService, @@ -1268,7 +1336,7 @@ describe('TokenRatesController', () => { { allTokens: { '0x1': { - [defaultSelectedAddress]: [ + [defaultMockInternalAccount.address]: [ { address: mockTokenAddress, decimals: 0, @@ -1301,10 +1369,13 @@ describe('TokenRatesController', () => { { interval, getNetworkClientById: jest.fn(), + getInternalAccount: jest + .fn() + .mockReturnValue(defaultMockInternalAccount), chainId: '0x1', ticker: NetworksTicker.mainnet, - selectedAddress: defaultSelectedAddress, - onPreferencesStateChange: jest.fn(), + selectedAccountId: defaultMockInternalAccount.id, + onSelectedAccountChange: jest.fn(), onTokensStateChange: jest.fn(), onNetworkStateChange: jest.fn(), tokenPricesService, @@ -1312,7 +1383,7 @@ describe('TokenRatesController', () => { { allTokens: { '0x1': { - [defaultSelectedAddress]: [ + [defaultMockInternalAccount.address]: [ { address: mockTokenAddress, decimals: 0, @@ -1356,8 +1427,8 @@ describe('TokenRatesController', () => { interval, chainId: '0x2', ticker: 'ticker', - selectedAddress: defaultSelectedAddress, - onPreferencesStateChange: jest.fn(), + selectedAccountId: defaultMockInternalAccount.id, + onSelectedAccountChange: jest.fn(), onTokensStateChange: jest.fn(), onNetworkStateChange: jest.fn(), getNetworkClientById: jest.fn().mockReturnValue({ @@ -1366,12 +1437,15 @@ describe('TokenRatesController', () => { ticker: NetworksTicker.mainnet, }, }), + getInternalAccount: jest + .fn() + .mockReturnValue(defaultMockInternalAccount), tokenPricesService, }, { allTokens: { '0x1': { - [defaultSelectedAddress]: [ + [defaultMockInternalAccount.address]: [ { address: mockTokenAddress, decimals: 0, @@ -1408,8 +1482,8 @@ describe('TokenRatesController', () => { { chainId: '0x2', ticker: 'ticker', - selectedAddress: defaultSelectedAddress, - onPreferencesStateChange: jest.fn(), + selectedAccountId: defaultMockInternalAccount.id, + onSelectedAccountChange: jest.fn(), onTokensStateChange: jest.fn(), onNetworkStateChange: jest.fn(), getNetworkClientById: jest.fn().mockReturnValue({ @@ -1418,12 +1492,15 @@ describe('TokenRatesController', () => { ticker: NetworksTicker.mainnet, }, }), + getInternalAccount: jest + .fn() + .mockReturnValue(defaultMockInternalAccount), tokenPricesService, }, { allTokens: { '0x1': { - [defaultSelectedAddress]: [ + [defaultMockInternalAccount.address]: [ { address: '0x02', decimals: 0, @@ -1513,8 +1590,8 @@ describe('TokenRatesController', () => { { chainId: '0x2', ticker: 'ticker', - selectedAddress: defaultSelectedAddress, - onPreferencesStateChange: jest.fn(), + selectedAccountId: defaultMockInternalAccount.id, + onSelectedAccountChange: jest.fn(), onTokensStateChange: jest.fn(), onNetworkStateChange: jest.fn(), getNetworkClientById: jest.fn().mockReturnValue({ @@ -1523,12 +1600,15 @@ describe('TokenRatesController', () => { ticker: 'LOL', }, }), + getInternalAccount: jest + .fn() + .mockReturnValue(defaultMockInternalAccount), tokenPricesService, }, { allTokens: { '0x1': { - [defaultSelectedAddress]: [ + [defaultMockInternalAccount.address]: [ { address: '0x02', decimals: 0, @@ -1617,8 +1697,8 @@ describe('TokenRatesController', () => { { chainId: '0x2', ticker: 'ETH', - selectedAddress: defaultSelectedAddress, - onPreferencesStateChange: jest.fn(), + selectedAccountId: defaultMockInternalAccount.id, + onSelectedAccountChange: jest.fn(), onTokensStateChange: jest.fn(), onNetworkStateChange: jest.fn(), getNetworkClientById: jest.fn().mockReturnValue({ @@ -1627,12 +1707,15 @@ describe('TokenRatesController', () => { ticker: 'LOL', }, }), + getInternalAccount: jest + .fn() + .mockReturnValue(defaultMockInternalAccount), tokenPricesService, }, { allTokens: { '0x1': { - [defaultSelectedAddress]: [ + [defaultMockInternalAccount.address]: [ { address: '0x02', decimals: 0, @@ -1674,8 +1757,8 @@ describe('TokenRatesController', () => { interval, chainId: '0x2', ticker: 'ticker', - selectedAddress: defaultSelectedAddress, - onPreferencesStateChange: jest.fn(), + selectedAccountId: defaultMockInternalAccount.id, + onSelectedAccountChange: jest.fn(), onTokensStateChange: jest.fn(), onNetworkStateChange: jest.fn(), getNetworkClientById: jest.fn().mockReturnValue({ @@ -1684,12 +1767,15 @@ describe('TokenRatesController', () => { ticker: NetworksTicker.mainnet, }, }), + getInternalAccount: jest + .fn() + .mockReturnValue(defaultMockInternalAccount), tokenPricesService, }, { allTokens: { '0x1': { - [defaultSelectedAddress]: [ + [defaultMockInternalAccount.address]: [ { address: mockTokenAddress, decimals: 0, @@ -1728,14 +1814,24 @@ describe('TokenRatesController', () => { ])('%s', (method) => { it('does not update state when disabled', async () => { await withController( - { config: { disabled: true } }, + { + options: { + getInternalAccount: jest + .fn() + .mockReturnValue(defaultMockInternalAccount), + }, + config: { + disabled: true, + selectedAccountId: defaultMockInternalAccount.id, + }, + }, async ({ controller, controllerEvents }) => { const tokenAddress = '0x0000000000000000000000000000000000000001'; await callUpdateExchangeRatesMethod({ allTokens: { [ChainId.mainnet]: { - [controller.config.selectedAddress]: [ + [defaultMockInternalAccount.address]: [ { address: tokenAddress, decimals: 18, @@ -1759,51 +1855,65 @@ describe('TokenRatesController', () => { }); it('does not update state if there are no tokens for the given chain and address', async () => { - await withController(async ({ controller, controllerEvents }) => { - const tokenAddress = '0x0000000000000000000000000000000000000001'; - const differentAccount = '0x1000000000000000000000000000000000000000'; + await withController( + { + options: { + getInternalAccount: jest + .fn() + .mockReturnValue(defaultMockInternalAccount), + }, + config: { + selectedAccountId: defaultMockInternalAccount.id, + }, + }, + async ({ controller, controllerEvents }) => { + const tokenAddress = '0x0000000000000000000000000000000000000001'; + const differentAccount = '0x1000000000000000000000000000000000000000'; - await callUpdateExchangeRatesMethod({ - allTokens: { - // These tokens are for the right chain but wrong account - [ChainId.mainnet]: { - [differentAccount]: [ - { - address: tokenAddress, - decimals: 18, - symbol: 'TST', - aggregators: [], - }, - ], - }, - // These tokens are for the right account but wrong chain - [toHex(2)]: { - [controller.config.selectedAddress]: [ - { - address: tokenAddress, - decimals: 18, - symbol: 'TST', - aggregators: [], - }, - ], + await callUpdateExchangeRatesMethod({ + allTokens: { + // These tokens are for the right chain but wrong account + [ChainId.mainnet]: { + [differentAccount]: [ + { + address: tokenAddress, + decimals: 18, + symbol: 'TST', + aggregators: [], + }, + ], + }, + // These tokens are for the right account but wrong chain + [toHex(2)]: { + [defaultMockInternalAccount.address]: [ + { + address: tokenAddress, + decimals: 18, + symbol: 'TST', + aggregators: [], + }, + ], + }, }, - }, - chainId: ChainId.mainnet, - controller, - controllerEvents, - method, - nativeCurrency: 'ETH', - selectedNetworkClientId: InfuraNetworkType.mainnet, - }); + chainId: toHex(1), + controller, + controllerEvents, + method, + nativeCurrency: 'ETH', + selectedNetworkClientId: InfuraNetworkType.mainnet, + }); - expect(controller.state).toStrictEqual({ - marketData: { - '0x1': { - '0x0000000000000000000000000000000000000000': { currency: 'ETH' }, + expect(controller.state).toStrictEqual({ + marketData: { + '0x1': { + '0x0000000000000000000000000000000000000000': { + currency: 'ETH', + }, + }, }, - }, - }); - }); + }); + }, + ); }); it('does not update state if the price update fails', async () => { @@ -1813,7 +1923,17 @@ describe('TokenRatesController', () => { .mockRejectedValue(new Error('Failed to fetch')), }); await withController( - { options: { tokenPricesService } }, + { + options: { + tokenPricesService, + getInternalAccount: jest + .fn() + .mockReturnValue(defaultMockInternalAccount), + }, + config: { + selectedAccountId: defaultMockInternalAccount.id, + }, + }, async ({ controller, controllerEvents }) => { const tokenAddress = '0x0000000000000000000000000000000000000001'; @@ -1822,7 +1942,7 @@ describe('TokenRatesController', () => { await callUpdateExchangeRatesMethod({ allTokens: { [ChainId.mainnet]: { - [controller.config.selectedAddress]: [ + [defaultMockInternalAccount.address]: [ { address: tokenAddress, decimals: 18, @@ -1866,13 +1986,19 @@ describe('TokenRatesController', () => { options: { ticker, tokenPricesService, + getInternalAccount: jest + .fn() + .mockReturnValue(defaultMockInternalAccount), + }, + config: { + selectedAccountId: defaultMockInternalAccount.id, }, }, async ({ controller, controllerEvents }) => { await callUpdateExchangeRatesMethod({ allTokens: { [chainId]: { - [controller.config.selectedAddress]: tokens, + [defaultMockInternalAccount.address]: tokens, }, }, chainId, @@ -1922,12 +2048,22 @@ describe('TokenRatesController', () => { }), }); await withController( - { options: { tokenPricesService } }, + { + options: { + tokenPricesService, + getInternalAccount: jest + .fn() + .mockReturnValue(defaultMockInternalAccount), + }, + config: { + selectedAccountId: defaultMockInternalAccount.id, + }, + }, async ({ controller, controllerEvents }) => { await callUpdateExchangeRatesMethod({ allTokens: { [ChainId.mainnet]: { - [controller.config.selectedAddress]: [ + [defaultMockInternalAccount.address]: [ { address: tokenAddresses[0], decimals: 18, @@ -1994,12 +2130,20 @@ describe('TokenRatesController', () => { }), }); await withController( - { options: { tokenPricesService } }, + { + options: { + tokenPricesService, + getInternalAccount: jest + .fn() + .mockReturnValue(defaultMockInternalAccount), + }, + config: { selectedAccountId: defaultMockInternalAccount.id }, + }, async ({ controller, controllerEvents }) => { await callUpdateExchangeRatesMethod({ allTokens: { [toHex(2)]: { - [controller.config.selectedAddress]: [ + [defaultMockInternalAccount.address]: [ { address: tokenAddresses[0], decimals: 18, @@ -2088,6 +2232,12 @@ describe('TokenRatesController', () => { { options: { tokenPricesService, + getInternalAccount: jest + .fn() + .mockReturnValue(defaultMockInternalAccount), + }, + config: { + selectedAccountId: defaultMockInternalAccount.id, }, mockNetworkClientConfigurationsByNetworkClientId: { [selectedNetworkClientId]: selectedNetworkClientConfiguration, @@ -2097,7 +2247,7 @@ describe('TokenRatesController', () => { await callUpdateExchangeRatesMethod({ allTokens: { [selectedNetworkClientConfiguration.chainId]: { - [controller.config.selectedAddress]: [ + [defaultMockInternalAccount.address]: [ { address: tokenAddresses[0], decimals: 18, @@ -2178,6 +2328,12 @@ describe('TokenRatesController', () => { { options: { tokenPricesService, + getInternalAccount: jest + .fn() + .mockReturnValue(defaultMockInternalAccount), + }, + config: { + selectedAccountId: defaultMockInternalAccount.id, }, mockNetworkClientConfigurationsByNetworkClientId: { [selectedNetworkClientId]: selectedNetworkClientConfiguration, @@ -2187,7 +2343,7 @@ describe('TokenRatesController', () => { await callUpdateExchangeRatesMethod({ allTokens: { [selectedNetworkClientConfiguration.chainId]: { - [controller.config.selectedAddress]: tokens, + [defaultMockInternalAccount.address]: tokens, }, }, chainId: selectedNetworkClientConfiguration.chainId, @@ -2251,6 +2407,12 @@ describe('TokenRatesController', () => { { options: { tokenPricesService, + getInternalAccount: jest + .fn() + .mockReturnValue(defaultMockInternalAccount), + }, + config: { + selectedAccountId: defaultMockInternalAccount.id, }, mockNetworkClientConfigurationsByNetworkClientId: { [selectedNetworkClientId]: selectedNetworkClientConfiguration, @@ -2260,7 +2422,7 @@ describe('TokenRatesController', () => { await callUpdateExchangeRatesMethod({ allTokens: { [selectedNetworkClientConfiguration.chainId]: { - [controller.config.selectedAddress]: [ + [defaultMockInternalAccount.address]: [ { address: tokenAddresses[0], decimals: 18, @@ -2319,13 +2481,23 @@ describe('TokenRatesController', () => { fetchTokenPrices: fetchTokenPricesMock, }); await withController( - { options: { tokenPricesService } }, + { + options: { + tokenPricesService, + getInternalAccount: jest + .fn() + .mockReturnValue(defaultMockInternalAccount), + }, + config: { + selectedAccountId: defaultMockInternalAccount.id, + }, + }, async ({ controller, controllerEvents }) => { const updateExchangeRates = async () => await callUpdateExchangeRatesMethod({ allTokens: { [toHex(1)]: { - [controller.config.selectedAddress]: [ + [defaultMockInternalAccount.address]: [ { address: tokenAddresses[0], decimals: 18, @@ -2382,7 +2554,7 @@ describe('TokenRatesController', () => { */ type ControllerEvents = { networkStateChange: (state: NetworkState) => void; - preferencesStateChange: (state: PreferencesState) => void; + seletedAccountChange: (internalAccount: InternalAccount) => void; tokensStateChange: (state: TokensControllerState) => void; }; @@ -2454,13 +2626,14 @@ async function withController( onNetworkStateChange: (listener) => { controllerEvents.networkStateChange = listener; }, - onPreferencesStateChange: (listener) => { - controllerEvents.preferencesStateChange = listener; + onSelectedAccountChange: (listener) => { + controllerEvents.seletedAccountChange = listener; }, onTokensStateChange: (listener) => { controllerEvents.tokensStateChange = listener; }, - selectedAddress: defaultSelectedAddress, + getInternalAccount: jest.fn(), + selectedAccountId: defaultMockInternalAccount.id, ticker: NetworksTicker.mainnet, tokenPricesService: buildMockTokenPricesService(), ...options, diff --git a/packages/assets-controllers/src/TokenRatesController.ts b/packages/assets-controllers/src/TokenRatesController.ts index 0f0fa9cd32..06e9d3e842 100644 --- a/packages/assets-controllers/src/TokenRatesController.ts +++ b/packages/assets-controllers/src/TokenRatesController.ts @@ -5,13 +5,13 @@ import { FALL_BACK_VS_CURRENCY, toHex, } from '@metamask/controller-utils'; +import { type InternalAccount } from '@metamask/keyring-api'; import type { NetworkClientId, NetworkController, NetworkState, } from '@metamask/network-controller'; import { StaticIntervalPollingControllerV1 } from '@metamask/polling-controller'; -import type { PreferencesState } from '@metamask/preferences-controller'; import { createDeferredPromise, type Hex } from '@metamask/utils'; import { isEqual } from 'lodash'; @@ -59,7 +59,7 @@ export interface TokenRatesConfig extends BaseConfig { interval: number; nativeCurrency: string; chainId: Hex; - selectedAddress: string; + selectedAccountId: string; allTokens: { [chainId: Hex]: { [key: string]: Token[] } }; allDetectedTokens: { [chainId: Hex]: { [key: string]: Token[] } }; threshold: number; @@ -175,6 +175,8 @@ export class TokenRatesController extends StaticIntervalPollingControllerV1< private readonly getNetworkClientById: NetworkController['getNetworkClientById']; + private readonly getInternalAccount: (accountId: string) => InternalAccount; + /** * Creates a TokenRatesController instance. * @@ -184,8 +186,9 @@ export class TokenRatesController extends StaticIntervalPollingControllerV1< * @param options.getNetworkClientById - Gets the network client with the given id from the NetworkController. * @param options.chainId - The chain ID of the current network. * @param options.ticker - The ticker for the current network. - * @param options.selectedAddress - The current selected address. - * @param options.onPreferencesStateChange - Allows subscribing to preference controller state changes. + * @param options.getInternalAccount - A callback to get an InternalAccount by id. + * @param options.selectedAccountId - The current selected address. + * @param options.onSelectedAccountChange - Allows subscribing to changes of selected account. * @param options.onTokensStateChange - Allows subscribing to token controller state changes. * @param options.onNetworkStateChange - Allows subscribing to network state changes. * @param options.tokenPricesService - An object in charge of retrieving token prices. @@ -199,8 +202,9 @@ export class TokenRatesController extends StaticIntervalPollingControllerV1< getNetworkClientById, chainId: initialChainId, ticker: initialTicker, - selectedAddress: initialSelectedAddress, - onPreferencesStateChange, + selectedAccountId, + getInternalAccount, + onSelectedAccountChange, onTokensStateChange, onNetworkStateChange, tokenPricesService, @@ -210,9 +214,10 @@ export class TokenRatesController extends StaticIntervalPollingControllerV1< getNetworkClientById: NetworkController['getNetworkClientById']; chainId: Hex; ticker: string; - selectedAddress: string; - onPreferencesStateChange: ( - listener: (preferencesState: PreferencesState) => void, + selectedAccountId: string; + getInternalAccount: (accountId: string) => InternalAccount; + onSelectedAccountChange: ( + listener: (internalAccount: InternalAccount) => void, ) => void; onTokensStateChange: ( listener: (tokensState: TokensControllerState) => void, @@ -232,7 +237,7 @@ export class TokenRatesController extends StaticIntervalPollingControllerV1< disabled: false, nativeCurrency: initialTicker, chainId: initialChainId, - selectedAddress: initialSelectedAddress, + selectedAccountId, allTokens: {}, // TODO: initialize these correctly, maybe as part of BaseControllerV2 migration allDetectedTokens: {}, }; @@ -243,15 +248,16 @@ export class TokenRatesController extends StaticIntervalPollingControllerV1< this.initialize(); this.setIntervalLength(interval); this.getNetworkClientById = getNetworkClientById; + this.getInternalAccount = getInternalAccount; this.#tokenPricesService = tokenPricesService; if (config?.disabled) { this.configure({ disabled: true }, false, false); } - onPreferencesStateChange(async ({ selectedAddress }) => { - if (this.config.selectedAddress !== selectedAddress) { - this.configure({ selectedAddress }); + onSelectedAccountChange(async (internalAccount) => { + if (this.config.selectedAccountId !== internalAccount.id) { + this.configure({ selectedAccountId: internalAccount.id }); if (this.#pollState === PollState.Active) { await this.updateExchangeRates(); } @@ -298,10 +304,11 @@ export class TokenRatesController extends StaticIntervalPollingControllerV1< * @returns The list of tokens addresses for the current chain */ #getTokenAddresses(chainId: Hex): Hex[] { - const { allTokens, allDetectedTokens } = this.config; - const tokens = allTokens[chainId]?.[this.config.selectedAddress] || []; + const { allTokens, allDetectedTokens, selectedAccountId } = this.config; + const internalAccount = this.getInternalAccount(selectedAccountId); + const tokens = allTokens[chainId]?.[internalAccount.address] || []; const detectedTokens = - allDetectedTokens[chainId]?.[this.config.selectedAddress] || []; + allDetectedTokens[chainId]?.[internalAccount.address] || []; return [ ...new Set( diff --git a/packages/assets-controllers/src/TokensController.test.ts b/packages/assets-controllers/src/TokensController.test.ts index 970310d99b..442ff88374 100644 --- a/packages/assets-controllers/src/TokensController.test.ts +++ b/packages/assets-controllers/src/TokensController.test.ts @@ -1,4 +1,5 @@ import { Contract } from '@ethersproject/contracts'; +import type { AccountsController } from '@metamask/accounts-controller'; import type { ApprovalStateChange } from '@metamask/approval-controller'; import { ApprovalController, @@ -13,18 +14,18 @@ import { convertHexToDecimal, InfuraNetworkType, } from '@metamask/controller-utils'; +import type { InternalAccount } from '@metamask/keyring-api'; import type { NetworkClientConfiguration, NetworkClientId, } from '@metamask/network-controller'; import { defaultState as defaultNetworkState } from '@metamask/network-controller'; -import type { PreferencesState } from '@metamask/preferences-controller'; -import { getDefaultPreferencesState } from '@metamask/preferences-controller'; import nock from 'nock'; import * as sinon from 'sinon'; import { v1 as uuidV1 } from 'uuid'; import { FakeProvider } from '../../../tests/fake-provider'; +import { createMockInternalAccount } from '../../accounts-controller/src/tests/mocks'; import type { ExtractAvailableAction, ExtractAvailableEvent, @@ -58,6 +59,10 @@ const uuidV1Mock = jest.mocked(uuidV1); const ERC20StandardMock = jest.mocked(ERC20Standard); const ERC1155StandardMock = jest.mocked(ERC1155Standard); +const defaultMockInternalAccount = createMockInternalAccount({ + address: '0x1', +}); + describe('TokensController', () => { beforeEach(() => { uuidV1Mock.mockReturnValue('9b1deb4d-3b7d-4bad-9bdd-2b0d7b3dcb6d'); @@ -266,32 +271,34 @@ describe('TokensController', () => { it('should add token by selected address', async () => { await withController( - async ({ controller, triggerPreferencesStateChange }) => { + async ({ + controller, + triggerSelectedAccountChange, + getAccountHandler, + }) => { ContractMock.mockReturnValue( buildMockEthersERC721Contract({ supportsInterface: false }), ); const firstAddress = '0x123'; + const firstAccount = createMockInternalAccount({ + address: firstAddress, + }); const secondAddress = '0x321'; - - triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - selectedAddress: firstAddress, + const secondAccount = createMockInternalAccount({ + address: secondAddress, }); + + getAccountHandler.mockReturnValue(firstAccount); + triggerSelectedAccountChange(firstAccount); await controller.addToken({ address: '0x01', symbol: 'bar', decimals: 2, }); - triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - selectedAddress: secondAddress, - }); + triggerSelectedAccountChange(secondAccount); expect(controller.state.tokens).toHaveLength(0); - triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - selectedAddress: firstAddress, - }); + triggerSelectedAccountChange(firstAccount); expect(controller.state.tokens[0]).toStrictEqual({ address: '0x01', decimals: 2, @@ -408,25 +415,32 @@ describe('TokensController', () => { it('should remove token by selected address', async () => { await withController( - async ({ controller, triggerPreferencesStateChange }) => { + async ({ + controller, + triggerSelectedAccountChange, + getAccountHandler, + }) => { ContractMock.mockReturnValue( buildMockEthersERC721Contract({ supportsInterface: false }), ); const firstAddress = '0x123'; + const firstAccount = createMockInternalAccount({ + address: firstAddress, + }); const secondAddress = '0x321'; - triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - selectedAddress: firstAddress, + const secondAccount = createMockInternalAccount({ + address: secondAddress, }); + + getAccountHandler.mockReturnValue(firstAccount); + triggerSelectedAccountChange(firstAccount); await controller.addToken({ address: '0x02', symbol: 'baz', decimals: 2, }); - triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - selectedAddress: secondAddress, - }); + getAccountHandler.mockReturnValue(secondAccount); + triggerSelectedAccountChange(secondAccount); await controller.addToken({ address: '0x01', symbol: 'bar', @@ -436,10 +450,7 @@ describe('TokensController', () => { controller.ignoreTokens(['0x01']); expect(controller.state.tokens).toHaveLength(0); - triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - selectedAddress: firstAddress, - }); + triggerSelectedAccountChange(firstAccount); expect(controller.state.tokens[0]).toStrictEqual({ address: '0x02', decimals: 2, @@ -522,14 +533,16 @@ describe('TokensController', () => { await withController( async ({ controller, - triggerPreferencesStateChange, + triggerSelectedAccountChange, changeNetwork, + getAccountHandler, }) => { const selectedAddress = '0x0001'; - triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - selectedAddress, + const selectedAccount = createMockInternalAccount({ + address: selectedAddress, }); + getAccountHandler.mockReturnValue(selectedAccount); + triggerSelectedAccountChange(selectedAccount); changeNetwork({ selectedNetworkClientId: InfuraNetworkType.sepolia }); await controller.addToken({ address: '0x01', @@ -569,14 +582,16 @@ describe('TokensController', () => { await withController( async ({ controller, - triggerPreferencesStateChange, + triggerSelectedAccountChange, changeNetwork, + getAccountHandler, }) => { const selectedAddress = '0x0001'; - triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - selectedAddress, + const selectedAccount = createMockInternalAccount({ + address: selectedAddress, }); + getAccountHandler.mockReturnValue(selectedAccount); + triggerSelectedAccountChange(selectedAccount); changeNetwork({ selectedNetworkClientId: InfuraNetworkType.sepolia }); await controller.addToken({ address: '0x01', @@ -606,15 +621,20 @@ describe('TokensController', () => { await withController( async ({ controller, - triggerPreferencesStateChange, + triggerSelectedAccountChange, changeNetwork, + getAccountHandler, }) => { const selectedAddress1 = '0x0001'; + const selectedAccount1 = createMockInternalAccount({ + address: selectedAddress1, + }); const selectedAddress2 = '0x0002'; - triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - selectedAddress: selectedAddress1, + const selectedAccount2 = createMockInternalAccount({ + address: selectedAddress2, }); + getAccountHandler.mockReturnValue(selectedAccount1); + triggerSelectedAccountChange(selectedAccount1); changeNetwork({ selectedNetworkClientId: InfuraNetworkType.sepolia }); await controller.addToken({ address: '0x01', @@ -638,10 +658,8 @@ describe('TokensController', () => { controller.ignoreTokens(['0x02']); expect(controller.state.ignoredTokens).toStrictEqual(['0x02']); - triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - selectedAddress: selectedAddress2, - }); + getAccountHandler.mockReturnValue(selectedAccount2); + triggerSelectedAccountChange(selectedAccount2); expect(controller.state.ignoredTokens).toHaveLength(0); await controller.addToken({ @@ -780,7 +798,8 @@ describe('TokensController', () => { describe('addToken method', () => { it('should add isERC721 = true when token is an NFT and is in our contract-metadata repo', async () => { - await withController(async ({ controller }) => { + await withController(async ({ controller, getAccountHandler }) => { + getAccountHandler.mockReturnValue(defaultMockInternalAccount); const contractAddresses = Object.keys(contractMaps); const erc721ContractAddresses = contractAddresses.filter( (contractAddress) => contractMaps[contractAddress].erc721 === true, @@ -802,7 +821,8 @@ describe('TokensController', () => { }); it('should add isERC721 = true when the token is an NFT but not in our contract-metadata repo', async () => { - await withController(async ({ controller }) => { + await withController(async ({ controller, getAccountHandler }) => { + getAccountHandler.mockReturnValue(defaultMockInternalAccount); ContractMock.mockReturnValue( buildMockEthersERC721Contract({ supportsInterface: true }), ); @@ -830,7 +850,8 @@ describe('TokensController', () => { }); it('should add isERC721 = false to token object already in state when token is not an NFT and in our contract-metadata repo', async () => { - await withController(async ({ controller }) => { + await withController(async ({ controller, getAccountHandler }) => { + getAccountHandler.mockReturnValue(defaultMockInternalAccount); const contractAddresses = Object.keys(contractMaps); const erc20ContractAddresses = contractAddresses.filter( (contractAddress) => contractMaps[contractAddress].erc20 === true, @@ -852,7 +873,8 @@ describe('TokensController', () => { }); it('should add isERC721 = false when the token is not an NFT and not in our contract-metadata repo', async () => { - await withController(async ({ controller }) => { + await withController(async ({ controller, getAccountHandler }) => { + getAccountHandler.mockReturnValue(defaultMockInternalAccount); ContractMock.mockReturnValue( buildMockEthersERC721Contract({ supportsInterface: false }), ); @@ -880,21 +902,26 @@ describe('TokensController', () => { }); it('should throw error if switching networks while adding token', async () => { - await withController(async ({ controller, changeNetwork }) => { - const dummyTokenAddress = - '0x514910771AF9Ca656af840dff83E8264EcF986CA'; + await withController( + async ({ controller, changeNetwork, getAccountHandler }) => { + getAccountHandler.mockReturnValue(defaultMockInternalAccount); + const dummyTokenAddress = + '0x514910771AF9Ca656af840dff83E8264EcF986CA'; - const addTokenPromise = controller.addToken({ - address: dummyTokenAddress, - symbol: 'LINK', - decimals: 18, - }); - changeNetwork({ selectedNetworkClientId: InfuraNetworkType.goerli }); + const addTokenPromise = controller.addToken({ + address: dummyTokenAddress, + symbol: 'LINK', + decimals: 18, + }); + changeNetwork({ + selectedNetworkClientId: InfuraNetworkType.goerli, + }); - await expect(addTokenPromise).rejects.toThrow( - 'TokensController Error: Switched networks while adding token', - ); - }); + await expect(addTokenPromise).rejects.toThrow( + 'TokensController Error: Switched networks while adding token', + ); + }, + ); }); }); @@ -971,7 +998,8 @@ describe('TokensController', () => { async ({ controller, changeNetwork, - triggerPreferencesStateChange, + triggerSelectedAccountChange, + getAccountHandler, }) => { ContractMock.mockReturnValue( buildMockEthersERC721Contract({ supportsInterface: false }), @@ -981,13 +1009,13 @@ describe('TokensController', () => { const CONFIGURED_CHAIN = ChainId.sepolia; const CONFIGURED_NETWORK_CLIENT_ID = InfuraNetworkType.sepolia; const CONFIGURED_ADDRESS = '0xConfiguredAddress'; + const configuredAccount = createMockInternalAccount({ + address: CONFIGURED_ADDRESS, + }); changeNetwork({ selectedNetworkClientId: CONFIGURED_NETWORK_CLIENT_ID, }); - triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - selectedAddress: CONFIGURED_ADDRESS, - }); + triggerSelectedAccountChange(configuredAccount); // A different chain + address const OTHER_CHAIN = '0xOtherChainId'; @@ -1011,6 +1039,8 @@ describe('TokensController', () => { detectedTokenOtherAccount, ] = generateTokens(3); + getAccountHandler.mockReturnValue(configuredAccount); + // Run twice to ensure idempotency for (let i = 0; i < 2; i++) { // Add and detect some tokens on the configured chain + account @@ -1570,7 +1600,6 @@ describe('TokensController', () => { buildMockEthersERC721Contract({ supportsInterface: false }), ); uuidV1Mock.mockReturnValue(requestId); - await controller.watchAsset({ asset, type: 'ERC20' }); expect(controller.state.tokens).toHaveLength(1); @@ -1721,7 +1750,6 @@ describe('TokensController', () => { buildMockEthersERC721Contract({ supportsInterface: false }), ); uuidV1Mock.mockReturnValue(requestId); - await expect( controller.watchAsset({ asset, type: 'ERC20' }), ).rejects.toThrow(errorMessage); @@ -1844,14 +1872,20 @@ describe('TokensController', () => { describe('when PreferencesController:stateChange is published', () => { it('should update tokens list when set address changes', async () => { await withController( - async ({ controller, triggerPreferencesStateChange }) => { + async ({ + controller, + triggerSelectedAccountChange, + getAccountHandler, + }) => { ContractMock.mockReturnValue( buildMockEthersERC721Contract({ supportsInterface: false }), ); - triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - selectedAddress: '0x1', + const selectedAccount = createMockInternalAccount({ address: '0x1' }); + const selectedAccount2 = createMockInternalAccount({ + address: '0x2', }); + getAccountHandler.mockReturnValue(selectedAccount); + triggerSelectedAccountChange(selectedAccount); await controller.addToken({ address: '0x01', symbol: 'A', @@ -1862,10 +1896,8 @@ describe('TokensController', () => { symbol: 'B', decimals: 5, }); - triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - selectedAddress: '0x2', - }); + getAccountHandler.mockReturnValue(selectedAccount2); + triggerSelectedAccountChange(selectedAccount2); expect(controller.state.tokens).toStrictEqual([]); await controller.addToken({ @@ -1873,10 +1905,7 @@ describe('TokensController', () => { symbol: 'C', decimals: 6, }); - triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - selectedAddress: '0x1', - }); + triggerSelectedAccountChange(selectedAccount); expect(controller.state.tokens).toStrictEqual([ { address: '0x01', @@ -1900,10 +1929,7 @@ describe('TokensController', () => { }, ]); - triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - selectedAddress: '0x2', - }); + triggerSelectedAccountChange(selectedAccount2); expect(controller.state.tokens).toStrictEqual([ { address: '0x03', @@ -2010,6 +2036,9 @@ describe('TokensController', () => { describe('Clearing nested lists', () => { it('should clear nest allTokens under chain ID and selected address when an added token is ignored', async () => { const selectedAddress = '0x1'; + const selectedAccount = createMockInternalAccount({ + address: selectedAddress, + }); const tokenAddress = '0x01'; const dummyTokens = [ { @@ -2025,10 +2054,11 @@ describe('TokensController', () => { { options: { chainId: ChainId.mainnet, - selectedAddress, + selectedAccountId: selectedAccount.id, }, }, - async ({ controller }) => { + async ({ controller, getAccountHandler }) => { + getAccountHandler.mockReturnValue(selectedAccount); await controller.addTokens(dummyTokens); controller.ignoreTokens([tokenAddress]); @@ -2041,6 +2071,9 @@ describe('TokensController', () => { it('should clear nest allIgnoredTokens under chain ID and selected address when an ignored token is re-added', async () => { const selectedAddress = '0x1'; + const selectedAccount = createMockInternalAccount({ + address: selectedAddress, + }); const tokenAddress = '0x01'; const dummyTokens = [ { @@ -2056,10 +2089,11 @@ describe('TokensController', () => { { options: { chainId: ChainId.mainnet, - selectedAddress, + selectedAccountId: selectedAccount.id, }, }, - async ({ controller }) => { + async ({ controller, getAccountHandler }) => { + getAccountHandler.mockReturnValue(selectedAccount); await controller.addTokens(dummyTokens); controller.ignoreTokens([tokenAddress]); await controller.addTokens(dummyTokens); @@ -2073,6 +2107,9 @@ describe('TokensController', () => { it('should clear nest allDetectedTokens under chain ID and selected address when an detected token is added to tokens list', async () => { const selectedAddress = '0x1'; + const selectedAccount = createMockInternalAccount({ + address: selectedAddress, + }); const tokenAddress = '0x01'; const dummyTokens = [ { @@ -2088,10 +2125,11 @@ describe('TokensController', () => { { options: { chainId: ChainId.mainnet, - selectedAddress, + selectedAccountId: selectedAccount.id, }, }, - async ({ controller }) => { + async ({ controller, getAccountHandler }) => { + getAccountHandler.mockReturnValue(selectedAccount); await controller.addDetectedTokens(dummyTokens); await controller.addTokens(dummyTokens); @@ -2165,7 +2203,7 @@ type WithControllerCallback = ({ changeNetwork, messenger, approvalController, - triggerPreferencesStateChange, + triggerSelectedAccountChange, }: { controller: TokensController; changeNetwork: (networkControllerState: { @@ -2173,7 +2211,11 @@ type WithControllerCallback = ({ }) => void; messenger: UnrestrictedMessenger; approvalController: ApprovalController; - triggerPreferencesStateChange: (state: PreferencesState) => void; + triggerSelectedAccountChange: (internalAccount: InternalAccount) => void; + getAccountHandler: jest.Mock< + ReturnType, + Parameters + >; }) => Promise | ReturnValue; type WithControllerArgs = @@ -2227,16 +2269,17 @@ async function withController( allowedActions: [ 'ApprovalController:addRequest', 'NetworkController:getNetworkClientById', + 'AccountsController:getAccount', ], allowedEvents: [ 'NetworkController:networkDidChange', - 'PreferencesController:stateChange', + 'AccountsController:selectedEvmAccountChange', 'TokenListController:stateChange', ], }); const controller = new TokensController({ chainId: ChainId.mainnet, - selectedAddress: '0x1', + selectedAccountId: defaultMockInternalAccount.id, // The tests assume that this is set, but they shouldn't make that // assumption. But we have to do this due to a bug in TokensController // where the provider can possibly be `undefined` if `networkClientId` is @@ -2246,10 +2289,20 @@ async function withController( ...options, }); - const triggerPreferencesStateChange = (state: PreferencesState) => { - messenger.publish('PreferencesController:stateChange', state, []); + const triggerSelectedAccountChange = (internalAccount: InternalAccount) => { + messenger.publish( + 'AccountsController:selectedEvmAccountChange', + internalAccount, + ); }; + const getAccountHandler = jest.fn(); + + messenger.registerActionHandler( + `AccountsController:getAccount`, + getAccountHandler.mockReturnValue(defaultMockInternalAccount), + ); + const changeNetwork = ({ selectedNetworkClientId, }: { @@ -2274,7 +2327,8 @@ async function withController( changeNetwork, messenger, approvalController, - triggerPreferencesStateChange, + triggerSelectedAccountChange, + getAccountHandler, }); } diff --git a/packages/assets-controllers/src/TokensController.ts b/packages/assets-controllers/src/TokensController.ts index ce7cb493de..4b843d2906 100644 --- a/packages/assets-controllers/src/TokensController.ts +++ b/packages/assets-controllers/src/TokensController.ts @@ -1,5 +1,9 @@ import { Contract } from '@ethersproject/contracts'; import { Web3Provider } from '@ethersproject/providers'; +import type { + AccountsControllerGetAccountAction, + AccountsControllerSelectedEvmAccountChangeEvent, +} from '@metamask/accounts-controller'; import type { AddApprovalRequest } from '@metamask/approval-controller'; import type { RestrictedControllerMessenger, @@ -19,6 +23,7 @@ import { isValidHexAddress, safelyExecute, } from '@metamask/controller-utils'; +import type { InternalAccount } from '@metamask/keyring-api'; import { abiERC721 } from '@metamask/metamask-eth-abis'; import type { NetworkClientId, @@ -27,10 +32,6 @@ import type { NetworkState, Provider, } from '@metamask/network-controller'; -import type { - PreferencesControllerStateChangeEvent, - PreferencesState, -} from '@metamask/preferences-controller'; import { rpcErrors } from '@metamask/rpc-errors'; import type { Hex } from '@metamask/utils'; import { Mutex } from 'async-mutex'; @@ -136,7 +137,8 @@ export type TokensControllerAddDetectedTokensAction = { */ export type AllowedActions = | AddApprovalRequest - | NetworkControllerGetNetworkClientByIdAction; + | NetworkControllerGetNetworkClientByIdAction + | AccountsControllerGetAccountAction; export type TokensControllerStateChangeEvent = ControllerStateChangeEvent< typeof controllerName, @@ -147,8 +149,8 @@ export type TokensControllerEvents = TokensControllerStateChangeEvent; export type AllowedEvents = | NetworkControllerNetworkDidChangeEvent - | PreferencesControllerStateChangeEvent - | TokenListStateChange; + | TokenListStateChange + | AccountsControllerSelectedEvmAccountChangeEvent; /** * The messenger of the {@link TokensController}. @@ -184,7 +186,7 @@ export class TokensController extends BaseController< #chainId: Hex; - #selectedAddress: string; + #selectedAccountId: string; #provider: Provider | undefined; @@ -194,20 +196,20 @@ export class TokensController extends BaseController< * Tokens controller options * @param options - Constructor options. * @param options.chainId - The chain ID of the current network. - * @param options.selectedAddress - Vault selected address + * @param options.selectedAccountId - Vault selected account id * @param options.provider - Network provider. * @param options.state - Initial state to set on this controller. * @param options.messenger - The controller messenger. */ constructor({ chainId: initialChainId, - selectedAddress, + selectedAccountId, provider, state, messenger, }: { chainId: Hex; - selectedAddress: string; + selectedAccountId: string; provider: Provider | undefined; state?: Partial; messenger: TokensControllerMessenger; @@ -226,7 +228,7 @@ export class TokensController extends BaseController< this.#provider = provider; - this.#selectedAddress = selectedAddress; + this.#selectedAccountId = selectedAccountId; this.#abortController = new AbortController(); @@ -236,8 +238,8 @@ export class TokensController extends BaseController< ); this.messagingSystem.subscribe( - 'PreferencesController:stateChange', - this.#onPreferenceControllerStateChange.bind(this), + 'AccountsController:selectedEvmAccountChange', + this.#onSelectedAccountChange.bind(this), ); this.messagingSystem.subscribe( @@ -273,29 +275,32 @@ export class TokensController extends BaseController< this.#abortController.abort(); this.#abortController = new AbortController(); this.#chainId = chainId; + const selectedAccount = this.messagingSystem.call( + 'AccountsController:getAccount', + this.#selectedAccountId, + ); this.update((state) => { - state.tokens = allTokens[chainId]?.[this.#selectedAddress] || []; + state.tokens = allTokens[chainId]?.[selectedAccount?.address || ''] || []; state.ignoredTokens = - allIgnoredTokens[chainId]?.[this.#selectedAddress] || []; + allIgnoredTokens[chainId]?.[selectedAccount?.address || ''] || []; state.detectedTokens = - allDetectedTokens[chainId]?.[this.#selectedAddress] || []; + allDetectedTokens[chainId]?.[selectedAccount?.address || ''] || []; }); } /** - * Handles the state change of the preference controller. - * @param preferencesState - The new state of the preference controller. - * @param preferencesState.selectedAddress - The current selected address of the preference controller. + * Handles the selected account change in the accounts controller. + * @param selectedAccount - The new selected account */ - #onPreferenceControllerStateChange({ selectedAddress }: PreferencesState) { + #onSelectedAccountChange(selectedAccount: InternalAccount) { const { allTokens, allIgnoredTokens, allDetectedTokens } = this.state; - this.#selectedAddress = selectedAddress; + this.#selectedAccountId = selectedAccount.id; this.update((state) => { - state.tokens = allTokens[this.#chainId]?.[selectedAddress] ?? []; + state.tokens = allTokens[this.#chainId]?.[selectedAccount.address] ?? []; state.ignoredTokens = - allIgnoredTokens[this.#chainId]?.[selectedAddress] ?? []; + allIgnoredTokens[this.#chainId]?.[selectedAccount.address] ?? []; state.detectedTokens = - allDetectedTokens[this.#chainId]?.[selectedAddress] ?? []; + allDetectedTokens[this.#chainId]?.[selectedAccount.address] ?? []; }); } @@ -357,7 +362,6 @@ export class TokensController extends BaseController< networkClientId?: NetworkClientId; }): Promise { const chainId = this.#chainId; - const selectedAddress = this.#selectedAddress; const releaseLock = await this.#mutex.acquire(); const { allTokens, allIgnoredTokens, allDetectedTokens } = this.state; let currentChainId = chainId; @@ -368,8 +372,15 @@ export class TokensController extends BaseController< ).configuration.chainId; } - const accountAddress = interactingAddress || selectedAddress; - const isInteractingWithWalletAccount = accountAddress === selectedAddress; + const internalAccount = this.messagingSystem.call( + 'AccountsController:getAccount', + this.#selectedAccountId, + ); + + // Previously selectedAddress could be an empty string. This is to preserve the behaviour + const accountAddress = interactingAddress || internalAccount?.address || ''; + const isInteractingWithWalletAccount = + accountAddress === internalAccount?.address; try { address = toChecksumHexAddress(address); @@ -578,10 +589,15 @@ export class TokensController extends BaseController< ) { const releaseLock = await this.#mutex.acquire(); - // Get existing tokens for the chain + account + const internalAccount = this.messagingSystem.call( + 'AccountsController:getAccount', + this.#selectedAccountId, + ); + const chainId = detectionDetails?.chainId ?? this.#chainId; + // Previously selectedAddress could be an empty string. This is to preserve the behaviour const accountAddress = - detectionDetails?.selectedAddress ?? this.#selectedAddress; + detectionDetails?.selectedAddress ?? internalAccount?.address ?? ''; const { allTokens, allDetectedTokens, allIgnoredTokens } = this.state; let newTokens = [...(allTokens?.[chainId]?.[accountAddress] ?? [])]; @@ -648,9 +664,17 @@ export class TokensController extends BaseController< // We may be detecting tokens on a different chain/account pair than are currently configured. // Re-point `tokens` and `detectedTokens` to keep them referencing the current chain/account. - newTokens = newAllTokens?.[this.#chainId]?.[this.#selectedAddress] || []; + const currentInternalAccount = this.messagingSystem.call( + 'AccountsController:getAccount', + this.#selectedAccountId, + ); + + // Previously selectedAddress could be an empty string. This is to preserve the behaviour + const currentAddress = currentInternalAccount?.address || ''; + + newTokens = newAllTokens?.[this.#chainId]?.[currentAddress] || []; newDetectedTokens = - newAllDetectedTokens?.[this.#chainId]?.[this.#selectedAddress] || []; + newAllDetectedTokens?.[this.#chainId]?.[currentAddress] || []; this.update((state) => { state.tokens = newTokens; @@ -806,6 +830,12 @@ export class TokensController extends BaseController< throw rpcErrors.invalidParams(`Invalid address "${asset.address}"`); } + // Validate if account is an evm account + const selectedAccount = this.messagingSystem.call( + 'AccountsController:getAccount', + this.#selectedAccountId, + ); + // Validate contract if (await this.#detectIsERC721(asset.address, networkClientId)) { @@ -896,7 +926,8 @@ export class TokensController extends BaseController< id: this.#generateRandomId(), time: Date.now(), type, - interactingAddress: interactingAddress || this.#selectedAddress, + // Previously selectedAddress could be an empty string. This is to preserve the behaviour + interactingAddress: interactingAddress || selectedAccount?.address || '', }; await this.#requestApproval(suggestedAssetMeta); @@ -940,8 +971,14 @@ export class TokensController extends BaseController< interactingChainId, } = params; const { allTokens, allIgnoredTokens, allDetectedTokens } = this.state; + const selectedInternalAccount = this.messagingSystem.call( + 'AccountsController:getAccount', + this.#selectedAccountId, + ); + // Previously selectedAddress could be an empty string. This is to preserve the behaviour + const userAddressToAddTokens = + interactingAddress ?? selectedInternalAccount?.address ?? ''; - const userAddressToAddTokens = interactingAddress ?? this.#selectedAddress; const chainIdToAddTokens = interactingChainId ?? this.#chainId; let newAllTokens = allTokens; From c8ac59024ac260c5fc94c1ac983481c97ea0f26a Mon Sep 17 00:00:00 2001 From: Monte Lai Date: Thu, 6 Jun 2024 00:09:33 +0800 Subject: [PATCH 02/22] fix: add more tests --- .../src/TokenDetectionController.test.ts | 80 ++++++++++++ .../src/TokensController.test.ts | 116 ++++++++++++++++++ 2 files changed, 196 insertions(+) diff --git a/packages/assets-controllers/src/TokenDetectionController.test.ts b/packages/assets-controllers/src/TokenDetectionController.test.ts index a21c05ff46..249f5d869b 100644 --- a/packages/assets-controllers/src/TokenDetectionController.test.ts +++ b/packages/assets-controllers/src/TokenDetectionController.test.ts @@ -2112,6 +2112,86 @@ describe('TokenDetectionController', () => { }, ); }); + + it('should not trigger `TokensController:addDetectedTokens` action when selectedAccount is not found', async () => { + const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ + [sampleTokenA.address]: new BN(1), + }); + + const mockTrackMetaMetricsEvent = jest.fn(); + + await withController( + { + options: { + disabled: false, + getBalancesInSingleCall: mockGetBalancesInSingleCall, + trackMetaMetricsEvent: mockTrackMetaMetricsEvent, + selectedAccountId: '', + }, + }, + async ({ + controller, + mockGetAccount, + mockTokenListGetState, + callActionSpy, + }) => { + // @ts-expect-error forcing an undefined value + mockGetAccount(undefined); + mockTokenListGetState({ + ...getDefaultTokenListState(), + tokensChainsCache: { + '0x1': { + timestamp: 0, + data: { + [sampleTokenA.address]: { + name: sampleTokenA.name, + symbol: sampleTokenA.symbol, + decimals: sampleTokenA.decimals, + address: sampleTokenA.address, + occurrences: 1, + aggregators: sampleTokenA.aggregators, + iconUrl: sampleTokenA.image, + }, + }, + }, + }, + }); + + await controller.detectTokens({ + networkClientId: NetworkType.mainnet, + }); + + expect(callActionSpy).toHaveBeenLastCalledWith( + 'TokensController:addDetectedTokens', + [ + { + address: '0x514910771AF9Ca656af840dff83E8264EcF986CA', + aggregators: [ + 'Paraswap', + 'PMM', + 'AirswapLight', + '0x', + 'Bancor', + 'CoinGecko', + 'Zapper', + 'Kleros', + 'Zerion', + 'CMC', + '1inch', + ], + decimals: 18, + image: + 'https://static.cx.metamask.io/api/v1/tokenIcons/1/0x514910771af9ca656af840dff83e8264ecf986ca.png', + isERC721: false, + name: 'Chainlink', + symbol: 'LINK', + }, + ], + { chainId: '0x1', selectedAddress: '' }, + ); + }, + ); + }); }); }); diff --git a/packages/assets-controllers/src/TokensController.test.ts b/packages/assets-controllers/src/TokensController.test.ts index 442ff88374..5561793cea 100644 --- a/packages/assets-controllers/src/TokensController.test.ts +++ b/packages/assets-controllers/src/TokensController.test.ts @@ -2196,6 +2196,122 @@ describe('TokensController', () => { }); }); }); + + describe('when selectedAccountId is not set or account not found', () => { + describe('detectTokens', () => { + it('should update the token states to empty arrays if the selectedAccountId account is undefined', async () => { + await withController(async ({ controller, changeNetwork }) => { + ContractMock.mockReturnValue( + buildMockEthersERC721Contract({ supportsInterface: false }), + ); + + // getAccountHandler.mockReturnValue(undefined); + changeNetwork({ selectedNetworkClientId: InfuraNetworkType.sepolia }); + + expect(controller.state.tokens).toStrictEqual([]); + expect(controller.state.ignoredTokens).toStrictEqual([]); + expect(controller.state.detectedTokens).toStrictEqual([]); + }); + }); + + it('should update the token states to empty arrays if the selectedAccountId is not set', async () => { + await withController( + { options: { selectedAccountId: '' } }, + async ({ controller, changeNetwork }) => { + ContractMock.mockReturnValue( + buildMockEthersERC721Contract({ supportsInterface: false }), + ); + + changeNetwork({ + selectedNetworkClientId: InfuraNetworkType.sepolia, + }); + + expect(controller.state.tokens).toStrictEqual([]); + expect(controller.state.ignoredTokens).toStrictEqual([]); + expect(controller.state.detectedTokens).toStrictEqual([]); + }, + ); + }); + }); + + describe('addToken', () => { + it('should handle undefined selected account', async () => { + await withController(async ({ controller, getAccountHandler }) => { + getAccountHandler.mockReturnValue(undefined); + const contractAddresses = Object.keys(contractMaps); + const erc721ContractAddresses = contractAddresses.filter( + (contractAddress) => contractMaps[contractAddress].erc721 === true, + ); + const address = erc721ContractAddresses[0]; + const { symbol, decimals } = contractMaps[address]; + + await controller.addToken({ address, symbol, decimals }); + + expect(controller.state.tokens).toStrictEqual([]); + }); + }); + }); + + describe('addDetectedTokens', () => { + it('should handle undefined selected account', async () => { + await withController(async ({ controller, getAccountHandler }) => { + getAccountHandler.mockReturnValue(undefined); + await controller.addDetectedTokens([ + { + address: '0x01', + symbol: 'barA', + decimals: 2, + aggregators: [], + }, + ]); + console.log(controller.state.allDetectedTokens[ChainId.mainnet]); + expect(controller.state.detectedTokens[0]).toStrictEqual({ + address: '0x01', + decimals: 2, + image: undefined, + symbol: 'barA', + aggregators: [], + isERC721: undefined, + name: undefined, + }); + }); + }); + }); + + describe('watchAsset', () => { + it('should handle undefined selected account', async () => { + await withController( + async ({ controller, approvalController, getAccountHandler }) => { + const requestId = '12345'; + const addAndShowApprovalRequestSpy = jest + .spyOn(approvalController, 'addAndShowApprovalRequest') + .mockResolvedValue(undefined); + const asset = buildToken(); + ContractMock.mockReturnValue( + buildMockEthersERC721Contract({ supportsInterface: false }), + ); + uuidV1Mock.mockReturnValue(requestId); + getAccountHandler.mockReturnValue(undefined); + await controller.watchAsset({ asset, type: 'ERC20' }); + + expect(controller.state.tokens).toHaveLength(0); + expect(controller.state.tokens).toStrictEqual([]); + expect(addAndShowApprovalRequestSpy).toHaveBeenCalledTimes(1); + expect(addAndShowApprovalRequestSpy).toHaveBeenCalledWith({ + id: requestId, + origin: ORIGIN_METAMASK, + type: ApprovalType.WatchAsset, + requestData: { + id: requestId, + interactingAddress: '', // this is the default value if account is not found + asset, + }, + }); + }, + ); + }); + }); + }); }); type WithControllerCallback = ({ From 87b6a32ce748283a5f3f2dcf66e6a4d0aa33d406 Mon Sep 17 00:00:00 2001 From: Monte Lai Date: Thu, 6 Jun 2024 00:09:41 +0800 Subject: [PATCH 03/22] fix: lower branch coverage --- packages/assets-controllers/jest.config.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/assets-controllers/jest.config.js b/packages/assets-controllers/jest.config.js index c5034d8960..b72d599a24 100644 --- a/packages/assets-controllers/jest.config.js +++ b/packages/assets-controllers/jest.config.js @@ -17,7 +17,7 @@ module.exports = merge(baseConfig, { // An object that configures minimum threshold enforcement for coverage results coverageThreshold: { global: { - branches: 90.35, + branches: 90.11, functions: 96.74, lines: 97.34, statements: 97.36, From 22ee7826ef0fe60bda388c3de5f5e02508650d8f Mon Sep 17 00:00:00 2001 From: Monte Lai Date: Wed, 12 Jun 2024 14:16:00 +0800 Subject: [PATCH 04/22] refactor: token detection test --- .../src/TokenDetectionController.test.ts | 462 ++++++++++-------- .../src/TokenDetectionController.ts | 26 +- 2 files changed, 256 insertions(+), 232 deletions(-) diff --git a/packages/assets-controllers/src/TokenDetectionController.test.ts b/packages/assets-controllers/src/TokenDetectionController.test.ts index 249f5d869b..309190bbcf 100644 --- a/packages/assets-controllers/src/TokenDetectionController.test.ts +++ b/packages/assets-controllers/src/TokenDetectionController.test.ts @@ -207,7 +207,10 @@ describe('TokenDetectionController', () => { await withController( { isKeyringUnlocked: false, - options: { selectedAccountId: defaultSelectedAccount.id }, + options: {}, + mocks: { + getSelectedAccount: defaultSelectedAccount, + }, }, async ({ controller }) => { const mockTokens = sinon.stub(controller, 'detectTokens'); @@ -226,12 +229,12 @@ describe('TokenDetectionController', () => { await withController( { isKeyringUnlocked: false, - options: { - selectedAccountId: defaultSelectedAccount.id, + options: {}, + mocks: { + getSelectedAccount: defaultSelectedAccount, }, }, - async ({ controller, mockGetAccount, triggerKeyringUnlock }) => { - mockGetAccount(defaultSelectedAccount); + async ({ controller, triggerKeyringUnlock }) => { const mockTokens = sinon.stub(controller, 'detectTokens'); await controller.start(); @@ -266,12 +269,12 @@ describe('TokenDetectionController', () => { it('should poll and detect tokens on interval while on supported networks', async () => { await withController( { - options: { - selectedAccountId: defaultSelectedAccount.id, + options: {}, + mocks: { + getSelectedAccount: defaultSelectedAccount, }, }, - async ({ controller, mockGetAccount }) => { - mockGetAccount(defaultSelectedAccount); + async ({ controller }) => { const mockTokens = sinon.stub(controller, 'detectTokens'); controller.setIntervalLength(10); @@ -292,11 +295,12 @@ describe('TokenDetectionController', () => { { options: { getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAccountId: defaultSelectedAccount.id, + }, + mocks: { + getSelectedAccount: defaultSelectedAccount, }, }, - async ({ controller, mockGetAccount, mockNetworkState }) => { - mockGetAccount(defaultSelectedAccount); + async ({ controller, mockNetworkState }) => { mockNetworkState({ ...defaultNetworkState, selectedNetworkClientId: NetworkType.goerli, @@ -319,16 +323,13 @@ describe('TokenDetectionController', () => { { options: { getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAccountId: selectedAccount.id, + }, + mocks: { + getAccount: selectedAccount, + getSelectedAccount: selectedAccount, }, }, - async ({ - controller, - mockGetAccount, - mockTokenListGetState, - callActionSpy, - }) => { - mockGetAccount(selectedAccount); + async ({ controller, mockTokenListGetState, callActionSpy }) => { mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { @@ -374,18 +375,19 @@ describe('TokenDetectionController', () => { { options: { getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAccountId: selectedAccount.id, + }, + mocks: { + getAccount: selectedAccount, + getSelectedAccount: selectedAccount, }, }, async ({ controller, - mockGetAccount, mockTokenListGetState, mockNetworkState, mockGetNetworkClientById, callActionSpy, }) => { - mockGetAccount(selectedAccount); mockNetworkState({ ...defaultNetworkState, selectedNetworkClientId: 'polygon', @@ -445,16 +447,13 @@ describe('TokenDetectionController', () => { options: { getBalancesInSingleCall: mockGetBalancesInSingleCall, interval, - selectedAccountId: selectedAccount.id, + }, + mocks: { + getAccount: selectedAccount, + getSelectedAccount: selectedAccount, }, }, - async ({ - controller, - mockGetAccount, - mockTokenListGetState, - callActionSpy, - }) => { - mockGetAccount(selectedAccount); + async ({ controller, mockTokenListGetState, callActionSpy }) => { const tokenListState = { ...getDefaultTokenListState(), tokensChainsCache: { @@ -512,17 +511,18 @@ describe('TokenDetectionController', () => { { options: { getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAccountId: selectedAccount.id, + }, + mocks: { + getAccount: selectedAccount, + getSelectedAccount: selectedAccount, }, }, async ({ controller, - mockGetAccount, mockTokensGetState, mockTokenListGetState, callActionSpy, }) => { - mockGetAccount(selectedAccount); mockTokensGetState({ ...getDefaultTokensState(), ignoredTokens: [sampleTokenA.address], @@ -564,16 +564,12 @@ describe('TokenDetectionController', () => { { options: { getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAccountId: defaultSelectedAccount.id, + }, + mocks: { + getSelectedAccount: defaultSelectedAccount, }, }, - async ({ - controller, - mockGetAccount, - mockTokenListGetState, - callActionSpy, - }) => { - mockGetAccount(defaultSelectedAccount); + async ({ controller, mockTokenListGetState, callActionSpy }) => { mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { @@ -630,7 +626,9 @@ describe('TokenDetectionController', () => { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAccountId: firstSelectedAccount.id, + }, + mocks: { + getSelectedAccount: firstSelectedAccount, }, }, async ({ @@ -639,7 +637,6 @@ describe('TokenDetectionController', () => { triggerSelectedAccountChange, callActionSpy, }) => { - mockGetAccount(firstSelectedAccount); mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { @@ -688,7 +685,9 @@ describe('TokenDetectionController', () => { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAccountId: selectedAccount.id, + }, + mocks: { + getSelectedAccount: selectedAccount, }, }, async ({ @@ -744,7 +743,9 @@ describe('TokenDetectionController', () => { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAccountId: firstSelectedAccount.id, + }, + mocks: { + getSelectedAccount: firstSelectedAccount, }, isKeyringUnlocked: false, }, @@ -803,7 +804,9 @@ describe('TokenDetectionController', () => { options: { disabled: true, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAccountId: firstSelectedAccount.id, + }, + mocks: { + getSelectedAccount: firstSelectedAccount, }, }, async ({ @@ -871,7 +874,9 @@ describe('TokenDetectionController', () => { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAccountId: firstSelectedAccount.id, + }, + mocks: { + getSelectedAccount: firstSelectedAccount, }, }, async ({ @@ -881,7 +886,6 @@ describe('TokenDetectionController', () => { triggerSelectedAccountChange, callActionSpy, }) => { - mockGetAccount(firstSelectedAccount); mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { @@ -934,7 +938,9 @@ describe('TokenDetectionController', () => { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAccountId: selectedAccount.id, + }, + mocks: { + getSelectedAccount: selectedAccount, }, }, async ({ @@ -1003,7 +1009,9 @@ describe('TokenDetectionController', () => { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAccountId: firstSelectedAccount.id, + }, + mocks: { + getSelectedAccount: firstSelectedAccount, }, }, async ({ @@ -1056,16 +1064,17 @@ describe('TokenDetectionController', () => { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAccountId: selectedAccount.id, + }, + mocks: { + getAccount: selectedAccount, + getSelectedAccount: selectedAccount, }, }, async ({ - mockGetAccount, mockTokenListGetState, triggerPreferencesStateChange, callActionSpy, }) => { - mockGetAccount(selectedAccount); mockTokenListGetState({ ...getDefaultTokenListState(), tokenList: { @@ -1093,121 +1102,124 @@ describe('TokenDetectionController', () => { }, ); }); + }); - describe('when keyring is locked', () => { - it('should not detect new tokens after switching between accounts', async () => { - const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ - [sampleTokenA.address]: new BN(1), - }); - const firstSelectedAccount = createMockInternalAccount({ - address: '0x0000000000000000000000000000000000000001', - }); - const secondSelectedAccount = createMockInternalAccount({ - address: '0x0000000000000000000000000000000000000002', - }); - await withController( - { - options: { - disabled: false, - getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAccountId: firstSelectedAccount.id, - }, - isKeyringUnlocked: false, + describe('when keyring is locked', () => { + it('should not detect new tokens after switching between accounts', async () => { + const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ + [sampleTokenA.address]: new BN(1), + }); + const firstSelectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); + const secondSelectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000002', + }); + await withController( + { + options: { + disabled: false, + getBalancesInSingleCall: mockGetBalancesInSingleCall, }, - async ({ - mockGetAccount, - mockTokenListGetState, - triggerPreferencesStateChange, - triggerSelectedAccountChange, - callActionSpy, - }) => { - mockGetAccount(firstSelectedAccount); - mockTokenListGetState({ - ...getDefaultTokenListState(), - tokenList: { - [sampleTokenA.address]: { - name: sampleTokenA.name, - symbol: sampleTokenA.symbol, - decimals: sampleTokenA.decimals, - address: sampleTokenA.address, - occurrences: 1, - aggregators: sampleTokenA.aggregators, - iconUrl: sampleTokenA.image, - }, + mocks: { + getSelectedAccount: firstSelectedAccount, + getAccount: firstSelectedAccount, + }, + isKeyringUnlocked: false, + }, + async ({ + mockGetAccount, + mockTokenListGetState, + triggerPreferencesStateChange, + triggerSelectedAccountChange, + callActionSpy, + }) => { + mockTokenListGetState({ + ...getDefaultTokenListState(), + tokenList: { + [sampleTokenA.address]: { + name: sampleTokenA.name, + symbol: sampleTokenA.symbol, + decimals: sampleTokenA.decimals, + address: sampleTokenA.address, + occurrences: 1, + aggregators: sampleTokenA.aggregators, + iconUrl: sampleTokenA.image, }, - }); + }, + }); - triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - useTokenDetection: true, - }); - mockGetAccount(secondSelectedAccount); - triggerSelectedAccountChange(secondSelectedAccount); - await advanceTime({ clock, duration: 1 }); + triggerPreferencesStateChange({ + ...getDefaultPreferencesState(), + useTokenDetection: true, + }); + mockGetAccount(secondSelectedAccount); + triggerSelectedAccountChange(secondSelectedAccount); + await advanceTime({ clock, duration: 1 }); - expect(callActionSpy).not.toHaveBeenCalledWith( - 'TokensController:addDetectedTokens', - ); - }, - ); - }); + expect(callActionSpy).not.toHaveBeenCalledWith( + 'TokensController:addDetectedTokens', + ); + }, + ); + }); - it('should not detect new tokens after enabling token detection', async () => { - const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ - [sampleTokenA.address]: new BN(1), - }); - const selectedAccount = createMockInternalAccount({ - address: '0x0000000000000000000000000000000000000001', - }); - await withController( - { - options: { - disabled: false, - getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAccountId: selectedAccount.id, - }, - isKeyringUnlocked: false, + it('should not detect new tokens after enabling token detection', async () => { + const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ + [sampleTokenA.address]: new BN(1), + }); + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); + await withController( + { + options: { + disabled: false, + getBalancesInSingleCall: mockGetBalancesInSingleCall, }, - async ({ - mockGetAccount, - mockTokenListGetState, - triggerPreferencesStateChange, - callActionSpy, - }) => { - mockGetAccount(selectedAccount); - mockTokenListGetState({ - ...getDefaultTokenListState(), - tokenList: { - [sampleTokenA.address]: { - name: sampleTokenA.name, - symbol: sampleTokenA.symbol, - decimals: sampleTokenA.decimals, - address: sampleTokenA.address, - occurrences: 1, - aggregators: sampleTokenA.aggregators, - iconUrl: sampleTokenA.image, - }, + isKeyringUnlocked: false, + mocks: { + getSelectedAccount: selectedAccount, + getAccount: selectedAccount, + }, + }, + async ({ + mockTokenListGetState, + triggerPreferencesStateChange, + callActionSpy, + }) => { + mockTokenListGetState({ + ...getDefaultTokenListState(), + tokenList: { + [sampleTokenA.address]: { + name: sampleTokenA.name, + symbol: sampleTokenA.symbol, + decimals: sampleTokenA.decimals, + address: sampleTokenA.address, + occurrences: 1, + aggregators: sampleTokenA.aggregators, + iconUrl: sampleTokenA.image, }, - }); + }, + }); - triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - useTokenDetection: false, - }); - await advanceTime({ clock, duration: 1 }); + triggerPreferencesStateChange({ + ...getDefaultPreferencesState(), + useTokenDetection: false, + }); + await advanceTime({ clock, duration: 1 }); - triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - useTokenDetection: true, - }); - await advanceTime({ clock, duration: 1 }); + triggerPreferencesStateChange({ + ...getDefaultPreferencesState(), + useTokenDetection: true, + }); + await advanceTime({ clock, duration: 1 }); - expect(callActionSpy).not.toHaveBeenCalledWith( - 'TokensController:addDetectedTokens', - ); - }, - ); - }); + expect(callActionSpy).not.toHaveBeenCalledWith( + 'TokensController:addDetectedTokens', + ); + }, + ); }); }); @@ -1227,7 +1239,10 @@ describe('TokenDetectionController', () => { options: { disabled: true, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAccountId: firstSelectedAccount.id, + }, + mocks: { + getAccount: firstSelectedAccount, + getSelectedAccount: firstSelectedAccount, }, }, async ({ @@ -1237,7 +1252,6 @@ describe('TokenDetectionController', () => { triggerSelectedAccountChange, callActionSpy, }) => { - mockGetAccount(firstSelectedAccount); mockTokenListGetState({ ...getDefaultTokenListState(), tokenList: { @@ -1280,16 +1294,17 @@ describe('TokenDetectionController', () => { options: { disabled: true, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAccountId: selectedAccount.id, + }, + mocks: { + getAccount: selectedAccount, + getSelectedAccount: selectedAccount, }, }, async ({ - mockGetAccount, mockTokenListGetState, triggerPreferencesStateChange, callActionSpy, }) => { - mockGetAccount(selectedAccount); mockTokenListGetState({ ...getDefaultTokenListState(), tokenList: { @@ -1349,16 +1364,17 @@ describe('TokenDetectionController', () => { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAccountId: selectedAccount.id, + }, + mocks: { + getAccount: selectedAccount, + getSelectedAccount: selectedAccount, }, }, async ({ - mockGetAccount, mockTokenListGetState, callActionSpy, triggerNetworkDidChange, }) => { - mockGetAccount(selectedAccount); mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { @@ -1409,16 +1425,17 @@ describe('TokenDetectionController', () => { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAccountId: selectedAccount.id, + }, + mocks: { + getAccount: selectedAccount, + getSelectedAccount: selectedAccount, }, }, async ({ - mockGetAccount, mockTokenListGetState, callActionSpy, triggerNetworkDidChange, }) => { - mockGetAccount(selectedAccount); mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { @@ -1464,16 +1481,17 @@ describe('TokenDetectionController', () => { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAccountId: selectedAccount.id, + }, + mocks: { + getAccount: selectedAccount, + getSelectedAccount: selectedAccount, }, }, async ({ - mockGetAccount, mockTokenListGetState, callActionSpy, triggerNetworkDidChange, }) => { - mockGetAccount(selectedAccount); mockTokenListGetState({ ...getDefaultTokenListState(), tokenList: { @@ -1515,17 +1533,18 @@ describe('TokenDetectionController', () => { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAccountId: selectedAccount.id, }, isKeyringUnlocked: false, + mocks: { + getAccount: selectedAccount, + getSelectedAccount: selectedAccount, + }, }, async ({ - mockGetAccount, mockTokenListGetState, callActionSpy, triggerNetworkDidChange, }) => { - mockGetAccount(selectedAccount); mockTokenListGetState({ ...getDefaultTokenListState(), tokenList: { @@ -1569,16 +1588,17 @@ describe('TokenDetectionController', () => { options: { disabled: true, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAccountId: selectedAccount.id, + }, + mocks: { + getAccount: selectedAccount, + getSelectedAccount: selectedAccount, }, }, async ({ - mockGetAccount, mockTokenListGetState, callActionSpy, triggerNetworkDidChange, }) => { - mockGetAccount(selectedAccount); mockTokenListGetState({ ...getDefaultTokenListState(), tokenList: { @@ -1632,16 +1652,17 @@ describe('TokenDetectionController', () => { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAccountId: selectedAccount.id, + }, + mocks: { + getSelectedAccount: selectedAccount, + getAccount: selectedAccount, }, }, async ({ - mockGetAccount, mockTokenListGetState, callActionSpy, triggerTokenListStateChange, }) => { - mockGetAccount(selectedAccount); const tokenList = { [sampleTokenA.address]: { name: sampleTokenA.name, @@ -1692,16 +1713,17 @@ describe('TokenDetectionController', () => { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAccountId: selectedAccount.id, + }, + mocks: { + getSelectedAccount: selectedAccount, + getAccount: selectedAccount, }, }, async ({ - mockGetAccount, mockTokenListGetState, callActionSpy, triggerTokenListStateChange, }) => { - mockGetAccount(selectedAccount); const tokenListState = { ...getDefaultTokenListState(), tokenList: {}, @@ -1731,17 +1753,18 @@ describe('TokenDetectionController', () => { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAccountId: selectedAccount.id, }, isKeyringUnlocked: false, + mocks: { + getSelectedAccount: selectedAccount, + getAccount: selectedAccount, + }, }, async ({ - mockGetAccount, mockTokenListGetState, callActionSpy, triggerTokenListStateChange, }) => { - mockGetAccount(selectedAccount); const tokenListState = { ...getDefaultTokenListState(), tokenList: { @@ -1783,16 +1806,17 @@ describe('TokenDetectionController', () => { options: { disabled: true, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAccountId: selectedAccount.id, + }, + mocks: { + getSelectedAccount: selectedAccount, + getAccount: selectedAccount, }, }, async ({ - mockGetAccount, mockTokenListGetState, callActionSpy, triggerTokenListStateChange, }) => { - mockGetAccount(selectedAccount); const tokenListState = { ...getDefaultTokenListState(), tokenList: { @@ -1843,7 +1867,10 @@ describe('TokenDetectionController', () => { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAccountId: selectedAccount.id, + }, + mocks: { + getSelectedAccount: selectedAccount, + getAccount: selectedAccount, }, }, async ({ controller, mockTokenListGetState }) => { @@ -1911,7 +1938,10 @@ describe('TokenDetectionController', () => { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAccountId: selectedAccount.id, + }, + mocks: { + getSelectedAccount: selectedAccount, + getAccount: selectedAccount, }, }, async ({ @@ -1919,9 +1949,7 @@ describe('TokenDetectionController', () => { mockNetworkState, triggerPreferencesStateChange, callActionSpy, - mockGetAccount, }) => { - mockGetAccount(selectedAccount); mockNetworkState({ ...defaultNetworkState, selectedNetworkClientId: NetworkType.goerli, @@ -1959,16 +1987,17 @@ describe('TokenDetectionController', () => { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAccountId: selectedAccount.id, + }, + mocks: { + getSelectedAccount: selectedAccount, + getAccount: selectedAccount, }, }, async ({ controller, - mockGetAccount, triggerPreferencesStateChange, callActionSpy, }) => { - mockGetAccount(selectedAccount); triggerPreferencesStateChange({ ...getDefaultPreferencesState(), useTokenDetection: false, @@ -2008,16 +2037,13 @@ describe('TokenDetectionController', () => { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAccountId: selectedAccount.id, + }, + mocks: { + getSelectedAccount: selectedAccount, + getAccount: selectedAccount, }, }, - async ({ - controller, - mockGetAccount, - mockTokenListGetState, - callActionSpy, - }) => { - mockGetAccount(selectedAccount); + async ({ controller, mockTokenListGetState, callActionSpy }) => { mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { @@ -2070,11 +2096,13 @@ describe('TokenDetectionController', () => { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, trackMetaMetricsEvent: mockTrackMetaMetricsEvent, - selectedAccountId: selectedAccount.id, + }, + mocks: { + getSelectedAccount: selectedAccount, + getAccount: selectedAccount, }, }, - async ({ controller, mockGetAccount, mockTokenListGetState }) => { - mockGetAccount(selectedAccount); + async ({ controller, mockTokenListGetState }) => { mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { @@ -2126,7 +2154,6 @@ describe('TokenDetectionController', () => { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, trackMetaMetricsEvent: mockTrackMetaMetricsEvent, - selectedAccountId: '', }, }, async ({ @@ -2255,6 +2282,10 @@ type WithControllerOptions = { options?: Partial[0]>; isKeyringUnlocked?: boolean; messenger?: ControllerMessenger; + mocks?: { + getAccount?: InternalAccount; + getSelectedAccount?: InternalAccount; + }; }; type WithControllerArgs = @@ -2274,22 +2305,25 @@ async function withController( ...args: WithControllerArgs ): Promise { const [{ ...rest }, fn] = args.length === 2 ? args : [{}, args[0]]; - const { options, isKeyringUnlocked, messenger } = rest; + const { options, isKeyringUnlocked, messenger, mocks } = rest; const controllerMessenger = messenger ?? new ControllerMessenger(); const mockGetAccount = jest.fn(); controllerMessenger.registerActionHandler( 'AccountsController:getAccount', - mockGetAccount, + mockGetAccount.mockReturnValue( + mocks?.getAccount ?? createMockInternalAccount({ address: '0x1' }), + ), ); const mockGetSelectedAccount = jest.fn(); controllerMessenger.registerActionHandler( 'AccountsController:getSelectedAccount', - mockGetSelectedAccount.mockReturnValue({ - address: '0x1', - } as InternalAccount), + mockGetSelectedAccount.mockReturnValue( + mocks?.getSelectedAccount ?? + createMockInternalAccount({ address: '0x1' }), + ), ); const mockKeyringState = jest.fn(); controllerMessenger.registerActionHandler( diff --git a/packages/assets-controllers/src/TokenDetectionController.ts b/packages/assets-controllers/src/TokenDetectionController.ts index a4f914fb86..b2b53b7bab 100644 --- a/packages/assets-controllers/src/TokenDetectionController.ts +++ b/packages/assets-controllers/src/TokenDetectionController.ts @@ -188,19 +188,16 @@ export class TokenDetectionController extends StaticIntervalPollingController< * @param options.messenger - The controller messaging system. * @param options.disabled - If set to true, all network requests are blocked. * @param options.interval - Polling interval used to fetch new token rates - * @param options.selectedAccountId - Vault selected address * @param options.getBalancesInSingleCall - Gets the balances of a list of tokens for the given address. * @param options.trackMetaMetricsEvent - Sets options for MetaMetrics event tracking. */ constructor({ - selectedAccountId, interval = DEFAULT_INTERVAL, disabled = true, getBalancesInSingleCall, trackMetaMetricsEvent, messenger, }: { - selectedAccountId?: string; interval?: number; disabled?: boolean; getBalancesInSingleCall: AssetsContractController['getBalancesInSingleCall']; @@ -225,9 +222,9 @@ export class TokenDetectionController extends StaticIntervalPollingController< this.#disabled = disabled; this.setIntervalLength(interval); - this.#selectedAccountId = - selectedAccountId ?? - this.messagingSystem.call('AccountsController:getSelectedAccount').id; + this.#selectedAccountId = this.messagingSystem.call( + 'AccountsController:getSelectedAccount', + ).id; const { chainId, networkClientId } = this.#getCorrectChainIdAndNetworkClientId(); @@ -289,7 +286,7 @@ export class TokenDetectionController extends StaticIntervalPollingController< if (isDetectionChangedFromPreferences) { await this.#restartTokenDetection({ - selectedAccountId: selectedAccount.id, + selectedAddress: selectedAccount.address, }); } }, @@ -303,7 +300,7 @@ export class TokenDetectionController extends StaticIntervalPollingController< if (didSelectedAccountIdChanged) { this.#selectedAccountId = internalAccount.id; await this.#restartTokenDetection({ - selectedAccountId: this.#selectedAccountId, + selectedAddress: internalAccount.address, }); } }, @@ -437,23 +434,16 @@ export class TokenDetectionController extends StaticIntervalPollingController< * in case of address change or user session initialization. * * @param options - Options for restart token detection. - * @param options.selectedAccountId - the id of the InternalAccount against which to detect for token balances + * @param options.selectedAddress - the selectedAddress against which to detect for token balances * @param options.networkClientId - The ID of the network client to use. */ async #restartTokenDetection({ - selectedAccountId, + selectedAddress, networkClientId, }: { - selectedAccountId?: string; + selectedAddress?: string; networkClientId?: NetworkClientId; } = {}): Promise { - const internalAccount = this.messagingSystem.call( - 'AccountsController:getAccount', - selectedAccountId ?? this.#selectedAccountId, - ); - - const selectedAddress = internalAccount?.address || ''; - await this.detectTokens({ networkClientId, selectedAddress, From f8d1303547d8e90ef1c2de07c9d530771db09295 Mon Sep 17 00:00:00 2001 From: Monte Lai Date: Wed, 12 Jun 2024 14:16:19 +0800 Subject: [PATCH 05/22] refactor: token balances test --- .../src/TokenBalancesController.test.ts | 332 +++++++++--------- 1 file changed, 170 insertions(+), 162 deletions(-) diff --git a/packages/assets-controllers/src/TokenBalancesController.test.ts b/packages/assets-controllers/src/TokenBalancesController.test.ts index 5ac19788a4..49390c64fe 100644 --- a/packages/assets-controllers/src/TokenBalancesController.test.ts +++ b/packages/assets-controllers/src/TokenBalancesController.test.ts @@ -1,5 +1,6 @@ import { ControllerMessenger } from '@metamask/base-controller'; import { toHex } from '@metamask/controller-utils'; +import type { InternalAccount } from '@metamask/keyring-api'; import BN from 'bn.js'; import { flushPromises } from '../../../tests/helpers'; @@ -37,14 +38,58 @@ function getMessenger( }); } -describe('TokenBalancesController', () => { - let controllerMessenger: ControllerMessenger; - let messenger: TokenBalancesControllerMessenger; +const setupController = ({ + config, + mock, +}: { + config?: Partial[0]>; + mock: { + getBalanceOf?: BN; + selectedAccount: InternalAccount; + }; +}): { + controller: TokenBalancesController; + messenger: TokenBalancesControllerMessenger; + mockSelectedAccount: jest.Mock; + mockGetERC20BalanceOf: jest.Mock; + triggerTokensStateChange: (state: TokensControllerState) => Promise; +} => { + const controllerMessenger = new ControllerMessenger< + AllowedActions, + AllowedEvents + >(); + const messenger = getMessenger(controllerMessenger); + + const mockSelectedAccount = jest.fn().mockReturnValue(mock.selectedAccount); + const mockGetERC20BalanceOf = jest.fn().mockReturnValue(mock.getBalanceOf); + + controllerMessenger.registerActionHandler( + 'AccountsController:getSelectedAccount', + mockSelectedAccount, + ); + + const controller = new TokenBalancesController({ + getERC20BalanceOf: mockGetERC20BalanceOf, + messenger, + ...config, + }); + + const triggerTokensStateChange = async (state: TokensControllerState) => { + controllerMessenger.publish('TokensController:stateChange', state, []); + }; + return { + controller, + messenger, + mockSelectedAccount, + mockGetERC20BalanceOf, + triggerTokensStateChange, + }; +}; + +describe('TokenBalancesController', () => { beforeEach(() => { jest.useFakeTimers(); - controllerMessenger = new ControllerMessenger(); - messenger = getMessenger(controllerMessenger); }); afterEach(() => { @@ -52,27 +97,16 @@ describe('TokenBalancesController', () => { }); it('should set default state', () => { - controllerMessenger.registerActionHandler( - 'AccountsController:getSelectedAccount', - jest - .fn() - .mockReturnValue(createMockInternalAccount({ address: '0x1234' })), - ); - const controller = new TokenBalancesController({ - getERC20BalanceOf: jest.fn(), - messenger, + const { controller } = setupController({ + mock: { + selectedAccount: createMockInternalAccount({ address: '0x1234' }), + }, }); expect(controller.state).toStrictEqual({ contractBalances: {} }); }); it('should poll and update balances in the right interval', async () => { - controllerMessenger.registerActionHandler( - 'AccountsController:getSelectedAccount', - jest - .fn() - .mockReturnValue(createMockInternalAccount({ address: '0x1234' })), - ); const updateBalancesSpy = jest.spyOn( TokenBalancesController.prototype, 'updateBalances', @@ -81,7 +115,7 @@ describe('TokenBalancesController', () => { new TokenBalancesController({ interval: 10, getERC20BalanceOf: jest.fn(), - messenger, + messenger: getMessenger(new ControllerMessenger()), }); await flushPromises(); @@ -95,18 +129,16 @@ describe('TokenBalancesController', () => { it('should update balances if enabled', async () => { const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; - controllerMessenger.registerActionHandler( - 'AccountsController:getSelectedAccount', - jest - .fn() - .mockReturnValue(createMockInternalAccount({ address: '0x1234' })), - ); - const controller = new TokenBalancesController({ - disabled: false, - tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }], - interval: 10, - getERC20BalanceOf: jest.fn().mockReturnValue(new BN(1)), - messenger, + const { controller } = setupController({ + config: { + disabled: false, + tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }], + interval: 10, + }, + mock: { + getBalanceOf: new BN(1), + selectedAccount: createMockInternalAccount({ address: '0x1234' }), + }, }); await controller.updateBalances(); @@ -118,18 +150,16 @@ describe('TokenBalancesController', () => { it('should not update balances if disabled', async () => { const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; - controllerMessenger.registerActionHandler( - 'AccountsController:getSelectedAccount', - jest - .fn() - .mockReturnValue(createMockInternalAccount({ address: '0x1234' })), - ); - const controller = new TokenBalancesController({ - disabled: true, - tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }], - interval: 10, - getERC20BalanceOf: jest.fn().mockReturnValue(new BN(1)), - messenger, + const { controller } = setupController({ + config: { + disabled: true, + tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }], + interval: 10, + }, + mock: { + selectedAccount: createMockInternalAccount({ address: '0x1234' }), + getBalanceOf: new BN(1), + }, }); await controller.updateBalances(); @@ -139,18 +169,16 @@ describe('TokenBalancesController', () => { it('should update balances if controller is manually enabled', async () => { const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; - controllerMessenger.registerActionHandler( - 'AccountsController:getSelectedAccount', - jest - .fn() - .mockReturnValue(createMockInternalAccount({ address: '0x1234' })), - ); - const controller = new TokenBalancesController({ - disabled: true, - tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }], - interval: 10, - getERC20BalanceOf: jest.fn().mockReturnValue(new BN(1)), - messenger, + const { controller } = setupController({ + config: { + disabled: true, + tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }], + interval: 10, + }, + mock: { + selectedAccount: createMockInternalAccount({ address: '0x1234' }), + getBalanceOf: new BN(1), + }, }); await controller.updateBalances(); @@ -167,18 +195,16 @@ describe('TokenBalancesController', () => { it('should not update balances if controller is manually disabled', async () => { const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; - controllerMessenger.registerActionHandler( - 'AccountsController:getSelectedAccount', - jest - .fn() - .mockReturnValue(createMockInternalAccount({ address: '0x1234' })), - ); - const controller = new TokenBalancesController({ - disabled: false, - tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }], - interval: 10, - getERC20BalanceOf: jest.fn().mockReturnValue(new BN(1)), - messenger, + const { controller } = setupController({ + config: { + disabled: false, + tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }], + interval: 10, + }, + mock: { + selectedAccount: createMockInternalAccount({ address: '0x1234' }), + getBalanceOf: new BN(1), + }, }); await controller.updateBalances(); @@ -197,22 +223,17 @@ describe('TokenBalancesController', () => { it('should update balances if tokens change and controller is manually enabled', async () => { const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; - controllerMessenger.registerActionHandler( - 'AccountsController:getSelectedAccount', - jest - .fn() - .mockReturnValue(createMockInternalAccount({ address: '0x1234' })), - ); - const controller = new TokenBalancesController({ - disabled: true, - tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }], - interval: 10, - getERC20BalanceOf: jest.fn().mockReturnValue(new BN(1)), - messenger, + const { controller, triggerTokensStateChange } = setupController({ + config: { + disabled: true, + tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }], + interval: 10, + }, + mock: { + selectedAccount: createMockInternalAccount({ address: '0x1234' }), + getBalanceOf: new BN(1), + }, }); - const triggerTokensStateChange = async (state: TokensControllerState) => { - controllerMessenger.publish('TokensController:stateChange', state, []); - }; await controller.updateBalances(); @@ -237,22 +258,17 @@ describe('TokenBalancesController', () => { it('should not update balances if tokens change and controller is manually disabled', async () => { const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; - controllerMessenger.registerActionHandler( - 'AccountsController:getSelectedAccount', - jest - .fn() - .mockReturnValue(createMockInternalAccount({ address: '0x1234' })), - ); - const controller = new TokenBalancesController({ - disabled: false, - tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }], - interval: 10, - getERC20BalanceOf: jest.fn().mockReturnValue(new BN(1)), - messenger, + const { controller, triggerTokensStateChange } = setupController({ + config: { + disabled: false, + tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }], + interval: 10, + }, + mock: { + selectedAccount: createMockInternalAccount({ address: '0x1234' }), + getBalanceOf: new BN(1), + }, }); - const triggerTokensStateChange = async (state: TokensControllerState) => { - controllerMessenger.publish('TokensController:stateChange', state, []); - }; await controller.updateBalances(); @@ -278,16 +294,14 @@ describe('TokenBalancesController', () => { }); it('should clear previous interval', async () => { - controllerMessenger.registerActionHandler( - 'AccountsController:getSelectedAccount', - jest - .fn() - .mockReturnValue(createMockInternalAccount({ address: '0x1234' })), - ); - const controller = new TokenBalancesController({ - interval: 1337, - getERC20BalanceOf: jest.fn(), - messenger, + const { controller } = setupController({ + config: { + interval: 1337, + }, + mock: { + selectedAccount: createMockInternalAccount({ address: '0x1234' }), + getBalanceOf: new BN(1), + }, }); const mockClearTimeout = jest.spyOn(global, 'clearTimeout'); @@ -310,19 +324,17 @@ describe('TokenBalancesController', () => { aggregators: [], }, ]; - controllerMessenger.registerActionHandler( - 'AccountsController:getSelectedAccount', - jest - .fn() - .mockReturnValue( - createMockInternalAccount({ address: selectedAddress }), - ), - ); - const controller = new TokenBalancesController({ - interval: 1337, - tokens, - getERC20BalanceOf: jest.fn().mockReturnValue(new BN(1)), - messenger, + const { controller } = setupController({ + config: { + interval: 1337, + tokens, + }, + mock: { + selectedAccount: createMockInternalAccount({ + address: selectedAddress, + }), + getBalanceOf: new BN(1), + }, }); expect(controller.state.contractBalances).toStrictEqual({}); @@ -337,9 +349,6 @@ describe('TokenBalancesController', () => { it('should handle `getERC20BalanceOf` error case', async () => { const errorMsg = 'Failed to get balance'; const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; - const getERC20BalanceOfStub = jest - .fn() - .mockReturnValue(Promise.reject(new Error(errorMsg))); const tokens: Token[] = [ { address, @@ -349,17 +358,21 @@ describe('TokenBalancesController', () => { }, ]; - controllerMessenger.registerActionHandler( - 'AccountsController:getSelectedAccount', - jest.fn().mockReturnValue(createMockInternalAccount({ address })), - ); - const controller = new TokenBalancesController({ - interval: 1337, - tokens, - getERC20BalanceOf: getERC20BalanceOfStub, - messenger, + const { controller, mockGetERC20BalanceOf } = setupController({ + config: { + interval: 1337, + tokens, + }, + mock: { + selectedAccount: createMockInternalAccount({ + address, + }), + }, }); + // @ts-expect-error Testing error case + mockGetERC20BalanceOf.mockReturnValueOnce(new Error(errorMsg)); + expect(controller.state.contractBalances).toStrictEqual({}); await controller.updateBalances(); @@ -367,8 +380,7 @@ describe('TokenBalancesController', () => { expect(tokens[0].hasBalanceError).toBe(true); expect(controller.state.contractBalances[address]).toBe(toHex(0)); - getERC20BalanceOfStub.mockReturnValue(new BN(1)); - + mockGetERC20BalanceOf.mockReturnValueOnce(new BN(1)); await controller.updateBalances(); expect(tokens[0].hasBalanceError).toBe(false); @@ -377,20 +389,18 @@ describe('TokenBalancesController', () => { }); it('should update balances when tokens change', async () => { - controllerMessenger.registerActionHandler( - 'AccountsController:getSelectedAccount', - jest - .fn() - .mockReturnValue(createMockInternalAccount({ address: '0x1234' })), - ); - const controller = new TokenBalancesController({ - getERC20BalanceOf: jest.fn(), - interval: 1337, - messenger, + const { controller, triggerTokensStateChange } = setupController({ + config: { + interval: 1337, + }, + mock: { + selectedAccount: createMockInternalAccount({ + address: '0x1234', + }), + getBalanceOf: new BN(1), + }, }); - const triggerTokensStateChange = async (state: TokensControllerState) => { - controllerMessenger.publish('TokensController:stateChange', state, []); - }; + const updateBalancesSpy = jest.spyOn(controller, 'updateBalances'); await triggerTokensStateChange({ @@ -408,20 +418,18 @@ describe('TokenBalancesController', () => { }); it('should update token balances when detected tokens are added', async () => { - controllerMessenger.registerActionHandler( - 'AccountsController:getSelectedAccount', - jest - .fn() - .mockReturnValue(createMockInternalAccount({ address: '0x1234' })), - ); - const controller = new TokenBalancesController({ - interval: 1337, - getERC20BalanceOf: jest.fn().mockReturnValue(new BN(1)), - messenger, + const { controller, triggerTokensStateChange } = setupController({ + config: { + interval: 1337, + }, + mock: { + selectedAccount: createMockInternalAccount({ + address: '0x1234', + }), + getBalanceOf: new BN(1), + }, }); - const triggerTokensStateChange = async (state: TokensControllerState) => { - controllerMessenger.publish('TokensController:stateChange', state, []); - }; + expect(controller.state.contractBalances).toStrictEqual({}); await triggerTokensStateChange({ From 580038136aecc88aa16a501aa238cf924488fc64 Mon Sep 17 00:00:00 2001 From: Monte Lai Date: Wed, 12 Jun 2024 23:14:29 +0800 Subject: [PATCH 06/22] refactor: token rates tests --- .../src/TokenRatesController.test.ts | 495 ++++++++---------- .../src/TokenRatesController.ts | 91 +++- 2 files changed, 279 insertions(+), 307 deletions(-) diff --git a/packages/assets-controllers/src/TokenRatesController.test.ts b/packages/assets-controllers/src/TokenRatesController.test.ts index 5cf53a0bae..73a71504e8 100644 --- a/packages/assets-controllers/src/TokenRatesController.test.ts +++ b/packages/assets-controllers/src/TokenRatesController.test.ts @@ -1,3 +1,4 @@ +import { ControllerMessenger } from '@metamask/base-controller'; import { ChainId, InfuraNetworkType, @@ -35,6 +36,8 @@ import type { TokenRatesConfig, Token, TokenRatesState, + AllowedActions, + AllowedEvents, } from './TokenRatesController'; import type { TokensControllerState } from './TokensController'; @@ -43,6 +46,62 @@ const defaultMockInternalAccount = createMockInternalAccount({ }); const mockTokenAddress = '0x0000000000000000000000000000000000000010'; +/** + * Builds a new ControllerMessenger instance for TokenRatesController. + * @returns A new ControllerMessenger instance. + */ +const buildMessenger = (): ControllerMessenger< + AllowedActions, + AllowedEvents +> => { + return new ControllerMessenger(); +}; + +const buildTokenRatesControllerMessenger = ({ + messenger = buildMessenger(), + mocks: { + getSelectedAccount = defaultMockInternalAccount, + getAccount = defaultMockInternalAccount, + } = {}, +} = {}) => { + const tokenRatesControllerMessenger = messenger.getRestricted({ + name: 'TokenRatesController', + allowedActions: [ + 'AccountsController:getAccount', + 'AccountsController:getSelectedAccount', + ], + allowedEvents: [ + 'AccountsController:selectedEvmAccountChange', + 'TokensController:stateChange', + 'NetworkController:stateChange', + ], + }); + + const mockGetAccount = jest.fn().mockReturnValue(getAccount); + const mockGetSelectedAccount = jest.fn().mockReturnValue(getSelectedAccount); + + messenger.registerActionHandler( + 'AccountsController:getAccount', + mockGetAccount, + ); + + messenger.registerActionHandler( + 'AccountsController:getSelectedAccount', + mockGetSelectedAccount, + ); + + const triggerSelectedAccountChange = (account: InternalAccount) => { + messenger.publish('AccountsController:selectedEvmAccountChange', account); + }; + + return { + messenger: tokenRatesControllerMessenger, + triggerSelectedAccountChange, + mockGetAccount, + mockGetSelectedAccount, + }; +}; + describe('TokenRatesController', () => { afterEach(() => { jest.restoreAllMocks(); @@ -60,13 +119,12 @@ describe('TokenRatesController', () => { }); it('should set default state', () => { + const { messenger } = buildTokenRatesControllerMessenger(); const controller = new TokenRatesController({ + messenger, getNetworkClientById: jest.fn(), - getInternalAccount: jest.fn(), chainId: '0x1', ticker: NetworksTicker.mainnet, - selectedAccountId: defaultMockInternalAccount.id, - onSelectedAccountChange: jest.fn(), onTokensStateChange: jest.fn(), onNetworkStateChange: jest.fn(), tokenPricesService: buildMockTokenPricesService(), @@ -77,13 +135,12 @@ describe('TokenRatesController', () => { }); it('should initialize with the default config', () => { + const { messenger } = buildTokenRatesControllerMessenger(); const controller = new TokenRatesController({ + messenger, getNetworkClientById: jest.fn(), - getInternalAccount: jest.fn(), chainId: '0x1', ticker: NetworksTicker.mainnet, - selectedAccountId: defaultMockInternalAccount.id, - onSelectedAccountChange: jest.fn(), onTokensStateChange: jest.fn(), onNetworkStateChange: jest.fn(), tokenPricesService: buildMockTokenPricesService(), @@ -102,14 +159,13 @@ describe('TokenRatesController', () => { it('should not poll by default', async () => { const fetchSpy = jest.spyOn(globalThis, 'fetch'); + const { messenger } = buildTokenRatesControllerMessenger(); new TokenRatesController({ + messenger, interval: 100, getNetworkClientById: jest.fn(), - getInternalAccount: jest.fn(), chainId: '0x1', ticker: NetworksTicker.mainnet, - selectedAccountId: defaultMockInternalAccount.id, - onSelectedAccountChange: jest.fn(), onTokensStateChange: jest.fn(), onNetworkStateChange: jest.fn(), tokenPricesService: buildMockTokenPricesService(), @@ -137,15 +193,10 @@ describe('TokenRatesController', () => { const chainId = '0xC'; const selectedAccount = defaultMockInternalAccount; const tokenAddresses = ['0xE1', '0xE2']; - const mockGetInternalAccount = jest - .fn() - .mockReturnValue(selectedAccount); await withController( { options: { chainId, - selectedAccountId: selectedAccount.id, - getInternalAccount: mockGetInternalAccount, }, config: { allTokens: { @@ -162,6 +213,10 @@ describe('TokenRatesController', () => { }, allDetectedTokens: {}, }, + mocks: { + getSelectedAccount: selectedAccount, + getAccount: selectedAccount, + }, }, async ({ controller, controllerEvents }) => { const updateExchangeRatesSpy = jest @@ -196,15 +251,10 @@ describe('TokenRatesController', () => { const chainId = '0xC'; const selectedAccount = createMockInternalAccount({ address: '0xA' }); const tokenAddresses = ['0xE1', '0xE2']; - const mockGetInternalAccount = jest - .fn() - .mockReturnValue(selectedAccount); await withController( { options: { chainId, - selectedAccountId: selectedAccount.id, - getInternalAccount: mockGetInternalAccount, }, config: { allTokens: {}, @@ -221,6 +271,10 @@ describe('TokenRatesController', () => { }, }, }, + mocks: { + getAccount: selectedAccount, + getSelectedAccount: selectedAccount, + }, }, async ({ controller, controllerEvents }) => { const updateExchangeRatesSpy = jest @@ -254,9 +308,6 @@ describe('TokenRatesController', () => { it('should not update exchange rates if both the "all tokens" or "all detected tokens" are exactly the same', async () => { const chainId = '0xC'; const selectedAccount = createMockInternalAccount({ address: '0xA' }); - const mockGetInternalAccount = jest - .fn() - .mockReturnValue(selectedAccount); const tokensState = { allTokens: { [chainId]: { @@ -276,10 +327,12 @@ describe('TokenRatesController', () => { { options: { chainId, - selectedAccountId: selectedAccount.id, - getInternalAccount: mockGetInternalAccount, }, config: tokensState, + mocks: { + getAccount: selectedAccount, + getSelectedAccount: selectedAccount, + }, }, async ({ controller, controllerEvents }) => { const updateExchangeRatesSpy = jest @@ -315,14 +368,14 @@ describe('TokenRatesController', () => { { options: { chainId, - getInternalAccount: jest - .fn() - .mockReturnValue(defaultMockInternalAccount), }, config: { allTokens: tokens, allDetectedTokens: {}, - selectedAccountId: selectedAccount.id, + }, + mocks: { + getAccount: selectedAccount, + getSelectedAccount: selectedAccount, }, }, async ({ controller, controllerEvents }) => { @@ -346,15 +399,14 @@ describe('TokenRatesController', () => { it('should not update exchange rates if a new token is added to "all detected tokens" but is already present in "all tokens"', async () => { const chainId = '0xC'; const selectedAccount = createMockInternalAccount({ address: '0xA' }); - const mockGetInternalAccount = jest - .fn() - .mockReturnValue(selectedAccount); await withController( { options: { chainId, - selectedAccountId: selectedAccount.id, - getInternalAccount: mockGetInternalAccount, + }, + mocks: { + getSelectedAccount: selectedAccount, + getAccount: selectedAccount, }, config: { allTokens: { @@ -415,15 +467,10 @@ describe('TokenRatesController', () => { it('should not update exchange rates if a new token is added to "all tokens" but is already present in "all detected tokens"', async () => { const chainId = '0xC'; const selectedAccount = createMockInternalAccount({ address: '0xA' }); - const mockGetInternalAccount = jest - .fn() - .mockReturnValue(selectedAccount); await withController( { options: { chainId, - selectedAccountId: selectedAccount.id, - getInternalAccount: mockGetInternalAccount, }, config: { allTokens: {}, @@ -440,6 +487,10 @@ describe('TokenRatesController', () => { }, }, }, + mocks: { + getSelectedAccount: selectedAccount, + getAccount: selectedAccount, + }, }, async ({ controller, controllerEvents }) => { const updateExchangeRatesSpy = jest @@ -484,15 +535,10 @@ describe('TokenRatesController', () => { it('should not update exchange rates if none of the addresses in "all tokens" or "all detected tokens" change, even if other parts of the token change', async () => { const chainId = '0xC'; const selectedAccount = createMockInternalAccount({ address: '0xA' }); - const mockGetInternalAccount = jest - .fn() - .mockReturnValue(selectedAccount); await withController( { options: { chainId, - selectedAccountId: selectedAccount.id, - getInternalAccount: mockGetInternalAccount, }, config: { allTokens: {}, @@ -509,6 +555,10 @@ describe('TokenRatesController', () => { }, }, }, + mocks: { + getSelectedAccount: selectedAccount, + getAccount: selectedAccount, + }, }, async ({ controller, controllerEvents }) => { const updateExchangeRatesSpy = jest @@ -544,15 +594,10 @@ describe('TokenRatesController', () => { const selectedAccount = createMockInternalAccount({ address: '0xAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA', }); - const mockGetInternalAccount = jest - .fn() - .mockReturnValue(selectedAccount); await withController( { options: { chainId, - selectedAccountId: selectedAccount.id, - getInternalAccount: mockGetInternalAccount, }, config: { allTokens: {}, @@ -569,6 +614,10 @@ describe('TokenRatesController', () => { }, }, }, + mocks: { + getSelectedAccount: selectedAccount, + getAccount: selectedAccount, + }, }, async ({ controller, controllerEvents }) => { const updateExchangeRatesSpy = jest @@ -602,15 +651,10 @@ describe('TokenRatesController', () => { it('should not update exchange rates if any of the addresses in "all tokens" or "all detected tokens" merely change order', async () => { const chainId = '0xC'; const selectedAccount = createMockInternalAccount({ address: '0xA' }); - const mockGetInternalAccount = jest - .fn() - .mockReturnValue(selectedAccount); await withController( { options: { chainId, - selectedAccountId: selectedAccount.id, - getInternalAccount: mockGetInternalAccount, }, config: { allTokens: {}, @@ -633,6 +677,10 @@ describe('TokenRatesController', () => { }, }, }, + mocks: { + getSelectedAccount: selectedAccount, + getAccount: selectedAccount, + }, }, async ({ controller, controllerEvents }) => { const updateExchangeRatesSpy = jest @@ -674,16 +722,11 @@ describe('TokenRatesController', () => { it('should not update exchange rates when any of the addresses in the "all tokens" collection change', async () => { const chainId = '0xC'; const selectedAccount = createMockInternalAccount({ address: '0xA' }); - const mockGetInternalAccount = jest - .fn() - .mockReturnValue(selectedAccount); const tokenAddresses = ['0xE1', '0xE2']; await withController( { options: { chainId, - selectedAccountId: selectedAccount.id, - getInternalAccount: mockGetInternalAccount, }, config: { allTokens: { @@ -700,6 +743,10 @@ describe('TokenRatesController', () => { }, allDetectedTokens: {}, }, + mocks: { + getSelectedAccount: selectedAccount, + getAccount: selectedAccount, + }, }, async ({ controller, controllerEvents }) => { const updateExchangeRatesSpy = jest @@ -731,16 +778,11 @@ describe('TokenRatesController', () => { it('should not update exchange rates when any of the addresses in the "all detected tokens" collection change', async () => { const chainId = '0xC'; const selectedAccount = createMockInternalAccount({ address: '0xA' }); - const mockGetInternalAccount = jest - .fn() - .mockReturnValue(selectedAccount); const tokenAddresses = ['0xE1', '0xE2']; await withController( { options: { chainId, - selectedAccountId: selectedAccount.id, - getInternalAccount: mockGetInternalAccount, }, config: { allTokens: {}, @@ -757,6 +799,10 @@ describe('TokenRatesController', () => { }, }, }, + mocks: { + getSelectedAccount: selectedAccount, + getAccount: selectedAccount, + }, }, async ({ controller, controllerEvents }) => { const updateExchangeRatesSpy = jest @@ -812,14 +858,14 @@ describe('TokenRatesController', () => { .mockImplementation((listener) => { networkStateChangeListener = listener; }); + + const { messenger } = buildTokenRatesControllerMessenger(); const controller = new TokenRatesController({ + messenger, interval: 100, - getInternalAccount: jest.fn(), getNetworkClientById, chainId: toHex(1337), ticker: 'TEST', - selectedAccountId: defaultMockInternalAccount.id, - onSelectedAccountChange: jest.fn(), onTokensStateChange: jest.fn(), onNetworkStateChange, tokenPricesService: buildMockTokenPricesService(), @@ -851,14 +897,13 @@ describe('TokenRatesController', () => { .mockImplementation((listener) => { networkStateChangeListener = listener; }); + const { messenger } = buildTokenRatesControllerMessenger(); const controller = new TokenRatesController({ + messenger, interval: 100, - getInternalAccount: jest.fn(), getNetworkClientById, chainId: toHex(1337), ticker: 'TEST', - selectedAccountId: defaultMockInternalAccount.id, - onSelectedAccountChange: jest.fn(), onTokensStateChange: jest.fn(), onNetworkStateChange, tokenPricesService: buildMockTokenPricesService(), @@ -890,14 +935,13 @@ describe('TokenRatesController', () => { .mockImplementation((listener) => { networkStateChangeListener = listener; }); + const { messenger } = buildTokenRatesControllerMessenger(); const controller = new TokenRatesController({ interval: 100, - getInternalAccount: jest.fn(), + messenger, getNetworkClientById, chainId: toHex(1337), ticker: 'TEST', - selectedAccountId: defaultMockInternalAccount.id, - onSelectedAccountChange: jest.fn(), onTokensStateChange: jest.fn(), onNetworkStateChange, tokenPricesService: buildMockTokenPricesService(), @@ -927,14 +971,13 @@ describe('TokenRatesController', () => { .mockImplementation((listener) => { networkStateChangeListener = listener; }); + const { messenger } = buildTokenRatesControllerMessenger(); const controller = new TokenRatesController({ + messenger, interval: 100, - getInternalAccount: jest.fn(), getNetworkClientById, chainId: toHex(1337), ticker: 'TEST', - selectedAccountId: defaultMockInternalAccount.id, - onSelectedAccountChange: jest.fn(), onTokensStateChange: jest.fn(), onNetworkStateChange, tokenPricesService: buildMockTokenPricesService(), @@ -964,14 +1007,13 @@ describe('TokenRatesController', () => { .mockImplementation((listener) => { networkStateChangeListener = listener; }); + const { messenger } = buildTokenRatesControllerMessenger(); const controller = new TokenRatesController({ + messenger, interval: 100, - getInternalAccount: jest.fn(), getNetworkClientById, chainId: toHex(1337), ticker: 'TEST', - selectedAccountId: defaultMockInternalAccount.id, - onSelectedAccountChange: jest.fn(), onTokensStateChange: jest.fn(), onNetworkStateChange, tokenPricesService: buildMockTokenPricesService(), @@ -1005,14 +1047,13 @@ describe('TokenRatesController', () => { .mockImplementation((listener) => { networkStateChangeListener = listener; }); + const { messenger } = buildTokenRatesControllerMessenger(); const controller = new TokenRatesController({ + messenger, interval: 100, - getInternalAccount: jest.fn(), getNetworkClientById, chainId: toHex(1337), ticker: 'TEST', - selectedAccountId: defaultMockInternalAccount.id, - onSelectedAccountChange: jest.fn(), onTokensStateChange: jest.fn(), onNetworkStateChange, tokenPricesService: buildMockTokenPricesService(), @@ -1043,14 +1084,13 @@ describe('TokenRatesController', () => { .mockImplementation((listener) => { networkStateChangeListener = listener; }); + const { messenger } = buildTokenRatesControllerMessenger(); const controller = new TokenRatesController({ + messenger, interval: 100, - getInternalAccount: jest.fn(), getNetworkClientById, chainId: toHex(1337), ticker: 'TEST', - selectedAccountId: defaultMockInternalAccount.id, - onSelectedAccountChange: jest.fn(), onTokensStateChange: jest.fn(), onNetworkStateChange, tokenPricesService: buildMockTokenPricesService(), @@ -1081,14 +1121,13 @@ describe('TokenRatesController', () => { .mockImplementation((listener) => { networkStateChangeListener = listener; }); + const { messenger } = buildTokenRatesControllerMessenger(); const controller = new TokenRatesController({ + messenger, interval: 100, - getInternalAccount: jest.fn(), getNetworkClientById, chainId: toHex(1337), ticker: 'TEST', - selectedAccountId: defaultMockInternalAccount.id, - onSelectedAccountChange: jest.fn(), onTokensStateChange: jest.fn(), onNetworkStateChange, tokenPricesService: buildMockTokenPricesService(), @@ -1117,14 +1156,13 @@ describe('TokenRatesController', () => { .mockImplementation((listener) => { networkStateChangeListener = listener; }); + const { messenger } = buildTokenRatesControllerMessenger(); const controller = new TokenRatesController({ + messenger, interval: 100, - getInternalAccount: jest.fn(), getNetworkClientById, chainId: toHex(1337), ticker: 'TEST', - selectedAccountId: defaultMockInternalAccount.id, - onSelectedAccountChange: jest.fn(), onTokensStateChange: jest.fn(), onNetworkStateChange, tokenPricesService: buildMockTokenPricesService(), @@ -1159,24 +1197,15 @@ describe('TokenRatesController', () => { address: alternateSelectedAddress, }); - let selectedAccountChangeListener: ( - interalAccount: InternalAccount, - ) => Promise; - const onSelectedAccountChange = jest - .fn() - .mockImplementation((listener) => { - selectedAccountChangeListener = listener; - }); - + const { messenger, triggerSelectedAccountChange } = + buildTokenRatesControllerMessenger(); const controller = new TokenRatesController( { + messenger, interval: 100, getNetworkClientById: jest.fn(), - getInternalAccount: jest.fn(), chainId: '0x1', ticker: NetworksTicker.mainnet, - selectedAccountId: defaultMockInternalAccount.id, - onSelectedAccountChange, onTokensStateChange: jest.fn(), onNetworkStateChange: jest.fn(), tokenPricesService: buildMockTokenPricesService(), @@ -1197,32 +1226,21 @@ describe('TokenRatesController', () => { .spyOn(controller, 'updateExchangeRates') .mockResolvedValue(); - // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - await selectedAccountChangeListener!(alternativeAccount); + triggerSelectedAccountChange(alternativeAccount); expect(updateExchangeRatesSpy).toHaveBeenCalled(); }); it('should not update exchange rates when preferences state changes without selected address changing', async () => { - // TODO: Replace `any` with type - - let selectedAccountChangeListener: ( - interalAccount: InternalAccount, - ) => Promise; - const onSelectedAccountChange = jest - .fn() - .mockImplementation((listener) => { - selectedAccountChangeListener = listener; - }); + const { messenger, triggerSelectedAccountChange } = + buildTokenRatesControllerMessenger(); const controller = new TokenRatesController( { + messenger, interval: 100, - getInternalAccount: jest.fn(), getNetworkClientById: jest.fn(), chainId: '0x1', ticker: NetworksTicker.mainnet, - selectedAccountId: defaultMockInternalAccount.id, - onSelectedAccountChange, onTokensStateChange: jest.fn(), onNetworkStateChange: jest.fn(), tokenPricesService: buildMockTokenPricesService(), @@ -1243,8 +1261,7 @@ describe('TokenRatesController', () => { .spyOn(controller, 'updateExchangeRates') .mockResolvedValue(); - // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - await selectedAccountChangeListener!(defaultMockInternalAccount); + triggerSelectedAccountChange(defaultMockInternalAccount); expect(updateExchangeRatesSpy).not.toHaveBeenCalled(); }); @@ -1257,24 +1274,15 @@ describe('TokenRatesController', () => { const alternateAccount = createMockInternalAccount({ address: alternateSelectedAddress, }); - let selectedAccountChangeListener: ( - interalAccount: InternalAccount, - ) => Promise; - const onSelectedAccountChange = jest - .fn() - .mockImplementation((listener) => { - selectedAccountChangeListener = listener; - }); - + const { messenger, triggerSelectedAccountChange } = + buildTokenRatesControllerMessenger(); const controller = new TokenRatesController( { + messenger, interval: 100, - getInternalAccount: jest.fn(), getNetworkClientById: jest.fn(), chainId: '0x1', ticker: NetworksTicker.mainnet, - selectedAccountId: defaultMockInternalAccount.id, - onSelectedAccountChange, onTokensStateChange: jest.fn(), onNetworkStateChange: jest.fn(), tokenPricesService: buildMockTokenPricesService(), @@ -1294,8 +1302,7 @@ describe('TokenRatesController', () => { .spyOn(controller, 'updateExchangeRates') .mockResolvedValue(); - // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - await selectedAccountChangeListener!(alternateAccount); + triggerSelectedAccountChange(alternateAccount); expect(updateExchangeRatesSpy).not.toHaveBeenCalled(); }); @@ -1318,17 +1325,14 @@ describe('TokenRatesController', () => { const interval = 100; const tokenPricesService = buildMockTokenPricesService(); jest.spyOn(tokenPricesService, 'fetchTokenPrices'); + const { messenger } = buildTokenRatesControllerMessenger(); const controller = new TokenRatesController( { + messenger, interval, getNetworkClientById: jest.fn(), - getInternalAccount: jest - .fn() - .mockReturnValue(defaultMockInternalAccount), chainId: '0x1', ticker: NetworksTicker.mainnet, - selectedAccountId: defaultMockInternalAccount.id, - onSelectedAccountChange: jest.fn(), onTokensStateChange: jest.fn(), onNetworkStateChange: jest.fn(), tokenPricesService, @@ -1365,17 +1369,14 @@ describe('TokenRatesController', () => { const interval = 100; const tokenPricesService = buildMockTokenPricesService(); jest.spyOn(tokenPricesService, 'fetchTokenPrices'); + const { messenger } = buildTokenRatesControllerMessenger(); const controller = new TokenRatesController( { + messenger, interval, getNetworkClientById: jest.fn(), - getInternalAccount: jest - .fn() - .mockReturnValue(defaultMockInternalAccount), chainId: '0x1', ticker: NetworksTicker.mainnet, - selectedAccountId: defaultMockInternalAccount.id, - onSelectedAccountChange: jest.fn(), onTokensStateChange: jest.fn(), onNetworkStateChange: jest.fn(), tokenPricesService, @@ -1422,13 +1423,13 @@ describe('TokenRatesController', () => { const interval = 100; const tokenPricesService = buildMockTokenPricesService(); jest.spyOn(tokenPricesService, 'fetchTokenPrices'); + const { messenger } = buildTokenRatesControllerMessenger(); const controller = new TokenRatesController( { + messenger, interval, chainId: '0x2', ticker: 'ticker', - selectedAccountId: defaultMockInternalAccount.id, - onSelectedAccountChange: jest.fn(), onTokensStateChange: jest.fn(), onNetworkStateChange: jest.fn(), getNetworkClientById: jest.fn().mockReturnValue({ @@ -1437,9 +1438,6 @@ describe('TokenRatesController', () => { ticker: NetworksTicker.mainnet, }, }), - getInternalAccount: jest - .fn() - .mockReturnValue(defaultMockInternalAccount), tokenPricesService, }, { @@ -1478,12 +1476,12 @@ describe('TokenRatesController', () => { return currency === 'ETH'; }, }); + const { messenger } = buildTokenRatesControllerMessenger(); const controller = new TokenRatesController( { + messenger, chainId: '0x2', ticker: 'ticker', - selectedAccountId: defaultMockInternalAccount.id, - onSelectedAccountChange: jest.fn(), onTokensStateChange: jest.fn(), onNetworkStateChange: jest.fn(), getNetworkClientById: jest.fn().mockReturnValue({ @@ -1492,9 +1490,6 @@ describe('TokenRatesController', () => { ticker: NetworksTicker.mainnet, }, }), - getInternalAccount: jest - .fn() - .mockReturnValue(defaultMockInternalAccount), tokenPricesService, }, { @@ -1586,12 +1581,12 @@ describe('TokenRatesController', () => { return currency !== 'LOL'; }, }); + const { messenger } = buildTokenRatesControllerMessenger(); const controller = new TokenRatesController( { + messenger, chainId: '0x2', ticker: 'ticker', - selectedAccountId: defaultMockInternalAccount.id, - onSelectedAccountChange: jest.fn(), onTokensStateChange: jest.fn(), onNetworkStateChange: jest.fn(), getNetworkClientById: jest.fn().mockReturnValue({ @@ -1600,9 +1595,6 @@ describe('TokenRatesController', () => { ticker: 'LOL', }, }), - getInternalAccount: jest - .fn() - .mockReturnValue(defaultMockInternalAccount), tokenPricesService, }, { @@ -1693,12 +1685,12 @@ describe('TokenRatesController', () => { ); const tokenPricesService = buildMockTokenPricesService(); + const { messenger } = buildTokenRatesControllerMessenger(); const controller = new TokenRatesController( { + messenger, chainId: '0x2', ticker: 'ETH', - selectedAccountId: defaultMockInternalAccount.id, - onSelectedAccountChange: jest.fn(), onTokensStateChange: jest.fn(), onNetworkStateChange: jest.fn(), getNetworkClientById: jest.fn().mockReturnValue({ @@ -1707,9 +1699,6 @@ describe('TokenRatesController', () => { ticker: 'LOL', }, }), - getInternalAccount: jest - .fn() - .mockReturnValue(defaultMockInternalAccount), tokenPricesService, }, { @@ -1752,13 +1741,13 @@ describe('TokenRatesController', () => { const interval = 100; const tokenPricesService = buildMockTokenPricesService(); jest.spyOn(tokenPricesService, 'fetchTokenPrices'); + const { messenger } = buildTokenRatesControllerMessenger(); const controller = new TokenRatesController( { + messenger, interval, chainId: '0x2', ticker: 'ticker', - selectedAccountId: defaultMockInternalAccount.id, - onSelectedAccountChange: jest.fn(), onTokensStateChange: jest.fn(), onNetworkStateChange: jest.fn(), getNetworkClientById: jest.fn().mockReturnValue({ @@ -1767,9 +1756,6 @@ describe('TokenRatesController', () => { ticker: NetworksTicker.mainnet, }, }), - getInternalAccount: jest - .fn() - .mockReturnValue(defaultMockInternalAccount), tokenPricesService, }, { @@ -1815,14 +1801,8 @@ describe('TokenRatesController', () => { it('does not update state when disabled', async () => { await withController( { - options: { - getInternalAccount: jest - .fn() - .mockReturnValue(defaultMockInternalAccount), - }, config: { disabled: true, - selectedAccountId: defaultMockInternalAccount.id, }, }, async ({ controller, controllerEvents }) => { @@ -1855,65 +1835,53 @@ describe('TokenRatesController', () => { }); it('does not update state if there are no tokens for the given chain and address', async () => { - await withController( - { - options: { - getInternalAccount: jest - .fn() - .mockReturnValue(defaultMockInternalAccount), - }, - config: { - selectedAccountId: defaultMockInternalAccount.id, - }, - }, - async ({ controller, controllerEvents }) => { - const tokenAddress = '0x0000000000000000000000000000000000000001'; - const differentAccount = '0x1000000000000000000000000000000000000000'; + await withController(async ({ controller, controllerEvents }) => { + const tokenAddress = '0x0000000000000000000000000000000000000001'; + const differentAccount = '0x1000000000000000000000000000000000000000'; - await callUpdateExchangeRatesMethod({ - allTokens: { - // These tokens are for the right chain but wrong account - [ChainId.mainnet]: { - [differentAccount]: [ - { - address: tokenAddress, - decimals: 18, - symbol: 'TST', - aggregators: [], - }, - ], - }, - // These tokens are for the right account but wrong chain - [toHex(2)]: { - [defaultMockInternalAccount.address]: [ - { - address: tokenAddress, - decimals: 18, - symbol: 'TST', - aggregators: [], - }, - ], - }, + await callUpdateExchangeRatesMethod({ + allTokens: { + // These tokens are for the right chain but wrong account + [ChainId.mainnet]: { + [differentAccount]: [ + { + address: tokenAddress, + decimals: 18, + symbol: 'TST', + aggregators: [], + }, + ], }, - chainId: toHex(1), - controller, - controllerEvents, - method, - nativeCurrency: 'ETH', - selectedNetworkClientId: InfuraNetworkType.mainnet, - }); - - expect(controller.state).toStrictEqual({ - marketData: { - '0x1': { - '0x0000000000000000000000000000000000000000': { - currency: 'ETH', + // These tokens are for the right account but wrong chain + [toHex(2)]: { + [defaultMockInternalAccount.address]: [ + { + address: tokenAddress, + decimals: 18, + symbol: 'TST', + aggregators: [], }, + ], + }, + }, + chainId: toHex(1), + controller, + controllerEvents, + method, + nativeCurrency: 'ETH', + selectedNetworkClientId: InfuraNetworkType.mainnet, + }); + + expect(controller.state).toStrictEqual({ + marketData: { + '0x1': { + '0x0000000000000000000000000000000000000000': { + currency: 'ETH', }, }, - }); - }, - ); + }, + }); + }); }); it('does not update state if the price update fails', async () => { @@ -1926,12 +1894,6 @@ describe('TokenRatesController', () => { { options: { tokenPricesService, - getInternalAccount: jest - .fn() - .mockReturnValue(defaultMockInternalAccount), - }, - config: { - selectedAccountId: defaultMockInternalAccount.id, }, }, async ({ controller, controllerEvents }) => { @@ -1986,12 +1948,6 @@ describe('TokenRatesController', () => { options: { ticker, tokenPricesService, - getInternalAccount: jest - .fn() - .mockReturnValue(defaultMockInternalAccount), - }, - config: { - selectedAccountId: defaultMockInternalAccount.id, }, }, async ({ controller, controllerEvents }) => { @@ -2051,12 +2007,6 @@ describe('TokenRatesController', () => { { options: { tokenPricesService, - getInternalAccount: jest - .fn() - .mockReturnValue(defaultMockInternalAccount), - }, - config: { - selectedAccountId: defaultMockInternalAccount.id, }, }, async ({ controller, controllerEvents }) => { @@ -2133,11 +2083,7 @@ describe('TokenRatesController', () => { { options: { tokenPricesService, - getInternalAccount: jest - .fn() - .mockReturnValue(defaultMockInternalAccount), }, - config: { selectedAccountId: defaultMockInternalAccount.id }, }, async ({ controller, controllerEvents }) => { await callUpdateExchangeRatesMethod({ @@ -2232,12 +2178,6 @@ describe('TokenRatesController', () => { { options: { tokenPricesService, - getInternalAccount: jest - .fn() - .mockReturnValue(defaultMockInternalAccount), - }, - config: { - selectedAccountId: defaultMockInternalAccount.id, }, mockNetworkClientConfigurationsByNetworkClientId: { [selectedNetworkClientId]: selectedNetworkClientConfiguration, @@ -2328,12 +2268,6 @@ describe('TokenRatesController', () => { { options: { tokenPricesService, - getInternalAccount: jest - .fn() - .mockReturnValue(defaultMockInternalAccount), - }, - config: { - selectedAccountId: defaultMockInternalAccount.id, }, mockNetworkClientConfigurationsByNetworkClientId: { [selectedNetworkClientId]: selectedNetworkClientConfiguration, @@ -2407,12 +2341,6 @@ describe('TokenRatesController', () => { { options: { tokenPricesService, - getInternalAccount: jest - .fn() - .mockReturnValue(defaultMockInternalAccount), - }, - config: { - selectedAccountId: defaultMockInternalAccount.id, }, mockNetworkClientConfigurationsByNetworkClientId: { [selectedNetworkClientId]: selectedNetworkClientConfiguration, @@ -2484,12 +2412,6 @@ describe('TokenRatesController', () => { { options: { tokenPricesService, - getInternalAccount: jest - .fn() - .mockReturnValue(defaultMockInternalAccount), - }, - config: { - selectedAccountId: defaultMockInternalAccount.id, }, }, async ({ controller, controllerEvents }) => { @@ -2582,6 +2504,10 @@ type PartialConstructorParameters = { NetworkClientId, NetworkClientConfiguration >; + mocks?: { + getSelectedAccount: InternalAccount; + getAccount: InternalAccount; + }; }; type WithControllerArgs = @@ -2606,14 +2532,29 @@ async function withController( config = {}, state = {}, mockNetworkClientConfigurationsByNetworkClientId = {}, + mocks = {} as { + getSelectedAccount: InternalAccount; + getAccount: InternalAccount; + }, }, testFunction, ] = args.length === 2 ? args : [{}, args[0]]; + const { messenger, triggerSelectedAccountChange } = + buildTokenRatesControllerMessenger({ + mocks: { + getSelectedAccount: + mocks?.getSelectedAccount ?? defaultMockInternalAccount, + getAccount: mocks?.getAccount ?? defaultMockInternalAccount, + }, + }); + // explit cast used here because we know the `on____` functions are always // set in the constructor. const controllerEvents = {} as ControllerEvents; + controllerEvents.seletedAccountChange = triggerSelectedAccountChange; + const getNetworkClientById = buildMockGetNetworkClientById( mockNetworkClientConfigurationsByNetworkClientId, ); @@ -2621,19 +2562,15 @@ async function withController( const controllerOptions: ConstructorParameters< typeof TokenRatesController >[0] = { + messenger, chainId: toHex(1), getNetworkClientById, onNetworkStateChange: (listener) => { controllerEvents.networkStateChange = listener; }, - onSelectedAccountChange: (listener) => { - controllerEvents.seletedAccountChange = listener; - }, onTokensStateChange: (listener) => { controllerEvents.tokensStateChange = listener; }, - getInternalAccount: jest.fn(), - selectedAccountId: defaultMockInternalAccount.id, ticker: NetworksTicker.mainnet, tokenPricesService: buildMockTokenPricesService(), ...options, diff --git a/packages/assets-controllers/src/TokenRatesController.ts b/packages/assets-controllers/src/TokenRatesController.ts index b72f3d3906..62f3595df0 100644 --- a/packages/assets-controllers/src/TokenRatesController.ts +++ b/packages/assets-controllers/src/TokenRatesController.ts @@ -1,4 +1,13 @@ -import type { BaseConfig, BaseState } from '@metamask/base-controller'; +import type { + AccountsControllerGetAccountAction, + AccountsControllerGetSelectedAccountAction, + AccountsControllerSelectedEvmAccountChangeEvent, +} from '@metamask/accounts-controller'; +import type { + BaseConfig, + BaseState, + RestrictedControllerMessenger, +} from '@metamask/base-controller'; import { safelyExecute, toChecksumHexAddress, @@ -9,6 +18,7 @@ import { type InternalAccount } from '@metamask/keyring-api'; import type { NetworkClientId, NetworkController, + NetworkControllerStateChangeEvent, NetworkState, } from '@metamask/network-controller'; import { StaticIntervalPollingControllerV1 } from '@metamask/polling-controller'; @@ -19,7 +29,10 @@ import { reduceInBatchesSerially, TOKEN_PRICES_BATCH_SIZE } from './assetsUtil'; import { fetchExchangeRate as fetchNativeCurrencyExchangeRate } from './crypto-compare-service'; import type { AbstractTokenPricesService } from './token-prices-service/abstract-token-prices-service'; import { ZERO_ADDRESS } from './token-prices-service/codefi-v2'; -import type { TokensControllerState } from './TokensController'; +import type { + TokensControllerState, + TokensControllerStateChangeEvent, +} from './TokensController'; /** * @type Token @@ -152,6 +165,23 @@ async function getCurrencyConversionRate({ } } +export type AllowedActions = + | AccountsControllerGetAccountAction + | AccountsControllerGetSelectedAccountAction; + +export type AllowedEvents = + | AccountsControllerSelectedEvmAccountChangeEvent + | NetworkControllerStateChangeEvent + | TokensControllerStateChangeEvent; + +export type TokenRatesControllerMessenger = RestrictedControllerMessenger< + 'TokenRatesController', + AllowedActions, + AllowedEvents, + AllowedActions['type'], + AllowedEvents['type'] +>; + /** * Controller that passively polls on a set interval for token-to-fiat exchange rates * for tokens stored in the TokensController @@ -175,7 +205,7 @@ export class TokenRatesController extends StaticIntervalPollingControllerV1< private readonly getNetworkClientById: NetworkController['getNetworkClientById']; - private readonly getInternalAccount: (accountId: string) => InternalAccount; + private readonly messagingSystem: TokenRatesControllerMessenger; /** * Creates a TokenRatesController instance. @@ -183,12 +213,10 @@ export class TokenRatesController extends StaticIntervalPollingControllerV1< * @param options - The controller options. * @param options.interval - The polling interval in ms * @param options.threshold - The duration in ms before metadata fetched from CoinGecko is considered stale + * @param options.messenger - The messaging system used to communicate between controllers. * @param options.getNetworkClientById - Gets the network client with the given id from the NetworkController. * @param options.chainId - The chain ID of the current network. * @param options.ticker - The ticker for the current network. - * @param options.getInternalAccount - A callback to get an InternalAccount by id. - * @param options.selectedAccountId - The current selected address. - * @param options.onSelectedAccountChange - Allows subscribing to changes of selected account. * @param options.onTokensStateChange - Allows subscribing to token controller state changes. * @param options.onNetworkStateChange - Allows subscribing to network state changes. * @param options.tokenPricesService - An object in charge of retrieving token prices. @@ -199,26 +227,21 @@ export class TokenRatesController extends StaticIntervalPollingControllerV1< { interval = 3 * 60 * 1000, threshold = 6 * 60 * 60 * 1000, + messenger, getNetworkClientById, chainId: initialChainId, ticker: initialTicker, - selectedAccountId, - getInternalAccount, - onSelectedAccountChange, onTokensStateChange, onNetworkStateChange, tokenPricesService, }: { interval?: number; threshold?: number; + messenger: TokenRatesControllerMessenger; getNetworkClientById: NetworkController['getNetworkClientById']; chainId: Hex; ticker: string; - selectedAccountId: string; - getInternalAccount: (accountId: string) => InternalAccount; - onSelectedAccountChange: ( - listener: (internalAccount: InternalAccount) => void, - ) => void; + onTokensStateChange: ( listener: (tokensState: TokensControllerState) => void, ) => void; @@ -231,13 +254,18 @@ export class TokenRatesController extends StaticIntervalPollingControllerV1< state?: Partial, ) { super(config, state); + + this.messagingSystem = messenger; + this.defaultConfig = { interval, threshold, disabled: false, nativeCurrency: initialTicker, chainId: initialChainId, - selectedAccountId, + selectedAccountId: this.messagingSystem.call( + 'AccountsController:getSelectedAccount', + ).id, allTokens: {}, // TODO: initialize these correctly, maybe as part of BaseControllerV2 migration allDetectedTokens: {}, }; @@ -248,22 +276,12 @@ export class TokenRatesController extends StaticIntervalPollingControllerV1< this.initialize(); this.setIntervalLength(interval); this.getNetworkClientById = getNetworkClientById; - this.getInternalAccount = getInternalAccount; this.#tokenPricesService = tokenPricesService; if (config?.disabled) { this.configure({ disabled: true }, false, false); } - onSelectedAccountChange(async (internalAccount) => { - if (this.config.selectedAccountId !== internalAccount.id) { - this.configure({ selectedAccountId: internalAccount.id }); - if (this.#pollState === PollState.Active) { - await this.updateExchangeRates(); - } - } - }); - onTokensStateChange(async ({ allTokens, allDetectedTokens }) => { const previousTokenAddresses = this.#getTokenAddresses( this.config.chainId, @@ -295,6 +313,11 @@ export class TokenRatesController extends StaticIntervalPollingControllerV1< } } }); + + this.messagingSystem.subscribe( + 'AccountsController:selectedEvmAccountChange', + (internalAccount) => this.#onSelectedAccountChange(internalAccount), + ); } /** @@ -305,10 +328,13 @@ export class TokenRatesController extends StaticIntervalPollingControllerV1< */ #getTokenAddresses(chainId: Hex): Hex[] { const { allTokens, allDetectedTokens, selectedAccountId } = this.config; - const internalAccount = this.getInternalAccount(selectedAccountId); - const tokens = allTokens[chainId]?.[internalAccount.address] || []; + const internalAccount = this.messagingSystem.call( + 'AccountsController:getAccount', + selectedAccountId, + ); + const tokens = allTokens[chainId]?.[internalAccount?.address ?? ''] || []; const detectedTokens = - allDetectedTokens[chainId]?.[internalAccount.address] || []; + allDetectedTokens[chainId]?.[internalAccount?.address ?? ''] || []; return [ ...new Set( @@ -626,6 +652,15 @@ export class TokenRatesController extends StaticIntervalPollingControllerV1< return updatedContractExchangeRates; } + + async #onSelectedAccountChange(newInternalAccount: InternalAccount) { + if (this.config.selectedAccountId !== newInternalAccount.id) { + this.configure({ selectedAccountId: newInternalAccount.id }); + if (this.#pollState === PollState.Active) { + await this.updateExchangeRates(); + } + } + } } export default TokenRatesController; From d38c696f95bbfeee67fb6073830b8933687a321f Mon Sep 17 00:00:00 2001 From: Monte Lai Date: Thu, 13 Jun 2024 00:48:45 +0800 Subject: [PATCH 07/22] fix: test --- .../src/TokenRatesController.test.ts | 104 ++++++------------ 1 file changed, 34 insertions(+), 70 deletions(-) diff --git a/packages/assets-controllers/src/TokenRatesController.test.ts b/packages/assets-controllers/src/TokenRatesController.test.ts index 10d6122f6c..1d2dcf27df 100644 --- a/packages/assets-controllers/src/TokenRatesController.test.ts +++ b/packages/assets-controllers/src/TokenRatesController.test.ts @@ -1,3 +1,4 @@ +import { createMockInternalAccount } from '@metamask/accounts-controller/src/tests/mocks'; import type { AddApprovalRequest } from '@metamask/approval-controller'; import { ControllerMessenger } from '@metamask/base-controller'; import { @@ -7,16 +8,13 @@ import { toChecksumHexAddress, toHex, } from '@metamask/controller-utils'; +import type { InternalAccount } from '@metamask/keyring-api'; import type { NetworkClientId, NetworkState, } from '@metamask/network-controller'; import { defaultState as defaultNetworkState } from '@metamask/network-controller'; import type { NetworkClientConfiguration } from '@metamask/network-controller/src/types'; -import { - getDefaultPreferencesState, - type PreferencesState, -} from '@metamask/preferences-controller'; import type { Hex } from '@metamask/utils'; import { add0x } from '@metamask/utils'; import assert from 'assert'; @@ -45,6 +43,9 @@ import { getDefaultTokensState } from './TokensController'; import type { TokensControllerState } from './TokensController'; const defaultSelectedAddress = '0x0000000000000000000000000000000000000001'; +const defaultSelectedAccount = createMockInternalAccount({ + address: defaultSelectedAddress, +}); const mockTokenAddress = '0x0000000000000000000000000000000000000010'; const defaultSelectedNetworkClientId = 'AAAA-BBBB-CCCC-DDDD'; @@ -69,11 +70,14 @@ function buildTokenRatesControllerMessenger( 'NetworkController:getNetworkClientById', 'NetworkController:getState', 'PreferencesController:getState', + 'AccountsController:getAccount', + 'AccountsController:getSelectedAccount', ], allowedEvents: [ 'PreferencesController:stateChange', 'TokensController:stateChange', 'NetworkController:stateChange', + 'AccountsController:selectedEvmAccountChange', ], }); } @@ -997,6 +1001,9 @@ describe('TokenRatesController', () => { it('should update exchange rates when selected address changes', async () => { const alternateSelectedAddress = '0x0000000000000000000000000000000000000002'; + const alternateSelectedAccount = createMockInternalAccount({ + address: alternateSelectedAddress, + }); await withController( { options: { @@ -1023,69 +1030,26 @@ describe('TokenRatesController', () => { }, }, }, - async ({ controller, triggerPreferencesStateChange }) => { + async ({ controller, triggerSelectedAccountChange }) => { await controller.start(); const updateExchangeRatesSpy = jest .spyOn(controller, 'updateExchangeRates') .mockResolvedValue(); - triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - selectedAddress: alternateSelectedAddress, - }); + triggerSelectedAccountChange(alternateSelectedAccount); expect(updateExchangeRatesSpy).toHaveBeenCalledTimes(1); }, ); }); - - it('should not update exchange rates when preferences state changes without selected address changing', async () => { - await withController( - { - options: { - interval: 100, - }, - mockTokensControllerState: { - allTokens: { - '0x1': { - [defaultSelectedAddress]: [ - { - address: '0x02', - decimals: 0, - symbol: '', - aggregators: [], - }, - { - address: '0x03', - decimals: 0, - symbol: '', - aggregators: [], - }, - ], - }, - }, - }, - }, - async ({ controller, triggerPreferencesStateChange }) => { - await controller.start(); - const updateExchangeRatesSpy = jest - .spyOn(controller, 'updateExchangeRates') - .mockResolvedValue(); - triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - selectedAddress: defaultSelectedAddress, - openSeaEnabled: false, - }); - - expect(updateExchangeRatesSpy).not.toHaveBeenCalled(); - }, - ); - }); }); describe('when polling is inactive', () => { - it('should not update exchange rates when selected address changes', async () => { + it('should not update exchange rates when selected account hanges', async () => { const alternateSelectedAddress = '0x0000000000000000000000000000000000000002'; + const alternateSelectedAccount = createMockInternalAccount({ + address: alternateSelectedAddress, + }); await withController( { options: { @@ -1112,14 +1076,11 @@ describe('TokenRatesController', () => { }, }, }, - async ({ controller, triggerPreferencesStateChange }) => { + async ({ controller, triggerSelectedAccountChange }) => { const updateExchangeRatesSpy = jest .spyOn(controller, 'updateExchangeRates') .mockResolvedValue(); - triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - selectedAddress: alternateSelectedAddress, - }); + triggerSelectedAccountChange(alternateSelectedAccount); expect(updateExchangeRatesSpy).not.toHaveBeenCalled(); }, @@ -2315,12 +2276,12 @@ describe('TokenRatesController', () => { */ type WithControllerCallback = ({ controller, - triggerPreferencesStateChange, + triggerSelectedAccountChange, triggerTokensStateChange, triggerNetworkStateChange, }: { controller: TokenRatesController; - triggerPreferencesStateChange: (state: PreferencesState) => void; + triggerSelectedAccountChange: (state: InternalAccount) => void; triggerTokensStateChange: (state: TokensControllerState) => void; triggerNetworkStateChange: (state: NetworkState) => void; }) => Promise | ReturnValue; @@ -2389,13 +2350,16 @@ async function withController( }), ); - const mockPreferencesState = jest.fn(); + const mockGetSelectedAccount = jest.fn(); controllerMessenger.registerActionHandler( - 'PreferencesController:getState', - mockPreferencesState.mockReturnValue({ - ...getDefaultPreferencesState(), - selectedAddress: defaultSelectedAddress, - }), + 'AccountsController:getSelectedAccount', + mockGetSelectedAccount.mockReturnValue(defaultSelectedAccount), + ); + + const mockGetAccount = jest.fn(); + controllerMessenger.registerActionHandler( + 'AccountsController:getAccount', + mockGetAccount.mockReturnValue(defaultSelectedAccount), ); const controller = new TokenRatesController({ @@ -2406,13 +2370,13 @@ async function withController( try { return await fn({ controller, - triggerPreferencesStateChange: (state: PreferencesState) => { + triggerSelectedAccountChange: (account: InternalAccount) => { controllerMessenger.publish( - 'PreferencesController:stateChange', - state, - [], + 'AccountsController:selectedEvmAccountChange', + account, ); }, + triggerTokensStateChange: (state: TokensControllerState) => { controllerMessenger.publish('TokensController:stateChange', state, []); }, From 03aac220f4ec6ec3e6d625cf8712dfb75aece7f2 Mon Sep 17 00:00:00 2001 From: Monte Lai Date: Thu, 13 Jun 2024 10:16:27 +0800 Subject: [PATCH 08/22] refactor: TokensController to not require selectedAccountId in constructor --- .../src/TokensController.test.ts | 315 ++++++++++-------- .../src/TokensController.ts | 71 ++-- 2 files changed, 226 insertions(+), 160 deletions(-) diff --git a/packages/assets-controllers/src/TokensController.test.ts b/packages/assets-controllers/src/TokensController.test.ts index fb4e74c46f..76e352139f 100644 --- a/packages/assets-controllers/src/TokensController.test.ts +++ b/packages/assets-controllers/src/TokensController.test.ts @@ -40,12 +40,17 @@ import { TOKEN_END_POINT_API } from './token-service'; import type { Token } from './TokenRatesController'; import { TokensController } from './TokensController'; import type { + AllowedActions, + AllowedEvents, TokensControllerMessenger, TokensControllerState, } from './TokensController'; jest.mock('@ethersproject/contracts'); -jest.mock('uuid'); +jest.mock('uuid', () => ({ + ...jest.requireActual('uuid'), + v1: jest.fn(), +})); jest.mock('./Standards/ERC20Standard'); jest.mock('./Standards/NftStandards/ERC1155/ERC1155Standard'); @@ -270,7 +275,21 @@ describe('TokensController', () => { }); it('should add token by selected address', async () => { + const firstAddress = '0x123'; + const firstAccount = createMockInternalAccount({ + address: firstAddress, + }); + const secondAddress = '0x321'; + const secondAccount = createMockInternalAccount({ + address: secondAddress, + }); await withController( + { + mocks: { + getAccount: firstAccount, + getSelectedAccount: firstAccount, + }, + }, async ({ controller, triggerSelectedAccountChange, @@ -279,22 +298,14 @@ describe('TokensController', () => { ContractMock.mockReturnValue( buildMockEthersERC721Contract({ supportsInterface: false }), ); - const firstAddress = '0x123'; - const firstAccount = createMockInternalAccount({ - address: firstAddress, - }); - const secondAddress = '0x321'; - const secondAccount = createMockInternalAccount({ - address: secondAddress, - }); - getAccountHandler.mockReturnValue(firstAccount); triggerSelectedAccountChange(firstAccount); await controller.addToken({ address: '0x01', symbol: 'bar', decimals: 2, }); + getAccountHandler.mockReturnValue(secondAccount); triggerSelectedAccountChange(secondAccount); expect(controller.state.tokens).toHaveLength(0); @@ -414,7 +425,21 @@ describe('TokensController', () => { }); it('should remove token by selected address', async () => { + const firstAddress = '0x123'; + const firstAccount = createMockInternalAccount({ + address: firstAddress, + }); + const secondAddress = '0x321'; + const secondAccount = createMockInternalAccount({ + address: secondAddress, + }); await withController( + { + mocks: { + getAccount: firstAccount, + getSelectedAccount: firstAccount, + }, + }, async ({ controller, triggerSelectedAccountChange, @@ -423,16 +448,7 @@ describe('TokensController', () => { ContractMock.mockReturnValue( buildMockEthersERC721Contract({ supportsInterface: false }), ); - const firstAddress = '0x123'; - const firstAccount = createMockInternalAccount({ - address: firstAddress, - }); - const secondAddress = '0x321'; - const secondAccount = createMockInternalAccount({ - address: secondAddress, - }); - getAccountHandler.mockReturnValue(firstAccount); triggerSelectedAccountChange(firstAccount); await controller.addToken({ address: '0x02', @@ -450,6 +466,7 @@ describe('TokensController', () => { controller.ignoreTokens(['0x01']); expect(controller.state.tokens).toHaveLength(0); + getAccountHandler.mockReturnValue(firstAccount); triggerSelectedAccountChange(firstAccount); expect(controller.state.tokens[0]).toStrictEqual({ address: '0x02', @@ -530,18 +547,18 @@ describe('TokensController', () => { }); it('should remove a token from the ignoredTokens/allIgnoredTokens lists if re-added as part of a bulk addTokens add', async () => { + const selectedAddress = '0x0001'; + const selectedAccount = createMockInternalAccount({ + address: selectedAddress, + }); await withController( - async ({ - controller, - triggerSelectedAccountChange, - changeNetwork, - getAccountHandler, - }) => { - const selectedAddress = '0x0001'; - const selectedAccount = createMockInternalAccount({ - address: selectedAddress, - }); - getAccountHandler.mockReturnValue(selectedAccount); + { + mocks: { + getSelectedAccount: selectedAccount, + getAccount: selectedAccount, + }, + }, + async ({ controller, triggerSelectedAccountChange, changeNetwork }) => { triggerSelectedAccountChange(selectedAccount); changeNetwork({ selectedNetworkClientId: InfuraNetworkType.sepolia }); await controller.addToken({ @@ -579,18 +596,18 @@ describe('TokensController', () => { }); it('should be able to clear the ignoredTokens list', async () => { + const selectedAddress = '0x0001'; + const selectedAccount = createMockInternalAccount({ + address: selectedAddress, + }); await withController( - async ({ - controller, - triggerSelectedAccountChange, - changeNetwork, - getAccountHandler, - }) => { - const selectedAddress = '0x0001'; - const selectedAccount = createMockInternalAccount({ - address: selectedAddress, - }); - getAccountHandler.mockReturnValue(selectedAccount); + { + mocks: { + getSelectedAccount: selectedAccount, + getAccount: selectedAccount, + }, + }, + async ({ controller, triggerSelectedAccountChange, changeNetwork }) => { triggerSelectedAccountChange(selectedAccount); changeNetwork({ selectedNetworkClientId: InfuraNetworkType.sepolia }); await controller.addToken({ @@ -618,22 +635,27 @@ describe('TokensController', () => { }); it('should ignore tokens by [chainID][accountAddress]', async () => { + const selectedAddress1 = '0x0001'; + const selectedAccount1 = createMockInternalAccount({ + address: selectedAddress1, + }); + const selectedAddress2 = '0x0002'; + const selectedAccount2 = createMockInternalAccount({ + address: selectedAddress2, + }); await withController( + { + mocks: { + getSelectedAccount: selectedAccount1, + getAccount: selectedAccount1, + }, + }, async ({ controller, triggerSelectedAccountChange, changeNetwork, getAccountHandler, }) => { - const selectedAddress1 = '0x0001'; - const selectedAccount1 = createMockInternalAccount({ - address: selectedAddress1, - }); - const selectedAddress2 = '0x0002'; - const selectedAccount2 = createMockInternalAccount({ - address: selectedAddress2, - }); - getAccountHandler.mockReturnValue(selectedAccount1); triggerSelectedAccountChange(selectedAccount1); changeNetwork({ selectedNetworkClientId: InfuraNetworkType.sepolia }); await controller.addToken({ @@ -798,8 +820,7 @@ describe('TokensController', () => { describe('addToken method', () => { it('should add isERC721 = true when token is an NFT and is in our contract-metadata repo', async () => { - await withController(async ({ controller, getAccountHandler }) => { - getAccountHandler.mockReturnValue(defaultMockInternalAccount); + await withController(async ({ controller }) => { const contractAddresses = Object.keys(contractMaps); const erc721ContractAddresses = contractAddresses.filter( (contractAddress) => contractMaps[contractAddress].erc721 === true, @@ -821,8 +842,7 @@ describe('TokensController', () => { }); it('should add isERC721 = true when the token is an NFT but not in our contract-metadata repo', async () => { - await withController(async ({ controller, getAccountHandler }) => { - getAccountHandler.mockReturnValue(defaultMockInternalAccount); + await withController(async ({ controller }) => { ContractMock.mockReturnValue( buildMockEthersERC721Contract({ supportsInterface: true }), ); @@ -850,8 +870,7 @@ describe('TokensController', () => { }); it('should add isERC721 = false to token object already in state when token is not an NFT and in our contract-metadata repo', async () => { - await withController(async ({ controller, getAccountHandler }) => { - getAccountHandler.mockReturnValue(defaultMockInternalAccount); + await withController(async ({ controller }) => { const contractAddresses = Object.keys(contractMaps); const erc20ContractAddresses = contractAddresses.filter( (contractAddress) => contractMaps[contractAddress].erc20 === true, @@ -873,8 +892,7 @@ describe('TokensController', () => { }); it('should add isERC721 = false when the token is not an NFT and not in our contract-metadata repo', async () => { - await withController(async ({ controller, getAccountHandler }) => { - getAccountHandler.mockReturnValue(defaultMockInternalAccount); + await withController(async ({ controller }) => { ContractMock.mockReturnValue( buildMockEthersERC721Contract({ supportsInterface: false }), ); @@ -902,26 +920,23 @@ describe('TokensController', () => { }); it('should throw error if switching networks while adding token', async () => { - await withController( - async ({ controller, changeNetwork, getAccountHandler }) => { - getAccountHandler.mockReturnValue(defaultMockInternalAccount); - const dummyTokenAddress = - '0x514910771AF9Ca656af840dff83E8264EcF986CA'; + await withController(async ({ controller, changeNetwork }) => { + const dummyTokenAddress = + '0x514910771AF9Ca656af840dff83E8264EcF986CA'; - const addTokenPromise = controller.addToken({ - address: dummyTokenAddress, - symbol: 'LINK', - decimals: 18, - }); - changeNetwork({ - selectedNetworkClientId: InfuraNetworkType.goerli, - }); + const addTokenPromise = controller.addToken({ + address: dummyTokenAddress, + symbol: 'LINK', + decimals: 18, + }); + changeNetwork({ + selectedNetworkClientId: InfuraNetworkType.goerli, + }); - await expect(addTokenPromise).rejects.toThrow( - 'TokensController Error: Switched networks while adding token', - ); - }, - ); + await expect(addTokenPromise).rejects.toThrow( + 'TokensController Error: Switched networks while adding token', + ); + }); }); }); @@ -996,13 +1011,17 @@ describe('TokensController', () => { }); it('should add tokens to the correct chainId/selectedAddress on which they were detected even if its not the currently configured chainId/selectedAddress', async () => { + const CONFIGURED_ADDRESS = '0xConfiguredAddress'; + const configuredAccount = createMockInternalAccount({ + address: CONFIGURED_ADDRESS, + }); await withController( - async ({ - controller, - changeNetwork, - triggerSelectedAccountChange, - getAccountHandler, - }) => { + { + mocks: { + getAccount: configuredAccount, + }, + }, + async ({ controller, changeNetwork, triggerSelectedAccountChange }) => { ContractMock.mockReturnValue( buildMockEthersERC721Contract({ supportsInterface: false }), ); @@ -1010,10 +1029,7 @@ describe('TokensController', () => { // The currently configured chain + address const CONFIGURED_CHAIN = ChainId.sepolia; const CONFIGURED_NETWORK_CLIENT_ID = InfuraNetworkType.sepolia; - const CONFIGURED_ADDRESS = '0xConfiguredAddress'; - const configuredAccount = createMockInternalAccount({ - address: CONFIGURED_ADDRESS, - }); + changeNetwork({ selectedNetworkClientId: CONFIGURED_NETWORK_CLIENT_ID, }); @@ -1041,8 +1057,6 @@ describe('TokensController', () => { detectedTokenOtherAccount, ] = generateTokens(3); - getAccountHandler.mockReturnValue(configuredAccount); - // Run twice to ensure idempotency for (let i = 0; i < 2; i++) { // Add and detect some tokens on the configured chain + account @@ -1873,7 +1887,17 @@ describe('TokensController', () => { describe('when PreferencesController:stateChange is published', () => { it('should update tokens list when set address changes', async () => { + const selectedAccount = createMockInternalAccount({ address: '0x1' }); + const selectedAccount2 = createMockInternalAccount({ + address: '0x2', + }); await withController( + { + mocks: { + getAccount: selectedAccount, + getSelectedAccount: selectedAccount, + }, + }, async ({ controller, triggerSelectedAccountChange, @@ -1882,11 +1906,7 @@ describe('TokensController', () => { ContractMock.mockReturnValue( buildMockEthersERC721Contract({ supportsInterface: false }), ); - const selectedAccount = createMockInternalAccount({ address: '0x1' }); - const selectedAccount2 = createMockInternalAccount({ - address: '0x2', - }); - getAccountHandler.mockReturnValue(selectedAccount); + triggerSelectedAccountChange(selectedAccount); await controller.addToken({ address: '0x01', @@ -1898,6 +1918,7 @@ describe('TokensController', () => { symbol: 'B', decimals: 5, }); + getAccountHandler.mockReturnValue(selectedAccount2); triggerSelectedAccountChange(selectedAccount2); expect(controller.state.tokens).toStrictEqual([]); @@ -1907,6 +1928,7 @@ describe('TokensController', () => { symbol: 'C', decimals: 6, }); + getAccountHandler.mockReturnValue(selectedAccount); triggerSelectedAccountChange(selectedAccount); expect(controller.state.tokens).toStrictEqual([ { @@ -2056,11 +2078,12 @@ describe('TokensController', () => { { options: { chainId: ChainId.mainnet, - selectedAccountId: selectedAccount.id, + }, + mocks: { + getSelectedAccount: selectedAccount, }, }, - async ({ controller, getAccountHandler }) => { - getAccountHandler.mockReturnValue(selectedAccount); + async ({ controller }) => { await controller.addTokens(dummyTokens); controller.ignoreTokens([tokenAddress]); @@ -2091,11 +2114,12 @@ describe('TokensController', () => { { options: { chainId: ChainId.mainnet, - selectedAccountId: selectedAccount.id, + }, + mocks: { + getSelectedAccount: selectedAccount, }, }, - async ({ controller, getAccountHandler }) => { - getAccountHandler.mockReturnValue(selectedAccount); + async ({ controller }) => { await controller.addTokens(dummyTokens); controller.ignoreTokens([tokenAddress]); await controller.addTokens(dummyTokens); @@ -2127,11 +2151,12 @@ describe('TokensController', () => { { options: { chainId: ChainId.mainnet, - selectedAccountId: selectedAccount.id, + }, + mocks: { + getSelectedAccount: selectedAccount, }, }, - async ({ controller, getAccountHandler }) => { - getAccountHandler.mockReturnValue(selectedAccount); + async ({ controller }) => { await controller.addDetectedTokens(dummyTokens); await controller.addTokens(dummyTokens); @@ -2216,24 +2241,24 @@ describe('TokensController', () => { }); }); - it('should update the token states to empty arrays if the selectedAccountId is not set', async () => { - await withController( - { options: { selectedAccountId: '' } }, - async ({ controller, changeNetwork }) => { - ContractMock.mockReturnValue( - buildMockEthersERC721Contract({ supportsInterface: false }), - ); - - changeNetwork({ - selectedNetworkClientId: InfuraNetworkType.sepolia, - }); - - expect(controller.state.tokens).toStrictEqual([]); - expect(controller.state.ignoredTokens).toStrictEqual([]); - expect(controller.state.detectedTokens).toStrictEqual([]); - }, - ); - }); + // it('should update the token states to empty arrays if the selectedAccountId is not set', async () => { + // await withController( + // { options: { selectedAccountId: '' } }, + // async ({ controller, changeNetwork }) => { + // ContractMock.mockReturnValue( + // buildMockEthersERC721Contract({ supportsInterface: false }), + // ); + + // changeNetwork({ + // selectedNetworkClientId: InfuraNetworkType.sepolia, + // }); + + // expect(controller.state.tokens).toStrictEqual([]); + // expect(controller.state.ignoredTokens).toStrictEqual([]); + // expect(controller.state.detectedTokens).toStrictEqual([]); + // }, + // ); + // }); }); describe('addToken', () => { @@ -2255,7 +2280,7 @@ describe('TokensController', () => { }); describe('addDetectedTokens', () => { - it('should handle undefined selected account', async () => { + it('handles an undefined selected account', async () => { await withController(async ({ controller, getAccountHandler }) => { getAccountHandler.mockReturnValue(undefined); await controller.addDetectedTokens([ @@ -2330,12 +2355,17 @@ type WithControllerCallback = ({ messenger: UnrestrictedMessenger; approvalController: ApprovalController; triggerSelectedAccountChange: (internalAccount: InternalAccount) => void; - getAccountHandler: jest.Mock< - ReturnType, - Parameters + getAccountHandler: jest.Mock>; + getSelectedAccountHandler: jest.Mock< + ReturnType >; }) => Promise | ReturnValue; +type WithControllerMockArgs = { + getAccount?: InternalAccount; + getSelectedAccount?: InternalAccount; +}; + type WithControllerArgs = | [WithControllerCallback] | [ @@ -2345,6 +2375,7 @@ type WithControllerArgs = NetworkClientId, NetworkClientConfiguration >; + mocks?: WithControllerMockArgs; }, WithControllerCallback, ]; @@ -2359,17 +2390,22 @@ type WithControllerArgs = * @param args.mockNetworkClientConfigurationsByNetworkClientId - Used to construct * mock versions of network clients and ultimately mock the * `NetworkController:getNetworkClientById` action. + * @param args.mocks - Move values for actions to be mocked. * @returns A collection of test controllers and mocks. */ async function withController( ...args: WithControllerArgs ): Promise { const [ - { options = {}, mockNetworkClientConfigurationsByNetworkClientId = {} }, + { + options = {}, + mockNetworkClientConfigurationsByNetworkClientId = {}, + mocks = {} as WithControllerMockArgs, + }, fn, ] = args.length === 2 ? args : [{}, args[0]]; - const messenger: UnrestrictedMessenger = new ControllerMessenger(); + const messenger = new ControllerMessenger(); const approvalControllerMessenger = messenger.getRestricted({ name: 'ApprovalController', @@ -2388,6 +2424,7 @@ async function withController( 'ApprovalController:addRequest', 'NetworkController:getNetworkClientById', 'AccountsController:getAccount', + 'AccountsController:getSelectedAccount', ], allowedEvents: [ 'NetworkController:networkDidChange', @@ -2395,9 +2432,25 @@ async function withController( 'TokenListController:stateChange', ], }); + + const getAccountHandler = jest.fn(); + messenger.registerActionHandler( + 'AccountsController:getAccount', + getAccountHandler.mockReturnValue( + mocks?.getAccount ?? defaultMockInternalAccount, + ), + ); + + const getSelectedAccountHandler = jest.fn(); + messenger.registerActionHandler( + 'AccountsController:getSelectedAccount', + getSelectedAccountHandler.mockReturnValue( + mocks?.getSelectedAccount ?? defaultMockInternalAccount, + ), + ); + const controller = new TokensController({ chainId: ChainId.mainnet, - selectedAccountId: defaultMockInternalAccount.id, // The tests assume that this is set, but they shouldn't make that // assumption. But we have to do this due to a bug in TokensController // where the provider can possibly be `undefined` if `networkClientId` is @@ -2414,13 +2467,6 @@ async function withController( ); }; - const getAccountHandler = jest.fn(); - - messenger.registerActionHandler( - `AccountsController:getAccount`, - getAccountHandler.mockReturnValue(defaultMockInternalAccount), - ); - const changeNetwork = ({ selectedNetworkClientId, }: { @@ -2447,6 +2493,7 @@ async function withController( approvalController, triggerSelectedAccountChange, getAccountHandler, + getSelectedAccountHandler, }); } diff --git a/packages/assets-controllers/src/TokensController.ts b/packages/assets-controllers/src/TokensController.ts index e55829caa4..aa97aa1eb6 100644 --- a/packages/assets-controllers/src/TokensController.ts +++ b/packages/assets-controllers/src/TokensController.ts @@ -2,6 +2,7 @@ import { Contract } from '@ethersproject/contracts'; import { Web3Provider } from '@ethersproject/providers'; import type { AccountsControllerGetAccountAction, + AccountsControllerGetSelectedAccountAction, AccountsControllerSelectedEvmAccountChangeEvent, } from '@metamask/accounts-controller'; import type { AddApprovalRequest } from '@metamask/approval-controller'; @@ -138,7 +139,8 @@ export type TokensControllerAddDetectedTokensAction = { export type AllowedActions = | AddApprovalRequest | NetworkControllerGetNetworkClientByIdAction - | AccountsControllerGetAccountAction; + | AccountsControllerGetAccountAction + | AccountsControllerGetSelectedAccountAction; export type TokensControllerStateChangeEvent = ControllerStateChangeEvent< typeof controllerName, @@ -196,20 +198,17 @@ export class TokensController extends BaseController< * Tokens controller options * @param options - Constructor options. * @param options.chainId - The chain ID of the current network. - * @param options.selectedAccountId - Vault selected account id * @param options.provider - Network provider. * @param options.state - Initial state to set on this controller. * @param options.messenger - The controller messenger. */ constructor({ chainId: initialChainId, - selectedAccountId, provider, state, messenger, }: { chainId: Hex; - selectedAccountId: string; provider: Provider | undefined; state?: Partial; messenger: TokensControllerMessenger; @@ -228,7 +227,7 @@ export class TokensController extends BaseController< this.#provider = provider; - this.#selectedAccountId = selectedAccountId; + this.#selectedAccountId = this.#getSelectedAccount().id; this.#abortController = new AbortController(); @@ -372,15 +371,10 @@ export class TokensController extends BaseController< ).configuration.chainId; } - const internalAccount = this.messagingSystem.call( - 'AccountsController:getAccount', - this.#selectedAccountId, - ); - - // Previously selectedAddress could be an empty string. This is to preserve the behaviour - const accountAddress = interactingAddress || internalAccount?.address || ''; + const accountAddress = + this.#getAddressOrSelectedAddress(interactingAddress); const isInteractingWithWalletAccount = - accountAddress === internalAccount?.address; + this.#isInterctingWithWallet(accountAddress); try { address = toChecksumHexAddress(address); @@ -830,11 +824,8 @@ export class TokensController extends BaseController< throw rpcErrors.invalidParams(`Invalid address "${asset.address}"`); } - // Validate if account is an evm account - const selectedAccount = this.messagingSystem.call( - 'AccountsController:getAccount', - this.#selectedAccountId, - ); + const selectedAddress = + this.#getAddressOrSelectedAddress(interactingAddress); // Validate contract @@ -936,8 +927,7 @@ export class TokensController extends BaseController< id: this.#generateRandomId(), time: Date.now(), type, - // Previously selectedAddress could be an empty string. This is to preserve the behaviour - interactingAddress: interactingAddress || selectedAccount?.address || '', + interactingAddress: selectedAddress, }; await this.#requestApproval(suggestedAssetMeta); @@ -981,13 +971,11 @@ export class TokensController extends BaseController< interactingChainId, } = params; const { allTokens, allIgnoredTokens, allDetectedTokens } = this.state; - const selectedInternalAccount = this.messagingSystem.call( - 'AccountsController:getAccount', - this.#selectedAccountId, - ); - // Previously selectedAddress could be an empty string. This is to preserve the behaviour + const userAddressToAddTokens = - interactingAddress ?? selectedInternalAccount?.address ?? ''; + this.#getAddressOrSelectedAddress(interactingAddress); + + console.log('userAddressToAddTokens', userAddressToAddTokens); const chainIdToAddTokens = interactingChainId ?? this.#chainId; @@ -1050,6 +1038,29 @@ export class TokensController extends BaseController< return { newAllTokens, newAllIgnoredTokens, newAllDetectedTokens }; } + #getAddressOrSelectedAddress(address: string | undefined): string { + if (address) { + return address; + } + + // If the address is not defined (or empty), we fallback to the currently selected account's address + const selectedAccount = this.messagingSystem.call( + 'AccountsController:getAccount', + this.#selectedAccountId, + ); + return selectedAccount?.address || ''; + } + + #isInterctingWithWallet(address: string) { + // If the address is not defined (or empty), we fallback to the currently selected account's address + const selectedAccount = this.messagingSystem.call( + 'AccountsController:getAccount', + this.#selectedAccountId, + ); + + return selectedAccount?.address === address; + } + /** * Removes all tokens from the ignored list. */ @@ -1081,6 +1092,14 @@ export class TokensController extends BaseController< true, ); } + + #getSelectedAccount() { + const account = this.messagingSystem.call( + 'AccountsController:getSelectedAccount', + ); + + return account; + } } export default TokensController; From 5f3f15bf60369ca8aa875c6c5f333b38bddbde35 Mon Sep 17 00:00:00 2001 From: Monte Lai Date: Thu, 13 Jun 2024 13:42:45 +0800 Subject: [PATCH 09/22] fix: lint --- packages/assets-controllers/src/TokenRatesController.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/assets-controllers/src/TokenRatesController.ts b/packages/assets-controllers/src/TokenRatesController.ts index e84f67be42..b9e3e72858 100644 --- a/packages/assets-controllers/src/TokenRatesController.ts +++ b/packages/assets-controllers/src/TokenRatesController.ts @@ -14,6 +14,7 @@ import { FALL_BACK_VS_CURRENCY, toHex, } from '@metamask/controller-utils'; +import type { InternalAccount } from '@metamask/keyring-api'; import type { NetworkClientId, NetworkControllerGetNetworkClientByIdAction, @@ -37,7 +38,6 @@ import type { TokensControllerStateChangeEvent, TokensControllerState, } from './TokensController'; -import { InternalAccount } from '@metamask/keyring-api'; /** * @type Token From b0f09ea323b2d26fe814ef49677acbf90314e376 Mon Sep 17 00:00:00 2001 From: Monte Lai Date: Fri, 14 Jun 2024 00:30:09 +0800 Subject: [PATCH 10/22] fix: test names --- .../src/TokenDetectionController.test.ts | 2 +- .../src/TokensController.test.ts | 25 +++---------------- 2 files changed, 4 insertions(+), 23 deletions(-) diff --git a/packages/assets-controllers/src/TokenDetectionController.test.ts b/packages/assets-controllers/src/TokenDetectionController.test.ts index 5cc0125221..2e23ceb5fd 100644 --- a/packages/assets-controllers/src/TokenDetectionController.test.ts +++ b/packages/assets-controllers/src/TokenDetectionController.test.ts @@ -2149,7 +2149,7 @@ describe('TokenDetectionController', () => { ); }); - it('should not trigger `TokensController:addDetectedTokens` action when selectedAccount is not found', async () => { + it('does not trigger `TokensController:addDetectedTokens` action when selectedAccount is not found', async () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); diff --git a/packages/assets-controllers/src/TokensController.test.ts b/packages/assets-controllers/src/TokensController.test.ts index 76e352139f..352ad744c4 100644 --- a/packages/assets-controllers/src/TokensController.test.ts +++ b/packages/assets-controllers/src/TokensController.test.ts @@ -2226,7 +2226,7 @@ describe('TokensController', () => { describe('when selectedAccountId is not set or account not found', () => { describe('detectTokens', () => { - it('should update the token states to empty arrays if the selectedAccountId account is undefined', async () => { + it('updates the token states to empty arrays if the selectedAccountId account is undefined', async () => { await withController(async ({ controller, changeNetwork }) => { ContractMock.mockReturnValue( buildMockEthersERC721Contract({ supportsInterface: false }), @@ -2240,29 +2240,10 @@ describe('TokensController', () => { expect(controller.state.detectedTokens).toStrictEqual([]); }); }); - - // it('should update the token states to empty arrays if the selectedAccountId is not set', async () => { - // await withController( - // { options: { selectedAccountId: '' } }, - // async ({ controller, changeNetwork }) => { - // ContractMock.mockReturnValue( - // buildMockEthersERC721Contract({ supportsInterface: false }), - // ); - - // changeNetwork({ - // selectedNetworkClientId: InfuraNetworkType.sepolia, - // }); - - // expect(controller.state.tokens).toStrictEqual([]); - // expect(controller.state.ignoredTokens).toStrictEqual([]); - // expect(controller.state.detectedTokens).toStrictEqual([]); - // }, - // ); - // }); }); describe('addToken', () => { - it('should handle undefined selected account', async () => { + it('handles undefined selected account', async () => { await withController(async ({ controller, getAccountHandler }) => { getAccountHandler.mockReturnValue(undefined); const contractAddresses = Object.keys(contractMaps); @@ -2306,7 +2287,7 @@ describe('TokensController', () => { }); describe('watchAsset', () => { - it('should handle undefined selected account', async () => { + it('handles undefined selected account', async () => { await withController( async ({ controller, approvalController, getAccountHandler }) => { const requestId = '12345'; From 3539272d12df831f3e0460a838842734125e88da Mon Sep 17 00:00:00 2001 From: Monte Lai Date: Fri, 14 Jun 2024 00:32:45 +0800 Subject: [PATCH 11/22] revert: coverage --- packages/assets-controllers/jest.config.js | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/packages/assets-controllers/jest.config.js b/packages/assets-controllers/jest.config.js index b72d599a24..142e65bf76 100644 --- a/packages/assets-controllers/jest.config.js +++ b/packages/assets-controllers/jest.config.js @@ -17,7 +17,8 @@ module.exports = merge(baseConfig, { // An object that configures minimum threshold enforcement for coverage results coverageThreshold: { global: { - branches: 90.11, + branches: 90.35, + functions: 96.74, lines: 97.34, statements: 97.36, From f92a0a0443f65d92dfeb5b0c2d6494a9fb15f2e3 Mon Sep 17 00:00:00 2001 From: Monte Lai Date: Fri, 14 Jun 2024 08:37:58 +0800 Subject: [PATCH 12/22] fix: lint --- packages/assets-controllers/src/TokensController.test.ts | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/packages/assets-controllers/src/TokensController.test.ts b/packages/assets-controllers/src/TokensController.test.ts index 352ad744c4..9c355c9cc4 100644 --- a/packages/assets-controllers/src/TokensController.test.ts +++ b/packages/assets-controllers/src/TokensController.test.ts @@ -1,5 +1,4 @@ import { Contract } from '@ethersproject/contracts'; -import type { AccountsController } from '@metamask/accounts-controller'; import type { ApprovalStateChange } from '@metamask/approval-controller'; import { ApprovalController, @@ -2336,10 +2335,8 @@ type WithControllerCallback = ({ messenger: UnrestrictedMessenger; approvalController: ApprovalController; triggerSelectedAccountChange: (internalAccount: InternalAccount) => void; - getAccountHandler: jest.Mock>; - getSelectedAccountHandler: jest.Mock< - ReturnType - >; + getAccountHandler: jest.Mock; + getSelectedAccountHandler: jest.Mock; }) => Promise | ReturnValue; type WithControllerMockArgs = { From fd652a3aa050425675f9295c675f372e1ad64ef7 Mon Sep 17 00:00:00 2001 From: Monte Lai Date: Sat, 15 Jun 2024 00:33:43 +0800 Subject: [PATCH 13/22] refactor: create new helper --- .../src/TokensController.ts | 40 ++++++------------- 1 file changed, 13 insertions(+), 27 deletions(-) diff --git a/packages/assets-controllers/src/TokensController.ts b/packages/assets-controllers/src/TokensController.ts index aa97aa1eb6..dbf59cfa95 100644 --- a/packages/assets-controllers/src/TokensController.ts +++ b/packages/assets-controllers/src/TokensController.ts @@ -375,7 +375,6 @@ export class TokensController extends BaseController< this.#getAddressOrSelectedAddress(interactingAddress); const isInteractingWithWalletAccount = this.#isInterctingWithWallet(accountAddress); - try { address = toChecksumHexAddress(address); const tokens = allTokens[currentChainId]?.[accountAddress] || []; @@ -583,15 +582,10 @@ export class TokensController extends BaseController< ) { const releaseLock = await this.#mutex.acquire(); - const internalAccount = this.messagingSystem.call( - 'AccountsController:getAccount', - this.#selectedAccountId, - ); - const chainId = detectionDetails?.chainId ?? this.#chainId; // Previously selectedAddress could be an empty string. This is to preserve the behaviour const accountAddress = - detectionDetails?.selectedAddress ?? internalAccount?.address ?? ''; + detectionDetails?.selectedAddress ?? this.#getSelectedAddress(); const { allTokens, allDetectedTokens, allIgnoredTokens } = this.state; let newTokens = [...(allTokens?.[chainId]?.[accountAddress] ?? [])]; @@ -658,17 +652,11 @@ export class TokensController extends BaseController< // We may be detecting tokens on a different chain/account pair than are currently configured. // Re-point `tokens` and `detectedTokens` to keep them referencing the current chain/account. - const currentInternalAccount = this.messagingSystem.call( - 'AccountsController:getAccount', - this.#selectedAccountId, - ); - - // Previously selectedAddress could be an empty string. This is to preserve the behaviour - const currentAddress = currentInternalAccount?.address || ''; + const selectedAddress = this.#getSelectedAddress(); - newTokens = newAllTokens?.[this.#chainId]?.[currentAddress] || []; + newTokens = newAllTokens?.[this.#chainId]?.[selectedAddress] || []; newDetectedTokens = - newAllDetectedTokens?.[this.#chainId]?.[currentAddress] || []; + newAllDetectedTokens?.[this.#chainId]?.[selectedAddress] || []; this.update((state) => { state.tokens = newTokens; @@ -975,8 +963,6 @@ export class TokensController extends BaseController< const userAddressToAddTokens = this.#getAddressOrSelectedAddress(interactingAddress); - console.log('userAddressToAddTokens', userAddressToAddTokens); - const chainIdToAddTokens = interactingChainId ?? this.#chainId; let newAllTokens = allTokens; @@ -1043,12 +1029,7 @@ export class TokensController extends BaseController< return address; } - // If the address is not defined (or empty), we fallback to the currently selected account's address - const selectedAccount = this.messagingSystem.call( - 'AccountsController:getAccount', - this.#selectedAccountId, - ); - return selectedAccount?.address || ''; + return this.#getSelectedAddress(); } #isInterctingWithWallet(address: string) { @@ -1094,11 +1075,16 @@ export class TokensController extends BaseController< } #getSelectedAccount() { + return this.messagingSystem.call('AccountsController:getSelectedAccount'); + } + + #getSelectedAddress() { + // If the address is not defined (or empty), we fallback to the currently selected account's address const account = this.messagingSystem.call( - 'AccountsController:getSelectedAccount', + 'AccountsController:getAccount', + this.#selectedAccountId, ); - - return account; + return account?.address || ''; } } From e3d389bb21254e6af19208600f3bc5ddf6d31f98 Mon Sep 17 00:00:00 2001 From: Monte Lai Date: Tue, 18 Jun 2024 18:41:14 +0800 Subject: [PATCH 14/22] fix: remove empty line --- packages/assets-controllers/jest.config.js | 1 - 1 file changed, 1 deletion(-) diff --git a/packages/assets-controllers/jest.config.js b/packages/assets-controllers/jest.config.js index 142e65bf76..c5034d8960 100644 --- a/packages/assets-controllers/jest.config.js +++ b/packages/assets-controllers/jest.config.js @@ -18,7 +18,6 @@ module.exports = merge(baseConfig, { coverageThreshold: { global: { branches: 90.35, - functions: 96.74, lines: 97.34, statements: 97.36, From 63c2b82bdcb9e7ce13e0f05460138893141d893b Mon Sep 17 00:00:00 2001 From: Monte Lai Date: Tue, 18 Jun 2024 18:43:34 +0800 Subject: [PATCH 15/22] refactor: variable names and helper method --- .../src/TokenDetectionController.ts | 20 ++-- .../src/TokenRatesController.test.ts | 2 - .../src/TokenRatesController.ts | 18 ++-- .../src/TokensController.ts | 4 +- tests/mocks.txt | 94 +++++++++++++++++++ 5 files changed, 113 insertions(+), 25 deletions(-) create mode 100644 tests/mocks.txt diff --git a/packages/assets-controllers/src/TokenDetectionController.ts b/packages/assets-controllers/src/TokenDetectionController.ts index 317a23d8ca..38a0d24e42 100644 --- a/packages/assets-controllers/src/TokenDetectionController.ts +++ b/packages/assets-controllers/src/TokenDetectionController.ts @@ -230,9 +230,7 @@ export class TokenDetectionController extends StaticIntervalPollingController< this.#disabled = disabled; this.setIntervalLength(interval); - this.#selectedAccountId = this.messagingSystem.call( - 'AccountsController:getSelectedAccount', - ).id; + this.#selectedAccountId = this.#getSelectedAccount().id; const { chainId, networkClientId } = this.#getCorrectChainIdAndNetworkClientId(); @@ -310,13 +308,13 @@ export class TokenDetectionController extends StaticIntervalPollingController< 'AccountsController:selectedEvmAccountChange', // TODO: Either fix this lint violation or explain why it's necessary to ignore. // eslint-disable-next-line @typescript-eslint/no-misused-promises - async (internalAccount) => { + async (selectedAccount) => { const isSelectedAccountIdChanged = - this.#selectedAccountId !== internalAccount.id; + this.#selectedAccountId !== selectedAccount.id; if (isSelectedAccountIdChanged) { - this.#selectedAccountId = internalAccount.id; + this.#selectedAccountId = selectedAccount.id; await this.#restartTokenDetection({ - selectedAddress: internalAccount.address, + selectedAddress: selectedAccount.address, }); } }, @@ -490,13 +488,13 @@ export class TokenDetectionController extends StaticIntervalPollingController< return; } - const selectedInternalAccount = this.messagingSystem.call( + const selectedAccount = this.messagingSystem.call( 'AccountsController:getAccount', this.#selectedAccountId, ); const addressAgainstWhichToDetect = - selectedAddress ?? selectedInternalAccount?.address ?? ''; + selectedAddress ?? selectedAccount?.address ?? ''; const { chainId, networkClientId: selectedNetworkClientId } = this.#getCorrectChainIdAndNetworkClientId(networkClientId); const chainIdAgainstWhichToDetect = chainId; @@ -640,6 +638,10 @@ export class TokenDetectionController extends StaticIntervalPollingController< } }); } + + #getSelectedAccount() { + return this.messagingSystem.call('AccountsController:getSelectedAccount'); + } } export default TokenDetectionController; diff --git a/packages/assets-controllers/src/TokenRatesController.test.ts b/packages/assets-controllers/src/TokenRatesController.test.ts index 1d2dcf27df..052f9436ff 100644 --- a/packages/assets-controllers/src/TokenRatesController.test.ts +++ b/packages/assets-controllers/src/TokenRatesController.test.ts @@ -69,12 +69,10 @@ function buildTokenRatesControllerMessenger( 'TokensController:getState', 'NetworkController:getNetworkClientById', 'NetworkController:getState', - 'PreferencesController:getState', 'AccountsController:getAccount', 'AccountsController:getSelectedAccount', ], allowedEvents: [ - 'PreferencesController:stateChange', 'TokensController:stateChange', 'NetworkController:stateChange', 'AccountsController:selectedEvmAccountChange', diff --git a/packages/assets-controllers/src/TokenRatesController.ts b/packages/assets-controllers/src/TokenRatesController.ts index b9e3e72858..91107835eb 100644 --- a/packages/assets-controllers/src/TokenRatesController.ts +++ b/packages/assets-controllers/src/TokenRatesController.ts @@ -22,10 +22,6 @@ import type { NetworkControllerStateChangeEvent, } from '@metamask/network-controller'; import { StaticIntervalPollingController } from '@metamask/polling-controller'; -import type { - PreferencesControllerGetStateAction, - PreferencesControllerStateChangeEvent, -} from '@metamask/preferences-controller'; import { createDeferredPromise, type Hex } from '@metamask/utils'; import { isEqual } from 'lodash'; @@ -109,7 +105,6 @@ export type AllowedActions = | TokensControllerGetStateAction | NetworkControllerGetNetworkClientByIdAction | NetworkControllerGetStateAction - | PreferencesControllerGetStateAction | AccountsControllerGetAccountAction | AccountsControllerGetSelectedAccountAction; @@ -117,7 +112,6 @@ export type AllowedActions = * The external events available to the {@link TokenRatesController}. */ export type AllowedEvents = - | PreferencesControllerStateChangeEvent | TokensControllerStateChangeEvent | NetworkControllerStateChangeEvent | AccountsControllerSelectedEvmAccountChangeEvent; @@ -367,9 +361,9 @@ export class TokenRatesController extends StaticIntervalPollingController< 'AccountsController:selectedEvmAccountChange', // TODO: Either fix this lint violation or explain why it's necessary to ignore. // eslint-disable-next-line @typescript-eslint/no-misused-promises - async (newInternalAccount) => { - if (this.#selectedAccountId !== newInternalAccount.id) { - this.#selectedAccountId = newInternalAccount.id; + async (selectedAccount) => { + if (this.#selectedAccountId !== selectedAccount.id) { + this.#selectedAccountId = selectedAccount.id; if (this.#pollState === PollState.Active) { await this.updateExchangeRates(); } @@ -385,14 +379,14 @@ export class TokenRatesController extends StaticIntervalPollingController< * @returns The list of tokens addresses for the current chain */ #getTokenAddresses(chainId: Hex): Hex[] { - const internalAccount = this.messagingSystem.call( + const selectedAccount = this.messagingSystem.call( 'AccountsController:getAccount', this.#selectedAccountId, ); const tokens = - this.#allTokens[chainId]?.[internalAccount?.address ?? ''] || []; + this.#allTokens[chainId]?.[selectedAccount?.address ?? ''] || []; const detectedTokens = - this.#allDetectedTokens[chainId]?.[internalAccount?.address ?? ''] || []; + this.#allDetectedTokens[chainId]?.[selectedAccount?.address ?? ''] || []; return [ ...new Set( diff --git a/packages/assets-controllers/src/TokensController.ts b/packages/assets-controllers/src/TokensController.ts index dbf59cfa95..3bdd049823 100644 --- a/packages/assets-controllers/src/TokensController.ts +++ b/packages/assets-controllers/src/TokensController.ts @@ -374,7 +374,7 @@ export class TokensController extends BaseController< const accountAddress = this.#getAddressOrSelectedAddress(interactingAddress); const isInteractingWithWalletAccount = - this.#isInterctingWithWallet(accountAddress); + this.#isInteractingWithWallet(accountAddress); try { address = toChecksumHexAddress(address); const tokens = allTokens[currentChainId]?.[accountAddress] || []; @@ -1032,7 +1032,7 @@ export class TokensController extends BaseController< return this.#getSelectedAddress(); } - #isInterctingWithWallet(address: string) { + #isInteractingWithWallet(address: string) { // If the address is not defined (or empty), we fallback to the currently selected account's address const selectedAccount = this.messagingSystem.call( 'AccountsController:getAccount', diff --git a/tests/mocks.txt b/tests/mocks.txt new file mode 100644 index 0000000000..f6a573d185 --- /dev/null +++ b/tests/mocks.txt @@ -0,0 +1,94 @@ +// import { v4 } from 'uuid'; + +// // Duplicate code here to avoid using `@metamask/keyring-api` and `@metamask/keyring-controller` as dependencies +// export enum KeyringTypes { +// // TODO: Either fix this lint violation or explain why it's necessary to ignore. +// // eslint-disable-next-line @typescript-eslint/naming-convention +// simple = 'Simple Key Pair', +// // TODO: Either fix this lint violation or explain why it's necessary to ignore. +// // eslint-disable-next-line @typescript-eslint/naming-convention +// hd = 'HD Key Tree', +// // TODO: Either fix this lint violation or explain why it's necessary to ignore. +// // eslint-disable-next-line @typescript-eslint/naming-convention +// qr = 'QR Hardware Wallet Device', +// // TODO: Either fix this lint violation or explain why it's necessary to ignore. +// // eslint-disable-next-line @typescript-eslint/naming-convention +// trezor = 'Trezor Hardware', +// // TODO: Either fix this lint violation or explain why it's necessary to ignore. +// // eslint-disable-next-line @typescript-eslint/naming-convention +// ledger = 'Ledger Hardware', +// // TODO: Either fix this lint violation or explain why it's necessary to ignore. +// // eslint-disable-next-line @typescript-eslint/naming-convention +// lattice = 'Lattice Hardware', +// // TODO: Either fix this lint violation or explain why it's necessary to ignore. +// // eslint-disable-next-line @typescript-eslint/naming-convention +// snap = 'Snap Keyring', +// } + +// export const createMockInternalAccount = ({ +// id = v4(), +// address = '0x2990079bcdee240329a520d2444386fc119da21a', +// type = , +// name = 'Account 1', +// keyringType = KeyringTypes.hd, +// snap, +// importTime = Date.now(), +// lastSelected = Date.now(), +// }: { +// id?: string; +// address?: string; +// type?: ''; +// name?: string; +// keyringType?: KeyringTypes; +// snap?: { +// id: string; +// enabled: boolean; +// name: string; +// }; +// importTime?: number; +// lastSelected?: number; +// } = {}): InternalAccount => { +// let methods; + +// switch (type) { +// case EthAccountType.Eoa: +// methods = [ +// EthMethod.PersonalSign, +// EthMethod.Sign, +// EthMethod.SignTransaction, +// EthMethod.SignTypedDataV1, +// EthMethod.SignTypedDataV3, +// EthMethod.SignTypedDataV4, +// ]; +// break; +// case EthAccountType.Erc4337: +// methods = [ +// EthMethod.PatchUserOperation, +// EthMethod.PrepareUserOperation, +// EthMethod.SignUserOperation, +// ]; +// break; +// case BtcAccountType.P2wpkh: +// methods = [BtcMethod.SendMany]; +// break; +// default: +// throw new Error(`Unknown account type: ${type as string}`); +// } + +// return { +// id, +// address, +// options: {}, +// methods, +// type, +// metadata: { +// name, +// keyring: { type: keyringType }, +// importTime, +// lastSelected, +// snap, +// }, +// }; +// }; + +console.log(1); From 8882ab2f7fc2d0f4d80976cd8fd8e665b975ffff Mon Sep 17 00:00:00 2001 From: Monte Lai Date: Tue, 18 Jun 2024 21:13:47 +0800 Subject: [PATCH 16/22] refactor: use helper --- .../src/TokensController.ts | 20 ++-- tests/mocks.txt | 94 ------------------- 2 files changed, 6 insertions(+), 108 deletions(-) delete mode 100644 tests/mocks.txt diff --git a/packages/assets-controllers/src/TokensController.ts b/packages/assets-controllers/src/TokensController.ts index 3bdd049823..a63f12bdd8 100644 --- a/packages/assets-controllers/src/TokensController.ts +++ b/packages/assets-controllers/src/TokensController.ts @@ -274,16 +274,12 @@ export class TokensController extends BaseController< this.#abortController.abort(); this.#abortController = new AbortController(); this.#chainId = chainId; - const selectedAccount = this.messagingSystem.call( - 'AccountsController:getAccount', - this.#selectedAccountId, - ); + const selectedAddress = this.#getSelectedAddress(); this.update((state) => { - state.tokens = allTokens[chainId]?.[selectedAccount?.address || ''] || []; - state.ignoredTokens = - allIgnoredTokens[chainId]?.[selectedAccount?.address || ''] || []; + state.tokens = allTokens[chainId]?.[selectedAddress] || []; + state.ignoredTokens = allIgnoredTokens[chainId]?.[selectedAddress] || []; state.detectedTokens = - allDetectedTokens[chainId]?.[selectedAccount?.address || ''] || []; + allDetectedTokens[chainId]?.[selectedAddress] || []; }); } @@ -1033,13 +1029,9 @@ export class TokensController extends BaseController< } #isInteractingWithWallet(address: string) { - // If the address is not defined (or empty), we fallback to the currently selected account's address - const selectedAccount = this.messagingSystem.call( - 'AccountsController:getAccount', - this.#selectedAccountId, - ); + const selectedAddress = this.#getSelectedAddress(); - return selectedAccount?.address === address; + return selectedAddress === address; } /** diff --git a/tests/mocks.txt b/tests/mocks.txt deleted file mode 100644 index f6a573d185..0000000000 --- a/tests/mocks.txt +++ /dev/null @@ -1,94 +0,0 @@ -// import { v4 } from 'uuid'; - -// // Duplicate code here to avoid using `@metamask/keyring-api` and `@metamask/keyring-controller` as dependencies -// export enum KeyringTypes { -// // TODO: Either fix this lint violation or explain why it's necessary to ignore. -// // eslint-disable-next-line @typescript-eslint/naming-convention -// simple = 'Simple Key Pair', -// // TODO: Either fix this lint violation or explain why it's necessary to ignore. -// // eslint-disable-next-line @typescript-eslint/naming-convention -// hd = 'HD Key Tree', -// // TODO: Either fix this lint violation or explain why it's necessary to ignore. -// // eslint-disable-next-line @typescript-eslint/naming-convention -// qr = 'QR Hardware Wallet Device', -// // TODO: Either fix this lint violation or explain why it's necessary to ignore. -// // eslint-disable-next-line @typescript-eslint/naming-convention -// trezor = 'Trezor Hardware', -// // TODO: Either fix this lint violation or explain why it's necessary to ignore. -// // eslint-disable-next-line @typescript-eslint/naming-convention -// ledger = 'Ledger Hardware', -// // TODO: Either fix this lint violation or explain why it's necessary to ignore. -// // eslint-disable-next-line @typescript-eslint/naming-convention -// lattice = 'Lattice Hardware', -// // TODO: Either fix this lint violation or explain why it's necessary to ignore. -// // eslint-disable-next-line @typescript-eslint/naming-convention -// snap = 'Snap Keyring', -// } - -// export const createMockInternalAccount = ({ -// id = v4(), -// address = '0x2990079bcdee240329a520d2444386fc119da21a', -// type = , -// name = 'Account 1', -// keyringType = KeyringTypes.hd, -// snap, -// importTime = Date.now(), -// lastSelected = Date.now(), -// }: { -// id?: string; -// address?: string; -// type?: ''; -// name?: string; -// keyringType?: KeyringTypes; -// snap?: { -// id: string; -// enabled: boolean; -// name: string; -// }; -// importTime?: number; -// lastSelected?: number; -// } = {}): InternalAccount => { -// let methods; - -// switch (type) { -// case EthAccountType.Eoa: -// methods = [ -// EthMethod.PersonalSign, -// EthMethod.Sign, -// EthMethod.SignTransaction, -// EthMethod.SignTypedDataV1, -// EthMethod.SignTypedDataV3, -// EthMethod.SignTypedDataV4, -// ]; -// break; -// case EthAccountType.Erc4337: -// methods = [ -// EthMethod.PatchUserOperation, -// EthMethod.PrepareUserOperation, -// EthMethod.SignUserOperation, -// ]; -// break; -// case BtcAccountType.P2wpkh: -// methods = [BtcMethod.SendMany]; -// break; -// default: -// throw new Error(`Unknown account type: ${type as string}`); -// } - -// return { -// id, -// address, -// options: {}, -// methods, -// type, -// metadata: { -// name, -// keyring: { type: keyringType }, -// importTime, -// lastSelected, -// snap, -// }, -// }; -// }; - -console.log(1); From 8ab89f224e43de6738cd2081ddc197032aefaa5c Mon Sep 17 00:00:00 2001 From: Monte Lai Date: Tue, 18 Jun 2024 22:52:15 +0800 Subject: [PATCH 17/22] fix: test --- .../src/TokensController.test.ts | 29 +++++++++++++++---- .../src/TokensController.ts | 2 +- 2 files changed, 25 insertions(+), 6 deletions(-) diff --git a/packages/assets-controllers/src/TokensController.test.ts b/packages/assets-controllers/src/TokensController.test.ts index 9c355c9cc4..f91760a615 100644 --- a/packages/assets-controllers/src/TokensController.test.ts +++ b/packages/assets-controllers/src/TokensController.test.ts @@ -2230,8 +2230,6 @@ describe('TokensController', () => { ContractMock.mockReturnValue( buildMockEthersERC721Contract({ supportsInterface: false }), ); - - // getAccountHandler.mockReturnValue(undefined); changeNetwork({ selectedNetworkClientId: InfuraNetworkType.sepolia }); expect(controller.state.tokens).toStrictEqual([]); @@ -2254,7 +2252,18 @@ describe('TokensController', () => { await controller.addToken({ address, symbol, decimals }); - expect(controller.state.tokens).toStrictEqual([]); + expect(controller.state.tokens).toStrictEqual([ + { + address: '0x9C8fF314C9Bc7F6e59A9d9225Fb22946427eDC03', + aggregators: [], + decimals: 0, + image: + 'https://static.cx.metamask.io/api/v1/tokenIcons/1/0x9c8ff314c9bc7f6e59a9d9225fb22946427edc03.png', + isERC721: true, + name: undefined, + symbol: 'NOUN', + }, + ]); }); }); }); @@ -2301,8 +2310,18 @@ describe('TokensController', () => { getAccountHandler.mockReturnValue(undefined); await controller.watchAsset({ asset, type: 'ERC20' }); - expect(controller.state.tokens).toHaveLength(0); - expect(controller.state.tokens).toStrictEqual([]); + expect(controller.state.tokens).toHaveLength(1); + expect(controller.state.tokens).toStrictEqual([ + { + address: '0x000000000000000000000000000000000000dEaD', + aggregators: [], + decimals: 12, + image: 'image', + isERC721: false, + name: undefined, + symbol: 'TOKEN', + }, + ]); expect(addAndShowApprovalRequestSpy).toHaveBeenCalledTimes(1); expect(addAndShowApprovalRequestSpy).toHaveBeenCalledWith({ id: requestId, diff --git a/packages/assets-controllers/src/TokensController.ts b/packages/assets-controllers/src/TokensController.ts index a63f12bdd8..9345f8b334 100644 --- a/packages/assets-controllers/src/TokensController.ts +++ b/packages/assets-controllers/src/TokensController.ts @@ -1028,7 +1028,7 @@ export class TokensController extends BaseController< return this.#getSelectedAddress(); } - #isInteractingWithWallet(address: string) { + #isInteractingWithWallet(address: string | undefined) { const selectedAddress = this.#getSelectedAddress(); return selectedAddress === address; From 12814e718382cf4da5cf44205e211d59e86d91bf Mon Sep 17 00:00:00 2001 From: Monte Lai Date: Thu, 20 Jun 2024 16:40:38 +0800 Subject: [PATCH 18/22] fix: test name --- packages/assets-controllers/src/TokenRatesController.test.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/assets-controllers/src/TokenRatesController.test.ts b/packages/assets-controllers/src/TokenRatesController.test.ts index 34e4e4ed3b..3f8404ae36 100644 --- a/packages/assets-controllers/src/TokenRatesController.test.ts +++ b/packages/assets-controllers/src/TokenRatesController.test.ts @@ -1037,7 +1037,7 @@ describe('TokenRatesController', () => { }); describe('when polling is inactive', () => { - it('should not update exchange rates when selected account hanges', async () => { + it('does not update exchange rates when selected account changes', async () => { const alternateSelectedAddress = '0x0000000000000000000000000000000000000002'; const alternateSelectedAccount = createMockInternalAccount({ From 1c401338a4596d5f5e0afde53871c80a185cce48 Mon Sep 17 00:00:00 2001 From: Monte Lai Date: Thu, 20 Jun 2024 16:40:51 +0800 Subject: [PATCH 19/22] feat: add #getSelectedAddress helper --- .../src/TokenDetectionController.ts | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/packages/assets-controllers/src/TokenDetectionController.ts b/packages/assets-controllers/src/TokenDetectionController.ts index 38a0d24e42..64572b9a43 100644 --- a/packages/assets-controllers/src/TokenDetectionController.ts +++ b/packages/assets-controllers/src/TokenDetectionController.ts @@ -288,9 +288,7 @@ export class TokenDetectionController extends StaticIntervalPollingController< // TODO: Either fix this lint violation or explain why it's necessary to ignore. // eslint-disable-next-line @typescript-eslint/no-misused-promises async ({ useTokenDetection }) => { - const selectedAccount = this.messagingSystem.call( - 'AccountsController:getSelectedAccount', - ); + const selectedAccount = this.#getSelectedAccount(); const isDetectionChangedFromPreferences = this.#isDetectionEnabledFromPreferences !== useTokenDetection; @@ -488,13 +486,8 @@ export class TokenDetectionController extends StaticIntervalPollingController< return; } - const selectedAccount = this.messagingSystem.call( - 'AccountsController:getAccount', - this.#selectedAccountId, - ); - const addressAgainstWhichToDetect = - selectedAddress ?? selectedAccount?.address ?? ''; + selectedAddress ?? this.#getSelectedAddress(); const { chainId, networkClientId: selectedNetworkClientId } = this.#getCorrectChainIdAndNetworkClientId(networkClientId); const chainIdAgainstWhichToDetect = chainId; @@ -642,6 +635,15 @@ export class TokenDetectionController extends StaticIntervalPollingController< #getSelectedAccount() { return this.messagingSystem.call('AccountsController:getSelectedAccount'); } + + #getSelectedAddress() { + // If the address is not defined (or empty), we fallback to the currently selected account's address + const account = this.messagingSystem.call( + 'AccountsController:getAccount', + this.#selectedAccountId, + ); + return account?.address || ''; + } } export default TokenDetectionController; From 58d8a42f9749c0e6552a9c1ae4f7cb8b01aac6f3 Mon Sep 17 00:00:00 2001 From: Monte Lai Date: Thu, 20 Jun 2024 16:44:56 +0800 Subject: [PATCH 20/22] refactor: create const to be used in getTOkenAddresses --- packages/assets-controllers/src/TokenRatesController.ts | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/packages/assets-controllers/src/TokenRatesController.ts b/packages/assets-controllers/src/TokenRatesController.ts index b4ed2dc340..534ca176bf 100644 --- a/packages/assets-controllers/src/TokenRatesController.ts +++ b/packages/assets-controllers/src/TokenRatesController.ts @@ -383,10 +383,10 @@ export class TokenRatesController extends StaticIntervalPollingController< 'AccountsController:getAccount', this.#selectedAccountId, ); - const tokens = - this.#allTokens[chainId]?.[selectedAccount?.address ?? ''] || []; + const selectedAddress = selectedAccount?.address ?? ''; + const tokens = this.#allTokens[chainId]?.[selectedAddress] || []; const detectedTokens = - this.#allDetectedTokens[chainId]?.[selectedAccount?.address ?? ''] || []; + this.#allDetectedTokens[chainId]?.[selectedAddress] || []; return [ ...new Set( From c7805b95369d986c6fc1bfbc888f0c6a27043f80 Mon Sep 17 00:00:00 2001 From: Monte Lai Date: Thu, 20 Jun 2024 17:26:07 +0800 Subject: [PATCH 21/22] fix: reduce getAccountHandler calls --- .../src/TokensController.test.ts | 32 +++---------------- 1 file changed, 5 insertions(+), 27 deletions(-) diff --git a/packages/assets-controllers/src/TokensController.test.ts b/packages/assets-controllers/src/TokensController.test.ts index f91760a615..5ad74d66f7 100644 --- a/packages/assets-controllers/src/TokensController.test.ts +++ b/packages/assets-controllers/src/TokensController.test.ts @@ -289,11 +289,7 @@ describe('TokensController', () => { getSelectedAccount: firstAccount, }, }, - async ({ - controller, - triggerSelectedAccountChange, - getAccountHandler, - }) => { + async ({ controller, triggerSelectedAccountChange }) => { ContractMock.mockReturnValue( buildMockEthersERC721Contract({ supportsInterface: false }), ); @@ -304,7 +300,6 @@ describe('TokensController', () => { symbol: 'bar', decimals: 2, }); - getAccountHandler.mockReturnValue(secondAccount); triggerSelectedAccountChange(secondAccount); expect(controller.state.tokens).toHaveLength(0); @@ -439,11 +434,7 @@ describe('TokensController', () => { getSelectedAccount: firstAccount, }, }, - async ({ - controller, - triggerSelectedAccountChange, - getAccountHandler, - }) => { + async ({ controller, triggerSelectedAccountChange }) => { ContractMock.mockReturnValue( buildMockEthersERC721Contract({ supportsInterface: false }), ); @@ -454,7 +445,6 @@ describe('TokensController', () => { symbol: 'baz', decimals: 2, }); - getAccountHandler.mockReturnValue(secondAccount); triggerSelectedAccountChange(secondAccount); await controller.addToken({ address: '0x01', @@ -465,7 +455,6 @@ describe('TokensController', () => { controller.ignoreTokens(['0x01']); expect(controller.state.tokens).toHaveLength(0); - getAccountHandler.mockReturnValue(firstAccount); triggerSelectedAccountChange(firstAccount); expect(controller.state.tokens[0]).toStrictEqual({ address: '0x02', @@ -649,12 +638,7 @@ describe('TokensController', () => { getAccount: selectedAccount1, }, }, - async ({ - controller, - triggerSelectedAccountChange, - changeNetwork, - getAccountHandler, - }) => { + async ({ controller, triggerSelectedAccountChange, changeNetwork }) => { triggerSelectedAccountChange(selectedAccount1); changeNetwork({ selectedNetworkClientId: InfuraNetworkType.sepolia }); await controller.addToken({ @@ -679,7 +663,6 @@ describe('TokensController', () => { controller.ignoreTokens(['0x02']); expect(controller.state.ignoredTokens).toStrictEqual(['0x02']); - getAccountHandler.mockReturnValue(selectedAccount2); triggerSelectedAccountChange(selectedAccount2); expect(controller.state.ignoredTokens).toHaveLength(0); @@ -1897,11 +1880,7 @@ describe('TokensController', () => { getSelectedAccount: selectedAccount, }, }, - async ({ - controller, - triggerSelectedAccountChange, - getAccountHandler, - }) => { + async ({ controller, triggerSelectedAccountChange }) => { ContractMock.mockReturnValue( buildMockEthersERC721Contract({ supportsInterface: false }), ); @@ -1918,7 +1897,6 @@ describe('TokensController', () => { decimals: 5, }); - getAccountHandler.mockReturnValue(selectedAccount2); triggerSelectedAccountChange(selectedAccount2); expect(controller.state.tokens).toStrictEqual([]); @@ -1927,7 +1905,6 @@ describe('TokensController', () => { symbol: 'C', decimals: 6, }); - getAccountHandler.mockReturnValue(selectedAccount); triggerSelectedAccountChange(selectedAccount); expect(controller.state.tokens).toStrictEqual([ { @@ -2458,6 +2435,7 @@ async function withController( }); const triggerSelectedAccountChange = (internalAccount: InternalAccount) => { + getAccountHandler.mockReturnValue(internalAccount); messenger.publish( 'AccountsController:selectedEvmAccountChange', internalAccount, From 692d1c270abd380e23224e2b003185bd2f508454 Mon Sep 17 00:00:00 2001 From: Monte Lai Date: Thu, 20 Jun 2024 18:47:04 +0800 Subject: [PATCH 22/22] fix: refactor test vars --- .../src/TokensController.test.ts | 26 +++++++------------ 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/packages/assets-controllers/src/TokensController.test.ts b/packages/assets-controllers/src/TokensController.test.ts index 5ad74d66f7..38692b6404 100644 --- a/packages/assets-controllers/src/TokensController.test.ts +++ b/packages/assets-controllers/src/TokensController.test.ts @@ -1896,7 +1896,6 @@ describe('TokensController', () => { symbol: 'B', decimals: 5, }); - triggerSelectedAccountChange(selectedAccount2); expect(controller.state.tokens).toStrictEqual([]); @@ -2231,14 +2230,14 @@ describe('TokensController', () => { expect(controller.state.tokens).toStrictEqual([ { - address: '0x9C8fF314C9Bc7F6e59A9d9225Fb22946427eDC03', + address, aggregators: [], - decimals: 0, + decimals, image: 'https://static.cx.metamask.io/api/v1/tokenIcons/1/0x9c8ff314c9bc7f6e59a9d9225fb22946427edc03.png', isERC721: true, name: undefined, - symbol: 'NOUN', + symbol, }, ]); }); @@ -2249,21 +2248,16 @@ describe('TokensController', () => { it('handles an undefined selected account', async () => { await withController(async ({ controller, getAccountHandler }) => { getAccountHandler.mockReturnValue(undefined); - await controller.addDetectedTokens([ - { - address: '0x01', - symbol: 'barA', - decimals: 2, - aggregators: [], - }, - ]); - console.log(controller.state.allDetectedTokens[ChainId.mainnet]); - expect(controller.state.detectedTokens[0]).toStrictEqual({ + const mockToken = { address: '0x01', - decimals: 2, - image: undefined, symbol: 'barA', + decimals: 2, aggregators: [], + }; + await controller.addDetectedTokens([mockToken]); + expect(controller.state.detectedTokens[0]).toStrictEqual({ + ...mockToken, + image: undefined, isERC721: undefined, name: undefined, });