stormbrigade_sheriff/sbsheriff/Lib/site-packages/asyncpg/connect_utils.py

957 lines
29 KiB
Python

# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
import asyncio
import collections
import enum
import functools
import getpass
import os
import pathlib
import platform
import re
import socket
import ssl as ssl_module
import stat
import struct
import sys
import time
import typing
import urllib.parse
import warnings
import inspect
from . import compat
from . import exceptions
from . import protocol
class SSLMode(enum.IntEnum):
    """``sslmode`` values understood by the connection parser.

    Values increase with strictness, and callers compare members with
    ``<`` / ``>=`` (e.g. "is this at least ``require``?").
    """

    disable = 0
    allow = 1
    prefer = 2
    require = 3
    verify_ca = 4
    verify_full = 5

    @classmethod
    def parse(cls, sslmode):
        """Return the member named by *sslmode*.

        :param sslmode: An ``SSLMode`` member (returned unchanged) or a
            mode name; ``-`` and ``_`` are interchangeable, so
            ``'verify-full'`` and ``'verify_full'`` are equivalent.
        :raises AttributeError: if *sslmode* does not name a member.
            (``AttributeError`` is kept because existing callers catch it.)
        """
        if isinstance(sslmode, cls):
            return sslmode
        name = sslmode.replace('-', '_')
        try:
            # Look up enum *members* only: getattr() would also return
            # arbitrary class attributes such as this ``parse`` method.
            return cls[name]
        except KeyError:
            raise AttributeError(
                '{!r} is not a valid {}'.format(sslmode, cls.__name__)
            ) from None
_ConnectionParameters = collections.namedtuple(
'ConnectionParameters',
[
'user',
'password',
'database',
'ssl',
'sslmode',
'direct_tls',
'connect_timeout',
'server_settings',
])
_ClientConfiguration = collections.namedtuple(
'ConnectionConfiguration',
[
'command_timeout',
'statement_cache_size',
'max_cached_statement_lifetime',
'max_cacheable_statement_size',
])
_system = platform.uname().system
if _system == 'Windows':
PGPASSFILE = 'pgpass.conf'
else:
PGPASSFILE = '.pgpass'
def _read_password_file(passfile: pathlib.Path) \
-> typing.List[typing.Tuple[str, ...]]:
passtab = []
try:
if not passfile.exists():
return []
if not passfile.is_file():
warnings.warn(
'password file {!r} is not a plain file'.format(passfile))
return []
if _system != 'Windows':
if passfile.stat().st_mode & (stat.S_IRWXG | stat.S_IRWXO):
warnings.warn(
'password file {!r} has group or world access; '
'permissions should be u=rw (0600) or less'.format(
passfile))
return []
with passfile.open('rt') as f:
for line in f:
line = line.strip()
if not line or line.startswith('#'):
# Skip empty lines and comments.
continue
# Backslash escapes both itself and the colon,
# which is a record separator.
line = line.replace(R'\\', '\n')
passtab.append(tuple(
p.replace('\n', R'\\')
for p in re.split(r'(?<!\\):', line, maxsplit=4)
))
except IOError:
pass
return passtab
def _read_password_from_pgpass(
*, passfile: typing.Optional[pathlib.Path],
hosts: typing.List[str],
ports: typing.List[int],
database: str,
user: str):
"""Parse the pgpass file and return the matching password.
:return:
Password string, if found, ``None`` otherwise.
"""
passtab = _read_password_file(passfile)
if not passtab:
return None
for host, port in zip(hosts, ports):
if host.startswith('/'):
# Unix sockets get normalized into 'localhost'
host = 'localhost'
for phost, pport, pdatabase, puser, ppassword in passtab:
if phost != '*' and phost != host:
continue
if pport != '*' and pport != str(port):
continue
if pdatabase != '*' and pdatabase != database:
continue
if puser != '*' and puser != user:
continue
# Found a match.
return ppassword
return None
def _validate_port_spec(hosts, port):
if isinstance(port, list):
# If there is a list of ports, its length must
# match that of the host list.
if len(port) != len(hosts):
raise exceptions.InterfaceError(
'could not match {} port numbers to {} hosts'.format(
len(port), len(hosts)))
else:
port = [port for _ in range(len(hosts))]
return port
def _parse_hostlist(hostlist, port, *, unquote=False):
    """Split a comma-separated host list into parallel host and port lists.

    Each entry is one of:

    - an absolute Unix-socket path (``/...``), which never carries a port;
    - a bracketed IPv6 address with an optional ``:port`` suffix;
    - any other host name or address with an optional ``:port`` suffix.

    :param hostlist: Comma-separated host specification string.
    :param port: Explicit port(s).  When truthy, this overrides every port
        embedded in *hostlist* and is validated against the host count.
        When falsy, each entry uses its own embedded port.  Entries without
        one fall back to ``$PGPORT`` (a single value or a comma-separated
        list), or to 5432 when that is unset.
    :param unquote: Percent-decode hosts and embedded ports.
    :return: ``(hosts, ports)`` tuple.
    :raises ValueError: on a malformed bracketed IPv6 entry or a
        non-numeric port.
    """
    # NOTE(review): an empty entry (e.g. a trailing comma) fails with
    # IndexError on ``spec[0]`` below.
    hostspecs = hostlist.split(',')

    # The truthiness of ``port`` is fixed for the whole call: a validated
    # explicit port list is never empty.
    explicit_ports = bool(port)
    if explicit_ports:
        port = _validate_port_spec(hostspecs, port)
    else:
        env_ports = os.environ.get('PGPORT')
        if not env_ports:
            fallback = 5432
        elif ',' in env_ports:
            fallback = [int(p) for p in env_ports.split(',')]
        else:
            fallback = int(env_ports)
        default_ports = _validate_port_spec(hostspecs, fallback)

    hosts = []
    embedded_ports = []
    for index, spec in enumerate(hostspecs):
        head = spec[0]
        if head == '/':
            # Unix socket
            addr, spec_port = spec, ''
        elif head == '[':
            # IPv6 address, optionally followed by ":port".
            match = re.match(r'(?:\[([^\]]+)\])(?::([0-9]+))?', spec)
            if match is None:
                raise ValueError(
                    'invalid IPv6 address in the connection URI: {!r}'.format(
                        spec
                    )
                )
            addr, spec_port = match.group(1), match.group(2)
        else:
            # IPv4 address or host name, optionally followed by ":port".
            addr, _, spec_port = spec.partition(':')

        hosts.append(urllib.parse.unquote(addr) if unquote else addr)

        if explicit_ports:
            continue
        if not spec_port:
            embedded_ports.append(default_ports[index])
            continue
        if unquote:
            spec_port = urllib.parse.unquote(spec_port)
        embedded_ports.append(int(spec_port))

    return hosts, (port if explicit_ports else embedded_ports)
