diff --git a/connector.go b/connector.go index 9b1e087..4d01965 100644 --- a/connector.go +++ b/connector.go @@ -6,6 +6,7 @@ import ( "database/sql/driver" "fmt" "net/http" + "net/url" "strings" "time" @@ -80,11 +81,21 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { } log := logger.WithContext(conn.id, driverctx.CorrelationIdFromContext(ctx), "") + // Extract SPOG routing headers from ?o= in HTTPPath. When ?o= + // is present (Custom URL / SPOG hosts), wrap the HTTP client used for + // telemetry + feature-flag calls with a transport that injects + // x-databricks-org-id. Thrift routes via the URL so its own c.client + // doesn't need wrapping. + telemetryClient := c.client + if spogHeaders := extractSpogHeaders(c.cfg.HTTPPath); len(spogHeaders) > 0 { + telemetryClient = withSpogHeaders(c.client, spogHeaders) + } + // Initialize telemetry: client config overlay decides; if unset, feature flags decide conn.telemetry = telemetry.InitializeForConnection(ctx, telemetry.TelemetryInitOptions{ Host: c.cfg.Host, DriverVersion: c.cfg.DriverVersion, - HTTPClient: c.client, + HTTPClient: telemetryClient, EnableTelemetry: c.cfg.EnableTelemetry, BatchSize: c.cfg.TelemetryBatchSize, FlushInterval: c.cfg.TelemetryFlushInterval, @@ -126,6 +137,91 @@ func NewConnector(options ...ConnOption) (driver.Connector, error) { return &connector{cfg: cfg, client: client}, nil } +// extractSpogHeaders extracts ?o= from httpPath and returns +// an x-databricks-org-id header for SPOG routing. +// +// On SPOG (Custom URL) workspaces, httpPath is of the form +// /sql/1.0/warehouses/?o=. The ?o= parameter keeps Thrift +// requests routed to the correct workspace via the URL itself, but other +// endpoints (telemetry, feature flags) run on separate hosts and need the +// x-databricks-org-id header. This function extracts ?o= from httpPath once +// and returns it so those paths can inject it as an HTTP header. +// +// Returns nil if: +// - httpPath has no query string ("?"), or +// - the query string is malformed and can't be parsed, or +// - the ?o= parameter is missing or empty. +func extractSpogHeaders(httpPath string) map[string]string { + if !strings.Contains(httpPath, "?") { + return nil + } + // Parse query string from httpPath + parts := strings.SplitN(httpPath, "?", 2) + params, err := url.ParseQuery(parts[1]) + if err != nil { + logger.Debug().Msgf( + "SPOG header extraction: malformed query string in httpPath, skipping org-id extraction: %s", + err) + return nil + } + orgID := params.Get("o") + if orgID == "" { + logger.Debug().Msg( + "SPOG header extraction: httpPath has query string but no ?o= param, " + + "skipping x-databricks-org-id injection") + return nil + } + logger.Debug().Msgf( + "SPOG header extraction: injecting x-databricks-org-id=%s (extracted from ?o= in httpPath)", + orgID) + return map[string]string{"x-databricks-org-id": orgID} +} + +// withSpogHeaders returns a new *http.Client that reuses the transport of the +// provided client, wrapped to inject the given SPOG headers on every outbound +// request. The original client is left unchanged. If a request already has a +// given header set (e.g., the caller set it explicitly), the wrapper does not +// override it. +// +// This is how the driver gets x-databricks-org-id onto both the feature-flag +// check and the telemetry push without touching the telemetry package's +// signatures. +func withSpogHeaders(base *http.Client, headers map[string]string) *http.Client { + baseTransport := base.Transport + if baseTransport == nil { + baseTransport = http.DefaultTransport + } + return &http.Client{ + Transport: &headerInjectingTransport{ + base: baseTransport, + headers: headers, + }, + CheckRedirect: base.CheckRedirect, + Jar: base.Jar, + Timeout: base.Timeout, + } +} + +// headerInjectingTransport wraps an http.RoundTripper and sets a fixed set of +// headers on every outbound request. Caller-supplied headers with the same +// name are not overridden. +type headerInjectingTransport struct { + base http.RoundTripper + headers map[string]string +} + +// RoundTrip implements http.RoundTripper. +func (t *headerInjectingTransport) RoundTrip(req *http.Request) (*http.Response, error) { + // Clone per RoundTripper contract — must not mutate the caller's request. + req2 := req.Clone(req.Context()) + for k, v := range t.headers { + if req2.Header.Get(k) == "" { + req2.Header.Set(k, v) + } + } + return t.base.RoundTrip(req2) +} + func withUserConfig(ucfg config.UserConfig) ConnOption { return func(c *config.Config) { c.UserConfig = ucfg diff --git a/connector_spog_test.go b/connector_spog_test.go new file mode 100644 index 0000000..69273ee --- /dev/null +++ b/connector_spog_test.go @@ -0,0 +1,164 @@ +package dbsql + +import ( + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestExtractSpogHeaders(t *testing.T) { + tests := []struct { + name string + httpPath string + want map[string]string + }{ + { + name: "no query string returns nil", + httpPath: "/sql/1.0/warehouses/abc123", + want: nil, + }, + { + name: "empty httpPath returns nil", + httpPath: "", + want: nil, + }, + { + name: "query string with o= extracts org id", + httpPath: "/sql/1.0/warehouses/abc123?o=7064161269814046", + want: map[string]string{"x-databricks-org-id": "7064161269814046"}, + }, + { + name: "query string without o= returns nil", + httpPath: "/sql/1.0/warehouses/abc123?other=val", + want: nil, + }, + { + name: "empty o= value returns nil", + httpPath: "/sql/1.0/warehouses/abc123?o=", + want: nil, + }, + { + name: "o= among multiple params extracts correctly", + httpPath: "/sql/1.0/warehouses/abc?foo=1&o=12345&bar=2", + want: map[string]string{"x-databricks-org-id": "12345"}, + }, + { + name: "first o= wins when duplicated", + httpPath: "/sql/1.0/warehouses/abc?o=first&o=second", + want: map[string]string{"x-databricks-org-id": "first"}, + }, + { + name: "just ? with nothing after returns nil", + httpPath: "/sql/1.0/warehouses/abc?", + want: nil, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := extractSpogHeaders(tc.httpPath) + assert.Equal(t, tc.want, got) + }) + } +} + +func TestHeaderInjectingTransport_InjectsHeader(t *testing.T) { + var gotHeader string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotHeader = r.Header.Get("x-databricks-org-id") + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + client := withSpogHeaders(&http.Client{}, map[string]string{ + "x-databricks-org-id": "7064161269814046", + }) + + req, err := http.NewRequest("GET", srv.URL, nil) + require.NoError(t, err) + resp, err := client.Do(req) + require.NoError(t, err) + _, _ = io.Copy(io.Discard, resp.Body) + _ = resp.Body.Close() + + assert.Equal(t, "7064161269814046", gotHeader, "SPOG header should be injected") +} + +func TestHeaderInjectingTransport_DoesNotOverrideCallerSet(t *testing.T) { + var gotHeader string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotHeader = r.Header.Get("x-databricks-org-id") + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + client := withSpogHeaders(&http.Client{}, map[string]string{ + "x-databricks-org-id": "from-wrapper", + }) + + req, err := http.NewRequest("GET", srv.URL, nil) + require.NoError(t, err) + req.Header.Set("x-databricks-org-id", "from-caller") + resp, err := client.Do(req) + require.NoError(t, err) + _, _ = io.Copy(io.Discard, resp.Body) + _ = resp.Body.Close() + + assert.Equal(t, "from-caller", gotHeader, "caller-set header must not be overridden") +} + +func TestHeaderInjectingTransport_PreservesOtherHeaders(t *testing.T) { + var gotAuth, gotSpog, gotCustom string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAuth = r.Header.Get("Authorization") + gotSpog = r.Header.Get("x-databricks-org-id") + gotCustom = r.Header.Get("X-Custom") + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + client := withSpogHeaders(&http.Client{}, map[string]string{ + "x-databricks-org-id": "abc", + }) + + req, err := http.NewRequest("GET", srv.URL, nil) + require.NoError(t, err) + req.Header.Set("Authorization", "Bearer xxx") + req.Header.Set("X-Custom", "hello") + resp, err := client.Do(req) + require.NoError(t, err) + _, _ = io.Copy(io.Discard, resp.Body) + _ = resp.Body.Close() + + assert.Equal(t, "Bearer xxx", gotAuth) + assert.Equal(t, "hello", gotCustom) + assert.Equal(t, "abc", gotSpog) +} + +func TestWithSpogHeaders_OriginalClientUntouched(t *testing.T) { + originalTransport := &countingTransport{} + original := &http.Client{Transport: originalTransport} + + wrapped := withSpogHeaders(original, map[string]string{"x-databricks-org-id": "x"}) + + // Original client's transport should NOT be the wrapper type. + _, isWrapped := original.Transport.(*headerInjectingTransport) + assert.False(t, isWrapped, "original client's transport must not be mutated") + + // Wrapped client MUST have the wrapper. + _, isWrapped = wrapped.Transport.(*headerInjectingTransport) + assert.True(t, isWrapped, "wrapped client must have headerInjectingTransport") +} + +type countingTransport struct { + count int +} + +func (c *countingTransport) RoundTrip(req *http.Request) (*http.Response, error) { + c.count++ + return &http.Response{StatusCode: 200, Body: io.NopCloser(strings.NewReader(""))}, nil +}