From 35c14975d035fc3cd572e265af6d7fe7e9c612d2 Mon Sep 17 00:00:00 2001 From: Kevin Rieger Date: Tue, 16 Jan 2024 21:02:18 +0100 Subject: [PATCH] fix remaining bugs --- internal/routing/handlers/post_handler.go | 43 +++++++++++++---------- internal/routing/handlers/user_handler.go | 8 +++-- 2 files changed, 30 insertions(+), 21 deletions(-) diff --git a/internal/routing/handlers/post_handler.go b/internal/routing/handlers/post_handler.go index 5bb3653..5fe01e6 100644 --- a/internal/routing/handlers/post_handler.go +++ b/internal/routing/handlers/post_handler.go @@ -2,13 +2,13 @@ package handlers import ( "context" + "fmt" "github.com/jackc/pgx/v5" "net/http" "regexp" "server-alpha/internal/managers" "server-alpha/internal/schemas" "server-alpha/internal/utils" - "strconv" "time" "github.com/golang-jwt/jwt/v5" @@ -160,7 +160,7 @@ func (handler *PostHandler) HandleGetFeedRequest(w http.ResponseWriter, r *http. Pagination: pagination, } - if err := utils.CommitTransaction(w, nil, transactionCtx, nil); err != nil { + if err := utils.CommitTransaction(w, tx, transactionCtx, cancel); err != nil { utils.WriteAndLogError(w, schemas.DatabaseError, http.StatusInternalServerError, err) } @@ -193,26 +193,34 @@ func determineFeedType(r *http.Request, w http.ResponseWriter, handler *PostHand return publicFeedWanted, claims, nil } -func retrieveFeed(ctx context.Context, tx pgx.Tx, w http.ResponseWriter, r *http.Request, private bool, claims jwt.Claims) ([]*schemas.PostDTO, int, string, string, error) { +func retrieveFeed(ctx context.Context, tx pgx.Tx, w http.ResponseWriter, r *http.Request, publicFeedWanted bool, claims jwt.Claims) ([]*schemas.PostDTO, int, string, string, error) { lastPostId := r.URL.Query().Get(utils.PostIdParamKey) limit := r.URL.Query().Get(utils.LimitParamKey) - var userId interface{} + if limit == "" { + limit = "10" + } + + currentDataQueryIndex := 1 + currentCountQueryIndex := 1 + var userId string var countQuery string var dataQuery string var countQueryArgs []interface{} var dataQueryArgs []interface{} - if private { - userId = claims.(jwt.MapClaims)["sub"] + if !publicFeedWanted { + userId = claims.(jwt.MapClaims)["sub"].(string) countQuery = `SELECT COUNT(*) FROM alpha_schema.posts INNER JOIN alpha_schema.users ON author_id = user_id INNER JOIN alpha_schema.subscriptions ON user_id = subscriptions.subscriber_id - WHERE subscriptions.subscriber_id = $1` + WHERE subscriptions.subscriber_id` + fmt.Sprintf(" = $%d", currentCountQueryIndex) + currentCountQueryIndex++ dataQuery = `SELECT post_id, username, nickname, profile_picture_url, content, posts.created_at FROM alpha_schema.posts INNER JOIN alpha_schema.users ON author_id = user_id INNER JOIN alpha_schema.subscriptions ON user_id = subscriptions.subscriber_id - WHERE subscriptions.subscriber_id = $1` + WHERE subscriptions.subscriber_id` + fmt.Sprintf(" = $%d", currentDataQueryIndex) + currentDataQueryIndex++ countQueryArgs = append(countQueryArgs, userId) dataQueryArgs = append(dataQueryArgs, userId) @@ -236,11 +244,13 @@ func retrieveFeed(ctx context.Context, tx pgx.Tx, w http.ResponseWriter, r *http if lastPostId == "" { // If we don't have a last post ID, we'll return the newest posts - dataQuery += " ORDER BY created_at DESC LIMIT $2" + dataQuery += " ORDER BY created_at DESC LIMIT" + fmt.Sprintf(" $%d", currentDataQueryIndex) dataQueryArgs = append(dataQueryArgs, limit) } else { // If we have a last post ID, we'll return the posts that were created before the last post - dataQuery += " AND posts.created_at < (SELECT created_at FROM alpha_schema.posts WHERE post_id = $2) ORDER BY created_at DESC LIMIT $3" + dataQuery += " AND posts.created_at < (SELECT created_at FROM alpha_schema.posts WHERE post_id = " + + fmt.Sprintf("$%d) ", currentDataQueryIndex) + "ORDER BY created_at DESC LIMIT " + + fmt.Sprintf("$%d", currentDataQueryIndex+1) dataQueryArgs = append(dataQueryArgs, lastPostId, limit) } @@ -251,21 +261,18 @@ func retrieveFeed(ctx context.Context, tx pgx.Tx, w http.ResponseWriter, r *http } // Iterate over the rows and create the post DTOs - intLimit, err := strconv.Atoi(limit) - if err != nil { - utils.WriteAndLogError(w, schemas.BadRequest, http.StatusBadRequest, err) - return nil, 0, "", "", err - } - - posts := make([]*schemas.PostDTO, intLimit) + posts := make([]*schemas.PostDTO, 0) for rows.Next() { post := &schemas.PostDTO{} - if err := rows.Scan(&post.PostId, &post.Author.Username, &post.Author.Nickname, &post.Author.ProfilePictureURL, &post.Content, &post.CreationDate); err != nil { + var createdAt time.Time + + if err := rows.Scan(&post.PostId, &post.Author.Username, &post.Author.Nickname, &post.Author.ProfilePictureURL, &post.Content, &createdAt); err != nil { utils.WriteAndLogError(w, schemas.DatabaseError, http.StatusInternalServerError, err) return nil, 0, "", "", err } + post.CreationDate = createdAt.Format(time.RFC3339) posts = append(posts, post) } diff --git a/internal/routing/handlers/user_handler.go b/internal/routing/handlers/user_handler.go index f18c5f1..1e4e8fb 100644 --- a/internal/routing/handlers/user_handler.go +++ b/internal/routing/handlers/user_handler.go @@ -52,6 +52,8 @@ func NewUserHandler(databaseManager *managers.DatabaseMgr, jwtManager *managers. } } +var errInvalidToken = errors.New("invalid token") + // RegisterUser registers a new user and sends an activation token to the user's email. func (handler *UserHandler) RegisterUser(w http.ResponseWriter, r *http.Request) { // Begin a new transaction @@ -710,7 +712,7 @@ func (handler *UserHandler) RefreshToken(w http.ResponseWriter, r *http.Request) isRefreshToken := refreshClaims["refresh"].(string) if isRefreshToken != "true" { - utils.WriteAndLogError(w, schemas.Unauthorized, http.StatusUnauthorized, errors.New("invalid token")) + utils.WriteAndLogError(w, schemas.Unauthorized, http.StatusUnauthorized, errInvalidToken) return } @@ -796,8 +798,8 @@ func checkTokenValidity(ctx context.Context, w http.ResponseWriter, tx pgx.Tx, t defer rows.Close() if !rows.Next() { - utils.WriteAndLogError(w, schemas.InvalidToken, http.StatusUnauthorized, errors.New("invalid token")) - return errors.New("invalid token") + utils.WriteAndLogError(w, schemas.InvalidToken, http.StatusUnauthorized, errInvalidToken) + return errInvalidToken } var expiresAt pgtype.Timestamptz