From bf85f0003787d5594c39a3e3877e751f12a4bf26 Mon Sep 17 00:00:00 2001 From: Kevin Rieger <74607524+Frecherenkel60@users.noreply.github.com> Date: Wed, 17 Jan 2024 13:29:17 +0100 Subject: [PATCH] Add global and private feed endpoints (#16) --- .github/CODEOWNERS | 16 +- api/README.md | 16 +- go.mod | 1 + go.sum | 2 + internal/managers/jwt_manager.go | 138 +++---- internal/routing/handlers/post_handler.go | 171 ++++++++- internal/routing/handlers/user_handler.go | 430 +++++++++++++--------- internal/routing/router.go | 24 +- internal/routing/router_test.go | 1 + internal/schemas/dto_schema.go | 34 +- internal/utils/url_keys.go | 2 + 11 files changed, 557 insertions(+), 278 deletions(-) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index be00294..2add5bf 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,16 +1,16 @@ # Scripts owned by Luca (@930C) -/scripts/* @930C +/scripts/** @930C # GitHub related files owned by Kevin (@Frecherenkel60) -.github/* @Frecherenkel60 +.github/** @Frecherenkel60 # Deployments and Build owned by Luca (@930C) and Keivn (@Frecherenkel60) /deployments/* @930C @Frecherenkel60 -/build/* @930C @Frecherenkel60 +/build/** @930C @Frecherenkel60 # Everything else owned by project team (@wwi21seb-projekt/server-alpha) -/api/* @wwi21seb-projekt/server-alpha -/cmd/* @wwi21seb-projekt/server-alpha -/configs/* @wwi21seb-projekt/server-alpha -/docs/* @wwi21seb-projekt/server-alpha -/internal/* @wwi21seb-projekt/server-alpha +/api/** @wwi21seb-projekt/server-alpha +/cmd/** @wwi21seb-projekt/server-alpha +/configs/** @wwi21seb-projekt/server-alpha +/docs/** @wwi21seb-projekt/server-alpha +/internal/** @wwi21seb-projekt/server-alpha diff --git a/api/README.md b/api/README.md index 44f34e1..869738a 100644 --- a/api/README.md +++ b/api/README.md @@ -6,15 +6,15 @@ | [Trello Ticket 1] | POST | `/users/login` | Login | ✅ | | [Trello Ticket 2] | GET | `/imprint` | Impressum | ✅ | | [Trello Ticket 3.1], [Trello Ticket 3.2] | POST | `/posts` | Create Posts | ✅ | -| [Trello Ticket 5] | GET | `/users/:username` | Visit User Profile | ❌ | -| [Trello Ticket 4] | GET | `/users?username&offset&limit` | Search User | ❌ | +| [Trello Ticket 5] | GET | `/users/:username` | Visit User Profile | ✅ | +| [Trello Ticket 4] | GET | `/users?username&offset&limit` | Search User | ✅ | | [Trello Ticket 4] | GET | `/users/:username/feed?offset&limit` | User Feed | ❌ | -| [Trello Ticket 4] | POST | `/subscriptions` | Subscribe User | ❌ | -| [Trello Ticket 4] | DELETE | `/subscriptions/:subscriptionId` | Unsubscribe User | ❌ | -| [Trello Ticket 6.1] | GET | `/feed?postId&limit&feedType` | Get own or global feed | ❌ | -| [Trello Ticket 6.2] | GET | `/feed?postId&limit` | Get global feed (no auth) | ❌ | -| [Trello Ticket 7] | PUT | `/users` | Change trivial information | ❌ | -| [Trello Ticket 7] | PATCH | `/users` | Change password | ❌ | +| [Trello Ticket 4] | POST | `/subscriptions` | Subscribe User | ✅ | +| [Trello Ticket 4] | DELETE | `/subscriptions/:subscriptionId` | Unsubscribe User | ✅ | +| [Trello Ticket 6.1] | GET | `/feed?postId&limit&feedType` | Get own or global feed | ✅ | +| [Trello Ticket 6.2] | GET | `/feed?postId&limit` | Get global feed (no auth) | ✅ | +| [Trello Ticket 7] | PUT | `/users` | Change trivial information | ✅ | +| [Trello Ticket 7] | PATCH | `/users` | Change password | ✅ | [Trello Ticket 1]: https://trello.com/c/1w0QP6u5/209-id-0-als-nutzer-möchte-ich-mich-mit-email-und-passwort-registrieren-können-um-einen-gesicherten-zugang-zu-meinem-account-zu-habe diff --git a/go.mod b/go.mod index a88aa54..71d1eb9 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.21 require ( github.com/gavv/httpexpect/v2 v2.16.0 github.com/go-chi/chi/v5 v5.0.11 + github.com/go-chi/cors v1.2.1 github.com/go-playground/validator/v10 v10.17.0 github.com/golang-jwt/jwt/v5 v5.2.0 github.com/google/uuid v1.5.0 diff --git a/go.sum b/go.sum index 4612e1a..22ac0f3 100644 --- a/go.sum +++ b/go.sum @@ -51,6 +51,8 @@ github.com/gavv/httpexpect/v2 v2.16.0/go.mod h1:uJLaO+hQ25ukBJtQi750PsztObHybNll github.com/go-chi/chi/v5 v5.0.8/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= github.com/go-chi/chi/v5 v5.0.11 h1:BnpYbFZ3T3S1WMpD79r7R5ThWX40TaFB7L31Y8xqSwA= github.com/go-chi/chi/v5 v5.0.11/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= +github.com/go-chi/cors v1.2.1 h1:xEC8UT3Rlp2QuWNEr4Fs/c2EAGVKBwy/1vHx3bppil4= +github.com/go-chi/cors v1.2.1/go.mod h1:sSbTewc+6wYHBBCW7ytsFSn836hqM7JxpglAy2Vzc58= github.com/go-gomail/gomail v0.0.0-20160411212932-81ebce5c23df/go.mod h1:GJr+FCSXshIwgHBtLglIg9M2l2kQSi6QjVAngtzI08Y= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= diff --git a/internal/managers/jwt_manager.go b/internal/managers/jwt_manager.go index e792d0c..212dd31 100644 --- a/internal/managers/jwt_manager.go +++ b/internal/managers/jwt_manager.go @@ -17,10 +17,9 @@ import ( ) type JWTMgr interface { - GenerateJWT(claims jwt.Claims) (string, error) + GenerateJWT(userId, username string, isRefreshToken bool) (string, error) ValidateJWT(tokenString string) (jwt.Claims, error) JWTMiddleware(next http.Handler) http.Handler - GenerateClaims(userId, username string) jwt.Claims } // JWTManager handles JWT generation, signing, and validation. @@ -43,6 +42,67 @@ func NewJWTManager(privateKey ed25519.PrivateKey, publicKey ed25519.PublicKey) J return &JWTManager } +// GenerateJWT generates a new JWT with the given claims. +func (jm *JWTManager) GenerateJWT(userId, username string, isRefreshToken bool) (string, error) { + claims := generateClaims(userId, username, isRefreshToken) + + token := jwt.NewWithClaims(jwt.SigningMethodEdDSA, claims) + return token.SignedString(jm.privateKey) +} + +// ValidateJWT validates the given JWT and returns the claims if valid. +func (jm *JWTManager) ValidateJWT(tokenString string) (jwt.Claims, error) { + token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { + // Verify the signing method + if token.Method.Alg() != jwt.SigningMethodEdDSA.Alg() { + return nil, fmt.Errorf("invalid signing method") + } + + return jm.publicKey, nil + }) + + if err != nil { + return nil, err + } + + if !token.Valid { + return nil, jwt.ErrSignatureInvalid + } + + return token.Claims, nil +} + +// JWTMiddleware is a middleware that validates the JWT token in the request header. +func (jm *JWTManager) JWTMiddleware(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + // Check if the request has a JWT token + if r.Header.Get("Authorization") == "" { + utils.WriteAndLogError(w, schemas.Unauthorized, http.StatusUnauthorized, fmt.Errorf("missing authorization header")) + return + } + + // Extract the JWT token from the request header + header := r.Header.Get("Authorization") + token := header[len("Bearer "):] + + // Validate the JWT token + claims, err := jm.ValidateJWT(token) + if err != nil || claims.(jwt.MapClaims)["refresh"] == "true" { + utils.WriteAndLogError(w, schemas.Unauthorized, http.StatusUnauthorized, err) + return + } + + // Add the claims to the request context + ctx := r.Context() + ctx = context.WithValue(ctx, utils.ClaimsKey, claims) + r = r.WithContext(ctx) + + next.ServeHTTP(w, r) + } + + return http.HandlerFunc(fn) +} + func NewJWTManagerFromFile() (JWTMgr, error) { log.Info("Initializing JWT manager using key pair from file...") privateKeyPath := "private_key.pem" @@ -201,70 +261,24 @@ func generateAndStoreKeys(privateKeyPath, publicKeyPath string) error { } // GenerateClaims generates the standard JWT claims. -func (jm *JWTManager) GenerateClaims(userId, username string) jwt.Claims { +func generateClaims(userId, username string, isRefreshToken bool) jwt.Claims { + var exp int64 + var refresh string + + if isRefreshToken { + exp = time.Now().Add(time.Hour * 24 * 7).Unix() + refresh = "true" + } else { + exp = time.Now().Add(time.Hour * 24).Unix() + refresh = "false" + } + return jwt.MapClaims{ "iss": "server-alpha.tech", "iat": time.Now().Unix(), - "exp": time.Now().Add(time.Hour * 24).Unix(), + "exp": exp, "sub": userId, "username": username, + "refresh": refresh, } } - -// GenerateJWT generates a new JWT with the given claims. -func (jm *JWTManager) GenerateJWT(claims jwt.Claims) (string, error) { - token := jwt.NewWithClaims(jwt.SigningMethodEdDSA, claims) - return token.SignedString(jm.privateKey) -} - -// ValidateJWT validates the given JWT and returns the claims if valid. -func (jm *JWTManager) ValidateJWT(tokenString string) (jwt.Claims, error) { - token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { - // Verify the signing method - if token.Method.Alg() != jwt.SigningMethodEdDSA.Alg() { - return nil, fmt.Errorf("invalid signing method") - } - - return jm.publicKey, nil - }) - - if err != nil { - return nil, err - } - - if !token.Valid { - return nil, jwt.ErrSignatureInvalid - } - - return token.Claims, nil -} - -func (jm *JWTManager) JWTMiddleware(next http.Handler) http.Handler { - fn := func(w http.ResponseWriter, r *http.Request) { - // Check if the request has a JWT token - if r.Header.Get("Authorization") == "" { - utils.WriteAndLogError(w, schemas.Unauthorized, http.StatusUnauthorized, fmt.Errorf("missing authorization header")) - return - } - - // Extract the JWT token from the request header - header := r.Header.Get("Authorization") - token := header[len("Bearer "):] - - // Validate the JWT token - claims, err := jm.ValidateJWT(token) - if err != nil { - utils.WriteAndLogError(w, schemas.Unauthorized, http.StatusUnauthorized, err) - return - } - - // Add the claims to the request context - ctx := r.Context() - ctx = context.WithValue(ctx, utils.ClaimsKey, claims) - r = r.WithContext(ctx) - - next.ServeHTTP(w, r) - } - - return http.HandlerFunc(fn) -} diff --git a/internal/routing/handlers/post_handler.go b/internal/routing/handlers/post_handler.go index 0d22928..5fe01e6 100644 --- a/internal/routing/handlers/post_handler.go +++ b/internal/routing/handlers/post_handler.go @@ -1,6 +1,9 @@ package handlers import ( + "context" + "fmt" + "github.com/jackc/pgx/v5" "net/http" "regexp" "server-alpha/internal/managers" @@ -14,18 +17,21 @@ import ( type PostHdl interface { CreatePost(w http.ResponseWriter, r *http.Request) + HandleGetFeedRequest(w http.ResponseWriter, r *http.Request) } type PostHandler struct { DatabaseManager managers.DatabaseMgr + JWTManager managers.JWTMgr Validator *utils.Validator } var hashtagRegex = regexp.MustCompile(`#\w+`) -func NewPostHandler(databaseManager *managers.DatabaseMgr) PostHdl { +func NewPostHandler(databaseManager *managers.DatabaseMgr, jwtManager *managers.JWTMgr) PostHdl { return &PostHandler{ DatabaseManager: *databaseManager, + JWTManager: *jwtManager, Validator: utils.GetValidator(), } } @@ -34,6 +40,7 @@ func (handler *PostHandler) CreatePost(w http.ResponseWriter, r *http.Request) { // Begin a new transaction tx, transactionCtx, cancel := utils.BeginTransaction(w, r, handler.DatabaseManager.GetPool()) if tx == nil || transactionCtx == nil { + utils.WriteAndLogError(w, schemas.DatabaseError, http.StatusInternalServerError, nil) return } var err error @@ -109,7 +116,165 @@ func (handler *PostHandler) CreatePost(w http.ResponseWriter, r *http.Request) { Nickname: author.Nickname, ProfilePictureURL: "", }, - Content: createPostRequest.Content, - CreatedAt: createdAt.String(), + Content: createPostRequest.Content, + CreationDate: createdAt.Format(time.RFC3339), }, http.StatusCreated) } + +func (handler *PostHandler) HandleGetFeedRequest(w http.ResponseWriter, r *http.Request) { + // Begin a new transaction + tx, transactionCtx, cancel := utils.BeginTransaction(w, r, handler.DatabaseManager.GetPool()) + if tx == nil || transactionCtx == nil { + utils.WriteAndLogError(w, schemas.DatabaseError, http.StatusInternalServerError, nil) + return + } + var err error + defer utils.RollbackTransaction(w, tx, transactionCtx, cancel, err) + + // Determine the feed type + publicFeedWanted, claims, err := determineFeedType(r, w, handler) + if err != nil { + return + } + + // Retrieve the feed based on the wanted feed type + posts, records, lastPostId, limit, err := retrieveFeed(transactionCtx, tx, w, r, publicFeedWanted, claims) + if err != nil { + return + } + + // Get the last post ID + if len(posts) > 0 { + lastPostId = posts[len(posts)-1].PostId + } + + // Create pagination DTO + pagination := &schemas.PostPagination{ + LastPostId: lastPostId, + Limit: limit, + Records: records, + } + + paginatedResponse := &schemas.PaginatedResponse{ + Records: posts, + Pagination: pagination, + } + + if err := utils.CommitTransaction(w, tx, transactionCtx, cancel); err != nil { + utils.WriteAndLogError(w, schemas.DatabaseError, http.StatusInternalServerError, err) + } + + utils.WriteAndLogResponse(w, paginatedResponse, http.StatusOK) +} + +func determineFeedType(r *http.Request, w http.ResponseWriter, handler *PostHandler) (bool, jwt.Claims, error) { + // Get UserId from JWT token + authHeader := r.Header.Get("Authorization") + var claims jwt.Claims + var err error + publicFeedWanted := false + + if authHeader == "" { + publicFeedWanted = true + } else { + jwtToken := authHeader[len("Bearer "):] + claims, err = handler.JWTManager.ValidateJWT(jwtToken) + if err != nil { + utils.WriteAndLogError(w, schemas.Unauthorized, http.StatusUnauthorized, err) + return false, nil, err + } + + feedType := r.URL.Query().Get(utils.FeedTypeParamKey) + if feedType == "global" { + publicFeedWanted = true + } + } + + return publicFeedWanted, claims, nil +} + +func retrieveFeed(ctx context.Context, tx pgx.Tx, w http.ResponseWriter, r *http.Request, publicFeedWanted bool, claims jwt.Claims) ([]*schemas.PostDTO, int, string, string, error) { + lastPostId := r.URL.Query().Get(utils.PostIdParamKey) + limit := r.URL.Query().Get(utils.LimitParamKey) + + if limit == "" { + limit = "10" + } + + currentDataQueryIndex := 1 + currentCountQueryIndex := 1 + var userId string + var countQuery string + var dataQuery string + var countQueryArgs []interface{} + var dataQueryArgs []interface{} + + if !publicFeedWanted { + userId = claims.(jwt.MapClaims)["sub"].(string) + countQuery = `SELECT COUNT(*) FROM alpha_schema.posts + INNER JOIN alpha_schema.users ON author_id = user_id + INNER JOIN alpha_schema.subscriptions ON user_id = subscriptions.subscriber_id + WHERE subscriptions.subscriber_id` + fmt.Sprintf(" = $%d", currentCountQueryIndex) + currentCountQueryIndex++ + dataQuery = `SELECT post_id, username, nickname, profile_picture_url, content, posts.created_at FROM alpha_schema.posts + INNER JOIN alpha_schema.users ON author_id = user_id + INNER JOIN alpha_schema.subscriptions ON user_id = subscriptions.subscriber_id + WHERE subscriptions.subscriber_id` + fmt.Sprintf(" = $%d", currentDataQueryIndex) + currentDataQueryIndex++ + + countQueryArgs = append(countQueryArgs, userId) + dataQueryArgs = append(dataQueryArgs, userId) + } else { + countQuery = "SELECT COUNT(*) FROM alpha_schema.posts" + dataQuery = `SELECT post_id, username, nickname, profile_picture_url, content, posts.created_at FROM alpha_schema.posts + INNER JOIN alpha_schema.users ON author_id = user_id` + } + + // Get the count of posts in the database that match the criteria + row := tx.QueryRow(ctx, countQuery, countQueryArgs...) + + var count int + if err := row.Scan(&count); err != nil { + utils.WriteAndLogError(w, schemas.DatabaseError, http.StatusInternalServerError, err) + return nil, 0, "", "", err + } + + var rows pgx.Rows + var err error + + if lastPostId == "" { + // If we don't have a last post ID, we'll return the newest posts + dataQuery += " ORDER BY created_at DESC LIMIT" + fmt.Sprintf(" $%d", currentDataQueryIndex) + dataQueryArgs = append(dataQueryArgs, limit) + } else { + // If we have a last post ID, we'll return the posts that were created before the last post + dataQuery += " AND posts.created_at < (SELECT created_at FROM alpha_schema.posts WHERE post_id = " + + fmt.Sprintf("$%d) ", currentDataQueryIndex) + "ORDER BY created_at DESC LIMIT " + + fmt.Sprintf("$%d", currentDataQueryIndex+1) + dataQueryArgs = append(dataQueryArgs, lastPostId, limit) + } + + rows, err = tx.Query(ctx, dataQuery, dataQueryArgs...) + if err != nil { + utils.WriteAndLogError(w, schemas.DatabaseError, http.StatusInternalServerError, err) + return nil, 0, "", "", err + } + + // Iterate over the rows and create the post DTOs + posts := make([]*schemas.PostDTO, 0) + + for rows.Next() { + post := &schemas.PostDTO{} + var createdAt time.Time + + if err := rows.Scan(&post.PostId, &post.Author.Username, &post.Author.Nickname, &post.Author.ProfilePictureURL, &post.Content, &createdAt); err != nil { + utils.WriteAndLogError(w, schemas.DatabaseError, http.StatusInternalServerError, err) + return nil, 0, "", "", err + } + + post.CreationDate = createdAt.Format(time.RFC3339) + posts = append(posts, post) + } + + return posts, count, lastPostId, limit, nil +} diff --git a/internal/routing/handlers/user_handler.go b/internal/routing/handlers/user_handler.go index c0bcf45..b840873 100644 --- a/internal/routing/handlers/user_handler.go +++ b/internal/routing/handlers/user_handler.go @@ -2,7 +2,6 @@ package handlers import ( "context" - "encoding/json" "errors" "math/rand" "net/http" @@ -34,6 +33,7 @@ type UserHdl interface { SearchUsers(w http.ResponseWriter, r *http.Request) ChangeTrivialInformation(w http.ResponseWriter, r *http.Request) ChangePassword(w http.ResponseWriter, r *http.Request) + RefreshToken(w http.ResponseWriter, r *http.Request) } type UserHandler struct { @@ -52,6 +52,9 @@ func NewUserHandler(databaseManager *managers.DatabaseMgr, jwtManager *managers. } } +var errInvalidToken = errors.New("invalid token") + +// RegisterUser registers a new user and sends an activation token to the user's email. func (handler *UserHandler) RegisterUser(w http.ResponseWriter, r *http.Request) { // Begin a new transaction tx, transactionCtx, cancel := utils.BeginTransaction(w, r, handler.DatabaseManager.GetPool()) @@ -103,7 +106,7 @@ func (handler *UserHandler) RegisterUser(w http.ResponseWriter, r *http.Request) } // Generate a token for the user - if err = generateAndSendToken(w, handler, tx, transactionCtx, registrationRequest.Email, registrationRequest.Username, userId.String()); err != nil { + if err = generateAndSendToken(transactionCtx, w, handler, tx, registrationRequest.Email, registrationRequest.Username, userId.String()); err != nil { return } @@ -112,20 +115,17 @@ func (handler *UserHandler) RegisterUser(w http.ResponseWriter, r *http.Request) return } - // Send success response - w.WriteHeader(http.StatusCreated) userDto := &schemas.UserDTO{ Username: registrationRequest.Username, Nickname: registrationRequest.Nickname, Email: registrationRequest.Email, } - if err = json.NewEncoder(w).Encode(userDto); err != nil { - w.WriteHeader(http.StatusInternalServerError) - return - } + // Send success response + utils.WriteAndLogResponse(w, userDto, http.StatusCreated) } +// ActivateUser activates the user specified in the path. func (handler *UserHandler) ActivateUser(w http.ResponseWriter, r *http.Request) { // Begin a new transaction tx, transactionCtx, cancel := utils.BeginTransaction(w, r, handler.DatabaseManager.GetPool()) @@ -157,7 +157,7 @@ func (handler *UserHandler) ActivateUser(w http.ResponseWriter, r *http.Request) // Check if the user is activated if _, activated, err := checkUserExistenceAndActivation(transactionCtx, w, tx, username); err != nil { - + return } else if activated { utils.WriteAndLogError(w, schemas.UserAlreadyActivated, http.StatusAlreadyReported, errors.New("already activated")) return @@ -182,12 +182,21 @@ func (handler *UserHandler) ActivateUser(w http.ResponseWriter, r *http.Request) return } + // Generate token pair + tokenDto, err := generateTokenPair(handler, userID.String(), username) + if err != nil { + utils.WriteAndLogError(w, schemas.InternalServerError, http.StatusInternalServerError, err) + return + } + if err := utils.CommitTransaction(w, tx, transactionCtx, cancel); err != nil { return } - w.WriteHeader(http.StatusNoContent) + + utils.WriteAndLogResponse(w, tokenDto, http.StatusOK) } +// ResendToken resends the activation token to the user specified in the path. func (handler *UserHandler) ResendToken(w http.ResponseWriter, r *http.Request) { // Begin a new transaction tx, transactionCtx, cancel := utils.BeginTransaction(w, r, handler.DatabaseManager.GetPool()) @@ -214,7 +223,7 @@ func (handler *UserHandler) ResendToken(w http.ResponseWriter, r *http.Request) } // Generate a new token and send it to the user - if err := generateAndSendToken(w, handler, tx, transactionCtx, email, username, userId.String()); err != nil { + if err := generateAndSendToken(transactionCtx, w, handler, tx, email, username, userId.String()); err != nil { return } @@ -222,9 +231,11 @@ func (handler *UserHandler) ResendToken(w http.ResponseWriter, r *http.Request) if err := utils.CommitTransaction(w, tx, transactionCtx, cancel); err != nil { return } + w.WriteHeader(http.StatusNoContent) } +// ChangeTrivialInformation changes the nickname and status of the user specified in the path. func (handler *UserHandler) ChangeTrivialInformation(w http.ResponseWriter, r *http.Request) { // Begin a new transaction tx, transactionCtx, cancel := utils.BeginTransaction(w, r, handler.DatabaseManager.GetPool()) @@ -270,13 +281,10 @@ func (handler *UserHandler) ChangeTrivialInformation(w http.ResponseWriter, r *h } // Send the updated user in the response - if err := json.NewEncoder(w).Encode(UserNicknameAndStatusDTO); err != nil { - w.WriteHeader(http.StatusInternalServerError) - return - } - w.WriteHeader(http.StatusOK) + utils.WriteAndLogResponse(w, UserNicknameAndStatusDTO, http.StatusOK) } +// ChangePassword changes the password of the user specified in the path. func (handler *UserHandler) ChangePassword(w http.ResponseWriter, r *http.Request) { // Begin a new transaction tx, transactionCtx, cancel := utils.BeginTransaction(w, r, handler.DatabaseManager.GetPool()) @@ -334,6 +342,7 @@ func (handler *UserHandler) ChangePassword(w http.ResponseWriter, r *http.Reques w.WriteHeader(http.StatusNoContent) } +// LoginUser logs in the user specified in the request body and returns a token pair. func (handler *UserHandler) LoginUser(w http.ResponseWriter, r *http.Request) { // Begin a new transaction tx, transactionCtx, cancel := utils.BeginTransaction(w, r, handler.DatabaseManager.GetPool()) @@ -381,169 +390,24 @@ func (handler *UserHandler) LoginUser(w http.ResponseWriter, r *http.Request) { return } - // Generate a token for the user - claims := handler.JWTManager.GenerateClaims(userId.String(), loginRequest.Username) - token, err := handler.JWTManager.GenerateJWT(claims) - + // Generate token pair + tokenDto, err := generateTokenPair(handler, userId.String(), loginRequest.Username) if err != nil { utils.WriteAndLogError(w, schemas.InternalServerError, http.StatusInternalServerError, err) return - } - tokenDto := &schemas.TokenDTO{ - Token: token, } - if err := json.NewEncoder(w).Encode(tokenDto); err != nil { - utils.WriteAndLogError(w, schemas.InternalServerError, http.StatusInternalServerError, err) + // Commit the transaction + if err = utils.CommitTransaction(w, tx, transactionCtx, cancel); err != nil { return } - w.WriteHeader(200) -} - -func retrieveUserIdAndEmail(transactionCtx context.Context, w http.ResponseWriter, tx pgx.Tx, username string) (string, uuid.UUID, bool) { - // Get the user ID - queryString := "SELECT email, user_id FROM alpha_schema.users WHERE username = $1" - rows, err := tx.Query(transactionCtx, queryString, username) - if err != nil { - utils.WriteAndLogError(w, schemas.DatabaseError, http.StatusInternalServerError, err) - return "", uuid.UUID{}, true - } - defer rows.Close() - - var email string - var userID uuid.UUID - if rows.Next() { - if err := rows.Scan(&email, &userID); err != nil { - utils.WriteAndLogError(w, schemas.DatabaseError, http.StatusInternalServerError, err) - return "", uuid.UUID{}, true - } - } else { - utils.WriteAndLogError(w, schemas.UserNotFound, http.StatusNotFound, errors.New("user not found")) - return "", uuid.UUID{}, true - } - - return email, userID, false -} - -func checkUsernameEmailTaken(ctx context.Context, w http.ResponseWriter, tx pgx.Tx, username, email string) error { - queryString := "SELECT username, email FROM alpha_schema.users WHERE username = $1 OR email = $2" - rows, err := tx.Query(ctx, queryString, username, email) - if err != nil { - utils.WriteAndLogError(w, schemas.DatabaseError, http.StatusInternalServerError, err) - return err - } - defer rows.Close() - - if rows.Next() { - var foundUsername string - var foundEmail string - - if err := rows.Scan(&foundUsername, &foundEmail); err != nil { - utils.WriteAndLogError(w, schemas.DatabaseError, http.StatusInternalServerError, err) - return err - } - - customErr := &schemas.CustomError{} - if foundUsername == username { - customErr = schemas.UsernameTaken - } else { - customErr = schemas.EmailTaken - } - - err = errors.New("username or email taken") - utils.WriteAndLogError(w, customErr, http.StatusConflict, err) - return err - } - - return nil -} - -func checkTokenValidity(ctx context.Context, w http.ResponseWriter, tx pgx.Tx, token, username string) error { - queryString := "SELECT expires_at FROM alpha_schema.activation_tokens WHERE token = $1 AND user_id = (SELECT user_id FROM alpha_schema.users WHERE username = $2)" - rows, err := tx.Query(ctx, queryString, token, username) - if err != nil { - utils.WriteAndLogError(w, schemas.DatabaseError, http.StatusInternalServerError, err) - return err - } - defer rows.Close() - - if !rows.Next() { - utils.WriteAndLogError(w, schemas.InvalidToken, http.StatusUnauthorized, errors.New("invalid token")) - return errors.New("invalid token") - } - - var expiresAt pgtype.Timestamptz - if err := rows.Scan(&expiresAt); err != nil { - utils.WriteAndLogError(w, schemas.DatabaseError, http.StatusInternalServerError, err) - return err - } - if time.Now().After(expiresAt.Time) { - utils.WriteAndLogError(w, schemas.ActivationTokenExpired, http.StatusUnauthorized, errors.New("token expired")) - return errors.New("token expired") - } - - return nil -} - -func checkUserExistenceAndActivation(ctx context.Context, w http.ResponseWriter, tx pgx.Tx, username string) (bool, bool, error) { - queryString := "SELECT activated_at FROM alpha_schema.users WHERE username = $1" - rows, err := tx.Query(ctx, queryString, username) - if err != nil { - utils.WriteAndLogError(w, schemas.DatabaseError, http.StatusInternalServerError, err) - return false, false, err - } - defer rows.Close() - - var activatedAt pgtype.Timestamptz - if rows.Next() { - if err := rows.Scan(&activatedAt); err != nil { - utils.WriteAndLogError(w, schemas.DatabaseError, http.StatusInternalServerError, err) - return false, false, err - } - } else { - return false, false, nil - } - - return true, !activatedAt.Time.IsZero() && activatedAt.Valid, nil -} - -func generateAndSendToken(w http.ResponseWriter, handler *UserHandler, tx pgx.Tx, ctx context.Context, email, username, userId string) error { - // Generate a new token and send it to the user - token := generateToken() - tokenID := uuid.New() - tokenExpiresAt := time.Now().Add(2 * time.Hour) - - // Delete the old token if it exists - queryString := "DELETE FROM alpha_schema.activation_tokens WHERE user_id = $1" - if _, err := tx.Exec(ctx, queryString, userId); err != nil { - utils.WriteAndLogError(w, schemas.DatabaseError, http.StatusInternalServerError, err) - return err - } - - queryString = "INSERT INTO alpha_schema.activation_tokens (token_id, user_id, token, expires_at) VALUES ($1, $2, $3, $4)" - if _, err := tx.Exec(ctx, queryString, tokenID, userId, token, tokenExpiresAt); err != nil { - utils.WriteAndLogError(w, schemas.DatabaseError, http.StatusInternalServerError, err) - return err - } - - // Send the token to the user - if err := handler.MailManager.SendActivationMail(email, username, token, "UI-Service"); err != nil { - utils.WriteAndLogError(w, schemas.EmailNotSent, http.StatusInternalServerError, err) - return err - } - - return nil -} - -func generateToken() string { - rand.NewSource(time.Now().UnixNano()) - - // Generate a random 6-digit number - return strconv.Itoa(rand.Intn(900000) + 100000) + // Send success response + utils.WriteAndLogResponse(w, tokenDto, http.StatusOK) } +// HandleGetUserRequest returns the user profile of the user specified in the path. func (handler *UserHandler) HandleGetUserRequest(w http.ResponseWriter, r *http.Request) { ctx, cancel := context.WithDeadline(r.Context(), time.Now().Add(10*time.Second)) defer func() { @@ -604,12 +468,7 @@ func (handler *UserHandler) HandleGetUserRequest(w http.ResponseWriter, r *http. } // Send success response - w.WriteHeader(http.StatusOK) - - if err := json.NewEncoder(w).Encode(user); err != nil { - w.WriteHeader(http.StatusInternalServerError) - return - } + utils.WriteAndLogResponse(w, user, http.StatusOK) } // Subscribe creates a new subscription between the current user and the username specified in the request body. @@ -684,23 +543,18 @@ func (handler *UserHandler) Subscribe(w http.ResponseWriter, r *http.Request) { // Send the subscription to the user subscriptionDto := &schemas.SubscriptionDTO{ SubscriptionId: subscriptionId, - SubscriptionDate: createdAt.String(), + SubscriptionDate: createdAt.Format(time.RFC3339), Following: subscriptionRequest.Following, Follower: jwtUsername, } - if err := json.NewEncoder(w).Encode(subscriptionDto); err != nil { - w.WriteHeader(http.StatusInternalServerError) - return - } - // Commit the transaction if err := utils.CommitTransaction(w, tx, transactionCtx, cancel); err != nil { return } // Send success response - w.WriteHeader(http.StatusCreated) + utils.WriteAndLogResponse(w, subscriptionDto, http.StatusCreated) } // Unsubscribe removes a subscription between the current user and the user specified by the subscription ID. @@ -845,14 +699,201 @@ func (handler *UserHandler) SearchUsers(w http.ResponseWriter, r *http.Request) Pagination: paginationDto, } - w.WriteHeader(http.StatusOK) - if err := json.NewEncoder(w).Encode(paginatedResponse); err != nil { - w.WriteHeader(http.StatusInternalServerError) + // Send success response + utils.WriteAndLogResponse(w, paginatedResponse, http.StatusOK) +} + +// RefreshToken refreshes the token pair of the user. +func (handler *UserHandler) RefreshToken(w http.ResponseWriter, r *http.Request) { + // Get the refresh token from the request body + refreshTokenRequest := &schemas.RefreshTokenRequest{} + if err := utils.DecodeRequestBody(w, r, refreshTokenRequest); err != nil { + return + } + + // Validate the refresh token request struct using the validator + if err := utils.ValidateStruct(w, refreshTokenRequest); err != nil { + return + } + + // Get the user ID and username from the refresh token + refreshTokenClaims, err := handler.JWTManager.ValidateJWT(refreshTokenRequest.RefreshToken) + if err != nil { + utils.WriteAndLogError(w, schemas.InvalidToken, http.StatusUnauthorized, err) + return + } + + refreshClaims := refreshTokenClaims.(jwt.MapClaims) + userId := refreshClaims["sub"].(string) + username := refreshClaims["username"].(string) + isRefreshToken := refreshClaims["refresh"].(string) + + if isRefreshToken != "true" { + utils.WriteAndLogError(w, schemas.Unauthorized, http.StatusUnauthorized, errInvalidToken) + return + } + + // Generate token pair + tokenDto, err := generateTokenPair(handler, userId, username) + if err != nil { utils.WriteAndLogError(w, schemas.InternalServerError, http.StatusInternalServerError, err) return } + + // Send success response + utils.WriteAndLogResponse(w, tokenDto, http.StatusOK) +} + +// retrieveUserIdAndEmail retrieves the user ID and email of the user specified by the username. +func retrieveUserIdAndEmail(transactionCtx context.Context, w http.ResponseWriter, tx pgx.Tx, username string) (string, uuid.UUID, bool) { + // Get the user ID + queryString := "SELECT email, user_id FROM alpha_schema.users WHERE username = $1" + rows, err := tx.Query(transactionCtx, queryString, username) + if err != nil { + utils.WriteAndLogError(w, schemas.DatabaseError, http.StatusInternalServerError, err) + return "", uuid.UUID{}, true + } + defer rows.Close() + + var email string + var userID uuid.UUID + if rows.Next() { + if err := rows.Scan(&email, &userID); err != nil { + utils.WriteAndLogError(w, schemas.DatabaseError, http.StatusInternalServerError, err) + return "", uuid.UUID{}, true + } + } else { + utils.WriteAndLogError(w, schemas.UserNotFound, http.StatusNotFound, errors.New("user not found")) + return "", uuid.UUID{}, true + } + + return email, userID, false +} + +// checkUsernameEmailTaken checks if the username or email is taken. +func checkUsernameEmailTaken(ctx context.Context, w http.ResponseWriter, tx pgx.Tx, username, email string) error { + queryString := "SELECT username, email FROM alpha_schema.users WHERE username = $1 OR email = $2" + rows, err := tx.Query(ctx, queryString, username, email) + if err != nil { + utils.WriteAndLogError(w, schemas.DatabaseError, http.StatusInternalServerError, err) + return err + } + defer rows.Close() + + if rows.Next() { + var foundUsername string + var foundEmail string + + if err := rows.Scan(&foundUsername, &foundEmail); err != nil { + utils.WriteAndLogError(w, schemas.DatabaseError, http.StatusInternalServerError, err) + return err + } + + customErr := &schemas.CustomError{} + if foundUsername == username { + customErr = schemas.UsernameTaken + } else { + customErr = schemas.EmailTaken + } + + err = errors.New("username or email taken") + utils.WriteAndLogError(w, customErr, http.StatusConflict, err) + return err + } + + return nil +} + +// checkTokenValidity checks if the token for the given user and token value combination is valid. +func checkTokenValidity(ctx context.Context, w http.ResponseWriter, tx pgx.Tx, token, username string) error { + queryString := "SELECT expires_at FROM alpha_schema.activation_tokens WHERE token = $1 AND user_id = (SELECT user_id FROM alpha_schema.users WHERE username = $2)" + rows, err := tx.Query(ctx, queryString, token, username) + if err != nil { + utils.WriteAndLogError(w, schemas.DatabaseError, http.StatusInternalServerError, err) + return err + } + defer rows.Close() + + if !rows.Next() { + utils.WriteAndLogError(w, schemas.InvalidToken, http.StatusUnauthorized, errInvalidToken) + return errInvalidToken + } + + var expiresAt pgtype.Timestamptz + if err := rows.Scan(&expiresAt); err != nil { + utils.WriteAndLogError(w, schemas.DatabaseError, http.StatusInternalServerError, err) + return err + } + + if time.Now().After(expiresAt.Time) { + utils.WriteAndLogError(w, schemas.ActivationTokenExpired, http.StatusUnauthorized, errors.New("token expired")) + return errors.New("token expired") + } + + return nil } +// checkUserExistenceAndActivation checks if the user exists and if yes, returns separate values for existence and activation. +func checkUserExistenceAndActivation(ctx context.Context, w http.ResponseWriter, tx pgx.Tx, username string) (bool, bool, error) { + queryString := "SELECT activated_at FROM alpha_schema.users WHERE username = $1" + rows, err := tx.Query(ctx, queryString, username) + if err != nil { + utils.WriteAndLogError(w, schemas.DatabaseError, http.StatusInternalServerError, err) + return false, false, err + } + defer rows.Close() + + var activatedAt pgtype.Timestamptz + if rows.Next() { + if err := rows.Scan(&activatedAt); err != nil { + utils.WriteAndLogError(w, schemas.DatabaseError, http.StatusInternalServerError, err) + return false, false, err + } + } else { + return false, false, nil + } + + return true, !activatedAt.Time.IsZero() && activatedAt.Valid, nil +} + +// generateAndSendToken generates a new token and sends it to the user's email. +func generateAndSendToken(ctx context.Context, w http.ResponseWriter, handler *UserHandler, tx pgx.Tx, email, username, userId string) error { + // Generate a new token and send it to the user + token := generateToken() + tokenID := uuid.New() + tokenExpiresAt := time.Now().Add(2 * time.Hour) + + // Delete the old token if it exists + queryString := "DELETE FROM alpha_schema.activation_tokens WHERE user_id = $1" + if _, err := tx.Exec(ctx, queryString, userId); err != nil { + utils.WriteAndLogError(w, schemas.DatabaseError, http.StatusInternalServerError, err) + return err + } + + queryString = "INSERT INTO alpha_schema.activation_tokens (token_id, user_id, token, expires_at) VALUES ($1, $2, $3, $4)" + if _, err := tx.Exec(ctx, queryString, tokenID, userId, token, tokenExpiresAt); err != nil { + utils.WriteAndLogError(w, schemas.DatabaseError, http.StatusInternalServerError, err) + return err + } + + // Send the token to the user + if err := handler.MailManager.SendActivationMail(email, username, token, "UI-Service"); err != nil { + utils.WriteAndLogError(w, schemas.EmailNotSent, http.StatusInternalServerError, err) + return err + } + + return nil +} + +// generateToken generates a new token for the user activation. +func generateToken() string { + rand.NewSource(time.Now().UnixNano()) + + // Generate a random 6-digit number + return strconv.Itoa(rand.Intn(900000) + 100000) +} + +// checkPassword checks if the given password is correct for the given user. func checkPassword(transactionCtx context.Context, w http.ResponseWriter, tx pgx.Tx, username, givenPassword string) error { queryString := "SELECT password, user_id FROM alpha_schema.users WHERE username = $1" rows, err := tx.Query(transactionCtx, queryString, username) @@ -877,3 +918,24 @@ func checkPassword(transactionCtx context.Context, w http.ResponseWriter, tx pgx } return nil } + +// generateTokenPair generates a token pair for the given user. +func generateTokenPair(handler *UserHandler, userId, username string) (*schemas.TokenPairDTO, error) { + // Generate a token for the user + token, err := handler.JWTManager.GenerateJWT(userId, username, false) + if err != nil { + return nil, err + } + + refreshToken, err := handler.JWTManager.GenerateJWT(userId, username, true) + if err != nil { + return nil, err + } + + tokenPair := &schemas.TokenPairDTO{ + Token: token, + RefreshToken: refreshToken, + } + + return tokenPair, nil +} diff --git a/internal/routing/router.go b/internal/routing/router.go index f7a9b03..f4c58c0 100644 --- a/internal/routing/router.go +++ b/internal/routing/router.go @@ -1,6 +1,7 @@ package routing import ( + "github.com/go-chi/cors" "net/http" "server-alpha/internal/managers" "server-alpha/internal/routing/handlers" @@ -15,15 +16,26 @@ import ( func InitRouter(databaseMgr managers.DatabaseMgr, mailMgr managers.MailMgr, jwtMgr managers.JWTMgr) *chi.Mux { r := chi.NewRouter() + // Configure CORS + r.Use(cors.Handler(cors.Options{ + AllowedOrigins: []string{"http://localhost:5173", "http://localhost:19000"}, + AllowedMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"}, + AllowedHeaders: []string{"Accept", "Authorization", "Content-Type"}, + MaxAge: 300, + })) + // Inject request ID into context r.Use(middleware.RequestID) - r.Use(middleware.RealIP) + // Add logger to middleware stack r.Use(middleware.Logger) + // Add recoverer to middleware stack r.Use(middleware.Recoverer) + // Set base timeout for all requests to 15 seconds r.Use(middleware.Timeout(15 * time.Second)) + // Set content type to JSON r.Use(middleware.SetHeader("Content-Type", "application/json")) // Initialize handlers - postHdl := handlers.NewPostHandler(&databaseMgr) + postHdl := handlers.NewPostHandler(&databaseMgr, &jwtMgr) // Initialize user handlers userHdl := handlers.NewUserHandler(&databaseMgr, &jwtMgr, &mailMgr) @@ -40,7 +52,7 @@ func InitRouter(databaseMgr managers.DatabaseMgr, mailMgr managers.MailMgr, jwtM w.WriteHeader(http.StatusOK) }) - r.Get("/imprint", func(w http.ResponseWriter, r *http.Request) { + r.Get("/api/v1/imprint", func(w http.ResponseWriter, r *http.Request) { imprint := "Impressum\n\nEinen Löwen interessiert es nicht, was Schafe über ihn denken.\n\nDiese Webseite " + "wird im Rahmen eines Universitätsprojektes angeboten von:\nKurs WWI21SEB\nDuale Hochschule " + "Baden-Württemberg Mannheim\nCoblitzallee 1 – 9, 68163 Mannheim\n\nKontakt:\nE-Mail: " + @@ -64,6 +76,11 @@ func InitRouter(databaseMgr managers.DatabaseMgr, mailMgr managers.MailMgr, jwtM // Initialize user routes r.Route("/api/v1/users", userRouter(&databaseMgr, &jwtMgr, &mailMgr)) + // Initialize feed routes + r.Route("/api/v1/feed", func(r chi.Router) { + r.Get("/", postHdl.HandleGetFeedRequest) + }) + // Initialize post routes r.Route("/api/v1/posts", func(r chi.Router) { r.Use(jwtMgr.JWTMiddleware) @@ -88,6 +105,7 @@ func userRouter(databaseMgr *managers.DatabaseMgr, jwtMgr *managers.JWTMgr, mail r.Put("/", userHdl.ChangeTrivialInformation) r.Patch("/", userHdl.ChangePassword) r.Post("/login", userHdl.LoginUser) + r.Post("/refresh", userHdl.RefreshToken) r.Post("/{username}/activate", userHdl.ActivateUser) r.Delete("/{username}/activate", userHdl.ResendToken) } diff --git a/internal/routing/router_test.go b/internal/routing/router_test.go index a5a86b1..5014b06 100644 --- a/internal/routing/router_test.go +++ b/internal/routing/router_test.go @@ -206,6 +206,7 @@ func TestUserLogin(t *testing.T) { poolMock.ExpectQuery("SELECT activated_at").WithArgs(tc.user.Username).WillReturnRows(pgxmock.NewRows([]string{"activated_at"}).AddRow("2006-01-02 15:04:05.999999999Z")) poolMock.ExpectQuery("SELECT password, user_id").WithArgs(tc.user.Username).WillReturnRows(pgxmock.NewRows([]string{"password", "user_id"}).AddRow(tc.user.HashedPassword, tc.user.UserId)) poolMock.ExpectQuery("SELECT email, user_id").WithArgs(tc.user.Username).WillReturnRows(pgxmock.NewRows([]string{"email", "user_id"}).AddRow(tc.user.Email, tc.user.UserId)) + poolMock.ExpectCommit() // Assert that the response status code is 200 and the response body contains the expected values expect := httpexpect.Default(t, server.URL) diff --git a/internal/schemas/dto_schema.go b/internal/schemas/dto_schema.go index 5cae510..af51e93 100644 --- a/internal/schemas/dto_schema.go +++ b/internal/schemas/dto_schema.go @@ -32,10 +32,12 @@ type UserNicknameAndStatusDTO struct { Status string `json:"status"` } -// TokenDTO is a struct that represents a token response -// Token is the JWT token -type TokenDTO struct { - Token string `json:"token"` +// TokenPairDTO is a struct that represents a token response +// Token is the main JWT token used for auth +// RefreshToken is the refresh token used to get a new token +type TokenPairDTO struct { + Token string `json:"token"` + RefreshToken string `json:"refreshToken"` } // AuthorDTO is a struct that represents an author response @@ -45,7 +47,7 @@ type TokenDTO struct { type AuthorDTO struct { Username string `json:"username"` Nickname string `json:"nickname"` - ProfilePictureURL string `json:"profile_picture_url"` + ProfilePictureURL string `json:"profilePictureURL"` } // PostDTO is a struct that represents a post response @@ -54,10 +56,10 @@ type AuthorDTO struct { // Content is the content of the post // CreatedAt is the timestamp of when the post was created type PostDTO struct { - PostId string `json:"post_id"` - Author AuthorDTO `json:"author"` - Content string `json:"content"` - CreatedAt string `json:"created_at"` + PostId string `json:"postId"` + Author AuthorDTO `json:"author"` + CreationDate string `json:"creationDate"` + Content string `json:"content"` } /** **/ @@ -120,7 +122,7 @@ type SubscriptionDTO struct { type PaginatedResponse struct { Records interface{} `json:"records"` - Pagination Pagination `json:"pagination"` + Pagination interface{} `json:"pagination"` } type Pagination struct { @@ -129,6 +131,12 @@ type Pagination struct { Records int `json:"records"` } +type PostPagination struct { + LastPostId string `json:"lastPostId"` + Limit string `json:"limit"` + Records int `json:"records"` +} + // ChangeTrivialInformationRequest is a struct that represents a NicknameChange request // NewNickname is required and must be less than 25 characters // Status is required and must be less than 256 characters @@ -144,3 +152,9 @@ type ChangePasswordRequest struct { OldPassword string `json:"oldPassword" validate:"required,min=8,password_validation"` NewPassword string `json:"newPassword" validate:"required,min=8,password_validation"` } + +// RefreshTokenRequest is a struct that represents a RefreshToken request +// RefreshToken is required and must be a valid refresh token +type RefreshTokenRequest struct { + RefreshToken string `json:"refreshToken" validate:"required"` +} diff --git a/internal/utils/url_keys.go b/internal/utils/url_keys.go index 7e7e121..b387751 100644 --- a/internal/utils/url_keys.go +++ b/internal/utils/url_keys.go @@ -6,4 +6,6 @@ const ( UsernameParamKey = "username" OffsetParamKey = "offset" LimitParamKey = "limit" + PostIdParamKey = "postId" + FeedTypeParamKey = "feedType" )