diff --git a/go.mod b/go.mod index 6bfa287..7475498 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.24.0 require ( github.com/Azure/azure-extension-foundation v0.0.0-20250620154556-caff9e3c3c5c - github.com/Azure/azure-extension-platform v0.0.0-20250107200156-aa20f765d49f + github.com/Azure/azure-extension-platform v0.0.0-20260107210613-2a62cc200c34 github.com/Azure/azure-sdk-for-go v63.2.0+incompatible github.com/ahmetalpbalkan/go-httpbin v0.0.0-20160706084156-8817b883dae1 github.com/go-kit/kit v0.12.0 @@ -35,6 +35,7 @@ require ( github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect golang.org/x/crypto v0.45.0 // indirect + golang.org/x/sys v0.15.0 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index ae334f2..914bce2 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,8 @@ github.com/Azure/azure-extension-foundation v0.0.0-20250620154556-caff9e3c3c5c h github.com/Azure/azure-extension-foundation v0.0.0-20250620154556-caff9e3c3c5c/go.mod h1:sNC6lMTUkXwjrQ+nttr6GXhDfvSGT7t3UDq30BEYzu8= github.com/Azure/azure-extension-platform v0.0.0-20250107200156-aa20f765d49f h1:ddsUz/suc9txCMz/xWOslqNMvzhbWFMTflUrbcMNoSw= github.com/Azure/azure-extension-platform v0.0.0-20250107200156-aa20f765d49f/go.mod h1:0458BvQsi5ch6kn+KZtI5m88Z3L9UFXdoY1+6nKdivY= +github.com/Azure/azure-extension-platform v0.0.0-20260107210613-2a62cc200c34 h1:7bEC4DJC4w0gx7SBy7M7Q2qi6ckmHcnnlFJzo+X/gi4= +github.com/Azure/azure-extension-platform v0.0.0-20260107210613-2a62cc200c34/go.mod h1:0458BvQsi5ch6kn+KZtI5m88Z3L9UFXdoY1+6nKdivY= github.com/Azure/azure-sdk-for-go v63.2.0+incompatible h1:OIqkK/zTGqVUuzpEvY0B1YSYDRAFC/j+y0w2GovCggI= github.com/Azure/azure-sdk-for-go v63.2.0+incompatible/go.mod h1:9XXNKU+eRnpl9moKnB4QOLf1HestfXbmab5FXxiDBjc= github.com/Azure/go-autorest v14.2.0+incompatible h1:V5VMDjClD3GiElqLWO7mz2MxNAK/vTfRHdAubSIPRgs= @@ -93,6 +95,7 @@ golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= diff --git a/main/cmds.go b/main/cmds.go index 354f815..5bd856f 100644 --- a/main/cmds.go +++ b/main/cmds.go @@ -12,7 +12,9 @@ import ( "strconv" "time" - utils "github.com/Azure/azure-extension-platform/pkg/utils" + "github.com/Azure/azure-extension-platform/pkg/utils" + vmextension "github.com/Azure/azure-extension-platform/vmextension" + "github.com/Azure/custom-script-extension-linux/pkg/errorutil" "github.com/Azure/custom-script-extension-linux/pkg/seqnum" "github.com/go-kit/kit/log" "github.com/pkg/errors" @@ -22,7 +24,7 @@ const ( maxScriptSize = 256 * 1024 ) -type cmdFunc func(ctx *log.Context, hEnv HandlerEnvironment, seqNum int) (msg string, err error) +type cmdFunc func(ctx *log.Context, hEnv HandlerEnvironment, seqNum int) (msg string, ewc *vmextension.ErrorWithClarification) type preFunc func(ctx *log.Context, hEnv HandlerEnvironment, seqNum int) error type cmd struct { @@ -55,14 +57,14 @@ var ( } ) -func noop(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, error) { +func noop(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, *vmextension.ErrorWithClarification) { ctx.Log("event", "noop") return "", nil } -func install(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, error) { +func install(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, *vmextension.ErrorWithClarification) { if err := os.MkdirAll(dataDir, 0755); err != nil { - return "", errors.Wrap(err, "failed to create data dir") + return "", vmextension.NewErrorWithClarificationPtr(errorutil.SystemError, errors.Wrap(err, "failed to create data dir")) } // If the file mrseq does not exists it is for two possible reasons. @@ -77,12 +79,12 @@ func install(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, error) return "", nil } -func uninstall(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, error) { +func uninstall(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, *vmextension.ErrorWithClarification) { { // a new context scope with path ctx = ctx.With("path", dataDir) ctx.Log("event", "removing data dir", "path", dataDir) if err := os.RemoveAll(dataDir); err != nil { - return "", errors.Wrap(err, "failed to delete data directory") + return "", vmextension.NewErrorWithClarificationPtr(errorutil.Os_FailedToDeleteDataDir, errors.Wrap(err, "failed to delete data directory")) } ctx.Log("event", "removed data dir") } @@ -110,16 +112,18 @@ func min(a, b int) int { return b } -func enable(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, error) { +func enable(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, *vmextension.ErrorWithClarification) { // parse the extension handler settings (not available prior to 'enable') - cfg, err := parseAndValidateSettings(ctx, h.HandlerEnvironment.ConfigFolder, seqNum) - if err != nil { - return "", errors.Wrap(err, "failed to get configuration") + cfg, ewc := parseAndValidateSettings(ctx, h.HandlerEnvironment.ConfigFolder, seqNum) + if ewc != nil { + ewc.Err = errors.Wrap(ewc.Err, "failed to get configuration") + return "", ewc } dir := filepath.Join(dataDir, downloadDir, fmt.Sprintf("%d", seqNum)) - if err := downloadFiles(ctx, dir, cfg); err != nil { - return "", errors.Wrap(err, "processing file downloads failed") + if ewc := downloadFiles(ctx, dir, cfg); ewc != nil { + ewc.Err = errors.Wrap(ewc.Err, "processing file downloads failed") + return "", ewc } // execute the command, save its error @@ -175,12 +179,12 @@ func checkAndSaveSeqNum(ctx log.Logger, seq int, mrseqPath string) (shouldExit b // downloadFiles downloads the files specified in cfg into dir (creates if does // not exist) and takes storage credentials specified in cfg into account. -func downloadFiles(ctx *log.Context, dir string, cfg handlerSettings) error { +func downloadFiles(ctx *log.Context, dir string, cfg handlerSettings) *vmextension.ErrorWithClarification { // - prepare the output directory for files and the command output // - create the directory if missing ctx.Log("event", "creating output directory", "path", dir) if err := os.MkdirAll(dir, 0700); err != nil { - return errors.Wrap(err, "failed to prepare output directory") + return vmextension.NewErrorWithClarificationPtr(errorutil.FileDownload_unableToCreateDownloadDirectory, errors.Wrap(err, "failed to prepare output directory")) } ctx.Log("event", "created output directory") @@ -200,9 +204,9 @@ func downloadFiles(ctx *log.Context, dir string, cfg handlerSettings) error { for i, f := range cfg.fileUrls() { ctx := ctx.With("file", i) ctx.Log("event", "download start") - if err := downloadAndProcessURL(ctx, f, dir, &cfg); err != nil { - ctx.Log("event", "download failed", "error", err) - return errors.Wrapf(err, "failed to download file[%d]", i) + if ewc := downloadAndProcessURL(ctx, f, dir, &cfg); ewc != nil { + ctx.Log("event", "download failed", "error", ewc.Err) + return vmextension.NewErrorWithClarificationPtr(ewc.ErrorCode, errors.Wrapf(ewc.Err, "failed to download file[%d]", i)) } ctx.Log("event", "download complete", "output", dir) } @@ -210,11 +214,12 @@ func downloadFiles(ctx *log.Context, dir string, cfg handlerSettings) error { } // runCmd runs the command (extracted from cfg) in the given dir (assumed to exist). -func runCmd(ctx log.Logger, dir string, cfg handlerSettings) (err error) { +func runCmd(ctx log.Logger, dir string, cfg handlerSettings) (ewc *vmextension.ErrorWithClarification) { ctx.Log("event", "executing command", "output", dir) var cmd string var scenario string var scenarioInfo string + var err error // So many ways to execute a command! if cfg.publicSettings.CommandToExecute != "" { @@ -228,27 +233,27 @@ func runCmd(ctx log.Logger, dir string, cfg handlerSettings) (err error) { } else if cfg.publicSettings.Script != "" { ctx.Log("event", "executing public script", "output", dir) if cmd, scenarioInfo, err = writeTempScript(cfg.publicSettings.Script, dir, cfg.publicSettings.SkipDos2Unix); err != nil { - return + return nil } scenario = fmt.Sprintf("public-script;%s", scenarioInfo) } else if cfg.protectedSettings.Script != "" { ctx.Log("event", "executing protected script", "output", dir) if cmd, scenarioInfo, err = writeTempScript(cfg.protectedSettings.Script, dir, cfg.publicSettings.SkipDos2Unix); err != nil { - return + return nil } scenario = fmt.Sprintf("protected-script;%s", scenarioInfo) } begin := time.Now() - err = ExecCmdInDir(cmd, dir) + ewc = ExecCmdInDir(cmd, dir) elapsed := time.Now().Sub(begin) - isSuccess := err == nil + isSuccess := ewc == nil telemetry("scenario", scenario, isSuccess, elapsed) - if err != nil { + if ewc != nil { ctx.Log("event", "failed to execute command", "error", err, "output", dir) - return errors.Wrap(err, "failed to execute command") + return vmextension.NewErrorWithClarificationPtr(ewc.ErrorCode, errors.Wrap(ewc.Err, "failed to execute command")) } ctx.Log("event", "executed command", "output", dir) return nil diff --git a/main/cmds_test.go b/main/cmds_test.go index d6294df..8437e6f 100644 --- a/main/cmds_test.go +++ b/main/cmds_test.go @@ -7,6 +7,7 @@ import ( "path/filepath" "testing" + "github.com/Azure/custom-script-extension-linux/pkg/errorutil" "github.com/ahmetalpbalkan/go-httpbin" "github.com/go-kit/kit/log" "github.com/stretchr/testify/require" @@ -84,7 +85,7 @@ func Test_runCmd_success(t *testing.T) { require.Nil(t, runCmd(log.NewNopLogger(), dir, handlerSettings{ publicSettings: publicSettings{CommandToExecute: "date"}, - }), "command should run successfully") + }).Err, "command should run successfully") // check stdout stderr files _, err = os.Stat(filepath.Join(dir, "stdout")) @@ -98,11 +99,12 @@ func Test_runCmd_fail(t *testing.T) { require.Nil(t, err) defer os.RemoveAll(dir) - err = runCmd(log.NewNopLogger(), dir, handlerSettings{ + ewc := runCmd(log.NewNopLogger(), dir, handlerSettings{ publicSettings: publicSettings{CommandToExecute: "non-existing-cmd"}, }) - require.NotNil(t, err, "command terminated with exit status") - require.Contains(t, err.Error(), "failed to execute command") + require.Equal(t, errorutil.CommandExecution_failureExitCode, ewc.ErrorCode) + require.NotNil(t, ewc.Err, "command terminated with exit status") + require.Contains(t, ewc.Err.Error(), "failed to execute command") } func Test_downloadFiles(t *testing.T) { @@ -113,7 +115,7 @@ func Test_downloadFiles(t *testing.T) { srv := httptest.NewServer(httpbin.GetMux()) defer srv.Close() - err = downloadFiles(log.NewContext(log.NewNopLogger()), + ewc := downloadFiles(log.NewContext(log.NewNopLogger()), dir, handlerSettings{ publicSettings: publicSettings{ @@ -123,7 +125,7 @@ func Test_downloadFiles(t *testing.T) { srv.URL + "/bytes/1000", }}, }) - require.Nil(t, err) + require.Nil(t, ewc) // check the files f := []string{"10", "100", "1000"} diff --git a/main/exec.go b/main/exec.go index ed367ca..d9328df 100644 --- a/main/exec.go +++ b/main/exec.go @@ -8,6 +8,8 @@ import ( "path/filepath" "syscall" + vmextension "github.com/Azure/azure-extension-platform/vmextension" + errorutil "github.com/Azure/custom-script-extension-linux/pkg/errorutil" "github.com/pkg/errors" ) @@ -16,7 +18,7 @@ import ( // // On error, an exit code may be returned if it is an exit code error. // Given stdout and stderr will be closed upon returning. -func Exec(cmd, workdir string, stdout, stderr io.WriteCloser) (int, error) { +func Exec(cmd, workdir string, stdout, stderr io.WriteCloser) (int, *vmextension.ErrorWithClarification) { defer stdout.Close() defer stderr.Close() @@ -30,10 +32,14 @@ func Exec(cmd, workdir string, stdout, stderr io.WriteCloser) (int, error) { if ok { if status, ok := exitErr.Sys().(syscall.WaitStatus); ok { code := status.ExitStatus() - return code, fmt.Errorf("command terminated with exit status=%d", code) + return code, vmextension.NewErrorWithClarificationPtr(errorutil.CommandExecution_failureExitCode, fmt.Errorf("command terminated with exit status=%d", code)) } } - return 0, errors.Wrapf(err, "failed to execute command") + if err == nil { + return 0, nil + } + + return 0, vmextension.NewErrorWithClarificationPtr(errorutil.CommandExecution_failedUnknownError, errors.Wrapf(err, "failed to execute command")) } // ExecCmdInDir executes the given command in given directory and saves output @@ -42,20 +48,20 @@ func Exec(cmd, workdir string, stdout, stderr io.WriteCloser) (int, error) { // // Ideally, we execute commands only once per sequence number in custom-script-extension, // and save their output under /var/lib/waagent//download//*. -func ExecCmdInDir(cmd, workdir string) error { +func ExecCmdInDir(cmd, workdir string) *vmextension.ErrorWithClarification { outFn, errFn := logPaths(workdir) outF, err := os.OpenFile(outFn, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600) if err != nil { - return errors.Wrapf(err, "failed to open stdout file") + return vmextension.NewErrorWithClarificationPtr(errorutil.Os_FailedToOpenStdOut, errors.Wrapf(err, "failed to open stdout file")) } errF, err := os.OpenFile(errFn, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600) if err != nil { - return errors.Wrapf(err, "failed to open stderr file") + return vmextension.NewErrorWithClarificationPtr(errorutil.Os_FailedToOpenStdErr, errors.Wrapf(err, "failed to open stderr file")) } - _, err = Exec(cmd, workdir, outF, errF) - return err + _, ewc := Exec(cmd, workdir, outF, errF) + return ewc } // logPaths returns stdout and stderr file paths for the specified output diff --git a/main/exec_test.go b/main/exec_test.go index e01b34c..5eabc88 100644 --- a/main/exec_test.go +++ b/main/exec_test.go @@ -8,13 +8,14 @@ import ( "path/filepath" "testing" + "github.com/Azure/custom-script-extension-linux/pkg/errorutil" "github.com/stretchr/testify/require" ) func TestExec_success(t *testing.T) { v := new(mockFile) ec, err := Exec("date", "/", v, v) - require.Nil(t, err, "err: %v -- out: %s", err, v.b.Bytes()) + require.Nil(t, err.Err, "err: %v -- out: %s", err.Err, v.b.Bytes()) require.EqualValues(t, 0, ec) } @@ -24,7 +25,7 @@ func TestExec_success_redirectsStdStreams_closesFds(t *testing.T) { require.False(t, e.closed, "stderr open") _, err := Exec("/bin/echo 'I am stdout!'>&1; /bin/echo 'I am stderr!'>&2", "/", o, e) - require.Nil(t, err, "err: %v -- stderr: %s", err, e.b.Bytes()) + require.Nil(t, err, "err: %v -- stderr: %s", err.Err, e.b.Bytes()) require.Equal(t, "I am stdout!\n", string(o.b.Bytes())) require.Equal(t, "I am stderr!\n", string(e.b.Bytes())) require.True(t, o.closed, "stdout closed") @@ -33,15 +34,17 @@ func TestExec_success_redirectsStdStreams_closesFds(t *testing.T) { func TestExec_failure_exitError(t *testing.T) { ec, err := Exec("exit 12", "/", new(mockFile), new(mockFile)) - require.NotNil(t, err) - require.EqualError(t, err, "command terminated with exit status=12") // error is customized + require.Equal(t, err.ErrorCode, errorutil.CommandExecution_failureExitCode) + require.NotNil(t, err.Err) + require.EqualError(t, err.Err, "command terminated with exit status=12") // error is customized require.EqualValues(t, 12, ec) } func TestExec_failure_genericError(t *testing.T) { _, err := Exec("date", "/non-existing-path", new(mockFile), new(mockFile)) - require.NotNil(t, err) - require.Contains(t, err.Error(), "failed to execute command:") // error is wrapped + require.Equal(t, err.ErrorCode, errorutil.CommandExecution_failedUnknownError) + require.NotNil(t, err.Err) + require.Contains(t, err.Err.Error(), "failed to execute command:") // error is wrapped } func TestExec_failure_fdClosed(t *testing.T) { @@ -49,8 +52,9 @@ func TestExec_failure_fdClosed(t *testing.T) { require.Nil(t, out.Close()) _, err := Exec("date", "/", out, out) - require.NotNil(t, err) - require.Contains(t, err.Error(), "file closed") // error is wrapped + require.Equal(t, err.ErrorCode, errorutil.CommandExecution_failedUnknownError) + require.NotNil(t, err.Err) + require.Contains(t, err.Err.Error(), "file closed") // error is wrapped } func TestExec_failure_redirectsStdStreams_closesFds(t *testing.T) { @@ -59,7 +63,8 @@ func TestExec_failure_redirectsStdStreams_closesFds(t *testing.T) { require.False(t, e.closed, "stderr open") _, err := Exec(`/bin/echo 'I am stdout!'>&1; /bin/echo 'I am stderr!'>&2; exit 12`, "/", o, e) - require.NotNil(t, err) + require.Equal(t, err.ErrorCode, errorutil.CommandExecution_failureExitCode) + require.NotNil(t, err.Err) require.Equal(t, "I am stdout!\n", string(o.b.Bytes())) require.Equal(t, "I am stderr!\n", string(e.b.Bytes())) require.True(t, o.closed, "stdout closed") @@ -71,8 +76,8 @@ func TestExecCmdInDir(t *testing.T) { require.Nil(t, err) defer os.RemoveAll(dir) - err = ExecCmdInDir("/bin/echo 'Hello world'", dir) - require.Nil(t, err) + ewc := ExecCmdInDir("/bin/echo 'Hello world'", dir) + require.Nil(t, ewc) require.True(t, fileExists(t, filepath.Join(dir, "stdout")), "stdout file should be created") require.True(t, fileExists(t, filepath.Join(dir, "stderr")), "stderr file should be created") @@ -87,7 +92,9 @@ func TestExecCmdInDir(t *testing.T) { func TestExecCmdInDir_cantOpenError(t *testing.T) { err := ExecCmdInDir("/bin/echo 'Hello world'", "/non-existing-dir") - require.Contains(t, err.Error(), "failed to open stdout file") + require.Equal(t, err.ErrorCode, errorutil.Os_FailedToOpenStdErr) + require.NotNil(t, err.Err) + require.Contains(t, err.Err.Error(), "failed to open stdout file") } func TestExecCmdInDir_truncates(t *testing.T) { diff --git a/main/files.go b/main/files.go index 8e57090..dc7ce6e 100644 --- a/main/files.go +++ b/main/files.go @@ -9,48 +9,56 @@ import ( "os" + "github.com/Azure/azure-extension-platform/vmextension" "github.com/Azure/custom-script-extension-linux/pkg/blobutil" "github.com/Azure/custom-script-extension-linux/pkg/download" "github.com/Azure/custom-script-extension-linux/pkg/preprocess" "github.com/Azure/custom-script-extension-linux/pkg/urlutil" "github.com/go-kit/kit/log" "github.com/pkg/errors" + + "github.com/Azure/custom-script-extension-linux/pkg/errorutil" ) // downloadAndProcessURL downloads using the specified downloader and saves it to the // specified existing directory, which must be the path to the saved file. Then // it post-processes file based on heuristics. -func downloadAndProcessURL(ctx *log.Context, url, downloadDir string, cfg *handlerSettings) error { +func downloadAndProcessURL(ctx *log.Context, url, downloadDir string, cfg *handlerSettings) *vmextension.ErrorWithClarification { fn, err := urlToFileName(url) if err != nil { - return err + return vmextension.NewErrorWithClarificationPtr(errorutil.CustomerInput_invalidFileUris, err) } if !urlutil.IsValidUrl(url) { - return fmt.Errorf("[REDACTED] is not a valid url") + return vmextension.NewErrorWithClarificationPtr(errorutil.CustomerInput_invalidFileUris, fmt.Errorf("[REDACTED] is not a valid url")) } - dl, err := getDownloaders(url, cfg.StorageAccountName, cfg.StorageAccountKey, cfg.ManagedIdentity) - if err != nil { - return err + dl, ewc := getDownloaders(url, cfg.StorageAccountName, cfg.StorageAccountKey, cfg.ManagedIdentity) + if ewc != nil { + return ewc } fp := filepath.Join(downloadDir, fn) const mode = 0500 // we assume users download scripts to execute - if _, err := download.SaveTo(ctx, dl, fp, mode); err != nil { - return err + if _, ewc := download.SaveTo(ctx, dl, fp, mode); ewc != nil { + return ewc } if cfg.SkipDos2Unix == false { err = postProcessFile(fp) } - return errors.Wrapf(err, "failed to post-process '%s'", fn) + + if err != nil { + return vmextension.NewErrorWithClarificationPtr(errorutil.SystemError, errors.Wrapf(err, "failed to post-process '%s'", fn)) + } + + return nil } // getDownloader returns a downloader for the given URL based on whether the // storage credentials are empty or not. func getDownloaders(fileURL string, storageAccountName, storageAccountKey string, managedIdentity *clientOrObjectId) ( - []download.Downloader, error) { + []download.Downloader, *vmextension.ErrorWithClarification) { if storageAccountName == "" || storageAccountKey == "" { // storage account name and key cannot be specified with managed identity, handler settings validation won't allow that // handler settings validation will also not allow storageAccountName XOR storageAccountKey == 1 @@ -67,7 +75,7 @@ func getDownloaders(fileURL string, storageAccountName, storageAccountKey string case managedIdentity.ClientId == "" && managedIdentity.ObjectId != "": msiProvider = download.GetMsiProviderForStorageAccountsWithObjectId(fileURL, managedIdentity.ObjectId) default: - return nil, fmt.Errorf("unexpected combination of ClientId and ObjectId found") + return nil, vmextension.NewErrorWithClarificationPtr(errorutil.CustomerInput_clientIdObjectIdBothSpecified, fmt.Errorf("unexpected combination of ClientId and ObjectId found")) } return []download.Downloader{ // try downloading without MSI token first, but attempt with MSI if the download fails @@ -83,11 +91,11 @@ func getDownloaders(fileURL string, storageAccountName, storageAccountKey string // this preserves old behavior blob, err := blobutil.ParseBlobURL(fileURL) if err != nil { - return nil, err + return nil, vmextension.NewErrorWithClarificationPtr(errorutil.CustomerInput_invalidFileUris, err) } return []download.Downloader{download.NewBlobDownload( - storageAccountName, storageAccountKey, - blob)}, nil + storageAccountName, storageAccountKey, blob)}, + nil } } diff --git a/main/files_test.go b/main/files_test.go index 72e60a6..dd1bb05 100644 --- a/main/files_test.go +++ b/main/files_test.go @@ -124,8 +124,8 @@ func Test_downloadAndProcessURL(t *testing.T) { defer os.RemoveAll(tmpDir) cfg := handlerSettings{publicSettings{}, protectedSettings{StorageAccountName: "", StorageAccountKey: ""}} - err = downloadAndProcessURL(log.NewContext(log.NewNopLogger()), srv.URL+"/bytes/256", tmpDir, &cfg) - require.Nil(t, err) + ewc := downloadAndProcessURL(log.NewContext(log.NewNopLogger()), srv.URL+"/bytes/256", tmpDir, &cfg) + require.Nil(t, ewc) fp := filepath.Join(tmpDir, "256") fi, err := os.Stat(fp) diff --git a/main/handlersettings.go b/main/handlersettings.go index a15dee5..430ee9d 100644 --- a/main/handlersettings.go +++ b/main/handlersettings.go @@ -5,6 +5,8 @@ import ( "fmt" "path/filepath" + "github.com/Azure/azure-extension-platform/vmextension" + "github.com/Azure/custom-script-extension-linux/pkg/errorutil" "github.com/go-kit/kit/log" "github.com/pkg/errors" ) @@ -13,6 +15,7 @@ var ( errStoragePartialCredentials = errors.New("both 'storageAccountName' and 'storageAccountKey' must be specified") errCmdTooMany = errors.New("'commandToExecute' was specified both in public and protected settings; it must be specified only once") errScriptTooMany = errors.New("'script' was specified both in public and protected settings; it must be specified only once") + errFileUrisTooMany = errors.New("'fileUris' were specified both in public and protected settings; it must be specified only once") errCmdAndScript = errors.New("'commandToExecute' and 'script' were both specified, but only one is validate at a time") errCmdMissing = errors.New("'commandToExecute' is not specified") errUsingBothKeyAndMsi = errors.New("'storageAccountName' or 'storageAccountKey' must not be specified with 'managedServiceIdentity'") @@ -48,34 +51,38 @@ func (s *handlerSettings) fileUrls() []string { // validate makes logical validation on the handlerSettings which already passed // the schema validation. -func (h handlerSettings) validate() error { +func (h handlerSettings) validate() *vmextension.ErrorWithClarification { if h.commandToExecute() == "" && h.script() == "" { - return errCmdMissing + return vmextension.NewErrorWithClarificationPtr(errorutil.CustomerInput_commandToExecuteAndScriptNotSpecified, errCmdMissing) } if h.publicSettings.CommandToExecute != "" && h.protectedSettings.CommandToExecute != "" { - return errCmdTooMany + return vmextension.NewErrorWithClarificationPtr(errorutil.CustomerInput_commandToExecuteSpecifiedInTwoPlaces, errCmdTooMany) } if h.publicSettings.Script != "" && h.protectedSettings.Script != "" { - return errScriptTooMany + return vmextension.NewErrorWithClarificationPtr(errorutil.CustomerInput_scriptSpecifiedInTwoPlaces, errScriptTooMany) + } + + if (h.publicSettings.FileURLs != nil && len(h.publicSettings.FileURLs) > 0) && (h.protectedSettings.FileURLs != nil && len(h.protectedSettings.FileURLs) > 0) { + return vmextension.NewErrorWithClarificationPtr(errorutil.CustomerInput_fileUrisSpecifiedInTwoPlaces, errFileUrisTooMany) } if h.commandToExecute() != "" && h.script() != "" { - return errCmdAndScript + return vmextension.NewErrorWithClarificationPtr(errorutil.CustomerInput_commandToExecuteAndScriptBothSpecified, errCmdAndScript) } if (h.protectedSettings.StorageAccountName != "") != (h.protectedSettings.StorageAccountKey != "") { - return errStoragePartialCredentials + return vmextension.NewErrorWithClarificationPtr(errorutil.CustomerInput_incompleteStorageCreds, errStoragePartialCredentials) } if (h.protectedSettings.StorageAccountKey != "" || h.protectedSettings.StorageAccountName != "") && h.protectedSettings.ManagedIdentity != nil { - return errUsingBothKeyAndMsi + return vmextension.NewErrorWithClarificationPtr(errorutil.CustomerInput_storageCredsAndMIBothSpecified, errUsingBothKeyAndMsi) } if h.protectedSettings.ManagedIdentity != nil { if h.protectedSettings.ManagedIdentity.ClientId != "" && h.protectedSettings.ManagedIdentity.ObjectId != "" { - return errUsingBothClientIdAndObjectId + return vmextension.NewErrorWithClarificationPtr(errorutil.CustomerInput_clientIdObjectIdBothSpecified, errUsingBothClientIdAndObjectId) } } @@ -113,29 +120,30 @@ func (self *clientOrObjectId) isEmpty() bool { // parseAndValidateSettings reads configuration from configFolder, decrypts it, // runs JSON-schema and logical validation on it and returns it back. -func parseAndValidateSettings(ctx *log.Context, configFolder string, seqNum int) (h handlerSettings, _ error) { +func parseAndValidateSettings(ctx *log.Context, configFolder string, seqNum int) (h handlerSettings, _ *vmextension.ErrorWithClarification) { ctx.Log("event", "reading configuration") pubJSON, protJSON, err := readSettings(configFolder, seqNum) if err != nil { - return h, err + return h, vmextension.NewErrorWithClarificationPtr(errorutil.Internal_badConfig, err) } ctx.Log("event", "read configuration") ctx.Log("event", "validating json schema") if err := validateSettingsSchema(pubJSON, protJSON); err != nil { - return h, errors.Wrap(err, "json validation error") + return h, vmextension.NewErrorWithClarificationPtr(errorutil.Internal_badConfig, errors.Wrap(err, "json validation error")) } ctx.Log("event", "json schema valid") ctx.Log("event", "parsing configuration json") if err := UnmarshalHandlerSettings(pubJSON, protJSON, &h.publicSettings, &h.protectedSettings); err != nil { - return h, errors.Wrap(err, "json parsing error") + return h, vmextension.NewErrorWithClarificationPtr(errorutil.Internal_badConfig, errors.Wrap(err, "json parsing error")) } ctx.Log("event", "parsed configuration json") ctx.Log("event", "validating configuration logically") - if err := h.validate(); err != nil { - return h, errors.Wrap(err, "invalid configuration") + if ewc := h.validate(); err != nil { + ewc.Err = errors.Wrap(ewc.Err, "invalid configuration") + return h, ewc } ctx.Log("event", "validated configuration") return h, nil diff --git a/main/handlersettings_test.go b/main/handlersettings_test.go index 3addd72..53f149a 100644 --- a/main/handlersettings_test.go +++ b/main/handlersettings_test.go @@ -3,38 +3,39 @@ package main import ( "encoding/json" "testing" + + "github.com/stretchr/testify/require" ) -import "github.com/stretchr/testify/require" func Test_handlerSettingsValidate(t *testing.T) { // commandToExecute not specified require.Equal(t, errCmdMissing, handlerSettings{ publicSettings{}, protectedSettings{}, - }.validate()) + }.validate().Err) // commandToExecute specified twice require.Equal(t, errCmdTooMany, handlerSettings{ publicSettings{CommandToExecute: "foo"}, protectedSettings{CommandToExecute: "foo"}, - }.validate()) + }.validate().Err) // script specified twice require.Equal(t, errScriptTooMany, handlerSettings{ publicSettings{Script: "foo"}, protectedSettings{Script: "foo"}, - }.validate()) + }.validate().Err) // commandToExecute and script both specified require.Equal(t, errCmdAndScript, handlerSettings{ publicSettings{CommandToExecute: "foo"}, protectedSettings{Script: "foo"}, - }.validate()) + }.validate().Err) require.Equal(t, errCmdAndScript, handlerSettings{ publicSettings{Script: "foo"}, protectedSettings{CommandToExecute: "foo"}, - }.validate()) + }.validate().Err) // storageAccount name specified; but not key require.Equal(t, errStoragePartialCredentials, handlerSettings{ @@ -42,7 +43,7 @@ func Test_handlerSettingsValidate(t *testing.T) { CommandToExecute: "date", StorageAccountName: "foo", StorageAccountKey: ""}, - }.validate()) + }.validate().Err) // storageAccount key specified; but not name require.Equal(t, errStoragePartialCredentials, handlerSettings{ @@ -50,7 +51,7 @@ func Test_handlerSettingsValidate(t *testing.T) { CommandToExecute: "date", StorageAccountName: "", StorageAccountKey: "foo"}, - }.validate()) + }.validate().Err) } func Test_commandToExecutePrivateIfNotPublic(t *testing.T) { @@ -90,20 +91,22 @@ func Test_skipDos2UnixDefaultsToFalse(t *testing.T) { } func Test_managedIdentityVerification(t *testing.T) { - require.NoError(t, handlerSettings{publicSettings{}, protectedSettings{ + err := handlerSettings{publicSettings{}, protectedSettings{ CommandToExecute: "echo hi", FileURLs: []string{"file1", "file2"}, ManagedIdentity: &clientOrObjectId{ ClientId: "31b403aa-c364-4240-a7ff-d85fb6cd7232", }, - }}.validate(), "validation failed for settings with MSI") + }}.validate() + require.Nil(t, err, "validation failed for settings with MSI") - require.NoError(t, handlerSettings{publicSettings{}, protectedSettings{ + err = handlerSettings{publicSettings{}, protectedSettings{ CommandToExecute: "echo hi", ManagedIdentity: &clientOrObjectId{ ObjectId: "31b403aa-c364-4240-a7ff-d85fb6cd7232", }, - }}.validate(), "validation failed for settings with MSI") + }}.validate() + require.Nil(t, err, "validation failed for settings with MSI") require.Equal(t, errUsingBothKeyAndMsi, handlerSettings{publicSettings{}, @@ -114,7 +117,7 @@ func Test_managedIdentityVerification(t *testing.T) { ManagedIdentity: &clientOrObjectId{ ObjectId: "31b403aa-c364-4240-a7ff-d85fb6cd7232", }, - }}.validate(), "validation didn't fail for settings with both MSI and storage account") + }}.validate().Err, "validation didn't fail for settings with both MSI and storage account") require.Equal(t, errUsingBothClientIdAndObjectId, handlerSettings{publicSettings{}, @@ -124,7 +127,7 @@ func Test_managedIdentityVerification(t *testing.T) { ObjectId: "31b403aa-c364-4240-a7ff-d85fb6cd7232", ClientId: "31b403aa-c364-4240-a7ff-d85fb6cd7232", }, - }}.validate(), "validation didn't fail for settings with both MSI and storage account") + }}.validate().Err, "validation didn't fail for settings with both MSI and storage account") } func Test_toJSON_empty(t *testing.T) { @@ -148,7 +151,7 @@ func Test_toJSONUmarshallForManagedIdentity(t *testing.T) { require.NoError(t, err, "error while deserializing json") require.Nil(t, protSettings.ManagedIdentity, "ProtectedSettings.ManagedIdentity was expected to be nil") h := handlerSettings{publicSettings{}, *protSettings} - require.NoError(t, h.validate(), "settings should be valid") + require.Nil(t, h.validate(), "settings should be valid") testString = `{"commandToExecute" : "echo hello", "fileUris":["https://a.com/file.txt"], "managedIdentity": { }}` require.NoError(t, validateProtectedSettings(testString), "protected settings should be valid") @@ -159,7 +162,7 @@ func Test_toJSONUmarshallForManagedIdentity(t *testing.T) { require.Equal(t, protSettings.ManagedIdentity.ClientId, "") require.Equal(t, protSettings.ManagedIdentity.ObjectId, "") h = handlerSettings{publicSettings{}, *protSettings} - require.NoError(t, h.validate(), "settings should be valid") + require.Nil(t, h.validate(), "settings should be valid") testString = `{"commandToExecute" : "echo hello", "fileUris":["https://a.com/file.txt", "https://b.com/file2.txt"], "managedIdentity": { "clientId": "31b403aa-c364-4240-a7ff-d85fb6cd7232"}}` require.NoError(t, validateProtectedSettings(testString), "protected settings should be valid") @@ -170,7 +173,7 @@ func Test_toJSONUmarshallForManagedIdentity(t *testing.T) { require.Equal(t, protSettings.ManagedIdentity.ClientId, "31b403aa-c364-4240-a7ff-d85fb6cd7232") require.Equal(t, protSettings.ManagedIdentity.ObjectId, "") h = handlerSettings{publicSettings{}, *protSettings} - require.NoError(t, h.validate(), "settings should be valid") + require.Nil(t, h.validate(), "settings should be valid") testString = `{"commandToExecute" : "echo hello", "fileUris":["https://a.com/file.txt"], "managedIdentity": { "objectId": "31b403aa-c364-4240-a7ff-d85fb6cd7232"}}` require.NoError(t, validateProtectedSettings(testString), "protected settings should be valid") @@ -181,7 +184,7 @@ func Test_toJSONUmarshallForManagedIdentity(t *testing.T) { require.Equal(t, protSettings.ManagedIdentity.ObjectId, "31b403aa-c364-4240-a7ff-d85fb6cd7232") require.Equal(t, protSettings.ManagedIdentity.ClientId, "") h = handlerSettings{publicSettings{}, *protSettings} - require.NoError(t, h.validate(), "settings should be valid") + require.Nil(t, h.validate(), "settings should be valid") testString = `{"commandToExecute" : "echo hello", "fileUris":["https://a.com/file.txt", "https://b.com/file2.txt"], "managedIdentity": { "clientId": "31b403aa-c364-4240-a7ff-d85fb6cd7232", "objectId": "41b403aa-c364-4240-a7ff-d85fb6cd7232"}}` require.NoError(t, validateProtectedSettings(testString), "protected settings should be valid") diff --git a/main/main.go b/main/main.go index f104200..9c5b839 100644 --- a/main/main.go +++ b/main/main.go @@ -7,6 +7,7 @@ import ( "strings" "github.com/go-kit/kit/log" + "github.com/pkg/errors" ) var ( @@ -78,10 +79,11 @@ func main() { } // execute the subcommand reportStatus(ctx, hEnv, seqNum, StatusTransitioning, cmd, "") - msg, err := cmd.f(ctx, hEnv, seqNum) - if err != nil { - ctx.Log("event", "failed to handle", "error", err) - reportStatus(ctx, hEnv, seqNum, StatusError, cmd, err.Error()+msg) + msg, ewc := cmd.f(ctx, hEnv, seqNum) + if ewc.Err != nil { + ctx.Log("event", "failed to handle", "error", ewc.Error()) + ewc.Err = errors.Wrap(ewc.Err, ewc.Error()+msg) + reportErrorStatus(ctx, hEnv, seqNum, StatusError, cmd, ewc) os.Exit(cmd.failExitCode) } reportStatus(ctx, hEnv, seqNum, StatusSuccess, cmd, msg) diff --git a/main/status.go b/main/status.go index 196feb4..7f9fa03 100644 --- a/main/status.go +++ b/main/status.go @@ -8,6 +8,8 @@ import ( "path/filepath" "time" + status "github.com/Azure/azure-extension-platform/pkg/status" + vmextension "github.com/Azure/azure-extension-platform/vmextension" "github.com/go-kit/kit/log" "github.com/pkg/errors" ) @@ -102,6 +104,31 @@ func reportStatus(ctx *log.Context, hEnv HandlerEnvironment, seqNum int, t Type, return nil } +// reportErrorStatus saves the error(s) that occurred during the operation +// to the status file for the extension handler with clarification messages and codes, +// if the given cmd requires reporting status. +// +// If an error occurs reporting the status, it will be logged and returned. +func reportErrorStatus(ctx *log.Context, hEnv HandlerEnvironment, seqNum int, t Type, c cmd, ewc *vmextension.ErrorWithClarification) error { + if !c.shouldReportStatus { + ctx.Log("status", "not reported for operation (by design)") + return nil + } + var err error + if ewc == nil { + s := NewStatus(t, c.name, statusMsg(c, t, ewc.Err.Error())) + err = s.Save(hEnv.HandlerEnvironment.StatusFolder, seqNum) + } else { + s := status.NewError(c.name, status.ErrorClarification{Code: ewc.ErrorCode, Message: ewc.Error()}) + err = s.Save(hEnv.HandlerEnvironment.StatusFolder, uint(seqNum)) + } + if err != nil { + ctx.Log("event", "failed to save handler status", "error", err) + return errors.Wrap(err, "failed to save handler status") + } + return nil +} + // readStatus loads current status file in StatusReport func readStatus(ctx *log.Context, hEnv HandlerEnvironment, seqNum int) (Type, error) { fileName := fmt.Sprintf("%d.status", seqNum) diff --git a/main/status_test.go b/main/status_test.go index 770180d..333539e 100644 --- a/main/status_test.go +++ b/main/status_test.go @@ -1,11 +1,14 @@ package main import ( + "fmt" "io/ioutil" "os" "path/filepath" "testing" + vmextension "github.com/Azure/azure-extension-platform/vmextension" + "github.com/Azure/custom-script-extension-linux/pkg/errorutil" "github.com/go-kit/kit/log" "github.com/stretchr/testify/require" ) @@ -46,6 +49,23 @@ func Test_reportStatus_fileExists(t *testing.T) { require.NotEqual(t, 0, len(b), ".status file not empty") } +func Test_reportErrorStatus_fileExists(t *testing.T) { + tmpDir, err := ioutil.TempDir("", "") + require.Nil(t, err) + defer os.RemoveAll(tmpDir) + + fakeEnv := HandlerEnvironment{} + fakeEnv.HandlerEnvironment.StatusFolder = tmpDir + ewc := vmextension.NewErrorWithClarificationPtr(errorutil.CommandExecution_failureExitCode, fmt.Errorf("command failed with exit code = 1")) + + require.Nil(t, reportErrorStatus(log.NewContext(log.NewNopLogger()), fakeEnv, 1, StatusError, cmdEnable, ewc)) + + path := filepath.Join(tmpDir, "1.status") + b, err := ioutil.ReadFile(path) + require.Nil(t, err, ".status file exists") + require.NotEqual(t, 0, len(b), ".status file not empty") +} + func Test_reportStatus_checksIfShouldBeReported(t *testing.T) { for _, c := range cmds { tmpDir, err := ioutil.TempDir("", "status-"+c.name) diff --git a/pkg/download/blobwithmsitoken_test.go b/pkg/download/blobwithmsitoken_test.go index ec33db0..7134cda 100644 --- a/pkg/download/blobwithmsitoken_test.go +++ b/pkg/download/blobwithmsitoken_test.go @@ -46,8 +46,8 @@ func Test_realDownloadBlobWithMsiToken(t *testing.T) { err := json.Unmarshal([]byte(msiJson), &msi) return msi, err }} - _, stream, err := Download(testctx, &downloader) - require.NoError(t, err, "File download failed") + _, stream, ewc := Download(testctx, &downloader) + require.NoError(t, ewc.Err, "File download failed") defer stream.Close() bytes, err := ioutil.ReadAll(stream) diff --git a/pkg/download/downloader.go b/pkg/download/downloader.go index 9c89748..bb7589c 100644 --- a/pkg/download/downloader.go +++ b/pkg/download/downloader.go @@ -5,8 +5,11 @@ import ( "io" "net" "net/http" + "net/url" "time" + "github.com/Azure/azure-extension-platform/vmextension" + "github.com/Azure/custom-script-extension-linux/pkg/errorutil" "github.com/Azure/custom-script-extension-linux/pkg/urlutil" "github.com/go-kit/kit/log" @@ -20,8 +23,10 @@ type Downloader interface { } const ( - MsiDownload404ErrorString = "please ensure that the blob location in the fileUri setting exists, and the specified Managed Identity has read permissions to the storage blob" - MsiDownload403ErrorString = "please ensure that the specified Managed Identity has read permissions to the storage blob" + MsiDownload404ErrorString = "please ensure that the blob location in the fileUri setting exists, and the specified Managed Identity has read permissions to the storage blob" + MsiDownload403ErrorString = "please ensure that the specified Managed Identity has read permissions to the storage blob" + MsiDownloadGenericErrorString = "unable to download the MSI. This may be due to firewall rules or a networking error" + MsiDownload500ErrorString = "the IMDS service returned a00 upon requesting the MSI" ) var ( @@ -44,10 +49,10 @@ var ( // Download retrieves a response body and checks the response status code to see // if it is 200 OK and then returns the response body. It issues a new request // every time called. It is caller's responsibility to close the response body. -func Download(ctx *log.Context, d Downloader) (int, io.ReadCloser, error) { +func Download(ctx *log.Context, d Downloader) (int, io.ReadCloser, *vmextension.ErrorWithClarification) { req, err := d.GetRequest() if err != nil { - return -1, nil, errors.Wrapf(err, "failed to create http request") + return -1, nil, vmextension.NewErrorWithClarificationPtr(errorutil.FileDownload_genericError, errors.Wrapf(err, "failed to create http request")) } requestID := req.Header.Get(xMsClientRequestIdHeaderName) if len(requestID) > 0 { @@ -55,8 +60,12 @@ func Download(ctx *log.Context, d Downloader) (int, io.ReadCloser, error) { } resp, err := httpClient.Do(req) if err != nil { + if (err.(*url.Error)).Timeout() { + err = urlutil.RemoveUrlFromErr(err) + return -1, nil, vmextension.NewErrorWithClarificationPtr(errorutil.FileDownload_exceededTimeout, errors.Wrapf(err, "http request timed out")) + } err = urlutil.RemoveUrlFromErr(err) - return -1, nil, errors.Wrapf(err, "http request failed") + return -1, nil, vmextension.NewErrorWithClarificationPtr(errorutil.FileDownload_unknownError, errors.Wrapf(err, "http request failed")) } if resp.StatusCode == http.StatusOK { @@ -64,14 +73,23 @@ func Download(ctx *log.Context, d Downloader) (int, io.ReadCloser, error) { } errString := "" + errClarificationCode := 0 requestId := resp.Header.Get(xMsServiceRequestIdHeaderName) switch d.(type) { case *blobWithMsiToken: switch resp.StatusCode { case http.StatusNotFound: errString = MsiDownload404ErrorString + errClarificationCode = errorutil.Msi_notFound case http.StatusForbidden: errString = MsiDownload403ErrorString + errClarificationCode = errorutil.Msi_doesNotHaveRightPermissions + case http.StatusInternalServerError: + errString = MsiDownload500ErrorString + errClarificationCode = errorutil.Imds_internalMsiError + default: + errString = MsiDownloadGenericErrorString + errClarificationCode = errorutil.Msi_GenericRetrievalError } break default: @@ -81,28 +99,38 @@ func Download(ctx *log.Context, d Downloader) (int, io.ReadCloser, error) { errString = fmt.Sprintf("CustomScript failed to download the file from %s because access was denied. Please fix the blob permissions and try again, the response code and message returned were: %q", hostname, resp.Status) + errClarificationCode = errorutil.FileDownload_accessDenied case http.StatusNotFound: errString = fmt.Sprintf("CustomScript failed to download the file from %s because it does not exist. Please create the blob and try again, the response code and message returned were: %q", hostname, resp.Status) + errClarificationCode = errorutil.FileDownload_doesNotExist case http.StatusBadRequest: errString = fmt.Sprintf("CustomScript failed to download the file from %s because parts of the request were incorrectly formatted, missing, and/or invalid. The response code and message returned were: %q", hostname, resp.Status) + errClarificationCode = errorutil.FileDownload_badRequest case http.StatusInternalServerError: errString = fmt.Sprintf("CustomScript failed to download the file from %s due to an issue with storage. The response code and message returned were: %q", hostname, resp.Status) + errClarificationCode = errorutil.Storage_internalServerError default: errString = fmt.Sprintf("CustomScript failed to download the file from %s because the server returned a response code and message of %q Please verify the machine has network connectivity.", hostname, resp.Status) + errClarificationCode = errorutil.FileDownload_networkingError } } if len(requestId) > 0 { errString += fmt.Sprintf(" (Service request ID: %s)", requestId) } - return resp.StatusCode, nil, fmt.Errorf("%s", errString) + + if errClarificationCode == 0 { + return resp.StatusCode, nil, nil + } + + return resp.StatusCode, nil, vmextension.NewErrorWithClarificationPtr(errClarificationCode, errors.New(errString)) } diff --git a/pkg/download/downloader_test.go b/pkg/download/downloader_test.go index 7c2bf2d..78ceda3 100644 --- a/pkg/download/downloader_test.go +++ b/pkg/download/downloader_test.go @@ -3,7 +3,7 @@ package download_test import ( "errors" "fmt" - "io/ioutil" + "io" "net/http" "net/http/httptest" "strings" @@ -13,6 +13,7 @@ import ( "github.com/go-kit/kit/log" "github.com/Azure/custom-script-extension-linux/pkg/download" + "github.com/Azure/custom-script-extension-linux/pkg/errorutil" "github.com/ahmetalpbalkan/go-httpbin" "github.com/stretchr/testify/require" ) @@ -29,15 +30,17 @@ func (b *badDownloader) GetRequest() (*http.Request, error) { } func TestDownload_wrapsGetRequestError(t *testing.T) { - _, _, err := download.Download(testctx, new(badDownloader)) - require.NotNil(t, err) - require.EqualError(t, err, "failed to create http request: expected error") + _, _, ewc := download.Download(testctx, new(badDownloader)) + require.Equal(t, ewc.ErrorCode, errorutil.FileDownload_genericError) + require.NotNil(t, ewc.Err) + require.EqualError(t, ewc.Err, "failed to create http request: expected error") } func TestDownload_wrapsHTTPError(t *testing.T) { - _, _, err := download.Download(testctx, download.NewURLDownload("bad url")) - require.NotNil(t, err) - require.Contains(t, err.Error(), "http request failed:") + _, _, ewc := download.Download(testctx, download.NewURLDownload("bad url")) + require.Equal(t, ewc.ErrorCode, errorutil.FileDownload_unknownError) + require.NotNil(t, ewc.Err) + require.Contains(t, ewc.Err.Error(), "http request failed:") } // This test is only to make sure that formatting of error messages for specific codes is correct @@ -53,20 +56,25 @@ func TestDownload_wrapsCommonErrorCodes(t *testing.T) { http.StatusBadRequest, http.StatusUnauthorized, } { - respCode, _, err := download.Download(testctx, download.NewURLDownload(fmt.Sprintf("%s/status/%d", srv.URL, code))) - require.NotNil(t, err, "not failed for code:%d", code) + respCode, _, ewc := download.Download(testctx, download.NewURLDownload(fmt.Sprintf("%s/status/%d", srv.URL, code))) + require.NotNil(t, ewc.Err, "not failed for code:%d", code) require.Equal(t, code, respCode) switch respCode { case http.StatusNotFound: - require.Contains(t, err.Error(), "because it does not exist") + require.Equal(t, ewc.ErrorCode, errorutil.FileDownload_doesNotExist) + require.Contains(t, ewc.Err.Error(), "because it does not exist") case http.StatusForbidden: - require.Contains(t, err.Error(), "Please verify the machine has network connectivity") + require.Equal(t, ewc.ErrorCode, errorutil.FileDownload_networkingError) + require.Contains(t, ewc.Err.Error(), "Please verify the machine has network connectivity") case http.StatusInternalServerError: - require.Contains(t, err.Error(), "due to an issue with storage") + require.Equal(t, ewc.ErrorCode, errorutil.Storage_internalServerError) + require.Contains(t, ewc.Err.Error(), "due to an issue with storage") case http.StatusBadRequest: - require.Contains(t, err.Error(), "because parts of the request were incorrectly formatted, missing, and/or invalid") + require.Equal(t, ewc.ErrorCode, errorutil.FileDownload_badRequest) + require.Contains(t, ewc.Err.Error(), "because parts of the request were incorrectly formatted, missing, and/or invalid") case http.StatusUnauthorized: - require.Contains(t, err.Error(), "because access was denied") + require.Equal(t, ewc.ErrorCode, errorutil.FileDownload_accessDenied) + require.Contains(t, ewc.Err.Error(), "because access was denied") } } } @@ -75,8 +83,8 @@ func TestDownload_statusOKSucceeds(t *testing.T) { srv := httptest.NewServer(httpbin.GetMux()) defer srv.Close() - _, body, err := download.Download(testctx, download.NewURLDownload(srv.URL+"/status/200")) - require.Nil(t, err) + _, body, ewc := download.Download(testctx, download.NewURLDownload(srv.URL+"/status/200")) + require.Nil(t, ewc) defer body.Close() require.NotNil(t, body) } @@ -90,27 +98,43 @@ func TestDowload_msiDownloaderErrorMessage(t *testing.T) { msiDownloader404 := download.NewBlobWithMsiDownload(srv.URL+"/status/404", mockMsiProvider) - returnCode, body, err := download.Download(testctx, msiDownloader404) - require.True(t, strings.Contains(err.Error(), download.MsiDownload404ErrorString), "error string doesn't contain the correct message") + returnCode, body, ewc := download.Download(testctx, msiDownloader404) + require.Equal(t, ewc.ErrorCode, errorutil.Msi_notFound) + require.True(t, strings.Contains(ewc.Err.Error(), download.MsiDownload404ErrorString), "error string doesn't contain the correct message") require.Nil(t, body, "body is not nil for failed download") require.Equal(t, 404, returnCode, "return code was not 404") msiDownloader403 := download.NewBlobWithMsiDownload(srv.URL+"/status/403", mockMsiProvider) - returnCode, body, err = download.Download(testctx, msiDownloader403) - require.True(t, strings.Contains(err.Error(), download.MsiDownload403ErrorString), "error string doesn't contain the correct message") + returnCode, body, ewc = download.Download(testctx, msiDownloader403) + require.Equal(t, ewc.ErrorCode, errorutil.Msi_doesNotHaveRightPermissions) + require.True(t, strings.Contains(ewc.Err.Error(), download.MsiDownload403ErrorString), "error string doesn't contain the correct message") require.Nil(t, body, "body is not nil for failed download") require.Equal(t, 403, returnCode, "return code was not 403") + msiDownloade500 := download.NewBlobWithMsiDownload(srv.URL+"/status/500", mockMsiProvider) + returnCode, body, ewc = download.Download(testctx, msiDownloade500) + require.Equal(t, ewc.ErrorCode, errorutil.Imds_internalMsiError) + require.True(t, strings.Contains(ewc.Err.Error(), download.MsiDownload500ErrorString), "error string doesn't contain the correct message") + require.Nil(t, body, "body is not nil for failed download") + require.Equal(t, 500, returnCode, "return code was not 500") + + msiDownloader400 := download.NewBlobWithMsiDownload(srv.URL+"/status/400", mockMsiProvider) + returnCode, body, ewc = download.Download(testctx, msiDownloader400) + require.Equal(t, ewc.ErrorCode, errorutil.Msi_GenericRetrievalError) + require.True(t, strings.Contains(ewc.Err.Error(), download.MsiDownloadGenericErrorString), "error string doesn't contain the correct message") + require.Nil(t, body, "body is not nil for failed download") + require.Equal(t, 400, returnCode, "return code was not 400") + } func TestDownload_retrievesBody(t *testing.T) { srv := httptest.NewServer(httpbin.GetMux()) defer srv.Close() - _, body, err := download.Download(testctx, download.NewURLDownload(srv.URL+"/bytes/65536")) - require.Nil(t, err) + _, body, ewc := download.Download(testctx, download.NewURLDownload(srv.URL+"/bytes/65536")) + require.Nil(t, ewc) defer body.Close() - b, err := ioutil.ReadAll(body) + b, err := io.ReadAll(body) require.Nil(t, err) require.EqualValues(t, 65536, len(b)) } @@ -119,7 +143,7 @@ func TestDownload_bodyClosesWithoutError(t *testing.T) { srv := httptest.NewServer(httpbin.GetMux()) defer srv.Close() - _, body, err := download.Download(testctx, download.NewURLDownload(srv.URL+"/get")) - require.Nil(t, err) + _, body, ewc := download.Download(testctx, download.NewURLDownload(srv.URL+"/get")) + require.Nil(t, ewc) require.Nil(t, body.Close(), "body should close fine") } diff --git a/pkg/download/retry.go b/pkg/download/retry.go index 6999119..515b36e 100644 --- a/pkg/download/retry.go +++ b/pkg/download/retry.go @@ -8,7 +8,10 @@ import ( "os" "time" + "github.com/Azure/azure-extension-platform/vmextension" "github.com/go-kit/kit/log" + + errorutil "github.com/Azure/custom-script-extension-linux/pkg/errorutil" ) // SleepFunc pauses the execution for at least duration d. @@ -33,17 +36,19 @@ const ( // closed on failures). If the retries do not succeed, the last error is returned. // // It sleeps in exponentially increasing durations between retries. -func WithRetries(ctx *log.Context, f *os.File, downloaders []Downloader, sf SleepFunc) (int64, error) { +func WithRetries(ctx *log.Context, f *os.File, downloaders []Downloader, sf SleepFunc) (int64, *vmextension.ErrorWithClarification) { var lastErr error + var lastErrCode int for _, d := range downloaders { for n := 0; n < expRetryN; n++ { ctx := ctx.With("retry", n) // reset the last error before each retry lastErr = nil + lastErrCode = 0 start := time.Now() - status, out, err := Download(ctx, d) - if err == nil { + status, out, ewc := Download(ctx, d) + if ewc == nil { // server returned status code 200 OK // we have a response body, copy it to the file nBytes, innerErr := io.CopyBuffer(f, out, make([]byte, writeBufSize)) @@ -62,11 +67,13 @@ func WithRetries(ctx *log.Context, f *os.File, downloaders []Downloader, sf Slee // clear out the contents of the file so as to not leave a partial file f.Truncate(0) // cache the inner error + lastErrCode = errorutil.FileDownload_genericError lastErr = innerErr } } else { // cache the outer error - lastErr = err + lastErr = ewc.Err + lastErrCode = ewc.ErrorCode } // we are here because either server returned a non-200 status code @@ -94,7 +101,12 @@ func WithRetries(ctx *log.Context, f *os.File, downloaders []Downloader, sf Slee } } } - return 0, lastErr + + if lastErr == nil { + return 0, nil + } + + return 0, vmextension.NewErrorWithClarificationPtr(lastErrCode, lastErr) } func isTransientHttpStatusCode(statusCode int) bool { diff --git a/pkg/download/retry_test.go b/pkg/download/retry_test.go index aeb6b69..267ca1d 100644 --- a/pkg/download/retry_test.go +++ b/pkg/download/retry_test.go @@ -2,7 +2,7 @@ package download_test import ( "fmt" - "io/ioutil" + // "io/ioutil" "net/http" "net/http/httptest" "os" @@ -65,7 +65,7 @@ func TestWithRetries_failing_validateNumberOfCalls(t *testing.T) { sr := new(sleepRecorder) n, err := download.WithRetries(nopLog(), file, []download.Downloader{bd}, sr.Sleep) - require.Contains(t, err.Error(), "expected error", "error is preserved") + require.Contains(t, err.Err.Error(), "expected error", "error is preserved") require.EqualValues(t, 0, n, "downloaded number of bytes should be zero") require.EqualValues(t, 7, bd.calls, "calls exactly expRetryN times") } @@ -82,8 +82,8 @@ func TestWithRetries_failingBadStatusCode_validateSleeps(t *testing.T) { sr := new(sleepRecorder) n, err := download.WithRetries(nopLog(), file, []download.Downloader{d}, sr.Sleep) - require.Contains(t, err.Error(), "429 Too Many Requests") - require.Contains(t, err.Error(), "Please verify the machine has network connectivity") + require.Contains(t, err.Err.Error(), "429 Too Many Requests") + require.Contains(t, err.Err.Error(), "Please verify the machine has network connectivity") require.EqualValues(t, 0, n, "downloaded number of bytes should be zero") require.Equal(t, sleepSchedule, []time.Duration(*sr)) } @@ -168,8 +168,8 @@ func TestRetriesWith_LargeFileThatTimesOutWhileDownloading(t *testing.T) { largeFileDownloader := mockDownloader{0, srv.URL + "/bytes/" + fmt.Sprintf("%d", size)} sr := new(sleepRecorder) - n, err := download.WithRetries(nopLog(), file, []download.Downloader{&largeFileDownloader}, sr.Sleep) - require.NotNil(t, err, "download with retries should fail because of server timeout") + n, ewc := download.WithRetries(nopLog(), file, []download.Downloader{&largeFileDownloader}, sr.Sleep) + require.NotNil(t, ewc.Err, "download with retries should fail because of server timeout") require.EqualValues(t, 0, n, "downloaded number of bytes should be zero") fi, err := file.Stat() @@ -178,8 +178,8 @@ func TestRetriesWith_LargeFileThatTimesOutWhileDownloading(t *testing.T) { } func CreateTestFile(t *testing.T) (string, *os.File) { - dir, err := ioutil.TempDir("", "") - require.Nil(t, err) + dir := os.TempDir() + // require.Nil(t, err) path := filepath.Join(dir, "test-file") diff --git a/pkg/download/save.go b/pkg/download/save.go index 7273138..c74a57f 100644 --- a/pkg/download/save.go +++ b/pkg/download/save.go @@ -3,6 +3,8 @@ package download import ( "os" + "github.com/Azure/azure-extension-platform/vmextension" + "github.com/Azure/custom-script-extension-linux/pkg/errorutil" "github.com/go-kit/kit/log" "github.com/pkg/errors" ) @@ -11,16 +13,18 @@ import ( // given file. Directory of dst is not created by this function. If a file at // dst exists, it will be truncated. If a new file is created, mode is used to // set the permission bits. Written number of bytes are returned on success. -func SaveTo(ctx *log.Context, d []Downloader, dst string, mode os.FileMode) (int64, error) { +func SaveTo(ctx *log.Context, d []Downloader, dst string, mode os.FileMode) (int64, *vmextension.ErrorWithClarification) { f, err := os.OpenFile(dst, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, mode) if err != nil { - return 0, errors.Wrap(err, "failed to open file for writing") + return 0, vmextension.NewErrorWithClarificationPtr(errorutil.FileDownload_unknownError, errors.Wrap(err, "failed to open file for writing")) + } defer f.Close() - n, err := WithRetries(ctx, f, d, ActualSleep) - if err != nil { - return n, errors.Wrapf(err, "failed to download response and write to file: %s", dst) + n, ewc := WithRetries(ctx, f, d, ActualSleep) + if ewc != nil { + return n, vmextension.NewErrorWithClarificationPtr(ewc.ErrorCode, errors.Wrapf(ewc.Err, "failed to download response and write to file: %s", dst)) + } return n, nil diff --git a/pkg/download/save_test.go b/pkg/download/save_test.go index 3bbfb9f..3b16de3 100644 --- a/pkg/download/save_test.go +++ b/pkg/download/save_test.go @@ -33,8 +33,8 @@ func TestSave(t *testing.T) { d := download.NewURLDownload(srv.URL + "/bytes/65536") path := filepath.Join(dir, "test-file") - n, err := download.SaveTo(nopLog(), []download.Downloader{d}, path, 0600) - require.Nil(t, err) + n, ewc := download.SaveTo(nopLog(), []download.Downloader{d}, path, 0600) + require.Nil(t, ewc.Err) require.EqualValues(t, 65536, n) fi, err := os.Stat(path) @@ -52,10 +52,10 @@ func TestSave_truncates(t *testing.T) { defer os.RemoveAll(dir) path := filepath.Join(dir, "test-file") - _, err = download.SaveTo(nopLog(), []download.Downloader{download.NewURLDownload(srv.URL + "/bytes/65536")}, path, 0600) - require.Nil(t, err) - _, err = download.SaveTo(nopLog(), []download.Downloader{download.NewURLDownload(srv.URL + "/bytes/128")}, path, 0777) - require.Nil(t, err) + _, ewc := download.SaveTo(nopLog(), []download.Downloader{download.NewURLDownload(srv.URL + "/bytes/65536")}, path, 0600) + require.Nil(t, ewc.Err) + _, ewc = download.SaveTo(nopLog(), []download.Downloader{download.NewURLDownload(srv.URL + "/bytes/128")}, path, 0777) + require.Nil(t, ewc.Err) fi, err := os.Stat(path) require.Nil(t, err) @@ -74,8 +74,8 @@ func TestSave_largeFile(t *testing.T) { size := 1024 * 1024 * 128 // 128 mb path := filepath.Join(dir, "large-file") - n, err := download.SaveTo(nopLog(), []download.Downloader{download.NewURLDownload(srv.URL + "/bytes/" + fmt.Sprintf("%d", size))}, path, 0600) - require.Nil(t, err) + n, ewc := download.SaveTo(nopLog(), []download.Downloader{download.NewURLDownload(srv.URL + "/bytes/" + fmt.Sprintf("%d", size))}, path, 0600) + require.Nil(t, ewc.Err) require.EqualValues(t, size, n) fi, err := os.Stat(path) diff --git a/pkg/errorutil/errorclarificationcodes.go b/pkg/errorutil/errorclarificationcodes.go new file mode 100644 index 0000000..e823017 --- /dev/null +++ b/pkg/errorutil/errorclarificationcodes.go @@ -0,0 +1,56 @@ +package errorutil + +import ( + "math" +) + +const ( + // System errors + FileDownload_badRequest int = -41 + FileDownload_unknownError int = -40 + + Imds_internalMsiError int = -30 + + Internal_badConfig int = -21 + Internal_couldNotFindCertificate int = -20 + + Os_FailedToDeleteDataDir int = -50 + Os_FailedToOpenStdOut int = -51 + Os_FailedToOpenStdErr int = -52 + + Storage_internalServerError int = -1 + SystemError int = 0 // CRP interprets anything > 0 as user errors + + // User errors + CommandExecution_failedUnknownError int = 1 + CommandExecution_failureExitCode int = 2 + CommandExecution_interruptedByVmShutdown int = 3 + + CustomerInput_commandToExecuteSpecifiedInTwoPlaces int = 20 + CustomerInput_fileUrisSpecifiedInTwoPlaces int = 22 + CustomerInput_commandToExecuteAndScriptNotSpecified int = 23 + CustomerInput_fileUriContainsNull int = 24 + CustomerInput_invalidFileUris int = 25 + CustomerInput_storageCredsAndMIBothSpecified int = 26 + CustomerInput_clientIdObjectIdBothSpecified int = 27 + CustomerInput_scriptSpecifiedInTwoPlaces int = 28 + CustomerInput_commandToExecuteAndScriptBothSpecified int = 29 + CustomerInput_incompleteStorageCreds int = 30 + + FileDownload_unableToCreateDownloadDirectory int = 50 + FileDownload_sasExpired int = 51 + FileDownload_accessDenied int = 52 + FileDownload_doesNotExist int = 53 + FileDownload_networkingError int = 54 + FileDownload_genericError int = 55 + FileDownload_exceededTimeout int = 56 + + Msi_notFound int = 70 + Msi_doesNotHaveRightPermissions int = 71 + Msi_GenericRetrievalError int = 72 + + // No Error - used as a placeholder value + // when representing an "empty" ErrorWithClarification + // or when the error can be treated without the clarification + NoError int = math.MaxInt +)