diff --git a/core/ipc_handler.go b/core/ipc_handler.go index f687b24b..5f016594 100644 --- a/core/ipc_handler.go +++ b/core/ipc_handler.go @@ -212,14 +212,27 @@ func buildEndpoints(neigh *state.Neighbour) []*protocol.EndpointInfo { if ap, err := nep.DynEP.Get(); err == nil { resolved = stringPtr(ap.String()) } + var bindID, bindInterface, bindSource *string + if nep.LocalBind != "" { + bindID = stringPtr(string(nep.LocalBind)) + } + if nep.Bind.Interface != "" { + bindInterface = stringPtr(nep.Bind.Interface) + } + if nep.Bind.Source.IsValid() { + bindSource = stringPtr(nep.Bind.Source.String()) + } eps = append(eps, &protocol.EndpointInfo{ - Address: nep.DynEP.Value, - Resolved: resolved, - Active: ep.IsActive(), - RemoteInit: ep.IsRemote(), - Metric: ep.Metric(), - FilteredRttNs: int64(nep.FilteredPing()), - StabilizedRttNs: int64(nep.StabilizedPing()), + Address: nep.DynEP.Value, + Resolved: resolved, + Active: ep.IsActive(), + RemoteInit: ep.IsRemote(), + Metric: ep.Metric(), + FilteredRttNs: int64(nep.FilteredPing()), + StabilizedRttNs: int64(nep.StabilizedPing()), + LocalBindId: bindID, + LocalBindInterface: bindInterface, + LocalBindSource: bindSource, }) } slices.SortFunc(eps, func(a, b *protocol.EndpointInfo) int { diff --git a/core/nylon.go b/core/nylon.go index aa6f31d7..6659d014 100644 --- a/core/nylon.go +++ b/core/nylon.go @@ -43,6 +43,7 @@ type Nylon struct { router struct { LastStarvationRequest time.Time IO map[state.NodeId]*IOPending + RouteComputePending atomic.Bool // ForwardTable contains the full routing table ForwardTable atomic.Pointer[bart.Table[RouteTableEntry]] diff --git a/core/nylon_apply.go b/core/nylon_apply.go index b5ad501b..0c86d350 100644 --- a/core/nylon_apply.go +++ b/core/nylon_apply.go @@ -2,6 +2,7 @@ package core import ( "errors" + "fmt" "net/netip" "reflect" "slices" @@ -70,7 +71,7 @@ func (n *Nylon) reconcileRouterState(next *state.CentralCfg) error { continue } // configure existing neighbours - reconcileConfiguredEndpoints(neigh, cfg.Endpoints, &n.RouterTunables) + reconcileConfiguredEndpoints(neigh, configuredEndpoints(n.LocalCfg.EndpointBinds, neigh.Id, cfg.Endpoints, &n.RouterTunables)) neighs = append(neighs, neigh) delete(desired, neigh.Id) } @@ -88,9 +89,7 @@ func (n *Nylon) reconcileRouterState(next *state.CentralCfg) error { Routes: make(map[netip.Prefix]state.NeighRoute), Eps: make([]state.Endpoint, 0, len(cfg.Endpoints)), } - for _, ep := range cfg.Endpoints { - stNeigh.Eps = append(stNeigh.Eps, state.NewEndpoint(ep, false, nil, &n.RouterTunables)) - } + stNeigh.Eps = append(stNeigh.Eps, configuredEndpoints(n.LocalCfg.EndpointBinds, id, cfg.Endpoints, &n.RouterTunables)...) neighs = append(neighs, stNeigh) } n.RouterState.Neighbours = neighs @@ -107,10 +106,35 @@ func (n *Nylon) reconcileRouterState(next *state.CentralCfg) error { return nil } -func reconcileConfiguredEndpoints(neigh *state.Neighbour, desired []*state.DynamicEndpoint, t *state.RouterTunables) { - desiredByValue := make(map[string]*state.DynamicEndpoint, len(desired)) +func configuredEndpoints(binds []state.LocalEndpointBind, peer state.NodeId, endpoints []*state.DynamicEndpoint, t *state.RouterTunables) []state.Endpoint { + eps := make([]state.Endpoint, 0, len(endpoints)) + for _, ep := range endpoints { + matched := false + for idx, bind := range binds { + if bind.Peer != peer || bind.Endpoint != ep.Value { + continue + } + nep := state.NewEndpoint(ep, false, nil, t) + nep.LocalBind = bind.LocalBind(idx).ID + nep.Bind = bind.LocalBind(idx) + eps = append(eps, nep) + matched = true + } + if !matched { + eps = append(eps, state.NewEndpoint(ep, false, nil, t)) + } + } + return eps +} + +func endpointKey(ep *state.NylonEndpoint) string { + return fmt.Sprintf("%s\x00%s\x00%s", ep.DynEP.Value, ep.Bind.Interface, ep.Bind.Source) +} + +func reconcileConfiguredEndpoints(neigh *state.Neighbour, desired []state.Endpoint) { + desiredByKey := make(map[string]state.Endpoint, len(desired)) for _, ep := range desired { - desiredByValue[ep.Value] = ep + desiredByKey[endpointKey(ep.AsNylonEndpoint())] = ep } eps := make([]state.Endpoint, 0, len(neigh.Eps)+len(desired)) @@ -122,16 +146,18 @@ func reconcileConfiguredEndpoints(neigh *state.Neighbour, desired []*state.Dynam continue } // only keep if desired - if desiredEp, ok := desiredByValue[nep.DynEP.Value]; ok { + key := endpointKey(nep) + if _, ok := desiredByKey[key]; ok { eps = append(eps, ep) - seen[desiredEp.Value] = struct{}{} + seen[key] = struct{}{} } } for _, ep := range desired { - if _, ok := seen[ep.Value]; ok { + key := endpointKey(ep.AsNylonEndpoint()) + if _, ok := seen[key]; ok { continue } - eps = append(eps, state.NewEndpoint(ep, false, nil, t)) + eps = append(eps, ep) } neigh.Eps = eps } diff --git a/core/nylon_endpoints.go b/core/nylon_endpoints.go index 67812ab1..9327f5e9 100644 --- a/core/nylon_endpoints.go +++ b/core/nylon_endpoints.go @@ -2,6 +2,7 @@ package core import ( "math/rand/v2" + "net" "slices" "sync" "time" @@ -15,6 +16,8 @@ import ( type EpPing struct { TimeSent time.Time + Node state.NodeId + Endpoint *state.NylonEndpoint } func (n *Nylon) Probe(node state.NodeId, ep *state.NylonEndpoint, waitErr bool) error { @@ -46,6 +49,8 @@ func (n *Nylon) Probe(node state.NodeId, ep *state.NylonEndpoint, waitErr bool) n.PingBuf.Set(token, EpPing{ TimeSent: time.Now(), + Node: node, + Endpoint: ep, }, ttlcache.DefaultTTL) }() @@ -83,6 +88,29 @@ func handleProbe(n *Nylon, pkt *protocol.Ny_Probe, endpoint conn.Endpoint, peer } } +func endpointMatchesPacket(ep *state.NylonEndpoint, packetEp conn.Endpoint) bool { + ap, err := ep.DynEP.Get() + if err != nil || ap != packetEp.DstIPPort() { + return false + } + if ep.Bind.Source.IsValid() && ep.Bind.Source != packetEp.SrcIP() { + return false + } + if !ep.Bind.Source.IsValid() && ep.Bind.Interface != "" { + iface, err := net.InterfaceByName(ep.Bind.Interface) + if err != nil { + return false + } + srcIf, ok := packetEp.(interface { + SrcIfidx() int32 + }) + if !ok || int32(iface.Index) != srcIf.SrcIfidx() { + return false + } + } + return true +} + func handleProbePing(n *Nylon, node state.NodeId, wgEndpoint conn.Endpoint) { if node == n.LocalCfg.Id { return @@ -91,8 +119,7 @@ func handleProbePing(n *Nylon, node state.NodeId, wgEndpoint conn.Endpoint) { for _, neigh := range n.RouterState.Neighbours { for _, dep := range neigh.Eps { dep := dep.AsNylonEndpoint() - ap, err := dep.DynEP.Get() - if err == nil && ap == wgEndpoint.DstIPPort() && neigh.Id == node { + if neigh.Id == node && endpointMatchesPacket(dep, wgEndpoint) { // we have a link // refresh wireguard ep @@ -104,7 +131,7 @@ func handleProbePing(n *Nylon, node state.NodeId, wgEndpoint conn.Endpoint) { dep.Renew() if n.DBG_log_probe { - n.Log.Debug("probe from", "addr", ap.String()) + n.Log.Debug("probe from", "addr", wgEndpoint.DstToString(), "src", wgEndpoint.SrcIP()) } return } @@ -124,33 +151,35 @@ func handleProbePing(n *Nylon, node state.NodeId, wgEndpoint conn.Endpoint) { } func handleProbePong(n *Nylon, node state.NodeId, token uint64, ep conn.Endpoint) { - // check if link exists - for _, neigh := range n.RouterState.Neighbours { - for _, dep := range neigh.Eps { - dpLink := dep.AsNylonEndpoint() - ap, err := dpLink.DynEP.Get() - if err == nil && ap == ep.DstIPPort() && neigh.Id == node { - linkHealth, ok := n.PingBuf.GetAndDelete(token) - if ok { - health := linkHealth.Value() - latency := time.Since(health.TimeSent) - // we have a link - if n.DBG_log_probe { - n.Log.Debug("probe back", "peer", node, "ping", latency) - } - dpLink.Renew() - dpLink.UpdatePing(latency) - - // update wireguard endpoint - dpLink.WgEndpoint = ep - - ComputeRoutes(n.RouterState, n) - } - return - } - } + linkHealth, ok := n.PingBuf.GetAndDelete(token) + if !ok { + n.Log.Warn("probe came back and couldn't find token", "from", ep.DstToString(), "node", node) + return + } + health := linkHealth.Value() + dpLink := health.Endpoint + if health.Node != node || dpLink == nil { + n.Log.Warn("probe came back for unexpected node", "from", ep.DstToString(), "node", node, "expected", health.Node) + return + } + latency := time.Since(health.TimeSent) + // we have a link + if n.DBG_log_probe { + n.Log.Debug("probe back", "peer", node, "ping", latency) + } + dpLink.Renew() + dpLink.UpdatePing(latency) + + if dpLink.Bind.Source.IsValid() && dpLink.Bind.Source != ep.SrcIP() { + n.Log.Warn("bound probe returned on unexpected source", "from", ep.DstToString(), "node", node, "expected", dpLink.Bind.Source, "actual", ep.SrcIP()) + } else { + // update wireguard endpoint + dpLink.WgEndpoint = ep } - n.Log.Warn("probe came back and couldn't find link", "from", ep.DstToString(), "node", node) + + // Probe pongs arrive frequently on larger meshes. Coalesce route + // recomputation so RTT samples can update without saturating dispatch. + n.ScheduleRouteCompute(n.StarvationDelay) } func (n *Nylon) probeLinks(active bool) error { diff --git a/core/nylon_scheduler.go b/core/nylon_scheduler.go index c5d53ce9..11049df1 100644 --- a/core/nylon_scheduler.go +++ b/core/nylon_scheduler.go @@ -7,8 +7,9 @@ import ( "time" ) -// Dispatch Dispatches the function to run on the main thread without waiting for it to complete -func (n *Nylon) Dispatch(fun func() error) { +// Dispatch dispatches the function to run on the main thread without waiting for it to complete. +// It returns false when the dispatch queue is full and the function was dropped. +func (n *Nylon) Dispatch(fun func() error) bool { defer func() { if r := recover(); r != nil { n.Cancel(fmt.Errorf("dispatch panic: %v", r)) @@ -17,10 +18,10 @@ func (n *Nylon) Dispatch(fun func() error) { for { select { case n.DispatchChannel <- fun: - return + return true default: n.Log.Error("dispatch channel is full, discarded function", "fun", runtime.FuncForPC(reflect.ValueOf(fun).Pointer()).Name(), "len", len(n.DispatchChannel)) - return + return false } } } diff --git a/core/nylon_tc.go b/core/nylon_tc.go index cce08160..7e449240 100644 --- a/core/nylon_tc.go +++ b/core/nylon_tc.go @@ -186,23 +186,49 @@ func (n *Nylon) handleNylonPacket(packet []byte, endpoint conn.Endpoint, peer *d } }() + controlPackets := make([]*protocol.Ny, 0, len(bundle.Packets)) + for _, pkt := range bundle.Packets { switch pkt.Type.(type) { case *protocol.Ny_SeqnoRequestOp: - n.Dispatch(func() error { - return n.routerHandleSeqnoRequest(neigh, pkt.GetSeqnoRequestOp()) - }) + controlPackets = append(controlPackets, pkt) case *protocol.Ny_RouteOp: - n.Dispatch(func() error { - return n.routerHandleRouteUpdate(neigh, pkt.GetRouteOp()) - }) + controlPackets = append(controlPackets, pkt) case *protocol.Ny_AckRetractOp: - n.Dispatch(func() error { - return n.routerHandleAckRetract(neigh, pkt.GetAckRetractOp()) - }) + controlPackets = append(controlPackets, pkt) case *protocol.Ny_ProbeOp: // we don't want to wait for dispatch before responding to this packet handleProbe(n, pkt.GetProbeOp(), endpoint, peer, neigh) } } + + if len(controlPackets) == 0 { + return + } + + n.Dispatch(func() error { + routeUpdated := false + for _, pkt := range controlPackets { + switch pkt.Type.(type) { + case *protocol.Ny_SeqnoRequestOp: + if err := n.routerHandleSeqnoRequest(neigh, pkt.GetSeqnoRequestOp()); err != nil { + return err + } + case *protocol.Ny_RouteOp: + applied, err := n.routerApplyRouteUpdate(neigh, pkt.GetRouteOp()) + if err != nil { + return err + } + routeUpdated = routeUpdated || applied + case *protocol.Ny_AckRetractOp: + if err := n.routerHandleAckRetract(neigh, pkt.GetAckRetractOp()); err != nil { + return err + } + } + } + if routeUpdated { + ComputeRoutes(n.RouterState, n) + } + return nil + }) } diff --git a/core/router.go b/core/router.go index e9ed6139..8562053c 100644 --- a/core/router.go +++ b/core/router.go @@ -93,6 +93,22 @@ func (n *Nylon) RouterEvent(event string, desc string, args ...any) { n.router.log.Debug(desc, append([]any{"event", event}, args...)...) } +func (n *Nylon) ScheduleRouteCompute(delay time.Duration) { + if n.router.RouteComputePending.Swap(true) { + return + } + time.AfterFunc(delay, func() { + ok := n.Dispatch(func() error { + defer n.router.RouteComputePending.Store(false) + ComputeRoutes(n.RouterState, n) + return nil + }) + if !ok { + n.router.RouteComputePending.Store(false) + } + }) +} + func (n *Nylon) UpdateNeighbour(neigh state.NodeId) { PushFullTable(n.RouterState, n, neigh) } @@ -297,17 +313,17 @@ func (n *Nylon) checkNode(id state.NodeId) bool { } // packet handlers -func (n *Nylon) routerHandleRouteUpdate(node state.NodeId, update *protocol.Ny_Update) error { +func (n *Nylon) routerApplyRouteUpdate(node state.NodeId, update *protocol.Ny_Update) (bool, error) { prefix := netip.Prefix{} err := prefix.UnmarshalBinary(update.Prefix) if err != nil { n.router.log.Warn("received update with invalid prefix", "prefix", update.Prefix, "err", err) - return nil + return false, nil } if !n.checkNeigh(node) || !n.checkPrefix(prefix) || !n.checkNode(state.NodeId(update.RouterId)) { - return nil + return false, nil } HandleNeighbourUpdate(n.RouterState, n, node, state.PubRoute{ Source: state.Source{ @@ -319,8 +335,7 @@ func (n *Nylon) routerHandleRouteUpdate(node state.NodeId, update *protocol.Ny_U Metric: update.Metric, }, }) - ComputeRoutes(n.RouterState, n) - return nil + return true, nil } func (n *Nylon) routerHandleAckRetract(neigh state.NodeId, update *protocol.Ny_AckRetract) error { diff --git a/docs/reference/endpoint-local-binds.mdx b/docs/reference/endpoint-local-binds.mdx new file mode 100644 index 00000000..8d8d80bc --- /dev/null +++ b/docs/reference/endpoint-local-binds.mdx @@ -0,0 +1,62 @@ +# Endpoint Local Binds + +Nylon keeps routing decisions separate from transport endpoint selection. +Routing chooses which node to visit next. Endpoint selection chooses which +underlay address, source address, and local interface are used for the UDP +packet sent to that next node. + +## Model + +At the routing level, a neighbour is represented once. Multiple underlay +endpoints for the same neighbour are link-level candidates for the same routing +edge, and Nylon surfaces the best active endpoint metric to the router. + +At the link level, a Nylon endpoint may carry an optional local bind selector: + +- source address: the local address to use for the outgoing UDP packet +- interface index: the local interface to constrain the outgoing UDP packet to + +The source address and endpoint address must be in the same IP family when a +source address is configured. If no local bind selector is configured, endpoint +traffic uses the kernel's normal source address and route selection. + +This design does not create one socket per local bind. A Nylon device owns one +IPv4 UDP socket and one IPv6 UDP socket when both address families are +available. Local bind selectors are attached to endpoint sends. + +## Linux Transport + +On Linux, endpoint-local binds are encoded as per-message ancillary data: + +- IPv4 uses `IP_PKTINFO` with `in_pktinfo.ipi_spec_dst` and + `in_pktinfo.ipi_ifindex`. +- IPv6 uses `IPV6_PKTINFO` with `in6_pktinfo.ipi6_addr` and + `in6_pktinfo.ipi6_ifindex`. + +When both a source address and interface index are provided, Nylon sends both in +the packet info control message. The source address remains the requested source +address; the interface index constrains output interface selection. When only an +interface is provided, the packet info source address is left unspecified and +the kernel may select an address for that interface. + +This keeps probe packets and other control traffic precise without changing the +socket-level bind. It also allows different endpoints for the same peer to use +different local bind selectors while still sharing the device sockets. + +## Non-Linux Platforms + +The endpoint-local bind mechanism is currently implemented only for Linux. +Other platforms keep the default socket behaviour. A cross-platform user-facing +configuration should only expose selectors that the current platform can honor, +or should fail validation before the device starts. + +## Routing Behaviour + +Endpoint-local binds affect only the link-level endpoint used to send packets. +They do not create additional routing-level neighbours or independent routed +edges. Nylon still computes routes over nodes and uses the best active endpoint +metric for each neighbour. + +Bandwidth aggregation, ECMP, and max-flow style forwarding are outside this +model. Supporting those would require different routing semantics rather than +additional endpoint-local bind selectors. diff --git a/polyamide/conn/bind_std.go b/polyamide/conn/bind_std.go index 27f5fe86..6938d788 100644 --- a/polyamide/conn/bind_std.go +++ b/polyamide/conn/bind_std.go @@ -374,12 +374,12 @@ func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error { defer s.udpAddrPool.Put(ua) if is6 { as16 := endpoint.DstIP().As16() - copy(ua.IP, as16[:]) ua.IP = ua.IP[:16] + copy(ua.IP, as16[:]) } else { as4 := endpoint.DstIP().As4() - copy(ua.IP, as4[:]) ua.IP = ua.IP[:4] + copy(ua.IP, as4[:]) } ua.Port = int(endpoint.(*StdNetEndpoint).Port()) var ( @@ -526,6 +526,7 @@ func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFu if err != nil { return n, err } + control := append([]byte(nil), msg.OOB[:msg.NN]...) if gsoSize > 0 { numToSplit = (msg.N + gsoSize - 1) / gsoSize end = gsoSize @@ -536,6 +537,8 @@ func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFu } copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end]) msgs[n].N = copied + msgs[n].OOB = append(msgs[n].OOB[:0], control...) + msgs[n].NN = len(control) msgs[n].Addr = msg.Addr start = end end += gsoSize diff --git a/polyamide/conn/bind_std_test.go b/polyamide/conn/bind_std_test.go index 34a3c9ac..cecf4afd 100644 --- a/polyamide/conn/bind_std_test.go +++ b/polyamide/conn/bind_std_test.go @@ -248,3 +248,27 @@ func Test_splitCoalescedMessages(t *testing.T) { }) } } + +func Test_splitCoalescedMessagesPreservesControl(t *testing.T) { + msgs := []ipv6.Message{ + {Buffers: [][]byte{make([]byte, 1<<16-1)}, OOB: make([]byte, 0, 2)}, + {Buffers: [][]byte{make([]byte, 1<<16-1)}, OOB: make([]byte, 0, 2)}, + {Buffers: [][]byte{make([]byte, 1<<16-1)}, N: 2, NN: 2, OOB: []byte{1, 0}}, + {Buffers: [][]byte{make([]byte, 1<<16-1)}, OOB: make([]byte, 0, 2)}, + } + got, err := splitCoalescedMessages(msgs, 2, mockGetGSOSize) + if err != nil { + t.Fatalf("err: %v", err) + } + if got != 2 { + t.Fatalf("got to eval: %d want: %d", got, 2) + } + for i := 0; i < got; i++ { + if msgs[i].NN != 2 { + t.Fatalf("msg[%d].NN: %d want: 2", i, msgs[i].NN) + } + if gotGSO, err := mockGetGSOSize(msgs[i].OOB[:msgs[i].NN]); err != nil || gotGSO != 1 { + t.Fatalf("msg[%d] gso: %d err: %v", i, gotGSO, err) + } + } +} diff --git a/polyamide/conn/sticky_default.go b/polyamide/conn/sticky_default.go index 15b65af8..c05ba49f 100644 --- a/polyamide/conn/sticky_default.go +++ b/polyamide/conn/sticky_default.go @@ -21,6 +21,8 @@ func (e *StdNetEndpoint) SrcToString() string { return "" } +func (e *StdNetEndpoint) SetSrc(addr netip.Addr, ifidx int32) {} + // TODO: macOS, FreeBSD and other BSDs likely do support the sticky sockets // {get,set}srcControl feature set, but use alternatively named flags and need // ports and require testing. diff --git a/polyamide/conn/sticky_linux.go b/polyamide/conn/sticky_linux.go index adfedc17..8d4d0b5b 100644 --- a/polyamide/conn/sticky_linux.go +++ b/polyamide/conn/sticky_linux.go @@ -45,6 +45,54 @@ func (e *StdNetEndpoint) SrcToString() string { return e.SrcIP().String() } +func (e *StdNetEndpoint) SetSrc(addr netip.Addr, ifidx int32) { + if !addr.IsValid() && ifidx == 0 { + e.ClearSrc() + return + } + if !addr.IsValid() { + addr = netip.AddrFrom16([16]byte{}) + if e.DstIP().Is4() { + addr = netip.AddrFrom4([4]byte{}) + } + } + if addr.Is4() { + if e.src == nil || cap(e.src) < unix.CmsgSpace(unix.SizeofInet4Pktinfo) { + e.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet4Pktinfo)) + } + e.src = e.src[:unix.CmsgSpace(unix.SizeofInet4Pktinfo)] + hdr := unix.Cmsghdr{ + Level: unix.IPPROTO_IP, + Type: unix.IP_PKTINFO, + } + hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo)) + copy(e.src, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr)))) + info := unix.Inet4Pktinfo{ + Ifindex: ifidx, + Spec_dst: addr.As4(), + } + copy(e.src[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet4Pktinfo)) + return + } + if addr.Is6() { + if e.src == nil || cap(e.src) < unix.CmsgSpace(unix.SizeofInet6Pktinfo) { + e.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet6Pktinfo)) + } + e.src = e.src[:unix.CmsgSpace(unix.SizeofInet6Pktinfo)] + hdr := unix.Cmsghdr{ + Level: unix.IPPROTO_IPV6, + Type: unix.IPV6_PKTINFO, + } + hdr.SetLen(unix.CmsgLen(unix.SizeofInet6Pktinfo)) + copy(e.src, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr)))) + info := unix.Inet6Pktinfo{ + Ifindex: uint32(ifidx), + Addr: addr.As16(), + } + copy(e.src[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet6Pktinfo)) + } +} + // getSrcFromControl parses the control for PKTINFO and if found updates ep with // the source information found. func getSrcFromControl(control []byte, ep *StdNetEndpoint) { diff --git a/polyamide/conn/sticky_linux_test.go b/polyamide/conn/sticky_linux_test.go index 1b1ee683..185a4199 100644 --- a/polyamide/conn/sticky_linux_test.go +++ b/polyamide/conn/sticky_linux_test.go @@ -54,6 +54,49 @@ func setSrc(ep *StdNetEndpoint, addr netip.Addr, ifidx int32) { } func Test_setSrcControl(t *testing.T) { + t.Run("SetSrcStoresPacketInfo", func(t *testing.T) { + ep := &StdNetEndpoint{ + AddrPort: netip.MustParseAddrPort("127.0.0.1:1234"), + } + ep.SetSrc(netip.MustParseAddr("127.0.0.1"), 5) + + control := make([]byte, stickyControlSize) + setSrcControl(&control, ep) + + info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)])) + if info.Spec_dst != [4]byte{127, 0, 0, 1} { + t.Errorf("unexpected address: %v", info.Spec_dst) + } + if info.Ifindex != 5 { + t.Errorf("unexpected ifindex: %d", info.Ifindex) + } + }) + + t.Run("SetSrcStoresInterfaceOnlyPacketInfo", func(t *testing.T) { + ep := &StdNetEndpoint{ + AddrPort: netip.MustParseAddrPort("127.0.0.1:1234"), + } + ep.SetSrc(netip.Addr{}, 5) + + control := make([]byte, stickyControlSize) + setSrcControl(&control, ep) + + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) + if hdr.Level != unix.IPPROTO_IP { + t.Errorf("unexpected level: %d", hdr.Level) + } + if hdr.Type != unix.IP_PKTINFO { + t.Errorf("unexpected type: %d", hdr.Type) + } + info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)])) + if info.Spec_dst != [4]byte{} { + t.Errorf("unexpected address: %v", info.Spec_dst) + } + if info.Ifindex != 5 { + t.Errorf("unexpected ifindex: %d", info.Ifindex) + } + }) + t.Run("IPv4", func(t *testing.T) { ep := &StdNetEndpoint{ AddrPort: netip.MustParseAddrPort("127.0.0.1:1234"), diff --git a/polyamide/device/peer.go b/polyamide/device/peer.go index 9d1afc7b..6497742f 100644 --- a/polyamide/device/peer.go +++ b/polyamide/device/peer.go @@ -136,11 +136,6 @@ func (peer *Peer) SendBuffers(buffers [][]byte, eps []conn.Endpoint) error { for _, ep := range endpoints { ep.ClearSrc() } - for _, ep := range eps { - if ep != nil { - ep.ClearSrc() - } - } peer.endpoints.clearSrcOnTx = false } peer.endpoints.Unlock() diff --git a/protocol/nylon_ipc.pb.go b/protocol/nylon_ipc.pb.go index 5be9758f..dc756d88 100644 --- a/protocol/nylon_ipc.pb.go +++ b/protocol/nylon_ipc.pb.go @@ -578,16 +578,19 @@ func (x *Advertisement) GetPassiveHold() bool { } type EndpointInfo struct { - state protoimpl.MessageState `protogen:"open.v1"` - Address string `protobuf:"bytes,1,opt,name=address,proto3" json:"address,omitempty"` - Resolved *string `protobuf:"bytes,2,opt,name=resolved,proto3,oneof" json:"resolved,omitempty"` - Active bool `protobuf:"varint,3,opt,name=active,proto3" json:"active,omitempty"` - RemoteInit bool `protobuf:"varint,4,opt,name=remote_init,json=remoteInit,proto3" json:"remote_init,omitempty"` - Metric uint32 `protobuf:"varint,5,opt,name=metric,proto3" json:"metric,omitempty"` - FilteredRttNs int64 `protobuf:"varint,7,opt,name=filtered_rtt_ns,json=filteredRttNs,proto3" json:"filtered_rtt_ns,omitempty"` - StabilizedRttNs int64 `protobuf:"varint,8,opt,name=stabilized_rtt_ns,json=stabilizedRttNs,proto3" json:"stabilized_rtt_ns,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + Address string `protobuf:"bytes,1,opt,name=address,proto3" json:"address,omitempty"` + Resolved *string `protobuf:"bytes,2,opt,name=resolved,proto3,oneof" json:"resolved,omitempty"` + Active bool `protobuf:"varint,3,opt,name=active,proto3" json:"active,omitempty"` + RemoteInit bool `protobuf:"varint,4,opt,name=remote_init,json=remoteInit,proto3" json:"remote_init,omitempty"` + Metric uint32 `protobuf:"varint,5,opt,name=metric,proto3" json:"metric,omitempty"` + FilteredRttNs int64 `protobuf:"varint,7,opt,name=filtered_rtt_ns,json=filteredRttNs,proto3" json:"filtered_rtt_ns,omitempty"` + StabilizedRttNs int64 `protobuf:"varint,8,opt,name=stabilized_rtt_ns,json=stabilizedRttNs,proto3" json:"stabilized_rtt_ns,omitempty"` + LocalBindId *string `protobuf:"bytes,9,opt,name=local_bind_id,json=localBindId,proto3,oneof" json:"local_bind_id,omitempty"` + LocalBindInterface *string `protobuf:"bytes,10,opt,name=local_bind_interface,json=localBindInterface,proto3,oneof" json:"local_bind_interface,omitempty"` + LocalBindSource *string `protobuf:"bytes,11,opt,name=local_bind_source,json=localBindSource,proto3,oneof" json:"local_bind_source,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *EndpointInfo) Reset() { @@ -669,6 +672,27 @@ func (x *EndpointInfo) GetStabilizedRttNs() int64 { return 0 } +func (x *EndpointInfo) GetLocalBindId() string { + if x != nil && x.LocalBindId != nil { + return *x.LocalBindId + } + return "" +} + +func (x *EndpointInfo) GetLocalBindInterface() string { + if x != nil && x.LocalBindInterface != nil { + return *x.LocalBindInterface + } + return "" +} + +func (x *EndpointInfo) GetLocalBindSource() string { + if x != nil && x.LocalBindSource != nil { + return *x.LocalBindSource + } + return "" +} + type WireGuardPeerStats struct { state protoimpl.MessageState `protogen:"open.v1"` LatestHandshakeUnix int64 `protobuf:"varint,1,opt,name=latest_handshake_unix,json=latestHandshakeUnix,proto3" json:"latest_handshake_unix,omitempty"` @@ -1799,7 +1823,7 @@ const file_protocol_nylon_ipc_proto_rawDesc = "" + "\x06metric\x18\x03 \x01(\rR\x06metric\x12\x1f\n" + "\vexpiry_unix\x18\x04 \x01(\x03R\n" + "expiryUnix\x12!\n" + - "\fpassive_hold\x18\x05 \x01(\bR\vpassiveHold\"\xfb\x01\n" + + "\fpassive_hold\x18\x05 \x01(\bR\vpassiveHold\"\xcd\x03\n" + "\fEndpointInfo\x12\x18\n" + "\aaddress\x18\x01 \x01(\tR\aaddress\x12\x1f\n" + "\bresolved\x18\x02 \x01(\tH\x00R\bresolved\x88\x01\x01\x12\x16\n" + @@ -1808,8 +1832,15 @@ const file_protocol_nylon_ipc_proto_rawDesc = "" + "remoteInit\x12\x16\n" + "\x06metric\x18\x05 \x01(\rR\x06metric\x12&\n" + "\x0ffiltered_rtt_ns\x18\a \x01(\x03R\rfilteredRttNs\x12*\n" + - "\x11stabilized_rtt_ns\x18\b \x01(\x03R\x0fstabilizedRttNsB\v\n" + - "\t_resolved\"\xf0\x01\n" + + "\x11stabilized_rtt_ns\x18\b \x01(\x03R\x0fstabilizedRttNs\x12'\n" + + "\rlocal_bind_id\x18\t \x01(\tH\x01R\vlocalBindId\x88\x01\x01\x125\n" + + "\x14local_bind_interface\x18\n" + + " \x01(\tH\x02R\x12localBindInterface\x88\x01\x01\x12/\n" + + "\x11local_bind_source\x18\v \x01(\tH\x03R\x0flocalBindSource\x88\x01\x01B\v\n" + + "\t_resolvedB\x10\n" + + "\x0e_local_bind_idB\x17\n" + + "\x15_local_bind_interfaceB\x14\n" + + "\x12_local_bind_source\"\xf0\x01\n" + "\x12WireGuardPeerStats\x122\n" + "\x15latest_handshake_unix\x18\x01 \x01(\x03R\x13latestHandshakeUnix\x12\x19\n" + "\btx_bytes\x18\x02 \x01(\x04R\atxBytes\x12\x19\n" + diff --git a/protocol/nylon_ipc.proto b/protocol/nylon_ipc.proto index 080276ca..e152295d 100644 --- a/protocol/nylon_ipc.proto +++ b/protocol/nylon_ipc.proto @@ -63,6 +63,9 @@ message EndpointInfo { uint32 metric = 5; int64 filtered_rtt_ns = 7; int64 stabilized_rtt_ns = 8; + optional string local_bind_id = 9; + optional string local_bind_interface = 10; + optional string local_bind_source = 11; } message WireGuardPeerStats { diff --git a/state/config.go b/state/config.go index 083493e0..ea2d6f12 100644 --- a/state/config.go +++ b/state/config.go @@ -64,6 +64,27 @@ type LocalCfg struct { PreDown []string `yaml:"pre_down,omitempty"` // a list of commands executed in order before the nylon interface is brought down PostUp []string `yaml:"post_up,omitempty"` // a list of commands executed in order after the nylon interface is brought up PostDown []string `yaml:"post_down,omitempty"` // a list of commands executed in order after the nylon interface is brought down + EndpointBinds []LocalEndpointBind `yaml:"endpoint_binds,omitempty"` // manual endpoint-level local source/interface selectors +} + +type LocalEndpointBind struct { + Peer NodeId `yaml:"peer"` + Endpoint string `yaml:"endpoint"` + ID LocalBindID `yaml:"id,omitempty"` + Interface string `yaml:"interface,omitempty"` + Source netip.Addr `yaml:"source,omitempty"` +} + +func (b LocalEndpointBind) LocalBind(idx int) LocalBind { + id := b.ID + if id == "" { + id = LocalBindID(fmt.Sprintf("%s/%s/%d", b.Peer, b.Endpoint, idx)) + } + return LocalBind{ + ID: id, + Interface: b.Interface, + Source: b.Source, + } } func (c *CentralCfg) Clone() (error, *CentralCfg) { diff --git a/state/config_test.go b/state/config_test.go index 92a63912..aa9d1397 100644 --- a/state/config_test.go +++ b/state/config_test.go @@ -1,9 +1,12 @@ package state import ( + "net/netip" + "runtime" "strings" "testing" + "github.com/goccy/go-yaml" "github.com/stretchr/testify/assert" ) @@ -149,3 +152,56 @@ func TestParseGraph_InvalidGraph(t *testing.T) { failGraph(t, `,,,,,,,,,,,,,,,,`) failGraph(t, `a=a`) } + +func TestLocalEndpointBindsParse(t *testing.T) { + data := []byte(` +key: AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= +id: alice +port: 57175 +endpoint_binds: + - peer: bob + endpoint: 198.51.100.10:57175 + interface: eth0 + source: 203.0.113.10 +`) + var local LocalCfg + err := yaml.Unmarshal(data, &local) + assert.NoError(t, err) + assert.Len(t, local.EndpointBinds, 1) + assert.Equal(t, NodeId("bob"), local.EndpointBinds[0].Peer) + assert.Equal(t, netip.MustParseAddr("203.0.113.10"), local.EndpointBinds[0].Source) +} + +func TestNodeConfigValidatorRejectsSelectorlessEndpointBind(t *testing.T) { + local := LocalCfg{ + Id: "alice", + Key: [32]byte{1}, + Port: 57175, + EndpointBinds: []LocalEndpointBind{{ + Peer: "bob", + Endpoint: "198.51.100.10:57175", + }}, + } + + err := NodeConfigValidator(nil, &local) + assert.ErrorContains(t, err, "must specify source or interface") +} + +func TestNodeConfigValidatorAllowsEndpointBind(t *testing.T) { + if runtime.GOOS != "linux" { + t.Skip("endpoint local binds are linux-only") + } + local := LocalCfg{ + Id: "alice", + Key: [32]byte{1}, + Port: 57175, + EndpointBinds: []LocalEndpointBind{{ + Peer: "bob", + Endpoint: "198.51.100.10:57175", + Source: netip.MustParseAddr("203.0.113.10"), + }}, + } + + err := NodeConfigValidator(nil, &local) + assert.NoError(t, err) +} diff --git a/state/endpoint.go b/state/endpoint.go index 2bdd3a1d..8965e222 100644 --- a/state/endpoint.go +++ b/state/endpoint.go @@ -23,6 +23,13 @@ type Endpoint interface { AsNylonEndpoint() *NylonEndpoint } +func SameIPFamily(a, b netip.Addr) bool { + if !a.IsValid() || !b.IsValid() { + return true + } + return a.BitLen() == b.BitLen() +} + /* DynamicEndpoint represents either an ip:port or a dns name. This may be resolved to a different address at any time @@ -164,6 +171,8 @@ type NylonEndpoint struct { remoteInit bool WgEndpoint conn.Endpoint DynEP *DynamicEndpoint + LocalBind LocalBindID + Bind LocalBind } func (ep *NylonEndpoint) AsNylonEndpoint() *NylonEndpoint { @@ -175,6 +184,9 @@ func (ep *NylonEndpoint) GetWgEndpoint(device *device.Device) (conn.Endpoint, er if err != nil { return nil, err } + if !SameIPFamily(ep.Bind.Source, ap.Addr()) { + return nil, fmt.Errorf("bind source %s does not match endpoint %s", ep.Bind.Source, ap) + } if ep.WgEndpoint == nil || ep.WgEndpoint.DstIPPort() != ap { wgEp, err := device.Bind().ParseEndpoint(ap.String()) @@ -183,6 +195,19 @@ func (ep *NylonEndpoint) GetWgEndpoint(device *device.Device) (conn.Endpoint, er } ep.WgEndpoint = wgEp } + if setter, ok := ep.WgEndpoint.(interface { + SetSrc(netip.Addr, int32) + }); ok && (ep.Bind.Source.IsValid() || ep.Bind.Interface != "") { + ifidx := int32(0) + if ep.Bind.Interface != "" { + iface, err := net.InterfaceByName(ep.Bind.Interface) + if err != nil { + return nil, fmt.Errorf("failed to resolve bind interface %s: %w", ep.Bind.Interface, err) + } + ifidx = int32(iface.Index) + } + setter.SetSrc(ep.Bind.Source, ifidx) + } return ep.WgEndpoint, nil } @@ -231,6 +256,8 @@ func NewEndpoint(endpoint *DynamicEndpoint, remoteInit bool, wgEndpoint conn.End remoteInit: remoteInit, WgEndpoint: wgEndpoint, DynEP: endpoint, + LocalBind: DefaultLocalBindID, + Bind: LocalBind{ID: DefaultLocalBindID}, history: make([]time.Duration, 0), expRTT: math.Inf(1), } diff --git a/state/endpoint_test.go b/state/endpoint_test.go index b7ad2a7d..a6cfc0a4 100644 --- a/state/endpoint_test.go +++ b/state/endpoint_test.go @@ -3,12 +3,20 @@ package state import ( "math" "math/rand/v2" + "net/netip" "testing" "time" "github.com/stretchr/testify/assert" ) +func TestSameIPFamily(t *testing.T) { + assert.True(t, SameIPFamily(netip.MustParseAddr("192.0.2.1"), netip.MustParseAddr("198.51.100.1"))) + assert.True(t, SameIPFamily(netip.MustParseAddr("2001:db8::1"), netip.MustParseAddr("2001:db8::2"))) + assert.False(t, SameIPFamily(netip.MustParseAddr("192.0.2.1"), netip.MustParseAddr("2001:db8::1"))) + assert.True(t, SameIPFamily(netip.Addr{}, netip.MustParseAddr("2001:db8::1"))) +} + type DataSource struct { Name string Data []time.Duration diff --git a/state/routing.go b/state/routing.go index 2de499f6..018c9f3f 100644 --- a/state/routing.go +++ b/state/routing.go @@ -10,6 +10,15 @@ import ( ) type NodeId string +type LocalBindID string + +const DefaultLocalBindID LocalBindID = "default" + +type LocalBind struct { + ID LocalBindID + Interface string + Source netip.Addr +} // Source is a pair of a router-id and a prefix (Babel Section 2.7). type Source struct { diff --git a/state/validation.go b/state/validation.go index 5c0e7009..69b69c30 100644 --- a/state/validation.go +++ b/state/validation.go @@ -5,6 +5,7 @@ import ( "net/netip" "net/url" "regexp" + "runtime" "slices" ) @@ -43,6 +44,49 @@ func NodeConfigValidator(central *CentralCfg, node *LocalCfg) error { return err } } + seenEndpointBinds := make(map[string]struct{}) + for idx, bind := range node.EndpointBinds { + if bind.Peer == "" { + return fmt.Errorf("endpoint bind peer must not be empty") + } + if err := NameValidator(string(bind.Peer)); err != nil { + return fmt.Errorf("endpoint bind peer is invalid: %w", err) + } + if bind.Endpoint == "" { + return fmt.Errorf("endpoint bind for peer %s has empty endpoint", bind.Peer) + } + if bind.Interface == "" && !bind.Source.IsValid() { + return fmt.Errorf("endpoint bind for peer %s endpoint %s must specify source or interface", bind.Peer, bind.Endpoint) + } + if bind.ID != "" { + if err := NameValidator(string(bind.ID)); err != nil { + return fmt.Errorf("endpoint bind id is invalid: %w", err) + } + } + key := fmt.Sprintf("%s\x00%s\x00%s\x00%s", bind.Peer, bind.Endpoint, bind.Interface, bind.Source) + if _, ok := seenEndpointBinds[key]; ok { + return fmt.Errorf("duplicate endpoint bind for peer %s endpoint %s", bind.Peer, bind.Endpoint) + } + seenEndpointBinds[key] = struct{}{} + if runtime.GOOS != "linux" { + return fmt.Errorf("endpoint local binds are only supported on linux") + } + if central == nil { + continue + } + if !central.IsRouter(bind.Peer) { + return fmt.Errorf("endpoint bind peer %s is not a router", bind.Peer) + } + if !slices.Contains(central.GetPeers(node.Id), bind.Peer) { + return fmt.Errorf("endpoint bind peer %s is not a peer of %s", bind.Peer, node.Id) + } + peerCfg := central.GetRouter(bind.Peer) + if !slices.ContainsFunc(peerCfg.Endpoints, func(ep *DynamicEndpoint) bool { + return ep != nil && ep.Value == bind.Endpoint + }) { + return fmt.Errorf("endpoint bind %d references unknown endpoint %s for peer %s", idx, bind.Endpoint, bind.Peer) + } + } if len(node.DnsResolvers) != 0 { for _, resolver := range node.DnsResolvers { if _, err := netip.ParseAddrPort(resolver); err != nil {