import Foundation
import Network

/// Compile-time switch for this file's debug output: `log(_:data:read:)` and
/// `logState(_:)` only print when this is `true`.
private let doPrint = false

/// A socket that starts as a plain TCP connection and can later be upgraded to TLS.
///
/// The upgrade is delegated to `initBridgeAndSecureConnection()`, which routes traffic
/// through a local Unix-domain socket located at `bridgeSocketPath`.
/// All socket activity is reported through `events`.
final class StartTLSSocket: Socket {
    let (events, eventsContinuation) = AsyncStream.makeStream(of: SocketEvent.self, bufferingPolicy: .unbounded)

    private let queue: DispatchQueue
    private let bridgeSocketPath: URL
    private let host: String
    private let port: Int
    private let allowInsecure: Bool
    private var rawConnection: NWConnection?
    private var secConnection: NWConnection?
    private var bridge: NWListener?
    private var brgConnection: NWConnection?

    /// Prepares (but does not start) the plain TCP connection to `host:port`.
    ///
    /// - Parameters:
    ///   - id: Suffix used to make the dispatch queue label and the bridge socket path unique.
    ///   - host: Remote host name.
    ///   - port: Remote TCP port. A value outside `0...65535` does not crash; instead a
    ///     `.disconnected` event is queued on `events` and no connection is created.
    ///   - allowInsecure: When `true`, the TLS peer certificate is accepted without evaluation.
    init(id: String, host: String, port: Int, allowInsecure: Bool) {
        self.host = host
        self.port = port
        self.allowInsecure = allowInsecure
        queue = DispatchQueue(label: "another.xmpp.network.queue_\(id)")

        bridgeSocketPath = URL(fileURLWithPath: "/tmp/tls_bridge_listener\(id).sock")
        // A stale socket file from a previous run would prevent the listener from binding.
        try? FileManager.default.removeItem(at: bridgeSocketPath)

        // `UInt16(port)` would trap on an out-of-range value; validate and report instead.
        guard let portNumber = UInt16(exactly: port),
              let endpointPort = NWEndpoint.Port(rawValue: portNumber) else {
            eventsContinuation.yield(.state(.disconnected(NWError.posix(.EINVAL))))
            return
        }

        // tcp options
        let tcpOptions = NWProtocolTCP.Options()
        tcpOptions.noDelay = true
        tcpOptions.connectionTimeout = 5
        tcpOptions.enableFastOpen = true
        tcpOptions.disableAckStretching = true
        let params = NWParameters(tls: nil, tcp: tcpOptions)
        params.serviceClass = .responsiveData

        // Use the configured `params` so the service class above is actually applied.
        let connection = NWConnection(host: .name(host, nil), port: endpointPort, using: params)
        connection.stateUpdateHandler = { [weak self] state in
            guard let self else { return }
            self.logState("Raw connection \(state)")
            switch state {
            case .ready:
                self.eventsContinuation.yield(.state(.connected))
                self.rawRead()

            case .waiting(let error), .failed(let error):
                // Debug output is gated like the rest of the file's logging.
                if doPrint {
                    print(error.localizedDescription)
                }
                self.eventsContinuation.yield(.state(.disconnected(error)))

            default:
                break
            }
        }
        rawConnection = connection
    }

    /// Finishes the event stream, cancels every connection and the bridge listener,
    /// and removes the bridge socket file so it does not linger in `/tmp`.
    deinit {
        eventsContinuation.finish()
        rawConnection?.cancel()
        brgConnection?.cancel()
        secConnection?.cancel()
        bridge?.cancel()
        try? FileManager.default.removeItem(at: bridgeSocketPath)
    }

    /// Starts the plain TCP connection on the socket's private queue.
    /// Progress is reported asynchronously through `events`.
    func connect() async {
        rawConnection?.start(queue: queue)
    }

    /// Sends `data` over the secure connection once one exists, otherwise over the raw connection.
    func send(_ data: Data) async {
        if secConnection != nil {
            secWrite(data: data)
        } else {
            rawWrite(data: data)
        }
    }

