mirror of
https://github.com/djohnlewis/stackdump
synced 2025-01-23 15:11:36 +00:00
306 lines
12 KiB
Python
306 lines
12 KiB
Python
|
from sqlobject import col
|
||
|
from sqlobject.dbconnection import DBAPI
|
||
|
from sqlobject.dberrors import *
|
||
|
|
||
|
class ErrorMessage(str):
|
||
|
def __new__(cls, e, append_msg=''):
|
||
|
obj = str.__new__(cls, e[1] + append_msg)
|
||
|
obj.code = int(e[0])
|
||
|
obj.module = e.__module__
|
||
|
obj.exception = e.__class__.__name__
|
||
|
return obj
|
||
|
|
||
|
class MySQLConnection(DBAPI):
|
||
|
|
||
|
supportTransactions = False
|
||
|
dbName = 'mysql'
|
||
|
schemes = [dbName]
|
||
|
|
||
|
def __init__(self, db, user, password='', host='localhost', port=0, **kw):
|
||
|
import MySQLdb, MySQLdb.constants.CR, MySQLdb.constants.ER
|
||
|
self.module = MySQLdb
|
||
|
self.host = host
|
||
|
self.port = port
|
||
|
self.db = db
|
||
|
self.user = user
|
||
|
self.password = password
|
||
|
self.kw = {}
|
||
|
for key in ("unix_socket", "init_command",
|
||
|
"read_default_file", "read_default_group", "conv"):
|
||
|
if key in kw:
|
||
|
self.kw[key] = kw.pop(key)
|
||
|
for key in ("connect_timeout", "compress", "named_pipe", "use_unicode",
|
||
|
"client_flag", "local_infile"):
|
||
|
if key in kw:
|
||
|
self.kw[key] = int(kw.pop(key))
|
||
|
for key in ("ssl_key", "ssl_cert", "ssl_ca", "ssl_capath"):
|
||
|
if key in kw:
|
||
|
if "ssl" not in self.kw:
|
||
|
self.kw["ssl"] = {}
|
||
|
self.kw["ssl"][key[4:]] = kw.pop(key)
|
||
|
if "charset" in kw:
|
||
|
self.dbEncoding = self.kw["charset"] = kw.pop("charset")
|
||
|
else:
|
||
|
self.dbEncoding = None
|
||
|
|
||
|
# MySQLdb < 1.2.1: only ascii
|
||
|
# MySQLdb = 1.2.1: only unicode
|
||
|
# MySQLdb > 1.2.1: both ascii and unicode
|
||
|
self.need_unicode = (self.module.version_info[:3] >= (1, 2, 1)) and (self.module.version_info[:3] < (1, 2, 2))
|
||
|
|
||
|
DBAPI.__init__(self, **kw)
|
||
|
|
||
|
@classmethod
|
||
|
def _connectionFromParams(cls, user, password, host, port, path, args):
|
||
|
return cls(db=path.strip('/'), user=user or '', password=password or '',
|
||
|
host=host or 'localhost', port=port or 0, **args)
|
||
|
|
||
|
def makeConnection(self):
|
||
|
dbEncoding = self.dbEncoding
|
||
|
if dbEncoding:
|
||
|
from MySQLdb.connections import Connection
|
||
|
if not hasattr(Connection, 'set_character_set'):
|
||
|
# monkeypatch pre MySQLdb 1.2.1
|
||
|
def character_set_name(self):
|
||
|
return dbEncoding + '_' + dbEncoding
|
||
|
Connection.character_set_name = character_set_name
|
||
|
try:
|
||
|
conn = self.module.connect(host=self.host, port=self.port,
|
||
|
db=self.db, user=self.user, passwd=self.password, **self.kw)
|
||
|
if self.module.version_info[:3] >= (1, 2, 2):
|
||
|
conn.ping(True) # Attempt to reconnect. This setting is persistent.
|
||
|
except self.module.OperationalError, e:
|
||
|
conninfo = "; used connection string: host=%(host)s, port=%(port)s, db=%(db)s, user=%(user)s" % self.__dict__
|
||
|
raise OperationalError(ErrorMessage(e, conninfo))
|
||
|
|
||
|
if hasattr(conn, 'autocommit'):
|
||
|
conn.autocommit(bool(self.autoCommit))
|
||
|
|
||
|
if dbEncoding:
|
||
|
if hasattr(conn, 'set_character_set'): # MySQLdb 1.2.1 and later
|
||
|
conn.set_character_set(dbEncoding)
|
||
|
else: # pre MySQLdb 1.2.1
|
||
|
# works along with monkeypatching code above
|
||
|
conn.query("SET NAMES %s" % dbEncoding)
|
||
|
|
||
|
return conn
|
||
|
|
||
|
def _setAutoCommit(self, conn, auto):
|
||
|
if hasattr(conn, 'autocommit'):
|
||
|
conn.autocommit(auto)
|
||
|
|
||
|
def _executeRetry(self, conn, cursor, query):
|
||
|
if self.need_unicode and not isinstance(query, unicode):
|
||
|
try:
|
||
|
query = unicode(query, self.dbEncoding)
|
||
|
except UnicodeError:
|
||
|
pass
|
||
|
|
||
|
# When a server connection is lost and a query is attempted, most of
|
||
|
# the time the query will raise a SERVER_LOST exception, then at the
|
||
|
# second attempt to execute it, the mysql lib will reconnect and
|
||
|
# succeed. However is a few cases, the first attempt raises the
|
||
|
# SERVER_GONE exception, the second attempt the SERVER_LOST exception
|
||
|
# and only the third succeeds. Thus the 3 in the loop count.
|
||
|
# If it doesn't reconnect even after 3 attempts, while the database is
|
||
|
# up and running, it is because a 5.0.3 (or newer) server is used
|
||
|
# which no longer permits autoreconnects by default. In that case a
|
||
|
# reconnect flag must be set when making the connection to indicate
|
||
|
# that autoreconnecting is desired. In MySQLdb 1.2.2 or newer this is
|
||
|
# done by calling ping(True) on the connection.
|
||
|
for count in range(3):
|
||
|
try:
|
||
|
return cursor.execute(query)
|
||
|
except self.module.OperationalError, e:
|
||
|
if e.args[0] in (self.module.constants.CR.SERVER_GONE_ERROR, self.module.constants.CR.SERVER_LOST):
|
||
|
if count == 2:
|
||
|
raise OperationalError(ErrorMessage(e))
|
||
|
if self.debug:
|
||
|
self.printDebug(conn, str(e), 'ERROR')
|
||
|
else:
|
||
|
raise OperationalError(ErrorMessage(e))
|
||
|
except self.module.IntegrityError, e:
|
||
|
msg = ErrorMessage(e)
|
||
|
if e.args[0] == self.module.constants.ER.DUP_ENTRY:
|
||
|
raise DuplicateEntryError(msg)
|
||
|
else:
|
||
|
raise IntegrityError(msg)
|
||
|
except self.module.InternalError, e:
|
||
|
raise InternalError(ErrorMessage(e))
|
||
|
except self.module.ProgrammingError, e:
|
||
|
raise ProgrammingError(ErrorMessage(e))
|
||
|
except self.module.DataError, e:
|
||
|
raise DataError(ErrorMessage(e))
|
||
|
except self.module.NotSupportedError, e:
|
||
|
raise NotSupportedError(ErrorMessage(e))
|
||
|
except self.module.DatabaseError, e:
|
||
|
raise DatabaseError(ErrorMessage(e))
|
||
|
except self.module.InterfaceError, e:
|
||
|
raise InterfaceError(ErrorMessage(e))
|
||
|
except self.module.Warning, e:
|
||
|
raise Warning(ErrorMessage(e))
|
||
|
except self.module.Error, e:
|
||
|
raise Error(ErrorMessage(e))
|
||
|
|
||
|
def _queryInsertID(self, conn, soInstance, id, names, values):
|
||
|
table = soInstance.sqlmeta.table
|
||
|
idName = soInstance.sqlmeta.idName
|
||
|
c = conn.cursor()
|
||
|
if id is not None:
|
||
|
names = [idName] + names
|
||
|
values = [id] + values
|
||
|
q = self._insertSQL(table, names, values)
|
||
|
if self.debug:
|
||
|
self.printDebug(conn, q, 'QueryIns')
|
||
|
self._executeRetry(conn, c, q)
|
||
|
if id is None:
|
||
|
try:
|
||
|
id = c.lastrowid
|
||
|
except AttributeError:
|
||
|
id = c.insert_id()
|
||
|
if self.debugOutput:
|
||
|
self.printDebug(conn, id, 'QueryIns', 'result')
|
||
|
return id
|
||
|
|
||
|
@classmethod
|
||
|
def _queryAddLimitOffset(cls, query, start, end):
|
||
|
if not start:
|
||
|
return "%s LIMIT %i" % (query, end)
|
||
|
if not end:
|
||
|
return "%s LIMIT %i, -1" % (query, start)
|
||
|
return "%s LIMIT %i, %i" % (query, start, end-start)
|
||
|
|
||
|
def createReferenceConstraint(self, soClass, col):
|
||
|
return col.mysqlCreateReferenceConstraint()
|
||
|
|
||
|
def createColumn(self, soClass, col):
|
||
|
return col.mysqlCreateSQL()
|
||
|
|
||
|
def createIndexSQL(self, soClass, index):
|
||
|
return index.mysqlCreateIndexSQL(soClass)
|
||
|
|
||
|
def createIDColumn(self, soClass):
|
||
|
if soClass.sqlmeta.idType == str:
|
||
|
return '%s TEXT PRIMARY KEY' % soClass.sqlmeta.idName
|
||
|
return '%s INT PRIMARY KEY AUTO_INCREMENT' % soClass.sqlmeta.idName
|
||
|
|
||
|
def joinSQLType(self, join):
|
||
|
return 'INT NOT NULL'
|
||
|
|
||
|
def tableExists(self, tableName):
|
||
|
try:
|
||
|
# Use DESCRIBE instead of SHOW TABLES because SHOW TABLES
|
||
|
# assumes there is a default database selected
|
||
|
# which is not always True (for an embedded application, e.g.)
|
||
|
self.query('DESCRIBE %s' % (tableName))
|
||
|
return True
|
||
|
except ProgrammingError, e:
|
||
|
if e[0].code == 1146: # ER_NO_SUCH_TABLE
|
||
|
return False
|
||
|
raise
|
||
|
|
||
|
def addColumn(self, tableName, column):
|
||
|
self.query('ALTER TABLE %s ADD COLUMN %s' %
|
||
|
(tableName,
|
||
|
column.mysqlCreateSQL()))
|
||
|
|
||
|
def delColumn(self, sqlmeta, column):
|
||
|
self.query('ALTER TABLE %s DROP COLUMN %s' % (sqlmeta.table, column.dbName))
|
||
|
|
||
|
def columnsFromSchema(self, tableName, soClass):
|
||
|
colData = self.queryAll("SHOW COLUMNS FROM %s"
|
||
|
% tableName)
|
||
|
results = []
|
||
|
for field, t, nullAllowed, key, default, extra in colData:
|
||
|
if field == soClass.sqlmeta.idName:
|
||
|
continue
|
||
|
colClass, kw = self.guessClass(t)
|
||
|
if self.kw.get('use_unicode') and colClass is col.StringCol:
|
||
|
colClass = col.UnicodeCol
|
||
|
if self.dbEncoding: kw['dbEncoding'] = self.dbEncoding
|
||
|
kw['name'] = soClass.sqlmeta.style.dbColumnToPythonAttr(field)
|
||
|
kw['dbName'] = field
|
||
|
|
||
|
# Since MySQL 5.0, 'NO' is returned in the NULL column (SQLObject expected '')
|
||
|
kw['notNone'] = (nullAllowed.upper() != 'YES' and True or False)
|
||
|
|
||
|
if default and t.startswith('int'):
|
||
|
kw['default'] = int(default)
|
||
|
elif default and t.startswith('float'):
|
||
|
kw['default'] = float(default)
|
||
|
elif default == 'CURRENT_TIMESTAMP' and t == 'timestamp':
|
||
|
kw['default'] = None
|
||
|
elif default and colClass is col.BoolCol:
|
||
|
kw['default'] = int(default) and True or False
|
||
|
else:
|
||
|
kw['default'] = default
|
||
|
# @@ skip key...
|
||
|
# @@ skip extra...
|
||
|
results.append(colClass(**kw))
|
||
|
return results
|
||
|
|
||
|
def guessClass(self, t):
|
||
|
if t.startswith('int'):
|
||
|
return col.IntCol, {}
|
||
|
elif t.startswith('enum'):
|
||
|
values = []
|
||
|
for i in t[5:-1].split(','): # take the enum() off and split
|
||
|
values.append(i[1:-1]) # remove the surrounding \'
|
||
|
return col.EnumCol, {'enumValues': values}
|
||
|
elif t.startswith('double'):
|
||
|
return col.FloatCol, {}
|
||
|
elif t.startswith('varchar'):
|
||
|
colType = col.StringCol
|
||
|
if self.kw.get('use_unicode', False):
|
||
|
colType = col.UnicodeCol
|
||
|
if t.endswith('binary'):
|
||
|
return colType, {'length': int(t[8:-8]),
|
||
|
'char_binary': True}
|
||
|
else:
|
||
|
return colType, {'length': int(t[8:-1])}
|
||
|
elif t.startswith('char'):
|
||
|
if t.endswith('binary'):
|
||
|
return col.StringCol, {'length': int(t[5:-8]),
|
||
|
'varchar': False,
|
||
|
'char_binary': True}
|
||
|
else:
|
||
|
return col.StringCol, {'length': int(t[5:-1]),
|
||
|
'varchar': False}
|
||
|
elif t.startswith('datetime'):
|
||
|
return col.DateTimeCol, {}
|
||
|
elif t.startswith('date'):
|
||
|
return col.DateCol, {}
|
||
|
elif t.startswith('time'):
|
||
|
return col.TimeCol, {}
|
||
|
elif t.startswith('timestamp'):
|
||
|
return col.TimestampCol, {}
|
||
|
elif t.startswith('bool'):
|
||
|
return col.BoolCol, {}
|
||
|
elif t.startswith('tinyblob'):
|
||
|
return col.BLOBCol, {"length": 2**8-1}
|
||
|
elif t.startswith('tinytext'):
|
||
|
return col.StringCol, {"length": 2**8-1, "varchar": True}
|
||
|
elif t.startswith('blob'):
|
||
|
return col.BLOBCol, {"length": 2**16-1}
|
||
|
elif t.startswith('text'):
|
||
|
return col.StringCol, {"length": 2**16-1, "varchar": True}
|
||
|
elif t.startswith('mediumblob'):
|
||
|
return col.BLOBCol, {"length": 2**24-1}
|
||
|
elif t.startswith('mediumtext'):
|
||
|
return col.StringCol, {"length": 2**24-1, "varchar": True}
|
||
|
elif t.startswith('longblob'):
|
||
|
return col.BLOBCol, {"length": 2**32}
|
||
|
elif t.startswith('longtext'):
|
||
|
return col.StringCol, {"length": 2**32, "varchar": True}
|
||
|
else:
|
||
|
return col.Col, {}
|
||
|
|
||
|
def _createOrDropDatabase(self, op="CREATE"):
|
||
|
self.query('%s DATABASE %s' % (op, self.db))
|
||
|
|
||
|
def createEmptyDatabase(self):
|
||
|
self._createOrDropDatabase()
|
||
|
|
||
|
def dropDatabase(self):
|
||
|
self._createOrDropDatabase(op="DROP")
|