Skip to content
This repository has been archived by the owner on Jun 2, 2024. It is now read-only.

Commit

Permalink
fix remaining bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
Frecherenkel60 committed Jan 16, 2024
1 parent b165e9c commit 35c1497
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 21 deletions.
43 changes: 25 additions & 18 deletions internal/routing/handlers/post_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ package handlers

import (
"context"
"fmt"
"github.com/jackc/pgx/v5"
"net/http"
"regexp"
"server-alpha/internal/managers"
"server-alpha/internal/schemas"
"server-alpha/internal/utils"
"strconv"
"time"

"github.com/golang-jwt/jwt/v5"
Expand Down Expand Up @@ -160,7 +160,7 @@ func (handler *PostHandler) HandleGetFeedRequest(w http.ResponseWriter, r *http.
Pagination: pagination,
}

if err := utils.CommitTransaction(w, nil, transactionCtx, nil); err != nil {
if err := utils.CommitTransaction(w, tx, transactionCtx, cancel); err != nil {
utils.WriteAndLogError(w, schemas.DatabaseError, http.StatusInternalServerError, err)
}

Expand Down Expand Up @@ -193,26 +193,34 @@ func determineFeedType(r *http.Request, w http.ResponseWriter, handler *PostHand
return publicFeedWanted, claims, nil
}

func retrieveFeed(ctx context.Context, tx pgx.Tx, w http.ResponseWriter, r *http.Request, private bool, claims jwt.Claims) ([]*schemas.PostDTO, int, string, string, error) {
func retrieveFeed(ctx context.Context, tx pgx.Tx, w http.ResponseWriter, r *http.Request, publicFeedWanted bool, claims jwt.Claims) ([]*schemas.PostDTO, int, string, string, error) {
lastPostId := r.URL.Query().Get(utils.PostIdParamKey)
limit := r.URL.Query().Get(utils.LimitParamKey)

var userId interface{}
if limit == "" {
limit = "10"
}

currentDataQueryIndex := 1
currentCountQueryIndex := 1
var userId string
var countQuery string
var dataQuery string
var countQueryArgs []interface{}
var dataQueryArgs []interface{}

if private {
userId = claims.(jwt.MapClaims)["sub"]
if !publicFeedWanted {
userId = claims.(jwt.MapClaims)["sub"].(string)
countQuery = `SELECT COUNT(*) FROM alpha_schema.posts
INNER JOIN alpha_schema.users ON author_id = user_id
INNER JOIN alpha_schema.subscriptions ON user_id = subscriptions.subscriber_id
WHERE subscriptions.subscriber_id = $1`
WHERE subscriptions.subscriber_id` + fmt.Sprintf(" = $%d", currentCountQueryIndex)
currentCountQueryIndex++
dataQuery = `SELECT post_id, username, nickname, profile_picture_url, content, posts.created_at FROM alpha_schema.posts
INNER JOIN alpha_schema.users ON author_id = user_id
INNER JOIN alpha_schema.subscriptions ON user_id = subscriptions.subscriber_id
WHERE subscriptions.subscriber_id = $1`
WHERE subscriptions.subscriber_id` + fmt.Sprintf(" = $%d", currentDataQueryIndex)
currentDataQueryIndex++

countQueryArgs = append(countQueryArgs, userId)
dataQueryArgs = append(dataQueryArgs, userId)
Expand All @@ -236,11 +244,13 @@ func retrieveFeed(ctx context.Context, tx pgx.Tx, w http.ResponseWriter, r *http

if lastPostId == "" {
// If we don't have a last post ID, we'll return the newest posts
dataQuery += " ORDER BY created_at DESC LIMIT $2"
dataQuery += " ORDER BY created_at DESC LIMIT" + fmt.Sprintf(" $%d", currentDataQueryIndex)
dataQueryArgs = append(dataQueryArgs, limit)
} else {
// If we have a last post ID, we'll return the posts that were created before the last post
dataQuery += " AND posts.created_at < (SELECT created_at FROM alpha_schema.posts WHERE post_id = $2) ORDER BY created_at DESC LIMIT $3"
dataQuery += " AND posts.created_at < (SELECT created_at FROM alpha_schema.posts WHERE post_id = " +
fmt.Sprintf("$%d) ", currentDataQueryIndex) + "ORDER BY created_at DESC LIMIT " +
fmt.Sprintf("$%d", currentDataQueryIndex+1)
dataQueryArgs = append(dataQueryArgs, lastPostId, limit)
}

Expand All @@ -251,21 +261,18 @@ func retrieveFeed(ctx context.Context, tx pgx.Tx, w http.ResponseWriter, r *http
}

// Iterate over the rows and create the post DTOs
intLimit, err := strconv.Atoi(limit)
if err != nil {
utils.WriteAndLogError(w, schemas.BadRequest, http.StatusBadRequest, err)
return nil, 0, "", "", err
}

posts := make([]*schemas.PostDTO, intLimit)
posts := make([]*schemas.PostDTO, 0)

for rows.Next() {
post := &schemas.PostDTO{}
if err := rows.Scan(&post.PostId, &post.Author.Username, &post.Author.Nickname, &post.Author.ProfilePictureURL, &post.Content, &post.CreationDate); err != nil {
var createdAt time.Time

if err := rows.Scan(&post.PostId, &post.Author.Username, &post.Author.Nickname, &post.Author.ProfilePictureURL, &post.Content, &createdAt); err != nil {
utils.WriteAndLogError(w, schemas.DatabaseError, http.StatusInternalServerError, err)
return nil, 0, "", "", err
}

post.CreationDate = createdAt.Format(time.RFC3339)
posts = append(posts, post)
}

Expand Down
8 changes: 5 additions & 3 deletions internal/routing/handlers/user_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ func NewUserHandler(databaseManager *managers.DatabaseMgr, jwtManager *managers.
}
}

var errInvalidToken = errors.New("invalid token")

// RegisterUser registers a new user and sends an activation token to the user's email.
func (handler *UserHandler) RegisterUser(w http.ResponseWriter, r *http.Request) {
// Begin a new transaction
Expand Down Expand Up @@ -710,7 +712,7 @@ func (handler *UserHandler) RefreshToken(w http.ResponseWriter, r *http.Request)
isRefreshToken := refreshClaims["refresh"].(string)

if isRefreshToken != "true" {
utils.WriteAndLogError(w, schemas.Unauthorized, http.StatusUnauthorized, errors.New("invalid token"))
utils.WriteAndLogError(w, schemas.Unauthorized, http.StatusUnauthorized, errInvalidToken)
return
}

Expand Down Expand Up @@ -796,8 +798,8 @@ func checkTokenValidity(ctx context.Context, w http.ResponseWriter, tx pgx.Tx, t
defer rows.Close()

if !rows.Next() {
utils.WriteAndLogError(w, schemas.InvalidToken, http.StatusUnauthorized, errors.New("invalid token"))
return errors.New("invalid token")
utils.WriteAndLogError(w, schemas.InvalidToken, http.StatusUnauthorized, errInvalidToken)
return errInvalidToken
}

var expiresAt pgtype.Timestamptz
Expand Down

0 comments on commit 35c1497

Please sign in to comment.