diff --git a/.taskcluster.yml b/.taskcluster.yml index 1ef60ad2..102f36bf 100644 --- a/.taskcluster.yml +++ b/.taskcluster.yml @@ -68,6 +68,9 @@ tasks: - generic-worker:cache:generic-worker-checkout payload: maxRunTime: 3600 + env: + # once bug 1508383 is landed, this won't be required (the worker will supply it) + TASKCLUSTER_ROOT_URL: https://taskcluster.net command: - set CGO_ENABLED=0 - set GOPATH=%CD%\gopath1.10.3 @@ -150,6 +153,9 @@ tasks: - generic-worker:cache:generic-worker-checkout payload: maxRunTime: 3600 + env: + # once bug 1508383 is landed, this won't be required (the worker will supply it) + TASKCLUSTER_ROOT_URL: https://taskcluster.net command: - set CGO_ENABLED=0 - set GOPATH=%CD%\gopath1.10.3 @@ -221,6 +227,9 @@ tasks: - generic-worker:cache:generic-worker-checkout payload: maxRunTime: 3600 + env: + # once bug 1508383 is landed, this won't be required (the worker will supply it) + TASKCLUSTER_ROOT_URL: https://taskcluster.net command: - set CGO_ENABLED=0 - set GOPATH=%CD%\gopath1.10.3 @@ -303,6 +312,9 @@ tasks: - generic-worker:cache:generic-worker-checkout payload: maxRunTime: 3600 + env: + # once bug 1508383 is landed, this won't be required (the worker will supply it) + TASKCLUSTER_ROOT_URL: https://taskcluster.net command: - - /bin/bash - -vxec @@ -445,6 +457,9 @@ tasks: taskclusterProxy: true maxRunTime: 3600 image: golang + env: + # once bug 1508383 is landed, this won't be required (the worker will supply it) + TASKCLUSTER_ROOT_URL: https://taskcluster.net command: - /bin/bash - -vxec diff --git a/aws.go b/aws.go index 83b2c296..c64316db 100644 --- a/aws.go +++ b/aws.go @@ -53,7 +53,7 @@ func queryMetaData(url string) (string, error) { return string(content), err } -// taken from https://github.com/taskcluster/aws-provisioner/blob/5a01a94141c38447968ec75232fd86a86cca366a/src/worker-type.js#L601-L615 +// taken from https://github.com/taskcluster/aws-provisioner/blob/5a2bc7c57b20df00f9c4357e0daeb7967e6f5ee8/lib/worker-type.js#L607-L624 type UserData struct { Data interface{} `json:"data"` Capacity int `json:"capacity"` @@ -67,6 +67,7 @@ type UserData struct { LaunchSpecGenerated time.Time `json:"launchSpecGenerated"` LastModified time.Time `json:"lastModified"` ProvisionerBaseURL string `json:"provisionerBaseUrl"` + TaskclusterRootURL string `json:"taskclusterRootUrl"` SecurityToken string `json:"securityToken"` } @@ -209,9 +210,11 @@ func updateConfigWithAmazonSettings(c *gwconfig.Config) error { if removeErr != nil { return removeErr } + c.AccessToken = secToken.Credentials.AccessToken - c.ClientID = secToken.Credentials.ClientID c.Certificate = secToken.Credentials.Certificate + c.ClientID = secToken.Credentials.ClientID + c.RootURL = userData.TaskclusterRootURL c.WorkerGroup = userData.Region c.WorkerType = userData.WorkerType diff --git a/aws_helper_test.go b/aws_helper_test.go index 5ea6ba7d..92fa8fd4 100644 --- a/aws_helper_test.go +++ b/aws_helper_test.go @@ -93,6 +93,7 @@ func (m *MockAWSProvisionedEnvironment) Setup(t *testing.T) func() { "instanceType": "p3.teenyweeny", "spotBid": 3.5, "price": 3.02, + "taskclusterRootUrl": "http://localhost:13243", "launchSpecGenerated": time.Now(), "lastModified": time.Now().Add(time.Minute * -30), "provisionerBaseUrl": "http://localhost:13243/provisioner", diff --git a/docs/features.md b/docs/features.md index 4e039bf2..abab0bb4 100644 --- a/docs/features.md +++ b/docs/features.md @@ -59,53 +59,37 @@ References: #### Since: generic-worker 10.6.0 The taskcluster proxy provides an easy and safe way to make authenticated -taskcluster requests within the scope(s) of a particular task. +taskcluster requests within the scope(s) of a particular task. The proxy +accepts un-authenticated requests and attaches credentials to them +corresponding to `task.scopes` as well as scopes to upload artifacts. -For example lets say we have a task like this: +The proxy's rootUrl is available to tasks in the environment variable +`TASKCLUSTER_PROXY_URL`. It can be used with a client like this: ```js -{ - "scopes": ["a", "b"], - "payload": { - "features": { - "taskclusterProxy": true - } - } -} +var taskcluster = require('taskcluster-client'); +var queue = new taskcluster.Queue({ + rootUrl: process.env.TASKCLUSTER_PROXY_URL, +}); +queue.createTask(..); ``` -A web service will execute (typically on port 80) of the local machine for the -duration of the task, with which you can proxy unauthenticated requests to -various taskcluster services. The proxy will inject the Authorization http -header for you and proxy the request to the target service, granting the -request the scopes of the task (in this case ["a", "b"]). +This request would require that `task.scopes` contain the appropriate +`queue:create-task:..` scope for the `createTask` API call. -| Target Destination | Proxy Address | -|------------------------------------------------|------------------------------------------| -| https://queue.taskcluster.net/ | http://localhost/queue/ | -| https://index.taskcluster.net/ | http://localhost/index/ | -| https://aws-provisioner.taskcluster.net/ | http://localhost/aws-provisioner/ | -| https://secrets.taskcluster.net/ | http://localhost/secrets/ | -| https://auth.taskcluster.net/ | http://localhost/auth/ | -| https://hooks.taskcluster.net/ | http://localhost/hooks/ | -| https://purge-cache.taskcluster.net/ | http://localhost/purge-cache/ | +*NOTE*: as a special case, the scopes required to call +`queue.createArtifact(, , ..)` are automatically included, +regardless of `task.scopes`. -For example (using curl) inside a task container. +The proxy is easy to use within a shell command, too: ```sh -cat secret | curl --header 'Content-Type: application/json' --request PUT --data @- http://localhost/secrets/v1/secret/ +curl $TASKCLUSTER_PROXY_URL/api/secrets/v1/secret/my-top-secret-secret +# ..or +cat secret | curl --header 'Content-Type: application/json' --request PUT --data @- $TASKCLUSTER_PROXY_URL/api/secrets/v1/secret/my-top-secret-secret ``` -You can also use the `baseUrl` parameter in the taskcluster-client - -```js -var taskcluster = require('taskcluster-client'); -var queue = new taskcluster.Queue({ - baseUrl: 'http://localhost/queue' - }); - -queue.createTask(...); -``` +These invocations would require `secrets:get:my-top-secret-secret` or `secrets:put:my-top-secret-secret`, respectively, in `task.scopes`. References: diff --git a/gwconfig/config.go b/gwconfig/config.go index c246a333..d8a29301 100644 --- a/gwconfig/config.go +++ b/gwconfig/config.go @@ -43,6 +43,7 @@ type ( QueueBaseURL string `json:"queueBaseURL"` Region string `json:"region"` RequiredDiskSpaceMegabytes uint `json:"requiredDiskSpaceMegabytes"` + RootURL string `json:"rootURL"` RunAfterUserCreation string `json:"runAfterUserCreation"` RunTasksAsCurrentUser bool `json:"runTasksAsCurrentUser"` SentryProject string `json:"sentryProject"` @@ -100,6 +101,7 @@ func (c *Config) Validate() error { {value: c.LiveLogSecret, name: "livelogSecret", disallowed: ""}, {value: c.ProvisionerID, name: "provisionerId", disallowed: ""}, {value: c.PublicIP, name: "publicIP", disallowed: net.IP(nil)}, + {value: c.RootURL, name: "rootURL", disallowed: ""}, {value: c.SigningKeyLocation, name: "signingKeyLocation", disallowed: ""}, {value: c.Subdomain, name: "subdomain", disallowed: ""}, {value: c.TasksDir, name: "tasksDir", disallowed: ""}, diff --git a/helper_test.go b/helper_test.go index 52313664..d54dd61d 100644 --- a/helper_test.go +++ b/helper_test.go @@ -129,6 +129,7 @@ func setup(t *testing.T) (teardown func()) { PurgeCacheBaseURL: tcpurgecache.DefaultBaseURL, QueueBaseURL: tcqueue.DefaultBaseURL, Region: "test-worker-group", + RootURL: os.Getenv("TASKCLUSTER_ROOT_URL"), // should be enough for tests, and travis-ci.org CI environments don't // have a lot of free disk RequiredDiskSpaceMegabytes: 16, @@ -172,8 +173,10 @@ func setup(t *testing.T) (teardown func()) { func NewQueue(t *testing.T) *tcqueue.Queue { // check we have all the env vars we need to run this test - if os.Getenv("TASKCLUSTER_CLIENT_ID") == "" || os.Getenv("TASKCLUSTER_ACCESS_TOKEN") == "" { - t.Skip("Skipping test since TASKCLUSTER_CLIENT_ID and/or TASKCLUSTER_ACCESS_TOKEN env vars not set") + if os.Getenv("TASKCLUSTER_CLIENT_ID") == "" || + os.Getenv("TASKCLUSTER_ACCESS_TOKEN") == "" || + os.Getenv("TASKCLUSTER_ROOT_URL") == "" { + t.Skip("Skipping test since TASKCLUSTER_{CLIENT_ID,ACCESS_TOKEN,ROOT_URL} env vars not set") } return tcqueue.NewFromEnv() } diff --git a/main.go b/main.go index a459f975..716f6e80 100644 --- a/main.go +++ b/main.go @@ -170,6 +170,9 @@ and reports back results to the queue. for serving live logs; see https://github.com/taskcluster/livelog and https://github.com/taskcluster/stateless-dns-server + rootURL The root URL of the Taskcluster deploment to which + clientId and accessToken grant access. For example, + 'https://taskcluster.net'. signingKeyLocation The PGP signing key for signing artifacts with. workerId A name to uniquely identify your worker. workerType This should match a worker_type managed by the @@ -487,6 +490,7 @@ func loadConfig(filename string, queryUserData bool) (*gwconfig.Config, error) { PurgeCacheBaseURL: tcpurgecache.DefaultBaseURL, QueueBaseURL: tcqueue.DefaultBaseURL, RequiredDiskSpaceMegabytes: 10240, + RootURL: "", RunAfterUserCreation: "", RunTasksAsCurrentUser: runtime.GOOS != "windows", SentryProject: "", diff --git a/main_test.go b/main_test.go index b512c001..f5e676b2 100644 --- a/main_test.go +++ b/main_test.go @@ -92,8 +92,8 @@ func TestAbortAfterMaxRunTime(t *testing.T) { func TestIdleWithoutCrash(t *testing.T) { defer setup(t)() - if config.ClientID == "" || config.AccessToken == "" { - t.Skip("Skipping test since TASKCLUSTER_CLIENT_ID and/or TASKCLUSTER_ACCESS_TOKEN env vars not set") + if config.ClientID == "" || config.AccessToken == "" || config.RootURL == "" { + t.Skip("Skipping test since TASKCLUSTER_{CLIENT_ID,ACCESS_TOKEN,ROOT_URL} env vars not set") } start := time.Now() config.IdleTimeoutSecs = 7 diff --git a/plat_all-unix-style.go b/plat_all-unix-style.go index be0abc97..3e826c0d 100644 --- a/plat_all-unix-style.go +++ b/plat_all-unix-style.go @@ -104,6 +104,15 @@ func install(arguments map[string]interface{}) (err error) { return nil } +// Set an environment variable in each command. This can be called from a feature's +// NewTaskFeature method to set variables for the task. +func (task *TaskRun) setVariable(variable string, value string) error { + for i := range task.Commands { + task.Commands[i].Cmd.Env = append(task.Commands[i].Cmd.Env, fmt.Sprintf("%s=%s", variable, value)) + } + return nil +} + func (task *TaskRun) EnvVars() []string { workerEnv := os.Environ() taskEnv := map[string]string{} @@ -121,6 +130,8 @@ func (task *TaskRun) EnvVars() []string { taskEnv[k] = v } taskEnv["TASK_ID"] = task.TaskID + taskEnv["TASKCLUSTER_ROOT_URL"] = config.RootURL + for i, j := range taskEnv { taskEnvArray = append(taskEnvArray, i+"="+j) } diff --git a/plat_windows.go b/plat_windows.go index 25839b70..72906666 100644 --- a/plat_windows.go +++ b/plat_windows.go @@ -311,6 +311,7 @@ func (task *TaskRun) prepareCommand(index int) *CommandExecutionError { contents += "set " + envVar + "=" + envValue + "\r\n" } contents += "set TASK_ID=" + task.TaskID + "\r\n" + contents += "set TASKCLUSTER_ROOT_URL=" + config.RootURL + "\r\n" contents += "cd \"" + taskContext.TaskDir + "\"" + "\r\n" // Otherwise get the env from the previous command @@ -396,6 +397,20 @@ func (task *TaskRun) prepareCommand(index int) *CommandExecutionError { return nil } +// Set an environment variable in each command. This can be called from a feature's +// NewTaskFeature method to set variables for the task. +func (task *TaskRun) setVariable(variable string, value string) error { + for i := range task.Commands { + newEnv := []string{fmt.Sprintf("%s=%s", variable, value)} + combined, err := win32.MergeEnvLists(&task.Commands[i].Cmd.Env, &newEnv) + if err != nil { + return err + } + task.Commands[i].Cmd.Env = *combined + } + return nil +} + // Only return critical errors func purgeOldTasks() error { if config.CleanUpTaskDirs { diff --git a/process/process_windows.go b/process/process_windows.go index ea9a52dc..648db644 100644 --- a/process/process_windows.go +++ b/process/process_windows.go @@ -4,6 +4,7 @@ import ( "fmt" "io" "log" + "os" "os/exec" "strconv" "sync" @@ -53,15 +54,22 @@ func (r *Result) Crashed() bool { } func NewCommand(commandLine []string, workingDirectory string, env []string, accessToken syscall.Token) (*Command, error) { + var err error + var combined *[]string if accessToken != 0 { - environment, err := win32.CreateEnvironment(&env, accessToken) - if err != nil { - return nil, err - } - env = *environment + // in task-user mode, we must merge env with the task user's environment + combined, err = win32.CreateEnvironment(&env, accessToken) + } else { + // in current-user mode, we merge env with the *current* environment + parentEnv := os.Environ() + combined, err = win32.MergeEnvLists(&parentEnv, &env) + + } + if err != nil { + return nil, err } cmd := exec.Command(commandLine[0], commandLine[1:]...) - cmd.Env = env + cmd.Env = *combined cmd.Dir = workingDirectory isWindows8OrGreater := win32.IsWindows8OrGreater() creationFlags := uint32(win32.CREATE_NEW_PROCESS_GROUP | win32.CREATE_NEW_CONSOLE) diff --git a/taskcluster_proxy.go b/taskcluster_proxy.go index b4d6ac09..8510af25 100644 --- a/taskcluster_proxy.go +++ b/taskcluster_proxy.go @@ -53,15 +53,27 @@ func (l *TaskclusterProxyTask) RequiredScopes() scopes.Required { } func (l *TaskclusterProxyTask) Start() *CommandExecutionError { + // Set TASKCLUSTER_PROXY_URL in the task environment + err := l.task.setVariable("TASKCLUSTER_PROXY_URL", + fmt.Sprintf("http://localhost:%d", config.TaskclusterProxyPort)) + if err != nil { + return MalformedPayloadError(err) + } + + // include all scopes from task.scopes, as well as the scope to create artifacts on + // this task (which cannot be represented in task.scopes) + scopes := append(l.task.Definition.Scopes, + fmt.Sprintf("queue:create-artifact:%s/%d", l.task.TaskID, l.task.RunID)) taskclusterProxy, err := tcproxy.New( config.TaskclusterProxyExecutable, config.TaskclusterProxyPort, + config.RootURL, &tcclient.Credentials{ - AccessToken: l.task.TaskClaimResponse.Credentials.AccessToken, - Certificate: l.task.TaskClaimResponse.Credentials.Certificate, - ClientID: l.task.TaskClaimResponse.Credentials.ClientID, + AccessToken: l.task.TaskClaimResponse.Credentials.AccessToken, + Certificate: l.task.TaskClaimResponse.Credentials.Certificate, + ClientID: l.task.TaskClaimResponse.Credentials.ClientID, + AuthorizedScopes: scopes, }, - l.task.TaskID, ) if err != nil { return executionError(internalError, errored, fmt.Errorf("Could not start taskcluster proxy: %s", err)) diff --git a/taskcluster_proxy_test.go b/taskcluster_proxy_test.go index 71b533ab..840c4cd6 100644 --- a/taskcluster_proxy_test.go +++ b/taskcluster_proxy_test.go @@ -7,16 +7,24 @@ import ( ) func TestTaskclusterProxy(t *testing.T) { + if os.Getenv("TASKCLUSTER_CLIENT_ID") == "" || + os.Getenv("TASKCLUSTER_ACCESS_TOKEN") == "" || + os.Getenv("TASKCLUSTER_ROOT_URL") == "" { + t.Skip("Skipping test since TASKCLUSTER_{CLIENT_ID,ACCESS_TOKEN,ROOT_URL} env vars not set") + } + defer setup(t)() payload := GenericWorkerPayload{ Command: append( append( goEnv(), + // long enough to reclaim and get new credentials sleep(12)..., ), goRun( "curlget.go", - fmt.Sprintf("http://localhost:%v/queue/v1/task/KTBKfEgxR5GdfIIREQIvFQ/runs/0/artifacts/SampleArtifacts/_/X.txt", config.TaskclusterProxyPort), + // note that curlget.go supports substituting the proxy URL from its runtime environment + fmt.Sprintf("TASKCLUSTER_PROXY_URL/queue/v1/task/KTBKfEgxR5GdfIIREQIvFQ/runs/0/artifacts/SampleArtifacts/_/X.txt"), )..., ), MaxRunTime: 60, diff --git a/tcproxy/tcproxy.go b/tcproxy/tcproxy.go index 34568f84..8595f8e4 100644 --- a/tcproxy/tcproxy.go +++ b/tcproxy/tcproxy.go @@ -25,9 +25,10 @@ type TaskclusterProxy struct { // New starts a tcproxy OS process using the executable specified, and returns // a *TaskclusterProxy. -func New(taskclusterProxyExecutable string, httpPort uint16, creds *tcclient.Credentials, taskID string) (*TaskclusterProxy, error) { +func New(taskclusterProxyExecutable string, httpPort uint16, rootURL string, creds *tcclient.Credentials) (*TaskclusterProxy, error) { args := []string{ "--port", strconv.Itoa(int(httpPort)), + "--root-url", rootURL, "--client-id", creds.ClientID, "--access-token", creds.AccessToken, "--ip-address", "127.0.0.1", @@ -35,9 +36,6 @@ func New(taskclusterProxyExecutable string, httpPort uint16, creds *tcclient.Cre if creds.Certificate != "" { args = append(args, "--certificate", creds.Certificate) } - if taskID != "" { - args = append(args, "--task-id", taskID) - } for _, scope := range creds.AuthorizedScopes { args = append(args, scope) } diff --git a/tcproxy/tcproxy_test.go b/tcproxy/tcproxy_test.go index 58448a26..c359b0bf 100644 --- a/tcproxy/tcproxy_test.go +++ b/tcproxy/tcproxy_test.go @@ -10,7 +10,13 @@ import ( tcclient "github.com/taskcluster/taskcluster-client-go" ) -func TestTaskclusterProxy(t *testing.T) { +func TestTcProxy(t *testing.T) { + if os.Getenv("TASKCLUSTER_CLIENT_ID") == "" || + os.Getenv("TASKCLUSTER_ACCESS_TOKEN") == "" || + os.Getenv("TASKCLUSTER_ROOT_URL") == "" { + t.Skip("Skipping test since TASKCLUSTER_{CLIENT_ID,ACCESS_TOKEN,ROOT_URL} env vars not set") + } + var executable string switch runtime.GOOS { case "windows": @@ -24,7 +30,7 @@ func TestTaskclusterProxy(t *testing.T) { Certificate: os.Getenv("TASKCLUSTER_CERTIFICATE"), AuthorizedScopes: []string{"queue:get-artifact:SampleArtifacts/_/X.txt"}, } - ll, err := New(executable, 34569, creds, "") + ll, err := New(executable, 34569, os.Getenv("TASKCLUSTER_ROOT_URL"), creds) // Do defer before checking err since err could be a different error and // process may have already started up. defer func() { diff --git a/testdata/config/bool-as-string.json b/testdata/config/bool-as-string.json index 476e8270..a79692ee 100644 --- a/testdata/config/bool-as-string.json +++ b/testdata/config/bool-as-string.json @@ -2,6 +2,7 @@ "livelogSecret" : "this-is-a-secret", "clientId" : "test-client", "workerId" : "myworkerid", + "rootURL" : "https://tc-tests.example.com", "accessToken" : "V7w5mcc3Q3mQHp3ns0C7dA", "workerGroup" : "abcde", "workerType" : "some-worker-type", diff --git a/testdata/config/invalid-ip.json b/testdata/config/invalid-ip.json index 5db198db..0391a506 100644 --- a/testdata/config/invalid-ip.json +++ b/testdata/config/invalid-ip.json @@ -2,6 +2,7 @@ "livelogSecret" : "this-is-a-secret", "clientId" : "test-client", "workerId" : "myworkerid", + "rootURL" : "https://tc-tests.example.com", "accessToken" : "V7w5mcc3Q3mQHp3ns0C7dA", "workerGroup" : "abcde", "workerType" : "some-worker-type", diff --git a/testdata/config/invalid-json.json b/testdata/config/invalid-json.json index 0a009327..9f89ce6a 100644 --- a/testdata/config/invalid-json.json +++ b/testdata/config/invalid-json.json @@ -1,6 +1,7 @@ { "livelogSecret" : "this-is-a-secret", "clientId" : "test-client", + "rootURL" : "https://tc-tests.example.com", "workerId" : "THERE IS A MISSING COMMA AT THE END OF THIS LINE!!!" "accessToken" : "V0C7dA", "workerGroup" : "abcde", diff --git a/testdata/config/noip.json b/testdata/config/noip.json index 70e41e7c..03a26b84 100644 --- a/testdata/config/noip.json +++ b/testdata/config/noip.json @@ -2,6 +2,7 @@ "livelogSecret" : "this-is-a-secret", "clientId" : "test-client", "workerId" : "myworkerid", + "rootURL" : "https://tc-tests.example.com", "accessToken" : "V7w5mcc3Q3mQHp3ns0C7dA", "workerGroup" : "abcde", "workerType" : "some-worker-type" diff --git a/testdata/config/valid.json b/testdata/config/valid.json index 1e4fb6f9..030ba663 100644 --- a/testdata/config/valid.json +++ b/testdata/config/valid.json @@ -2,6 +2,7 @@ "livelogSecret" : "this-is-a-secret", "clientId" : "test-client", "workerId" : "myworkerid", + "rootURL" : "https://tc-tests.example.com", "accessToken" : "V7w5mcc3Q3mQHp3ns0C7dA", "workerGroup" : "abcde", "workerType" : "some-worker-type", diff --git a/testdata/config/worker-type-metadata.json b/testdata/config/worker-type-metadata.json index 3bc5d19f..03883a0d 100644 --- a/testdata/config/worker-type-metadata.json +++ b/testdata/config/worker-type-metadata.json @@ -2,6 +2,7 @@ "livelogSecret" : "this-is-a-secret", "clientId" : "test-client", "workerId" : "myworkerid", + "rootURL" : "https://tc-tests.example.com", "accessToken" : "V7w5mcc3Q3mQHp3ns0C7dA", "workerGroup" : "abcde", "workerType" : "some-worker-type", diff --git a/testdata/curlget.go b/testdata/curlget.go index 38f57919..abf4b366 100644 --- a/testdata/curlget.go +++ b/testdata/curlget.go @@ -5,13 +5,16 @@ import ( "log" "net/http" "os" + "strings" ) func main() { if len(os.Args) != 2 { - log.Fatal("Usage: go run curlget.go ") + log.Fatal("Usage: go run curlget.go \n will have the current $TASKCLUSTER_PROXY_URL substituted for the string TASKCLUSTER_PROXY_URL") } - res, err := http.Get(os.Args[1]) + url := os.Args[1] + url = strings.Replace(url, "TASKCLUSTER_PROXY_URL", os.Getenv("TASKCLUSTER_PROXY_URL"), -1) + res, err := http.Get(url) if err != nil { log.Fatalf("%v", err) } diff --git a/win32/merge.go b/win32/merge.go new file mode 100644 index 00000000..0e393cc2 --- /dev/null +++ b/win32/merge.go @@ -0,0 +1,68 @@ +package win32 + +import ( + "fmt" + "sort" + "strings" + "unicode/utf8" +) + +type envSetting struct { + name string + value string +} + +func MergeEnvLists(envLists ...*[]string) (*[]string, error) { + mergedEnvMap := map[string]envSetting{} + for _, envList := range envLists { + if envList == nil { + continue + } + for _, env := range *envList { + if utf8.RuneCountInString(env) > 32767 { + return nil, fmt.Errorf("Env setting is more than 32767 runes: %v", env) + } + spl := strings.SplitN(env, "=", 2) + if len(spl) != 2 { + return nil, fmt.Errorf("Could not interpret string %q as `key=value`", env) + } + newVarName := spl[0] + newVarValue := spl[1] + // if env var already exists, use case of existing name, to simulate behaviour of + // setting an existing env var with a different case + // e.g. + // set aVar=3 + // set AVAR=4 + // results in + // aVar=4 + canonicalVarName := strings.ToLower(newVarName) + if existingVarName := mergedEnvMap[canonicalVarName].name; existingVarName != "" { + newVarName = existingVarName + } + mergedEnvMap[canonicalVarName] = envSetting{ + name: newVarName, + value: newVarValue, + } + } + } + canonicalVarNames := make([]string, len(mergedEnvMap)) + i := 0 + for k := range mergedEnvMap { + canonicalVarNames[i] = k + i++ + } + // All strings in the environment block must be sorted alphabetically by + // name. The sort is case-insensitive, Unicode order, without regard to + // locale. + // + // See https://msdn.microsoft.com/en-us/library/windows/desktop/ms682009(v=vs.85).aspx + sort.Strings(canonicalVarNames) + // Finally piece back together into an environment block + mergedEnv := make([]string, len(mergedEnvMap)) + i = 0 + for _, canonicalVarName := range canonicalVarNames { + mergedEnv[i] = mergedEnvMap[canonicalVarName].name + "=" + mergedEnvMap[canonicalVarName].value + i++ + } + return &mergedEnv, nil +} diff --git a/win32/merge_test.go b/win32/merge_test.go new file mode 100644 index 00000000..8cfbb6f0 --- /dev/null +++ b/win32/merge_test.go @@ -0,0 +1,28 @@ +package win32_test + +import ( + "log" + "testing" + + "github.com/taskcluster/generic-worker/win32" +) + +func TestMergeNilListsFirstNil(t *testing.T) { + res, err := win32.MergeEnvLists(nil, &[]string{"FOO=bar"}) + if err != nil { + log.Fatalf("Hit error: %v", err) + } + if res == nil || len(*res) != 1 || (*res)[0] != "FOO=bar" { + t.Fatalf("Did not merge correctly; got %#v", res) + } +} + +func TestMergeNilListsSecondNil(t *testing.T) { + res, err := win32.MergeEnvLists(&[]string{"FOO=bar"}, nil) + if err != nil { + log.Fatalf("Hit error: %v", err) + } + if res == nil || len(*res) != 1 || (*res)[0] != "FOO=bar" { + t.Fatalf("Did not merge correctly; got %#v", res) + } +} diff --git a/win32/win32_windows.go b/win32/win32_windows.go index 83cab137..3fef6380 100644 --- a/win32/win32_windows.go +++ b/win32/win32_windows.go @@ -10,8 +10,6 @@ import ( "log" "os" "runtime" - "sort" - "strings" "syscall" "time" "unicode/utf8" @@ -293,66 +291,6 @@ func CreateEnvironment(env *[]string, hUser syscall.Token) (mergedEnv *[]string, return } -type envSetting struct { - name string - value string -} - -func MergeEnvLists(envLists ...*[]string) (*[]string, error) { - mergedEnvMap := map[string]envSetting{} - for _, envList := range envLists { - if envList == nil { - continue - } - for _, env := range *envList { - if utf8.RuneCountInString(env) > 32767 { - return nil, fmt.Errorf("Env setting is more than 32767 runes: %v", env) - } - spl := strings.SplitN(env, "=", 2) - if len(spl) != 2 { - return nil, fmt.Errorf("Could not interpret string %q as `key=value`", env) - } - newVarName := spl[0] - newVarValue := spl[1] - // if env var already exists, use case of existing name, to simulate behaviour of - // setting an existing env var with a different case - // e.g. - // set aVar=3 - // set AVAR=4 - // results in - // aVar=4 - canonicalVarName := strings.ToLower(newVarName) - if existingVarName := mergedEnvMap[canonicalVarName].name; existingVarName != "" { - newVarName = existingVarName - } - mergedEnvMap[canonicalVarName] = envSetting{ - name: newVarName, - value: newVarValue, - } - } - } - canonicalVarNames := make([]string, len(mergedEnvMap)) - i := 0 - for k := range mergedEnvMap { - canonicalVarNames[i] = k - i++ - } - // All strings in the environment block must be sorted alphabetically by - // name. The sort is case-insensitive, Unicode order, without regard to - // locale. - // - // See https://msdn.microsoft.com/en-us/library/windows/desktop/ms682009(v=vs.85).aspx - sort.Strings(canonicalVarNames) - // Finally piece back together into an environment block - mergedEnv := make([]string, len(mergedEnvMap)) - i = 0 - for _, canonicalVarName := range canonicalVarNames { - mergedEnv[i] = mergedEnvMap[canonicalVarName].name + "=" + mergedEnvMap[canonicalVarName].value - i++ - } - return &mergedEnv, nil -} - // https://msdn.microsoft.com/en-us/library/windows/desktop/bb762188(v=vs.85).aspx func SHGetKnownFolderPath(rfid *syscall.GUID, dwFlags uint32, hToken syscall.Token, pszPath *uintptr) (err error) { r0, _, _ := procSHGetKnownFolderPath.Call(