Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package pgkit
import (
"fmt"
"reflect"
"slices"

sq "github.com/Masterminds/squirrel"
)
Expand All @@ -19,10 +20,19 @@ func (s *StatementBuilder) InsertRecord(record interface{}, optTableName ...stri
if err != nil {
return InsertBuilder{InsertBuilder: insert, err: wrapErr(err)}
}
if len(cols) == 0 {
return InsertBuilder{InsertBuilder: insert, err: wrapErr(fmt.Errorf("Map returned no columns for %T; for an all-default INSERT use sq.Expr(\"INSERT INTO %s DEFAULT VALUES\")", record, tableName))}
}

return InsertBuilder{InsertBuilder: insert.Into(tableName).Columns(cols...).Values(vals...)}
}

// InsertRecords builds a multi-row INSERT from a slice of records.
//
// Every record must produce the same non-empty Map column set; a drifted
// shape (e.g. mixed nil and non-nil empty slices under ,omitzero) or an
// all-default record returns a build-time error rather than emitting
// malformed multi-row SQL.
func (s StatementBuilder) InsertRecords(recordsSlice interface{}, optTableName ...string) InsertBuilder {
insert := sq.InsertBuilder(s.StatementBuilderType)

Expand All @@ -39,6 +49,7 @@ func (s StatementBuilder) InsertRecords(recordsSlice interface{}, optTableName .
tableName = optTableName[0]
}

var baseCols []string
for i := 0; i < v.Len(); i++ {
record := v.Index(i).Interface()

Expand All @@ -52,10 +63,20 @@ func (s StatementBuilder) InsertRecords(recordsSlice interface{}, optTableName .
if err != nil {
return InsertBuilder{InsertBuilder: insert, err: wrapErr(err)}
}
if len(cols) == 0 {
return InsertBuilder{InsertBuilder: insert, err: wrapErr(fmt.Errorf("Map returned no columns for record %d (%T); for an all-default INSERT use sq.Expr", i, record))}
}

if i == 0 {
baseCols = cols
insert = insert.Columns(cols...).Values(vals...)
} else {
if !slices.Equal(cols, baseCols) {
return InsertBuilder{
InsertBuilder: insert,
err: wrapErr(fmt.Errorf("record %d columns %v differ from record 0 columns %v", i, cols, baseCols)),
}
}
insert = insert.Values(vals...)
}
}
Expand Down
90 changes: 90 additions & 0 deletions builder_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package pgkit_test

import (
"testing"

sq "github.com/Masterminds/squirrel"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/goware/pgkit/v2"
)

func TestInsertRecords_ColumnDriftRejected(t *testing.T) {
// ,omitzero produces different column shapes for nil vs non-nil empty
// slices; squirrel would otherwise stitch the mismatched widths into
// malformed multi-row SQL and surface only at exec time.
type Item struct {
ID int `db:"id"`
Tags []string `db:"tags,omitzero"`
}

sb := &pgkit.StatementBuilder{StatementBuilderType: sq.StatementBuilder.PlaceholderFormat(sq.Dollar)}
records := []Item{
{ID: 1, Tags: nil},
{ID: 2, Tags: []string{}},
}
b := sb.InsertRecords(records, "items")
require.Error(t, b.Err())
assert.Contains(t, b.Err().Error(), "differ from record 0")
}

func TestInsertRecords_UniformShape(t *testing.T) {
// Sanity: batches with consistent column shape across rows still build.
type Item struct {
ID int `db:"id"`
Tags []string `db:"tags,omitzero"`
}

sb := &pgkit.StatementBuilder{StatementBuilderType: sq.StatementBuilder.PlaceholderFormat(sq.Dollar)}
records := []Item{
{ID: 1, Tags: []string{"a"}},
{ID: 2, Tags: []string{"b"}},
}
b := sb.InsertRecords(records, "items")
require.NoError(t, b.Err())
}

func TestInsertRecord_EmptyColumnsRejected(t *testing.T) {
// All fields tagged ,omitzero (or ,omitempty) and all zero leaves
// Map with no columns. Squirrel would emit invalid INSERT INTO t
// VALUES (); fail fast at build time and point at sq.Expr as the
// escape for the all-default INSERT case. Tracked in goware/pgkit#51.
type Item struct {
Tags []string `db:"tags,omitzero"`
}
sb := &pgkit.StatementBuilder{StatementBuilderType: sq.StatementBuilder.PlaceholderFormat(sq.Dollar)}
b := sb.InsertRecord(&Item{}, "items")
require.Error(t, b.Err())
assert.Contains(t, b.Err().Error(), "no columns")
assert.Contains(t, b.Err().Error(), "sq.Expr")
}

func TestInsertRecords_EmptyColumnsRejected(t *testing.T) {
type Item struct {
Tags []string `db:"tags,omitzero"`
}
sb := &pgkit.StatementBuilder{StatementBuilderType: sq.StatementBuilder.PlaceholderFormat(sq.Dollar)}
records := []Item{{}, {}}
b := sb.InsertRecords(records, "items")
require.Error(t, b.Err())
assert.Contains(t, b.Err().Error(), "no columns")
}

func TestInsertRecords_OmitEmptyMapDriftRejected(t *testing.T) {
// Latent footgun ,omitzero exposes: legacy ,omitempty on a map already
// produced shape drift (nil map skipped, non-nil empty map kept via the
// DeepEqual fallback). The validation catches this case for free.
type Item struct {
ID int `db:"id"`
Tags map[string]string `db:"tags,omitempty"`
}

sb := &pgkit.StatementBuilder{StatementBuilderType: sq.StatementBuilder.PlaceholderFormat(sq.Dollar)}
records := []Item{
{ID: 1, Tags: nil},
{ID: 2, Tags: map[string]string{}},
}
b := sb.InsertRecords(records, "items")
require.Error(t, b.Err())
}
82 changes: 56 additions & 26 deletions mapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,13 @@ type MapOptions struct {
IncludeNil bool
}

// Map converts a struct object (aka record) to a mapping of column names and values
// which can be directly passed to a query executor. This allows you to use structs/objects
// to build easy insert/update queries without having to specify the column names manually.
// The mapper works by reading the column names from a struct fields `db:""` struct tag.
// If you specify `,omitempty` as a tag option, then it will omit the column from the list,
// which allows the database to take over and use its default value.
// Map converts a struct to (column, value) slices using `db:""` struct tags.
//
// ,omitempty and ,omitzero (mutually exclusive) both skip zero values, but
// ,omitzero keeps non-nil empty slices/maps so a clear-to-empty UPDATE
// actually clears the column. Matches encoding/json's omitzero (Go 1.24+).
// IncludeNil surfaces nil pointers as DEFAULT under ,omitempty and as
// NULL under ,omitzero.
func Map(record interface{}) ([]string, []interface{}, error) {
return MapWithOptions(record, nil)
}
Expand Down Expand Up @@ -86,38 +87,34 @@ func MapWithOptions(record interface{}, options *MapOptions) ([]string, []interf

// Field options
_, tagOmitEmpty := fi.Options["omitempty"]
_, tagOmitZero := fi.Options["omitzero"]
if tagOmitEmpty && tagOmitZero {
return nil, nil, fmt.Errorf("field %q has both ,omitempty and ,omitzero tags (mutually exclusive)", fi.Name)
}

fld := reflectx.FieldByIndexesReadOnly(recordV, fi.Index)

if fld.Kind() == reflect.Ptr && fld.IsNil() {
if tagOmitEmpty && !options.IncludeNil {
if (tagOmitEmpty || tagOmitZero) && !options.IncludeNil {
continue
}
fv.fields = append(fv.fields, fi.Name)
// ,omitempty preserves legacy: forced-include emits DEFAULT
// so callers can fall back to the column's DB default. ,omitzero
// is the strict tag: forced-include emits literal NULL so a
// PATCH can clear a nullable column with a non-null default.
var v any
if tagOmitEmpty {
fv.values = append(fv.values, sqlDefault)
} else {
fv.values = append(fv.values, nil)
v = sqlDefault
}
fv.values = append(fv.values, v)
continue
}

value := fld.Interface()

isZero := false
if t, ok := fld.Interface().(hasIsZero); ok {
if t.IsZero() {
isZero = true
}
} else if fld.Kind() == reflect.Array || fld.Kind() == reflect.Slice {
if fld.Len() == 0 {
isZero = true
}
} else if reflect.DeepEqual(fi.Zero.Interface(), value) {
isZero = true
}

if isZero && tagOmitEmpty && !options.IncludeZeroed {
isEmpty, isStrictZero := zeroFlags(fld, fi.Zero.Interface())
skip := (isEmpty && tagOmitEmpty) || (isStrictZero && tagOmitZero)
if skip && !options.IncludeZeroed {
continue
}

Expand All @@ -127,7 +124,7 @@ func MapWithOptions(record interface{}, options *MapOptions) ([]string, []interf
// return nil, nil, err
// }
v := value
if isZero && tagOmitEmpty {
if skip {
v = sqlDefault
}
fv.values = append(fv.values, v)
Expand Down Expand Up @@ -185,6 +182,39 @@ func (fv *fieldValue) Less(i, j int) bool {
return fv.fields[i] < fv.fields[j]
}

// Two return values because omitempty and omitzero disagree only on
// non-nil empty slices; every other path returns both flags the same.
func zeroFlags(fld reflect.Value, fieldZero any) (isEmpty, isStrictZero bool) {
if t, ok := fld.Interface().(hasIsZero); ok {
if t.IsZero() {
return true, true
}
return false, false
}
switch fld.Kind() {
case reflect.Slice:
if fld.IsNil() {
return true, true
}
if fld.Len() == 0 {
return true, false
}
case reflect.Map:
if fld.IsNil() {
return true, true
}
case reflect.Array:
// omitempty must keep all-zero arrays of normal length; switching
// to IsZero here would silently drop [16]byte UUIDs, [32]byte hashes.
return fld.Len() == 0, fld.IsZero()
default:
if reflect.DeepEqual(fieldZero, fld.Interface()) {
return true, true
}
}
return false, false
}

type hasIsZero interface {
IsZero() bool
}
Expand Down
Loading
Loading