Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions Sources/DNSServer/DNSHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,20 @@
// limitations under the License.
//===----------------------------------------------------------------------===//

/// Context for a DNS request.
public struct DNSRequestContext: Sendable {
/// The source IP address for the datagram, if available.
public let remoteIPAddress: String?

/// The source port for the datagram, if available.
public let remotePort: Int?

public init(remoteIPAddress: String? = nil, remotePort: Int? = nil) {
self.remoteIPAddress = remoteIPAddress
self.remotePort = remotePort
}
}

/// Protocol for implementing custom DNS handlers.
public protocol DNSHandler {
/// Attempt to answer a DNS query
Expand All @@ -22,4 +36,19 @@ public protocol DNSHandler {
/// - Returns: The response message for the query, or nil if the request
/// is not within the scope of the handler.
func answer(query: Message) async throws -> Message?

/// Attempt to answer a DNS query with request context.
/// - Parameters:
/// - query: the query message
/// - context: request context such as the source address.
/// - Throws: a server failure occurred during the query
/// - Returns: The response message for the query, or nil if the request
/// is not within the scope of the handler.
func answer(query: Message, context: DNSRequestContext) async throws -> Message?
}

extension DNSHandler {
public func answer(query: Message, context: DNSRequestContext) async throws -> Message? {
try await answer(query: query)
}
}
16 changes: 13 additions & 3 deletions Sources/DNSServer/DNSServer+Handle.swift
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,25 @@ extension DNSServer {
self.log?.debug("processing query: \(query.questions)")

self.log?.debug("awaiting processing")
var response =
try await handler.answer(query: query)
?? Message(
let context = DNSRequestContext(
remoteIPAddress: packet.remoteAddress.ipAddress,
remotePort: packet.remoteAddress.port
)
var response: Message
if let handlerResponse = try await handler.answer(query: query, context: context) {
response = handlerResponse
} else if respondWhenUnhandled {
response = Message(
id: query.id,
type: .response,
returnCode: .notImplemented,
questions: query.questions,
answers: []
)
} else {
self.log?.debug("dropping unhandled DNS query")
return
}

// Only set NXDOMAIN if handler didn't explicitly set noError (NODATA response).
// This preserves NODATA responses for AAAA queries when A record exists,
Expand Down
3 changes: 3 additions & 0 deletions Sources/DNSServer/DNSServer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,16 @@ import NIOPosix
/// - port: The port for the server to listen.
public struct DNSServer {
public var handler: DNSHandler
let respondWhenUnhandled: Bool
let log: Logger?

public init(
handler: DNSHandler,
respondWhenUnhandled: Bool = true,
log: Logger? = nil
) {
self.handler = handler
self.respondWhenUnhandled = respondWhenUnhandled
self.log = log
}

Expand Down
10 changes: 10 additions & 0 deletions Sources/DNSServer/Handlers/CompositeResolver.swift
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,14 @@ public struct CompositeResolver: DNSHandler {

return nil
}

public func answer(query: Message, context: DNSRequestContext) async throws -> Message? {
for handler in self.handlers {
if let response = try await handler.answer(query: query, context: context) {
return response
}
}

return nil
}
}
142 changes: 142 additions & 0 deletions Sources/DNSServer/Handlers/ForwardingResolver.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
//===----------------------------------------------------------------------===//
// Copyright © 2026 Apple Inc. and the container project authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//===----------------------------------------------------------------------===//

import Foundation
import Logging
import NIOCore
import NIOPosix

/// Handler that forwards DNS queries to upstream DNS servers.
public struct ForwardingResolver: DNSHandler {
public struct Upstream: Sendable, Equatable {
public let host: String
public let port: Int

public init(host: String, port: Int = 53) {
self.host = host
self.port = port
}
}

private let upstreams: [Upstream]
private let timeoutNanoseconds: UInt64
private let log: Logger?

public init(
upstreams: [Upstream],
timeoutNanoseconds: UInt64 = 2_000_000_000,
log: Logger? = nil
) {
self.upstreams = upstreams
self.timeoutNanoseconds = timeoutNanoseconds
self.log = log
}

public func answer(query: Message) async throws -> Message? {
for upstream in upstreams {
do {
if let response = try await forward(query: query, to: upstream) {
return response
}
} catch {
log?.debug(
"DNS upstream query failed",
metadata: [
"host": "\(upstream.host)",
"port": "\(upstream.port)",
"error": "\(error)",
])
}
}
return nil
}

public static func systemUpstreams(resolvConfPath: String = "/etc/resolv.conf") -> [Upstream] {
guard let contents = try? String(contentsOfFile: resolvConfPath, encoding: .utf8) else {
return [Upstream(host: "1.1.1.1")]
}

let upstreams = contents
.split(whereSeparator: \.isNewline)
.compactMap { line -> Upstream? in
let stripped = line.split(separator: "#", maxSplits: 1, omittingEmptySubsequences: false)[0]
let parts = stripped.split(whereSeparator: \.isWhitespace)
guard parts.count >= 2, parts[0] == "nameserver" else {
return nil
}
let host = String(parts[1])
guard isUsableUpstreamHost(host) else {
return nil
}
return Upstream(host: host)
}

return upstreams.isEmpty ? [Upstream(host: "1.1.1.1")] : upstreams
}

private static func isUsableUpstreamHost(_ host: String) -> Bool {
guard !host.hasPrefix("127."), host != "::1", host != "0.0.0.0", host != "::" else {
return false
}
return (try? SocketAddress(ipAddress: host, port: 53)) != nil
}

private func forward(query: Message, to upstream: Upstream) async throws -> Message? {
let queryData = try query.serialize()
let upstreamAddress = try SocketAddress(ipAddress: upstream.host, port: upstream.port)
let bindHost = upstream.host.contains(":") ? "::" : "0.0.0.0"
let channel = try await DatagramBootstrap(group: NIOSingletons.posixEventLoopGroup)
.bind(host: bindHost, port: 0)
.flatMapThrowing { channel in
try NIOAsyncChannel(
wrappingChannelSynchronously: channel,
configuration: NIOAsyncChannel.Configuration(
inboundType: AddressedEnvelope<ByteBuffer>.self,
outboundType: AddressedEnvelope<ByteBuffer>.self
)
)
}
.get()

return try await channel.executeThenClose { inbound, outbound in
try await outbound.write(AddressedEnvelope(remoteAddress: upstreamAddress, data: ByteBuffer(bytes: queryData)))
let timeoutNanoseconds = self.timeoutNanoseconds

return try await withThrowingTaskGroup(of: Message?.self) { group in
group.addTask {
for try await var packet in inbound {
var data = Data()
while packet.data.readableBytes > 0 {
if let chunk = packet.data.readBytes(length: packet.data.readableBytes) {
data.append(contentsOf: chunk)
}
}
return try Message(deserialize: data)
}
return nil
}
group.addTask {
try await Task.sleep(nanoseconds: timeoutNanoseconds)
return nil
}

let result = try await group.next() ?? nil
group.cancelAll()
return result
}
}
}
}
54 changes: 43 additions & 11 deletions Sources/DNSServer/Records/Message.swift
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ public struct Message: Sendable {
/// Additional resource records.
public var additional: [ResourceRecord]

private var rawAnswerCount: UInt16?
private var rawAuthorityCount: UInt16?
private var rawAdditionalCount: UInt16?
private var rawResourceRecords: Data?

/// Creates a new DNS message.
public init(
id: UInt16 = 0,
Expand Down Expand Up @@ -101,6 +106,10 @@ public struct Message: Sendable {
self.answers = answers
self.authorities = authorities
self.additional = additional
self.rawAnswerCount = nil
self.rawAuthorityCount = nil
self.rawAdditionalCount = nil
self.rawResourceRecords = nil
}

/// Deserialize a DNS message from raw data.
Expand Down Expand Up @@ -153,15 +162,13 @@ public struct Message: Sendable {
guard let (newOffset, rawNsCount) = buffer.copyOut(as: UInt16.self, offset: offset) else {
throw DNSBindError.unmarshalFailure(type: "Message", field: "nscount")
}
// nsCount not used for now, but we need to read past it
_ = UInt16(bigEndian: rawNsCount)
let nsCount = UInt16(bigEndian: rawNsCount)
offset = newOffset

guard let (newOffset, rawArCount) = buffer.copyOut(as: UInt16.self, offset: offset) else {
throw DNSBindError.unmarshalFailure(type: "Message", field: "arcount")
}
// arCount not used for now, but we need to read past it
_ = UInt16(bigEndian: rawArCount)
let arCount = UInt16(bigEndian: rawArCount)
offset = newOffset

// Read questions
Expand All @@ -172,13 +179,16 @@ public struct Message: Sendable {
self.questions.append(question)
}

// Read answers (simplified - skip for now as we only need to parse queries)
// Resource-record parsing is intentionally minimal. Preserve the raw
// section so forwarded DNS responses can be serialized without dropping
// upstream answers.
self.answers = []
self.authorities = []
self.additional = []

// Skip answer parsing for now - we primarily receive queries and send responses
_ = anCount
self.rawAnswerCount = anCount
self.rawAuthorityCount = nsCount
self.rawAdditionalCount = arCount
self.rawResourceRecords = offset < buffer.count ? Data(buffer[offset...]) : nil
}

/// Serialize this message to raw data.
Expand All @@ -196,6 +206,9 @@ public struct Message: Sendable {
let rdataSize = answer.type == .host ? 4 : 16
bufferSize += (try DNSName(labels: n.isEmpty ? [] : n.split(separator: ".", omittingEmptySubsequences: false).map(String.init))).size + 10 + rdataSize
}
if shouldSerializeRawResourceRecords, let rawResourceRecords {
bufferSize += rawResourceRecords.count
}

var buffer = [UInt8](repeating: 0, count: bufferSize)
var offset = 0
Expand Down Expand Up @@ -240,17 +253,21 @@ public struct Message: Sendable {
}
offset = newOffset

guard let newOffset = buffer.copyIn(as: UInt16.self, value: UInt16(answers.count).bigEndian, offset: offset) else {
let answerCount = shouldSerializeRawResourceRecords ? (rawAnswerCount ?? 0) : UInt16(answers.count)
let authorityCount = shouldSerializeRawResourceRecords ? (rawAuthorityCount ?? 0) : UInt16(authorities.count)
let additionalCount = shouldSerializeRawResourceRecords ? (rawAdditionalCount ?? 0) : UInt16(additional.count)

guard let newOffset = buffer.copyIn(as: UInt16.self, value: answerCount.bigEndian, offset: offset) else {
throw DNSBindError.marshalFailure(type: "Message", field: "ancount")
}
offset = newOffset

guard let newOffset = buffer.copyIn(as: UInt16.self, value: UInt16(authorities.count).bigEndian, offset: offset) else {
guard let newOffset = buffer.copyIn(as: UInt16.self, value: authorityCount.bigEndian, offset: offset) else {
throw DNSBindError.marshalFailure(type: "Message", field: "nscount")
}
offset = newOffset

guard let newOffset = buffer.copyIn(as: UInt16.self, value: UInt16(additional.count).bigEndian, offset: offset) else {
guard let newOffset = buffer.copyIn(as: UInt16.self, value: additionalCount.bigEndian, offset: offset) else {
throw DNSBindError.marshalFailure(type: "Message", field: "arcount")
}
offset = newOffset
Expand All @@ -260,6 +277,17 @@ public struct Message: Sendable {
offset = try question.appendBuffer(&buffer, offset: offset)
}

if shouldSerializeRawResourceRecords, let rawResourceRecords {
guard let newOffset = buffer.copyIn(buffer: Array(rawResourceRecords), offset: offset) else {
throw DNSBindError.marshalFailure(type: "Message", field: "resourceRecords")
}
offset = newOffset
guard offset == bufferSize else {
throw DNSBindError.unexpectedOffset(type: "Message", expected: bufferSize, actual: offset)
}
return Data(buffer[0..<offset])
}

// Write answers
for answer in answers {
offset = try answer.appendBuffer(&buffer, offset: offset)
Expand All @@ -280,4 +308,8 @@ public struct Message: Sendable {
}
return Data(buffer[0..<offset])
}

private var shouldSerializeRawResourceRecords: Bool {
rawResourceRecords != nil && answers.isEmpty && authorities.isEmpty && additional.isEmpty
}
}
10 changes: 9 additions & 1 deletion Sources/Services/Network/Server/AttachmentAllocator.swift
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ actor AttachmentAllocator {

/// Allocate a network address for a host.
func allocate(hostname: String) async throws -> UInt32 {
let hostname = Self.normalized(hostname: hostname)
// Client is responsible for ensuring two containers don't use same hostname, so provide existing IP if hostname exists
if let index = hostnames[hostname] {
return index
Expand All @@ -44,6 +45,7 @@ actor AttachmentAllocator {
/// Free an allocated network address by hostname.
@discardableResult
func deallocate(hostname: String) async throws -> UInt32? {
let hostname = Self.normalized(hostname: hostname)
guard let index = hostnames.removeValue(forKey: hostname) else {
return nil
}
Expand All @@ -54,6 +56,12 @@ actor AttachmentAllocator {

/// Retrieve the allocator index for a hostname.
func lookup(hostname: String) async throws -> UInt32? {
hostnames[hostname]
let hostname = Self.normalized(hostname: hostname)
return hostnames[hostname]
}

private static func normalized(hostname: String) -> String {
let hostname = hostname.hasSuffix(".") ? String(hostname.dropLast()) : hostname
return hostname.lowercased()
}
}
Loading