    /// Begins the TLS upgrade.
    ///
    /// - Throws: Any error raised while creating the bridge listener.
    func startTls() throws {
        try initBridgeAndSecureConnection()
    }
}

// MARK: - Read/write loops, TLS bridge setup and debug logging

private extension StartTLSSocket {
    /// Receives the next chunk from the raw TCP connection and re-arms itself.
    ///
    /// While a bridge connection exists the bytes are forwarded to it; otherwise they
    /// are delivered to the consumer as `.dataReceived`.
    func rawRead() {
        rawConnection?.receive(minimumIncompleteLength: 1, maximumLength: 4096 * 2, completion: { [weak self] data, _, _, error in
            guard let self else { return }
            if let err = error {
                self.eventsContinuation.yield(.state(.disconnected(err)))
                return
            }
            // NOTE(review): a completion with neither data nor error stops the read loop
            // without emitting any event — confirm this is the intended handling of a
            // stream closed by the peer.
            guard let data else { return }
            self.log("raw", data: data, read: true)
            if let brg = self.brgConnection {
                brg.send(content: data, completion: .contentProcessed { [weak self] error in
                    if let err = error {
                        self?.eventsContinuation.yield(.state(.disconnected(err)))
                    }
                })
            } else {
                self.eventsContinuation.yield(.dataReceived(data))
            }
            self.rawRead()
        })
    }

    /// Writes `data` to the raw TCP connection; a send error is reported as `.disconnected`.
    func rawWrite(data: Data) {
        log("raw", data: data, read: false)
        rawConnection?.send(content: data, completion: .contentProcessed { [weak self] error in
            if let err = error {
                self?.eventsContinuation.yield(.state(.disconnected(err)))
            }
        })
    }

    /// Receives the next chunk from the bridge connection, forwards it to the raw
    /// connection and re-arms itself.
    func brgRead() {
        brgConnection?.receive(minimumIncompleteLength: 1, maximumLength: 4096 * 2, completion: { [weak self] data, _, _, error in
            guard let self else { return }
            if let err = error {
                self.eventsContinuation.yield(.state(.disconnected(err)))
                return
            }
            guard let data else { return }
            self.rawWrite(data: data)
            self.brgRead()
        })
    }

    /// Receives the next chunk from the secure connection, delivers it as
    /// `.dataReceived` and re-arms itself.
    func secRead() {
        secConnection?.receive(minimumIncompleteLength: 1, maximumLength: 4096 * 2, completion: { [weak self] data, _, _, error in
            guard let self else { return }
            if let err = error {
                self.eventsContinuation.yield(.state(.disconnected(err)))
                return
            }
            guard let data else { return }
            self.log("sec", data: data, read: true)
            self.eventsContinuation.yield(.dataReceived(data))
            self.secRead()
        })
    }

    /// Writes `data` to the secure connection; a send error is reported as `.disconnected`.
    func secWrite(data: Data) {
        log("sec", data: data, read: false)
        secConnection?.send(content: data, completion: .contentProcessed { [weak self] error in
            if let err = error {
                self?.eventsContinuation.yield(.state(.disconnected(err)))
            }
        })
    }