def _parse_tls_version(tls_version):
if tls_version.startswith('SSL'):
raise ValueError(
f"Unsupported TLS version: {tls_version}"
)
try:
return ssl_module.TLSVersion[tls_version.replace('.', '_')]
except KeyError:
raise ValueError(
f"No such TLS version: {tls_version}"
)
def _dot_postgresql_path(filename) -> pathlib.Path:
return (pathlib.Path.home() / '.postgresql' / filename).resolve()
def _parse_connect_dsn_and_args(*, dsn, host, port, user,
                                password, passfile, database, ssl,
                                direct_tls, connect_timeout, server_settings):
    """Resolve the connection arguments, an optional DSN and the
    environment into a list of addresses and connection parameters.

    Each setting comes from the first source that provides it:

    1. the explicit keyword argument;
    2. the DSN (netloc, path and query string);
    3. the libpq-style environment variables (``PGHOST``, ``PGPORT``,
       ``PGUSER``, ``PGPASSWORD``, ``PGDATABASE``, ``PGPASSFILE``,
       ``PGSSLMODE`` and the ``PGSSL*`` certificate/key/CRL/version
       variables);
    4. built-in defaults.

    If no password is found that way, the pgpass file is consulted.  DSN
    query parameters not recognised here are merged into
    *server_settings*, with explicitly passed entries taking precedence.

    :return: ``(addrs, params)``.  *addrs* is a list whose items are either
        Unix socket paths (``str``) or ``(host, port)`` tuples.  *params*
        is a ``_ConnectionParameters``.
    :raises ValueError: on an invalid DSN scheme, no usable address, a
        missing root certificate for ``verify-ca``/``verify-full``, an
        invalid TLS version, or *server_settings* that is not a
        ``Dict[str, str]``.
    :raises exceptions.InterfaceError: when no user/database can be
        determined, the ``sslmode`` is unknown, or host and port counts
        do not match.
    """
    # `auth_hosts` is the version of host information for the purposes
    # of reading the pgpass file.
    auth_hosts = None
    # TLS file/password/version settings.  Only the DSN query string sets
    # these here; the PGSSL* environment fallbacks are read further below.
    sslcert = sslkey = sslrootcert = sslcrl = sslpassword = None
    ssl_min_protocol_version = ssl_max_protocol_version = None

    if dsn:
        parsed = urllib.parse.urlparse(dsn)

        if parsed.scheme not in {'postgresql', 'postgres'}:
            raise ValueError(
                'invalid DSN: scheme is expected to be either '
                '"postgresql" or "postgres", got {!r}'.format(parsed.scheme))

        # Split "userinfo@hostspec" at the first '@'.
        if parsed.netloc:
            if '@' in parsed.netloc:
                dsn_auth, _, dsn_hostspec = parsed.netloc.partition('@')
            else:
                dsn_hostspec = parsed.netloc
                dsn_auth = ''
        else:
            dsn_auth = dsn_hostspec = ''

        if dsn_auth:
            dsn_user, _, dsn_password = dsn_auth.partition(':')
        else:
            dsn_user = dsn_password = ''

        # DSN values only fill in settings not passed explicitly.
        if not host and dsn_hostspec:
            host, port = _parse_hostlist(dsn_hostspec, port, unquote=True)

        if parsed.path and database is None:
            dsn_database = parsed.path
            if dsn_database.startswith('/'):
                dsn_database = dsn_database[1:]
            database = urllib.parse.unquote(dsn_database)

        if user is None and dsn_user:
            user = urllib.parse.unquote(dsn_user)

        if password is None and dsn_password:
            password = urllib.parse.unquote(dsn_password)

        if parsed.query:
            query = urllib.parse.parse_qs(parsed.query, strict_parsing=True)
            # A repeated query parameter keeps only its last value.
            for key, val in query.items():
                if isinstance(val, list):
                    query[key] = val[-1]

            # Recognised parameters are popped from `query`; whatever is
            # left over becomes server settings at the end of this branch.
            if 'port' in query:
                val = query.pop('port')
                if not port and val:
                    port = [int(p) for p in val.split(',')]

            if 'host' in query:
                val = query.pop('host')
                if not host and val:
                    host, port = _parse_hostlist(val, port)

            if 'dbname' in query:
                val = query.pop('dbname')
                if database is None:
                    database = val

            if 'database' in query:
                val = query.pop('database')
                if database is None:
                    database = val

            if 'user' in query:
                val = query.pop('user')
                if user is None:
                    user = val

            if 'password' in query:
                val = query.pop('password')
                if password is None:
                    password = val

            if 'passfile' in query:
                val = query.pop('passfile')
                if passfile is None:
                    passfile = val

            # The `sslmode` query parameter feeds the `ssl` argument, which
            # is interpreted as a mode name further below.
            if 'sslmode' in query:
                val = query.pop('sslmode')
                if ssl is None:
                    ssl = val

            if 'sslcert' in query:
                sslcert = query.pop('sslcert')

            if 'sslkey' in query:
                sslkey = query.pop('sslkey')

            if 'sslrootcert' in query:
                sslrootcert = query.pop('sslrootcert')

            if 'sslcrl' in query:
                sslcrl = query.pop('sslcrl')

            if 'sslpassword' in query:
                sslpassword = query.pop('sslpassword')

            if 'ssl_min_protocol_version' in query:
                ssl_min_protocol_version = query.pop(
                    'ssl_min_protocol_version'
                )

            if 'ssl_max_protocol_version' in query:
                ssl_max_protocol_version = query.pop(
                    'ssl_max_protocol_version'
                )

            if query:
                # Explicit server_settings entries win over DSN ones.
                if server_settings is None:
                    server_settings = query
                else:
                    server_settings = {**query, **server_settings}

    if not host:
        hostspec = os.environ.get('PGHOST')
        if hostspec:
            host, port = _parse_hostlist(hostspec, port)

    if not host:
        # No host anywhere: try the default locations.  pgpass lookups use
        # 'localhost' regardless of which of these is used.
        auth_hosts = ['localhost']

        if _system == 'Windows':
            host = ['localhost']
        else:
            host = ['/run/postgresql', '/var/run/postgresql',
                    '/tmp', '/private/tmp', 'localhost']

    if not isinstance(host, list):
        host = [host]

    if auth_hosts is None:
        auth_hosts = host

    # Normalise `port` to an int or a list of ints, then pair it with the
    # hosts (one port per host).
    if not port:
        portspec = os.environ.get('PGPORT')
        if portspec:
            if ',' in portspec:
                port = [int(p) for p in portspec.split(',')]
            else:
                port = int(portspec)
        else:
            port = 5432

    elif isinstance(port, (list, tuple)):
        port = [int(p) for p in port]

    else:
        port = int(port)

    port = _validate_port_spec(host, port)

    if user is None:
        user = os.getenv('PGUSER')
        if not user:
            user = getpass.getuser()

    if password is None:
        password = os.getenv('PGPASSWORD')

    if database is None:
        database = os.getenv('PGDATABASE')

    # The database name defaults to the user name.
    if database is None:
        database = user

    if user is None:
        raise exceptions.InterfaceError(
            'could not determine user name to connect with')

    if database is None:
        raise exceptions.InterfaceError(
            'could not determine database name to connect to')

    # Last resort for the password: the pgpass file (explicit/DSN passfile,
    # then $PGPASSFILE, then PGPASSFILE in the PostgreSQL home directory).
    if password is None:
        if passfile is None:
            passfile = os.getenv('PGPASSFILE')

        if passfile is None:
            homedir = compat.get_pg_home_directory()
            if homedir:
                passfile = homedir / PGPASSFILE
            else:
                passfile = None
        else:
            passfile = pathlib.Path(passfile)

        if passfile is not None:
            password = _read_password_from_pgpass(
                hosts=auth_hosts, ports=port,
                database=database, user=user,
                passfile=passfile)

    addrs = []
    have_tcp_addrs = False
    for h, p in zip(host, port):
        if h.startswith('/'):
            # UNIX socket name
            if '.s.PGSQL.' not in h:
                h = os.path.join(h, '.s.PGSQL.{}'.format(p))
            addrs.append(h)
        else:
            # TCP host/port
            addrs.append((h, p))
            have_tcp_addrs = True

    if not addrs:
        raise ValueError(
            'could not determine the database address to connect to')

    if ssl is None:
        ssl = os.getenv('PGSSLMODE')

    # SSL defaults to 'prefer' only when at least one TCP address exists;
    # with Unix sockets only, `ssl` stays None and ends up as `disable`.
    if ssl is None and have_tcp_addrs:
        ssl = 'prefer'

    if isinstance(ssl, (str, SSLMode)):
        # An sslmode name: build an SSLContext that implements it.
        try:
            sslmode = SSLMode.parse(ssl)
        except AttributeError:
            modes = ', '.join(m.name.replace('_', '-') for m in SSLMode)
            raise exceptions.InterfaceError(
                '`sslmode` parameter must be one of: {}'.format(modes))

        # docs at https://www.postgresql.org/docs/10/static/libpq-connect.html
        if sslmode < SSLMode.allow:
            ssl = False
        else:
            ssl = ssl_module.SSLContext(ssl_module.PROTOCOL_TLS_CLIENT)
            ssl.check_hostname = sslmode >= SSLMode.verify_full
            if sslmode < SSLMode.require:
                # allow/prefer: encrypt, but do not verify the server.
                ssl.verify_mode = ssl_module.CERT_NONE
            else:
                if sslrootcert is None:
                    sslrootcert = os.getenv('PGSSLROOTCERT')
                if sslrootcert:
                    ssl.load_verify_locations(cafile=sslrootcert)
                    ssl.verify_mode = ssl_module.CERT_REQUIRED
                else:
                    sslrootcert = _dot_postgresql_path('root.crt')
                    try:
                        ssl.load_verify_locations(cafile=sslrootcert)
                    except FileNotFoundError:
                        # verify-ca/verify-full cannot proceed without a
                        # root certificate; plain `require` skips
                        # verification instead.
                        if sslmode > SSLMode.require:
                            raise ValueError(
                                f'root certificate file "{sslrootcert}" does '
                                f'not exist\nEither provide the file or '
                                f'change sslmode to disable server '
                                f'certificate verification.'
                            )
                        elif sslmode == SSLMode.require:
                            ssl.verify_mode = ssl_module.CERT_NONE
                        else:
                            assert False, 'unreachable'
                    else:
                        ssl.verify_mode = ssl_module.CERT_REQUIRED

                # Optional certificate revocation list: explicit/env file,
                # else ~/.postgresql/root.crl if it exists.
                if sslcrl is None:
                    sslcrl = os.getenv('PGSSLCRL')
                if sslcrl:
                    ssl.load_verify_locations(cafile=sslcrl)
                    ssl.verify_flags |= ssl_module.VERIFY_CRL_CHECK_CHAIN
                else:
                    sslcrl = _dot_postgresql_path('root.crl')
                    try:
                        ssl.load_verify_locations(cafile=sslcrl)
                    except FileNotFoundError:
                        pass
                    else:
                        ssl.verify_flags |= ssl_module.VERIFY_CRL_CHECK_CHAIN

            # Optional client certificate/key: explicit/env file, else the
            # ~/.postgresql defaults when they exist.
            if sslkey is None:
                sslkey = os.getenv('PGSSLKEY')
            if not sslkey:
                sslkey = _dot_postgresql_path('postgresql.key')
                if not sslkey.exists():
                    sslkey = None
            if not sslpassword:
                sslpassword = ''
            if sslcert is None:
                sslcert = os.getenv('PGSSLCERT')
            if sslcert:
                ssl.load_cert_chain(
                    sslcert, keyfile=sslkey, password=lambda: sslpassword
                )
            else:
                sslcert = _dot_postgresql_path('postgresql.crt')
                try:
                    ssl.load_cert_chain(
                        sslcert, keyfile=sslkey, password=lambda: sslpassword
                    )
                except FileNotFoundError:
                    pass

            # OpenSSL 1.1.1 keylog file, copied from create_default_context()
            if hasattr(ssl, 'keylog_filename'):
                keylogfile = os.environ.get('SSLKEYLOGFILE')
                if keylogfile and not sys.flags.ignore_environment:
                    ssl.keylog_filename = keylogfile

            # Protocol version bounds; the minimum defaults to TLSv1.2.
            if ssl_min_protocol_version is None:
                ssl_min_protocol_version = os.getenv('PGSSLMINPROTOCOLVERSION')
            if ssl_min_protocol_version:
                ssl.minimum_version = _parse_tls_version(
                    ssl_min_protocol_version
                )
            else:
                ssl.minimum_version = _parse_tls_version('TLSv1.2')

            if ssl_max_protocol_version is None:
                ssl_max_protocol_version = os.getenv('PGSSLMAXPROTOCOLVERSION')
            if ssl_max_protocol_version:
                ssl.maximum_version = _parse_tls_version(
                    ssl_max_protocol_version
                )

    elif ssl is True:
        ssl = ssl_module.create_default_context()
        sslmode = SSLMode.verify_full
    else:
        # Anything else (e.g. False, None or a caller-supplied SSLContext)
        # is passed through unchanged and labelled `disable`, which disables
        # the allow/prefer retry logic.
        sslmode = SSLMode.disable

    if server_settings is not None and (
            not isinstance(server_settings, dict) or
            not all(isinstance(k, str) for k in server_settings) or
            not all(isinstance(v, str) for v in server_settings.values())):
        raise ValueError(
            'server_settings is expected to be None or '
            'a Dict[str, str]')

    params = _ConnectionParameters(
        user=user, password=password, database=database, ssl=ssl,
        sslmode=sslmode, direct_tls=direct_tls,
        connect_timeout=connect_timeout, server_settings=server_settings)

    return addrs, params
