diff --git a/go.mod b/go.mod index f762653..1617ce0 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.24.0 require ( github.com/Azure/azure-extension-foundation v0.0.0-20250620154556-caff9e3c3c5c - github.com/Azure/azure-extension-platform v0.0.0-20260107210613-2a62cc200c34 + github.com/Azure/azure-extension-platform v0.0.0-20260312212104-6da8a253549a github.com/Azure/azure-sdk-for-go v63.2.0+incompatible github.com/ahmetalpbalkan/go-httpbin v0.0.0-20160706084156-8817b883dae1 github.com/go-kit/kit v0.12.0 diff --git a/go.sum b/go.sum index 87eb274..a7e566e 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,20 @@ github.com/Azure/azure-extension-foundation v0.0.0-20250620154556-caff9e3c3c5c h github.com/Azure/azure-extension-foundation v0.0.0-20250620154556-caff9e3c3c5c/go.mod h1:sNC6lMTUkXwjrQ+nttr6GXhDfvSGT7t3UDq30BEYzu8= github.com/Azure/azure-extension-platform v0.0.0-20260107210613-2a62cc200c34 h1:7bEC4DJC4w0gx7SBy7M7Q2qi6ckmHcnnlFJzo+X/gi4= github.com/Azure/azure-extension-platform v0.0.0-20260107210613-2a62cc200c34/go.mod h1:0458BvQsi5ch6kn+KZtI5m88Z3L9UFXdoY1+6nKdivY= +github.com/Azure/azure-extension-platform v0.0.0-20260303193429-96e5f13d68a7 h1:sckOZUC/OwKKF/bneOiik8bCf4wiOtEBpyN26IUCBMQ= +github.com/Azure/azure-extension-platform v0.0.0-20260303193429-96e5f13d68a7/go.mod h1:0458BvQsi5ch6kn+KZtI5m88Z3L9UFXdoY1+6nKdivY= +github.com/Azure/azure-extension-platform v0.0.0-20260304003136-336d3e56fa20 h1:d5WpxK2xd2htsET+EQhZa/WjauDV0RLZef15x+JLShI= +github.com/Azure/azure-extension-platform v0.0.0-20260304003136-336d3e56fa20/go.mod h1:0458BvQsi5ch6kn+KZtI5m88Z3L9UFXdoY1+6nKdivY= +github.com/Azure/azure-extension-platform v0.0.0-20260304003541-160ebdf80e13 h1:t+jiig+xE6PQjBQOASaMZfFeo8Ln5ClUCxB/SFSJ6Dw= +github.com/Azure/azure-extension-platform v0.0.0-20260304003541-160ebdf80e13/go.mod h1:0458BvQsi5ch6kn+KZtI5m88Z3L9UFXdoY1+6nKdivY= +github.com/Azure/azure-extension-platform v0.0.0-20260304004038-0cf99cda38d7 h1:TUZKvOXi2MUEQhLYsTqSIyLqxQy2U08cP9VHSNCF2Bk= +github.com/Azure/azure-extension-platform v0.0.0-20260304004038-0cf99cda38d7/go.mod h1:0458BvQsi5ch6kn+KZtI5m88Z3L9UFXdoY1+6nKdivY= +github.com/Azure/azure-extension-platform v0.0.0-20260304193358-17aecbaff233 h1:XOj8YJeGOdLZiXdSrXbdYFi8dDJnUi7xXlg7CvD26EI= +github.com/Azure/azure-extension-platform v0.0.0-20260304193358-17aecbaff233/go.mod h1:0458BvQsi5ch6kn+KZtI5m88Z3L9UFXdoY1+6nKdivY= +github.com/Azure/azure-extension-platform v0.0.0-20260305214320-4828fb38d797 h1:9YWYQi8rHmaMm1cdwVhS5FAL/ivPDzvOEyxsmHymc4w= +github.com/Azure/azure-extension-platform v0.0.0-20260305214320-4828fb38d797/go.mod h1:0458BvQsi5ch6kn+KZtI5m88Z3L9UFXdoY1+6nKdivY= +github.com/Azure/azure-extension-platform v0.0.0-20260312212104-6da8a253549a h1:VbxYR5y5uVJaNKQBzDnFBKoHqEjBmix+27/mPvOTqb4= +github.com/Azure/azure-extension-platform v0.0.0-20260312212104-6da8a253549a/go.mod h1:0458BvQsi5ch6kn+KZtI5m88Z3L9UFXdoY1+6nKdivY= github.com/Azure/azure-sdk-for-go v63.2.0+incompatible h1:OIqkK/zTGqVUuzpEvY0B1YSYDRAFC/j+y0w2GovCggI= github.com/Azure/azure-sdk-for-go v63.2.0+incompatible/go.mod h1:9XXNKU+eRnpl9moKnB4QOLf1HestfXbmab5FXxiDBjc= github.com/Azure/go-autorest v14.2.0+incompatible h1:V5VMDjClD3GiElqLWO7mz2MxNAK/vTfRHdAubSIPRgs= diff --git a/main/cmds.go b/main/cmds.go index b41b01b..72c0b98 100644 --- a/main/cmds.go +++ b/main/cmds.go @@ -12,6 +12,7 @@ import ( "strconv" "time" + "github.com/Azure/azure-extension-platform/pkg/extensionpolicysettings" "github.com/Azure/azure-extension-platform/pkg/utils" vmextension "github.com/Azure/azure-extension-platform/vmextension" "github.com/Azure/custom-script-extension-linux/pkg/errorutil" @@ -39,6 +40,7 @@ const ( fullName = "Microsoft.Azure.Extensions.CustomScript" maxTailLen = 4 * 1024 // length of max stdout/stderr to be transmitted in .status file maxTelemetryTailLen int = 1800 + policyFileName = "waagent_runtime_policy.json" ) var ( @@ -112,6 +114,19 @@ func min(a, b int) int { return b } +type CSEExtensionPolicySettings struct { + RequireSigning bool `json:"requireSigning"` + FileRootCertCA string `json:"fileRootCertCA,omitempty"` // optional field for customer that want to specify a root cert for script signature verification. This is a path to a cert file on the VM that the extension can use to verify script signatures. The customer is responsible for ensuring the cert is there and updated as needed (e.g. if the cert expires). The customer can choose to use this field or not based on their needs. + AllowedScripts []string `json:"allowedScripts"` +} + +func (cseps CSEExtensionPolicySettings) ValidateFormat() error { + if cseps.RequireSigning && len(cseps.FileRootCertCA) == 0 { + return errors.New("invalid policy settings: if RequireSigning is true, fileRootCertCA must be provided") + } + return nil +} + func enable(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, *vmextension.ErrorWithClarification) { // parse the extension handler settings (not available prior to 'enable') cfg, ewc := parseAndValidateSettings(ctx, h.HandlerEnvironment.ConfigFolder, seqNum) @@ -121,8 +136,36 @@ func enable(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, *vmexte return "", ewc } + // If policy file exists, load the policy. + // If the policy is invalid, we log error and exit. + // If the policy file does not exist, proceed as normal. + var ExtensionPolicyManagerPtr *extensionpolicysettings.ExtensionPolicySettingsManager[CSEExtensionPolicySettings] + policyPath := filepath.Join(h.HandlerEnvironment.ConfigFolder, policyFileName) + + if _, err := os.Stat(policyPath); err == nil { + ExtensionPolicyManagerPtr, err = extensionpolicysettings.NewExtensionPolicySettingsManager[CSEExtensionPolicySettings](policyPath) + if err != nil { + return "", vmextension.NewErrorWithClarificationPtr(errorutil.ExtensionPolicySettings_policyLoadFailed, errors.Wrap(err, "failed to create extension policy settings manager")) + } + err = ExtensionPolicyManagerPtr.LoadExtensionPolicySettings() + if err != nil { + return "", vmextension.NewErrorWithClarificationPtr(errorutil.ExtensionPolicySettings_policyLoadFailed, errors.Wrap(err, "failed to load extension policy settings")) + } else { + settings, err := ExtensionPolicyManagerPtr.GetSettings() + if err != nil { + return "", vmextension.NewErrorWithClarificationPtr(errorutil.ExtensionPolicySettings_policyLoadFailed, errors.Wrap(err, "failed to get extension policy settings")) + } + ctx.Log("message", "successfully loaded extension policy settings", "settings", fmt.Sprintf("%+v", settings)) + } + } else if os.IsNotExist(err) { + ctx.Log("message", "extension policy settings file does not exist, proceeding with default extension behavior.", "path", policyPath) + ExtensionPolicyManagerPtr = nil + } else { + return "", vmextension.NewErrorWithClarificationPtr(errorutil.ExtensionPolicySettings_policyLoadFailed, errors.Wrap(err, "error while checking for extension policy settings file. Stat failed with an error other than file not existing")) + } + dir := filepath.Join(dataDir, downloadDir, fmt.Sprintf("%d", seqNum)) - if ewc := downloadFiles(ctx, dir, cfg); ewc != nil { + if ewc := downloadFiles(ctx, dir, cfg, ExtensionPolicyManagerPtr); ewc != nil { ewc.Err = errors.Wrap(ewc.Err, "processing file downloads failed") return "", ewc } @@ -180,7 +223,8 @@ func checkAndSaveSeqNum(ctx log.Logger, seq int, mrseqPath string) (shouldExit b // downloadFiles downloads the files specified in cfg into dir (creates if does // not exist) and takes storage credentials specified in cfg into account. -func downloadFiles(ctx *log.Context, dir string, cfg handlerSettings) *vmextension.ErrorWithClarification { +// If extension policy settings is provided, they are passed on to downloadAndProcessURL for file validation. +func downloadFiles(ctx *log.Context, dir string, cfg handlerSettings, eps *extensionpolicysettings.ExtensionPolicySettingsManager[CSEExtensionPolicySettings]) *vmextension.ErrorWithClarification { // - prepare the output directory for files and the command output // - create the directory if missing ctx.Log("event", "creating output directory", "path", dir) @@ -205,7 +249,7 @@ func downloadFiles(ctx *log.Context, dir string, cfg handlerSettings) *vmextensi for i, f := range cfg.fileUrls() { ctx := ctx.With("file", i) ctx.Log("event", "download start") - if ewc := downloadAndProcessURL(ctx, f, dir, &cfg); ewc != nil { + if ewc := downloadAndProcessURL(ctx, f, dir, &cfg, eps); ewc != nil { ctx.Log("event", "download failed", "error", ewc.Err) return vmextension.NewErrorWithClarificationPtr(ewc.ErrorCode, errors.Wrapf(ewc.Err, "failed to download file[%d]", i)) } diff --git a/main/cmds_test.go b/main/cmds_test.go index 4c53e27..25d3adc 100644 --- a/main/cmds_test.go +++ b/main/cmds_test.go @@ -1,12 +1,15 @@ package main import ( + "fmt" "io/ioutil" + "net/http" "net/http/httptest" "os" "path/filepath" "testing" + "github.com/Azure/azure-extension-platform/pkg/extensionpolicysettings" "github.com/Azure/custom-script-extension-linux/pkg/errorutil" "github.com/ahmetalpbalkan/go-httpbin" "github.com/go-kit/kit/log" @@ -78,6 +81,77 @@ func Test_checkAndSaveSeqNum(t *testing.T) { require.True(t, shouldExit) } +const policyTestDir = "./testdata" +const policyTestFile = "extensionPolicySettingsTestConfig.json" +const policyTestPath = policyTestDir + "/" + policyTestFile + +func Test_LoadExtensionPolicySettings_PolicyFileExistsValid(t *testing.T) { + // Set up test parameters + require.Nil(t, setupPolicyDir(policyTestDir)) + defer cleanupPolicyFile(policyTestPath) + // maybe clean up policy test directory too? + require.Nil(t, loadTestPolicy("valid, basic", nil)) + + // Replicate the logic in cmd.go enable() + ExtensionPolicyManagerPtr, err := extensionpolicysettings.NewExtensionPolicySettingsManager[CSEExtensionPolicySettings](policyTestPath) + require.NoError(t, err, "should be able to create extension policy settings manager") + err = ExtensionPolicyManagerPtr.LoadExtensionPolicySettings() + require.NoError(t, err, "should be able to load extension policy settings") + // Check that settings are loaded correctly + require.NoError(t, err) + settings, err := ExtensionPolicyManagerPtr.GetSettings() + require.NoError(t, err, "should be able to get extension policy settings") + require.NotNil(t, settings, "settings should not be nil") + require.Equal(t, false, settings.RequireSigning) + require.Empty(t, settings.AllowedScripts) +} + +func Test_LoadExtensionPolicySettings_PolicyFileMissing(t *testing.T) { + // Replicate the logic in cmd.go enable() + missingPolicyFilePath := filepath.Join(policyTestDir, "missingPolicyFile.json") + ExtensionPolicyManagerPtr, err := extensionpolicysettings.NewExtensionPolicySettingsManager[CSEExtensionPolicySettings](missingPolicyFilePath) + require.NoError(t, err, "should be able to create extension policy settings manager") + + _, err = os.Stat(missingPolicyFilePath) + require.True(t, os.IsNotExist(err), "policy file should not exist") + err = ExtensionPolicyManagerPtr.LoadExtensionPolicySettings() + require.Error(t, err, "should not be able to load extension policy settings") +} + +func Test_LoadExtensionPolicySettings_InvalidJSON(t *testing.T) { + require.NoError(t, setupPolicyDir(policyTestDir)) + defer cleanupPolicyFile(policyTestPath) + + invalidPolicyContent := `{ + "requireSigning": false, + "allowedScripts": [} + }` + require.NoError(t, writeToFile(policyTestPath, invalidPolicyContent)) + + ExtensionPolicyManagerPtr, err := extensionpolicysettings.NewExtensionPolicySettingsManager[CSEExtensionPolicySettings](policyTestPath) + require.NoError(t, err, "should be able to create extension policy settings manager") + + err = ExtensionPolicyManagerPtr.LoadExtensionPolicySettings() + require.Error(t, err, "invalid JSON policy should fail to load") // lourdes: be specific about the error message. +} + +func Test_LoadExtensionPolicySettings_InvalidFormat_RequireSigningWithoutRootCA(t *testing.T) { + require.NoError(t, setupPolicyDir(policyTestDir)) + defer cleanupPolicyFile(policyTestPath) + + invalidPolicyContent := `{ + "requireSigning": true, + "allowedScripts": [] + }` + require.NoError(t, writeToFile(policyTestPath, invalidPolicyContent)) + + ExtensionPolicyManagerPtr, err := extensionpolicysettings.NewExtensionPolicySettingsManager[CSEExtensionPolicySettings](policyTestPath) + require.NoError(t, err, "should be able to create extension policy settings manager") + + err = ExtensionPolicyManagerPtr.LoadExtensionPolicySettings() + require.Error(t, err, "policy requiring signing without fileRootCertCA should fail validation") +} + func Test_runCmd_success(t *testing.T) { dir, err := ioutil.TempDir("", "") require.Nil(t, err) @@ -86,12 +160,6 @@ func Test_runCmd_success(t *testing.T) { require.Nil(t, runCmd(log.NewNopLogger(), dir, handlerSettings{ publicSettings: publicSettings{CommandToExecute: "date"}, }), "command should run successfully") - - // check stdout stderr files - _, err = os.Stat(filepath.Join(dir, "stdout")) - require.Nil(t, err, "stdout should exist") - _, err = os.Stat(filepath.Join(dir, "stderr")) - require.Nil(t, err, "stderr should exist") } func Test_runCmd_fail(t *testing.T) { @@ -124,11 +192,272 @@ func Test_downloadFiles(t *testing.T) { srv.URL + "/bytes/100", srv.URL + "/bytes/1000", }}, - }) + }, nil) require.Nil(t, ewc) // check the files f := []string{"10", "100", "1000"} + for _, fn := range f { + fp := filepath.Join(dir, fn) + _, err := os.Stat(fp) + data, err := os.ReadFile(fp) + fmt.Println("File Content:") + fmt.Println(string(data)) + + require.Nil(t, err, "%s is missing from download dir", fp) + } +} + +func Test_downloadFiles_allowlistStopsOnFirstDisallowedFile(t *testing.T) { + dir, err := ioutil.TempDir("", "") + require.NoError(t, err) + defer os.RemoveAll(dir) + + require.NoError(t, setupPolicyDir(policyTestDir)) + defer cleanupPolicyFile(policyTestPath) + + good1Content := "echo good1" + good2Content := "echo good2" + badContent := "echo bad" + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + files := map[string]string{ + "/bad": badContent, + "/good1": good1Content, + "/good2": good2Content, + } + if content, ok := files[r.URL.Path]; ok { + w.Header().Set("Content-Type", "application/octet-stream") + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, content) + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer srv.Close() + + good1Hash, _ := extensionpolicysettings.ComputeFileHash(good1Content, extensionpolicysettings.HashTypeSHA256) + good2Hash, _ := extensionpolicysettings.ComputeFileHash(good2Content, extensionpolicysettings.HashTypeSHA256) + require.NoError(t, loadTestPolicy("valid, allowlist present", []string{good1Hash, good2Hash})) + + ExtensionPolicyManagerPtr, err := extensionpolicysettings.NewExtensionPolicySettingsManager[CSEExtensionPolicySettings](policyTestPath) + require.NoError(t, err) + require.NoError(t, ExtensionPolicyManagerPtr.LoadExtensionPolicySettings()) + + ewc := downloadFiles(log.NewContext(log.NewNopLogger()), + dir, + handlerSettings{ + publicSettings: publicSettings{ + FileURLs: []string{ + srv.URL + "/bad", // first file is disallowed + srv.URL + "/good1", // should not be attempted + srv.URL + "/good2", // should not be attempted + }, + }, + }, + ExtensionPolicyManagerPtr, + ) + + require.NotNil(t, ewc, "download should fail on first disallowed file") + require.Contains(t, ewc.Err.Error(), "bad", "error should identify disallowed script") + + // "bad" is downloaded before validation, so it may exist. + _, err = os.Stat(filepath.Join(dir, "bad")) + require.NoError(t, err, "first (disallowed) file should have been downloaded before validation failure") + + _, err = os.Stat(filepath.Join(dir, "good1")) + require.True(t, os.IsNotExist(err), "good1 should not be downloaded after first failure") + + _, err = os.Stat(filepath.Join(dir, "good2")) + require.True(t, os.IsNotExist(err), "good2 should not be downloaded after first failure") +} + +func Test_downloadFiles_goodAllowlist_SHA256(t *testing.T) { + dir, err := ioutil.TempDir("", "") + require.Nil(t, err) + defer os.RemoveAll(dir) + + file1Content := "echo hello" + file2Content := "echo world" + file3Content := "echo !" + + // Create a custom HTTP server with preset content + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // r.URL.Path is our key string to fetch file content below. + // Define preset file contents + files := map[string]string{ + "/file1": file1Content, + "/file2": file2Content, + "/file3": file3Content, + } + + if content, ok := files[r.URL.Path]; ok { // 'content' is value corresponding to the key r.URL.Path. 'ok' is true if content exists. + w.Header().Set("Content-Type", "application/octet-stream") // 'content' is going to be the response body from the test server. + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, content) + } else { + w.WriteHeader(http.StatusNotFound) + } + })) + defer srv.Close() + + // Compute SHA256 hashes (you can use the extensionpolicysettings.ComputeFileHash function) + file1Hash, _ := extensionpolicysettings.ComputeFileHash(file1Content, extensionpolicysettings.HashTypeSHA256) + file2Hash, _ := extensionpolicysettings.ComputeFileHash(file2Content, extensionpolicysettings.HashTypeSHA256) + file3Hash, _ := extensionpolicysettings.ComputeFileHash(file3Content, extensionpolicysettings.HashTypeSHA256) + + // Create a hash list. + al := []string{file1Hash, file2Hash, file3Hash} + + // Write the policy file (GA behavior) + require.Nil(t, loadTestPolicy("valid, allowlist present", al)) + defer cleanupPolicyFile(policyTestPath) + + // Load the policy into manager (enable() behavior) + ExtensionPolicyManagerPtr, err := extensionpolicysettings.NewExtensionPolicySettingsManager[CSEExtensionPolicySettings](policyTestPath) + require.NoError(t, err, "should be able to create extension policy settings manager") + err = ExtensionPolicyManagerPtr.LoadExtensionPolicySettings() + require.NoError(t, err, "should be able to load extension policy settings") + + ewc := downloadFiles(log.NewContext(log.NewNopLogger()), + dir, + handlerSettings{ + publicSettings: publicSettings{ + FileURLs: []string{ + srv.URL + "/file1", + srv.URL + "/file2", + srv.URL + "/file3", + }}, + }, ExtensionPolicyManagerPtr) + require.Nil(t, ewc) + + // check the files. All files should have passed. + f := []string{"file1", "file2", "file3"} + for _, fn := range f { + fp := filepath.Join(dir, fn) + _, err := os.Stat(fp) + require.Nil(t, err, "%s is missing from download dir", fp) + } +} + +func Test_downloadFiles_badAllowlist(t *testing.T) { + dir, err := ioutil.TempDir("", "") + require.Nil(t, err) + defer os.RemoveAll(dir) + + // Generate hashes of the preset content for the allowlist + file1Content := "echo hello" + file2Content := "echo world" + file3Content := "echo !" + file4Content := "echo bad" + file5Content := "echo also bad" + + // Create a custom HTTP server with preset content + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Define preset file contents + files := map[string]string{ + "/file1": file1Content, + "/file2": file2Content, + "/file3": file3Content, + "/file4": file4Content, + "/file5": file5Content, + } + + if content, ok := files[r.URL.Path]; ok { + w.Header().Set("Content-Type", "application/octet-stream") + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, content) + } else { + w.WriteHeader(http.StatusNotFound) + } + })) + defer srv.Close() + + // Compute SHA256 hashes (you can use the extensionpolicysettings.ComputeFileHash function) + file1Hash, _ := extensionpolicysettings.ComputeFileHash(file1Content, extensionpolicysettings.HashTypeSHA256) + file2Hash, _ := extensionpolicysettings.ComputeFileHash(file2Content, extensionpolicysettings.HashTypeSHA256) + file3Hash, _ := extensionpolicysettings.ComputeFileHash(file3Content, extensionpolicysettings.HashTypeSHA256) + + // Create a hash list. + al := []string{file1Hash, file2Hash, file3Hash} + + // Write the policy file (GA behavior) + require.Nil(t, loadTestPolicy("valid, allowlist present", al)) + defer cleanupPolicyFile(policyTestPath) + + // Load the policy into manager (enable() behavior) + ExtensionPolicyManagerPtr, err := extensionpolicysettings.NewExtensionPolicySettingsManager[CSEExtensionPolicySettings](policyTestPath) + require.NoError(t, err, "should be able to create extension policy settings manager") + err = ExtensionPolicyManagerPtr.LoadExtensionPolicySettings() + require.NoError(t, err, "should be able to load extension policy settings") + + ewc := downloadFiles(log.NewContext(log.NewNopLogger()), + dir, + handlerSettings{ + publicSettings: publicSettings{ + FileURLs: []string{ + srv.URL + "/file1", + srv.URL + "/file2", + srv.URL + "/file3", + srv.URL + "/file4", // this file is not in the allowlist. Extension should exit gracefully. + }}, + }, ExtensionPolicyManagerPtr) + require.NotNil(t, ewc) + require.Contains(t, ewc.Err.Error(), "Validation of script 'file4' against policy-allowlist failed", "error should indicate that file4 failed validation") + + // Check the files. At most file1, file2, and file3 should be present. file4 and file5 should not both be present, + // but one of them will be present because they are validated after being downloaded. + // The extension should exit gracefully. If downloaded out of order, it's possible that not file1, file2, and file3 are all + // present because the extension will exit immediately after downloading the first file that is not in the allowlist. + + // The check below assumes the files are downloaded in order. + f := []string{"file1", "file2", "file3", "file4", "file5"} + for _, fn := range f { + fp := filepath.Join(dir, fn) + _, err := os.Stat(fp) + if fn == "file5" { + require.NotNil(t, err, "%s should not be downloaded because it's not in the allowlist", fp) + continue + } else { + require.Nil(t, err, "%s is missing from download dir", fp) + } + } +} + +func Test_downloadFiles_emptyAllowlist(t *testing.T) { + dir, err := ioutil.TempDir("", "") + require.Nil(t, err) + defer os.RemoveAll(dir) + + srv := httptest.NewServer(httpbin.GetMux()) + defer srv.Close() + + // Create an EMPTY list. + al := []string{} + + // Write the policy file (GA behavior) + require.Nil(t, loadTestPolicy("valid, allowlist present", al)) + defer cleanupPolicyFile(policyTestPath) + + // Load the policy into manager (enable() behavior) + ExtensionPolicyManagerPtr, err := extensionpolicysettings.NewExtensionPolicySettingsManager[CSEExtensionPolicySettings](policyTestPath) + require.NoError(t, err, "should be able to create extension policy settings manager") + err = ExtensionPolicyManagerPtr.LoadExtensionPolicySettings() + require.NoError(t, err, "should be able to load extension policy settings") + + ewc := downloadFiles(log.NewContext(log.NewNopLogger()), + dir, + handlerSettings{ + publicSettings: publicSettings{ + FileURLs: []string{ + srv.URL + "/bytes/10", + srv.URL + "/bytes/100", + srv.URL + "/bytes/1000", + }}, + }, ExtensionPolicyManagerPtr) + require.Nil(t, ewc) + + // Check the files. All files should have passed. + f := []string{"10", "100", "1000"} for _, fn := range f { fp := filepath.Join(dir, fn) _, err := os.Stat(fp) @@ -153,3 +482,57 @@ func Test_decodeScriptGzip(t *testing.T) { require.Equal(t, info, "32;3;gzip=1") require.Equal(t, s, "ls\n") } + +// Helper Methods +func writeToFile(filePath, content string) error { + err := os.WriteFile(filePath, []byte(content), 0644) + return err +} + +func cleanupPolicyFile(path string) { + // Do not remove missingPolicyFilePath as it simulates a missing file + if _, err := os.Stat(path); err == nil { + os.Remove(path) + } +} + +func setupPolicyDir(path string) error { + if _, err := os.Stat(path); os.IsNotExist(err) { + err = os.MkdirAll(path, 0750) + return err + } + return nil +} + +func loadTestPolicy(scenario string, list []string) error { + var validPolicyContent string + var allowlistStr string + if list != nil { + allowlistStr = "[" + for i, item := range list { + allowlistStr += `"` + item + `"` + if i < len(list)-1 { + allowlistStr += "," + } + } + allowlistStr += "]" + } + + switch scenario { + case "valid, basic": + validPolicyContent = `{ + "requireSigning": false, + "allowedScripts": [] + }` + case "valid, allowlist present": + // Convert list to JSON array string + + validPolicyContent = `{ + "requireSigning": false, + "allowedScripts": ` + allowlistStr + ` + }` + default: + validPolicyContent = `{}` + } + return writeToFile(filepath.Join(policyTestDir, policyTestFile), validPolicyContent) +} diff --git a/main/files.go b/main/files.go index dc7ce6e..6cd8f72 100644 --- a/main/files.go +++ b/main/files.go @@ -9,6 +9,7 @@ import ( "os" + "github.com/Azure/azure-extension-platform/pkg/extensionpolicysettings" "github.com/Azure/azure-extension-platform/vmextension" "github.com/Azure/custom-script-extension-linux/pkg/blobutil" "github.com/Azure/custom-script-extension-linux/pkg/download" @@ -23,7 +24,8 @@ import ( // downloadAndProcessURL downloads using the specified downloader and saves it to the // specified existing directory, which must be the path to the saved file. Then // it post-processes file based on heuristics. -func downloadAndProcessURL(ctx *log.Context, url, downloadDir string, cfg *handlerSettings) *vmextension.ErrorWithClarification { +// If extension policy settings manager is provided, the downloaded file will be validated against the policy. +func downloadAndProcessURL(ctx *log.Context, url, downloadDir string, cfg *handlerSettings, eps *extensionpolicysettings.ExtensionPolicySettingsManager[CSEExtensionPolicySettings]) *vmextension.ErrorWithClarification { fn, err := urlToFileName(url) if err != nil { return vmextension.NewErrorWithClarificationPtr(errorutil.CustomerInput_invalidFileUris, err) @@ -52,6 +54,23 @@ func downloadAndProcessURL(ctx *log.Context, url, downloadDir string, cfg *handl return vmextension.NewErrorWithClarificationPtr(errorutil.SystemError, errors.Wrapf(err, "failed to post-process '%s'", fn)) } + if eps != nil { + settings, err := eps.GetSettings() + if err != nil { + return vmextension.NewErrorWithClarificationPtr(errorutil.SystemError, fmt.Errorf("failed to get extension policy settings: %w", err)) + } + if settings == nil { + return vmextension.NewErrorWithClarificationPtr(errorutil.SystemError, fmt.Errorf("extension policy settings manager initialized, but settings not properly loaded.")) + } + + if len(settings.AllowedScripts) > 0 { + if err := extensionpolicysettings.ValidateFileHashInAllowlist(fp, settings.AllowedScripts, extensionpolicysettings.HashTypeSHA256); err != nil { + // TO DO: Consider whether to delete the blocked file. + return vmextension.NewErrorWithClarificationPtr(errorutil.SystemError, fmt.Errorf("Validation of script '%s' against policy-allowlist failed: %w.", fn, err)) + } + } + } + return nil } diff --git a/main/files_test.go b/main/files_test.go index dd1bb05..32051cb 100644 --- a/main/files_test.go +++ b/main/files_test.go @@ -124,7 +124,7 @@ func Test_downloadAndProcessURL(t *testing.T) { defer os.RemoveAll(tmpDir) cfg := handlerSettings{publicSettings{}, protectedSettings{StorageAccountName: "", StorageAccountKey: ""}} - ewc := downloadAndProcessURL(log.NewContext(log.NewNopLogger()), srv.URL+"/bytes/256", tmpDir, &cfg) + ewc := downloadAndProcessURL(log.NewContext(log.NewNopLogger()), srv.URL+"/bytes/256", tmpDir, &cfg, nil) require.Nil(t, ewc) fp := filepath.Join(tmpDir, "256") diff --git a/main/main.go b/main/main.go index 9c5b839..59f966e 100644 --- a/main/main.go +++ b/main/main.go @@ -36,6 +36,7 @@ func main() { ctx := log.NewContext(log.NewSyncLogger(log.NewLogfmtLogger( os.Stdout))).With("time", log.DefaultTimestamp).With("version", VersionString()) + ctx.Log("BRO IM IN HERE IM IN MAIN") // parse command line arguments cmd := parseCmd(os.Args) ctx = ctx.With("operation", strings.ToLower(cmd.name)) diff --git a/misc/HandlerManifest.json b/misc/HandlerManifest.json index 237c77e..325f5cd 100644 --- a/misc/HandlerManifest.json +++ b/misc/HandlerManifest.json @@ -6,6 +6,7 @@ "updateCommand": "bin/custom-script-shim update", "enableCommand": "bin/custom-script-shim enable", "disableCommand": "bin/custom-script-shim disable", + "supportsPolicy": true, "rebootAfterInstall": false, "reportHeartbeat": false, "updateMode": "UpdateWithInstall" diff --git a/pkg/errorutil/errorclarificationcodes.go b/pkg/errorutil/errorclarificationcodes.go index e823017..6461790 100644 --- a/pkg/errorutil/errorclarificationcodes.go +++ b/pkg/errorutil/errorclarificationcodes.go @@ -49,6 +49,8 @@ const ( Msi_doesNotHaveRightPermissions int = 71 Msi_GenericRetrievalError int = 72 + ExtensionPolicySettings_invalidPolicyFileFormat int = 80 + ExtensionPolicySettings_policyLoadFailed int = 81 // No Error - used as a placeholder value // when representing an "empty" ErrorWithClarification // or when the error can be treated without the clarification