from sqlobject.dbconnection import DBAPI
import re
from sqlobject import col
from sqlobject import sqlbuilder
from sqlobject.converters import registerConverter


class PostgresConnection(DBAPI):
    """SQLObject connection class for PostgreSQL.

    Works with the psycopg2, psycopg (version 1) and PyGreSQL (``pgdb``)
    DB-API drivers; the driver is chosen in ``__init__`` from the
    ``driver`` keyword argument.
    """

    supportTransactions = True
    dbName = 'postgres'
    schemes = [dbName, 'postgresql']

    def __init__(self, dsn=None, host=None, port=None, db=None,
                 user=None, password=None, **kw):
        """Select a driver and collect the connection parameters.

        ``kw['driver']`` is a comma-separated list of driver names tried in
        order (default ``'psycopg'``): ``psycopg2``, ``psycopg1``,
        ``psycopg`` (psycopg2, falling back to psycopg 1) or ``pygresql``.
        The first one that imports successfully is used.

        If *dsn* is given it is later passed verbatim to the driver's
        ``connect()``; otherwise the keyword dict ``self.dsn_dict`` is used
        (``self.dsn`` is still built, and is shown in connection errors).

        Extra keywords consumed here: ``sslmode``, ``unicodeCols``,
        ``schema`` and ``charset``.  The rest go to ``DBAPI.__init__``.

        Raises:
            ValueError: an unknown driver name is listed.
            ImportError: none of the listed drivers can be imported.
        """
        drivers = kw.pop('driver', None) or 'psycopg'
        for driver in drivers.split(','):
            driver = driver.strip()
            if not driver:
                continue
            try:
                if driver == 'psycopg2':
                    import psycopg2 as psycopg
                elif driver == 'psycopg1':
                    import psycopg
                elif driver == 'psycopg':
                    try:
                        import psycopg2 as psycopg
                    except ImportError:
                        import psycopg
                elif driver == 'pygresql':
                    import pgdb
                    self.module = pgdb
                else:
                    raise ValueError(
                        'Unknown PostgreSQL driver "%s", expected psycopg2, '
                        'psycopg1 or pygresql' % driver)
            except ImportError:
                pass
            else:
                break
        else:
            raise ImportError(
                'Cannot find a PostgreSQL driver, tried %s' % drivers)
        if driver.startswith('psycopg'):
            self.module = psycopg
            # Register a converter for psycopg Binary type.
            registerConverter(type(psycopg.Binary('')),
                              PsycoBinaryConverter)

        self.user = user
        self.host = host
        self.port = port
        self.db = db
        self.password = password

        # Keyword arguments for self.module.connect(**dsn_dict).
        self.dsn_dict = dsn_dict = {}
        if host:
            dsn_dict["host"] = host
        if port:
            if driver == 'pygresql':
                # pgdb receives the port inside its "host:port" string.
                # Formatting a missing host would yield "None:<port>";
                # an empty host part is used instead.
                # NOTE(review): assumes pgdb treats an empty host as the
                # default host -- confirm against the installed PyGreSQL.
                dsn_dict["host"] = "%s:%d" % (host or '', port)
            else:
                if psycopg.__version__.split('.')[0] == '1':
                    # psycopg 1 is given the port as a string.
                    dsn_dict["port"] = str(port)
                else:
                    dsn_dict["port"] = port
        if db:
            dsn_dict["database"] = db
        if user:
            dsn_dict["user"] = user
        if password:
            dsn_dict["password"] = password
        sslmode = kw.pop("sslmode", None)
        if sslmode:
            dsn_dict["sslmode"] = sslmode

        # Only an explicitly supplied DSN string is used for connecting.
        self.use_dsn = dsn is not None
        if dsn is None:
            if driver == 'pygresql':
                # pgdb's "host:database:user:password" form.
                dsn = '%s:%s:%s:%s' % (host or '', db or '',
                                       user or '', password or '')
            else:
                # libpq "key=value" form.
                parts = []
                if db:
                    parts.append('dbname=%s' % db)
                if user:
                    parts.append('user=%s' % user)
                if password:
                    parts.append('password=%s' % password)
                if host:
                    parts.append('host=%s' % host)
                if port:
                    parts.append('port=%d' % port)
                if sslmode:
                    parts.append('sslmode=%s' % sslmode)
                dsn = ' '.join(parts)
        self.driver = driver
        self.dsn = dsn
        self.unicodeCols = kw.pop('unicodeCols', False)
        self.schema = kw.pop('schema', None)
        self.dbEncoding = kw.pop("charset", None)
        DBAPI.__init__(self, **kw)

    @classmethod
    def _connectionFromParams(cls, user, password, host, port, path, args):
        """Build a connection from parsed URI components.

        When no host is given and *path* has several components, every
        component except the last is taken as a unix-socket directory
        (used as ``host``) and the last component as the database name.
        """
        path = path.strip('/')
        if (host is None) and path.count('/'):  # Non-default unix socket
            path_parts = path.split('/')
            host = '/' + '/'.join(path_parts[:-1])
            path = path_parts[-1]
        return cls(host=host, port=port, db=path, user=user,
                   password=password, **args)

    def _setAutoCommit(self, conn, auto):
        """Toggle autocommit on *conn* if its driver offers ``autocommit()``."""
        # psycopg2 does not have an autocommit method.
        if hasattr(conn, 'autocommit'):
            conn.autocommit(auto)

    def makeConnection(self):
        """Open a new DB-API connection and apply per-session settings.

        Connects with ``self.dsn`` when a DSN was supplied, otherwise with
        ``self.dsn_dict``.  Then sets ``search_path`` (``schema``) and
        ``client_encoding`` (``charset``) when they were configured.

        Raises:
            self.module.OperationalError: connecting failed.  The message
            is extended with the connection string.
        """
        try:
            if self.use_dsn:
                conn = self.module.connect(self.dsn)
            else:
                conn = self.module.connect(**self.dsn_dict)
        except self.module.OperationalError as e:
            # NOTE(review): self.dsn may contain "password=..." (built in
            # __init__), so the password can end up in this message.
            raise self.module.OperationalError(
                "%s; used connection string %r" % (e, self.dsn))
        if self.autoCommit:
            # psycopg2 does not have an autocommit method.
            if hasattr(conn, 'autocommit'):
                conn.autocommit(1)
        # The cursor only issues session-level SET commands; close it
        # afterwards instead of leaking it.
        c = conn.cursor()
        try:
            if self.schema:
                c.execute("SET search_path TO " + self.schema)
            dbEncoding = self.dbEncoding
            if dbEncoding:
                c.execute("SET client_encoding TO '%s'" % dbEncoding)
        finally:
            c.close()
        return conn

    def _queryInsertID(self, conn, soInstance, id, names, values):
        """Insert a row for *soInstance* and return its id.

        When *id* is None, the next value of the id sequence is fetched
        first.  That sequence is ``sqlmeta.idSequence`` or, by default,
        ``<table>_<idName>_seq``.  The id column and value are always
        prepended to *names* / *values*.
        """
        table = soInstance.sqlmeta.table
        idName = soInstance.sqlmeta.idName
        sequenceName = soInstance.sqlmeta.idSequence or \
            '%s_%s_seq' % (table, idName)
        c = conn.cursor()
        if id is None:
            c.execute("SELECT NEXTVAL('%s')" % sequenceName)
            id = c.fetchone()[0]
        names = [idName] + names
        values = [id] + values
        q = self._insertSQL(table, names, values)
        if self.debug:
            self.printDebug(conn, q, 'QueryIns')
        c.execute(q)
        if self.debugOutput:
            self.printDebug(conn, id, 'QueryIns', 'result')
        return id

    @classmethod
    def _queryAddLimitOffset(cls, query, start, end):
        """Append LIMIT/OFFSET to *query* for the slice [start:end)."""
        if not start:
            return "%s LIMIT %i" % (query, end)
        if not end:
            return "%s OFFSET %i" % (query, start)
        return "%s LIMIT %i OFFSET %i" % (query, end - start, start)

    def createColumn(self, soClass, col):
        """Return the column's PostgreSQL CREATE SQL fragment."""
        return col.postgresCreateSQL()

    def createReferenceConstraint(self, soClass, col):
        """Return the column's PostgreSQL foreign-key constraint SQL."""
        return col.postgresCreateReferenceConstraint()

    def createIndexSQL(self, soClass, index):
        """Return the index's PostgreSQL CREATE INDEX SQL."""
        return index.postgresCreateIndexSQL(soClass)

    def createIDColumn(self, soClass):
        """Return the primary-key column definition.

        The column is ``SERIAL`` for int ids and ``TEXT`` for str ids.
        """
        key_type = {int: "SERIAL", str: "TEXT"}[soClass.sqlmeta.idType]
        return '%s %s PRIMARY KEY' % (soClass.sqlmeta.idName, key_type)

    def dropTable(self, tableName, cascade=False):
        """Drop *tableName*, adding CASCADE when *cascade* is true."""
        self.query("DROP TABLE %s %s" % (tableName,
                                         'CASCADE' if cascade else ''))

    def joinSQLType(self, join):
        """Return the SQL type used for join-table id columns."""
        return 'INT NOT NULL'

    def tableExists(self, tableName):
        """Return the number of pg_class entries named *tableName*.

        A non-zero (truthy) count means the relation exists.
        """
        result = self.queryOne(
            "SELECT COUNT(relname) FROM pg_class WHERE relname = %s"
            % self.sqlrepr(tableName))
        return result[0]

    def addColumn(self, tableName, column):
        """Add *column* to *tableName* with ALTER TABLE ... ADD COLUMN."""
        self.query('ALTER TABLE %s ADD COLUMN %s' %
                   (tableName,
                    column.postgresCreateSQL()))

    def delColumn(self, sqlmeta, column):
        """Drop *column* from the table described by *sqlmeta*."""
        self.query('ALTER TABLE %s DROP COLUMN %s' %
                   (sqlmeta.table, column.dbName))

    def columnsFromSchema(self, tableName, soClass):
        """Introspect *tableName* and return a list of Col instances.

        The primary-key column is skipped.  Columns that appear in a
        foreign-key constraint become ``col.ForeignKey``.  All other column
        classes come from ``guessClass``.  Column defaults are converted
        via ``defaultFromSchema``.
        """
        keyQuery = """
        SELECT pg_catalog.pg_get_constraintdef(oid) as condef
        FROM pg_catalog.pg_constraint r
        WHERE r.conrelid = %s::regclass AND r.contype = 'f'"""

        # pg_attrdef.adsrc was removed in PostgreSQL 12; rebuild the
        # default's source text from adbin with pg_get_expr instead.
        colQuery = """
        SELECT a.attname,
        pg_catalog.format_type(a.atttypid, a.atttypmod), a.attnotnull,
        (SELECT substring(pg_catalog.pg_get_expr(d.adbin, d.adrelid)
                          for 128)
         FROM pg_catalog.pg_attrdef d
         WHERE d.adrelid=a.attrelid AND d.adnum = a.attnum)
        FROM pg_catalog.pg_attribute a
        WHERE a.attrelid =%s::regclass
        AND a.attnum > 0 AND NOT a.attisdropped
        ORDER BY a.attnum"""

        primaryKeyQuery = """
        SELECT pg_index.indisprimary,
            pg_catalog.pg_get_indexdef(pg_index.indexrelid)
        FROM pg_catalog.pg_class c, pg_catalog.pg_class c2,
            pg_catalog.pg_index AS pg_index
        WHERE c.relname = %s
            AND c.oid = pg_index.indrelid
            AND pg_index.indexrelid = c2.oid
            AND pg_index.indisprimary
        """

        # Map column name -> referenced table (capitalized) from the text
        # of each foreign-key constraint.
        keyData = self.queryAll(keyQuery % self.sqlrepr(tableName))
        keyRE = re.compile(r"\((.+)\) REFERENCES (.+)\(")
        keymap = {}
        for (condef,) in keyData:
            match = keyRE.search(condef)
            if match:
                field, reftable = match.groups()
                keymap[field] = reftable.capitalize()

        # Extract the primary-key column from its index definition.
        primaryData = self.queryAll(primaryKeyQuery % self.sqlrepr(tableName))
        primaryRE = re.compile(r'CREATE .*? USING .* \((.+?)\)')
        primaryKey = None
        for isPrimary, indexDef in primaryData:
            match = primaryRE.search(indexDef)
            assert match, "Unparseable contraint definition: %r" % indexDef
            assert primaryKey is None, \
                "Already found primary key (%r), then found: %r" \
                % (primaryKey, indexDef)
            primaryKey = match.group(1)
        assert primaryKey, "No primary key found in table %r" % tableName
        if primaryKey.startswith('"'):
            assert primaryKey.endswith('"')
            primaryKey = primaryKey[1:-1]

        colData = self.queryAll(colQuery % self.sqlrepr(tableName))
        results = []
        if self.unicodeCols:
            client_encoding = self.queryOne("SHOW client_encoding")[0]
        for field, t, notnull, defaultstr in colData:
            if field == primaryKey:
                continue
            if field in keymap:
                colClass = col.ForeignKey
                kw = {'foreignKey': soClass.sqlmeta.style.dbTableToPythonClass(
                    keymap[field])}
                name = soClass.sqlmeta.style.dbColumnToPythonAttr(field)
                # Drop the trailing "ID" from foreign-key attribute names.
                if name.endswith('ID'):
                    name = name[:-2]
                kw['name'] = name
            else:
                colClass, kw = self.guessClass(t)
                if self.unicodeCols and colClass is col.StringCol:
                    colClass = col.UnicodeCol
                    kw['dbEncoding'] = client_encoding
                kw['name'] = soClass.sqlmeta.style.dbColumnToPythonAttr(field)
            kw['dbName'] = field
            kw['notNone'] = notnull
            if defaultstr is not None:
                kw['default'] = self.defaultFromSchema(colClass, defaultstr)
            elif not notnull:
                kw['default'] = None
            results.append(colClass(**kw))
        return results

    def guessClass(self, t):
        """Map a ``format_type`` string *t* to ``(ColClass, kwargs)``.

        Unrecognized types fall back to ``(col.Col, {})``.
        """
        if t.count('point'):  # poINT before INT
            return col.StringCol, {}
        elif t.startswith('interval'):  # INTerval before INT
            # Previously matched the 'int' test below and became IntCol;
            # use the generic fallback column class instead.
            return col.Col, {}
        elif t.count('int'):
            return col.IntCol, {}
        elif t.count('varying') or t.count('varchar'):
            if '(' in t:
                return col.StringCol, {'length': int(t[t.index('(') + 1:-1])}
            else:  # varchar without length in Postgres means any length
                return col.StringCol, {}
        elif t.startswith('character('):
            return col.StringCol, {'length': int(t[t.index('(') + 1:-1]),
                                   'varchar': False}
        elif t.count('float') or t.count('real') or t.count('double'):
            return col.FloatCol, {}
        elif t == 'text':
            return col.StringCol, {}
        elif t.startswith('timestamp'):
            return col.DateTimeCol, {}
        elif t.startswith('datetime'):
            return col.DateTimeCol, {}
        elif t.startswith('date'):
            return col.DateCol, {}
        elif t.startswith('bool'):
            return col.BoolCol, {}
        elif t.startswith('bytea'):
            return col.BLOBCol, {}
        else:
            return col.Col, {}

    def defaultFromSchema(self, colClass, defaultstr):
        """
        If the default can be converted to a python constant, convert it.
        Otherwise return it as a sqlbuilder constant.
        """
        if colClass == col.BoolCol:
            if defaultstr == 'false':
                return False
            elif defaultstr == 'true':
                return True
        return getattr(sqlbuilder.const, defaultstr)

    def _createOrDropDatabase(self, op="CREATE"):
        """Run ``<op> DATABASE <self.db>`` from a template1 connection.

        *op* is interpolated directly into the SQL (``"CREATE"`` or
        ``"DROP"`` are the values used by this class).  The connection
        and cursor are closed even when the statement fails.
        """
        # We have to connect to *some* database, so we'll connect to
        # template1, which is a common open database.
        # @@: This doesn't use self.use_dsn or self.dsn_dict
        if self.driver == 'pygresql':
            # NOTE(review): the port is not forwarded for pygresql here.
            dsn = '%s:template1:%s:%s' % (
                self.host or '', self.user or '', self.password or '')
        else:
            dsn = 'dbname=template1'
            if self.user:
                dsn += ' user=%s' % self.user
            if self.password:
                dsn += ' password=%s' % self.password
            if self.host:
                dsn += ' host=%s' % self.host
            # Previously omitted, so servers on a non-default port were
            # unreachable from here.
            if self.port:
                dsn += ' port=%s' % self.port
        conn = self.module.connect(dsn)
        try:
            cur = conn.cursor()
            try:
                # We must close the transaction with a commit so that
                # the CREATE DATABASE can work (which can't be in a
                # transaction):
                cur.execute('COMMIT')
                cur.execute('%s DATABASE %s' % (op, self.db))
            finally:
                cur.close()
        finally:
            conn.close()

    def createEmptyDatabase(self):
        """Create the database named by ``self.db``."""
        self._createOrDropDatabase()

    def dropDatabase(self):
        """Drop the database named by ``self.db``."""
        self._createOrDropDatabase(op="DROP")


# Converter for psycopg Binary type.
def PsycoBinaryConverter(value, db):
    """Render a psycopg ``Binary`` wrapper as SQL via its ``str()`` form.

    Registered with ``registerConverter`` in ``PostgresConnection.__init__``
    when a psycopg driver is used.
    """
    assert db == 'postgres'
    return str(value)