def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
database, timeout, command_timeout,
statement_cache_size,
max_cached_statement_lifetime,
max_cacheable_statement_size,
ssl, direct_tls, server_settings):
local_vars = locals()
for var_name in {'max_cacheable_statement_size',
'max_cached_statement_lifetime',
'statement_cache_size'}:
var_val = local_vars[var_name]
if var_val is None or isinstance(var_val, bool) or var_val < 0:
raise ValueError(
'{} is expected to be greater '
'or equal to 0, got {!r}'.format(var_name, var_val))
if command_timeout is not None:
try:
if isinstance(command_timeout, bool):
raise ValueError
command_timeout = float(command_timeout)
if command_timeout <= 0:
raise ValueError
except ValueError:
raise ValueError(
'invalid command_timeout value: '
'expected greater than 0 float (got {!r})'.format(
command_timeout)) from None
addrs, params = _parse_connect_dsn_and_args(
dsn=dsn, host=host, port=port, user=user,
password=password, passfile=passfile, ssl=ssl,
direct_tls=direct_tls, database=database,
connect_timeout=timeout, server_settings=server_settings)
config = _ClientConfiguration(
command_timeout=command_timeout,
statement_cache_size=statement_cache_size,
max_cached_statement_lifetime=max_cached_statement_lifetime,
max_cacheable_statement_size=max_cacheable_statement_size,)
return addrs, params, config
class TLSUpgradeProto(asyncio.Protocol):
    """Temporary protocol that reads the server's reply to an SSLRequest.

    ``on_data`` resolves to:

    - ``True`` when the server answers ``b'S'`` (go ahead with TLS);
    - ``False`` when it answers ``b'N'``, the request was advisory, and the
      SSL context does not verify certificates (continue in plaintext);
    - a ``ConnectionError`` for any other reply, or the connection-loss
      error if the connection drops before a reply arrives.
    """

    def __init__(self, loop, host, port, ssl_context, ssl_is_advisory):
        self.on_data = _create_future(loop)
        self.host = host
        self.port = port
        self.ssl_context = ssl_context
        self.ssl_is_advisory = ssl_is_advisory

    def data_received(self, data):
        future = self.on_data
        if data == b'S':
            future.set_result(True)
            return

        # ssl_is_advisory should already imply CERT_NONE (it only comes
        # from sslmode=prefer), but refuse a plaintext fallback whenever the
        # context asks for real verification, just to be safe.
        plaintext_ok = (
            self.ssl_is_advisory
            and self.ssl_context.verify_mode == ssl_module.CERT_NONE
            and data == b'N'
        )
        if plaintext_ok:
            future.set_result(False)
        else:
            future.set_exception(ConnectionError(
                'PostgreSQL server at "{host}:{port}" '
                'rejected SSL upgrade'.format(
                    host=self.host, port=self.port)))

    def connection_lost(self, exc):
        if self.on_data.done():
            return
        if exc is None:
            exc = ConnectionError('unexpected connection_lost() call')
        self.on_data.set_exception(exc)
