mirror of
https://github.com/djohnlewis/stackdump
synced 2025-01-22 22:51:36 +00:00
307 lines
11 KiB
Python
307 lines
11 KiB
Python
|
from sqlobject.dbconnection import DBAPI
|
||
|
from sqlobject import col
|
||
|
import re
|
||
|
|
||
|
class MSSQLConnection(DBAPI):
    """SQLObject connection for Microsoft SQL Server.

    The DB-API driver is chosen at construction time: ``adodbapi`` is tried
    first and ``pymssql`` second, unless a ``driver`` keyword argument gives
    a different comma-separated preference list.
    """

    supportTransactions = True
    dbName = 'mssql'
    schemes = [dbName]

    # Splits a query into its leading "select " keyword and everything after
    # it.  DOTALL lets the second group span newlines so that multi-line
    # queries keep their FROM/WHERE clauses when TOP is injected.
    limit_re = re.compile(r'^\s*(select )(.*)', re.IGNORECASE | re.DOTALL)

    # Cache for server_version().  A separate flag is needed because None is
    # itself a valid cached result ("version unknown").
    _server_version = None
    _server_version_cached = False

    def __init__(self, db, user, password='', host='localhost', port=None,
                 autoCommit=0, **kw):
        """Select a driver and store the connection parameters.

        Recognised extra keywords (removed from ``kw`` before it is handed
        to ``DBAPI.__init__``):

        * ``driver`` -- comma-separated driver names to try in order
          (``adodb``/``adodbapi`` or ``pymssql``).
        * ``ncli`` -- adodbapi only: use the SQLNCLI provider instead of
          SQLOLEDB.
        * ``sspi`` -- adodbapi only: use integrated (SSPI) security instead
          of user id / password.

        Raises ``ValueError`` for an unknown driver name and ``ImportError``
        when none of the requested drivers can be imported.
        """
        drivers = kw.pop('driver', None) or 'adodb,pymssql'
        for driver in drivers.split(','):
            driver = driver.strip()
            if not driver:
                continue
            try:
                if driver in ('adodb', 'adodbapi'):
                    import adodbapi as sqlmodule
                elif driver == 'pymssql':
                    import pymssql as sqlmodule
                else:
                    raise ValueError(
                        'Unknown MSSQL driver "%s", expected adodb or pymssql'
                        % driver)
            except ImportError:
                # This driver is not installed; try the next one.
                pass
            else:
                break
        else:
            raise ImportError(
                'Cannot find an MSSQL driver, tried %s' % drivers)
        self.module = sqlmodule

        if sqlmodule.__name__ == 'adodbapi':
            self.dbconnection = sqlmodule.connect
            # ADO uses unicode only (AFAIK)
            self.usingUnicodeStrings = True

            # These options only make sense for ADO, so they are only
            # consumed on this branch.
            use_ncli = kw.pop("ncli", None)
            use_sspi = kw.pop("sspi", None)

            # Need to use SQLNCLI provider for SQL Server Express Edition
            if use_ncli:
                conn_str = "Provider=SQLNCLI;"
            else:
                conn_str = "Provider=SQLOLEDB;"

            conn_str += "Data Source=%s;Initial Catalog=%s;"

            # MSDE does not allow SQL server login
            if use_sspi:
                conn_str += ("Integrated Security=SSPI;"
                             "Persist Security Info=False")
                self.make_conn_str = lambda keys: [
                    conn_str % (keys.host, keys.db)]
            else:
                conn_str += "User Id=%s;Password=%s"
                self.make_conn_str = lambda keys: [
                    conn_str % (keys.host, keys.db, keys.user, keys.password)]

        else:  # pymssql
            self.dbconnection = sqlmodule.connect
            sqlmodule.Binary = lambda st: str(st)
            # don't know whether pymssql uses unicode
            self.usingUnicodeStrings = False
            # Positional arguments for the driver's connect() call.
            self.make_conn_str = lambda keys: [
                "", keys.user, keys.password, keys.host, keys.db]

        self.autoCommit = int(autoCommit)
        self.host = host
        self.port = port
        self.db = db
        self.user = user
        self.password = password
        self._can_use_max_types = None
        DBAPI.__init__(self, **kw)

    @classmethod
    def _connectionFromParams(cls, user, password, host, port, path, args):
        """Build a connection from parsed URI parts.

        ``path`` (with surrounding slashes removed) is used as the database
        name; a missing host defaults to ``'localhost'``.
        """
        path = path.strip('/')
        return cls(user=user, password=password,
                   host=host or 'localhost', port=port, db=path, **args)

    def insert_id(self, conn):
        """Return the value of ``@@IDENTITY`` on ``conn`` as an integer."""
        c = conn.cursor()
        # converting the identity to an int is ugly, but it gets returned
        # as a decimal otherwise :S
        c.execute('SELECT CONVERT(INT, @@IDENTITY)')
        return c.fetchone()[0]

    def makeConnection(self):
        """Open a new driver connection and return it.

        As a side effect this turns ``ANSI_NULLS`` on for the session and
        records the decimal separator the driver uses when rendering a
        DECIMAL value in ``self.decimalSeparator``.
        """
        con = self.dbconnection(*self.make_conn_str(self))
        cur = con.cursor()
        cur.execute('SET ANSI_NULLS ON')
        cur.execute("SELECT CAST('12345.21' AS DECIMAL(10, 2))")
        # The value has two decimals, so the separator is third from the end.
        self.decimalSeparator = str(cur.fetchone()[0])[-3]
        cur.close()
        return con

    HAS_IDENTITY = """
       SELECT col.name, col.status, obj.name
       FROM syscolumns col
       JOIN sysobjects obj
       ON obj.id = col.id
       WHERE obj.name = '%s'
       and col.autoval is not null

    """

    def _hasIdentity(self, conn, table):
        """Return True if ``HAS_IDENTITY`` yields at least one row for
        ``table``, i.e. the table has a column whose ``autoval`` is set."""
        query = self.HAS_IDENTITY % table
        c = conn.cursor()
        c.execute(query)
        r = c.fetchone()
        return r is not None

    def _queryInsertID(self, conn, soInstance, id, names, values):
        """Insert a row for ``soInstance`` and return its id.

        ``names`` and ``values`` are parallel lists of column names and
        values; they are not modified.  When ``id`` is given it is inserted
        explicitly (with ``IDENTITY_INSERT`` switched on if the table has an
        identity column).  When ``id`` is None the id is read back from the
        server after the insert.
        """
        table = soInstance.sqlmeta.table
        idName = soInstance.sqlmeta.idName
        c = conn.cursor()
        has_identity = self._hasIdentity(conn, table)
        if id is not None:
            names = [idName] + names
            values = [id] + values
        elif has_identity and idName in names:
            # The server generates the id, so the id column must not be part
            # of the INSERT -- wherever it sits in the list, including
            # position 0.  Slicing builds new lists and leaves the caller's
            # untouched.
            i = names.index(idName)
            names = names[:i] + names[i + 1:]
            values = values[:i] + values[i + 1:]

        if has_identity:
            if id is not None:
                c.execute('SET IDENTITY_INSERT %s ON' % table)
            else:
                c.execute('SET IDENTITY_INSERT %s OFF' % table)

        q = self._insertSQL(table, names, values)
        if self.debug:
            # Single-argument form prints the same on Python 2 and 3.
            print('QueryIns: %s' % q)
        c.execute(q)
        if has_identity:
            c.execute('SET IDENTITY_INSERT %s OFF' % table)

        if id is None:
            id = self.insert_id(conn)
        if self.debugOutput:
            self.printDebug(conn, id, 'QueryIns', 'result')
        return id

    @classmethod
    def _queryAddLimitOffset(cls, query, start, end):
        """Apply a row limit to ``query`` using ``SELECT TOP``.

        Only the "limit without offset" case (``end`` set, ``start`` falsy)
        is rewritten.  In every other case, and when the query does not
        start with ``select``, the query is returned unchanged.
        """
        if end and not start:
            match = cls.limit_re.match(query)
            if match:
                return ' '.join(["SELECT TOP %i" % end, match.group(2)])
        return query

    def createReferenceConstraint(self, soClass, col):
        """Return the column's MSSQL reference-constraint SQL."""
        return col.mssqlCreateReferenceConstraint()

    def createColumn(self, soClass, col):
        """Return the column's MSSQL column-definition SQL."""
        return col.mssqlCreateSQL(self)

    def createIDColumn(self, soClass):
        """Return the definition of the id column for ``soClass``.

        Raises ``KeyError`` when ``sqlmeta.idType`` is neither ``int`` nor
        ``str``.
        """
        key_type = {int: "INT", str: "TEXT"}[soClass.sqlmeta.idType]
        return '%s %s IDENTITY UNIQUE' % (soClass.sqlmeta.idName, key_type)

    def createIndexSQL(self, soClass, index):
        """Return the index's MSSQL ``CREATE INDEX`` SQL."""
        return index.mssqlCreateIndexSQL(soClass)

    def joinSQLType(self, join):
        """Return the SQL type used for join columns."""
        return 'INT NOT NULL'

    SHOW_TABLES = "SELECT name FROM sysobjects WHERE type='U'"

    def tableExists(self, tableName):
        """Return True if a table named ``tableName`` is listed by
        ``SHOW_TABLES`` (compared case-insensitively)."""
        wanted = tableName.lower()
        for (table,) in self.queryAll(self.SHOW_TABLES):
            if table.lower() == wanted:
                return True
        return False

    def addColumn(self, tableName, column):
        """Add ``column`` to ``tableName`` with ``ALTER TABLE ... ADD``."""
        self.query('ALTER TABLE %s ADD %s' %
                   (tableName,
                    column.mssqlCreateSQL(self)))

    def delColumn(self, sqlmeta, column):
        """Drop ``column`` from the table described by ``sqlmeta``."""
        self.query('ALTER TABLE %s DROP COLUMN %s'
                   % (sqlmeta.table, column.dbName))

    # precision and scale is gotten from column table so that we can create
    # decimal columns if needed
    SHOW_COLUMNS = """
        select
                name,
                length,
                (       select name
                        from systypes
                        where cast(xusertype as int)= cast(sc.xtype as int)
                ) datatype,
                prec,
                scale,
                isnullable,
                cdefault,
                m.text default_text,
                isnull(len(autoval),0) is_identity
        from    syscolumns sc
        LEFT OUTER JOIN syscomments m on sc.cdefault = m.id
                AND m.colid = 1
        where
                sc.id in (select id
                        from sysobjects
                        where name = '%s')
        order by
                colorder"""

    def columnsFromSchema(self, tableName, soClass):
        """Return SQLObject column objects describing ``tableName``.

        The id column (``soClass.sqlmeta.idName``) is skipped.  A column
        default, when present, is unwrapped from its enclosing parentheses
        and quotes; defaults of ``int``, ``float`` and ``numeric`` columns
        are converted to Python numbers (``ValueError`` propagates if the
        default is not a numeric literal).
        """
        colData = self.queryAll(self.SHOW_COLUMNS % tableName)
        results = []
        for (field, size, t, precision, scale, nullAllowed,
                _cdefault, defaultText, _is_identity) in colData:
            if field == soClass.sqlmeta.idName:
                continue
            # precision is needed for decimal columns
            colClass, kw = self.guessClass(t, size, precision, scale)
            kw['name'] = str(soClass.sqlmeta.style.dbColumnToPythonAttr(field))
            kw['dbName'] = str(field)
            kw['notNone'] = not nullAllowed
            if defaultText:
                # Strip ( and )
                defaultText = defaultText[1:-1]
                if defaultText.startswith("'"):
                    # String literal: drop the surrounding quotes.
                    defaultText = defaultText[1:-1]
                elif t in ("int", "float", "numeric"):
                    # Numeric defaults may carry a second pair of
                    # parentheses, e.g. "((0))"; remove what is left of them
                    # before converting.
                    number = defaultText.strip('()')
                    if t == "int":
                        defaultText = int(number)
                    else:
                        defaultText = float(number)
                    # TODO need to access the "column" to_python method
                    # here--but the object doesn't exists yet

                # @@ skip key...
                kw['default'] = defaultText

            results.append(colClass(**kw))
        return results

    def _setAutoCommit(self, conn, auto):
        """Do nothing: autocommit switching is not implemented here.

        The method has always returned immediately; the statements that used
        to follow the ``return`` were unreachable and have been removed.
        """
        return

    # precision and scale is needed for decimal columns
    def guessClass(self, t, size, precision, scale):
        """Map raw ``syscolumns`` values to a SQLObject column class.

        Returns a ``(column_class, keyword_dict)`` pair.  Type names are
        matched by prefix, in this order: int, varchar, char, datetime,
        decimal; anything else maps to the generic ``col.Col``.
        """
        if t.startswith('int'):
            return col.IntCol, {}
        if t.startswith('varchar'):
            if self.usingUnicodeStrings:
                return col.UnicodeCol, {'length': size}
            return col.StringCol, {'length': size}
        if t.startswith('char'):
            if self.usingUnicodeStrings:
                return col.UnicodeCol, {'length': size, 'varchar': False}
            return col.StringCol, {'length': size, 'varchar': False}
        if t.startswith('datetime'):
            return col.DateTimeCol, {}
        if t.startswith('decimal'):
            # be careful for awkward naming: DecimalCol's "size" is the SQL
            # precision and its "precision" is the SQL scale.
            return col.DecimalCol, {'size': precision, 'precision': scale}
        return col.Col, {}

    def server_version(self):
        """Return the server's major version as an int, or None if unknown.

        The server is queried once; the result (including None) is cached on
        the instance, so the method can safely be called repeatedly.
        """
        if self._server_version_cached:
            return self._server_version
        try:
            version = self.queryAll(
                "SELECT SERVERPROPERTY('productversion')")[0][0]
            version = int(version.split('.')[0])
        except Exception:
            version = None  # unknown
        # Cache in private attributes; assigning to self.server_version
        # would replace this method and break every later call.
        self._server_version = version
        self._server_version_cached = True
        return version

    def can_use_max_types(self):
        """Return True if the server's major version is known and >= 9.

        The answer is computed once and cached in ``_can_use_max_types``.
        """
        if self._can_use_max_types is not None:
            return self._can_use_max_types
        server_version = self.server_version()
        self._can_use_max_types = can_use_max_types = \
            (server_version is not None) and (server_version >= 9)
        return can_use_max_types
|