summaryrefslogtreecommitdiffstats
path: root/Lib/asyncio/base_events.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/asyncio/base_events.py')
-rw-r--r--Lib/asyncio/base_events.py102
1 files changed, 82 insertions, 20 deletions
diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py
index e722cf2..94eb308 100644
--- a/Lib/asyncio/base_events.py
+++ b/Lib/asyncio/base_events.py
@@ -157,47 +157,106 @@ def _run_until_complete_cb(fut):
class Server(events.AbstractServer):
- def __init__(self, loop, sockets):
+ def __init__(self, loop, sockets, protocol_factory, ssl_context, backlog,
+ ssl_handshake_timeout):
self._loop = loop
- self.sockets = sockets
+ self._sockets = sockets
self._active_count = 0
self._waiters = []
+ self._protocol_factory = protocol_factory
+ self._backlog = backlog
+ self._ssl_context = ssl_context
+ self._ssl_handshake_timeout = ssl_handshake_timeout
+ self._serving = False
+ self._serving_forever_fut = None
def __repr__(self):
return f'<{self.__class__.__name__} sockets={self.sockets!r}>'
def _attach(self):
- assert self.sockets is not None
+ assert self._sockets is not None
self._active_count += 1
def _detach(self):
assert self._active_count > 0
self._active_count -= 1
- if self._active_count == 0 and self.sockets is None:
+ if self._active_count == 0 and self._sockets is None:
self._wakeup()
+ def _wakeup(self):
+ waiters = self._waiters
+ self._waiters = None
+ for waiter in waiters:
+ if not waiter.done():
+ waiter.set_result(waiter)
+
+ def _start_serving(self):
+ if self._serving:
+ return
+ self._serving = True
+ for sock in self._sockets:
+ sock.listen(self._backlog)
+ self._loop._start_serving(
+ self._protocol_factory, sock, self._ssl_context,
+ self, self._backlog, self._ssl_handshake_timeout)
+
+ def get_loop(self):
+ return self._loop
+
+ def is_serving(self):
+ return self._serving
+
+ @property
+ def sockets(self):
+ if self._sockets is None:
+ return []
+ return list(self._sockets)
+
def close(self):
- sockets = self.sockets
+ sockets = self._sockets
if sockets is None:
return
- self.sockets = None
+ self._sockets = None
+
for sock in sockets:
self._loop._stop_serving(sock)
+
+ self._serving = False
+
+ if (self._serving_forever_fut is not None and
+ not self._serving_forever_fut.done()):
+ self._serving_forever_fut.cancel()
+ self._serving_forever_fut = None
+
if self._active_count == 0:
self._wakeup()
- def get_loop(self):
- return self._loop
+ async def start_serving(self):
+ self._start_serving()
- def _wakeup(self):
- waiters = self._waiters
- self._waiters = None
- for waiter in waiters:
- if not waiter.done():
- waiter.set_result(waiter)
+ async def serve_forever(self):
+ if self._serving_forever_fut is not None:
+ raise RuntimeError(
+ f'server {self!r} is already being awaited on serve_forever()')
+ if self._sockets is None:
+ raise RuntimeError(f'server {self!r} is closed')
+
+ self._start_serving()
+ self._serving_forever_fut = self._loop.create_future()
+
+ try:
+ await self._serving_forever_fut
+ except futures.CancelledError:
+ try:
+ self.close()
+ await self.wait_closed()
+ finally:
+ raise
+ finally:
+ self._serving_forever_fut = None
async def wait_closed(self):
- if self.sockets is None or self._waiters is None:
+ if self._sockets is None or self._waiters is None:
return
waiter = self._loop.create_future()
self._waiters.append(waiter)
@@ -1059,7 +1118,8 @@ class BaseEventLoop(events.AbstractEventLoop):
ssl=None,
reuse_address=None,
reuse_port=None,
- ssl_handshake_timeout=None):
+ ssl_handshake_timeout=None,
+ start_serving=True):
"""Create a TCP server.
The host parameter can be a string, in that case the TCP server is
@@ -1149,12 +1209,14 @@ class BaseEventLoop(events.AbstractEventLoop):
raise ValueError(f'A Stream Socket was expected, got {sock!r}')
sockets = [sock]
- server = Server(self, sockets)
for sock in sockets:
- sock.listen(backlog)
sock.setblocking(False)
- self._start_serving(protocol_factory, sock, ssl, server, backlog,
- ssl_handshake_timeout)
+
+ server = Server(self, sockets, protocol_factory,
+ ssl, backlog, ssl_handshake_timeout)
+ if start_serving:
+ server._start_serving()
+
if self._debug:
logger.info("%r is serving", server)
return server