From 0b6e85ae378d0bafc5ca23866da4f6a08ea5e972 Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Mon, 29 Sep 2025 17:16:08 +0100 Subject: [PATCH] Add a prototype Edge Worker client to the Go SDK As the Edeg API server is currently implemented we need to "pretend" to be a specific Airflow and Edge provider version. These default to the currently released versions, but can be changed via env vars to work elsewhere. This works enough to run tasks, but there might need to be some changes to the Edge API to support non-python clients (for example, working out the versioning strategy to make it long-term supportible and not need the Airflow and Edge Provider version to match 100%). A chunk of the changes here are to make the config and global setup "more well structured" -- so that they are suitable to be easily called from multiple workers (Celery and Go, useful if we don't end up keeping the Celery worker) --- go-sdk/Justfile | 5 +- .../bundle/bundlev1/bundlev1server/server.go | 3 +- go-sdk/celery/cmd/main.go | 11 +- go-sdk/celery/commands/root.go | 57 +- go-sdk/edge/cmd/main.go | 37 + go-sdk/edge/commands/root.go | 45 + go-sdk/edge/commands/run.go | 68 + go-sdk/edge/worker.go | 456 +++++++ go-sdk/go.mod | 2 +- go-sdk/go.sum | 4 +- .../pkg/{bundles/shared => config}/config.go | 93 +- go-sdk/pkg/edgeapi/client.gen.go | 1207 +++++++++++++++++ go-sdk/pkg/edgeapi/client.go | 62 + go-sdk/pkg/edgeapi/oapi-codegen.yml | 34 + go-sdk/pkg/edgeapi/overlay.yml | 29 + go-sdk/pkg/logging/level.go | 23 + go-sdk/pkg/logging/shclog/shclog.go | 9 +- go-sdk/pkg/worker/runner.go | 7 - 18 files changed, 2076 insertions(+), 76 deletions(-) create mode 100644 go-sdk/edge/cmd/main.go create mode 100644 go-sdk/edge/commands/root.go create mode 100644 go-sdk/edge/commands/run.go create mode 100644 go-sdk/edge/worker.go rename go-sdk/pkg/{bundles/shared => config}/config.go (58%) create mode 100644 go-sdk/pkg/edgeapi/client.gen.go create mode 100644 go-sdk/pkg/edgeapi/client.go create mode 100644 go-sdk/pkg/edgeapi/oapi-codegen.yml create mode 100644 go-sdk/pkg/edgeapi/overlay.yml create mode 100644 go-sdk/pkg/logging/level.go diff --git a/go-sdk/Justfile b/go-sdk/Justfile index cbbd3e87c2af5..4f70f70574259 100644 --- a/go-sdk/Justfile +++ b/go-sdk/Justfile @@ -23,12 +23,15 @@ default: @just --list # Build all components -build: build-celery build-examples +build: build-celery build-edge build-examples # Build the worker binary build-celery: go build -o bin/airflow-go-celery ./celery/cmd +build-edge: + go build -o bin/airflow-go-edge ./edge/cmd + # Build all example bundles build-examples: @just example/bundle/build diff --git a/go-sdk/bundle/bundlev1/bundlev1server/server.go b/go-sdk/bundle/bundlev1/bundlev1server/server.go index 42028eb6af4da..67d212ff06976 100644 --- a/go-sdk/bundle/bundlev1/bundlev1server/server.go +++ b/go-sdk/bundle/bundlev1/bundlev1server/server.go @@ -31,6 +31,7 @@ import ( "github.com/apache/airflow/go-sdk/bundle/bundlev1" "github.com/apache/airflow/go-sdk/bundle/bundlev1/bundlev1server/impl" "github.com/apache/airflow/go-sdk/pkg/bundles/shared" + "github.com/apache/airflow/go-sdk/pkg/config" ) var versionInfo *bool = flag.Bool("bundle-metadata", false, "show the embedded bundle info") @@ -56,7 +57,7 @@ type ServerConfig struct{} // Zero or more options to configure the server may also be passed. There are no options yet, this is to allow // future changes without breaking compatibility func Serve(bundle bundlev1.BundleProvider, opts ...ServeOpt) error { - shared.SetupViper("") + config.SetupViper("") hcLogger := hclog.New(&hclog.LoggerOptions{ Level: hclog.Trace, diff --git a/go-sdk/celery/cmd/main.go b/go-sdk/celery/cmd/main.go index 1ae0ddb0ce589..6eb4deb3dcf74 100644 --- a/go-sdk/celery/cmd/main.go +++ b/go-sdk/celery/cmd/main.go @@ -17,8 +17,17 @@ package main -import "github.com/apache/airflow/go-sdk/celery/commands" +import ( + "fmt" + "os" + + "github.com/apache/airflow/go-sdk/celery/commands" +) func main() { + if os.Getenv("AIRFLOW_BUNDLE_MAGIC_COOKIE") != "" { + fmt.Println("(We're not a bundle plugin)") + os.Exit(0) + } commands.Execute() } diff --git a/go-sdk/celery/commands/root.go b/go-sdk/celery/commands/root.go index 2b723c0c1650c..7882a6723ba29 100644 --- a/go-sdk/celery/commands/root.go +++ b/go-sdk/celery/commands/root.go @@ -19,20 +19,13 @@ package commands import ( "context" - "log/slog" "os" - "github.com/MatusOllah/slogcolor" - "github.com/fatih/color" - cc "github.com/ivanpirog/coloredcobra" "github.com/spf13/cobra" - "github.com/apache/airflow/go-sdk/pkg/bundles/shared" - "github.com/apache/airflow/go-sdk/pkg/logging/shclog" + "github.com/apache/airflow/go-sdk/pkg/config" ) -var cfgFile string - // rootCmd represents the base command when called without any subcommands var rootCmd = &cobra.Command{ Use: "airflow-go-celery", @@ -42,21 +35,12 @@ var rootCmd = &cobra.Command{ All options (other than ` + "`--config`" + `) can be specified in the config file using the same name as the CLI argument but without the ` + "`--`" + ` prefix.`, PersistentPreRunE: func(cmd *cobra.Command, args []string) error { - return initializeConfig(cmd) + return config.Configure(cmd) }, } // Execute is the main entrypoint, and runs the Celery broker app and listens for Celery Tasks func Execute() { - cc.Init(&cc.Config{ - RootCmd: rootCmd, - Headings: cc.Bold, - Commands: cc.Yellow + cc.Bold, - Example: cc.Italic, - ExecName: cc.HiMagenta + cc.Bold, - Flags: cc.Green, - FlagsDataType: cc.Italic + cc.White, - }) err := rootCmd.ExecuteContext(context.Background()) if err != nil { os.Exit(1) @@ -64,39 +48,8 @@ func Execute() { } func init() { - rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", - "config file (default is $HOME/airflow/go-sdk.yaml)") + config.InitColor(rootCmd) + rootCmd.PersistentFlags(). + String("config", "", "config file (default is $HOME/airflow/go-sdk.yaml)") rootCmd.AddCommand(runCmd) } - -// initConfig reads in config file and ENV variables if set. -func initializeConfig(cmd *cobra.Command) error { - if err := shared.SetupViper(cfgFile); err != nil { - return err - } - // Bind the current command's flags to viper - shared.BindFlagsToViper(cmd) - - logger := makeLogger() - slog.SetDefault(logger) - - return nil -} - -func makeLogger() *slog.Logger { - opts := *slogcolor.DefaultOptions - leveler := &slog.LevelVar{} - leveler.Set(shclog.SlogLevelTrace) - - opts.Level = leveler - opts.LevelTags = map[slog.Level]string{ - shclog.SlogLevelTrace: color.New(color.FgHiGreen).Sprint("TRACE"), - slog.LevelDebug: color.New(color.BgCyan, color.FgHiWhite).Sprint("DEBUG"), - slog.LevelInfo: color.New(color.BgGreen, color.FgHiWhite).Sprint("INFO "), - slog.LevelWarn: color.New(color.BgYellow, color.FgHiWhite).Sprint("WARN "), - slog.LevelError: color.New(color.BgRed, color.FgHiWhite).Sprint("ERROR"), - } - - log := slog.New(slogcolor.NewHandler(os.Stderr, &opts)) - return log -} diff --git a/go-sdk/edge/cmd/main.go b/go-sdk/edge/cmd/main.go new file mode 100644 index 0000000000000..ac50572639146 --- /dev/null +++ b/go-sdk/edge/cmd/main.go @@ -0,0 +1,37 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package main + +import ( + "context" + "fmt" + "os" + + "github.com/apache/airflow/go-sdk/edge/commands" +) + +func main() { + if os.Getenv("AIRFLOW_BUNDLE_MAGIC_COOKIE") != "" { + fmt.Println("(We're not a bundle plugin)") + os.Exit(0) + } + err := commands.Root.ExecuteContext(context.Background()) + if err != nil { + os.Exit(1) + } +} diff --git a/go-sdk/edge/commands/root.go b/go-sdk/edge/commands/root.go new file mode 100644 index 0000000000000..0a97a4e1e23c1 --- /dev/null +++ b/go-sdk/edge/commands/root.go @@ -0,0 +1,45 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package commands + +import ( + "github.com/spf13/cobra" + + "github.com/apache/airflow/go-sdk/pkg/config" +) + +// Root represents the base command when called without any subcommands +var Root = &cobra.Command{ + Use: "airflow-go-edge", + Short: "Airflow worker for running Go workloads sent via Edge Worker API.", + Long: `Airflow worker for running Go workloads sent via Edge Worker API. + +All options (other than ` + "`--config`" + `) can be specified in the config file using +the same name as the CLI argument but without the ` + "`--`" + ` prefix.`, + PersistentPreRunE: func(cmd *cobra.Command, args []string) error { + return config.Configure(cmd) + }, + SilenceUsage: true, +} + +func init() { + Root.PersistentFlags(). + String("config", "", "config file (default is $HOME/airflow/go-sdk.yaml)") + Root.AddCommand(runCmd) + config.InitColor(Root) +} diff --git a/go-sdk/edge/commands/run.go b/go-sdk/edge/commands/run.go new file mode 100644 index 0000000000000..7eb247c6a9a3d --- /dev/null +++ b/go-sdk/edge/commands/run.go @@ -0,0 +1,68 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package commands + +import ( + "context" + + "github.com/spf13/cobra" + + "github.com/apache/airflow/go-sdk/edge" +) + +var runCmd = &cobra.Command{ + Use: "run", + Short: "Connect to Edge Executor API and run Airflow workloads", + Long: "Connect to Edge Executor API and run Airflow workloads", + + RunE: func(cmd *cobra.Command, args []string) error { + return edge.Run(context.Background()) + }, +} + +func init() { + flags := runCmd.Flags() + flags.StringP( + "execution-api-url", + "", + "http://localhost:8080/execution/", + "Execution API to connect to", + ) + flags.StringSliceP( + "queues", + "q", + []string{"default"}, + "Comma delimited list of queues to serve, serve all queues if not provided.", + ) + flags.StringP( + "api-url", + "", + "", + "URL endpoint on which the Airflow code edge API is accessible from edge worker.", + ) + flags.StringP( + "hostname", + "H", + "", + "Set the hostname of worker if you have multiple workers on a single machine.", + ) + + runCmd.MarkFlagRequired("api-url") + flags.SetAnnotation("api-url", "viper-mapping", []string{"edge.api_url"}) + flags.SetAnnotation("hostname", "viper-mapping", []string{"edge.hostname"}) +} diff --git a/go-sdk/edge/worker.go b/go-sdk/edge/worker.go new file mode 100644 index 0000000000000..6403f2d413da3 --- /dev/null +++ b/go-sdk/edge/worker.go @@ -0,0 +1,456 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package edge + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "os" + "os/signal" + "runtime" + "runtime/debug" + "sync/atomic" + "syscall" + "time" + + "github.com/google/uuid" + "github.com/spf13/cast" + "github.com/spf13/viper" + + "github.com/apache/airflow/go-sdk/bundle/bundlev1" + "github.com/apache/airflow/go-sdk/bundle/bundlev1/bundlev1client" + "github.com/apache/airflow/go-sdk/pkg/bundles/shared" + "github.com/apache/airflow/go-sdk/pkg/edgeapi" + "github.com/apache/airflow/go-sdk/pkg/logging" + logserver "github.com/apache/airflow/go-sdk/pkg/logging/server" +) + +type worker struct { + *shared.Discovery + + hostname string + client edgeapi.ClientInterface + queues []string + logger *slog.Logger + + maxConcurrency int32 + + freeConcurrency atomic.Int32 + + // Are we currently attempting to drain jobs? + drain bool + + activeWorkloads map[uuid.UUID]bundlev1.ExecuteTaskWorkload + + // We need to send this on most requests, so we keep a copy of it around + sysInfo map[string]edgeapi.WorkerStateBody_Sysinfo_AdditionalProperties +} + +var ( + HeartbeatInterval = 30 * time.Second + DeregisterTimeout = 5 * time.Second +) + +func Run(ctx context.Context) error { + apiURL := viper.GetString("edge.api_url") + + hostname := viper.GetString("edge.hostname") + + if hostname == "" { + var err error + hostname, err = os.Hostname() + if err != nil { + return err + } + } + + w, err := NewWorker(hostname, apiURL, viper.GetString("api_auth.secret_key"), + viper.GetStringSlice("queues"), + ) + if err != nil { + return err + } + + err = w.Register(ctx) + if err != nil { + return err + } + + defer w.deregister(ctx) + + return w.mainLoop(ctx) +} + +func configOrDefault[T cast.Basic](key string, fallback T) T { + x := viper.Get(key) + if x == nil { + return fallback + } + return cast.To[T](x) +} + +func NewWorker( + hostname string, + apiURL string, + apiJWTSecretKey string, + queues []string, +) (*worker, error) { + client, err := edgeapi.NewClient(apiURL, edgeapi.WithEdgeAPIJWTKey([]byte(apiJWTSecretKey))) + if err != nil { + return nil, err + } + + var maxConcurrency int32 = 16 + + var airflowVer, edgeVer, concurrency, freeConcurrency, goVer edgeapi.WorkerStateBody_Sysinfo_AdditionalProperties + airflowVer.FromWorkerStateBodySysinfo0(edgeapi.WorkerStateBodySysinfo0( + configOrDefault("edge.airflow_version", "3.1.0"), + )) + edgeVer.FromWorkerStateBodySysinfo0(edgeapi.WorkerStateBodySysinfo0( + configOrDefault("edge.provider_version", "1.3.1"), + )) + concurrency.FromWorkerStateBodySysinfo1(edgeapi.WorkerStateBodySysinfo1(maxConcurrency)) + freeConcurrency.FromWorkerStateBodySysinfo1(edgeapi.WorkerStateBodySysinfo1(maxConcurrency)) + goVer.FromWorkerStateBodySysinfo0(edgeapi.WorkerStateBodySysinfo0(runtime.Version())) + + sysInfo := map[string]edgeapi.WorkerStateBody_Sysinfo_AdditionalProperties{ + "airflow_version": airflowVer, + "edge_provider_version": edgeVer, + "concurrency": concurrency, + "free_concurrency": freeConcurrency, + "go_version": goVer, + } + w := &worker{ + Discovery: shared.NewDiscovery(viper.GetString("bundles.folder"), nil), + + hostname: hostname, + queues: queues, + client: client, + sysInfo: sysInfo, + logger: slog.Default().With("logger", "edge.worker"), + maxConcurrency: maxConcurrency, + activeWorkloads: map[uuid.UUID]bundlev1.ExecuteTaskWorkload{}, + } + + w.logger.Info("Starting Go Edge worker", "queues", queues) + + w.freeConcurrency.Store(maxConcurrency) + + return w, nil +} + +func (w *worker) Register(ctx context.Context) error { + _, err := w.client.Worker().Register(ctx, w.hostname, &edgeapi.WorkerStateBody{ + State: edgeapi.EdgeWorkerStateStarting, + Sysinfo: w.sysInfo, + Queues: &w.queues, + }) + return err +} + +func (w *worker) deregister(ctx context.Context) { + w.logger.Debug("Deregistering worker") + // Create a new context that isn't cancelled to give us time to report + ctx, cancel := context.WithTimeout(context.WithoutCancel(ctx), DeregisterTimeout) + defer cancel() + + _, err := w.client.Worker().SetState(ctx, w.hostname, &edgeapi.WorkerStateBody{ + State: edgeapi.EdgeWorkerStateOffline, + Sysinfo: w.sysInfo, + }) + if err != nil { + w.logger.Warn("Unable to report worker shutdown to Edge API server", "err", err) + } +} + +func (w *worker) _currentState(_ context.Context) edgeapi.EdgeWorkerState { + free := w.freeConcurrency.Load() + if free != w.maxConcurrency { + if w.drain { + return edgeapi.EdgeWorkerStateTerminating + } + return edgeapi.EdgeWorkerStateRunning + } else if w.drain { + // We were asked to drain, and we've got nothing left running + return edgeapi.EdgeWorkerStateOffline + } + return edgeapi.EdgeWorkerStateIdle +} + +func (w *worker) heartbeat(ctx context.Context) error { + state := w._currentState(ctx) + w.logger.Debug("Heartbeating", "current_state", state) + free := w.freeConcurrency.Load() + + slot := w.sysInfo["free_concurrency"] + slot.FromWorkerStateBodySysinfo1(edgeapi.WorkerStateBodySysinfo1(free)) + w.sysInfo["free_concurrency"] = slot + + jobsActive := len(w.activeWorkloads) + + resp, err := w.client.Worker().SetState(ctx, w.hostname, &edgeapi.WorkerStateBody{ + State: state, + Sysinfo: w.sysInfo, + Queues: &w.queues, + JobsActive: &jobsActive, + }) + if err != nil { + return err + } + + if resp.Queues != nil { + w.queues = *resp.Queues + } + + switch resp.State { + case edgeapi.EdgeWorkerStateShutdownRequest: + w.logger.Info("Shutdown request from server!") + w.drain = true + } + + return nil +} + +func (w *worker) mainLoop(ctx context.Context) error { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + logServer, err := logserver.NewFromConfig(viper.GetViper()) + if err != nil { + return err + } + + go logServer.ListenAndServe(ctx, time.Duration(0)) + + if err := w.DiscoverBundles(ctx); err != nil { + return err + } + + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM, syscall.SIGUSR1) + + // Report state to Idle + if err := w.heartbeat(ctx); err != nil { + return err + } + + heartbeat := time.NewTicker(HeartbeatInterval) + defer heartbeat.Stop() + + // Create a sub-context so we stop the fetching jobs loop when we want to shut down + fetchJobsCtx, stopFetchJobs := context.WithCancel(ctx) + defer stopFetchJobs() + + workloadsChan := w.fetchJobsForever(fetchJobsCtx) + + for { + select { + case sig := <-sigChan: + switch sig { + case os.Interrupt: + if !w.drain { + w.logger.Info( + "Draining worker of running tasks before stopping, hit ^C again to force terminate", + ) + w.drain = true + w.heartbeat(ctx) + stopFetchJobs() + } else { + return nil + } + case syscall.SIGUSR1: + // TODO: Print stats + } + case <-heartbeat.C: + err := w.heartbeat(ctx) + if err != nil { + // TODO: shut down the worker after too many failed heartbeats + w.logger.Warn("Unable to send Heartbeat request", logging.AttrErr(err)) + continue + } + + // TODO: Deal with response (drain/cordon/etc) + case workload := <-workloadsChan: + w.logger.Debug("Got allocation", "workload", workload) + go w.runWorkload(ctx, workload.ConcurrencySlots, workload.ExecuteTaskWorkload) + // ... + case <-ctx.Done(): + w.logger.Debug("runHeartbeater stopping") + // Something signaled us to stop. + return nil + } + + if w.drain && len(w.activeWorkloads) == 0 { + return nil + } + } +} + +type jobInfo struct { + bundlev1.ExecuteTaskWorkload + ConcurrencySlots int32 +} + +// fetchJobsForever will fetch jobs from the API server every second (if there is capacity), sending the +// resulting workloads out on the channel. +func (w *worker) fetchJobsForever(ctx context.Context) <-chan jobInfo { + ch := make(chan jobInfo) + + go func() { + forever := time.NewTicker(time.Second) + defer forever.Stop() + for { + select { + case <-forever.C: + workload, slots, err := w.fetchJob(ctx) + if err != nil { + w.logger.Warn("Problem getting task from API server", "err", err) + } else if workload != nil { + ch <- jobInfo{*workload, slots} + } + case <-ctx.Done(): + return + } + } + }() + + return ch +} + +func (w *worker) fetchJob(ctx context.Context) (*bundlev1.ExecuteTaskWorkload, int32, error) { + free := w.freeConcurrency.Load() + if free == 0 { + // No free slots, no point making a request + return nil, 0, nil + } + // This is super verbose, once a second, so even on debug we don't want this on + w.logger.LogAttrs( + ctx, + logging.LevelTrace, + "Asking for work", + slog.Int("freeConcurrency", int(free)), + slog.Int("max", int(w.maxConcurrency)), + ) + resp, err := w.client.Jobs().Fetch(ctx, w.hostname, &edgeapi.WorkerQueuesBody{ + FreeConcurrency: int(free), + Queues: &w.queues, + }) + if err != nil { + return nil, 0, fmt.Errorf("unable to get jobs %w", err) + } + + if resp.TaskId == "" { + // Empty response, got nothing! + return nil, 0, nil + } + + w.logger.Info("fetchJob", "resp", fmt.Sprintf("%#v\n", resp)) + + // Round trip via json. Inefficient, but easy to code + asJSON, err := json.Marshal(resp.Command) + if err != nil { + // TODO: Report this to API server + return nil, 0, fmt.Errorf("unable to marshal workload %w", err) + } + + var out bundlev1.ExecuteTaskWorkload + err = json.Unmarshal(asJSON, &out) + if err != nil { + // TODO: Report this to API server + return nil, 0, fmt.Errorf("unable to unmarshal into workload %w", err) + } + w.logger.Info("fetchJob", "out", fmt.Sprintf("%#v\n", out)) + + return &out, int32(resp.ConcurrencySlots), nil +} + +func (w *worker) runWorkload( + ctx context.Context, + slots int32, + workload bundlev1.ExecuteTaskWorkload, +) error { + var err error + w.freeConcurrency.Add(-slots) + w.activeWorkloads[workload.TI.Id] = workload + + mapIndex := -1 + if workload.TI.MapIndex != nil { + mapIndex = *workload.TI.MapIndex + } + w.client.Jobs().State( + ctx, + workload.TI.DagId, + workload.TI.TaskId, + workload.TI.RunId, + workload.TI.TryNumber, + mapIndex, + edgeapi.TaskInstanceStateRunning, + ) + + defer func() { + jobState := edgeapi.TaskInstanceStateSuccess + + if r := recover(); r != nil { + w.logger.Error( + "Recovered in runWorkload", + slog.Any("error", r), + "stack", + string(debug.Stack()), + ) + jobState = edgeapi.TaskInstanceStateFailed + } else if err != nil { + jobState = edgeapi.TaskInstanceStateFailed + } + w.client.Jobs().State( + ctx, + workload.TI.DagId, + workload.TI.TaskId, + workload.TI.RunId, + workload.TI.TryNumber, + mapIndex, + jobState, + ) + + w.freeConcurrency.Add(slots) + delete(w.activeWorkloads, workload.TI.Id) + }() + + client, err := w.ClientForBundle(workload.BundleInfo.Name, workload.BundleInfo.Version) + if err != nil { + // TODO: This Should write something to the log file + return err + } + // TODO: Don't kill the backend process here, but instead kill it after a bit of idleness. See if we can + // reuse the process for multiple tasks too + defer client.Kill() + rpcClient, err := client.Client() + if err != nil { + return err + } + + raw, err := rpcClient.Dispense("dag-bundle") + if err != nil { + return err + } + + bundleClient := raw.(bundlev1client.BundleClient) + err = bundleClient.ExecuteTaskWorkload(ctx, workload) + return err +} diff --git a/go-sdk/go.mod b/go-sdk/go.mod index bcf328d37f9bd..697297c4801bd 100644 --- a/go-sdk/go.mod +++ b/go-sdk/go.mod @@ -9,6 +9,7 @@ require ( github.com/hashicorp/go-plugin v1.7.0 github.com/ivanpirog/coloredcobra v1.0.1 github.com/oapi-codegen/runtime v1.1.1 + github.com/spf13/cast v1.10.0 github.com/spf13/cobra v1.9.1 github.com/spf13/pflag v1.0.6 github.com/spf13/viper v1.20.1 @@ -37,7 +38,6 @@ require ( github.com/sagikazarmark/locafero v0.7.0 // indirect github.com/sourcegraph/conc v0.3.0 // indirect github.com/spf13/afero v1.12.0 // indirect - github.com/spf13/cast v1.7.1 // indirect github.com/stretchr/objx v0.5.2 // indirect github.com/subosito/gotenv v1.6.0 // indirect go.opentelemetry.io/otel v1.29.0 // indirect diff --git a/go-sdk/go.sum b/go-sdk/go.sum index c8fea1d41b8f9..55a27682e3465 100644 --- a/go-sdk/go.sum +++ b/go-sdk/go.sum @@ -102,8 +102,8 @@ github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9yS github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= github.com/spf13/afero v1.12.0 h1:UcOPyRBYczmFn6yvphxkn9ZEOY65cpwGKb5mL36mrqs= github.com/spf13/afero v1.12.0/go.mod h1:ZTlWwG4/ahT8W7T0WQ5uYmjI9duaLQGy3Q2OAl4sk/4= -github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y= -github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= +github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY= +github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo= github.com/spf13/cobra v1.4.0/go.mod h1:Wo4iy3BUC+X2Fybo0PDqwJIv3dNRiZLHQymsfxlB84g= github.com/spf13/cobra v1.9.1 h1:CXSaggrXdbHK9CF+8ywj8Amf7PBRmPCOJugH954Nnlo= github.com/spf13/cobra v1.9.1/go.mod h1:nDyEzZ8ogv936Cinf6g1RU9MRY64Ir93oCnqb9wxYW0= diff --git a/go-sdk/pkg/bundles/shared/config.go b/go-sdk/pkg/config/config.go similarity index 58% rename from go-sdk/pkg/bundles/shared/config.go rename to go-sdk/pkg/config/config.go index 26e002b8b0d4e..6fb5c66f16176 100644 --- a/go-sdk/pkg/bundles/shared/config.go +++ b/go-sdk/pkg/config/config.go @@ -15,17 +15,23 @@ // specific language governing permissions and limitations // under the License. -package shared +package config import ( "fmt" + "log/slog" "os" "path" "strings" + "github.com/MatusOllah/slogcolor" + "github.com/fatih/color" + cc "github.com/ivanpirog/coloredcobra" "github.com/spf13/cobra" "github.com/spf13/pflag" "github.com/spf13/viper" + + "github.com/apache/airflow/go-sdk/pkg/logging" ) type BundleConfig struct { @@ -34,7 +40,69 @@ type BundleConfig struct { var envKeyReplacer *strings.Replacer = strings.NewReplacer(".", "__", "-", "_") -func SetupViper(cfgFile string) error { +func InitColor(rootCmd *cobra.Command) { + cc.Init(&cc.Config{ + RootCmd: rootCmd, + Headings: cc.Bold, + Commands: cc.Yellow + cc.Bold, + Example: cc.Italic, + ExecName: cc.HiMagenta + cc.Bold, + Flags: cc.Green, + FlagsDataType: cc.Italic + cc.White, + }) +} + +func Configure(cmd *cobra.Command) error { + var cfgFile string + cfgFlag := cmd.Flags().Lookup("config") + if cfgFlag != nil { + cfgFile = cfgFlag.Value.String() + } + + v, err := SetupViper(cfgFile) + if err != nil { + return err + } + // Bind the current command's flags to viper + BindFlagsToViper(cmd, v) + + logger := makeLogger(v) + slog.SetDefault(logger) + + return nil +} + +func makeLogger(v *viper.Viper) *slog.Logger { + opts := *slogcolor.DefaultOptions + leveler := &slog.LevelVar{} + + // TODO: Should we have consistency with Airflow's config option? That would mean "logging.logging_level" here + levelConfig := v.GetString("log.level") + + switch strings.ToUpper(levelConfig) { + case "TRACE": + leveler.Set(logging.LevelTrace) + case "": + // Default level is info. Job done + default: + err := leveler.UnmarshalText([]byte(levelConfig)) + cobra.CheckErr(err) + } + + opts.Level = leveler + opts.LevelTags = map[slog.Level]string{ + logging.LevelTrace: color.New(color.FgHiGreen).Sprint("TRACE"), + slog.LevelDebug: color.New(color.BgCyan, color.FgHiWhite).Sprint("DEBUG"), + slog.LevelInfo: color.New(color.BgGreen, color.FgHiWhite).Sprint("INFO "), + slog.LevelWarn: color.New(color.BgYellow, color.FgHiWhite).Sprint("WARN "), + slog.LevelError: color.New(color.BgRed, color.FgHiWhite).Sprint("ERROR"), + } + + log := slog.New(slogcolor.NewHandler(os.Stderr, &opts)) + return log +} + +func SetupViper(cfgFile string) (*viper.Viper, error) { if cfgFile != "" { // Use config file from the flag. viper.SetConfigFile(cfgFile) @@ -43,7 +111,7 @@ func SetupViper(cfgFile string) error { if airflowHome == "" { home, err := os.UserHomeDir() if err != nil { - return err + return nil, err } airflowHome = path.Join(home, "airflow") } @@ -62,7 +130,7 @@ func SetupViper(cfgFile string) error { if err := viper.ReadInConfig(); err != nil { // It's okay if there isn't a config file if _, ok := err.(viper.ConfigFileNotFoundError); !ok { - return err + return nil, err } } @@ -73,12 +141,12 @@ func SetupViper(cfgFile string) error { viper.SetEnvKeyReplacer(envKeyReplacer) viper.AutomaticEnv() // read in environment variables that match - return nil + return viper.GetViper(), nil } // Bind each cobra flag to its associated viper configuration (config file and environment variable) // This approach cribbed from https://github.com/carolynvs/stingoftheviper/blob/19bd73117f0285436505ca17616cbc394d22e63d/main.go -func BindFlagsToViper(cmd *cobra.Command) { +func BindFlagsToViper(cmd *cobra.Command, viper *viper.Viper) { cmd.Flags().VisitAll(func(f *pflag.Flag) { // Determine the naming convention of the flags when represented in the config file @@ -87,6 +155,11 @@ func BindFlagsToViper(cmd *cobra.Command) { if ann, ok := f.Annotations["viper-mapping"]; ok { configName = ann[0] } else { + // Skip the default "help" flag + if f.Name == "help" { + return + } + // Since viper does case-insensitive comparisons, we don't need to bother fixing the case, and only need to remove the hyphens. configName = envKeyReplacer.Replace(f.Name) } @@ -103,6 +176,14 @@ func BindFlagsToViper(cmd *cobra.Command) { // If we have a viper config but no flag, set the flag value. This lets `MarkRequiredFlag` work val := viper.Get(configName) cmd.Flags().Set(f.Name, fmt.Sprintf("%v", val)) + } else if f.Value.String() != "" { + // No changed flag (i.e. default), and no explicit viper set, set the viper value to the flag default + val := f.Value + if slice, ok := val.(pflag.SliceValue); ok { + viper.Set(configName, strings.Join(slice.GetSlice(), " ")) + } else { + viper.Set(configName, f.Value.String()) + } } }) } diff --git a/go-sdk/pkg/edgeapi/client.gen.go b/go-sdk/pkg/edgeapi/client.gen.go new file mode 100644 index 0000000000000..ab1f489632742 --- /dev/null +++ b/go-sdk/pkg/edgeapi/client.gen.go @@ -0,0 +1,1207 @@ +// Package edgeapi provides primitives to interact with the openapi HTTP API. +// +// Code generated by github.com/ashb/oapi-resty-codegen version v0.0.0-20250926211356-86b7321337dd DO NOT EDIT. +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package edgeapi + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strings" + "time" + + "github.com/oapi-codegen/runtime" + openapi_types "github.com/oapi-codegen/runtime/types" + "resty.dev/v3" +) + +// Defines values for EdgeWorkerState. +const ( + EdgeWorkerStateIdle EdgeWorkerState = "idle" + EdgeWorkerStateMaintenanceExit EdgeWorkerState = "maintenance exit" + EdgeWorkerStateMaintenanceMode EdgeWorkerState = "maintenance mode" + EdgeWorkerStateMaintenancePending EdgeWorkerState = "maintenance pending" + EdgeWorkerStateMaintenanceRequest EdgeWorkerState = "maintenance request" + EdgeWorkerStateOffline EdgeWorkerState = "offline" + EdgeWorkerStateOfflineMaintenance EdgeWorkerState = "offline maintenance" + EdgeWorkerStateRunning EdgeWorkerState = "running" + EdgeWorkerStateShutdownRequest EdgeWorkerState = "shutdown request" + EdgeWorkerStateStarting EdgeWorkerState = "starting" + EdgeWorkerStateTerminating EdgeWorkerState = "terminating" + EdgeWorkerStateUnknown EdgeWorkerState = "unknown" +) + +// Defines values for TaskInstanceState. +const ( + TaskInstanceStateDeferred TaskInstanceState = "deferred" + TaskInstanceStateFailed TaskInstanceState = "failed" + TaskInstanceStateQueued TaskInstanceState = "queued" + TaskInstanceStateRemoved TaskInstanceState = "removed" + TaskInstanceStateRestarting TaskInstanceState = "restarting" + TaskInstanceStateRunning TaskInstanceState = "running" + TaskInstanceStateScheduled TaskInstanceState = "scheduled" + TaskInstanceStateSkipped TaskInstanceState = "skipped" + TaskInstanceStateSuccess TaskInstanceState = "success" + TaskInstanceStateUpForReschedule TaskInstanceState = "up_for_reschedule" + TaskInstanceStateUpForRetry TaskInstanceState = "up_for_retry" + TaskInstanceStateUpstreamFailed TaskInstanceState = "upstream_failed" +) + +// BundleInfo Schema for telling task which bundle to run with. +type BundleInfo struct { + Name string `json:"name"` + Version *string `json:"version"` +} + +// EdgeJobFetched Job that is to be executed on the edge worker. +type EdgeJobFetched struct { + // Command Execute the given Task. + Command ExecuteTask `json:"command"` + + // ConcurrencySlots Number of concurrency slots the job requires. + ConcurrencySlots int `json:"concurrency_slots"` + + // DagId Identifier of the DAG to which the task belongs. + DagId string `json:"dag_id"` + + // MapIndex For dynamically mapped tasks the mapping number, -1 if the task is not mapped. + MapIndex int `json:"map_index"` + + // RunId Run ID of the DAG execution. + RunId string `json:"run_id"` + + // TaskId Task name in the DAG. + TaskId string `json:"task_id"` + + // TryNumber The number of attempt to execute this task. + TryNumber int `json:"try_number"` +} + +// EdgeWorkerState Status of a Edge Worker instance. +type EdgeWorkerState string + +// ExecuteTask Execute the given Task. +type ExecuteTask struct { + // BundleInfo Schema for telling task which bundle to run with. + BundleInfo BundleInfo `json:"bundle_info"` + DagRelPath string `json:"dag_rel_path"` + LogPath *string `json:"log_path"` + + // Ti Schema for TaskInstance with minimal required fields needed for Executors and Task SDK. + Ti TaskInstance `json:"ti"` + Token string `json:"token"` + Type *string `json:"type,omitempty"` +} + +// HTTPExceptionResponse HTTPException Model used for error response. +type HTTPExceptionResponse struct { + Detail HTTPExceptionResponse_Detail `json:"detail"` +} + +// HTTPExceptionResponseDetail0 defines model for . +type HTTPExceptionResponseDetail0 = string + +// HTTPExceptionResponseDetail1 defines model for . +type HTTPExceptionResponseDetail1 map[string]interface{} + +// HTTPExceptionResponse_Detail defines model for HTTPExceptionResponse.Detail. +type HTTPExceptionResponse_Detail struct { + union json.RawMessage +} + +// HTTPValidationError defines model for HTTPValidationError. +type HTTPValidationError struct { + Detail *[]ValidationError `json:"detail,omitempty"` +} + +// Job Details of the job sent to the scheduler. +type Job struct { + // DagId Identifier of the DAG to which the task belongs. + DagId string `json:"dag_id"` + + // EdgeWorker The worker processing the job during execution. + EdgeWorker *string `json:"edge_worker"` + + // LastUpdate Last heartbeat of the job. + LastUpdate *time.Time `json:"last_update"` + + // MapIndex For dynamically mapped tasks the mapping number, -1 if the task is not mapped. + MapIndex int `json:"map_index"` + + // Queue Queue for which the task is scheduled/running. + Queue string `json:"queue"` + + // QueuedDttm When the job was queued. + QueuedDttm *time.Time `json:"queued_dttm"` + + // RunId Run ID of the DAG execution. + RunId string `json:"run_id"` + + // State All possible states that a Task Instance can be in. + // + // Note that None is also allowed, so always use this in a type hint with Optional. + State TaskInstanceState `json:"state"` + + // TaskId Task name in the DAG. + TaskId string `json:"task_id"` + + // TryNumber The number of attempt to execute this task. + TryNumber int `json:"try_number"` +} + +// JobCollectionResponse Job Collection serializer. +type JobCollectionResponse struct { + Jobs []Job `json:"jobs"` + TotalEntries int `json:"total_entries"` +} + +// MaintenanceRequest Request body for maintenance operations. +type MaintenanceRequest struct { + // MaintenanceComment Comment describing the maintenance reason. + MaintenanceComment string `json:"maintenance_comment"` +} + +// PushLogsBody Incremental new log content from worker. +type PushLogsBody struct { + // LogChunkData Log chunk data as incremental log text. + LogChunkData string `json:"log_chunk_data"` + + // LogChunkTime Time of the log chunk at point of sending. + LogChunkTime time.Time `json:"log_chunk_time"` +} + +// TaskInstance Schema for TaskInstance with minimal required fields needed for Executors and Task SDK. +type TaskInstance struct { + ContextCarrier *map[string]interface{} `json:"context_carrier"` + DagId string `json:"dag_id"` + DagVersionId openapi_types.UUID `json:"dag_version_id"` + Id openapi_types.UUID `json:"id"` + MapIndex *int `json:"map_index,omitempty"` + ParentContextCarrier *map[string]interface{} `json:"parent_context_carrier"` + PoolSlots int `json:"pool_slots"` + PriorityWeight int `json:"priority_weight"` + Queue string `json:"queue"` + RunId string `json:"run_id"` + TaskId string `json:"task_id"` + TryNumber int `json:"try_number"` +} + +// TaskInstanceState All possible states that a Task Instance can be in. +// +// Note that None is also allowed, so always use this in a type hint with Optional. +type TaskInstanceState string + +// ValidationError defines model for ValidationError. +type ValidationError struct { + Loc []ValidationError_Loc_Item `json:"loc"` + Msg string `json:"msg"` + Type string `json:"type"` +} + +// ValidationErrorLoc0 defines model for . +type ValidationErrorLoc0 = string + +// ValidationErrorLoc1 defines model for . +type ValidationErrorLoc1 = int + +// ValidationError_Loc_Item defines model for ValidationError.loc.Item. +type ValidationError_Loc_Item struct { + union json.RawMessage +} + +// Worker Details of the worker state sent to the scheduler. +type Worker struct { + // FirstOnline When the worker was first online. + FirstOnline *time.Time `json:"first_online"` + + // JobsActive Number of active jobs the worker is running. + JobsActive *int `json:"jobs_active,omitempty"` + + // LastHeartbeat When the worker last sent a heartbeat. + LastHeartbeat *time.Time `json:"last_heartbeat"` + + // MaintenanceComments Comments about the maintenance state of the worker. + MaintenanceComments *string `json:"maintenance_comments"` + + // Queues List of queues the worker is pulling jobs from. If not provided, worker pulls from all queues. + Queues *[]string `json:"queues"` + + // State Status of a Edge Worker instance. + State EdgeWorkerState `json:"state"` + + // Sysinfo System information of the worker. + Sysinfo map[string]Worker_Sysinfo_AdditionalProperties `json:"sysinfo"` + + // WorkerName Name of the worker. + WorkerName string `json:"worker_name"` +} + +// WorkerSysinfo0 defines model for . +type WorkerSysinfo0 = string + +// WorkerSysinfo1 defines model for . +type WorkerSysinfo1 = int + +// Worker_Sysinfo_AdditionalProperties defines model for Worker.sysinfo.AdditionalProperties. +type Worker_Sysinfo_AdditionalProperties struct { + union json.RawMessage +} + +// WorkerCollectionResponse Worker Collection serializer. +type WorkerCollectionResponse struct { + TotalEntries int `json:"total_entries"` + Workers []Worker `json:"workers"` +} + +// WorkerQueueUpdateBody Changed queues for the worker. +type WorkerQueueUpdateBody struct { + // NewQueues Additional queues to be added to worker. + NewQueues *[]string `json:"new_queues"` + + // RemoveQueues Queues to remove from worker. + RemoveQueues *[]string `json:"remove_queues"` +} + +// WorkerQueuesBody Queues that a worker supports to run jobs on. +type WorkerQueuesBody struct { + // FreeConcurrency Number of free concurrency slots on the worker. + FreeConcurrency int `json:"free_concurrency"` + + // Queues List of queues the worker is pulling jobs from. If not provided, worker pulls from all queues. + Queues *[]string `json:"queues"` +} + +// WorkerRegistrationReturn The return class for the worker registration. +type WorkerRegistrationReturn struct { + // LastUpdate Time of the last update of the worker. + LastUpdate time.Time `json:"last_update"` +} + +// WorkerSetStateReturn The return class for the worker set state. +type WorkerSetStateReturn struct { + // MaintenanceComments Comments about the maintenance state of the worker. + MaintenanceComments *string `json:"maintenance_comments"` + + // Queues List of queues the worker is pulling jobs from. If not provided, worker pulls from all queues. + Queues *[]string `json:"queues"` + + // State Status of a Edge Worker instance. + State EdgeWorkerState `json:"state"` +} + +// WorkerStateBody Details of the worker state sent to the scheduler. +type WorkerStateBody struct { + // JobsActive Number of active jobs the worker is running. + JobsActive *int `json:"jobs_active,omitempty"` + + // MaintenanceComments Comments about the maintenance state of the worker. + MaintenanceComments *string `json:"maintenance_comments"` + + // Queues List of queues the worker is pulling jobs from. If not provided, worker pulls from all queues. + Queues *[]string `json:"queues"` + + // State Status of a Edge Worker instance. + State EdgeWorkerState `json:"state"` + + // Sysinfo System information of the worker. + Sysinfo map[string]WorkerStateBody_Sysinfo_AdditionalProperties `json:"sysinfo"` +} + +// WorkerStateBodySysinfo0 defines model for . +type WorkerStateBodySysinfo0 = string + +// WorkerStateBodySysinfo1 defines model for . +type WorkerStateBodySysinfo1 = int + +// WorkerStateBody_Sysinfo_AdditionalProperties defines model for WorkerStateBody.sysinfo.AdditionalProperties. +type WorkerStateBody_Sysinfo_AdditionalProperties struct { + union json.RawMessage +} + +// FetchJSONRequestBody defines body for Fetch for application/json ContentType. +type FetchJSONRequestBody = WorkerQueuesBody + +// PushLogsJSONRequestBody defines body for PushLogs for application/json ContentType. +type PushLogsJSONRequestBody = PushLogsBody + +// UpdateQueuesJSONRequestBody defines body for UpdateQueues for application/json ContentType. +type UpdateQueuesJSONRequestBody = WorkerQueueUpdateBody + +// SetStateJSONRequestBody defines body for SetState for application/json ContentType. +type SetStateJSONRequestBody = WorkerStateBody + +// RegisterJSONRequestBody defines body for Register for application/json ContentType. +type RegisterJSONRequestBody = WorkerStateBody + +// AsHTTPExceptionResponseDetail0 returns the union data inside the HTTPExceptionResponse_Detail as a HTTPExceptionResponseDetail0 +func (t HTTPExceptionResponse_Detail) AsHTTPExceptionResponseDetail0() (HTTPExceptionResponseDetail0, error) { + var body HTTPExceptionResponseDetail0 + err := json.Unmarshal(t.union, &body) + return body, err +} + +// FromHTTPExceptionResponseDetail0 overwrites any union data inside the HTTPExceptionResponse_Detail as the provided HTTPExceptionResponseDetail0 +func (t *HTTPExceptionResponse_Detail) FromHTTPExceptionResponseDetail0(v HTTPExceptionResponseDetail0) error { + b, err := json.Marshal(v) + t.union = b + return err +} + +// MergeHTTPExceptionResponseDetail0 performs a merge with any union data inside the HTTPExceptionResponse_Detail, using the provided HTTPExceptionResponseDetail0 +func (t *HTTPExceptionResponse_Detail) MergeHTTPExceptionResponseDetail0(v HTTPExceptionResponseDetail0) error { + b, err := json.Marshal(v) + if err != nil { + return err + } + + merged, err := runtime.JSONMerge(t.union, b) + t.union = merged + return err +} + +// AsHTTPExceptionResponseDetail1 returns the union data inside the HTTPExceptionResponse_Detail as a HTTPExceptionResponseDetail1 +func (t HTTPExceptionResponse_Detail) AsHTTPExceptionResponseDetail1() (HTTPExceptionResponseDetail1, error) { + var body HTTPExceptionResponseDetail1 + err := json.Unmarshal(t.union, &body) + return body, err +} + +// FromHTTPExceptionResponseDetail1 overwrites any union data inside the HTTPExceptionResponse_Detail as the provided HTTPExceptionResponseDetail1 +func (t *HTTPExceptionResponse_Detail) FromHTTPExceptionResponseDetail1(v HTTPExceptionResponseDetail1) error { + b, err := json.Marshal(v) + t.union = b + return err +} + +// MergeHTTPExceptionResponseDetail1 performs a merge with any union data inside the HTTPExceptionResponse_Detail, using the provided HTTPExceptionResponseDetail1 +func (t *HTTPExceptionResponse_Detail) MergeHTTPExceptionResponseDetail1(v HTTPExceptionResponseDetail1) error { + b, err := json.Marshal(v) + if err != nil { + return err + } + + merged, err := runtime.JSONMerge(t.union, b) + t.union = merged + return err +} + +func (t HTTPExceptionResponse_Detail) MarshalJSON() ([]byte, error) { + b, err := t.union.MarshalJSON() + return b, err +} + +func (t *HTTPExceptionResponse_Detail) UnmarshalJSON(b []byte) error { + err := t.union.UnmarshalJSON(b) + return err +} + +// AsValidationErrorLoc0 returns the union data inside the ValidationError_Loc_Item as a ValidationErrorLoc0 +func (t ValidationError_Loc_Item) AsValidationErrorLoc0() (ValidationErrorLoc0, error) { + var body ValidationErrorLoc0 + err := json.Unmarshal(t.union, &body) + return body, err +} + +// FromValidationErrorLoc0 overwrites any union data inside the ValidationError_Loc_Item as the provided ValidationErrorLoc0 +func (t *ValidationError_Loc_Item) FromValidationErrorLoc0(v ValidationErrorLoc0) error { + b, err := json.Marshal(v) + t.union = b + return err +} + +// MergeValidationErrorLoc0 performs a merge with any union data inside the ValidationError_Loc_Item, using the provided ValidationErrorLoc0 +func (t *ValidationError_Loc_Item) MergeValidationErrorLoc0(v ValidationErrorLoc0) error { + b, err := json.Marshal(v) + if err != nil { + return err + } + + merged, err := runtime.JSONMerge(t.union, b) + t.union = merged + return err +} + +// AsValidationErrorLoc1 returns the union data inside the ValidationError_Loc_Item as a ValidationErrorLoc1 +func (t ValidationError_Loc_Item) AsValidationErrorLoc1() (ValidationErrorLoc1, error) { + var body ValidationErrorLoc1 + err := json.Unmarshal(t.union, &body) + return body, err +} + +// FromValidationErrorLoc1 overwrites any union data inside the ValidationError_Loc_Item as the provided ValidationErrorLoc1 +func (t *ValidationError_Loc_Item) FromValidationErrorLoc1(v ValidationErrorLoc1) error { + b, err := json.Marshal(v) + t.union = b + return err +} + +// MergeValidationErrorLoc1 performs a merge with any union data inside the ValidationError_Loc_Item, using the provided ValidationErrorLoc1 +func (t *ValidationError_Loc_Item) MergeValidationErrorLoc1(v ValidationErrorLoc1) error { + b, err := json.Marshal(v) + if err != nil { + return err + } + + merged, err := runtime.JSONMerge(t.union, b) + t.union = merged + return err +} + +func (t ValidationError_Loc_Item) MarshalJSON() ([]byte, error) { + b, err := t.union.MarshalJSON() + return b, err +} + +func (t *ValidationError_Loc_Item) UnmarshalJSON(b []byte) error { + err := t.union.UnmarshalJSON(b) + return err +} + +// AsWorkerSysinfo0 returns the union data inside the Worker_Sysinfo_AdditionalProperties as a WorkerSysinfo0 +func (t Worker_Sysinfo_AdditionalProperties) AsWorkerSysinfo0() (WorkerSysinfo0, error) { + var body WorkerSysinfo0 + err := json.Unmarshal(t.union, &body) + return body, err +} + +// FromWorkerSysinfo0 overwrites any union data inside the Worker_Sysinfo_AdditionalProperties as the provided WorkerSysinfo0 +func (t *Worker_Sysinfo_AdditionalProperties) FromWorkerSysinfo0(v WorkerSysinfo0) error { + b, err := json.Marshal(v) + t.union = b + return err +} + +// MergeWorkerSysinfo0 performs a merge with any union data inside the Worker_Sysinfo_AdditionalProperties, using the provided WorkerSysinfo0 +func (t *Worker_Sysinfo_AdditionalProperties) MergeWorkerSysinfo0(v WorkerSysinfo0) error { + b, err := json.Marshal(v) + if err != nil { + return err + } + + merged, err := runtime.JSONMerge(t.union, b) + t.union = merged + return err +} + +// AsWorkerSysinfo1 returns the union data inside the Worker_Sysinfo_AdditionalProperties as a WorkerSysinfo1 +func (t Worker_Sysinfo_AdditionalProperties) AsWorkerSysinfo1() (WorkerSysinfo1, error) { + var body WorkerSysinfo1 + err := json.Unmarshal(t.union, &body) + return body, err +} + +// FromWorkerSysinfo1 overwrites any union data inside the Worker_Sysinfo_AdditionalProperties as the provided WorkerSysinfo1 +func (t *Worker_Sysinfo_AdditionalProperties) FromWorkerSysinfo1(v WorkerSysinfo1) error { + b, err := json.Marshal(v) + t.union = b + return err +} + +// MergeWorkerSysinfo1 performs a merge with any union data inside the Worker_Sysinfo_AdditionalProperties, using the provided WorkerSysinfo1 +func (t *Worker_Sysinfo_AdditionalProperties) MergeWorkerSysinfo1(v WorkerSysinfo1) error { + b, err := json.Marshal(v) + if err != nil { + return err + } + + merged, err := runtime.JSONMerge(t.union, b) + t.union = merged + return err +} + +func (t Worker_Sysinfo_AdditionalProperties) MarshalJSON() ([]byte, error) { + b, err := t.union.MarshalJSON() + return b, err +} + +func (t *Worker_Sysinfo_AdditionalProperties) UnmarshalJSON(b []byte) error { + err := t.union.UnmarshalJSON(b) + return err +} + +// AsWorkerStateBodySysinfo0 returns the union data inside the WorkerStateBody_Sysinfo_AdditionalProperties as a WorkerStateBodySysinfo0 +func (t WorkerStateBody_Sysinfo_AdditionalProperties) AsWorkerStateBodySysinfo0() (WorkerStateBodySysinfo0, error) { + var body WorkerStateBodySysinfo0 + err := json.Unmarshal(t.union, &body) + return body, err +} + +// FromWorkerStateBodySysinfo0 overwrites any union data inside the WorkerStateBody_Sysinfo_AdditionalProperties as the provided WorkerStateBodySysinfo0 +func (t *WorkerStateBody_Sysinfo_AdditionalProperties) FromWorkerStateBodySysinfo0(v WorkerStateBodySysinfo0) error { + b, err := json.Marshal(v) + t.union = b + return err +} + +// MergeWorkerStateBodySysinfo0 performs a merge with any union data inside the WorkerStateBody_Sysinfo_AdditionalProperties, using the provided WorkerStateBodySysinfo0 +func (t *WorkerStateBody_Sysinfo_AdditionalProperties) MergeWorkerStateBodySysinfo0(v WorkerStateBodySysinfo0) error { + b, err := json.Marshal(v) + if err != nil { + return err + } + + merged, err := runtime.JSONMerge(t.union, b) + t.union = merged + return err +} + +// AsWorkerStateBodySysinfo1 returns the union data inside the WorkerStateBody_Sysinfo_AdditionalProperties as a WorkerStateBodySysinfo1 +func (t WorkerStateBody_Sysinfo_AdditionalProperties) AsWorkerStateBodySysinfo1() (WorkerStateBodySysinfo1, error) { + var body WorkerStateBodySysinfo1 + err := json.Unmarshal(t.union, &body) + return body, err +} + +// FromWorkerStateBodySysinfo1 overwrites any union data inside the WorkerStateBody_Sysinfo_AdditionalProperties as the provided WorkerStateBodySysinfo1 +func (t *WorkerStateBody_Sysinfo_AdditionalProperties) FromWorkerStateBodySysinfo1(v WorkerStateBodySysinfo1) error { + b, err := json.Marshal(v) + t.union = b + return err +} + +// MergeWorkerStateBodySysinfo1 performs a merge with any union data inside the WorkerStateBody_Sysinfo_AdditionalProperties, using the provided WorkerStateBodySysinfo1 +func (t *WorkerStateBody_Sysinfo_AdditionalProperties) MergeWorkerStateBodySysinfo1(v WorkerStateBodySysinfo1) error { + b, err := json.Marshal(v) + if err != nil { + return err + } + + merged, err := runtime.JSONMerge(t.union, b) + t.union = merged + return err +} + +func (t WorkerStateBody_Sysinfo_AdditionalProperties) MarshalJSON() ([]byte, error) { + b, err := t.union.MarshalJSON() + return b, err +} + +func (t *WorkerStateBody_Sysinfo_AdditionalProperties) UnmarshalJSON(b []byte) error { + err := t.union.UnmarshalJSON(b) + return err +} + +// Client which conforms to the OpenAPI3 specification for this service. +type Client struct { + // The endpoint of the server conforming to this interface, with scheme, + // https://api.deepmap.com for example. This can contain a path relative + // to the server, such as https://api.deepmap.com/dev-test, and all the + // paths in the swagger spec will be appended to the server. + Server string + + *resty.Client + + RequestMiddleware []resty.RequestMiddleware +} + +// ClientOption allows setting custom parameters during construction +type ClientOption func(*Client) error + +func NewClient(server string, opts ...ClientOption) (ClientInterface, error) { + // create a client with sane default values + client := Client{ + Server: server, + Client: resty.New(), + } + client.Client.SetBaseURL(client.Server) + // mutate client and add all optional params + for _, o := range opts { + if err := o(&client); err != nil { + return nil, err + } + } + // ensure the server URL always has a trailing slash + if !strings.HasSuffix(client.Server, "/") { + client.Server += "/" + } + + return &client, nil +} + +// WithClient allows overriding the default [resty.Client], which is +// automatically created using http.Client. +// +// If this is used the `server` base URL argument passed in will not be respected anymore +func WithClient(r *resty.Client) ClientOption { + return func(c *Client) error { + c.Client = r + return nil + } +} + +// WithRoundTripper method sets custom http.Transport or any http.RoundTripper +// compatible interface implementation in the Resty client +func WithRoundTripper(transport http.RoundTripper) ClientOption { + return func(c *Client) error { + c.Client.SetTransport(transport) + return nil + } +} + +// WithRequestMiddleware allows setting up a callback function, which will be +// called right before sending the request. This can be used to mutate the request. +func WithRequestMiddleware(mw resty.RequestMiddleware) ClientOption { + return func(c *Client) error { + c.RequestMiddleware = append(c.RequestMiddleware, mw) + c.Client = c.Client.AddRequestMiddleware(mw) + return nil + } +} + +func (c *Client) Jobs() JobsClient { + return &jobsClient{c.Client} +} + +func (c *Client) Logs() LogsClient { + return &logsClient{c.Client} +} + +func (c *Client) Worker() WorkerClient { + return &workerClient{c.Client} +} + +type ClientInterface interface { + // Jobs deals with all the Jobs endpoints + Jobs() JobsClient + // Logs deals with all the Logs endpoints + Logs() LogsClient + // Worker deals with all the Worker endpoints + Worker() WorkerClient +} + +type GeneralHTTPError struct { + Response *resty.Response + JSON map[string]any + Text string +} + +var errorTypes = map[int]string{ + 1: "informational response", + 3: "redirect response", + 4: "client error", + 5: "server error", +} + +func (e GeneralHTTPError) Error() string { + var b strings.Builder + kind, ok := errorTypes[e.Response.StatusCode()/100] + if !ok { + kind = "unknown HTTP error" + } + fmt.Fprintf(&b, "%s '%s'", kind, e.Response.Status()) + if e.JSON != nil { + fmt.Fprintf(&b, " %v", e.JSON) + } else { + fmt.Fprintf(&b, " content=%q", e.Text) + } + return b.String() +} + +func HandleError(client *resty.Client, resp *resty.Response) error { + if !resp.IsError() { + return nil + } + + e := GeneralHTTPError{Response: resp} + + e.Text = resp.String() + if resp.Header().Get("content-type") == "application/json" { + if json.Unmarshal([]byte(e.Text), &e.JSON) == nil { + e.Text = "" + } + } + + // Set the parsed error back into the object so `resp.Error()` returns the populated one! + resp.Request.SetError(&e) + + return &e +} + +type jobsClient struct { + *resty.Client +} + +// FetchResponse performs the HTTP request and returns the lower level [resty.Response] +func (c *jobsClient) FetchResponse(ctx context.Context, workerName string, body *WorkerQueuesBody) (resp *resty.Response, err error) { + var res struct { + // Command Execute the given Task. + Command ExecuteTask `json:"command"` + + // ConcurrencySlots Number of concurrency slots the job requires. + ConcurrencySlots int `json:"concurrency_slots"` + + // DagId Identifier of the DAG to which the task belongs. + DagId string `json:"dag_id"` + + // MapIndex For dynamically mapped tasks the mapping number, -1 if the task is not mapped. + MapIndex int `json:"map_index"` + + // RunId Run ID of the DAG execution. + RunId string `json:"run_id"` + + // TaskId Task name in the DAG. + TaskId string `json:"task_id"` + + // TryNumber The number of attempt to execute this task. + TryNumber int `json:"try_number"` + } + + var pathParam0 string + + pathParam0, err = runtime.StyleParamWithLocation("simple", false, "worker_name", runtime.ParamLocationPath, workerName) + if err != nil { + return nil, err + } + + if body == nil { + return nil, fmt.Errorf("Fetch requires a non-nil body argument") + } + resp, err = c.R(). + SetContext(ctx). + SetPathParam("worker_name", pathParam0). + SetBody(body). + SetResult(&res). + Post("edge_worker/v1/jobs/fetch/{worker_name}") + if err != nil { + return resp, err + } + return resp, HandleError(c.Client, resp) +} + +func (c *jobsClient) Fetch(ctx context.Context, workerName string, body *WorkerQueuesBody) (*struct { + // Command Execute the given Task. + Command ExecuteTask `json:"command"` + + // ConcurrencySlots Number of concurrency slots the job requires. + ConcurrencySlots int `json:"concurrency_slots"` + + // DagId Identifier of the DAG to which the task belongs. + DagId string `json:"dag_id"` + + // MapIndex For dynamically mapped tasks the mapping number, -1 if the task is not mapped. + MapIndex int `json:"map_index"` + + // RunId Run ID of the DAG execution. + RunId string `json:"run_id"` + + // TaskId Task name in the DAG. + TaskId string `json:"task_id"` + + // TryNumber The number of attempt to execute this task. + TryNumber int `json:"try_number"` +}, error, +) { + res, err := c.FetchResponse(ctx, workerName, body) + if err != nil { + return nil, err + } + + return res.Result().(*struct { + // Command Execute the given Task. + Command ExecuteTask `json:"command"` + + // ConcurrencySlots Number of concurrency slots the job requires. + ConcurrencySlots int `json:"concurrency_slots"` + + // DagId Identifier of the DAG to which the task belongs. + DagId string `json:"dag_id"` + + // MapIndex For dynamically mapped tasks the mapping number, -1 if the task is not mapped. + MapIndex int `json:"map_index"` + + // RunId Run ID of the DAG execution. + RunId string `json:"run_id"` + + // TaskId Task name in the DAG. + TaskId string `json:"task_id"` + + // TryNumber The number of attempt to execute this task. + TryNumber int `json:"try_number"` + }), nil +} + +// StateResponse performs the HTTP request and returns the lower level [resty.Response] +func (c *jobsClient) StateResponse(ctx context.Context, dagId string, taskId string, runId string, tryNumber int, mapIndex int, state TaskInstanceState) (resp *resty.Response, err error) { + var res interface{} + + var pathParam0 string + + pathParam0, err = runtime.StyleParamWithLocation("simple", false, "dag_id", runtime.ParamLocationPath, dagId) + if err != nil { + return nil, err + } + var pathParam1 string + + pathParam1, err = runtime.StyleParamWithLocation("simple", false, "task_id", runtime.ParamLocationPath, taskId) + if err != nil { + return nil, err + } + var pathParam2 string + + pathParam2, err = runtime.StyleParamWithLocation("simple", false, "run_id", runtime.ParamLocationPath, runId) + if err != nil { + return nil, err + } + var pathParam3 string + + pathParam3, err = runtime.StyleParamWithLocation("simple", false, "try_number", runtime.ParamLocationPath, tryNumber) + if err != nil { + return nil, err + } + var pathParam4 string + + pathParam4, err = runtime.StyleParamWithLocation("simple", false, "map_index", runtime.ParamLocationPath, mapIndex) + if err != nil { + return nil, err + } + var pathParam5 string + + pathParam5, err = runtime.StyleParamWithLocation("simple", false, "state", runtime.ParamLocationPath, state) + if err != nil { + return nil, err + } + + resp, err = c.R(). + SetContext(ctx). + SetPathParam("dag_id", pathParam0). + SetPathParam("task_id", pathParam1). + SetPathParam("run_id", pathParam2). + SetPathParam("try_number", pathParam3). + SetPathParam("map_index", pathParam4). + SetPathParam("state", pathParam5). + SetResult(&res). + Patch("edge_worker/v1/jobs/state/{dag_id}/{task_id}/{run_id}/{try_number}/{map_index}/{state}") + if err != nil { + return resp, err + } + return resp, HandleError(c.Client, resp) +} + +func (c *jobsClient) State(ctx context.Context, dagId string, taskId string, runId string, tryNumber int, mapIndex int, state TaskInstanceState) (*interface{}, error) { + res, err := c.StateResponse(ctx, dagId, taskId, runId, tryNumber, mapIndex, state) + if err != nil { + return nil, err + } + + return res.Result().(*interface{}), nil +} + +type logsClient struct { + *resty.Client +} + +// filePathResponse performs the HTTP request and returns the lower level [resty.Response] +func (c *logsClient) filePathResponse(ctx context.Context, dagId string, taskId string, runId string, tryNumber int, mapIndex int) (resp *resty.Response, err error) { + var res string + + var pathParam0 string + + pathParam0, err = runtime.StyleParamWithLocation("simple", false, "dag_id", runtime.ParamLocationPath, dagId) + if err != nil { + return nil, err + } + var pathParam1 string + + pathParam1, err = runtime.StyleParamWithLocation("simple", false, "task_id", runtime.ParamLocationPath, taskId) + if err != nil { + return nil, err + } + var pathParam2 string + + pathParam2, err = runtime.StyleParamWithLocation("simple", false, "run_id", runtime.ParamLocationPath, runId) + if err != nil { + return nil, err + } + var pathParam3 string + + pathParam3, err = runtime.StyleParamWithLocation("simple", false, "try_number", runtime.ParamLocationPath, tryNumber) + if err != nil { + return nil, err + } + var pathParam4 string + + pathParam4, err = runtime.StyleParamWithLocation("simple", false, "map_index", runtime.ParamLocationPath, mapIndex) + if err != nil { + return nil, err + } + + resp, err = c.R(). + SetContext(ctx). + SetPathParam("dag_id", pathParam0). + SetPathParam("task_id", pathParam1). + SetPathParam("run_id", pathParam2). + SetPathParam("try_number", pathParam3). + SetPathParam("map_index", pathParam4). + SetResult(&res). + Get("edge_worker/v1/logs/logfile_path/{dag_id}/{task_id}/{run_id}/{try_number}/{map_index}") + if err != nil { + return resp, err + } + return resp, HandleError(c.Client, resp) +} + +func (c *logsClient) filePath(ctx context.Context, dagId string, taskId string, runId string, tryNumber int, mapIndex int) (*string, error) { + res, err := c.filePathResponse(ctx, dagId, taskId, runId, tryNumber, mapIndex) + if err != nil { + return nil, err + } + + return res.Result().(*string), nil +} + +// PushResponse performs the HTTP request and returns the lower level [resty.Response] +func (c *logsClient) PushResponse(ctx context.Context, dagId string, taskId string, runId string, tryNumber int, mapIndex int, body *PushLogsBody) (resp *resty.Response, err error) { + var res interface{} + + var pathParam0 string + + pathParam0, err = runtime.StyleParamWithLocation("simple", false, "dag_id", runtime.ParamLocationPath, dagId) + if err != nil { + return nil, err + } + var pathParam1 string + + pathParam1, err = runtime.StyleParamWithLocation("simple", false, "task_id", runtime.ParamLocationPath, taskId) + if err != nil { + return nil, err + } + var pathParam2 string + + pathParam2, err = runtime.StyleParamWithLocation("simple", false, "run_id", runtime.ParamLocationPath, runId) + if err != nil { + return nil, err + } + var pathParam3 string + + pathParam3, err = runtime.StyleParamWithLocation("simple", false, "try_number", runtime.ParamLocationPath, tryNumber) + if err != nil { + return nil, err + } + var pathParam4 string + + pathParam4, err = runtime.StyleParamWithLocation("simple", false, "map_index", runtime.ParamLocationPath, mapIndex) + if err != nil { + return nil, err + } + + if body == nil { + return nil, fmt.Errorf("Push requires a non-nil body argument") + } + resp, err = c.R(). + SetContext(ctx). + SetPathParam("dag_id", pathParam0). + SetPathParam("task_id", pathParam1). + SetPathParam("run_id", pathParam2). + SetPathParam("try_number", pathParam3). + SetPathParam("map_index", pathParam4). + SetBody(body). + SetResult(&res). + Post("edge_worker/v1/logs/push/{dag_id}/{task_id}/{run_id}/{try_number}/{map_index}") + if err != nil { + return resp, err + } + return resp, HandleError(c.Client, resp) +} + +func (c *logsClient) Push(ctx context.Context, dagId string, taskId string, runId string, tryNumber int, mapIndex int, body *PushLogsBody) (*interface{}, error) { + res, err := c.PushResponse(ctx, dagId, taskId, runId, tryNumber, mapIndex, body) + if err != nil { + return nil, err + } + + return res.Result().(*interface{}), nil +} + +type workerClient struct { + *resty.Client +} + +// UpdateQueuesResponse performs the HTTP request and returns the lower level [resty.Response] +func (c *workerClient) UpdateQueuesResponse(ctx context.Context, workerName string, body *WorkerQueueUpdateBody) (resp *resty.Response, err error) { + var res interface{} + + var pathParam0 string + + pathParam0, err = runtime.StyleParamWithLocation("simple", false, "worker_name", runtime.ParamLocationPath, workerName) + if err != nil { + return nil, err + } + + if body == nil { + return nil, fmt.Errorf("UpdateQueues requires a non-nil body argument") + } + resp, err = c.R(). + SetContext(ctx). + SetPathParam("worker_name", pathParam0). + SetBody(body). + SetResult(&res). + Patch("edge_worker/v1/worker/queues/{worker_name}") + if err != nil { + return resp, err + } + return resp, HandleError(c.Client, resp) +} + +func (c *workerClient) UpdateQueues(ctx context.Context, workerName string, body *WorkerQueueUpdateBody) (*interface{}, error) { + res, err := c.UpdateQueuesResponse(ctx, workerName, body) + if err != nil { + return nil, err + } + + return res.Result().(*interface{}), nil +} + +// SetStateResponse performs the HTTP request and returns the lower level [resty.Response] +func (c *workerClient) SetStateResponse(ctx context.Context, workerName string, body *WorkerStateBody) (resp *resty.Response, err error) { + var res WorkerSetStateReturn + + var pathParam0 string + + pathParam0, err = runtime.StyleParamWithLocation("simple", false, "worker_name", runtime.ParamLocationPath, workerName) + if err != nil { + return nil, err + } + + if body == nil { + return nil, fmt.Errorf("SetState requires a non-nil body argument") + } + resp, err = c.R(). + SetContext(ctx). + SetPathParam("worker_name", pathParam0). + SetBody(body). + SetResult(&res). + Patch("edge_worker/v1/worker/{worker_name}") + if err != nil { + return resp, err + } + return resp, HandleError(c.Client, resp) +} + +func (c *workerClient) SetState(ctx context.Context, workerName string, body *WorkerStateBody) (*WorkerSetStateReturn, error) { + res, err := c.SetStateResponse(ctx, workerName, body) + if err != nil { + return nil, err + } + + return res.Result().(*WorkerSetStateReturn), nil +} + +// RegisterResponse performs the HTTP request and returns the lower level [resty.Response] +func (c *workerClient) RegisterResponse(ctx context.Context, workerName string, body *WorkerStateBody) (resp *resty.Response, err error) { + var res WorkerRegistrationReturn + + var pathParam0 string + + pathParam0, err = runtime.StyleParamWithLocation("simple", false, "worker_name", runtime.ParamLocationPath, workerName) + if err != nil { + return nil, err + } + + if body == nil { + return nil, fmt.Errorf("Register requires a non-nil body argument") + } + resp, err = c.R(). + SetContext(ctx). + SetPathParam("worker_name", pathParam0). + SetBody(body). + SetResult(&res). + Post("edge_worker/v1/worker/{worker_name}") + if err != nil { + return resp, err + } + return resp, HandleError(c.Client, resp) +} + +func (c *workerClient) Register(ctx context.Context, workerName string, body *WorkerStateBody) (*WorkerRegistrationReturn, error) { + res, err := c.RegisterResponse(ctx, workerName, body) + if err != nil { + return nil, err + } + + return res.Result().(*WorkerRegistrationReturn), nil +} + +type JobsClient interface { + // Fetch a job to execute on the edge worker. + Fetch(ctx context.Context, workerName string, body *WorkerQueuesBody) (*struct { + // Command Execute the given Task. + Command ExecuteTask `json:"command"` + + // ConcurrencySlots Number of concurrency slots the job requires. + ConcurrencySlots int `json:"concurrency_slots"` + + // DagId Identifier of the DAG to which the task belongs. + DagId string `json:"dag_id"` + + // MapIndex For dynamically mapped tasks the mapping number, -1 if the task is not mapped. + MapIndex int `json:"map_index"` + + // RunId Run ID of the DAG execution. + RunId string `json:"run_id"` + + // TaskId Task name in the DAG. + TaskId string `json:"task_id"` + + // TryNumber The number of attempt to execute this task. + TryNumber int `json:"try_number"` + }, error) + // FetchResponse is a lower level version of [Fetch] and provides access to the raw [resty.Response] + FetchResponse(ctx context.Context, workerName string, body *WorkerQueuesBody) (*resty.Response, error) + + // Update the state of a job running on the edge worker. + State(ctx context.Context, dagId string, taskId string, runId string, tryNumber int, mapIndex int, state TaskInstanceState) (*interface{}, error) + // StateResponse is a lower level version of [State] and provides access to the raw [resty.Response] + StateResponse(ctx context.Context, dagId string, taskId string, runId string, tryNumber int, mapIndex int, state TaskInstanceState) (*resty.Response, error) +} + +var _ JobsClient = (*jobsClient)(nil) + +type LogsClient interface { + // Elaborate the path and filename to expect from task execution. + filePath(ctx context.Context, dagId string, taskId string, runId string, tryNumber int, mapIndex int) (*string, error) + // filePathResponse is a lower level version of [filePath] and provides access to the raw [resty.Response] + filePathResponse(ctx context.Context, dagId string, taskId string, runId string, tryNumber int, mapIndex int) (*resty.Response, error) + + // Push an incremental log chunk from Edge Worker to central site. + Push(ctx context.Context, dagId string, taskId string, runId string, tryNumber int, mapIndex int, body *PushLogsBody) (*interface{}, error) + // PushResponse is a lower level version of [Push] and provides access to the raw [resty.Response] + PushResponse(ctx context.Context, dagId string, taskId string, runId string, tryNumber int, mapIndex int, body *PushLogsBody) (*resty.Response, error) +} + +var _ LogsClient = (*logsClient)(nil) + +type WorkerClient interface { + UpdateQueues(ctx context.Context, workerName string, body *WorkerQueueUpdateBody) (*interface{}, error) + // UpdateQueuesResponse is a lower level version of [UpdateQueues] and provides access to the raw [resty.Response] + UpdateQueuesResponse(ctx context.Context, workerName string, body *WorkerQueueUpdateBody) (*resty.Response, error) + + // Set state of worker and returns the current assigned queues. + SetState(ctx context.Context, workerName string, body *WorkerStateBody) (*WorkerSetStateReturn, error) + // SetStateResponse is a lower level version of [SetState] and provides access to the raw [resty.Response] + SetStateResponse(ctx context.Context, workerName string, body *WorkerStateBody) (*resty.Response, error) + + // Register a new worker to the backend. + Register(ctx context.Context, workerName string, body *WorkerStateBody) (*WorkerRegistrationReturn, error) + // RegisterResponse is a lower level version of [Register] and provides access to the raw [resty.Response] + RegisterResponse(ctx context.Context, workerName string, body *WorkerStateBody) (*resty.Response, error) +} + +var _ WorkerClient = (*workerClient)(nil) diff --git a/go-sdk/pkg/edgeapi/client.go b/go-sdk/pkg/edgeapi/client.go new file mode 100644 index 0000000000000..d4d888ffbbf5c --- /dev/null +++ b/go-sdk/pkg/edgeapi/client.go @@ -0,0 +1,62 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package edgeapi + +import ( + "strings" + "time" + + "github.com/golang-jwt/jwt/v5" + "resty.dev/v3" +) + +//go:generate -command openapi-gen go run github.com/ashb/oapi-resty-codegen@latest --config oapi-codegen.yml + +//go:generate openapi-gen https://raw.githubusercontent.com/apache/airflow/refs/tags/providers-edge3/1.3.0/providers/edge3/src/airflow/providers/edge3/openapi/v2-edge-generated.yaml + +func WithEdgeAPIJWTKey(key []byte) ClientOption { + return func(c *Client) error { + c.SetAuthScheme("") + + mw := func(c *resty.Client, req *resty.Request) error { + endpointPath := strings.TrimPrefix(req.RawRequest.URL.String(), c.BaseURL()) + endpointPath = strings.TrimPrefix(endpointPath, "/edge_worker/v1/") + now := time.Now().UTC().Unix() + t := jwt.NewWithClaims(jwt.SigningMethodHS512, jwt.MapClaims{ + "method": endpointPath, + "aud": "api", + "iat": now, + "nbf": now, + "exp": now + 5, + }) + s, err := t.SignedString(key) + if err != nil { + return err + } + req.RawRequest.Header.Set( + req.HeaderAuthorizationKey, + strings.TrimSpace(req.AuthScheme+" "+s), + ) + return nil + } + + mws := append(c.RequestMiddleware, resty.PrepareRequestMiddleware, mw) + c.SetRequestMiddlewares(mws...) + return nil + } +} diff --git a/go-sdk/pkg/edgeapi/oapi-codegen.yml b/go-sdk/pkg/edgeapi/oapi-codegen.yml new file mode 100644 index 0000000000000..e2061ae374f7c --- /dev/null +++ b/go-sdk/pkg/edgeapi/oapi-codegen.yml @@ -0,0 +1,34 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# yamllint disable rule:line-length +# yaml-language-server: $schema=https://raw.githubusercontent.com/oapi-codegen/oapi-codegen/HEAD/configuration-schema.json +# yamllint enable +--- +package: edgeapi +output: client.gen.go +generate: + models: true + client: true +output-options: + exclude-tags: [Health, UI, Monitor] + # Include the unreferenced schemas/models, such as `TaskInstance` + skip-prune: true + overlay: + path: overlay.yml +downgrade-options: + anyOf-to-oneOf: true diff --git a/go-sdk/pkg/edgeapi/overlay.yml b/go-sdk/pkg/edgeapi/overlay.yml new file mode 100644 index 0000000000000..2592f41db9ba5 --- /dev/null +++ b/go-sdk/pkg/edgeapi/overlay.yml @@ -0,0 +1,29 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +--- +# yamllint disable rule:line-length +overlay: 1.0.0 +info: + title: OpenAPI Overlay + version: 0.0.0 +actions: + - description: |- + Remove Authorization param + + We handle it separately, we don't want it in the params object + target: $.paths.*.*.parameters[?(@.name == "authorization" && @.in == "header")] + remove: true diff --git a/go-sdk/pkg/logging/level.go b/go-sdk/pkg/logging/level.go new file mode 100644 index 0000000000000..5b17f4a0b9c98 --- /dev/null +++ b/go-sdk/pkg/logging/level.go @@ -0,0 +1,23 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package logging + +import "log/slog" + +// LevelTrace is for log messages even lower than debug +const LevelTrace = slog.LevelDebug - (slog.LevelInfo - slog.LevelDebug) diff --git a/go-sdk/pkg/logging/shclog/shclog.go b/go-sdk/pkg/logging/shclog/shclog.go index c1984fbea2e3a..dda40d6bad6a4 100644 --- a/go-sdk/pkg/logging/shclog/shclog.go +++ b/go-sdk/pkg/logging/shclog/shclog.go @@ -31,10 +31,9 @@ import ( "time" "github.com/hashicorp/go-hclog" -) -// Slog does not have built in trace log level -const SlogLevelTrace = slog.LevelDebug - (slog.LevelInfo - slog.LevelDebug) + "github.com/apache/airflow/go-sdk/pkg/logging" +) const ( TimeFormatJSON = "2006-01-02T15:04:05.000000Z0700" @@ -114,7 +113,7 @@ func (l *Shclog) log(ctx context.Context, level slog.Level, msg string, args ... } func (l *Shclog) Trace(msg string, args ...any) { - l.log(context.Background(), SlogLevelTrace, msg, args...) + l.log(context.Background(), logging.LevelTrace, msg, args...) } func (l *Shclog) Debug(msg string, args ...any) { @@ -219,7 +218,7 @@ func getSlogLevel(l *slog.Logger) hclog.Level { h := l.Handler() ctx := context.Background() - if h.Enabled(ctx, SlogLevelTrace) { + if h.Enabled(ctx, logging.LevelTrace) { return hclog.Trace } if h.Enabled(ctx, slog.LevelDebug) { diff --git a/go-sdk/pkg/worker/runner.go b/go-sdk/pkg/worker/runner.go index 9d63d7f1a093c..29a4b659bb0fd 100644 --- a/go-sdk/pkg/worker/runner.go +++ b/go-sdk/pkg/worker/runner.go @@ -258,13 +258,6 @@ func (w *worker) ExecuteTaskWorkload(ctx context.Context, workload api.ExecuteTa State: api.Running, StartDate: time.Now().UTC(), }) - logger.LogAttrs( - ctx, - slog.LevelInfo, - "Start context", - slog.Any("resp", runtimeContext), - logging.AttrErr(err), - ) if err != nil { var httpError *api.GeneralHTTPError if errors.As(err, &httpError) {