// RFC 6120, section 3: TCP binding (resolution of the server's domain, section 3.2)
|
|
import Combine
|
|
import dnssd
|
|
import Foundation
|
|
|
|
// MARK: Public
|
|
/// Failures that can be reported while resolving the XMPP SRV records of a domain.
enum SRVResolverError: Error {
    /// `DNSServiceQueryRecord` succeeded but did not hand back a service reference.
    case srvReferenceError
    /// The service reference did not provide a usable socket descriptor.
    case srvSocketError
    /// No complete answer arrived before the request's timeout elapsed.
    case srvTimeout
    /// `DNSServiceQueryRecord` itself returned an error code.
    case srvUnableToComplete
    /// `DNSServiceProcessResult` returned an error while handling a reply.
    case srvProcessError
}
|
|
|
|
/// Module that turns `.resolveDomain` into an SRV lookup of the JID's domain and
/// stores the resulting records in the client state.
final class SRVResolverModule: XmppModule {
    let id = "SRV resolver module"

    /// Records the outcome of a successful domain resolution.
    ///
    /// On `.domainResolved` the records are stored and `srvRecordIndex` is reset to -1.
    /// Every other event leaves the state unchanged.
    func reduce(oldState: ClientState, with event: Event) -> ClientState {
        var newState = oldState
        switch event {
        case .domainResolved(let records):
            newState.srvRecords = records
            newState.srvRecordIndex = -1 // will be increased on each connection attempt

        default:
            break
        }
        return newState
    }

    /// Performs the SRV lookup for `state.jid.domainPart` when `.resolveDomain` arrives.
    ///
    /// - Returns: `.domainResolved` with the records, `.domainResolvingError` when the
    ///   lookup fails, or `nil` for any other event.
    func process(state: ClientState, with event: Event) async -> Event? {
        switch event {
        case .resolveDomain:
            let domain = state.jid.domainPart
            do {
                let records = try await SRVResolver.resolve(domain: domain)
                return .domainResolved(records)
            } catch let error as SRVResolverError {
                return .domainResolvingError(error)
            } catch {
                // `resolve` is only expected to throw `SRVResolverError`; report anything
                // else as a generic failure instead of crashing on a forced cast.
                return .domainResolvingError(SRVResolverError.srvUnableToComplete)
            }

        default:
            return nil
        }
    }
}
|
|
|
|
// MARK: Private
|
|
/// Runs the SRV lookups for an XMPP client connection (RFC 6120, section 3.2).
private enum SRVResolver {
    /// Resolves the XMPP client SRV records of `domain`.
    ///
    /// It queries `_xmpp-client._tcp.<domain>` (non-TLS) and `_xmpps-client._tcp.<domain>`
    /// (TLS) concurrently, merges the answers and orders them by ascending priority. When
    /// no query yields a record, it returns a single fallback record for `domain` itself
    /// (RFC 6120, section 3.2.2).
    ///
    /// - Throws: The `SRVResolverError` of the `_xmpp-client` query, but only when *both*
    ///   queries fail. If just one query fails, the other one's records are used.
    static func resolve(domain: String) async throws -> [SRVRecord] {
        async let plain = lookup("_xmpp-client._tcp." + domain)   // request for non-tls
        async let direct = lookup("_xmpps-client._tcp." + domain) // request for tls
        let (plainOutcome, directOutcome) = await (plain, direct)

        // Give up only when neither lookup produced an answer. Otherwise a slow or broken
        // optional `_xmpps-client` lookup would throw away usable `_xmpp-client` records.
        if case .failure(let error) = plainOutcome, case .failure = directOutcome {
            throw error
        }

        // sort by priority
        var records: [SRVRecord] = [plainOutcome, directOutcome]
            .flatMap { (try? $0.get()) ?? [] }
            .sorted { $0.priority < $1.priority }

        // for fallback according to RFC 6120 section 3.2.2
        if records.isEmpty {
            records.append(.init(fallbackTarget: domain))
        }
        return records
    }

