# 1
# 0
# mirror of https://github.com/djohnlewis/stackdump synced 2025-01-23 15:11:36 +00:00
# stackdump/python/packages/sqlobject/mssql/mssqlconnection.py
#
# 307 lines
# 11 KiB
# Python

from sqlobject.dbconnection import DBAPI
from sqlobject import col
import re
class MSSQLConnection(DBAPI):
    """SQLObject connection for Microsoft SQL Server.

    Talks to the server through either ``adodbapi`` or ``pymssql``,
    whichever of the requested drivers can be imported first (see
    ``__init__``).
    """

    supportTransactions = True
    dbName = 'mssql'
    schemes = [dbName]

    # Splits a query into its leading "select " and everything after it.
    # DOTALL lets the remainder span several lines, so a multi-line query is
    # not cut off at its first newline when TOP is spliced in.
    limit_re = re.compile(r'^\s*(select )(.*)', re.IGNORECASE | re.DOTALL)

    # Marker meaning "server_version() has not asked the server yet".  None
    # cannot play that role because None is itself a valid (unknown) result.
    _NOT_QUERIED = object()
    _server_version = _NOT_QUERIED

    def __init__(self, db, user, password='', host='localhost', port=None,
                 autoCommit=0, **kw):
        """Pick a DB-API driver and remember the connection parameters.

        Keywords consumed from ``kw`` (the rest go to ``DBAPI.__init__``):

        ``driver``
            Comma-separated list of drivers to try, in order.  Defaults to
            ``'adodb,pymssql'``; ``adodb``/``adodbapi`` and ``pymssql`` are
            the accepted names.
        ``ncli`` (adodbapi only)
            Use the SQLNCLI provider instead of SQLOLEDB.
        ``sspi`` (adodbapi only)
            Use integrated (SSPI) security instead of user id/password.

        Raises ValueError for an unknown driver name, and ImportError when
        none of the requested drivers can be imported.
        """
        drivers = kw.pop('driver', None) or 'adodb,pymssql'
        for driver in drivers.split(','):
            driver = driver.strip()
            if not driver:
                continue
            try:
                if driver in ('adodb', 'adodbapi'):
                    import adodbapi as sqlmodule
                elif driver == 'pymssql':
                    import pymssql as sqlmodule
                else:
                    raise ValueError(
                        'Unknown MSSQL driver "%s", expected adodb or pymssql'
                        % driver)
            except ImportError:
                # This driver is not installed; try the next one.
                pass
            else:
                break
        else:
            # The loop ran out of candidates without a successful import.
            raise ImportError(
                'Cannot find an MSSQL driver, tried %s' % drivers)
        self.module = sqlmodule

        if sqlmodule.__name__ == 'adodbapi':
            self.dbconnection = sqlmodule.connect
            # ADO uses unicode only (AFAIK)
            self.usingUnicodeStrings = True

            # Need to use SQLNCLI provider for SQL Server Express Edition
            if kw.get("ncli"):
                conn_str = "Provider=SQLNCLI;"
            else:
                conn_str = "Provider=SQLOLEDB;"

            conn_str += "Data Source=%s;Initial Catalog=%s;"

            # MSDE does not allow SQL server login
            if kw.get("sspi"):
                conn_str += "Integrated Security=SSPI;Persist Security Info=False"
                self.make_conn_str = lambda keys: [
                    conn_str % (keys.host, keys.db)]
            else:
                conn_str += "User Id=%s;Password=%s"
                self.make_conn_str = lambda keys: [
                    conn_str % (keys.host, keys.db, keys.user, keys.password)]

            # These options are ours; DBAPI.__init__ must not see them.
            kw.pop("sspi", None)
            kw.pop("ncli", None)

        else:  # pymssql
            self.dbconnection = sqlmodule.connect
            sqlmodule.Binary = lambda st: str(st)
            # don't know whether pymssql uses unicode
            self.usingUnicodeStrings = False
            self.make_conn_str = lambda keys: [
                "", keys.user, keys.password, keys.host, keys.db]

        self.autoCommit = int(autoCommit)
        self.host = host
        # Stored only; the connection arguments built above do not use it.
        self.port = port
        self.db = db
        self.user = user
        self.password = password
        self._can_use_max_types = None
        DBAPI.__init__(self, **kw)

    @classmethod
    def _connectionFromParams(cls, user, password, host, port, path, args):
        """Build a connection from URI parts.

        ``path`` with its slashes stripped is used as the database name and
        ``args`` is expanded into keyword arguments of the constructor.
        """
        path = path.strip('/')
        return cls(user=user, password=password,
                   host=host or 'localhost', port=port, db=path, **args)

    def insert_id(self, conn):
        """Return ``@@IDENTITY`` for ``conn``, converted to INT."""
        c = conn.cursor()
        # converting the identity to an int is ugly, but it gets returned
        # as a decimal otherwise :S
        c.execute('SELECT CONVERT(INT, @@IDENTITY)')
        return c.fetchone()[0]

    def makeConnection(self):
        """Open and return a new driver connection.

        Turns ANSI_NULLS on and records in ``self.decimalSeparator`` the
        character the driver puts between the integer and fractional parts
        when a DECIMAL value is converted with ``str``.
        """
        con = self.dbconnection(*self.make_conn_str(self))
        cur = con.cursor()
        cur.execute('SET ANSI_NULLS ON')
        cur.execute("SELECT CAST('12345.21' AS DECIMAL(10, 2))")
        # "12345.21" has two decimals, so the separator is third from last.
        self.decimalSeparator = str(cur.fetchone()[0])[-3]
        cur.close()
        return con

    # Yields a row for every column of the named table whose autoval is set.
    HAS_IDENTITY = """
SELECT col.name, col.status, obj.name
FROM syscolumns col
JOIN sysobjects obj
ON obj.id = col.id
WHERE obj.name = '%s'
and col.autoval is not null
"""

    def _hasIdentity(self, conn, table):
        """Return True if ``HAS_IDENTITY`` finds at least one row for ``table``."""
        query = self.HAS_IDENTITY % table
        c = conn.cursor()
        c.execute(query)
        r = c.fetchone()
        return r is not None

    def _queryInsertID(self, conn, soInstance, id, names, values):
        """Insert a row for ``soInstance`` and return its id.

        ``names``/``values`` are the column names and values to insert.  When
        ``id`` is given it is inserted explicitly (IDENTITY_INSERT is switched
        on for tables that have an identity column); otherwise the id is read
        back with ``insert_id`` after the insert.  The caller's lists are
        never modified.
        """
        table = soInstance.sqlmeta.table
        idName = soInstance.sqlmeta.idName
        c = conn.cursor()
        has_identity = self._hasIdentity(conn, table)
        if id is not None:
            names = [idName] + names
            values = [id] + values
        elif has_identity and idName in names:
            # The server generates the id, so the id column must not be part
            # of the INSERT -- wherever it sits, including position 0.
            i = names.index(idName)
            names = names[:i] + names[i + 1:]
            values = values[:i] + values[i + 1:]

        if has_identity:
            if id is not None:
                c.execute('SET IDENTITY_INSERT %s ON' % table)
            else:
                c.execute('SET IDENTITY_INSERT %s OFF' % table)

        q = self._insertSQL(table, names, values)
        if self.debug:
            print('QueryIns: %s' % q)
        c.execute(q)
        if has_identity:
            c.execute('SET IDENTITY_INSERT %s OFF' % table)

        if id is None:
            id = self.insert_id(conn)
        if self.debugOutput:
            self.printDebug(conn, id, 'QueryIns', 'result')
        return id

    @classmethod
    def _queryAddLimitOffset(cls, query, start, end):
        """Return ``query`` limited to its first ``end`` rows where possible.

        Only the "limit without offset" case is rewritten, by turning the
        leading ``select`` into ``SELECT TOP n``.  A query with a ``start``
        offset, or one that does not begin with ``select``, is returned
        unchanged.
        """
        if end and not start:
            match = cls.limit_re.match(query)
            if match:
                limit_str = "SELECT TOP %i" % end
                return ' '.join([limit_str, match.group(2)])
        return query

    def createReferenceConstraint(self, soClass, col):
        """Return the reference-constraint SQL produced by ``col``."""
        return col.mssqlCreateReferenceConstraint()

    def createColumn(self, soClass, col):
        """Return the column definition SQL produced by ``col``."""
        return col.mssqlCreateSQL(self)

    def createIDColumn(self, soClass):
        """Return the definition of the identity id column for ``soClass``.

        ``sqlmeta.idType`` must be ``int`` or ``str``; anything else raises
        KeyError.
        """
        key_type = {int: "INT", str: "TEXT"}[soClass.sqlmeta.idType]
        return '%s %s IDENTITY UNIQUE' % (soClass.sqlmeta.idName, key_type)

    def createIndexSQL(self, soClass, index):
        """Return the index creation SQL produced by ``index``."""
        return index.mssqlCreateIndexSQL(soClass)

    def joinSQLType(self, join):
        """Return the SQL type used for join columns."""
        return 'INT NOT NULL'

    SHOW_TABLES = "SELECT name FROM sysobjects WHERE type='U'"

    def tableExists(self, tableName):
        """Return True if ``SHOW_TABLES`` lists ``tableName`` (case-insensitive)."""
        wanted = tableName.lower()
        for (table,) in self.queryAll(self.SHOW_TABLES):
            if table.lower() == wanted:
                return True
        return False

    def addColumn(self, tableName, column):
        """Add ``column`` to the table ``tableName``."""
        self.query('ALTER TABLE %s ADD %s' %
                   (tableName,
                    column.mssqlCreateSQL(self)))

    def delColumn(self, sqlmeta, column):
        """Drop ``column`` from the table described by ``sqlmeta``."""
        self.query('ALTER TABLE %s DROP COLUMN %s'
                   % (sqlmeta.table, column.dbName))

    # precision and scale is gotten from column table so that we can create
    # decimal columns if needed
    SHOW_COLUMNS = """
select
name,
length,
( select name
from systypes
where cast(xusertype as int)= cast(sc.xtype as int)
) datatype,
prec,
scale,
isnullable,
cdefault,
m.text default_text,
isnull(len(autoval),0) is_identity
from syscolumns sc
LEFT OUTER JOIN syscomments m on sc.cdefault = m.id
AND m.colid = 1
where
sc.id in (select id
from sysobjects
where name = '%s')
order by
colorder"""

    @staticmethod
    def _stripOuterParens(text):
        """Remove every pair of parentheses that wraps the whole of ``text``.

        ``"(0)"`` becomes ``"0"``; ``"(1)+(2)"`` and ``"getdate()"`` are left
        alone because their first parenthesis does not close at the end.
        """
        while text.startswith('(') and text.endswith(')'):
            depth = 0
            closes_at = None
            for pos, ch in enumerate(text):
                if ch == '(':
                    depth += 1
                elif ch == ')':
                    depth -= 1
                    if depth == 0:
                        closes_at = pos
                        break
            if closes_at != len(text) - 1:
                break
            text = text[1:-1]
        return text

    def columnsFromSchema(self, tableName, soClass):
        """Return column objects describing the existing table ``tableName``.

        The id column (``soClass.sqlmeta.idName``) is skipped.  Each remaining
        row of ``SHOW_COLUMNS`` is mapped to a column class by ``guessClass``
        and instantiated with its name, nullability and default value.
        """
        colData = self.queryAll(self.SHOW_COLUMNS % tableName)
        results = []
        for (field, size, t, precision, scale, nullAllowed,
             default, defaultText, is_identity) in colData:
            if field == soClass.sqlmeta.idName:
                continue
            # precision is needed for decimal columns
            colClass, kw = self.guessClass(t, size, precision, scale)
            kw['name'] = str(soClass.sqlmeta.style.dbColumnToPythonAttr(field))
            kw['dbName'] = str(field)
            kw['notNone'] = not nullAllowed
            if defaultText:
                # The stored text is wrapped in parentheses, and may be
                # wrapped more than once (e.g. "((0))"), so strip the first
                # pair unconditionally and then any further wrapping pairs.
                defaultText = self._stripOuterParens(defaultText[1:-1])
                if defaultText.startswith("'"):
                    # Quoted string literal: drop the quotes.
                    defaultText = defaultText[1:-1]
                elif t == "int":
                    defaultText = int(defaultText)
                elif t in ("float", "numeric"):
                    defaultText = float(defaultText)
                # TODO need to access the "column" to_python method here --
                # but the object doesn't exist yet

            # @@ skip key...
            kw['default'] = defaultText

            results.append(colClass(**kw))
        return results

    def _setAutoCommit(self, conn, auto):
        """Do nothing; autocommit is not switched on the connection here."""
        return

    # precision and scale is needed for decimal columns
    def guessClass(self, t, size, precision, scale):
        """Map raw syscolumns values to a SQLObject column class.

        Returns a ``(column class, keyword dict)`` pair chosen from the type
        name ``t``; unrecognised types map to the generic ``col.Col``.
        """
        if t.startswith('int'):
            return col.IntCol, {}
        elif t.startswith('varchar'):
            if self.usingUnicodeStrings:
                return col.UnicodeCol, {'length': size}
            return col.StringCol, {'length': size}
        elif t.startswith('char'):
            if self.usingUnicodeStrings:
                return col.UnicodeCol, {'length': size,
                                        'varchar': False}
            return col.StringCol, {'length': size,
                                   'varchar': False}
        elif t.startswith('datetime'):
            return col.DateTimeCol, {}
        elif t.startswith('decimal'):
            # be careful for awkward naming: DecimalCol's "size" is the SQL
            # precision and its "precision" is the SQL scale
            return col.DecimalCol, {'size': precision,
                                    'precision': scale}
        else:
            return col.Col, {}

    def server_version(self):
        """Return the server's major version number, or None if unknown.

        The server is asked once; the answer (including None) is cached in
        ``self._server_version`` so this method stays callable afterwards.
        """
        if self._server_version is not self._NOT_QUERIED:
            return self._server_version
        try:
            version = self.queryAll(
                "SELECT SERVERPROPERTY('productversion')")[0][0]
            version = int(version.split('.')[0])
        except Exception:
            version = None  # unknown
        self._server_version = version  # cache it forever
        return version

    def can_use_max_types(self):
        """Return True when the server major version is known and >= 9.

        The result is cached in ``self._can_use_max_types``.
        """
        if self._can_use_max_types is not None:
            return self._can_use_max_types
        server_version = self.server_version()
        self._can_use_max_types = can_use_max_types = \
            (server_version is not None) and (server_version >= 9)
        return can_use_max_types