Skip to content

Commit cbdce64

Browse files
mjolivercopybara-github
authored andcommitted
Refactor Windows service helpers for testability and improve RestartServiceWithVerify.
Introduced mockable variables for Windows service management functions and time.Sleep in helpers_windows.go. Updated all service-related functions to use these new variables. Improved RestartServiceWithVerify to more accurately handle service state transitions by querying the actual service state immediately after restart and looping until the service is in the svc.Running state. Added TestRestartServiceWithVerify in helpers_windows_test.go to test various service restart scenarios, including pending states and timeouts. PiperOrigin-RevId: 903171985
1 parent 43fb0ae commit cbdce64

2 files changed

Lines changed: 176 additions & 51 deletions

File tree

go/helpers/helpers_windows.go

Lines changed: 64 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,19 @@ var (
3939
prodSetWindowPos = moduser32.NewProc("SetWindowPos")
4040
prodGetConsoleWindow = modkernel32.NewProc("GetConsoleWindow")
4141
prodGetParent = moduser32.NewProc("GetParent")
42-
)
4342

44-
var (
45-
// Test helpers
46-
fnProcessList = winapi.ProcessList
43+
// for testing
44+
mgrConnect = mgr.Connect
45+
mgrDisconnect = func(m *mgr.Mgr) error { return m.Disconnect() }
46+
mgrOpenService = func(m *mgr.Mgr, name string) (*mgr.Service, error) { return m.OpenService(name) }
47+
svcClose = func(s *mgr.Service) error { return s.Close() }
48+
svcConfig = func(s *mgr.Service) (mgr.Config, error) { return s.Config() }
49+
svcQuery = func(s *mgr.Service) (svc.Status, error) { return s.Query() }
50+
svcUpdateConfig = func(s *mgr.Service, c mgr.Config) error { return s.UpdateConfig(c) }
51+
svcStart = func(s *mgr.Service) error { return s.Start() }
52+
svcControl = func(s *mgr.Service, c svc.Cmd) (svc.Status, error) { return s.Control(c) }
53+
timeSleep = time.Sleep
54+
fnProcessList = winapi.ProcessList
4755
)
4856

