# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


import asyncio
import os
import os.path
import platform
import re
import shutil
import socket
import subprocess
import sys
import tempfile
import textwrap
import time

import asyncpg
from asyncpg import serverversion


_system = platform.uname().system

if _system == 'Windows':
    def platform_exe(name):
        if name.endswith('.exe'):
            return name
        return name + '.exe'
else:
    def platform_exe(name):
        return name


def find_available_port():
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        sock.bind(('127.0.0.1', 0))
        return sock.getsockname()[1]
    except Exception:
        return None
    finally:
        sock.close()
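
# Note: this helper is inherently racy -- the probe socket is closed before
# the server binds the port, so another process may claim the port in the
# meantime.  Cluster.start() calls it when given port='dynamic'.  A minimal,
# hypothetical sketch (the `cluster` name is assumed to be a Cluster
# instance, not something defined in this module):
#
#     port = find_available_port()
#     if port is not None:
#         cluster.start(port=port)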


class ClusterError(Exception):
    pass


class Cluster:
    def __init__(self, data_dir, *, pg_config_path=None):
        self._data_dir = data_dir
        self._pg_config_path = pg_config_path
        self._pg_bin_dir = (
            os.environ.get('PGINSTALLATION')
            or os.environ.get('PGBIN')
        )
        self._pg_ctl = None
        self._daemon_pid = None
        self._daemon_process = None
        self._connection_addr = None
        self._connection_spec_override = None

    def get_pg_version(self):
        return self._pg_version

    def is_managed(self):
        return True

    def get_data_dir(self):
        return self._data_dir

    def get_status(self):
        if self._pg_ctl is None:
            self._init_env()

        process = subprocess.run(
            [self._pg_ctl, 'status', '-D', self._data_dir],
            stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        stdout, stderr = process.stdout, process.stderr

        if (process.returncode == 4 or not os.path.exists(self._data_dir) or
                not os.listdir(self._data_dir)):
            return 'not-initialized'
        elif process.returncode == 3:
            return 'stopped'
        elif process.returncode == 0:
            r = re.match(r'.*PID\s?:\s+(\d+).*', stdout.decode())
            if not r:
                raise ClusterError(
                    'could not parse pg_ctl status output: {}'.format(
                        stdout.decode()))
            self._daemon_pid = int(r.group(1))
            return self._test_connection(timeout=0)
        else:
            raise ClusterError(
                'pg_ctl status exited with status {:d}: {}'.format(
                    process.returncode, stderr))

    async def connect(self, loop=None, **kwargs):
        conn_info = self.get_connection_spec()
        conn_info.update(kwargs)
        return await asyncpg.connect(loop=loop, **conn_info)
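
    # A minimal usage sketch for connect(): the keyword arguments are merged
    # over the cluster's own connection spec and handed to asyncpg.connect(),
    # so standard asyncpg connection parameters apply.  'postgres' is assumed
    # here as a database/user that exists in the cluster:
    #
    #     con = await cluster.connect(database='postgres', user='postgres')
    #     await con.close()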

    def init(self, **settings):
        """Initialize cluster."""
        if self.get_status() != 'not-initialized':
            raise ClusterError(
                'cluster in {!r} has already been initialized'.format(
                    self._data_dir))

        settings = dict(settings)

        if 'encoding' not in settings:
            settings['encoding'] = 'UTF-8'

        if settings:
            settings_args = ['--{}={}'.format(k, v)
                             for k, v in settings.items()]
            extra_args = ['-o'] + [' '.join(settings_args)]
        else:
            extra_args = []

        process = subprocess.run(
            [self._pg_ctl, 'init', '-D', self._data_dir] + extra_args,
            stdout=subprocess.PIPE, stderr=subprocess.STDOUT)

        output = process.stdout

        if process.returncode != 0:
            raise ClusterError(
                'pg_ctl init exited with status {:d}:\n{}'.format(
                    process.returncode, output.decode()))

        return output.decode()
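
    # The keyword arguments to init() are forwarded to initdb as long options
    # via `pg_ctl init -o`, e.g. init(encoding='UTF-8') runs initdb with
    # --encoding=UTF-8.  Option names containing dashes can be passed with a
    # dict; a hypothetical example:
    #
    #     cluster.init(**{'auth-local': 'trust', 'username': 'postgres'})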

    def start(self, wait=60, *, server_settings={}, **opts):
        """Start the cluster."""
        status = self.get_status()
        if status == 'running':
            return
        elif status == 'not-initialized':
            raise ClusterError(
                'cluster in {!r} has not been initialized'.format(
                    self._data_dir))

        port = opts.pop('port', None)
        if port == 'dynamic':
            port = find_available_port()

        extra_args = ['--{}={}'.format(k, v) for k, v in opts.items()]
        extra_args.append('--port={}'.format(port))

        sockdir = server_settings.get('unix_socket_directories')
        if sockdir is None:
            sockdir = server_settings.get('unix_socket_directory')
        if sockdir is None and _system != 'Windows':
            sockdir = tempfile.gettempdir()

        ssl_key = server_settings.get('ssl_key_file')
        if ssl_key:
            # Make sure server certificate key file has correct permissions.
            keyfile = os.path.join(self._data_dir, 'srvkey.pem')
            shutil.copy(ssl_key, keyfile)
            os.chmod(keyfile, 0o600)
            server_settings = server_settings.copy()
            server_settings['ssl_key_file'] = keyfile

        if sockdir is not None:
            if self._pg_version < (9, 3):
                sockdir_opt = 'unix_socket_directory'
            else:
                sockdir_opt = 'unix_socket_directories'

            server_settings[sockdir_opt] = sockdir

        for k, v in server_settings.items():
            extra_args.extend(['-c', '{}={}'.format(k, v)])

        if _system == 'Windows':
            # On Windows we have to use pg_ctl as direct execution
            # of postgres daemon under an Administrative account
            # is not permitted and there is no easy way to drop
            # privileges.
            if os.getenv('ASYNCPG_DEBUG_SERVER'):
                stdout = sys.stdout
                print(
                    'asyncpg.cluster: Running',
                    ' '.join([
                        self._pg_ctl, 'start', '-D', self._data_dir,
                        '-o', ' '.join(extra_args)
                    ]),
                    file=sys.stderr,
                )
            else:
                stdout = subprocess.DEVNULL

            process = subprocess.run(
                [self._pg_ctl, 'start', '-D', self._data_dir,
                 '-o', ' '.join(extra_args)],
                stdout=stdout, stderr=subprocess.STDOUT)

            if process.returncode != 0:
                if process.stderr:
                    stderr = ':\n{}'.format(process.stderr.decode())
                else:
                    stderr = ''

                raise ClusterError(
                    'pg_ctl start exited with status {:d}{}'.format(
                        process.returncode, stderr))
        else:
            if os.getenv('ASYNCPG_DEBUG_SERVER'):
                stdout = sys.stdout
            else:
                stdout = subprocess.DEVNULL

            self._daemon_process = \
                subprocess.Popen(
                    [self._postgres, '-D', self._data_dir, *extra_args],
                    stdout=stdout, stderr=subprocess.STDOUT)

            self._daemon_pid = self._daemon_process.pid

        self._test_connection(timeout=wait)
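
    # In start(), extra keyword arguments become long options on the postgres
    # command line (--opt=value), while server_settings entries are passed as
    # `-c name=value` GUC overrides.  A hypothetical sketch:
    #
    #     cluster.start(port='dynamic',
    #                   server_settings={'log_connections': 'on'})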

    def reload(self):
        """Reload server configuration."""
        status = self.get_status()
        if status != 'running':
            raise ClusterError('cannot reload: cluster is not running')

        process = subprocess.run(
            [self._pg_ctl, 'reload', '-D', self._data_dir],
            stdout=subprocess.PIPE, stderr=subprocess.PIPE)

        stderr = process.stderr

        if process.returncode != 0:
            raise ClusterError(
                'pg_ctl reload exited with status {:d}: {}'.format(
                    process.returncode, stderr.decode()))

    def stop(self, wait=60):
        process = subprocess.run(
            [self._pg_ctl, 'stop', '-D', self._data_dir, '-t', str(wait),
             '-m', 'fast'],
            stdout=subprocess.PIPE, stderr=subprocess.PIPE)

        stderr = process.stderr

        if process.returncode != 0:
            raise ClusterError(
                'pg_ctl stop exited with status {:d}: {}'.format(
                    process.returncode, stderr.decode()))

        if (self._daemon_process is not None and
                self._daemon_process.returncode is None):
            self._daemon_process.kill()

    def destroy(self):
        status = self.get_status()
        if status == 'stopped' or status == 'not-initialized':
            shutil.rmtree(self._data_dir)
        else:
            raise ClusterError('cannot destroy {} cluster'.format(status))

    def _get_connection_spec(self):
        if self._connection_addr is None:
            self._connection_addr = self._connection_addr_from_pidfile()

        if self._connection_addr is not None:
            if self._connection_spec_override:
                args = self._connection_addr.copy()
                args.update(self._connection_spec_override)
                return args
            else:
                return self._connection_addr

    def get_connection_spec(self):
        status = self.get_status()
        if status != 'running':
            raise ClusterError('cluster is not running')

        return self._get_connection_spec()

    def override_connection_spec(self, **kwargs):
        self._connection_spec_override = kwargs

    def reset_wal(self, *, oid=None, xid=None):
        status = self.get_status()
        if status == 'not-initialized':
            raise ClusterError(
                'cannot modify WAL status: cluster is not initialized')

        if status == 'running':
            raise ClusterError(
                'cannot modify WAL status: cluster is running')

        opts = []
        if oid is not None:
            opts.extend(['-o', str(oid)])
        if xid is not None:
            opts.extend(['-x', str(xid)])
        if not opts:
            return

        opts.append(self._data_dir)

        try:
            reset_wal = self._find_pg_binary('pg_resetwal')
        except ClusterError:
            reset_wal = self._find_pg_binary('pg_resetxlog')

        process = subprocess.run(
            [reset_wal] + opts,
            stdout=subprocess.PIPE, stderr=subprocess.PIPE)

        stderr = process.stderr

        if process.returncode != 0:
            raise ClusterError(
                'pg_resetwal exited with status {:d}: {}'.format(
                    process.returncode, stderr.decode()))
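
    # reset_wal() maps its arguments onto pg_resetwal flags: `-o` sets the
    # next OID and `-x` sets the next transaction ID (pg_resetxlog is the
    # pre-PostgreSQL-10 name of the same tool).  Hypothetical example:
    #
    #     cluster.reset_wal(xid=100000)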

    def reset_hba(self):
        """Remove all records from pg_hba.conf."""
        status = self.get_status()
        if status == 'not-initialized':
            raise ClusterError(
                'cannot modify HBA records: cluster is not initialized')

        pg_hba = os.path.join(self._data_dir, 'pg_hba.conf')

        try:
            with open(pg_hba, 'w'):
                pass
        except IOError as e:
            raise ClusterError(
                'cannot modify HBA records: {}'.format(e)) from e

    def add_hba_entry(self, *, type='host', database, user, address=None,
                      auth_method, auth_options=None):
        """Add a record to pg_hba.conf."""
        status = self.get_status()
        if status == 'not-initialized':
            raise ClusterError(
                'cannot modify HBA records: cluster is not initialized')

        if type not in {'local', 'host', 'hostssl', 'hostnossl'}:
            raise ValueError('invalid HBA record type: {!r}'.format(type))

        pg_hba = os.path.join(self._data_dir, 'pg_hba.conf')

        record = '{} {} {}'.format(type, database, user)

        if type != 'local':
            if address is None:
                raise ValueError(
                    '{!r} entry requires a valid address'.format(type))
            else:
                record += ' {}'.format(address)

        record += ' {}'.format(auth_method)

        if auth_options is not None:
            record += ' ' + ' '.join(
                '{}={}'.format(k, v) for k, v in auth_options)

        try:
            with open(pg_hba, 'a') as f:
                print(record, file=f)
        except IOError as e:
            raise ClusterError(
                'cannot modify HBA records: {}'.format(e)) from e
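
    # add_hba_entry() appends one space-separated pg_hba.conf line, e.g.
    # add_hba_entry(type='host', database='all', user='all',
    # address='127.0.0.1/32', auth_method='trust') produces:
    #
    #     host all all 127.0.0.1/32 trust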

    def trust_local_connections(self):
        self.reset_hba()

        if _system != 'Windows':
            self.add_hba_entry(type='local', database='all',
                               user='all', auth_method='trust')
        self.add_hba_entry(type='host', address='127.0.0.1/32',
                           database='all', user='all',
                           auth_method='trust')
        self.add_hba_entry(type='host', address='::1/128',
                           database='all', user='all',
                           auth_method='trust')
        status = self.get_status()
        if status == 'running':
            self.reload()

    def trust_local_replication_by(self, user):
        if _system != 'Windows':
            self.add_hba_entry(type='local', database='replication',
                               user=user, auth_method='trust')
        self.add_hba_entry(type='host', address='127.0.0.1/32',
                           database='replication', user=user,
                           auth_method='trust')
        self.add_hba_entry(type='host', address='::1/128',
                           database='replication', user=user,
                           auth_method='trust')
        status = self.get_status()
        if status == 'running':
            self.reload()

    def _init_env(self):
        if not self._pg_bin_dir:
            pg_config = self._find_pg_config(self._pg_config_path)
            pg_config_data = self._run_pg_config(pg_config)

            self._pg_bin_dir = pg_config_data.get('bindir')
            if not self._pg_bin_dir:
                raise ClusterError(
                    'pg_config output did not provide the BINDIR value')

        self._pg_ctl = self._find_pg_binary('pg_ctl')
        self._postgres = self._find_pg_binary('postgres')
        self._pg_version = self._get_pg_version()

    def _connection_addr_from_pidfile(self):
        pidfile = os.path.join(self._data_dir, 'postmaster.pid')

        try:
            with open(pidfile, 'rt') as f:
                piddata = f.read()
        except FileNotFoundError:
            return None

        lines = piddata.splitlines()

        if len(lines) < 6:
            # A complete postgres pidfile is at least 6 lines
            return None

        pmpid = int(lines[0])
        if self._daemon_pid and pmpid != self._daemon_pid:
            # This might be an old pidfile left from previous postgres
            # daemon run.
            return None

        portnum = lines[3]
        sockdir = lines[4]
        hostaddr = lines[5]

        if sockdir:
            if sockdir[0] != '/':
                # Relative sockdir
                sockdir = os.path.normpath(
                    os.path.join(self._data_dir, sockdir))
            host_str = sockdir
        else:
            host_str = hostaddr

        if host_str == '*':
            host_str = 'localhost'
        elif host_str == '0.0.0.0':
            host_str = '127.0.0.1'
        elif host_str == '::':
            host_str = '::1'

        return {
            'host': host_str,
            'port': portnum
        }
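
    # _connection_addr_from_pidfile() relies on the fixed layout of
    # postmaster.pid: line 1 is the postmaster PID, line 4 the port,
    # line 5 the Unix socket directory, and line 6 the first listen address
    # (hence the lines[0], lines[3], lines[4] and lines[5] indices above).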

    def _test_connection(self, timeout=60):
        self._connection_addr = None

        loop = asyncio.new_event_loop()

        try:
            for i in range(timeout):
                if self._connection_addr is None:
                    conn_spec = self._get_connection_spec()
                    if conn_spec is None:
                        time.sleep(1)
                        continue

                try:
                    con = loop.run_until_complete(
                        asyncpg.connect(database='postgres',
                                        user='postgres',
                                        timeout=5, loop=loop,
                                        **self._connection_addr))
                except (OSError, asyncio.TimeoutError,
                        asyncpg.CannotConnectNowError,
                        asyncpg.PostgresConnectionError):
                    time.sleep(1)
                    continue
                except asyncpg.PostgresError:
                    # Any other error other than ServerNotReadyError or
                    # ConnectionError is interpreted to indicate the server is
                    # up.
                    break
                else:
                    loop.run_until_complete(con.close())
                    break
        finally:
            loop.close()

        return 'running'

    def _run_pg_config(self, pg_config_path):
        process = subprocess.run(
            pg_config_path, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        stdout, stderr = process.stdout, process.stderr
        if process.returncode != 0:
            raise ClusterError('pg_config exited with status {:d}: {}'.format(
                process.returncode, stderr))
        else:
            config = {}
            for line in stdout.splitlines():
                k, eq, v = line.decode('utf-8').partition('=')
                if eq:
                    config[k.strip().lower()] = v.strip()
            return config

    def _find_pg_config(self, pg_config_path):
        if pg_config_path is None:
            pg_install = (
                os.environ.get('PGINSTALLATION')
                or os.environ.get('PGBIN')
            )
            if pg_install:
                pg_config_path = platform_exe(
                    os.path.join(pg_install, 'pg_config'))
            else:
                pathenv = os.environ.get('PATH').split(os.pathsep)
                for path in pathenv:
                    pg_config_path = platform_exe(
                        os.path.join(path, 'pg_config'))
                    if os.path.exists(pg_config_path):
                        break
                else:
                    pg_config_path = None

        if not pg_config_path:
            raise ClusterError('could not find pg_config executable')

        if not os.path.isfile(pg_config_path):
            raise ClusterError('{!r} is not an executable'.format(
                pg_config_path))

        return pg_config_path

    def _find_pg_binary(self, binary):
        bpath = platform_exe(os.path.join(self._pg_bin_dir, binary))

        if not os.path.isfile(bpath):
            raise ClusterError(
                'could not find {} executable: '.format(binary) +
                '{!r} does not exist or is not a file'.format(bpath))

        return bpath

    def _get_pg_version(self):
        process = subprocess.run(
            [self._postgres, '--version'],
            stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        stdout, stderr = process.stdout, process.stderr
        if process.returncode != 0:
            raise ClusterError(
                'postgres --version exited with status {:d}: {}'.format(
                    process.returncode, stderr))

        version_string = stdout.decode('utf-8').strip(' \n')
        prefix = 'postgres (PostgreSQL) '
        if not version_string.startswith(prefix):
            raise ClusterError(
                'could not determine server version from {!r}'.format(
                    version_string))
        version_string = version_string[len(prefix):]

        return serverversion.split_server_version_string(version_string)


class TempCluster(Cluster):
    def __init__(self, *,
                 data_dir_suffix=None, data_dir_prefix=None,
                 data_dir_parent=None, pg_config_path=None):
        self._data_dir = tempfile.mkdtemp(suffix=data_dir_suffix,
                                          prefix=data_dir_prefix,
                                          dir=data_dir_parent)
        super().__init__(self._data_dir, pg_config_path=pg_config_path)
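
# A minimal end-to-end sketch of the managed-cluster lifecycle, assuming a
# local PostgreSQL installation is discoverable via PGINSTALLATION, PGBIN or
# PATH (the coroutine name `demo` is hypothetical):
#
#     async def demo():
#         cluster = TempCluster()
#         cluster.init()
#         cluster.trust_local_connections()
#         cluster.start(port='dynamic')
#         try:
#             con = await cluster.connect(database='postgres',
#                                         user='postgres')
#             print(await con.fetchval('SELECT 1'))
#             await con.close()
#         finally:
#             cluster.stop()
#             cluster.destroy()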


class HotStandbyCluster(TempCluster):
    def __init__(self, *,
                 master, replication_user,
                 data_dir_suffix=None, data_dir_prefix=None,
                 data_dir_parent=None, pg_config_path=None):
        self._master = master
        self._repl_user = replication_user
        super().__init__(
            data_dir_suffix=data_dir_suffix,
            data_dir_prefix=data_dir_prefix,
            data_dir_parent=data_dir_parent,
            pg_config_path=pg_config_path)

    def _init_env(self):
        super()._init_env()
        self._pg_basebackup = self._find_pg_binary('pg_basebackup')

    def init(self, **settings):
        """Initialize cluster."""
        if self.get_status() != 'not-initialized':
            raise ClusterError(
                'cluster in {!r} has already been initialized'.format(
                    self._data_dir))

        process = subprocess.run(
            [self._pg_basebackup, '-h', self._master['host'],
             '-p', self._master['port'], '-D', self._data_dir,
             '-U', self._repl_user],
            stdout=subprocess.PIPE, stderr=subprocess.STDOUT)

        output = process.stdout

        if process.returncode != 0:
            raise ClusterError(
                'pg_basebackup init exited with status {:d}:\n{}'.format(
                    process.returncode, output.decode()))

        if self._pg_version < (12, 0):
            with open(os.path.join(self._data_dir, 'recovery.conf'), 'w') as f:
                f.write(textwrap.dedent("""\
                    standby_mode = 'on'
                    primary_conninfo = 'host={host} port={port} user={user}'
                """.format(
                    host=self._master['host'],
                    port=self._master['port'],
                    user=self._repl_user)))
        else:
            f = open(os.path.join(self._data_dir, 'standby.signal'), 'w')
            f.close()

        return output.decode()

    def start(self, wait=60, *, server_settings={}, **opts):
        if self._pg_version >= (12, 0):
            server_settings = server_settings.copy()
            server_settings['primary_conninfo'] = (
                '"host={host} port={port} user={user}"'.format(
                    host=self._master['host'],
                    port=self._master['port'],
                    user=self._repl_user,
                )
            )

        super().start(wait=wait, server_settings=server_settings, **opts)
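
# A hypothetical sketch of pairing a primary with a hot standby.  The standby
# is initialized from a pg_basebackup of the master, whose connection info is
# passed via the `master` argument (a dict with 'host' and 'port').  The
# 'replicator' role is an assumption: it must already exist on the master
# with REPLICATION rights for this to work.
#
#     master = TempCluster()
#     master.init()
#     master.trust_local_connections()
#     master.trust_local_replication_by('replicator')
#     master.start(port='dynamic')
#
#     standby = HotStandbyCluster(master=master.get_connection_spec(),
#                                 replication_user='replicator')
#     standby.init()
#     standby.start(port='dynamic')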


class RunningCluster(Cluster):
    def __init__(self, **kwargs):
        self.conn_spec = kwargs

    def is_managed(self):
        return False

    def get_connection_spec(self):
        return dict(self.conn_spec)

    def get_status(self):
        return 'running'

    def init(self, **settings):
        pass

    def start(self, wait=60, **settings):
        pass

    def stop(self, wait=60):
        pass

    def destroy(self):
        pass

    def reset_hba(self):
        raise ClusterError('cannot modify HBA records of unmanaged cluster')

    def add_hba_entry(self, *, type='host', database, user, address=None,
                      auth_method, auth_options=None):
        raise ClusterError('cannot modify HBA records of unmanaged cluster')
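
# RunningCluster wraps a PostgreSQL server that is managed elsewhere: the
# keyword arguments become the connection spec returned by
# get_connection_spec() and forwarded to asyncpg.connect().  A hypothetical
# sketch (host/port/user values are assumptions):
#
#     cluster = RunningCluster(host='localhost', port=5432, user='postgres')
#     con = await cluster.connect(database='postgres')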