async def _create_ssl_connection(protocol_factory, host, port, *,
                                 loop, ssl_context, ssl_is_advisory=False):
    """Open a TCP connection and negotiate TLS via an SSLRequest.

    An SSLRequest packet is sent over a plain connection, and the server's
    one-byte answer is interpreted by ``TLSUpgradeProto``.  The transport is
    then either upgraded to TLS or, when that protocol allows it, kept in
    plaintext.  The final protocol is produced by *protocol_factory* and
    its ``is_ssl`` attribute records whether the upgrade happened.

    :param protocol_factory: Zero-argument callable that builds the protocol
        owning the final transport.
    :param host: Server host; also passed as ``server_hostname`` for TLS.
    :param port: Server port.
    :param loop: Event loop used to create the connection(s).
    :param ssl_context: ``ssl.SSLContext`` used for the upgrade.
    :param ssl_is_advisory: Forwarded to ``TLSUpgradeProto``.  When true,
        a ``b'N'`` reply may fall back to plaintext.
    :return: ``(transport, protocol)`` tuple.
    """
    tr, pr = await loop.create_connection(
        lambda: TLSUpgradeProto(loop, host, port,
                                ssl_context, ssl_is_advisory),
        host, port)

    tr.write(struct.pack('!ll', 8, 80877103))  # SSLRequest message.

    try:
        do_ssl_upgrade = await pr.on_data
    except (Exception, asyncio.CancelledError):
        tr.close()
        raise

    if hasattr(loop, 'start_tls'):
        # Preferred path: upgrade the existing transport in place.
        if do_ssl_upgrade:
            try:
                new_tr = await loop.start_tls(
                    tr, pr, ssl_context, server_hostname=host)
            except (Exception, asyncio.CancelledError):
                tr.close()
                raise
        else:
            new_tr = tr

        # Hand the (possibly upgraded) transport over from the temporary
        # TLSUpgradeProto to the real protocol.
        pg_proto = protocol_factory()
        pg_proto.is_ssl = do_ssl_upgrade
        pg_proto.connection_made(new_tr)
        new_tr.set_protocol(pg_proto)

        return new_tr, pg_proto
    else:
        # Fallback for loops without start_tls().  Duplicate the socket,
        # close the negotiation transport, and open a new connection on the
        # duplicate (wrapped in TLS if the server accepted the upgrade).
        conn_factory = functools.partial(
            loop.create_connection, protocol_factory)

        if do_ssl_upgrade:
            conn_factory = functools.partial(
                conn_factory, ssl=ssl_context, server_hostname=host)

        sock = _get_socket(tr)
        sock = sock.dup()
        _set_nodelay(sock)
        tr.close()

        try:
            new_tr, pg_proto = await conn_factory(sock=sock)
            pg_proto.is_ssl = do_ssl_upgrade
            return new_tr, pg_proto
        except (Exception, asyncio.CancelledError):
            sock.close()
            raise
