diff --git a/cursor.go b/cursor.go new file mode 100644 index 0000000..49f4814 --- /dev/null +++ b/cursor.go @@ -0,0 +1,108 @@ +package pgkit + +import ( + "encoding/base64" + "encoding/json" + "errors" + "fmt" + + sq "github.com/Masterminds/squirrel" +) + +// ErrInvalidCursor signals a client-supplied cursor that failed to decode — map to 400, not 500. +var ErrInvalidCursor = errors.New("invalid cursor") + +// EncodeCursor produces an opaque cursor: base64-JSON, not signed, never use it for authorization. +func EncodeCursor[C any](cursor C) (string, error) { + raw, err := json.Marshal(cursor) + if err != nil { + return "", fmt.Errorf("marshal cursor: %w", err) + } + return base64.RawURLEncoding.EncodeToString(raw), nil +} + +// DecodeCursor returns (nil, nil) for empty input so callers can compose with a nil-check. +func DecodeCursor[C any](value string) (*C, error) { + if value == "" { + return nil, nil + } + raw, err := base64.RawURLEncoding.DecodeString(value) + if err != nil { + return nil, ErrInvalidCursor + } + var cursor C + if err := json.Unmarshal(raw, &cursor); err != nil { + return nil, ErrInvalidCursor + } + return &cursor, nil +} + +// Cursor is the interface a typed keyset cursor satisfies — mirrors pgkit.Record[T, I]'s self-pointer pattern. +type Cursor[Self any, Row any] interface { + *Self + Apply(sq.SelectBuilder) sq.SelectBuilder + From(Row) error +} + +// CursorPaginator is the keyset sibling of Paginator[T] for ordering-stable pagination under concurrent writes. +// The caller owns ORDER BY; C.Apply must match it or pages will silently skip or duplicate rows. +type CursorPaginator[T any, C any, PC Cursor[C, T]] struct { + settings PaginatorSettings +} + +// NewCursorPaginator honors only size options — WithSort / WithColumnFunc are no-ops because the caller owns ORDER BY. +func NewCursorPaginator[T any, C any, PC Cursor[C, T]](options ...PaginatorOption) CursorPaginator[T, C, PC] { + settings := &PaginatorSettings{ + DefaultSize: DefaultPageSize, + MaxSize: MaxPageSize, + } + for _, option := range options { + option(settings) + } + if settings.MaxSize < settings.DefaultSize { + settings.MaxSize = settings.DefaultSize + } + return CursorPaginator[T, C, PC]{settings: *settings} +} + +// PrepareQuery chains LIMIT n+1 so PrepareResult can detect a next page without a second round-trip. +func (p CursorPaginator[T, C, PC]) PrepareQuery(q sq.SelectBuilder, page *Page) ([]T, sq.SelectBuilder, error) { + if page == nil { + page = &Page{} + } + page.SetDefaults(&p.settings) + + if page.Cursor != "" { + cursor, err := DecodeCursor[C](page.Cursor) + if err != nil { + return nil, q, err + } + q = PC(cursor).Apply(q) + } + + limit := page.Limit() + q = q.Limit(limit + 1) + return make([]T, 0, limit+1), q, nil +} + +// PrepareResult must be called after GetAll to populate page.More and page.NextCursor. +func (p CursorPaginator[T, C, PC]) PrepareResult(result []T, page *Page) ([]T, error) { + limit := int(page.Limit()) + page.Size = uint32(limit) + page.More = len(result) > limit + if !page.More { + return result, nil + } + result = result[:limit] + + var cursor C + if err := PC(&cursor).From(result[len(result)-1]); err != nil { + return nil, fmt.Errorf("cursor from row: %w", err) + } + next, err := EncodeCursor(cursor) + if err != nil { + return nil, err + } + page.NextCursor = next + return result, nil +} diff --git a/cursor_test.go b/cursor_test.go new file mode 100644 index 0000000..0ff530a --- /dev/null +++ b/cursor_test.go @@ -0,0 +1,258 @@ +package pgkit_test + +import ( + "errors" + "strconv" + "strings" + "testing" + + sq "github.com/Masterminds/squirrel" + "github.com/goware/pgkit/v2" + "github.com/stretchr/testify/require" +) + +type row struct { + ID string +} + +type rowCursor struct { + ID string `json:"id"` +} + +func (c *rowCursor) Apply(q sq.SelectBuilder) sq.SelectBuilder { + return q.Where(sq.Lt{"id": c.ID}) +} + +func (c *rowCursor) From(r row) error { + c.ID = r.ID + return nil +} + +func TestEncodeDecodeCursorRoundTrip(t *testing.T) { + encoded, err := pgkit.EncodeCursor(rowCursor{ID: "row_1"}) + require.NoError(t, err) + require.NotEmpty(t, encoded) + + decoded, err := pgkit.DecodeCursor[rowCursor](encoded) + require.NoError(t, err) + require.NotNil(t, decoded) + require.Equal(t, "row_1", decoded.ID) +} + +func TestDecodeCursorEmptyReturnsNil(t *testing.T) { + decoded, err := pgkit.DecodeCursor[rowCursor]("") + require.NoError(t, err) + require.Nil(t, decoded) +} + +func TestDecodeCursorInvalidBase64(t *testing.T) { + _, err := pgkit.DecodeCursor[rowCursor]("!!!not-base64!!!") + require.Error(t, err) + require.True(t, errors.Is(err, pgkit.ErrInvalidCursor)) +} + +func TestDecodeCursorInvalidJSON(t *testing.T) { + encoded, err := pgkit.EncodeCursor("not a struct") + require.NoError(t, err) + + _, err = pgkit.DecodeCursor[rowCursor](encoded) + require.Error(t, err) + require.True(t, errors.Is(err, pgkit.ErrInvalidCursor)) +} + +func TestCursorPaginatorFirstPage(t *testing.T) { + paginator := pgkit.NewCursorPaginator[row, rowCursor, *rowCursor]( + pgkit.WithDefaultSize(2), + pgkit.WithMaxSize(5), + ) + page := &pgkit.Page{} + + result, q, err := paginator.PrepareQuery(sq.Select("*").From("t"), page) + require.NoError(t, err) + require.Len(t, result, 0) + require.Equal(t, 3, cap(result)) + + sql, args, err := q.ToSql() + require.NoError(t, err) + require.Equal(t, "SELECT * FROM t LIMIT 3", sql) + require.Empty(t, args) +} + +func TestCursorPaginatorWithCursor(t *testing.T) { + paginator := pgkit.NewCursorPaginator[row, rowCursor, *rowCursor](pgkit.WithDefaultSize(2)) + encoded, err := pgkit.EncodeCursor(rowCursor{ID: "row_5"}) + require.NoError(t, err) + page := &pgkit.Page{Cursor: encoded} + + _, q, err := paginator.PrepareQuery(sq.Select("*").From("t"), page) + require.NoError(t, err) + + sql, args, err := q.ToSql() + require.NoError(t, err) + require.Equal(t, "SELECT * FROM t WHERE id < ? LIMIT 3", sql) + require.Equal(t, []any{"row_5"}, args) +} + +func TestCursorPaginatorInvalidCursor(t *testing.T) { + paginator := pgkit.NewCursorPaginator[row, rowCursor, *rowCursor]() + page := &pgkit.Page{Cursor: "!!!not-base64!!!"} + + _, _, err := paginator.PrepareQuery(sq.Select("*").From("t"), page) + require.Error(t, err) + require.True(t, errors.Is(err, pgkit.ErrInvalidCursor)) +} + +func TestCursorPaginatorPrepareResultNoMore(t *testing.T) { + paginator := pgkit.NewCursorPaginator[row, rowCursor, *rowCursor](pgkit.WithDefaultSize(3)) + page := &pgkit.Page{} + _, _, err := paginator.PrepareQuery(sq.Select("*").From("t"), page) + require.NoError(t, err) + + result, err := paginator.PrepareResult([]row{{ID: "1"}, {ID: "2"}}, page) + require.NoError(t, err) + require.Len(t, result, 2) + require.False(t, page.More) + require.Empty(t, page.NextCursor) + require.Equal(t, uint32(3), page.Size) +} + +func TestCursorPaginatorPrepareResultHasMore(t *testing.T) { + paginator := pgkit.NewCursorPaginator[row, rowCursor, *rowCursor](pgkit.WithDefaultSize(2)) + page := &pgkit.Page{} + _, _, err := paginator.PrepareQuery(sq.Select("*").From("t"), page) + require.NoError(t, err) + + result, err := paginator.PrepareResult( + []row{{ID: "3"}, {ID: "2"}, {ID: "1"}}, + page, + ) + require.NoError(t, err) + require.Equal(t, []row{{ID: "3"}, {ID: "2"}}, result) + require.True(t, page.More) + require.NotEmpty(t, page.NextCursor) + + decoded, err := pgkit.DecodeCursor[rowCursor](page.NextCursor) + require.NoError(t, err) + require.NotNil(t, decoded) + require.Equal(t, "2", decoded.ID) +} + +func TestCursorPaginatorDefaultsFromNilPage(t *testing.T) { + paginator := pgkit.NewCursorPaginator[row, rowCursor, *rowCursor]() + _, q, err := paginator.PrepareQuery(sq.Select("*").From("t"), nil) + require.NoError(t, err) + + sql, _, err := q.ToSql() + require.NoError(t, err) + require.Equal(t, "SELECT * FROM t LIMIT 11", sql) +} + +func TestCursorPaginatorCapsAtMaxSize(t *testing.T) { + paginator := pgkit.NewCursorPaginator[row, rowCursor, *rowCursor]( + pgkit.WithDefaultSize(5), + pgkit.WithMaxSize(10), + ) + page := &pgkit.Page{Size: 999} + + _, q, err := paginator.PrepareQuery(sq.Select("*").From("t"), page) + require.NoError(t, err) + + sql, _, err := q.ToSql() + require.NoError(t, err) + require.Equal(t, "SELECT * FROM t LIMIT 11", sql) + require.Equal(t, uint32(10), page.Size) +} + +func TestCursorPaginatorMaxSizeBelowDefaultIsLifted(t *testing.T) { + paginator := pgkit.NewCursorPaginator[row, rowCursor, *rowCursor]( + pgkit.WithDefaultSize(20), + pgkit.WithMaxSize(5), + ) + page := &pgkit.Page{} + + _, q, err := paginator.PrepareQuery(sq.Select("*").From("t"), page) + require.NoError(t, err) + + sql, _, err := q.ToSql() + require.NoError(t, err) + require.Equal(t, "SELECT * FROM t LIMIT 21", sql) +} + +func TestCursorPaginatorWalksPages(t *testing.T) { + paginator := pgkit.NewCursorPaginator[row, rowCursor, *rowCursor](pgkit.WithDefaultSize(2)) + all := []row{{ID: "5"}, {ID: "4"}, {ID: "3"}, {ID: "2"}, {ID: "1"}} + + var ( + page = &pgkit.Page{} + seen []row + ) + for step := 0; step < 5; step++ { + _, q, err := paginator.PrepareQuery(sq.Select("*").From("t"), page) + require.NoError(t, err) + + fetched := fetch(t, all, q) + got, err := paginator.PrepareResult(fetched, page) + require.NoError(t, err) + + seen = append(seen, got...) + if !page.More { + break + } + page.Cursor = page.NextCursor + page.NextCursor = "" + } + require.Equal(t, all, seen) + require.False(t, page.More) +} + +type failingRowCursor struct { + ID string `json:"id"` +} + +func (c *failingRowCursor) Apply(q sq.SelectBuilder) sq.SelectBuilder { + return q.Where(sq.Lt{"id": c.ID}) +} + +var errBoom = errors.New("boom") + +func (c *failingRowCursor) From(row) error { + return errBoom +} + +func TestCursorPaginatorPrepareResultPropagatesCursorError(t *testing.T) { + paginator := pgkit.NewCursorPaginator[row, failingRowCursor, *failingRowCursor](pgkit.WithDefaultSize(1)) + page := &pgkit.Page{} + _, _, err := paginator.PrepareQuery(sq.Select("*").From("t"), page) + require.NoError(t, err) + + _, err = paginator.PrepareResult([]row{{ID: "2"}, {ID: "1"}}, page) + require.Error(t, err) + require.True(t, errors.Is(err, errBoom)) +} + +// In-memory stand-in so the pagination walk exercises encode/decode without a real database. +func fetch(t *testing.T, all []row, q sq.SelectBuilder) []row { + t.Helper() + sql, args, err := q.ToSql() + require.NoError(t, err) + + limit, err := strconv.Atoi(sql[strings.LastIndex(sql, " ")+1:]) + require.NoError(t, err) + + cutoff := "" + if len(args) == 1 { + cutoff = args[0].(string) + } + + out := make([]row, 0, limit) + for _, r := range all { + if cutoff != "" && r.ID >= cutoff { + continue + } + out = append(out, r) + if len(out) == limit { + break + } + } + return out +} diff --git a/page.go b/page.go index 35621e7..6f2de60 100644 --- a/page.go +++ b/page.go @@ -81,6 +81,10 @@ type Page struct { More bool Column string Sort []Sort + + // Unused by the offset Paginator — shared here so callers can swap paginators without changing the Page type. + Cursor string + NextCursor string } func NewPage(size, page uint32, sort ...Sort) *Page {