Skip to content

Commit

Permalink
Fix hanging DB connection issue
Browse files Browse the repository at this point in the history
  • Loading branch information
DavidArthurCole committed Jul 23, 2024
1 parent 1eefedb commit ac0776f
Show file tree
Hide file tree
Showing 5 changed files with 148 additions and 114 deletions.
6 changes: 3 additions & 3 deletions data.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func fetchFirstContactWithContext(ctx context.Context, playerId string) (*ei.Egg
}
timestamp := fc.GetBackup().GetSettings().GetLastBackupTime()
if timestamp != 0 {
if err := db.InsertBackup(playerId, timestamp, payload, 12*time.Hour); err != nil {
if err := db.InsertBackup(ctx, playerId, timestamp, payload, 12*time.Hour); err != nil {
// Treat as non-fatal error for now.
log.Error(err)
}
Expand All @@ -56,7 +56,7 @@ func fetchCompleteMissionWithContext(ctx context.Context, playerId string, missi
wrap := func(err error) error {
return errors.Wrap(err, "error "+action)
}
resp, err := db.RetrieveCompleteMission(playerId, missionId)
resp, err := db.RetrieveCompleteMission(ctx, playerId, missionId)
if err != nil {
return nil, wrap(err)
}
Expand All @@ -77,6 +77,6 @@ func fetchCompleteMissionWithContext(ctx context.Context, playerId string, missi
if len(resp.GetArtifacts()) == 0 {
return nil, wrap(errors.New("no artifact found in server response"))
}
err = db.InsertCompleteMission(playerId, missionId, startTimestamp, payload)
err = db.InsertCompleteMission(ctx, playerId, missionId, startTimestamp, payload)
return resp, err
}
200 changes: 100 additions & 100 deletions db/crud.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package db

import (
"context"
"database/sql"
"fmt"
"time"
Expand All @@ -12,78 +13,79 @@ import (
"github.com/DavidArthurCole/EggLedger/ei"
)

func InsertBackup(playerId string, timestamp float64, payload []byte, minimumTimeSinceLastEntry time.Duration) error {
func InsertBackup(ctx context.Context, playerId string, timestamp float64, payload []byte, minimumTimeSinceLastEntry time.Duration) error {
action := fmt.Sprintf("insert backup for player %s into database", playerId)
compressedPayload, err := compress(payload)
if err != nil {
return errors.Wrap(err, action)
}
return transact(action, func(tx *sql.Tx) error {
var previousTimestamp float64
if minimumTimeSinceLastEntry.Seconds() > 0 {
row := tx.QueryRow(`SELECT backed_up_at FROM backup
WHERE player_id = ?
ORDER BY backed_up_at DESC LIMIT 1;`,
playerId)
err := row.Scan(&previousTimestamp)
switch {
case err == sql.ErrNoRows:
// No stored backup
case err != nil:
return DoDBOperation(ctx, func(ctx context.Context, db *sql.DB) error {
return transact(ctx, action, func(tx *sql.Tx) error {
var previousTimestamp float64
if minimumTimeSinceLastEntry.Seconds() > 0 {
row := tx.QueryRowContext(ctx, `SELECT backed_up_at FROM backup
WHERE player_id = ?
ORDER BY backed_up_at DESC LIMIT 1;`, playerId)
err := row.Scan(&previousTimestamp)
switch {
case err == sql.ErrNoRows:
// No stored backup
case err != nil:
return err
}
}
timeSinceLastEntry := time.Duration(timestamp-previousTimestamp) * time.Second
if timeSinceLastEntry < minimumTimeSinceLastEntry {
log.Infof("%s: %s since last recorded backup, ignoring", playerId, timeSinceLastEntry)
return nil
}
_, err = tx.ExecContext(ctx, `INSERT INTO
backup(player_id, backed_up_at, payload, payload_authenticated)
VALUES (?, ?, ?, FALSE);`, playerId, timestamp, compressedPayload)
if err != nil {
return err
}
}
timeSinceLastEntry := time.Duration(timestamp-previousTimestamp) * time.Second
if timeSinceLastEntry < minimumTimeSinceLastEntry {
log.Infof("%s: %s since last recorded backup, ignoring", playerId, timeSinceLastEntry)
return nil
}
_, err = tx.Exec(`INSERT INTO
backup(player_id, backed_up_at, payload, payload_authenticated)
VALUES (?, ?, ?, FALSE);`,
playerId, timestamp, compressedPayload)
if err != nil {
return err
}
return nil
})
})
}

func InsertCompleteMission(playerId string, missionId string, startTimestamp float64, completePayload []byte) error {
func InsertCompleteMission(ctx context.Context, playerId string, missionId string, startTimestamp float64, completePayload []byte) error {
action := fmt.Sprintf("insert mission %s for player %s into database", missionId, playerId)
compressedPayload, err := compress(completePayload)
if err != nil {
return errors.Wrap(err, action)
}
return transact(action, func(tx *sql.Tx) error {
_, err := tx.Exec(`INSERT INTO
mission(player_id, mission_id, start_timestamp, complete_payload)
VALUES (?, ?, ?, ?);`,
playerId, missionId, startTimestamp, compressedPayload)
if err != nil {
return err
}
return nil
return DoDBOperation(ctx, func(ctx context.Context, db *sql.DB) error {
return transact(ctx, action, func(tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, `INSERT INTO
mission(player_id, mission_id, start_timestamp, complete_payload)
VALUES (?, ?, ?, ?);`, playerId, missionId, startTimestamp, compressedPayload)
if err != nil {
return err
}
return nil
})
})
}

// RetrieveCompleteMission returns the stored CompleteMissionResponse, or nil if not found.
func RetrieveCompleteMission(playerId string, missionId string) (*ei.CompleteMissionResponse, error) {
func RetrieveCompleteMission(ctx context.Context, playerId string, missionId string) (*ei.CompleteMissionResponse, error) {
action := fmt.Sprintf("retrieve mission %s for player %s from database", missionId, playerId)
var startTimestamp float64
var compressedPayload []byte
err := transact(action, func(tx *sql.Tx) error {
row := tx.QueryRow(`SELECT start_timestamp, complete_payload FROM mission
WHERE player_id = ? AND mission_id = ?;`,
playerId, missionId)
err := row.Scan(&startTimestamp, &compressedPayload)
switch {
case err == sql.ErrNoRows:
// No such mission
case err != nil:
return err
}
return nil
err := DoDBOperation(ctx, func(ctx context.Context, db *sql.DB) error {
return transact(ctx, action, func(tx *sql.Tx) error {
row := tx.QueryRowContext(ctx, `SELECT start_timestamp, complete_payload FROM mission
WHERE player_id = ? AND mission_id = ?;`, playerId, missionId)
err := row.Scan(&startTimestamp, &compressedPayload)
switch {
case err == sql.ErrNoRows:
// No such mission
case err != nil:
return err
}
return nil
})
})
if err != nil {
return nil, err
Expand All @@ -99,44 +101,42 @@ func RetrieveCompleteMission(playerId string, missionId string) (*ei.CompleteMis
if err != nil {
return nil, errors.Wrap(err, action)
}
// /ei_afx/complete_mission response leaves out start_time_derived, so we
// have to manually attach it.
m.Info.StartTimeDerived = &startTimestamp
return m, nil
}

// RetrievePlayerCompleteMissions retrieves stored completed missions for a
// player, in chronological order.
func RetrievePlayerCompleteMissions(playerId string) ([]*ei.CompleteMissionResponse, error) {
func RetrievePlayerCompleteMissions(ctx context.Context, playerId string) ([]*ei.CompleteMissionResponse, error) {
action := fmt.Sprintf("retrieve complete missions for player %s from database", playerId)
var count int
var startTimestamps []float64
var compressedPayloads [][]byte
trerr := transact(action, func(tx *sql.Tx) error {
rows, querr := tx.Query(`SELECT start_timestamp, complete_payload FROM mission
WHERE player_id = ?
ORDER BY start_timestamp;`, playerId)
if querr != nil {
return querr
}
defer rows.Close()
for rows.Next() {
var startTimestamp float64
var compressedPayload []byte
if scerr := rows.Scan(&startTimestamp, &compressedPayload); scerr != nil {
return scerr
err := DoDBOperation(ctx, func(ctx context.Context, db *sql.DB) error {
return transact(ctx, action, func(tx *sql.Tx) error {
rows, querr := tx.QueryContext(ctx, `SELECT start_timestamp, complete_payload FROM mission
WHERE player_id = ?
ORDER BY start_timestamp;`, playerId)
if querr != nil {
return querr
}
count++
startTimestamps = append(startTimestamps, startTimestamp)
compressedPayloads = append(compressedPayloads, compressedPayload)
}
if rerr := rows.Err(); rerr != nil {
return rerr
}
return nil
defer rows.Close()
for rows.Next() {
var startTimestamp float64
var compressedPayload []byte
if scerr := rows.Scan(&startTimestamp, &compressedPayload); scerr != nil {
return scerr
}
count++
startTimestamps = append(startTimestamps, startTimestamp)
compressedPayloads = append(compressedPayloads, compressedPayload)
}
if rerr := rows.Err(); rerr != nil {
return rerr
}
return nil
})
})
if trerr != nil {
return nil, trerr
if err != nil {
return nil, err
}
var missions []*ei.CompleteMissionResponse
for i := 0; i < count; i++ {
Expand All @@ -154,39 +154,39 @@ func RetrievePlayerCompleteMissions(playerId string) ([]*ei.CompleteMissionRespo
return missions, nil
}

// RetrievePlayerCompleteMissionIds retrieves IDs of stored completed missions
// for a player, in chronological order.
func RetrievePlayerCompleteMissionIds(playerId string) ([]string, error) {
func RetrievePlayerCompleteMissionIds(ctx context.Context, playerId string) ([]string, error) {
action := fmt.Sprintf("retrieve complete mission ids for player %s from database", playerId)
var missionIds []string
err := transact(action, func(tx *sql.Tx) error {
rows, err := tx.Query(`SELECT mission_id FROM mission
WHERE player_id = ?
ORDER BY start_timestamp;`, playerId)
if err != nil {
return err
}
defer rows.Close()
for rows.Next() {
var missionId string
if err := rows.Scan(&missionId); err != nil {
err := DoDBOperation(ctx, func(ctx context.Context, db *sql.DB) error {
return transact(ctx, action, func(tx *sql.Tx) error {
rows, err := tx.QueryContext(ctx, `SELECT mission_id FROM mission
WHERE player_id = ?
ORDER BY start_timestamp;`, playerId)
if err != nil {
return err
}
missionIds = append(missionIds, missionId)
}
if err := rows.Err(); err != nil {
return err
}
return nil
defer rows.Close()
for rows.Next() {
var missionId string
if err := rows.Scan(&missionId); err != nil {
return err
}
missionIds = append(missionIds, missionId)
}
if err := rows.Err(); err != nil {
return err
}
return nil
})
})
if err != nil {
return nil, err
}
return missionIds, nil
}

func transact(description string, txFunc func(*sql.Tx) error) (err error) {
tx, err := _db.Begin()
func transact(ctx context.Context, description string, txFunc func(*sql.Tx) error) (err error) {
tx, err := _db.BeginTx(ctx, nil)
if err != nil {
return errors.Wrap(err, description)
}
Expand Down
28 changes: 28 additions & 0 deletions db/init.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package db

import (
"context"
"database/sql"
"os"
"path/filepath"
Expand All @@ -14,6 +15,9 @@ import (
var (
_db *sql.DB
_initDBOnce sync.Once
_dbWg sync.WaitGroup
_dbCtx context.Context
_dbCancel context.CancelFunc
)

func InitDB(path string) error {
Expand All @@ -39,7 +43,31 @@ func InitDB(path string) error {
err = errors.Wrapf(err, "failed to open SQLite3 database %#v", path)
return
}

_dbCtx, _dbCancel = context.WithCancel(context.Background())
err = nil
})
return err
}

func CloseDB() error {
if _db != nil {
_dbCancel() // cancel all ongoing operations
_dbWg.Wait() // wait for all operations to complete
return _db.Close()
}
return nil
}

// A function to perform a database operation with context and wait group
func DoDBOperation(ctx context.Context, operation func(ctx context.Context, db *sql.DB) error) error {
_dbWg.Add(1)
defer _dbWg.Done()

// Create a child context that will be canceled if the parent context is canceled
childCtx, cancel := context.WithCancel(_dbCtx)
defer cancel()

// Run the operation
return operation(childCtx, _db)
}
Loading

0 comments on commit ac0776f

Please sign in to comment.