diff --git a/pkg/family/solana/mcms_pda.go b/pkg/family/solana/mcms_pda.go new file mode 100644 index 0000000..c90c64d --- /dev/null +++ b/pkg/family/solana/mcms_pda.go @@ -0,0 +1,57 @@ +package solana + +import ( + "github.com/gagliardetto/solana-go" +) + +const ( + pdaPrefixMultisigSigner = "multisig_signer" + pdaPrefixMultisigConfig = "multisig_config" + pdaPrefixRootMetadata = "root_metadata" + pdaPrefixExpiringRootAndOpCount = "expiring_root_and_op_count" + pdaPrefixTimelockConfig = "timelock_config" + pdaPrefixTimelockSigner = "timelock_signer" +) + +// GetMCMSignerPDA returns the PDA for the MCMS signer +func GetMCMSignerPDA(programID solana.PublicKey, seed PDASeed) solana.PublicKey { + seeds := [][]byte{[]byte(pdaPrefixMultisigSigner), seed[:]} + return getPDA(programID, seeds) +} + +// GetMCMConfigPDA returns the PDA for the MCMS config +func GetMCMConfigPDA(programID solana.PublicKey, seed PDASeed) solana.PublicKey { + seeds := [][]byte{[]byte(pdaPrefixMultisigConfig), seed[:]} + return getPDA(programID, seeds) +} + +// GetMCMRootMetadataPDA returns the PDA for the MCMS root metadata +func GetMCMRootMetadataPDA(programID solana.PublicKey, seed PDASeed) solana.PublicKey { + seeds := [][]byte{[]byte(pdaPrefixRootMetadata), seed[:]} + return getPDA(programID, seeds) +} + +// GetMCMExpiringRootAndOpCountPDA returns the PDA for the MCMS expiring root and op count +func GetMCMExpiringRootAndOpCountPDA(programID solana.PublicKey, seed PDASeed) solana.PublicKey { + seeds := [][]byte{[]byte(pdaPrefixExpiringRootAndOpCount), seed[:]} + return getPDA(programID, seeds) +} + +// GetTimelockConfigPDA returns the PDA for the Timelock config +func GetTimelockConfigPDA(programID solana.PublicKey, seed PDASeed) solana.PublicKey { + seeds := [][]byte{[]byte(pdaPrefixTimelockConfig), seed[:]} + return getPDA(programID, seeds) +} + +// GetTimelockSignerPDA returns the PDA for the Timelock signer +func GetTimelockSignerPDA(programID solana.PublicKey, seed PDASeed) solana.PublicKey { + seeds := [][]byte{[]byte(pdaPrefixTimelockSigner), seed[:]} + return getPDA(programID, seeds) +} + +// getPDA returns the PDA for the given program ID and seeds +func getPDA(programID solana.PublicKey, seeds [][]byte) solana.PublicKey { + // todo(ggoh): add error handling + pda, _, _ := solana.FindProgramAddress(seeds, programID) + return pda +} diff --git a/pkg/family/solana/mcms_pda_test.go b/pkg/family/solana/mcms_pda_test.go new file mode 100644 index 0000000..6959b92 --- /dev/null +++ b/pkg/family/solana/mcms_pda_test.go @@ -0,0 +1,76 @@ +package solana + +import ( + "testing" + + "github.com/gagliardetto/solana-go" + "github.com/stretchr/testify/require" +) + +func TestMCMSPDA(t *testing.T) { + t.Parallel() + + programID := solana.MustPublicKeyFromBase58("11111111111111111111111111111111") + + tests := []struct { + name string + prefix string + fn func(programID solana.PublicKey, seed PDASeed) solana.PublicKey + }{ + {name: "GetMCMSignerPDA", prefix: pdaPrefixMultisigSigner, fn: GetMCMSignerPDA}, + {name: "GetMCMConfigPDA", prefix: pdaPrefixMultisigConfig, fn: GetMCMConfigPDA}, + {name: "GetMCMRootMetadataPDA", prefix: pdaPrefixRootMetadata, fn: GetMCMRootMetadataPDA}, + {name: "GetMCMExpiringRootAndOpCountPDA", prefix: pdaPrefixExpiringRootAndOpCount, fn: GetMCMExpiringRootAndOpCountPDA}, + {name: "GetTimelockConfigPDA", prefix: pdaPrefixTimelockConfig, fn: GetTimelockConfigPDA}, + {name: "GetTimelockSignerPDA", prefix: pdaPrefixTimelockSigner, fn: GetTimelockSignerPDA}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + seed := testPDASeed(t) + seeds := [][]byte{[]byte(tt.prefix), seed[:]} + want := mustFindPDA(t, seeds, programID) + got := tt.fn(programID, seed) + require.Equal(t, want, got) + }) + } +} + +func TestPDAGeneratorsUseDistinctSeeds(t *testing.T) { + t.Parallel() + programID := solana.MustPublicKeyFromBase58("11111111111111111111111111111111") + id := testPDASeed(t) + + signer := GetMCMSignerPDA(programID, id) + cfg := GetMCMConfigPDA(programID, id) + meta := GetMCMRootMetadataPDA(programID, id) + exp := GetMCMExpiringRootAndOpCountPDA(programID, id) + tlCfg := GetTimelockConfigPDA(programID, id) + tlSigner := GetTimelockSignerPDA(programID, id) + + keys := []solana.PublicKey{signer, cfg, meta, exp, tlCfg, tlSigner} + for i := range keys { + for j := i + 1; j < len(keys); j++ { + require.NotEqualf(t, keys[i], keys[j], "PDA at %d equals PDA at %d", i, j) + } + } +} + +func mustFindPDA(t *testing.T, seeds [][]byte, programID solana.PublicKey) solana.PublicKey { + t.Helper() + pda, _, err := solana.FindProgramAddress(seeds, programID) + require.NoError(t, err) + + return pda +} + +func testPDASeed(t *testing.T) PDASeed { + t.Helper() + var s PDASeed + for i := range s { + s[i] = byte(i + 1) + } + + return s +}