diff --git a/modules/apps/transfer/ibc_module.go b/modules/apps/transfer/ibc_module.go index c6445645c23..95d4681b9b9 100644 --- a/modules/apps/transfer/ibc_module.go +++ b/modules/apps/transfer/ibc_module.go @@ -184,7 +184,7 @@ func (im IBCModule) OnRecvPacket( events.EmitOnRecvPacketEvent(ctx, data, ack, ackErr) }() - data, ackErr = types.UnmarshalPacketData(packet.GetData(), channelVersion) + data, ackErr = types.UnmarshalPacketData(packet.GetData(), channelVersion, "") if ackErr != nil { ack = channeltypes.NewErrorAcknowledgement(ackErr) im.keeper.Logger(ctx).Error(fmt.Sprintf("%s sequence %d", ackErr.Error(), packet.Sequence)) @@ -221,7 +221,7 @@ func (im IBCModule) OnAcknowledgementPacket( return errorsmod.Wrapf(ibcerrors.ErrUnknownRequest, "cannot unmarshal ICS-20 transfer packet acknowledgement: %v", err) } - data, err := types.UnmarshalPacketData(packet.GetData(), channelVersion) + data, err := types.UnmarshalPacketData(packet.GetData(), channelVersion, "") if err != nil { return err } @@ -242,7 +242,7 @@ func (im IBCModule) OnTimeoutPacket( packet channeltypes.Packet, relayer sdk.AccAddress, ) error { - data, err := types.UnmarshalPacketData(packet.GetData(), channelVersion) + data, err := types.UnmarshalPacketData(packet.GetData(), channelVersion, "") if err != nil { return err } @@ -305,6 +305,6 @@ func (im IBCModule) UnmarshalPacketData(ctx context.Context, portID string, chan return types.FungibleTokenPacketDataV2{}, "", errorsmod.Wrapf(ibcerrors.ErrNotFound, "app version not found for port %s and channel %s", portID, channelID) } - ftpd, err := types.UnmarshalPacketData(bz, ics20Version) + ftpd, err := types.UnmarshalPacketData(bz, ics20Version, "") return ftpd, ics20Version, err } diff --git a/modules/apps/transfer/types/packet.go b/modules/apps/transfer/types/packet.go index 98525dac769..968e996c08f 100644 --- a/modules/apps/transfer/types/packet.go +++ b/modules/apps/transfer/types/packet.go @@ -23,6 +23,11 @@ var ( _ ibcexported.PacketDataProvider = (*FungibleTokenPacketDataV2)(nil) ) +const ( + EncodingJSON = "application/json" + EncodingProtobuf = "application/x-protobuf" +) + // NewFungibleTokenPacketData constructs a new FungibleTokenPacketData instance func NewFungibleTokenPacketData( denom string, amount string, @@ -208,36 +213,78 @@ func (ftpd FungibleTokenPacketDataV2) HasForwarding() bool { } // UnmarshalPacketData attempts to unmarshal the provided packet data bytes into a FungibleTokenPacketDataV2. -// The version of ics20 should be provided and should be either ics20-1 or ics20-2. -func UnmarshalPacketData(bz []byte, ics20Version string) (FungibleTokenPacketDataV2, error) { - // TODO: in transfer ibc module V2, we need to respect he encoding value passed via the payload, some hard coded assumptions about - // encoding exist here based on the ics20 version passed in. +func UnmarshalPacketData(bz []byte, ics20Version string, encoding string) (FungibleTokenPacketDataV2, error) { + const failedUnmarshalingErrorMsg = "cannot unmarshal %s transfer packet data: %s" + + // Depending on the ics20 version, we use a different default encoding (json for V1, proto for V2) + // and we have a different type to unmarshal the data into. + var data proto.Message switch ics20Version { case V1: - var datav1 FungibleTokenPacketData - if err := json.Unmarshal(bz, &datav1); err != nil { - return FungibleTokenPacketDataV2{}, errorsmod.Wrapf(ibcerrors.ErrInvalidType, "cannot unmarshal ICS20-V1 transfer packet data: %s", err.Error()) + if encoding == "" { + encoding = EncodingJSON } - - return PacketDataV1ToV2(datav1) + data = &FungibleTokenPacketData{} case V2: - var datav2 FungibleTokenPacketDataV2 - if err := unknownproto.RejectUnknownFieldsStrict(bz, &datav2, unknownproto.DefaultAnyResolver{}); err != nil { - return FungibleTokenPacketDataV2{}, errorsmod.Wrapf(ibcerrors.ErrInvalidType, "cannot unmarshal ICS20-V2 transfer packet data: %s", err.Error()) + if encoding == "" { + encoding = EncodingProtobuf } + data = &FungibleTokenPacketDataV2{} + default: + return FungibleTokenPacketDataV2{}, errorsmod.Wrap(ErrInvalidVersion, ics20Version) + } - if err := proto.Unmarshal(bz, &datav2); err != nil { - return FungibleTokenPacketDataV2{}, errorsmod.Wrapf(ibcerrors.ErrInvalidType, "cannot unmarshal ICS20-V2 transfer packet data: %s", err.Error()) + errorMsgVersion := "ICS20-V2" + if ics20Version == V1 { + errorMsgVersion = "ICS20-V1" + } + + // Here we perform the unmarshaling based on the specified encoding. + // The functions act on the generic "data" variable which is of type proto.Message (an interface). + // The underlying type is either FungibleTokenPacketData or FungibleTokenPacketDataV2, based on the value + // of "ics20Version". + switch encoding { + case EncodingJSON: + if err := json.Unmarshal(bz, &data); err != nil { + return FungibleTokenPacketDataV2{}, errorsmod.Wrapf(ibcerrors.ErrInvalidType, failedUnmarshalingErrorMsg, errorMsgVersion, err.Error()) + } + case EncodingProtobuf: + if err := unknownproto.RejectUnknownFieldsStrict(bz, data, unknownproto.DefaultAnyResolver{}); err != nil { + return FungibleTokenPacketDataV2{}, errorsmod.Wrapf(ibcerrors.ErrInvalidType, failedUnmarshalingErrorMsg, errorMsgVersion, err.Error()) } - if err := datav2.ValidateBasic(); err != nil { - return FungibleTokenPacketDataV2{}, err + if err := proto.Unmarshal(bz, data); err != nil { + return FungibleTokenPacketDataV2{}, errorsmod.Wrapf(ibcerrors.ErrInvalidType, failedUnmarshalingErrorMsg, errorMsgVersion, err.Error()) } - return datav2, nil default: - return FungibleTokenPacketDataV2{}, errorsmod.Wrap(ErrInvalidVersion, ics20Version) + return FungibleTokenPacketDataV2{}, errorsmod.Wrapf(ibcerrors.ErrInvalidType, "invalid encoding provided, must be either empty or one of [%q, %q], got %s", EncodingJSON, EncodingProtobuf, encoding) + } + + // When the unmarshaling is done, we want to retrieve the underlying data type based on the value of ics20Version + // If it's v1, we convert the data to FungibleTokenPacketData and then call the conversion function to construct + // the v2 type. + if ics20Version == V1 { + datav1, ok := data.(*FungibleTokenPacketData) + if !ok { + // We should never get here, as we manually constructed the type at the beginning of the file + return FungibleTokenPacketDataV2{}, errorsmod.Wrapf(ibcerrors.ErrInvalidType, "cannot convert proto message into FungibleTokenPacketData") + } + // The call to ValidateBasic for V1 is done inside PacketDataV1toV2. + return PacketDataV1ToV2(*datav1) + } + + // If it's v2, we convert the data to FungibleTokenPacketDataV2, validate it and return it. + datav2, ok := data.(*FungibleTokenPacketDataV2) + if !ok { + // We should never get here, as we manually constructed the type at the beginning of the file + return FungibleTokenPacketDataV2{}, errorsmod.Wrapf(ibcerrors.ErrInvalidType, "cannot convert proto message into FungibleTokenPacketDataV2") + } + + if err := datav2.ValidateBasic(); err != nil { + return FungibleTokenPacketDataV2{}, err } + return *datav2, nil } // PacketDataV1ToV2 converts a v1 packet data to a v2 packet data. The packet data is validated diff --git a/modules/apps/transfer/types/packet_test.go b/modules/apps/transfer/types/packet_test.go index 254cec03522..db166b2b8d8 100644 --- a/modules/apps/transfer/types/packet_test.go +++ b/modules/apps/transfer/types/packet_test.go @@ -765,7 +765,7 @@ func TestUnmarshalPacketData(t *testing.T) { tc.malleate() - packetData, err := types.UnmarshalPacketData(packetDataBz, version) + packetData, err := types.UnmarshalPacketData(packetDataBz, version, "") expPass := tc.expError == nil if expPass { @@ -819,7 +819,7 @@ func TestV2ForwardsCompatibilityFails(t *testing.T) { tc.malleate() - packetData, err := types.UnmarshalPacketData(packetDataBz, types.V2) + packetData, err := types.UnmarshalPacketData(packetDataBz, types.V2, types.EncodingProtobuf) expPass := tc.expError == nil if expPass { diff --git a/modules/apps/transfer/v2/ibc_module.go b/modules/apps/transfer/v2/ibc_module.go index fcbedc6ef03..19aa3db5292 100644 --- a/modules/apps/transfer/v2/ibc_module.go +++ b/modules/apps/transfer/v2/ibc_module.go @@ -41,7 +41,7 @@ func (im *IBCModule) OnSendPacket(goCtx context.Context, sourceChannel string, d return errorsmod.Wrapf(ibcerrors.ErrUnauthorized, "%s is not allowed to send funds", signer) } - data, err := transfertypes.UnmarshalPacketData(payload.Value, payload.Version) + data, err := transfertypes.UnmarshalPacketData(payload.Value, payload.Version, payload.Encoding) if err != nil { return err } @@ -66,7 +66,7 @@ func (im *IBCModule) OnRecvPacket(ctx context.Context, sourceChannel string, des events.EmitOnRecvPacketEvent(ctx, data, ack, ackErr) }() - data, ackErr = transfertypes.UnmarshalPacketData(payload.Value, payload.Version) + data, ackErr = transfertypes.UnmarshalPacketData(payload.Value, payload.Version, payload.Encoding) if ackErr != nil { ack = channeltypes.NewErrorAcknowledgement(ackErr) im.keeper.Logger(ctx).Error(fmt.Sprintf("%s sequence %d", ackErr.Error(), sequence)) @@ -99,7 +99,7 @@ func (im *IBCModule) OnRecvPacket(ctx context.Context, sourceChannel string, des } func (im *IBCModule) OnTimeoutPacket(ctx context.Context, sourceChannel string, destinationChannel string, sequence uint64, payload types.Payload, relayer sdk.AccAddress) error { - data, err := transfertypes.UnmarshalPacketData(payload.Value, payload.Version) + data, err := transfertypes.UnmarshalPacketData(payload.Value, payload.Version, payload.Encoding) if err != nil { return err } @@ -119,7 +119,7 @@ func (im *IBCModule) OnAcknowledgementPacket(ctx context.Context, sourceChannel return errorsmod.Wrapf(ibcerrors.ErrUnknownRequest, "cannot unmarshal ICS-20 transfer packet acknowledgement: %v", err) } - data, err := transfertypes.UnmarshalPacketData(payload.Value, payload.Version) + data, err := transfertypes.UnmarshalPacketData(payload.Value, payload.Version, payload.Encoding) if err != nil { return err } diff --git a/modules/apps/transfer/v2/keeper/msg_server_test.go b/modules/apps/transfer/v2/keeper/msg_server_test.go index 9dcaba67e89..1e4c270c4ec 100644 --- a/modules/apps/transfer/v2/keeper/msg_server_test.go +++ b/modules/apps/transfer/v2/keeper/msg_server_test.go @@ -68,10 +68,7 @@ func (suite *KeeperTestSuite) TestMsgSendPacketTransfer() { bz := suite.chainA.Codec.MustMarshal(&ftpd) timestamp := suite.chainA.GetTimeoutTimestampSecs() - // TODO: note, encoding field currently not respected in the implementation. encoding is determined by the version. - // ics20-v1 == json - // ics20-v2 == proto - payload = channeltypesv2.NewPayload(transfertypes.ModuleName, transfertypes.ModuleName, transfertypes.V2, "json", bz) + payload = channeltypesv2.NewPayload(transfertypes.ModuleName, transfertypes.ModuleName, transfertypes.V2, transfertypes.EncodingProtobuf, bz) tc.malleate() @@ -166,7 +163,7 @@ func (suite *KeeperTestSuite) TestMsgRecvPacketTransfer() { bz := suite.chainA.Codec.MustMarshal(&ftpd) timestamp := suite.chainA.GetTimeoutTimestampSecs() - payload := channeltypesv2.NewPayload(transfertypes.ModuleName, transfertypes.ModuleName, transfertypes.V2, "json", bz) + payload := channeltypesv2.NewPayload(transfertypes.ModuleName, transfertypes.ModuleName, transfertypes.V2, transfertypes.EncodingProtobuf, bz) var err error packet, err = path.EndpointA.MsgSendPacket(timestamp, payload) suite.Require().NoError(err) @@ -275,7 +272,7 @@ func (suite *KeeperTestSuite) TestMsgAckPacketTransfer() { bz := suite.chainA.Codec.MustMarshal(&ftpd) timestamp := suite.chainA.GetTimeoutTimestampSecs() - payload := channeltypesv2.NewPayload(transfertypes.ModuleName, transfertypes.ModuleName, transfertypes.V2, "json", bz) + payload := channeltypesv2.NewPayload(transfertypes.ModuleName, transfertypes.ModuleName, transfertypes.V2, transfertypes.EncodingProtobuf, bz) var err error packet, err = path.EndpointA.MsgSendPacket(timestamp, payload) @@ -376,7 +373,7 @@ func (suite *KeeperTestSuite) TestMsgTimeoutPacketTransfer() { bz := suite.chainA.Codec.MustMarshal(&ftpd) timeoutTimestamp = uint64(suite.chainA.GetContext().BlockTime().Unix()) + uint64(time.Hour.Seconds()) - payload := channeltypesv2.NewPayload(transfertypes.ModuleName, transfertypes.ModuleName, transfertypes.V2, "json", bz) + payload := channeltypesv2.NewPayload(transfertypes.ModuleName, transfertypes.ModuleName, transfertypes.V2, transfertypes.EncodingProtobuf, bz) var err error packet, err = path.EndpointA.MsgSendPacket(timeoutTimestamp, payload) @@ -497,7 +494,7 @@ func (suite *KeeperTestSuite) TestV2RetainsFungibility() { bz := suite.chainB.Codec.MustMarshal(&ftpd) timeoutTimestamp := uint64(suite.chainB.GetContext().BlockTime().Unix()) + uint64(time.Hour.Seconds()) - payload := channeltypesv2.NewPayload(transfertypes.ModuleName, transfertypes.ModuleName, transfertypes.V2, "json", bz) + payload := channeltypesv2.NewPayload(transfertypes.ModuleName, transfertypes.ModuleName, transfertypes.V2, transfertypes.EncodingProtobuf, bz) var err error packetV2, err = pathv2.EndpointA.MsgSendPacket(timeoutTimestamp, payload) @@ -536,7 +533,7 @@ func (suite *KeeperTestSuite) TestV2RetainsFungibility() { bz := suite.chainC.Codec.MustMarshal(&ftpd) timeoutTimestamp := uint64(suite.chainC.GetContext().BlockTime().Unix()) + uint64(time.Hour.Seconds()) - payload := channeltypesv2.NewPayload(transfertypes.ModuleName, transfertypes.ModuleName, transfertypes.V2, "json", bz) + payload := channeltypesv2.NewPayload(transfertypes.ModuleName, transfertypes.ModuleName, transfertypes.V2, transfertypes.EncodingProtobuf, bz) var err error packetV2, err = pathv2.EndpointB.MsgSendPacket(timeoutTimestamp, payload) diff --git a/modules/core/04-channel/v2/types/msgs_test.go b/modules/core/04-channel/v2/types/msgs_test.go index fd6f5ceaa0a..35165f0d72e 100644 --- a/modules/core/04-channel/v2/types/msgs_test.go +++ b/modules/core/04-channel/v2/types/msgs_test.go @@ -6,6 +6,7 @@ import ( "github.com/stretchr/testify/suite" + transfertypes "github.com/cosmos/ibc-go/v9/modules/apps/transfer/types" clienttypes "github.com/cosmos/ibc-go/v9/modules/core/02-client/types" "github.com/cosmos/ibc-go/v9/modules/core/04-channel/v2/types" commitmenttypes "github.com/cosmos/ibc-go/v9/modules/core/23-commitment/types" @@ -205,7 +206,7 @@ func (s *TypesTestSuite) TestMsgSendPacketValidateBasic() { msg = types.NewMsgSendPacket( ibctesting.FirstChannelID, s.chainA.GetTimeoutTimestamp(), s.chainA.SenderAccount.GetAddress().String(), - types.Payload{SourcePort: ibctesting.MockPort, DestinationPort: ibctesting.MockPort, Version: "ics20-1", Encoding: "json", Value: ibctesting.MockPacketData}, + types.Payload{SourcePort: ibctesting.MockPort, DestinationPort: ibctesting.MockPort, Version: "ics20-1", Encoding: transfertypes.EncodingJSON, Value: ibctesting.MockPacketData}, ) tc.malleate() diff --git a/modules/core/04-channel/v2/types/packet_test.go b/modules/core/04-channel/v2/types/packet_test.go index 0a8b079b0a5..4952e03671d 100644 --- a/modules/core/04-channel/v2/types/packet_test.go +++ b/modules/core/04-channel/v2/types/packet_test.go @@ -6,6 +6,7 @@ import ( "github.com/stretchr/testify/require" + transfertypes "github.com/cosmos/ibc-go/v9/modules/apps/transfer/types" "github.com/cosmos/ibc-go/v9/modules/core/04-channel/v2/types" host "github.com/cosmos/ibc-go/v9/modules/core/24-host" ibctesting "github.com/cosmos/ibc-go/v9/testing" @@ -109,7 +110,7 @@ func TestValidateBasic(t *testing.T) { SourcePort: ibctesting.MockPort, DestinationPort: ibctesting.MockPort, Version: "ics20-v2", - Encoding: "json", + Encoding: transfertypes.EncodingProtobuf, Value: mock.MockPacketData, }) diff --git a/testing/mock/v2/mock.go b/testing/mock/v2/mock.go index 41582217070..3d2a0bd8201 100644 --- a/testing/mock/v2/mock.go +++ b/testing/mock/v2/mock.go @@ -1,6 +1,7 @@ package mock import ( + transfertypes "github.com/cosmos/ibc-go/v9/modules/apps/transfer/types" channeltypesv2 "github.com/cosmos/ibc-go/v9/modules/core/04-channel/v2/types" mockv1 "github.com/cosmos/ibc-go/v9/testing/mock" ) @@ -24,7 +25,7 @@ func NewMockPayload(sourcePort, destPort string) channeltypesv2.Payload { return channeltypesv2.Payload{ SourcePort: sourcePort, DestinationPort: destPort, - Encoding: "json", + Encoding: transfertypes.EncodingProtobuf, Value: mockv1.MockPacketData, Version: mockv1.Version, }