diff --git a/internal/extension/extension.go b/internal/extension/extension.go index 3f69f41..8e54afc 100644 --- a/internal/extension/extension.go +++ b/internal/extension/extension.go @@ -14,6 +14,7 @@ import ( "encoding/base64" "encoding/json" "fmt" + "io" "net/http" "os" "reflect" @@ -97,7 +98,14 @@ func (em *ExtensionManager) checkAgentRunning() { // Tell the extension not to create an execution span if universal instrumentation is disabled if !em.isUniversalInstrumentation { req, _ := http.NewRequest(http.MethodGet, em.helloRoute, nil) - if response, err := em.httpClient.Do(req); err == nil && response.StatusCode == 200 { + response, err := em.httpClient.Do(req) + if response != nil && response.Body != nil { + defer func() { + _, _ = io.Copy(io.Discard, response.Body) + response.Body.Close() + }() + } + if err == nil && response.StatusCode == 200 { logger.Debug("Hit the extension /hello route") } else { logger.Debug("Will use the API since the Serverless Agent was detected but the hello route was unreachable") @@ -110,7 +118,14 @@ func (em *ExtensionManager) checkAgentRunning() { func (em *ExtensionManager) SendStartInvocationRequest(ctx context.Context, eventPayload json.RawMessage) context.Context { body := bytes.NewBuffer(eventPayload) req, _ := http.NewRequest(http.MethodPost, em.startInvocationUrl, body) - if response, err := em.httpClient.Do(req); err == nil && response.StatusCode == 200 { + response, err := em.httpClient.Do(req) + if response != nil && response.Body != nil { + defer func() { + _, _ = io.Copy(io.Discard, response.Body) + response.Body.Close() + }() + } + if err == nil && response.StatusCode == 200 { // Propagate dd-trace context from the extension response if found in the response headers traceId := response.Header.Get(string(DdTraceId)) if traceId != "" { @@ -179,7 +194,7 @@ func (em *ExtensionManager) SendEndInvocationRequest(ctx context.Context, functi logger.Error(fmt.Errorf("could not get sampling priority from spanContext.SamplingPriority()")) } } else { - if priority, ok := getSamplingPriority(functionExecutionSpan) ; ok { + if priority, ok := getSamplingPriority(functionExecutionSpan); ok { req.Header.Set(string(DdSamplingPriority), fmt.Sprint(priority)) } else { logger.Error(fmt.Errorf("could not get sampling priority from getSamplingPriority()")) @@ -188,6 +203,12 @@ func (em *ExtensionManager) SendEndInvocationRequest(ctx context.Context, functi } resp, err := em.httpClient.Do(req) + if resp != nil && resp.Body != nil { + defer func() { + _, _ = io.Copy(io.Discard, resp.Body) + resp.Body.Close() + }() + } if err != nil || resp.StatusCode != 200 { logger.Error(fmt.Errorf("could not send end invocation payload to the extension: %v", err)) } @@ -236,7 +257,14 @@ func (em *ExtensionManager) IsExtensionRunning() bool { func (em *ExtensionManager) Flush() error { req, _ := http.NewRequest(http.MethodGet, em.flushRoute, nil) - if response, err := em.httpClient.Do(req); err != nil { + response, err := em.httpClient.Do(req) + if response != nil && response.Body != nil { + defer func() { + _, _ = io.Copy(io.Discard, response.Body) + response.Body.Close() + }() + } + if err != nil { err := fmt.Errorf("was not able to reach the Agent to flush: %s", err) logger.Error(err) return err @@ -252,33 +280,33 @@ func (em *ExtensionManager) Flush() error { // But for dd-trace-go v1.74.x, reflection is needed to access the SamplingPriority method because // the method hidden in the v2 SpanContextV2Adapter struct. func getSamplingPriority(span ddtrace.Span) (int, bool) { - // Get the span context - ctx := span.Context() - - // Use reflection to access the underlying v2 SpanContext - ctxValue := reflect.ValueOf(ctx) - if ctxValue.Type().String() != "internal.SpanContextV2Adapter" { - return 0, false - } - - // Get the Ctx field (the underlying v2.SpanContext) - ctxField := ctxValue.FieldByName("Ctx") - if !ctxField.IsValid() { - return 0, false - } - - // Call SamplingPriority() on the underlying v2 SpanContext - method := ctxField.MethodByName("SamplingPriority") - if !method.IsValid() { - return 0, false - } - - results := method.Call([]reflect.Value{}) - if len(results) != 2 { - return 0, false - } - - priority := int(results[0].Int()) - ok := results[1].Bool() - return priority, ok - } + // Get the span context + ctx := span.Context() + + // Use reflection to access the underlying v2 SpanContext + ctxValue := reflect.ValueOf(ctx) + if ctxValue.Type().String() != "internal.SpanContextV2Adapter" { + return 0, false + } + + // Get the Ctx field (the underlying v2.SpanContext) + ctxField := ctxValue.FieldByName("Ctx") + if !ctxField.IsValid() { + return 0, false + } + + // Call SamplingPriority() on the underlying v2 SpanContext + method := ctxField.MethodByName("SamplingPriority") + if !method.IsValid() { + return 0, false + } + + results := method.Call([]reflect.Value{}) + if len(results) != 2 { + return 0, false + } + + priority := int(results[0].Int()) + ok := results[1].Bool() + return priority, ok +}