Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 108 additions & 0 deletions cursor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package pgkit

import (
"encoding/base64"
"encoding/json"
"errors"
"fmt"

sq "github.com/Masterminds/squirrel"
)

// ErrInvalidCursor signals a client-supplied cursor that failed to decode — map to 400, not 500.
var ErrInvalidCursor = errors.New("invalid cursor")

// EncodeCursor produces an opaque cursor: base64-JSON, not signed, never use it for authorization.
func EncodeCursor[C any](cursor C) (string, error) {
raw, err := json.Marshal(cursor)
if err != nil {
return "", fmt.Errorf("marshal cursor: %w", err)
}
return base64.RawURLEncoding.EncodeToString(raw), nil
}

// DecodeCursor returns (nil, nil) for empty input so callers can compose with a nil-check.
func DecodeCursor[C any](value string) (*C, error) {
if value == "" {
return nil, nil
}
raw, err := base64.RawURLEncoding.DecodeString(value)
if err != nil {
return nil, ErrInvalidCursor
}
var cursor C
if err := json.Unmarshal(raw, &cursor); err != nil {
return nil, ErrInvalidCursor
}
return &cursor, nil
}

// Cursor is the interface a typed keyset cursor satisfies — mirrors pgkit.Record[T, I]'s self-pointer pattern.
type Cursor[Self any, Row any] interface {
*Self
Apply(sq.SelectBuilder) sq.SelectBuilder
From(Row) error
}

// CursorPaginator is the keyset sibling of Paginator[T] for ordering-stable pagination under concurrent writes.
// The caller owns ORDER BY; C.Apply must match it or pages will silently skip or duplicate rows.
type CursorPaginator[T any, C any, PC Cursor[C, T]] struct {
settings PaginatorSettings
}

// NewCursorPaginator honors only size options — WithSort / WithColumnFunc are no-ops because the caller owns ORDER BY.
func NewCursorPaginator[T any, C any, PC Cursor[C, T]](options ...PaginatorOption) CursorPaginator[T, C, PC] {
settings := &PaginatorSettings{
DefaultSize: DefaultPageSize,
MaxSize: MaxPageSize,
}
for _, option := range options {
option(settings)
}
if settings.MaxSize < settings.DefaultSize {
settings.MaxSize = settings.DefaultSize
}
return CursorPaginator[T, C, PC]{settings: *settings}
}

// PrepareQuery chains LIMIT n+1 so PrepareResult can detect a next page without a second round-trip.
func (p CursorPaginator[T, C, PC]) PrepareQuery(q sq.SelectBuilder, page *Page) ([]T, sq.SelectBuilder, error) {
if page == nil {
page = &Page{}
}
page.SetDefaults(&p.settings)

if page.Cursor != "" {
cursor, err := DecodeCursor[C](page.Cursor)
if err != nil {
return nil, q, err
}
q = PC(cursor).Apply(q)
}

limit := page.Limit()
q = q.Limit(limit + 1)
return make([]T, 0, limit+1), q, nil
}

// PrepareResult must be called after GetAll to populate page.More and page.NextCursor.
func (p CursorPaginator[T, C, PC]) PrepareResult(result []T, page *Page) ([]T, error) {
limit := int(page.Limit())
page.Size = uint32(limit)
page.More = len(result) > limit
if !page.More {
return result, nil
}
result = result[:limit]

var cursor C
if err := PC(&cursor).From(result[len(result)-1]); err != nil {
return nil, fmt.Errorf("cursor from row: %w", err)
}
next, err := EncodeCursor(cursor)
if err != nil {
return nil, err
}
page.NextCursor = next
return result, nil
}
258 changes: 258 additions & 0 deletions cursor_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,258 @@
package pgkit_test

import (
"errors"
"strconv"
"strings"
"testing"

sq "github.com/Masterminds/squirrel"
"github.com/goware/pgkit/v2"
"github.com/stretchr/testify/require"
)

type row struct {
ID string
}

type rowCursor struct {
ID string `json:"id"`
}

func (c *rowCursor) Apply(q sq.SelectBuilder) sq.SelectBuilder {
return q.Where(sq.Lt{"id": c.ID})
}

func (c *rowCursor) From(r row) error {
c.ID = r.ID
return nil
}

