Skip to content

Commit

Permalink
feat: Add AndCond/OrCond extensions condition method support. #72 fea…
Browse files Browse the repository at this point in the history
…t: Add streaming paginate support. #71 (#73)
  • Loading branch information
PhoenixL0911 authored Sep 14, 2023
1 parent eb9dcc4 commit 0d39664
Show file tree
Hide file tree
Showing 2 changed files with 299 additions and 1 deletion.
92 changes: 91 additions & 1 deletion gplus/dao.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@ package gplus

import (
"database/sql"
"fmt"
"github.com/acmestack/gorm-plus/constants"
"gorm.io/gorm"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
"reflect"
"strings"
"time"
)

var globalDb *gorm.DB
Expand Down Expand Up @@ -51,6 +53,29 @@ func NewPage[T any](current, size int) *Page[T] {
return &Page[T]{Current: current, Size: size}
}

type Comparable interface {
~int | ~int8 | ~int16 | ~int32 | ~int64 | ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~float32 | ~float64 | time.Time
}

type StreamingPage[T any, V Comparable] struct {
ColumnName any `json:"columnName"` // 进行分页的列字段名称
StartValue V `json:"startValue"` // 分页起始值
Limit int `json:"limit"` // 页大小
Forward bool `json:"forward"` // 上下页翻页标识
Total int64 `json:"total"` // 总记录数
Records []*T `json:"records"` // 查询记录
RecordsMap []T `json:"recordsMap"` // 查询记录Map
}

func NewStreamingPage[T any, V Comparable](columnName any, startValue V, limit int) *StreamingPage[T, V] {
return &StreamingPage[T, V]{
ColumnName: columnName,
StartValue: startValue,
Limit: limit,
Forward: true,
}
}