4957
const (
@@ -57,41 +65,41 @@ const (
5765

5866
// GetServiceState interrogates local system services and returns their status and configuration.
5967
func GetServiceState(name string) (svc.Status, mgr.Config, error) {
60-
m, err := mgr.Connect()
68+
m, err := mgrConnect()
6169
if err != nil {
6270
return svc.Status{}, mgr.Config{}, err
6371
}
64-
defer m.Disconnect()
65-
s, err := m.OpenService(name)
72+
defer mgrDisconnect(m)
73+
s, err := mgrOpenService(m, name)
6674
if err != nil {
6775
return svc.Status{}, mgr.Config{}, fmt.Errorf("could not access service: %v", err)
6876
}
69-
defer s.Close()
77+
defer svcClose(s)
7078

71-
config, err := s.Config()
79+
config, err := svcConfig(s)
7280
if err != nil {
7381
return svc.Status{}, mgr.Config{}, err
7482
}
75-
status, err := s.Query()
83+
status, err := svcQuery(s)
7684
return status, config, err
7785
}
7886

7987
// ChangeService can change a services type or/and startup behaviour
8088
// https://docs.microsoft.com/en-us/dotnet/api/system.serviceprocess.servicestartmode
8189
// https://docs.microsoft.com/en-us/dotnet/api/system.serviceprocess.servicetype
8290
func ChangeService(name string, c mgr.Config) error {
83-
m, err := mgr.Connect()
91+
m, err := mgrConnect()
8492
if err != nil {
8593
return err
8694
}
87-
defer m.Disconnect()
88-
s, err := m.OpenService(name)
95+
defer mgrDisconnect(m)
96+
s, err := mgrOpenService(m, name)
8997
if err != nil {
9098
return fmt.Errorf("could not access service: %v", err)
9199
}
92-
defer s.Close()
100+
defer svcClose(s)
93101

94-
return s.UpdateConfig(c)
102+
return svcUpdateConfig(s, c)
95103
}
96104

97105
const (
@@ -112,25 +120,25 @@ func GetSysEnv(key string) (string, error) {
112120

113121
// RestartService attempts to restart local system services.
114122
func RestartService(name string) error {
115-
m, err := mgr.Connect()
123+
m, err := mgrConnect()
116124
if err != nil {
117125
return err
118126
}
119-
defer m.Disconnect()
120-
s, err := m.OpenService(name)
127+
defer mgrDisconnect(m)
128+
s, err := mgrOpenService(m, name)
121129
if err != nil {
122130
return err
123131
}
124-
defer s.Close()
132+
defer svcClose(s)
125133

126-
if err := stopService(s); err != nil {
134+
if err := stopService(s, name); err != nil {
127135
return err
128136
}
129137

130-
return s.Start()
138+
return svcStart(s)
131139
}
132140

133-
// RestartServiceWithVerify attempts to restart local system services and verifies the service is running with a 60 second timeout.
141+
// RestartServiceWithVerify attempts to restart local system services and verifies the service is running with a timeout.
134142
func RestartServiceWithVerify(name string, retryCount ...int) error {
135143
retryAttempts := 12
136144
if len(retryCount) > 0 {
@@ -139,24 +147,29 @@ func RestartServiceWithVerify(name string, retryCount ...int) error {
139147
if err := RestartService(name); err != nil {
140148
return err
141149
}
142-
status := svc.Status{
143-
State: svc.StartPending, // Assume the service is starting
150+
151+
// Check the actual state immediately, rather than faking it
152+
status, _, err := GetServiceState(name)
153+
if err != nil {
154+
return err
144155
}
145-
for retry := 0; status.State == svc.StartPending; retry++ {
156+
157+
// Loop as long as the service is NOT running
158+
for retry := 0; status.State != svc.Running; retry++ {
159+
if retry == retryAttempts {
160+
return fmt.Errorf("timed out waiting for service %q to start", name)
161+
}
162+
146163
deck.Infof("Waiting for service %q to start, sleeping for 5 seconds", name)
147-
time.Sleep(5 * time.Second)
148-
var err error
164+
timeSleep(5 * time.Second)
165+
149166
status, _, err = GetServiceState(name)
150167
if err != nil {
151168
return err
152169
}
153-
if retry == retryAttempts {
154-
return fmt.Errorf("timed out waiting for service %q to start", name)
155-
}
156-
}
157-
if status.State != svc.Running {
158-
return fmt.Errorf("service %q is not running after restart, current state: %v", name, status.State)
159170
}
171+
172+
// If the loop exits normally, we know status.State == svc.Running
160173
return nil
161174
}
162175

@@ -181,18 +194,18 @@ func SetSysEnv(key, value string) error {
181194

182195
// StartService attempts to start local system services.
183196
func StartService(name string) error {
184-
m, err := mgr.Connect()
197+
m, err := mgrConnect()
185198
if err != nil {
186199
return err
187200
}
188-
defer m.Disconnect()
189-
s, err := m.OpenService(name)
201+
defer mgrDisconnect(m)
202+
s, err := mgrOpenService(m, name)
190203
if err != nil {
191204
return err
192205
}
193-
defer s.Close()
206+
defer svcClose(s)
194207

195-
return s.Start()
208+
return svcStart(s)
196209
}
197210

198211
// StartServiceWithVerify attempts to start local system services and verifies
@@ -210,7 +223,7 @@ func StartServiceWithVerify(name string, retryCount ...int) error {
210223
}
211224
for retry := 0; status.State == svc.StartPending; retry++ {
212225
deck.Infof("Waiting for service %q to start, sleeping for 5 seconds", name)
213-
time.Sleep(5 * time.Second)
226+
timeSleep(5 * time.Second)
214227
var err error
215228
status, _, err = GetServiceState(name)
216229
if err != nil {
@@ -226,28 +239,28 @@ func StartServiceWithVerify(name string, retryCount ...int) error {
226239
return nil
227240
}
228241

229-
func stopService(s *mgr.Service) error {
242+
func stopService(s *mgr.Service, name string) error {
230243
// although s.Control returns stat, if the service is already stopped it returns an error
231-
stat, err := s.Query()
244+
stat, err := svcQuery(s)
232245
if err != nil {
233246
return err
234247
}
235248
if stat.State == svc.Stopped {
236249
return nil
237250
}
238-
stat, err = s.Control(svc.Stop)
251+
stat, err = svcControl(s, svc.Stop)
239252
if err != nil {
240253
return err
241254
}
242255
retry := 0
243256
for stat.State != svc.Stopped {
244-
deck.Infof("Waiting for service %q to stop.", s.Name)
245-
time.Sleep(5 * time.Second)
257+
deck.Infof("Waiting for service %q to stop.", name)
258+
timeSleep(5 * time.Second)
246259
retry++
247260
if retry > 12 {
248-
return fmt.Errorf("timed out waiting for service %q to stop", s.Name)
261+
return fmt.Errorf("timed out waiting for service %q to stop", name)
249262
}
250-
stat, err = s.Query()
263+
stat, err = svcQuery(s)
251264
if err != nil {
252265
return err
253266
}
@@ -257,18 +270,18 @@ func stopService(s *mgr.Service) error {
257270

258271
// StopService attempts to stop local system services.
259272
func StopService(name string) error {
260-
m, err := mgr.Connect()
273+
m, err := mgrConnect()
261274
if err != nil {
262275
return err
263276
}
264-
defer m.Disconnect()
265-
s, err := m.OpenService(name)
277+
defer mgrDisconnect(m)
278+
s, err := mgrOpenService(m, name)
266279
if err != nil {
267280
return err
268281
}
269-
defer s.Close()
282+
defer svcClose(s)
270283

271-
return stopService(s)
284+
return stopService(s, name)
272285
}
273286

274287
// WaitForProcessExit waits for a process to stop (no longer appear in the process list).

go/helpers/helpers_windows_test.go

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,121 @@ import (
2424
"testing"
2525
"time"
2626

27+
"golang.org/x/sys/windows/svc/mgr"
28+
"golang.org/x/sys/windows/svc"
2729
so "github.com/iamacarpet/go-win64api/shared"
2830
)
2931

32+
type serviceState int
33+
34+
const (
35+
stateStopped serviceState = iota
36+
stateStartPending
37+
stateRunning
38+
stateStopPending
39+
)
40+
41+
type fakeService struct {
42+
state []svc.State
43+
t *testing.T
44+
i int
45+
}
46+
47+
func (s *fakeService) next() svc.State {
48+
if s.i >= len(s.state) {
49+
s.t.Fatalf("ran out of service states...")
50+
}
51+
st := s.state[s.i]
52+
s.i++
53+
return st
54+
}
55+
56+
func TestRestartServiceWithVerify(t *testing.T) {
57+
tests := []struct {
58+
name string
59+
states []svc.State // sequence of states returned by Query
60+
startState svc.State // state returned by Control(Stop)
61+
wantErr bool
62+
}{
63+
{
64+
"GoodService",
65+
[]svc.State{svc.Running, svc.Stopped, svc.Running},
66+
svc.StopPending,
67+
false,
68+
},
69+
{
70+
"PendingService",
71+
[]svc.State{svc.Running, svc.Stopped, svc.StartPending, svc.StartPending, svc.Running},
72+
svc.StopPending,
73+
false,
74+
},
75+
{
76+
"TimeoutService",
77+
[]svc.State{svc.Running, svc.Stopped, svc.StartPending, svc.StartPending, svc.StartPending, svc.StartPending, svc.StartPending, svc.StartPending, svc.StartPending, svc.StartPending, svc.StartPending, svc.StartPending, svc.StartPending, svc.StartPending, svc.StartPending},
78+
svc.StopPending,
79+
true,
80+
},
81+
{
82+
"AlreadyStoppedService",
83+
[]svc.State{svc.Stopped, svc.Running},
84+
svc.Stopped,
85+
false,
86+
},
87+
}
88+
89+
for _, tt := range tests {
90+
t.Run(tt.name, func(t *testing.T) {
91+
fs := &fakeService{state: tt.states, t: t}
92+
oldMgrConnect := mgrConnect
93+
oldMgrDisconnect := mgrDisconnect
94+
oldMgrOpenService := mgrOpenService
95+
oldSvcClose := svcClose
96+
oldSvcConfig := svcConfig
97+
oldSvcQuery := svcQuery
98+
oldSvcUpdateConfig := svcUpdateConfig
99+
oldSvcStart := svcStart
100+
oldSvcControl := svcControl
101+
oldTimeSleep := timeSleep
102+
defer func() {
103+
mgrConnect = oldMgrConnect
104+
mgrDisconnect = oldMgrDisconnect
105+
mgrOpenService = oldMgrOpenService
106+
svcClose = oldSvcClose
107+
svcConfig = oldSvcConfig
108+
svcQuery = oldSvcQuery
109+
svcUpdateConfig = oldSvcUpdateConfig
110+
svcStart = oldSvcStart
111+
svcControl = oldSvcControl
112+
timeSleep = oldTimeSleep
113+
}()
114+
mgrConnect = func() (*mgr.Mgr, error) { return nil, nil }
115+
mgrDisconnect = func(*mgr.Mgr) error { return nil }
116+
mgrOpenService = func(*mgr.Mgr, string) (*mgr.Service, error) { return nil, nil }
117+
svcClose = func(*mgr.Service) error { return nil }
118+
svcConfig = func(*mgr.Service) (mgr.Config, error) { return mgr.Config{}, nil }
119+
svcQuery = func(*mgr.Service) (svc.Status, error) {
120+
return svc.Status{State: fs.next()}, nil
121+
}
122+
svcStart = func(*mgr.Service) error { return nil }
123+
svcControl = func(s *mgr.Service, c svc.Cmd) (svc.Status, error) {
124+
if c == svc.Stop {
125+
return svc.Status{State: tt.startState}, nil
126+
}
127+
return svc.Status{}, fmt.Errorf("unexpected control code: %v", c)
128+
}
129+
timeSleep = func(time.Duration) {}
130+
131+
err := RestartServiceWithVerify(tt.name, 12)
132+
if err != nil && !tt.wantErr {
133+
t.Errorf("RestartServiceWithVerify(%q) returned error: %v, want nil", tt.name, err)
134+
}
135+
if err == nil && tt.wantErr {
136+
t.Errorf("RestartServiceWithVerify(%q) returned nil, want error", tt.name)
137+
}
138+
})
139+
}
140+
}
141+
30142
func TestWaitForProcessExit(t *testing.T) {
31143
tests := []struct {
32144
match string

0 commit comments

Comments
 (0)