From 5372c37e6c1b909f6f86a200666b54dcd80a8836 Mon Sep 17 00:00:00 2001 From: Carlos Alexandro Becker Date: Fri, 21 Feb 2020 17:22:12 -0300 Subject: [PATCH] fix: improve code a bit Signed-off-by: Carlos Alexandro Becker --- main.go | 54 ++++++++++++++++++------------------------------------ 1 file changed, 18 insertions(+), 36 deletions(-) diff --git a/main.go b/main.go index 6ff067b..1c6f7b1 100644 --- a/main.go +++ b/main.go @@ -18,15 +18,15 @@ import ( ) var ( - authorizations authorizationSlice - bucket string - address string + auths stringSlice + bucket string + listen string ) func main() { - flag.Var(&authorizations, "authorize", "usernames:passwords to be used to authenticate (e.g.: carlos:asd123)") + flag.Var(&auths, "authorize", "user/passwords that can authenticate in the user:pwd format (e.g.: carlos:asd123)") flag.StringVar(&bucket, "bucket", "", "bucket name (e.g.: s3://foo)") - flag.StringVar(&address, "addr", "127.0.0.1:8080", "address to listen to (e.g. 127.0.0.1:9090)") + flag.StringVar(&listen, "listen", "127.0.0.1:8080", "address to listen to (e.g. 127.0.0.1:9090)") flag.Parse() ctx := context.Background() @@ -42,8 +42,9 @@ func main() { var path = strings.Replace(r.URL.EscapedPath(), "/", "", 1) log.Println(path) - if !authorize(r) { - log.Println("unauthorized") + user, pwd, ok := r.BasicAuth() + if !(ok && isAuthorized(user+":"+pwd)) { + log.Println("unauthorized:", user) return httperr.Wrap(fmt.Errorf("missing/invalid authorization"), http.StatusUnauthorized) } @@ -58,43 +59,24 @@ func main() { return nil })) - log.Println("listening on", address) - http.ListenAndServe(address, handler) + log.Println("listening on", listen) + http.ListenAndServe(listen, handler) } -type authorization struct { - Username, Password string -} - -type authorizationSlice []authorization +type stringSlice []string -func (i *authorizationSlice) String() string { - var strs []string - for _, a := range *i { - strs = append(strs, a.Username) - } - return "[" + strings.Join(strs, ", ") + "]" +func (i *stringSlice) String() string { + return "[" + strings.Join(*i, ", ") + "]" } -func (i *authorizationSlice) Set(value string) error { - var parts = strings.Split(value, ":") - if len(parts) != 2 { - return fmt.Errorf("must be in the username:password format") - } - *i = append(*i, authorization{ - parts[0], - parts[1], - }) +func (i *stringSlice) Set(value string) error { + *i = append(*i, value) return nil } -func authorize(r *http.Request) bool { - user, pwd, ok := r.BasicAuth() - if !ok { - return false - } - for _, auth := range authorizations { - if auth.Username == user && auth.Password == pwd { +func isAuthorized(input string) bool { + for _, auth := range auths { + if input == auth { return true } }