func TestEncodeDecodeCursorRoundTrip(t *testing.T) {
encoded, err := pgkit.EncodeCursor(rowCursor{ID: "row_1"})
require.NoError(t, err)
require.NotEmpty(t, encoded)

decoded, err := pgkit.DecodeCursor[rowCursor](encoded)
require.NoError(t, err)
require.NotNil(t, decoded)
require.Equal(t, "row_1", decoded.ID)
}

func TestDecodeCursorEmptyReturnsNil(t *testing.T) {
decoded, err := pgkit.DecodeCursor[rowCursor]("")
require.NoError(t, err)
require.Nil(t, decoded)
}

func TestDecodeCursorInvalidBase64(t *testing.T) {
_, err := pgkit.DecodeCursor[rowCursor]("!!!not-base64!!!")
require.Error(t, err)
require.True(t, errors.Is(err, pgkit.ErrInvalidCursor))
}

func TestDecodeCursorInvalidJSON(t *testing.T) {
encoded, err := pgkit.EncodeCursor("not a struct")
require.NoError(t, err)

_, err = pgkit.DecodeCursor[rowCursor](encoded)
require.Error(t, err)
require.True(t, errors.Is(err, pgkit.ErrInvalidCursor))
}

func TestCursorPaginatorFirstPage(t *testing.T) {
paginator := pgkit.NewCursorPaginator[row, rowCursor, *rowCursor](
pgkit.WithDefaultSize(2),
pgkit.WithMaxSize(5),
)
page := &pgkit.Page{}

result, q, err := paginator.PrepareQuery(sq.Select("*").From("t"), page)
require.NoError(t, err)
require.Len(t, result, 0)
require.Equal(t, 3, cap(result))

sql, args, err := q.ToSql()
require.NoError(t, err)
require.Equal(t, "SELECT * FROM t LIMIT 3", sql)
require.Empty(t, args)
}

func TestCursorPaginatorWithCursor(t *testing.T) {
paginator := pgkit.NewCursorPaginator[row, rowCursor, *rowCursor](pgkit.WithDefaultSize(2))
encoded, err := pgkit.EncodeCursor(rowCursor{ID: "row_5"})
require.NoError(t, err)
page := &pgkit.Page{Cursor: encoded}

_, q, err := paginator.PrepareQuery(sq.Select("*").From("t"), page)
require.NoError(t, err)

sql, args, err := q.ToSql()
require.NoError(t, err)
require.Equal(t, "SELECT * FROM t WHERE id < ? LIMIT 3", sql)
require.Equal(t, []any{"row_5"}, args)
}

func TestCursorPaginatorInvalidCursor(t *testing.T) {
paginator := pgkit.NewCursorPaginator[row, rowCursor, *rowCursor]()
page := &pgkit.Page{Cursor: "!!!not-base64!!!"}

_, _, err := paginator.PrepareQuery(sq.Select("*").From("t"), page)
require.Error(t, err)
require.True(t, errors.Is(err, pgkit.ErrInvalidCursor))
}

func TestCursorPaginatorPrepareResultNoMore(t *testing.T) {
paginator := pgkit.NewCursorPaginator[row, rowCursor, *rowCursor](pgkit.WithDefaultSize(3))
page := &pgkit.Page{}
_, _, err := paginator.PrepareQuery(sq.Select("*").From("t"), page)
require.NoError(t, err)

result, err := paginator.PrepareResult([]row{{ID: "1"}, {ID: "2"}}, page)
require.NoError(t, err)
require.Len(t, result, 2)
require.False(t, page.More)
require.Empty(t, page.NextCursor)
require.Equal(t, uint32(3), page.Size)
}

func TestCursorPaginatorPrepareResultHasMore(t *testing.T) {
paginator := pgkit.NewCursorPaginator[row, rowCursor, *rowCursor](pgkit.WithDefaultSize(2))
page := &pgkit.Page{}
_, _, err := paginator.PrepareQuery(sq.Select("*").From("t"), page)
require.NoError(t, err)

result, err := paginator.PrepareResult(
[]row{{ID: "3"}, {ID: "2"}, {ID: "1"}},
page,
)
require.NoError(t, err)
require.Equal(t, []row{{ID: "3"}, {ID: "2"}}, result)
require.True(t, page.More)
require.NotEmpty(t, page.NextCursor)

decoded, err := pgkit.DecodeCursor[rowCursor](page.NextCursor)
require.NoError(t, err)
require.NotNil(t, decoded)
require.Equal(t, "2", decoded.ID)
}

