Skip to content

Commit 4dc9469

Browse files
feat(stf/router): support backwards compat type URL in router (#21177)
1 parent 90fd632 commit 4dc9469

File tree

3 files changed

+160
-18
lines changed

3 files changed

+160
-18
lines changed

server/v2/stf/stf.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ var Identity = []byte("stf")
2626
type STF[T transaction.Tx] struct {
2727
logger log.Logger
2828

29-
msgRouter Router
30-
queryRouter Router
29+
msgRouter coreRouterImpl
30+
queryRouter coreRouterImpl
3131

3232
doPreBlock func(ctx context.Context, txs []T) error
3333
doBeginBlock func(ctx context.Context) error
@@ -584,8 +584,8 @@ func newExecutionContext(
584584
sender transaction.Identity,
585585
state store.WriterMap,
586586
execMode transaction.ExecMode,
587-
msgRouter Router,
588-
queryRouter Router,
587+
msgRouter coreRouterImpl,
588+
queryRouter coreRouterImpl,
589589
) *executionContext {
590590
meter := makeGasMeterFn(gas.NoGasLimit)
591591
meteredState := makeGasMeteredStoreFn(meter, state)

server/v2/stf/stf_router.go

Lines changed: 49 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"errors"
66
"fmt"
77
"reflect"
8+
"strings"
89

910
gogoproto "github.com/cosmos/gogoproto/proto"
1011

@@ -61,7 +62,7 @@ func (b *MsgRouterBuilder) HandlerExists(msgType string) bool {
6162
return ok
6263
}
6364

64-
func (b *MsgRouterBuilder) Build() (Router, error) {
65+
func (b *MsgRouterBuilder) Build() (coreRouterImpl, error) {
6566
handlers := make(map[string]appmodulev2.Handler)
6667

6768
globalPreHandler := func(ctx context.Context, msg appmodulev2.Message) error {
@@ -93,7 +94,7 @@ func (b *MsgRouterBuilder) Build() (Router, error) {
9394
handlers[msgType] = buildHandler(handler, preHandlers, globalPreHandler, postHandlers, globalPostHandler)
9495
}
9596

96-
return Router{
97+
return coreRouterImpl{
9798
handlers: handlers,
9899
}, nil
99100
}
@@ -139,39 +140,73 @@ func msgTypeURL(msg gogoproto.Message) string {
139140
return gogoproto.MessageName(msg)
140141
}
141142

142-
var _ router.Service = (*Router)(nil)
143+
var _ router.Service = (*coreRouterImpl)(nil)
143144

144-
// Router implements the STF router for msg and query handlers.
145-
type Router struct {
145+
// coreRouterImpl implements the STF router for msg and query handlers.
146+
type coreRouterImpl struct {
146147
handlers map[string]appmodulev2.Handler
147148
}
148149

149-
func (r Router) CanInvoke(_ context.Context, typeURL string) error {
150+
func (r coreRouterImpl) CanInvoke(_ context.Context, typeURL string) error {
151+
// trimming prefixes is a backwards compatibility strategy that we use
152+
// for baseapp components that did routing through type URL rather
153+
// than protobuf message names.
154+
typeURL = strings.TrimPrefix(typeURL, "/")
150155
_, exists := r.handlers[typeURL]
151156
if !exists {
152157
return fmt.Errorf("%w: %s", ErrNoHandler, typeURL)
153158
}
154159
return nil
155160
}
156161

157-
func (r Router) InvokeTyped(ctx context.Context, req, resp gogoproto.Message) error {
162+
func (r coreRouterImpl) InvokeTyped(ctx context.Context, req, resp gogoproto.Message) error {
158163
handlerResp, err := r.InvokeUntyped(ctx, req)
159164
if err != nil {
160165
return err
161166
}
162-
merge(handlerResp, resp)
163-
return nil
164-
}
165-
166-
func merge(src, dst gogoproto.Message) {
167-
reflect.Indirect(reflect.ValueOf(dst)).Set(reflect.Indirect(reflect.ValueOf(src)))
167+
return merge(handlerResp, resp)
168168
}
169169

170-
func (r Router) InvokeUntyped(ctx context.Context, req gogoproto.Message) (res gogoproto.Message, err error) {
170+
func (r coreRouterImpl) InvokeUntyped(ctx context.Context, req gogoproto.Message) (res gogoproto.Message, err error) {
171171
typeName := msgTypeURL(req)
172172
handler, exists := r.handlers[typeName]
173173
if !exists {
174174
return nil, fmt.Errorf("%w: %s", ErrNoHandler, typeName)
175175
}
176176
return handler(ctx, req)
177177
}
178+
179+
// merge merges together two protobuf messages by setting the pointer
180+
// to src in dst. Used internally.
181+
func merge(src, dst gogoproto.Message) error {
182+
if src == nil {
183+
return fmt.Errorf("source message is nil")
184+
}
185+
if dst == nil {
186+
return fmt.Errorf("destination message is nil")
187+
}
188+
189+
srcVal := reflect.ValueOf(src)
190+
dstVal := reflect.ValueOf(dst)
191+
192+
if srcVal.Kind() == reflect.Interface {
193+
srcVal = srcVal.Elem()
194+
}
195+
if dstVal.Kind() == reflect.Interface {
196+
dstVal = dstVal.Elem()
197+
}
198+
199+
if srcVal.Kind() != reflect.Ptr || dstVal.Kind() != reflect.Ptr {
200+
return fmt.Errorf("both source and destination must be pointers")
201+
}
202+
203+
srcElem := srcVal.Elem()
204+
dstElem := dstVal.Elem()
205+
206+
if !srcElem.Type().AssignableTo(dstElem.Type()) {
207+
return fmt.Errorf("incompatible types: cannot merge %v into %v", srcElem.Type(), dstElem.Type())
208+
}
209+
210+
dstElem.Set(srcElem)
211+
return nil
212+
}

server/v2/stf/stf_router_test.go

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
package stf
2+
3+
import (
4+
"context"
5+
"testing"
6+
7+
gogoproto "github.com/cosmos/gogoproto/proto"
8+
gogotypes "github.com/cosmos/gogoproto/types"
9+
"github.com/stretchr/testify/assert"
10+
"github.com/stretchr/testify/require"
11+
12+
"cosmossdk.io/core/appmodule/v2"
13+
)
14+
15+
func TestRouter(t *testing.T) {
16+
expectedMsg := &gogotypes.BoolValue{Value: true}
17+
expectedMsgName := gogoproto.MessageName(expectedMsg)
18+
19+
expectedResp := &gogotypes.StringValue{Value: "test"}
20+
21+
router := coreRouterImpl{handlers: map[string]appmodule.Handler{
22+
gogoproto.MessageName(expectedMsg): func(ctx context.Context, gotMsg appmodule.Message) (msgResp appmodule.Message, err error) {
23+
require.Equal(t, expectedMsg, gotMsg)
24+
return expectedResp, nil
25+
},
26+
}}
27+
28+
t.Run("can invoke message by name", func(t *testing.T) {
29+
err := router.CanInvoke(context.Background(), expectedMsgName)
30+
require.NoError(t, err, "must be invokable")
31+
})
32+
33+
t.Run("can invoke message by type URL", func(t *testing.T) {
34+
err := router.CanInvoke(context.Background(), "/"+expectedMsgName)
35+
require.NoError(t, err)
36+
})
37+
38+
t.Run("cannot invoke unknown message", func(t *testing.T) {
39+
err := router.CanInvoke(context.Background(), "not exist")
40+
require.Error(t, err)
41+
})
42+
43+
t.Run("invoke untyped", func(t *testing.T) {
44+
gotResp, err := router.InvokeUntyped(context.Background(), expectedMsg)
45+
require.NoError(t, err)
46+
require.Equal(t, expectedResp, gotResp)
47+
})
48+
49+
t.Run("invoked typed", func(t *testing.T) {
50+
gotResp := new(gogotypes.StringValue)
51+
err := router.InvokeTyped(context.Background(), expectedMsg, gotResp)
52+
require.NoError(t, err)
53+
require.Equal(t, expectedResp, gotResp)
54+
})
55+
}
56+
57+
func TestMerge(t *testing.T) {
58+
tests := []struct {
59+
name string
60+
src gogoproto.Message
61+
dst gogoproto.Message
62+
expected gogoproto.Message
63+
wantErr bool
64+
}{
65+
{
66+
name: "success",
67+
src: &gogotypes.BoolValue{Value: true},
68+
dst: &gogotypes.BoolValue{},
69+
expected: &gogotypes.BoolValue{Value: true},
70+
wantErr: false,
71+
},
72+
{
73+
name: "nil src",
74+
src: nil,
75+
dst: &gogotypes.StringValue{},
76+
expected: &gogotypes.StringValue{},
77+
wantErr: true,
78+
},
79+
{
80+
name: "nil dst",
81+
src: &gogotypes.StringValue{Value: "hello"},
82+
dst: nil,
83+
expected: nil,
84+
wantErr: true,
85+
},
86+
{
87+
name: "incompatible types",
88+
src: &gogotypes.StringValue{Value: "hello"},
89+
dst: &gogotypes.BoolValue{},
90+
expected: &gogotypes.BoolValue{},
91+
wantErr: true,
92+
},
93+
}
94+
95+
for _, tt := range tests {
96+
t.Run(tt.name, func(t *testing.T) {
97+
err := merge(tt.src, tt.dst)
98+
99+
if tt.wantErr {
100+
assert.Error(t, err)
101+
} else {
102+
assert.NoError(t, err)
103+
assert.Equal(t, tt.expected, tt.dst)
104+
}
105+
})
106+
}
107+
}

0 commit comments

Comments
 (0)