Mirror of https://github.com/djohnlewis/stackdump, synced 2025-12-08 08:53:25 +00:00
Initial commit. Still building up the env and some parsing code.
python/packages/sqlobject/postgres/pgconnection.py (new normal file, 359 lines added)
@@ -0,0 +1,359 @@
from sqlobject.dbconnection import DBAPI
import re
from sqlobject import col
from sqlobject import sqlbuilder
from sqlobject.converters import registerConverter

class PostgresConnection(DBAPI):

    supportTransactions = True
    dbName = 'postgres'
    schemes = [dbName, 'postgresql']

    def __init__(self, dsn=None, host=None, port=None, db=None,
                 user=None, password=None, **kw):
        drivers = kw.pop('driver', None) or 'psycopg'
        for driver in drivers.split(','):
            driver = driver.strip()
            if not driver:
                continue
            try:
                if driver == 'psycopg2':
                    import psycopg2 as psycopg
                elif driver == 'psycopg1':
                    import psycopg
                elif driver == 'psycopg':
                    try:
                        import psycopg2 as psycopg
                    except ImportError:
                        import psycopg
                elif driver == 'pygresql':
                    import pgdb
                    self.module = pgdb
                else:
                    raise ValueError('Unknown PostgreSQL driver "%s", expected psycopg2, psycopg1 or pygresql' % driver)
            except ImportError:
                pass
            else:
                break
        else:
            raise ImportError('Cannot find a PostgreSQL driver, tried %s' % drivers)
        if driver.startswith('psycopg'):
            self.module = psycopg
            # Register a converter for psycopg Binary type.
            registerConverter(type(psycopg.Binary('')),
                              PsycoBinaryConverter)

        self.user = user
        self.host = host
        self.port = port
        self.db = db
        self.password = password
        self.dsn_dict = dsn_dict = {}
        if host:
            dsn_dict["host"] = host
        if port:
            if driver == 'pygresql':
                dsn_dict["host"] = "%s:%d" % (host, port)
            else:
                if psycopg.__version__.split('.')[0] == '1':
                    dsn_dict["port"] = str(port)
                else:
                    dsn_dict["port"] = port
        if db:
            dsn_dict["database"] = db
        if user:
            dsn_dict["user"] = user
        if password:
            dsn_dict["password"] = password
        sslmode = kw.pop("sslmode", None)
        if sslmode:
            dsn_dict["sslmode"] = sslmode
        self.use_dsn = dsn is not None
        if dsn is None:
            if driver == 'pygresql':
                dsn = ''
                if host:
                    dsn += host
                dsn += ':'
                if db:
                    dsn += db
                dsn += ':'
                if user:
                    dsn += user
                dsn += ':'
                if password:
                    dsn += password
            else:
                dsn = []
                if db:
                    dsn.append('dbname=%s' % db)
                if user:
                    dsn.append('user=%s' % user)
                if password:
                    dsn.append('password=%s' % password)
                if host:
                    dsn.append('host=%s' % host)
                if port:
                    dsn.append('port=%d' % port)
                if sslmode:
                    dsn.append('sslmode=%s' % sslmode)
                dsn = ' '.join(dsn)
        self.driver = driver
        self.dsn = dsn
        self.unicodeCols = kw.pop('unicodeCols', False)
        self.schema = kw.pop('schema', None)
        self.dbEncoding = kw.pop("charset", None)
        DBAPI.__init__(self, **kw)
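
    # Editor's note (illustrative, not part of the original file): with the
    # default psycopg driver and e.g. host='localhost', db='stackdump',
    # user='su', the code above builds a libpq-style DSN such as
    #   'dbname=stackdump user=su host=localhost'
    # whereas the pygresql branch builds the colon-separated form
    #   'localhost:stackdump:su:'.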

    @classmethod
    def _connectionFromParams(cls, user, password, host, port, path, args):
        path = path.strip('/')
        if (host is None) and path.count('/'):  # Non-default unix socket
            path_parts = path.split('/')
            host = '/' + '/'.join(path_parts[:-1])
            path = path_parts[-1]
        return cls(host=host, port=port, db=path, user=user, password=password, **args)
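
    # Editor's note (illustrative, not part of the original file): a URI such
    # as postgres://user:pass@host:5432/dbname would arrive here with
    # path='/dbname', while a non-default unix-socket URI with no host, e.g.
    # a path of '/var/run/postgresql/dbname', is split above so that
    # host='/var/run/postgresql' and db='dbname'.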

    def _setAutoCommit(self, conn, auto):
        # psycopg2 does not have an autocommit method.
        if hasattr(conn, 'autocommit'):
            conn.autocommit(auto)

    def makeConnection(self):
        try:
            if self.use_dsn:
                conn = self.module.connect(self.dsn)
            else:
                conn = self.module.connect(**self.dsn_dict)
        except self.module.OperationalError, e:
            raise self.module.OperationalError("%s; used connection string %r" % (e, self.dsn))
        if self.autoCommit:
            # psycopg2 does not have an autocommit method.
            if hasattr(conn, 'autocommit'):
                conn.autocommit(1)
        c = conn.cursor()
        if self.schema:
            c.execute("SET search_path TO " + self.schema)
        dbEncoding = self.dbEncoding
        if dbEncoding:
            c.execute("SET client_encoding TO '%s'" % dbEncoding)
        return conn

    def _queryInsertID(self, conn, soInstance, id, names, values):
        table = soInstance.sqlmeta.table
        idName = soInstance.sqlmeta.idName
        sequenceName = soInstance.sqlmeta.idSequence or \
                       '%s_%s_seq' % (table, idName)
        c = conn.cursor()
        if id is None:
            c.execute("SELECT NEXTVAL('%s')" % sequenceName)
            id = c.fetchone()[0]
        names = [idName] + names
        values = [id] + values
        q = self._insertSQL(table, names, values)
        if self.debug:
            self.printDebug(conn, q, 'QueryIns')
        c.execute(q)
        if self.debugOutput:
            self.printDebug(conn, id, 'QueryIns', 'result')
        return id
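
    # Editor's note (illustrative, not part of the original file): when no
    # explicit id is given, the id is pre-fetched from the table's sequence
    # (e.g. SELECT NEXTVAL('person_id_seq') for a hypothetical 'person' table)
    # and then included in the INSERT built by _insertSQL, so the new id is
    # known without relying on RETURNING.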

    @classmethod
    def _queryAddLimitOffset(cls, query, start, end):
        if not start:
            return "%s LIMIT %i" % (query, end)
        if not end:
            return "%s OFFSET %i" % (query, start)
        return "%s LIMIT %i OFFSET %i" % (query, end-start, start)

    def createColumn(self, soClass, col):
        return col.postgresCreateSQL()

    def createReferenceConstraint(self, soClass, col):
        return col.postgresCreateReferenceConstraint()

    def createIndexSQL(self, soClass, index):
        return index.postgresCreateIndexSQL(soClass)

    def createIDColumn(self, soClass):
        key_type = {int: "SERIAL", str: "TEXT"}[soClass.sqlmeta.idType]
        return '%s %s PRIMARY KEY' % (soClass.sqlmeta.idName, key_type)
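
    # Editor's note (illustrative, not part of the original file): for an
    # integer idType and the default id name this emits DDL like
    # 'id SERIAL PRIMARY KEY'; a str idType would produce 'id TEXT PRIMARY KEY'.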

    def dropTable(self, tableName, cascade=False):
        self.query("DROP TABLE %s %s" % (tableName,
                                         cascade and 'CASCADE' or ''))

    def joinSQLType(self, join):
        return 'INT NOT NULL'

    def tableExists(self, tableName):
        result = self.queryOne("SELECT COUNT(relname) FROM pg_class WHERE relname = %s"
                               % self.sqlrepr(tableName))
        return result[0]

    def addColumn(self, tableName, column):
        self.query('ALTER TABLE %s ADD COLUMN %s' %
                   (tableName,
                    column.postgresCreateSQL()))

    def delColumn(self, sqlmeta, column):
        self.query('ALTER TABLE %s DROP COLUMN %s' % (sqlmeta.table, column.dbName))

    def columnsFromSchema(self, tableName, soClass):

        keyQuery = """
        SELECT pg_catalog.pg_get_constraintdef(oid) as condef
        FROM pg_catalog.pg_constraint r
        WHERE r.conrelid = %s::regclass AND r.contype = 'f'"""

        colQuery = """
        SELECT a.attname,
            pg_catalog.format_type(a.atttypid, a.atttypmod), a.attnotnull,
            (SELECT substring(d.adsrc for 128) FROM pg_catalog.pg_attrdef d
             WHERE d.adrelid = a.attrelid AND d.adnum = a.attnum)
        FROM pg_catalog.pg_attribute a
        WHERE a.attrelid = %s::regclass
        AND a.attnum > 0 AND NOT a.attisdropped
        ORDER BY a.attnum"""

        primaryKeyQuery = """
        SELECT pg_index.indisprimary,
            pg_catalog.pg_get_indexdef(pg_index.indexrelid)
        FROM pg_catalog.pg_class c, pg_catalog.pg_class c2,
            pg_catalog.pg_index AS pg_index
        WHERE c.relname = %s
            AND c.oid = pg_index.indrelid
            AND pg_index.indexrelid = c2.oid
            AND pg_index.indisprimary
        """

        keyData = self.queryAll(keyQuery % self.sqlrepr(tableName))
        keyRE = re.compile(r"\((.+)\) REFERENCES (.+)\(")
        keymap = {}

        for (condef,) in keyData:
            match = keyRE.search(condef)
            if match:
                field, reftable = match.groups()
                keymap[field] = reftable.capitalize()

        primaryData = self.queryAll(primaryKeyQuery % self.sqlrepr(tableName))
        primaryRE = re.compile(r'CREATE .*? USING .* \((.+?)\)')
        primaryKey = None
        for isPrimary, indexDef in primaryData:
            match = primaryRE.search(indexDef)
            assert match, "Unparseable constraint definition: %r" % indexDef
            assert primaryKey is None, "Already found primary key (%r), then found: %r" % (primaryKey, indexDef)
            primaryKey = match.group(1)
        assert primaryKey, "No primary key found in table %r" % tableName
        if primaryKey.startswith('"'):
            assert primaryKey.endswith('"')
            primaryKey = primaryKey[1:-1]

        colData = self.queryAll(colQuery % self.sqlrepr(tableName))
        results = []
        if self.unicodeCols:
            client_encoding = self.queryOne("SHOW client_encoding")[0]
        for field, t, notnull, defaultstr in colData:
            if field == primaryKey:
                continue
            if field in keymap:
                colClass = col.ForeignKey
                kw = {'foreignKey': soClass.sqlmeta.style.dbTableToPythonClass(keymap[field])}
                name = soClass.sqlmeta.style.dbColumnToPythonAttr(field)
                if name.endswith('ID'):
                    name = name[:-2]
                kw['name'] = name
            else:
                colClass, kw = self.guessClass(t)
                if self.unicodeCols and colClass is col.StringCol:
                    colClass = col.UnicodeCol
                    kw['dbEncoding'] = client_encoding
                kw['name'] = soClass.sqlmeta.style.dbColumnToPythonAttr(field)
            kw['dbName'] = field
            kw['notNone'] = notnull
            if defaultstr is not None:
                kw['default'] = self.defaultFromSchema(colClass, defaultstr)
            elif not notnull:
                kw['default'] = None
            results.append(colClass(**kw))
        return results
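
    # Editor's note (illustrative, not part of the original file): the method
    # above reverse-engineers a table from the system catalogs: foreign keys
    # come from pg_constraint (contype 'f'), the primary key from pg_index,
    # and the remaining columns from pg_attribute, which are then mapped to
    # SQLObject column classes via guessClass().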

    def guessClass(self, t):
        if t.count('point'):  # poINT before INT
            return col.StringCol, {}
        elif t.count('int'):
            return col.IntCol, {}
        elif t.count('varying') or t.count('varchar'):
            if '(' in t:
                return col.StringCol, {'length': int(t[t.index('(')+1:-1])}
            else:  # varchar without length in Postgres means any length
                return col.StringCol, {}
        elif t.startswith('character('):
            return col.StringCol, {'length': int(t[t.index('(')+1:-1]),
                                   'varchar': False}
        elif t.count('float') or t.count('real') or t.count('double'):
            return col.FloatCol, {}
        elif t == 'text':
            return col.StringCol, {}
        elif t.startswith('timestamp'):
            return col.DateTimeCol, {}
        elif t.startswith('datetime'):
            return col.DateTimeCol, {}
        elif t.startswith('date'):
            return col.DateCol, {}
        elif t.startswith('bool'):
            return col.BoolCol, {}
        elif t.startswith('bytea'):
            return col.BLOBCol, {}
        else:
            return col.Col, {}
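
    # Editor's note (illustrative, not part of the original file): guessClass
    # works on the type names returned by pg_catalog.format_type, e.g.
    #   'integer'                  -> col.IntCol
    #   'character varying(40)'    -> col.StringCol with length=40
    #   'timestamp with time zone' -> col.DateTimeCol
    # with col.Col as the catch-all fallback.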

    def defaultFromSchema(self, colClass, defaultstr):
        """
        If the default can be converted to a Python constant, convert it.
        Otherwise return it as a sqlbuilder constant.
        """
        if colClass == col.BoolCol:
            if defaultstr == 'false':
                return False
            elif defaultstr == 'true':
                return True
        return getattr(sqlbuilder.const, defaultstr)
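
    # Editor's note (illustrative, not part of the original file): boolean
    # defaults become Python True/False; any other default string (for
    # instance a server-side expression like now()) is wrapped as a
    # sqlbuilder constant so it is passed through to the database unchanged.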

    def _createOrDropDatabase(self, op="CREATE"):
        # We have to connect to *some* database, so we'll connect to
        # template1, which is a common open database.
        # @@: This doesn't use self.use_dsn or self.dsn_dict
        if self.driver == 'pygresql':
            dsn = '%s:template1:%s:%s' % (
                self.host or '', self.user or '', self.password or '')
        else:
            dsn = 'dbname=template1'
            if self.user:
                dsn += ' user=%s' % self.user
            if self.password:
                dsn += ' password=%s' % self.password
            if self.host:
                dsn += ' host=%s' % self.host
        conn = self.module.connect(dsn)
        cur = conn.cursor()
        # We must close the transaction with a commit so that
        # the CREATE DATABASE can work (which can't be in a transaction):
        cur.execute('COMMIT')
        cur.execute('%s DATABASE %s' % (op, self.db))
        cur.close()
        conn.close()
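
    # Editor's note (illustrative, not part of the original file): the
    # connection targets template1 because CREATE/DROP DATABASE cannot run
    # against the database being created or dropped, and the explicit COMMIT
    # closes the driver's implicit transaction first, since Postgres refuses
    # to run CREATE DATABASE inside a transaction block.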

    def createEmptyDatabase(self):
        self._createOrDropDatabase()

    def dropDatabase(self):
        self._createOrDropDatabase(op="DROP")


# Converter for psycopg Binary type.
def PsycoBinaryConverter(value, db):
    assert db == 'postgres'
    return str(value)
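

# Editor's usage sketch (illustrative, not part of the original file): since
# this class registers the 'postgres' and 'postgresql' URI schemes, a typical
# way to obtain a connection through SQLObject would be roughly:
#
#   from sqlobject import connectionForURI, sqlhub
#   sqlhub.processConnection = connectionForURI(
#       'postgres://user:password@localhost:5432/stackdump?driver=psycopg2')
#
# The 'stackdump' database name and the credentials here are placeholders.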