// another.im-ios/AnotherXMPP/modules/connection/StartTLSSocket.swift
// (252 lines, 9.2 KiB, Swift; revision dated 2024-06-19 15:06:39 +00:00)
import Foundation
import Network
/// File-wide switch for debug output: `log` and `logState` only print when this is `true`.
private let doPrint: Bool = false
/// A `Socket` that starts as a plain TCP connection and can later be switched to TLS
/// on the same underlying connection (STARTTLS).
///
/// Four Network.framework objects cooperate:
/// - `rawConnection`: the TCP connection to `host:port`.
/// - `bridge`: a listener on a local unix-domain socket (`bridgeSocketPath`).
/// - `secConnection`: a TLS connection that dials that unix socket.
/// - `brgConnection`: the listener-side end of that local connection.
///
/// After `startTls()`, bytes read from `rawConnection` are forwarded to `brgConnection`
/// and bytes read from `brgConnection` are written to `rawConnection`, so the TLS
/// session of `secConnection` is carried over the already-open TCP connection.
final class StartTLSSocket: Socket {
    // Stream of socket events (state changes and received data) and its producer side.
    let (events, eventsContinuation) = AsyncStream.makeStream(of: SocketEvent.self, bufferingPolicy: .unbounded)

    /// Serial queue on which every connection and the listener deliver their callbacks.
    private let queue: DispatchQueue
    /// Filesystem location of the unix-domain socket used by the local bridge.
    private let bridgeSocketPath: URL
    private let host: String
    private let port: Int
    /// When `true`, the TLS verify block accepts any server certificate.
    private let allowInsecure: Bool

    private var rawConnection: NWConnection?
    private var secConnection: NWConnection?
    private var bridge: NWListener?
    private var brgConnection: NWConnection?

    /// Prepares (but does not start) the plain TCP connection.
    ///
    /// - Parameters:
    ///   - id: Used to make the dispatch-queue label and the bridge socket path unique.
    ///   - host: Remote host name; also used as the TLS server name after `startTls()`.
    ///   - port: Remote TCP port. A value outside `0...65535` no longer traps; the
    ///     connection is simply not created and `connect()` reports a disconnect.
    ///   - allowInsecure: Skip certificate validation when upgrading to TLS.
    init(id: String, host: String, port: Int, allowInsecure: Bool) {
        self.host = host
        self.port = port
        self.allowInsecure = allowInsecure
        queue = DispatchQueue(label: "another.xmpp.network.queue_\(id)")
        bridgeSocketPath = URL(fileURLWithPath: "/tmp/tls_bridge_listener\(id).sock")
        // Best effort: drop a socket file that an earlier instance may have left behind.
        try? FileManager.default.removeItem(at: bridgeSocketPath)

        // tcp options
        let tcpOptions = NWProtocolTCP.Options()
        tcpOptions.noDelay = true
        tcpOptions.connectionTimeout = 5
        tcpOptions.enableFastOpen = true
        tcpOptions.disableAckStretching = true

        let params = NWParameters(tls: nil, tcp: tcpOptions)
        params.serviceClass = .responsiveData

        // Validate the port instead of trapping in `UInt16(port)`.
        guard let portNumber = UInt16(exactly: port),
              let endpointPort = NWEndpoint.Port(rawValue: portNumber) else {
            return
        }

        // Use the configured `params` so that `serviceClass` actually applies.
        let connection = NWConnection(host: .name(host, nil), port: endpointPort, using: params)
        connection.stateUpdateHandler = { [weak self] state in
            guard let self else { return }
            self.logState("Raw connection \(state)")
            switch state {
            case .ready:
                self.eventsContinuation.yield(.state(.connected))
                self.rawRead()
            case .waiting(let error), .failed(let error):
                // `.waiting` is reported as a disconnect as well, same as `.failed`.
                if doPrint {
                    print(error.localizedDescription)
                }
                self.eventsContinuation.yield(.state(.disconnected(error)))
            default:
                break
            }
        }
        rawConnection = connection
    }

    deinit {
        eventsContinuation.finish()
        rawConnection?.cancel()
        brgConnection?.cancel()
        secConnection?.cancel()
        bridge?.cancel()
        // Best effort: do not leave the bridge's socket file behind.
        try? FileManager.default.removeItem(at: bridgeSocketPath)
    }