func TestCursorPaginatorDefaultsFromNilPage(t *testing.T) {
paginator := pgkit.NewCursorPaginator[row, rowCursor, *rowCursor]()
_, q, err := paginator.PrepareQuery(sq.Select("*").From("t"), nil)
require.NoError(t, err)

sql, _, err := q.ToSql()
require.NoError(t, err)
require.Equal(t, "SELECT * FROM t LIMIT 11", sql)
}

func TestCursorPaginatorCapsAtMaxSize(t *testing.T) {
paginator := pgkit.NewCursorPaginator[row, rowCursor, *rowCursor](
pgkit.WithDefaultSize(5),
pgkit.WithMaxSize(10),
)
page := &pgkit.Page{Size: 999}

_, q, err := paginator.PrepareQuery(sq.Select("*").From("t"), page)
require.NoError(t, err)

sql, _, err := q.ToSql()
require.NoError(t, err)
require.Equal(t, "SELECT * FROM t LIMIT 11", sql)
require.Equal(t, uint32(10), page.Size)
}

func TestCursorPaginatorMaxSizeBelowDefaultIsLifted(t *testing.T) {
paginator := pgkit.NewCursorPaginator[row, rowCursor, *rowCursor](
pgkit.WithDefaultSize(20),
pgkit.WithMaxSize(5),
)
page := &pgkit.Page{}

_, q, err := paginator.PrepareQuery(sq.Select("*").From("t"), page)
require.NoError(t, err)

sql, _, err := q.ToSql()
require.NoError(t, err)
require.Equal(t, "SELECT * FROM t LIMIT 21", sql)
}

func TestCursorPaginatorWalksPages(t *testing.T) {
paginator := pgkit.NewCursorPaginator[row, rowCursor, *rowCursor](pgkit.WithDefaultSize(2))
all := []row{{ID: "5"}, {ID: "4"}, {ID: "3"}, {ID: "2"}, {ID: "1"}}

var (
page = &pgkit.Page{}
seen []row
)
for step := 0; step < 5; step++ {
_, q, err := paginator.PrepareQuery(sq.Select("*").From("t"), page)
require.NoError(t, err)

fetched := fetch(t, all, q)
got, err := paginator.PrepareResult(fetched, page)
require.NoError(t, err)

seen = append(seen, got...)
if !page.More {
break
}
page.Cursor = page.NextCursor
page.NextCursor = ""
}
require.Equal(t, all, seen)
require.False(t, page.More)
}

type failingRowCursor struct {
ID string `json:"id"`
}

func (c *failingRowCursor) Apply(q sq.SelectBuilder) sq.SelectBuilder {
return q.Where(sq.Lt{"id": c.ID})
}

var errBoom = errors.New("boom")

func (c *failingRowCursor) From(row) error {
return errBoom
}

func TestCursorPaginatorPrepareResultPropagatesCursorError(t *testing.T) {
paginator := pgkit.NewCursorPaginator[row, failingRowCursor, *failingRowCursor](pgkit.WithDefaultSize(1))
page := &pgkit.Page{}
_, _, err := paginator.PrepareQuery(sq.Select("*").From("t"), page)
require.NoError(t, err)

_, err = paginator.PrepareResult([]row{{ID: "2"}, {ID: "1"}}, page)
require.Error(t, err)
require.True(t, errors.Is(err, errBoom))
}

// In-memory stand-in so the pagination walk exercises encode/decode without a real database.
func fetch(t *testing.T, all []row, q sq.SelectBuilder) []row {
t.Helper()
sql, args, err := q.ToSql()
require.NoError(t, err)

limit, err := strconv.Atoi(sql[strings.LastIndex(sql, " ")+1:])
require.NoError(t, err)

cutoff := ""
if len(args) == 1 {
cutoff = args[0].(string)
}

out := make([]row, 0, limit)
for _, r := range all {
if cutoff != "" && r.ID >= cutoff {
continue
}
out = append(out, r)
if len(out) == limit {
break
}
}
return out
}
4 changes: 4 additions & 0 deletions page.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ type Page struct {
More bool
Column string
Sort []Sort

// Unused by the offset Paginator — shared here so callers can swap paginators without changing the Page type.
Cursor string
NextCursor string
}

func NewPage(size, page uint32, sort ...Sort) *Page {
Expand Down
Loading