From 23697e5c7051c199f9c15001463f93a8fc72da9f Mon Sep 17 00:00:00 2001 From: Madhavendra Rathore Date: Tue, 7 Apr 2026 00:28:40 +0530 Subject: [PATCH 1/6] Fix U2M OAuth: don't send empty client_secret for public apps The U2M flow uses PKCE (public app) and should not send a client secret. Previously, ClientSecret was always set to "" on the oauth2.Config, which caused Go's oauth2 library to send an empty client_secret via Basic auth. The OIDC server rejects this with "Public app should not use a client secret". Only set ClientSecret when it's non-empty, so public apps use the "none" token endpoint auth method as intended. Signed-off-by: Madhavendra Rathore Co-authored-by: Isaac Signed-off-by: Madhavendra Rathore --- auth/oauth/u2m/u2m.go | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/auth/oauth/u2m/u2m.go b/auth/oauth/u2m/u2m.go index 456e369a..c0eb615d 100644 --- a/auth/oauth/u2m/u2m.go +++ b/auth/oauth/u2m/u2m.go @@ -25,11 +25,16 @@ func GetConfig(ctx context.Context, hostName, clientID, clientSecret, callbackUR } config := oauth2.Config{ - ClientID: clientID, - ClientSecret: clientSecret, - Endpoint: endpoint, - RedirectURL: callbackURL, - Scopes: scopes, + ClientID: clientID, + Endpoint: endpoint, + RedirectURL: callbackURL, + Scopes: scopes, + } + // Only set ClientSecret if non-empty. For U2M (public apps using PKCE), + // sending an empty client_secret causes the server to reject with + // "Public app should not use a client secret". + if clientSecret != "" { + config.ClientSecret = clientSecret } return config, nil From 0ec7e068a8aecbc9943214bdad22c16701a56678 Mon Sep 17 00:00:00 2001 From: Madhavendra Rathore Date: Tue, 7 Apr 2026 00:36:07 +0530 Subject: [PATCH 2/6] Fix feature flags and U2M OAuth for SPOG support Feature flags: - Fix endpoint path: /api/2.0/feature-flags -> /api/2.0/connector-service/feature-flags/GOLANG/{version} - Fix response parsing: map format -> array of {name, value} entries - Add extraHeaders for SPOG routing (x-databricks-org-id) - Extract ?o= from httpPath in connector U2M OAuth: - Don't set ClientSecret for public apps (PKCE) - Force AuthStyleInParams to prevent Basic auth with empty password - Server rejects "Public app should not use a client secret" otherwise Signed-off-by: Madhavendra Rathore Co-authored-by: Isaac Signed-off-by: Madhavendra Rathore --- auth/oauth/u2m/u2m.go | 9 +++-- connector.go | 25 +++++++++++++ telemetry/config.go | 4 +- telemetry/config_test.go | 50 +++++++++---------------- telemetry/driver_integration.go | 4 +- telemetry/featureflag.go | 65 +++++++++++++++++++++------------ telemetry/featureflag_test.go | 43 ++++++++++------------ 7 files changed, 113 insertions(+), 87 deletions(-) diff --git a/auth/oauth/u2m/u2m.go b/auth/oauth/u2m/u2m.go index c0eb615d..81cf12d1 100644 --- a/auth/oauth/u2m/u2m.go +++ b/auth/oauth/u2m/u2m.go @@ -30,11 +30,14 @@ func GetConfig(ctx context.Context, hostName, clientID, clientSecret, callbackUR RedirectURL: callbackURL, Scopes: scopes, } - // Only set ClientSecret if non-empty. For U2M (public apps using PKCE), - // sending an empty client_secret causes the server to reject with - // "Public app should not use a client secret". if clientSecret != "" { config.ClientSecret = clientSecret + } else { + // For U2M (public apps using PKCE), force AuthStyleInParams to avoid + // sending Basic auth with empty password. AuthStyleInHeader sends + // "Authorization: Basic base64(clientID:)" which the server rejects + // with "Public app should not use a client secret". + config.Endpoint.AuthStyle = oauth2.AuthStyleInParams } return config, nil diff --git a/connector.go b/connector.go index f5d33d37..e16eb9c5 100644 --- a/connector.go +++ b/connector.go @@ -6,6 +6,7 @@ import ( "database/sql/driver" "fmt" "net/http" + "net/url" "strings" "time" @@ -76,12 +77,16 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { } log := logger.WithContext(conn.id, driverctx.CorrelationIdFromContext(ctx), "") + // Extract SPOG routing headers from ?o= in HTTPPath + spogHeaders := extractSpogHeaders(c.cfg.HTTPPath) + // Initialize telemetry: client config overlay decides; if unset, feature flags decide conn.telemetry = telemetry.InitializeForConnection( ctx, c.cfg.Host, c.client, c.cfg.EnableTelemetry, + spogHeaders, ) if conn.telemetry != nil { log.Debug().Msg("telemetry initialized for connection") @@ -107,6 +112,7 @@ func NewConnector(options ...ConnOption) (driver.Connector, error) { // config with default options cfg := config.WithDefaults() cfg.DriverVersion = DriverVersion + telemetry.SetDriverVersion(DriverVersion) for _, opt := range options { opt(cfg) @@ -117,6 +123,25 @@ func NewConnector(options ...ConnOption) (driver.Connector, error) { return &connector{cfg: cfg, client: client}, nil } +// extractSpogHeaders extracts ?o= from httpPath and returns +// an x-databricks-org-id header for SPOG routing. +func extractSpogHeaders(httpPath string) map[string]string { + if !strings.Contains(httpPath, "?") { + return nil + } + // Parse query string from httpPath + parts := strings.SplitN(httpPath, "?", 2) + params, err := url.ParseQuery(parts[1]) + if err != nil { + return nil + } + orgID := params.Get("o") + if orgID == "" { + return nil + } + return map[string]string{"x-databricks-org-id": orgID} +} + func withUserConfig(ucfg config.UserConfig) ConnOption { return func(c *config.Config) { c.UserConfig = ucfg diff --git a/telemetry/config.go b/telemetry/config.go index 7bc76d00..a049bf79 100644 --- a/telemetry/config.go +++ b/telemetry/config.go @@ -102,7 +102,7 @@ func ParseTelemetryConfig(params map[string]string) *Config { // // Returns: // - bool: true if telemetry should be enabled, false otherwise -func isTelemetryEnabled(ctx context.Context, cfg *Config, host string, httpClient *http.Client) bool { +func isTelemetryEnabled(ctx context.Context, cfg *Config, host string, httpClient *http.Client, extraHeaders map[string]string) bool { // Priority 1: Client explicitly set (overrides server) if cfg.EnableTelemetry.IsSet() { val, _ := cfg.EnableTelemetry.Get() @@ -111,7 +111,7 @@ func isTelemetryEnabled(ctx context.Context, cfg *Config, host string, httpClien // Priority 2: Check server-side feature flag flagCache := getFeatureFlagCache() - serverEnabled, err := flagCache.isTelemetryEnabled(ctx, host, httpClient) + serverEnabled, err := flagCache.isTelemetryEnabled(ctx, host, httpClient, extraHeaders) if err != nil { // Priority 3: Fail-safe default (disabled) return false diff --git a/telemetry/config_test.go b/telemetry/config_test.go index d5ecdc2b..3b35bcb6 100644 --- a/telemetry/config_test.go +++ b/telemetry/config_test.go @@ -2,7 +2,7 @@ package telemetry import ( "context" - "encoding/json" + "net/http" "net/http/httptest" "testing" @@ -206,12 +206,8 @@ func TestIsTelemetryEnabled_ClientOverrideEnabled(t *testing.T) { // Setup: Create a server that returns disabled server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Server says disabled, but client override should win - resp := map[string]interface{}{ - "flags": map[string]bool{ - "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver": false, - }, - } - _ = json.NewEncoder(w).Encode(resp) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"flags": [{"name": "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver", "value": "false"}]}`)) })) defer server.Close() @@ -228,7 +224,7 @@ func TestIsTelemetryEnabled_ClientOverrideEnabled(t *testing.T) { defer flagCache.releaseContext(server.URL) // Client override should bypass server check - result := isTelemetryEnabled(ctx, cfg, server.URL, httpClient) + result := isTelemetryEnabled(ctx, cfg, server.URL, httpClient, nil) if !result { t.Error("Expected telemetry to be enabled when client explicitly sets enableTelemetry=true, got disabled") @@ -240,12 +236,8 @@ func TestIsTelemetryEnabled_ClientOverrideDisabled(t *testing.T) { // Setup: Create a server that returns enabled server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Server says enabled, but client override should win - resp := map[string]interface{}{ - "flags": map[string]bool{ - "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver": true, - }, - } - _ = json.NewEncoder(w).Encode(resp) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"flags": [{"name": "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver", "value": "true"}]}`)) })) defer server.Close() @@ -261,7 +253,7 @@ func TestIsTelemetryEnabled_ClientOverrideDisabled(t *testing.T) { flagCache.getOrCreateContext(server.URL) defer flagCache.releaseContext(server.URL) - result := isTelemetryEnabled(ctx, cfg, server.URL, httpClient) + result := isTelemetryEnabled(ctx, cfg, server.URL, httpClient, nil) if result { t.Error("Expected telemetry to be disabled when client explicitly sets enableTelemetry=false, got enabled") @@ -272,12 +264,8 @@ func TestIsTelemetryEnabled_ClientOverrideDisabled(t *testing.T) { func TestIsTelemetryEnabled_ServerEnabled(t *testing.T) { // Setup: Create a server that returns enabled server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - resp := map[string]interface{}{ - "flags": map[string]bool{ - "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver": true, - }, - } - _ = json.NewEncoder(w).Encode(resp) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"flags": [{"name": "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver", "value": "true"}]}`)) })) defer server.Close() @@ -293,7 +281,7 @@ func TestIsTelemetryEnabled_ServerEnabled(t *testing.T) { flagCache.getOrCreateContext(server.URL) defer flagCache.releaseContext(server.URL) - result := isTelemetryEnabled(ctx, cfg, server.URL, httpClient) + result := isTelemetryEnabled(ctx, cfg, server.URL, httpClient, nil) if !result { t.Error("Expected telemetry to be enabled when server flag is true, got disabled") @@ -304,12 +292,8 @@ func TestIsTelemetryEnabled_ServerEnabled(t *testing.T) { func TestIsTelemetryEnabled_ServerDisabled(t *testing.T) { // Setup: Create a server that returns disabled server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - resp := map[string]interface{}{ - "flags": map[string]bool{ - "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver": false, - }, - } - _ = json.NewEncoder(w).Encode(resp) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"flags": [{"name": "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver", "value": "false"}]}`)) })) defer server.Close() @@ -325,7 +309,7 @@ func TestIsTelemetryEnabled_ServerDisabled(t *testing.T) { flagCache.getOrCreateContext(server.URL) defer flagCache.releaseContext(server.URL) - result := isTelemetryEnabled(ctx, cfg, server.URL, httpClient) + result := isTelemetryEnabled(ctx, cfg, server.URL, httpClient, nil) if result { t.Error("Expected telemetry to be disabled when server flag is false, got enabled") @@ -340,7 +324,7 @@ func TestIsTelemetryEnabled_FailSafeDefault(t *testing.T) { httpClient := &http.Client{Timeout: 5 * time.Second} // No server available, should default to disabled (fail-safe) - result := isTelemetryEnabled(ctx, cfg, "nonexistent-host", httpClient) + result := isTelemetryEnabled(ctx, cfg, "nonexistent-host", httpClient, nil) if result { t.Error("Expected telemetry to be disabled when server unavailable (fail-safe), got enabled") @@ -367,7 +351,7 @@ func TestIsTelemetryEnabled_ServerError(t *testing.T) { flagCache.getOrCreateContext(server.URL) defer flagCache.releaseContext(server.URL) - result := isTelemetryEnabled(ctx, cfg, server.URL, httpClient) + result := isTelemetryEnabled(ctx, cfg, server.URL, httpClient, nil) // On error, should default to disabled (fail-safe) if result { @@ -390,7 +374,7 @@ func TestIsTelemetryEnabled_ServerUnreachable(t *testing.T) { flagCache.getOrCreateContext(unreachableHost) defer flagCache.releaseContext(unreachableHost) - result := isTelemetryEnabled(ctx, cfg, unreachableHost, httpClient) + result := isTelemetryEnabled(ctx, cfg, unreachableHost, httpClient, nil) // On error, should default to disabled (fail-safe) if result { @@ -418,7 +402,7 @@ func TestIsTelemetryEnabled_ClientOverridesServerError(t *testing.T) { flagCache.getOrCreateContext(server.URL) defer flagCache.releaseContext(server.URL) - result := isTelemetryEnabled(ctx, cfg, server.URL, httpClient) + result := isTelemetryEnabled(ctx, cfg, server.URL, httpClient, nil) // Client override should work even when server errors if !result { diff --git a/telemetry/driver_integration.go b/telemetry/driver_integration.go index 5dcd2f71..4d9de81c 100644 --- a/telemetry/driver_integration.go +++ b/telemetry/driver_integration.go @@ -16,6 +16,7 @@ import ( // - host: Databricks host // - httpClient: HTTP client for making requests // - enableTelemetry: Client config overlay (unset = check server flag, true/false = override server) +// - extraHeaders: Additional HTTP headers for SPOG routing (e.g. x-databricks-org-id) // // Returns: // - *Interceptor: Telemetry interceptor if enabled, nil otherwise @@ -24,13 +25,14 @@ func InitializeForConnection( host string, httpClient *http.Client, enableTelemetry config.ConfigValue[bool], + extraHeaders map[string]string, ) *Interceptor { // Create telemetry config and apply client overlay cfg := DefaultConfig() cfg.EnableTelemetry = enableTelemetry // Check if telemetry should be enabled - if !isTelemetryEnabled(ctx, cfg, host, httpClient) { + if !isTelemetryEnabled(ctx, cfg, host, httpClient, extraHeaders) { return nil } diff --git a/telemetry/featureflag.go b/telemetry/featureflag.go index 6943e455..209c4216 100644 --- a/telemetry/featureflag.go +++ b/telemetry/featureflag.go @@ -24,6 +24,14 @@ const ( // flagEnableNewFeature = "databricks.partnerplatform.clientConfigsFeatureFlags.enableNewFeatureForGoDriver" ) +// driverVersion is set during initialization from the connector package. +var driverVersion = "unknown" + +// SetDriverVersion sets the driver version used in feature flag endpoint paths. +func SetDriverVersion(version string) { + driverVersion = version +} + // featureFlagCache manages feature flag state per host with reference counting. // This prevents rate limiting by caching feature flag responses. type featureFlagCache struct { @@ -90,7 +98,7 @@ func (c *featureFlagCache) releaseContext(host string) { // getFeatureFlag retrieves a specific feature flag value for the host. // This is the generic method that handles caching and fetching for any flag. // Uses cached value if available and not expired. -func (c *featureFlagCache) getFeatureFlag(ctx context.Context, host string, httpClient *http.Client, flagName string) (bool, error) { +func (c *featureFlagCache) getFeatureFlag(ctx context.Context, host string, httpClient *http.Client, flagName string, extraHeaders map[string]string) (bool, error) { c.mu.RLock() flagCtx, exists := c.contexts[host] c.mu.RUnlock() @@ -111,7 +119,7 @@ func (c *featureFlagCache) getFeatureFlag(ctx context.Context, host string, http // If we just created the context, make the initial blocking fetch if !exists { - flags, err := fetchFeatureFlags(ctx, host, httpClient) + flags, err := fetchFeatureFlags(ctx, host, httpClient, extraHeaders) flagCtx.mu.Lock() flagCtx.fetching = false @@ -155,7 +163,7 @@ func (c *featureFlagCache) getFeatureFlag(ctx context.Context, host string, http flagCtx.mu.RUnlock() // Fetch fresh values for all flags - flags, err := fetchFeatureFlags(ctx, host, httpClient) + flags, err := fetchFeatureFlags(ctx, host, httpClient, extraHeaders) // Update cache (with proper locking) flagCtx.mu.Lock() @@ -184,8 +192,8 @@ func (c *featureFlagCache) getFeatureFlag(ctx context.Context, host string, http // isTelemetryEnabled checks if telemetry is enabled for the host. // Uses cached value if available and not expired. -func (c *featureFlagCache) isTelemetryEnabled(ctx context.Context, host string, httpClient *http.Client) (bool, error) { - return c.getFeatureFlag(ctx, host, httpClient, flagEnableTelemetry) +func (c *featureFlagCache) isTelemetryEnabled(ctx context.Context, host string, httpClient *http.Client, extraHeaders map[string]string) (bool, error) { + return c.getFeatureFlag(ctx, host, httpClient, flagEnableTelemetry, extraHeaders) } // isExpired returns true if the cache has expired. @@ -203,9 +211,22 @@ func getAllFeatureFlags() []string { } } -// fetchFeatureFlags fetches multiple feature flag values from Databricks in a single request. +// featureFlagEntry represents a single flag from the connector-service response. +type featureFlagEntry struct { + Name string `json:"name"` + Value string `json:"value"` +} + +// featureFlagResponse represents the response from the connector-service endpoint. +type featureFlagResponse struct { + Flags []featureFlagEntry `json:"flags"` + TTLSeconds int `json:"ttl_seconds,omitempty"` +} + +// fetchFeatureFlags fetches feature flag values from the connector-service endpoint. +// Endpoint: GET /api/2.0/connector-service/feature-flags/{CLIENT_TYPE}/{VERSION} // Returns a map of flag names to their boolean values. -func fetchFeatureFlags(ctx context.Context, host string, httpClient *http.Client) (map[string]bool, error) { +func fetchFeatureFlags(ctx context.Context, host string, httpClient *http.Client, extraHeaders map[string]string) (map[string]bool, error) { // Add timeout to context if it doesn't have a deadline if _, hasDeadline := ctx.Deadline(); !hasDeadline { var cancel context.CancelFunc @@ -213,12 +234,12 @@ func fetchFeatureFlags(ctx context.Context, host string, httpClient *http.Client defer cancel() } - // Construct endpoint URL, adding https:// if not already present + // Construct endpoint URL using the connector-service path var endpoint string if strings.HasPrefix(host, "http://") || strings.HasPrefix(host, "https://") { - endpoint = fmt.Sprintf("%s/api/2.0/feature-flags", host) + endpoint = fmt.Sprintf("%s/api/2.0/connector-service/feature-flags/GOLANG/%s", host, driverVersion) } else { - endpoint = fmt.Sprintf("https://%s/api/2.0/feature-flags", host) + endpoint = fmt.Sprintf("https://%s/api/2.0/connector-service/feature-flags/GOLANG/%s", host, driverVersion) } req, err := http.NewRequestWithContext(ctx, "GET", endpoint, nil) @@ -226,12 +247,10 @@ func fetchFeatureFlags(ctx context.Context, host string, httpClient *http.Client return nil, fmt.Errorf("failed to create feature flag request: %w", err) } - // Add query parameter with comma-separated list of feature flags - // This fetches all flags in a single request for efficiency - allFlags := getAllFeatureFlags() - q := req.URL.Query() - q.Add("flags", strings.Join(allFlags, ",")) - req.URL.RawQuery = q.Encode() + // Add extra headers (e.g. x-databricks-org-id for SPOG routing) + for k, v := range extraHeaders { + req.Header.Set(k, v) + } resp, err := httpClient.Do(req) if err != nil { @@ -245,18 +264,16 @@ func fetchFeatureFlags(ctx context.Context, host string, httpClient *http.Client return nil, fmt.Errorf("feature flag check failed: %d", resp.StatusCode) } - var result struct { - Flags map[string]bool `json:"flags"` - } + var result featureFlagResponse if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { return nil, fmt.Errorf("failed to decode feature flag response: %w", err) } - // Return the full map of flags - // Flags not present in the response will have false value when accessed - if result.Flags == nil { - return make(map[string]bool), nil + // Convert array of {name, value} entries to a map of name -> bool + flags := make(map[string]bool, len(result.Flags)) + for _, f := range result.Flags { + flags[f.Name] = strings.EqualFold(f.Value, "true") } - return result.Flags, nil + return flags, nil } diff --git a/telemetry/featureflag_test.go b/telemetry/featureflag_test.go index b0aa519a..02467adb 100644 --- a/telemetry/featureflag_test.go +++ b/telemetry/featureflag_test.go @@ -100,7 +100,7 @@ func TestFeatureFlagCache_IsTelemetryEnabled_Cached(t *testing.T) { ctx.lastFetched = time.Now() // Should return cached value without HTTP call - result, err := cache.isTelemetryEnabled(context.Background(), host, nil) + result, err := cache.isTelemetryEnabled(context.Background(), host, nil, nil) if err != nil { t.Errorf("Expected no error, got %v", err) } @@ -116,7 +116,7 @@ func TestFeatureFlagCache_IsTelemetryEnabled_Expired(t *testing.T) { callCount++ w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte(`{"flags": {"databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver": true}}`)) + _, _ = w.Write([]byte(`{"flags": [{"name": "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver", "value": "true"}]}`)) })) defer server.Close() @@ -135,7 +135,7 @@ func TestFeatureFlagCache_IsTelemetryEnabled_Expired(t *testing.T) { // Should fetch fresh value httpClient := &http.Client{} - result, err := cache.isTelemetryEnabled(context.Background(), host, httpClient) + result, err := cache.isTelemetryEnabled(context.Background(), host, httpClient, nil) if err != nil { t.Errorf("Expected no error, got %v", err) } @@ -161,7 +161,7 @@ func TestFeatureFlagCache_IsTelemetryEnabled_NoContext(t *testing.T) { // Should return false for non-existent context (network error expected) httpClient := &http.Client{Timeout: 1 * time.Second} - result, err := cache.isTelemetryEnabled(context.Background(), host, httpClient) + result, err := cache.isTelemetryEnabled(context.Background(), host, httpClient, nil) // Error expected due to network failure, but should not panic if result != false { t.Error("Expected false for non-existent context") @@ -192,7 +192,7 @@ func TestFeatureFlagCache_IsTelemetryEnabled_ErrorFallback(t *testing.T) { // Should return cached value on error httpClient := &http.Client{} - result, err := cache.isTelemetryEnabled(context.Background(), host, httpClient) + result, err := cache.isTelemetryEnabled(context.Background(), host, httpClient, nil) if err != nil { t.Errorf("Expected no error (fallback to cache), got %v", err) } @@ -217,7 +217,7 @@ func TestFeatureFlagCache_IsTelemetryEnabled_ErrorNoCache(t *testing.T) { // No cached value, should return error httpClient := &http.Client{} - result, err := cache.isTelemetryEnabled(context.Background(), host, httpClient) + result, err := cache.isTelemetryEnabled(context.Background(), host, httpClient, nil) if err == nil { t.Error("Expected error when no cache available and fetch fails") } @@ -323,27 +323,22 @@ func TestFetchFeatureFlags_Success(t *testing.T) { if r.Method != "GET" { t.Errorf("Expected GET request, got %s", r.Method) } - if r.URL.Path != "/api/2.0/feature-flags" { - t.Errorf("Expected /api/2.0/feature-flags path, got %s", r.URL.Path) + expectedPath := "/api/2.0/connector-service/feature-flags/GOLANG/" + driverVersion + if r.URL.Path != expectedPath { + t.Errorf("Expected %s path, got %s", expectedPath, r.URL.Path) } - flags := r.URL.Query().Get("flags") - expectedFlag := "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver" - if flags != expectedFlag { - t.Errorf("Expected flag query param %s, got %s", expectedFlag, flags) - } - - // Return success response + // Return success response in connector-service format w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte(`{"flags": {"databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver": true}}`)) + _, _ = w.Write([]byte(`{"flags": [{"name": "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver", "value": "true"}]}`)) })) defer server.Close() host := server.URL // Use full URL for testing httpClient := &http.Client{} - flags, err := fetchFeatureFlags(context.Background(), host, httpClient) + flags, err := fetchFeatureFlags(context.Background(), host, httpClient, nil) if err != nil { t.Errorf("Expected no error, got %v", err) } @@ -356,14 +351,14 @@ func TestFetchFeatureFlags_Disabled(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte(`{"flags": {"databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver": false}}`)) + _, _ = w.Write([]byte(`{"flags": [{"name": "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver", "value": "false"}]}`)) })) defer server.Close() host := server.URL // Use full URL for testing httpClient := &http.Client{} - flags, err := fetchFeatureFlags(context.Background(), host, httpClient) + flags, err := fetchFeatureFlags(context.Background(), host, httpClient, nil) if err != nil { t.Errorf("Expected no error, got %v", err) } @@ -376,14 +371,14 @@ func TestFetchFeatureFlags_FlagNotPresent(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte(`{"flags": {}}`)) + _, _ = w.Write([]byte(`{"flags": []}`)) })) defer server.Close() host := server.URL // Use full URL for testing httpClient := &http.Client{} - flags, err := fetchFeatureFlags(context.Background(), host, httpClient) + flags, err := fetchFeatureFlags(context.Background(), host, httpClient, nil) if err != nil { t.Errorf("Expected no error, got %v", err) } @@ -401,7 +396,7 @@ func TestFetchFeatureFlags_HTTPError(t *testing.T) { host := server.URL // Use full URL for testing httpClient := &http.Client{} - _, err := fetchFeatureFlags(context.Background(), host, httpClient) + _, err := fetchFeatureFlags(context.Background(), host, httpClient, nil) if err == nil { t.Error("Expected error for HTTP 500") } @@ -418,7 +413,7 @@ func TestFetchFeatureFlags_InvalidJSON(t *testing.T) { host := server.URL // Use full URL for testing httpClient := &http.Client{} - _, err := fetchFeatureFlags(context.Background(), host, httpClient) + _, err := fetchFeatureFlags(context.Background(), host, httpClient, nil) if err == nil { t.Error("Expected error for invalid JSON") } @@ -437,7 +432,7 @@ func TestFetchFeatureFlags_ContextCancellation(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() // Cancel immediately - _, err := fetchFeatureFlags(ctx, host, httpClient) + _, err := fetchFeatureFlags(ctx, host, httpClient, nil) if err == nil { t.Error("Expected error for cancelled context") } From 6bda425b319e9f9eec67b6aa72d0413bc4aa6783 Mon Sep 17 00:00:00 2001 From: Madhavendra Rathore Date: Tue, 21 Apr 2026 12:46:38 +0530 Subject: [PATCH 3/6] Add debug logging for SPOG x-databricks-org-id header extraction MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Mirrors equivalent logging added to OSS JDBC and pysql. Emits at DEBUG level in three paths of extractSpogHeaders: 1. Malformed query string in httpPath — log and skip. 2. httpPath has "?" but no ?o= param — log and skip. 3. Injection happens — log the extracted workspace ID so customers diagnosing SPOG routing can confirm the header was added. Also adds a detailed docstring explaining the role this header plays: Thrift routing stays URL-driven via ?o= in httpPath; only the separate endpoints (telemetry, feature flags) need the header for account-level routing on SPOG hosts. Helps with customer support: when a customer reports "SPOG isn't routing correctly", they can enable DEBUG logging and immediately see whether the driver saw their ?o= value. Signed-off-by: Madhavendra Rathore Signed-off-by: Madhavendra Rathore --- connector.go | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/connector.go b/connector.go index e16eb9c5..773a6435 100644 --- a/connector.go +++ b/connector.go @@ -125,6 +125,18 @@ func NewConnector(options ...ConnOption) (driver.Connector, error) { // extractSpogHeaders extracts ?o= from httpPath and returns // an x-databricks-org-id header for SPOG routing. +// +// On SPOG (Custom URL) workspaces, httpPath is of the form +// /sql/1.0/warehouses/?o=. The ?o= parameter keeps Thrift +// requests routed to the correct workspace via the URL itself, but other +// endpoints (telemetry, feature flags) run on separate hosts and need the +// x-databricks-org-id header. This function extracts ?o= from httpPath once +// and returns it so those paths can inject it as an HTTP header. +// +// Returns nil if: +// - httpPath has no query string ("?"), or +// - the query string is malformed and can't be parsed, or +// - the ?o= parameter is missing or empty. func extractSpogHeaders(httpPath string) map[string]string { if !strings.Contains(httpPath, "?") { return nil @@ -133,12 +145,21 @@ func extractSpogHeaders(httpPath string) map[string]string { parts := strings.SplitN(httpPath, "?", 2) params, err := url.ParseQuery(parts[1]) if err != nil { + logger.Debug().Msgf( + "SPOG header extraction: malformed query string in httpPath, skipping org-id extraction: %s", + err) return nil } orgID := params.Get("o") if orgID == "" { + logger.Debug().Msg( + "SPOG header extraction: httpPath has query string but no ?o= param, " + + "skipping x-databricks-org-id injection") return nil } + logger.Debug().Msgf( + "SPOG header extraction: injecting x-databricks-org-id=%s (extracted from ?o= in httpPath)", + orgID) return map[string]string{"x-databricks-org-id": orgID} } From bc4e923345a023ac28ffe565c25a6bc304fd44c5 Mon Sep 17 00:00:00 2001 From: Madhavendra Rathore Date: Tue, 21 Apr 2026 13:13:22 +0530 Subject: [PATCH 4/6] Refactor SPOG header injection: use transport wrapper, drop telemetry API churn MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace the per-function extraHeaders parameter threading (added during the previous merge with main) with a minimal http.Client wrapper in the top-level dbsql package. Before (5 files, including 3 in telemetry/*): connector.go extracted ?o=, then passed it as an ExtraHeaders field through telemetry.TelemetryInitOptions → isTelemetryEnabled → featureFlagCache.isTelemetryEnabled → fetchFeatureFlag, where it was applied to the outbound request. The telemetry push path (telemetry/exporter.go) was NOT covered. After (2 files — connector.go and auth/oauth/u2m/u2m.go): connector.go extracts ?o= as before, but now wraps the driver's *http.Client with a headerInjectingTransport that sets the SPOG header on every outbound request through that client. Passes the wrapped client (not c.client) into TelemetryInitOptions.HTTPClient. Advantages: - telemetry/*.go files revert to identical-to-main. No API churn. - Both feature-flag and telemetry-push paths automatically get the SPOG header (previously only feature-flag did). - Future HTTP paths that reuse the telemetry http.Client inherit SPOG routing for free. Thrift is unaffected: it uses c.client directly (not the wrapper) and routes via ?o= in the URL path. The transport wrapper is only applied to the HTTP client handed to telemetry. Signed-off-by: Madhavendra Rathore Signed-off-by: Madhavendra Rathore --- connector.go | 59 ++++++++++++++++++++++++++++++--- telemetry/config.go | 4 +-- telemetry/driver_integration.go | 10 +----- telemetry/featureflag.go | 11 ++---- 4 files changed, 61 insertions(+), 23 deletions(-) diff --git a/connector.go b/connector.go index e6effe99..d97196db 100644 --- a/connector.go +++ b/connector.go @@ -81,20 +81,26 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { } log := logger.WithContext(conn.id, driverctx.CorrelationIdFromContext(ctx), "") - // Extract SPOG routing headers from ?o= in HTTPPath - spogHeaders := extractSpogHeaders(c.cfg.HTTPPath) + // Extract SPOG routing headers from ?o= in HTTPPath. When ?o= + // is present (Custom URL / SPOG hosts), wrap the HTTP client used for + // telemetry + feature-flag calls with a transport that injects + // x-databricks-org-id. Thrift routes via the URL so its own c.client + // doesn't need wrapping. + telemetryClient := c.client + if spogHeaders := extractSpogHeaders(c.cfg.HTTPPath); len(spogHeaders) > 0 { + telemetryClient = withSpogHeaders(c.client, spogHeaders) + } // Initialize telemetry: client config overlay decides; if unset, feature flags decide conn.telemetry = telemetry.InitializeForConnection(ctx, telemetry.TelemetryInitOptions{ Host: c.cfg.Host, DriverVersion: c.cfg.DriverVersion, - HTTPClient: c.client, + HTTPClient: telemetryClient, EnableTelemetry: c.cfg.EnableTelemetry, BatchSize: c.cfg.TelemetryBatchSize, FlushInterval: c.cfg.TelemetryFlushInterval, RetryCount: c.cfg.TelemetryRetryCount, RetryDelay: c.cfg.TelemetryRetryDelay, - ExtraHeaders: spogHeaders, }) if conn.telemetry != nil { log.Debug().Msg("telemetry initialized for connection") @@ -171,6 +177,51 @@ func extractSpogHeaders(httpPath string) map[string]string { return map[string]string{"x-databricks-org-id": orgID} } +// withSpogHeaders returns a new *http.Client that reuses the transport of the +// provided client, wrapped to inject the given SPOG headers on every outbound +// request. The original client is left unchanged. If a request already has a +// given header set (e.g., the caller set it explicitly), the wrapper does not +// override it. +// +// This is how the driver gets x-databricks-org-id onto both the feature-flag +// check and the telemetry push without touching the telemetry package's +// signatures. +func withSpogHeaders(base *http.Client, headers map[string]string) *http.Client { + baseTransport := base.Transport + if baseTransport == nil { + baseTransport = http.DefaultTransport + } + return &http.Client{ + Transport: &headerInjectingTransport{ + base: baseTransport, + headers: headers, + }, + CheckRedirect: base.CheckRedirect, + Jar: base.Jar, + Timeout: base.Timeout, + } +} + +// headerInjectingTransport wraps an http.RoundTripper and sets a fixed set of +// headers on every outbound request. Caller-supplied headers with the same +// name are not overridden. +type headerInjectingTransport struct { + base http.RoundTripper + headers map[string]string +} + +// RoundTrip implements http.RoundTripper. +func (t *headerInjectingTransport) RoundTrip(req *http.Request) (*http.Response, error) { + // Clone per RoundTripper contract — must not mutate the caller's request. + req2 := req.Clone(req.Context()) + for k, v := range t.headers { + if req2.Header.Get(k) == "" { + req2.Header.Set(k, v) + } + } + return t.base.RoundTrip(req2) +} + func withUserConfig(ucfg config.UserConfig) ConnOption { return func(c *config.Config) { c.UserConfig = ucfg diff --git a/telemetry/config.go b/telemetry/config.go index 4138aa1e..9054cb36 100644 --- a/telemetry/config.go +++ b/telemetry/config.go @@ -126,12 +126,12 @@ func ParseTelemetryConfig(params map[string]string) *Config { // (databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver). // // In all other cases — explicit opt-out or server flag absent/unreachable — returns false. -func isTelemetryEnabled(ctx context.Context, cfg *Config, host string, driverVersion string, httpClient *http.Client, extraHeaders map[string]string) bool { +func isTelemetryEnabled(ctx context.Context, cfg *Config, host string, driverVersion string, httpClient *http.Client) bool { if cfg.EnableTelemetry != nil { return *cfg.EnableTelemetry } - serverEnabled, err := getFeatureFlagCache().isTelemetryEnabled(ctx, host, driverVersion, httpClient, extraHeaders) + serverEnabled, err := getFeatureFlagCache().isTelemetryEnabled(ctx, host, driverVersion, httpClient) if err != nil { return false } diff --git a/telemetry/driver_integration.go b/telemetry/driver_integration.go index f46271cf..f8b32ffc 100644 --- a/telemetry/driver_integration.go +++ b/telemetry/driver_integration.go @@ -41,14 +41,6 @@ type TelemetryInitOptions struct { // RetryDelay is the base delay between retries (0 = use default 100ms). RetryDelay time.Duration - - // ExtraHeaders are additional HTTP headers to attach to feature-flag - // check requests and telemetry-push requests. Primarily used to carry - // x-databricks-org-id for SPOG (Custom URL) workspace routing — see - // extractSpogHeaders in the top-level dbsql package. - // - // May be nil. - ExtraHeaders map[string]string } // InitializeForConnection initializes telemetry for a database connection. @@ -86,7 +78,7 @@ func InitializeForConnection(ctx context.Context, opts TelemetryInitOptions) *In flagCache.getOrCreateContext(opts.Host) // Check if telemetry should be enabled - enabled := isTelemetryEnabled(ctx, cfg, opts.Host, opts.DriverVersion, opts.HTTPClient, opts.ExtraHeaders) + enabled := isTelemetryEnabled(ctx, cfg, opts.Host, opts.DriverVersion, opts.HTTPClient) if !enabled { flagCache.releaseContext(opts.Host) return nil diff --git a/telemetry/featureflag.go b/telemetry/featureflag.go index 228918b8..81696baa 100644 --- a/telemetry/featureflag.go +++ b/telemetry/featureflag.go @@ -86,7 +86,7 @@ func (c *featureFlagCache) releaseContext(host string) { // isTelemetryEnabled checks if telemetry is enabled for the host. // Uses cached value if available and not expired. -func (c *featureFlagCache) isTelemetryEnabled(ctx context.Context, host string, driverVersion string, httpClient *http.Client, extraHeaders map[string]string) (bool, error) { +func (c *featureFlagCache) isTelemetryEnabled(ctx context.Context, host string, driverVersion string, httpClient *http.Client) (bool, error) { c.mu.RLock() flagCtx, exists := c.contexts[host] c.mu.RUnlock() @@ -135,7 +135,7 @@ func (c *featureFlagCache) isTelemetryEnabled(ctx context.Context, host string, flagCtx.mu.Unlock() // Fetch fresh value (outside lock so other readers are not blocked). - enabled, err := fetchFeatureFlag(ctx, host, driverVersion, httpClient, extraHeaders) + enabled, err := fetchFeatureFlag(ctx, host, driverVersion, httpClient) // Update cache. flagCtx.mu.Lock() @@ -166,7 +166,7 @@ func (c *featureFlagContext) isExpired() bool { } // fetchFeatureFlag fetches the feature flag value from Databricks. -func fetchFeatureFlag(ctx context.Context, host string, driverVersion string, httpClient *http.Client, extraHeaders map[string]string) (bool, error) { +func fetchFeatureFlag(ctx context.Context, host string, driverVersion string, httpClient *http.Client) (bool, error) { // Add timeout to context if it doesn't have a deadline if _, hasDeadline := ctx.Deadline(); !hasDeadline { var cancel context.CancelFunc @@ -183,11 +183,6 @@ func fetchFeatureFlag(ctx context.Context, host string, driverVersion string, ht return false, fmt.Errorf("failed to create feature flag request: %w", err) } - // Attach extra headers (e.g. x-databricks-org-id for SPOG routing). - for k, v := range extraHeaders { - req.Header.Set(k, v) - } - resp, err := httpClient.Do(req) if err != nil { return false, fmt.Errorf("failed to fetch feature flag: %w", err) From 3576c92c33d4278fbcad33c43451de21d3703268 Mon Sep 17 00:00:00 2001 From: Madhavendra Rathore Date: Tue, 21 Apr 2026 13:17:04 +0530 Subject: [PATCH 5/6] =?UTF-8?q?Revert=20u2m.go=20client=5Fsecret=20handlin?= =?UTF-8?q?g=20=E2=80=94=20not=20needed?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Earlier commits in this branch (23697e5, 0ec7e06) modified u2m.go to avoid sending an empty client_secret on the PKCE public-app flow, citing server rejection with "Public app should not use a client secret". Empirical verification (2026-04-21): - Prod Legacy (adb-6436897454825492.12.azuredatabricks.net): PASS with unpatched u2m.go — server accepts the request. - Stg Legacy (adb-7064161269814046.2.staging.azuredatabricks.net): FAIL with 400 Bad Request on unpatched u2m.go. Since the production server tolerates the current behavior, the patch isn't strictly required for customers. Reverting to keep the PR minimal and matching upstream main exactly for this file. If staging server strictness later rolls out to prod, we can re-add this fix then. Signed-off-by: Madhavendra Rathore Signed-off-by: Madhavendra Rathore --- auth/oauth/u2m/u2m.go | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/auth/oauth/u2m/u2m.go b/auth/oauth/u2m/u2m.go index 81cf12d1..456e369a 100644 --- a/auth/oauth/u2m/u2m.go +++ b/auth/oauth/u2m/u2m.go @@ -25,19 +25,11 @@ func GetConfig(ctx context.Context, hostName, clientID, clientSecret, callbackUR } config := oauth2.Config{ - ClientID: clientID, - Endpoint: endpoint, - RedirectURL: callbackURL, - Scopes: scopes, - } - if clientSecret != "" { - config.ClientSecret = clientSecret - } else { - // For U2M (public apps using PKCE), force AuthStyleInParams to avoid - // sending Basic auth with empty password. AuthStyleInHeader sends - // "Authorization: Basic base64(clientID:)" which the server rejects - // with "Public app should not use a client secret". - config.Endpoint.AuthStyle = oauth2.AuthStyleInParams + ClientID: clientID, + ClientSecret: clientSecret, + Endpoint: endpoint, + RedirectURL: callbackURL, + Scopes: scopes, } return config, nil From 3d8c1593acd69273d444d2b1366a6827bdce1cd9 Mon Sep 17 00:00:00 2001 From: Madhavendra Rathore Date: Tue, 21 Apr 2026 13:30:55 +0530 Subject: [PATCH 6/6] Add unit tests for SPOG header extraction and injection Covers extractSpogHeaders (8 cases: missing/empty query, valid o=, missing o=, empty value, multi-param, duplicate o=, bare `?`) and headerInjectingTransport (injection, caller-set not overridden, other headers preserved, original client untouched). Signed-off-by: Madhavendra Rathore --- connector_spog_test.go | 164 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 164 insertions(+) create mode 100644 connector_spog_test.go diff --git a/connector_spog_test.go b/connector_spog_test.go new file mode 100644 index 00000000..69273ee4 --- /dev/null +++ b/connector_spog_test.go @@ -0,0 +1,164 @@ +package dbsql + +import ( + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestExtractSpogHeaders(t *testing.T) { + tests := []struct { + name string + httpPath string + want map[string]string + }{ + { + name: "no query string returns nil", + httpPath: "/sql/1.0/warehouses/abc123", + want: nil, + }, + { + name: "empty httpPath returns nil", + httpPath: "", + want: nil, + }, + { + name: "query string with o= extracts org id", + httpPath: "/sql/1.0/warehouses/abc123?o=7064161269814046", + want: map[string]string{"x-databricks-org-id": "7064161269814046"}, + }, + { + name: "query string without o= returns nil", + httpPath: "/sql/1.0/warehouses/abc123?other=val", + want: nil, + }, + { + name: "empty o= value returns nil", + httpPath: "/sql/1.0/warehouses/abc123?o=", + want: nil, + }, + { + name: "o= among multiple params extracts correctly", + httpPath: "/sql/1.0/warehouses/abc?foo=1&o=12345&bar=2", + want: map[string]string{"x-databricks-org-id": "12345"}, + }, + { + name: "first o= wins when duplicated", + httpPath: "/sql/1.0/warehouses/abc?o=first&o=second", + want: map[string]string{"x-databricks-org-id": "first"}, + }, + { + name: "just ? with nothing after returns nil", + httpPath: "/sql/1.0/warehouses/abc?", + want: nil, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := extractSpogHeaders(tc.httpPath) + assert.Equal(t, tc.want, got) + }) + } +} + +func TestHeaderInjectingTransport_InjectsHeader(t *testing.T) { + var gotHeader string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotHeader = r.Header.Get("x-databricks-org-id") + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + client := withSpogHeaders(&http.Client{}, map[string]string{ + "x-databricks-org-id": "7064161269814046", + }) + + req, err := http.NewRequest("GET", srv.URL, nil) + require.NoError(t, err) + resp, err := client.Do(req) + require.NoError(t, err) + _, _ = io.Copy(io.Discard, resp.Body) + _ = resp.Body.Close() + + assert.Equal(t, "7064161269814046", gotHeader, "SPOG header should be injected") +} + +func TestHeaderInjectingTransport_DoesNotOverrideCallerSet(t *testing.T) { + var gotHeader string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotHeader = r.Header.Get("x-databricks-org-id") + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + client := withSpogHeaders(&http.Client{}, map[string]string{ + "x-databricks-org-id": "from-wrapper", + }) + + req, err := http.NewRequest("GET", srv.URL, nil) + require.NoError(t, err) + req.Header.Set("x-databricks-org-id", "from-caller") + resp, err := client.Do(req) + require.NoError(t, err) + _, _ = io.Copy(io.Discard, resp.Body) + _ = resp.Body.Close() + + assert.Equal(t, "from-caller", gotHeader, "caller-set header must not be overridden") +} + +func TestHeaderInjectingTransport_PreservesOtherHeaders(t *testing.T) { + var gotAuth, gotSpog, gotCustom string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAuth = r.Header.Get("Authorization") + gotSpog = r.Header.Get("x-databricks-org-id") + gotCustom = r.Header.Get("X-Custom") + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + client := withSpogHeaders(&http.Client{}, map[string]string{ + "x-databricks-org-id": "abc", + }) + + req, err := http.NewRequest("GET", srv.URL, nil) + require.NoError(t, err) + req.Header.Set("Authorization", "Bearer xxx") + req.Header.Set("X-Custom", "hello") + resp, err := client.Do(req) + require.NoError(t, err) + _, _ = io.Copy(io.Discard, resp.Body) + _ = resp.Body.Close() + + assert.Equal(t, "Bearer xxx", gotAuth) + assert.Equal(t, "hello", gotCustom) + assert.Equal(t, "abc", gotSpog) +} + +func TestWithSpogHeaders_OriginalClientUntouched(t *testing.T) { + originalTransport := &countingTransport{} + original := &http.Client{Transport: originalTransport} + + wrapped := withSpogHeaders(original, map[string]string{"x-databricks-org-id": "x"}) + + // Original client's transport should NOT be the wrapper type. + _, isWrapped := original.Transport.(*headerInjectingTransport) + assert.False(t, isWrapped, "original client's transport must not be mutated") + + // Wrapped client MUST have the wrapper. + _, isWrapped = wrapped.Transport.(*headerInjectingTransport) + assert.True(t, isWrapped, "wrapped client must have headerInjectingTransport") +} + +type countingTransport struct { + count int +} + +func (c *countingTransport) RoundTrip(req *http.Request) (*http.Response, error) { + c.count++ + return &http.Response{StatusCode: 200, Body: io.NopCloser(strings.NewReader(""))}, nil +}