diff --git a/cli/app.go b/cli/app.go index ba15a22ca00..c211cf0c6cf 100644 --- a/cli/app.go +++ b/cli/app.go @@ -8,6 +8,7 @@ import ( "strings" "github.com/urfave/cli/v2" + apppb "go.viam.com/api/app/v1" ) // CLI flags. @@ -124,12 +125,24 @@ const ( packageMetadataFlagFramework = "model_framework" + // TODO: APP-6993 remove these flags. authApplicationFlagName = "application-name" authApplicationFlagApplicationID = "application-id" authApplicationFlagOriginURIs = "origin-uris" authApplicationFlagRedirectURIs = "redirect-uris" authApplicationFlagLogoutURI = "logout-uri" + oauthAppFlagClientID = "client-id" + oauthAppFlagClientName = "client-name" + oauthAppFlagClientAuthentication = "client-authentication" + oauthAppFlagPKCE = "pkce" + oauthAppFlagEnabledGrants = "enabled-grants" + oauthAppFlagURLValidation = "url-validation" + oauthAppFlagOriginURIs = "origin-uris" + oauthAppFlagRedirectURIs = "redirect-uris" + oauthAppFlagLogoutURI = "logout-uri" + unspecified = "unspecified" + cpFlagRecursive = "recursive" cpFlagPreserve = "preserve" @@ -424,6 +437,82 @@ var app = &cli.App{ Usage: "work with organizations", HideHelpCommand: true, Subcommands: []*cli.Command{ + { + Name: "auth-service", + Usage: "manage auth-service", + Subcommands: []*cli.Command{ + { + Name: "oauth-app", + Usage: "manage the OAuth applications for an organization", + Subcommands: []*cli.Command{ + { + Name: "update", + Usage: "update an OAuth application", + Flags: []cli.Flag{ + &cli.StringFlag{ + Name: generalFlagOrgID, + Required: true, + Usage: "organization ID that is tied to the OAuth application", + }, + &cli.StringFlag{ + Name: oauthAppFlagClientID, + Usage: "id for the OAuth application to be updated", + Required: true, + }, + &cli.StringFlag{ + Name: oauthAppFlagClientName, + Usage: "updated name for the OAuth application", + Required: false, + }, + &cli.StringFlag{ + Name: oauthAppFlagClientAuthentication, + Usage: "updated client authentication policy for the OAuth application. can be one of " + + allEnumValues(clientAuthenticationPrefix, apppb.ClientAuthentication_value), + Required: false, + Value: unspecified, + }, + &cli.StringFlag{ + Name: oauthAppFlagURLValidation, + Usage: "updated url validation for the OAuth application. can be one of " + + allEnumValues(urlValidationPrefix, apppb.URLValidation_value), + Required: false, + Value: unspecified, + }, + &cli.StringFlag{ + Name: oauthAppFlagPKCE, + Usage: "updated pkce for the OAuth application. can be one of " + + allEnumValues(pkcePrefix, apppb.PKCE_value), + Required: false, + Value: unspecified, + }, + &cli.StringSliceFlag{ + Name: oauthAppFlagOriginURIs, + Usage: "updated comma separated origin uris for the OAuth application", + Required: false, + }, + &cli.StringSliceFlag{ + Name: oauthAppFlagRedirectURIs, + Usage: "updated comma separated redirect uris for the OAuth application", + Required: false, + }, + &cli.StringFlag{ + Name: oauthAppFlagLogoutURI, + Usage: "updated logout uri for the OAuth application", + Required: false, + }, + &cli.StringSliceFlag{ + Name: oauthAppFlagEnabledGrants, + Usage: "updated comma separated enabled grants for the OAuth application. values can be of " + + allEnumValues(enabledGrantPrefix, apppb.EnabledGrant_value), + Required: false, + }, + }, + Action: createCommandWithT[updateOAuthAppArgs](UpdateOAuthAppAction), + }, + }, + }, + }, + }, { Name: "list", Usage: "list organizations for the current user", diff --git a/cli/client.go b/cli/client.go index 6ffc908bddc..30e8d5d20c4 100644 --- a/cli/client.go +++ b/cli/client.go @@ -14,6 +14,7 @@ import ( "os/signal" "path/filepath" "runtime/debug" + "slices" "strings" "time" @@ -2050,3 +2051,135 @@ func logEntryFieldsToString(fields []*structpb.Struct) (string, error) { } return message + "}", nil } + +type updateOAuthAppArgs struct { + OrgID string + ClientID string + ClientName string + ClientAuthentication string + Pkce string + LogoutURI string + UrlValidation string //nolint:revive,stylecheck + OriginURIs []string + RedirectURIs []string + EnabledGrants []string +} + +const ( + clientAuthenticationPrefix = "CLIENT_AUTHENTICATION_" + pkcePrefix = "PKCE_" + urlValidationPrefix = "URL_VALIDATION_" + enabledGrantPrefix = "ENABLED_GRANT_" +) + +// allEnumValues returns the possible values we accept for a given proto enum. +func allEnumValues(prefixToTrim string, enumValueMap map[string]int32) string { + var formattedValues []string + for values := range enumValueMap { + formattedValue := strings.ToLower(strings.TrimPrefix(values, prefixToTrim)) + formattedValues = append(formattedValues, formattedValue) + } + slices.Sort(formattedValues) + return "[" + strings.Join(formattedValues, ", ") + "]" +} + +// UpdateOAuthAppAction is the corresponding action for 'oauth-app update'. +func UpdateOAuthAppAction(c *cli.Context, args updateOAuthAppArgs) error { + client, err := newViamClient(c) + if err != nil { + return err + } + + return client.updateOAuthAppAction(c, args) +} + +func (c *viamClient) updateOAuthAppAction(cCtx *cli.Context, args updateOAuthAppArgs) error { + if err := c.ensureLoggedIn(); err != nil { + return err + } + + req, err := createUpdateOAuthAppRequest(args) + if err != nil { + return err + } + + _, err = c.client.UpdateOAuthApp(c.c.Context, req) + if err != nil { + return err + } + + printf(cCtx.App.Writer, "Successfully updated OAuth app %s", args.ClientID) + return nil +} + +func createUpdateOAuthAppRequest(args updateOAuthAppArgs) (*apppb.UpdateOAuthAppRequest, error) { + orgID := args.OrgID + clientID := args.ClientID + clientName := args.ClientName + clientAuthentication := args.ClientAuthentication + urlValidation := args.UrlValidation + pkce := args.Pkce + originURIs := args.OriginURIs + redirectURIs := args.RedirectURIs + logoutURI := args.LogoutURI + enabledGrants := args.EnabledGrants + + clientAuthenticationEnum, ok := apppb.ClientAuthentication_value[clientAuthenticationPrefix+strings.ToUpper(clientAuthentication)] + if !ok { + return nil, errors.Errorf("--%s must be a valid ClientAuthentication, got %s. "+ + "See `viam organizations auth-service oauth-app update --help` for supported options", + oauthAppFlagClientAuthentication, clientAuthentication) + } + + pkceEnum, ok := apppb.PKCE_value[pkcePrefix+strings.ToUpper(pkce)] + if !ok { + return nil, errors.Errorf("--%s must be a valid PKCE, got %s. "+ + "See `viam organizations auth-service oauth-app update --help` for supported options", + oauthAppFlagPKCE, pkce) + } + + urlValidationEnum, ok := apppb.URLValidation_value[urlValidationPrefix+strings.ToUpper(urlValidation)] + if !ok { + return nil, errors.Errorf("--%s must be a valid UrlValidation, got %s. "+ + "See `viam organizations auth-service oauth-app update --help` for supported options", + oauthAppFlagURLValidation, urlValidation) + } + + egProto, err := enabledGrantsToProto(enabledGrants) + if err != nil { + return nil, err + } + + req := &apppb.UpdateOAuthAppRequest{ + OrgId: orgID, + ClientId: clientID, + ClientName: clientName, + OauthConfig: &apppb.OAuthConfig{ + ClientAuthentication: apppb.ClientAuthentication(clientAuthenticationEnum), + Pkce: apppb.PKCE(pkceEnum), + UrlValidation: apppb.URLValidation(urlValidationEnum), + OriginUris: originURIs, + RedirectUris: redirectURIs, + LogoutUri: logoutURI, + EnabledGrants: egProto, + }, + } + return req, nil +} + +func enabledGrantsToProto(enabledGrants []string) ([]apppb.EnabledGrant, error) { + if enabledGrants == nil { + return nil, nil + } + enabledGrantsProto := make([]apppb.EnabledGrant, len(enabledGrants)) + for i, eg := range enabledGrants { + enum, ok := apppb.EnabledGrant_value[enabledGrantPrefix+strings.ToUpper(eg)] + if !ok { + return nil, errors.Errorf("%s must consist of valid EnabledGrants, got %s. "+ + "See `viam organizations auth-service oauth-app update --help` for supported options", + oauthAppFlagEnabledGrants, eg) + } + enabledGrantsProto[i] = apppb.EnabledGrant(enum) + } + return enabledGrantsProto, nil +} diff --git a/cli/client_test.go b/cli/client_test.go index bda3d9e26ca..f215f5b5688 100644 --- a/cli/client_test.go +++ b/cli/client_test.go @@ -1001,3 +1001,65 @@ func TestShellFileCopy(t *testing.T) { }) }) } + +func TestUpdateOAuthAppAction(t *testing.T) { + updateOAuthAppFunc := func(ctx context.Context, in *apppb.UpdateOAuthAppRequest, + opts ...grpc.CallOption, + ) (*apppb.UpdateOAuthAppResponse, error) { + return &apppb.UpdateOAuthAppResponse{}, nil + } + asc := &inject.AppServiceClient{ + UpdateOAuthAppFunc: updateOAuthAppFunc, + } + + t.Run("valid inputs", func(t *testing.T) { + flags := make(map[string]any) + flags[generalFlagOrgID] = "org-id" + flags[oauthAppFlagClientID] = "client-id" + flags[oauthAppFlagClientName] = "client-name" + flags[oauthAppFlagClientAuthentication] = "required" + flags[oauthAppFlagURLValidation] = "allow_wildcards" + flags[oauthAppFlagPKCE] = "not_required" + flags[oauthAppFlagOriginURIs] = []string{"https://woof.com/login", "https://arf.com/"} + flags[oauthAppFlagRedirectURIs] = []string{"https://woof.com/home", "https://arf.com/home"} + flags[oauthAppFlagLogoutURI] = "https://woof.com/logout" + flags[oauthAppFlagEnabledGrants] = []string{"implicit", "password"} + cCtx, ac, out, errOut := setup(asc, nil, nil, nil, flags, "token") + test.That(t, ac.updateOAuthAppAction(cCtx, parseStructFromCtx[updateOAuthAppArgs](cCtx)), test.ShouldBeNil) + test.That(t, len(errOut.messages), test.ShouldEqual, 0) + test.That(t, out.messages[0], test.ShouldContainSubstring, "Successfully updated OAuth app") + }) + + t.Run("should error if client-authentication is not a valid enum value", func(t *testing.T) { + flags := make(map[string]any) + flags[generalFlagOrgID] = "org-id" + flags[oauthAppFlagClientID] = "client-id" + flags[oauthAppFlagClientAuthentication] = "not_one_of_the_allowed_values" + cCtx, ac, out, _ := setup(asc, nil, nil, nil, flags, "token") + err := ac.updateOAuthAppAction(cCtx, parseStructFromCtx[updateOAuthAppArgs](cCtx)) + test.That(t, err, test.ShouldNotBeNil) + test.That(t, err.Error(), test.ShouldContainSubstring, "client-authentication must be a valid ClientAuthentication") + test.That(t, len(out.messages), test.ShouldEqual, 0) + }) + + t.Run("should error if url-validation is not a valid enum value", func(t *testing.T) { + flags := map[string]any{oauthAppFlagClientAuthentication: unspecified, oauthAppFlagPKCE: "not_one_of_the_allowed_values"} + cCtx, ac, out, _ := setup(asc, nil, nil, nil, flags, "token") + err := ac.updateOAuthAppAction(cCtx, parseStructFromCtx[updateOAuthAppArgs](cCtx)) + test.That(t, err, test.ShouldNotBeNil) + test.That(t, err.Error(), test.ShouldContainSubstring, "pkce must be a valid PKCE") + test.That(t, len(out.messages), test.ShouldEqual, 0) + }) + + t.Run("should error if url-validation is not a valid enum value", func(t *testing.T) { + flags := map[string]any{ + oauthAppFlagClientAuthentication: unspecified, oauthAppFlagPKCE: unspecified, + oauthAppFlagURLValidation: "not_one_of_the_allowed_values", + } + cCtx, ac, out, _ := setup(asc, nil, nil, nil, flags, "token") + err := ac.updateOAuthAppAction(cCtx, parseStructFromCtx[updateOAuthAppArgs](cCtx)) + test.That(t, err, test.ShouldNotBeNil) + test.That(t, err.Error(), test.ShouldContainSubstring, "url-validation must be a valid UrlValidation") + test.That(t, len(out.messages), test.ShouldEqual, 0) + }) +} diff --git a/testutils/inject/app_service_client.go b/testutils/inject/app_service_client.go index 45f35604723..74a129d134f 100644 --- a/testutils/inject/app_service_client.go +++ b/testutils/inject/app_service_client.go @@ -56,6 +56,8 @@ type AppServiceClient struct { opts ...grpc.CallOption) (*apppb.OrganizationSetLogoResponse, error) OrganizationGetLogoFunc func(ctx context.Context, in *apppb.OrganizationGetLogoRequest, opts ...grpc.CallOption) (*apppb.OrganizationGetLogoResponse, error) + UpdateOAuthAppFunc func(ctx context.Context, in *apppb.UpdateOAuthAppRequest, + opts ...grpc.CallOption) (*apppb.UpdateOAuthAppResponse, error) CreateLocationFunc func(ctx context.Context, in *apppb.CreateLocationRequest, opts ...grpc.CallOption) (*apppb.CreateLocationResponse, error) GetLocationFunc func(ctx context.Context, in *apppb.GetLocationRequest, @@ -403,6 +405,16 @@ func (asc *AppServiceClient) OrganizationGetLogo( return asc.OrganizationGetLogoFunc(ctx, in, opts...) } +// UpdateOAuthApp calls the injected UpdateOAuthAppFunc or the real version. +func (asc *AppServiceClient) UpdateOAuthApp( + ctx context.Context, in *apppb.UpdateOAuthAppRequest, opts ...grpc.CallOption, +) (*apppb.UpdateOAuthAppResponse, error) { + if asc.UpdateOAuthAppFunc == nil { + return asc.AppServiceClient.UpdateOAuthApp(ctx, in, opts...) + } + return asc.UpdateOAuthAppFunc(ctx, in, opts...) +} + // CreateLocation calls the injected CreateLocationFunc or the real version. func (asc *AppServiceClient) CreateLocation( ctx context.Context, in *apppb.CreateLocationRequest, opts ...grpc.CallOption,