diff --git a/driver/conn.go b/driver/conn.go index d95c3fa..b0c2bc5 100644 --- a/driver/conn.go +++ b/driver/conn.go @@ -20,21 +20,37 @@ import ( const maxSendChunk = ipcutil.MaxMessageSize - 64 // Conn implements net.Conn over a Pilot Protocol stream. +// +// Concurrency: like *net.TCPConn, a Conn may be used by at most one reader +// and one writer goroutine at a time. Read serialises concurrent callers +// with readMu (so recvBuf is never corrupted), but interleaving two +// readers still yields each a non-deterministic slice of the stream; do +// not do that. Write is safe for one writer; concurrent writers may +// interleave chunks on the wire. SetDeadline/SetReadDeadline/ +// SetWriteDeadline are safe to call from any goroutine. type Conn struct { id uint32 localAddr protocol.SocketAddr remoteAddr protocol.SocketAddr ipc *ipcClient recvCh chan []byte - recvBuf []byte // leftover from previous read - closed bool - mu sync.Mutex - readDeadline time.Time - deadlineCh chan struct{} // closed when deadline is set/changed + // readMu serialises Read so recvBuf (leftover from a previous read) + // cannot be observed or mutated by two readers at once. + readMu sync.Mutex + recvBuf []byte // leftover from previous read; guarded by readMu + + mu sync.Mutex + closed bool + readDeadline time.Time + writeDeadline time.Time + deadlineCh chan struct{} // closed when deadline is set/changed } func (c *Conn) Read(b []byte) (int, error) { + c.readMu.Lock() + defer c.readMu.Unlock() + // Drain leftover first if len(c.recvBuf) > 0 { n := copy(b, c.recvBuf) @@ -78,17 +94,36 @@ func (c *Conn) Read(b []byte) (int, error) { } } +// Write enqueues b to the local daemon over IPC, splitting it into +// maxSendChunk-sized cmdSend frames. +// +// Send semantics: a nil error and n == len(b) mean every chunk was handed +// to the local daemon over IPC — NOT that the bytes were transmitted on the +// wire or acknowledged by the peer. The Pilot stream layer in the daemon +// handles retransmission/ordering after this point; Write does not block on +// it. Errors reported here are local IPC write failures or a passed +// write deadline. func (c *Conn) Write(b []byte) (int, error) { c.mu.Lock() if c.closed { c.mu.Unlock() return 0, protocol.ErrConnClosed } + wdl := c.writeDeadline c.mu.Unlock() + if !wdl.IsZero() && !time.Now().Before(wdl) { + return 0, os.ErrDeadlineExceeded + } + total := len(b) written := 0 for written < total { + // Honour the write deadline between chunks so a large, slow write + // to a backed-up IPC socket cannot block past the deadline. + if !wdl.IsZero() && !time.Now().Before(wdl) { + return written, os.ErrDeadlineExceeded + } chunk := total - written if chunk > maxSendChunk { chunk = maxSendChunk @@ -125,7 +160,15 @@ func (c *Conn) LocalAddr() net.Addr { return pilotAddr(c.localAddr) } func (c *Conn) RemoteAddr() net.Addr { return pilotAddr(c.remoteAddr) } func (c *Conn) SetDeadline(t time.Time) error { - c.SetReadDeadline(t) + c.mu.Lock() + c.readDeadline = t + c.writeDeadline = t + // Signal any blocked Read to re-check. + if c.deadlineCh != nil { + close(c.deadlineCh) + } + c.deadlineCh = make(chan struct{}) + c.mu.Unlock() return nil } @@ -141,7 +184,18 @@ func (c *Conn) SetReadDeadline(t time.Time) error { return nil } -func (c *Conn) SetWriteDeadline(t time.Time) error { return nil } +// SetWriteDeadline sets a deadline for Write. A passed deadline causes Write +// to return os.ErrDeadlineExceeded. Because Write never blocks waiting on a +// remote peer (it only enqueues chunks to the local daemon over IPC), the +// deadline is enforced before each chunk rather than via an interrupt — a +// zero time clears it. This satisfies the net.Conn contract instead of the +// previous silent no-op. +func (c *Conn) SetWriteDeadline(t time.Time) error { + c.mu.Lock() + c.writeDeadline = t + c.mu.Unlock() + return nil +} // pilotAddr wraps SocketAddr to satisfy net.Addr. type pilotAddr protocol.SocketAddr diff --git a/driver/driver.go b/driver/driver.go index c9fba5d..74b9aed 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -27,6 +27,14 @@ func DefaultSocketPath() string { return "/tmp/pilot.sock" } +// defaultDialTimeout bounds DialAddr / Listen / Broadcast so a wedged or +// non-responsive daemon can't block the caller forever. The daemon resolves +// + dials within this window in the normal case (direct punch or relay +// fallback both complete well under it); callers needing a tighter bound use +// DialAddrTimeout. Operations that legitimately block in the daemon +// (WaitForTrust) deliberately keep the unbounded sendAndWait path. +const defaultDialTimeout = 30 * time.Second + // Handshake sub-commands (must match daemon SubHandshake* constants) const ( subHandshakeSend byte = 0x01 @@ -83,32 +91,11 @@ func (d *Driver) Dial(addr string) (*Conn, error) { return d.DialAddr(sa.Addr, sa.Port) } -// DialAddr opens a stream connection to a remote Addr + port. +// DialAddr opens a stream connection to a remote Addr + port. It applies +// defaultDialTimeout so a non-responsive daemon cannot block the caller +// indefinitely; use DialAddrTimeout to supply an explicit bound. func (d *Driver) DialAddr(dst protocol.Addr, port uint16) (*Conn, error) { - msg := make([]byte, 1+protocol.AddrSize+2) - msg[0] = cmdDial - dst.MarshalTo(msg, 1) - binary.BigEndian.PutUint16(msg[1+protocol.AddrSize:], port) - - resp, err := d.ipc.sendAndWait(msg, cmdDialOK) - if err != nil { - return nil, fmt.Errorf("dial: %w", err) - } - - if len(resp) < 4 { - return nil, fmt.Errorf("invalid dial response") - } - - connID := binary.BigEndian.Uint32(resp[0:4]) - recvCh := d.ipc.registerRecvCh(connID) - - return &Conn{ - id: connID, - remoteAddr: protocol.SocketAddr{Addr: dst, Port: port}, - ipc: d.ipc, - recvCh: recvCh, - deadlineCh: make(chan struct{}), - }, nil + return d.DialAddrTimeout(dst, port, defaultDialTimeout) } // DialAddrTimeout opens a stream connection with a client-side timeout. @@ -146,7 +133,7 @@ func (d *Driver) Listen(port uint16) (*Listener, error) { msg[0] = cmdBind binary.BigEndian.PutUint16(msg[1:3], port) - resp, err := d.ipc.sendAndWait(msg, cmdBindOK) + resp, err := d.ipc.sendAndWaitTimeout(msg, cmdBindOK, defaultDialTimeout) if err != nil { return nil, fmt.Errorf("bind: %w", err) } @@ -167,6 +154,13 @@ func (d *Driver) Listen(port uint16) (*Listener, error) { // SendTo sends an unreliable unicast datagram to the given address:port. // Broadcast addresses (Node=0xFFFFFFFF) are not accepted on this path; use // Broadcast, which requires the daemon's admin token. +// +// Send semantics: this is fire-and-forget. A nil return means only that the +// frame was successfully enqueued to the local daemon over IPC — it does NOT +// indicate the datagram was transmitted on the wire, routed, or delivered to +// the peer. Datagrams are unreliable; there is no acknowledgement. The only +// errors reported are local IPC failures (empty/oversized frame, socket +// write error). func (d *Driver) SendTo(dst protocol.Addr, port uint16, data []byte) error { if dst.IsBroadcast() { return fmt.Errorf("broadcast address requires admin token: use Driver.Broadcast") @@ -192,7 +186,7 @@ func (d *Driver) Broadcast(netID uint16, port uint16, data []byte, adminToken st binary.BigEndian.PutUint16(msg[5:7], uint16(len(tokenBytes))) copy(msg[7:7+len(tokenBytes)], tokenBytes) copy(msg[7+len(tokenBytes):], data) - if _, err := d.ipc.sendAndWait(msg, cmdBroadcastOK); err != nil { + if _, err := d.ipc.sendAndWaitTimeout(msg, cmdBroadcastOK, defaultDialTimeout); err != nil { return err } return nil diff --git a/driver/ipc.go b/driver/ipc.go index c2a88fc..eb55c59 100644 --- a/driver/ipc.go +++ b/driver/ipc.go @@ -109,6 +109,19 @@ type pendingResponse struct { payload []byte } +// ipcWaiter is the per-request reply slot. A waiter is identified by its +// pointer identity (abandonWaiter clears the active slot only if it is still +// this exact waiter), so a late reply for an abandoned request is dropped +// rather than delivered to the next caller. expect is the cmd this request +// awaits; deliverReply only accepts a frame whose cmd is expect or cmdError, +// dropping anything else (e.g. a stale reply for a prior request that arrives +// WHILE this one is in flight) and leaving the waiter armed. ch has capacity +// 1 so readLoop never blocks delivering a reply. +type ipcWaiter struct { + expect byte + ch chan *pendingResponse +} + type ipcClient struct { conn net.Conn @@ -121,8 +134,27 @@ type ipcClient struct { // instead of sync.Mutex lets goroutines waiting for the semaphore be // woken on doneCh close, preventing a deadlock when the daemon closes // while many goroutines are queued behind a slow sendAndWait. - waitSem chan struct{} // capacity 1 - pending chan *pendingResponse // capacity 16; buffers reply frames from readLoop + waitSem chan struct{} // capacity 1 + + // waiterMu guards the active-waiter slot. The IPC wire protocol has no + // request IDs and the daemon dispatches requests concurrently, so a + // reply that arrives AFTER its request timed out (or was abandoned) + // must NOT be handed to the next caller — doing so mis-correlates a + // stale reply with an unrelated request (wrong conn_id / result). + // + // We mitigate this client-side: every sendAndWait registers a private + // reply channel as the active waiter. readLoop delivers a reply only to + // the CURRENTLY active waiter (and only when the cmd matches). When a + // request times out or is abandoned it clears the active slot, so a + // late reply finds no matching waiter and is dropped. + // + // TODO(PILOT): cross-process ordering correctness ultimately requires + // daemon-side request IDs echoed in every reply envelope. That is a + // coordinated wire-protocol change (daemon + version bump) and is + // deliberately NOT done here; this slot scheme is the client-side + // mitigation that keeps a late reply from being mis-delivered. + waiterMu sync.Mutex + activeWaiter *ipcWaiter // current in-flight waiter, nil when idle recvMu sync.Mutex recvChs map[uint32]chan []byte // conn_id → data channel @@ -147,7 +179,6 @@ func newIPCClient(socketPath string) (*ipcClient, error) { c := &ipcClient{ conn: conn, waitSem: make(chan struct{}, 1), - pending: make(chan *pendingResponse, 16), recvChs: make(map[uint32]chan []byte), pendRecv: make(map[uint32][][]byte), pendAccept: make(map[uint16][][]byte), @@ -175,10 +206,10 @@ func (c *ipcClient) close() error { // Server-pushed frames (cmdRecv, cmdCloseOK, cmdRecvFrom, cmdAccept) are // routed by cmd to their per-connection channels. cmdCloseOK is always // a server-push (remote FIN); Driver.Disconnect uses send() not -// sendAndWait() so it never waits for cmdCloseOK in pending. -// Known response cmds are forwarded to c.pending for sendAndWait. -// Unknown cmds are silently dropped — they never reach pending, so -// sendAndWaitTimeout can use a single read without a discard loop. +// sendAndWait() so it never waits for cmdCloseOK. +// Known response cmds are delivered to the active sendAndWait waiter (if +// any). A reply with no active waiter — e.g. one that arrives after its +// request timed out — is dropped. Unknown cmds are silently dropped. func (c *ipcClient) readLoop() { defer c.cleanup() for { @@ -202,11 +233,12 @@ func (c *ipcClient) readLoop() { cmdDeregisterOK, cmdSetTagsOK, cmdSetWebhookOK, cmdNetworkOK, cmdHealthOK, cmdManagedOK, cmdRotateKeyOK, cmdBroadcastOK, cmdPreferDirectOK, cmdSubmitBadgeOK, cmdEnrollRecoveryOK: - // Known response cmds: route to pending for the in-flight sendAndWait. - select { - case c.pending <- &pendingResponse{cmd: cmd, payload: append([]byte(nil), payload...)}: - default: - } + // Known response cmds: deliver to the active sendAndWait waiter. + // If there is no active waiter (the request timed out / was + // abandoned, or this is a duplicate), the reply is dropped — + // this is the client-side mitigation that prevents a stale + // reply from being mis-delivered to a later, unrelated request. + c.deliverReply(&pendingResponse{cmd: cmd, payload: append([]byte(nil), payload...)}) // default: unknown cmd — silently drop (version mismatch, test injection, etc.) } } @@ -295,15 +327,12 @@ func (c *ipcClient) dispatchPush(cmd byte, payload []byte) { func (c *ipcClient) cleanup() { close(c.doneCh) - // Drain all buffered responses. - for { - select { - case <-c.pending: - default: - goto drained - } - } -drained: + // Clear any active waiter. The in-flight sendAndWaitTimeout selects on + // doneCh and will return "daemon disconnected"; dropping the slot here + // keeps a racing late reply from being delivered after shutdown. + c.waiterMu.Lock() + c.activeWaiter = nil + c.waiterMu.Unlock() // Close all receive channels c.recvMu.Lock() @@ -348,11 +377,73 @@ func (c *ipcClient) sendAndWait(data []byte, expectCmd byte) ([]byte, error) { return c.sendAndWaitTimeout(data, expectCmd, 0) } +// deliverReply hands a reply frame to the active waiter, if any. +// +// A frame is delivered only when there is an active waiter AND the frame's +// cmd is what that waiter expects (or cmdError, which is valid for any +// request). Anything else is dropped: +// - no active waiter: the request already timed out / was abandoned, or +// this is a duplicate; +// - cmd mismatch: a stale reply for a PRIOR (abandoned) request that +// happens to arrive while a different request is in flight — delivering +// it would mis-correlate, so it is dropped and the current waiter stays +// armed for its own reply. +// +// The active waiter's channel has capacity 1 and is single-use, so the send +// never blocks. When a frame is delivered the slot is cleared so a second +// (duplicate) frame for the same request is dropped rather than re-delivered. +func (c *ipcClient) deliverReply(resp *pendingResponse) { + c.waiterMu.Lock() + w := c.activeWaiter + if w == nil || (resp.cmd != w.expect && resp.cmd != cmdError) { + c.waiterMu.Unlock() + return + } + c.activeWaiter = nil + c.waiterMu.Unlock() + w.ch <- resp +} + +// registerWaiter installs a fresh active waiter for a request expecting +// expect (or cmdError) and returns it. Any previous waiter is replaced; +// since waitSem serialises callers there is normally at most one, but +// replacing defensively guarantees a stale slot never lingers. +func (c *ipcClient) registerWaiter(expect byte) *ipcWaiter { + c.waiterMu.Lock() + w := &ipcWaiter{expect: expect, ch: make(chan *pendingResponse, 1)} + c.activeWaiter = w + c.waiterMu.Unlock() + return w +} + +// abandonWaiter clears the active slot iff it is still w (pointer identity). +// Called on timeout/disconnect so a reply that arrives afterwards finds no +// active waiter and is dropped by deliverReply. If deliverReply already +// consumed the slot (reply won the race), this is a no-op. +func (c *ipcClient) abandonWaiter(w *ipcWaiter) { + c.waiterMu.Lock() + if c.activeWaiter == w { + c.activeWaiter = nil + } + c.waiterMu.Unlock() +} + // sendAndWaitTimeout serialises at most one request/reply pair at a time // via waitSem. timeout=0 means wait forever. The timer is started BEFORE // acquiring the semaphore so the timeout applies to queue wait + reply // wait combined — without this, goroutines queued behind the semaphore // can't time out and pile up indefinitely under high concurrency. +// +// Stale-reply safety: the reply is delivered through a private, single-use +// waiter channel registered immediately before the request is written. On +// timeout/disconnect the waiter is abandoned, so a late reply for THIS +// request is dropped by deliverReply instead of being handed to the next +// caller (which would mis-correlate a stale conn_id / result). +// +// TODO(PILOT): full cross-process ordering correctness needs daemon-side +// request IDs echoed in every reply (coordinated daemon change + version +// bump). This client-side scheme only guarantees a late reply is not +// MIS-DELIVERED; it cannot recover the abandoned request's actual result. func (c *ipcClient) sendAndWaitTimeout(data []byte, expectCmd byte, timeout time.Duration) ([]byte, error) { if len(data) < 1 { return nil, fmt.Errorf("ipc: empty request") @@ -378,24 +469,18 @@ func (c *ipcClient) sendAndWaitTimeout(data []byte, expectCmd byte, timeout time } defer func() { <-c.waitSem }() - // Drain all stale replies buffered before this request was sent. - for { - select { - case <-c.pending: - default: - goto drained - } - } -drained: + // Register the private reply slot BEFORE writing so readLoop can only + // ever route this request's reply here; mismatched (stale) frames are + // dropped by deliverReply and never reach w.ch. + w := c.registerWaiter(expectCmd) if err := c.writeFrame(data[0], data[1:]); err != nil { + c.abandonWaiter(w) return nil, err } - // Unknown cmds are dropped in readLoop, so the first frame in pending - // is always either the expected response or cmdError. select { - case resp := <-c.pending: + case resp := <-w.ch: if resp.cmd == cmdError { if len(resp.payload) >= 2 { return nil, fmt.Errorf("daemon: %s", string(resp.payload[2:])) @@ -407,8 +492,12 @@ drained: } return resp.payload, nil case <-c.doneCh: + c.abandonWaiter(w) return nil, fmt.Errorf("daemon disconnected") case <-timer: + // Abandon the slot so the late reply is dropped, not handed to the + // next caller. + c.abandonWaiter(w) return nil, fmt.Errorf("dial timeout") } } diff --git a/driver/zz_conn_test.go b/driver/zz_conn_test.go index 6ef215c..32c4650 100644 --- a/driver/zz_conn_test.go +++ b/driver/zz_conn_test.go @@ -10,6 +10,7 @@ import ( "testing" "time" + "github.com/pilot-protocol/common/ipcutil" "github.com/pilot-protocol/common/protocol" ) @@ -171,7 +172,7 @@ func TestSetReadDeadlineUnblocksReader(t *testing.T) { } } -func TestSetDeadlineDelegatesToRead(t *testing.T) { +func TestSetDeadlineSetsReadAndWrite(t *testing.T) { t.Parallel() c := &Conn{ recvCh: make(chan []byte), @@ -184,14 +185,56 @@ func TestSetDeadlineDelegatesToRead(t *testing.T) { if !c.readDeadline.Equal(dl) { t.Errorf("readDeadline = %v, want %v", c.readDeadline, dl) } + if !c.writeDeadline.Equal(dl) { + t.Errorf("writeDeadline = %v, want %v", c.writeDeadline, dl) + } } -func TestSetWriteDeadlineNoop(t *testing.T) { +// TestSetWriteDeadlinePastBlocksWrite verifies SetWriteDeadline is no longer +// a no-op: a deadline already in the past makes Write fail with +// os.ErrDeadlineExceeded instead of silently succeeding. +func TestSetWriteDeadlinePastBlocksWrite(t *testing.T) { t.Parallel() - c := &Conn{} - if err := c.SetWriteDeadline(time.Now()); err != nil { - t.Errorf("expected nil, got %v", err) + clientSide, serverSide := net.Pipe() + defer clientSide.Close() + defer serverSide.Close() + + ipc := &ipcClient{ + conn: clientSide, + waitSem: make(chan struct{}, 1), + recvChs: make(map[uint32]chan []byte), + pendRecv: make(map[uint32][][]byte), + acceptChs: make(map[uint16]chan []byte), + dgCh: make(chan *Datagram, 1), + doneCh: make(chan struct{}), + } + c := &Conn{id: 1, ipc: ipc, deadlineCh: make(chan struct{})} + + if err := c.SetWriteDeadline(time.Now().Add(-time.Second)); err != nil { + t.Fatalf("SetWriteDeadline: %v", err) + } + n, err := c.Write([]byte("data")) + if !errors.Is(err, os.ErrDeadlineExceeded) { + t.Fatalf("Write err = %v, want os.ErrDeadlineExceeded", err) + } + if n != 0 { + t.Errorf("Write n = %d, want 0", n) + } + + // Clearing the deadline (zero time) restores normal writes. + if err := c.SetWriteDeadline(time.Time{}); err != nil { + t.Fatalf("clear SetWriteDeadline: %v", err) + } + done := make(chan struct{}) + go func() { + defer close(done) + _ = serverSide.SetReadDeadline(time.Now().Add(2 * time.Second)) + _, _ = ipcutil.Read(serverSide) + }() + if _, err := c.Write([]byte("ok")); err != nil { + t.Fatalf("Write after clearing deadline: %v", err) } + <-done } func TestConnAddrs(t *testing.T) { diff --git a/driver/zz_conn_write_test.go b/driver/zz_conn_write_test.go index f208518..c525b7b 100644 --- a/driver/zz_conn_write_test.go +++ b/driver/zz_conn_write_test.go @@ -25,7 +25,6 @@ func TestConnWriteChunksLargePayload(t *testing.T) { ipc := &ipcClient{ conn: clientSide, waitSem: make(chan struct{}, 1), - pending: make(chan *pendingResponse, 16), recvChs: make(map[uint32]chan []byte), pendRecv: make(map[uint32][][]byte), acceptChs: make(map[uint16]chan []byte), @@ -123,7 +122,6 @@ func TestConnWriteSinglePayloadNotSplit(t *testing.T) { ipc := &ipcClient{ conn: clientSide, waitSem: make(chan struct{}, 1), - pending: make(chan *pendingResponse, 16), recvChs: make(map[uint32]chan []byte), pendRecv: make(map[uint32][][]byte), acceptChs: make(map[uint16]chan []byte), diff --git a/driver/zz_driver_simple_ops_test.go b/driver/zz_driver_simple_ops_test.go index b310ba2..ddc5d4b 100644 --- a/driver/zz_driver_simple_ops_test.go +++ b/driver/zz_driver_simple_ops_test.go @@ -111,7 +111,7 @@ func TestDriverWaitForTrust(t *testing.T) { // - the request frame is exactly [cmdPreferDirect(0x2D)][big-endian uint32 nodeID] (5 bytes), // - the cmdPreferDirectOK (0x2E) reply is routed/accepted by readLoop (not dropped) — // proven by the happy path returning a non-nil result and nil error, which only -// happens if the OK frame reaches the in-flight sendAndWait via c.pending, +// happens if the OK frame reaches the in-flight sendAndWait's active waiter, // - the daemon's returned routing state is unmarshalled and surfaced. func TestDriverPreferDirect(t *testing.T) { t.Parallel() diff --git a/driver/zz_ipc_listener_test.go b/driver/zz_ipc_listener_test.go index 6cf7db2..122a5c6 100644 --- a/driver/zz_ipc_listener_test.go +++ b/driver/zz_ipc_listener_test.go @@ -480,6 +480,105 @@ func TestSendAndWaitTimeoutFires(t *testing.T) { } } +// TestLateReplyAfterTimeoutNotMisdelivered is the regression test for the +// stale-IPC-reply mis-correlation bug. The daemon delays the reply to the +// first request past its timeout; that reply then arrives while a SECOND, +// unrelated request is in flight. The fix (per-request waiter slots that are +// abandoned on timeout) must DROP the late reply so the second caller gets +// its OWN reply, not the stale one — otherwise it would consume a reply with +// the wrong cmd / payload (in production: wrong conn_id / result). +// +// NOTE: cross-process ordering correctness ultimately needs daemon-side +// request IDs; this only guarantees the late reply is not mis-delivered. +func TestLateReplyAfterTimeoutNotMisdelivered(t *testing.T) { + t.Parallel() + d := newFakeDaemon(t) + defer d.close() + drv, err := Connect(d.path) + if err != nil { + t.Fatalf("connect: %v", err) + } + defer drv.Close() + + // First request (cmdInfo): the daemon never replies promptly. Instead it + // schedules a STALE cmdInfoOK to be written to the conn slightly AFTER + // the first call has timed out AND a second, unrelated request is in + // flight. This reproduces the production race: a timed-out dial's late + // reply arriving during a later request. + staleArm := make(chan struct{}) // close to release the stale reply + staleSent := make(chan struct{}) // closed once the stale reply is written + d.onCmd(cmdInfo, func(_ []byte) [][]byte { + // acceptLoop already holds d.mu while invoking this handler, so read + // d.conn directly — re-locking would deadlock. + conn := d.conn + go func() { + <-staleArm + _ = ipcutil.Write(conn, jsonOK(cmdInfoOK, `{"stale":"info-reply"}`)) + close(staleSent) + }() + return nil + }) + + // 1) First call times out (no reply within the window). The buggy code + // left this request's reply destined for a shared buffer; the fix + // abandons the waiter slot here. + if _, err := drv.ipc.sendAndWaitTimeout([]byte{cmdInfo}, cmdInfoOK, 40*time.Millisecond); err == nil { + t.Fatal("expected first request to time out") + } + + // 2) Start a second, unrelated request (cmdHealth) that the daemon does + // NOT answer, so it is parked waiting for a reply. While it waits we + // release the stale cmdInfoOK. Under the bug, that stale frame would + // be handed to this Health waiter (mis-correlation: wrong cmd/payload + // — in production a wrong conn_id). The fix drops it because its cmd + // (cmdInfoOK) doesn't match what Health expects, so Health correctly + // times out instead of consuming the stale reply. + type result struct { + payload []byte + err error + } + resCh := make(chan result, 1) + go func() { + p, err := drv.ipc.sendAndWaitTimeout([]byte{cmdHealth}, cmdHealthOK, 400*time.Millisecond) + resCh <- result{p, err} + }() + + // Give Health time to park in its wait, then release the stale reply so + // it lands while Health is in flight. + time.Sleep(60 * time.Millisecond) + close(staleArm) + select { + case <-staleSent: + case <-time.After(2 * time.Second): + t.Fatal("stale reply was never written — test did not exercise the race") + } + + res := <-resCh + if res.err == nil { + t.Fatalf("Health must NOT succeed by consuming the stale info reply; got payload %q", res.payload) + } + // The stale reply is cmdInfoOK; if it were mis-delivered, Health would + // have returned either that payload or an "unexpected reply" cmd error + // rather than its own timeout. + if want := "dial timeout"; res.err.Error() != want { + t.Fatalf("Health err = %q, want %q (stale reply was mis-delivered)", res.err.Error(), want) + } + + // 3) Prove the connection is still healthy: a fresh Health that the + // daemon answers must succeed, confirming the dropped stale frame did + // not poison the waiter slot. + d.onCmd(cmdHealth, func(_ []byte) [][]byte { + return [][]byte{jsonOK(cmdHealthOK, `{"fresh":"health-reply"}`)} + }) + got, err := drv.Health() + if err != nil { + t.Fatalf("Health after dropped stale frame failed (slot poisoned?): %v", err) + } + if got["fresh"] != "health-reply" { + t.Fatalf("Health got %v, want fresh health-reply", got) + } +} + func TestSendAndWaitReturnsWhenDaemonDisconnects(t *testing.T) { t.Parallel() d := newFakeDaemon(t) diff --git a/registry/client/client.go b/registry/client/client.go index a6f6e46..d5a2f31 100644 --- a/registry/client/client.go +++ b/registry/client/client.go @@ -27,6 +27,12 @@ import ( // recoverable condition. var ErrNoRegistry = errors.New("registry client not configured") +// dialTimeout bounds every TCP/TLS connection attempt to the registry so an +// unreachable or black-holed registry host cannot hang startup or any +// registry operation indefinitely. It matches the per-attempt timeout the +// reconnect paths already use. +const dialTimeout = 5 * time.Second + // Client talks to a registry server over TCP (optionally TLS). // It automatically reconnects if the connection drops. // @@ -126,7 +132,7 @@ func (c *Client) sign(challenge string) (string, error) { } func Dial(addr string) (*Client, error) { - conn, err := net.Dial("tcp", addr) + conn, err := net.DialTimeout("tcp", addr, dialTimeout) if err != nil { return nil, fmt.Errorf("dial registry: %w", err) } @@ -151,7 +157,7 @@ func DialPool(addr string, size int) (*Client, error) { if size <= 0 { size = 1 } - primary, err := net.Dial("tcp", addr) + primary, err := net.DialTimeout("tcp", addr, dialTimeout) if err != nil { return nil, fmt.Errorf("dial registry: %w", err) } @@ -169,7 +175,7 @@ func DialTLS(addr string, tlsConfig *tls.Config) (*Client, error) { if tlsConfig == nil { return nil, fmt.Errorf("TLS config required; use DialTLSPinned for certificate pinning") } - conn, err := tls.Dial("tcp", addr, tlsConfig) + conn, err := tls.DialWithDialer(&net.Dialer{Timeout: dialTimeout}, "tcp", addr, tlsConfig) if err != nil { return nil, fmt.Errorf("dial registry TLS: %w", err) } @@ -184,7 +190,7 @@ func DialTLSPool(addr string, tlsConfig *tls.Config, size int) (*Client, error) if size <= 0 { size = 1 } - primary, err := tls.Dial("tcp", addr, tlsConfig) + primary, err := tls.DialWithDialer(&net.Dialer{Timeout: dialTimeout}, "tcp", addr, tlsConfig) if err != nil { return nil, fmt.Errorf("dial registry TLS: %w", err) } @@ -210,9 +216,9 @@ func (c *Client) initPool(size int, tlsCfg *tls.Config) error { var conn net.Conn var err error if tlsCfg != nil { - conn, err = tls.Dial("tcp", c.addr, tlsCfg) + conn, err = tls.DialWithDialer(&net.Dialer{Timeout: dialTimeout}, "tcp", c.addr, tlsCfg) } else { - conn, err = net.Dial("tcp", c.addr) + conn, err = net.DialTimeout("tcp", c.addr, dialTimeout) } if err != nil { // Close any conns we already opened (excluding primary — @@ -255,7 +261,7 @@ func DialTLSPinned(addr, fingerprint string) (*Client, error) { return nil }, } - conn, err := tls.Dial("tcp", addr, tlsConfig) + conn, err := tls.DialWithDialer(&net.Dialer{Timeout: dialTimeout}, "tcp", addr, tlsConfig) if err != nil { return nil, fmt.Errorf("dial registry TLS pinned: %w", err) }