// Insert 插入一条记录
func Insert[T any](entity *T, opts ...OptionFunc) *gorm.DB {
db := getDb(opts...)
Expand Down Expand Up @@ -193,6 +218,26 @@ func SelectPage[T any](page *Page[T], q *QueryCond[T], opts ...OptionFunc) (*Pag
return page, resultDb
}

// SelectStreamingPage 根据条件分页查询记录
func SelectStreamingPage[T any, V Comparable](page *StreamingPage[T, V], q *QueryCond[T], opts ...OptionFunc) (*StreamingPage[T, V], *gorm.DB) {
option := getOption(opts)

// 如果需要分页忽略总数,不查询总数
if !option.IgnoreTotal {
total, countDb := SelectCount[T](q, opts...)
if countDb.Error != nil {
return page, countDb
}
page.Total = total
}

resultDb := buildCondition(q, opts...)
var results []*T
resultDb.Scopes(streamingPaginate(page)).Find(&results)
page.Records = results
return page, resultDb
}

// SelectCount 根据条件查询记录数量
func SelectCount[T any](q *QueryCond[T], opts ...OptionFunc) (int64, *gorm.DB) {
var count int64
Expand Down Expand Up @@ -237,6 +282,34 @@ func SelectPageGeneric[T any, R any](page *Page[R], q *QueryCond[T], opts ...Opt
return page, resultDb
}

// SelectStreamingPageGeneric 根据传入的泛型封装分页记录
// 第一个泛型代表数据库表实体
// 第二个泛型代表返回记录实体
func SelectStreamingPageGeneric[T any, R any, V Comparable](page *StreamingPage[R, V], q *QueryCond[T], opts ...OptionFunc) (*StreamingPage[R, V], *gorm.DB) {
option := getOption(opts)
// 如果需要分页忽略总数,不查询总数
if !option.IgnoreTotal {
total, countDb := SelectCount[T](q, opts...)
if countDb.Error != nil {
return page, countDb
}
page.Total = total
}
resultDb := buildCondition(q, opts...)
var r R
switch any(r).(type) {
case map[string]any:
var results []R
resultDb.Scopes(streamingPaginate(page)).Scan(&results)
page.RecordsMap = results
default:
var results []*R
resultDb.Scopes(streamingPaginate(page)).Scan(&results)
page.Records = results
}
return page, resultDb
}

// SelectGeneric 根据传入的泛型封装记录
// 第一个泛型代表数据库表实体
// 第二个泛型代表返回记录实体
Expand All @@ -251,12 +324,13 @@ func Begin(opts ...*sql.TxOptions) *gorm.DB {
return db.Begin(opts...)
}

// 事务
// Tx 事务
func Tx(txFunc func(tx *gorm.DB) error, opts ...OptionFunc) error {
db := getDb(opts...)
return db.Transaction(txFunc)
}

// paginate offset分页
func paginate[T any](p *Page[T]) func(db *gorm.DB) *gorm.DB {
page := p.Current
pageSize := p.Size
Expand All @@ -272,6 +346,22 @@ func paginate[T any](p *Page[T]) func(db *gorm.DB) *gorm.DB {
}
}

// streamingPaginate 流式分页,根据自增ID、雪花ID、时间等数值类型或者时间类型分页
// Tips: 相比于 offset 分页性能更好,走的是 range,缺点是没办法跳页查询
func streamingPaginate[T any, V Comparable](p *StreamingPage[T, V]) func(db *gorm.DB) *gorm.DB {
column := getColumnName(p.ColumnName)
startValue := p.StartValue
limit := p.Limit
return func(db *gorm.DB) *gorm.DB {
// 下一页
if p.Forward {
return db.Where(fmt.Sprintf("%v > ?", column), startValue).Limit(limit)
}
// 上一页
return db.Where(fmt.Sprintf("%v < ?", column), startValue).Order(fmt.Sprintf("%v DESC", column)).Limit(limit)
}
}

func buildCondition[T any](q *QueryCond[T], opts ...OptionFunc) *gorm.DB {
db := getDb(opts...)
resultDb := db.Model(new(T))
Expand Down
208 changes: 208 additions & 0 deletions gplus/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,214 @@ func (q *QueryCond[T]) InCond(cond bool, column any, val any) *QueryCond[T] {
return q
}

// AndEqCond 并且等于 =
func (q *QueryCond[T]) AndEqCond(cond bool, column any, val any) *QueryCond[T] {
if cond {
return q.And().Eq(column, val)
}
return q
}

// AndNeCond 并且不等于 !=
func (q *QueryCond[T]) AndNeCond(cond bool, column any, val any) *QueryCond[T] {
if cond {
return q.And().Ne(column, val)
}
return q
}

// AndGtCond 并且大于 >
func (q *QueryCond[T]) AndGtCond(cond bool, column any, val any) *QueryCond[T] {
if cond {
return q.And().Gt(column, val)
}
return q
}

// AndGeCond 并且大于等于 >=
func (q *QueryCond[T]) AndGeCond(cond bool, column any, val any) *QueryCond[T] {
if cond {
return q.And().Ge(column, val)
}
return q
}

// AndLtCond 并且小于 <
func (q *QueryCond[T]) AndLtCond(cond bool, column any, val any) *QueryCond[T] {
if cond {
return q.And().Lt(column, val)
}
return q
}

// AndLeCond 并且小于等于 <=
func (q *QueryCond[T]) AndLeCond(cond bool, column any, val any) *QueryCond[T] {
if cond {
return q.And().Le(column, val)
}
return q
}

// AndLikeCond 并且模糊 LIKE '%值%'
func (q *QueryCond[T]) AndLikeCond(cond bool, column any, val any) *QueryCond[T] {
if cond {
return q.And().Like(column, val)
}
return q
}

// AndNotLikeCond 并且非模糊 NOT LIKE '%值%'
func (q *QueryCond[T]) AndNotLikeCond(cond bool, column any, val any) *QueryCond[T] {
if cond {
return q.And().NotLike(column, val)
}
return q
}

// AndLikeLeftCond 并且左模糊 LIKE '%值'
func (q *QueryCond[T]) AndLikeLeftCond(cond bool, column any, val any) *QueryCond[T] {
if cond {
return q.And().LikeLeft(column, val)
}
return q
}

// AndNotLikeLeftCond 并且非左模糊 NOT LIKE '%值'
func (q *QueryCond[T]) AndNotLikeLeftCond(cond bool, column any, val any) *QueryCond[T] {
if cond {
return q.And().NotLikeLeft(column, val)
}
return q
}

// AndLikeRightCond 并且右模糊 LIKE '值%'
func (q *QueryCond[T]) AndLikeRightCond(cond bool, column any, val any) *QueryCond[T] {
if cond {
return q.And().LikeRight(column, val)
}
return q
}

// AndNotLikeRightCond 并且非右模糊 NOT LIKE '值%'
func (q *QueryCond[T]) AndNotLikeRightCond(cond bool, column any, val any) *QueryCond[T] {
if cond {
return q.And().NotLikeRight(column, val)
}
return q
}

// AndInCond 并且字段 IN (值1, 值2, ...)
func (q *QueryCond[T]) AndInCond(cond bool, column any, val any) *QueryCond[T] {
if cond {
return q.And().In(column, val)
}
return q
}

// OrEqCond 或者等于 =
func (q *QueryCond[T]) OrEqCond(cond bool, column any, val any) *QueryCond[T] {
if cond {
return q.Or().Eq(column, val)
}
return q
}

// OrNeCond 或者不等于 !=
func (q *QueryCond[T]) OrNeCond(cond bool, column any, val any) *QueryCond[T] {
if cond {
return q.Or().Ne(column, val)
}
return q
}

// OrGtCond 或者大于 >
func (q *QueryCond[T]) OrGtCond(cond bool, column any, val any) *QueryCond[T] {
if cond {
return q.Or().Gt(column, val)
}
return q
}

// OrGeCond 或者大于等于 >=
func (q *QueryCond[T]) OrGeCond(cond bool, column any, val any) *QueryCond[T] {
if cond {
return q.Or().Ge(column, val)
}
return q
}

// OrLtCond 或者小于 <
func (q *QueryCond[T]) OrLtCond(cond bool, column any, val any) *QueryCond[T] {
if cond {
return q.Or().Lt(column, val)
}
return q
}

// OrLeCond 或者小于等于 <=
func (q *QueryCond[T]) OrLeCond(cond bool, column any, val any) *QueryCond[T] {
if cond {
return q.Or().Le(column, val)
}
return q
}

// OrLikeCond 或者模糊 LIKE '%值%'
func (q *QueryCond[T]) OrLikeCond(cond bool, column any, val any) *QueryCond[T] {
if cond {
return q.Or().Like(column, val)
}
return q
}

// OrNotLikeCond 或者非模糊 NOT LIKE '%值%'
func (q *QueryCond[T]) OrNotLikeCond(cond bool, column any, val any) *QueryCond[T] {
if cond {
return q.Or().NotLike(column, val)
}
return q
}

// OrLikeLeftCond 或者左模糊 LIKE '%值'
func (q *QueryCond[T]) OrLikeLeftCond(cond bool, column any, val any) *QueryCond[T] {
if cond {
return q.Or().LikeLeft(column, val)
}
return q
}

// OrNotLikeLeftCond 或者非左模糊 NOT LIKE '%值'
func (q *QueryCond[T]) OrNotLikeLeftCond(cond bool, column any, val any) *QueryCond[T] {
if cond {
return q.Or().NotLikeLeft(column, val)
}
return q
}

// OrLikeRightCond 或者右模糊 LIKE '值%'
func (q *QueryCond[T]) OrLikeRightCond(cond bool, column any, val any) *QueryCond[T] {
if cond {
return q.Or().LikeRight(column, val)
}
return q
}

// OrNotLikeRightCond 或者非右模糊 NOT LIKE '值%'
func (q *QueryCond[T]) OrNotLikeRightCond(cond bool, column any, val any) *QueryCond[T] {
if cond {
return q.Or().NotLikeRight(column, val)
}
return q
}

// OrInCond 或者字段 IN (值1, 值2, ...)
func (q *QueryCond[T]) OrInCond(cond bool, column any, val any) *QueryCond[T] {
if cond {
return q.Or().In(column, val)
}
return q
}

func (q *QueryCond[T]) addExpression(sqlSegments ...SqlSegment) {
if len(sqlSegments) == 1 {
q.handleSingle(sqlSegments[0])
Expand Down

0 comments on commit 0d39664

Please sign in to comment.