    /// Starts the plain TCP connection; progress is reported through `events`.
    func connect() async {
        guard let rawConnection else {
            // The initializer rejected the port, so there is nothing to start.
            logState("Raw connection not created, invalid port \(port)")
            eventsContinuation.yield(.state(.disconnected(NWError.posix(.EINVAL))))
            return
        }
        rawConnection.start(queue: queue)
    }

    /// Sends `data` over the TLS connection once one exists, otherwise over the plain one.
    func send(_ data: Data) async {
        if secConnection != nil {
            secWrite(data: data)
        } else {
            rawWrite(data: data)
        }
    }

    /// Begins the TLS upgrade by creating the local bridge listener.
    ///
    /// - Throws: Whatever `NWListener(using:)` throws when the bridge listener cannot be created.
    func startTls() throws {
        try initBridgeAndSecureConnection()
    }
}
// MARK: - Read/write plumbing, TLS bridge setup and logging

private extension StartTLSSocket {
    /// Upper bound on the number of bytes requested per `receive` call.
    static let readChunkSize = 4096 * 2

    /// Emits a `.disconnected` state event carrying `error`.
    func reportDisconnect(_ error: NWError) {
        eventsContinuation.yield(.state(.disconnected(error)))
    }

    /// Repeatedly receives from `connection` and hands every chunk to `deliver`.
    ///
    /// The loop ends when an error is reported, when the stream is complete (the peer
    /// closed it), or when this socket has been deallocated. An error or an end of
    /// stream produces a `.disconnected` event; previously a clean close by the peer
    /// ended the loop without any event.
    ///
    /// - Parameters:
    ///   - connection: The connection to read from. Captured weakly by the completion.
    ///   - deliver: Receives the socket and the chunk; it must not capture the socket itself.
    func readLoop(on connection: NWConnection, deliver: @escaping (StartTLSSocket, Data) -> Void) {
        connection.receive(minimumIncompleteLength: 1, maximumLength: StartTLSSocket.readChunkSize) { [weak self, weak connection] data, _, isComplete, error in
            guard let self, let connection else { return }
            if let error {
                self.reportDisconnect(error)
                return
            }
            if let data {
                deliver(self, data)
            }
            if isComplete {
                // End of stream without an error: there is no errno for a clean close,
                // so ENOTCONN is used to tell listeners the connection is gone.
                self.reportDisconnect(.posix(.ENOTCONN))
                return
            }
            // No data, no error, not complete: stop, as the original loop did.
            guard data != nil else { return }
            self.readLoop(on: connection, deliver: deliver)
        }
    }

    /// Reads from the plain connection. Before the bridge connection exists the bytes
    /// are published as `.dataReceived`; afterwards they are forwarded to the bridge.
    func rawRead() {
        guard let rawConnection else { return }
        readLoop(on: rawConnection) { socket, data in
            socket.log("raw", data: data, read: true)
            if let bridgeConnection = socket.brgConnection {
                bridgeConnection.send(content: data, completion: .contentProcessed { [weak socket] error in
                    if let error {
                        socket?.reportDisconnect(error)
                    }
                })
            } else {
                socket.eventsContinuation.yield(.dataReceived(data))
            }
        }
    }

    /// Writes `data` to the plain connection; a send error becomes a `.disconnected` event.
    func rawWrite(data: Data) {
        log("raw", data: data, read: false)
        rawConnection?.send(content: data, completion: .contentProcessed { [weak self] error in
            if let error {
                self?.reportDisconnect(error)
            }
        })
    }

    /// Reads from the bridge connection and writes every chunk to the plain connection.
    func brgRead() {
        guard let brgConnection else { return }
        readLoop(on: brgConnection) { socket, data in
            socket.rawWrite(data: data)
        }
    }

    /// Reads from the TLS connection and publishes every chunk as `.dataReceived`.
    func secRead() {
        guard let secConnection else { return }
        readLoop(on: secConnection) { socket, data in
            socket.log("sec", data: data, read: true)
            socket.eventsContinuation.yield(.dataReceived(data))
        }
    }

    /// Writes `data` to the TLS connection; a send error becomes a `.disconnected` event.
    func secWrite(data: Data) {
        log("sec", data: data, read: false)
        secConnection?.send(content: data, completion: .contentProcessed { [weak self] error in
            if let error {
                self?.reportDisconnect(error)
            }
        })
    }

