Skip to content

Commit 8eae033

Browse files
authored
Modify UnmarshalPacketData interface to allow additional args (#6341)
* api(port)!: Allow passing of context, port and channel identifier to unmarshal packet data interface as disussed. This allows us to grab the app version in transfer and unmarshal the packet based on that instead of a hacky unmarshal v2 then v1 and whatever happens. * lint: as we do * callbacks: fix signature of UnmarshalPacketData as per changes, make refactors to hopefully simplify signatures. * chore: lint and remove some todos. * review: address feedback.
1 parent 43877df commit 8eae033

File tree

16 files changed

+110
-59
lines changed

16 files changed

+110
-59
lines changed

modules/apps/27-interchain-accounts/controller/ibc_middleware.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ func (im IBCMiddleware) GetAppVersion(ctx sdk.Context, portID, channelID string)
340340
// UnmarshalPacketData attempts to unmarshal the provided packet data bytes
341341
// into an InterchainAccountPacketData. This function implements the optional
342342
// PacketDataUnmarshaler interface required for ADR 008 support.
343-
func (IBCMiddleware) UnmarshalPacketData(bz []byte) (interface{}, error) {
343+
func (IBCMiddleware) UnmarshalPacketData(_ sdk.Context, _, _ string, bz []byte) (interface{}, error) {
344344
var data icatypes.InterchainAccountPacketData
345345
err := data.UnmarshalJSON(bz)
346346
if err != nil {

modules/apps/27-interchain-accounts/controller/ibc_middleware_test.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1276,13 +1276,15 @@ func (suite *InterchainAccountsTestSuite) TestPacketDataUnmarshalerInterface() {
12761276
Memo: "",
12771277
}
12781278

1279-
packetData, err := controller.IBCMiddleware{}.UnmarshalPacketData(expPacketData.GetBytes())
1279+
// Context, port identifier and channel identifier are unused for controller.
1280+
packetData, err := controller.IBCMiddleware{}.UnmarshalPacketData(suite.chainA.GetContext(), "", "", expPacketData.GetBytes())
12801281
suite.Require().NoError(err)
12811282
suite.Require().Equal(expPacketData, packetData)
12821283

12831284
// test invalid packet data
12841285
invalidPacketData := []byte("invalid packet data")
1285-
packetData, err = controller.IBCMiddleware{}.UnmarshalPacketData(invalidPacketData)
1286+
// Context, port identifier and channel identifier are not used for controller.
1287+
packetData, err = controller.IBCMiddleware{}.UnmarshalPacketData(suite.chainA.GetContext(), "", "", invalidPacketData)
12861288
suite.Require().Error(err)
12871289
suite.Require().Nil(packetData)
12881290
}

modules/apps/27-interchain-accounts/host/ibc_module.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ func (IBCModule) OnChanUpgradeOpen(ctx sdk.Context, portID, channelID string, pr
183183
// UnmarshalPacketData attempts to unmarshal the provided packet data bytes
184184
// into an InterchainAccountPacketData. This function implements the optional
185185
// PacketDataUnmarshaler interface required for ADR 008 support.
186-
func (IBCModule) UnmarshalPacketData(bz []byte) (interface{}, error) {
186+
func (IBCModule) UnmarshalPacketData(_ sdk.Context, _, _ string, bz []byte) (interface{}, error) {
187187
var data icatypes.InterchainAccountPacketData
188188
err := data.UnmarshalJSON(bz)
189189
if err != nil {

modules/apps/27-interchain-accounts/host/ibc_module_test.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -883,13 +883,15 @@ func (suite *InterchainAccountsTestSuite) TestPacketDataUnmarshalerInterface() {
883883
Memo: "",
884884
}
885885

886-
packetData, err := icahost.IBCModule{}.UnmarshalPacketData(expPacketData.GetBytes())
886+
// Context, port identifier and channel identifier are unused for host.
887+
packetData, err := icahost.IBCModule{}.UnmarshalPacketData(suite.chainA.GetContext(), "", "", expPacketData.GetBytes())
887888
suite.Require().NoError(err)
888889
suite.Require().Equal(expPacketData, packetData)
889890

890891
// test invalid packet data
891892
invalidPacketData := []byte("invalid packet data")
892-
packetData, err = icahost.IBCModule{}.UnmarshalPacketData(invalidPacketData)
893+
// Context, port identifier and channel identifier are unused for host.
894+
packetData, err = icahost.IBCModule{}.UnmarshalPacketData(suite.chainA.GetContext(), "", "", invalidPacketData)
893895
suite.Require().Error(err)
894896
suite.Require().Nil(packetData)
895897
}

modules/apps/29-fee/ibc_middleware.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -472,11 +472,11 @@ func (im IBCMiddleware) GetAppVersion(ctx sdk.Context, portID, channelID string)
472472
// UnmarshalPacketData attempts to use the underlying app to unmarshal the packet data.
473473
// If the underlying app does not support the PacketDataUnmarshaler interface, an error is returned.
474474
// This function implements the optional PacketDataUnmarshaler interface required for ADR 008 support.
475-
func (im IBCMiddleware) UnmarshalPacketData(bz []byte) (interface{}, error) {
475+
func (im IBCMiddleware) UnmarshalPacketData(ctx sdk.Context, portID, channelID string, bz []byte) (interface{}, error) {
476476
unmarshaler, ok := im.app.(porttypes.PacketDataUnmarshaler)
477477
if !ok {
478478
return nil, errorsmod.Wrapf(types.ErrUnsupportedAction, "underlying app does not implement %T", (*porttypes.PacketDataUnmarshaler)(nil))
479479
}
480480

481-
return unmarshaler.UnmarshalPacketData(bz)
481+
return unmarshaler.UnmarshalPacketData(ctx, portID, channelID, bz)
482482
}

modules/apps/29-fee/ibc_middleware_test.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1573,7 +1573,8 @@ func (suite *FeeTestSuite) TestPacketDataUnmarshalerInterface() {
15731573
feeModule, ok := cbs.(porttypes.PacketDataUnmarshaler)
15741574
suite.Require().True(ok)
15751575

1576-
packetData, err := feeModule.UnmarshalPacketData(ibcmock.MockPacketData)
1576+
// Context, port identifier, channel identifier are not used in current wiring of fee.
1577+
packetData, err := feeModule.UnmarshalPacketData(suite.chainA.GetContext(), "", "", ibcmock.MockPacketData)
15771578
suite.Require().NoError(err)
15781579
suite.Require().Equal(ibcmock.MockPacketData, packetData)
15791580
}
@@ -1582,7 +1583,8 @@ func (suite *FeeTestSuite) TestPacketDataUnmarshalerInterfaceError() {
15821583
// test the case when the underlying application cannot be casted to a PacketDataUnmarshaler
15831584
mockFeeMiddleware := ibcfee.NewIBCMiddleware(nil, feekeeper.Keeper{})
15841585

1585-
_, err := mockFeeMiddleware.UnmarshalPacketData(ibcmock.MockPacketData)
1586+
// Context, port identifier, channel identifier are not used in mockFeeMiddleware.
1587+
_, err := mockFeeMiddleware.UnmarshalPacketData(suite.chainA.GetContext(), "", "", ibcmock.MockPacketData)
15861588
expError := errorsmod.Wrapf(types.ErrUnsupportedAction, "underlying app does not implement %T", (*porttypes.PacketDataUnmarshaler)(nil))
15871589
suite.Require().ErrorIs(err, expError)
15881590
}

modules/apps/callbacks/callbacks_test.go

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111

1212
"cosmossdk.io/log"
1313
sdkmath "cosmossdk.io/math"
14+
storetypes "cosmossdk.io/store/types"
1415

1516
simtestutil "github.com/cosmos/cosmos-sdk/testutil/sims"
1617
sdk "github.com/cosmos/cosmos-sdk/types"
@@ -24,6 +25,7 @@ import (
2425
icatypes "github.com/cosmos/ibc-go/v8/modules/apps/27-interchain-accounts/types"
2526
feetypes "github.com/cosmos/ibc-go/v8/modules/apps/29-fee/types"
2627
transfertypes "github.com/cosmos/ibc-go/v8/modules/apps/transfer/types"
28+
clienttypes "github.com/cosmos/ibc-go/v8/modules/core/02-client/types"
2729
channeltypes "github.com/cosmos/ibc-go/v8/modules/core/04-channel/types"
2830
porttypes "github.com/cosmos/ibc-go/v8/modules/core/05-port/types"
2931
ibctesting "github.com/cosmos/ibc-go/v8/testing"
@@ -291,17 +293,24 @@ func (s *CallbacksTestSuite) AssertHasExecutedExpectedCallbackWithFee(
291293

292294
// GetExpectedEvent returns the expected event for a callback.
293295
func GetExpectedEvent(
294-
packetDataUnmarshaler porttypes.PacketDataUnmarshaler, remainingGas uint64, data []byte, srcPortID,
296+
ctx sdk.Context, packetDataUnmarshaler porttypes.PacketDataUnmarshaler, remainingGas uint64, data []byte, srcPortID,
295297
eventPortID, eventChannelID string, seq uint64, callbackType types.CallbackType, expError error,
296298
) (abci.Event, bool) {
297299
var (
298300
callbackData types.CallbackData
299301
err error
300302
)
303+
304+
// Set up gas meter with remainingGas.
305+
gasMeter := storetypes.NewGasMeter(remainingGas)
306+
ctx = ctx.WithGasMeter(gasMeter)
307+
308+
// Mock packet.
309+
packet := channeltypes.NewPacket(data, 0, srcPortID, "", "", "", clienttypes.ZeroHeight(), 0)
301310
if callbackType == types.CallbackTypeReceivePacket {
302-
callbackData, err = types.GetDestCallbackData(packetDataUnmarshaler, data, srcPortID, remainingGas, maxCallbackGas)
311+
callbackData, err = types.GetDestCallbackData(ctx, packetDataUnmarshaler, packet, maxCallbackGas)
303312
} else {
304-
callbackData, err = types.GetSourceCallbackData(packetDataUnmarshaler, data, srcPortID, remainingGas, maxCallbackGas)
313+
callbackData, err = types.GetSourceCallbackData(ctx, packetDataUnmarshaler, packet, maxCallbackGas)
305314
}
306315
if err != nil {
307316
return abci.Event{}, false

modules/apps/callbacks/ibc_middleware.go

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,10 @@ func (im IBCMiddleware) SendPacket(
9999
return 0, err
100100
}
101101

102-
callbackData, err := types.GetSourceCallbackData(im.app, data, sourcePort, ctx.GasMeter().GasRemaining(), im.maxCallbackGas)
102+
// packet is created withouth destination information present, GetSourceCallbackData does not use these.
103+
packet := channeltypes.NewPacket(data, seq, sourcePort, sourceChannel, "", "", timeoutHeight, timeoutTimestamp)
104+
105+
callbackData, err := types.GetSourceCallbackData(ctx, im.app, packet, im.maxCallbackGas)
103106
// SendPacket is not blocked if the packet does not opt-in to callbacks
104107
if err != nil {
105108
return seq, nil
@@ -138,7 +141,7 @@ func (im IBCMiddleware) OnAcknowledgementPacket(
138141
}
139142

140143
callbackData, err := types.GetSourceCallbackData(
141-
im.app, packet.GetData(), packet.GetSourcePort(), ctx.GasMeter().GasRemaining(), im.maxCallbackGas,
144+
ctx, im.app, packet, im.maxCallbackGas,
142145
)
143146
// OnAcknowledgementPacket is not blocked if the packet does not opt-in to callbacks
144147
if err != nil {
@@ -172,7 +175,7 @@ func (im IBCMiddleware) OnTimeoutPacket(ctx sdk.Context, packet channeltypes.Pac
172175
}
173176

174177
callbackData, err := types.GetSourceCallbackData(
175-
im.app, packet.GetData(), packet.GetSourcePort(), ctx.GasMeter().GasRemaining(), im.maxCallbackGas,
178+
ctx, im.app, packet, im.maxCallbackGas,
176179
)
177180
// OnTimeoutPacket is not blocked if the packet does not opt-in to callbacks
178181
if err != nil {
@@ -208,7 +211,7 @@ func (im IBCMiddleware) OnRecvPacket(ctx sdk.Context, packet channeltypes.Packet
208211
}
209212

210213
callbackData, err := types.GetDestCallbackData(
211-
im.app, packet.GetData(), packet.GetSourcePort(), ctx.GasMeter().GasRemaining(), im.maxCallbackGas,
214+
ctx, im.app, packet, im.maxCallbackGas,
212215
)
213216
// OnRecvPacket is not blocked if the packet does not opt-in to callbacks
214217
if err != nil {
@@ -245,8 +248,13 @@ func (im IBCMiddleware) WriteAcknowledgement(
245248
return err
246249
}
247250

251+
chanPacket, ok := packet.(channeltypes.Packet)
252+
if !ok {
253+
panic(fmt.Errorf("expected type %T, got %T", &channeltypes.Packet{}, packet))
254+
}
255+
248256
callbackData, err := types.GetDestCallbackData(
249-
im.app, packet.GetData(), packet.GetSourcePort(), ctx.GasMeter().GasRemaining(), im.maxCallbackGas,
257+
ctx, im.app, chanPacket, im.maxCallbackGas,
250258
)
251259
// WriteAcknowledgement is not blocked if the packet does not opt-in to callbacks
252260
if err != nil {
@@ -417,6 +425,6 @@ func (im IBCMiddleware) GetAppVersion(ctx sdk.Context, portID, channelID string)
417425

418426
// UnmarshalPacketData defers to the underlying app to unmarshal the packet data.
419427
// This function implements the optional PacketDataUnmarshaler interface.
420-
func (im IBCMiddleware) UnmarshalPacketData(bz []byte) (interface{}, error) {
421-
return im.app.UnmarshalPacketData(bz)
428+
func (im IBCMiddleware) UnmarshalPacketData(ctx sdk.Context, portID, channelID string, bz []byte) (interface{}, error) {
429+
return im.app.UnmarshalPacketData(ctx, portID, channelID, bz)
422430
}

modules/apps/callbacks/ibc_middleware_test.go

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ func (s *CallbacksTestSuite) TestSendPacket() {
200200
s.Require().Equal(uint64(1), seq)
201201

202202
expEvent, exists := GetExpectedEvent(
203-
transferICS4Wrapper.(porttypes.PacketDataUnmarshaler), gasLimit, packetData.GetBytes(), s.path.EndpointA.ChannelConfig.PortID,
203+
ctx, transferICS4Wrapper.(porttypes.PacketDataUnmarshaler), gasLimit, packetData.GetBytes(), s.path.EndpointA.ChannelConfig.PortID,
204204
s.path.EndpointA.ChannelConfig.PortID, s.path.EndpointA.ChannelID, seq, types.CallbackTypeSendPacket, nil,
205205
)
206206
if exists {
@@ -381,7 +381,7 @@ func (s *CallbacksTestSuite) TestOnAcknowledgementPacket() {
381381
s.Require().Equal(uint8(1), sourceStatefulCounter)
382382

383383
expEvent, exists := GetExpectedEvent(
384-
transferStack.(porttypes.PacketDataUnmarshaler), gasLimit, packet.Data, packet.SourcePort,
384+
ctx, transferStack.(porttypes.PacketDataUnmarshaler), gasLimit, packet.Data, packet.SourcePort,
385385
packet.SourcePort, packet.SourceChannel, packet.Sequence, types.CallbackTypeAcknowledgementPacket, nil,
386386
)
387387
s.Require().True(exists)
@@ -543,7 +543,7 @@ func (s *CallbacksTestSuite) TestOnTimeoutPacket() {
543543
s.Require().Equal(uint8(2), sourceStatefulCounter)
544544

545545
expEvent, exists := GetExpectedEvent(
546-
transferStack.(porttypes.PacketDataUnmarshaler), gasLimit, packet.Data, packet.SourcePort,
546+
ctx, transferStack.(porttypes.PacketDataUnmarshaler), gasLimit, packet.Data, packet.SourcePort,
547547
packet.SourcePort, packet.SourceChannel, packet.Sequence, types.CallbackTypeTimeoutPacket, nil,
548548
)
549549
s.Require().True(exists)
@@ -712,7 +712,7 @@ func (s *CallbacksTestSuite) TestOnRecvPacket() {
712712
s.Require().Equal(uint8(1), destStatefulCounter)
713713

714714
expEvent, exists := GetExpectedEvent(
715-
transferStack.(porttypes.PacketDataUnmarshaler), gasLimit, packet.Data, packet.SourcePort,
715+
ctx, transferStack.(porttypes.PacketDataUnmarshaler), gasLimit, packet.Data, packet.SourcePort,
716716
packet.DestinationPort, packet.DestinationChannel, packet.Sequence, types.CallbackTypeReceivePacket, nil,
717717
)
718718
s.Require().True(exists)
@@ -814,7 +814,7 @@ func (s *CallbacksTestSuite) TestWriteAcknowledgement() {
814814
s.Require().NoError(err)
815815

816816
expEvent, exists := GetExpectedEvent(
817-
transferICS4Wrapper.(porttypes.PacketDataUnmarshaler), gasLimit, packet.Data, packet.SourcePort,
817+
ctx, transferICS4Wrapper.(porttypes.PacketDataUnmarshaler), gasLimit, packet.Data, packet.SourcePort,
818818
packet.DestinationPort, packet.DestinationChannel, packet.Sequence, types.CallbackTypeReceivePacket, nil,
819819
)
820820
if exists {
@@ -1003,15 +1003,18 @@ func (s *CallbacksTestSuite) TestUnmarshalPacketData() {
10031003
Memo: fmt.Sprintf(`{"src_callback": {"address": "%s"}, "dest_callback": {"address":"%s"}}`, ibctesting.TestAccAddress, ibctesting.TestAccAddress),
10041004
}
10051005

1006+
portID := s.path.EndpointA.ChannelConfig.PortID
1007+
channelID := s.path.EndpointA.ChannelID
1008+
10061009
// Unmarshal ICS20 v1 packet data
10071010
data := expPacketDataICS20V1.GetBytes()
1008-
packetData, err := unmarshalerStack.UnmarshalPacketData(data)
1011+
packetData, err := unmarshalerStack.UnmarshalPacketData(s.chainA.GetContext(), portID, channelID, data)
10091012
s.Require().NoError(err)
10101013
s.Require().Equal(expPacketDataICS20V2, packetData)
10111014

10121015
// Unmarshal ICS20 v1 packet data
10131016
data = expPacketDataICS20V2.GetBytes()
1014-
packetData, err = unmarshalerStack.UnmarshalPacketData(data)
1017+
packetData, err = unmarshalerStack.UnmarshalPacketData(s.chainA.GetContext(), portID, channelID, data)
10151018
s.Require().NoError(err)
10161019
s.Require().Equal(expPacketDataICS20V2, packetData)
10171020
}

modules/apps/callbacks/types/callbacks.go

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ import (
66

77
errorsmod "cosmossdk.io/errors"
88

9+
sdk "github.com/cosmos/cosmos-sdk/types"
10+
11+
channeltypes "github.com/cosmos/ibc-go/v8/modules/core/04-channel/types"
912
porttypes "github.com/cosmos/ibc-go/v8/modules/core/05-port/types"
1013
ibcexported "github.com/cosmos/ibc-go/v8/modules/core/exported"
1114
)
@@ -67,35 +70,43 @@ type CallbackData struct {
6770

6871
// GetSourceCallbackData parses the packet data and returns the source callback data.
6972
func GetSourceCallbackData(
73+
ctx sdk.Context,
7074
packetDataUnmarshaler porttypes.PacketDataUnmarshaler,
71-
data []byte, srcPortID string, remainingGas uint64, maxGas uint64,
75+
packet channeltypes.Packet,
76+
maxGas uint64,
7277
) (CallbackData, error) {
73-
return getCallbackData(packetDataUnmarshaler, data, srcPortID, remainingGas, maxGas, SourceCallbackKey)
78+
packetData, err := packetDataUnmarshaler.UnmarshalPacketData(ctx, packet.GetSourcePort(), packet.GetSourceChannel(), packet.GetData())
79+
if err != nil {
80+
return CallbackData{}, errorsmod.Wrap(ErrCannotUnmarshalPacketData, err.Error())
81+
}
82+
83+
return getCallbackData(packetData, packet.GetSourcePort(), ctx.GasMeter().GasRemaining(), maxGas, SourceCallbackKey)
7484
}
7585

7686
// GetDestCallbackData parses the packet data and returns the destination callback data.
7787
func GetDestCallbackData(
88+
ctx sdk.Context,
7889
packetDataUnmarshaler porttypes.PacketDataUnmarshaler,
79-
data []byte, srcPortID string, remainingGas, maxGas uint64,
90+
packet channeltypes.Packet, maxGas uint64,
8091
) (CallbackData, error) {
81-
return getCallbackData(packetDataUnmarshaler, data, srcPortID, remainingGas, maxGas, DestinationCallbackKey)
92+
packetData, err := packetDataUnmarshaler.UnmarshalPacketData(ctx, packet.GetDestPort(), packet.GetDestChannel(), packet.GetData())
93+
if err != nil {
94+
return CallbackData{}, errorsmod.Wrap(ErrCannotUnmarshalPacketData, err.Error())
95+
}
96+
97+
return getCallbackData(packetData, packet.GetSourcePort(), ctx.GasMeter().GasRemaining(), maxGas, DestinationCallbackKey)
8298
}
8399

84100
// getCallbackData parses the packet data and returns the callback data.
85101
// It also checks that the remaining gas is greater than the gas limit specified in the packet data.
86102
// The addressGetter and gasLimitGetter functions are used to retrieve the callback
87103
// address and gas limit from the callback data.
88104
func getCallbackData(
89-
packetDataUnmarshaler porttypes.PacketDataUnmarshaler,
90-
data []byte, srcPortID string, remainingGas,
105+
packetData interface{},
106+
srcPortID string,
107+
remainingGas,
91108
maxGas uint64, callbackKey string,
92109
) (CallbackData, error) {
93-
// unmarshal packet data
94-
packetData, err := packetDataUnmarshaler.UnmarshalPacketData(data)
95-
if err != nil {
96-
return CallbackData{}, errorsmod.Wrap(ErrCannotUnmarshalPacketData, err.Error())
97-
}
98-
99110
packetDataProvider, ok := packetData.(ibcexported.PacketDataProvider)
100111
if !ok {
101112
return CallbackData{}, ErrNotPacketDataProvider

0 commit comments

Comments
 (0)