    /// Creates the local bridge listener on `bridgeSocketPath` and, once it is ready,
    /// opens the TLS connection that connects to it.
    ///
    /// Data flow after the upgrade: bytes written to the secure connection arrive on the
    /// accepted bridge connection and are copied to the raw connection (`brgRead`), while
    /// bytes read from the raw connection are copied to the bridge connection (`rawRead`).
    ///
    /// - Throws: Any error raised by `NWListener(using:)`.
    func initBridgeAndSecureConnection() throws {
        let params = NWParameters()
        params.defaultProtocolStack.transportProtocol = NWProtocolTCP.Options()
        params.requiredLocalEndpoint = NWEndpoint.unix(path: bridgeSocketPath.path)
        params.allowLocalEndpointReuse = false
        let listener = try NWListener(using: params)
        listener.newConnectionLimit = 1
        bridge = listener

        // Open the secure connection only after the listener can accept it.
        listener.stateUpdateHandler = { [weak self] state in
            guard let self else { return }
            self.logState("Bridge \(state)")
            switch state {
            case .ready:
                self.startSecureConnection()

            case .waiting(let error), .failed(let error):
                self.eventsContinuation.yield(.state(.disconnected(error)))

            default:
                break
            }
        }

        // bridge handler
        listener.newConnectionHandler = { [weak self] connection in
            guard let self else { return }
            self.brgConnection = connection
            // Capture weakly: the socket owns `brgConnection`, which owns this handler.
            // A strong capture here would form a cycle and keep the socket alive forever.
            connection.stateUpdateHandler = { [weak self] state in
                guard let self else { return }
                self.logState("Bridge connection \(state)")
                switch state {
                case .ready:
                    self.brgRead()

                case .waiting(let error), .failed(let error):
                    self.eventsContinuation.yield(.state(.disconnected(error)))

                default:
                    break
                }
            }
            connection.start(queue: self.queue)
        }
        listener.start(queue: queue)
    }

    /// Builds the TLS connection to the bridge socket, installs its handlers and starts it.
    func startSecureConnection() {
        // tcp options for secure connection
        let tcpOptions = NWProtocolTCP.Options()
        tcpOptions.noDelay = true
        tcpOptions.enableFastOpen = true
        tcpOptions.disableAckStretching = true

        // tls options for secure connection
        let tlsOptions = NWProtocolTLS.Options()
        let securityOptions = tlsOptions.securityProtocolOptions
        sec_protocol_options_set_min_tls_protocol_version(securityOptions, .TLSv12)
        sec_protocol_options_set_max_tls_protocol_version(securityOptions, .TLSv13)
        // The connection targets a Unix socket, so the server name is set explicitly to the real host.
        if let domain = host.cString(using: .utf8) {
            sec_protocol_options_set_tls_server_name(securityOptions, domain)
        }

        // Copy the flag instead of capturing `self`: the verify block is stored in the
        // options held by `secConnection`, which the socket owns, so capturing `self`
        // strongly would create a reference cycle.
        let skipVerification = allowInsecure
        sec_protocol_options_set_verify_block(securityOptions, { _, secTrust, complete in
            guard !skipVerification else {
                complete(true)
                return
            }
            let trust = sec_trust_copy_ref(secTrust).takeRetainedValue()
            complete(SecTrustEvaluateWithError(trust, nil))
        }, queue)

        // secure connection
        let params = NWParameters(tls: tlsOptions, tcp: tcpOptions)
        params.serviceClass = .responsiveData
        let connection = NWConnection(to: .unix(path: bridgeSocketPath.path), using: params)
        connection.stateUpdateHandler = { [weak self] state in
            guard let self else { return }
            self.logState("Secure connection \(state)")
            switch state {
            case .ready:
                self.eventsContinuation.yield(.state(.startTlsReady))
                self.secRead()

            case .waiting(let error), .failed(let error):
                self.eventsContinuation.yield(.state(.disconnected(error)))

            default:
                break
            }
        }
        secConnection = connection
        connection.start(queue: queue)
    }

    /// Prints a payload when `doPrint` is enabled.
    ///
    /// - Parameters:
    ///   - socket: Label of the connection the data belongs to (`"raw"` or `"sec"`).
    ///   - data: Payload to print, decoded as ASCII (`"?"` when decoding fails).
    ///   - read: `true` for incoming data, `false` for outgoing data.
    func log(_ socket: String, data: Data, read: Bool) {
        // Raw traffic is not printed once a secure connection exists.
        if socket == "raw", secConnection != nil { return }
        if doPrint {
            let direction = read ? "read-in" : "write-out"
            let str = String(bytes: data, encoding: .ascii) ?? "?"
            print("\nStartTLSSocket-\(socket) \(direction): \(str)\n")
        }
    }

    /// Prints a connection state description when `doPrint` is enabled.
    func logState(_ str: String) {
        if doPrint {
            print("Connection state: \(str)")
        }
    }
}