From 542adffb8c0d6871d0924e58c5fef32180acc6ee Mon Sep 17 00:00:00 2001 From: AK Date: Wed, 22 Apr 2026 22:24:43 -0500 Subject: [PATCH 1/2] Added support for Hijack UDP --- UDP_HIJACK.md | 109 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 109 insertions(+) create mode 100644 UDP_HIJACK.md diff --git a/UDP_HIJACK.md b/UDP_HIJACK.md new file mode 100644 index 00000000..6ec23453 --- /dev/null +++ b/UDP_HIJACK.md @@ -0,0 +1,109 @@ +# UDP Hijacking for Homa + +## Overview + +UDP hijacking is an optional mechanism that encapsulates Homa packets as UDP +datagrams, using `IPPROTO_UDP` instead of `IPPROTO_HOMA` as the IP protocol. +It works alongside the existing TCP hijacking feature — only one can be active +at a time on a given socket. + +### Why UDP hijacking? + +TCP hijacking uses `SYN+RST` flag combinations that never occur in real TCP +traffic. However, some firewalls (particularly on virtualized environments) +inspect TCP flags and drop packets with these "impossible" flag combinations. +UDP hijacking avoids this issue entirely since UDP has no flags for firewalls +to inspect. + +### Trade-offs vs TCP hijacking + +| Feature | TCP hijacking | UDP hijacking | +|---------------------|------------------------|------------------------| +| NIC TSO support | Yes (multi-segment) | No (single-segment) | +| Firewall friendly | No (SYN+RST blocked) | Yes | +| GSO segments/packet | Multiple | 1 (`segs_per_gso = 1`) | +| IP protocol | `IPPROTO_TCP` | `IPPROTO_UDP` | +| sysctl | `hijack_tcp` | `hijack_udp` | + +Because NICs do not perform TSO on UDP packets the same way they do for TCP, +UDP hijacking forces `segs_per_gso = 1` (one segment per GSO packet). This +means each Homa data packet is sent individually rather than being batched +into large TSO super-packets. + +## Configuration + +Enable UDP hijacking at runtime via sysctl: + +```bash +# Enable UDP hijacking (disable TCP hijacking first if it was on) +sudo sysctl net.homa.hijack_tcp=0 +sudo sysctl net.homa.hijack_udp=1 +``` + +To switch back to TCP hijacking: + +```bash +sudo sysctl net.homa.hijack_udp=0 +sudo sysctl net.homa.hijack_tcp=1 +``` + +**Note:** If both `hijack_tcp` and `hijack_udp` are set, TCP hijacking takes +priority (sockets opened while both are set will use TCP). + +## How It Works + +### Sending (outgoing packets) + +1. **Socket initialization** (`homa_hijack_sock_init`): When a new Homa socket + is created, if `hijack_udp` is set the socket's `sk_protocol` is set to + `IPPROTO_UDP`. The kernel then transmits packets with a UDP IP protocol. + +2. **Header setup** (`homa_udp_hijack_set_hdr`): Before transmission, Homa + writes UDP-compatible header fields: + - `flags` is set to `HOMA_HIJACK_FLAGS` (6) — a marker value. + - `urgent` is set to `HOMA_HIJACK_URGENT` (0xb97d) — a second marker. + - Bytes 4-5 of the transport header are overwritten with the UDP length. + - Bytes 6-7 are set up for proper UDP checksum offload. + - Because the sequence field (bytes 4-7) is overwritten, the packet offset + is stored in `seg.offset` instead. + +3. **GSO geometry**: With UDP hijacking, `segs_per_gso` is forced to 1 (no + multi-segment GSO batching). + +### Receiving (incoming packets) + +1. **GRO interception** (`homa_udp_hijack_gro_receive`): Homa hooks into the + UDP GRO pipeline. When a UDP packet arrives, Homa checks: + - At least 20 bytes of transport header are available. + - `flags == HOMA_HIJACK_FLAGS` and `urgent == HOMA_HIJACK_URGENT`. + +2. If the packet is identified as a Homa-over-UDP packet, the IP protocol + is rewritten to `IPPROTO_HOMA` and the packet is handed to Homa's normal + GRO handler. Real UDP packets are passed through to the normal UDP stack. + +### Qdisc support + +The `is_homa_pkt()` function in `homa_qdisc.c` recognizes both TCP-hijacked +and UDP-hijacked packets, ensuring they receive proper Homa qdisc treatment. + +## Files Modified + +| File | Changes | +|-------------------|------------------------------------------------------------| +| `homa_wire.h` | No new defines needed (reuses `HOMA_HIJACK_FLAGS` and `HOMA_HIJACK_URGENT`) | +| `homa_impl.h` | Added `hijack_udp` field to `struct homa` | +| `homa_hijack.h` | Added `homa_udp_hijack_set_hdr()`, `homa_sock_udp_hijacked()`, `homa_skb_udp_hijacked()`; updated `homa_hijack_sock_init()` | +| `homa_hijack.c` | Added `homa_udp_hijack_init()`, `homa_udp_hijack_end()`, `homa_udp_hijack_gro_receive()` | +| `homa_outgoing.c` | Added `segs_per_gso=1` for UDP; added UDP header calls in xmit paths | +| `homa_plumbing.c` | Added `hijack_udp` sysctl; added UDP init/end calls | +| `homa_qdisc.c` | Added `IPPROTO_UDP` check in `is_homa_pkt()` | +| `util/homa_test.cc` | Added `udp_ping()`, `test_udp()`, "udp" test command | +| `util/server.cc` | Added `udp_server()` function | +| `util/cp_node.cc` | Added `udp_server` and `udp_client` classes, "udp" protocol option | + +## Key Constants + +| Constant | Value | Purpose | +|----------------------|----------|------------------------------------------------------| +| `HOMA_HIJACK_FLAGS` | 6 | Marker in the `flags` field (shared with TCP hijack) | +| `HOMA_HIJACK_URGENT` | 0xb97d | Marker in the `urgent` field (shared with TCP hijack)| From 40f82102d6902d7aaf4652ba8beaee634f773664 Mon Sep 17 00:00:00 2001 From: AK Date: Wed, 22 Apr 2026 22:25:36 -0500 Subject: [PATCH 2/2] modified: homa_hijack.c modified: homa_hijack.h modified: homa_impl.h modified: homa_outgoing.c modified: homa_plumbing.c modified: homa_qdisc.c modified: homa_wire.h modified: util/cp_node.cc modified: util/homa_test.cc modified: util/server.cc --- homa_hijack.c | 93 +++++++++++- homa_hijack.h | 84 ++++++++++- homa_impl.h | 8 ++ homa_outgoing.c | 30 +++- homa_plumbing.c | 9 ++ homa_qdisc.c | 3 +- homa_wire.h | 5 +- util/cp_node.cc | 358 +++++++++++++++++++++++++++++++++++++++++++++- util/homa_test.cc | 93 ++++++++++++ util/server.cc | 69 +++++++++ 10 files changed, 737 insertions(+), 15 deletions(-) diff --git a/homa_hijack.c b/homa_hijack.c index d1d11a7b..30c6f6fa 100644 --- a/homa_hijack.c +++ b/homa_hijack.c @@ -1,7 +1,9 @@ // SPDX-License-Identifier: BSD-2-Clause or GPL-2.0+ -/* This file implements TCP hijacking for Homa. See comments at the top of - * homa_hijack.h for an overview of TCP hijacking. +/* This file implements TCP and UDP hijacking for Homa. See comments at the + * top of homa_hijack.h for an overview of TCP hijacking. UDP hijacking works + * similarly but uses UDP as the IP protocol, which avoids issues with + * firewalls that inspect TCP flags. */ #include "homa_hijack.h" @@ -20,6 +22,19 @@ static const struct net_offload *tcp6_net_offload; static struct net_offload hook_tcp_net_offload; static struct net_offload hook_tcp6_net_offload; +/* Pointers to UDP's net_offload structures. NULL means homa_udp_hijack_init + * hasn't been called yet. + */ +static const struct net_offload *udp_net_offload; +static const struct net_offload *udp6_net_offload; + +/* + * Identical to *udp_net_offload except that the gro_receive function + * has been replaced with homa_udp_hijack_gro_receive. + */ +static struct net_offload hook_udp_net_offload; +static struct net_offload hook_udp6_net_offload; + /** * homa_hijack_init() - Initializes the mechanism for TCP hijacking (allows * incoming Homa packets encapsulated as TCP frames to be "stolen" back from @@ -93,4 +108,78 @@ struct sk_buff *homa_hijack_gro_receive(struct list_head *held_list, ip_hdr(skb)->protocol = IPPROTO_HOMA; } return homa_gro_receive(held_list, skb); +} + +/** + * homa_udp_hijack_init() - Initializes the mechanism for UDP hijacking + * (allows incoming Homa packets encapsulated as UDP datagrams to be + * "stolen" back from the UDP pipeline and funneled through Homa). + */ +void homa_udp_hijack_init(void) +{ + if (udp_net_offload) + return; + + pr_notice("Homa setting up UDP hijacking\n"); + rcu_read_lock(); + udp_net_offload = rcu_dereference(inet_offloads[IPPROTO_UDP]); + hook_udp_net_offload = *udp_net_offload; + hook_udp_net_offload.callbacks.gro_receive = homa_udp_hijack_gro_receive; + inet_offloads[IPPROTO_UDP] = (struct net_offload __rcu *) + &hook_udp_net_offload; + + udp6_net_offload = rcu_dereference(inet6_offloads[IPPROTO_UDP]); + hook_udp6_net_offload = *udp6_net_offload; + hook_udp6_net_offload.callbacks.gro_receive = homa_udp_hijack_gro_receive; + inet6_offloads[IPPROTO_UDP] = (struct net_offload __rcu *) + &hook_udp6_net_offload; + rcu_read_unlock(); +} + +/** + * homa_udp_hijack_end() - Reverses the effects of a previous call to + * homa_udp_hijack_init, so that incoming UDP packets are no longer checked + * to see if they are actually Homa frames. + */ +void homa_udp_hijack_end(void) +{ + if (!udp_net_offload) + return; + pr_notice("Homa cancelling UDP hijacking\n"); + inet_offloads[IPPROTO_UDP] = (struct net_offload __rcu *) + udp_net_offload; + udp_net_offload = NULL; + inet6_offloads[IPPROTO_UDP] = (struct net_offload __rcu *) + udp6_net_offload; + udp6_net_offload = NULL; +} + +/** + * homa_udp_hijack_gro_receive() - Invoked instead of UDP's gro_receive + * function when UDP hijacking is enabled. Identifies Homa-over-UDP packets + * and passes them to Homa; sends real UDP packets to UDP's gro_receive. + * @held_list: Pointer to header for list of packets that are being + * held for possible GRO merging. + * @skb: The newly arrived packet. + */ +struct sk_buff *homa_udp_hijack_gro_receive(struct list_head *held_list, + struct sk_buff *skb) +{ + /* Need at least 20 bytes of transport data to safely check the + * flags (offset 13) and urgent (offset 18-19) fields. + */ + if (skb_headlen(skb) >= skb_transport_offset(skb) + 20 && + homa_skb_udp_hijacked(skb)) { + if (skb_is_ipv6(skb)) { + ipv6_hdr(skb)->nexthdr = IPPROTO_HOMA; + } else { + ip_hdr(skb)->check = ~csum16_add( + csum16_sub(~ip_hdr(skb)->check, + htons(ip_hdr(skb)->protocol)), + htons(IPPROTO_HOMA)); + ip_hdr(skb)->protocol = IPPROTO_HOMA; + } + return homa_gro_receive(held_list, skb); + } + return udp_net_offload->callbacks.gro_receive(held_list, skb); } \ No newline at end of file diff --git a/homa_hijack.h b/homa_hijack.h index b65b4036..3930c465 100644 --- a/homa_hijack.h +++ b/homa_hijack.h @@ -25,6 +25,7 @@ #include "homa_peer.h" #include "homa_sock.h" #include "homa_wire.h" +#include /* Special value stored in the flags field of TCP headers to indicate that * the packet is actually a Homa packet. It includes the SYN and RST flags @@ -33,6 +34,11 @@ */ #define HOMA_HIJACK_FLAGS 6 +/* Special value stored in the flags field for UDP hijacking, distinct from + * the TCP hijack value. + */ +#define HOMA_UDP_FLAGS 5 + /* Special value stored in the urgent pointer of a TCP header to indicate * that the packet is actually a Homa packet (note that urgent pointer is * set even though the URG flag is not set). @@ -74,14 +80,16 @@ static inline void homa_hijack_set_hdr(struct sk_buff *skb, /** * homa_hijack_sock_init() - Perform socket initialization related to - * TCP hijacking (arrange for outgoing packets on the socket to use TCP, - * if the hijack_tcp option is set.) + * TCP/UDP hijacking (arrange for outgoing packets on the socket to use + * TCP or UDP, if the corresponding hijack option is set.) * @hsk: New socket to initialize. */ static inline void homa_hijack_sock_init(struct homa_sock *hsk) { if (hsk->homa->hijack_tcp) hsk->sock.sk_protocol = IPPROTO_TCP; + else if (hsk->homa->hijack_udp) + hsk->sock.sk_protocol = IPPROTO_UDP; } /* homa_sock_hijacked() - Returns true if outgoing packets on a socket @@ -93,6 +101,14 @@ static inline bool homa_sock_hijacked(struct homa_sock *hsk) return hsk->sock.sk_protocol == IPPROTO_TCP; } +/* homa_sock_udp_hijacked() - Returns true if outgoing packets on a socket + * should use UDP hijacking. + */ +static inline bool homa_sock_udp_hijacked(struct homa_sock *hsk) +{ + return hsk->sock.sk_protocol == IPPROTO_UDP; +} + /** * homa_skb_hijacked() - Return true if the TCP header fields in a packet * indicate that the packet is actually a Homa packet, false otherwise. @@ -108,10 +124,74 @@ static inline bool homa_skb_hijacked(struct sk_buff *skb) h->urgent == ntohs(HOMA_HIJACK_URGENT); } +/** + * homa_udp_hijack_set_hdr() - Set all header fields needed for UDP hijacking + * in an outgoing Homa packet. Overwrites the sequence field (bytes 4-7) with + * UDP length and checksum, so the packet offset must be stored in seg.offset. + * @skb: Packet buffer in which to set fields. + * @peer: Peer that contains source and destination addresses for the packet. + * @ipv6: True means the packet is going to be sent via IPv6. + */ +static inline void homa_udp_hijack_set_hdr(struct sk_buff *skb, + struct homa_peer *peer, + bool ipv6) +{ + struct homa_common_hdr *h; + int transport_len; + + h = (struct homa_common_hdr *)skb_transport_header(skb); + h->flags = HOMA_UDP_FLAGS; + h->urgent = htons(HOMA_UDP_URGENT); + + transport_len = skb->len - skb_transport_offset(skb); + + /* Set UDP length at bytes 4-5 (overlaps high 16 bits of sequence). */ + *((__be16 *)((u8 *)h + 4)) = htons(transport_len); + + /* Arrange for proper UDP checksumming at bytes 6-7. */ + skb->ip_summed = CHECKSUM_PARTIAL; + skb->csum_start = skb_transport_header(skb) - skb->head; + skb->csum_offset = 6; + if (ipv6) + *((__be16 *)((u8 *)h + 6)) = ~csum_ipv6_magic( + &peer->flow.u.ip6.saddr, + &peer->flow.u.ip6.daddr, + transport_len, IPPROTO_UDP, 0); + else + *((__be16 *)((u8 *)h + 6)) = ~csum_tcpudp_magic( + peer->flow.u.ip4.saddr, + peer->flow.u.ip4.daddr, + transport_len, IPPROTO_UDP, 0); +} + +/** + * homa_skb_udp_hijacked() - Return true if the header fields in a UDP + * packet indicate that the packet is actually a Homa packet, false otherwise. + * @skb: Packet to check: must have an IP protocol of IPPROTO_UDP. + */ +static inline bool homa_skb_udp_hijacked(struct sk_buff *skb) +{ + struct homa_common_hdr *h; + + /* Need at least 20 bytes of transport data to safely check the + * flags (offset 13) and urgent (offset 18-19) fields. + */ + if (skb_headlen(skb) < skb_transport_offset(skb) + 20) + return false; + h = (struct homa_common_hdr *)skb_transport_header(skb); + return h->flags == HOMA_UDP_FLAGS && + h->urgent == ntohs(HOMA_UDP_URGENT); +} + void homa_hijack_end(void); struct sk_buff * homa_hijack_gro_receive(struct list_head *held_list, struct sk_buff *skb); void homa_hijack_init(void); +void homa_udp_hijack_end(void); +struct sk_buff * + homa_udp_hijack_gro_receive(struct list_head *held_list, + struct sk_buff *skb); +void homa_udp_hijack_init(void); #endif /* _HOMA_HIJACK_H */ diff --git a/homa_impl.h b/homa_impl.h index 0261abc5..9d8f9bdd 100644 --- a/homa_impl.h +++ b/homa_impl.h @@ -363,6 +363,14 @@ struct homa { */ int hijack_tcp; + /** + * @hijack_udp: Non-zero means encapsulate outgoing Homa packets + * as UDP packets (i.e. use UDP as the IP protocol). This provides + * network traversability similar to TCP hijacking but avoids issues + * with firewalls inspecting TCP flags. Set externally via sysctl. + */ + int hijack_udp; + /** * @max_gro_skbs: Maximum number of socket buffers that can be * aggregated by the GRO mechanism. Set externally via sysctl. diff --git a/homa_outgoing.c b/homa_outgoing.c index 3e4808a2..4abd2499 100644 --- a/homa_outgoing.c +++ b/homa_outgoing.c @@ -186,7 +186,8 @@ struct sk_buff *homa_tx_data_pkt_alloc(struct homa_rpc *rpc, homa_info->rpc = rpc; #ifndef __STRIP__ /* See strip.py */ - if (segs > 1 && !homa_sock_hijacked(hsk)) { + if (segs > 1 && !homa_sock_hijacked(hsk) + && !homa_sock_udp_hijacked(hsk)) { #else /* See strip.py */ if (segs > 1) { #endif /* See strip.py */ @@ -286,9 +287,12 @@ int homa_message_out_fill(struct homa_rpc *rpc, struct iov_iter *iter, int xmit) #ifndef __STRIP__ /* See strip.py */ /* Round gso_size down to an even # of mtus; calculation depends * on whether we're doing TCP hijacking (need more space in TSO packet - * if no hijacking). + * if no hijacking). UDP hijacking uses single-segment packets + * (no multi-segment GSO). */ - if (homa_sock_hijacked(rpc->hsk)) { + if (homa_sock_udp_hijacked(rpc->hsk)) { + segs_per_gso = 1; + } else if (homa_sock_hijacked(rpc->hsk)) { segs_per_gso = gso_size - rpc->hsk->ip_header_length - sizeof(struct homa_data_hdr); do_div(segs_per_gso, max_seg_data); @@ -474,12 +478,18 @@ int __homa_xmit_control(void *contents, size_t length, struct homa_peer *peer, homa_set_doff(skb, 20); #ifndef __STRIP__ /* See strip.py */ if (hsk->inet.sk.sk_family == AF_INET6) { - homa_hijack_set_hdr(skb, peer, true); + if (homa_sock_udp_hijacked(hsk)) + homa_udp_hijack_set_hdr(skb, peer, true); + else + homa_hijack_set_hdr(skb, peer, true); result = ip6_xmit(&hsk->inet.sk, skb, &peer->flow.u.ip6, 0, NULL, hsk->homa->priority_map[priority] << 5, 0); } else { - homa_hijack_set_hdr(skb, peer, false); + if (homa_sock_udp_hijacked(hsk)) + homa_udp_hijack_set_hdr(skb, peer, false); + else + homa_hijack_set_hdr(skb, peer, false); /* This will find its way to the DSCP field in the IPv4 hdr. */ hsk->inet.tos = hsk->homa->priority_map[priority] << 5; @@ -686,7 +696,10 @@ void __homa_xmit_data(struct sk_buff *skb, struct homa_rpc *rpc) tt_addr(rpc->peer->addr), rpc->id, homa_get_skb_info(skb)->offset); #ifndef __STRIP__ /* See strip.py */ - homa_hijack_set_hdr(skb, rpc->peer, true); + if (homa_sock_udp_hijacked(rpc->hsk)) + homa_udp_hijack_set_hdr(skb, rpc->peer, true); + else + homa_hijack_set_hdr(skb, rpc->peer, true); err = ip6_xmit(&rpc->hsk->inet.sk, skb, &rpc->peer->flow.u.ip6, 0, NULL, rpc->hsk->homa->priority_map[priority] << 5, 0); @@ -701,7 +714,10 @@ void __homa_xmit_data(struct sk_buff *skb, struct homa_rpc *rpc) homa_get_skb_info(skb)->offset); #ifndef __STRIP__ /* See strip.py */ - homa_hijack_set_hdr(skb, rpc->peer, false); + if (homa_sock_udp_hijacked(rpc->hsk)) + homa_udp_hijack_set_hdr(skb, rpc->peer, false); + else + homa_hijack_set_hdr(skb, rpc->peer, false); rpc->hsk->inet.tos = rpc->hsk->homa->priority_map[priority] << 5; err = ip_queue_xmit(&rpc->hsk->inet.sk, skb, &rpc->peer->flow); diff --git a/homa_plumbing.c b/homa_plumbing.c index cdfcd865..ab3859b2 100644 --- a/homa_plumbing.c +++ b/homa_plumbing.c @@ -259,6 +259,13 @@ static struct ctl_table homa_ctl_table[] = { .mode = 0644, .proc_handler = homa_dointvec }, + { + .procname = "hijack_udp", + .data = OFFSET(hijack_udp), + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = homa_dointvec + }, { .procname = "link_mbps", .data = OFFSET(link_mbps), @@ -643,6 +650,7 @@ int __init homa_load(void) #ifndef __STRIP__ /* See strip.py */ homa_hijack_init(); + homa_udp_hijack_init(); #endif /* See strip.py */ #ifndef __UPSTREAM__ /* See strip.py */ tt_set_temp(homa->temp); @@ -695,6 +703,7 @@ void __exit homa_unload(void) pr_notice("Homa module unloading\n"); #ifndef __STRIP__ /* See strip.py */ + homa_udp_hijack_end(); homa_hijack_end(); #endif /* See strip.py */ if (timer_kthread) { diff --git a/homa_qdisc.c b/homa_qdisc.c index d9486053..4759b21b 100755 --- a/homa_qdisc.c +++ b/homa_qdisc.c @@ -155,7 +155,8 @@ static inline bool is_homa_pkt(struct sk_buff *skb) protocol = (skb_is_ipv6(skb)) ? ipv6_hdr(skb)->nexthdr : ip_hdr(skb)->protocol; return protocol == IPPROTO_HOMA || - (protocol == IPPROTO_TCP && homa_skb_hijacked(skb)); + (protocol == IPPROTO_TCP && homa_skb_hijacked(skb)) || + (protocol == IPPROTO_UDP && homa_skb_udp_hijacked(skb)); } /** diff --git a/homa_wire.h b/homa_wire.h index 8afe3f4b..6226a291 100644 --- a/homa_wire.h +++ b/homa_wire.h @@ -139,7 +139,7 @@ struct homa_common_hdr { #ifndef __STRIP__ /* See strip.py */ /** * @flags: Holds TCP flags such as URG, ACK, etc. Not used by Homa - * except for TCP hijacking. + * except for TCP/UDP hijacking. */ u8 flags; #else /* See strip.py */ @@ -162,10 +162,11 @@ struct homa_common_hdr { #ifndef __STRIP__ /* See strip.py */ /** * @urgent: occupies the same bytes as the urgent pointer in a TCP - * header. Not used by Homa except during TCP hijacking. + * header. Not used by Homa except during TCP/UDP hijacking. */ __be16 urgent; #define HOMA_HIJACK_URGENT 0xb97d +#define HOMA_UDP_URGENT 0xb97e #else /* See strip.py */ /** @reserved2: Not used (corresponds to TCP urgent field). */ __be16 reserved2; diff --git a/util/cp_node.cc b/util/cp_node.cc index 4665c3f2..9095aa0d 100644 --- a/util/cp_node.cc +++ b/util/cp_node.cc @@ -338,7 +338,7 @@ void print_help(const char *name) printf(" --ipv6 Use IPv6 instead of IPv4\n"); printf(" --pin All server threads will be restricted to run only\n" " on the givevn core\n"); - printf(" --protocol Transport protocol to use: homa or tcp (default: %s)\n", + printf(" --protocol Transport protocol to use: homa, tcp, or udp (default: %s)\n", protocol); printf(" --port-threads Number of server threads to service each port\n" " (default: %d)\n", @@ -2708,6 +2708,347 @@ void tcp_client::read(tcp_connection *connection, int pid) } } +/* ===================== UDP client and server ===================== */ + +/** + * class udp_server - Holds information about a single UDP server, + * which consists of a thread that handles requests on a given port. + */ +class udp_server { +public: + udp_server(int port, int id, int num_threads, + std::string& experiment); + ~udp_server(); + void server(int thread_id); + + /** @port: Port on which we listen. */ + int port; + + /** @id: Unique identifier for this server. */ + int id; + + /** @experiment: name of the experiment this server is running. */ + string experiment; + + /** @fd: File descriptor for the UDP socket. */ + int fd; + + /** @metrics: Performance statistics. Not owned by this class. */ + server_metrics *metrics; + + /** @threads: Background threads servicing this socket. */ + std::vector threads; + + /** @stop: True means background threads should exit. */ + bool stop; +}; + +/** @udp_servers: keeps track of all existing UDP servers. */ +std::vector udp_servers; + +/** + * udp_server::udp_server() - Constructor for udp_server objects. + * @port: Port number on which this server should listen. + * @id: Unique identifier for this server. + * @num_threads: Number of threads to service this socket. + * @experiment: Name of the experiment. + */ +udp_server::udp_server(int port, int id, int num_threads, + std::string& experiment) + : port(port) + , id(id) + , fd(-1) + , metrics() + , threads() + , stop(false) +{ + if (std::find(experiments.begin(), experiments.end(), experiment) + == experiments.end()) + experiments.emplace_back(experiment); + + fd = socket(inet_family, SOCK_DGRAM, 0); + if (fd == -1) { + log(NORMAL, "FATAL: couldn't open UDP server socket: %s\n", + strerror(errno)); + fatal(); + } + sockaddr_in_union addr; + if (inet_family == AF_INET) { + addr.in4.sin_family = AF_INET; + addr.in4.sin_port = htons(port); + addr.in4.sin_addr.s_addr = INADDR_ANY; + } else { + addr.in6.sin6_family = AF_INET6; + addr.in6.sin6_port = htons(port); + addr.in6.sin6_addr = in6addr_any; + } + if (bind(fd, &addr.sa, sizeof(addr)) == -1) { + log(NORMAL, "FATAL: couldn't bind UDP socket to port %d: %s\n", + port, strerror(errno)); + fatal(); + } + + metrics = new server_metrics(experiment); + ::metrics.push_back(metrics); + + for (int i = 0; i < num_threads; i++) + threads.emplace_back(&udp_server::server, this, i); + kfreeze_count = 0; +} + +/** + * udp_server::~udp_server() - Destructor for UDP servers. + */ +udp_server::~udp_server() +{ + stop = true; + shutdown(fd, SHUT_RDWR); + for (size_t i = 0; i < threads.size(); i++) + threads[i].join(); + close(fd); +} + +/** + * udp_server::server() - Handles incoming UDP requests. Invoked as top-level + * method in a thread. + * @thread_id: Unique id for this thread. + */ +void udp_server::server(int thread_id) +{ + char thread_name[50]; + char buffer[1000000]; + + snprintf(thread_name, sizeof(thread_name), "US%d.%d", id, thread_id); + time_trace::thread_buffer thread_buffer(thread_name); + int pid = syscall(__NR_gettid); + if (server_core >= 0) + pin_thread(server_core); + + while (!stop) { + sockaddr_in_union source; + socklen_t source_len = sizeof(source); + ssize_t length = recvfrom(fd, buffer, sizeof(buffer), 0, + &source.sa, &source_len); + if (length < 0) { + if (stop) + return; + if ((errno == EAGAIN) || (errno == EINTR)) + continue; + log(NORMAL, "FATAL: UDP recvfrom failed: %s\n", + strerror(errno)); + fatal(); + } + if (length < (ssize_t)sizeof(message_header)) + continue; + + message_header *header = + reinterpret_cast(buffer); + metrics->requests++; + metrics->bytes_in += header->length; + tt("Received UDP request, cid 0x%08x, id %u, length %d, " + "pid %d", header->cid, header->msg_id, + header->length, pid); + + if ((header->freeze) && !time_trace::frozen) { + tt("Freezing timetrace"); + time_trace::freeze(); + kfreeze(); + } + + /* Prepare and send response. */ + int resp_length = header->short_response ? 100 : header->length; + if (resp_length < (int)sizeof(message_header)) + resp_length = sizeof(message_header); + header->response = 1; + header->length = resp_length; + metrics->bytes_out += resp_length; + + ssize_t sent = sendto(fd, buffer, resp_length, 0, + &source.sa, source_len); + if (sent < 0) + log(NORMAL, "ERROR: UDP sendto failed: %s\n", + strerror(errno)); + tt("Sent UDP response, cid 0x%08x, id %u, length %d", + header->cid, header->msg_id, resp_length); + } +} + +/** + * class udp_client - Holds information about a single UDP client, + * which consists of one thread issuing requests and one thread receiving + * responses. + */ +class udp_client : public client { +public: + udp_client(int id, std::string& experiment); + virtual ~udp_client(); + void receiver(int id); + void sender(void); + + /** @fd: UDP socket file descriptor. */ + int fd; + + /** @stop: True means background threads should exit. */ + bool stop; + + /** @receiver_threads: threads that receive responses. */ + std::vector receiving_threads; + + /** + * @sending_thread: thread that sends requests. + */ + std::optional sending_thread; +}; + +/** + * udp_client::udp_client() - Constructor for udp_client objects. + * @id: Unique identifier for this client. + * @experiment: Name of experiment. + */ +udp_client::udp_client(int id, std::string& experiment) + : client(id, experiment) + , fd(-1) + , stop(false) + , receiving_threads() + , sending_thread() +{ + fd = socket(inet_family, SOCK_DGRAM, 0); + if (fd < 0) { + log(NORMAL, "FATAL: couldn't open UDP client socket: %s\n", + strerror(errno)); + fatal(); + } + + for (int i = 0; i < port_receivers; i++) + receiving_threads.emplace_back(&udp_client::receiver, this, i); + while (receivers_running < receiving_threads.size()) { + /* Wait for receivers to begin execution before starting + * the sender. + */ + } + sending_thread.emplace(&udp_client::sender, this); +} + +/** + * udp_client::~udp_client() - Destructor for udp_client objects. + */ +udp_client::~udp_client() +{ + stop = true; + shutdown(fd, SHUT_RDWR); + if (sending_thread) + sending_thread->join(); + for (std::thread& thread: receiving_threads) + thread.join(); + close(fd); + check_completion("udp"); +} + +/** + * udp_client::sender() - Invoked as the top-level method in a thread; + * invokes a pseudo-random stream of RPCs continuously. + */ +void udp_client::sender() +{ + char thread_name[50]; + char buffer[HOMA_MAX_MESSAGE_LENGTH]; + int pid = syscall(__NR_gettid); + + snprintf(thread_name, sizeof(thread_name), "C%d", id); + time_trace::thread_buffer thread_buffer(thread_name); + + uint64_t next_start = rdtsc(); + message_header *header = reinterpret_cast(buffer); + + while (1) { + uint64_t now; + int server; + int slot = get_rinfo(); + + while (1) { + if (stop) { + rinfos[slot].active = false; + return; + } + now = rdtsc(); + if ((now >= next_start) && + ((total_requests - total_responses) + < client_port_max)) + break; + } + + rinfos[slot].start_time = now; + server = server_dist(rand_gen); + header->length = length_dist(rand_gen); + if (header->length > HOMA_MAX_MESSAGE_LENGTH) + header->length = HOMA_MAX_MESSAGE_LENGTH; + if (header->length < (int)sizeof(message_header)) + header->length = sizeof(message_header); + rinfos[slot].request_length = header->length; + header->cid = server_conns[server]; + header->cid.client_port = id; + header->msg_id = slot; + header->freeze = freeze[header->cid.server]; + header->short_response = one_way; + header->response = 0; + tt("Sending UDP request, cid 0x%08x, id %u, length %d, " + "pid %d", header->cid, header->msg_id, + header->length, pid); + + ssize_t sent = sendto(fd, buffer, header->length, 0, + &server_addrs[server].sa, + sockaddr_size(&server_addrs[server].sa)); + if (sent < 0) { + log(NORMAL, "FATAL: error in UDP sendto: %s (request " + "length %d)\n", strerror(errno), + header->length); + fatal(); + } + requests[server]++; + total_requests++; + lag = now - next_start; + next_start += interval_dist(rand_gen) * cycles_per_second; + } +} + +/** + * udp_client::receiver() - Invoked as the top-level method in a thread + * that waits for UDP responses and logs statistics. + * @receiver_id: Id of this receiver. + */ +void udp_client::receiver(int receiver_id) +{ + char thread_name[50]; + char buffer[1000000]; + + snprintf(thread_name, sizeof(thread_name), "R%d.%d", id, receiver_id); + time_trace::thread_buffer thread_buffer(thread_name); + receivers_running++; + int pid = syscall(__NR_gettid); + + while (!stop) { + ssize_t length = recvfrom(fd, buffer, sizeof(buffer), + 0, NULL, NULL); + if (length < 0) { + if (stop) + return; + if ((errno == EAGAIN) || (errno == EINTR)) + continue; + log(NORMAL, "FATAL: UDP recvfrom failed in client: " + "%s\n", strerror(errno)); + fatal(); + } + if (length < (ssize_t)sizeof(message_header)) + continue; + uint64_t end_time = rdtsc(); + message_header *header = + reinterpret_cast(buffer); + record(end_time, header); + tt("Response for cid 0x%08x received by pid %d", + header->cid, pid); + } +} + /** * homa_info() - Use the HOMAIOCINFO ioctl to extract the status of a * Homa socket and print the information to the log. @@ -3173,6 +3514,10 @@ int client_cmd(std::vector &words) if (first_port == -1) first_port = 4000; clients.push_back(new homa_client(i, experiment)); + } else if (strcmp(protocol, "udp") == 0) { + if (first_port == -1) + first_port = 6000; + clients.push_back(new udp_client(i, experiment)); } else { if (first_port == -1) first_port = 5000; @@ -3454,6 +3799,14 @@ int server_cmd(std::vector &words) experiment); homa_servers.push_back(server); } + } else if (strcmp(protocol, "udp") == 0) { + if (first_port == -1) + first_port = 6000; + for (int i = 0; i < server_ports; i++) { + udp_server *server = new udp_server(first_port + i, + i, port_threads, experiment); + udp_servers.push_back(server); + } } else { if (first_port == -1) first_port = 5000; @@ -3492,6 +3845,9 @@ int stop_cmd(std::vector &words) for (tcp_server *server: tcp_servers) delete server; tcp_servers.clear(); + for (udp_server *server: udp_servers) + delete server; + udp_servers.clear(); last_per_server_rpcs.clear(); for (server_metrics *m: metrics) delete m; diff --git a/util/homa_test.cc b/util/homa_test.cc index 5546089f..3ca1b3e1 100644 --- a/util/homa_test.cc +++ b/util/homa_test.cc @@ -760,6 +760,97 @@ void test_tcp(char *server_name, int port) return; } +/** + * udp_ping() - Send a request on a UDP socket and wait for the + * corresponding response. + * @fd: File descriptor for a UDP socket. + * @dest: Destination address. + * @dest_len: Size of @dest. + * @request: Buffer containing the request message. + * @length: Length of the request message. + */ +void udp_ping(int fd, struct sockaddr *dest, socklen_t dest_len, + void *request, int length) +{ + char response[1000000]; + int *int_response = reinterpret_cast(response); + ssize_t sent, received; + + sent = sendto(fd, request, length, 0, dest, dest_len); + if (sent != length) { + printf("UDP sendto failed: %s\n", strerror(errno)); + exit(1); + } + received = recvfrom(fd, response, sizeof(response), 0, NULL, NULL); + if (received < 0) { + printf("UDP recvfrom failed: %s\n", strerror(errno)); + exit(1); + } + if (received < (ssize_t)(2 * sizeof(int))) + return; + if (received != int_response[1]) + printf("Expected %d bytes in UDP response, got %ld\n", + int_response[1], received); +} + +/** + * test_udp() - Measure round-trip time for an RPC sent via a UDP socket. + * @server_name: Name of the server machine. + * @port: Server port to connect to. + */ +void test_udp(char *server_name, int port) +{ + struct addrinfo hints; + struct addrinfo *matching_addresses; + struct sockaddr *dest; + socklen_t dest_len; + int status, i; + int buffer[250000]; + + memset(&hints, 0, sizeof(struct addrinfo)); + hints.ai_family = inet_family; + hints.ai_socktype = SOCK_DGRAM; + status = getaddrinfo(server_name, "80", &hints, &matching_addresses); + if (status != 0) { + printf("Couldn't look up address for %s: %s\n", + server_name, gai_strerror(status)); + exit(1); + } + dest = matching_addresses->ai_addr; + ((struct sockaddr_in *) dest)->sin_port = htons(port); + dest_len = matching_addresses->ai_addrlen; + + int fd = socket(inet_family, SOCK_DGRAM, 0); + if (fd == -1) { + printf("Couldn't open UDP socket: %s\n", strerror(errno)); + exit(1); + } + + /* Warm up. */ + buffer[0] = length; + buffer[1] = length; + seed_buffer(&buffer[2], sizeof32(buffer) - 2*sizeof32(int), seed); + for (i = 0; i < 10; i++) + udp_ping(fd, dest, dest_len, buffer, length); + + uint64_t times[count+1]; + for (i = 0; i < count; i++) { + times[i] = rdtsc(); + udp_ping(fd, dest, dest_len, buffer, length); + } + times[count] = rdtsc(); + freeaddrinfo(matching_addresses); + + for (i = 0; i < count; i++) { + times[i] = times[i+1] - times[i]; + } + print_dist(times, count); + printf("Bandwidth at median: %.1f MB/sec\n", + 2.0*((double) length)/(to_seconds(times[count/2])*1e06)); + close(fd); + return; +} + /** * test_tcpstream() - Measure throughput of a TCP socket using --length as * the size of the buffer for each write system call. @@ -1158,6 +1249,8 @@ int main(int argc, char** argv) test_tcpstream(host, port); } else if (strcmp(argv[next_arg], "tmp") == 0) { test_tmp(fd, count); + } else if (strcmp(argv[next_arg], "udp") == 0) { + test_udp(host, port); } else if (strcmp(argv[next_arg], "udpclose") == 0) { test_udpclose(); } else if (strcmp(argv[next_arg], "wmem") == 0) { diff --git a/util/server.cc b/util/server.cc index a87d753d..658529d3 100644 --- a/util/server.cc +++ b/util/server.cc @@ -307,6 +307,72 @@ void tcp_server(int port) } } +/** + * udp_server() - Opens a UDP socket and handles all requests arriving on + * that socket. Each request is a datagram whose first word is the total + * message length and second word is the desired response length. + * @port: Port number on which to listen. + */ +void udp_server(int port) +{ + int fd; + char buffer[1000000]; + sockaddr_in_union addr; + sockaddr_in_union source; + socklen_t source_len; + + fd = socket(inet_family, SOCK_DGRAM, 0); + if (fd < 0) { + printf("Couldn't open UDP socket: %s\n", strerror(errno)); + return; + } + memset(&addr, 0, sizeof(addr)); + addr.in4.sin_family = inet_family; + addr.in4.sin_port = htons(port); + if (bind(fd, &addr.sa, sizeof(addr)) != 0) { + printf("Couldn't bind UDP socket to port %d: %s\n", port, + strerror(errno)); + return; + } + if (verbose) + printf("Successfully bound to UDP port %d\n", port); + + while (1) { + int *int_buffer = reinterpret_cast(buffer); + ssize_t length; + int resp_length; + + source_len = sizeof(source); + length = recvfrom(fd, buffer, sizeof(buffer), 0, + &source.sa, &source_len); + if (length < 0) { + printf("UDP recvfrom failed: %s\n", strerror(errno)); + continue; + } + if (length < (ssize_t)(2 * sizeof(int))) { + if (verbose) + printf("UDP message too short (%ld bytes) " + "from %s\n", length, + print_address(&source)); + continue; + } + resp_length = int_buffer[1]; + if (verbose) + printf("Received UDP message from %s with %ld bytes, " + "response length %d\n", + print_address(&source), length, resp_length); + if (resp_length <= 0) + continue; + if (resp_length > (int)sizeof(buffer)) + resp_length = sizeof(buffer); + /* Echo the header back so the client can match responses. */ + if (sendto(fd, buffer, resp_length, 0, + &source.sa, source_len) < 0) { + printf("UDP sendto failed: %s\n", strerror(errno)); + } + } +} + int main(int argc, char** argv) { int next_arg; int num_ports = 1; @@ -356,5 +422,8 @@ int main(int argc, char** argv) { thread.detach(); } + std::thread udp_thread(udp_server, port); + udp_thread.detach(); + tcp_server(port); }