Skip to content

Commit 147cf17

Browse files
chore: implement version checking for channel handshake application callbacks (#6175)
* add SupportedVersions array for different ics20 versions, add version checking on channel handshake application callbacks * add tests * update pr review * pr review * last few pr review nits * linter * add version counter proposing * fix missing app versino * update code + tests to return our proposed version if counterparty version is invalid * remove if statement * address review comments: return ics20-2 if counterparty version is not supported --------- Co-authored-by: Carlos Rodriguez <carlos@interchain.io>
1 parent ca056cf commit 147cf17

File tree

3 files changed

+64
-46
lines changed

3 files changed

+64
-46
lines changed

modules/apps/transfer/ibc_module.go

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"encoding/json"
55
"fmt"
66
"math"
7+
"slices"
78
"strings"
89

910
errorsmod "cosmossdk.io/errors"
@@ -87,12 +88,13 @@ func (im IBCModule) OnChanOpenInit(
8788
return "", err
8889
}
8990

91+
// default to latest supported version
9092
if strings.TrimSpace(version) == "" {
9193
version = types.Version
9294
}
9395

94-
if version != types.Version {
95-
return "", errorsmod.Wrapf(types.ErrInvalidVersion, "expected %s, got %s", types.Version, version)
96+
if !slices.Contains(types.SupportedVersions, version) {
97+
return "", errorsmod.Wrapf(types.ErrInvalidVersion, "expected one of %s, got %s", types.SupportedVersions, version)
9698
}
9799

98100
// Claim channel capability passed back by IBC module
@@ -118,16 +120,16 @@ func (im IBCModule) OnChanOpenTry(
118120
return "", err
119121
}
120122

121-
if counterpartyVersion != types.Version {
122-
return "", errorsmod.Wrapf(types.ErrInvalidVersion, "invalid counterparty version: expected %s, got %s", types.Version, counterpartyVersion)
123+
if !slices.Contains(types.SupportedVersions, counterpartyVersion) {
124+
return types.Version, nil
123125
}
124126

125127
// OpenTry must claim the channelCapability that IBC passes into the callback
126128
if err := im.keeper.ClaimCapability(ctx, chanCap, host.ChannelCapabilityPath(portID, channelID)); err != nil {
127129
return "", err
128130
}
129131

130-
return types.Version, nil
132+
return counterpartyVersion, nil
131133
}
132134

133135
// OnChanOpenAck implements the IBCModule interface
@@ -138,9 +140,10 @@ func (IBCModule) OnChanOpenAck(
138140
_ string,
139141
counterpartyVersion string,
140142
) error {
141-
if counterpartyVersion != types.Version {
142-
return errorsmod.Wrapf(types.ErrInvalidVersion, "invalid counterparty version: expected %s, got %s", types.Version, counterpartyVersion)
143+
if !slices.Contains(types.SupportedVersions, counterpartyVersion) {
144+
return errorsmod.Wrapf(types.ErrInvalidVersion, "invalid counterparty version: expected one of %s, got %s", types.SupportedVersions, counterpartyVersion)
143145
}
146+
144147
return nil
145148
}
146149

@@ -338,8 +341,8 @@ func (im IBCModule) OnChanUpgradeInit(ctx sdk.Context, portID, channelID string,
338341
return "", err
339342
}
340343

341-
if proposedVersion != types.Version {
342-
return "", errorsmod.Wrapf(types.ErrInvalidVersion, "expected %s, got %s", types.Version, proposedVersion)
344+
if !slices.Contains(types.SupportedVersions, proposedVersion) {
345+
return "", errorsmod.Wrapf(types.ErrInvalidVersion, "invalid counterparty version: expected one of %s, got %s", types.SupportedVersions, proposedVersion)
343346
}
344347

345348
return proposedVersion, nil
@@ -351,17 +354,17 @@ func (im IBCModule) OnChanUpgradeTry(ctx sdk.Context, portID, channelID string,
351354
return "", err
352355
}
353356

354-
if counterpartyVersion != types.Version {
355-
return "", errorsmod.Wrapf(types.ErrInvalidVersion, "expected %s, got %s", types.Version, counterpartyVersion)
357+
if !slices.Contains(types.SupportedVersions, counterpartyVersion) {
358+
return types.Version, nil
356359
}
357360

358361
return counterpartyVersion, nil
359362
}
360363

361364
// OnChanUpgradeAck implements the IBCModule interface
362365
func (IBCModule) OnChanUpgradeAck(ctx sdk.Context, portID, channelID, counterpartyVersion string) error {
363-
if counterpartyVersion != types.Version {
364-
return errorsmod.Wrapf(types.ErrInvalidVersion, "expected %s, got %s", types.Version, counterpartyVersion)
366+
if !slices.Contains(types.SupportedVersions, counterpartyVersion) {
367+
return errorsmod.Wrapf(types.ErrInvalidVersion, "invalid counterparty version: expected one of %s, got %s", types.SupportedVersions, counterpartyVersion)
365368
}
366369

367370
return nil

modules/apps/transfer/ibc_module_test.go

Lines changed: 45 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -26,50 +26,56 @@ func (suite *TransferTestSuite) TestOnChanOpenInit() {
2626
)
2727

2828
testCases := []struct {
29-
name string
30-
malleate func()
31-
expPass bool
29+
name string
30+
malleate func()
31+
expPass bool
32+
expVersion string
3233
}{
3334
{
34-
"success", func() {}, true,
35+
"success", func() {}, true, types.Version,
3536
},
3637
{
3738
// connection hops is not used in the transfer application callback,
3839
// it is already validated in the core OnChanUpgradeInit.
3940
"success: invalid connection hops", func() {
4041
path.EndpointA.ConnectionID = "invalid-connection-id"
41-
}, true,
42+
}, true, types.Version,
4243
},
4344
{
44-
"empty version string", func() {
45+
"success: empty version string", func() {
4546
channel.Version = ""
46-
}, true,
47+
}, true, types.Version,
48+
},
49+
{
50+
"success: ics20-1 legacy", func() {
51+
channel.Version = types.Version1
52+
}, true, types.Version1,
4753
},
4854
{
4955
"max channels reached", func() {
5056
path.EndpointA.ChannelID = channeltypes.FormatChannelIdentifier(math.MaxUint32 + 1)
51-
}, false,
57+
}, false, "",
5258
},
5359
{
5460
"invalid order - ORDERED", func() {
5561
channel.Ordering = channeltypes.ORDERED
56-
}, false,
62+
}, false, "",
5763
},
5864
{
5965
"invalid port ID", func() {
6066
path.EndpointA.ChannelConfig.PortID = ibctesting.MockPort
61-
}, false,
67+
}, false, "",
6268
},
6369
{
6470
"invalid version", func() {
6571
channel.Version = "version" //nolint:goconst
66-
}, false,
72+
}, false, "",
6773
},
6874
{
6975
"capability already claimed", func() {
7076
err := suite.chainA.GetSimApp().ScopedTransferKeeper.ClaimCapability(suite.chainA.GetContext(), chanCap, host.ChannelCapabilityPath(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID))
7177
suite.Require().NoError(err)
72-
}, false,
78+
}, false, "",
7379
},
7480
}
7581

