diff --git a/internal/auth0/connection.go b/internal/auth0/connection.go index c0b553fbb..cecc24418 100644 --- a/internal/auth0/connection.go +++ b/internal/auth0/connection.go @@ -27,4 +27,7 @@ type ConnectionAPI interface { // List all connections. List(ctx context.Context, opts ...management.RequestOption) (ul *management.ConnectionList, err error) + + // ReadEnabledClients retrieves the enabled clients for a connection. + ReadEnabledClients(ctx context.Context, id string, opts ...management.RequestOption) (c *management.ConnectionEnabledClientList, err error) } diff --git a/internal/auth0/mock/connection_mock.go b/internal/auth0/mock/connection_mock.go index 0060253aa..c63d4fb83 100644 --- a/internal/auth0/mock/connection_mock.go +++ b/internal/auth0/mock/connection_mock.go @@ -133,6 +133,26 @@ func (mr *MockConnectionAPIMockRecorder) ReadByName(ctx, id interface{}, opts .. return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadByName", reflect.TypeOf((*MockConnectionAPI)(nil).ReadByName), varargs...) } +// ReadEnabledClients mocks base method. +func (m *MockConnectionAPI) ReadEnabledClients(ctx context.Context, id string, opts ...management.RequestOption) (*management.ConnectionEnabledClientList, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, id} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "ReadEnabledClients", varargs...) + ret0, _ := ret[0].(*management.ConnectionEnabledClientList) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadEnabledClients indicates an expected call of ReadEnabledClients. +func (mr *MockConnectionAPIMockRecorder) ReadEnabledClients(ctx, id interface{}, opts ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, id}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadEnabledClients", reflect.TypeOf((*MockConnectionAPI)(nil).ReadEnabledClients), varargs...) +} + // Update mocks base method. func (m *MockConnectionAPI) Update(ctx context.Context, id string, c *management.Connection, opts ...management.RequestOption) error { m.ctrl.T.Helper() diff --git a/internal/cli/users.go b/internal/cli/users.go index 6284369d6..fb0955f95 100644 --- a/internal/cli/users.go +++ b/internal/cli/users.go @@ -404,7 +404,11 @@ func createUserCmd(cli *cli) *cobra.Command { return fmt.Errorf("failed to find connection with name %q: %w", inputs.connectionName, err) } - if len(connection.GetEnabledClients()) == 0 { + hasClients, err := connectionHasEnabledClients(cmd.Context(), cli.api.Connection, connection.GetID()) + if err != nil { + return fmt.Errorf("failed to check enabled clients for connection %q: %w", inputs.connectionName, err) + } + if !hasClients { return fmt.Errorf( "failed to continue due to the connection with name %q being disabled, enable an application on this connection and try again", inputs.connectionName, @@ -894,15 +898,19 @@ The file size limit for a bulk import is 500KB. You will need to start multiple return fmt.Errorf("failed to read connection with name %q: %w", inputs.ConnectionName, err) } - if len(connection.GetEnabledClients()) == 0 { + inputs.ConnectionID = connection.GetID() + + hasClients, err := connectionHasEnabledClients(cmd.Context(), cli.api.Connection, inputs.ConnectionID) + if err != nil { + return fmt.Errorf("failed to check enabled clients for connection %q: %w", inputs.ConnectionName, err) + } + if !hasClients { return fmt.Errorf( "failed to continue due to the connection with name %q being disabled, enable an application on this connection and try again", inputs.ConnectionName, ) } - inputs.ConnectionID = connection.GetID() - pipedUsersBody := iostream.PipedInput() if len(pipedUsersBody) > 0 && inputs.UsersBody == "" { inputs.UsersBody = string(pipedUsersBody) @@ -996,7 +1004,11 @@ func (c *cli) databaseAndPasswordlessConnectionOptions(ctx context.Context) ([]s var connectionNames []string for _, connection := range connectionList.Connections { - if len(connection.GetEnabledClients()) == 0 { + hasClients, err := connectionHasEnabledClients(ctx, c.api.Connection, connection.GetID()) + if err != nil { + continue + } + if !hasClients { continue } @@ -1010,6 +1022,17 @@ func (c *cli) databaseAndPasswordlessConnectionOptions(ctx context.Context) ([]s return connectionNames, nil } +// connectionHasEnabledClients checks if a connection has any enabled clients +// using the dedicated endpoint (replaces deprecated enabled_clients field). +func connectionHasEnabledClients(ctx context.Context, api auth0.ConnectionAPI, connectionID string) (bool, error) { + clients, err := api.ReadEnabledClients(ctx, connectionID) + if err != nil { + return false, err + } + + return clients.Clients != nil && len(*clients.Clients) > 0, nil +} + func (c *cli) getUserConnection(users *management.User) []string { var res []string for _, i := range users.Identities { diff --git a/internal/cli/users_test.go b/internal/cli/users_test.go index c66f21709..c1fe74dea 100644 --- a/internal/cli/users_test.go +++ b/internal/cli/users_test.go @@ -16,35 +16,26 @@ import ( func TestConnectionsPickerOptions(t *testing.T) { tests := []struct { - name string - connections []*management.Connection - apiError error - assertOutput func(t testing.TB, options []string) - assertError func(t testing.TB, err error) + name string + connections []*management.Connection + enabledClients map[string]*management.ConnectionEnabledClientList // Keyed by connection ID. + apiError error + assertOutput func(t testing.TB, options []string) + assertError func(t testing.TB, err error) }{ { name: "happy path", connections: []*management.Connection{ - { - Name: auth0.String("some-name-1"), - Strategy: auth0.String("auth0"), - EnabledClients: &[]string{"1"}, - }, - { - Name: auth0.String("some-name-2"), - Strategy: auth0.String("auth0"), - EnabledClients: &[]string{"1"}, - }, - { - Name: auth0.String("some-name-3"), - Strategy: auth0.String("sms"), - EnabledClients: &[]string{"1"}, - }, - { - Name: auth0.String("some-name-4"), - Strategy: auth0.String("email"), - EnabledClients: &[]string{"1"}, - }, + {ID: auth0.String("conn-1"), Name: auth0.String("some-name-1"), Strategy: auth0.String("auth0")}, + {ID: auth0.String("conn-2"), Name: auth0.String("some-name-2"), Strategy: auth0.String("auth0")}, + {ID: auth0.String("conn-3"), Name: auth0.String("some-name-3"), Strategy: auth0.String("sms")}, + {ID: auth0.String("conn-4"), Name: auth0.String("some-name-4"), Strategy: auth0.String("email")}, + }, + enabledClients: map[string]*management.ConnectionEnabledClientList{ + "conn-1": {Clients: &[]management.ConnectionEnabledClient{{ClientID: auth0.String("app-1")}}}, + "conn-2": {Clients: &[]management.ConnectionEnabledClient{{ClientID: auth0.String("app-1")}}}, + "conn-3": {Clients: &[]management.ConnectionEnabledClient{{ClientID: auth0.String("app-1")}}}, + "conn-4": {Clients: &[]management.ConnectionEnabledClient{{ClientID: auth0.String("app-1")}}}, }, assertOutput: func(t testing.TB, options []string) { assert.Len(t, options, 4) @@ -60,24 +51,16 @@ func TestConnectionsPickerOptions(t *testing.T) { { name: "happy path: returning only active connections", connections: []*management.Connection{ - { - Name: auth0.String("some-name-1"), - Strategy: auth0.String("auth0"), - EnabledClients: &[]string{"1"}, - }, - { - Name: auth0.String("some-name-2"), - Strategy: auth0.String("auth0"), - EnabledClients: &[]string{"1"}, - }, - { - Name: auth0.String("some-name-3"), - Strategy: auth0.String("sms"), - }, - { - Name: auth0.String("some-name-4"), - Strategy: auth0.String("email"), - }, + {ID: auth0.String("conn-1"), Name: auth0.String("some-name-1"), Strategy: auth0.String("auth0")}, + {ID: auth0.String("conn-2"), Name: auth0.String("some-name-2"), Strategy: auth0.String("auth0")}, + {ID: auth0.String("conn-3"), Name: auth0.String("some-name-3"), Strategy: auth0.String("sms")}, + {ID: auth0.String("conn-4"), Name: auth0.String("some-name-4"), Strategy: auth0.String("email")}, + }, + enabledClients: map[string]*management.ConnectionEnabledClientList{ + "conn-1": {Clients: &[]management.ConnectionEnabledClient{{ClientID: auth0.String("app-1")}}}, + "conn-2": {Clients: &[]management.ConnectionEnabledClient{{ClientID: auth0.String("app-1")}}}, + "conn-3": {Clients: &[]management.ConnectionEnabledClient{}}, + "conn-4": {Clients: &[]management.ConnectionEnabledClient{}}, }, assertOutput: func(t testing.TB, options []string) { assert.Len(t, options, 2) @@ -127,6 +110,18 @@ func TestConnectionsPickerOptions(t *testing.T) { test.apiError, ) + // Set up ReadEnabledClients expectations for each connection. + if test.enabledClients != nil { + for _, conn := range test.connections { + id := conn.GetID() + if clients, ok := test.enabledClients[id]; ok { + connectionAPI.EXPECT(). + ReadEnabledClients(ctx, id). + Return(clients, nil) + } + } + } + cli := &cli{ api: &auth0.API{ Connection: connectionAPI,