    /// Creates and starts the local bridge listener on `bridgeSocketPath`.
    ///
    /// When the listener becomes ready the TLS connection is created and pointed at it;
    /// when the listener receives that connection it is stored as `brgConnection`.
    ///
    /// - Throws: Whatever `NWListener(using:)` throws.
    func initBridgeAndSecureConnection() throws {
        let params = NWParameters()
        params.defaultProtocolStack.transportProtocol = NWProtocolTCP.Options()
        params.requiredLocalEndpoint = NWEndpoint.unix(path: bridgeSocketPath.path)
        params.allowLocalEndpointReuse = false

        let listener = try NWListener(using: params)
        bridge = listener
        listener.newConnectionLimit = 1

        // make bridge and connection
        listener.stateUpdateHandler = { [weak self] state in
            guard let self else { return }
            self.logState("Bridge \(state)")
            switch state {
            case .ready:
                self.startSecureConnection()
            case .waiting(let error), .failed(let error):
                self.reportDisconnect(error)
            default:
                break
            }
        }

        // bridge handler
        listener.newConnectionHandler = { [weak self] connection in
            self?.adoptBridgeConnection(connection)
        }

        listener.start(queue: queue)
    }

    /// Builds the TLS connection to the bridge's unix socket and starts it.
    func startSecureConnection() {
        // tcp options for secure connection
        let tcpOptions = NWProtocolTCP.Options()
        tcpOptions.noDelay = true
        tcpOptions.enableFastOpen = true
        tcpOptions.disableAckStretching = true

        // tls options for secure connection
        let tlsOptions = NWProtocolTLS.Options()
        let securityOptions = tlsOptions.securityProtocolOptions
        sec_protocol_options_set_min_tls_protocol_version(securityOptions, .TLSv12)
        sec_protocol_options_set_max_tls_protocol_version(securityOptions, .TLSv13)
        // sec_protocol_options_set_peer_authentication_required(securityOptions, false)
        sec_protocol_options_set_tls_server_name(securityOptions, host)

        // Capture the flag by value. Capturing `self` here would create the cycle
        // self -> secConnection -> parameters -> verify block -> self and keep the
        // socket alive forever.
        let allowInsecure = self.allowInsecure
        sec_protocol_options_set_verify_block(securityOptions, { _, secTrust, complete in
            guard !allowInsecure else {
                complete(true)
                return
            }
            let trust = sec_trust_copy_ref(secTrust).takeRetainedValue()
            complete(SecTrustEvaluateWithError(trust, nil))
        }, queue)

        // secure connection
        let params = NWParameters(tls: tlsOptions, tcp: tcpOptions)
        params.serviceClass = .responsiveData
        let connection = NWConnection(to: .unix(path: bridgeSocketPath.path), using: params)
        secConnection = connection
        connection.stateUpdateHandler = { [weak self] state in
            guard let self else { return }
            self.logState("Secure connection \(state)")
            switch state {
            case .ready:
                self.eventsContinuation.yield(.state(.startTlsReady))
                self.secRead()
            case .waiting(let error), .failed(let error):
                self.reportDisconnect(error)
            default:
                break
            }
        }
        connection.start(queue: queue)
    }

    /// Stores the connection accepted by the bridge listener and starts it.
    func adoptBridgeConnection(_ connection: NWConnection) {
        brgConnection = connection
        // Weak capture: the connection is owned by `self`, so a strong capture in its
        // own handler would be a retain cycle.
        connection.stateUpdateHandler = { [weak self] state in
            guard let self else { return }
            self.logState("Bridge connection \(state)")
            switch state {
            case .ready:
                self.brgRead()
            case .waiting(let error), .failed(let error):
                self.reportDisconnect(error)
            default:
                break
            }
        }
        connection.start(queue: queue)
    }

    /// Prints a payload when `doPrint` is enabled.
    ///
    /// Traffic on the `"raw"` socket is skipped once the TLS connection exists.
    /// Payloads are decoded as UTF-8 (invalid bytes become U+FFFD) instead of ASCII,
    /// so a chunk containing non-ASCII text is no longer logged as just `"?"`.
    func log(_ socket: String, data: Data, read: Bool) {
        if socket == "raw", secConnection != nil { return }
        guard doPrint else { return }
        let direction = read ? "read-in" : "write-out"
        let str = String(decoding: data, as: UTF8.self)
        print("\nStartTLSSocket-\(socket) \(direction): \(str)\n")
    }

    /// Prints a connection-state message when `doPrint` is enabled.
    func logState(_ str: String) {
        guard doPrint else { return }
        print("Connection state: \(str)")
    }
}