# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.

"""
UDP support for IOCP reactor
"""

from __future__ import annotations

import errno
import socket
import struct
import warnings
from typing import TYPE_CHECKING, Optional

from zope.interface import implementer

from twisted.internet import address, defer, error, interfaces
from twisted.internet._multicast import MulticastMixin
from twisted.internet.abstract import isIPAddress, isIPv6Address
from twisted.internet.iocpreactor import abstract, iocpsupport as _iocp
from twisted.internet.iocpreactor.const import (
    ERROR_CONNECTION_REFUSED,
    ERROR_IO_PENDING,
    ERROR_PORT_UNREACHABLE,
)
from twisted.internet.iocpreactor.interfaces import IReadWriteHandle
from twisted.internet.protocol import AbstractDatagramProtocol
from twisted.python import log

if TYPE_CHECKING:
    from twisted.internet.iocpreactor.reactor import IOCPReactor


@implementer(
    IReadWriteHandle,
    interfaces.IListeningPort,
    interfaces.IUDPTransport,
    interfaces.ISystemHandle,
)
class Port(abstract.FileHandle):
    """
    UDP port, listening for packets.

    @ivar addressFamily: L{socket.AF_INET} or L{socket.AF_INET6}, depending on
        whether this port is listening on an IPv4 address or an IPv6 address.

    @ivar socketType: The socket type passed to the reactor when the socket
        is created; always L{socket.SOCK_DGRAM} for this class.

    @ivar _connectedAddr: The C{(host, port)} passed to L{connect}, or L{None}
        while the port is not in connected mode.
    """

    reactor: IOCPReactor
    addressFamily = socket.AF_INET
    socketType = socket.SOCK_DGRAM
    dynamicReadBuffers = False

    # Actual port number being listened on, only set to a non-None
    # value when we are actually listening.
    _realPortNumber: Optional[int] = None

    def __init__(
        self,
        port: int,
        proto: AbstractDatagramProtocol,
        interface: str = "",
        maxPacketSize: int = 8192,
        reactor: IOCPReactor | None = None,
    ) -> None:
        """
        Initialize with a numeric port to listen on.

        No listening socket is created here; that happens in
        L{startListening}.

        @param port: The port number to bind to.
        @param proto: The datagram protocol which will receive the packets.
        @param interface: The literal IPv4 or IPv6 address to bind to, or an
            empty string.
        @param maxPacketSize: Stored as C{readBufferSize}, the size used for
            the receive buffer.
        @param reactor: The reactor handed to L{abstract.FileHandle}.

        @raise error.InvalidAddressError: If C{interface} is non-empty and is
            not an IPv4 or IPv6 address literal.
        """
        self.port = port
        self.protocol = proto
        self.readBufferSize = maxPacketSize
        self.interface = interface
        self.setLogStr()
        self._connectedAddr = None
        self._setAddressFamily()

        abstract.FileHandle.__init__(self, reactor)

        # A throwaway socket is only needed to ask the IOCP support module
        # how large an address buffer this address family requires.  Close it
        # as soon as the answer is known so that the handle is not leaked
        # until garbage collection.
        skt = socket.socket(self.addressFamily, self.socketType)
        try:
            addrLen = _iocp.maxAddrLen(skt.fileno())
        finally:
            skt.close()
        self.addressBuffer = bytearray(addrLen)
        # WSARecvFrom takes an int
        self.addressLengthBuffer = bytearray(struct.calcsize("i"))

    def _setAddressFamily(self):
        """
        Resolve address family for the socket.

        An empty C{interface} leaves the class default (L{socket.AF_INET}) in
        place.

        @raise error.InvalidAddressError: If C{interface} is non-empty and is
            neither an IPv4 nor an IPv6 address literal.
        """
        if isIPv6Address(self.interface):
            self.addressFamily = socket.AF_INET6
        elif isIPAddress(self.interface):
            self.addressFamily = socket.AF_INET
        elif self.interface:
            raise error.InvalidAddressError(
                self.interface, "not an IPv4 or IPv6 address"
            )

    def __repr__(self) -> str:
        """
        Describe the port by its protocol class and, while listening, the
        port number actually bound.
        """
        if self._realPortNumber is not None:
            return f"<{self.protocol.__class__} on {self._realPortNumber}>"
        else:
            return f"<{self.protocol.__class__} not connected>"

    def getHandle(self):
        """
        Return a socket object.

        The C{socket} attribute only exists between L{_bindSocket} and
        L{connectionLost}.
        """
        return self.socket

    def startListening(self):
        """
        Create and bind my socket, and begin listening on it.

        This is called on unserialization, and must be called after creating a
        server to begin listening on the specified port.

        @raise error.CannotListenError: If the socket cannot be created or
            bound.
        """
        self._bindSocket()
        self._connectToProtocol()

    def createSocket(self) -> socket.socket:
        """
        Ask the reactor for a new socket of this port's address family and
        socket type.
        """
        return self.reactor.createSocket(self.addressFamily, self.socketType)

    def _bindSocket(self):
        """
        Create the socket, bind it to C{(interface, port)} and record it on
        this instance.

        @raise error.CannotListenError: If creating or binding the socket
            fails with L{OSError}.
        """
        try:
            skt = self.createSocket()
        except OSError as le:
            raise error.CannotListenError(self.interface, self.port, le) from le

        try:
            skt.bind((self.interface, self.port))
        except OSError as le:
            # Nothing else references the socket yet, so release it here
            # instead of leaking the handle.
            skt.close()
            raise error.CannotListenError(self.interface, self.port, le) from le

        # Make sure that if we listened on port 0, we update that to
        # reflect what the OS actually assigned us.
        self._realPortNumber = skt.getsockname()[1]

        log.msg(
            "%s starting on %s"
            % (self._getLogPrefix(self.protocol), self._realPortNumber)
        )

        self.connected = True
        self.socket = skt
        # Instance-level alias; removed again in connectionLost().
        self.getFileHandle = self.socket.fileno

    def _connectToProtocol(self):
        """
        Attach this transport to the protocol, start reading and register
        with the reactor as an active handle.
        """
        self.protocol.makeConnection(self)
        self.startReading()
        self.reactor.addActiveHandle(self)

    def cbRead(self, rc, data, evt):
        """
        Callback for a receive issued by L{doRead}.

        The result is ignored when the port is no longer reading; otherwise
        it is dispatched through L{handleRead} and the next receive is
        issued.

        @param rc: The result code of the receive; C{0} on success.
        @param data: On success, the number of bytes received.
        @param evt: The event object created in L{doRead}, carrying the
            buffers used for the receive.
        """
        if self.reading:
            self.handleRead(rc, data, evt)
            self.doRead()

    def handleRead(self, rc, data, evt):
        """
        Deliver the outcome of one receive to the protocol.

        "Connection refused"-style result codes are reported through
        C{protocol.connectionRefused()}, but only in connected mode.  Any
        other non-zero code is logged.  On success the datagram and its
        source address are passed to C{protocol.datagramReceived()}.

        @param rc: The result code of the receive; C{0} on success.
        @param data: On success, the number of bytes received.
        @param evt: The event object carrying C{buff} and C{addr_buff}.
        """
        if rc in (
            errno.WSAECONNREFUSED,
            errno.WSAECONNRESET,
            ERROR_CONNECTION_REFUSED,
            ERROR_PORT_UNREACHABLE,
        ):
            if self._connectedAddr:
                self.protocol.connectionRefused()
        elif rc:
            log.msg(
                "error in recvfrom -- %s (%s)"
                % (errno.errorcode.get(rc, "unknown error"), rc)
            )
        else:
            try:
                # Copy the payload out of the receive buffer, which is
                # handed to the next receive by doRead().
                self.protocol.datagramReceived(
                    bytes(evt.buff[:data]), _iocp.makesockaddr(evt.addr_buff)
                )
            except BaseException:
                # A failing protocol must not stop the read loop; log the
                # failure and carry on.
                log.err()

    def doRead(self):
        """
        Issue a receive on the socket, with L{cbRead} as its callback.
        """
        evt = _iocp.Event(self.cbRead, self)

        # Keep the buffers referenced from the event so that cbRead() and
        # handleRead() can find them again.
        evt.buff = buff = self._readBuffers[0]
        evt.addr_buff = addr_buff = self.addressBuffer
        evt.addr_len_buff = addr_len_buff = self.addressLengthBuffer
        rc, data = _iocp.recvfrom(
            self.getFileHandle(), buff, addr_buff, addr_len_buff, evt
        )

        if rc and rc != ERROR_IO_PENDING:
            # If the error was not 0 or IO_PENDING then that means recvfrom() hit a
            # failure condition. In this situation recvfrom() gives us our response
            # right away and we don't need to wait for Windows to call the callback
            # on our event. In fact, windows will not call it for us so we must call it
            # ourselves manually
            self.reactor.callLater(0, self.cbRead, rc, data, evt)

    def write(self, datagram, addr=None):
        """
        Write a datagram.

        @param datagram: The bytes to send.
        @param addr: should be a tuple (ip, port), can be None in connected
            mode.

        @return: The result of the underlying C{send()} / C{sendto()} call,
            or L{None} if the send was refused.

        @raise error.MessageLengthError: If the datagram is too long to be
            sent.
        @raise error.InvalidAddressError: In unconnected mode, if C{addr[0]}
            is a hostname or belongs to the other address family.
        """
        if self._connectedAddr:
            assert addr in (None, self._connectedAddr)
            try:
                return self.socket.send(datagram)
            except OSError as se:
                no = se.args[0]
                if no == errno.WSAEINTR:
                    # Interrupted; simply try again.
                    return self.write(datagram)
                elif no == errno.WSAEMSGSIZE:
                    raise error.MessageLengthError("message too long")
                elif no in (
                    errno.WSAECONNREFUSED,
                    errno.WSAECONNRESET,
                    ERROR_CONNECTION_REFUSED,
                    ERROR_PORT_UNREACHABLE,
                ):
                    self.protocol.connectionRefused()
                else:
                    raise
        else:
            assert addr is not None
            if (
                not isIPAddress(addr[0])
                and not isIPv6Address(addr[0])
                and addr[0] != "<broadcast>"
            ):
                raise error.InvalidAddressError(
                    addr[0], "write() only accepts IP addresses, not hostnames"
                )
            if isIPAddress(addr[0]) and self.addressFamily == socket.AF_INET6:
                raise error.InvalidAddressError(
                    addr[0], "IPv6 port write() called with IPv4 address"
                )
            if isIPv6Address(addr[0]) and self.addressFamily == socket.AF_INET:
                raise error.InvalidAddressError(
                    addr[0], "IPv4 port write() called with IPv6 address"
                )
            try:
                return self.socket.sendto(datagram, addr)
            except OSError as se:
                no = se.args[0]
                if no == errno.WSAEINTR:
                    # Interrupted; simply try again.
                    return self.write(datagram, addr)
                elif no == errno.WSAEMSGSIZE:
                    raise error.MessageLengthError("message too long")
                elif no in (
                    errno.WSAECONNREFUSED,
                    errno.WSAECONNRESET,
                    ERROR_CONNECTION_REFUSED,
                    ERROR_PORT_UNREACHABLE,
                ):
                    # in non-connected UDP ECONNREFUSED is platform dependent,
                    # I think and the info is not necessarily useful.
                    # Nevertheless maybe we should call connectionRefused? XXX
                    return
                else:
                    raise

    def writeSequence(self, seq, addr):
        """
        Join an iterable of byte strings and send the result as a single
        datagram via L{write}.

        @param seq: An iterable of L{bytes}.
        @param addr: The destination, as accepted by L{write}.
        """
        self.write(b"".join(seq), addr)

    def connect(self, host, port):
        """
        'Connect' to remote server.

        @param host: An IPv4 or IPv6 address literal; hostnames are rejected.
        @param port: The remote port number.

        @raise RuntimeError: If the port is already connected.
        @raise error.InvalidAddressError: If C{host} is not an IP address
            literal.
        """
        if self._connectedAddr:
            raise RuntimeError(
                "already connected, reconnecting is not currently supported "
                "(talk to itamar if you want this)"
            )
        if not isIPAddress(host) and not isIPv6Address(host):
            raise error.InvalidAddressError(host, "not an IPv4 or IPv6 address.")
        # Only record the peer once the socket-level connect has succeeded;
        # otherwise a failed connect would leave the port believing it is in
        # connected mode and refusing any further attempt.
        self.socket.connect((host, port))
        self._connectedAddr = (host, port)

    def _loseConnection(self):
        """
        Stop reading, unregister from the reactor and, if listening, schedule
        L{connectionLost} for the next reactor iteration.
        """
        self.stopReading()
        self.reactor.removeActiveHandle(self)
        if self.connected:  # actually means if we are *listening*
            self.reactor.callLater(0, self.connectionLost)

    def stopListening(self):
        """
        Stop listening on this port.

        @return: A L{defer.Deferred} which L{connectionLost} fires with
            L{None}, or L{None} if the port was not listening.
        """
        if self.connected:
            # Stored on the instance so that connectionLost() can fire it.
            result = self.d = defer.Deferred()
        else:
            result = None
        self._loseConnection()
        return result

    def loseConnection(self):
        """
        Deprecated alias for L{stopListening}; emits a L{DeprecationWarning}
        and discards the result of L{stopListening}.
        """
        warnings.warn(
            "Please use stopListening() to disconnect port",
            DeprecationWarning,
            stacklevel=2,
        )
        self.stopListening()

    def connectionLost(self, reason=None):
        """
        Cleans up my socket.

        Stops the protocol, closes and forgets the socket, and fires the
        L{defer.Deferred} created by L{stopListening}, if there is one.

        @param reason: Passed through to L{abstract.FileHandle.connectionLost}.
        """
        log.msg("(UDP Port %s Closed)" % self._realPortNumber)
        self._realPortNumber = None
        abstract.FileHandle.connectionLost(self, reason)
        self.protocol.doStop()
        self.socket.close()
        # Drop the per-instance attributes created by _bindSocket().
        del self.socket
        del self.getFileHandle
        if hasattr(self, "d"):
            self.d.callback(None)
            del self.d

    def setLogStr(self):
        """
        Initialize the C{logstr} attribute to be used by C{logPrefix}.
        """
        logPrefix = self._getLogPrefix(self.protocol)
        self.logstr = "%s (UDP)" % logPrefix

    def logPrefix(self):
        """
        Returns the name of my class, to prefix log entries with.
        """
        return self.logstr

    def getHost(self):
        """
        Return the local address of the UDP connection

        @returns: the local address of the UDP connection
        @rtype: L{IPv4Address} or L{IPv6Address}
        """
        addr = self.socket.getsockname()
        if self.addressFamily == socket.AF_INET:
            return address.IPv4Address("UDP", *addr)
        elif self.addressFamily == socket.AF_INET6:
            # Only the first two items (host, port) of the socket name are
            # used for the address object.
            return address.IPv6Address("UDP", *(addr[:2]))

    def setBroadcastAllowed(self, enabled):
        """
        Set whether this port may broadcast. This is disabled by default.

        @param enabled: Whether the port may broadcast.
        @type enabled: L{bool}
        """
        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, enabled)

    def getBroadcastAllowed(self):
        """
        Checks if broadcast is currently allowed on this port.

        @return: Whether this port may broadcast.
        @rtype: L{bool}
        """
        return bool(self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST))


