diff --git a/initializer/initializer.go b/initializer/initializer.go index 37ce4c15..15c41acd 100644 --- a/initializer/initializer.go +++ b/initializer/initializer.go @@ -157,9 +157,20 @@ var ( metricsWorkPool, readWorkPool *workpool.WorkPool ) -func Initialize(logger lager.Logger, config ExecutorConfig, cellID, zone string, - rootFSes map[string]string, metronClient loggingclient.IngressClient, - clock clock.Clock) (executor.Client, *containermetrics.StatsReporter, grouper.Members, error) { +func Initialize( + logger lager.Logger, + config ExecutorConfig, + cellID string, + zone string, + rootFSes map[string]string, + metronClient loggingclient.IngressClient, + clock clock.Clock, +) ( + executor.Client, + *containermetrics.StatsReporter, + grouper.Members, + error, +) { var gardenHealthcheckRootFS string for _, rootFSPath := range rootFSes { @@ -218,15 +229,18 @@ func Initialize(logger lager.Logger, config ExecutorConfig, cellID, zone string, return nil, nil, grouper.Members{}, err } - downloader := cacheddownloader.NewDownloader(10*time.Minute, int(math.MaxInt8), assetTLSConfig) + downloader := cacheddownloader.NewDownloader(10*time.Minute, math.MaxInt8, assetTLSConfig) uploader := uploader.New(logger, 10*time.Minute, assetTLSConfig) cache := cacheddownloader.NewCache(config.CachePath, int64(config.MaxCacheSizeInBytes)) - cachedDownloader := cacheddownloader.New( + cachedDownloader, err := cacheddownloader.New( downloader, cache, cacheddownloader.TarTransform, ) + if err != nil { + return nil, nil, grouper.Members{}, err + } err = cachedDownloader.RecoverState(logger.Session("downloader")) if err != nil { diff --git a/initializer/initializer_test.go b/initializer/initializer_test.go index 435234f3..67b5c816 100644 --- a/initializer/initializer_test.go +++ b/initializer/initializer_test.go @@ -1,12 +1,14 @@ package initializer_test import ( + "code.cloudfoundry.org/cacheddownloader" "crypto/tls" "crypto/x509" "encoding/asn1" "encoding/pem" "errors" "fmt" + "math" "net/http" "os" "path/filepath" @@ -24,7 +26,7 @@ import ( "code.cloudfoundry.org/executor/initializer/configuration" "code.cloudfoundry.org/executor/initializer/fakes" "code.cloudfoundry.org/garden" - loggregator "code.cloudfoundry.org/go-loggregator/v9" + "code.cloudfoundry.org/go-loggregator/v9" "code.cloudfoundry.org/lager/v3" "code.cloudfoundry.org/lager/v3/lagertest" . "github.com/onsi/ginkgo/v2" @@ -45,7 +47,7 @@ var _ = Describe("Initializer", func() { logger lager.Logger fakeMetronClient *mfakes.FakeIngressClient metricMap map[string]time.Duration - m sync.RWMutex + mutex sync.RWMutex ) BeforeEach(func() { @@ -105,7 +107,7 @@ var _ = Describe("Initializer", func() { fakeMetronClient = new(mfakes.FakeIngressClient) - m = sync.RWMutex{} + mutex = sync.RWMutex{} }) AfterEach(func() { @@ -114,8 +116,8 @@ var _ = Describe("Initializer", func() { }) getMetrics := func() map[string]time.Duration { - m.Lock() - defer m.Unlock() + mutex.Lock() + defer mutex.Unlock() m := make(map[string]time.Duration, len(metricMap)) for k, v := range metricMap { m[k] = v @@ -135,9 +137,9 @@ var _ = Describe("Initializer", func() { metricMap = make(map[string]time.Duration) fakeMetronClient.SendDurationStub = func(name string, time time.Duration, opts ...loggregator.EmitGaugeOption) error { - m.Lock() + mutex.Lock() metricMap[name] = time - m.Unlock() + mutex.Unlock() return nil } @@ -698,4 +700,40 @@ var _ = Describe("Initializer", func() { }) }) }) + + Describe("CachedDownloader", func() { + Context("when cacheddownloader.New receives a malformed cache.CachedPath", func() { + It("returns an error", func() { + logger := lagertest.NewTestLogger("executor") + fakeCertPoolRetriever := &fakes.FakeCertPoolRetriever{} + config.PathToTLSCert = "fixtures/downloader/client.crt" + config.PathToTLSKey = "fixtures/downloader/client.key" + config.PathToTLSCACert = "fixtures/downloader/ca.crt" + config.CachePath = "" + + fakeCertPoolRetriever.SystemCertsReturns(x509.NewCertPool(), nil) + certBytes, err := os.ReadFile(config.PathToTLSCACert) + Expect(err).NotTo(HaveOccurred()) + block, _ := pem.Decode(certBytes) + _, err = x509.ParseCertificate(block.Bytes) + Expect(err).NotTo(HaveOccurred()) + + tlsConfig, err := initializer.TLSConfigFromConfig(logger, fakeCertPoolRetriever, config) + Expect(err).To(Succeed()) + Expect(tlsConfig).NotTo(BeNil()) + + newDownloader := cacheddownloader.NewDownloader(10*time.Minute, math.MaxInt8, tlsConfig) + newCache := cacheddownloader.NewCache(config.CachePath, int64(config.MaxCacheSizeInBytes)) + + newCachedDownloader, err := cacheddownloader.New( + newDownloader, + newCache, + cacheddownloader.TarTransform, + ) + + Expect(newCachedDownloader).To(BeNil()) + Expect(err.Error()).To(ContainSubstring("could not create cache path")) + }) + }) + }) })