diff --git a/go.mod b/go.mod
index 6bfa287..7475498 100644
--- a/go.mod
+++ b/go.mod
@@ -4,7 +4,7 @@ go 1.24.0
require (
github.com/Azure/azure-extension-foundation v0.0.0-20250620154556-caff9e3c3c5c
- github.com/Azure/azure-extension-platform v0.0.0-20250107200156-aa20f765d49f
+ github.com/Azure/azure-extension-platform v0.0.0-20260107210613-2a62cc200c34
github.com/Azure/azure-sdk-for-go v63.2.0+incompatible
github.com/ahmetalpbalkan/go-httpbin v0.0.0-20160706084156-8817b883dae1
github.com/go-kit/kit v0.12.0
@@ -35,6 +35,7 @@ require (
github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect
github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect
golang.org/x/crypto v0.45.0 // indirect
+ golang.org/x/sys v0.15.0 // indirect
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
diff --git a/go.sum b/go.sum
index ae334f2..914bce2 100644
--- a/go.sum
+++ b/go.sum
@@ -2,6 +2,8 @@ github.com/Azure/azure-extension-foundation v0.0.0-20250620154556-caff9e3c3c5c h
github.com/Azure/azure-extension-foundation v0.0.0-20250620154556-caff9e3c3c5c/go.mod h1:sNC6lMTUkXwjrQ+nttr6GXhDfvSGT7t3UDq30BEYzu8=
github.com/Azure/azure-extension-platform v0.0.0-20250107200156-aa20f765d49f h1:ddsUz/suc9txCMz/xWOslqNMvzhbWFMTflUrbcMNoSw=
github.com/Azure/azure-extension-platform v0.0.0-20250107200156-aa20f765d49f/go.mod h1:0458BvQsi5ch6kn+KZtI5m88Z3L9UFXdoY1+6nKdivY=
+github.com/Azure/azure-extension-platform v0.0.0-20260107210613-2a62cc200c34 h1:7bEC4DJC4w0gx7SBy7M7Q2qi6ckmHcnnlFJzo+X/gi4=
+github.com/Azure/azure-extension-platform v0.0.0-20260107210613-2a62cc200c34/go.mod h1:0458BvQsi5ch6kn+KZtI5m88Z3L9UFXdoY1+6nKdivY=
github.com/Azure/azure-sdk-for-go v63.2.0+incompatible h1:OIqkK/zTGqVUuzpEvY0B1YSYDRAFC/j+y0w2GovCggI=
github.com/Azure/azure-sdk-for-go v63.2.0+incompatible/go.mod h1:9XXNKU+eRnpl9moKnB4QOLf1HestfXbmab5FXxiDBjc=
github.com/Azure/go-autorest v14.2.0+incompatible h1:V5VMDjClD3GiElqLWO7mz2MxNAK/vTfRHdAubSIPRgs=
@@ -93,6 +95,7 @@ golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
diff --git a/main/cmds.go b/main/cmds.go
index 354f815..5bd856f 100644
--- a/main/cmds.go
+++ b/main/cmds.go
@@ -12,7 +12,9 @@ import (
"strconv"
"time"
- utils "github.com/Azure/azure-extension-platform/pkg/utils"
+ "github.com/Azure/azure-extension-platform/pkg/utils"
+ vmextension "github.com/Azure/azure-extension-platform/vmextension"
+ "github.com/Azure/custom-script-extension-linux/pkg/errorutil"
"github.com/Azure/custom-script-extension-linux/pkg/seqnum"
"github.com/go-kit/kit/log"
"github.com/pkg/errors"
@@ -22,7 +24,7 @@ const (
maxScriptSize = 256 * 1024
)
-type cmdFunc func(ctx *log.Context, hEnv HandlerEnvironment, seqNum int) (msg string, err error)
+type cmdFunc func(ctx *log.Context, hEnv HandlerEnvironment, seqNum int) (msg string, ewc *vmextension.ErrorWithClarification)
type preFunc func(ctx *log.Context, hEnv HandlerEnvironment, seqNum int) error
type cmd struct {
@@ -55,14 +57,14 @@ var (
}
)
-func noop(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, error) {
+func noop(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, *vmextension.ErrorWithClarification) {
ctx.Log("event", "noop")
return "", nil
}
-func install(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, error) {
+func install(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, *vmextension.ErrorWithClarification) {
if err := os.MkdirAll(dataDir, 0755); err != nil {
- return "", errors.Wrap(err, "failed to create data dir")
+ return "", vmextension.NewErrorWithClarificationPtr(errorutil.SystemError, errors.Wrap(err, "failed to create data dir"))
}
// If the file mrseq does not exists it is for two possible reasons.
@@ -77,12 +79,12 @@ func install(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, error)
return "", nil
}
-func uninstall(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, error) {
+func uninstall(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, *vmextension.ErrorWithClarification) {
{ // a new context scope with path
ctx = ctx.With("path", dataDir)
ctx.Log("event", "removing data dir", "path", dataDir)
if err := os.RemoveAll(dataDir); err != nil {
- return "", errors.Wrap(err, "failed to delete data directory")
+ return "", vmextension.NewErrorWithClarificationPtr(errorutil.Os_FailedToDeleteDataDir, errors.Wrap(err, "failed to delete data directory"))
}
ctx.Log("event", "removed data dir")
}
@@ -110,16 +112,18 @@ func min(a, b int) int {
return b
}
-func enable(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, error) {
+func enable(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, *vmextension.ErrorWithClarification) {
// parse the extension handler settings (not available prior to 'enable')
- cfg, err := parseAndValidateSettings(ctx, h.HandlerEnvironment.ConfigFolder, seqNum)
- if err != nil {
- return "", errors.Wrap(err, "failed to get configuration")
+ cfg, ewc := parseAndValidateSettings(ctx, h.HandlerEnvironment.ConfigFolder, seqNum)
+ if ewc != nil {
+ ewc.Err = errors.Wrap(ewc.Err, "failed to get configuration")
+ return "", ewc
}
dir := filepath.Join(dataDir, downloadDir, fmt.Sprintf("%d", seqNum))
- if err := downloadFiles(ctx, dir, cfg); err != nil {
- return "", errors.Wrap(err, "processing file downloads failed")
+ if ewc := downloadFiles(ctx, dir, cfg); ewc != nil {
+ ewc.Err = errors.Wrap(ewc.Err, "processing file downloads failed")
+ return "", ewc
}
// execute the command, save its error
@@ -175,12 +179,12 @@ func checkAndSaveSeqNum(ctx log.Logger, seq int, mrseqPath string) (shouldExit b
// downloadFiles downloads the files specified in cfg into dir (creates if does
// not exist) and takes storage credentials specified in cfg into account.
-func downloadFiles(ctx *log.Context, dir string, cfg handlerSettings) error {
+func downloadFiles(ctx *log.Context, dir string, cfg handlerSettings) *vmextension.ErrorWithClarification {
// - prepare the output directory for files and the command output
// - create the directory if missing
ctx.Log("event", "creating output directory", "path", dir)
if err := os.MkdirAll(dir, 0700); err != nil {
- return errors.Wrap(err, "failed to prepare output directory")
+ return vmextension.NewErrorWithClarificationPtr(errorutil.FileDownload_unableToCreateDownloadDirectory, errors.Wrap(err, "failed to prepare output directory"))
}
ctx.Log("event", "created output directory")
@@ -200,9 +204,9 @@ func downloadFiles(ctx *log.Context, dir string, cfg handlerSettings) error {
for i, f := range cfg.fileUrls() {
ctx := ctx.With("file", i)
ctx.Log("event", "download start")
- if err := downloadAndProcessURL(ctx, f, dir, &cfg); err != nil {
- ctx.Log("event", "download failed", "error", err)
- return errors.Wrapf(err, "failed to download file[%d]", i)
+ if ewc := downloadAndProcessURL(ctx, f, dir, &cfg); ewc != nil {
+ ctx.Log("event", "download failed", "error", ewc.Err)
+ return vmextension.NewErrorWithClarificationPtr(ewc.ErrorCode, errors.Wrapf(ewc.Err, "failed to download file[%d]", i))
}
ctx.Log("event", "download complete", "output", dir)
}
@@ -210,11 +214,12 @@ func downloadFiles(ctx *log.Context, dir string, cfg handlerSettings) error {
}
// runCmd runs the command (extracted from cfg) in the given dir (assumed to exist).
-func runCmd(ctx log.Logger, dir string, cfg handlerSettings) (err error) {
+func runCmd(ctx log.Logger, dir string, cfg handlerSettings) (ewc *vmextension.ErrorWithClarification) {
ctx.Log("event", "executing command", "output", dir)
var cmd string
var scenario string
var scenarioInfo string
+ var err error
// So many ways to execute a command!
if cfg.publicSettings.CommandToExecute != "" {
@@ -228,27 +233,27 @@ func runCmd(ctx log.Logger, dir string, cfg handlerSettings) (err error) {
} else if cfg.publicSettings.Script != "" {
ctx.Log("event", "executing public script", "output", dir)
if cmd, scenarioInfo, err = writeTempScript(cfg.publicSettings.Script, dir, cfg.publicSettings.SkipDos2Unix); err != nil {
- return
+ return nil
}
scenario = fmt.Sprintf("public-script;%s", scenarioInfo)
} else if cfg.protectedSettings.Script != "" {
ctx.Log("event", "executing protected script", "output", dir)
if cmd, scenarioInfo, err = writeTempScript(cfg.protectedSettings.Script, dir, cfg.publicSettings.SkipDos2Unix); err != nil {
- return
+ return nil
}
scenario = fmt.Sprintf("protected-script;%s", scenarioInfo)
}
begin := time.Now()
- err = ExecCmdInDir(cmd, dir)
+ ewc = ExecCmdInDir(cmd, dir)
elapsed := time.Now().Sub(begin)
- isSuccess := err == nil
+ isSuccess := ewc == nil
telemetry("scenario", scenario, isSuccess, elapsed)
- if err != nil {
+ if ewc != nil {
ctx.Log("event", "failed to execute command", "error", err, "output", dir)
- return errors.Wrap(err, "failed to execute command")
+ return vmextension.NewErrorWithClarificationPtr(ewc.ErrorCode, errors.Wrap(ewc.Err, "failed to execute command"))
}
ctx.Log("event", "executed command", "output", dir)
return nil
diff --git a/main/cmds_test.go b/main/cmds_test.go
index d6294df..8437e6f 100644
--- a/main/cmds_test.go
+++ b/main/cmds_test.go
@@ -7,6 +7,7 @@ import (
"path/filepath"
"testing"
+ "github.com/Azure/custom-script-extension-linux/pkg/errorutil"
"github.com/ahmetalpbalkan/go-httpbin"
"github.com/go-kit/kit/log"
"github.com/stretchr/testify/require"
@@ -84,7 +85,7 @@ func Test_runCmd_success(t *testing.T) {
require.Nil(t, runCmd(log.NewNopLogger(), dir, handlerSettings{
publicSettings: publicSettings{CommandToExecute: "date"},
- }), "command should run successfully")
+ }).Err, "command should run successfully")
// check stdout stderr files
_, err = os.Stat(filepath.Join(dir, "stdout"))
@@ -98,11 +99,12 @@ func Test_runCmd_fail(t *testing.T) {
require.Nil(t, err)
defer os.RemoveAll(dir)
- err = runCmd(log.NewNopLogger(), dir, handlerSettings{
+ ewc := runCmd(log.NewNopLogger(), dir, handlerSettings{
publicSettings: publicSettings{CommandToExecute: "non-existing-cmd"},
})
- require.NotNil(t, err, "command terminated with exit status")
- require.Contains(t, err.Error(), "failed to execute command")
+ require.Equal(t, errorutil.CommandExecution_failureExitCode, ewc.ErrorCode)
+ require.NotNil(t, ewc.Err, "command terminated with exit status")
+ require.Contains(t, ewc.Err.Error(), "failed to execute command")
}
func Test_downloadFiles(t *testing.T) {
@@ -113,7 +115,7 @@ func Test_downloadFiles(t *testing.T) {
srv := httptest.NewServer(httpbin.GetMux())
defer srv.Close()
- err = downloadFiles(log.NewContext(log.NewNopLogger()),
+ ewc := downloadFiles(log.NewContext(log.NewNopLogger()),
dir,
handlerSettings{
publicSettings: publicSettings{
@@ -123,7 +125,7 @@ func Test_downloadFiles(t *testing.T) {
srv.URL + "/bytes/1000",
}},
})
- require.Nil(t, err)
+ require.Nil(t, ewc)
// check the files
f := []string{"10", "100", "1000"}
diff --git a/main/exec.go b/main/exec.go
index ed367ca..d9328df 100644
--- a/main/exec.go
+++ b/main/exec.go
@@ -8,6 +8,8 @@ import (
"path/filepath"
"syscall"
+ vmextension "github.com/Azure/azure-extension-platform/vmextension"
+ errorutil "github.com/Azure/custom-script-extension-linux/pkg/errorutil"
"github.com/pkg/errors"
)
@@ -16,7 +18,7 @@ import (
//
// On error, an exit code may be returned if it is an exit code error.
// Given stdout and stderr will be closed upon returning.
-func Exec(cmd, workdir string, stdout, stderr io.WriteCloser) (int, error) {
+func Exec(cmd, workdir string, stdout, stderr io.WriteCloser) (int, *vmextension.ErrorWithClarification) {
defer stdout.Close()
defer stderr.Close()
@@ -30,10 +32,14 @@ func Exec(cmd, workdir string, stdout, stderr io.WriteCloser) (int, error) {
if ok {
if status, ok := exitErr.Sys().(syscall.WaitStatus); ok {
code := status.ExitStatus()
- return code, fmt.Errorf("command terminated with exit status=%d", code)
+ return code, vmextension.NewErrorWithClarificationPtr(errorutil.CommandExecution_failureExitCode, fmt.Errorf("command terminated with exit status=%d", code))
}
}
- return 0, errors.Wrapf(err, "failed to execute command")
+ if err == nil {
+ return 0, nil
+ }
+
+ return 0, vmextension.NewErrorWithClarificationPtr(errorutil.CommandExecution_failedUnknownError, errors.Wrapf(err, "failed to execute command"))
}
// ExecCmdInDir executes the given command in given directory and saves output
@@ -42,20 +48,20 @@ func Exec(cmd, workdir string, stdout, stderr io.WriteCloser) (int, error) {
//
// Ideally, we execute commands only once per sequence number in custom-script-extension,
// and save their output under /var/lib/waagent/
/download//*.
-func ExecCmdInDir(cmd, workdir string) error {
+func ExecCmdInDir(cmd, workdir string) *vmextension.ErrorWithClarification {
outFn, errFn := logPaths(workdir)
outF, err := os.OpenFile(outFn, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600)
if err != nil {
- return errors.Wrapf(err, "failed to open stdout file")
+ return vmextension.NewErrorWithClarificationPtr(errorutil.Os_FailedToOpenStdOut, errors.Wrapf(err, "failed to open stdout file"))
}
errF, err := os.OpenFile(errFn, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600)
if err != nil {
- return errors.Wrapf(err, "failed to open stderr file")
+ return vmextension.NewErrorWithClarificationPtr(errorutil.Os_FailedToOpenStdErr, errors.Wrapf(err, "failed to open stderr file"))
}
- _, err = Exec(cmd, workdir, outF, errF)
- return err
+ _, ewc := Exec(cmd, workdir, outF, errF)
+ return ewc
}
// logPaths returns stdout and stderr file paths for the specified output
diff --git a/main/exec_test.go b/main/exec_test.go
index e01b34c..5eabc88 100644
--- a/main/exec_test.go
+++ b/main/exec_test.go
@@ -8,13 +8,14 @@ import (
"path/filepath"
"testing"
+ "github.com/Azure/custom-script-extension-linux/pkg/errorutil"
"github.com/stretchr/testify/require"
)
func TestExec_success(t *testing.T) {
v := new(mockFile)
ec, err := Exec("date", "/", v, v)
- require.Nil(t, err, "err: %v -- out: %s", err, v.b.Bytes())
+ require.Nil(t, err.Err, "err: %v -- out: %s", err.Err, v.b.Bytes())
require.EqualValues(t, 0, ec)
}
@@ -24,7 +25,7 @@ func TestExec_success_redirectsStdStreams_closesFds(t *testing.T) {
require.False(t, e.closed, "stderr open")
_, err := Exec("/bin/echo 'I am stdout!'>&1; /bin/echo 'I am stderr!'>&2", "/", o, e)
- require.Nil(t, err, "err: %v -- stderr: %s", err, e.b.Bytes())
+ require.Nil(t, err, "err: %v -- stderr: %s", err.Err, e.b.Bytes())
require.Equal(t, "I am stdout!\n", string(o.b.Bytes()))
require.Equal(t, "I am stderr!\n", string(e.b.Bytes()))
require.True(t, o.closed, "stdout closed")
@@ -33,15 +34,17 @@ func TestExec_success_redirectsStdStreams_closesFds(t *testing.T) {
func TestExec_failure_exitError(t *testing.T) {
ec, err := Exec("exit 12", "/", new(mockFile), new(mockFile))
- require.NotNil(t, err)
- require.EqualError(t, err, "command terminated with exit status=12") // error is customized
+ require.Equal(t, err.ErrorCode, errorutil.CommandExecution_failureExitCode)
+ require.NotNil(t, err.Err)
+ require.EqualError(t, err.Err, "command terminated with exit status=12") // error is customized
require.EqualValues(t, 12, ec)
}
func TestExec_failure_genericError(t *testing.T) {
_, err := Exec("date", "/non-existing-path", new(mockFile), new(mockFile))
- require.NotNil(t, err)
- require.Contains(t, err.Error(), "failed to execute command:") // error is wrapped
+ require.Equal(t, err.ErrorCode, errorutil.CommandExecution_failedUnknownError)
+ require.NotNil(t, err.Err)
+ require.Contains(t, err.Err.Error(), "failed to execute command:") // error is wrapped
}
func TestExec_failure_fdClosed(t *testing.T) {
@@ -49,8 +52,9 @@ func TestExec_failure_fdClosed(t *testing.T) {
require.Nil(t, out.Close())
_, err := Exec("date", "/", out, out)
- require.NotNil(t, err)
- require.Contains(t, err.Error(), "file closed") // error is wrapped
+ require.Equal(t, err.ErrorCode, errorutil.CommandExecution_failedUnknownError)
+ require.NotNil(t, err.Err)
+ require.Contains(t, err.Err.Error(), "file closed") // error is wrapped
}
func TestExec_failure_redirectsStdStreams_closesFds(t *testing.T) {
@@ -59,7 +63,8 @@ func TestExec_failure_redirectsStdStreams_closesFds(t *testing.T) {
require.False(t, e.closed, "stderr open")
_, err := Exec(`/bin/echo 'I am stdout!'>&1; /bin/echo 'I am stderr!'>&2; exit 12`, "/", o, e)
- require.NotNil(t, err)
+ require.Equal(t, err.ErrorCode, errorutil.CommandExecution_failureExitCode)
+ require.NotNil(t, err.Err)
require.Equal(t, "I am stdout!\n", string(o.b.Bytes()))
require.Equal(t, "I am stderr!\n", string(e.b.Bytes()))
require.True(t, o.closed, "stdout closed")
@@ -71,8 +76,8 @@ func TestExecCmdInDir(t *testing.T) {
require.Nil(t, err)
defer os.RemoveAll(dir)
- err = ExecCmdInDir("/bin/echo 'Hello world'", dir)
- require.Nil(t, err)
+ ewc := ExecCmdInDir("/bin/echo 'Hello world'", dir)
+ require.Nil(t, ewc)
require.True(t, fileExists(t, filepath.Join(dir, "stdout")), "stdout file should be created")
require.True(t, fileExists(t, filepath.Join(dir, "stderr")), "stderr file should be created")
@@ -87,7 +92,9 @@ func TestExecCmdInDir(t *testing.T) {
func TestExecCmdInDir_cantOpenError(t *testing.T) {
err := ExecCmdInDir("/bin/echo 'Hello world'", "/non-existing-dir")
- require.Contains(t, err.Error(), "failed to open stdout file")
+ require.Equal(t, err.ErrorCode, errorutil.Os_FailedToOpenStdErr)
+ require.NotNil(t, err.Err)
+ require.Contains(t, err.Err.Error(), "failed to open stdout file")
}
func TestExecCmdInDir_truncates(t *testing.T) {
diff --git a/main/files.go b/main/files.go
index 8e57090..dc7ce6e 100644
--- a/main/files.go
+++ b/main/files.go
@@ -9,48 +9,56 @@ import (
"os"
+ "github.com/Azure/azure-extension-platform/vmextension"
"github.com/Azure/custom-script-extension-linux/pkg/blobutil"
"github.com/Azure/custom-script-extension-linux/pkg/download"
"github.com/Azure/custom-script-extension-linux/pkg/preprocess"
"github.com/Azure/custom-script-extension-linux/pkg/urlutil"
"github.com/go-kit/kit/log"
"github.com/pkg/errors"
+
+ "github.com/Azure/custom-script-extension-linux/pkg/errorutil"
)
// downloadAndProcessURL downloads using the specified downloader and saves it to the
// specified existing directory, which must be the path to the saved file. Then
// it post-processes file based on heuristics.
-func downloadAndProcessURL(ctx *log.Context, url, downloadDir string, cfg *handlerSettings) error {
+func downloadAndProcessURL(ctx *log.Context, url, downloadDir string, cfg *handlerSettings) *vmextension.ErrorWithClarification {
fn, err := urlToFileName(url)
if err != nil {
- return err
+ return vmextension.NewErrorWithClarificationPtr(errorutil.CustomerInput_invalidFileUris, err)
}
if !urlutil.IsValidUrl(url) {
- return fmt.Errorf("[REDACTED] is not a valid url")
+ return vmextension.NewErrorWithClarificationPtr(errorutil.CustomerInput_invalidFileUris, fmt.Errorf("[REDACTED] is not a valid url"))
}
- dl, err := getDownloaders(url, cfg.StorageAccountName, cfg.StorageAccountKey, cfg.ManagedIdentity)
- if err != nil {
- return err
+ dl, ewc := getDownloaders(url, cfg.StorageAccountName, cfg.StorageAccountKey, cfg.ManagedIdentity)
+ if ewc != nil {
+ return ewc
}
fp := filepath.Join(downloadDir, fn)
const mode = 0500 // we assume users download scripts to execute
- if _, err := download.SaveTo(ctx, dl, fp, mode); err != nil {
- return err
+ if _, ewc := download.SaveTo(ctx, dl, fp, mode); ewc != nil {
+ return ewc
}
if cfg.SkipDos2Unix == false {
err = postProcessFile(fp)
}
- return errors.Wrapf(err, "failed to post-process '%s'", fn)
+
+ if err != nil {
+ return vmextension.NewErrorWithClarificationPtr(errorutil.SystemError, errors.Wrapf(err, "failed to post-process '%s'", fn))
+ }
+
+ return nil
}
// getDownloader returns a downloader for the given URL based on whether the
// storage credentials are empty or not.
func getDownloaders(fileURL string, storageAccountName, storageAccountKey string, managedIdentity *clientOrObjectId) (
- []download.Downloader, error) {
+ []download.Downloader, *vmextension.ErrorWithClarification) {
if storageAccountName == "" || storageAccountKey == "" {
// storage account name and key cannot be specified with managed identity, handler settings validation won't allow that
// handler settings validation will also not allow storageAccountName XOR storageAccountKey == 1
@@ -67,7 +75,7 @@ func getDownloaders(fileURL string, storageAccountName, storageAccountKey string
case managedIdentity.ClientId == "" && managedIdentity.ObjectId != "":
msiProvider = download.GetMsiProviderForStorageAccountsWithObjectId(fileURL, managedIdentity.ObjectId)
default:
- return nil, fmt.Errorf("unexpected combination of ClientId and ObjectId found")
+ return nil, vmextension.NewErrorWithClarificationPtr(errorutil.CustomerInput_clientIdObjectIdBothSpecified, fmt.Errorf("unexpected combination of ClientId and ObjectId found"))
}
return []download.Downloader{
// try downloading without MSI token first, but attempt with MSI if the download fails
@@ -83,11 +91,11 @@ func getDownloaders(fileURL string, storageAccountName, storageAccountKey string
// this preserves old behavior
blob, err := blobutil.ParseBlobURL(fileURL)
if err != nil {
- return nil, err
+ return nil, vmextension.NewErrorWithClarificationPtr(errorutil.CustomerInput_invalidFileUris, err)
}
return []download.Downloader{download.NewBlobDownload(
- storageAccountName, storageAccountKey,
- blob)}, nil
+ storageAccountName, storageAccountKey, blob)},
+ nil
}
}
diff --git a/main/files_test.go b/main/files_test.go
index 72e60a6..dd1bb05 100644
--- a/main/files_test.go
+++ b/main/files_test.go
@@ -124,8 +124,8 @@ func Test_downloadAndProcessURL(t *testing.T) {
defer os.RemoveAll(tmpDir)
cfg := handlerSettings{publicSettings{}, protectedSettings{StorageAccountName: "", StorageAccountKey: ""}}
- err = downloadAndProcessURL(log.NewContext(log.NewNopLogger()), srv.URL+"/bytes/256", tmpDir, &cfg)
- require.Nil(t, err)
+ ewc := downloadAndProcessURL(log.NewContext(log.NewNopLogger()), srv.URL+"/bytes/256", tmpDir, &cfg)
+ require.Nil(t, ewc)
fp := filepath.Join(tmpDir, "256")
fi, err := os.Stat(fp)
diff --git a/main/handlersettings.go b/main/handlersettings.go
index a15dee5..430ee9d 100644
--- a/main/handlersettings.go
+++ b/main/handlersettings.go
@@ -5,6 +5,8 @@ import (
"fmt"
"path/filepath"
+ "github.com/Azure/azure-extension-platform/vmextension"
+ "github.com/Azure/custom-script-extension-linux/pkg/errorutil"
"github.com/go-kit/kit/log"
"github.com/pkg/errors"
)
@@ -13,6 +15,7 @@ var (
errStoragePartialCredentials = errors.New("both 'storageAccountName' and 'storageAccountKey' must be specified")
errCmdTooMany = errors.New("'commandToExecute' was specified both in public and protected settings; it must be specified only once")
errScriptTooMany = errors.New("'script' was specified both in public and protected settings; it must be specified only once")
+ errFileUrisTooMany = errors.New("'fileUris' were specified both in public and protected settings; it must be specified only once")
errCmdAndScript = errors.New("'commandToExecute' and 'script' were both specified, but only one is validate at a time")
errCmdMissing = errors.New("'commandToExecute' is not specified")
errUsingBothKeyAndMsi = errors.New("'storageAccountName' or 'storageAccountKey' must not be specified with 'managedServiceIdentity'")
@@ -48,34 +51,38 @@ func (s *handlerSettings) fileUrls() []string {
// validate makes logical validation on the handlerSettings which already passed
// the schema validation.
-func (h handlerSettings) validate() error {
+func (h handlerSettings) validate() *vmextension.ErrorWithClarification {
if h.commandToExecute() == "" && h.script() == "" {
- return errCmdMissing
+ return vmextension.NewErrorWithClarificationPtr(errorutil.CustomerInput_commandToExecuteAndScriptNotSpecified, errCmdMissing)
}
if h.publicSettings.CommandToExecute != "" && h.protectedSettings.CommandToExecute != "" {
- return errCmdTooMany
+ return vmextension.NewErrorWithClarificationPtr(errorutil.CustomerInput_commandToExecuteSpecifiedInTwoPlaces, errCmdTooMany)
}
if h.publicSettings.Script != "" && h.protectedSettings.Script != "" {
- return errScriptTooMany
+ return vmextension.NewErrorWithClarificationPtr(errorutil.CustomerInput_scriptSpecifiedInTwoPlaces, errScriptTooMany)
+ }
+
+ if (h.publicSettings.FileURLs != nil && len(h.publicSettings.FileURLs) > 0) && (h.protectedSettings.FileURLs != nil && len(h.protectedSettings.FileURLs) > 0) {
+ return vmextension.NewErrorWithClarificationPtr(errorutil.CustomerInput_fileUrisSpecifiedInTwoPlaces, errFileUrisTooMany)
}
if h.commandToExecute() != "" && h.script() != "" {
- return errCmdAndScript
+ return vmextension.NewErrorWithClarificationPtr(errorutil.CustomerInput_commandToExecuteAndScriptBothSpecified, errCmdAndScript)
}
if (h.protectedSettings.StorageAccountName != "") !=
(h.protectedSettings.StorageAccountKey != "") {
- return errStoragePartialCredentials
+ return vmextension.NewErrorWithClarificationPtr(errorutil.CustomerInput_incompleteStorageCreds, errStoragePartialCredentials)
}
if (h.protectedSettings.StorageAccountKey != "" || h.protectedSettings.StorageAccountName != "") && h.protectedSettings.ManagedIdentity != nil {
- return errUsingBothKeyAndMsi
+ return vmextension.NewErrorWithClarificationPtr(errorutil.CustomerInput_storageCredsAndMIBothSpecified, errUsingBothKeyAndMsi)
}
if h.protectedSettings.ManagedIdentity != nil {
if h.protectedSettings.ManagedIdentity.ClientId != "" && h.protectedSettings.ManagedIdentity.ObjectId != "" {
- return errUsingBothClientIdAndObjectId
+ return vmextension.NewErrorWithClarificationPtr(errorutil.CustomerInput_clientIdObjectIdBothSpecified, errUsingBothClientIdAndObjectId)
}
}
@@ -113,29 +120,30 @@ func (self *clientOrObjectId) isEmpty() bool {
// parseAndValidateSettings reads configuration from configFolder, decrypts it,
// runs JSON-schema and logical validation on it and returns it back.
-func parseAndValidateSettings(ctx *log.Context, configFolder string, seqNum int) (h handlerSettings, _ error) {
+func parseAndValidateSettings(ctx *log.Context, configFolder string, seqNum int) (h handlerSettings, _ *vmextension.ErrorWithClarification) {
ctx.Log("event", "reading configuration")
pubJSON, protJSON, err := readSettings(configFolder, seqNum)
if err != nil {
- return h, err
+ return h, vmextension.NewErrorWithClarificationPtr(errorutil.Internal_badConfig, err)
}
ctx.Log("event", "read configuration")
ctx.Log("event", "validating json schema")
if err := validateSettingsSchema(pubJSON, protJSON); err != nil {
- return h, errors.Wrap(err, "json validation error")
+ return h, vmextension.NewErrorWithClarificationPtr(errorutil.Internal_badConfig, errors.Wrap(err, "json validation error"))
}
ctx.Log("event", "json schema valid")
ctx.Log("event", "parsing configuration json")
if err := UnmarshalHandlerSettings(pubJSON, protJSON, &h.publicSettings, &h.protectedSettings); err != nil {
- return h, errors.Wrap(err, "json parsing error")
+ return h, vmextension.NewErrorWithClarificationPtr(errorutil.Internal_badConfig, errors.Wrap(err, "json parsing error"))
}
ctx.Log("event", "parsed configuration json")
ctx.Log("event", "validating configuration logically")
- if err := h.validate(); err != nil {
- return h, errors.Wrap(err, "invalid configuration")
+ if ewc := h.validate(); err != nil {
+ ewc.Err = errors.Wrap(ewc.Err, "invalid configuration")
+ return h, ewc
}
ctx.Log("event", "validated configuration")
return h, nil
diff --git a/main/handlersettings_test.go b/main/handlersettings_test.go
index 3addd72..53f149a 100644
--- a/main/handlersettings_test.go
+++ b/main/handlersettings_test.go
@@ -3,38 +3,39 @@ package main
import (
"encoding/json"
"testing"
+
+ "github.com/stretchr/testify/require"
)
-import "github.com/stretchr/testify/require"
func Test_handlerSettingsValidate(t *testing.T) {
// commandToExecute not specified
require.Equal(t, errCmdMissing, handlerSettings{
publicSettings{},
protectedSettings{},
- }.validate())
+ }.validate().Err)
// commandToExecute specified twice
require.Equal(t, errCmdTooMany, handlerSettings{
publicSettings{CommandToExecute: "foo"},
protectedSettings{CommandToExecute: "foo"},
- }.validate())
+ }.validate().Err)
// script specified twice
require.Equal(t, errScriptTooMany, handlerSettings{
publicSettings{Script: "foo"},
protectedSettings{Script: "foo"},
- }.validate())
+ }.validate().Err)
// commandToExecute and script both specified
require.Equal(t, errCmdAndScript, handlerSettings{
publicSettings{CommandToExecute: "foo"},
protectedSettings{Script: "foo"},
- }.validate())
+ }.validate().Err)
require.Equal(t, errCmdAndScript, handlerSettings{
publicSettings{Script: "foo"},
protectedSettings{CommandToExecute: "foo"},
- }.validate())
+ }.validate().Err)
// storageAccount name specified; but not key
require.Equal(t, errStoragePartialCredentials, handlerSettings{
@@ -42,7 +43,7 @@ func Test_handlerSettingsValidate(t *testing.T) {
CommandToExecute: "date",
StorageAccountName: "foo",
StorageAccountKey: ""},
- }.validate())
+ }.validate().Err)
// storageAccount key specified; but not name
require.Equal(t, errStoragePartialCredentials, handlerSettings{
@@ -50,7 +51,7 @@ func Test_handlerSettingsValidate(t *testing.T) {
CommandToExecute: "date",
StorageAccountName: "",
StorageAccountKey: "foo"},
- }.validate())
+ }.validate().Err)
}
func Test_commandToExecutePrivateIfNotPublic(t *testing.T) {
@@ -90,20 +91,22 @@ func Test_skipDos2UnixDefaultsToFalse(t *testing.T) {
}
func Test_managedIdentityVerification(t *testing.T) {
- require.NoError(t, handlerSettings{publicSettings{}, protectedSettings{
+ err := handlerSettings{publicSettings{}, protectedSettings{
CommandToExecute: "echo hi",
FileURLs: []string{"file1", "file2"},
ManagedIdentity: &clientOrObjectId{
ClientId: "31b403aa-c364-4240-a7ff-d85fb6cd7232",
},
- }}.validate(), "validation failed for settings with MSI")
+ }}.validate()
+ require.Nil(t, err, "validation failed for settings with MSI")
- require.NoError(t, handlerSettings{publicSettings{}, protectedSettings{
+ err = handlerSettings{publicSettings{}, protectedSettings{
CommandToExecute: "echo hi",
ManagedIdentity: &clientOrObjectId{
ObjectId: "31b403aa-c364-4240-a7ff-d85fb6cd7232",
},
- }}.validate(), "validation failed for settings with MSI")
+ }}.validate()
+ require.Nil(t, err, "validation failed for settings with MSI")
require.Equal(t, errUsingBothKeyAndMsi,
handlerSettings{publicSettings{},
@@ -114,7 +117,7 @@ func Test_managedIdentityVerification(t *testing.T) {
ManagedIdentity: &clientOrObjectId{
ObjectId: "31b403aa-c364-4240-a7ff-d85fb6cd7232",
},
- }}.validate(), "validation didn't fail for settings with both MSI and storage account")
+ }}.validate().Err, "validation didn't fail for settings with both MSI and storage account")
require.Equal(t, errUsingBothClientIdAndObjectId,
handlerSettings{publicSettings{},
@@ -124,7 +127,7 @@ func Test_managedIdentityVerification(t *testing.T) {
ObjectId: "31b403aa-c364-4240-a7ff-d85fb6cd7232",
ClientId: "31b403aa-c364-4240-a7ff-d85fb6cd7232",
},
- }}.validate(), "validation didn't fail for settings with both MSI and storage account")
+ }}.validate().Err, "validation didn't fail for settings with both MSI and storage account")
}
func Test_toJSON_empty(t *testing.T) {
@@ -148,7 +151,7 @@ func Test_toJSONUmarshallForManagedIdentity(t *testing.T) {
require.NoError(t, err, "error while deserializing json")
require.Nil(t, protSettings.ManagedIdentity, "ProtectedSettings.ManagedIdentity was expected to be nil")
h := handlerSettings{publicSettings{}, *protSettings}
- require.NoError(t, h.validate(), "settings should be valid")
+ require.Nil(t, h.validate(), "settings should be valid")
testString = `{"commandToExecute" : "echo hello", "fileUris":["https://a.com/file.txt"], "managedIdentity": { }}`
require.NoError(t, validateProtectedSettings(testString), "protected settings should be valid")
@@ -159,7 +162,7 @@ func Test_toJSONUmarshallForManagedIdentity(t *testing.T) {
require.Equal(t, protSettings.ManagedIdentity.ClientId, "")
require.Equal(t, protSettings.ManagedIdentity.ObjectId, "")
h = handlerSettings{publicSettings{}, *protSettings}
- require.NoError(t, h.validate(), "settings should be valid")
+ require.Nil(t, h.validate(), "settings should be valid")
testString = `{"commandToExecute" : "echo hello", "fileUris":["https://a.com/file.txt", "https://b.com/file2.txt"], "managedIdentity": { "clientId": "31b403aa-c364-4240-a7ff-d85fb6cd7232"}}`
require.NoError(t, validateProtectedSettings(testString), "protected settings should be valid")
@@ -170,7 +173,7 @@ func Test_toJSONUmarshallForManagedIdentity(t *testing.T) {
require.Equal(t, protSettings.ManagedIdentity.ClientId, "31b403aa-c364-4240-a7ff-d85fb6cd7232")
require.Equal(t, protSettings.ManagedIdentity.ObjectId, "")
h = handlerSettings{publicSettings{}, *protSettings}
- require.NoError(t, h.validate(), "settings should be valid")
+ require.Nil(t, h.validate(), "settings should be valid")
testString = `{"commandToExecute" : "echo hello", "fileUris":["https://a.com/file.txt"], "managedIdentity": { "objectId": "31b403aa-c364-4240-a7ff-d85fb6cd7232"}}`
require.NoError(t, validateProtectedSettings(testString), "protected settings should be valid")
@@ -181,7 +184,7 @@ func Test_toJSONUmarshallForManagedIdentity(t *testing.T) {
require.Equal(t, protSettings.ManagedIdentity.ObjectId, "31b403aa-c364-4240-a7ff-d85fb6cd7232")
require.Equal(t, protSettings.ManagedIdentity.ClientId, "")
h = handlerSettings{publicSettings{}, *protSettings}
- require.NoError(t, h.validate(), "settings should be valid")
+ require.Nil(t, h.validate(), "settings should be valid")
testString = `{"commandToExecute" : "echo hello", "fileUris":["https://a.com/file.txt", "https://b.com/file2.txt"], "managedIdentity": { "clientId": "31b403aa-c364-4240-a7ff-d85fb6cd7232", "objectId": "41b403aa-c364-4240-a7ff-d85fb6cd7232"}}`
require.NoError(t, validateProtectedSettings(testString), "protected settings should be valid")
diff --git a/main/main.go b/main/main.go
index f104200..9c5b839 100644
--- a/main/main.go
+++ b/main/main.go
@@ -7,6 +7,7 @@ import (
"strings"
"github.com/go-kit/kit/log"
+ "github.com/pkg/errors"
)
var (
@@ -78,10 +79,11 @@ func main() {
}
// execute the subcommand
reportStatus(ctx, hEnv, seqNum, StatusTransitioning, cmd, "")
- msg, err := cmd.f(ctx, hEnv, seqNum)
- if err != nil {
- ctx.Log("event", "failed to handle", "error", err)
- reportStatus(ctx, hEnv, seqNum, StatusError, cmd, err.Error()+msg)
+ msg, ewc := cmd.f(ctx, hEnv, seqNum)
+ if ewc.Err != nil {
+ ctx.Log("event", "failed to handle", "error", ewc.Error())
+ ewc.Err = errors.Wrap(ewc.Err, ewc.Error()+msg)
+ reportErrorStatus(ctx, hEnv, seqNum, StatusError, cmd, ewc)
os.Exit(cmd.failExitCode)
}
reportStatus(ctx, hEnv, seqNum, StatusSuccess, cmd, msg)
diff --git a/main/status.go b/main/status.go
index 196feb4..7f9fa03 100644
--- a/main/status.go
+++ b/main/status.go
@@ -8,6 +8,8 @@ import (
"path/filepath"
"time"
+ status "github.com/Azure/azure-extension-platform/pkg/status"
+ vmextension "github.com/Azure/azure-extension-platform/vmextension"
"github.com/go-kit/kit/log"
"github.com/pkg/errors"
)
@@ -102,6 +104,31 @@ func reportStatus(ctx *log.Context, hEnv HandlerEnvironment, seqNum int, t Type,
return nil
}
+// reportErrorStatus saves the error(s) that occurred during the operation
+// to the status file for the extension handler with clarification messages and codes,
+// if the given cmd requires reporting status.
+//
+// If an error occurs reporting the status, it will be logged and returned.
+func reportErrorStatus(ctx *log.Context, hEnv HandlerEnvironment, seqNum int, t Type, c cmd, ewc *vmextension.ErrorWithClarification) error {
+ if !c.shouldReportStatus {
+ ctx.Log("status", "not reported for operation (by design)")
+ return nil
+ }
+ var err error
+ if ewc == nil {
+ s := NewStatus(t, c.name, statusMsg(c, t, ewc.Err.Error()))
+ err = s.Save(hEnv.HandlerEnvironment.StatusFolder, seqNum)
+ } else {
+ s := status.NewError(c.name, status.ErrorClarification{Code: ewc.ErrorCode, Message: ewc.Error()})
+ err = s.Save(hEnv.HandlerEnvironment.StatusFolder, uint(seqNum))
+ }
+ if err != nil {
+ ctx.Log("event", "failed to save handler status", "error", err)
+ return errors.Wrap(err, "failed to save handler status")
+ }
+ return nil
+}
+
// readStatus loads current status file in StatusReport
func readStatus(ctx *log.Context, hEnv HandlerEnvironment, seqNum int) (Type, error) {
fileName := fmt.Sprintf("%d.status", seqNum)
diff --git a/main/status_test.go b/main/status_test.go
index 770180d..333539e 100644
--- a/main/status_test.go
+++ b/main/status_test.go
@@ -1,11 +1,14 @@
package main
import (
+ "fmt"
"io/ioutil"
"os"
"path/filepath"
"testing"
+ vmextension "github.com/Azure/azure-extension-platform/vmextension"
+ "github.com/Azure/custom-script-extension-linux/pkg/errorutil"
"github.com/go-kit/kit/log"
"github.com/stretchr/testify/require"
)
@@ -46,6 +49,23 @@ func Test_reportStatus_fileExists(t *testing.T) {
require.NotEqual(t, 0, len(b), ".status file not empty")
}
+func Test_reportErrorStatus_fileExists(t *testing.T) {
+ tmpDir, err := ioutil.TempDir("", "")
+ require.Nil(t, err)
+ defer os.RemoveAll(tmpDir)
+
+ fakeEnv := HandlerEnvironment{}
+ fakeEnv.HandlerEnvironment.StatusFolder = tmpDir
+ ewc := vmextension.NewErrorWithClarificationPtr(errorutil.CommandExecution_failureExitCode, fmt.Errorf("command failed with exit code = 1"))
+
+ require.Nil(t, reportErrorStatus(log.NewContext(log.NewNopLogger()), fakeEnv, 1, StatusError, cmdEnable, ewc))
+
+ path := filepath.Join(tmpDir, "1.status")
+ b, err := ioutil.ReadFile(path)
+ require.Nil(t, err, ".status file exists")
+ require.NotEqual(t, 0, len(b), ".status file not empty")
+}
+
func Test_reportStatus_checksIfShouldBeReported(t *testing.T) {
for _, c := range cmds {
tmpDir, err := ioutil.TempDir("", "status-"+c.name)
diff --git a/pkg/download/blobwithmsitoken_test.go b/pkg/download/blobwithmsitoken_test.go
index ec33db0..7134cda 100644
--- a/pkg/download/blobwithmsitoken_test.go
+++ b/pkg/download/blobwithmsitoken_test.go
@@ -46,8 +46,8 @@ func Test_realDownloadBlobWithMsiToken(t *testing.T) {
err := json.Unmarshal([]byte(msiJson), &msi)
return msi, err
}}
- _, stream, err := Download(testctx, &downloader)
- require.NoError(t, err, "File download failed")
+ _, stream, ewc := Download(testctx, &downloader)
+ require.NoError(t, ewc.Err, "File download failed")
defer stream.Close()
bytes, err := ioutil.ReadAll(stream)
diff --git a/pkg/download/downloader.go b/pkg/download/downloader.go
index 9c89748..bb7589c 100644
--- a/pkg/download/downloader.go
+++ b/pkg/download/downloader.go
@@ -5,8 +5,11 @@ import (
"io"
"net"
"net/http"
+ "net/url"
"time"
+ "github.com/Azure/azure-extension-platform/vmextension"
+ "github.com/Azure/custom-script-extension-linux/pkg/errorutil"
"github.com/Azure/custom-script-extension-linux/pkg/urlutil"
"github.com/go-kit/kit/log"
@@ -20,8 +23,10 @@ type Downloader interface {
}
const (
- MsiDownload404ErrorString = "please ensure that the blob location in the fileUri setting exists, and the specified Managed Identity has read permissions to the storage blob"
- MsiDownload403ErrorString = "please ensure that the specified Managed Identity has read permissions to the storage blob"
+ MsiDownload404ErrorString = "please ensure that the blob location in the fileUri setting exists, and the specified Managed Identity has read permissions to the storage blob"
+ MsiDownload403ErrorString = "please ensure that the specified Managed Identity has read permissions to the storage blob"
+ MsiDownloadGenericErrorString = "unable to download the MSI. This may be due to firewall rules or a networking error"
+ MsiDownload500ErrorString = "the IMDS service returned a00 upon requesting the MSI"
)
var (
@@ -44,10 +49,10 @@ var (
// Download retrieves a response body and checks the response status code to see
// if it is 200 OK and then returns the response body. It issues a new request
// every time called. It is caller's responsibility to close the response body.
-func Download(ctx *log.Context, d Downloader) (int, io.ReadCloser, error) {
+func Download(ctx *log.Context, d Downloader) (int, io.ReadCloser, *vmextension.ErrorWithClarification) {
req, err := d.GetRequest()
if err != nil {
- return -1, nil, errors.Wrapf(err, "failed to create http request")
+ return -1, nil, vmextension.NewErrorWithClarificationPtr(errorutil.FileDownload_genericError, errors.Wrapf(err, "failed to create http request"))
}
requestID := req.Header.Get(xMsClientRequestIdHeaderName)
if len(requestID) > 0 {
@@ -55,8 +60,12 @@ func Download(ctx *log.Context, d Downloader) (int, io.ReadCloser, error) {
}
resp, err := httpClient.Do(req)
if err != nil {
+ if (err.(*url.Error)).Timeout() {
+ err = urlutil.RemoveUrlFromErr(err)
+ return -1, nil, vmextension.NewErrorWithClarificationPtr(errorutil.FileDownload_exceededTimeout, errors.Wrapf(err, "http request timed out"))
+ }
err = urlutil.RemoveUrlFromErr(err)
- return -1, nil, errors.Wrapf(err, "http request failed")
+ return -1, nil, vmextension.NewErrorWithClarificationPtr(errorutil.FileDownload_unknownError, errors.Wrapf(err, "http request failed"))
}
if resp.StatusCode == http.StatusOK {
@@ -64,14 +73,23 @@ func Download(ctx *log.Context, d Downloader) (int, io.ReadCloser, error) {
}
errString := ""
+ errClarificationCode := 0
requestId := resp.Header.Get(xMsServiceRequestIdHeaderName)
switch d.(type) {
case *blobWithMsiToken:
switch resp.StatusCode {
case http.StatusNotFound:
errString = MsiDownload404ErrorString
+ errClarificationCode = errorutil.Msi_notFound
case http.StatusForbidden:
errString = MsiDownload403ErrorString
+ errClarificationCode = errorutil.Msi_doesNotHaveRightPermissions
+ case http.StatusInternalServerError:
+ errString = MsiDownload500ErrorString
+ errClarificationCode = errorutil.Imds_internalMsiError
+ default:
+ errString = MsiDownloadGenericErrorString
+ errClarificationCode = errorutil.Msi_GenericRetrievalError
}
break
default:
@@ -81,28 +99,38 @@ func Download(ctx *log.Context, d Downloader) (int, io.ReadCloser, error) {
errString = fmt.Sprintf("CustomScript failed to download the file from %s because access was denied. Please fix the blob permissions and try again, the response code and message returned were: %q",
hostname,
resp.Status)
+ errClarificationCode = errorutil.FileDownload_accessDenied
case http.StatusNotFound:
errString = fmt.Sprintf("CustomScript failed to download the file from %s because it does not exist. Please create the blob and try again, the response code and message returned were: %q",
hostname,
resp.Status)
+ errClarificationCode = errorutil.FileDownload_doesNotExist
case http.StatusBadRequest:
errString = fmt.Sprintf("CustomScript failed to download the file from %s because parts of the request were incorrectly formatted, missing, and/or invalid. The response code and message returned were: %q",
hostname,
resp.Status)
+ errClarificationCode = errorutil.FileDownload_badRequest
case http.StatusInternalServerError:
errString = fmt.Sprintf("CustomScript failed to download the file from %s due to an issue with storage. The response code and message returned were: %q",
hostname,
resp.Status)
+ errClarificationCode = errorutil.Storage_internalServerError
default:
errString = fmt.Sprintf("CustomScript failed to download the file from %s because the server returned a response code and message of %q Please verify the machine has network connectivity.",
hostname,
resp.Status)
+ errClarificationCode = errorutil.FileDownload_networkingError
}
}
if len(requestId) > 0 {
errString += fmt.Sprintf(" (Service request ID: %s)", requestId)
}
- return resp.StatusCode, nil, fmt.Errorf("%s", errString)
+
+ if errClarificationCode == 0 {
+ return resp.StatusCode, nil, nil
+ }
+
+ return resp.StatusCode, nil, vmextension.NewErrorWithClarificationPtr(errClarificationCode, errors.New(errString))
}
diff --git a/pkg/download/downloader_test.go b/pkg/download/downloader_test.go
index 7c2bf2d..78ceda3 100644
--- a/pkg/download/downloader_test.go
+++ b/pkg/download/downloader_test.go
@@ -3,7 +3,7 @@ package download_test
import (
"errors"
"fmt"
- "io/ioutil"
+ "io"
"net/http"
"net/http/httptest"
"strings"
@@ -13,6 +13,7 @@ import (
"github.com/go-kit/kit/log"
"github.com/Azure/custom-script-extension-linux/pkg/download"
+ "github.com/Azure/custom-script-extension-linux/pkg/errorutil"
"github.com/ahmetalpbalkan/go-httpbin"
"github.com/stretchr/testify/require"
)
@@ -29,15 +30,17 @@ func (b *badDownloader) GetRequest() (*http.Request, error) {
}
func TestDownload_wrapsGetRequestError(t *testing.T) {
- _, _, err := download.Download(testctx, new(badDownloader))
- require.NotNil(t, err)
- require.EqualError(t, err, "failed to create http request: expected error")
+ _, _, ewc := download.Download(testctx, new(badDownloader))
+ require.Equal(t, ewc.ErrorCode, errorutil.FileDownload_genericError)
+ require.NotNil(t, ewc.Err)
+ require.EqualError(t, ewc.Err, "failed to create http request: expected error")
}
func TestDownload_wrapsHTTPError(t *testing.T) {
- _, _, err := download.Download(testctx, download.NewURLDownload("bad url"))
- require.NotNil(t, err)
- require.Contains(t, err.Error(), "http request failed:")
+ _, _, ewc := download.Download(testctx, download.NewURLDownload("bad url"))
+ require.Equal(t, ewc.ErrorCode, errorutil.FileDownload_unknownError)
+ require.NotNil(t, ewc.Err)
+ require.Contains(t, ewc.Err.Error(), "http request failed:")
}
// This test is only to make sure that formatting of error messages for specific codes is correct
@@ -53,20 +56,25 @@ func TestDownload_wrapsCommonErrorCodes(t *testing.T) {
http.StatusBadRequest,
http.StatusUnauthorized,
} {
- respCode, _, err := download.Download(testctx, download.NewURLDownload(fmt.Sprintf("%s/status/%d", srv.URL, code)))
- require.NotNil(t, err, "not failed for code:%d", code)
+ respCode, _, ewc := download.Download(testctx, download.NewURLDownload(fmt.Sprintf("%s/status/%d", srv.URL, code)))
+ require.NotNil(t, ewc.Err, "not failed for code:%d", code)
require.Equal(t, code, respCode)
switch respCode {
case http.StatusNotFound:
- require.Contains(t, err.Error(), "because it does not exist")
+ require.Equal(t, ewc.ErrorCode, errorutil.FileDownload_doesNotExist)
+ require.Contains(t, ewc.Err.Error(), "because it does not exist")
case http.StatusForbidden:
- require.Contains(t, err.Error(), "Please verify the machine has network connectivity")
+ require.Equal(t, ewc.ErrorCode, errorutil.FileDownload_networkingError)
+ require.Contains(t, ewc.Err.Error(), "Please verify the machine has network connectivity")
case http.StatusInternalServerError:
- require.Contains(t, err.Error(), "due to an issue with storage")
+ require.Equal(t, ewc.ErrorCode, errorutil.Storage_internalServerError)
+ require.Contains(t, ewc.Err.Error(), "due to an issue with storage")
case http.StatusBadRequest:
- require.Contains(t, err.Error(), "because parts of the request were incorrectly formatted, missing, and/or invalid")
+ require.Equal(t, ewc.ErrorCode, errorutil.FileDownload_badRequest)
+ require.Contains(t, ewc.Err.Error(), "because parts of the request were incorrectly formatted, missing, and/or invalid")
case http.StatusUnauthorized:
- require.Contains(t, err.Error(), "because access was denied")
+ require.Equal(t, ewc.ErrorCode, errorutil.FileDownload_accessDenied)
+ require.Contains(t, ewc.Err.Error(), "because access was denied")
}
}
}
@@ -75,8 +83,8 @@ func TestDownload_statusOKSucceeds(t *testing.T) {
srv := httptest.NewServer(httpbin.GetMux())
defer srv.Close()
- _, body, err := download.Download(testctx, download.NewURLDownload(srv.URL+"/status/200"))
- require.Nil(t, err)
+ _, body, ewc := download.Download(testctx, download.NewURLDownload(srv.URL+"/status/200"))
+ require.Nil(t, ewc)
defer body.Close()
require.NotNil(t, body)
}
@@ -90,27 +98,43 @@ func TestDowload_msiDownloaderErrorMessage(t *testing.T) {
msiDownloader404 := download.NewBlobWithMsiDownload(srv.URL+"/status/404", mockMsiProvider)
- returnCode, body, err := download.Download(testctx, msiDownloader404)
- require.True(t, strings.Contains(err.Error(), download.MsiDownload404ErrorString), "error string doesn't contain the correct message")
+ returnCode, body, ewc := download.Download(testctx, msiDownloader404)
+ require.Equal(t, ewc.ErrorCode, errorutil.Msi_notFound)
+ require.True(t, strings.Contains(ewc.Err.Error(), download.MsiDownload404ErrorString), "error string doesn't contain the correct message")
require.Nil(t, body, "body is not nil for failed download")
require.Equal(t, 404, returnCode, "return code was not 404")
msiDownloader403 := download.NewBlobWithMsiDownload(srv.URL+"/status/403", mockMsiProvider)
- returnCode, body, err = download.Download(testctx, msiDownloader403)
- require.True(t, strings.Contains(err.Error(), download.MsiDownload403ErrorString), "error string doesn't contain the correct message")
+ returnCode, body, ewc = download.Download(testctx, msiDownloader403)
+ require.Equal(t, ewc.ErrorCode, errorutil.Msi_doesNotHaveRightPermissions)
+ require.True(t, strings.Contains(ewc.Err.Error(), download.MsiDownload403ErrorString), "error string doesn't contain the correct message")
require.Nil(t, body, "body is not nil for failed download")
require.Equal(t, 403, returnCode, "return code was not 403")
+ msiDownloade500 := download.NewBlobWithMsiDownload(srv.URL+"/status/500", mockMsiProvider)
+ returnCode, body, ewc = download.Download(testctx, msiDownloade500)
+ require.Equal(t, ewc.ErrorCode, errorutil.Imds_internalMsiError)
+ require.True(t, strings.Contains(ewc.Err.Error(), download.MsiDownload500ErrorString), "error string doesn't contain the correct message")
+ require.Nil(t, body, "body is not nil for failed download")
+ require.Equal(t, 500, returnCode, "return code was not 500")
+
+ msiDownloader400 := download.NewBlobWithMsiDownload(srv.URL+"/status/400", mockMsiProvider)
+ returnCode, body, ewc = download.Download(testctx, msiDownloader400)
+ require.Equal(t, ewc.ErrorCode, errorutil.Msi_GenericRetrievalError)
+ require.True(t, strings.Contains(ewc.Err.Error(), download.MsiDownloadGenericErrorString), "error string doesn't contain the correct message")
+ require.Nil(t, body, "body is not nil for failed download")
+ require.Equal(t, 400, returnCode, "return code was not 400")
+
}
func TestDownload_retrievesBody(t *testing.T) {
srv := httptest.NewServer(httpbin.GetMux())
defer srv.Close()
- _, body, err := download.Download(testctx, download.NewURLDownload(srv.URL+"/bytes/65536"))
- require.Nil(t, err)
+ _, body, ewc := download.Download(testctx, download.NewURLDownload(srv.URL+"/bytes/65536"))
+ require.Nil(t, ewc)
defer body.Close()
- b, err := ioutil.ReadAll(body)
+ b, err := io.ReadAll(body)
require.Nil(t, err)
require.EqualValues(t, 65536, len(b))
}
@@ -119,7 +143,7 @@ func TestDownload_bodyClosesWithoutError(t *testing.T) {
srv := httptest.NewServer(httpbin.GetMux())
defer srv.Close()
- _, body, err := download.Download(testctx, download.NewURLDownload(srv.URL+"/get"))
- require.Nil(t, err)
+ _, body, ewc := download.Download(testctx, download.NewURLDownload(srv.URL+"/get"))
+ require.Nil(t, ewc)
require.Nil(t, body.Close(), "body should close fine")
}
diff --git a/pkg/download/retry.go b/pkg/download/retry.go
index 6999119..515b36e 100644
--- a/pkg/download/retry.go
+++ b/pkg/download/retry.go
@@ -8,7 +8,10 @@ import (
"os"
"time"
+ "github.com/Azure/azure-extension-platform/vmextension"
"github.com/go-kit/kit/log"
+
+ errorutil "github.com/Azure/custom-script-extension-linux/pkg/errorutil"
)
// SleepFunc pauses the execution for at least duration d.
@@ -33,17 +36,19 @@ const (
// closed on failures). If the retries do not succeed, the last error is returned.
//
// It sleeps in exponentially increasing durations between retries.
-func WithRetries(ctx *log.Context, f *os.File, downloaders []Downloader, sf SleepFunc) (int64, error) {
+func WithRetries(ctx *log.Context, f *os.File, downloaders []Downloader, sf SleepFunc) (int64, *vmextension.ErrorWithClarification) {
var lastErr error
+ var lastErrCode int
for _, d := range downloaders {
for n := 0; n < expRetryN; n++ {
ctx := ctx.With("retry", n)
// reset the last error before each retry
lastErr = nil
+ lastErrCode = 0
start := time.Now()
- status, out, err := Download(ctx, d)
- if err == nil {
+ status, out, ewc := Download(ctx, d)
+ if ewc == nil {
// server returned status code 200 OK
// we have a response body, copy it to the file
nBytes, innerErr := io.CopyBuffer(f, out, make([]byte, writeBufSize))
@@ -62,11 +67,13 @@ func WithRetries(ctx *log.Context, f *os.File, downloaders []Downloader, sf Slee
// clear out the contents of the file so as to not leave a partial file
f.Truncate(0)
// cache the inner error
+ lastErrCode = errorutil.FileDownload_genericError
lastErr = innerErr
}
} else {
// cache the outer error
- lastErr = err
+ lastErr = ewc.Err
+ lastErrCode = ewc.ErrorCode
}
// we are here because either server returned a non-200 status code
@@ -94,7 +101,12 @@ func WithRetries(ctx *log.Context, f *os.File, downloaders []Downloader, sf Slee
}
}
}
- return 0, lastErr
+
+ if lastErr == nil {
+ return 0, nil
+ }
+
+ return 0, vmextension.NewErrorWithClarificationPtr(lastErrCode, lastErr)
}
func isTransientHttpStatusCode(statusCode int) bool {
diff --git a/pkg/download/retry_test.go b/pkg/download/retry_test.go
index aeb6b69..267ca1d 100644
--- a/pkg/download/retry_test.go
+++ b/pkg/download/retry_test.go
@@ -2,7 +2,7 @@ package download_test
import (
"fmt"
- "io/ioutil"
+ // "io/ioutil"
"net/http"
"net/http/httptest"
"os"
@@ -65,7 +65,7 @@ func TestWithRetries_failing_validateNumberOfCalls(t *testing.T) {
sr := new(sleepRecorder)
n, err := download.WithRetries(nopLog(), file, []download.Downloader{bd}, sr.Sleep)
- require.Contains(t, err.Error(), "expected error", "error is preserved")
+ require.Contains(t, err.Err.Error(), "expected error", "error is preserved")
require.EqualValues(t, 0, n, "downloaded number of bytes should be zero")
require.EqualValues(t, 7, bd.calls, "calls exactly expRetryN times")
}
@@ -82,8 +82,8 @@ func TestWithRetries_failingBadStatusCode_validateSleeps(t *testing.T) {
sr := new(sleepRecorder)
n, err := download.WithRetries(nopLog(), file, []download.Downloader{d}, sr.Sleep)
- require.Contains(t, err.Error(), "429 Too Many Requests")
- require.Contains(t, err.Error(), "Please verify the machine has network connectivity")
+ require.Contains(t, err.Err.Error(), "429 Too Many Requests")
+ require.Contains(t, err.Err.Error(), "Please verify the machine has network connectivity")
require.EqualValues(t, 0, n, "downloaded number of bytes should be zero")
require.Equal(t, sleepSchedule, []time.Duration(*sr))
}
@@ -168,8 +168,8 @@ func TestRetriesWith_LargeFileThatTimesOutWhileDownloading(t *testing.T) {
largeFileDownloader := mockDownloader{0, srv.URL + "/bytes/" + fmt.Sprintf("%d", size)}
sr := new(sleepRecorder)
- n, err := download.WithRetries(nopLog(), file, []download.Downloader{&largeFileDownloader}, sr.Sleep)
- require.NotNil(t, err, "download with retries should fail because of server timeout")
+ n, ewc := download.WithRetries(nopLog(), file, []download.Downloader{&largeFileDownloader}, sr.Sleep)
+ require.NotNil(t, ewc.Err, "download with retries should fail because of server timeout")
require.EqualValues(t, 0, n, "downloaded number of bytes should be zero")
fi, err := file.Stat()
@@ -178,8 +178,8 @@ func TestRetriesWith_LargeFileThatTimesOutWhileDownloading(t *testing.T) {
}
func CreateTestFile(t *testing.T) (string, *os.File) {
- dir, err := ioutil.TempDir("", "")
- require.Nil(t, err)
+ dir := os.TempDir()
+ // require.Nil(t, err)
path := filepath.Join(dir, "test-file")
diff --git a/pkg/download/save.go b/pkg/download/save.go
index 7273138..c74a57f 100644
--- a/pkg/download/save.go
+++ b/pkg/download/save.go
@@ -3,6 +3,8 @@ package download
import (
"os"
+ "github.com/Azure/azure-extension-platform/vmextension"
+ "github.com/Azure/custom-script-extension-linux/pkg/errorutil"
"github.com/go-kit/kit/log"
"github.com/pkg/errors"
)
@@ -11,16 +13,18 @@ import (
// given file. Directory of dst is not created by this function. If a file at
// dst exists, it will be truncated. If a new file is created, mode is used to
// set the permission bits. Written number of bytes are returned on success.
-func SaveTo(ctx *log.Context, d []Downloader, dst string, mode os.FileMode) (int64, error) {
+func SaveTo(ctx *log.Context, d []Downloader, dst string, mode os.FileMode) (int64, *vmextension.ErrorWithClarification) {
f, err := os.OpenFile(dst, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, mode)
if err != nil {
- return 0, errors.Wrap(err, "failed to open file for writing")
+ return 0, vmextension.NewErrorWithClarificationPtr(errorutil.FileDownload_unknownError, errors.Wrap(err, "failed to open file for writing"))
+
}
defer f.Close()
- n, err := WithRetries(ctx, f, d, ActualSleep)
- if err != nil {
- return n, errors.Wrapf(err, "failed to download response and write to file: %s", dst)
+ n, ewc := WithRetries(ctx, f, d, ActualSleep)
+ if ewc != nil {
+ return n, vmextension.NewErrorWithClarificationPtr(ewc.ErrorCode, errors.Wrapf(ewc.Err, "failed to download response and write to file: %s", dst))
+
}
return n, nil
diff --git a/pkg/download/save_test.go b/pkg/download/save_test.go
index 3bbfb9f..3b16de3 100644
--- a/pkg/download/save_test.go
+++ b/pkg/download/save_test.go
@@ -33,8 +33,8 @@ func TestSave(t *testing.T) {
d := download.NewURLDownload(srv.URL + "/bytes/65536")
path := filepath.Join(dir, "test-file")
- n, err := download.SaveTo(nopLog(), []download.Downloader{d}, path, 0600)
- require.Nil(t, err)
+ n, ewc := download.SaveTo(nopLog(), []download.Downloader{d}, path, 0600)
+ require.Nil(t, ewc.Err)
require.EqualValues(t, 65536, n)
fi, err := os.Stat(path)
@@ -52,10 +52,10 @@ func TestSave_truncates(t *testing.T) {
defer os.RemoveAll(dir)
path := filepath.Join(dir, "test-file")
- _, err = download.SaveTo(nopLog(), []download.Downloader{download.NewURLDownload(srv.URL + "/bytes/65536")}, path, 0600)
- require.Nil(t, err)
- _, err = download.SaveTo(nopLog(), []download.Downloader{download.NewURLDownload(srv.URL + "/bytes/128")}, path, 0777)
- require.Nil(t, err)
+ _, ewc := download.SaveTo(nopLog(), []download.Downloader{download.NewURLDownload(srv.URL + "/bytes/65536")}, path, 0600)
+ require.Nil(t, ewc.Err)
+ _, ewc = download.SaveTo(nopLog(), []download.Downloader{download.NewURLDownload(srv.URL + "/bytes/128")}, path, 0777)
+ require.Nil(t, ewc.Err)
fi, err := os.Stat(path)
require.Nil(t, err)
@@ -74,8 +74,8 @@ func TestSave_largeFile(t *testing.T) {
size := 1024 * 1024 * 128 // 128 mb
path := filepath.Join(dir, "large-file")
- n, err := download.SaveTo(nopLog(), []download.Downloader{download.NewURLDownload(srv.URL + "/bytes/" + fmt.Sprintf("%d", size))}, path, 0600)
- require.Nil(t, err)
+ n, ewc := download.SaveTo(nopLog(), []download.Downloader{download.NewURLDownload(srv.URL + "/bytes/" + fmt.Sprintf("%d", size))}, path, 0600)
+ require.Nil(t, ewc.Err)
require.EqualValues(t, size, n)
fi, err := os.Stat(path)
diff --git a/pkg/errorutil/errorclarificationcodes.go b/pkg/errorutil/errorclarificationcodes.go
new file mode 100644
index 0000000..e823017
--- /dev/null
+++ b/pkg/errorutil/errorclarificationcodes.go
@@ -0,0 +1,56 @@
+package errorutil
+
+import (
+ "math"
+)
+
+const (
+ // System errors
+ FileDownload_badRequest int = -41
+ FileDownload_unknownError int = -40
+
+ Imds_internalMsiError int = -30
+
+ Internal_badConfig int = -21
+ Internal_couldNotFindCertificate int = -20
+
+ Os_FailedToDeleteDataDir int = -50
+ Os_FailedToOpenStdOut int = -51
+ Os_FailedToOpenStdErr int = -52
+
+ Storage_internalServerError int = -1
+ SystemError int = 0 // CRP interprets anything > 0 as user errors
+
+ // User errors
+ CommandExecution_failedUnknownError int = 1
+ CommandExecution_failureExitCode int = 2
+ CommandExecution_interruptedByVmShutdown int = 3
+
+ CustomerInput_commandToExecuteSpecifiedInTwoPlaces int = 20
+ CustomerInput_fileUrisSpecifiedInTwoPlaces int = 22
+ CustomerInput_commandToExecuteAndScriptNotSpecified int = 23
+ CustomerInput_fileUriContainsNull int = 24
+ CustomerInput_invalidFileUris int = 25
+ CustomerInput_storageCredsAndMIBothSpecified int = 26
+ CustomerInput_clientIdObjectIdBothSpecified int = 27
+ CustomerInput_scriptSpecifiedInTwoPlaces int = 28
+ CustomerInput_commandToExecuteAndScriptBothSpecified int = 29
+ CustomerInput_incompleteStorageCreds int = 30
+
+ FileDownload_unableToCreateDownloadDirectory int = 50
+ FileDownload_sasExpired int = 51
+ FileDownload_accessDenied int = 52
+ FileDownload_doesNotExist int = 53
+ FileDownload_networkingError int = 54
+ FileDownload_genericError int = 55
+ FileDownload_exceededTimeout int = 56
+
+ Msi_notFound int = 70
+ Msi_doesNotHaveRightPermissions int = 71
+ Msi_GenericRetrievalError int = 72
+
+ // No Error - used as a placeholder value
+ // when representing an "empty" ErrorWithClarification
+ // or when the error can be treated without the clarification
+ NoError int = math.MaxInt
+)