From df395147dd8059b2aca661ed9f263c5aaef830ef Mon Sep 17 00:00:00 2001 From: alsanmsft Date: Thu, 26 Feb 2026 18:09:32 +0000 Subject: [PATCH 1/5] tmp change --- go.mod | 2 +- go.sum | 2 ++ main/cmds.go | 22 ++++++++++++++++++++++ main/cmds_test.go | 5 ++++- main/main.go | 1 + misc/HandlerManifest.json | 1 + pkg/download/save_test.go | 1 + 7 files changed, 32 insertions(+), 2 deletions(-) diff --git a/go.mod b/go.mod index 7475498..97fb948 100644 --- a/go.mod +++ b/go.mod @@ -35,7 +35,7 @@ require ( github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect golang.org/x/crypto v0.45.0 // indirect - golang.org/x/sys v0.15.0 // indirect + golang.org/x/sys v0.38.0 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 914bce2..8658986 100644 --- a/go.sum +++ b/go.sum @@ -97,6 +97,8 @@ golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= +golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= diff --git a/main/cmds.go b/main/cmds.go index 5bd856f..04c64b8 100644 --- a/main/cmds.go +++ b/main/cmds.go @@ -115,6 +115,25 @@ func min(a, b int) int { func enable(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, *vmextension.ErrorWithClarification) { // parse the extension handler settings (not available prior to 'enable') cfg, ewc := parseAndValidateSettings(ctx, h.HandlerEnvironment.ConfigFolder, seqNum) + + // Lourdes: tmp code, delete later. Try to read runtime policy file. if you're able to read it, put it in an out folder and log that you read it. YAY! + policyPath := filepath.Join(h.HandlerEnvironment.ConfigFolder, "waagent_runtime_policy.json") + ctx.Log("I am in ENABLE: ", policyPath) + // content, err := os.ReadFile(policyPath) + // if err != nil { + // ewc.Err = errors.Wrap(ewc.Err, "failed to read policy file") + // return "", ewc + // } + // fo, err := os.Create("lourdes_output.txt") + // defer fo.Close(); + // Write string to file + // _, err = fo.WriteString(string(content)) + // if err != nil { + // ewc.Err = errors.Wrap(ewc.Err, "lourdes: something wrong with repeating the policy file in your output") + // return "", ewc + // } + + if ewc != nil { ewc.Err = errors.Wrap(ewc.Err, "failed to get configuration") return "", ewc @@ -216,6 +235,7 @@ func downloadFiles(ctx *log.Context, dir string, cfg handlerSettings) *vmextensi // runCmd runs the command (extracted from cfg) in the given dir (assumed to exist). func runCmd(ctx log.Logger, dir string, cfg handlerSettings) (ewc *vmextension.ErrorWithClarification) { ctx.Log("event", "executing command", "output", dir) + fmt.Println("lourdes debugging-- inside runCmd") var cmd string var scenario string var scenarioInfo string @@ -243,11 +263,13 @@ func runCmd(ctx log.Logger, dir string, cfg handlerSettings) (ewc *vmextension.E } scenario = fmt.Sprintf("protected-script;%s", scenarioInfo) } + fmt.Println("lourdes debugging-made it through parsing without a seg fault") begin := time.Now() ewc = ExecCmdInDir(cmd, dir) elapsed := time.Now().Sub(begin) isSuccess := ewc == nil + fmt.Println("lourdes debugging-made it thorugh command a seg fault") telemetry("scenario", scenario, isSuccess, elapsed) diff --git a/main/cmds_test.go b/main/cmds_test.go index 8437e6f..b05d7a7 100644 --- a/main/cmds_test.go +++ b/main/cmds_test.go @@ -6,6 +6,7 @@ import ( "os" "path/filepath" "testing" + "fmt" "github.com/Azure/custom-script-extension-linux/pkg/errorutil" "github.com/ahmetalpbalkan/go-httpbin" @@ -79,14 +80,16 @@ func Test_checkAndSaveSeqNum(t *testing.T) { } func Test_runCmd_success(t *testing.T) { + fmt.Println("Lourdes debugging inside of test_runcmd_success") dir, err := ioutil.TempDir("", "") require.Nil(t, err) defer os.RemoveAll(dir) + fmt.Println("lourdes before nil check") require.Nil(t, runCmd(log.NewNopLogger(), dir, handlerSettings{ publicSettings: publicSettings{CommandToExecute: "date"}, }).Err, "command should run successfully") - + fmt.Println("lourdes-- i htink the above statement is casuing a null derefence (bad)") // check stdout stderr files _, err = os.Stat(filepath.Join(dir, "stdout")) require.Nil(t, err, "stdout should exist") diff --git a/main/main.go b/main/main.go index 9c5b839..59f966e 100644 --- a/main/main.go +++ b/main/main.go @@ -36,6 +36,7 @@ func main() { ctx := log.NewContext(log.NewSyncLogger(log.NewLogfmtLogger( os.Stdout))).With("time", log.DefaultTimestamp).With("version", VersionString()) + ctx.Log("BRO IM IN HERE IM IN MAIN") // parse command line arguments cmd := parseCmd(os.Args) ctx = ctx.With("operation", strings.ToLower(cmd.name)) diff --git a/misc/HandlerManifest.json b/misc/HandlerManifest.json index 237c77e..325f5cd 100644 --- a/misc/HandlerManifest.json +++ b/misc/HandlerManifest.json @@ -6,6 +6,7 @@ "updateCommand": "bin/custom-script-shim update", "enableCommand": "bin/custom-script-shim enable", "disableCommand": "bin/custom-script-shim disable", + "supportsPolicy": true, "rebootAfterInstall": false, "reportHeartbeat": false, "updateMode": "UpdateWithInstall" diff --git a/pkg/download/save_test.go b/pkg/download/save_test.go index 3b16de3..bcee1df 100644 --- a/pkg/download/save_test.go +++ b/pkg/download/save_test.go @@ -24,6 +24,7 @@ func TestSaveTo_invalidDir(t *testing.T) { } func TestSave(t *testing.T) { + fmt.Println("lourdes-debugging inside of test save") srv := httptest.NewServer(httpbin.GetMux()) defer srv.Close() From dae574dc358362834c154c22f2988a1ef2c3f7ef Mon Sep 17 00:00:00 2001 From: alsanmsft Date: Tue, 3 Mar 2026 22:50:06 +0000 Subject: [PATCH 2/5] merging --- go.mod | 2 +- go.sum | 2 ++ main/cmds.go | 46 ++++++++++++++++++++++++++++------------------ main/cmds_test.go | 19 +++++++++++++++---- 4 files changed, 46 insertions(+), 23 deletions(-) diff --git a/go.mod b/go.mod index 97fb948..3396983 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.24.0 require ( github.com/Azure/azure-extension-foundation v0.0.0-20250620154556-caff9e3c3c5c - github.com/Azure/azure-extension-platform v0.0.0-20260107210613-2a62cc200c34 + github.com/Azure/azure-extension-platform v0.0.0-20260303193429-96e5f13d68a7 github.com/Azure/azure-sdk-for-go v63.2.0+incompatible github.com/ahmetalpbalkan/go-httpbin v0.0.0-20160706084156-8817b883dae1 github.com/go-kit/kit v0.12.0 diff --git a/go.sum b/go.sum index 8658986..b8f7946 100644 --- a/go.sum +++ b/go.sum @@ -4,6 +4,8 @@ github.com/Azure/azure-extension-platform v0.0.0-20250107200156-aa20f765d49f h1: github.com/Azure/azure-extension-platform v0.0.0-20250107200156-aa20f765d49f/go.mod h1:0458BvQsi5ch6kn+KZtI5m88Z3L9UFXdoY1+6nKdivY= github.com/Azure/azure-extension-platform v0.0.0-20260107210613-2a62cc200c34 h1:7bEC4DJC4w0gx7SBy7M7Q2qi6ckmHcnnlFJzo+X/gi4= github.com/Azure/azure-extension-platform v0.0.0-20260107210613-2a62cc200c34/go.mod h1:0458BvQsi5ch6kn+KZtI5m88Z3L9UFXdoY1+6nKdivY= +github.com/Azure/azure-extension-platform v0.0.0-20260303193429-96e5f13d68a7 h1:sckOZUC/OwKKF/bneOiik8bCf4wiOtEBpyN26IUCBMQ= +github.com/Azure/azure-extension-platform v0.0.0-20260303193429-96e5f13d68a7/go.mod h1:0458BvQsi5ch6kn+KZtI5m88Z3L9UFXdoY1+6nKdivY= github.com/Azure/azure-sdk-for-go v63.2.0+incompatible h1:OIqkK/zTGqVUuzpEvY0B1YSYDRAFC/j+y0w2GovCggI= github.com/Azure/azure-sdk-for-go v63.2.0+incompatible/go.mod h1:9XXNKU+eRnpl9moKnB4QOLf1HestfXbmab5FXxiDBjc= github.com/Azure/go-autorest v14.2.0+incompatible h1:V5VMDjClD3GiElqLWO7mz2MxNAK/vTfRHdAubSIPRgs= diff --git a/main/cmds.go b/main/cmds.go index 04c64b8..150638b 100644 --- a/main/cmds.go +++ b/main/cmds.go @@ -12,6 +12,8 @@ import ( "strconv" "time" + "github.com/Azure/azure-extension-platform/pkg/extensionpolicysettings" + "github.com/Azure/azure-extension-platform/pkg/logging" "github.com/Azure/azure-extension-platform/pkg/utils" vmextension "github.com/Azure/azure-extension-platform/vmextension" "github.com/Azure/custom-script-extension-linux/pkg/errorutil" @@ -39,6 +41,7 @@ const ( fullName = "Microsoft.Azure.Extensions.CustomScript" maxTailLen = 4 * 1024 // length of max stdout/stderr to be transmitted in .status file maxTelemetryTailLen int = 1800 + policyFileName = "waagent_runtime_policy.json" ) var ( @@ -112,33 +115,40 @@ func min(a, b int) int { return b } +type CSEExtensionPolicySettings struct { + RequireSigning bool `json:"requireSigning"` + FileRootCertCA string `json:"fileRootCertCA,omitempty"` // optional field for customer that want to specify a root cert for script signature verification. This is a path to a cert file on the VM that the extension can use to verify script signatures. The customer is responsible for ensuring the cert is there and updated as needed (e.g. if the cert expires). The customer can choose to use this field or not based on their needs. + AllowedScripts []string `json:"allowedScripts"` +} + +func (cseps CSEExtensionPolicySettings) ValidateFormat() error { + if cseps.RequireSigning && len(cseps.FileRootCertCA) == 0 { + return errors.New("invalid policy settings: if RequireSigning is true, fileRootCertCA must be provided") + } + return nil +} + func enable(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, *vmextension.ErrorWithClarification) { // parse the extension handler settings (not available prior to 'enable') cfg, ewc := parseAndValidateSettings(ctx, h.HandlerEnvironment.ConfigFolder, seqNum) - // Lourdes: tmp code, delete later. Try to read runtime policy file. if you're able to read it, put it in an out folder and log that you read it. YAY! - policyPath := filepath.Join(h.HandlerEnvironment.ConfigFolder, "waagent_runtime_policy.json") - ctx.Log("I am in ENABLE: ", policyPath) - // content, err := os.ReadFile(policyPath) - // if err != nil { - // ewc.Err = errors.Wrap(ewc.Err, "failed to read policy file") - // return "", ewc - // } - // fo, err := os.Create("lourdes_output.txt") - // defer fo.Close(); - // Write string to file - // _, err = fo.WriteString(string(content)) - // if err != nil { - // ewc.Err = errors.Wrap(ewc.Err, "lourdes: something wrong with repeating the policy file in your output") - // return "", ewc - // } - - if ewc != nil { ewc.Err = errors.Wrap(ewc.Err, "failed to get configuration") return "", ewc } + // Lourdes: tmp code, delete later. Try to read runtime policy file. if you're able to read it, put it in an out folder and log that you read it. YAY! + policyPath := filepath.Join(h.HandlerEnvironment.ConfigFolder, policyFileName) + + ExtensionPolicyManagerPtr := extensionpolicysettings.NewExtensionPolicySettingsManager[CSEExtensionPolicySettings](policyPath, &logging.ExtensionLogger{}) + err := ExtensionPolicyManagerPtr.LoadExtensionPolicySettings() + if err != nil { + ctx.Log("message", "failed to load extension policy settings", "error", err) + } else { + ctx.Log("message", "successfully loaded extension policy settings", "settings", fmt.Sprintf("%+v", ExtensionPolicyManagerPtr.GetSettings())) + fmt.Println("lourdes debugging-- successfully loaded extension policy settings: \n" + fmt.Sprintf("%+v", ExtensionPolicyManagerPtr.GetSettings())) + } + dir := filepath.Join(dataDir, downloadDir, fmt.Sprintf("%d", seqNum)) if ewc := downloadFiles(ctx, dir, cfg); ewc != nil { ewc.Err = errors.Wrap(ewc.Err, "processing file downloads failed") diff --git a/main/cmds_test.go b/main/cmds_test.go index b05d7a7..92a9910 100644 --- a/main/cmds_test.go +++ b/main/cmds_test.go @@ -6,7 +6,6 @@ import ( "os" "path/filepath" "testing" - "fmt" "github.com/Azure/custom-script-extension-linux/pkg/errorutil" "github.com/ahmetalpbalkan/go-httpbin" @@ -80,16 +79,28 @@ func Test_checkAndSaveSeqNum(t *testing.T) { } func Test_runCmd_success(t *testing.T) { - fmt.Println("Lourdes debugging inside of test_runcmd_success") dir, err := ioutil.TempDir("", "") require.Nil(t, err) defer os.RemoveAll(dir) - fmt.Println("lourdes before nil check") + require.Nil(t, runCmd(log.NewNopLogger(), dir, handlerSettings{ + publicSettings: publicSettings{CommandToExecute: "date"}, + }), "command should run successfully") + // check stdout stderr files + _, err = os.Stat(filepath.Join(dir, "stdout")) + require.Nil(t, err, "stdout should exist") + _, err = os.Stat(filepath.Join(dir, "stderr")) + require.Nil(t, err, "stderr should exist") +} + +func Test_runCmd_success_with_policy(t *testing.T) { + dir, err := ioutil.TempDir("", "") + require.Nil(t, err) + defer os.RemoveAll(dir) + require.Nil(t, runCmd(log.NewNopLogger(), dir, handlerSettings{ publicSettings: publicSettings{CommandToExecute: "date"}, }).Err, "command should run successfully") - fmt.Println("lourdes-- i htink the above statement is casuing a null derefence (bad)") // check stdout stderr files _, err = os.Stat(filepath.Join(dir, "stdout")) require.Nil(t, err, "stdout should exist") From 2db1160940ec5b746365a2fa5390cf89cfdf912b Mon Sep 17 00:00:00 2001 From: alsanmsft Date: Tue, 10 Mar 2026 23:38:09 +0000 Subject: [PATCH 3/5] added allowlist and load policy logic --- go.mod | 2 +- go.sum | 10 + main/cmds.go | 42 ++-- main/cmds_test.go | 306 +++++++++++++++++++++-- main/files.go | 21 +- main/files_test.go | 2 +- pkg/errorutil/errorclarificationcodes.go | 2 + 7 files changed, 353 insertions(+), 32 deletions(-) diff --git a/go.mod b/go.mod index 3396983..04564b7 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.24.0 require ( github.com/Azure/azure-extension-foundation v0.0.0-20250620154556-caff9e3c3c5c - github.com/Azure/azure-extension-platform v0.0.0-20260303193429-96e5f13d68a7 + github.com/Azure/azure-extension-platform v0.0.0-20260305214320-4828fb38d797 github.com/Azure/azure-sdk-for-go v63.2.0+incompatible github.com/ahmetalpbalkan/go-httpbin v0.0.0-20160706084156-8817b883dae1 github.com/go-kit/kit v0.12.0 diff --git a/go.sum b/go.sum index 5427d4e..ed19870 100644 --- a/go.sum +++ b/go.sum @@ -4,6 +4,16 @@ github.com/Azure/azure-extension-platform v0.0.0-20260107210613-2a62cc200c34 h1: github.com/Azure/azure-extension-platform v0.0.0-20260107210613-2a62cc200c34/go.mod h1:0458BvQsi5ch6kn+KZtI5m88Z3L9UFXdoY1+6nKdivY= github.com/Azure/azure-extension-platform v0.0.0-20260303193429-96e5f13d68a7 h1:sckOZUC/OwKKF/bneOiik8bCf4wiOtEBpyN26IUCBMQ= github.com/Azure/azure-extension-platform v0.0.0-20260303193429-96e5f13d68a7/go.mod h1:0458BvQsi5ch6kn+KZtI5m88Z3L9UFXdoY1+6nKdivY= +github.com/Azure/azure-extension-platform v0.0.0-20260304003136-336d3e56fa20 h1:d5WpxK2xd2htsET+EQhZa/WjauDV0RLZef15x+JLShI= +github.com/Azure/azure-extension-platform v0.0.0-20260304003136-336d3e56fa20/go.mod h1:0458BvQsi5ch6kn+KZtI5m88Z3L9UFXdoY1+6nKdivY= +github.com/Azure/azure-extension-platform v0.0.0-20260304003541-160ebdf80e13 h1:t+jiig+xE6PQjBQOASaMZfFeo8Ln5ClUCxB/SFSJ6Dw= +github.com/Azure/azure-extension-platform v0.0.0-20260304003541-160ebdf80e13/go.mod h1:0458BvQsi5ch6kn+KZtI5m88Z3L9UFXdoY1+6nKdivY= +github.com/Azure/azure-extension-platform v0.0.0-20260304004038-0cf99cda38d7 h1:TUZKvOXi2MUEQhLYsTqSIyLqxQy2U08cP9VHSNCF2Bk= +github.com/Azure/azure-extension-platform v0.0.0-20260304004038-0cf99cda38d7/go.mod h1:0458BvQsi5ch6kn+KZtI5m88Z3L9UFXdoY1+6nKdivY= +github.com/Azure/azure-extension-platform v0.0.0-20260304193358-17aecbaff233 h1:XOj8YJeGOdLZiXdSrXbdYFi8dDJnUi7xXlg7CvD26EI= +github.com/Azure/azure-extension-platform v0.0.0-20260304193358-17aecbaff233/go.mod h1:0458BvQsi5ch6kn+KZtI5m88Z3L9UFXdoY1+6nKdivY= +github.com/Azure/azure-extension-platform v0.0.0-20260305214320-4828fb38d797 h1:9YWYQi8rHmaMm1cdwVhS5FAL/ivPDzvOEyxsmHymc4w= +github.com/Azure/azure-extension-platform v0.0.0-20260305214320-4828fb38d797/go.mod h1:0458BvQsi5ch6kn+KZtI5m88Z3L9UFXdoY1+6nKdivY= github.com/Azure/azure-sdk-for-go v63.2.0+incompatible h1:OIqkK/zTGqVUuzpEvY0B1YSYDRAFC/j+y0w2GovCggI= github.com/Azure/azure-sdk-for-go v63.2.0+incompatible/go.mod h1:9XXNKU+eRnpl9moKnB4QOLf1HestfXbmab5FXxiDBjc= github.com/Azure/go-autorest v14.2.0+incompatible h1:V5VMDjClD3GiElqLWO7mz2MxNAK/vTfRHdAubSIPRgs= diff --git a/main/cmds.go b/main/cmds.go index 150638b..7570d53 100644 --- a/main/cmds.go +++ b/main/cmds.go @@ -13,7 +13,6 @@ import ( "time" "github.com/Azure/azure-extension-platform/pkg/extensionpolicysettings" - "github.com/Azure/azure-extension-platform/pkg/logging" "github.com/Azure/azure-extension-platform/pkg/utils" vmextension "github.com/Azure/azure-extension-platform/vmextension" "github.com/Azure/custom-script-extension-linux/pkg/errorutil" @@ -137,20 +136,37 @@ func enable(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, *vmexte return "", ewc } - // Lourdes: tmp code, delete later. Try to read runtime policy file. if you're able to read it, put it in an out folder and log that you read it. YAY! + // TODO: Lourdes--what if it doesn't exist? + // If policy file exists, load the policy. + // If the policy is invalid, we log error and exit. + // If the policy file does not exist, proceed as normal. + var ExtensionPolicyManagerPtr *extensionpolicysettings.ExtensionPolicySettingsManager[CSEExtensionPolicySettings] policyPath := filepath.Join(h.HandlerEnvironment.ConfigFolder, policyFileName) - ExtensionPolicyManagerPtr := extensionpolicysettings.NewExtensionPolicySettingsManager[CSEExtensionPolicySettings](policyPath, &logging.ExtensionLogger{}) - err := ExtensionPolicyManagerPtr.LoadExtensionPolicySettings() - if err != nil { - ctx.Log("message", "failed to load extension policy settings", "error", err) + if _, err := os.Stat(policyPath); err == nil { + ExtensionPolicyManagerPtr, err = extensionpolicysettings.NewExtensionPolicySettingsManager[CSEExtensionPolicySettings](policyPath) + if err != nil { + return "", vmextension.NewErrorWithClarificationPtr(errorutil.ExtensionPolicySettings_policyLoadFailed, errors.Wrap(err, "failed to create extension policy settings manager")) + } + err = ExtensionPolicyManagerPtr.LoadExtensionPolicySettings() + if err != nil { + return "", vmextension.NewErrorWithClarificationPtr(errorutil.ExtensionPolicySettings_policyLoadFailed, errors.Wrap(err, "failed to load extension policy settings")) + } else { + settings, err := ExtensionPolicyManagerPtr.GetSettings() + if err != nil { + return "", vmextension.NewErrorWithClarificationPtr(errorutil.ExtensionPolicySettings_policyLoadFailed, errors.Wrap(err, "failed to get extension policy settings")) + } + ctx.Log("message", "successfully loaded extension policy settings", "settings", fmt.Sprintf("%+v", settings)) + } + } else if os.IsNotExist(err) { + ctx.Log("message", "extension policy settings file does not exist, proceeding with default extension behavior.", "path", policyPath) + ExtensionPolicyManagerPtr = nil } else { - ctx.Log("message", "successfully loaded extension policy settings", "settings", fmt.Sprintf("%+v", ExtensionPolicyManagerPtr.GetSettings())) - fmt.Println("lourdes debugging-- successfully loaded extension policy settings: \n" + fmt.Sprintf("%+v", ExtensionPolicyManagerPtr.GetSettings())) + return "", vmextension.NewErrorWithClarificationPtr(errorutil.ExtensionPolicySettings_policyLoadFailed, errors.Wrap(err, "error while checking for extension policy settings file. Stat failed with an error other than file not existing")) } dir := filepath.Join(dataDir, downloadDir, fmt.Sprintf("%d", seqNum)) - if ewc := downloadFiles(ctx, dir, cfg); ewc != nil { + if ewc := downloadFiles(ctx, dir, cfg, ExtensionPolicyManagerPtr); ewc != nil { ewc.Err = errors.Wrap(ewc.Err, "processing file downloads failed") return "", ewc } @@ -208,7 +224,8 @@ func checkAndSaveSeqNum(ctx log.Logger, seq int, mrseqPath string) (shouldExit b // downloadFiles downloads the files specified in cfg into dir (creates if does // not exist) and takes storage credentials specified in cfg into account. -func downloadFiles(ctx *log.Context, dir string, cfg handlerSettings) *vmextension.ErrorWithClarification { +// If extension policy settings is provided, they are passed on to downloadAndProcessURL for file validation. +func downloadFiles(ctx *log.Context, dir string, cfg handlerSettings, eps *extensionpolicysettings.ExtensionPolicySettingsManager[CSEExtensionPolicySettings]) *vmextension.ErrorWithClarification { // - prepare the output directory for files and the command output // - create the directory if missing ctx.Log("event", "creating output directory", "path", dir) @@ -233,7 +250,7 @@ func downloadFiles(ctx *log.Context, dir string, cfg handlerSettings) *vmextensi for i, f := range cfg.fileUrls() { ctx := ctx.With("file", i) ctx.Log("event", "download start") - if ewc := downloadAndProcessURL(ctx, f, dir, &cfg); ewc != nil { + if ewc := downloadAndProcessURL(ctx, f, dir, &cfg, eps); ewc != nil { ctx.Log("event", "download failed", "error", ewc.Err) return vmextension.NewErrorWithClarificationPtr(ewc.ErrorCode, errors.Wrapf(ewc.Err, "failed to download file[%d]", i)) } @@ -245,7 +262,6 @@ func downloadFiles(ctx *log.Context, dir string, cfg handlerSettings) *vmextensi // runCmd runs the command (extracted from cfg) in the given dir (assumed to exist). func runCmd(ctx log.Logger, dir string, cfg handlerSettings) (ewc *vmextension.ErrorWithClarification) { ctx.Log("event", "executing command", "output", dir) - fmt.Println("lourdes debugging-- inside runCmd") var cmd string var scenario string var scenarioInfo string @@ -273,13 +289,11 @@ func runCmd(ctx log.Logger, dir string, cfg handlerSettings) (ewc *vmextension.E } scenario = fmt.Sprintf("protected-script;%s", scenarioInfo) } - fmt.Println("lourdes debugging-made it through parsing without a seg fault") begin := time.Now() ewc = ExecCmdInDir(cmd, dir) elapsed := time.Now().Sub(begin) isSuccess := ewc == nil - fmt.Println("lourdes debugging-made it thorugh command a seg fault") telemetry("scenario", scenario, isSuccess, elapsed) diff --git a/main/cmds_test.go b/main/cmds_test.go index e5d5425..06498a8 100644 --- a/main/cmds_test.go +++ b/main/cmds_test.go @@ -1,12 +1,15 @@ package main import ( + "fmt" "io/ioutil" + "net/http" "net/http/httptest" "os" "path/filepath" "testing" + "github.com/Azure/azure-extension-platform/pkg/extensionpolicysettings" "github.com/Azure/custom-script-extension-linux/pkg/errorutil" "github.com/ahmetalpbalkan/go-httpbin" "github.com/go-kit/kit/log" @@ -78,29 +81,51 @@ func Test_checkAndSaveSeqNum(t *testing.T) { require.True(t, shouldExit) } -func Test_runCmd_success(t *testing.T) { - dir, err := ioutil.TempDir("", "") - require.Nil(t, err) - defer os.RemoveAll(dir) +const policyTestDir = "./testdata" +const policyTestFile = "extensionPolicySettingsTestConfig.json" +const policyTestPath = policyTestDir + "/" + policyTestFile - require.Nil(t, runCmd(log.NewNopLogger(), dir, handlerSettings{ - publicSettings: publicSettings{CommandToExecute: "date"}, - }), "command should run successfully") +func Test_LoadExtensionPolicySettings_PolicyFileExistsValid(t *testing.T) { + // Set up test parameters + require.Nil(t, setupPolicyDir(policyTestDir)) + defer cleanupPolicyFile(policyTestPath) + // maybe clean up policy test directory too? + require.Nil(t, loadTestPolicy("valid, basic", nil)) + + // Replicate the logic in cmd.go enable() + ExtensionPolicyManagerPtr, err := extensionpolicysettings.NewExtensionPolicySettingsManager[CSEExtensionPolicySettings](policyTestPath) + require.NoError(t, err, "should be able to create extension policy settings manager") + err = ExtensionPolicyManagerPtr.LoadExtensionPolicySettings() + require.NoError(t, err, "should be able to load extension policy settings") + // Check that settings are loaded correctly + require.NoError(t, err) + settings, err := ExtensionPolicyManagerPtr.GetSettings() + require.NoError(t, err, "should be able to get extension policy settings") + require.NotNil(t, settings, "settings should not be nil") + require.Equal(t, false, settings.RequireSigning) + require.Empty(t, settings.AllowedScripts) +} + +func Test_LoadExtensionPolicySettings_PolicyFileMissing(t *testing.T) { + // Replicate the logic in cmd.go enable() + missingPolicyFilePath := filepath.Join(policyTestDir, "missingPolicyFile.json") + ExtensionPolicyManagerPtr, err := extensionpolicysettings.NewExtensionPolicySettingsManager[CSEExtensionPolicySettings](missingPolicyFilePath) + require.NoError(t, err, "should be able to create extension policy settings manager") + + _, err = os.Stat(missingPolicyFilePath) + require.True(t, os.IsNotExist(err), "policy file should not exist") + err = ExtensionPolicyManagerPtr.LoadExtensionPolicySettings() + require.Error(t, err, "should not be able to load extension policy settings") } -func Test_runCmd_success_with_policy(t *testing.T) { +func Test_runCmd_success(t *testing.T) { dir, err := ioutil.TempDir("", "") require.Nil(t, err) defer os.RemoveAll(dir) require.Nil(t, runCmd(log.NewNopLogger(), dir, handlerSettings{ publicSettings: publicSettings{CommandToExecute: "date"}, - }).Err, "command should run successfully") - // check stdout stderr files - _, err = os.Stat(filepath.Join(dir, "stdout")) - require.Nil(t, err, "stdout should exist") - _, err = os.Stat(filepath.Join(dir, "stderr")) - require.Nil(t, err, "stderr should exist") + }), "command should run successfully") } func Test_runCmd_fail(t *testing.T) { @@ -133,11 +158,208 @@ func Test_downloadFiles(t *testing.T) { srv.URL + "/bytes/100", srv.URL + "/bytes/1000", }}, - }) + }, nil) require.Nil(t, ewc) // check the files f := []string{"10", "100", "1000"} + for _, fn := range f { + fp := filepath.Join(dir, fn) + _, err := os.Stat(fp) + data, err := os.ReadFile(fp) + fmt.Println("File Content:") + fmt.Println(string(data)) + + require.Nil(t, err, "%s is missing from download dir", fp) + } +} + +func Test_downloadFiles_goodAllowlist_SHA256(t *testing.T) { + dir, err := ioutil.TempDir("", "") + require.Nil(t, err) + defer os.RemoveAll(dir) + + file1Content := "echo hello" + file2Content := "echo world" + file3Content := "echo !" + + // Create a custom HTTP server with preset content + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // r.URL.Path is our key string to fetch file content below. + // Define preset file contents + files := map[string]string{ + "/file1": file1Content, + "/file2": file2Content, + "/file3": file3Content, + } + + if content, ok := files[r.URL.Path]; ok { // 'content' is value corresponding to the key r.URL.Path. 'ok' is true if content exists. + w.Header().Set("Content-Type", "application/octet-stream") // 'content' is going to be the response body from the test server. + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, content) + } else { + w.WriteHeader(http.StatusNotFound) + } + })) + defer srv.Close() + + // Compute SHA256 hashes (you can use the extensionpolicysettings.ComputeFileHash function) + file1Hash, _ := extensionpolicysettings.ComputeFileHash(file1Content, extensionpolicysettings.HashTypeSHA256) + file2Hash, _ := extensionpolicysettings.ComputeFileHash(file2Content, extensionpolicysettings.HashTypeSHA256) + file3Hash, _ := extensionpolicysettings.ComputeFileHash(file3Content, extensionpolicysettings.HashTypeSHA256) + + // Create a hash list. + al := []string{file1Hash, file2Hash, file3Hash} + + // Write the policy file (GA behavior) + require.Nil(t, loadTestPolicy("valid, allowlist present", al)) + defer cleanupPolicyFile(policyTestPath) + + // Load the policy into manager (enable() behavior) + ExtensionPolicyManagerPtr, err := extensionpolicysettings.NewExtensionPolicySettingsManager[CSEExtensionPolicySettings](policyTestPath) + require.NoError(t, err, "should be able to create extension policy settings manager") + err = ExtensionPolicyManagerPtr.LoadExtensionPolicySettings() + require.NoError(t, err, "should be able to load extension policy settings") + + ewc := downloadFiles(log.NewContext(log.NewNopLogger()), + dir, + handlerSettings{ + publicSettings: publicSettings{ + FileURLs: []string{ + srv.URL + "/file1", + srv.URL + "/file2", + srv.URL + "/file3", + }}, + }, ExtensionPolicyManagerPtr) + require.Nil(t, ewc) + + // check the files. All files should have passed. + f := []string{"file1", "file2", "file3"} + for _, fn := range f { + fp := filepath.Join(dir, fn) + _, err := os.Stat(fp) + require.Nil(t, err, "%s is missing from download dir", fp) + } +} + +func Test_downloadFiles_badAllowlist(t *testing.T) { + dir, err := ioutil.TempDir("", "") + require.Nil(t, err) + defer os.RemoveAll(dir) + + // Generate hashes of the preset content for the allowlist + file1Content := "echo hello" + file2Content := "echo world" + file3Content := "echo !" + file4Content := "echo bad" + file5Content := "echo also bad" + + // Create a custom HTTP server with preset content + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Define preset file contents + files := map[string]string{ + "/file1": file1Content, + "/file2": file2Content, + "/file3": file3Content, + "/file4": file4Content, + "/file5": file5Content, + } + + if content, ok := files[r.URL.Path]; ok { + w.Header().Set("Content-Type", "application/octet-stream") + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, content) + } else { + w.WriteHeader(http.StatusNotFound) + } + })) + defer srv.Close() + + // Compute SHA256 hashes (you can use the extensionpolicysettings.ComputeFileHash function) + file1Hash, _ := extensionpolicysettings.ComputeFileHash(file1Content, extensionpolicysettings.HashTypeSHA256) + file2Hash, _ := extensionpolicysettings.ComputeFileHash(file2Content, extensionpolicysettings.HashTypeSHA256) + file3Hash, _ := extensionpolicysettings.ComputeFileHash(file3Content, extensionpolicysettings.HashTypeSHA256) + + // Create a hash list. + al := []string{file1Hash, file2Hash, file3Hash} + + // Write the policy file (GA behavior) + require.Nil(t, loadTestPolicy("valid, allowlist present", al)) + defer cleanupPolicyFile(policyTestPath) + + // Load the policy into manager (enable() behavior) + ExtensionPolicyManagerPtr, err := extensionpolicysettings.NewExtensionPolicySettingsManager[CSEExtensionPolicySettings](policyTestPath) + require.NoError(t, err, "should be able to create extension policy settings manager") + err = ExtensionPolicyManagerPtr.LoadExtensionPolicySettings() + require.NoError(t, err, "should be able to load extension policy settings") + + ewc := downloadFiles(log.NewContext(log.NewNopLogger()), + dir, + handlerSettings{ + publicSettings: publicSettings{ + FileURLs: []string{ + srv.URL + "/file1", + srv.URL + "/file2", + srv.URL + "/file3", + srv.URL + "/file4", // this file is not in the allowlist. Extension should exit gracefully. + }}, + }, ExtensionPolicyManagerPtr) + require.NotNil(t, ewc) + require.Contains(t, ewc.Err.Error(), "Validation of script 'file4' against policy-allowlist failed", "error should indicate that file4 failed validation") + + // Check the files. At most file1, file2, and file3 should be present. file4 and file5 should not both be present, + // but one of them will be present because they are validated after being downloaded. + // The extension should exit gracefully. If downloaded out of order, it's possible that not file1, file2, and file3 are all + // present because the extension will exit immediately after downloading the first file that is not in the allowlist. + + // The check below assumes the files are downloaded in order. + f := []string{"file1", "file2", "file3", "file4", "file5"} + for _, fn := range f { + fp := filepath.Join(dir, fn) + _, err := os.Stat(fp) + if fn == "file5" { + require.NotNil(t, err, "%s should not be downloaded because it's not in the allowlist", fp) + continue + } else { + require.Nil(t, err, "%s is missing from download dir", fp) + } + } +} + +func Test_downloadFiles_emptyAllowlist(t *testing.T) { + dir, err := ioutil.TempDir("", "") + require.Nil(t, err) + defer os.RemoveAll(dir) + + srv := httptest.NewServer(httpbin.GetMux()) + defer srv.Close() + + // Create an EMPTY list. + al := []string{} + + // Write the policy file (GA behavior) + require.Nil(t, loadTestPolicy("valid, allowlist present", al)) + defer cleanupPolicyFile(policyTestPath) + + // Load the policy into manager (enable() behavior) + ExtensionPolicyManagerPtr, err := extensionpolicysettings.NewExtensionPolicySettingsManager[CSEExtensionPolicySettings](policyTestPath) + require.NoError(t, err, "should be able to create extension policy settings manager") + err = ExtensionPolicyManagerPtr.LoadExtensionPolicySettings() + require.NoError(t, err, "should be able to load extension policy settings") + + ewc := downloadFiles(log.NewContext(log.NewNopLogger()), + dir, + handlerSettings{ + publicSettings: publicSettings{ + FileURLs: []string{ + srv.URL + "/bytes/10", + srv.URL + "/bytes/100", + srv.URL + "/bytes/1000", + }}, + }, ExtensionPolicyManagerPtr) + require.Nil(t, ewc) + + // Check the files. All files should have passed. + f := []string{"10", "100", "1000"} for _, fn := range f { fp := filepath.Join(dir, fn) _, err := os.Stat(fp) @@ -162,3 +384,57 @@ func Test_decodeScriptGzip(t *testing.T) { require.Equal(t, info, "32;3;gzip=1") require.Equal(t, s, "ls\n") } + +// Helper Methods +func writeToFile(filePath, content string) error { + err := os.WriteFile(filePath, []byte(content), 0644) + return err +} + +func cleanupPolicyFile(path string) { + // Do not remove missingPolicyFilePath as it simulates a missing file + if _, err := os.Stat(path); err == nil { + os.Remove(path) + } +} + +func setupPolicyDir(path string) error { + if _, err := os.Stat(path); os.IsNotExist(err) { + err = os.MkdirAll(path, 0750) + return err + } + return nil +} + +func loadTestPolicy(scenario string, list []string) error { + var validPolicyContent string + var allowlistStr string + if list != nil { + allowlistStr = "[" + for i, item := range list { + allowlistStr += `"` + item + `"` + if i < len(list)-1 { + allowlistStr += "," + } + } + allowlistStr += "]" + } + + switch scenario { + case "valid, basic": + validPolicyContent = `{ + "requireSigning": false, + "allowedScripts": [] + }` + case "valid, allowlist present": + // Convert list to JSON array string + + validPolicyContent = `{ + "requireSigning": false, + "allowedScripts": ` + allowlistStr + ` + }` + default: + validPolicyContent = `{}` + } + return writeToFile(filepath.Join(policyTestDir, policyTestFile), validPolicyContent) +} diff --git a/main/files.go b/main/files.go index dc7ce6e..6cd8f72 100644 --- a/main/files.go +++ b/main/files.go @@ -9,6 +9,7 @@ import ( "os" + "github.com/Azure/azure-extension-platform/pkg/extensionpolicysettings" "github.com/Azure/azure-extension-platform/vmextension" "github.com/Azure/custom-script-extension-linux/pkg/blobutil" "github.com/Azure/custom-script-extension-linux/pkg/download" @@ -23,7 +24,8 @@ import ( // downloadAndProcessURL downloads using the specified downloader and saves it to the // specified existing directory, which must be the path to the saved file. Then // it post-processes file based on heuristics. -func downloadAndProcessURL(ctx *log.Context, url, downloadDir string, cfg *handlerSettings) *vmextension.ErrorWithClarification { +// If extension policy settings manager is provided, the downloaded file will be validated against the policy. +func downloadAndProcessURL(ctx *log.Context, url, downloadDir string, cfg *handlerSettings, eps *extensionpolicysettings.ExtensionPolicySettingsManager[CSEExtensionPolicySettings]) *vmextension.ErrorWithClarification { fn, err := urlToFileName(url) if err != nil { return vmextension.NewErrorWithClarificationPtr(errorutil.CustomerInput_invalidFileUris, err) @@ -52,6 +54,23 @@ func downloadAndProcessURL(ctx *log.Context, url, downloadDir string, cfg *handl return vmextension.NewErrorWithClarificationPtr(errorutil.SystemError, errors.Wrapf(err, "failed to post-process '%s'", fn)) } + if eps != nil { + settings, err := eps.GetSettings() + if err != nil { + return vmextension.NewErrorWithClarificationPtr(errorutil.SystemError, fmt.Errorf("failed to get extension policy settings: %w", err)) + } + if settings == nil { + return vmextension.NewErrorWithClarificationPtr(errorutil.SystemError, fmt.Errorf("extension policy settings manager initialized, but settings not properly loaded.")) + } + + if len(settings.AllowedScripts) > 0 { + if err := extensionpolicysettings.ValidateFileHashInAllowlist(fp, settings.AllowedScripts, extensionpolicysettings.HashTypeSHA256); err != nil { + // TO DO: Consider whether to delete the blocked file. + return vmextension.NewErrorWithClarificationPtr(errorutil.SystemError, fmt.Errorf("Validation of script '%s' against policy-allowlist failed: %w.", fn, err)) + } + } + } + return nil } diff --git a/main/files_test.go b/main/files_test.go index dd1bb05..32051cb 100644 --- a/main/files_test.go +++ b/main/files_test.go @@ -124,7 +124,7 @@ func Test_downloadAndProcessURL(t *testing.T) { defer os.RemoveAll(tmpDir) cfg := handlerSettings{publicSettings{}, protectedSettings{StorageAccountName: "", StorageAccountKey: ""}} - ewc := downloadAndProcessURL(log.NewContext(log.NewNopLogger()), srv.URL+"/bytes/256", tmpDir, &cfg) + ewc := downloadAndProcessURL(log.NewContext(log.NewNopLogger()), srv.URL+"/bytes/256", tmpDir, &cfg, nil) require.Nil(t, ewc) fp := filepath.Join(tmpDir, "256") diff --git a/pkg/errorutil/errorclarificationcodes.go b/pkg/errorutil/errorclarificationcodes.go index e823017..6461790 100644 --- a/pkg/errorutil/errorclarificationcodes.go +++ b/pkg/errorutil/errorclarificationcodes.go @@ -49,6 +49,8 @@ const ( Msi_doesNotHaveRightPermissions int = 71 Msi_GenericRetrievalError int = 72 + ExtensionPolicySettings_invalidPolicyFileFormat int = 80 + ExtensionPolicySettings_policyLoadFailed int = 81 // No Error - used as a placeholder value // when representing an "empty" ErrorWithClarification // or when the error can be treated without the clarification From 222915a58f6fb4d126873c81d5eb958d9f287a67 Mon Sep 17 00:00:00 2001 From: alsanmsft Date: Tue, 10 Mar 2026 23:45:51 +0000 Subject: [PATCH 4/5] del personal comments --- main/cmds.go | 1 - pkg/download/save_test.go | 1 - 2 files changed, 2 deletions(-) diff --git a/main/cmds.go b/main/cmds.go index 7570d53..72c0b98 100644 --- a/main/cmds.go +++ b/main/cmds.go @@ -136,7 +136,6 @@ func enable(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, *vmexte return "", ewc } - // TODO: Lourdes--what if it doesn't exist? // If policy file exists, load the policy. // If the policy is invalid, we log error and exit. // If the policy file does not exist, proceed as normal. diff --git a/pkg/download/save_test.go b/pkg/download/save_test.go index ea9db67..c81542c 100644 --- a/pkg/download/save_test.go +++ b/pkg/download/save_test.go @@ -24,7 +24,6 @@ func TestSaveTo_invalidDir(t *testing.T) { } func TestSave(t *testing.T) { - fmt.Println("lourdes-debugging inside of test save") srv := httptest.NewServer(httpbin.GetMux()) defer srv.Close() From d5b6189778b7c81779fabee58ab18be95780e83c Mon Sep 17 00:00:00 2001 From: alsanmsft Date: Thu, 12 Mar 2026 21:39:45 +0000 Subject: [PATCH 5/5] updated azure-extension-platform reference + added UTs --- go.mod | 2 +- go.sum | 2 + main/cmds_test.go | 98 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 101 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 04564b7..369df37 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.24.0 require ( github.com/Azure/azure-extension-foundation v0.0.0-20250620154556-caff9e3c3c5c - github.com/Azure/azure-extension-platform v0.0.0-20260305214320-4828fb38d797 + github.com/Azure/azure-extension-platform v0.0.0-20260312212104-6da8a253549a github.com/Azure/azure-sdk-for-go v63.2.0+incompatible github.com/ahmetalpbalkan/go-httpbin v0.0.0-20160706084156-8817b883dae1 github.com/go-kit/kit v0.12.0 diff --git a/go.sum b/go.sum index ed19870..7443846 100644 --- a/go.sum +++ b/go.sum @@ -14,6 +14,8 @@ github.com/Azure/azure-extension-platform v0.0.0-20260304193358-17aecbaff233 h1: github.com/Azure/azure-extension-platform v0.0.0-20260304193358-17aecbaff233/go.mod h1:0458BvQsi5ch6kn+KZtI5m88Z3L9UFXdoY1+6nKdivY= github.com/Azure/azure-extension-platform v0.0.0-20260305214320-4828fb38d797 h1:9YWYQi8rHmaMm1cdwVhS5FAL/ivPDzvOEyxsmHymc4w= github.com/Azure/azure-extension-platform v0.0.0-20260305214320-4828fb38d797/go.mod h1:0458BvQsi5ch6kn+KZtI5m88Z3L9UFXdoY1+6nKdivY= +github.com/Azure/azure-extension-platform v0.0.0-20260312212104-6da8a253549a h1:VbxYR5y5uVJaNKQBzDnFBKoHqEjBmix+27/mPvOTqb4= +github.com/Azure/azure-extension-platform v0.0.0-20260312212104-6da8a253549a/go.mod h1:0458BvQsi5ch6kn+KZtI5m88Z3L9UFXdoY1+6nKdivY= github.com/Azure/azure-sdk-for-go v63.2.0+incompatible h1:OIqkK/zTGqVUuzpEvY0B1YSYDRAFC/j+y0w2GovCggI= github.com/Azure/azure-sdk-for-go v63.2.0+incompatible/go.mod h1:9XXNKU+eRnpl9moKnB4QOLf1HestfXbmab5FXxiDBjc= github.com/Azure/go-autorest v14.2.0+incompatible h1:V5VMDjClD3GiElqLWO7mz2MxNAK/vTfRHdAubSIPRgs= diff --git a/main/cmds_test.go b/main/cmds_test.go index 06498a8..25d3adc 100644 --- a/main/cmds_test.go +++ b/main/cmds_test.go @@ -118,6 +118,40 @@ func Test_LoadExtensionPolicySettings_PolicyFileMissing(t *testing.T) { require.Error(t, err, "should not be able to load extension policy settings") } +func Test_LoadExtensionPolicySettings_InvalidJSON(t *testing.T) { + require.NoError(t, setupPolicyDir(policyTestDir)) + defer cleanupPolicyFile(policyTestPath) + + invalidPolicyContent := `{ + "requireSigning": false, + "allowedScripts": [} + }` + require.NoError(t, writeToFile(policyTestPath, invalidPolicyContent)) + + ExtensionPolicyManagerPtr, err := extensionpolicysettings.NewExtensionPolicySettingsManager[CSEExtensionPolicySettings](policyTestPath) + require.NoError(t, err, "should be able to create extension policy settings manager") + + err = ExtensionPolicyManagerPtr.LoadExtensionPolicySettings() + require.Error(t, err, "invalid JSON policy should fail to load") // lourdes: be specific about the error message. +} + +func Test_LoadExtensionPolicySettings_InvalidFormat_RequireSigningWithoutRootCA(t *testing.T) { + require.NoError(t, setupPolicyDir(policyTestDir)) + defer cleanupPolicyFile(policyTestPath) + + invalidPolicyContent := `{ + "requireSigning": true, + "allowedScripts": [] + }` + require.NoError(t, writeToFile(policyTestPath, invalidPolicyContent)) + + ExtensionPolicyManagerPtr, err := extensionpolicysettings.NewExtensionPolicySettingsManager[CSEExtensionPolicySettings](policyTestPath) + require.NoError(t, err, "should be able to create extension policy settings manager") + + err = ExtensionPolicyManagerPtr.LoadExtensionPolicySettings() + require.Error(t, err, "policy requiring signing without fileRootCertCA should fail validation") +} + func Test_runCmd_success(t *testing.T) { dir, err := ioutil.TempDir("", "") require.Nil(t, err) @@ -174,6 +208,70 @@ func Test_downloadFiles(t *testing.T) { } } +func Test_downloadFiles_allowlistStopsOnFirstDisallowedFile(t *testing.T) { + dir, err := ioutil.TempDir("", "") + require.NoError(t, err) + defer os.RemoveAll(dir) + + require.NoError(t, setupPolicyDir(policyTestDir)) + defer cleanupPolicyFile(policyTestPath) + + good1Content := "echo good1" + good2Content := "echo good2" + badContent := "echo bad" + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + files := map[string]string{ + "/bad": badContent, + "/good1": good1Content, + "/good2": good2Content, + } + if content, ok := files[r.URL.Path]; ok { + w.Header().Set("Content-Type", "application/octet-stream") + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, content) + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer srv.Close() + + good1Hash, _ := extensionpolicysettings.ComputeFileHash(good1Content, extensionpolicysettings.HashTypeSHA256) + good2Hash, _ := extensionpolicysettings.ComputeFileHash(good2Content, extensionpolicysettings.HashTypeSHA256) + require.NoError(t, loadTestPolicy("valid, allowlist present", []string{good1Hash, good2Hash})) + + ExtensionPolicyManagerPtr, err := extensionpolicysettings.NewExtensionPolicySettingsManager[CSEExtensionPolicySettings](policyTestPath) + require.NoError(t, err) + require.NoError(t, ExtensionPolicyManagerPtr.LoadExtensionPolicySettings()) + + ewc := downloadFiles(log.NewContext(log.NewNopLogger()), + dir, + handlerSettings{ + publicSettings: publicSettings{ + FileURLs: []string{ + srv.URL + "/bad", // first file is disallowed + srv.URL + "/good1", // should not be attempted + srv.URL + "/good2", // should not be attempted + }, + }, + }, + ExtensionPolicyManagerPtr, + ) + + require.NotNil(t, ewc, "download should fail on first disallowed file") + require.Contains(t, ewc.Err.Error(), "bad", "error should identify disallowed script") + + // "bad" is downloaded before validation, so it may exist. + _, err = os.Stat(filepath.Join(dir, "bad")) + require.NoError(t, err, "first (disallowed) file should have been downloaded before validation failure") + + _, err = os.Stat(filepath.Join(dir, "good1")) + require.True(t, os.IsNotExist(err), "good1 should not be downloaded after first failure") + + _, err = os.Stat(filepath.Join(dir, "good2")) + require.True(t, os.IsNotExist(err), "good2 should not be downloaded after first failure") +} + func Test_downloadFiles_goodAllowlist_SHA256(t *testing.T) { dir, err := ioutil.TempDir("", "") require.Nil(t, err)