Skip to content

Commit ba2ac63

Browse files
committed
Fix traced mTLS identity client transport
1 parent 231ef7d commit ba2ac63

3 files changed

Lines changed: 110 additions & 23 deletions

File tree

identityclient/client.go

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ import (
1818

1919
identityv1 "github.com/evalops/proto/gen/go/identity/v1"
2020
"github.com/evalops/service-runtime/mtls"
21-
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
2221
"google.golang.org/protobuf/encoding/protojson"
2322
"google.golang.org/protobuf/proto"
2423
"google.golang.org/protobuf/types/known/timestamppb"
@@ -159,7 +158,7 @@ func New(config Config) *Client {
159158
httpClient = http.DefaultClient
160159
}
161160
usesMTLSCert := httpClientUsesMTLSCertificate(httpClient)
162-
httpClient = tracedHTTPClient(httpClient)
161+
httpClient = mtls.TraceHTTPClient(httpClient)
163162
maxSize := config.MaxCacheSize
164163
if maxSize <= 0 {
165164
maxSize = defaultMaxCacheSize
@@ -177,19 +176,6 @@ func New(config Config) *Client {
177176
}
178177
}
179178

180-
func tracedHTTPClient(client *http.Client) *http.Client {
181-
if client == nil {
182-
client = http.DefaultClient
183-
}
184-
cloned := *client
185-
baseTransport := cloned.Transport
186-
if baseTransport == nil {
187-
baseTransport = http.DefaultTransport
188-
}
189-
cloned.Transport = otelhttp.NewTransport(baseTransport)
190-
return &cloned
191-
}
192-
193179
// NewClient creates a Client that introspects tokens at the given URL.
194180
func NewClient(introspectURL string, requestTimeout time.Duration, httpClient *http.Client) *Client {
195181
return New(Config{
@@ -422,11 +408,22 @@ func httpClientUsesMTLSCertificate(client *http.Client) bool {
422408
if client == nil {
423409
return false
424410
}
425-
transport, ok := client.Transport.(*http.Transport)
426-
if !ok || transport == nil {
427-
return false
411+
transport := client.Transport
412+
if transport == nil {
413+
transport = http.DefaultTransport
414+
}
415+
for transport != nil {
416+
httpTransport, ok := transport.(*http.Transport)
417+
if ok {
418+
return tlsConfigHasClientCertificate(httpTransport.TLSClientConfig)
419+
}
420+
wrapped, ok := transport.(interface{ Unwrap() http.RoundTripper })
421+
if !ok {
422+
return false
423+
}
424+
transport = wrapped.Unwrap()
428425
}
429-
return tlsConfigHasClientCertificate(transport.TLSClientConfig)
426+
return false
430427
}
431428

432429
func tlsConfigHasClientCertificate(cfg *tls.Config) bool {

identityclient/client_test.go

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"go.opentelemetry.io/otel"
1919
"go.opentelemetry.io/otel/propagation"
2020
sdktrace "go.opentelemetry.io/otel/sdk/trace"
21+
"go.opentelemetry.io/otel/sdk/trace/tracetest"
2122
)
2223

2324
type roundTripFunc func(*http.Request) (*http.Response, error)
@@ -67,6 +68,65 @@ func TestConfigured(t *testing.T) {
6768
}
6869
}
6970

71+
func TestConfiguredDetectsMTLSCertificatesThroughTracedTransport(t *testing.T) {
72+
if !New(Config{
73+
ServiceTokensURL: "https://identity.internal/v1/service-tokens",
74+
HTTPClient: mtls.TraceHTTPClient(&http.Client{
75+
Transport: &http.Transport{
76+
TLSClientConfig: &tls.Config{Certificates: []tls.Certificate{{}}},
77+
},
78+
}),
79+
}).ServiceTokensConfigured() {
80+
t.Fatal("expected traced mtls-authenticated service tokens to be configured")
81+
}
82+
}
83+
84+
func TestNewDoesNotDoubleWrapTracedHTTPClient(t *testing.T) {
85+
originalProvider := otel.GetTracerProvider()
86+
originalPropagator := otel.GetTextMapPropagator()
87+
recorder := tracetest.NewSpanRecorder()
88+
tracerProvider := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(recorder))
89+
otel.SetTracerProvider(tracerProvider)
90+
otel.SetTextMapPropagator(propagation.TraceContext{})
91+
t.Cleanup(func() {
92+
otel.SetTracerProvider(originalProvider)
93+
otel.SetTextMapPropagator(originalPropagator)
94+
_ = tracerProvider.Shutdown(context.Background())
95+
})
96+
97+
httpClient := mtls.TraceHTTPClient(&http.Client{
98+
Transport: roundTripFunc(func(*http.Request) (*http.Response, error) {
99+
return &http.Response{
100+
StatusCode: http.StatusOK,
101+
Body: io.NopCloser(strings.NewReader(`{"active":true,"organization_id":"org-123"}`)),
102+
Header: make(http.Header),
103+
}, nil
104+
}),
105+
})
106+
client := New(Config{
107+
IntrospectURL: "https://identity.test/v1/tokens/introspect",
108+
RequestTimeout: time.Second,
109+
HTTPClient: httpClient,
110+
})
111+
112+
ctx, span := tracerProvider.Tracer("identityclient-test").Start(context.Background(), "root")
113+
defer span.End()
114+
115+
if _, err := client.Introspect(ctx, "write-token"); err != nil {
116+
t.Fatalf("introspect: %v", err)
117+
}
118+
119+
httpSpanCount := 0
120+
for _, ended := range recorder.Ended() {
121+
if ended.Name() == "HTTP POST" {
122+
httpSpanCount++
123+
}
124+
}
125+
if httpSpanCount != 1 {
126+
t.Fatalf("expected one HTTP client span, got %d", httpSpanCount)
127+
}
128+
}
129+
70130
func TestIntrospectSuccess(t *testing.T) {
71131
server := testutil.NewTestServer(t, http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
72132
if got := request.Header.Get("Authorization"); got != "Bearer write-token" {

mtls/mtls.go

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ func BuildHTTPClient(cfg ClientConfig) (*http.Client, error) {
6969
return nil, err
7070
}
7171
if tlsConfig == nil {
72-
return traceHTTPClient(http.DefaultClient), nil
72+
return TraceHTTPClient(http.DefaultClient), nil
7373
}
7474

7575
transport, ok := http.DefaultTransport.(*http.Transport)
@@ -78,10 +78,11 @@ func BuildHTTPClient(cfg ClientConfig) (*http.Client, error) {
7878
}
7979
clone := transport.Clone()
8080
clone.TLSClientConfig = tlsConfig
81-
return traceHTTPClient(&http.Client{Transport: clone}), nil
81+
return TraceHTTPClient(&http.Client{Transport: clone}), nil
8282
}
8383

84-
func traceHTTPClient(client *http.Client) *http.Client {
84+
// TraceHTTPClient clones client and wraps its transport with OTel propagation.
85+
func TraceHTTPClient(client *http.Client) *http.Client {
8586
if client == nil {
8687
client = http.DefaultClient
8788
}
@@ -90,10 +91,39 @@ func traceHTTPClient(client *http.Client) *http.Client {
9091
if transport == nil {
9192
transport = http.DefaultTransport
9293
}
93-
cloned.Transport = otelhttp.NewTransport(transport)
94+
cloned.Transport = traceRoundTripper(transport)
9495
return &cloned
9596
}
9697

98+
func traceRoundTripper(transport http.RoundTripper) http.RoundTripper {
99+
if transport == nil {
100+
transport = http.DefaultTransport
101+
}
102+
if _, ok := transport.(*tracedTransport); ok {
103+
return transport
104+
}
105+
if _, ok := transport.(*otelhttp.Transport); ok {
106+
return transport
107+
}
108+
return &tracedTransport{
109+
base: transport,
110+
traced: otelhttp.NewTransport(transport),
111+
}
112+
}
113+
114+
type tracedTransport struct {
115+
base http.RoundTripper
116+
traced http.RoundTripper
117+
}
118+
119+
func (t *tracedTransport) RoundTrip(request *http.Request) (*http.Response, error) {
120+
return t.traced.RoundTrip(request)
121+
}
122+
123+
func (t *tracedTransport) Unwrap() http.RoundTripper {
124+
return t.base
125+
}
126+
97127
// BuildClientTLSConfig returns a *tls.Config for an outbound mTLS client, or nil when all fields are empty.
98128
func BuildClientTLSConfig(cfg ClientConfig) (*tls.Config, error) {
99129
if cfg.CAFile == "" && cfg.CertFile == "" && cfg.KeyFile == "" && cfg.ServerName == "" {

0 commit comments

Comments
 (0)