307 lines
9.6 KiB
Python
307 lines
9.6 KiB
Python
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
|
import asyncio
|
|
import socket
|
|
import threading
|
|
import typing
|
|
|
|
from asyncpg import cluster
|
|
|
|
|
|
class StopServer(Exception):
    """Signals that the proxy was asked to stop while awaiting socket work."""
|
|
|
|
|
|
class TCPFuzzingProxy:
|
|
def __init__(self, *, listening_addr: str='127.0.0.1',
|
|
listening_port: typing.Optional[int]=None,
|
|
backend_host: str, backend_port: int,
|
|
settings: typing.Optional[dict]=None) -> None:
|
|
self.listening_addr = listening_addr
|
|
self.listening_port = listening_port
|
|
self.backend_host = backend_host
|
|
self.backend_port = backend_port
|
|
self.settings = settings or {}
|
|
self.loop = None
|
|
self.connectivity = None
|
|
self.connectivity_loss = None
|
|
self.stop_event = None
|
|
self.connections = {}
|
|
self.sock = None
|
|
self.listen_task = None
|
|
|
|
async def _wait(self, work):
|
|
work_task = asyncio.ensure_future(work)
|
|
stop_event_task = asyncio.ensure_future(self.stop_event.wait())
|
|
|
|
try:
|
|
await asyncio.wait(
|
|
[work_task, stop_event_task],
|
|
return_when=asyncio.FIRST_COMPLETED)
|
|
|
|
if self.stop_event.is_set():
|
|
raise StopServer()
|
|
else:
|
|
return work_task.result()
|
|
finally:
|
|
if not work_task.done():
|
|
work_task.cancel()
|
|
if not stop_event_task.done():
|
|
stop_event_task.cancel()
|
|
|
|
def start(self):
|
|
started = threading.Event()
|
|
self.thread = threading.Thread(
|
|
target=self._start_thread, args=(started,))
|
|
self.thread.start()
|
|
if not started.wait(timeout=2):
|
|
raise RuntimeError('fuzzer proxy failed to start')
|
|
|
|
def stop(self):
|
|
self.loop.call_soon_threadsafe(self._stop)
|
|
self.thread.join()
|
|
|
|
def _stop(self):
|
|
self.stop_event.set()
|
|
|
|
def _start_thread(self, started_event):
|
|
self.loop = asyncio.new_event_loop()
|
|
asyncio.set_event_loop(self.loop)
|
|
|
|
self.connectivity = asyncio.Event()
|
|
self.connectivity.set()
|
|
self.connectivity_loss = asyncio.Event()
|
|
self.stop_event = asyncio.Event()
|
|
|
|
if self.listening_port is None:
|
|
self.listening_port = cluster.find_available_port()
|
|
|
|
self.sock = socket.socket()
|
|
self.sock.bind((self.listening_addr, self.listening_port))
|
|
self.sock.listen(50)
|
|
self.sock.setblocking(False)
|
|
|
|
try:
|
|
self.loop.run_until_complete(self._main(started_event))
|
|
finally:
|
|
self.loop.close()
|
|
|
|
async def _main(self, started_event):
|
|
self.listen_task = asyncio.ensure_future(self.listen())
|
|
# Notify the main thread that we are ready to go.
|
|
started_event.set()
|
|
try:
|
|
await self.listen_task
|
|
finally:
|
|
for c in list(self.connections):
|
|
c.close()
|
|
await asyncio.sleep(0.01)
|
|
if hasattr(self.loop, 'remove_reader'):
|
|
self.loop.remove_reader(self.sock.fileno())
|
|
self.sock.close()
|
|
|
|
async def listen(self):
|
|
while True:
|
|
try:
|
|
client_sock, _ = await self._wait(
|
|
self.loop.sock_accept(self.sock))
|
|
|
|
backend_sock = socket.socket()
|
|
backend_sock.setblocking(False)
|
|
|
|
await self._wait(self.loop.sock_connect(
|
|
backend_sock, (self.backend_host, self.backend_port)))
|
|
except StopServer:
|
|
break
|
|
|
|
conn = Connection(client_sock, backend_sock, self)
|
|
conn_task = self.loop.create_task(conn.handle())
|
|
self.connections[conn] = conn_task
|
|
|
|
def trigger_connectivity_loss(self):
|
|
self.loop.call_soon_threadsafe(self._trigger_connectivity_loss)
|
|
|
|
def _trigger_connectivity_loss(self):
|
|
self.connectivity.clear()
|
|
self.connectivity_loss.set()
|
|
|
|
def restore_connectivity(self):
|
|
self.loop.call_soon_threadsafe(self._restore_connectivity)
|
|
|
|
def _restore_connectivity(self):
|
|
self.connectivity.set()
|
|
self.connectivity_loss.clear()
|
|
|
|
def reset(self):
|
|
self.restore_connectivity()
|
|
|
|
def _close_connection(self, connection):
|
|
conn_task = self.connections.pop(connection, None)
|
|
if conn_task is not None:
|
|
conn_task.cancel()
|
|
|
|
def close_all_connections(self):
|
|
for conn in list(self.connections):
|
|
self.loop.call_soon_threadsafe(self._close_connection, conn)
|
|
|
|
|
|
class Connection:
    """A single proxied client connection, relayed in both directions.

    Two tasks copy bytes client -> backend and backend -> client.  Nothing
    is relayed while the proxy's ``connectivity`` event is cleared.  Bytes
    that were already read when connectivity is lost are held back and
    sent once it is restored.
    """

    def __init__(self, client_sock, backend_sock, proxy):
        """Wrap an accepted client socket and its connected backend socket.

        Both sockets are used with the loop's ``sock_*`` coroutines, so
        they are expected to be non-blocking.  The event loop and both
        connectivity events are taken from *proxy*.
        """
        self.client_sock = client_sock
        self.backend_sock = backend_sock
        self.proxy = proxy
        self.loop = proxy.loop
        self.connectivity = proxy.connectivity
        self.connectivity_loss = proxy.connectivity_loss
        self.proxy_to_backend_task = None
        self.proxy_from_backend_task = None
        self.is_closed = False

    def close(self):
        """Cancel both relay tasks and deregister from the proxy.

        Only the first call has an effect.  The sockets themselves are
        closed by :meth:`handle` when its task finishes.
        """
        if self.is_closed:
            return

        self.is_closed = True

        if self.proxy_to_backend_task is not None:
            self.proxy_to_backend_task.cancel()
            self.proxy_to_backend_task = None

        if self.proxy_from_backend_task is not None:
            self.proxy_from_backend_task.cancel()
            self.proxy_from_backend_task = None

        self.proxy._close_connection(self)

    async def handle(self):
        """Relay until either direction stops, then release both sockets."""
        self.proxy_to_backend_task = asyncio.ensure_future(
            self.proxy_to_backend())
        self.proxy_from_backend_task = asyncio.ensure_future(
            self.proxy_from_backend())

        try:
            await asyncio.wait(
                [self.proxy_to_backend_task, self.proxy_from_backend_task],
                return_when=asyncio.FIRST_COMPLETED)
        finally:
            # close() may already have cancelled and cleared these.
            if self.proxy_to_backend_task is not None:
                self.proxy_to_backend_task.cancel()
            if self.proxy_from_backend_task is not None:
                self.proxy_from_backend_task.cancel()

            # Asyncio fails to properly remove the readers and writers
            # when the task doing recv() or send() is cancelled, so
            # we must remove the readers and writers manually before
            # closing the sockets.
            for sock in (self.client_sock, self.backend_sock):
                self.loop.remove_reader(sock.fileno())
                self.loop.remove_writer(sock.fileno())

            self.client_sock.close()
            self.backend_sock.close()

    async def _run_until_connectivity_loss(self, coro):
        """Run *coro* as a task, racing it against connectivity loss.

        Returns the task if it has finished, or ``None`` if the
        connectivity-loss event fired first; the task is then cancelled.

        The outcome is decided by ``task.done()`` rather than by the event's
        current state.  The event can be set and cleared again before this
        coroutine resumes, and calling ``result()`` on an unfinished task
        raises ``InvalidStateError``.
        """
        op_task = asyncio.ensure_future(coro)
        loss_task = asyncio.ensure_future(self.connectivity_loss.wait())

        try:
            await asyncio.wait(
                [op_task, loss_task],
                return_when=asyncio.FIRST_COMPLETED)
            return op_task if op_task.done() else None
        finally:
            if not self.loop.is_closed():
                if not op_task.done():
                    op_task.cancel()
                if not loss_task.done():
                    loss_task.cancel()

    async def _read(self, sock, n):
        """Read up to *n* bytes from *sock* unless connectivity is lost first.

        Returns the bytes read (``b''`` on EOF), or ``None`` when there is
        nothing to forward.  That happens when connectivity loss fired
        before the read completed, or when loss is set and the read
        produced only EOF or an error.  Non-empty data that arrives together
        with a loss is still returned, so the caller can hold it until
        connectivity is back.
        """
        task = await self._run_until_connectivity_loss(
            self.loop.sock_recv(sock, n))
        if task is None:
            return None
        if self.connectivity_loss.is_set():
            if task.exception() is None and task.result():
                return task.result()
            # EOF or errors during an outage are not acted on here.  A later
            # read, made after connectivity returns, presumably observes
            # the closed socket again.
            return None
        return task.result()

    async def _write(self, sock, data):
        """Send all of *data* on *sock* unless connectivity is lost first.

        Returns ``None``.  If connectivity loss fires before the send
        completes, the send is cancelled; callers do not retry, so any
        unsent part of *data* is dropped.  A send error is raised, unless
        connectivity loss is set at that point, in which case it is ignored.
        """
        task = await self._run_until_connectivity_loss(
            self.loop.sock_sendall(sock, data))
        if task is None:
            return None
        if self.connectivity_loss.is_set():
            # Retrieve any error so asyncio does not report it as unhandled.
            task.exception()
            return None
        return task.result()

    async def _pump(self, src, dst):
        """Copy bytes from *src* to *dst* until EOF or a connection error.

        When it stops for any reason, :meth:`close` is scheduled on the
        loop, unless the loop is already closed.
        """
        held = None  # bytes read just as connectivity was lost, not yet sent

        try:
            while True:
                await self.connectivity.wait()
                if held is not None:
                    data, held = held, None
                else:
                    data = await self._read(src, 4096)
                if data == b'':
                    break  # orderly EOF from src
                if data is None:
                    # Connectivity loss interrupted the read; nothing to
                    # forward.  Wait for connectivity again.
                    continue
                if self.connectivity_loss.is_set():
                    held = data
                    continue
                await self._write(dst, data)
        except ConnectionError:
            # A reset or aborted peer ends this direction just like EOF.
            pass
        finally:
            if not self.loop.is_closed():
                self.loop.call_soon(self.close)

    async def proxy_to_backend(self):
        """Relay bytes from the client socket to the backend socket."""
        await self._pump(self.client_sock, self.backend_sock)

    async def proxy_from_backend(self):
        """Relay bytes from the backend socket to the client socket."""
        await self._pump(self.backend_sock, self.client_sock)
|