// XMPP client session setup (connection, stream negotiation, resource binding) per RFC 6120.
import Foundation

/// Phase of the connection / stream lifecycle that `SessionModule` drives.
enum SessionState: Equatable, Codable {
    /// No usable SRV record is selected yet; `.startClientLogin` in this phase triggers domain resolution.
    case waitingSRVRecords
    /// SRV records are known and the client is attempting a socket connection.
    case tryingConnect
    /// A socket is up (or TLS / SASL1 just completed) and a new `<stream:stream>` header must be sent.
    case readyToStreamInit
    /// NOTE(review): not assigned anywhere in `SessionModule`; presumably set elsewhere once the stream is usable.
    case streamActive
}

// TODO: Handle stream errors (RFC 6120 §4.9).
/// Drives XMPP session setup: SRV/connection attempts, stream (re)initialisation,
/// STARTTLS / SASL follow-ups and resource binding (RFC 6120 §7).
final class SessionModule: XmppModule {
    let id = "Session module"

    /// Id of the outstanding resource-binding IQ, or empty when no bind request is pending.
    /// Used by `process` to pick the server's bind response out of inbound stanzas.
    private var bindRequestId = ""

    /// Folds a session-related event into the client state.
    ///
    /// Tracks the login JID/credentials, SRV-record iteration, socket/TLS status and the
    /// outcome of SASL and resource binding. Events this module does not handle leave the
    /// state unchanged.
    /// - Parameters:
    ///   - oldState: State before `event` is applied.
    ///   - event: The event being dispatched.
    /// - Returns: The updated state.
    func reduce(oldState: ClientState, with event: Event) -> ClientState {
        var newState = oldState
        switch event {
        case .startClientLogin(let jid, let credsId):
            newState.jid = jid
            newState.credentialsId = credsId

        case .domainResolved:
            newState.sessionState = .tryingConnect

        case .tryConnect:
            // Advance to the next SRV record for this connection attempt.
            newState.srvRecordIndex += 1

        case .allRecordsUnreachable:
            // Start over: drop the exhausted records and wait for a fresh resolution.
            newState.srvRecords = []
            newState.srvRecordIndex = -1
            newState.sessionState = .waitingSRVRecords

        case .socketConnected(let type):
            newState.socketType = type
            newState.sessionState = .readyToStreamInit
            if type == .directTls {
                newState.isSocketSecured = true
            }

        case .startTlsFailed:
            newState.sessionState = .tryingConnect
            newState.isSocketSecured = false

        case .startTlsDone:
            newState.sessionState = .readyToStreamInit
            newState.isSocketSecured = true

        case .authDone(let saslType, let args):
            switch saslType {
            case .sasl1:
                // SASL1 success requires a stream restart.
                newState.sessionState = .readyToStreamInit

            case .sasl2:
                // SASL2 may hand back the authorized JID; adopt it when it parses.
                if let authId = args["authorization-identifier"], let newJid = try? JID(authId) {
                    newState.jid = newJid
                }
            }

        case .bindStreamDone(let jidStr):
            // A JID that fails to parse leaves the stream unbound.
            if let jid = try? JID(jidStr) {
                newState.jid = jid
                newState.isStreamBound = true
            }

        case .bindStreamError:
            newState.isStreamBound = false // TODO: implement good error handling

        default:
            break
        }
        return newState
    }

    /// Reacts to an event with the next step of session setup, if any.
    ///
    /// NOTE(review): `state` appears to already include `reduce(oldState:with:)`'s effects
    /// for `event` (e.g. `.socketConnected` is matched with `.readyToStreamInit`, which
    /// `reduce` sets for that event) — confirm with the dispatcher.
    /// - Parameters:
    ///   - state: Current client state.
    ///   - event: The event being dispatched.
    /// - Returns: The follow-up event to dispatch, or `nil` when nothing needs to happen.
    func process(state: ClientState, with event: Event) async -> Event? {
        switch (event, state.sessionState) {
        case (.startClientLogin, .waitingSRVRecords):
            return .resolveDomain

        case (.domainResolved, .tryingConnect):
            return .tryConnect

        case (.socketError, .tryingConnect):
            return .tryConnect

        case (.socketConnected, .readyToStreamInit):
            return .startStream

        case (.startStream, .readyToStreamInit):
            return .xmlOutbound(makeStreamHeader(for: state))

        case (.startTlsFailed, _):
            // try reconnect with another srv record if starttls failed
            return .tryConnect

        case (.startTlsDone, _):
            return .startStream

        case (.authDone, _):
            return .startStream

        // Stream binding: request a resource once the server advertises <bind/>.
        case (.xmlInbound(let xml), _):
            guard !state.isStreamBound,
                  xml.name == "stream:features",
                  xml.nodes.contains(where: { $0.name == "bind" })
            else {
                return nil
            }
            let payload = XMLElement(
                name: "bind",
                xmlns: "urn:ietf:params:xml:ns:xmpp-bind",
                attributes: [:],
                content: nil,
                nodes: []
            )
            guard let request = Stanza.iqSet(payload: payload), let requestId = request.id else {
                return nil
            }
            bindRequestId = requestId
            return .stanzaOutbound(request)

        case (.stanzaInbound(let stanza), _):
            // Only the response to our pending bind request is handled here. An empty id means
            // nothing is pending, so a stanza that happens to carry id="" must not match.
            guard !bindRequestId.isEmpty, stanza.id == bindRequestId else { return nil }
            // An IQ gets exactly one response; forget the id so later stanzas cannot match it.
            bindRequestId = ""
            switch stanza.type {
            case .iq(.result):
                let boundJid = stanza.wrapped
                    .nodes
                    .first(where: { $0.name == "bind" })?
                    .nodes
                    .first(where: { $0.name == "jid" })?
                    .content
                // A successful bind result carries the bound <jid/>. Without it the session
                // cannot proceed, so report failure instead of silently stalling.
                guard let boundJid else { return .bindStreamError }
                return .bindStreamDone(boundJid)

            default:
                return .bindStreamError // TODO: implement good error handling
            }

        case (.bindStreamDone, _):
            // reduce() only marks the stream bound when the returned JID parsed; do not report
            // readiness for a JID it rejected.
            return state.isStreamBound ? .streamReady : .bindStreamError

        default:
            return nil
        }
    }

    /// Builds the opening `<stream:stream>` header for a client-to-server stream.
    ///
    /// The user's bare JID is sent as `from` only once the socket is secured (direct TLS or
    /// after STARTTLS). RFC 6120 §4.7.1 advises against exposing it before the stream's
    /// confidentiality and integrity are protected.
    private func makeStreamHeader(for state: ClientState) -> XMLElement {
        var attributes: [String: String] = [
            "to": state.jid.domainPart,
            "xml:lang": "en",
            "version": "1.0",
            "xmlns": "jabber:client",
            "xmlns:stream": "http://etherx.jabber.org/streams"
        ]
        if state.isSocketSecured {
            attributes["from"] = state.jid.bare
        }
        return XMLElement(
            name: "stream:stream",
            xmlns: nil,
            attributes: attributes,
            content: nil,
            nodes: [],
            woClose: true
        )
    }
}