# -*- encoding: utf-8 -*-
"""
Basic support for communicating with the database server.
This is currently very postgres specific. If we really wanted to
support some other database, this would need massive refactoring.
"""
#c Copyright 2008-2023, the GAVO project <gavo@ari.uni-heidelberg.de>
#c
#c This program is free software, covered by the GNU GPL. See the
#c COPYING file in the source distribution.
import contextlib
import functools
import os
import random
import re
import threading
import warnings
import weakref
import numpy
from gavo import utils
from gavo.base import config
debug = "GAVO_SQL_DEBUG" in os.environ
import psycopg2
import psycopg2.extensions
import psycopg2.pool
from psycopg2.extras import DictCursor #noflake: exported name
class Error(utils.Error):
pass
NUMERIC_TYPES = frozenset(["smallint", "integer", "bigint", "real",
"double precision"])
ORDERED_TYPES = frozenset(["timestamp", "text", "unicode"]) | NUMERIC_TYPES
_PG_TIME_UNITS = {
"ms": 0.0001,
"s": 1.,
"": 1.,
"min": 60.,
"h": 3600.,
"d": 86400.,}
class SqlSetAdapter(object):
"""is an adapter that formats python sequences as SQL sets.
-- as opposed to psycopg2's apparent default of building arrays
out of them.
"""
def __init__(self, seq):
self._seq = seq
	def prepare(self, conn):
pass
	def getquoted(self):
qobjs = []
for o in self._seq:
if isinstance(o, str):
qobjs.append(psycopg2.extensions.adapt(str(o)).getquoted())
else:
qobjs.append(psycopg2.extensions.adapt(o).getquoted())
return b'(%s)'%(b", ".join(qobjs))
__str__ = getquoted
class SqlArrayAdapter(object):
"""An adapter that formats python lists as SQL arrays
This makes, in the shameful tradition of VOTable, empty arrays equal to
NULL.
"""
def __init__(self, seq):
self._seq = seq
for item in seq:
if item is not None:
self.itemType = type(item)
break
else:
self.itemType = None
	def prepare(self, conn):
pass
def _addCastIfNecessary(self, serializedList):
"""adds a typecast to serializedList if it needs one.
This is when all entries in serializedList are NULL; so, we're fine
anyway if the first element is non-NULL; if it's not, we try to
guess.
serializedList is changed in place, the method returns nothing.
"""
if not serializedList or serializedList[0]!=b"NULL":
return
if isinstance(self._seq, utils.floatlist):
serializedList[0] = b"NULL::REAL"
elif isinstance(self._seq, utils.intlist):
serializedList[0] = b"NULL::INTEGER"
	def getquoted(self):
if len(self._seq)==0:
return b'NULL'
if self.itemType and issubclass(self.itemType, str):
# I need to be a bit verbose here because psycopg's default
# encoding still is latin-1, and it seems there's no better
# way to force it to utf-8 than this:
qobjs = []
for o in self._seq:
item = psycopg2.extensions.adapt(o)
item.encoding = "utf-8"
qobjs.append(item.getquoted())
else:
qobjs = [psycopg2.extensions.adapt(o).getquoted()
for o in self._seq]
self._addCastIfNecessary(qobjs)
return b'ARRAY[ %s ]'%(b", ".join(qobjs))
__str__ = getquoted
class FloatableAdapter(object):
"""An adapter for things that do "float", in particular numpy.float*
"""
def __init__(self, val):
self.val = float(val)
	def prepare(self, conn):
pass
	def getquoted(self):
if self.val!=self.val:
return b"'nan'::real"
else:
return repr(self.val).encode("ascii")
__str__ = getquoted
class IntableAdapter(object):
"""An adapter for things that do "int", in particular numpy.int*
"""
def __init__(self, val):
self.val = int(val)
	def prepare(self, conn):
pass
	def getquoted(self):
return str(self.val).encode("ascii")
__str__ = getquoted
class NULLAdapter(object):
"""An adapter for things that should end up as NULL in the DB.
"""
def __init__(self, val):
# val doesn't matter, we're making it NULL anyway
pass
	def prepare(self, conn):
pass
	def getquoted(self):
return b"NULL"
__str__ = getquoted
psycopg2.extensions.register_adapter(list, SqlArrayAdapter)
psycopg2.extensions.register_adapter(numpy.ndarray, SqlArrayAdapter)
psycopg2.extensions.register_adapter(tuple, SqlSetAdapter)
psycopg2.extensions.register_adapter(set, SqlSetAdapter)
psycopg2.extensions.register_adapter(frozenset, SqlSetAdapter)
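# A sketch of what these registrations buy us (not executed here; string
# items additionally need a prepared connection for proper quoting):
#
#   from psycopg2.extensions import adapt
#   adapt([1, 2, 3]).getquoted()  # -> b'ARRAY[ 1, 2, 3 ]'
#   adapt((1, 2, 3)).getquoted()  # -> b'(1, 2, 3)' (usable after IN)
#   adapt([]).getquoted()         # -> b'NULL' (empty arrays become NULL)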
for numpyType, adapter in [
("float32", FloatableAdapter),
("float64", FloatableAdapter),
("float96", FloatableAdapter),
("int8", IntableAdapter),
("int16", IntableAdapter),
("int32", IntableAdapter),
("int64", IntableAdapter),]:
try:
psycopg2.extensions.register_adapter(
getattr(numpy, numpyType), adapter)
except AttributeError: # pragma: no cover
# types not present on the python end we don't need to adapt
pass
# Override psycopg2's mapping of numeric to decimal, because our
# serialisers (votable, fits, json) don't really work with decimal.
psycopg2.extensions.register_type(
psycopg2.extensions.new_type(
psycopg2.extensions.DECIMAL.values,
"numeric_float",
lambda value, cursor: float(value) if value is not None else None))
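# With this in place, a NUMERIC value such as 1.5 arrives as the python
# float 1.5 rather than decimal.Decimal("1.5"); NULLs still map to None.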
from gavo.utils import pyfits
psycopg2.extensions.register_adapter(pyfits.Undefined, NULLAdapter)
from psycopg2 import (OperationalError, #noflake: exported names
DatabaseError, IntegrityError, ProgrammingError,
InterfaceError, DataError, InternalError)
from psycopg2.extensions import QueryCanceledError #noflake: exported name
from psycopg2 import Error as DBError
class DebugCursor(psycopg2.extensions.cursor): # pragma: no cover
	def execute(self, sql, args=None):
print("Executing %s %s"%(id(self.connection), sql))
psycopg2.extensions.cursor.execute(self, sql, args)
print("Finished %s %s"%(id(self.connection), self.query.decode("utf-8")))
return self.rowcount
	def executemany(self, sql, args=[]):
print("Executing many", sql)
print(("%d args, first one:\n%s"%(len(args), args[0])))
res = psycopg2.extensions.cursor.executemany(self, sql, args)
print("Finished many", self.query.decode("utf-8"))
return res
class GAVOConnection(psycopg2.extensions.connection):
"""A psycopg2 connection with some additional methods.
This derivation is also done so we can attach the getDBConnection
arguments to the connection; it is used when recovering from
a database restart.
"""
	# extensionFunctions is filled by senseEnvironment (and contains names of
	# postgres extension functions that might modify our behaviour)
extensionFunctions = []
	@classmethod
def senseEnvironment(cls, conn):
"""configures us depending on what is in the database.
The argument needs to be a connection to the database we will
		connect to. In practice, initPsycopg calls this once during
		DaCHS startup.
"""
cls.extensionFunctions = frozenset(r[0] for r in
conn.query(
"SELECT proname FROM pg_proc WHERE"
" proname in ('epoch_prop', 'q3c_ang2ipix',"
" 'smoc_union', 'healpix_nest')"))
	def getParameter(self, key, cursor=None):
"""returns the value of the postgres parameter key.
This returns unprocessed values, probably almost always as strings.
Caveat emptor.
		The main purpose of this function is to help the parameters context
		manager, so users shouldn't really mess with it.
"""
cursor = cursor or self.cursor()
		if not re.match("[A-Za-z_]+$", key):
raise ValueError("Invalid settings key: %s"%key)
cursor.execute("SHOW %s"%key)
return list(cursor)[0][0]
	@contextlib.contextmanager
def parameters(self, settings, cursor=None):
"""executes a block with a certain set of parameters on a connection,
resetting them to their original value again afterwards.
		Of course, this only works as expected if you're not sharing your
		connections too widely.
This rolls back the connection by itself on database errors; we couldn't
reset the parameters otherwise.
"""
cursor = cursor or self.cursor()
resetTo = self.configure(settings, cursor)
try:
yield
except Exception as ex:
try:
if isinstance(ex, psycopg2.Error):
self.rollback()
self.configure(resetTo)
except psycopg2.Error:
# we believe the connection was already closed and don't bother
pass
raise
self.configure(resetTo, cursor)
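	# A usage sketch for parameters() (statement_timeout is a standard
	# postgres setting; the query is made up):
	#
	#   with conn.parameters([("statement_timeout", "5s")]):
	#       rows = list(conn.query("SELECT * FROM mytable"))
	#
	# After the block, statement_timeout is back at its previous value.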
	def queryToDicts(self, query, args={}, timeout=None, caseFixer=None):
"""iterates over dictionary rows for query.
This is a thin wrapper around query(yieldDicts=True) provided
for convenience and backwards compatibility.
"""
return self.query(query, args, timeout, True, caseFixer)
	def query(self, query, args={}, timeout=None,
yieldDicts=False, caseFixer=None):
"""iterates over result tuples for query.
This is mainly for ad-hoc queries needing little metadata.
You can pass yieldDicts=True to get dictionaries instead of tuples.
The dictionary keys are determined by what the database says the
column titles are; thus, it's usually lower-cased variants of what's
in the select-list. To fix this, you can pass in a caseFixer dict
that gives a properly cased version of lowercase names.
Timeout is in seconds.
Warning: this is an iterator, so unless you iterate over the result,
the query will not get executed. Hence, for non-select statements
you will generally have to use conn.execute.
"""
cursor = self.cursor()
params = []
if timeout is not None:
params.append(("statement_timeout", "%s ms"%int(timeout*1000)))
try:
with self.parameters(params, cursor):
cursor.execute(query, args)
if yieldDicts:
keys = [cd[0] for cd in cursor.description]
if caseFixer:
keys = [caseFixer.get(key, key) for key in keys]
for row in cursor:
						yield dict(zip(keys, row))
else:
for row in cursor:
yield row
finally:
cursor.close()
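	# A usage sketch for query() (the looked-up name is made up):
	#
	#   for row in conn.query("SELECT oid, relname FROM pg_class"
	#           " WHERE relname=%(n)s", {"n": "mytable"},
	#           timeout=5, yieldDicts=True):
	#       print(row["relname"])
	#
	# Nothing is executed until the iteration actually starts.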
	def execute(self, query, args={}):
"""executes query in a cursor.
This returns the rowcount of the cursor used.
"""
cursor = self.cursor()
try:
cursor.execute(query, args)
return cursor.rowcount
finally:
cursor.close()
	@contextlib.contextmanager
def savepoint(self):
"""sets up a section protected by a savepoint that will be released
after use.
If an exception happens in the controlled section, the connection
will be rolled back to the savepoint.
"""
savepointName = "auto_%s"%(random.randint(0, 2147483647))
self.execute("SAVEPOINT %s"%savepointName)
try:
yield
except:
self.execute("ROLLBACK TO SAVEPOINT %s"%savepointName)
raise
finally:
self.execute("RELEASE SAVEPOINT %s"%savepointName)
class DebugConnection(GAVOConnection): # pragma: no cover
	def cursor(self, *args, **kwargs):
kwargs["cursor_factory"] = DebugCursor
return psycopg2.extensions.connection.cursor(self, *args, **kwargs)
	def commit(self):
print("Commit %s"%id(self))
return GAVOConnection.commit(self)
	def rollback(self):
print("Rollback %s"%id(self))
return GAVOConnection.rollback(self)
	def getPID(self):
cursor = self.cursor()
cursor.execute("SELECT pg_backend_pid()")
pid = list(cursor)[0][0]
cursor.close()
return pid
def getDBConnection(profile, debug=debug, autocommitted=False):
"""returns an enhanced database connection through profile.
You will typically rather use the context managers for the standard
profiles (``getTableConnection`` and friends). Use this function if
you want to keep your connection out of connection pools or if you want
to use non-standard profiles.
profile will usually be a string naming a profile defined in
``GAVO_ROOT/etc``.
"""
if isinstance(profile, str):
profile = config.getDBProfile(profile)
if debug: # pragma: no cover
conn = psycopg2.connect(connection_factory=DebugConnection,
**profile.getArgs())
print("NEW CONN using %s (%s)"%(profile.name, conn.getPID()), id(conn))
def closer():
print("CONNECTION CLOSE", id(conn))
return DebugConnection.close(conn)
conn.close = closer
else:
try:
conn = psycopg2.connect(connection_factory=GAVOConnection,
**profile.getArgs())
except OperationalError as msg:
raise utils.ReportableError("Cannot connect to the database server."
" The database library reported:\n\n%s"%str(msg),
hint="This usually means you must adapt either the access profiles"
" in $GAVO_DIR/etc or your database config (in particular,"
" pg_hba.conf).")
if autocommitted:
conn.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)
conn.set_client_encoding("UTF8")
conn._getDBConnectionArgs = {
"profile": profile,
"debug": debug,
"autocommitted": autocommitted}
return conn
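# A usage sketch (hedged; "trustedquery" is one of the standard DaCHS
# profiles):
#
#   conn = getDBConnection("trustedquery")
#   try:
#       print(list(conn.query("SELECT version()"))[0][0])
#   finally:
#       conn.close()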
def _parseTableName(tableName, schema=None):
"""returns schema, unqualified table name for the arguments.
schema=None selects the default schema (public for postgresql).
If tableName is qualified (i.e. schema.table), the schema given
in the name overrides the schema argument.
We do not support delimited identifiers for tables in DaCHS. Hence,
this will raise a ValueError if anything that wouldn't work as
an SQL regular identifier (except we don't filter for reserved
words yet, which is an implementation detail that might change).
"""
parts = tableName.split(".")
if len(parts)>2:
raise ValueError("%s is not a SQL regular identifier"%repr(tableName))
for p in parts:
if not utils.identifierPattern.match(p):
raise ValueError("%s is not a SQL regular identifier"%repr(tableName))
if len(parts)==1:
name = parts[0]
else:
schema, name = parts
if schema is None:
schema = "public"
return schema.lower(), name.lower()
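# For illustration: _parseTableName("MySchema.MyTable") returns
# ("myschema", "mytable"), _parseTableName("mytable") returns
# ("public", "mytable"), and _parseTableName("x;drop table y") raises
# a ValueError.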
def _parseBannerString(bannerString):
"""returns digits from a postgres server banner.
This hardcodes the response given by postgres 8 and raises a ValueError
if the expected format is not found.
"""
mat = re.match(r"PostgreSQL ([\d.]*)", bannerString)
if not mat:
raise ValueError("Cannot make out the Postgres server version from %s"%
repr(bannerString))
return tuple(int(s) for s in mat.group(1).split("."))
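# For illustration: a banner like
# "PostgreSQL 15.3 on x86_64-pc-linux-gnu, compiled by gcc ..."
# comes out as (15, 3).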
def getPgVersion(digits=2):
"""returns the version number of the postgres server executing
untrusted (ADQL) queries.
This is relatively expensive, as it will actually ask the server.
"""
with getUntrustedConn() as conn:
bannerString = list(conn.query("SELECT version()"))[0][0]
	return _parseBannerString(bannerString)[:digits]
class PostgresQueryMixin(object):
	"""is a mixin containing various useful queries that are postgres specific.
	This mixin expects a parent that mixes in QuerierMixin (which, for now,
	also mixes in PostgresQueryMixin, so you won't need to mix this in).
"""
	def getPrimaryIndexName(self, tableName):
"""returns the name of the index corresponding to the primary key on
(the unqualified) tableName.
"""
return ("%s_pkey"%tableName).lower()
	def schemaExists(self, schema):
"""returns True if the named schema exists in the database.
"""
matches = list(self.connection.query("SELECT nspname FROM"
" pg_namespace WHERE LOWER(nspname)=%(schemaName)s", {
'schemaName': schema.lower(),
}))
return len(matches)!=0
	def hasIndex(self, tableName, indexName, schema=None):
		"""returns True if table tableName has an index called indexName.
See _parseTableName on the meaning of the arguments.
"""
schema, tableName = _parseTableName(tableName, schema)
res = list(self.connection.query("SELECT indexname FROM"
" pg_indexes WHERE schemaname=lower(%(schema)s) AND"
" tablename=lower(%(tableName)s) AND"
" indexname=lower(%(indexName)s)", locals()))
return len(list(res))>0
def _getColIndices(self, relOID, colNames):
"""returns a sorted tuple of column indices of colNames in the relation
relOID.
This really is a helper for foreignKeyExists.
"""
colNames = set(n.lower() for n in colNames)
res = [r[0] for r in
self.connection.query("SELECT attnum FROM pg_attribute WHERE"
" attrelid=%(relOID)s and attname IN %(colNames)s",
locals())]
res.sort()
return res
	def getForeignKeyName(self, srcTableName, destTableName, srcColNames,
			destColNames, schema=None):
		"""returns the name of the foreign key constraint on srcTable's
		srcColNames referencing destTableName's destColNames.
		Warning: names in XColNames that are not column names in the respective
		tables are ignored.
		This raises a ValueError if no matching foreign key exists.
		"""
try:
srcOID = self.getOIDForTable(srcTableName, schema)
srcColInds = self._getColIndices( #noflake: used in locals()
srcOID, srcColNames)
destOID = self.getOIDForTable(destTableName, schema)
destColInds = self._getColIndices( #noflake: used in locals()
destOID, destColNames)
except Error: # Some of the items related probably don't exist
return False
res = list(self.connection.query("""SELECT conname FROM pg_constraint WHERE
contype='f'
AND conrelid=%(srcOID)s
AND confrelid=%(destOID)s
AND conkey=%(srcColInds)s::SMALLINT[]
AND confkey=%(destColInds)s::SMALLINT[]""", locals()))
if len(res)==1:
return res[0][0]
else:
raise ValueError("Non-existing or ambiguous foreign key")
	def foreignKeyExists(self, srcTableName, destTableName, srcColNames,
destColNames, schema=None):
try:
_ = self.getForeignKeyName( #noflake: ignored value
srcTableName, destTableName,
srcColNames, destColNames,
schema)
return True
except ValueError:
return False
@functools.lru_cache()
def _resolveTypeCode(self, oid):
"""returns a textual description for a type oid as returned
by cursor.description.
These descriptions are *not* DDL-ready. There's the
*** postgres specific ***
"""
res = list(self.connection.query(
"select typname from pg_type where oid=%(oid)s", {"oid": oid}))
return res[0][0]
	def getColumnsFromDB(self, tableName):
"""returns a sequence of (name, type) pairs of the columns this
table has in the database.
If the table is not on disk, this will raise a NotFoundError.
*** psycopg2 specific ***
"""
# _parseTableName bombs out on non-regular identifiers, hence
# foiling a possible SQL injection
_parseTableName(tableName)
cursor = self.connection.cursor()
try:
cursor.execute("select * from %s limit 0"%tableName)
return [(col.name, self._resolveTypeCode(col.type_code)) for col in
cursor.description]
finally:
cursor.close()
	def getRowEstimate(self, tableName):
"""returns the size of the table in rows as estimated by the query
planner.
This will raise a KeyError with tableName if the table isn't known
to postgres.
"""
res = list(self.connection.query(
"SELECT reltuples FROM pg_class WHERE oid = %(tableName)s::regclass",
locals()))
# this is guaranteed to return something because of the ::regclass
# cast that will fail for non-existing tables.
return int(res[0][0])
	def roleExists(self, role):
		"""returns True if the role is known to the database.
		"""
matches = list(self.connection.query(
"SELECT usesysid FROM pg_user WHERE usename=%(role)s",
locals()))
return len(matches)!=0
	def getOIDForTable(self, tableName, schema=None):
"""returns the current oid of tableName.
tableName may be schema qualified. If it is not, public is assumed.
"""
schema, tableName = _parseTableName(tableName, schema)
return list(self.connection.query(
"SELECT %(tableName)s::regclass::bigint",
{"tableName": f"{schema}.{tableName}"}))[0][0]
	def getTableType(self, tableName, schema=None):
		"""returns the type of the relation tableName.
		If tableName does not exist, None is returned. Otherwise, the result
		is what the information schema gives for the table, which for postgres
		currently is one of BASE TABLE, VIEW, FOREIGN TABLE, MATERIALIZED VIEW,
		or LOCAL TEMPORARY.
The DaCHS-idiomatic way to see if a relation exists is
getTableType() is not None.
You can pass in schema-qualified relation names, or the relation name
and the schema separately.
*** postgres specific ***
"""
schema, tableName = _parseTableName(tableName, schema)
res = list(
self.connection.query("""SELECT table_name, table_type FROM
information_schema.tables WHERE (
table_schema=%(schemaName)s
OR table_type='LOCAL TEMPORARY')
AND table_name=%(tableName)s""", {
'tableName': tableName.lower(),
'schemaName': schema.lower()}))
if not res:
# materialised views are not yet in information_schema.tables,
# so we try again with a special postgres case.
if list(self.connection.query(
"select table_name from information_schema.tables"
" where table_name='pg_matviews'")):
res = list(
self.connection.query(
"""SELECT matviewname, 'MATERIALIZED VIEW' AS table_type
FROM pg_matviews
WHERE
schemaname=%(schemaName)s
AND matviewname=%(tableName)s""", {
'tableName': tableName.lower(),
'schemaName': schema.lower()}))
if not res:
return None
assert len(res)==1
return res[0][1]
	def dropTable(self, tableName, cascade=False):
"""drops a table or view named by tableName.
This does not raise an error if no such relation exists.
*** postgres specific ***
"""
tableType = self.getTableType(tableName)
if tableType is None:
return
dropQualification = {
"VIEW": "VIEW",
"MATERIALIZED VIEW": "MATERIALIZED VIEW",
"FOREIGN TABLE": "FOREIGN TABLE",
"BASE TABLE": "TABLE",
"LOCAL TEMPORARY": "TABLE"}[tableType]
self.connection.execute("DROP %s %s %s"%(
dropQualification,
tableName,
"CASCADE" if cascade else ""))
	def getSchemaPrivileges(self, schema):
"""returns (owner, readRoles, allRoles) for schema's ACL.
"""
res = list(self.connection.query("SELECT nspacl FROM pg_namespace WHERE"
" nspname=%(schema)s", locals()))
return self.parsePGACL(res[0][0])
	def getTablePrivileges(self, schema, tableName):
"""returns (owner, readRoles, allRoles) for the relation tableName
and the schema.
*** postgres specific ***
"""
res = list(self.connection.query("SELECT relacl FROM pg_class WHERE"
" lower(relname)=lower(%(tableName)s) AND"
" relnamespace=(SELECT oid FROM pg_namespace WHERE nspname=%(schema)s)",
locals()))
try:
return self.parsePGACL(res[0][0])
except IndexError: # Table doesn't exist, so no privileges
return {}
_privTable = {
"arwdRx": "ALL",
"arwdDxt": "ALL",
"arwdRxt": "ALL",
"arwdxt": "ALL",
"r": "SELECT",
"UC": "ALL",
"U": "USAGE",
}
	def parsePGACL(self, acl):
"""returns a dict roleName->acl for acl in postgres'
ACL serialization.
"""
if acl is None:
return {}
res = []
for acs in re.match("{(.*)}", acl).group(1).split(","):
if acs!='': # empty ACLs don't match the RE, so catch them here
role, privs, granter = re.match("([^=]*)=([^/]*)/(.*)", acs).groups()
res.append((role, self._privTable.get(privs, "READ")))
return dict(res)
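	# For illustration (role names made up): an ACL string like
	# "{gavo=arwdDxt/gavo,untrusted=r/gavo}" parses to
	# {"gavo": "ALL", "untrusted": "SELECT"}.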
	def getACLFromRes(self, thingWithPrivileges):
"""returns a dict of (role, ACL) as it is defined in thingWithPrivileges.
thingWithPrivileges is something mixing in rscdef.common.PrivilegesMixin.
(or has readProfiles and allProfiles attributes containing
sequences of profile names).
"""
res = []
if hasattr(thingWithPrivileges, "schema"): # it's an RD
readRight = "USAGE"
else:
readRight = "SELECT"
for profile in thingWithPrivileges.readProfiles:
res.append((config.getDBProfile(profile).roleName, readRight))
for profile in thingWithPrivileges.allProfiles:
res.append((config.getDBProfile(profile).roleName, "ALL"))
return dict(res)
class StandardQueryMixin(object):
	"""is a mixin containing various useful queries that should work
	against all SQL systems.
	This mixin expects a parent that mixes in QuerierMixin (which, for now,
	also mixes in StandardQueryMixin, so you won't need to mix this in).
	The parent also needs to mix in something like PostgresQueryMixin (I
	might want to define an interface there once I'd like to support
	other databases).
"""
	def setSchemaPrivileges(self, rd):
"""sets the privileges defined on rd to its schema.
This function will never touch the public schema.
"""
schema = rd.schema.lower()
if schema=="public":
return
self._updatePrivileges("SCHEMA %s"%schema,
self.getSchemaPrivileges(schema), self.getACLFromRes(rd))
	def setTablePrivileges(self, tableDef):
		"""sets the privileges defined in tableDef for that table using
		this querier's connection.
"""
self._updatePrivileges(tableDef.getQName(),
self.getTablePrivileges(tableDef.rd.schema, tableDef.id),
self.getACLFromRes(tableDef))
def _updatePrivileges(self, objectName, foundPrivs, shouldPrivs):
"""is a helper for set[Table|Schema]Privileges.
Requests for granting privileges not known to the database are
ignored, but a log entry is generated.
"""
for role in set(foundPrivs)-set(shouldPrivs):
if role:
self.connection.execute("REVOKE ALL PRIVILEGES ON %s FROM %s"%(
objectName, role))
for role in set(shouldPrivs)-set(foundPrivs):
if role:
if self.roleExists(role):
self.connection.execute(
"GRANT %s ON %s TO %s"%(shouldPrivs[role], objectName, role))
else:
utils.sendUIEvent("Warning",
"Request to grant privileges to non-existing"
" database user %s dropped"%role)
for role in set(shouldPrivs)&set(foundPrivs):
if role:
if shouldPrivs[role]!=foundPrivs[role]:
self.connection.execute("REVOKE ALL PRIVILEGES ON %s FROM %s"%(
objectName, role))
self.connection.execute("GRANT %s ON %s TO %s"%(shouldPrivs[role], objectName,
role))
class QuerierMixin(PostgresQueryMixin, StandardQueryMixin):
"""is a mixin for "queriers", i.e., objects that maintain a db connection.
The mixin assumes an attribute connection from the parent.
"""
defaultProfile = None
# _reconnecting is used in query
_reconnecting = False
	def enableAutocommit(self):
self.connection.set_isolation_level(
psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)
def _queryReconnecting(self, query, data, timeout):
"""helps query in case of disconnections.
"""
self.connection = getDBConnection(
**self.connection._getDBConnectionArgs)
self._reconnecting = True
res = self.connection.query(query, data, timeout)
		self._reconnecting = False
return res
	def query(self, query, data={}, timeout=None):
"""wraps conn.query adding logic to re-establish lost connections.
Don't use this method any more in new code. It contains wicked
logic to tell DDL statements (that run without anyone pulling
the results) from actual selects. That's a bad API. Also note
that the timeout is ignored for DDL statements.
We'll drop this some time in 2023.
Use either connection.query or connection.execute in new code.
"""
warnings.warn("You are using querier.query (or perhaps table.query)."
" This has terrible semantics; use querier.connection.query"
" for statements returning rows and .execute for DDL statements.",
category=FutureWarning)
if self.connection is None:
raise utils.ReportableError(
"SimpleQuerier connection is None.",
hint="This usually is because an AdhocQuerier's query method"
" was used outside of a with block.")
try:
if query[:5].lower() in ["selec", "with "]:
return self.connection.query(query, data, timeout)
else:
# it's DDL that we execute directly, ignoring the timeout
self.connection.execute(query, data)
except DBError as ex:
if isinstance(ex, OperationalError) and self.connection.fileno()==-1:
if not self._reconnecting:
return self._queryReconnecting(query, data, timeout)
raise
	def queryToDicts(self, *args, **kwargs):
		"""wraps conn.queryToDicts for backwards compatibility.
		"""
return self.connection.queryToDicts(*args, **kwargs)
class UnmanagedQuerier(QuerierMixin):
"""A simple interface to querying the database through a connection
managed by someone else.
This is typically used as in::
with base.getTableConn() as conn:
q = UnmanagedQuerier(conn)
...
This contains numerous methods abstracting DB functionality a bit.
Documented ones include:
* schemaExists(schema)
* getColumnsFromDB(tableName)
* getTableType(tableName) -- this will return None for non-existing tables,
which is DaCHS' official way to determine table existence.
"""
def __init__(self, connection):
self.connection = connection
class AdhocQuerier(QuerierMixin):
"""A simple interface to querying the database through pooled
connections.
These are constructed using the connection getters (getTableConn (default),
getAdminConn) and then serve as context managers, handing back the connection
as you exit the controlled block.
Since they operate through pooled connections, no transaction
management takes place. These are typically for read-only things.
You can use the query method and everything that's in the QuerierMixin.
"""
def __init__(self, connectionManager=None):
if connectionManager is None:
self.connectionManager = getTableConn
else:
self.connectionManager = connectionManager
self.connection = None
def __enter__(self):
self._cm = self.connectionManager()
self.connection = self._cm.__enter__()
return self
def __exit__(self, *args):
self.connection = None
return self._cm.__exit__(*args)
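# A usage sketch for AdhocQuerier (the table name is made up):
#
#   with AdhocQuerier() as q:
#       if q.getTableType("public.mytable") is not None:
#           print(q.getColumnsFromDB("public.mytable"))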
class NonBlockingQuery:
	"""a query run in a pseudo-nonblocking way.
	While psycopg2 can do proper async, that doesn't play well with just
	about everything else DaCHS is doing so far. So, here's a quick
	way to allow long-running queries that users can still interrupt.
	The ugly secret is that it's based on threads.
This should not be used within the server. We might want to port
the async taprunner (which runs outside of the server) to using this,
though.
To use it, construct it with conn, query and perhaps args and
use it as a context manager.
Wait for its result attribute to become non-None; this will then
be either a list of result rows or an Exception (which will also be
raised when exiting the context manager).
To abort a running query, call abort().
"""
def __init__(self, conn, query, args={}):
self.conn, self.query, self.args = conn, query, args
		self.backendPID = list(self.conn.query("SELECT pg_backend_pid()"))[0][0]
# will be set only from the thread
self.result = None
def __enter__(self):
self.thread = threading.Thread(target=self._runQuery)
		self.thread.daemon = True
self.thread.start()
return self
def __exit__(self, *excInfo):
self.cleanup(1)
if excInfo==(None, None, None) and isinstance(self.result, Exception):
# this probably is a memory leak, which is one of the reasons
# this shouldn't be used in the server without more thought
raise self.result
return False
def _runQuery(self):
try:
self.result = list(self.conn.query(self.query, self.args))
except QueryCanceledError:
# assume this happened on user request
pass
except Exception as ex:
self.result = ex
	def abort(self):
"""aborts the current query and reaps the thread.
"""
self.conn.cancel()
self.cleanup(1)
	def cleanup(self, timeout=None):
"""tries to reap the thread (i.e., join it).
If the thread hasn't terminated within timeout seconds, a
sqlsupport.Error is raised.
"""
self.thread.join(timeout=timeout)
if self.thread.is_alive():
raise Error("Could not join NonBlockingQuery")
class CustomConnectionPool(psycopg2.pool.ThreadedConnectionPool):
"""A threaded connection pool that returns connections made via
profileName.
"""
# we keep weak references to pools we've created so we can invalidate
# them all on a server restart to avoid having stale connections
# around.
knownPools = []
def __init__(self, minconn, maxconn, profileName, autocommitted=True):
# make sure no additional arguments come in, since we don't
# support them.
self.profileName = profileName
self.autocommitted = autocommitted
self.stale = False
psycopg2.pool.ThreadedConnectionPool.__init__(
self, minconn, maxconn)
self.knownPools.append(weakref.ref(self))
	@classmethod
def serverRestarted(cls):
utils.sendUIEvent("Warning", "Suspecting a database restart."
" Discarding old connection pools, asking to create new ones.")
for pool in cls.knownPools:
try:
pool().stale = True
except AttributeError:
# already gone
pass
# we risk a race condition here; this is used rarely enough that this
# shouldn't matter.
cls.knownPools = []
def _connect(self, key=None):
"""creates a new connection with our selected profile and assigns it to
key if not None.
This is an implementation detail of psycopg2's connection
pools.
"""
conn = getDBConnection(self.profileName)
if self.autocommitted:
try:
conn.set_session(autocommit=True, readonly=True)
except ProgrammingError:
utils.sendUIEvent("Warning", "Uncommitted transaction escaped; please"
" investigate and fix")
conn.commit()
if key is not None:
self._used[key] = conn
self._rused[id(conn)] = key
else:
self._pool.append(conn)
return conn
def _cleanupAfterDBError(ex, conn, pool, poolLock):
"""removes conn from pool after an error occurred.
This is a helper for getConnFromPool below.
"""
if isinstance(ex, OperationalError) and ex.pgcode is None:
# this is probably a db server restart. Invalidate all connections
# immediately.
with poolLock:
if pool:
pool[0].serverRestarted()
# Make sure the connection is closed; something bad happened
# in it, so we don't want to re-use it
try:
pool[0].putconn(conn, close=True)
except InterfaceError:
# Connection already closed
pass
except Exception as msg:
utils.sendUIEvent("Error",
"Disaster: %s while force-closing connection"%msg)
def _makeConnectionManager(profileName, autocommitted=True, singleton=False):
"""returns a context manager for a connection pool for profileName
connections.
With singleton=True, only one connection will be created rather than
[db]poolSize ones.
"""
pool = []
poolLock = threading.Lock()
def makePool():
if singleton:
minConn = 1
else:
minConn = config.get("db", "poolSize")
with poolLock:
pool.append(CustomConnectionPool(
minConn,
# I don't think there's any point in maxConn at all the
# way psycopg pools are done right now, so all I care about
# here is that it won't get in our way.
200,
profileName,
autocommitted))
def getConnFromPool():
# we delay pool creation since these functions are built during
# sqlsupport import. We probably don't have profiles ready
# at that point.
if not pool:
makePool()
if pool[0].stale:
pool[0].closeall()
pool.pop()
makePool()
conn = pool[0].getconn()
try:
yield conn
except Exception as ex:
# controlled block bombed out, do error handling
_cleanupAfterDBError(ex, conn, pool, poolLock)
raise
else:
# no exception raised, commit if not autocommitted
if not autocommitted:
conn.commit()
try:
pool[0].putconn(conn, close=conn.closed)
except InterfaceError:
# Connection already closed
pass
return contextlib.contextmanager(getConnFromPool)
getUntrustedConn = _makeConnectionManager("untrustedquery")
getTableConn = _makeConnectionManager("trustedquery")
getAdminConn = _makeConnectionManager("admin", singleton=True)
getWritableUntrustedConn = _makeConnectionManager("untrustedquery",
autocommitted=False, singleton=True)
getWritableTableConn = _makeConnectionManager("trustedquery",
autocommitted=False, singleton=True)
getWritableAdminConn = _makeConnectionManager("admin",
autocommitted=False, singleton=True)
def initPsycopg():
	"""does any DaCHS-specific database setup necessary.
	This is executed on sqlsupport import unless we are in initdachs (or
	setting up the testbed); see the foot of this module for how this is done.
	This needs to call GAVOConnection.senseEnvironment.
"""
conn = psycopg2.connect(connection_factory=GAVOConnection,
**config.getDBProfile("feed").getArgs())
try:
try:
from gavo.utils import pgsphere
pgsphere.preparePgSphere(conn)
		except Exception: # pragma: no cover
warnings.warn("pgsphere missing -- ADQL, pg-SIAP, and SSA will not work")
GAVOConnection.senseEnvironment(conn)
finally:
conn.close()
if "GAVO_INIT_RUNNING" not in os.environ:
initPsycopg()