async def _connect_addr(
*,
addr,
loop,
timeout,
params,
config,
connection_class,
record_class
):
assert loop is not None
if timeout <= 0:
raise asyncio.TimeoutError
params_input = params
if callable(params.password):
password = params.password()
if inspect.isawaitable(password):
password = await password
params = params._replace(password=password)
args = (addr, loop, config, connection_class, record_class, params_input)
# prepare the params (which attempt has ssl) for the 2 attempts
if params.sslmode == SSLMode.allow:
params_retry = params
params = params._replace(ssl=None)
elif params.sslmode == SSLMode.prefer:
params_retry = params._replace(ssl=None)
else:
# skip retry if we don't have to
return await __connect_addr(params, timeout, False, *args)
# first attempt
before = time.monotonic()
try:
return await __connect_addr(params, timeout, True, *args)
except _RetryConnectSignal:
pass
# second attempt
timeout -= time.monotonic() - before
if timeout <= 0:
raise asyncio.TimeoutError
else:
return await __connect_addr(params_retry, timeout, False, *args)
class _RetryConnectSignal(Exception):
pass
async def __connect_addr(
    params,
    timeout,
    retry,
    addr,
    loop,
    config,
    connection_class,
    record_class,
    params_input,
):
    """Make a single connection attempt to *addr*.

    The transport is chosen from *addr* and *params*: a Unix socket, direct
    TLS, SSLRequest-negotiated TLS, or plain TCP.  The function then waits
    for the ``connected`` future handed to ``protocol.Protocol`` and wraps
    the result in *connection_class*.

    :param params: Parameters for this attempt, as prepared by
        ``_connect_addr`` (password resolved; ``ssl`` possibly cleared).
    :param timeout: Seconds available for the whole attempt.
    :param retry: True only for a first attempt that the caller is prepared
        to retry; only then can ``_RetryConnectSignal`` be raised.
    :param params_input: The caller's original parameters; these are what
        the connection object receives.
    :raises _RetryConnectSignal: if *retry* is true and an authorization or
        connection-closed error suggests that the other SSL variant may
        succeed.
    :raises asyncio.TimeoutError: if the time budget runs out.
    """
    # NOTE(review): presumably resolved by protocol.Protocol once the
    # connection is fully established - confirm in protocol module.
    connected = _create_future(loop)

    proto_factory = lambda: protocol.Protocol(
        addr, connected, params, record_class, loop)

    if isinstance(addr, str):
        # UNIX socket
        connector = loop.create_unix_connection(proto_factory, addr)

    elif params.ssl and params.direct_tls:
        # if ssl and direct_tls are given, skip STARTTLS and perform direct
        # SSL connection
        connector = loop.create_connection(
            proto_factory, *addr, ssl=params.ssl
        )

    elif params.ssl:
        connector = _create_ssl_connection(
            proto_factory, *addr, loop=loop, ssl_context=params.ssl,
            ssl_is_advisory=params.sslmode == SSLMode.prefer)
    else:
        connector = loop.create_connection(proto_factory, *addr)

    # Phase 1: open the transport, charging the elapsed time to `timeout`.
    connector = asyncio.ensure_future(connector)
    before = time.monotonic()
    tr, pr = await compat.wait_for(connector, timeout=timeout)
    timeout -= time.monotonic() - before

    # Phase 2: wait for the protocol to report the connection as ready.
    # The transport is closed on every failure path.
    try:
        if timeout <= 0:
            raise asyncio.TimeoutError
        await compat.wait_for(connected, timeout=timeout)
    except (
        exceptions.InvalidAuthorizationSpecificationError,
        exceptions.ConnectionDoesNotExistError,  # seen on Windows
    ):
        tr.close()

        # retry=True here is a redundant check because we don't want to
        # accidentally raise the internal _RetryConnectSignal to the user
        if retry and (
            params.sslmode == SSLMode.allow and not pr.is_ssl or
            params.sslmode == SSLMode.prefer and pr.is_ssl
        ):
            # Trigger retry when:
            #   1. First attempt with sslmode=allow, ssl=None failed
            #   2. First attempt with sslmode=prefer, ssl=ctx failed while the
            #      server claimed to support SSL (returning "S" for SSLRequest)
            #      (likely because pg_hba.conf rejected the connection)
            raise _RetryConnectSignal()

        else:
            # but will NOT retry if:
            #   1. First attempt with sslmode=prefer failed but the server
            #      doesn't support SSL (returning 'N' for SSLRequest), because
            #      we already tried to connect without SSL thru ssl_is_advisory
            #   2. Second attempt with sslmode=prefer, ssl=None failed
            #   3. Second attempt with sslmode=allow, ssl=ctx failed
            #   4. Any other sslmode
            raise

    except (Exception, asyncio.CancelledError):
        tr.close()
        raise

    con = connection_class(pr, tr, loop, addr, config, params_input)
    pr.set_connection(con)
    return con