@implementer(interfaces.IMulticastTransport)
class MulticastPort(MulticastMixin, Port):
    """
    UDP Port that supports multicasting.

    @ivar listenMultiple: When true, L{createSocket} enables address (and,
        where the platform defines it, port) reuse on the new socket.
    """

    def __init__(
        self,
        port: int,
        proto: AbstractDatagramProtocol,
        interface: str = "",
        maxPacketSize: int = 8192,
        reactor: IOCPReactor | None = None,
        listenMultiple: bool = False,
    ) -> None:
        """
        Initialize like L{Port}, additionally remembering whether the socket
        should be created with the reuse options set.

        @param listenMultiple: Whether to set the reuse options on the socket
            when it is created.
        """
        Port.__init__(
            self,
            port=port,
            proto=proto,
            interface=interface,
            maxPacketSize=maxPacketSize,
            reactor=reactor,
        )
        self.listenMultiple = listenMultiple

    def createSocket(self) -> socket.socket:
        """
        Create the socket as L{Port.createSocket} does, then apply the reuse
        options if C{listenMultiple} is set.
        """
        sock = Port.createSocket(self)
        if not self.listenMultiple:
            return sock

        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        # SO_REUSEPORT is not defined by the socket module on every platform.
        reusePort = getattr(socket, "SO_REUSEPORT", None)
        if reusePort is not None:
            sock.setsockopt(socket.SOL_SOCKET, reusePort, 1)
        return sock
