"""
Test suite for socketserver.
"""

import contextlib
import imp
import os
import select
import signal
import socket
import tempfile
import unittest
import socketserver

import test.support
from test.support import reap_children, reap_threads, verbose
try:
    import threading
except ImportError:
    threading = None

test.support.requires("network")

TEST_STR = b"hello world\n"
HOST = test.support.HOST

HAVE_UNIX_SOCKETS = hasattr(socket, "AF_UNIX")
HAVE_FORKING = hasattr(os, "fork") and os.name != "os2"

def signal_alarm(n):
    """Call signal.alarm when it exists (i.e. not on Windows)."""
    if hasattr(signal, 'alarm'):
        signal.alarm(n)

def receive(sock, n, timeout=20):
    r, w, x = select.select([sock], [], [], timeout)
    if sock in r:
        return sock.recv(n)
    else:
        raise RuntimeError("timed out on %r" % (sock,))

if HAVE_UNIX_SOCKETS:
    class ForkingUnixStreamServer(socketserver.ForkingMixIn,
                                  socketserver.UnixStreamServer):
        pass

    class ForkingUnixDatagramServer(socketserver.ForkingMixIn,
                                    socketserver.UnixDatagramServer):
        pass


@contextlib.contextmanager
def simple_subprocess(testcase):
    pid = os.fork()
    if pid == 0:
        # Don't throw an exception; it would be caught by the test harness.
        os._exit(72)
    yield None
    pid2, status = os.waitpid(pid, 0)
    testcase.assertEqual(pid2, pid)
    testcase.assertEqual(72 << 8, status)


