From 58074615bbb9a4cb491fb7d833237eb41a80156e Mon Sep 17 00:00:00 2001 From: Harry Smaje Date: Tue, 31 Mar 2026 10:57:27 +0100 Subject: [PATCH 1/6] Add stream status trace messages for Airbyte protocol v2 Airbyte 2.x requires sources to emit STREAM_STATUS trace messages (STARTED, COMPLETE, INCOMPLETE) for each stream. Without these, every sync fails with: "streams did not receive a terminal stream status message" Changes: - Add TRACE message type and stream status constants to types.go - Add StreamDescriptor, AirbyteStreamStatus, AirbyteTraceMessage types - Replace legacy global State() with per-stream StreamState() that emits state.type=STREAM (required by Airbyte 2.x, which rejects the LEGACY format with IllegalArgumentException) - Add StreamStatus() method to emit STARTED/COMPLETE/INCOMPLETE traces - Update AirbyteLogger interface and test mock accordingly --- cmd/internal/logger.go | 38 +++++++++++++++++++++++++--- cmd/internal/mock_types.go | 9 ++++--- cmd/internal/types.go | 52 +++++++++++++++++++++++++++++++++----- 3 files changed, 85 insertions(+), 14 deletions(-) diff --git a/cmd/internal/logger.go b/cmd/internal/logger.go index 0ada8c0..15aa50a 100644 --- a/cmd/internal/logger.go +++ b/cmd/internal/logger.go @@ -14,8 +14,9 @@ type AirbyteLogger interface { ConnectionStatus(status ConnectionStatus) Record(tableNamespace, tableName string, data map[string]interface{}) Flush() - State(syncState SyncState) + StreamState(namespace, streamName string, shardStates ShardStates) Error(error string) + StreamStatus(namespace, streamName, status string) } const MaxBatchSize = 10000 @@ -82,10 +83,19 @@ func (a *airbyteLogger) Flush() { a.records = a.records[:0] } -func (a *airbyteLogger) State(syncState SyncState) { +func (a *airbyteLogger) StreamState(namespace, streamName string, shardStates ShardStates) { if err := a.recordEncoder.Encode(AirbyteMessage{ - Type: STATE, - State: &AirbyteState{syncState}, + Type: STATE, + State: &AirbyteState{ + Type: 
STATE_TYPE_STREAM, + Stream: &AirbyteStreamState{ + StreamDescriptor: StreamDescriptor{ + Name: streamName, + Namespace: namespace, + }, + StreamState: &shardStates, + }, + }, }); err != nil { a.Error(fmt.Sprintf("state encoding error: %v", err)) } @@ -103,6 +113,26 @@ func (a *airbyteLogger) Error(error string) { } } +func (a *airbyteLogger) StreamStatus(namespace, streamName, status string) { + now := time.Now() + if err := a.recordEncoder.Encode(AirbyteMessage{ + Type: TRACE, + Trace: &AirbyteTraceMessage{ + Type: TRACE_TYPE_STREAM_STATUS, + EmittedAt: float64(now.UnixMilli()), + StreamStatus: &AirbyteStreamStatus{ + StreamDescriptor: StreamDescriptor{ + Name: streamName, + Namespace: namespace, + }, + Status: status, + }, + }, + }); err != nil { + a.Error(fmt.Sprintf("stream status encoding error: %v", err)) + } +} + func (a *airbyteLogger) ConnectionStatus(status ConnectionStatus) { if err := a.recordEncoder.Encode(AirbyteMessage{ Type: CONNECTION_STATUS, diff --git a/cmd/internal/mock_types.go b/cmd/internal/mock_types.go index 742b822..2e7dc04 100644 --- a/cmd/internal/mock_types.go +++ b/cmd/internal/mock_types.go @@ -50,9 +50,8 @@ func (tal *testAirbyteLogger) Record(tableNamespace, tableName string, data map[ func (testAirbyteLogger) Flush() { } -func (testAirbyteLogger) State(syncState SyncState) { - // TODO implement me - panic("implement me") +func (testAirbyteLogger) StreamState(namespace, streamName string, shardStates ShardStates) { + // no-op for tests } func (testAirbyteLogger) Error(error string) { @@ -60,6 +59,10 @@ func (testAirbyteLogger) Error(error string) { panic("implement me") } +func (testAirbyteLogger) StreamStatus(namespace, streamName, status string) { + // no-op for tests +} + type vstreamClientMock struct { vstreamFn func(ctx context.Context, in *vtgate.VStreamRequest, opts ...grpc.CallOption) (vtgateservice.Vitess_VStreamClient, error) vstreamFnInvoked bool diff --git a/cmd/internal/types.go b/cmd/internal/types.go index 
19b5f64..a24a099 100644 --- a/cmd/internal/types.go +++ b/cmd/internal/types.go @@ -21,6 +21,17 @@ const ( LOG = "LOG" CONNECTION_STATUS = "CONNECTION_STATUS" CATALOG = "CATALOG" + TRACE = "TRACE" +) + +const ( + TRACE_TYPE_STREAM_STATUS = "STREAM_STATUS" +) + +const ( + STREAM_STATUS_STARTED = "STARTED" + STREAM_STATUS_COMPLETE = "COMPLETE" + STREAM_STATUS_INCOMPLETE = "INCOMPLETE" ) const ( @@ -385,17 +396,44 @@ func mapEnumValue(value sqltypes.Value, values []string) sqltypes.Value { return value } +const ( + STATE_TYPE_STREAM = "STREAM" +) + +type AirbyteStreamState struct { + StreamDescriptor StreamDescriptor `json:"stream_descriptor"` + StreamState *ShardStates `json:"stream_state"` +} + type AirbyteState struct { - Data SyncState `json:"data"` + Type string `json:"type"` + Stream *AirbyteStreamState `json:"stream,omitempty"` +} + +type StreamDescriptor struct { + Name string `json:"name"` + Namespace string `json:"namespace"` +} + +type AirbyteStreamStatus struct { + StreamDescriptor StreamDescriptor `json:"stream_descriptor"` + Status string `json:"status"` +} + +type AirbyteTraceMessage struct { + Type string `json:"type"` + EmittedAt float64 `json:"emitted_at"` + StreamStatus *AirbyteStreamStatus `json:"stream_status,omitempty"` } type AirbyteMessage struct { - Type string `json:"type"` - Log *AirbyteLogMessage `json:"log,omitempty"` - ConnectionStatus *ConnectionStatus `json:"connectionStatus,omitempty"` - Catalog *Catalog `json:"catalog,omitempty"` - Record *AirbyteRecord `json:"record,omitempty"` - State *AirbyteState `json:"state,omitempty"` + Type string `json:"type"` + Log *AirbyteLogMessage `json:"log,omitempty"` + ConnectionStatus *ConnectionStatus `json:"connectionStatus,omitempty"` + Catalog *Catalog `json:"catalog,omitempty"` + Record *AirbyteRecord `json:"record,omitempty"` + State *AirbyteState `json:"state,omitempty"` + Trace *AirbyteTraceMessage `json:"trace,omitempty"` } // A map of starting GTIDs for every keyspace and shard From 
8dd6975c8e5dce739fa8aafe40c4fc5a79e9e66f Mon Sep 17 00:00:00 2001 From: Harry Smaje Date: Tue, 31 Mar 2026 10:57:39 +0100 Subject: [PATCH 2/6] Emit per-stream status and state in read loop, handle v2 state input Update the read command to be fully compatible with Airbyte 2.x: Read loop changes: - Emit STARTED before reading each stream - Emit COMPLETE after successful read, INCOMPLETE on error - Replace os.Exit(1) with break on cursor-decoding and read errors inside the shard loop so remaining streams still get status messages (a stream with no state entry still emits INCOMPLETE and then exits) - Emit per-stream STATE (type=STREAM) once after each stream's shard loop completes, instead of re-emitting the global state blob after every shard State parsing changes: - Handle Airbyte v2 per-stream state format on incremental syncs. Airbyte 2.x passes state back as a JSON array of per-stream state objects, not the legacy global SyncState blob. Without this, the second sync always fails because json.Unmarshal fails on the array format, causing os.Exit(1) before any streams are processed. - Fall back to legacy format for backwards compatibility - Default empty namespace to source database name to prevent state key mismatches --- cmd/airbyte-source/read.go | 43 +++++++++++++++++++++++++++++++------- 1 file changed, 35 insertions(+), 8 deletions(-) diff --git a/cmd/airbyte-source/read.go b/cmd/airbyte-source/read.go index 09b8031..c56e04b 100644 --- a/cmd/airbyte-source/read.go +++ b/cmd/airbyte-source/read.go @@ -109,9 +109,13 @@ func ReadCommand(ch *Helper) *cobra.Command { streamState, ok := syncState.Streams[streamStateKey] if !ok { ch.Logger.Error(fmt.Sprintf("Unable to read state for stream %v", streamStateKey)) + ch.Logger.StreamStatus(keyspaceOrDatabase, configuredStream.Stream.Name, internal.STREAM_STATUS_INCOMPLETE) os.Exit(1) } + ch.Logger.StreamStatus(keyspaceOrDatabase, configuredStream.Stream.Name, internal.STREAM_STATUS_STARTED) + + streamFailed := false for shardName, shardState := range streamState.Shards { var tc *psdbconnectv1alpha1.TableCursor @@ -119,21 +123,27 @@ func ReadCommand(ch *Helper) 
*cobra.Command { ch.Logger.Log(internal.LOGLEVEL_INFO, fmt.Sprintf("Using serialized cursor for stream %s", streamStateKey)) if err != nil { ch.Logger.Error(fmt.Sprintf("Invalid serialized cursor for stream %v, failed with [%v]", streamStateKey, err)) - os.Exit(1) + ch.Logger.StreamStatus(keyspaceOrDatabase, configuredStream.Stream.Name, internal.STREAM_STATUS_INCOMPLETE) + streamFailed = true + break } sc, err := ch.Database.Read(ctx, cmd.OutOrStdout(), psc, configuredStream, tc) if err != nil { ch.Logger.Error(err.Error()) - os.Exit(1) + ch.Logger.StreamStatus(keyspaceOrDatabase, configuredStream.Stream.Name, internal.STREAM_STATUS_INCOMPLETE) + streamFailed = true + break } if sc != nil { - // if we get any new state, we assign it here. - // otherwise, the older state is round-tripped back to Airbyte. syncState.Streams[streamStateKey].Shards[shardName] = sc } - ch.Logger.State(syncState) + } + + if !streamFailed { + ch.Logger.StreamState(keyspaceOrDatabase, configuredStream.Stream.Name, syncState.Streams[streamStateKey]) + ch.Logger.StreamStatus(keyspaceOrDatabase, configuredStream.Stream.Name, internal.STREAM_STATUS_COMPLETE) } } }, @@ -153,9 +163,26 @@ func readState(state string, psc internal.PlanetScaleSource, streams []internal. 
Streams: map[string]internal.ShardStates{}, } if state != "" { - err := json.Unmarshal([]byte(state), &syncState) - if err != nil { - return syncState, err + // Try parsing as Airbyte v2 per-stream state array first + var perStreamStates []internal.AirbyteState + if err := json.Unmarshal([]byte(state), &perStreamStates); err == nil && len(perStreamStates) > 0 && perStreamStates[0].Type == internal.STATE_TYPE_STREAM { + logger.Log(internal.LOGLEVEL_INFO, fmt.Sprintf("Parsing Airbyte v2 per-stream state (%d streams)", len(perStreamStates))) + for _, s := range perStreamStates { + if s.Stream != nil && s.Stream.StreamState != nil { + ns := s.Stream.StreamDescriptor.Namespace + if ns == "" { + ns = psc.Database + } + key := ns + ":" + s.Stream.StreamDescriptor.Name + syncState.Streams[key] = *s.Stream.StreamState + } + } + } else { + // Fall back to legacy global state format + err := json.Unmarshal([]byte(state), &syncState) + if err != nil { + return syncState, err + } } } From 0f4df09b20d5fe072810dbcb5286726d40d2123d Mon Sep 17 00:00:00 2001 From: Harry Smaje Date: Tue, 31 Mar 2026 10:57:47 +0100 Subject: [PATCH 3/6] Add tests for Airbyte protocol v2 compliance Logger tests: - StreamState emits correct per-stream format with type=STREAM - Multiple shards included in state output - No legacy "data" field present (would cause LEGACY rejection) - StreamStatus emits TRACE messages with correct status values - JSON round-trip matches exact Airbyte protocol v2 structure Read protocol tests: - Read emits per-stream STATE, not legacy global state - STARTED and COMPLETE emitted for each configured stream - Correct message ordering: STARTED -> STATE -> COMPLETE - Multi-shard state contains all shard cursors - Read errors emit INCOMPLETE and skip state emission --- cmd/airbyte-source/read_protocol_test.go | 373 +++++++++++++++++++++++ cmd/internal/logger_test.go | 194 ++++++++++++ 2 files changed, 567 insertions(+) create mode 100644 cmd/airbyte-source/read_protocol_test.go 
create mode 100644 cmd/internal/logger_test.go diff --git a/cmd/airbyte-source/read_protocol_test.go b/cmd/airbyte-source/read_protocol_test.go new file mode 100644 index 0000000..3f9efce --- /dev/null +++ b/cmd/airbyte-source/read_protocol_test.go @@ -0,0 +1,373 @@ +package airbyte_source + +import ( + "bytes" + "context" + "encoding/json" + "io" + "os" + "testing" + + "github.com/planetscale/airbyte-source/cmd/internal" + psdbconnect "github.com/planetscale/airbyte-source/proto/psdbconnect/v1alpha1" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// mockDatabase implements internal.PlanetScaleDatabase for read protocol tests. +type mockDatabase struct { + shards []string + readFunc func(ctx context.Context, w io.Writer, ps internal.PlanetScaleSource, s internal.ConfiguredStream, tc *psdbconnect.TableCursor) (*internal.SerializedCursor, error) + readCalls int +} + +func (m *mockDatabase) CanConnect(ctx context.Context, ps internal.PlanetScaleSource) error { + return nil +} + +func (m *mockDatabase) DiscoverSchema(ctx context.Context, ps internal.PlanetScaleSource) (internal.Catalog, error) { + return internal.Catalog{}, nil +} + +func (m *mockDatabase) ListShards(ctx context.Context, ps internal.PlanetScaleSource) ([]string, error) { + return m.shards, nil +} + +func (m *mockDatabase) Read(ctx context.Context, w io.Writer, ps internal.PlanetScaleSource, s internal.ConfiguredStream, tc *psdbconnect.TableCursor) (*internal.SerializedCursor, error) { + m.readCalls++ + if m.readFunc != nil { + return m.readFunc(ctx, w, ps, s, tc) + } + newCursor, _ := internal.TableCursorToSerializedCursor(&psdbconnect.TableCursor{ + Shard: tc.Shard, + Keyspace: tc.Keyspace, + Position: "MySQL56/updated-position", + }) + return newCursor, nil +} + +func (m *mockDatabase) Close() error { + return nil +} + +func newTestConfig() []byte { + return []byte(`{"host":"test.psdb.cloud","database":"testdb","username":"user","password":"pass"}`) +} + +func 
newTestCatalog(t *testing.T, streams ...string) string { + t.Helper() + catalog := internal.ConfiguredCatalog{} + for _, name := range streams { + catalog.Streams = append(catalog.Streams, internal.ConfiguredStream{ + Stream: internal.Stream{ + Name: name, + Namespace: "testdb", + }, + SyncMode: "full_refresh", + }) + } + b, err := json.Marshal(catalog) + require.NoError(t, err) + return string(b) +} + +func writeTempFile(t *testing.T, content []byte) string { + t.Helper() + f, err := os.CreateTemp(t.TempDir(), "*.json") + require.NoError(t, err) + _, err = f.Write(content) + require.NoError(t, err) + require.NoError(t, f.Close()) + return f.Name() +} + +func parseOutputMessages(t *testing.T, buf *bytes.Buffer) []internal.AirbyteMessage { + t.Helper() + var messages []internal.AirbyteMessage + decoder := json.NewDecoder(buf) + for decoder.More() { + var msg internal.AirbyteMessage + if err := decoder.Decode(&msg); err != nil { + break + } + messages = append(messages, msg) + } + return messages +} + +func setupReadCommand(t *testing.T, db *mockDatabase, catalogJSON string) (*bytes.Buffer, *Helper) { + t.Helper() + b := bytes.NewBufferString("") + h := &Helper{ + Database: db, + FileReader: testFileReader{content: newTestConfig()}, + Logger: internal.NewLogger(b), + } + return b, h +} + +func TestRead_EmitsPerStreamStateNotLegacy(t *testing.T) { + db := &mockDatabase{shards: []string{"-"}} + catalogJSON := newTestCatalog(t, "users") + + configFile := writeTempFile(t, newTestConfig()) + catalogFile := writeTempFile(t, []byte(catalogJSON)) + + b, h := setupReadCommand(t, db, catalogJSON) + cmd := ReadCommand(h) + cmd.SetOut(b) + require.NoError(t, cmd.Flag("config").Value.Set(configFile)) + require.NoError(t, cmd.Flag("catalog").Value.Set(catalogFile)) + require.NoError(t, cmd.Execute()) + + messages := parseOutputMessages(t, b) + + var stateMessages []internal.AirbyteMessage + for _, msg := range messages { + if msg.Type == internal.STATE { + stateMessages = 
append(stateMessages, msg) + } + } + + require.NotEmpty(t, stateMessages, "should emit at least one STATE message") + + for _, msg := range stateMessages { + assert.Equal(t, internal.STATE_TYPE_STREAM, msg.State.Type, + "state.type must be STREAM, not LEGACY") + require.NotNil(t, msg.State.Stream, + "state.stream must be present") + assert.NotEmpty(t, msg.State.Stream.StreamDescriptor.Name, + "stream_descriptor.name must be set") + assert.NotEmpty(t, msg.State.Stream.StreamDescriptor.Namespace, + "stream_descriptor.namespace must be set") + require.NotNil(t, msg.State.Stream.StreamState, + "stream_state must be present") + } +} + +func TestRead_EmitsStartedAndCompletePerStream(t *testing.T) { + db := &mockDatabase{shards: []string{"-"}} + catalogJSON := newTestCatalog(t, "orders", "products") + + configFile := writeTempFile(t, newTestConfig()) + catalogFile := writeTempFile(t, []byte(catalogJSON)) + + b, h := setupReadCommand(t, db, catalogJSON) + cmd := ReadCommand(h) + cmd.SetOut(b) + require.NoError(t, cmd.Flag("config").Value.Set(configFile)) + require.NoError(t, cmd.Flag("catalog").Value.Set(catalogFile)) + require.NoError(t, cmd.Execute()) + + messages := parseOutputMessages(t, b) + + type streamStatusEntry struct { + Name string + Status string + } + var statuses []streamStatusEntry + for _, msg := range messages { + if msg.Type == internal.TRACE && msg.Trace != nil && + msg.Trace.Type == internal.TRACE_TYPE_STREAM_STATUS && + msg.Trace.StreamStatus != nil { + statuses = append(statuses, streamStatusEntry{ + Name: msg.Trace.StreamStatus.StreamDescriptor.Name, + Status: msg.Trace.StreamStatus.Status, + }) + } + } + + expectedStatuses := []streamStatusEntry{ + {"orders", internal.STREAM_STATUS_STARTED}, + {"orders", internal.STREAM_STATUS_COMPLETE}, + {"products", internal.STREAM_STATUS_STARTED}, + {"products", internal.STREAM_STATUS_COMPLETE}, + } + assert.Equal(t, expectedStatuses, statuses) +} + +func TestRead_StatePerStreamContainsCorrectDescriptor(t 
*testing.T) { + db := &mockDatabase{shards: []string{"-"}} + catalogJSON := newTestCatalog(t, "accounts", "sessions") + + configFile := writeTempFile(t, newTestConfig()) + catalogFile := writeTempFile(t, []byte(catalogJSON)) + + b, h := setupReadCommand(t, db, catalogJSON) + cmd := ReadCommand(h) + cmd.SetOut(b) + require.NoError(t, cmd.Flag("config").Value.Set(configFile)) + require.NoError(t, cmd.Flag("catalog").Value.Set(catalogFile)) + require.NoError(t, cmd.Execute()) + + messages := parseOutputMessages(t, b) + + statesByStream := map[string]internal.AirbyteMessage{} + for _, msg := range messages { + if msg.Type == internal.STATE { + name := msg.State.Stream.StreamDescriptor.Name + statesByStream[name] = msg + } + } + + assert.Contains(t, statesByStream, "accounts") + assert.Contains(t, statesByStream, "sessions") + assert.Equal(t, "testdb", statesByStream["accounts"].State.Stream.StreamDescriptor.Namespace) + assert.Equal(t, "testdb", statesByStream["sessions"].State.Stream.StreamDescriptor.Namespace) +} + +func TestRead_StateEmittedAfterStartedBeforeComplete(t *testing.T) { + db := &mockDatabase{shards: []string{"-"}} + catalogJSON := newTestCatalog(t, "events") + + configFile := writeTempFile(t, newTestConfig()) + catalogFile := writeTempFile(t, []byte(catalogJSON)) + + b, h := setupReadCommand(t, db, catalogJSON) + cmd := ReadCommand(h) + cmd.SetOut(b) + require.NoError(t, cmd.Flag("config").Value.Set(configFile)) + require.NoError(t, cmd.Flag("catalog").Value.Set(catalogFile)) + require.NoError(t, cmd.Execute()) + + messages := parseOutputMessages(t, b) + + startedIdx := -1 + stateIdx := -1 + completeIdx := -1 + + for i, msg := range messages { + if msg.Type == internal.TRACE && msg.Trace != nil && + msg.Trace.StreamStatus != nil && + msg.Trace.StreamStatus.StreamDescriptor.Name == "events" { + if msg.Trace.StreamStatus.Status == internal.STREAM_STATUS_STARTED { + startedIdx = i + } + if msg.Trace.StreamStatus.Status == internal.STREAM_STATUS_COMPLETE { 
+ completeIdx = i + } + } + if msg.Type == internal.STATE && msg.State != nil && + msg.State.Stream != nil && + msg.State.Stream.StreamDescriptor.Name == "events" { + stateIdx = i + } + } + + require.Greater(t, startedIdx, -1, "STARTED should be emitted") + require.Greater(t, stateIdx, -1, "STATE should be emitted") + require.Greater(t, completeIdx, -1, "COMPLETE should be emitted") + + assert.Less(t, startedIdx, stateIdx, "STARTED should come before STATE") + assert.Less(t, stateIdx, completeIdx, "STATE should come before COMPLETE") +} + +func TestRead_MultiShardStateContainsAllShards(t *testing.T) { + db := &mockDatabase{shards: []string{"-80", "80-"}} + catalogJSON := newTestCatalog(t, "data") + + configFile := writeTempFile(t, newTestConfig()) + catalogFile := writeTempFile(t, []byte(catalogJSON)) + + b, h := setupReadCommand(t, db, catalogJSON) + cmd := ReadCommand(h) + cmd.SetOut(b) + require.NoError(t, cmd.Flag("config").Value.Set(configFile)) + require.NoError(t, cmd.Flag("catalog").Value.Set(catalogFile)) + require.NoError(t, cmd.Execute()) + + messages := parseOutputMessages(t, b) + + var stateMsg *internal.AirbyteMessage + for _, msg := range messages { + if msg.Type == internal.STATE { + stateMsg = &msg + } + } + + require.NotNil(t, stateMsg, "should have a STATE message") + require.NotNil(t, stateMsg.State.Stream.StreamState) + assert.Len(t, stateMsg.State.Stream.StreamState.Shards, 2, + "state should contain both shards") + assert.Contains(t, stateMsg.State.Stream.StreamState.Shards, "-80") + assert.Contains(t, stateMsg.State.Stream.StreamState.Shards, "80-") +} + +func TestRead_ReadErrorEmitsIncompleteNotComplete(t *testing.T) { + db := &mockDatabase{ + shards: []string{"-"}, + readFunc: func(ctx context.Context, w io.Writer, ps internal.PlanetScaleSource, s internal.ConfiguredStream, tc *psdbconnect.TableCursor) (*internal.SerializedCursor, error) { + if s.Stream.Name == "bad_table" { + return nil, assert.AnError + } + newCursor, _ := 
internal.TableCursorToSerializedCursor(&psdbconnect.TableCursor{ + Shard: tc.Shard, + Keyspace: tc.Keyspace, + Position: "MySQL56/pos", + }) + return newCursor, nil + }, + } + + catalog := internal.ConfiguredCatalog{ + Streams: []internal.ConfiguredStream{ + { + Stream: internal.Stream{Name: "good_table", Namespace: "testdb"}, + SyncMode: "full_refresh", + }, + { + Stream: internal.Stream{Name: "bad_table", Namespace: "testdb"}, + SyncMode: "full_refresh", + }, + }, + } + catalogBytes, _ := json.Marshal(catalog) + + configFile := writeTempFile(t, newTestConfig()) + catalogFile := writeTempFile(t, catalogBytes) + + b := bytes.NewBufferString("") + h := &Helper{ + Database: db, + FileReader: testFileReader{content: newTestConfig()}, + Logger: internal.NewLogger(b), + } + + cmd := ReadCommand(h) + cmd.SetOut(b) + require.NoError(t, cmd.Flag("config").Value.Set(configFile)) + require.NoError(t, cmd.Flag("catalog").Value.Set(catalogFile)) + require.NoError(t, cmd.Execute()) + + messages := parseOutputMessages(t, b) + + streamStatuses := map[string][]string{} + for _, msg := range messages { + if msg.Type == internal.TRACE && msg.Trace != nil && + msg.Trace.StreamStatus != nil { + name := msg.Trace.StreamStatus.StreamDescriptor.Name + streamStatuses[name] = append(streamStatuses[name], msg.Trace.StreamStatus.Status) + } + } + + assert.Equal(t, []string{internal.STREAM_STATUS_STARTED, internal.STREAM_STATUS_COMPLETE}, + streamStatuses["good_table"]) + assert.Equal(t, []string{internal.STREAM_STATUS_STARTED, internal.STREAM_STATUS_INCOMPLETE}, + streamStatuses["bad_table"]) + + // good_table should have a STATE message, bad_table should NOT + hasGoodState := false + hasBadState := false + for _, msg := range messages { + if msg.Type == internal.STATE && msg.State != nil && msg.State.Stream != nil { + if msg.State.Stream.StreamDescriptor.Name == "good_table" { + hasGoodState = true + } + if msg.State.Stream.StreamDescriptor.Name == "bad_table" { + hasBadState = true + } + } 
+ } + assert.True(t, hasGoodState, "good_table should have a STATE message") + assert.False(t, hasBadState, "bad_table should NOT have a STATE message (it failed)") +} diff --git a/cmd/internal/logger_test.go b/cmd/internal/logger_test.go new file mode 100644 index 0000000..751655e --- /dev/null +++ b/cmd/internal/logger_test.go @@ -0,0 +1,194 @@ +package internal + +import ( + "bytes" + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestStreamState_EmitsPerStreamFormat(t *testing.T) { + b := bytes.NewBufferString("") + logger := NewLogger(b) + + shardStates := ShardStates{ + Shards: map[string]*SerializedCursor{ + "-": {Cursor: "abc123"}, + }, + } + + logger.StreamState("my-database", "users", shardStates) + + var msg AirbyteMessage + err := json.NewDecoder(b).Decode(&msg) + require.NoError(t, err) + + assert.Equal(t, STATE, msg.Type) + require.NotNil(t, msg.State) + assert.Equal(t, STATE_TYPE_STREAM, msg.State.Type) + require.NotNil(t, msg.State.Stream) + assert.Equal(t, "users", msg.State.Stream.StreamDescriptor.Name) + assert.Equal(t, "my-database", msg.State.Stream.StreamDescriptor.Namespace) + require.NotNil(t, msg.State.Stream.StreamState) + assert.Equal(t, "abc123", msg.State.Stream.StreamState.Shards["-"].Cursor) +} + +func TestStreamState_MultipleShards(t *testing.T) { + b := bytes.NewBufferString("") + logger := NewLogger(b) + + shardStates := ShardStates{ + Shards: map[string]*SerializedCursor{ + "-80": {Cursor: "cursor1"}, + "80-": {Cursor: "cursor2"}, + }, + } + + logger.StreamState("sharded-db", "orders", shardStates) + + var msg AirbyteMessage + err := json.NewDecoder(b).Decode(&msg) + require.NoError(t, err) + + assert.Equal(t, STATE_TYPE_STREAM, msg.State.Type) + assert.Equal(t, "orders", msg.State.Stream.StreamDescriptor.Name) + assert.Equal(t, "sharded-db", msg.State.Stream.StreamDescriptor.Namespace) + assert.Len(t, msg.State.Stream.StreamState.Shards, 2) + assert.Equal(t, 
"cursor1", msg.State.Stream.StreamState.Shards["-80"].Cursor) + assert.Equal(t, "cursor2", msg.State.Stream.StreamState.Shards["80-"].Cursor) +} + +func TestStreamState_NoLegacyDataField(t *testing.T) { + b := bytes.NewBufferString("") + logger := NewLogger(b) + + shardStates := ShardStates{ + Shards: map[string]*SerializedCursor{ + "-": {Cursor: "abc"}, + }, + } + + logger.StreamState("db", "table1", shardStates) + + // Parse as raw JSON to verify no "data" key exists (which would indicate LEGACY format) + var raw map[string]json.RawMessage + err := json.NewDecoder(b).Decode(&raw) + require.NoError(t, err) + + var stateRaw map[string]json.RawMessage + err = json.Unmarshal(raw["state"], &stateRaw) + require.NoError(t, err) + + _, hasData := stateRaw["data"] + assert.False(t, hasData, "state should not contain 'data' field (LEGACY format)") + + _, hasType := stateRaw["type"] + assert.True(t, hasType, "state must contain 'type' field") + + _, hasStream := stateRaw["stream"] + assert.True(t, hasStream, "state must contain 'stream' field") +} + +func TestStreamStatus_EmitsTraceMessage(t *testing.T) { + b := bytes.NewBufferString("") + logger := NewLogger(b) + + logger.StreamStatus("my-db", "accounts", STREAM_STATUS_STARTED) + + var msg AirbyteMessage + err := json.NewDecoder(b).Decode(&msg) + require.NoError(t, err) + + assert.Equal(t, TRACE, msg.Type) + require.NotNil(t, msg.Trace) + assert.Equal(t, TRACE_TYPE_STREAM_STATUS, msg.Trace.Type) + assert.True(t, msg.Trace.EmittedAt > 0) + require.NotNil(t, msg.Trace.StreamStatus) + assert.Equal(t, "accounts", msg.Trace.StreamStatus.StreamDescriptor.Name) + assert.Equal(t, "my-db", msg.Trace.StreamStatus.StreamDescriptor.Namespace) + assert.Equal(t, STREAM_STATUS_STARTED, msg.Trace.StreamStatus.Status) +} + +func TestStreamStatus_Complete(t *testing.T) { + b := bytes.NewBufferString("") + logger := NewLogger(b) + + logger.StreamStatus("ns", "tbl", STREAM_STATUS_COMPLETE) + + var msg AirbyteMessage + err := 
json.NewDecoder(b).Decode(&msg) + require.NoError(t, err) + + assert.Equal(t, STREAM_STATUS_COMPLETE, msg.Trace.StreamStatus.Status) +} + +func TestStreamStatus_Incomplete(t *testing.T) { + b := bytes.NewBufferString("") + logger := NewLogger(b) + + logger.StreamStatus("ns", "tbl", STREAM_STATUS_INCOMPLETE) + + var msg AirbyteMessage + err := json.NewDecoder(b).Decode(&msg) + require.NoError(t, err) + + assert.Equal(t, STREAM_STATUS_INCOMPLETE, msg.Trace.StreamStatus.Status) +} + +func TestStreamState_JSONRoundTrip(t *testing.T) { + // Verify the JSON output can be parsed back into the exact expected Airbyte protocol structure + b := bytes.NewBufferString("") + logger := NewLogger(b) + + logger.StreamState("anam-lab", "persona", ShardStates{ + Shards: map[string]*SerializedCursor{ + "-": {Cursor: "encoded-cursor-data"}, + }, + }) + + // Parse into a generic structure to verify exact JSON shape + var raw map[string]interface{} + err := json.NewDecoder(b).Decode(&raw) + require.NoError(t, err) + + assert.Equal(t, "STATE", raw["type"]) + + state := raw["state"].(map[string]interface{}) + assert.Equal(t, "STREAM", state["type"]) + + stream := state["stream"].(map[string]interface{}) + descriptor := stream["stream_descriptor"].(map[string]interface{}) + assert.Equal(t, "persona", descriptor["name"]) + assert.Equal(t, "anam-lab", descriptor["namespace"]) + + streamState := stream["stream_state"].(map[string]interface{}) + shards := streamState["shards"].(map[string]interface{}) + shard := shards["-"].(map[string]interface{}) + assert.Equal(t, "encoded-cursor-data", shard["cursor"]) +} + +func TestMultipleStreamStates_EachIndependent(t *testing.T) { + b := bytes.NewBufferString("") + logger := NewLogger(b) + + logger.StreamState("db", "table1", ShardStates{ + Shards: map[string]*SerializedCursor{"-": {Cursor: "c1"}}, + }) + logger.StreamState("db", "table2", ShardStates{ + Shards: map[string]*SerializedCursor{"-": {Cursor: "c2"}}, + }) + + decoder := json.NewDecoder(b) + 
+ var msg1 AirbyteMessage + require.NoError(t, decoder.Decode(&msg1)) + assert.Equal(t, "table1", msg1.State.Stream.StreamDescriptor.Name) + assert.Equal(t, "c1", msg1.State.Stream.StreamState.Shards["-"].Cursor) + + var msg2 AirbyteMessage + require.NoError(t, decoder.Decode(&msg2)) + assert.Equal(t, "table2", msg2.State.Stream.StreamDescriptor.Name) + assert.Equal(t, "c2", msg2.State.Stream.StreamState.Shards["-"].Cursor) +} From 41314533218a5bfcdc99c75df22a7efd309f36b9 Mon Sep 17 00:00:00 2001 From: Harry Smaje Date: Wed, 15 Apr 2026 12:06:53 +0100 Subject: [PATCH 4/6] Fix lost shard progress and silent success on stream errors Address review feedback: 1. Always emit StreamState after the shard loop, even on failure. Previously, state was only emitted when all shards succeeded. If shard A advanced and shard B failed, shard A's cursor was lost and the next retry would re-read already-synced data. 2. Return an error from the read command when any stream fails. The os.Exit(1) calls were replaced with break to allow other streams to emit proper status messages, but the command was silently exiting successfully. Now uses RunE so cobra surfaces the error and exits non-zero. Also converts remaining os.Exit(1) calls to return errors for consistency and testability, and adds a test for multi-shard partial failure checkpointing. 
--- cmd/airbyte-source/read.go | 44 +++++++++------ cmd/airbyte-source/read_protocol_test.go | 70 +++++++++++++++++++++++- 2 files changed, 93 insertions(+), 21 deletions(-) diff --git a/cmd/airbyte-source/read.go b/cmd/airbyte-source/read.go index c56e04b..d9ec84e 100644 --- a/cmd/airbyte-source/read.go +++ b/cmd/airbyte-source/read.go @@ -22,32 +22,33 @@ func init() { func ReadCommand(ch *Helper) *cobra.Command { readCmd := &cobra.Command{ - Use: "read", - Short: "Converts rows from a PlanetScale database into AirbyteRecordMessages", - Run: func(cmd *cobra.Command, args []string) { + Use: "read", + Short: "Converts rows from a PlanetScale database into AirbyteRecordMessages", + SilenceUsage: true, + RunE: func(cmd *cobra.Command, args []string) error { ctx := cmd.Context() ch.Logger = internal.NewLogger(cmd.OutOrStdout()) if readSourceConfigFilePath == "" { fmt.Fprintf(cmd.ErrOrStderr(), "Please pass path to a valid source config file via the [%v] argument", "config") - os.Exit(1) + return fmt.Errorf("missing config file path") } if readSourceCatalogPath == "" { fmt.Fprintf(cmd.OutOrStdout(), "Please pass path to a valid source catalog file via the [%v] argument", "config") - os.Exit(1) + return fmt.Errorf("missing catalog file path") } psc, err := parseSource(ch.FileReader, readSourceConfigFilePath) if err != nil { fmt.Fprintln(cmd.OutOrStdout(), "Please provide path to a valid configuration file") - return + return err } ch.Logger.Log(internal.LOGLEVEL_INFO, "Ensure database") if err := ch.EnsureDB(psc); err != nil { fmt.Fprintln(cmd.OutOrStdout(), "Unable to connect to PlanetScale Database") - return + return err } defer func() { @@ -60,19 +61,19 @@ func ReadCommand(ch *Helper) *cobra.Command { cs, err := checkConnectionStatus(ctx, ch.Database, psc) if err != nil { ch.Logger.ConnectionStatus(cs) - return + return err } ch.Logger.Log(internal.LOGLEVEL_INFO, "Reading catalog") catalog, err := readCatalog(readSourceCatalogPath) if err != nil { 
ch.Logger.Error(fmt.Sprintf("Unable to read catalog: %+v", err)) - os.Exit(1) + return fmt.Errorf("unable to read catalog: %w", err) } if len(catalog.Streams) == 0 { ch.Logger.Log(internal.LOGLEVEL_ERROR, "Catalog has no streams") - return + return nil } state := "" @@ -81,7 +82,7 @@ func ReadCommand(ch *Helper) *cobra.Command { b, err := os.ReadFile(stateFilePath) if err != nil { ch.Logger.Error(fmt.Sprintf("Unable to read state : %v", err)) - os.Exit(1) + return fmt.Errorf("unable to read state: %w", err) } state = string(b) } @@ -90,16 +91,17 @@ func ReadCommand(ch *Helper) *cobra.Command { shards, err := ch.Database.ListShards(ctx, psc) if err != nil { ch.Logger.Error(fmt.Sprintf("Unable to list shards : %v", err)) - os.Exit(1) + return fmt.Errorf("unable to list shards: %w", err) } ch.Logger.Log(internal.LOGLEVEL_INFO, "Reading state") syncState, err := readState(state, psc, catalog.Streams, shards, ch.Logger) if err != nil { ch.Logger.Error(fmt.Sprintf("Unable to read state : %v", err)) - os.Exit(1) + return fmt.Errorf("unable to read state: %w", err) } + var readErr error for _, configuredStream := range catalog.Streams { keyspaceOrDatabase := configuredStream.Stream.Namespace if keyspaceOrDatabase == "" { @@ -110,7 +112,7 @@ func ReadCommand(ch *Helper) *cobra.Command { if !ok { ch.Logger.Error(fmt.Sprintf("Unable to read state for stream %v", streamStateKey)) ch.Logger.StreamStatus(keyspaceOrDatabase, configuredStream.Stream.Name, internal.STREAM_STATUS_INCOMPLETE) - os.Exit(1) + return fmt.Errorf("unable to read state for stream %v", streamStateKey) } ch.Logger.StreamStatus(keyspaceOrDatabase, configuredStream.Stream.Name, internal.STREAM_STATUS_STARTED) @@ -123,7 +125,6 @@ func ReadCommand(ch *Helper) *cobra.Command { ch.Logger.Log(internal.LOGLEVEL_INFO, fmt.Sprintf("Using serialized cursor for stream %s", streamStateKey)) if err != nil { ch.Logger.Error(fmt.Sprintf("Invalid serialized cursor for stream %v, failed with [%v]", streamStateKey, err)) - 
ch.Logger.StreamStatus(keyspaceOrDatabase, configuredStream.Stream.Name, internal.STREAM_STATUS_INCOMPLETE) streamFailed = true break } @@ -131,7 +132,6 @@ func ReadCommand(ch *Helper) *cobra.Command { sc, err := ch.Database.Read(ctx, cmd.OutOrStdout(), psc, configuredStream, tc) if err != nil { ch.Logger.Error(err.Error()) - ch.Logger.StreamStatus(keyspaceOrDatabase, configuredStream.Stream.Name, internal.STREAM_STATUS_INCOMPLETE) streamFailed = true break } @@ -141,11 +141,19 @@ func ReadCommand(ch *Helper) *cobra.Command { } } - if !streamFailed { - ch.Logger.StreamState(keyspaceOrDatabase, configuredStream.Stream.Name, syncState.Streams[streamStateKey]) + // Always emit state to checkpoint whatever progress was made, + // including partial progress when only some shards succeeded. + ch.Logger.StreamState(keyspaceOrDatabase, configuredStream.Stream.Name, syncState.Streams[streamStateKey]) + + if streamFailed { + ch.Logger.StreamStatus(keyspaceOrDatabase, configuredStream.Stream.Name, internal.STREAM_STATUS_INCOMPLETE) + readErr = fmt.Errorf("read failed for stream %v", streamStateKey) + } else { ch.Logger.StreamStatus(keyspaceOrDatabase, configuredStream.Stream.Name, internal.STREAM_STATUS_COMPLETE) } } + + return readErr }, } readCmd.Flags().StringVar(&readSourceCatalogPath, "catalog", "", "Path to the PlanetScale catalog configuration") diff --git a/cmd/airbyte-source/read_protocol_test.go b/cmd/airbyte-source/read_protocol_test.go index 3f9efce..a0b501d 100644 --- a/cmd/airbyte-source/read_protocol_test.go +++ b/cmd/airbyte-source/read_protocol_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "fmt" "io" "os" "testing" @@ -337,7 +338,9 @@ func TestRead_ReadErrorEmitsIncompleteNotComplete(t *testing.T) { cmd.SetOut(b) require.NoError(t, cmd.Flag("config").Value.Set(configFile)) require.NoError(t, cmd.Flag("catalog").Value.Set(catalogFile)) - require.NoError(t, cmd.Execute()) + + err := cmd.Execute() + require.Error(t, err, "command should 
return an error when a stream fails") messages := parseOutputMessages(t, b) @@ -355,7 +358,8 @@ func TestRead_ReadErrorEmitsIncompleteNotComplete(t *testing.T) { assert.Equal(t, []string{internal.STREAM_STATUS_STARTED, internal.STREAM_STATUS_INCOMPLETE}, streamStatuses["bad_table"]) - // good_table should have a STATE message, bad_table should NOT + // Both streams should have STATE messages: state is always emitted to + // checkpoint whatever progress was made, even on failure. hasGoodState := false hasBadState := false for _, msg := range messages { @@ -369,5 +373,65 @@ func TestRead_ReadErrorEmitsIncompleteNotComplete(t *testing.T) { } } assert.True(t, hasGoodState, "good_table should have a STATE message") - assert.False(t, hasBadState, "bad_table should NOT have a STATE message (it failed)") + assert.True(t, hasBadState, "bad_table should have a STATE message (checkpointing progress)") +} + +func TestRead_MultiShardPartialFailureCheckpointsProgress(t *testing.T) { + db := &mockDatabase{ + shards: []string{"-80", "80-"}, + readFunc: func(ctx context.Context, w io.Writer, ps internal.PlanetScaleSource, s internal.ConfiguredStream, tc *psdbconnect.TableCursor) (*internal.SerializedCursor, error) { + // Fail the "80-" shard to simulate a partial failure. 
+ if tc.Shard == "80-" { + return nil, fmt.Errorf("shard read error") + } + newCursor, _ := internal.TableCursorToSerializedCursor(&psdbconnect.TableCursor{ + Shard: tc.Shard, + Keyspace: tc.Keyspace, + Position: "MySQL56/advanced-position", + }) + return newCursor, nil + }, + } + catalogJSON := newTestCatalog(t, "events") + + configFile := writeTempFile(t, newTestConfig()) + catalogFile := writeTempFile(t, []byte(catalogJSON)) + + b, h := setupReadCommand(t, db, catalogJSON) + cmd := ReadCommand(h) + cmd.SetOut(b) + require.NoError(t, cmd.Flag("config").Value.Set(configFile)) + require.NoError(t, cmd.Flag("catalog").Value.Set(catalogFile)) + + err := cmd.Execute() + require.Error(t, err, "command should fail when a shard errors") + + messages := parseOutputMessages(t, b) + + // A state message must be emitted even on partial failure so that + // progress from successful shards is checkpointed. + var stateMsg *internal.AirbyteMessage + for _, msg := range messages { + if msg.Type == internal.STATE && msg.State != nil { + stateMsg = &msg + } + } + require.NotNil(t, stateMsg, "state should be emitted even on partial failure") + require.NotNil(t, stateMsg.State.Stream) + require.NotNil(t, stateMsg.State.Stream.StreamState) + assert.Len(t, stateMsg.State.Stream.StreamState.Shards, 2, + "state should contain both shards") + + // Stream should be marked INCOMPLETE, not COMPLETE. 
+ var statuses []string + for _, msg := range messages { + if msg.Type == internal.TRACE && msg.Trace != nil && + msg.Trace.StreamStatus != nil && + msg.Trace.StreamStatus.StreamDescriptor.Name == "events" { + statuses = append(statuses, msg.Trace.StreamStatus.Status) + } + } + assert.Contains(t, statuses, internal.STREAM_STATUS_STARTED) + assert.Contains(t, statuses, internal.STREAM_STATUS_INCOMPLETE) + assert.NotContains(t, statuses, internal.STREAM_STATUS_COMPLETE) } From cc30081a9c007627702245284c7408cc8bf3a1fe Mon Sep 17 00:00:00 2001 From: Matthias Crauwels Date: Thu, 18 Jun 2026 10:26:33 +0200 Subject: [PATCH 5/6] Process remaining shards on error and checkpoint partial progress Address the second round of review feedback on the read loop: - On a shard error, continue to the remaining shards instead of break. A single failing shard no longer prevents the other shards in the stream from syncing (the previous break left it up to map-iteration order whether they ran at all). - Persist the cursor returned by Database.Read before handling the error. Read returns a progress-so-far cursor alongside the error on a server timeout; checking the error first discarded it, causing the next attempt to re-read already-synced data. - Extract the namespace/state-key logic into streamStateKeyFor so the read loop and both branches of readState share one implementation and cannot drift. Adds tests proving each fix: TestRead_ShardErrorStillProcessesOtherShards and TestRead_ProgressCursorPersistedOnError. 
Co-Authored-By: Claude Opus 4.8 (1M context) --- cmd/airbyte-source/read.go | 47 ++++++++------ cmd/airbyte-source/read_protocol_test.go | 83 ++++++++++++++++++++++++ 2 files changed, 109 insertions(+), 21 deletions(-) diff --git a/cmd/airbyte-source/read.go b/cmd/airbyte-source/read.go index d9ec84e..ce21933 100644 --- a/cmd/airbyte-source/read.go +++ b/cmd/airbyte-source/read.go @@ -103,11 +103,7 @@ func ReadCommand(ch *Helper) *cobra.Command { var readErr error for _, configuredStream := range catalog.Streams { - keyspaceOrDatabase := configuredStream.Stream.Namespace - if keyspaceOrDatabase == "" { - keyspaceOrDatabase = psc.Database - } - streamStateKey := keyspaceOrDatabase + ":" + configuredStream.Stream.Name + keyspaceOrDatabase, streamStateKey := streamStateKeyFor(configuredStream.Stream.Namespace, configuredStream.Stream.Name, psc.Database) streamState, ok := syncState.Streams[streamStateKey] if !ok { ch.Logger.Error(fmt.Sprintf("Unable to read state for stream %v", streamStateKey)) @@ -126,18 +122,24 @@ func ReadCommand(ch *Helper) *cobra.Command { if err != nil { ch.Logger.Error(fmt.Sprintf("Invalid serialized cursor for stream %v, failed with [%v]", streamStateKey, err)) streamFailed = true - break + // A bad cursor only affects this shard; keep going so the + // other shards in this stream can still sync. + continue } sc, err := ch.Database.Read(ctx, cmd.OutOrStdout(), psc, configuredStream, tc) + // Read can return a cursor reflecting the progress made so far + // alongside an error (e.g. on a server timeout), so persist it + // before handling the error to avoid re-reading already-synced + // data on the next attempt. + if sc != nil { + syncState.Streams[streamStateKey].Shards[shardName] = sc + } if err != nil { ch.Logger.Error(err.Error()) streamFailed = true - break - } - - if sc != nil { - syncState.Streams[streamStateKey].Shards[shardName] = sc + // One shard failing shouldn't stop the others from syncing. 
+ continue } } @@ -166,6 +168,17 @@ type State struct { Shards map[string]map[string]interface{} `json:"shards"` } +// streamStateKeyFor resolves the effective namespace for a stream (defaulting +// to the source database when the catalog leaves it empty) and the composite +// key used to look that stream up in the sync state. Keeping this in one place +// avoids the namespace/key logic drifting between the read loop and readState. +func streamStateKeyFor(namespace, streamName, database string) (string, string) { + if namespace == "" { + namespace = database + } + return namespace, namespace + ":" + streamName +} + func readState(state string, psc internal.PlanetScaleSource, streams []internal.ConfiguredStream, shards []string, logger internal.AirbyteLogger) (internal.SyncState, error) { syncState := internal.SyncState{ Streams: map[string]internal.ShardStates{}, @@ -177,11 +190,7 @@ func readState(state string, psc internal.PlanetScaleSource, streams []internal. logger.Log(internal.LOGLEVEL_INFO, fmt.Sprintf("Parsing Airbyte v2 per-stream state (%d streams)", len(perStreamStates))) for _, s := range perStreamStates { if s.Stream != nil && s.Stream.StreamState != nil { - ns := s.Stream.StreamDescriptor.Namespace - if ns == "" { - ns = psc.Database - } - key := ns + ":" + s.Stream.StreamDescriptor.Name + _, key := streamStateKeyFor(s.Stream.StreamDescriptor.Namespace, s.Stream.StreamDescriptor.Name, psc.Database) syncState.Streams[key] = *s.Stream.StreamState } } @@ -195,11 +204,7 @@ func readState(state string, psc internal.PlanetScaleSource, streams []internal. 
} for _, s := range streams { - keyspaceOrDatabase := s.Stream.Namespace - if keyspaceOrDatabase == "" { - keyspaceOrDatabase = psc.Database - } - stateKey := keyspaceOrDatabase + ":" + s.Stream.Name + keyspaceOrDatabase, stateKey := streamStateKeyFor(s.Stream.Namespace, s.Stream.Name, psc.Database) logger.Log(internal.LOGLEVEL_INFO, fmt.Sprintf("Syncing stream %s with sync mode %s", s.Stream.Name, s.SyncMode)) ignoreCurrentCursor := !s.IncrementalSyncRequested() diff --git a/cmd/airbyte-source/read_protocol_test.go b/cmd/airbyte-source/read_protocol_test.go index a0b501d..3ad9029 100644 --- a/cmd/airbyte-source/read_protocol_test.go +++ b/cmd/airbyte-source/read_protocol_test.go @@ -435,3 +435,86 @@ func TestRead_MultiShardPartialFailureCheckpointsProgress(t *testing.T) { assert.Contains(t, statuses, internal.STREAM_STATUS_INCOMPLETE) assert.NotContains(t, statuses, internal.STREAM_STATUS_COMPLETE) } + +// One failing shard must not stop the remaining shards in the same stream from +// being read (continue, not break). 
+func TestRead_ShardErrorStillProcessesOtherShards(t *testing.T) { + db := &mockDatabase{ + shards: []string{"-40", "40-80", "80-"}, + readFunc: func(ctx context.Context, w io.Writer, ps internal.PlanetScaleSource, s internal.ConfiguredStream, tc *psdbconnect.TableCursor) (*internal.SerializedCursor, error) { + if tc.Shard == "40-80" { + return nil, fmt.Errorf("shard read error") + } + newCursor, _ := internal.TableCursorToSerializedCursor(&psdbconnect.TableCursor{ + Shard: tc.Shard, + Keyspace: tc.Keyspace, + Position: "MySQL56/advanced-position", + }) + return newCursor, nil + }, + } + catalogJSON := newTestCatalog(t, "events") + + configFile := writeTempFile(t, newTestConfig()) + catalogFile := writeTempFile(t, []byte(catalogJSON)) + + b, h := setupReadCommand(t, db, catalogJSON) + cmd := ReadCommand(h) + cmd.SetOut(b) + require.NoError(t, cmd.Flag("config").Value.Set(configFile)) + require.NoError(t, cmd.Flag("catalog").Value.Set(catalogFile)) + + err := cmd.Execute() + require.Error(t, err, "command should fail when a shard errors") + + // Every shard should have been attempted regardless of which one failed + // and regardless of map iteration order. + assert.Equal(t, 3, db.readCalls, "all shards should be read even when one fails") +} + +// Read can hand back a cursor reflecting progress-so-far together with an error +// (e.g. a server timeout). That cursor must be checkpointed, not discarded. +func TestRead_ProgressCursorPersistedOnError(t *testing.T) { + advancedCursor, _ := internal.TableCursorToSerializedCursor(&psdbconnect.TableCursor{ + Shard: "-", + Keyspace: "testdb", + Position: "MySQL56/progress-before-timeout", + }) + db := &mockDatabase{ + shards: []string{"-"}, + readFunc: func(ctx context.Context, w io.Writer, ps internal.PlanetScaleSource, s internal.ConfiguredStream, tc *psdbconnect.TableCursor) (*internal.SerializedCursor, error) { + // Return progress so far alongside the error. 
+ return advancedCursor, fmt.Errorf("timed out mid-sync") + }, + } + catalogJSON := newTestCatalog(t, "events") + + configFile := writeTempFile(t, newTestConfig()) + catalogFile := writeTempFile(t, []byte(catalogJSON)) + + b, h := setupReadCommand(t, db, catalogJSON) + cmd := ReadCommand(h) + cmd.SetOut(b) + require.NoError(t, cmd.Flag("config").Value.Set(configFile)) + require.NoError(t, cmd.Flag("catalog").Value.Set(catalogFile)) + + err := cmd.Execute() + require.Error(t, err, "command should fail when the shard read errors") + + messages := parseOutputMessages(t, b) + + var stateMsg *internal.AirbyteMessage + for _, msg := range messages { + if msg.Type == internal.STATE && msg.State != nil { + stateMsg = &msg + } + } + require.NotNil(t, stateMsg, "state should be emitted even when the read errors") + require.NotNil(t, stateMsg.State.Stream) + require.NotNil(t, stateMsg.State.Stream.StreamState) + persisted, ok := stateMsg.State.Stream.StreamState.Shards["-"] + require.True(t, ok, "state should contain the shard that errored") + require.NotNil(t, persisted) + assert.Equal(t, advancedCursor.Cursor, persisted.Cursor, + "the progress-so-far cursor returned alongside the error should be checkpointed") +} From 6e0a2f1619cd7003acdb103db8fa8e05f6d08453 Mon Sep 17 00:00:00 2001 From: Matthias Crauwels Date: Fri, 19 Jun 2026 09:00:09 +0200 Subject: [PATCH 6/6] Address review feedback on read loop and state parsing - Checkpoint per-shard instead of once after the shard loop, so a crash mid-stream doesn't discard the progress of shards that already completed (mhamza15). - Accumulate per-stream read errors with errors.Join instead of letting each failing stream overwrite the last, so no failure is lost (mhamza15). - Treat an empty Airbyte v2 state array ("[]") as valid v2 state instead of falling through to the legacy object parser, which errors on an array; the stream loop then initializes fresh cursors (maxenglander, mhamza15). 
Adds TestReadState_EmptyV2ArrayInitializesFreshCursors, verified to fail without the fix. - Drop the unused catalogJSON parameter from setupReadCommand (mhamza15). Co-Authored-By: Claude Opus 4.8 (1M context) --- cmd/airbyte-source/read.go | 19 ++++++++++------- cmd/airbyte-source/read_protocol_test.go | 18 ++++++++-------- cmd/airbyte-source/read_test.go | 26 ++++++++++++++++++++++++ 3 files changed, 47 insertions(+), 16 deletions(-) diff --git a/cmd/airbyte-source/read.go b/cmd/airbyte-source/read.go index ce21933..9135234 100644 --- a/cmd/airbyte-source/read.go +++ b/cmd/airbyte-source/read.go @@ -2,6 +2,7 @@ package airbyte_source import ( "encoding/json" + "errors" "fmt" "os" @@ -135,6 +136,10 @@ func ReadCommand(ch *Helper) *cobra.Command { if sc != nil { syncState.Streams[streamStateKey].Shards[shardName] = sc } + // Checkpoint after every shard so that if we crash mid-stream + // the progress of shards that already completed isn't lost and + // re-read on the next attempt. + ch.Logger.StreamState(keyspaceOrDatabase, configuredStream.Stream.Name, syncState.Streams[streamStateKey]) if err != nil { ch.Logger.Error(err.Error()) streamFailed = true @@ -143,13 +148,9 @@ func ReadCommand(ch *Helper) *cobra.Command { } } - // Always emit state to checkpoint whatever progress was made, - // including partial progress when only some shards succeeded. - ch.Logger.StreamState(keyspaceOrDatabase, configuredStream.Stream.Name, syncState.Streams[streamStateKey]) - if streamFailed { ch.Logger.StreamStatus(keyspaceOrDatabase, configuredStream.Stream.Name, internal.STREAM_STATUS_INCOMPLETE) - readErr = fmt.Errorf("read failed for stream %v", streamStateKey) + readErr = errors.Join(readErr, fmt.Errorf("read failed for stream %v", streamStateKey)) } else { ch.Logger.StreamStatus(keyspaceOrDatabase, configuredStream.Stream.Name, internal.STREAM_STATUS_COMPLETE) } @@ -184,9 +185,13 @@ func readState(state string, psc internal.PlanetScaleSource, streams []internal. 
Streams: map[string]internal.ShardStates{}, } if state != "" { - // Try parsing as Airbyte v2 per-stream state array first + // Try parsing as Airbyte v2 per-stream state array first. An empty + // array is valid v2 state (no checkpoints yet) and must be treated as + // v2 rather than falling through to the legacy object parser, which + // would fail to unmarshal it; the stream loop below then initializes + // fresh cursors. var perStreamStates []internal.AirbyteState - if err := json.Unmarshal([]byte(state), &perStreamStates); err == nil && len(perStreamStates) > 0 && perStreamStates[0].Type == internal.STATE_TYPE_STREAM { + if err := json.Unmarshal([]byte(state), &perStreamStates); err == nil && (len(perStreamStates) == 0 || perStreamStates[0].Type == internal.STATE_TYPE_STREAM) { logger.Log(internal.LOGLEVEL_INFO, fmt.Sprintf("Parsing Airbyte v2 per-stream state (%d streams)", len(perStreamStates))) for _, s := range perStreamStates { if s.Stream != nil && s.Stream.StreamState != nil { diff --git a/cmd/airbyte-source/read_protocol_test.go b/cmd/airbyte-source/read_protocol_test.go index 3ad9029..4ab2498 100644 --- a/cmd/airbyte-source/read_protocol_test.go +++ b/cmd/airbyte-source/read_protocol_test.go @@ -96,7 +96,7 @@ func parseOutputMessages(t *testing.T, buf *bytes.Buffer) []internal.AirbyteMess return messages } -func setupReadCommand(t *testing.T, db *mockDatabase, catalogJSON string) (*bytes.Buffer, *Helper) { +func setupReadCommand(t *testing.T, db *mockDatabase) (*bytes.Buffer, *Helper) { t.Helper() b := bytes.NewBufferString("") h := &Helper{ @@ -114,7 +114,7 @@ func TestRead_EmitsPerStreamStateNotLegacy(t *testing.T) { configFile := writeTempFile(t, newTestConfig()) catalogFile := writeTempFile(t, []byte(catalogJSON)) - b, h := setupReadCommand(t, db, catalogJSON) + b, h := setupReadCommand(t, db) cmd := ReadCommand(h) cmd.SetOut(b) require.NoError(t, cmd.Flag("config").Value.Set(configFile)) @@ -153,7 +153,7 @@ func 
TestRead_EmitsStartedAndCompletePerStream(t *testing.T) { configFile := writeTempFile(t, newTestConfig()) catalogFile := writeTempFile(t, []byte(catalogJSON)) - b, h := setupReadCommand(t, db, catalogJSON) + b, h := setupReadCommand(t, db) cmd := ReadCommand(h) cmd.SetOut(b) require.NoError(t, cmd.Flag("config").Value.Set(configFile)) @@ -194,7 +194,7 @@ func TestRead_StatePerStreamContainsCorrectDescriptor(t *testing.T) { configFile := writeTempFile(t, newTestConfig()) catalogFile := writeTempFile(t, []byte(catalogJSON)) - b, h := setupReadCommand(t, db, catalogJSON) + b, h := setupReadCommand(t, db) cmd := ReadCommand(h) cmd.SetOut(b) require.NoError(t, cmd.Flag("config").Value.Set(configFile)) @@ -224,7 +224,7 @@ func TestRead_StateEmittedAfterStartedBeforeComplete(t *testing.T) { configFile := writeTempFile(t, newTestConfig()) catalogFile := writeTempFile(t, []byte(catalogJSON)) - b, h := setupReadCommand(t, db, catalogJSON) + b, h := setupReadCommand(t, db) cmd := ReadCommand(h) cmd.SetOut(b) require.NoError(t, cmd.Flag("config").Value.Set(configFile)) @@ -270,7 +270,7 @@ func TestRead_MultiShardStateContainsAllShards(t *testing.T) { configFile := writeTempFile(t, newTestConfig()) catalogFile := writeTempFile(t, []byte(catalogJSON)) - b, h := setupReadCommand(t, db, catalogJSON) + b, h := setupReadCommand(t, db) cmd := ReadCommand(h) cmd.SetOut(b) require.NoError(t, cmd.Flag("config").Value.Set(configFile)) @@ -397,7 +397,7 @@ func TestRead_MultiShardPartialFailureCheckpointsProgress(t *testing.T) { configFile := writeTempFile(t, newTestConfig()) catalogFile := writeTempFile(t, []byte(catalogJSON)) - b, h := setupReadCommand(t, db, catalogJSON) + b, h := setupReadCommand(t, db) cmd := ReadCommand(h) cmd.SetOut(b) require.NoError(t, cmd.Flag("config").Value.Set(configFile)) @@ -458,7 +458,7 @@ func TestRead_ShardErrorStillProcessesOtherShards(t *testing.T) { configFile := writeTempFile(t, newTestConfig()) catalogFile := writeTempFile(t, []byte(catalogJSON)) - 
b, h := setupReadCommand(t, db, catalogJSON) + b, h := setupReadCommand(t, db) cmd := ReadCommand(h) cmd.SetOut(b) require.NoError(t, cmd.Flag("config").Value.Set(configFile)) @@ -492,7 +492,7 @@ func TestRead_ProgressCursorPersistedOnError(t *testing.T) { configFile := writeTempFile(t, newTestConfig()) catalogFile := writeTempFile(t, []byte(catalogJSON)) - b, h := setupReadCommand(t, db, catalogJSON) + b, h := setupReadCommand(t, db) cmd := ReadCommand(h) cmd.SetOut(b) require.NoError(t, cmd.Flag("config").Value.Set(configFile)) diff --git a/cmd/airbyte-source/read_test.go b/cmd/airbyte-source/read_test.go index 5c30779..228b1d9 100644 --- a/cmd/airbyte-source/read_test.go +++ b/cmd/airbyte-source/read_test.go @@ -8,8 +8,34 @@ import ( "github.com/planetscale/airbyte-source/cmd/internal" psdbconnect "github.com/planetscale/airbyte-source/proto/psdbconnect/v1alpha1" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) +// An empty Airbyte v2 state array ("[]") is valid — it means there are no +// checkpoints yet. It must be parsed as v2 and the streams initialized from +// scratch, not fall through to the legacy object parser, which errors when +// asked to unmarshal an array. 
+func TestReadState_EmptyV2ArrayInitializesFreshCursors(t *testing.T) { + psc := internal.PlanetScaleSource{ + Host: "aws.connect.psdb.cloud", + Database: "mydb", + Username: "user", + Password: "pscale_password", + } + streams := []internal.ConfiguredStream{ + { + Stream: internal.Stream{Name: "table1", Namespace: "mydb"}, + SyncMode: "incremental", + }, + } + shards := []string{"-"} + + syncState, err := readState("[]", psc, streams, shards, internal.NewLogger(os.Stdout)) + require.NoError(t, err, "empty v2 state array should not error") + require.Contains(t, syncState.Streams, "mydb:table1") + require.Contains(t, syncState.Streams["mydb:table1"].Shards, "-") +} + // This tests that when starting_gtids are passed AND a state file is passed, // the state file takes precedence. func TestRead_StartingGtidsAndState(t *testing.T) {