@@ -104,7 +110,7 @@ func (suite *TransferTestSuite) TestOnChanOpenInit() {
104110

105111
if tc.expPass {
106112
suite.Require().NoError(err)
107-
suite.Require().Equal(types.Version, version)
113+
suite.Require().Equal(tc.expVersion, version)
108114
} else {
109115
suite.Require().Error(err)
110116
suite.Require().Equal(version, "")
@@ -123,38 +129,44 @@ func (suite *TransferTestSuite) TestOnChanOpenTry() {
123129
)
124130

125131
testCases := []struct {
126-
name string
127-
malleate func()
128-
expPass bool
132+
name string
133+
malleate func()
134+
expPass bool
135+
expVersion string
129136
}{
130137
{
131-
"success", func() {}, true,
138+
"success", func() {}, true, types.Version,
139+
},
140+
{
141+
"success: counterparty version is legacy ics20-1", func() {
142+
counterpartyVersion = types.Version1
143+
}, true, types.Version1,
144+
},
145+
{
146+
"success: invalid counterparty version, we use our proposed version", func() {
147+
counterpartyVersion = "version"
148+
}, true, types.Version,
132149
},
133150
{
134151
"max channels reached", func() {
135152
path.EndpointA.ChannelID = channeltypes.FormatChannelIdentifier(math.MaxUint32 + 1)
136-
}, false,
153+
}, false, "",
137154
},
138155
{
139156
"capability already claimed", func() {
140157
err := suite.chainA.GetSimApp().ScopedTransferKeeper.ClaimCapability(suite.chainA.GetContext(), chanCap, host.ChannelCapabilityPath(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID))
141158
suite.Require().NoError(err)
142-
}, false,
159+
}, false, "",
143160
},
144161
{
145162
"invalid order - ORDERED", func() {
146163
channel.Ordering = channeltypes.ORDERED
147-
}, false,
164+
}, false, "",
148165
},
149166
{
150167
"invalid port ID", func() {
151168
path.EndpointA.ChannelConfig.PortID = ibctesting.MockPort
152-
}, false,
153-
},
154-
{
155-
"invalid counterparty version", func() {
156-
counterpartyVersion = "version"
157-
}, false,
169+
}, false, "",
158170
},
159171
}
160172

@@ -195,7 +207,7 @@ func (suite *TransferTestSuite) TestOnChanOpenTry() {
195207

196208
if tc.expPass {
197209
suite.Require().NoError(err)
198-
suite.Require().Equal(types.Version, version)
210+
suite.Require().Equal(tc.expVersion, version)
199211
} else {
200212
suite.Require().Error(err)
201213
suite.Require().Equal("", version)
@@ -340,18 +352,18 @@ func (suite *TransferTestSuite) TestOnChanUpgradeTry() {
340352
nil,
341353
},
342354
{
343-
"invalid upgrade ordering",
355+
"success: invalid upgrade version from counterparty, we use our proposed version",
344356
func() {
345-
counterpartyUpgrade.Fields.Ordering = channeltypes.ORDERED
357+
counterpartyUpgrade.Fields.Version = ibctesting.InvalidID
346358
},
347-
channeltypes.ErrInvalidChannelOrdering,
359+
nil,
348360
},
349361
{
350-
"invalid upgrade version",
362+
"invalid upgrade ordering",
351363
func() {
352-
counterpartyUpgrade.Fields.Version = ibctesting.InvalidID
364+
counterpartyUpgrade.Fields.Ordering = channeltypes.ORDERED
353365
},
354-
types.ErrInvalidVersion,
366+
channeltypes.ErrInvalidChannelOrdering,
355367
},
356368
}
357369

modules/apps/transfer/types/keys.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ var (
5151
PortKey = []byte{0x01}
5252
// DenomTraceKey defines the key to store the denomination trace info in store
5353
DenomTraceKey = []byte{0x02}
54+
55+
// SupportedVersions defines all versions that are supported by the module
56+
SupportedVersions = []string{Version, Version1}
5457
)
5558

5659
// GetEscrowAddress returns the escrow address for the specified channel.

0 commit comments

Comments
 (0)