async def _connect(*, loop, timeout, connection_class, record_class, **kwargs):
    """Connect to the first reachable address derived from the arguments.

    Addresses are tried in order.  *timeout* is a budget shared by all of
    them: each attempt deducts the time it took.  Connection-level
    failures (``OSError``, ``asyncio.TimeoutError``, ``ConnectionError``)
    move on to the next address; any other exception propagates at once.

    :raises: the last connection-level error when every address fails.
    """
    if loop is None:
        # We are running inside a coroutine, so there is always a running
        # loop; get_running_loop() is the non-deprecated way to fetch it.
        loop = asyncio.get_running_loop()

    addrs, params, config = _parse_connect_arguments(timeout=timeout, **kwargs)

    last_error = None
    for addr in addrs:
        before = time.monotonic()
        try:
            return await _connect_addr(
                addr=addr,
                loop=loop,
                timeout=timeout,
                params=params,
                config=config,
                connection_class=connection_class,
                record_class=record_class,
            )
        except (OSError, asyncio.TimeoutError, ConnectionError) as ex:
            last_error = ex
        finally:
            timeout -= time.monotonic() - before

    if last_error is None:
        # Defensive: only reachable with an empty address list, which
        # _parse_connect_dsn_and_args already rejects.  Avoids `raise None`.
        raise ValueError(
            'could not determine the database address to connect to')
    raise last_error