    /// Runs a single SRV query for `target` and returns its outcome instead of throwing.
    ///
    /// Relies on `SRVRequest` calling its completion exactly once, because the checked
    /// continuation must be resumed exactly once.
    private static func lookup(_ target: String) async -> Result<[SRVRecord], SRVResolverError> {
        await withCheckedContinuation { continuation in
            SRVRequest(target: target) { outcome in
                continuation.resume(returning: outcome)
            }
            .runQuery()
        }
    }
}
|
|
|
|
/// Callback that receives the final outcome of one `SRVRequest`.
private typealias SRVRequestCompletion = (_ result: Result<[SRVRecord], SRVResolverError>) -> Void
|
|
|
|
/// A single asynchronous `dnssd` SRV query for one DNS name.
///
/// `completion` is invoked at most once, with either the collected records or the first
/// failure. Once the query is running, the DNS callback, the read-source handler and the
/// timeout handler all run on `queue`.
private class SRVRequest {
    private let queue = DispatchQueue(label: "srv.resolving")
    private var dispatchSourceRead: DispatchSourceRead?
    private var timeoutTimer: DispatchSourceTimer?
    private var serviceRef: DNSServiceRef?
    private var socket: dnssd_sock_t = -1
    private let timeout = TimeInterval(5)
    private let target: String
    /// Set by `finish(with:)`; any result produced after that is ignored.
    private var isFinished = false

    var records = [SRVRecord]()
    var completion: SRVRequestCompletion

    init(target: String, completion: @escaping SRVRequestCompletion) {
        self.target = target
        self.completion = completion
    }

    /// Starts the query. The outcome is delivered through `completion`.
    ///
    /// Failures detected while setting the query up are reported synchronously, before
    /// this method returns.
    func runQuery() {
        let result = DNSServiceQueryRecord(
            &serviceRef,
            kDNSServiceFlagsReturnIntermediates,
            UInt32(kDNSServiceInterfaceIndexAny),
            target.cString(using: .utf8),
            UInt16(kDNSServiceType_SRV),
            UInt16(kDNSServiceClass_IN),
            { _, flags, _, _, _, _, _, rdLen, rdata, _, context in
                // A C callback cannot capture, so the request is passed through `context`.
                guard let context = context else {
                    return
                }
                let request: SRVRequest = Mem.bridge(context)

                if
                    let data = rdata?.assumingMemoryBound(to: UInt8.self),
                    let record = SRVRecord(data: Data(bytes: data, count: Int(rdLen)))
                // swiftlint:disable:next opening_brace
                {
                    request.records.append(record)
                }
                // No further replies are flagged as queued: deliver what has been collected.
                if flags & kDNSServiceFlagsMoreComing == 0 {
                    request.finish(with: .success(request.records))
                }
            },
            Mem.bridge(self)
        )

        guard result == DNSServiceErrorType(kDNSServiceErr_NoError) else {
            finish(with: .failure(.srvUnableToComplete))
            return
        }
        guard let sdRef = serviceRef else {
            finish(with: .failure(.srvReferenceError))
            return
        }

        socket = DNSServiceRefSockFD(sdRef)
        guard socket != -1 else {
            // No read source exists yet, so its cancel handler cannot release the
            // reference; release it here to avoid leaking the query.
            DNSServiceRefDeallocate(sdRef)
            serviceRef = nil
            finish(with: .failure(.srvSocketError))
            return
        }

        // Build and store both sources before either is resumed. A reply can then never
        // finish the request while `timeoutTimer` is still unset, which would leave a
        // timer behind that reports a second result later.
        let timer = DispatchSource.makeTimerSource(flags: [], queue: queue)
        timer.setEventHandler { [weak self] in
            self?.finish(with: .failure(.srvTimeout))
        }
        // One-shot timer with an explicit, bounded leeway. `.never` as leeway does not mean
        // "no leeway"; it allows an effectively unbounded deferral of the timeout.
        timer.schedule(deadline: .now() + timeout, leeway: .milliseconds(100))
        timeoutTimer = timer

        let readSource = DispatchSource.makeReadSource(fileDescriptor: socket, queue: queue)
        readSource.setEventHandler { [weak self] in
            // Hands the pending reply to dnssd, which invokes the query callback above.
            let res = DNSServiceProcessResult(sdRef)
            if res != kDNSServiceErr_NoError {
                self?.finish(with: .failure(.srvProcessError))
            }
        }
        readSource.setCancelHandler {
            // Captures `self` strongly on purpose. Callers discard the request after
            // `runQuery()`, so this closure keeps it alive until the query is torn down.
            // (This assumes `Mem.bridge(self)` does not retain; verify.)
            DNSServiceRefDeallocate(self.serviceRef)
        }
        dispatchSourceRead = readSource

        timer.resume()
        readSource.resume()
    }

    /// Delivers `outcome` to `completion` once and stops both dispatch sources.
    ///
    /// Later calls are ignored, so a continuation waiting on `completion` is never resumed
    /// twice. This runs on `queue`, except for the synchronous setup failures in
    /// `runQuery()`, which happen before either source is resumed.
    private func finish(with outcome: Result<[SRVRecord], SRVResolverError>) {
        guard !isFinished else {
            return
        }
        isFinished = true
        timeoutTimer?.cancel()
        dispatchSourceRead?.cancel()
        completion(outcome)
    }
}
|