@unittest.skipUnless(threading, 'Threading required for this test.')
class SocketServerTest(unittest.TestCase):
    """Test all socket servers."""

    def setUp(self):
        signal_alarm(60)  # Kill deadlocks after 60 seconds.
        self.port_seed = 0
        self.test_files = []

    def tearDown(self):
        signal_alarm(0)  # Didn't deadlock.
        reap_children()

        for fn in self.test_files:
            try:
                os.remove(fn)
            except os.error:
                pass
        self.test_files[:] = []

    def pickaddr(self, proto):
        if proto == socket.AF_INET:
            return (HOST, 0)
        else:
            # XXX: We need a way to tell AF_UNIX to pick its own name
            # like AF_INET provides port==0.
            dir = None
            if os.name == 'os2':
                dir = '\socket'
            fn = tempfile.mktemp(prefix='unix_socket.', dir=dir)
            if os.name == 'os2':
                # AF_UNIX socket names on OS/2 require a specific prefix
                # which can't include a drive letter and must also use
                # backslashes as directory separators
                if fn[1] == ':':
                    fn = fn[2:]
                if fn[0] in (os.sep, os.altsep):
                    fn = fn[1:]
                if os.sep == '/':
                    fn = fn.replace(os.sep, os.altsep)
                else:
                    fn = fn.replace(os.altsep, os.sep)
            self.test_files.append(fn)
            return fn

    def make_server(self, addr, svrcls, hdlrbase):
        class MyServer(svrcls):
            def handle_error(self, request, client_address):
                self.close_request(request)
                self.server_close()
                raise

        class MyHandler(hdlrbase):
            def handle(self):
                line = self.rfile.readline()
                self.wfile.write(line)

        if verbose: print("creating server")
        server = MyServer(addr, MyHandler)
        self.assertEqual(server.server_address, server.socket.getsockname())
        return server

    @reap_threads
    def run_server(self, svrcls, hdlrbase, testfunc):
        server = self.make_server(self.pickaddr(svrcls.address_family),
                                  svrcls, hdlrbase)
        # We had the OS pick a port, so pull the real address out of
        # the server.
        addr = server.server_address
        if verbose:
            print("ADDR =", addr)
            print("CLASS =", svrcls)

        t = threading.Thread(
            name='%s serving' % svrcls,
            target=server.serve_forever,
            # Short poll interval to make the test finish quickly.
            # Time between requests is short enough that we won't wake
            # up spuriously too many times.
            kwargs={'poll_interval':0.01})
        t.daemon = True  # In case this function raises.
        t.start()
        if verbose: print("server running")
        for i in range(3):
            if verbose: print("test client", i)
            testfunc(svrcls.address_family, addr)
        if verbose: print("waiting for server")
        server.shutdown()
        t.join()
        server.server_close()
        if verbose: print("done")

    def stream_examine(self, proto, addr):
        s = socket.socket(proto, socket.SOCK_STREAM)
        s.connect(addr)
        s.sendall(TEST_STR)
        buf = data = receive(s, 100)
        while data and b'\n' not in buf:
            data = receive(s, 100)
            buf += data
        self.assertEqual(buf, TEST_STR)
        s.close()

    def dgram_examine(self, proto, addr):
        s = socket.socket(proto, socket.SOCK_DGRAM)
        s.sendto(TEST_STR, addr)
        buf = data = receive(s, 100)
        while data and b'\n' not in buf:
            data = receive(s, 100)
            buf += data
        self.assertEqual(buf, TEST_STR)
        s.close()

    def test_TCPServer(self):
        self.run_server(socketserver.TCPServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    def test_ThreadingTCPServer(self):
        self.run_server(socketserver.ThreadingTCPServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    if HAVE_FORKING:
        def test_ForkingTCPServer(self):
            with simple_subprocess(self):
                self.run_server(socketserver.ForkingTCPServer,
                                socketserver.StreamRequestHandler,
                                self.stream_examine)

    if HAVE_UNIX_SOCKETS:
        def test_UnixStreamServer(self):
            self.run_server(socketserver.UnixStreamServer,
                            socketserver.StreamRequestHandler,
                            self.stream_examine)

        def test_ThreadingUnixStreamServer(self):
            self.run_server(socketserver.ThreadingUnixStreamServer,
                            socketserver.StreamRequestHandler,
                            self.stream_examine)

        if HAVE_FORKING:
            def test_ForkingUnixStreamServer(self):
                with simple_subprocess(self):
                    self.run_server(ForkingUnixStreamServer,
                                    socketserver.StreamRequestHandler,
                                    self.stream_examine)

    def test_UDPServer(self):
        self.run_server(socketserver.UDPServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    def test_ThreadingUDPServer(self):
        self.run_server(socketserver.ThreadingUDPServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    if HAVE_FORKING:
        def test_ForkingUDPServer(self):
            with simple_subprocess(self):
                self.run_server(socketserver.ForkingUDPServer,
                                socketserver.DatagramRequestHandler,
                                self.dgram_examine)

    # Alas, on Linux (at least) recvfrom() doesn't return a meaningful
    # client address so this cannot work:

    # if HAVE_UNIX_SOCKETS:
    #     def test_UnixDatagramServer(self):
    #         self.run_server(socketserver.UnixDatagramServer,
    #                         socketserver.DatagramRequestHandler,
    #                         self.dgram_examine)
    #
    #     def test_ThreadingUnixDatagramServer(self):
    #         self.run_server(socketserver.ThreadingUnixDatagramServer,
    #                         socketserver.DatagramRequestHandler,
    #                         self.dgram_examine)
    #
    #     if HAVE_FORKING:
    #         def test_ForkingUnixDatagramServer(self):
    #             self.run_server(socketserver.ForkingUnixDatagramServer,
    #                             socketserver.DatagramRequestHandler,
    #                             self.dgram_examine)

    @reap_threads
    def test_shutdown(self):
        # Issue #2302: shutdown() should always succeed in making an
        # other thread leave serve_forever().
        class MyServer(socketserver.TCPServer):
            pass

        class MyHandler(socketserver.StreamRequestHandler):
            pass

        threads = []
        for i in range(20):
            s = MyServer((HOST, 0), MyHandler)
            t = threading.Thread(
                name='MyServer serving',
                target=s.serve_forever,
                kwargs={'poll_interval':0.01})
            t.daemon = True  # In case this function raises.
            threads.append((t, s))
        for t, s in threads:
            t.start()
            s.shutdown()
        for t, s in threads:
            t.join()
            s.server_close()


def test_main():
    if imp.lock_held():
        # If the import lock is held, the threads will hang
        raise unittest.SkipTest("can't run when import lock is held")

    test.support.run_unittest(SocketServerTest)

if __name__ == "__main__":
    test_main()