async def _cancel(*, loop, addr, params: _ConnectionParameters,
                  backend_pid, backend_secret):
    """Send a CancelRequest for another backend over a new connection.

    Opens a fresh connection to *addr*, writes a CancelRequest carrying
    *backend_pid* and *backend_secret*, and waits for the server to close
    the connection.  The transport is chosen the same way as in
    ``__connect_addr``: a Unix socket, direct TLS (``params.direct_tls``),
    SSLRequest-negotiated TLS, or plain TCP.  With ``sslmode=allow``, plain
    TCP is used.
    """

    class CancelProto(asyncio.Protocol):
        # Minimal protocol whose only job is to notice the server hanging up.

        def __init__(self):
            self.on_disconnect = _create_future(loop)
            self.is_ssl = False

        def connection_lost(self, exc):
            if not self.on_disconnect.done():
                self.on_disconnect.set_result(True)

    if isinstance(addr, str):
        tr, pr = await loop.create_unix_connection(CancelProto, addr)
    elif params.ssl and params.sslmode != SSLMode.allow:
        if params.direct_tls:
            # Like the main connection: with direct_tls the server expects
            # a TLS handshake immediately, so no SSLRequest is sent.
            tr, pr = await loop.create_connection(
                CancelProto, *addr, ssl=params.ssl)
        else:
            tr, pr = await _create_ssl_connection(
                CancelProto,
                *addr,
                loop=loop,
                ssl_context=params.ssl,
                ssl_is_advisory=params.sslmode == SSLMode.prefer)
    else:
        tr, pr = await loop.create_connection(
            CancelProto, *addr)
        _set_nodelay(_get_socket(tr))

    # Pack a CancelRequest message: length 16, the cancel request code,
    # then the target backend's PID and secret key.
    msg = struct.pack('!llll', 16, 80877102, backend_pid, backend_secret)

    try:
        tr.write(msg)
        await pr.on_disconnect
    finally:
        tr.close()
def _get_socket(transport):
sock = transport.get_extra_info('socket')
if sock is None:
# Shouldn't happen with any asyncio-complaint event loop.
raise ConnectionError(
'could not get the socket for transport {!r}'.format(transport))
return sock
def _set_nodelay(sock):
if not hasattr(socket, 'AF_UNIX') or sock.family != socket.AF_UNIX:
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
def _create_future(loop):
try:
create_future = loop.create_future
except AttributeError:
return asyncio.Future(loop=loop)
else:
return create_future()