1
0
mirror of https://github.com/djohnlewis/stackdump synced 2026-04-02 16:49:17 +00:00

Initial commit. Still building up the env and some parsing code.

This commit is contained in:
Samuel Lai
2011-09-11 14:29:39 +10:00
commit af2eafeccd
301 changed files with 82327 additions and 0 deletions

View File

@@ -0,0 +1,13 @@
"""
SQLObject 1.1.3
"""
from __version__ import version, version_info
from col import *
from index import *
from joins import *
from main import *
from sqlbuilder import AND, OR, NOT, IN, LIKE, RLIKE, DESC, CONTAINSSTRING, const, func
from styles import *
from dbconnection import connectionForURI
import dberrors

View File

@@ -0,0 +1,8 @@
# Version identification for this SQLObject release.
version = '1.1.3'
major = 1
minor = 1
micro = 3
release_level = 'final'
serial = 0
# Tuple form, ordered like sys.version_info for easy comparison.
version_info = (major, minor, micro, release_level, serial)

View File

@@ -0,0 +1,133 @@
"""
Bound attributes are attributes that are bound to a specific class and
a specific name. In SQLObject a typical example is a column object,
which knows its name and class.
A bound attribute should define a method ``__addtoclass__(added_class,
name)`` (attributes without this method will simply be treated as
normal). The return value is ignored; if the attribute wishes to
change the value in the class, it must call ``setattr(added_class,
name, new_value)``.
BoundAttribute is a class that facilitates lazy attribute creation.
``bind_attributes(cls, new_attrs)`` is a function that looks for
attributes with this special method. ``new_attrs`` is a dictionary,
as typically passed into ``__classinit__`` with declarative (calling
``bind_attributes`` in ``__classinit__`` would be typical).
Note if you do this that attributes defined in a superclass will not
be rebound in subclasses. If you want to rebind attributes in
subclasses, use ``bind_attributes_local``, which adds a
``__bound_attributes__`` variable to your class to track these active
attributes.
"""
__all__ = ['BoundAttribute', 'BoundFactory', 'bind_attributes',
'bind_attributes_local']
import declarative
import events
class BoundAttribute(declarative.Declarative):

    """
    This is a declarative class that passes all the values given to it
    to another object.  So you can pass it arguments (via
    __init__/__call__) or give it the equivalent of keyword arguments
    through subclassing.  Then a bound object will be added in its
    place.

    To hook this other object in, override ``make_object(added_class,
    name, **attrs)`` and maybe ``set_object(added_class, name,
    **attrs)`` (the default implementation of ``set_object``
    just resets the attribute to whatever ``make_object`` returned).

    Also see ``BoundFactory``.
    """

    # Names that are machinery, not user data; they are never collected
    # into _all_attrs by _add_attrs().
    # NOTE(review): this tuple lists '_all_attributes' and
    # 'clone_in_subclass', but the actual attributes below are named
    # '_all_attrs' and 'clone_for_subclass' -- the two entries look
    # stale; confirm against intended behavior.
    _private_variables = (
        '_private_variables',
        '_all_attributes',
        '__classinit__',
        '__addtoclass__',
        '_add_attrs',
        'set_object',
        'make_object',
        'clone_in_subclass',
        )

    # Accumulated public attribute names forwarded to make_object().
    _all_attrs = ()

    # When true, the bound object is re-created for every subclass of
    # the class it was added to (see the ClassCreateSignal hook below).
    clone_for_subclass = True

    def __classinit__(cls, new_attrs):
        declarative.Declarative.__classinit__(cls, new_attrs)
        cls._all_attrs = cls._add_attrs(cls, new_attrs)

    def __instanceinit__(self, new_attrs):
        declarative.Declarative.__instanceinit__(self, new_attrs)
        # Write through __dict__ to bypass __setattr__ below (which
        # would re-run _add_attrs on itself).
        self.__dict__['_all_attrs'] = self._add_attrs(self, new_attrs)

    @staticmethod
    def _add_attrs(this_object, new_attrs):
        # Merge the public keys of new_attrs into this_object's
        # _all_attrs, skipping underscore-prefixed and private names.
        private = this_object._private_variables
        all_attrs = list(this_object._all_attrs)
        for key in new_attrs.keys():
            if key.startswith('_') or key in private:
                continue
            if key not in all_attrs:
                all_attrs.append(key)
        return tuple(all_attrs)

    @declarative.classinstancemethod
    def __addtoclass__(self, cls, added_class, attr_name):
        # Invoked when this object (instance or class) is used as the
        # class attribute ``attr_name`` of ``added_class``.
        me = self or cls
        attrs = {}
        for name in me._all_attrs:
            attrs[name] = getattr(me, name)
        attrs['added_class'] = added_class
        attrs['attr_name'] = attr_name
        obj = me.make_object(**attrs)
        # NOTE(review): reads self.clone_for_subclass -- when invoked on
        # the class itself, ``self`` is None and this would raise
        # AttributeError; presumably me.clone_for_subclass was intended.
        if self.clone_for_subclass:
            def on_rebind(new_class_name, bases, new_attrs,
                          post_funcs, early_funcs):
                # Re-make and re-bind the object once the subclass
                # actually exists (post_funcs run after class creation).
                def rebind(new_class):
                    me.set_object(
                        new_class, attr_name,
                        me.make_object(**attrs))
                post_funcs.append(rebind)
            # weak=False: on_rebind is a local closure that a weak
            # listener reference would let die immediately.
            events.listen(receiver=on_rebind, soClass=added_class,
                          signal=events.ClassCreateSignal, weak=False)
        me.set_object(added_class, attr_name, obj)

    @classmethod
    def set_object(cls, added_class, attr_name, obj):
        # Default binding: replace the attribute with the made object.
        setattr(added_class, attr_name, obj)

    @classmethod
    def make_object(cls, added_class, attr_name, *args, **attrs):
        # Subclasses must produce the object that gets bound.
        raise NotImplementedError

    def __setattr__(self, name, value):
        # Record every assigned public attribute in _all_attrs so it is
        # forwarded to make_object() later.
        self.__dict__['_all_attrs'] = self._add_attrs(self, {name: value})
        self.__dict__[name] = value
class BoundFactory(BoundAttribute):
    """
    This will bind the attribute to whatever is given by
    ``factory_class``.  This factory should be a callable with the
    signature ``factory_class(added_class, attr_name, *args, **kw)``.

    The factory will be reinvoked (and the attribute rebound) for
    every subclassing.
    """

    # The callable invoked to produce the bound object.
    factory_class = None
    _private_variables = (
        BoundAttribute._private_variables + ('factory_class',))

    # NOTE(review): unlike the parent's make_object (a classmethod),
    # this is a plain method whose first parameter is merely *named*
    # ``cls``; it therefore binds differently when __addtoclass__ runs
    # with the class rather than an instance.  Confirm the intended
    # usage is via instances before changing.
    def make_object(cls, added_class, attr_name, *args, **kw):
        return cls.factory_class(added_class, attr_name, *args, **kw)

View File

@@ -0,0 +1,376 @@
"""
This implements the instance caching in SQLObject. Caching is
relatively aggressive. All objects are retained so long as they are
in memory, by keeping weak references to objects. We also keep other
objects in a cache that doesn't allow them to be garbage collected
(unless caching is turned off).
"""
import threading
from weakref import ref
from time import time as now
class CacheFactory(object):

    """
    CacheFactory caches object creation. Each object should be
    referenced by a single hashable ID (note tuples of hashable
    values are also hashable).
    """

    def __init__(self, cullFrequency=100, cullFraction=2,
                 cache=True):
        """
        Every cullFrequency times that an item is retrieved from
        this cache, the cull method is called.

        The cull method then expires an arbitrary fraction of
        the cached objects.  The idea is at no time will the cache
        be entirely emptied, placing a potentially high load at that
        moment, but every object will have its time to go
        eventually.  The fraction is given as an integer, and one
        in that many objects are expired (i.e., the default is 1/2
        of objects are expired).

        By setting cache to False, items won't be cached.
        However, in all cases a weak reference is kept to created
        objects, and if the object hasn't been garbage collected
        it will be returned.
        """
        self.cullFrequency = cullFrequency
        self.cullCount = 0
        # Rotating start index for cull(), so successive passes expire
        # different slices of the key list.
        self.cullOffset = 0
        self.cullFraction = cullFraction
        self.doCache = cache
        if self.doCache:
            # Strong-reference cache; entries stay alive until culled.
            self.cache = {}
        # id -> weakref.ref; holds expired objects (or, when doCache is
        # False, every object) without keeping them alive.
        self.expiredCache = {}
        # Guards promotion of entries between cache and expiredCache;
        # see get()/finishPut() for the (unusual) locking protocol.
        self.lock = threading.Lock()

    def tryGet(self, id):
        """
        This returns None, or the object in cache.  Never locks and
        never leaves a lock held.
        """
        value = self.expiredCache.get(id)
        if value:
            # it's actually a weakref:
            return value()
        if not self.doCache:
            return None
        return self.cache.get(id)

    def get(self, id):
        """
        This method can cause deadlocks!  tryGet is safer.

        This returns the object found in cache, or None.  If None,
        then the cache will remain locked!  This is so that the
        calling function can create the object in a threadsafe manner
        before releasing the lock.  You should use this like (note
        that ``cache`` is actually a CacheSet object in this
        example)::

            obj = cache.get(some_id, my_class)
            if obj is None:
                try:
                    obj = create_object(some_id)
                    cache.put(some_id, my_class, obj)
                finally:
                    cache.finishPut(cls)

        This method checks both the main cache (which retains
        references) and the 'expired' cache, which retains only weak
        references.
        """
        if self.doCache:
            if self.cullCount > self.cullFrequency:
                # Two threads could hit the cull in a row, but
                # that's not so bad.  At least by setting cullCount
                # back to zero right away we avoid this.  The cull
                # method has a lock, so it's threadsafe.
                self.cullCount = 0
                self.cull()
            else:
                self.cullCount = self.cullCount + 1
            # Fast path: unlocked read of the strong cache.
            try:
                return self.cache[id]
            except KeyError:
                pass
            self.lock.acquire()
            # Re-check under the lock in case another thread inserted
            # the object while we were waiting.
            try:
                val = self.cache[id]
            except KeyError:
                pass
            else:
                self.lock.release()
                return val
            try:
                val = self.expiredCache[id]()
            except KeyError:
                # Miss: return None *with the lock still held* -- the
                # caller must call put()/finishPut() to release it.
                return None
            else:
                del self.expiredCache[id]
                if val is None:
                    # The weakref was dead; same held-lock miss as above.
                    return None
            # The object is still alive: promote it back to the strong
            # cache before releasing the lock.
            self.cache[id] = val
            self.lock.release()
            return val
        else:
            # No strong cache: only the weakref table is consulted.
            try:
                val = self.expiredCache[id]()
                if val is not None:
                    return val
            except KeyError:
                pass
            self.lock.acquire()
            try:
                val = self.expiredCache[id]()
            except KeyError:
                # Held-lock miss, as documented above.
                return None
            else:
                if val is None:
                    del self.expiredCache[id]
                    # Dead weakref: held-lock miss.
                    return None
            self.lock.release()
            return val

    def put(self, id, obj):
        """
        Puts an object into the cache.  Should only be called after
        .get(), so that duplicate objects don't end up in the cache.
        """
        if self.doCache:
            self.cache[id] = obj
        else:
            self.expiredCache[id] = ref(obj)

    def finishPut(self):
        """
        Releases the lock that is retained when .get() is called and
        returns None.
        """
        self.lock.release()

    def created(self, id, obj):
        """
        Inserts an object into the cache.  Should be used when no one
        else knows about the object yet, so there cannot be any object
        already in the cache.  After a database INSERT is an example
        of this situation.
        """
        if self.doCache:
            if self.cullCount > self.cullFrequency:
                # Two threads could hit the cull in a row, but
                # that's not so bad.  At least by setting cullCount
                # back to zero right away we avoid this.  The cull
                # method has a lock, so it's threadsafe.
                self.cullCount = 0
                self.cull()
            else:
                self.cullCount = self.cullCount + 1
            self.cache[id] = obj
        else:
            self.expiredCache[id] = ref(obj)

    def cull(self):
        """Runs through the cache and expires objects

        E.g., if ``cullFraction`` is 3, then every third object is moved to
        the 'expired' (aka weakref) cache.
        """
        self.lock.acquire()
        try:
            # remove dead references from the expired cache
            # (keys() snapshots the list here, so popping while
            # iterating is safe)
            keys = self.expiredCache.keys()
            for key in keys:
                if self.expiredCache[key]() is None:
                    self.expiredCache.pop(key, None)
            keys = self.cache.keys()
            for i in xrange(self.cullOffset, len(keys), self.cullFraction):
                id = keys[i]
                # create a weakref, then remove from the cache
                obj = ref(self.cache[id])
                del self.cache[id]
                # the object may have been gc'd when removed from the cache
                # above, no need to place in expiredCache
                if obj() is not None:
                    self.expiredCache[id] = obj
            # This offset tries to balance out which objects we
            # expire, so no object will just hang out in the cache
            # forever.
            self.cullOffset = (self.cullOffset + 1) % self.cullFraction
        finally:
            self.lock.release()

    def clear(self):
        """
        Removes everything from the cache.  Warning!  This can cause
        duplicate objects in memory.
        """
        if self.doCache:
            self.cache.clear()
        self.expiredCache.clear()

    def expire(self, id):
        """
        Expires a single object.  Typically called after a delete.
        Doesn't even keep a weakref.  (@@: bad name?)
        """
        if not self.doCache:
            return
        self.lock.acquire()
        try:
            if id in self.cache:
                del self.cache[id]
            if id in self.expiredCache:
                del self.expiredCache[id]
        finally:
            self.lock.release()

    def expireAll(self):
        """
        Expires all objects, moving them all into the expired/weakref
        cache.
        """
        if not self.doCache:
            return
        self.lock.acquire()
        try:
            for key, value in self.cache.items():
                self.expiredCache[key] = ref(value)
            self.cache = {}
        finally:
            self.lock.release()

    def allIDs(self):
        """
        Returns the IDs of all objects in the cache.
        """
        if self.doCache:
            all = self.cache.keys()
        else:
            all = []
            for id, value in self.expiredCache.items():
                # Skip dead weakrefs.
                if value():
                    all.append(id)
        return all

    def getAll(self):
        """
        Return all the objects in the cache.
        """
        if self.doCache:
            all = self.cache.values()
        else:
            all = []
            for value in self.expiredCache.values():
                # Skip dead weakrefs.
                if value():
                    all.append(value())
        return all
class CacheSet(object):
    """
    A CacheSet is used to collect and maintain a series of caches.  In
    SQLObject, there is one CacheSet per connection, and one Cache
    in the CacheSet for each class, since IDs are not unique across
    classes.  It contains methods similar to Cache, but that take
    a ``cls`` argument.
    """

    def __init__(self, *args, **kw):
        # Maps class name -> CacheFactory; caches are created lazily,
        # each constructed with the args/kw given here.
        self.caches = {}
        self.args = args
        self.kw = kw

    def get(self, id, cls):
        """
        Fetch ``id`` from the cache for ``cls``, creating that cache on
        first use.  Inherits CacheFactory.get()'s locking protocol: a
        None result leaves the per-class cache LOCKED and the caller
        must finish with put() + finishPut().
        """
        try:
            return self.caches[cls.__name__].get(id)
        except KeyError:
            self.caches[cls.__name__] = CacheFactory(*self.args, **self.kw)
            return self.caches[cls.__name__].get(id)

    def put(self, id, cls, obj):
        """Store ``obj`` after a .get() miss (per-class cache exists)."""
        self.caches[cls.__name__].put(id, obj)

    def finishPut(self, cls):
        """Release the per-class lock held since a .get() miss."""
        self.caches[cls.__name__].finishPut()

    def created(self, id, cls, obj):
        """Insert a brand-new object (e.g. right after an INSERT)."""
        try:
            self.caches[cls.__name__].created(id, obj)
        except KeyError:
            self.caches[cls.__name__] = CacheFactory(*self.args, **self.kw)
            self.caches[cls.__name__].created(id, obj)

    def expire(self, id, cls):
        """Drop ``id`` from the cache for ``cls`` (no-op if absent)."""
        try:
            self.caches[cls.__name__].expire(id)
        except KeyError:
            pass

    def clear(self, cls=None):
        """Clear one class's cache, or every cache when cls is None."""
        if cls is None:
            for cache in self.caches.values():
                cache.clear()
        elif cls.__name__ in self.caches:
            self.caches[cls.__name__].clear()

    def tryGet(self, id, cls):
        """Lock-free lookup; returns None on a miss."""
        return self.tryGetByName(id, cls.__name__)

    def tryGetByName(self, id, clsname):
        """Like tryGet(), but keyed by class name instead of class."""
        try:
            return self.caches[clsname].tryGet(id)
        except KeyError:
            return None

    def allIDs(self, cls):
        """
        Returns the IDs of all objects in the cache for ``cls``.

        Bug fix: the original called allIDs() but dropped the result,
        so this method always returned None; the value is now returned.
        """
        try:
            return self.caches[cls.__name__].allIDs()
        except KeyError:
            return []

    def allSubCaches(self):
        """All per-class CacheFactory objects."""
        return self.caches.values()

    def allSubCachesByClassNames(self):
        """The class-name -> CacheFactory mapping itself."""
        return self.caches

    def weakrefAll(self, cls=None):
        """
        Move all objects in the cls (or if not given, then in all
        classes) to the weakref dictionary, where they can be
        collected.
        """
        if cls is None:
            for cache in self.caches.values():
                cache.expireAll()
        elif cls.__name__ in self.caches:
            self.caches[cls.__name__].expireAll()

    def getAll(self, cls=None):
        """
        Returns all instances in the cache for the given class or all
        classes.
        """
        if cls is None:
            results = []
            for cache in self.caches.values():
                results.extend(cache.getAll())
            return results
        elif cls.__name__ in self.caches:
            return self.caches[cls.__name__].getAll()
        else:
            return []

View File

@@ -0,0 +1,135 @@
"""
classregistry.py
2 February 2004, Ian Bicking <ianb@colorstudy.com>
Resolves strings to classes, and runs callbacks when referenced
classes are created.
Classes are referred to only by name, not by module. So that
identically-named classes can coexist, classes are put into individual
registries, which are keyed on strings (names). These registries are
created on demand.
Use like::
>>> import classregistry
>>> registry = classregistry.registry('MyModules')
>>> def afterMyClassExists(cls):
... print 'Class finally exists:', cls
>>> registry.addClassCallback('MyClass', afterMyClassExists)
>>> class MyClass:
... pass
>>> registry.addClass(MyClass)
Class finally exists: MyClass
"""
class ClassRegistry(object):
    """
    We'll be dealing with classes that reference each other, so
    class C1 may reference C2 (in a join), while C2 references
    C1 right back.  Since classes are created in an order, there
    will be a point when C1 exists but C2 doesn't.  So we deal
    with classes by name, and after each class is created we
    try to fix up any references by replacing the names with
    actual classes.

    Classes are kept in a name -> class dictionary.  Since the
    classes might be spread among different modules, names must be
    globally unique (not just module unique) within one registry.
    """

    def __init__(self, name):
        self.name = name
        self.classes = {}
        self.callbacks = {}
        self.genericCallbacks = []

    def addClassCallback(self, className, callback, *args, **kw):
        """
        Whenever a name is substituted for the class, you can register
        a callback that will be called when the needed class is
        created.  If it's already been created, the callback will be
        called immediately.
        """
        if className not in self.classes:
            # Not known yet: queue the callback for addClass().
            self.callbacks.setdefault(className, []).append(
                (callback, args, kw))
            return
        callback(self.classes[className], *args, **kw)

    def addCallback(self, callback, *args, **kw):
        """
        This callback is called for all classes, not just specific
        ones (like addClassCallback).
        """
        self.genericCallbacks.append((callback, args, kw))
        # Fire immediately for everything already registered.
        for registered in self.classes.values():
            callback(registered, *args, **kw)

    def addClass(self, cls):
        """
        Everytime a class is created, we add it to the registry, so
        that other classes can find it by name.  We also call any
        callbacks that are waiting for the class.
        """
        className = cls.__name__
        if className in self.classes:
            import sys
            other = self.classes[className]
            raise ValueError(
                "class %s is already in the registry (other class is "
                "%r, from the module %s in %s; attempted new class is "
                "%r, from the module %s in %s)"
                % (className,
                   other, other.__module__,
                   getattr(sys.modules.get(other.__module__),
                           '__file__', '(unknown)'),
                   cls, cls.__module__,
                   getattr(sys.modules.get(cls.__module__),
                           '__file__', '(unknown)')))
        self.classes[className] = cls
        # Class-specific callbacks fire first, then the generic ones.
        waiting = self.callbacks.pop(className, [])
        for callback, args, kw in waiting:
            callback(cls, *args, **kw)
        for callback, args, kw in self.genericCallbacks:
            callback(cls, *args, **kw)

    def getClass(self, className):
        """Return the class registered under ``className`` or raise."""
        if className in self.classes:
            return self.classes[className]
        known = sorted(self.classes.keys())
        raise KeyError(
            "No class %s found in the registry %s (these classes "
            "exist: %s)"
            % (className, self.name or '[default]', ', '.join(known)))

    def allClasses(self):
        """Every class registered so far."""
        return self.classes.values()
class _MasterRegistry(object):
    """
    This singleton holds all the class registries.  There can be
    multiple registries to hold different unrelated sets of classes
    that reside in the same process.  These registries are named with
    strings, and are created on demand.  The MasterRegistry module
    global holds the singleton.
    """

    def __init__(self):
        self.registries = {}

    def registry(self, item):
        """Return the registry named ``item``, creating it on demand."""
        try:
            return self.registries[item]
        except KeyError:
            fresh = ClassRegistry(item)
            self.registries[item] = fresh
            return fresh
# Module-level singleton holding every named registry.
MasterRegistry = _MasterRegistry()
# Convenience alias: registry(name) returns (creating on demand) the
# ClassRegistry for that name.
registry = MasterRegistry.registry
def findClass(name, class_registry=None):
    # Look up a class by name; class_registry=None selects the default
    # (None-keyed) registry.
    return registry(class_registry).getClass(name)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,68 @@
"""
This module is used by py.test to configure testing for this
application.
"""
# Override some options (doesn't override command line):
verbose = 0
exitfirst = True
import py
import os
import sqlobject
# pkg_resources is unavailable on very old Pythons / without
# setuptools; only require the SQLObject distribution when present.
try:
    import pkg_resources
except ImportError: # Python 2.2
    pass
else:
    pkg_resources.require('SQLObject')
# Aliases for the -D/--Database option: short backend name -> full
# connection URI used by the test suite.
connectionShortcuts = {
    'mysql': 'mysql://test@localhost/test',
    'dbm': 'dbm:///data',
    'postgres': 'postgres:///test',
    'postgresql': 'postgres:///test',
    # Bug fix: the URI scheme was misspelled 'rdhbost', which could
    # never match the 'rdbhost' connection scheme.
    'rdbhost': 'rdbhost://role:authcode@www.rdbhost.com/',
    'pygresql': 'pygresql://localhost/test',
    'sqlite': 'sqlite:/:memory:',
    'sybase': 'sybase://test:test123@sybase/test?autoCommit=0',
    'firebird': 'firebird://sysdba:masterkey@localhost/var/lib/firebird/data/test.gdb',
    'mssql': 'mssql://sa:@127.0.0.1/test'
}
Option = py.test.config.Option
# Register SQLObject-specific command-line options with py.test
# (legacy py.test config API).
option = py.test.config.addoptions(
    "SQLObject options",
    Option('-D', '--Database',
           action="store", dest="Database", default='sqlite',
           help="The database to run the tests under (default sqlite). "
           "Can also use an alias from: %s"
           % (', '.join(connectionShortcuts.keys()))),
    Option('-S', '--SQL',
           action="store_true", dest="show_sql", default=False,
           help="Show SQL from statements (when capturing stdout the "
           "SQL is only displayed when a test fails)"),
    Option('-O', '--SQL-output',
           action="store_true", dest="show_sql_output", default=False,
           help="Show output from SQL statements (when capturing "
           "stdout the output is only displayed when a test fails)"),
    Option('-E', '--events',
           action="store_true", dest="debug_events", default=False,
           help="Debug events (print information about events as they are "
           "sent)"),
    )
class SQLObjectClass(py.test.collect.Class):
    # Custom py.test collector: SQLObject subclasses found in test
    # modules are table definitions, not test containers, so collect
    # nothing from them.
    def run(self):
        if (isinstance(self.obj, type)
            and issubclass(self.obj, sqlobject.SQLObject)):
            return []
        return super(SQLObjectClass, self).run()
Class = SQLObjectClass
def setup_tests():
    # Test-session setup hook: turn on event debugging when the -E
    # (--events) option was given on the command line.
    if option.debug_events:
        from sqlobject import events
        events.debug_events()

View File

@@ -0,0 +1,63 @@
"""
Constraints
"""
class BadValue(ValueError):
    """
    Raised when a column constraint rejects a value.

    Carries the constraint description, the column, and reprs of the
    offending object and value.
    """

    def __init__(self, desc, obj, col, value, *args):
        self.desc = desc
        self.col = col
        # Store reprs only, so the exception never keeps the real
        # object/value alive (they stay garbage-collectable).
        self.obj = repr(obj)
        self.value = repr(value)
        message = "%s.%s %s (you gave: %s)" % (obj, col.name, desc, value)
        ValueError.__init__(self, message, *args)
def isString(obj, col, value):
    """Constraint: reject any value that is not a str."""
    if isinstance(value, str):
        return
    raise BadValue("only allows strings", obj, col, value)
def notNull(obj, col, value):
    """Constraint: reject None for NOT NULL columns."""
    if value is not None:
        return
    raise BadValue("is defined NOT NULL", obj, col, value)
def isInt(obj, col, value):
    """Constraint: reject anything that is not an int (or long)."""
    if isinstance(value, (int, long)):
        return
    raise BadValue("only allows integers", obj, col, value)
def isFloat(obj, col, value):
    """Constraint: accept only numbers (int, long, float)."""
    if isinstance(value, (int, long, float)):
        return
    raise BadValue("only allows floating point numbers", obj, col, value)
def isBool(obj, col, value):
    """Constraint: reject anything that is not a bool."""
    if isinstance(value, bool):
        return
    raise BadValue("only allows booleans", obj, col, value)
class InList:
    """Constraint factory: the value must be a member of a fixed list."""

    def __init__(self, l):
        self.list = l

    def __call__(self, obj, col, value):
        if value in self.list:
            return
        raise BadValue("accepts only values in %s" % repr(self.list),
                       obj, col, value)
class MaxLength:
    """Constraint factory: len(value) must not exceed a maximum."""

    def __init__(self, length):
        self.length = length

    def __call__(self, obj, col, value):
        try:
            actual = len(value)
        except TypeError:
            raise BadValue("object does not have a length",
                           obj, col, value)
        if actual > self.length:
            raise BadValue("must be shorter in length than %s"
                           % self.length,
                           obj, col, value)

View File

@@ -0,0 +1,215 @@
import sys
from array import array
# Jython doesn't have the buffer sequence type (bug #1521).
# using this workaround instead.
# NOTE: Python-2-only "except E, e" syntax; the bound ``e`` is unused.
try:
    buffer
except NameError, e:
    buffer = str
try:
import mx.DateTime.ISO
origISOStr = mx.DateTime.ISO.strGMT
from mx.DateTime import DateTimeType, DateTimeDeltaType
except ImportError:
try:
import DateTime.ISO
origISOStr = DateTime.ISO.strGMT
from DateTime import DateTimeType, DateTimeDeltaType
except ImportError:
origISOStr = None
DateTimeType = None
DateTimeDeltaType = None
import time
import datetime
try:
import Sybase
NumericType=Sybase.NumericType
except ImportError:
NumericType = None
from decimal import Decimal
from types import ClassType, InstanceType, NoneType
########################################
## Quoting
########################################
# (original, replacement) escape pairs, applied IN ORDER by
# StringLikeConverter for the backslash-escaping backends
# (mysql/postgres/rdbhost).
sqlStringReplace = [
    ("'", "''"),
    ('\\', '\\\\'),
    ('\000', '\\0'),
    ('\b', '\\b'),
    ('\n', '\\n'),
    ('\r', '\\r'),
    ('\t', '\\t'),
]
def isoStr(val):
    """
    Gets rid of time zone information
    (@@: should we convert to GMT?)
    """
    formatted = origISOStr(val)
    plus = formatted.find('+')
    if plus == -1:
        return formatted
    return formatted[:plus]
class ConverterRegistry:
    """
    Maps Python types (and old-style classes) to functions that render
    values of that type as SQL literals.
    """

    def __init__(self):
        # New-style types and old-style classes are kept in separate
        # tables because their lookup mechanics differ.
        self.basic = {}
        self.klass = {}

    def registerConverter(self, typ, func):
        if type(typ) is ClassType:
            table = self.klass
        else:
            table = self.basic
        table[typ] = func

    def lookupConverter(self, value, default=None):
        if type(value) is InstanceType:
            # Old-style instance: key on its class.
            return self.klass.get(value.__class__, default)
        return self.basic.get(type(value), default)
# Module-level registry singleton plus convenience aliases used by the
# converter definitions below and by sqlrepr().
converters = ConverterRegistry()
registerConverter = converters.registerConverter
lookupConverter = converters.lookupConverter
def StringLikeConverter(value, db):
    """
    Quote str/unicode/array/buffer values as SQL string literals using
    the escaping rules of the given backend name ``db``.
    """
    if isinstance(value, array):
        # array.array: prefer unicode contents, fall back to raw bytes.
        try:
            value = value.tounicode()
        except ValueError:
            value = value.tostring()
    elif isinstance(value, buffer):
        # buffer (or the Jython shim above): copy out as a plain string.
        value = str(value)
    if db in ('mysql', 'postgres', 'rdbhost'):
        # Backslash-escaping backends: apply the full table, in order.
        for orig, repl in sqlStringReplace:
            value = value.replace(orig, repl)
    elif db in ('sqlite', 'firebird', 'sybase', 'maxdb', 'mssql'):
        # Standard SQL escaping: only single quotes need doubling.
        value = value.replace("'", "''")
    else:
        # NOTE(review): assert-based validation disappears under
        # ``python -O``; an explicit exception would be safer.
        assert 0, "Database %s unknown" % db
    return "'%s'" % value
registerConverter(str, StringLikeConverter)
registerConverter(unicode, StringLikeConverter)
registerConverter(array, StringLikeConverter)
registerConverter(buffer, StringLikeConverter)
def IntConverter(value, db):
    # int() first so e.g. Sybase NumericType values render as plain
    # integer literals.
    return repr(int(value))
registerConverter(int, IntConverter)
def LongConverter(value, db):
    # str() of a Python 2 long omits the trailing 'L' that repr() adds.
    return str(value)
registerConverter(long, LongConverter)
if NumericType:
    registerConverter(NumericType, IntConverter)
def BoolConverter(value, db):
    # PostgreSQL-family backends use 't'/'f' literals; all others 1/0.
    if db in ('postgres', 'rdbhost'):
        if value:
            return "'t'"
        else:
            return "'f'"
    else:
        if value:
            return '1'
        else:
            return '0'
registerConverter(bool, BoolConverter)
def FloatConverter(value, db):
    # repr() preserves full float precision.
    return repr(value)
registerConverter(float, FloatConverter)
if DateTimeType:
    # mx.DateTime support; only defined when the mx import above
    # succeeded.
    def DateTimeConverter(value, db):
        return "'%s'" % isoStr(value)
    registerConverter(DateTimeType, DateTimeConverter)
    def TimeConverter(value, db):
        # "%T" is equivalent to "%H:%M:%S".
        return "'%s'" % value.strftime("%T")
    registerConverter(DateTimeDeltaType, TimeConverter)
def NoneConverter(value, db):
    return "NULL"
registerConverter(NoneType, NoneConverter)
def SequenceConverter(value, db):
    # Render Python sequences/sets as SQL tuples: (a, b, c).  Each
    # element goes through sqlrepr (defined below; resolved at call
    # time).
    # NOTE(review): for dicts, iteration yields only the keys.
    return "(%s)" % ", ".join([sqlrepr(v, db) for v in value])
registerConverter(tuple, SequenceConverter)
registerConverter(list, SequenceConverter)
registerConverter(dict, SequenceConverter)
registerConverter(set, SequenceConverter)
registerConverter(frozenset, SequenceConverter)
if sys.version_info[:3] < (2, 6, 0): # Module sets was deprecated in Python 2.6
    from sets import Set, ImmutableSet
    registerConverter(Set, SequenceConverter)
    registerConverter(ImmutableSet, SequenceConverter)
if hasattr(time, 'struct_time'):
    def StructTimeConverter(value, db):
        return time.strftime("'%Y-%m-%d %H:%M:%S'", value)
    registerConverter(time.struct_time, StructTimeConverter)
def DateTimeConverter(value, db):
    # Rebinding this module-level name shadows the mx DateTimeConverter
    # defined above; the mx registration already captured the old
    # function object, so both remain in effect.
    return "'%04d-%02d-%02d %02d:%02d:%02d'" % (
        value.year, value.month, value.day,
        value.hour, value.minute, value.second)
registerConverter(datetime.datetime, DateTimeConverter)
def DateConverter(value, db):
    return "'%04d-%02d-%02d'" % (value.year, value.month, value.day)
registerConverter(datetime.date, DateConverter)
def TimeConverter(value, db):
    # Likewise shadows the mx TimeConverter name; registrations are
    # unaffected.
    return "'%02d:%02d:%02d'" % (value.hour, value.minute, value.second)
registerConverter(datetime.time, TimeConverter)
def DecimalConverter(value, db):
    # See http://mail.python.org/pipermail/python-dev/2008-March/078189.html
    return str(value.to_eng_string()) # Convert to str to work around a bug in Python 2.5.2
registerConverter(Decimal, DecimalConverter)
def TimedeltaConverter(value, db):
    return """INTERVAL '%d days %d seconds'""" % \
        (value.days, value.seconds)
registerConverter(datetime.timedelta, TimedeltaConverter)
def sqlrepr(obj, db=None):
    """
    Render ``obj`` as a SQL literal for the backend named ``db``.

    If the object provides a ``__sqlrepr__(db)`` method that takes
    precedence; otherwise a converter registered for the object's type
    is looked up.

    Raises ValueError when no converter is known for the type.
    """
    try:
        reprFunc = obj.__sqlrepr__
    except AttributeError:
        converter = lookupConverter(obj)
        if converter is None:
            # Bug-proofing: the original used the deprecated
            # "raise E, value" statement form (a syntax error on
            # Python 3); the call form is equivalent on Python 2.
            raise ValueError("Unknown SQL builtin type: %s for %s" %
                             (type(obj), repr(obj)))
        return converter(obj, db)
    else:
        return reprFunc(db)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,21 @@
"""dberrors: database exception classes for SQLObject.
These classes are dictated by the DB API v2.0:
http://www.python.org/topics/database/DatabaseAPI-2.0.html
"""
# Root exception types.  NOTE: StandardError is Python-2-only (it was
# removed in Python 3, where Exception is the conventional base).
class Error(StandardError): pass
# Intentionally shadows the built-in Warning, per DB API 2.0 naming.
class Warning(StandardError): pass
class InterfaceError(Error): pass
class DatabaseError(Error): pass
# Finer-grained database error categories from DB API 2.0.
class InternalError(DatabaseError): pass
class OperationalError(DatabaseError): pass
class ProgrammingError(DatabaseError): pass
class IntegrityError(DatabaseError): pass
class DataError(DatabaseError): pass
class NotSupportedError(DatabaseError): pass
# SQLObject extension: raised on duplicate-key INSERTs.
class DuplicateEntryError(IntegrityError): pass

View File

@@ -0,0 +1,204 @@
"""
Declarative objects.
Declarative objects have a simple protocol: you can use classes in
lieu of instances and they are equivalent, and any keyword arguments
you give to the constructor will override those instance variables.
(So if a class is received, we'll simply instantiate an instance with
no arguments).
You can provide a variable __unpackargs__ (a list of strings), and if
the constructor is called with non-keyword arguments they will be
interpreted as the given keyword arguments.
If __unpackargs__ is ('*', name), then all the arguments will be put
in a variable by that name.
You can define a __classinit__(cls, new_attrs) method, which will be
called when the class is created (including subclasses). Note: you
can't use super() in __classinit__ because the class isn't bound to a
name. As an analog to __classinit__, Declarative adds
__instanceinit__ which is called with the same argument (new_attrs).
This is like __init__, but after __unpackargs__ and other factors have
been taken into account.
If __mutableattributes__ is defined as a sequence of strings, these
attributes will not be shared between superclasses and their
subclasses. E.g., if you have a class variable that contains a list
and you append to that list, changes to subclasses will effect
superclasses unless you add the attribute here.
Also defines classinstancemethod, which acts as either a class method
or an instance method depending on where it is called.
"""
import copy
import events
import itertools
counter = itertools.count()
__all__ = ('classinstancemethod', 'DeclarativeMeta', 'Declarative')
class classinstancemethod(object):
    """
    Descriptor that behaves as a class method when reached through the
    class and as an instance method when reached through an instance.
    The wrapped function must take ('self', 'cls') as its first two
    parameters; exactly one of them is None, depending on how the
    method was reached.
    """

    def __init__(self, func):
        self.func = func

    def __get__(self, obj, type=None):
        return _methodwrapper(self.func, obj=obj, type=type)


class _methodwrapper(object):
    """Bound callable produced by ``classinstancemethod.__get__``."""

    def __init__(self, func, obj, type):
        self.func = func
        self.obj = obj
        self.type = type

    def __call__(self, *args, **kw):
        assert 'self' not in kw and 'cls' not in kw, (
            "You cannot use 'self' or 'cls' arguments to a "
            "classinstancemethod")
        full_args = (self.obj, self.type) + args
        return self.func(*full_args, **kw)

    def __repr__(self):
        if self.obj is None:
            return ('<bound class method %s.%s>'
                    % (self.type.__name__, self.func.func_name))
        return ('<bound method %s.%s of %r>'
                % (self.type.__name__, self.func.func_name, self.obj))
class DeclarativeMeta(type):
    # Metaclass for Declarative: fires ClassCreateSignal around class
    # creation and invokes __classinit__ on every (sub)class.
    def __new__(meta, class_name, bases, new_attrs):
        post_funcs = []
        early_funcs = []
        # Let listeners inspect/extend the class dict before the class
        # object exists; the sender is the first base class.
        events.send(events.ClassCreateSignal,
                    bases[0], class_name, bases, new_attrs,
                    post_funcs, early_funcs)
        cls = type.__new__(meta, class_name, bases, new_attrs)
        # early_funcs run as soon as the class object exists, before
        # __classinit__.
        for func in early_funcs:
            func(cls)
        if '__classinit__' in new_attrs:
            # Unwrap to a plain function (Python 2 im_func) so it can
            # be called explicitly with the class as argument.
            cls.__classinit__ = staticmethod(cls.__classinit__.im_func)
        cls.__classinit__(cls, new_attrs)
        # post_funcs run last, after class initialization completes.
        for func in post_funcs:
            func(cls)
        return cls
class Declarative(object):
__unpackargs__ = ()
__mutableattributes__ = ()
__metaclass__ = DeclarativeMeta
__restrict_attributes__ = None
def __classinit__(cls, new_attrs):
cls.declarative_count = counter.next()
for name in cls.__mutableattributes__:
if name not in new_attrs:
setattr(cls, copy.copy(getattr(cls, name)))
def __instanceinit__(self, new_attrs):
if self.__restrict_attributes__ is not None:
for name in new_attrs:
if name not in self.__restrict_attributes__:
raise TypeError(
'%s() got an unexpected keyword argument %r'
% (self.__class__.__name__, name))
for name, value in new_attrs.items():
setattr(self, name, value)
if 'declarative_count' not in new_attrs:
self.declarative_count = counter.next()
def __init__(self, *args, **kw):
if self.__unpackargs__ and self.__unpackargs__[0] == '*':
assert len(self.__unpackargs__) == 2, \
"When using __unpackargs__ = ('*', varname), you must only provide a single variable name (you gave %r)" % self.__unpackargs__
name = self.__unpackargs__[1]
if name in kw:
raise TypeError(
"keyword parameter '%s' was given by position and name"
% name)
kw[name] = args
else:
if len(args) > len(self.__unpackargs__):
raise TypeError(
'%s() takes at most %i arguments (%i given)'
% (self.__class__.__name__,
len(self.__unpackargs__),
len(args)))
for name, arg in zip(self.__unpackargs__, args):
if name in kw:
raise TypeError(
"keyword parameter '%s' was given by position and name"
% name)
kw[name] = arg
if '__alsocopy' in kw:
for name, value in kw['__alsocopy'].items():
if name not in kw:
if name in self.__mutableattributes__:
value = copy.copy(value)
kw[name] = value
del kw['__alsocopy']
self.__instanceinit__(kw)
def __call__(self, *args, **kw):
kw['__alsocopy'] = self.__dict__
return self.__class__(*args, **kw)
@classinstancemethod
def singleton(self, cls):
if self:
return self
name = '_%s__singleton' % cls.__name__
if not hasattr(cls, name):
setattr(cls, name, cls(declarative_count=cls.declarative_count))
return getattr(cls, name)
@classinstancemethod
def __repr__(self, cls):
    """Repr for both the class and its instances: name, the
    ``declarative_count`` ordinal, and the public attributes."""
    if self:
        name = '%s object' % self.__class__.__name__
        v = self.__dict__.copy()
    else:
        name = '%s class' % cls.__name__
        v = cls.__dict__.copy()
    if 'declarative_count' in v:
        name = '%s %i' % (name, v['declarative_count'])
        del v['declarative_count']
    # @@: simplifying repr:
    #v = {}
    names = v.keys()
    args = []
    # NOTE(review): uses self._repr_vars even in the class-call branch --
    # confirm classinstancemethod supplies a usable ``self`` there.
    for n in self._repr_vars(names):
        args.append('%s=%r' % (n, v[n]))
    if not args:
        return '<%s>' % name
    else:
        return '<%s %s>' % (name, ' '.join(args))
@staticmethod
def _repr_vars(dictNames):
    """Return the sorted subset of ``dictNames`` worth showing in repr():
    public names only, excluding the bookkeeping ``declarative_count``."""
    return sorted(
        n for n in dictNames
        if not n.startswith('_') and n != 'declarative_count')
def setup_attributes(cls, new_attrs):
    """Give every bound attribute in ``new_attrs`` a chance to attach itself.

    A value providing an ``__addtoclass__(cls, name)`` hook is notified;
    plain values are left alone.  The hook's return value is ignored.
    """
    _missing = object()
    for attr_name, attr_value in new_attrs.items():
        hook = getattr(attr_value, '__addtoclass__', _missing)
        if hook is not _missing:
            hook(cls, attr_name)

View File

@@ -0,0 +1,316 @@
import sys
import types
from sqlobject.include.pydispatch import dispatcher
from weakref import ref
subclassClones = {}
def listen(receiver, soClass, signal, alsoSubclasses=True, weak=True):
    """
    Listen for the given ``signal`` on the SQLObject subclass
    ``soClass``, calling ``receiver()`` when ``send(soClass, signal,
    ...)`` is called.

    If ``alsoSubclasses`` is true, receiver will also be called when
    an event is fired on any subclass.
    """
    dispatcher.connect(receiver, signal=signal, sender=soClass, weak=weak)
    # Bug fix: ``alsoSubclasses`` was documented but never consulted, so
    # every listener was propagated to future subclasses unconditionally.
    if alsoSubclasses:
        # Track the receiver weakly so subclass creation can re-register
        # it (see _makeSubclassConnectionsPost) without keeping it alive.
        weakReceiver = ref(receiver)
        subclassClones.setdefault(soClass, []).append((weakReceiver, signal))
# We export this function:
send = dispatcher.send
class Signal(object):
    """
    Base event for all SQLObject events.

    In general the sender for these signals is the class, not the
    instance.
    """
class ClassCreateSignal(Signal):
    """
    Signal raised during class creation.  The sender is the superclass
    (in case of multiple superclasses, the first superclass).  The
    arguments are ``(new_class_name, bases, new_attrs, post_funcs,
    early_funcs)``.  ``new_attrs`` is a dictionary and may be modified
    (but ``new_class_name`` and ``bases`` are immutable).
    ``post_funcs`` is an initially-empty list that can have callbacks
    appended to it.

    Note: at the time this event is fired, the new class has not yet
    been created.  The functions in ``post_funcs`` will be called
    after the class is created, with the single argument
    ``(new_class)``.  Also, ``early_funcs`` will be called at the
    soonest possible time after class creation (``post_funcs`` is
    called after the class's ``__classinit__``).
    """
def _makeSubclassConnections(new_class_name, bases, new_attrs,
                             post_funcs, early_funcs):
    """ClassCreateSignal listener: schedule subclass re-wiring to run as
    early as possible once the new class object exists."""
    early_funcs.insert(0, _makeSubclassConnectionsPost)
def _makeSubclassConnectionsPost(new_class):
    # Re-register every still-alive receiver recorded against a base
    # class, so listeners follow newly created subclasses automatically.
    for cls in new_class.__bases__:
        for weakReceiver, signal in subclassClones.get(cls, []):
            receiver = weakReceiver()
            if not receiver:
                # Receiver has been garbage-collected; skip silently.
                continue
            listen(receiver, new_class, signal)
# Hook the subclass propagation into every class creation.
dispatcher.connect(_makeSubclassConnections, signal=ClassCreateSignal)
# @@: Should there be a class reload event? This would allow modules
# to be reloaded, possibly. Or it could even be folded into
# ClassCreateSignal, since anything that listens to that needs to pay
# attention to reloads (or else it is probably buggy).
class RowCreateSignal(Signal):
    """
    Called before an instance is created, with the class as the
    sender.  Called with the arguments ``(instance, kwargs, post_funcs)``.
    There may be a ``connection`` argument.  ``kwargs`` may be usefully
    modified.  ``post_funcs`` is a list of callbacks, intended to have
    functions appended to it; they are called with the argument
    ``(new_instance)``.

    Note: this is not called when an instance is created from an
    existing database row.
    """
class RowCreatedSignal(Signal):
    """
    Called after an instance is created, with the class as the
    sender.  Called with the arguments ``(instance, kwargs, post_funcs)``.
    There may be a ``connection`` argument.  ``kwargs`` may be usefully
    modified.  ``post_funcs`` is a list of callbacks, intended to have
    functions appended to it; they are called with the argument
    ``(new_instance)``.

    Note: this is not called when an instance is created from an
    existing database row.
    """
# @@: An event for getting a row? But for each row, when doing a
# select? For .sync, .syncUpdate, .expire?
class RowDestroySignal(Signal):
    """
    Called before an instance is deleted.  Sender is the instance's
    class.  Arguments are ``(instance, post_funcs)``.
    ``post_funcs`` is a list of callbacks, intended to have
    functions appended to it; they are called with the argument
    ``(instance)``.
    If any of the post_funcs raises an exception, the deletion is only
    affected if this will prevent a commit.

    You cannot cancel the delete, but you can raise an exception (which
    will probably cancel the delete, but also cause an uncaught exception
    if not expected).

    Note: this is not called when an instance is destroyed through
    garbage collection.

    @@: Should this allow ``instance`` to be a primary key, so that a
    row can be deleted without first fetching it?
    """
class RowDestroyedSignal(Signal):
    """
    Called after an instance is deleted.  Sender is the instance's
    class.  Arguments are ``(instance)``.
    This is called before the post_funcs of RowDestroySignal.

    Note: this is not called when an instance is destroyed through
    garbage collection.
    """
class RowUpdateSignal(Signal):
    """
    Called when an instance is updated through a call to ``.set()``
    (or a column attribute assignment).  The arguments are
    ``(instance, kwargs)``.  ``kwargs`` can be modified.  This is run
    *before* the instance is updated; if you want to look at the
    current values, simply look at ``instance``.
    """
class RowUpdatedSignal(Signal):
    """
    Called when an instance is updated through a call to ``.set()``
    (or a column attribute assignment).  The arguments are
    ``(instance, post_funcs)``.  ``post_funcs`` is a list of callbacks,
    intended to have functions appended to it; they are called with the
    argument ``(new_instance)``.  This is run *after* the instance is
    updated; works better with ``lazyUpdate = True``.
    """
class AddColumnSignal(Signal):
    """
    Called when a column is added to a class, with arguments ``(cls,
    connection, column_name, column_definition, changeSchema,
    post_funcs)``.  This is called *after* the column has been added,
    and is called for each column after class creation.

    post_funcs are called with ``(cls, so_column_obj)``
    """
class DeleteColumnSignal(Signal):
    """
    Called when a column is removed from a class, with the arguments
    ``(cls, connection, column_name, so_column_obj, post_funcs)``.
    Like ``AddColumnSignal`` this is called after the action has been
    performed, and is called for subclassing (when a column is
    implicitly removed by setting it to ``None``).

    post_funcs are called with ``(cls, so_column_obj)``
    """
# @@: Signals for indexes and joins? These are mostly event consumers,
# though.
class CreateTableSignal(Signal):
    """
    Called when a table is created.  If ``ifNotExists==True`` and the
    table exists, this event is not called.

    Called with ``(cls, connection, extra_sql, post_funcs)``.
    ``extra_sql`` is a list (which can be appended to) of extra SQL
    statements to be run after the table is created.  ``post_funcs``
    functions are called with ``(cls, connection)`` after the table
    has been created.  Those functions are *not* called when merely
    constructing the SQL.
    """
class DropTableSignal(Signal):
    """
    Called when a table is dropped.  If ``ifExists==True`` and the
    table doesn't exist, this event is not called.

    Called with ``(cls, connection, extra_sql, post_funcs)``.
    ``post_funcs`` functions are called with ``(cls, connection)``
    after the table has been dropped.
    """
############################################################
## Event Debugging
############################################################
def summarize_events_by_sender(sender=None, output=None, indent=0):
    """
    Prints out a summary of the senders and listeners in the system,
    for debugging purposes.

    With no ``sender``, walks every registered sender and recurses once
    per sender; with a ``sender``, lists its signals and live receivers.
    ``output`` defaults to sys.stdout.
    """
    if output is None:
        output = sys.stdout
    if sender is None:
        # Collect (sender, listeners) pairs for senders that are still alive.
        send_list = [
            (deref(dispatcher.senders.get(sid)), listeners)
            for sid, listeners in dispatcher.connections.items()
            if deref(dispatcher.senders.get(sid))]
        for sender, listeners in sorted_items(send_list):
            real_sender = deref(sender)
            if not real_sender:
                continue
            header = 'Sender: %r' % real_sender
            print >> output, (' '*indent) + header
            print >> output, (' '*indent) + '='*len(header)
            summarize_events_by_sender(real_sender, output=output, indent=indent+2)
    else:
        for signal, receivers in sorted_items(dispatcher.connections.get(id(sender), [])):
            # Drop dead weakrefs before counting/printing.
            receivers = [deref(r) for r in receivers if deref(r)]
            header = 'Signal: %s (%i receivers)' % (sort_name(signal),
                                                    len(receivers))
            print >> output, (' '*indent) + header
            print >> output, (' '*indent) + '-'*len(header)
            for receiver in sorted(receivers, key=sort_name):
                print >> output, (' '*indent) + ' ' + nice_repr(receiver)
def deref(value):
    """Resolve ``value`` if it is a weak reference; otherwise return it
    unchanged (a dead weakref therefore yields None)."""
    if not isinstance(value, dispatcher.WEAKREF_TYPES):
        return value
    return value()
def sorted_items(a_dict):
    """Return (key, value) pairs ordered by ``sort_name(key)``.

    Accepts either a dict or an already-built pair sequence.
    """
    pairs = a_dict.items() if isinstance(a_dict, dict) else a_dict
    return sorted(pairs, key=lambda pair: sort_name(pair[0]))
def sort_name(value):
    """Return a human-friendly sort key for ``value``.

    Classes and functions sort by their name; everything else by str().
    Uses ``__name__`` (identical to the Python-2-only ``func_name`` alias
    for functions, but portable) so this also works on Python 3.
    """
    if isinstance(value, type):
        return value.__name__
    elif isinstance(value, types.FunctionType):
        return value.__name__
    else:
        return str(value)
# Keep handles on the real dispatcher entry points so the _debug_*
# wrappers below can delegate to them after debug_events() swaps them in.
_real_dispatcher_send = dispatcher.send
_real_dispatcher_sendExact = dispatcher.sendExact
_real_dispatcher_disconnect = dispatcher.disconnect
_real_dispatcher_connect = dispatcher.connect
_debug_enabled = False
def debug_events():
    """Monkeypatch the dispatcher (and this module's ``send``) with
    wrappers that print every send/connect/disconnect.  Idempotent."""
    global _debug_enabled, send
    if _debug_enabled:
        return
    _debug_enabled = True
    dispatcher.send = send = _debug_send
    dispatcher.sendExact = _debug_sendExact
    dispatcher.disconnect = _debug_disconnect
    dispatcher.connect = _debug_connect
def _debug_send(signal=dispatcher.Any, sender=dispatcher.Anonymous,
                *arguments, **named):
    """Tracing replacement for dispatcher.send: log the event, then
    delegate to the saved real implementation."""
    print "send %s from %s: %s" % (
        nice_repr(signal), nice_repr(sender), fmt_args(*arguments, **named))
    return _real_dispatcher_send(signal, sender, *arguments, **named)
def _debug_sendExact(signal=dispatcher.Any, sender=dispatcher.Anonymous,
*arguments, **named):
print "sendExact %s from %s: %s" % (
nice_repr(signal), nice_repr(sender), fmt_args(*arguments, **name))
return _real_dispatcher_sendExact(signal, sender, *arguments, **named)
def _debug_connect(receiver, signal=dispatcher.Any, sender=dispatcher.Any,
weak=True):
print "connect %s to %s signal %s" % (
nice_repr(receiver), nice_repr(signal), nice_repr(sender))
return _real_dispatcher_connect(receiver, signal, sender, weak)
def _debug_disconnect(receiver, signal=dispatcher.Any, sender=dispatcher.Any,
weak=True):
print "disconnecting %s from %s signal %s" % (
nice_repr(receiver), nice_repr(signal), nice_repr(sender))
return disconnect(receiver, signal, sender, weak)
def fmt_args(*arguments, **name):
    """Render positional and keyword arguments the way they would appear
    in a call expression, keywords sorted by name."""
    rendered = [repr(value) for value in arguments]
    for key in sorted(name):
        rendered.append('%s=%r' % (key, name[key]))
    return ', '.join(rendered)
def nice_repr(v):
    """
    Like repr(), but nicer for debugging here: classes and module-level
    functions render as dotted module paths, bound methods show their
    class and instance.  (Uses Python 2 attribute spellings.)
    """
    if isinstance(v, (types.ClassType, type)):
        return v.__module__ + '.' + v.__name__
    elif isinstance(v, types.FunctionType):
        # Only qualify a function with its module when it is actually
        # reachable there under its own name (i.e. not a local/lambda).
        if '__name__' in v.func_globals:
            if getattr(sys.modules[v.func_globals['__name__']],
                       v.func_name, None) is v:
                return '%s.%s' % (v.func_globals['__name__'], v.func_name)
        return repr(v)
    elif isinstance(v, types.MethodType):
        return '%s.%s of %s' % (
            nice_repr(v.im_class), v.im_func.func_name,
            nice_repr(v.im_self))
    else:
        return repr(v)
__all__ = ['listen', 'send']
# Export every Signal subclass defined above; collected dynamically so a
# newly added signal class is published without touching this list.
for name, value in globals().items():
    if isinstance(value, type) and issubclass(value, Signal):
        __all__.append(name)

View File

@@ -0,0 +1,7 @@
from sqlobject.dbconnection import registerConnection
def builder():
    """Lazily import the Firebird backend so the kinterbasdb driver is
    only required when a firebird:// or interbase:// URI is used."""
    import firebirdconnection
    return firebirdconnection.FirebirdConnection
# Both URI schemes map to the same backend.
registerConnection(['firebird', 'interbase'], builder)

View File

@@ -0,0 +1,232 @@
import re
import os
from sqlobject.dbconnection import DBAPI
from sqlobject import col
class FirebirdConnection(DBAPI):
    # Transactions are emulated per-call in _runWithConnection instead.
    supportTransactions = False
    dbName = 'firebird'
    schemes = [dbName]
    # Captures "SELECT " plus the remainder so FIRST/SKIP can be spliced
    # right after the SELECT keyword (see _queryAddLimitOffset).
    limit_re = re.compile('^\s*(select )(.*)', re.IGNORECASE)
def __init__(self, host, port, db, user='sysdba',
             password='masterkey', autoCommit=1,
             dialect=None, role=None, charset=None, **kw):
    """Store Firebird connection parameters.

    kinterbasdb is imported here so the driver is only needed when a
    Firebird connection is actually constructed.  Defaults match the
    stock sysdba/masterkey credentials.

    NOTE(review): ``autoCommit`` is accepted but never stored or
    forwarded to DBAPI.__init__ -- confirm whether it should be.
    """
    import kinterbasdb
    self.module = kinterbasdb
    self.host = host
    self.port = port
    self.db = db
    self.user = user
    self.password = password
    if dialect:
        self.dialect = int(dialect)
    else:
        self.dialect = None
    self.role = role
    self.charset = charset
    DBAPI.__init__(self, **kw)
@classmethod
def _connectionFromParams(cls, auth, password, host, port, path, args):
    """Build a connection from a parsed firebird:// URI, falling back to
    the stock sysdba/masterkey credentials when the URI omits them."""
    if not password:
        password = 'masterkey'
    if not auth:
        auth='sysdba'
    # check for alias using
    # A path whose last component lacks an fdb/gdb extension is treated
    # as a database alias: strip the leading slash and use OS separators
    # so the server resolves it as an alias, not an absolute file path.
    if (path[0] == '/') and path[-3:].lower() not in ('fdb', 'gdb'):
        path = path[1:]
        path = path.replace('/', os.sep)
    return cls(host, port, db=path, user=auth, password=password, **args)
def _runWithConnection(self, meth, *args):
    """Run ``meth(conn, *args)`` on a pooled connection.

    In autoCommit mode Firebird still requires an explicit transaction,
    so one is begun and committed around the call; ProgrammingError from
    begin/commit (e.g. a transaction already active) is ignored.
    """
    if not self.autoCommit:
        # Bug fix: ``args`` was forwarded as a single tuple argument
        # instead of being re-expanded with ``*args``, so ``meth`` was
        # called with the wrong arity.
        return DBAPI._runWithConnection(self, meth, *args)
    conn = self.getConnection()
    # @@: Horrible auto-commit implementation. Just horrible!
    try:
        conn.begin()
    except self.module.ProgrammingError:
        pass
    try:
        val = meth(conn, *args)
        try:
            conn.commit()
        except self.module.ProgrammingError:
            pass
    finally:
        self.releaseConnection(conn)
    return val
def _setAutoCommit(self, conn, auto):
    """Intentional no-op: autocommit is emulated in _runWithConnection
    rather than set on the underlying kinterbasdb connection."""
    # Only _runWithConnection does "autocommit", so we don't
    # need to worry about that.
    pass
def makeConnection(self):
    """Open a new kinterbasdb connection using the stored parameters;
    ``dialect`` is passed through only when explicitly configured."""
    extra = {}
    if self.dialect:
        extra['dialect'] = self.dialect
    return self.module.connect(
        host=self.host,
        database=self.db,
        user=self.user,
        password=self.password,
        role=self.role,
        charset=self.charset,
        **extra
    )
def _queryInsertID(self, conn, soInstance, id, names, values):
    """Firebird uses 'generators' to create new ids for a table.
    The user needs to create a generator named GEN_<tablename>
    for each table for this method to work."""
    table = soInstance.sqlmeta.table
    idName = soInstance.sqlmeta.idName
    # Fall back to the GEN_<table> naming convention when the model does
    # not declare an explicit idSequence.
    sequenceName = soInstance.sqlmeta.idSequence or \
        'GEN_%s' % table
    c = conn.cursor()
    if id is None:
        # Draw the next id from the generator when none was supplied.
        c.execute('SELECT gen_id(%s,1) FROM rdb$database'
                  % sequenceName)
        id = c.fetchone()[0]
    names = [idName] + names
    values = [id] + values
    q = self._insertSQL(table, names, values)
    if self.debug:
        self.printDebug(conn, q, 'QueryIns')
    c.execute(q)
    if self.debugOutput:
        self.printDebug(conn, id, 'QueryIns', 'result')
    return id
@classmethod
def _queryAddLimitOffset(cls, query, start, end):
    """Firebird slaps the limit and offset (actually 'first' and
    'skip', respectively) statement right after the select.

    Returns the query unchanged when it does not match a plain SELECT.
    """
    # Bug fix: the original used two independent ``if`` statements, so
    # when start was falsy the FIRST-only string was immediately
    # overwritten by the else branch, and with both start and end falsy
    # ``end`` (None) was formatted with %i, raising TypeError.
    if not start:
        limit_str = "SELECT FIRST %i" % end
    elif not end:
        limit_str = "SELECT SKIP %i" % start
    else:
        limit_str = "SELECT FIRST %i SKIP %i" % (end - start, start)
    match = cls.limit_re.match(query)
    if match and len(match.groups()) == 2:
        return ' '.join([limit_str, match.group(2)])
    else:
        return query
def createTable(self, soClass):
    """Create the table plus its GEN_<table> id generator; no extra
    pending SQL statements are returned."""
    self.query('CREATE TABLE %s (\n%s\n)' % \
               (soClass.sqlmeta.table, self.createColumns(soClass)))
    self.query("CREATE GENERATOR GEN_%s" % soClass.sqlmeta.table)
    return []
def createReferenceConstraint(self, soClass, col):
    """Foreign-key constraints are not generated for Firebird."""
    return None
def createColumn(self, soClass, col):
    """Delegate column DDL to the column's Firebird-specific SQL."""
    return col.firebirdCreateSQL()
def createIDColumn(self, soClass):
    """DDL for the primary-key column; only int and str id types are
    supported (anything else raises KeyError)."""
    key_type = {int: "INT", str: "TEXT"}[soClass.sqlmeta.idType]
    return '%s %s NOT NULL PRIMARY KEY' % (soClass.sqlmeta.idName, key_type)
def createIndexSQL(self, soClass, index):
    """Delegate index DDL to the index object's Firebird-specific SQL."""
    return index.firebirdCreateIndexSQL(soClass)
def joinSQLType(self, join):
    """Column type used for join (foreign-key) columns."""
    return 'INT NOT NULL'
def tableExists(self, tableName):
    """Return a truthy count when a relation with this name exists.

    NOTE(review): ``tableName`` is interpolated directly into the SQL;
    this is only safe because table names come from trusted model
    definitions, not user input.
    """
    # there's something in the database by this name...let's
    # assume it's a table. By default, fb 1.0 stores EVERYTHING
    # it cares about in uppercase.
    result = self.queryOne("SELECT COUNT(rdb$relation_name) FROM rdb$relations WHERE rdb$relation_name = '%s'"
                           % tableName.upper())
    return result[0]
def addColumn(self, tableName, column):
    """ALTER the table to add one column, using the column's own DDL."""
    self.query('ALTER TABLE %s ADD %s' %
               (tableName,
                column.firebirdCreateSQL()))
def dropTable(self, tableName, cascade=False):
    """Drop the table and its companion GEN_<table> id generator.
    ``cascade`` is accepted for interface parity but not used here."""
    self.query("DROP TABLE %s" % tableName)
    self.query("DROP GENERATOR GEN_%s" % tableName)
def delColumn(self, sqlmeta, column):
    """ALTER the table to drop one column."""
    self.query('ALTER TABLE %s DROP %s' % (sqlmeta.table, column.dbName))
def columnsFromSchema(self, tableName, soClass):
    """
    Look at the given table and create Col instances (or
    subclasses of Col) for the fields it finds in that table.

    Reads Firebird's RDB$ system tables; the id column is skipped since
    SQLObject manages it itself.
    """
    fieldqry = """\
    SELECT rf.RDB$FIELD_NAME as field,
           t.RDB$TYPE_NAME as t,
           f.RDB$FIELD_LENGTH as flength,
           f.RDB$FIELD_SCALE as fscale,
           rf.RDB$NULL_FLAG as nullAllowed,
           coalesce(rf.RDB$DEFAULT_SOURCE, f.rdb$default_source) as thedefault,
           f.RDB$FIELD_SUB_TYPE as blobtype
    FROM RDB$RELATION_FIELDS rf
    INNER JOIN RDB$FIELDS f ON rf.RDB$FIELD_SOURCE = f.RDB$FIELD_NAME
    INNER JOIN RDB$TYPES t ON f.RDB$FIELD_TYPE = t.RDB$TYPE
    WHERE rf.RDB$RELATION_NAME = '%s'
    AND t.RDB$FIELD_NAME = 'RDB$FIELD_TYPE'"""
    colData = self.queryAll(fieldqry % tableName.upper())
    results = []
    for field, t, flength, fscale, nullAllowed, thedefault, blobType in colData:
        field = field.strip().lower()
        if thedefault:
            # RDB$DEFAULT_SOURCE looks like "DEFAULT <literal>"; keep the
            # literal and strip surrounding quotes.
            thedefault = thedefault.split(' ')[1]
            if thedefault.startswith("'") and thedefault.endswith("'"):
                thedefault = thedefault[1:-1]
        idName = str(soClass.sqlmeta.idName or 'id').upper()
        if field.upper() == idName:
            # The id column is handled by SQLObject itself.
            continue
        colClass, kw = self.guessClass(t, flength, fscale)
        kw['name'] = soClass.sqlmeta.style.dbColumnToPythonAttr(field).strip()
        kw['dbName'] = field
        kw['notNone'] = not nullAllowed
        kw['default'] = thedefault
        results.append(colClass(**kw))
    return results
# Firebird RDB$TYPES spellings grouped for guessClass below.
_intTypes=['INT64', 'SHORT','LONG']
_dateTypes=['DATE','TIME','TIMESTAMP']
def guessClass(self, t, flength, fscale=None):
    """
    An internal method that tries to figure out what Col subclass
    is appropriate given whatever introspective information is
    available -- both very database-specific.
    """
    if t in self._intTypes:
        return col.IntCol, {}
    elif t == 'VARYING':
        # VARCHAR
        return col.StringCol, {'length': flength}
    elif t == 'TEXT':
        # Fixed-width CHAR
        return col.StringCol, {'length': flength,
                               'varchar': False}
    elif t in self._dateTypes:
        return col.DateTimeCol, {}
    else:
        # Unknown type: fall back to the generic column.
        return col.Col, {}
def createEmptyDatabase(self):
    """Create the configured database via kinterbasdb.

    NOTE(review): credentials are interpolated into the DDL string; they
    come from the connection URI, not end users.
    """
    self.module.create_database("CREATE DATABASE '%s' user '%s' password '%s'" % \
                                (self.db, self.user, self.password))
def dropDatabase(self):
    """Drop the current database.

    NOTE(review): kinterbasdb's drop_database() is normally a method on
    a connection object; calling it on the module with no arguments
    looks wrong -- confirm against the driver documentation.
    """
    self.module.drop_database()

View File

@@ -0,0 +1 @@
#

View File

@@ -0,0 +1,62 @@
__all__ = ['HashCol']
import sqlobject.col
class DbHash:
    """ Presents a comparison object for hashes, allowing plain text to be
    automagically compared with the base content. """

    def __init__(self, hash, hashMethod):
        """Wrap a stored *hash* plus the *hashMethod* that produced it."""
        self.hash = hash
        self.hashMethod = hashMethod

    def __cmp__(self, other):
        """Compare against plain text (hashed first) or None.

        None equals a None hash; any other non-string comparand is a
        TypeError.
        """
        if other is None:
            # True (== 1) preserves the original "greater than None" result.
            return 0 if self.hash is None else True
        if not isinstance(other, basestring):
            raise TypeError("A hash may only be compared with a string, or None.")
        return cmp(self.hashMethod(other), self.hash)

    def __repr__(self):
        return "<DbHash>"
class HashValidator( sqlobject.col.StringValidator ):
    """ Provides formal SQLObject validation services for the HashCol. """

    def to_python(self, value, state):
        """Wrap the stored hash in a DbHash comparison object; None
        passes through unchanged."""
        return None if value is None else DbHash(hash=value,
                                                 hashMethod=self.hashMethod)

    def from_python(self, value, state):
        """Hash the plain-text value for storage; None passes through."""
        return None if value is None else self.hashMethod(value)
class SOHashCol( sqlobject.col.SOStringCol ):
    """ The internal HashCol definition. By default, enforces a md5 digest. """

    def __init__(self, **kw):
        """Accept an optional ``hashMethod`` (string -> digest string);
        default is MD5 with a matching 32-character column length."""
        if 'hashMethod' not in kw:
            from md5 import md5
            self.hashMethod = lambda v: md5(v).hexdigest()
            if 'length' not in kw:
                kw['length'] = 32
        else:
            self.hashMethod = kw['hashMethod']
            del kw['hashMethod']
        # Bug fix: super() was called with SOStringCol as the first
        # argument, which starts the MRO lookup *after* SOStringCol and
        # thus skipped SOStringCol.__init__ entirely; the lookup must
        # start from this class.
        super(SOHashCol, self).__init__(**kw)

    def createValidators(self):
        """Prepend the hashing validator to the inherited chain."""
        return [HashValidator(name=self.name, hashMethod=self.hashMethod)] + \
            super(SOHashCol, self).createValidators()
class HashCol( sqlobject.col.StringCol ):
    """ End-user HashCol class. May be instantiated with 'hashMethod', a function
    which returns the string hash of any other string (i.e. basestring). """
    # All real work happens in SOHashCol; this only binds the SO class.
    baseClass = SOHashCol

View File

@@ -0,0 +1,9 @@
This is from PyDispatcher <http://pydispatcher.sf.net>
It was moved here because installation of PyDispatcher conflicts with
RuleDispatch (they both use the dispatch top-level module), and I
thought it would be easier to just put it here. Also, PyDispatcher is
small and stable and doesn't need updating often.
If the name conflict is resolved in the future, this package can go
away.

View File

@@ -0,0 +1,6 @@
"""Multi-consumer multi-producer dispatching mechanism
"""
__version__ = "1.0.0"
__author__ = "Patrick K. O'Brien"
__license__ = "BSD-style, see license.txt for details"

View File

@@ -0,0 +1,497 @@
"""Multiple-producer-multiple-consumer signal-dispatching
dispatcher is the core of the PyDispatcher system,
providing the primary API and the core logic for the
system.
Module attributes of note:
Any -- Singleton used to signal either "Any Sender" or
"Any Signal". See documentation of the _Any class.
Anonymous -- Singleton used to signal "Anonymous Sender"
See documentation of the _Anonymous class.
Internal attributes:
WEAKREF_TYPES -- tuple of types/classes which represent
weak references to receivers, and thus must be de-
referenced on retrieval to retrieve the callable
object
connections -- { senderkey (id) : { signal : [receivers...]}}
senders -- { senderkey (id) : weakref(sender) }
used for cleaning up sender references on sender
deletion
sendersBack -- { receiverkey (id) : [senderkey (id)...] }
used for cleaning up receiver references on receiver
deletion, (considerably speeds up the cleanup process
vs. the original code.)
"""
from __future__ import generators
import types, weakref
import saferef, robustapply, errors
__author__ = "Patrick K. O'Brien <pobrien@orbtech.com>"
__cvsid__ = "$Id: dispatcher.py,v 1.9 2005/09/17 04:55:57 mcfletch Exp $"
__version__ = "$Revision: 1.9 $"[11:-2]
try:
True
except NameError:
True = 1==1
False = 1==0
class _Parameter:
    """Used to represent default parameter values."""
    def __repr__(self):
        # Subclasses are singletons, so the class name is the best repr.
        return self.__class__.__name__
class _Any(_Parameter):
    """Singleton used to signal either "Any Sender" or "Any Signal"

    The Any object can be used with connect, disconnect,
    send, or sendExact to signal that the parameter given
    Any should react to all senders/signals, not just
    a particular sender/signal.
    """
# The module-level singleton; always compared by identity.
Any = _Any()
class _Anonymous(_Parameter):
    """Singleton used to signal "Anonymous Sender"

    The Anonymous object is used to signal that the sender
    of a message is not specified (as distinct from being
    "any sender").  Registering callbacks for Anonymous
    will only receive messages sent without senders.  Sending
    with anonymous will only send messages to those receivers
    registered for Any or Anonymous.

    Note:
        The default sender for connect is Any, while the
        default sender for send is Anonymous.  This has
        the effect that if you do not specify any senders
        in either function then all messages are routed
        as though there was a single sender (Anonymous)
        being used everywhere.
    """
# The module-level singleton; always compared by identity.
Anonymous = _Anonymous()
# Weak-reference wrapper types that liveReceivers() must dereference.
WEAKREF_TYPES = (weakref.ReferenceType, saferef.BoundMethodWeakref)
# { senderkey (id) : { signal : [receivers...] } }
connections = {}
# { senderkey (id) : weakref(sender) } -- lets senders be cleaned up.
senders = {}
# { receiverkey (id) : [senderkey...] } -- reverse index for fast cleanup.
sendersBack = {}
def connect(receiver, signal=Any, sender=Any, weak=True):
    """Connect receiver to sender for signal

    receiver -- a callable Python object which is to receive
        messages/signals/events.  Receivers must be hashable objects.
        If weak is True, then receiver must be weak-referencable (more
        precisely saferef.safeRef() must be able to create a reference
        to the receiver).  Note: if receiver is itself a weak reference
        (a callable), it will be de-referenced by the system's
        machinery, so *generally* weak references are not suitable as
        receivers.
    signal -- the signal to which the receiver should respond.
        If Any, receiver will receive any signal from the indicated
        sender (which might also be Any).  Otherwise must be a hashable
        Python object other than None (DispatcherError raised on None).
    sender -- the sender to which the receiver should respond.
        If Any, receiver will receive the indicated signals from any
        sender.  If Anonymous, receiver will only receive indicated
        signals from send/sendExact which do not specify a sender, or
        specify Anonymous explicitly as the sender.
        Otherwise can be any python object.
    weak -- whether to use weak references to the receiver.
        By default, the module will attempt to use weak references to
        the receiver objects.  If this parameter is false, then strong
        references will be used.

    returns None, may raise DispatcherTypeError
    """
    if signal is None:
        raise errors.DispatcherTypeError(
            'Signal cannot be None (receiver=%r sender=%r)'%( receiver,sender)
        )
    if weak:
        receiver = saferef.safeRef(receiver, onDelete=_removeReceiver)
    senderkey = id(sender)
    if senderkey in connections:
        signals = connections[senderkey]
    else:
        connections[senderkey] = signals = {}
    # Keep track of senders for cleanup.
    # Is Anonymous something we want to clean up?
    if sender not in (None, Anonymous, Any):
        def remove(object, senderkey=senderkey):
            _removeSender(senderkey=senderkey)
        # Skip objects that can not be weakly referenced, which means
        # they won't be automatically cleaned up, but that's too bad.
        try:
            weakSender = weakref.ref(sender, remove)
            senders[senderkey] = weakSender
        except:
            pass
    receiverID = id(receiver)
    # get current set, remove any current references to
    # this receiver in the set, including back-references
    if signal in signals:
        receivers = signals[signal]
        _removeOldBackRefs(senderkey, signal, receiver, receivers)
    else:
        receivers = signals[signal] = []
    try:
        # Record the reverse mapping receiver -> senders for cleanup.
        current = sendersBack.get( receiverID )
        if current is None:
            sendersBack[ receiverID ] = current = []
        if senderkey not in current:
            current.append(senderkey)
    except:
        pass
    receivers.append(receiver)
def disconnect(receiver, signal=Any, sender=Any, weak=True):
    """Disconnect receiver from sender for signal

    receiver -- the registered receiver to disconnect
    signal -- the registered signal to disconnect
    sender -- the registered sender to disconnect
    weak -- the weakref state to disconnect

    disconnect reverses the process of connect; the semantics for the
    individual elements are logically equivalent to a tuple of
    (receiver, signal, sender, weak) used as a key to be deleted from
    the internal routing tables.  (The actual process is slightly more
    complex but the semantics are basically the same.)

    Note:
        Using disconnect is not required to cleanup routing when an
        object is deleted; the framework will remove routes for deleted
        objects automatically.  It's only necessary to disconnect if you
        want to stop routing to a live object.

    returns None, may raise DispatcherTypeError or DispatcherKeyError
    """
    if signal is None:
        raise errors.DispatcherTypeError(
            'Signal cannot be None (receiver=%r sender=%r)'%( receiver,sender)
        )
    # Wrap the receiver the same way connect did so lookup matches.
    if weak: receiver = saferef.safeRef(receiver)
    senderkey = id(sender)
    try:
        signals = connections[senderkey]
        receivers = signals[signal]
    except KeyError:
        raise errors.DispatcherKeyError(
            """No receivers found for signal %r from sender %r""" %(
                signal,
                sender
            )
        )
    try:
        # also removes from receivers
        _removeOldBackRefs(senderkey, signal, receiver, receivers)
    except ValueError:
        raise errors.DispatcherKeyError(
            """No connection to receiver %s for signal %s from sender %s""" %(
                receiver,
                signal,
                sender
            )
        )
    _cleanupConnections(senderkey, signal)
def getReceivers( sender = Any, signal = Any ):
    """Get the raw receiver list from the global routing tables.

    Note: there is no guarantee this is the actual list stored in the
    connections table, so treat it as a read-only iterable rather than
    a list to append to.  Normally you would pass the result through
    liveReceivers() to resolve the actual receiver objects.
    """
    senderkey = id(sender)
    by_signal = connections.get(senderkey)
    if by_signal is None or signal not in by_signal:
        return []
    return by_signal[signal]
def liveReceivers(receivers):
    """Generate the live receiver behind each entry of ``receivers``,
    resolving weak references and skipping entries whose referent has
    already been collected."""
    for entry in receivers:
        if isinstance(entry, WEAKREF_TYPES):
            target = entry()
            if target is not None:
                yield target
        else:
            yield entry
def getAllReceivers( sender = Any, signal = Any ):
    """Get list of all receivers from global tables

    This gets all receivers which should receive the given signal from
    sender; each receiver is produced only once by the resulting
    generator (deduplicated by hash).
    """
    receivers = {}
    for set in (
        # Get receivers that receive *this* signal from *this* sender.
        getReceivers( sender, signal ),
        # Add receivers that receive *any* signal from *this* sender.
        getReceivers( sender, Any ),
        # Add receivers that receive *this* signal from *any* sender.
        getReceivers( Any, signal ),
        # Add receivers that receive *any* signal from *any* sender.
        getReceivers( Any, Any ),
        ):
        for receiver in set:
            if receiver: # filter out dead instance-method weakrefs
                try:
                    if not receiver in receivers:
                        receivers[receiver] = 1
                        yield receiver
                except TypeError:
                    # dead weakrefs raise TypeError on hash...
                    pass
def send(signal=Any, sender=Anonymous, *arguments, **named):
    """Send signal from sender to all connected receivers.

    signal -- (hashable) signal value, see connect for details
    sender -- the sender of the signal.
        If Any, only receivers registered for Any will receive the
        message.  If Anonymous, only receivers registered to receive
        messages from Anonymous or Any will receive the message.
        Otherwise can be any python object (normally one registered
        with a connect if you actually want something to occur).
    arguments -- positional arguments which will be passed to *all*
        receivers.  Note that this may raise TypeErrors if the
        receivers do not allow the particular arguments.  Note also
        that arguments are applied before named arguments, so they
        should be used with care.
    named -- named arguments which will be filtered according to the
        parameters of the receivers to only provide those acceptable
        to the receiver.

    Return a list of tuple pairs [(receiver, response), ... ].
    If any receiver raises an error, the error propagates back through
    send, terminating the dispatch loop, so it is quite possible to not
    have all receivers called if one raises an error.
    """
    # Call each receiver with whatever arguments it can accept.
    # Return a list of tuple pairs [(receiver, response), ... ].
    responses = []
    for receiver in liveReceivers(getAllReceivers(sender, signal)):
        # robustApply trims the keyword arguments to what the receiver
        # can actually accept.
        response = robustapply.robustApply(
            receiver,
            signal=signal,
            sender=sender,
            *arguments,
            **named
        )
        responses.append((receiver, response))
    return responses
def sendExact( signal=Any, sender=Anonymous, *arguments, **named ):
    """Send signal only to receivers registered for exactly this
    (sender, signal) pair, bypassing Any/Anonymous registrations.

    Returns [(receiver, response), ...] like send().
    """
    results = []
    for target in liveReceivers(getReceivers(sender, signal)):
        reply = robustapply.robustApply(
            target,
            signal=signal,
            sender=sender,
            *arguments,
            **named
        )
        results.append((target, reply))
    return results
def _removeReceiver(receiver):
    """Remove receiver from connections.

    Called as a weakref finalizer; walks the reverse index sendersBack
    to purge every routing entry that still points at the receiver.
    """
    if not sendersBack:
        # During module cleanup the mapping will be replaced with None
        return False
    backKey = id(receiver)
    for senderkey in sendersBack.get(backKey,()):
        try:
            signals = connections[senderkey].keys()
        except KeyError,err:
            pass
        else:
            for signal in signals:
                try:
                    receivers = connections[senderkey][signal]
                except KeyError:
                    pass
                else:
                    try:
                        receivers.remove( receiver )
                    except Exception, err:
                        # Receiver already gone from this list; ignore.
                        pass
                    _cleanupConnections(senderkey, signal)
    try:
        del sendersBack[ backKey ]
    except KeyError:
        pass
def _cleanupConnections(senderkey, signal):
    """Delete any empty signals for senderkey. Delete senderkey if empty."""
    try:
        receivers = connections[senderkey][signal]
    except KeyError:
        # Narrowed from a bare except: only a missing sender/signal key is
        # expected here; anything else should propagate, not be swallowed.
        pass
    else:
        if not receivers:
            # No more connected receivers. Therefore, remove the signal.
            try:
                signals = connections[senderkey]
            except KeyError:
                pass
            else:
                del signals[signal]
                if not signals:
                    # No more signal connections. Therefore, remove the sender.
                    _removeSender(senderkey)
def _removeSender(senderkey):
    """Remove senderkey from connections."""
    _removeBackrefs(senderkey)
    try:
        del connections[senderkey]
    except KeyError:
        pass
    # Senderkey will only be in senders dictionary if sender
    # could be weakly referenced.
    try:
        del senders[senderkey]
    except KeyError:
        # Narrowed from a bare except: a missing key is the only
        # expected failure when the sender was never weakly referenced.
        pass
def _removeBackrefs( senderkey):
    """Remove all back-references to this senderkey"""
    try:
        signals = connections[senderkey]
    except KeyError:
        # Sender already gone; nothing to unlink.
        signals = None
    else:
        # Snapshot items first: _killBackref mutates sendersBack while we
        # iterate the receivers derived from this mapping.
        items = signals.items()
        def allReceivers( ):
            # Flatten every receiver registered under any signal.
            for signal,set in items:
                for item in set:
                    yield item
        for receiver in allReceivers():
            _killBackref( receiver, senderkey )
def _removeOldBackRefs(senderkey, signal, receiver, receivers):
    """Kill old sendersBack references from receiver
    This guards against multiple registration of the same
    receiver for a given signal and sender leaking memory
    as old back reference records build up.
    Also removes old receiver instance from receivers
    """
    try:
        index = receivers.index(receiver)
        # need to scan back references here and remove senderkey
    except ValueError:
        # Receiver not currently registered for this sender/signal.
        return False
    else:
        oldReceiver = receivers[index]
        del receivers[index]
        found = 0
        # NOTE(review): connections is keyed by senderkey elsewhere in this
        # module, so connections.get(signal) here looks like it should be
        # connections.get(senderkey) — confirm against upstream PyDispatcher
        # before changing; left as-is to preserve vendored behavior.
        signals = connections.get(signal)
        if signals is not None:
            # Check whether the old receiver is still registered for some
            # *other* signal of this sender; if so, keep its back-reference.
            for sig,recs in connections.get(signal,{}).iteritems():
                if sig != signal:
                    for rec in recs:
                        if rec is oldReceiver:
                            found = 1
                            break
        if not found:
            _killBackref( oldReceiver, senderkey )
            return True
        return False
def _killBackref( receiver, senderkey ):
    """Do the actual removal of back reference from receiver to senderkey"""
    receiverkey = id(receiver)
    # `set` is actually a list of senderkeys (or the default empty tuple).
    set = sendersBack.get( receiverkey, () )
    while senderkey in set:
        try:
            set.remove( senderkey )
        except:
            # Defensive: bail out rather than loop forever if removal fails
            # (e.g. during interpreter teardown).
            break
    if not set:
        # No back-references left for this receiver; drop its record.
        try:
            del sendersBack[ receiverkey ]
        except KeyError:
            pass
    return True

View File

@@ -0,0 +1,10 @@
"""Error types for dispatcher mechanism
"""
class DispatcherError(Exception):
    """Root of the dispatcher exception hierarchy; catch this to
    handle any dispatcher-specific failure."""
class DispatcherKeyError(KeyError, DispatcherError):
    """Raised when a (sender, signal) pair is looked up but was
    never connected; also a KeyError for generic handlers."""
class DispatcherTypeError(TypeError, DispatcherError):
    """Raised when a signal of an unusable type (e.g. None) is
    supplied; also a TypeError for generic handlers."""

View File

@@ -0,0 +1,57 @@
"""Module implementing error-catching version of send (sendRobust)"""
from dispatcher import Any, Anonymous, liveReceivers, getAllReceivers
from robustapply import robustApply
def sendRobust(
    signal=Any,
    sender=Anonymous,
    *arguments, **named
):
    """Send signal from sender to all connected receivers catching errors
    signal -- (hashable) signal value, see connect for details
    sender -- the sender of the signal
        if Any, only receivers registered for Any will receive
        the message.
        if Anonymous, only receivers registered to receive
        messages from Anonymous or Any will receive the message
        Otherwise can be any python object (normally one
        registered with a connect if you actually want
        something to occur).
    arguments -- positional arguments which will be passed to
        *all* receivers. Note that this may raise TypeErrors
        if the receivers do not allow the particular arguments.
        Note also that arguments are applied before named
        arguments, so they should be used with care.
    named -- named arguments which will be filtered according
        to the parameters of the receivers to only provide those
        acceptable to the receiver.
    Return a list of tuple pairs [(receiver, response), ... ]
    if any receiver raises an error (specifically any subclass of Exception),
    the error instance is returned as the result for that receiver.
    """
    # Call each receiver with whatever arguments it can accept.
    # Return a list of tuple pairs [(receiver, response), ... ].
    responses = []
    for receiver in liveReceivers(getAllReceivers(sender, signal)):
        try:
            response = robustApply(
                receiver,
                signal=signal,
                sender=sender,
                *arguments,
                **named
            )
        except Exception, err:
            # Unlike send(), an error does not abort the dispatch loop:
            # the exception object itself becomes that receiver's response.
            responses.append((receiver, err))
        else:
            responses.append((receiver, response))
    return responses

View File

@@ -0,0 +1,49 @@
"""Robust apply mechanism
Provides a function "call", which can sort out
what arguments a given callable object can take,
and subset the given arguments to match only
those which are acceptable.
"""
def function( receiver ):
    """Get function-like callable object for given receiver
    returns (function_or_method, codeObject, fromMethod)
    If fromMethod is true, then the callable already
    has its first argument bound
    """
    if hasattr(receiver, '__call__'):
        # receiver is a class instance; assume it is callable.
        # Reassign receiver to the actual method that will be called.
        if hasattr( receiver.__call__, 'im_func') or hasattr( receiver.__call__, 'im_code'):
            receiver = receiver.__call__
    if hasattr( receiver, 'im_func' ):
        # an instance-method...
        # (im_func/func_code are Python 2 method/function attributes)
        return receiver, receiver.im_func.func_code, 1
    elif not hasattr( receiver, 'func_code'):
        raise ValueError('unknown reciever type %s %s'%(receiver, type(receiver)))
    # Plain function: code object available directly, no bound argument.
    return receiver, receiver.func_code, 0
def robustApply(receiver, *arguments, **named):
    """Call receiver with arguments and an appropriate subset of named
    """
    receiver, codeObject, startIndex = function( receiver )
    # Named parameters the receiver can accept beyond the positionals given
    # (startIndex skips the bound `self` for instance methods).
    acceptable = codeObject.co_varnames[startIndex+len(arguments):codeObject.co_argcount]
    for name in codeObject.co_varnames[startIndex:startIndex+len(arguments)]:
        if name in named:
            raise TypeError(
                """Argument %r specified both positionally and as a keyword for calling %r"""% (
                    name, receiver,
                )
            )
    # co_flags & 8 == CO_VARKEYWORDS (receiver declares **kwds).
    if not (codeObject.co_flags & 8):
        # fc does not have a **kwds type parameter, therefore
        # remove unacceptable arguments.
        # (Python 2: .keys() returns a list, so deleting while
        # iterating is safe here.)
        for arg in named.keys():
            if arg not in acceptable:
                del named[arg]
    return receiver(*arguments, **named)

View File

@@ -0,0 +1,165 @@
"""Refactored "safe reference" from dispatcher.py"""
import weakref, traceback
def safeRef(target, onDelete = None):
    """Return a *safe* weak reference to a callable target
    target -- the object to be weakly referenced, if it's a
        bound method reference, will create a BoundMethodWeakref,
        otherwise creates a simple weakref.
    onDelete -- if provided, will have a hard reference stored
        to the callable to be called after the safe reference
        goes out of scope with the reference object, (either a
        weakref or a BoundMethodWeakref) as argument.
    """
    # im_self present and non-None means `target` is a Python 2 bound method;
    # a plain weakref to it would die immediately, hence BoundMethodWeakref.
    if hasattr(target, 'im_self'):
        if target.im_self is not None:
            # Turn a bound method into a BoundMethodWeakref instance.
            # Keep track of these instances for lookup by disconnect().
            assert hasattr(target, 'im_func'), """safeRef target %r has im_self, but no im_func, don't know how to create reference"""%( target,)
            reference = BoundMethodWeakref(
                target=target,
                onDelete=onDelete
            )
            return reference
    # Non-method callables get an ordinary weakref; only attach the
    # callback when one was actually supplied.
    if callable(onDelete):
        return weakref.ref(target, onDelete)
    else:
        return weakref.ref( target )
class BoundMethodWeakref(object):
    """'Safe' and reusable weak references to instance methods
    BoundMethodWeakref objects provide a mechanism for
    referencing a bound method without requiring that the
    method object itself (which is normally a transient
    object) is kept alive. Instead, the BoundMethodWeakref
    object keeps weak references to both the object and the
    function which together define the instance method.
    Attributes:
        key -- the identity key for the reference, calculated
            by the class's calculateKey method applied to the
            target instance method
        deletionMethods -- sequence of callable objects taking
            single argument, a reference to this object which
            will be called when *either* the target object or
            target function is garbage collected (i.e. when
            this object becomes invalid). These are specified
            as the onDelete parameters of safeRef calls.
        weakSelf -- weak reference to the target object
        weakFunc -- weak reference to the target function
    Class Attributes:
        _allInstances -- class attribute pointing to all live
            BoundMethodWeakref objects indexed by the class's
            calculateKey(target) method applied to the target
            objects. This weak value dictionary is used to
            short-circuit creation so that multiple references
            to the same (object, function) pair produce the
            same BoundMethodWeakref instance.
    """
    _allInstances = weakref.WeakValueDictionary()
    def __new__( cls, target, onDelete=None, *arguments,**named ):
        """Create new instance or return current instance
        Basically this method of construction allows us to
        short-circuit creation of references to already-
        referenced instance methods. The key corresponding
        to the target is calculated, and if there is already
        an existing reference, that is returned, with its
        deletionMethods attribute updated. Otherwise the
        new instance is created and registered in the table
        of already-referenced methods.
        """
        key = cls.calculateKey(target)
        current =cls._allInstances.get(key)
        if current is not None:
            # Reuse the existing reference; just record the extra callback.
            current.deletionMethods.append( onDelete)
            return current
        else:
            base = super( BoundMethodWeakref, cls).__new__( cls )
            cls._allInstances[key] = base
            # __init__ is invoked explicitly here because returning an
            # existing instance above must NOT re-run initialization.
            base.__init__( target, onDelete, *arguments,**named)
            return base
    def __init__(self, target, onDelete=None):
        """Return a weak-reference-like instance for a bound method
        target -- the instance-method target for the weak
            reference, must have im_self and im_func attributes
            and be reconstructable via:
                target.im_func.__get__( target.im_self )
            which is true of built-in instance methods.
        onDelete -- optional callback which will be called
            when this weak reference ceases to be valid
            (i.e. either the object or the function is garbage
            collected). Should take a single argument,
            which will be passed a pointer to this object.
        """
        def remove(weak, self=self):
            """Set self.isDead to true when method or instance is destroyed"""
            # Fire (and clear) all registered deletion callbacks exactly once,
            # then drop this reference from the class-level registry.
            methods = self.deletionMethods[:]
            del self.deletionMethods[:]
            try:
                del self.__class__._allInstances[ self.key ]
            except KeyError:
                pass
            for function in methods:
                try:
                    if callable( function ):
                        function( self )
                except Exception, e:
                    # Callbacks must never propagate during GC; log best-effort.
                    try:
                        traceback.print_exc()
                    except AttributeError, err:
                        print '''Exception during saferef %s cleanup function %s: %s'''%(
                            self, function, e
                        )
        self.deletionMethods = [onDelete]
        self.key = self.calculateKey( target )
        # Weakly reference both halves of the bound method; either dying
        # invalidates this reference (see remove above).
        self.weakSelf = weakref.ref(target.im_self, remove)
        self.weakFunc = weakref.ref(target.im_func, remove)
        # Cached strings for __str__ so repr works after the target dies.
        self.selfName = str(target.im_self)
        self.funcName = str(target.im_func.__name__)
    def calculateKey( cls, target ):
        """Calculate the reference key for this reference
        Currently this is a two-tuple of the id()'s of the
        target object and the target function respectively.
        """
        return (id(target.im_self),id(target.im_func))
    calculateKey = classmethod( calculateKey )
    def __str__(self):
        """Give a friendly representation of the object"""
        return """%s( %s.%s )"""%(
            self.__class__.__name__,
            self.selfName,
            self.funcName,
        )
    __repr__ = __str__
    def __nonzero__( self ):
        """Whether we are still a valid reference"""
        return self() is not None
    def __cmp__( self, other ):
        """Compare with another reference"""
        if not isinstance (other,self.__class__):
            return cmp( self.__class__, type(other) )
        return cmp( self.key, other.key)
    def __call__(self):
        """Return a strong reference to the bound method
        If the target cannot be retrieved, then will
        return None, otherwise returns a bound instance
        method for our object and function.
        Note:
            You may call this method any number of times,
            as it does not invalidate the reference.
        """
        target = self.weakSelf()
        if target is not None:
            function = self.weakFunc()
            if function is not None:
                # Re-bind the function to the instance via the descriptor
                # protocol, reconstructing the original bound method.
                return function.__get__(target)
        return None

View File

@@ -0,0 +1,180 @@
from itertools import count
from types import *
from converters import sqlrepr
creationOrder = count()
class SODatabaseIndex(object):
    """Class-bound form of DatabaseIndex: a DatabaseIndex attached to a
    specific SQLObject class, with column names resolved to column objects
    and per-backend CREATE INDEX SQL generation."""
    def __init__(self,
                 soClass,
                 name,
                 columns,
                 creationOrder,
                 unique=False):
        self.soClass = soClass
        self.name = name
        self.descriptions = self.convertColumns(columns)
        self.creationOrder = creationOrder
        self.unique = unique
    def get(self, *args, **kw):
        """Look up the single row matching this (unique) index.
        Accepts either positional values (in index-column order) or
        keyword values; an optional `connection` keyword selects the
        database connection."""
        if not self.unique:
            raise AttributeError, (
                "'%s' object has no attribute 'get' (index is not unique)" % self.name)
        connection = kw.pop('connection', None)
        if args and kw:
            raise TypeError, "You cannot mix named and unnamed arguments"
        columns = [d['column'] for d in self.descriptions
                   if 'column' in d]
        if kw and len(kw) != len(columns) or args and len(args) != len(columns):
            raise TypeError, ("get() takes exactly %d argument and an optional "
                              "named argument 'connection' (%d given)" % (
                                  len(columns), len(args)+len(kw)))
        if args:
            # Map positional values onto column names (foreign-key columns
            # are addressed via their foreignName attribute).
            kw = {}
            for i in range(len(args)):
                if columns[i].foreignName is not None:
                    kw[columns[i].foreignName] = args[i]
                else:
                    kw[columns[i].name] = args[i]
        return self.soClass.selectBy(connection=connection, **kw).getOne()
    def convertColumns(self, columns):
        """
        Converts all the columns to dictionary descriptors;
        dereferences string column names.
        """
        new = []
        for desc in columns:
            if not isinstance(desc, dict):
                # Bare column or column-name: wrap in a descriptor dict.
                desc = {'column': desc}
            if 'expression' in desc:
                assert 'column' not in desc, (
                    'You cannot provide both an expression and a column '
                    '(for %s in index %s in %s)' %
                    (desc, self.name, self.soClass))
                assert 'length' not in desc, (
                    'length does not apply to expressions (for %s in '
                    'index %s in %s)' %
                    (desc, self.name, self.soClass))
                new.append(desc)
                continue
            columnName = desc['column']
            if not isinstance(columnName, str):
                columnName = columnName.name
            colDict = self.soClass.sqlmeta.columns
            if columnName not in colDict:
                # Fall back to matching on the original (pre-style) name.
                for possible in colDict.values():
                    if possible.origName == columnName:
                        column = possible
                        break
                else:
                    # None found
                    raise ValueError, "The column by the name %r was not found in the class %r" % (columnName, self.soClass)
            else:
                column = colDict[columnName]
            desc['column'] = column
            new.append(desc)
        return new
    def getExpression(self, desc, db):
        """Render an expression descriptor to SQL for backend `db`."""
        if isinstance(desc['expression'], str):
            return desc['expression']
        else:
            return sqlrepr(desc['expression'], db)
    def sqliteCreateIndexSQL(self, soClass):
        """CREATE INDEX statement for SQLite (and, via aliases below,
        most other backends)."""
        if self.unique:
            uniqueOrIndex = 'UNIQUE INDEX'
        else:
            uniqueOrIndex = 'INDEX'
        spec = []
        for desc in self.descriptions:
            if 'expression' in desc:
                spec.append(self.getExpression(desc, 'sqlite'))
            else:
                spec.append(desc['column'].dbName)
        ret = 'CREATE %s %s_%s ON %s (%s)' % \
              (uniqueOrIndex,
               self.soClass.sqlmeta.table,
               self.name,
               self.soClass.sqlmeta.table,
               ', '.join(spec))
        return ret
    # These backends all accept the same CREATE INDEX syntax as SQLite.
    postgresCreateIndexSQL = maxdbCreateIndexSQL = mssqlCreateIndexSQL = sybaseCreateIndexSQL = firebirdCreateIndexSQL = sqliteCreateIndexSQL
    def mysqlCreateIndexSQL(self, soClass):
        """MySQL variant: uses ALTER TABLE ... ADD and supports prefix
        lengths on indexed columns."""
        if self.unique:
            uniqueOrIndex = 'UNIQUE'
        else:
            uniqueOrIndex = 'INDEX'
        spec = []
        for desc in self.descriptions:
            if 'expression' in desc:
                spec.append(self.getExpression(desc, 'mysql'))
            elif 'length' in desc:
                # MySQL can index just the first N characters of a column.
                spec.append('%s(%d)' % (desc['column'].dbName, desc['length']))
            else:
                spec.append(desc['column'].dbName)
        return 'ALTER TABLE %s ADD %s %s (%s)' % \
               (soClass.sqlmeta.table, uniqueOrIndex,
                self.name,
                ', '.join(spec))
class DatabaseIndex(object):
    """
    This takes a variable number of parameters, each of which is a
    column for indexing. Each column may be a column object or the
    string name of the column (*not* the database name). You may also
    use dictionaries, to further customize the indexing of the column.
    The dictionary may have certain keys:
    'column':
        The column object or string identifier.
    'length':
        MySQL will only index the first N characters if this is
        given. For other databases this is ignored.
    'expression':
        You can create an index based on an expression, e.g.,
        'lower(column)'. This can either be a string or a sqlbuilder
        expression.
    Further keys may be added to the column specs in the future.
    The class also take the keyword argument `unique`; if true then
    a UNIQUE index is created.
    """
    # Class-bound counterpart instantiated by withClass().
    baseClass = SODatabaseIndex
    def __init__(self, *columns, **kw):
        kw['columns'] = columns
        self.kw = kw
        # Global counter preserves declaration order across all indexes.
        self.creationOrder = creationOrder.next()
    def setName(self, value):
        # Name is assigned once, when the index is bound to its class.
        assert self.kw.get('name') is None, "You cannot change a name after it has already been set (from %s to %s)" % (self.kw['name'], value)
        self.kw['name'] = value
    def _get_name(self):
        return self.kw['name']
    def _set_name(self, value):
        self.setName(value)
    name = property(_get_name, _set_name)
    def withClass(self, soClass):
        """Bind this declaration to `soClass`, producing an SODatabaseIndex."""
        return self.baseClass(soClass=soClass,
                              creationOrder=self.creationOrder, **self.kw)
    def __repr__(self):
        return '<%s %s %s>' % (
            self.__class__.__name__,
            hex(abs(id(self)))[2:],
            self.kw)
__all__ = ['DatabaseIndex']

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,91 @@
from sqlobject import sqlbuilder
from sqlobject.classregistry import findClass
from sqlobject.dbconnection import Iteration
class InheritableIteration(Iteration):
    """Iteration over select results for inheritable SQLObject classes:
    batches rows via fetchmany() and prefetches child-class rows in bulk
    to avoid one query per object."""
    # Default array size for cursor.fetchmany()
    defaultArraySize = 10000
    def __init__(self, dbconn, rawconn, select, keepConnection=False):
        super(InheritableIteration, self).__init__(dbconn, rawconn, select, keepConnection)
        self.lazyColumns = select.ops.get('lazyColumns', False)
        self.cursor.arraysize = self.defaultArraySize
        self._results = []
        # Find the index of the childName column
        childNameIdx = None
        columns = select.sourceClass.sqlmeta.columnList
        for i, column in enumerate(columns):
            if column.name == "childName":
                childNameIdx = i
                break
        self._childNameIdx = childNameIdx
    def next(self):
        """Return the next object; refill the batch (and prefetch child
        rows for it) when the current batch is exhausted."""
        if not self._results:
            self._results = list(self.cursor.fetchmany())
            if not self.lazyColumns: self.fetchChildren()
        if not self._results:
            self._cleanup()
            raise StopIteration
        result = self._results[0]
        del self._results[0]
        if self.lazyColumns:
            # Only the id was selected; the object loads columns on demand.
            obj = self.select.sourceClass.get(result[0], connection=self.dbconn)
            return obj
        else:
            id = result[0]
            if id in self._childrenResults:
                childResults = self._childrenResults[id]
                del self._childrenResults[id]
            else:
                childResults = None
            obj = self.select.sourceClass.get(id, selectResults=result[1:],
                                              childResults=childResults, connection=self.dbconn)
            return obj
    def fetchChildren(self):
        """Prefetch childrens' data
        Fetch childrens' data for every subclass in one big .select()
        to avoid .get() fetching it one by one.
        """
        self._childrenResults = {}
        if self._childNameIdx is None:
            return
        # Group parent-row ids by the child class that owns them.
        childIdsNames = {}
        childNameIdx = self._childNameIdx
        for result in self._results:
            childName = result[childNameIdx+1]
            if childName:
                ids = childIdsNames.get(childName)
                if ids is None:
                    ids = childIdsNames[childName] = []
                ids.append(result[0])
        dbconn = self.dbconn
        rawconn = self.rawconn
        cursor = rawconn.cursor()
        registry = self.select.sourceClass.sqlmeta.registry
        for childName, ids in childIdsNames.items():
            klass = findClass(childName, registry)
            if len(ids) == 1:
                select = klass.select(klass.q.id == ids[0],
                                      childUpdate=True, connection=dbconn)
            else:
                select = klass.select(sqlbuilder.IN(klass.q.id, ids),
                                      childUpdate=True, connection=dbconn)
            query = dbconn.queryForSelect(select)
            if dbconn.debug:
                dbconn.printDebug(rawconn, query, 'Select children of the class %s' % childName)
            self.dbconn._executeRetry(rawconn, cursor, query)
            for result in cursor.fetchall():
                # Inheritance child classes may have no own columns
                # (that makes sense when child class has a join
                # that does not apply to parent class objects).
                # In such cases result[1:] gives an empty tuple
                # which is interpreted as "no results fetched" in .get().
                # So .get() issues another query which is absolutely
                # meaningless (like "SELECT NULL FROM child WHERE id=1").
                # In order to avoid this, we replace empty results
                # with non-empty tuple. Extra values in selectResults
                # are Ok - they will be ignored by ._SO_selectInit().
                self._childrenResults[result[0]] = result[1:] or (None,)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,7 @@
from sqlobject.dbconnection import registerConnection
def builder():
    """Return the connection class for the jython_sqlite scheme.
    Imported lazily so the zxJDBC dependency is only required when a
    jython_sqlite URI is actually used."""
    import sqliteconnection
    return sqliteconnection.SQLiteConnection
registerConnection(['jython_sqlite'], builder)

View File

@@ -0,0 +1,379 @@
import base64
import os
import thread
import urllib
from sqlobject.dbconnection import DBAPI, Boolean
from sqlobject import col, sqlbuilder
from sqlobject.dberrors import *
sqlite2_Binary = None
class ErrorMessage(str):
    """A str subclass carrying the first argument of a DB-API exception
    as its text, plus metadata about the originating exception.
    Note: e[0] relies on Python 2 exceptions being indexable."""
    def __new__(cls, e):
        obj = str.__new__(cls, e[0])
        obj.code = None  # no SQLite numeric error code available here
        obj.module = e.__module__
        obj.exception = e.__class__.__name__
        return obj
class SQLiteConnection(DBAPI):
    """SQLObject connection for SQLite running on Jython, backed by the
    zxJDBC bridge to the org.sqlite.JDBC driver.  Connections are
    per-thread (SQLite connections cannot be shared between threads);
    in-memory databases use a single shared connection."""
    supportTransactions = True
    dbName = 'sqlite'
    schemes = [dbName]
    def __init__(self, filename, autoCommit=1, **kw):
        from com.ziclix.python.sql import zxJDBC as sqlite
        self.module = sqlite
        self.using_sqlite2 = True
        self.filename = filename  # full path to sqlite-db-file
        self._memory = filename == ':memory:'
        if self._memory and not self.using_sqlite2:
            raise ValueError("You must use sqlite2 to use in-memory databases")
        # connection options
        opts = {}
        if self.using_sqlite2:
            if autoCommit:
                opts["isolation_level"] = None
            if 'factory' in kw:
                factory = kw.pop('factory')
                if isinstance(factory, str):
                    # Factory given by name: resolve in this module's globals.
                    factory = globals()[factory]
                opts['factory'] = factory(sqlite)
        else:
            opts['autocommit'] = Boolean(autoCommit)
            if 'encoding' in kw:
                opts['encoding'] = kw.pop('encoding')
            if 'mode' in kw:
                opts['mode'] = int(kw.pop('mode'), 0)
        if 'timeout' in kw:
            if self.using_sqlite2:
                opts['timeout'] = float(kw.pop('timeout'))
            else:
                # pysqlite 1.x expects milliseconds.
                opts['timeout'] = int(float(kw.pop('timeout')) * 1000)
        if 'check_same_thread' in kw:
            opts["check_same_thread"] = Boolean(kw.pop('check_same_thread'))
        # use only one connection for sqlite - supports multiple)
        # cursors per connection
        self._connOptions = opts
        self.use_table_info = Boolean(kw.pop("use_table_info", True))
        DBAPI.__init__(self, **kw)
        self._threadPool = {}
        self._threadOrigination = {}
        if self._memory:
            # A single shared connection: each new connection to :memory:
            # would otherwise see its own empty database.
            self._memoryConn = self.module.connect('jdbc:sqlite:%s' % self.filename, None, None, 'org.sqlite.JDBC')
            # Convert text data from SQLite to str, not unicode -
            # SQLObject converts it to unicode itself.
            #self._memoryConn.text_factory = str
    @classmethod
    def _connectionFromParams(cls, user, password, host, port, path, args):
        """Build a connection from parsed URI parts; SQLite is local-only."""
        assert host is None and port is None, (
            "SQLite can only be used locally (with a URI like "
            "sqlite:/file or sqlite:///file, not sqlite://%s%s)" %
            (host, port and ':%r' % port or ''))
        assert user is None and password is None, (
            "You may not provide usernames or passwords for SQLite "
            "databases")
        if path == "/:memory:":
            path = ":memory:"
        return cls(filename=path, **args)
    def oldUri(self):
        """URI in the legacy sqlite://path form (no quoting)."""
        path = self.filename
        if path == ":memory:":
            path = "/:memory:"
        else:
            path = "//" + path
        return 'sqlite:%s' % path
    def uri(self):
        """Canonical, percent-quoted URI for this connection."""
        path = self.filename
        if path == ":memory:":
            path = "/:memory:"
        else:
            if path.startswith('/'):
                path = "//" + path
            else:
                path = "///" + path
            path = urllib.quote(path)
        return 'sqlite:%s' % path
    def getConnection(self):
        # SQLite can't share connections between threads, and so can't
        # pool connections.  Since we are isolating threads here, we
        # don't have to worry about locking as much.
        if self._memory:
            conn = self.makeConnection()
            self._connectionNumbers[id(conn)] = self._connectionCount
            self._connectionCount += 1
            return conn
        threadid = thread.get_ident()
        if (self._pool is not None
            and threadid in self._threadPool):
            # Reuse the connection previously parked for this thread.
            conn = self._threadPool[threadid]
            del self._threadPool[threadid]
            if conn in self._pool:
                self._pool.remove(conn)
        else:
            conn = self.makeConnection()
            if self._pool is not None:
                self._threadOrigination[id(conn)] = threadid
            self._connectionNumbers[id(conn)] = self._connectionCount
            self._connectionCount += 1
        if self.debug:
            s = 'ACQUIRE'
            if self._pool is not None:
                s += ' pool=[%s]' % ', '.join([str(self._connectionNumbers[id(v)]) for v in self._pool])
            self.printDebug(conn, s, 'Pool')
        return conn
    def releaseConnection(self, conn, explicit=False):
        if self._memory:
            # The shared in-memory connection is never released.
            return
        threadid = self._threadOrigination.get(id(conn))
        DBAPI.releaseConnection(self, conn, explicit=explicit)
        if (self._pool is not None and threadid
            and threadid not in self._threadPool):
            # Park the connection for reuse by its originating thread.
            self._threadPool[threadid] = conn
        else:
            if self._pool and conn in self._pool:
                self._pool.remove(conn)
            conn.close()
    def _setAutoCommit(self, conn, auto):
        if self.using_sqlite2:
            # sqlite2: autocommit is expressed via isolation_level.
            if auto:
                conn.isolation_level = None
            else:
                conn.isolation_level = ""
        else:
            conn.autocommit = auto
    def _setIsolationLevel(self, conn, level):
        if not self.using_sqlite2:
            return
        conn.isolation_level = level
    def makeConnection(self):
        if self._memory:
            return self._memoryConn
        # TODO: self._connOptions is ignored because it causes errors
        conn = self.module.connect('jdbc:sqlite:%s' % self.filename, '', '', 'org.sqlite.JDBC')
        # TODO: zxjdbc.connect does not have a text_factory property
        #conn.text_factory = str # Convert text data to str, not unicode
        return conn
    def _executeRetry(self, conn, cursor, query):
        """Execute `query`, translating driver exceptions into SQLObject's
        dberrors hierarchy (with a special case for duplicate keys)."""
        if self.debug:
            self.printDebug(conn, query, 'QueryR')
        try:
            return cursor.execute(query)
        except self.module.OperationalError, e:
            raise OperationalError(ErrorMessage(e))
        except self.module.IntegrityError, e:
            msg = ErrorMessage(e)
            if msg.startswith('column') and msg.endswith('not unique'):
                raise DuplicateEntryError(msg)
            else:
                raise IntegrityError(msg)
        except self.module.InternalError, e:
            raise InternalError(ErrorMessage(e))
        except self.module.ProgrammingError, e:
            raise ProgrammingError(ErrorMessage(e))
        except self.module.DataError, e:
            raise DataError(ErrorMessage(e))
        except self.module.NotSupportedError, e:
            raise NotSupportedError(ErrorMessage(e))
        except self.module.DatabaseError, e:
            raise DatabaseError(ErrorMessage(e))
        except self.module.InterfaceError, e:
            raise InterfaceError(ErrorMessage(e))
        except self.module.Warning, e:
            raise Warning(ErrorMessage(e))
        except self.module.Error, e:
            raise Error(ErrorMessage(e))
    def _queryInsertID(self, conn, soInstance, id, names, values):
        """INSERT a row and return its id (explicit or auto-generated)."""
        table = soInstance.sqlmeta.table
        idName = soInstance.sqlmeta.idName
        c = conn.cursor()
        if id is not None:
            names = [idName] + names
            values = [id] + values
        q = self._insertSQL(table, names, values)
        if self.debug:
            self.printDebug(conn, q, 'QueryIns')
        self._executeRetry(conn, c, q)
        # lastrowid is a DB-API extension from "PEP 0249":
        if id is None:
            if c.lastrowid:
                id = c.lastrowid
            else:
                # the Java SQLite JDBC driver doesn't seem to have implemented
                # the lastrowid extension, so we have to do this manually.
                # Also getMetaData().getGeneratedKeys() is inaccessible.
                # TODO: make this a prepared statement?
                self._executeRetry(conn, c, 'select last_insert_rowid()')
                id = c.fetchone()[0]
        if self.debugOutput:
            self.printDebug(conn, id, 'QueryIns', 'result')
        return id
    def _insertSQL(self, table, names, values):
        if not names:
            assert not values
            # INSERT INTO table () VALUES () isn't allowed in
            # SQLite (though it is in other databases)
            return ("INSERT INTO %s VALUES (NULL)" % table)
        else:
            return DBAPI._insertSQL(self, table, names, values)
    @classmethod
    def _queryAddLimitOffset(cls, query, start, end):
        """Append SQLite LIMIT/OFFSET clauses for the slice [start:end]."""
        if not start:
            return "%s LIMIT %i" % (query, end)
        if not end:
            # LIMIT 0 would return nothing; LIMIT 0 OFFSET n is SQLite's
            # idiom for "skip n rows, no upper bound" here.
            return "%s LIMIT 0 OFFSET %i" % (query, start)
        return "%s LIMIT %i OFFSET %i" % (query, end-start, start)
    def createColumn(self, soClass, col):
        return col.sqliteCreateSQL()
    def createReferenceConstraint(self, soClass, col):
        # SQLite (as targeted here) does not enforce FK constraints.
        return None
    def createIDColumn(self, soClass):
        return self._createIDColumn(soClass.sqlmeta)
    def _createIDColumn(self, sqlmeta):
        if sqlmeta.idType == str:
            return '%s TEXT PRIMARY KEY' % sqlmeta.idName
        return '%s INTEGER PRIMARY KEY AUTOINCREMENT' % sqlmeta.idName
    def joinSQLType(self, join):
        return 'INT NOT NULL'
    def tableExists(self, tableName):
        result = self.queryOne("SELECT tbl_name FROM sqlite_master WHERE type='table' AND tbl_name = '%s'" % tableName)
        # turn it into a boolean:
        return not not result
    def createIndexSQL(self, soClass, index):
        return index.sqliteCreateIndexSQL(soClass)
    def addColumn(self, tableName, column):
        self.query('ALTER TABLE %s ADD COLUMN %s' %
                   (tableName,
                    column.sqliteCreateSQL()))
        self.query('VACUUM %s' % tableName)
    def delColumn(self, sqlmeta, column):
        # SQLite has no DROP COLUMN; rebuild the table without it.
        self.recreateTableWithoutColumn(sqlmeta, column)
    def recreateTableWithoutColumn(self, sqlmeta, column):
        """Rename the table aside, recreate it without `column`, copy the
        remaining data back, then drop the original."""
        new_name = sqlmeta.table + '_ORIGINAL'
        self.query('ALTER TABLE %s RENAME TO %s' % (sqlmeta.table, new_name))
        cols = [self._createIDColumn(sqlmeta)] \
               + [self.createColumn(None, col)
                  for col in sqlmeta.columnList if col.name != column.name]
        cols = ",\n".join(["    %s" % c for c in cols])
        self.query('CREATE TABLE %s (\n%s\n)' % (sqlmeta.table, cols))
        all_columns = ', '.join([sqlmeta.idName] + [col.dbName for col in sqlmeta.columnList])
        self.query('INSERT INTO %s (%s) SELECT %s FROM %s' % (
            sqlmeta.table, all_columns, all_columns, new_name))
        self.query('DROP TABLE %s' % new_name)
    def columnsFromSchema(self, tableName, soClass):
        """Reverse-engineer column objects for `tableName`, either via
        PRAGMA table_info or by parsing the stored CREATE TABLE SQL."""
        if self.use_table_info:
            return self._columnsFromSchemaTableInfo(tableName, soClass)
        else:
            return self._columnsFromSchemaParse(tableName, soClass)
    def _columnsFromSchemaTableInfo(self, tableName, soClass):
        colData = self.queryAll("PRAGMA table_info(%s)" % tableName)
        results = []
        for index, field, t, nullAllowed, default, key in colData:
            if field == soClass.sqlmeta.idName:
                continue
            colClass, kw = self.guessClass(t)
            if default == 'NULL':
                nullAllowed = True
                default = None
            kw['name'] = soClass.sqlmeta.style.dbColumnToPythonAttr(field)
            kw['dbName'] = field
            kw['notNone'] = not nullAllowed
            kw['default'] = default
            # @@ skip key...
            # @@ skip extra...
            results.append(colClass(**kw))
        return results
    def _columnsFromSchemaParse(self, tableName, soClass):
        colData = self.queryOne("SELECT sql FROM sqlite_master WHERE type='table' AND name='%s'"
                                % tableName)
        if not colData:
            raise ValueError('The table %s was not found in the database. Load failed.' % tableName)
        # Keep only the column definitions between the outer parentheses.
        colData = colData[0].split('(', 1)[1].strip()[:-2]
        while True:
            # Strip any nested parenthesised groups (e.g. type sizes) so the
            # later split on ',' only sees column boundaries.
            start = colData.find('(')
            if start == -1: break
            end = colData.find(')', start)
            if end == -1: break
            colData = colData[:start] + colData[end+1:]
        results = []
        for colDesc in colData.split(','):
            parts = colDesc.strip().split(' ', 2)
            field = parts[0].strip()
            # skip comments
            if field.startswith('--'):
                continue
            # get rid of enclosing quotes
            if field[0] == field[-1] == '"':
                field = field[1:-1]
            if field == getattr(soClass.sqlmeta, 'idName', 'id'):
                continue
            colClass, kw = self.guessClass(parts[1].strip())
            if len(parts) == 2:
                index_info = ''
            else:
                index_info = parts[2].strip().upper()
            kw['name'] = soClass.sqlmeta.style.dbColumnToPythonAttr(field)
            kw['dbName'] = field
            import re
            nullble = re.search(r'(\b\S*)\sNULL', index_info)
            default = re.search(r"DEFAULT\s((?:\d[\dA-FX.]*)|(?:'[^']*')|(?:#[^#]*#))", index_info)
            kw['notNone'] = nullble and nullble.group(1) == 'NOT'
            kw['default'] = default and default.group(1)
            # @@ skip key...
            # @@ skip extra...
            results.append(colClass(**kw))
        return results
    def guessClass(self, t):
        """Map a SQLite type-affinity string to (column class, kwargs)."""
        t = t.upper()
        if t.find('INT') >= 0:
            return col.IntCol, {}
        elif t.find('TEXT') >= 0 or t.find('CHAR') >= 0 or t.find('CLOB') >= 0:
            return col.StringCol, {'length': 2**32-1}
        elif t.find('BLOB') >= 0:
            return col.BLOBCol, {"length": 2**32-1}
        elif t.find('REAL') >= 0 or t.find('FLOAT') >= 0:
            return col.FloatCol, {}
        elif t.find('DECIMAL') >= 0:
            return col.DecimalCol, {'size': None, 'precision': None}
        elif t.find('BOOL') >= 0:
            return col.BoolCol, {}
        else:
            return col.Col, {}
    def createEmptyDatabase(self):
        if self._memory:
            return
        # Touch (and truncate) the database file.
        open(self.filename, 'w').close()
    def dropDatabase(self):
        if self._memory:
            return
        os.unlink(self.filename)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1 @@
#

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,7 @@
from sqlobject.dbconnection import registerConnection
def builder():
    """Return the connection class for the maxdb/sapdb schemes.
    Imported lazily so the sapdb driver is only required when a
    maxdb URI is actually used."""
    import maxdbconnection
    return maxdbconnection.MaxdbConnection
registerConnection(['maxdb','sapdb'],builder)

View File

@@ -0,0 +1,303 @@
"""
Contributed by Edigram SAS, Paris France Tel:01 44 77 94 00
Ahmed MOHAMED ALI <ahmedmoali@yahoo.com> 27 April 2004
This program is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
General Public License for more details.
connection creation sample::
__connection__ = DBConnection.maxdbConnection(
host=hostname, database=dbname,
user=user_name, password=user_password, autoCommit=1, debug=1)
"""
from sqlobject.dbconnection import DBAPI
from sqlobject import col
class maxdbException(Exception):
    """Root of the MaxDB adapter's exception hierarchy.

    Keeps the offending value and renders it with repr().
    """

    def __init__(self, value):
        self.value = value

    def __str__(self):
        return '%r' % (self.value,)
class LowerBoundOfSliceIsNotSupported(maxdbException):
    """Raised for slices with a non-zero start: MaxDB can only bound
    result sets from above via ROWNO, so lower bounds are unsupported."""
    def __init__(self, value):
        # The offending value is discarded; the exception type is the message.
        maxdbException.__init__(self, '')
class IncorrectIDStyleError(maxdbException) :
    """Raised when a table's primary key column name does not match the
    naming style currently in use."""
    def __init__(self,value):
        maxdbException.__init__(
            self,
            'This primary key name is not in the expected style, '
            'please rename the column to %r or switch to another style'
            % value)
class StyleMismatchError(maxdbException):
    """Raised when a non-primary-key column uses a name reserved for the
    primary key under the current naming style."""
    def __init__(self, value):
        maxdbException.__init__(
            self,
            'The name %r is only permitted for primary key, change the '
            'column name or switch to another style' % value)
class PrimaryKeyNotFounded(maxdbException):
    """Raised when schema introspection finds no primary key on a table.

    NOTE(review): the name is a grammatical slip for "PrimaryKeyNotFound";
    it is kept as-is because callers may catch it by this name.
    """
    def __init__(self, value):
        maxdbException.__init__(
            self,
            "No primary key was defined on table %r" % value)
# SAP DB/MaxDB identifiers are limited to 32 characters; used when
# deriving sequence names from table names.
SAPDBMAX_ID_LENGTH=32
class MaxdbConnection(DBAPI):
    """SQLObject connection class for SAP MaxDB / SAP DB.

    Uses the ``sapdb.dbapi`` driver.  Row ids are drawn from a
    per-table sequence (``<table>_SEQ``) because MaxDB has no
    auto-increment column type.
    """

    supportTransactions = True
    dbName = 'maxdb'
    schemes = [dbName]

    def __init__ (self, host='', port=None, user=None, password=None,
                  database=None, autoCommit=1, sqlmode='internal',
                  isolation=None, timeout=None, **kw):
        from sapdb import dbapi
        self.module = dbapi
        self.host = host
        self.port = port
        self.user = user
        self.password = password
        self.db = database
        self.autoCommit = autoCommit
        self.sqlmode = sqlmode
        self.isolation = isolation
        self.timeout = timeout
        DBAPI.__init__(self, **kw)

    @classmethod
    def _connectionFromParams(cls, auth, password, host, port, path, args):
        """Build a connection from a parsed connection URI."""
        # Fix: this module never imported os, so the original code
        # raised NameError here; import it locally.
        import os
        path = path.replace('/', os.path.sep)
        return cls(host, port, user=auth, password=password,
                   database=path, **args)

    def _getConfigParams(self, sqlmode, auto):
        """Return the driver keyword options for opening a connection."""
        autocommit = 'off'
        if auto:
            autocommit = 'on'
        opt = {}
        opt["autocommit"] = autocommit
        opt["sqlmode"] = sqlmode
        if self.isolation:
            opt["isolation"] = self.isolation
        if self.timeout:
            opt["timeout"] = self.timeout
        return opt

    def _setAutoCommit(self, conn, auto):
        # The driver fixes autocommit at connect time, so the existing
        # connection is closed and re-initialised with the new setting.
        conn.close()
        conn.__init__(self.user, self.password, self.db, self.host,
                      **self._getConfigParams(self.sqlmode, auto))

    def createSequenceName(self, table):
        """
        Return the id-sequence name for `table`.

        Sequence names are the table name plus '_SEQ'; the table part
        is truncated so the whole name stays within the 32-character
        identifier limit of SAP DB.
        """
        return '%s_SEQ' % (table[:SAPDBMAX_ID_LENGTH - 4])

    def makeConnection(self):
        """Open and return a new driver connection."""
        conn = self.module.Connection(
            self.user, self.password, self.db, self.host,
            **self._getConfigParams(self.sqlmode, self.autoCommit))
        return conn

    def _queryInsertID(self, conn, soInstance, id, names, values):
        """Insert a row, drawing a new id from the table's sequence if needed."""
        table = soInstance.sqlmeta.table
        idName = soInstance.sqlmeta.idName
        c = conn.cursor()
        if id is None:
            c.execute('SELECT %s.NEXTVAL FROM DUAL' % (self.createSequenceName(table)))
            id = c.fetchone()[0]
        names = [idName] + names
        values = [id] + values
        q = self._insertSQL(table, names, values)
        if self.debug:
            self.printDebug(conn, q, 'QueryIns')
        c.execute(q)
        if self.debugOutput:
            self.printDebug(conn, id, 'QueryIns', 'result')
        return id

    @classmethod
    def sqlAddLimit(cls, query, limit):
        """Append a ROWNO restriction to `query`."""
        # Fix: rewrite only the first SELECT (replacing every occurrence
        # would corrupt subqueries) and put a space before WHERE so the
        # keyword does not fuse with the preceding token.
        sql = query.replace("SELECT", "SELECT ROWNO, ", 1)
        if sql.find('WHERE') != -1:
            sql = sql + ' AND ' + limit
        else:
            sql = sql + ' WHERE ' + limit
        return sql

    @classmethod
    def _queryAddLimitOffset(cls, query, start, end):
        """Translate a slice into a ROWNO upper bound; offsets unsupported."""
        if start:
            # Fix: the bare `raise Class` form failed with TypeError
            # because the constructor requires a value argument.
            raise LowerBoundOfSliceIsNotSupported(start)
        limit = ' ROWNO <= %d ' % (end)
        return cls.sqlAddLimit(query, limit)

    def createTable(self, soClass):
        """Create the table and its id sequence.

        The two statements should be atomic; using a Transaction here
        caused a recursion-limit error, so plain queries are used until
        that problem is solved.
        """
        self.query('CREATE TABLE %s (\n%s\n)' % \
                   (soClass.sqlmeta.table, self.createColumns(soClass)))
        self.query("CREATE SEQUENCE %s"
                   % self.createSequenceName(soClass.sqlmeta.table))
        return []

    def createReferenceConstraint(self, soClass, col):
        return col.maxdbCreateReferenceConstraint()

    def createColumn(self, soClass, col):
        return col.maxdbCreateSQL()

    def createIDColumn(self, soClass):
        key_type = {int: "INT", str: "TEXT"}[soClass.sqlmeta.idType]
        return '%s %s PRIMARY KEY' % (soClass.sqlmeta.idName, key_type)

    def createIndexSQL(self, soClass, index):
        return index.maxdbCreateIndexSQL(soClass)

    def dropTable(self, tableName, cascade=False):
        """Drop the table and its id sequence.

        Like createTable, this should be atomic; see the note there
        about the Transaction recursion problem.
        """
        self.query("DROP TABLE %s" % tableName)
        self.query("DROP SEQUENCE %s" % self.createSequenceName(tableName))

    def joinSQLType(self, join):
        return 'INT NOT NULL'

    def tableExists(self, tableName):
        """Case-insensitive check for `tableName` in the catalog."""
        for (table,) in self.queryAll("SELECT OBJECT_NAME FROM ALL_OBJECTS WHERE OBJECT_TYPE='TABLE'"):
            if table.lower() == tableName.lower():
                return True
        return False

    def addColumn(self, tableName, column):
        self.query('ALTER TABLE %s ADD %s' %
                   (tableName,
                    column.maxdbCreateSQL()))

    def delColumn(self, sqlmeta, column):
        self.query('ALTER TABLE %s DROP COLUMN %s' % (sqlmeta.table, column.dbName))

    GET_COLUMNS = """
    SELECT COLUMN_NAME, NULLABLE, DATA_DEFAULT, DATA_TYPE,
    DATA_LENGTH, DATA_SCALE
    FROM USER_TAB_COLUMNS WHERE TABLE_NAME=UPPER('%s')"""

    GET_PK_AND_FK = """
    SELECT constraint_cols.column_name, constraints.constraint_type,
    refname,reftablename
    FROM user_cons_columns constraint_cols
    INNER JOIN user_constraints constraints
    ON constraint_cols.constraint_name = constraints.constraint_name
    LEFT OUTER JOIN show_foreign_key fk
    ON constraint_cols.column_name = fk.columnname
    WHERE constraints.table_name =UPPER('%s')"""

    def columnsFromSchema(self, tableName, soClass):
        """Introspect `tableName` and return a list of Col objects."""
        colData = self.queryAll(self.GET_COLUMNS
                                % tableName)
        results = []
        keymap = {}
        pkmap = {}
        fkData = self.queryAll(self.GET_PK_AND_FK % tableName)
        for col, cons_type, refcol, reftable in fkData:
            col_name = col.lower()
            pkmap[col_name] = False
            if cons_type == 'R':
                # 'R' = referential (foreign key) constraint
                keymap[col_name] = reftable.lower()
            elif cons_type == 'P':
                # 'P' = primary key constraint
                pkmap[col_name] = True
        if len(pkmap) == 0:
            raise PrimaryKeyNotFounded(tableName)
        for (field, nullAllowed, default, data_type, data_len,
             data_scale) in colData:
            # id is defined as primary key --> ok
            # We let sqlobject raise error if the 'id' is used for another column
            field_name = field.lower()
            if (field_name == soClass.sqlmeta.idName) and pkmap[field_name]:
                continue
            colClass, kw = self.guessClass(data_type, data_len, data_scale)
            kw['name'] = field_name
            kw['dbName'] = field
            # NULLABLE is 'Y' when NULLs are allowed, so notNone is its inverse.
            kw['notNone'] = nullAllowed != 'Y'
            if default is not None:
                kw['default'] = default
            if field_name in keymap:
                kw['foreignKey'] = keymap[field_name]
            results.append(colClass(**kw))
        return results

    _numericTypes = ['INTEGER', 'INT', 'SMALLINT']
    _dateTypes = ['DATE', 'TIME', 'TIMESTAMP']

    def guessClass(self, t, flength, fscale=None):
        """
        An internal method that tries to figure out what Col subclass
        is appropriate given whatever introspective information is
        available -- both very database-specific.
        """
        if t in self._numericTypes:
            return col.IntCol, {}
        # The type returned by the sapdb library for LONG is
        # SapDB_LongReader To get the data call the read member with
        # desired size (default =-1 means get all)
        elif t.find('LONG') != -1:
            return col.StringCol, {'length': flength,
                                   'varchar': False}
        elif t in self._dateTypes:
            return col.DateTimeCol, {}
        elif t == 'FIXED':
            # Fix: CurrencyCol lives in the col module; the bare name
            # raised NameError.
            return col.CurrencyCol, {'size': flength,
                                     'precision': fscale}
        else:
            return col.Col, {}

View File

@@ -0,0 +1,30 @@
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
Author: <ahmedmo@wanadoo.fr>
Edigram SA - Paris France
Tel:0144779400
SAPDBAPI installation
---------------------
The sapdb module can be downloaded from:
Win32
-------
ftp://ftp.sap.com/pub/sapdb/bin/win/sapdb-python-win32-7.4.03.31a.zip
Linux
------
ftp://ftp.sap.com/pub/sapdb/bin/linux/sapdb-python-linux-i386-7.4.03.31a.tgz
uncompress the archive and add the sapdb directory path to your PYTHONPATH.

View File

@@ -0,0 +1,7 @@
from sqlobject.dbconnection import registerConnection
def builder():
    # Lazy import: the MSSQL driver modules are only required once an
    # mssql:// URI is actually opened.
    import mssqlconnection
    return mssqlconnection.MSSQLConnection
registerConnection(['mssql'], builder)

View File

@@ -0,0 +1,306 @@
from sqlobject.dbconnection import DBAPI
from sqlobject import col
import re
class MSSQLConnection(DBAPI):
    """SQLObject connection class for Microsoft SQL Server.

    Works with either the adodbapi or the pymssql driver; the `driver`
    keyword selects which to try, in order (default "adodb,pymssql").
    """

    supportTransactions = True
    dbName = 'mssql'
    schemes = [dbName]

    # Matches the leading SELECT of a query so TOP can be spliced in.
    limit_re = re.compile(r'^\s*(select )(.*)', re.IGNORECASE)

    def __init__(self, db, user, password='', host='localhost', port=None,
                 autoCommit=0, **kw):
        drivers = kw.pop('driver', None) or 'adodb,pymssql'
        # Try each requested driver in turn; keep the first importable one.
        for driver in drivers.split(','):
            driver = driver.strip()
            if not driver:
                continue
            try:
                if driver in ('adodb', 'adodbapi'):
                    import adodbapi as sqlmodule
                elif driver == 'pymssql':
                    import pymssql as sqlmodule
                else:
                    raise ValueError('Unknown MSSQL driver "%s", expected adodb or pymssql' % driver)
            except ImportError:
                pass
            else:
                break
        else:
            raise ImportError('Cannot find an MSSQL driver, tried %s' % drivers)
        self.module = sqlmodule
        if sqlmodule.__name__ == 'adodbapi':
            self.dbconnection = sqlmodule.connect
            # ADO uses unicode only (AFAIK)
            self.usingUnicodeStrings = True
            # Need to use SQLNCLI provider for SQL Server Express Edition
            if kw.get("ncli"):
                conn_str = "Provider=SQLNCLI;"
            else:
                conn_str = "Provider=SQLOLEDB;"
            conn_str += "Data Source=%s;Initial Catalog=%s;"
            # MSDE does not allow SQL server login
            if kw.get("sspi"):
                conn_str += "Integrated Security=SSPI;Persist Security Info=False"
                self.make_conn_str = lambda keys: [conn_str % (keys.host, keys.db)]
            else:
                conn_str += "User Id=%s;Password=%s"
                self.make_conn_str = lambda keys: [conn_str % (keys.host, keys.db, keys.user, keys.password)]
            kw.pop("sspi", None)
            kw.pop("ncli", None)
        else: # pymssql
            self.dbconnection = sqlmodule.connect
            sqlmodule.Binary = lambda st: str(st)
            # don't know whether pymssql uses unicode
            self.usingUnicodeStrings = False
            self.make_conn_str = lambda keys: \
                ["", keys.user, keys.password, keys.host, keys.db]
        self.autoCommit = int(autoCommit)
        self.host = host
        self.port = port
        self.db = db
        self.user = user
        self.password = password  # fix: was redundantly assigned twice
        self._can_use_max_types = None
        DBAPI.__init__(self, **kw)

    @classmethod
    def _connectionFromParams(cls, user, password, host, port, path, args):
        """Build a connection from a parsed connection URI."""
        path = path.strip('/')
        return cls(user=user, password=password,
                   host=host or 'localhost', port=port, db=path, **args)

    def insert_id(self, conn):
        """
        insert_id method.
        """
        c = conn.cursor()
        # converting the identity to an int is ugly, but it gets returned
        # as a decimal otherwise :S
        c.execute('SELECT CONVERT(INT, @@IDENTITY)')
        return c.fetchone()[0]

    def makeConnection(self):
        """Open a connection and probe the server's decimal separator."""
        con = self.dbconnection( *self.make_conn_str(self) )
        cur = con.cursor()
        cur.execute('SET ANSI_NULLS ON')
        # Probe how the server renders decimals so values can be parsed.
        cur.execute("SELECT CAST('12345.21' AS DECIMAL(10, 2))")
        self.decimalSeparator = str(cur.fetchone()[0])[-3]
        cur.close()
        return con

    HAS_IDENTITY = """
       SELECT col.name, col.status, obj.name
       FROM syscolumns col
       JOIN sysobjects obj
       ON obj.id = col.id
       WHERE obj.name = '%s'
       and col.autoval is not null
    """

    def _hasIdentity(self, conn, table):
        """Return True if `table` has an IDENTITY column."""
        query = self.HAS_IDENTITY % table
        c = conn.cursor()
        c.execute(query)
        r = c.fetchone()
        return r is not None

    def _queryInsertID(self, conn, soInstance, id, names, values):
        """
        Insert the Initial with names and values, using id.
        """
        table = soInstance.sqlmeta.table
        idName = soInstance.sqlmeta.idName
        c = conn.cursor()
        has_identity = self._hasIdentity(conn, table)
        if id is not None:
            names = [idName] + names
            values = [id] + values
        elif has_identity and idName in names:
            # The server generates identity values; strip a caller-supplied
            # id column so the insert does not conflict.
            try:
                i = names.index( idName )
                if i:
                    del names[i]
                    del values[i]
            except ValueError:
                pass
        if has_identity:
            if id is not None:
                c.execute('SET IDENTITY_INSERT %s ON' % table)
            else:
                c.execute('SET IDENTITY_INSERT %s OFF' % table)
        q = self._insertSQL(table, names, values)
        if self.debug:
            # fix: use the class's debug machinery instead of a bare
            # Python 2 print statement
            self.printDebug(conn, q, 'QueryIns')
        c.execute(q)
        if has_identity:
            c.execute('SET IDENTITY_INSERT %s OFF' % table)
        if id is None:
            id = self.insert_id(conn)
        if self.debugOutput:
            self.printDebug(conn, id, 'QueryIns', 'result')
        return id

    @classmethod
    def _queryAddLimitOffset(cls, query, start, end):
        """Apply `end` as a TOP clause; offsets are not supported and
        leave the query unchanged."""
        if end and not start:
            match = cls.limit_re.match(query)
            if match and len(match.groups()) == 2:
                return ' '.join(["SELECT TOP %i" % end, match.group(2)])
        # Fix: previously fell off the end (returning None) when the
        # SELECT regex did not match; always return a query string.
        return query

    def createReferenceConstraint(self, soClass, col):
        return col.mssqlCreateReferenceConstraint()

    def createColumn(self, soClass, col):
        return col.mssqlCreateSQL(self)

    def createIDColumn(self, soClass):
        key_type = {int: "INT", str: "TEXT"}[soClass.sqlmeta.idType]
        return '%s %s IDENTITY UNIQUE' % (soClass.sqlmeta.idName, key_type)

    def createIndexSQL(self, soClass, index):
        return index.mssqlCreateIndexSQL(soClass)

    def joinSQLType(self, join):
        return 'INT NOT NULL'

    SHOW_TABLES = "SELECT name FROM sysobjects WHERE type='U'"

    def tableExists(self, tableName):
        """Case-insensitive check for `tableName` among user tables."""
        for (table,) in self.queryAll(self.SHOW_TABLES):
            if table.lower() == tableName.lower():
                return True
        return False

    def addColumn(self, tableName, column):
        self.query('ALTER TABLE %s ADD %s' %
                   (tableName,
                    column.mssqlCreateSQL(self)))

    def delColumn(self, sqlmeta, column):
        # Fix: referenced the undefined name `tableName` (NameError);
        # the table name comes from sqlmeta like the other backends.
        self.query('ALTER TABLE %s DROP COLUMN %s' % (sqlmeta.table, column.dbName))

    # precision and scale is gotten from column table so that we can create
    # decimal columns if needed
    SHOW_COLUMNS = """
        select
                name,
                length,
                ( select name
                  from systypes
                  where cast(xusertype as int)= cast(sc.xtype as int)
                ) datatype,
                prec,
                scale,
                isnullable,
                cdefault,
                m.text default_text,
                isnull(len(autoval),0) is_identity
        from syscolumns sc
        LEFT OUTER JOIN syscomments m on sc.cdefault = m.id
                AND m.colid = 1
        where
                sc.id in (select id
                          from sysobjects
                          where name = '%s')
        order by
                colorder"""

    def columnsFromSchema(self, tableName, soClass):
        """Introspect `tableName` and return a list of Col objects."""
        colData = self.queryAll(self.SHOW_COLUMNS
                                % tableName)
        results = []
        for field, size, t, precision, scale, nullAllowed, default, defaultText, is_identity in colData:
            if field == soClass.sqlmeta.idName:
                continue
            # precision is needed for decimal columns
            colClass, kw = self.guessClass(t, size, precision, scale)
            kw['name'] = str(soClass.sqlmeta.style.dbColumnToPythonAttr(field))
            kw['dbName'] = str(field)
            kw['notNone'] = not nullAllowed
            if (defaultText):
                # Strip ( and )
                defaultText = defaultText[1:-1]
                if defaultText[0] == "'":
                    defaultText = defaultText[1:-1]
                else:
                    if t == "int" : defaultText = int(defaultText)
                    if t == "float" : defaultText = float(defaultText)
                    if t == "numeric": defaultText = float(defaultText)
                    # TODO need to access the "column" to_python method here--but the object doesn't exists yet
                # @@ skip key...
                kw['default'] = defaultText
            results.append(colClass(**kw))
        return results

    def _setAutoCommit(self, conn, auto):
        # Autocommit is fixed at connection-creation time for this
        # backend; toggling it on a live connection is not supported.
        # (Unreachable legacy toggle code, which referenced an undefined
        # `SQL` name, has been removed.)
        return

    # precision and scale is needed for decimal columns
    def guessClass(self, t, size, precision, scale):
        """
        Here we take raw values coming out of syscolumns and map to SQLObject class types.
        """
        if t.startswith('int'):
            return col.IntCol, {}
        elif t.startswith('varchar'):
            if self.usingUnicodeStrings:
                return col.UnicodeCol, {'length': size}
            return col.StringCol, {'length': size}
        elif t.startswith('char'):
            if self.usingUnicodeStrings:
                return col.UnicodeCol, {'length': size,
                                        'varchar': False}
            return col.StringCol, {'length': size,
                                   'varchar': False}
        elif t.startswith('datetime'):
            return col.DateTimeCol, {}
        elif t.startswith('decimal'):
            return col.DecimalCol, {'size': precision, # be careful for awkward naming
                                    'precision': scale}
        else:
            return col.Col, {}

    def server_version(self):
        """Return the server's major version as an int, or None if it
        cannot be determined; caches by shadowing this method with an
        instance attribute."""
        try:
            server_version = self.queryAll("SELECT SERVERPROPERTY('productversion')")[0][0]
            server_version = server_version.split('.')[0]
            server_version = int(server_version)
        except Exception:
            server_version = None # unknown
        self.server_version = server_version # cache it forever
        return server_version

    def can_use_max_types(self):
        """True when the server (2005+, i.e. version >= 9) supports the
        VARCHAR(MAX)-style types."""
        if self._can_use_max_types is not None:
            return self._can_use_max_types
        server_version = self.server_version()
        self._can_use_max_types = can_use_max_types = \
            (server_version is not None) and (server_version >= 9)
        return can_use_max_types

View File

@@ -0,0 +1,8 @@
from sqlobject.dbconnection import registerConnection
#import mysqltypes
def builder():
    # Lazy import: MySQLdb is only required once a mysql:// URI is
    # actually opened.
    import mysqlconnection
    return mysqlconnection.MySQLConnection
registerConnection(['mysql'], builder)

View File

@@ -0,0 +1,305 @@
from sqlobject import col
from sqlobject.dbconnection import DBAPI
from sqlobject.dberrors import *
class ErrorMessage(str):
def __new__(cls, e, append_msg=''):
obj = str.__new__(cls, e[1] + append_msg)
obj.code = int(e[0])
obj.module = e.__module__
obj.exception = e.__class__.__name__
return obj
class MySQLConnection(DBAPI):
supportTransactions = False
dbName = 'mysql'
schemes = [dbName]
def __init__(self, db, user, password='', host='localhost', port=0, **kw):
import MySQLdb, MySQLdb.constants.CR, MySQLdb.constants.ER
self.module = MySQLdb
self.host = host
self.port = port
self.db = db
self.user = user
self.password = password
self.kw = {}
for key in ("unix_socket", "init_command",
"read_default_file", "read_default_group", "conv"):
if key in kw:
self.kw[key] = kw.pop(key)
for key in ("connect_timeout", "compress", "named_pipe", "use_unicode",
"client_flag", "local_infile"):
if key in kw:
self.kw[key] = int(kw.pop(key))
for key in ("ssl_key", "ssl_cert", "ssl_ca", "ssl_capath"):
if key in kw:
if "ssl" not in self.kw:
self.kw["ssl"] = {}
self.kw["ssl"][key[4:]] = kw.pop(key)
if "charset" in kw:
self.dbEncoding = self.kw["charset"] = kw.pop("charset")
else:
self.dbEncoding = None
# MySQLdb < 1.2.1: only ascii
# MySQLdb = 1.2.1: only unicode
# MySQLdb > 1.2.1: both ascii and unicode
self.need_unicode = (self.module.version_info[:3] >= (1, 2, 1)) and (self.module.version_info[:3] < (1, 2, 2))
DBAPI.__init__(self, **kw)
@classmethod
def _connectionFromParams(cls, user, password, host, port, path, args):
return cls(db=path.strip('/'), user=user or '', password=password or '',
host=host or 'localhost', port=port or 0, **args)
def makeConnection(self):
dbEncoding = self.dbEncoding
if dbEncoding:
from MySQLdb.connections import Connection
if not hasattr(Connection, 'set_character_set'):
# monkeypatch pre MySQLdb 1.2.1
def character_set_name(self):
return dbEncoding + '_' + dbEncoding
Connection.character_set_name = character_set_name
try:
conn = self.module.connect(host=self.host, port=self.port,
db=self.db, user=self.user, passwd=self.password, **self.kw)
if self.module.version_info[:3] >= (1, 2, 2):
conn.ping(True) # Attempt to reconnect. This setting is persistent.
except self.module.OperationalError, e:
conninfo = "; used connection string: host=%(host)s, port=%(port)s, db=%(db)s, user=%(user)s" % self.__dict__
raise OperationalError(ErrorMessage(e, conninfo))
if hasattr(conn, 'autocommit'):
conn.autocommit(bool(self.autoCommit))
if dbEncoding:
if hasattr(conn, 'set_character_set'): # MySQLdb 1.2.1 and later
conn.set_character_set(dbEncoding)
else: # pre MySQLdb 1.2.1
# works along with monkeypatching code above
conn.query("SET NAMES %s" % dbEncoding)
return conn
def _setAutoCommit(self, conn, auto):
if hasattr(conn, 'autocommit'):
conn.autocommit(auto)
def _executeRetry(self, conn, cursor, query):
if self.need_unicode and not isinstance(query, unicode):
try:
query = unicode(query, self.dbEncoding)
except UnicodeError:
pass
# When a server connection is lost and a query is attempted, most of
# the time the query will raise a SERVER_LOST exception, then at the
# second attempt to execute it, the mysql lib will reconnect and
# succeed. However is a few cases, the first attempt raises the
# SERVER_GONE exception, the second attempt the SERVER_LOST exception
# and only the third succeeds. Thus the 3 in the loop count.
# If it doesn't reconnect even after 3 attempts, while the database is
# up and running, it is because a 5.0.3 (or newer) server is used
# which no longer permits autoreconnects by default. In that case a
# reconnect flag must be set when making the connection to indicate
# that autoreconnecting is desired. In MySQLdb 1.2.2 or newer this is
# done by calling ping(True) on the connection.
for count in range(3):
try:
return cursor.execute(query)
except self.module.OperationalError, e:
if e.args[0] in (self.module.constants.CR.SERVER_GONE_ERROR, self.module.constants.CR.SERVER_LOST):
if count == 2:
raise OperationalError(ErrorMessage(e))
if self.debug:
self.printDebug(conn, str(e), 'ERROR')
else:
raise OperationalError(ErrorMessage(e))
except self.module.IntegrityError, e:
msg = ErrorMessage(e)
if e.args[0] == self.module.constants.ER.DUP_ENTRY:
raise DuplicateEntryError(msg)
else:
raise IntegrityError(msg)
except self.module.InternalError, e:
raise InternalError(ErrorMessage(e))
except self.module.ProgrammingError, e:
raise ProgrammingError(ErrorMessage(e))
except self.module.DataError, e:
raise DataError(ErrorMessage(e))
except self.module.NotSupportedError, e:
raise NotSupportedError(ErrorMessage(e))
except self.module.DatabaseError, e:
raise DatabaseError(ErrorMessage(e))
except self.module.InterfaceError, e:
raise InterfaceError(ErrorMessage(e))
except self.module.Warning, e:
raise Warning(ErrorMessage(e))
except self.module.Error, e:
raise Error(ErrorMessage(e))
def _queryInsertID(self, conn, soInstance, id, names, values):
table = soInstance.sqlmeta.table
idName = soInstance.sqlmeta.idName
c = conn.cursor()
if id is not None:
names = [idName] + names
values = [id] + values
q = self._insertSQL(table, names, values)
if self.debug:
self.printDebug(conn, q, 'QueryIns')
self._executeRetry(conn, c, q)
if id is None:
try:
id = c.lastrowid
except AttributeError:
id = c.insert_id()
if self.debugOutput:
self.printDebug(conn, id, 'QueryIns', 'result')
return id
@classmethod
def _queryAddLimitOffset(cls, query, start, end):
if not start:
return "%s LIMIT %i" % (query, end)
if not end:
return "%s LIMIT %i, -1" % (query, start)
return "%s LIMIT %i, %i" % (query, start, end-start)
def createReferenceConstraint(self, soClass, col):
return col.mysqlCreateReferenceConstraint()
def createColumn(self, soClass, col):
return col.mysqlCreateSQL()
def createIndexSQL(self, soClass, index):
return index.mysqlCreateIndexSQL(soClass)
def createIDColumn(self, soClass):
if soClass.sqlmeta.idType == str:
return '%s TEXT PRIMARY KEY' % soClass.sqlmeta.idName
return '%s INT PRIMARY KEY AUTO_INCREMENT' % soClass.sqlmeta.idName
def joinSQLType(self, join):
return 'INT NOT NULL'
def tableExists(self, tableName):
try:
# Use DESCRIBE instead of SHOW TABLES because SHOW TABLES
# assumes there is a default database selected
# which is not always True (for an embedded application, e.g.)
self.query('DESCRIBE %s' % (tableName))
return True
except ProgrammingError, e:
if e[0].code == 1146: # ER_NO_SUCH_TABLE
return False
raise
def addColumn(self, tableName, column):
self.query('ALTER TABLE %s ADD COLUMN %s' %
(tableName,
column.mysqlCreateSQL()))
def delColumn(self, sqlmeta, column):
self.query('ALTER TABLE %s DROP COLUMN %s' % (sqlmeta.table, column.dbName))
def columnsFromSchema(self, tableName, soClass):
colData = self.queryAll("SHOW COLUMNS FROM %s"
% tableName)
results = []
for field, t, nullAllowed, key, default, extra in colData:
if field == soClass.sqlmeta.idName:
continue
colClass, kw = self.guessClass(t)
if self.kw.get('use_unicode') and colClass is col.StringCol:
colClass = col.UnicodeCol
if self.dbEncoding: kw['dbEncoding'] = self.dbEncoding
kw['name'] = soClass.sqlmeta.style.dbColumnToPythonAttr(field)
kw['dbName'] = field
# Since MySQL 5.0, 'NO' is returned in the NULL column (SQLObject expected '')
kw['notNone'] = (nullAllowed.upper() != 'YES' and True or False)
if default and t.startswith('int'):
kw['default'] = int(default)
elif default and t.startswith('float'):
kw['default'] = float(default)
elif default == 'CURRENT_TIMESTAMP' and t == 'timestamp':
kw['default'] = None
elif default and colClass is col.BoolCol:
kw['default'] = int(default) and True or False
else:
kw['default'] = default
# @@ skip key...
# @@ skip extra...
results.append(colClass(**kw))
return results
def guessClass(self, t):
if t.startswith('int'):
return col.IntCol, {}
elif t.startswith('enum'):
values = []
for i in t[5:-1].split(','): # take the enum() off and split
values.append(i[1:-1]) # remove the surrounding \'
return col.EnumCol, {'enumValues': values}
elif t.startswith('double'):
return col.FloatCol, {}
elif t.startswith('varchar'):
colType = col.StringCol
if self.kw.get('use_unicode', False):
colType = col.UnicodeCol
if t.endswith('binary'):
return colType, {'length': int(t[8:-8]),
'char_binary': True}
else:
return colType, {'length': int(t[8:-1])}
elif t.startswith('char'):
if t.endswith('binary'):
return col.StringCol, {'length': int(t[5:-8]),
'varchar': False,
'char_binary': True}
else:
return col.StringCol, {'length': int(t[5:-1]),
'varchar': False}
elif t.startswith('datetime'):
return col.DateTimeCol, {}
elif t.startswith('date'):
return col.DateCol, {}
elif t.startswith('time'):
return col.TimeCol, {}
elif t.startswith('timestamp'):
return col.TimestampCol, {}
elif t.startswith('bool'):
return col.BoolCol, {}
elif t.startswith('tinyblob'):
return col.BLOBCol, {"length": 2**8-1}
elif t.startswith('tinytext'):
return col.StringCol, {"length": 2**8-1, "varchar": True}
elif t.startswith('blob'):
return col.BLOBCol, {"length": 2**16-1}
elif t.startswith('text'):
return col.StringCol, {"length": 2**16-1, "varchar": True}
elif t.startswith('mediumblob'):
return col.BLOBCol, {"length": 2**24-1}
elif t.startswith('mediumtext'):
return col.StringCol, {"length": 2**24-1, "varchar": True}
elif t.startswith('longblob'):
return col.BLOBCol, {"length": 2**32}
elif t.startswith('longtext'):
return col.StringCol, {"length": 2**32, "varchar": True}
else:
return col.Col, {}
def _createOrDropDatabase(self, op="CREATE"):
self.query('%s DATABASE %s' % (op, self.db))
def createEmptyDatabase(self):
self._createOrDropDatabase()
def dropDatabase(self):
self._createOrDropDatabase(op="DROP")

View File

@@ -0,0 +1,7 @@
from sqlobject.dbconnection import registerConnection
def builder():
    # Lazy import: the PostgreSQL driver is only required once a
    # postgres:// / postgresql: // psycopg URI is actually opened.
    import pgconnection
    return pgconnection.PostgresConnection
registerConnection(['postgres', 'postgresql', 'psycopg'], builder)

View File

@@ -0,0 +1,359 @@
from sqlobject.dbconnection import DBAPI
import re
from sqlobject import col
from sqlobject import sqlbuilder
from sqlobject.converters import registerConverter
class PostgresConnection(DBAPI):
supportTransactions = True
dbName = 'postgres'
schemes = [dbName, 'postgresql']
def __init__(self, dsn=None, host=None, port=None, db=None,
user=None, password=None, **kw):
drivers = kw.pop('driver', None) or 'psycopg'
for driver in drivers.split(','):
driver = driver.strip()
if not driver:
continue
try:
if driver == 'psycopg2':
import psycopg2 as psycopg
elif driver == 'psycopg1':
import psycopg
elif driver == 'psycopg':
try:
import psycopg2 as psycopg
except ImportError:
import psycopg
elif driver == 'pygresql':
import pgdb
self.module = pgdb
else:
raise ValueError('Unknown PostgreSQL driver "%s", expected psycopg2, psycopg1 or pygresql' % driver)
except ImportError:
pass
else:
break
else:
raise ImportError('Cannot find a PostgreSQL driver, tried %s' % drivers)
if driver.startswith('psycopg'):
self.module = psycopg
# Register a converter for psycopg Binary type.
registerConverter(type(psycopg.Binary('')),
PsycoBinaryConverter)
self.user = user
self.host = host
self.port = port
self.db = db
self.password = password
self.dsn_dict = dsn_dict = {}
if host:
dsn_dict["host"] = host
if port:
if driver == 'pygresql':
dsn_dict["host"] = "%s:%d" % (host, port)
else:
if psycopg.__version__.split('.')[0] == '1':
dsn_dict["port"] = str(port)
else:
dsn_dict["port"] = port
if db:
dsn_dict["database"] = db
if user:
dsn_dict["user"] = user
if password:
dsn_dict["password"] = password
sslmode = kw.pop("sslmode", None)
if sslmode:
dsn_dict["sslmode"] = sslmode
self.use_dsn = dsn is not None
if dsn is None:
if driver == 'pygresql':
dsn = ''
if host:
dsn += host
dsn += ':'
if db:
dsn += db
dsn += ':'
if user:
dsn += user
dsn += ':'
if password:
dsn += password
else:
dsn = []
if db:
dsn.append('dbname=%s' % db)
if user:
dsn.append('user=%s' % user)
if password:
dsn.append('password=%s' % password)
if host:
dsn.append('host=%s' % host)
if port:
dsn.append('port=%d' % port)
if sslmode:
dsn.append('sslmode=%s' % sslmode)
dsn = ' '.join(dsn)
self.driver = driver
self.dsn = dsn
self.unicodeCols = kw.pop('unicodeCols', False)
self.schema = kw.pop('schema', None)
self.dbEncoding = kw.pop("charset", None)
DBAPI.__init__(self, **kw)
@classmethod
def _connectionFromParams(cls, user, password, host, port, path, args):
path = path.strip('/')
if (host is None) and path.count('/'): # Non-default unix socket
path_parts = path.split('/')
host = '/' + '/'.join(path_parts[:-1])
path = path_parts[-1]
return cls(host=host, port=port, db=path, user=user, password=password, **args)
def _setAutoCommit(self, conn, auto):
# psycopg2 does not have an autocommit method.
if hasattr(conn, 'autocommit'):
conn.autocommit(auto)
def makeConnection(self):
try:
if self.use_dsn:
conn = self.module.connect(self.dsn)
else:
conn = self.module.connect(**self.dsn_dict)
except self.module.OperationalError, e:
raise self.module.OperationalError("%s; used connection string %r" % (e, self.dsn))
if self.autoCommit:
# psycopg2 does not have an autocommit method.
if hasattr(conn, 'autocommit'):
conn.autocommit(1)
c = conn.cursor()
if self.schema:
c.execute("SET search_path TO " + self.schema)
dbEncoding = self.dbEncoding
if dbEncoding:
c.execute("SET client_encoding TO '%s'" % dbEncoding)
return conn
def _queryInsertID(self, conn, soInstance, id, names, values):
table = soInstance.sqlmeta.table
idName = soInstance.sqlmeta.idName
sequenceName = soInstance.sqlmeta.idSequence or \
'%s_%s_seq' % (table, idName)
c = conn.cursor()
if id is None:
c.execute("SELECT NEXTVAL('%s')" % sequenceName)
id = c.fetchone()[0]
names = [idName] + names
values = [id] + values
q = self._insertSQL(table, names, values)
if self.debug:
self.printDebug(conn, q, 'QueryIns')
c.execute(q)
if self.debugOutput:
self.printDebug(conn, id, 'QueryIns', 'result')
return id
@classmethod
def _queryAddLimitOffset(cls, query, start, end):
if not start:
return "%s LIMIT %i" % (query, end)
if not end:
return "%s OFFSET %i" % (query, start)
return "%s LIMIT %i OFFSET %i" % (query, end-start, start)
def createColumn(self, soClass, col):
return col.postgresCreateSQL()
def createReferenceConstraint(self, soClass, col):
return col.postgresCreateReferenceConstraint()
def createIndexSQL(self, soClass, index):
return index.postgresCreateIndexSQL(soClass)
def createIDColumn(self, soClass):
key_type = {int: "SERIAL", str: "TEXT"}[soClass.sqlmeta.idType]
return '%s %s PRIMARY KEY' % (soClass.sqlmeta.idName, key_type)
def dropTable(self, tableName, cascade=False):
self.query("DROP TABLE %s %s" % (tableName,
cascade and 'CASCADE' or ''))
def joinSQLType(self, join):
return 'INT NOT NULL'
def tableExists(self, tableName):
result = self.queryOne("SELECT COUNT(relname) FROM pg_class WHERE relname = %s"
% self.sqlrepr(tableName))
return result[0]
def addColumn(self, tableName, column):
self.query('ALTER TABLE %s ADD COLUMN %s' %
(tableName,
column.postgresCreateSQL()))
def delColumn(self, sqlmeta, column):
self.query('ALTER TABLE %s DROP COLUMN %s' % (sqlmeta.table, column.dbName))
    def columnsFromSchema(self, tableName, soClass):
        """Introspect ``tableName`` from the PostgreSQL system catalogs and
        return a list of SQLObject column objects (excluding the primary
        key column, which SQLObject manages itself).

        Three catalog queries are used: foreign-key constraint definitions
        (pg_constraint), column name/type/notnull/default (pg_attribute +
        pg_attrdef), and the primary-key index definition (pg_index).
        """
        keyQuery = """
        SELECT pg_catalog.pg_get_constraintdef(oid) as condef
        FROM pg_catalog.pg_constraint r
        WHERE r.conrelid = %s::regclass AND r.contype = 'f'"""
        colQuery = """
        SELECT a.attname,
        pg_catalog.format_type(a.atttypid, a.atttypmod), a.attnotnull,
        (SELECT substring(d.adsrc for 128) FROM pg_catalog.pg_attrdef d
        WHERE d.adrelid=a.attrelid AND d.adnum = a.attnum)
        FROM pg_catalog.pg_attribute a
        WHERE a.attrelid =%s::regclass
        AND a.attnum > 0 AND NOT a.attisdropped
        ORDER BY a.attnum"""
        primaryKeyQuery = """
        SELECT pg_index.indisprimary,
            pg_catalog.pg_get_indexdef(pg_index.indexrelid)
        FROM pg_catalog.pg_class c, pg_catalog.pg_class c2,
            pg_catalog.pg_index AS pg_index
        WHERE c.relname = %s
            AND c.oid = pg_index.indrelid
            AND pg_index.indexrelid = c2.oid
            AND pg_index.indisprimary
        """
        keyData = self.queryAll(keyQuery % self.sqlrepr(tableName))
        # Parse "FOREIGN KEY (col) REFERENCES other_table(...)" definitions
        # into a column-name -> referenced-table map.
        keyRE = re.compile(r"\((.+)\) REFERENCES (.+)\(")
        keymap = {}
        for (condef,) in keyData:
            match = keyRE.search(condef)
            if match:
                field, reftable = match.groups()
                keymap[field] = reftable.capitalize()
        primaryData = self.queryAll(primaryKeyQuery % self.sqlrepr(tableName))
        # Extract the key column name from the CREATE INDEX definition.
        primaryRE = re.compile(r'CREATE .*? USING .* \((.+?)\)')
        primaryKey = None
        for isPrimary, indexDef in primaryData:
            match = primaryRE.search(indexDef)
            assert match, "Unparseable contraint definition: %r" % indexDef
            # Exactly one primary key per table is expected.
            assert primaryKey is None, "Already found primary key (%r), then found: %r" % (primaryKey, indexDef)
            primaryKey = match.group(1)
        assert primaryKey, "No primary key found in table %r" % tableName
        if primaryKey.startswith('"'):
            # Strip quoting from a quoted identifier.
            assert primaryKey.endswith('"')
            primaryKey = primaryKey[1:-1]
        colData = self.queryAll(colQuery % self.sqlrepr(tableName))
        results = []
        if self.unicodeCols:
            client_encoding = self.queryOne("SHOW client_encoding")[0]
        for field, t, notnull, defaultstr in colData:
            if field == primaryKey:
                # The id column is handled by SQLObject, not listed here.
                continue
            if field in keymap:
                # Foreign-key columns become ForeignKey, named without the
                # trailing "ID" per SQLObject convention.
                colClass = col.ForeignKey
                kw = {'foreignKey': soClass.sqlmeta.style.dbTableToPythonClass(keymap[field])}
                name = soClass.sqlmeta.style.dbColumnToPythonAttr(field)
                if name.endswith('ID'):
                    name = name[:-2]
                kw['name'] = name
            else:
                colClass, kw = self.guessClass(t)
                if self.unicodeCols and colClass is col.StringCol:
                    # Promote to UnicodeCol using the session encoding.
                    colClass = col.UnicodeCol
                    kw['dbEncoding'] = client_encoding
                kw['name'] = soClass.sqlmeta.style.dbColumnToPythonAttr(field)
            kw['dbName'] = field
            kw['notNone'] = notnull
            if defaultstr is not None:
                kw['default'] = self.defaultFromSchema(colClass, defaultstr)
            elif not notnull:
                # Nullable column without an explicit default: default None.
                kw['default'] = None
            results.append(colClass(**kw))
        return results
def guessClass(self, t):
if t.count('point'): # poINT before INT
return col.StringCol, {}
elif t.count('int'):
return col.IntCol, {}
elif t.count('varying') or t.count('varchar'):
if '(' in t:
return col.StringCol, {'length': int(t[t.index('(')+1:-1])}
else: # varchar without length in Postgres means any length
return col.StringCol, {}
elif t.startswith('character('):
return col.StringCol, {'length': int(t[t.index('(')+1:-1]),
'varchar': False}
elif t.count('float') or t.count('real') or t.count('double'):
return col.FloatCol, {}
elif t == 'text':
return col.StringCol, {}
elif t.startswith('timestamp'):
return col.DateTimeCol, {}
elif t.startswith('datetime'):
return col.DateTimeCol, {}
elif t.startswith('date'):
return col.DateCol, {}
elif t.startswith('bool'):
return col.BoolCol, {}
elif t.startswith('bytea'):
return col.BLOBCol, {}
else:
return col.Col, {}
def defaultFromSchema(self, colClass, defaultstr):
"""
If the default can be converted to a python constant, convert it.
Otherwise return is as a sqlbuilder constant.
"""
if colClass == col.BoolCol:
if defaultstr == 'false':
return False
elif defaultstr == 'true':
return True
return getattr(sqlbuilder.const, defaultstr)
    def _createOrDropDatabase(self, op="CREATE"):
        """Run ``<op> DATABASE <self.db>`` via a temporary connection.

        We have to connect to *some* database, so we'll connect to
        template1, which is a common open database.
        """
        # @@: This doesn't use self.use_dsn or self.dsn_dict
        if self.driver == 'pygresql':
            # pygresql uses a colon-separated DSN format.
            dsn = '%s:template1:%s:%s' % (
                self.host or '', self.user or '', self.password or '')
        else:
            # Other drivers take a libpq-style key=value DSN.
            dsn = 'dbname=template1'
            if self.user:
                dsn += ' user=%s' % self.user
            if self.password:
                dsn += ' password=%s' % self.password
            if self.host:
                dsn += ' host=%s' % self.host
        conn = self.module.connect(dsn)
        cur = conn.cursor()
        # We must close the transaction with a commit so that
        # the CREATE DATABASE can work (which can't be in a transaction):
        cur.execute('COMMIT')
        cur.execute('%s DATABASE %s' % (op, self.db))
        cur.close()
        conn.close()
    def createEmptyDatabase(self):
        # CREATE DATABASE through a temporary template1 connection.
        self._createOrDropDatabase()
    def dropDatabase(self):
        # DROP DATABASE through a temporary template1 connection.
        self._createOrDropDatabase(op="DROP")
# Converter for psycopg Binary type.
def PsycoBinaryConverter(value, db):
    """Render a psycopg Binary value as a plain string (postgres only)."""
    assert db == 'postgres'
    text = str(value)
    return text

View File

@@ -0,0 +1,7 @@
from sqlobject.dbconnection import registerConnection

def builder():
    """Return the RdbhostConnection class, importing its module lazily."""
    # Deferred import: the driver module is only loaded on first use.
    from rdbhostconnection import RdbhostConnection
    return RdbhostConnection

registerConnection(['rdbhost'], builder)

View File

@@ -0,0 +1,78 @@
"""
This module written by David Keeney, 2009, 2010
Released under the LGPL for use with the SQLObject ORM library.
"""
import re
from sqlobject import col
from sqlobject import sqlbuilder
from sqlobject.converters import registerConverter
from sqlobject.dbconnection import DBAPI
from sqlobject.postgres.pgconnection import PostgresConnection
class RdbhostConnection(PostgresConnection):
    """PostgresConnection variant for the Rdbhost hosted-database service,
    backed by the third-party ``rdbhdb`` driver.
    """
    supportTransactions = False
    dbName = 'rdbhost'
    schemes = [dbName]
    def __init__(self, dsn=None, host=None, port=None, db=None,
                 user=None, password=None, unicodeCols=False, driver='rdbhost',
                 **kw):
        # ``driver`` may be a comma-separated list; only 'rdbhost' is known.
        drivers = driver
        for driver in drivers.split(','):
            driver = driver.strip()
            if not driver:
                continue
            try:
                if driver == 'rdbhost':
                    from rdbhdb import rdbhdb as rdb
                    # monkey patch % escaping into Cursor._execute
                    old_execute = getattr(rdb.Cursor, '_execute')
                    setattr(rdb.Cursor, '_old_execute', old_execute)
                    def _execute(self, query, *args):
                        # SQLObject interpolates values itself, so no
                        # query parameters are expected here.
                        assert not any([a for a in args])
                        query = query.replace('%', '%%')
                        # NOTE(review): the three empty tuples assume a
                        # specific rdbhdb-internal _execute signature --
                        # confirm against the installed rdbhdb version.
                        self._old_execute(query, (), (), ())
                    setattr(rdb.Cursor, '_execute', _execute)
                    self.module = rdb
                else:
                    raise ValueError('Unknown Rdbhost driver %s' % driver)
            except ImportError:
                pass
            else:
                break
        else:
            # for/else: no driver in the list could be imported.
            raise ImportError('Cannot find the Rdbhost driver')
        self.user = user
        self.host = host
        self.port = port
        self.db = db
        self.password = password
        self.dsn_dict = dsn_dict = {}
        self.use_dsn = dsn is not None
        # Rdbhost names credentials "role"/"authcode" in its DSN dict.
        if host:
            dsn_dict["host"] = host
        if user:
            dsn_dict["role"] = user
        if password:
            dsn_dict["authcode"] = password
        if dsn is None:
            # Build a libpq-style key=value DSN string from the parts.
            dsn = []
            if db:
                dsn.append('dbname=%s' % db)
            if user:
                dsn.append('user=%s' % user)
            if password:
                dsn.append('password=%s' % password)
            if host:
                dsn.append('host=%s' % host)
            if port:
                dsn.append('port=%d' % port)
            dsn = ' '.join(dsn)
        self.dsn = dsn
        self.unicodeCols = unicodeCols
        self.schema = kw.pop('schema', None)
        self.dbEncoding = 'utf-8'
        DBAPI.__init__(self, **kw)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,7 @@
from sqlobject.dbconnection import registerConnection

def builder():
    """Return the SQLiteConnection class, importing its module lazily."""
    # Deferred import: the driver module is only loaded on first use.
    from sqliteconnection import SQLiteConnection
    return SQLiteConnection

registerConnection(['sqlite'], builder)

View File

@@ -0,0 +1,397 @@
import base64
import os
import thread
import urllib
from sqlobject.dbconnection import DBAPI, Boolean
from sqlobject import col, sqlbuilder
from sqlobject.dberrors import *
# Cache for the original sqlite Binary factory before it is wrapped with
# base64 encoding in SQLiteConnection.__init__ (set once, module-wide).
sqlite2_Binary = None
class ErrorMessage(str):
    """A str subclass carrying metadata about the driver exception it
    was built from (module name and exception class name).
    """
    def __new__(cls, e):
        # e[0] is the message text (Python 2 exception indexing).
        obj = str.__new__(cls, e[0])
        obj.code = None
        obj.module = e.__module__
        obj.exception = e.__class__.__name__
        return obj
class SQLiteConnection(DBAPI):
    """SQLObject connection class for SQLite.

    Supports three drivers (tried in order): pysqlite2, the stdlib
    sqlite3, and the legacy sqlite module.  Because SQLite connections
    cannot be shared between threads, connections are kept in a
    per-thread pool; in-memory databases use one shared connection.
    """
    supportTransactions = True
    dbName = 'sqlite'
    schemes = [dbName]
    def __init__(self, filename, autoCommit=1, **kw):
        """Open/describe the database at ``filename`` (or ':memory:').

        Driver selection, BLOB base64-wrapping and driver-specific
        connect options are all set up here.
        """
        drivers = kw.pop('driver', None) or 'pysqlite2,sqlite3,sqlite'
        # Try each requested driver until one imports successfully.
        for driver in drivers.split(','):
            driver = driver.strip()
            if not driver:
                continue
            try:
                if driver in ('sqlite2', 'pysqlite2'):
                    from pysqlite2 import dbapi2 as sqlite
                    self.using_sqlite2 = True
                elif driver == 'sqlite3':
                    import sqlite3 as sqlite
                    self.using_sqlite2 = True
                elif driver in ('sqlite', 'sqlite1'):
                    import sqlite
                    self.using_sqlite2 = False
                else:
                    raise ValueError('Unknown SQLite driver "%s", expected pysqlite2, sqlite3 or sqlite' % driver)
            except ImportError:
                pass
            else:
                break
        else:
            # for/else: every candidate failed to import.
            raise ImportError('Cannot find an SQLite driver, tried %s' % drivers)
        if self.using_sqlite2:
            # Attach base64 helpers to the driver module for BLOB handling
            # (Python 2 base64 API).
            sqlite.encode = base64.encodestring
            sqlite.decode = base64.decodestring
        self.module = sqlite
        self.filename = filename  # full path to sqlite-db-file
        self._memory = filename == ':memory:'
        if self._memory and not self.using_sqlite2:
            raise ValueError("You must use sqlite2 to use in-memory databases")
        # connection options
        opts = {}
        if self.using_sqlite2:
            if autoCommit:
                # isolation_level None puts sqlite3/pysqlite2 in
                # autocommit mode.
                opts["isolation_level"] = None
            global sqlite2_Binary
            if sqlite2_Binary is None:
                # Wrap the driver's Binary so values are base64-encoded;
                # done once per process (module-level cache).
                sqlite2_Binary = sqlite.Binary
                sqlite.Binary = lambda s: sqlite2_Binary(sqlite.encode(s))
            if 'factory' in kw:
                # A connection factory may be given by name (looked up in
                # this module's globals) or as a callable.
                factory = kw.pop('factory')
                if isinstance(factory, str):
                    factory = globals()[factory]
                opts['factory'] = factory(sqlite)
        else:
            # Legacy sqlite driver options.
            opts['autocommit'] = Boolean(autoCommit)
            if 'encoding' in kw:
                opts['encoding'] = kw.pop('encoding')
            if 'mode' in kw:
                opts['mode'] = int(kw.pop('mode'), 0)
        if 'timeout' in kw:
            if self.using_sqlite2:
                # sqlite2/sqlite3 take seconds as a float.
                opts['timeout'] = float(kw.pop('timeout'))
            else:
                # legacy sqlite takes milliseconds.
                opts['timeout'] = int(float(kw.pop('timeout')) * 1000)
        if 'check_same_thread' in kw:
            opts["check_same_thread"] = Boolean(kw.pop('check_same_thread'))
        # use only one connection for sqlite - supports multiple)
        # cursors per connection
        self._connOptions = opts
        self.use_table_info = Boolean(kw.pop("use_table_info", True))
        DBAPI.__init__(self, **kw)
        # Per-thread connection bookkeeping (see getConnection).
        self._threadPool = {}
        self._threadOrigination = {}
        if self._memory:
            # A single shared connection for in-memory databases; a new
            # connection would see a different (empty) database.
            self._memoryConn = sqlite.connect(
                self.filename, **self._connOptions)
            # Convert text data from SQLite to str, not unicode -
            # SQLObject converts it to unicode itself.
            self._memoryConn.text_factory = str
    @classmethod
    def _connectionFromParams(cls, user, password, host, port, path, args):
        # SQLite URIs carry only a path; reject network/auth components.
        assert host is None and port is None, (
            "SQLite can only be used locally (with a URI like "
            "sqlite:/file or sqlite:///file, not sqlite://%s%s)" %
            (host, port and ':%r' % port or ''))
        assert user is None and password is None, (
            "You may not provide usernames or passwords for SQLite "
            "databases")
        if path == "/:memory:":
            path = ":memory:"
        return cls(filename=path, **args)
    def oldUri(self):
        # Legacy URI form: sqlite://<path> (double slash).
        path = self.filename
        if path == ":memory:":
            path = "/:memory:"
        else:
            path = "//" + path
        return 'sqlite:%s' % path
    def uri(self):
        # Canonical URI form, quoted; absolute paths get three slashes.
        path = self.filename
        if path == ":memory:":
            path = "/:memory:"
        else:
            if path.startswith('/'):
                path = "//" + path
            else:
                path = "///" + path
        path = urllib.quote(path)
        return 'sqlite:%s' % path
    def getConnection(self):
        # SQLite can't share connections between threads, and so can't
        # pool connections.  Since we are isolating threads here, we
        # don't have to worry about locking as much.
        if self._memory:
            conn = self.makeConnection()
            self._connectionNumbers[id(conn)] = self._connectionCount
            self._connectionCount += 1
            return conn
        threadid = thread.get_ident()
        if (self._pool is not None
            and threadid in self._threadPool):
            # Reuse the connection previously parked for this thread.
            conn = self._threadPool[threadid]
            del self._threadPool[threadid]
            if conn in self._pool:
                self._pool.remove(conn)
        else:
            conn = self.makeConnection()
            if self._pool is not None:
                # Remember which thread owns this connection so
                # releaseConnection can park it for the same thread.
                self._threadOrigination[id(conn)] = threadid
            self._connectionNumbers[id(conn)] = self._connectionCount
            self._connectionCount += 1
        if self.debug:
            s = 'ACQUIRE'
            if self._pool is not None:
                s += ' pool=[%s]' % ', '.join([str(self._connectionNumbers[id(v)]) for v in self._pool])
            self.printDebug(conn, s, 'Pool')
        return conn
    def releaseConnection(self, conn, explicit=False):
        # In-memory databases keep their single connection open forever.
        if self._memory:
            return
        threadid = self._threadOrigination.get(id(conn))
        DBAPI.releaseConnection(self, conn, explicit=explicit)
        if (self._pool is not None and threadid
            and threadid not in self._threadPool):
            # Park the connection for its originating thread.
            self._threadPool[threadid] = conn
        else:
            # No pooling (or the thread already has one parked): close it.
            if self._pool and conn in self._pool:
                self._pool.remove(conn)
            conn.close()
    def _setAutoCommit(self, conn, auto):
        if self.using_sqlite2:
            # isolation_level None == autocommit; "" == deferred
            # transactions.
            if auto:
                conn.isolation_level = None
            else:
                conn.isolation_level = ""
        else:
            conn.autocommit = auto
    def _setIsolationLevel(self, conn, level):
        # The legacy sqlite driver has no isolation_level attribute.
        if not self.using_sqlite2:
            return
        conn.isolation_level = level
    def makeConnection(self):
        if self._memory:
            return self._memoryConn
        conn = self.module.connect(self.filename, **self._connOptions)
        conn.text_factory = str  # Convert text data to str, not unicode
        return conn
    def _executeRetry(self, conn, cursor, query):
        """Execute ``query``, translating driver exceptions into
        SQLObject's dberrors hierarchy (messages wrapped in ErrorMessage).
        """
        if self.debug:
            self.printDebug(conn, query, 'QueryR')
        try:
            return cursor.execute(query)
        except self.module.OperationalError, e:
            raise OperationalError(ErrorMessage(e))
        except self.module.IntegrityError, e:
            msg = ErrorMessage(e)
            # SQLite reports unique-constraint violations as
            # "column ... not unique"; surface them distinctly.
            if msg.startswith('column') and msg.endswith('not unique'):
                raise DuplicateEntryError(msg)
            else:
                raise IntegrityError(msg)
        except self.module.InternalError, e:
            raise InternalError(ErrorMessage(e))
        except self.module.ProgrammingError, e:
            raise ProgrammingError(ErrorMessage(e))
        except self.module.DataError, e:
            raise DataError(ErrorMessage(e))
        except self.module.NotSupportedError, e:
            raise NotSupportedError(ErrorMessage(e))
        except self.module.DatabaseError, e:
            raise DatabaseError(ErrorMessage(e))
        except self.module.InterfaceError, e:
            raise InterfaceError(ErrorMessage(e))
        except self.module.Warning, e:
            raise Warning(ErrorMessage(e))
        except self.module.Error, e:
            raise Error(ErrorMessage(e))
    def _queryInsertID(self, conn, soInstance, id, names, values):
        """Insert a row; when ``id`` is None, let SQLite assign it and
        read it back from cursor.lastrowid.
        """
        table = soInstance.sqlmeta.table
        idName = soInstance.sqlmeta.idName
        c = conn.cursor()
        if id is not None:
            names = [idName] + names
            values = [id] + values
        q = self._insertSQL(table, names, values)
        if self.debug:
            self.printDebug(conn, q, 'QueryIns')
        self._executeRetry(conn, c, q)
        # lastrowid is a DB-API extension from "PEP 0249":
        if id is None:
            id = int(c.lastrowid)
        if self.debugOutput:
            self.printDebug(conn, id, 'QueryIns', 'result')
        return id
    def _insertSQL(self, table, names, values):
        if not names:
            assert not values
            # INSERT INTO table () VALUES () isn't allowed in
            # SQLite (though it is in other databases)
            return ("INSERT INTO %s VALUES (NULL)" % table)
        else:
            return DBAPI._insertSQL(self, table, names, values)
    @classmethod
    def _queryAddLimitOffset(cls, query, start, end):
        # SQLite requires a LIMIT (0 is acceptable) before OFFSET.
        if not start:
            return "%s LIMIT %i" % (query, end)
        if not end:
            return "%s LIMIT 0 OFFSET %i" % (query, start)
        return "%s LIMIT %i OFFSET %i" % (query, end-start, start)
    def createColumn(self, soClass, col):
        # Column DDL is generated by the column object itself.
        return col.sqliteCreateSQL()
    def createReferenceConstraint(self, soClass, col):
        # SQLite (as used here) gets no separate FK constraint DDL.
        return None
    def createIDColumn(self, soClass):
        return self._createIDColumn(soClass.sqlmeta)
    def _createIDColumn(self, sqlmeta):
        if sqlmeta.idType == str:
            return '%s TEXT PRIMARY KEY' % sqlmeta.idName
        return '%s INTEGER PRIMARY KEY AUTOINCREMENT' % sqlmeta.idName
    def joinSQLType(self, join):
        return 'INT NOT NULL'
    def tableExists(self, tableName):
        result = self.queryOne("SELECT tbl_name FROM sqlite_master WHERE type='table' AND tbl_name = '%s'" % tableName)
        # turn it into a boolean:
        return not not result
    def createIndexSQL(self, soClass, index):
        return index.sqliteCreateIndexSQL(soClass)
    def addColumn(self, tableName, column):
        self.query('ALTER TABLE %s ADD COLUMN %s' %
                   (tableName,
                    column.sqliteCreateSQL()))
        self.query('VACUUM %s' % tableName)
    def delColumn(self, sqlmeta, column):
        # SQLite has no DROP COLUMN; rebuild the table without it.
        self.recreateTableWithoutColumn(sqlmeta, column)
    def recreateTableWithoutColumn(self, sqlmeta, column):
        """Rename the table aside, recreate it without ``column``, copy
        the remaining data across, then drop the renamed original.
        """
        new_name = sqlmeta.table + '_ORIGINAL'
        self.query('ALTER TABLE %s RENAME TO %s' % (sqlmeta.table, new_name))
        cols = [self._createIDColumn(sqlmeta)] \
            + [self.createColumn(None, col)
               for col in sqlmeta.columnList if col.name != column.name]
        cols = ",\n".join([" %s" % c for c in cols])
        self.query('CREATE TABLE %s (\n%s\n)' % (sqlmeta.table, cols))
        # NOTE(review): assumes sqlmeta.columnList no longer contains the
        # dropped column at this point -- confirm against the caller.
        all_columns = ', '.join([sqlmeta.idName] + [col.dbName for col in sqlmeta.columnList])
        self.query('INSERT INTO %s (%s) SELECT %s FROM %s' % (
            sqlmeta.table, all_columns, all_columns, new_name))
        self.query('DROP TABLE %s' % new_name)
    def columnsFromSchema(self, tableName, soClass):
        # PRAGMA table_info is preferred; fall back to parsing the stored
        # CREATE TABLE statement when disabled.
        if self.use_table_info:
            return self._columnsFromSchemaTableInfo(tableName, soClass)
        else:
            return self._columnsFromSchemaParse(tableName, soClass)
    def _columnsFromSchemaTableInfo(self, tableName, soClass):
        """Build column objects from ``PRAGMA table_info`` rows."""
        colData = self.queryAll("PRAGMA table_info(%s)" % tableName)
        results = []
        for index, field, t, nullAllowed, default, key in colData:
            if field == soClass.sqlmeta.idName:
                # The id column is managed by SQLObject itself.
                continue
            colClass, kw = self.guessClass(t)
            if default == 'NULL':
                nullAllowed = True
                default = None
            kw['name'] = soClass.sqlmeta.style.dbColumnToPythonAttr(field)
            kw['dbName'] = field
            kw['notNone'] = not nullAllowed
            kw['default'] = default
            # @@ skip key...
            # @@ skip extra...
            results.append(colClass(**kw))
        return results
    def _columnsFromSchemaParse(self, tableName, soClass):
        """Build column objects by parsing the table's stored CREATE
        TABLE statement from sqlite_master.
        """
        colData = self.queryOne("SELECT sql FROM sqlite_master WHERE type='table' AND name='%s'"
                                % tableName)
        if not colData:
            raise ValueError('The table %s was not found in the database. Load failed.' % tableName)
        # Keep only the column list between the outer parentheses.
        colData = colData[0].split('(', 1)[1].strip()[:-2]
        # Strip any nested parenthesised fragments (e.g. type lengths) so
        # the subsequent comma-split sees one column per piece.
        while True:
            start = colData.find('(')
            if start == -1: break
            end = colData.find(')', start)
            if end == -1: break
            colData = colData[:start] + colData[end+1:]
        results = []
        for colDesc in colData.split(','):
            parts = colDesc.strip().split(' ', 2)
            field = parts[0].strip()
            # skip comments
            if field.startswith('--'):
                continue
            # get rid of enclosing quotes
            if field[0] == field[-1] == '"':
                field = field[1:-1]
            if field == getattr(soClass.sqlmeta, 'idName', 'id'):
                continue
            colClass, kw = self.guessClass(parts[1].strip())
            if len(parts) == 2:
                index_info = ''
            else:
                index_info = parts[2].strip().upper()
            kw['name'] = soClass.sqlmeta.style.dbColumnToPythonAttr(field)
            kw['dbName'] = field
            import re
            # Detect "NOT NULL" and a literal DEFAULT from the trailing
            # column constraints.
            nullble = re.search(r'(\b\S*)\sNULL', index_info)
            default = re.search(r"DEFAULT\s((?:\d[\dA-FX.]*)|(?:'[^']*')|(?:#[^#]*#))", index_info)
            kw['notNone'] = nullble and nullble.group(1) == 'NOT'
            kw['default'] = default and default.group(1)
            # @@ skip key...
            # @@ skip extra...
            results.append(colClass(**kw))
        return results
    def guessClass(self, t):
        """Map a SQLite type string (case-insensitive, by substring) to a
        (column class, keyword dict) pair; unknown types become col.Col.
        """
        t = t.upper()
        if t.find('INT') >= 0:
            return col.IntCol, {}
        elif t.find('TEXT') >= 0 or t.find('CHAR') >= 0 or t.find('CLOB') >= 0:
            return col.StringCol, {'length': 2**32-1}
        elif t.find('BLOB') >= 0:
            return col.BLOBCol, {"length": 2**32-1}
        elif t.find('REAL') >= 0 or t.find('FLOAT') >= 0:
            return col.FloatCol, {}
        elif t.find('DECIMAL') >= 0:
            return col.DecimalCol, {'size': None, 'precision': None}
        elif t.find('BOOL') >= 0:
            return col.BoolCol, {}
        else:
            return col.Col, {}
    def createEmptyDatabase(self):
        if self._memory:
            return
        # Touch/truncate the file; SQLite treats an empty file as an
        # empty database.
        open(self.filename, 'w').close()
    def dropDatabase(self):
        if self._memory:
            return
        os.unlink(self.filename)

View File

@@ -0,0 +1,346 @@
import dbconnection
import joins
import main
import sqlbuilder
__all__ = ['SelectResults']
class SelectResults(object):
    """Lazy, immutable representation of a SELECT over one SQLObject class.

    Refinements (orderBy, slicing, filter, distinct, ...) return *new*
    SelectResults via clone(); the query only hits the database when the
    results are iterated or aggregated.  Query options live in the
    ``self.ops`` dict.
    """
    IterationClass = dbconnection.Iteration
    def __init__(self, sourceClass, clause, clauseTables=None,
                 **ops):
        self.sourceClass = sourceClass
        # None or the string 'all' means "no WHERE clause".
        if clause is None or isinstance(clause, str) and clause == 'all':
            clause = sqlbuilder.SQLTrueClause
        if not isinstance(clause, sqlbuilder.SQLExpression):
            # Raw strings are wrapped so they can be embedded in SQL.
            clause = sqlbuilder.SQLConstant(clause)
        self.clause = clause
        self.ops = ops
        if ops.get('orderBy', sqlbuilder.NoDefault) is sqlbuilder.NoDefault:
            ops['orderBy'] = sourceClass.sqlmeta.defaultOrder
        orderBy = ops['orderBy']
        # Normalize orderBy entries (strings, '-'-prefixed strings) into
        # sqlbuilder expressions.
        if isinstance(orderBy, (tuple, list)):
            orderBy = map(self._mungeOrderBy, orderBy)
        else:
            orderBy = self._mungeOrderBy(orderBy)
        ops['dbOrderBy'] = orderBy
        if 'connection' in ops and ops['connection'] is None:
            del ops['connection']
        # 'limit' is sugar for start=0/end=limit and can't be combined
        # with explicit start/end.
        if ops.get('limit', None):
            assert not ops.get('start', None) and not ops.get('end', None), \
                "'limit' cannot be used with 'start' or 'end'"
            ops["start"] = 0
            ops["end"] = ops.pop("limit")
        tablesSet = sqlbuilder.tablesUsedSet(self.clause, self._getConnection().dbName)
        if clauseTables:
            for table in clauseTables:
                tablesSet.add(table)
        self.clauseTables = clauseTables
        # Explicitly post-adding-in sqlmeta.table, sqlbuilder.Select will handle sqlrepr'ing and dupes
        self.tables = list(tablesSet) + [sourceClass.sqlmeta.table]
    def queryForSelect(self):
        """Build the sqlbuilder.Select for this result set (id column
        first, then every column of the source class).
        """
        columns = [self.sourceClass.q.id] + [getattr(self.sourceClass.q, x.name) for x in self.sourceClass.sqlmeta.columnList]
        query = sqlbuilder.Select(columns,
                                  where=self.clause,
                                  join=self.ops.get('join', sqlbuilder.NoDefault),
                                  distinct=self.ops.get('distinct',False),
                                  lazyColumns=self.ops.get('lazyColumns', False),
                                  start=self.ops.get('start', 0),
                                  end=self.ops.get('end', None),
                                  orderBy=self.ops.get('dbOrderBy',sqlbuilder.NoDefault),
                                  reversed=self.ops.get('reversed', False),
                                  staticTables=self.tables,
                                  forUpdate=self.ops.get('forUpdate', False))
        return query
    def __repr__(self):
        return "<%s at %x>" % (self.__class__.__name__, id(self))
    def _getConnection(self):
        # Per-select connection override wins over the class connection.
        return self.ops.get('connection') or self.sourceClass._connection
    def __str__(self):
        # str() renders the SQL the connection would run.
        conn = self._getConnection()
        return conn.queryForSelect(self)
    def _mungeOrderBy(self, orderBy):
        """Convert an orderBy spec (possibly a '-'-prefixed column name)
        into a sqlbuilder expression, wrapping in DESC when prefixed.
        """
        if isinstance(orderBy, str) and orderBy.startswith('-'):
            orderBy = orderBy[1:]
            desc = True
        else:
            desc = False
        if isinstance(orderBy, basestring):
            if orderBy in self.sourceClass.sqlmeta.columns:
                # Known column: order by its q-attribute.
                val = getattr(self.sourceClass.q, self.sourceClass.sqlmeta.columns[orderBy].name)
                if desc:
                    return sqlbuilder.DESC(val)
                else:
                    return val
            else:
                # Unknown name: pass it through as raw SQL.
                orderBy = sqlbuilder.SQLConstant(orderBy)
                if desc:
                    return sqlbuilder.DESC(orderBy)
                else:
                    return orderBy
        else:
            # Already an expression (or None): use as-is.
            return orderBy
    def clone(self, **newOps):
        # Immutable-style update: copy ops, overlay newOps, rebuild.
        ops = self.ops.copy()
        ops.update(newOps)
        return self.__class__(self.sourceClass, self.clause,
                              self.clauseTables, **ops)
    def orderBy(self, orderBy):
        return self.clone(orderBy=orderBy)
    def connection(self, conn):
        return self.clone(connection=conn)
    def limit(self, limit):
        return self[:limit]
    def lazyColumns(self, value):
        return self.clone(lazyColumns=value)
    def reversed(self):
        return self.clone(reversed=not self.ops.get('reversed', False))
    def distinct(self):
        return self.clone(distinct=True)
    def newClause(self, new_clause):
        # Like clone(), but replacing the WHERE clause instead of ops.
        return self.__class__(self.sourceClass, new_clause,
                              self.clauseTables, **self.ops)
    def filter(self, filter_clause):
        """Return a new select with ``filter_clause`` ANDed onto the
        current clause; ``None`` is a no-op.
        """
        if filter_clause is None:
            # None doesn't filter anything, it's just a no-op:
            return self
        clause = self.clause
        if isinstance(clause, basestring):
            clause = sqlbuilder.SQLConstant('(%s)' % self.clause)
        return self.newClause(sqlbuilder.AND(clause, filter_clause))
    def __getitem__(self, value):
        """Slicing composes with any existing start/end ops; negative
        indexes/bounds fall back to materializing the full list.
        """
        if isinstance(value, slice):
            assert not value.step, "Slices do not support steps"
            if not value.start and not value.stop:
                # No need to copy, I'm immutable
                return self
            # Negative indexes aren't handled (and everything we
            # don't handle ourselves we just create a list to
            # handle)
            if (value.start and value.start < 0) \
               or (value.stop and value.stop < 0):
                if value.start:
                    if value.stop:
                        return list(self)[value.start:value.stop]
                    return list(self)[value.start:]
                return list(self)[:value.stop]
            if value.start:
                assert value.start >= 0
                # Offsets accumulate across repeated slicing.
                start = self.ops.get('start', 0) + value.start
                if value.stop is not None:
                    assert value.stop >= 0
                    if value.stop < value.start:
                        # an empty result:
                        end = start
                    else:
                        end = value.stop + self.ops.get('start', 0)
                        if self.ops.get('end', None) is not None and \
                           self.ops['end'] < end:
                            # truncated by previous slice:
                            end = self.ops['end']
                else:
                    end = self.ops.get('end', None)
            else:
                start = self.ops.get('start', 0)
                end = value.stop + start
                if self.ops.get('end', None) is not None \
                   and self.ops['end'] < end:
                    end = self.ops['end']
            return self.clone(start=start, end=end)
        else:
            if value < 0:
                # Negative index: must fetch everything.
                return list(iter(self))[value]
            else:
                # Single item: fetch a one-row window.
                start = self.ops.get('start', 0) + value
                return list(self.clone(start=start, end=start+1))[0]
    def __iter__(self):
        # @@: This could be optimized, using a simpler algorithm
        # since we don't have to worry about garbage collection,
        # etc., like we do with .lazyIter()
        return iter(list(self.lazyIter()))
    def lazyIter(self):
        """
        Returns an iterator that will lazily pull rows out of the
        database and return SQLObject instances
        """
        conn = self._getConnection()
        return conn.iterSelect(self)
    def accumulate(self, *expressions):
        """ Use accumulate expression(s) to select result
            using another SQL select through current
            connection.

            Return the accumulate result
        """
        conn = self._getConnection()
        exprs = []
        for expr in expressions:
            if not isinstance(expr, sqlbuilder.SQLExpression):
                expr = sqlbuilder.SQLConstant(expr)
            exprs.append(expr)
        return conn.accumulateSelect(self, *exprs)
    def count(self):
        """ Counting elements of current select results """
        assert not self.ops.get('start') and not self.ops.get('end'), \
            "start/end/limit have no meaning with 'count'"
        assert not (self.ops.get('distinct') and (self.ops.get('start')
                                                  or self.ops.get('end'))), \
            "distinct-counting of sliced objects is not supported"
        if self.ops.get('distinct'):
            # Column must be specified, so we are using unique ID column.
            # COUNT(DISTINCT column) is supported by MySQL and PostgreSQL,
            # but not by SQLite. Perhaps more portable would be subquery:
            # SELECT COUNT(*) FROM (SELECT DISTINCT id FROM table)
            count = self.accumulate('COUNT(DISTINCT %s)' % self._getConnection().sqlrepr(self.sourceClass.q.id))
        else:
            count = self.accumulate('COUNT(*)')
        # NOTE(review): unreachable given the asserts above, which forbid
        # start/end here; kept as-is.
        if self.ops.get('start'):
            count -= self.ops['start']
        if self.ops.get('end'):
            count = min(self.ops['end'] - self.ops.get('start', 0), count)
        return count
    def accumulateMany(self, *attributes):
        """ Making the expressions for count/sum/min/max/avg
            of a given select result attributes.

            `attributes` must be a list/tuple of pairs (func_name, attribute);
            `attribute` can be a column name (like 'a_column')
            or a dot-q attribute (like Table.q.aColumn)
        """
        expressions = []
        conn = self._getConnection()
        if self.ops.get('distinct'):
            distinct = 'DISTINCT '
        else:
            distinct = ''
        for func_name, attribute in attributes:
            if not isinstance(attribute, str):
                # Render dot-q expressions to SQL text.
                attribute = conn.sqlrepr(attribute)
            expression = '%s(%s%s)' % (func_name, distinct, attribute)
            expressions.append(expression)
        return self.accumulate(*expressions)
    def accumulateOne(self, func_name, attribute):
        """ Making the sum/min/max/avg of a given select result attribute.

            `attribute` can be a column name (like 'a_column')
            or a dot-q attribute (like Table.q.aColumn)
        """
        return self.accumulateMany((func_name, attribute))
    def sum(self, attribute):
        return self.accumulateOne("SUM", attribute)
    def min(self, attribute):
        return self.accumulateOne("MIN", attribute)
    def avg(self, attribute):
        return self.accumulateOne("AVG", attribute)
    def max(self, attribute):
        return self.accumulateOne("MAX", attribute)
    def getOne(self, default=sqlbuilder.NoDefault):
        """
        If a query is expected to only return a single value,
        using ``.getOne()`` will return just that value.

        If no results are found, ``SQLObjectNotFound`` will be
        raised, unless you pass in a default value (like
        ``.getOne(None)``).

        If more than one result is returned,
        ``SQLObjectIntegrityError`` will be raised.
        """
        results = list(self)
        if not results:
            if default is sqlbuilder.NoDefault:
                raise main.SQLObjectNotFound(
                    "No results matched the query for %s"
                    % self.sourceClass.__name__)
            return default
        if len(results) > 1:
            raise main.SQLObjectIntegrityError(
                "More than one result returned from query: %s"
                % results)
        return results[0]
    def throughTo(self):
        # Property returning a proxy whose attribute access dispatches
        # to _throughTo (select.throughTo.someJoin).
        class _throughTo_getter(object):
            def __init__(self, inst):
                self.sresult = inst
            def __getattr__(self, attr):
                return self.sresult._throughTo(attr)
        return _throughTo_getter(self)
    throughTo = property(throughTo)
    def _throughTo(self, attr):
        """Resolve ``attr`` as either a foreign-key column (with or
        without the 'ID' suffix) or a join name, and return a select of
        the related class restricted to rows reachable from this select.
        """
        otherClass = None
        orderBy = sqlbuilder.NoDefault
        ref = self.sourceClass.sqlmeta.columns.get(attr.endswith('ID') and attr or attr+'ID', None)
        if ref and ref.foreignKey:
            otherClass, clause = self._throughToFK(ref)
        else:
            join = [x for x in self.sourceClass.sqlmeta.joins if x.joinMethodName==attr]
            if join:
                join = join[0]
                orderBy = join.orderBy
                # Related joins carry an otherColumn; multiple joins don't.
                if hasattr(join, 'otherColumn'):
                    otherClass, clause = self._throughToRelatedJoin(join)
                else:
                    otherClass, clause = self._throughToMultipleJoin(join)
        if not otherClass:
            raise AttributeError("throughTo argument (got %s) should be name of foreignKey or SQL*Join in %s" % (attr, self.sourceClass))
        return otherClass.select(clause,
                                 orderBy=orderBy,
                                 connection=self._getConnection())
    def _throughToFK(self, col):
        # Subselect of this result's FK values, aliased, joined on the
        # other class's id.
        otherClass = getattr(self.sourceClass, "_SO_class_"+col.foreignKey)
        colName = col.name
        query = self.queryForSelect().newItems([sqlbuilder.ColumnAS(getattr(self.sourceClass.q, colName), colName)]).orderBy(None).distinct()
        query = sqlbuilder.Alias(query, "%s_%s" % (self.sourceClass.__name__, col.name))
        return otherClass, otherClass.q.id==getattr(query.q, colName)
    def _throughToMultipleJoin(self, join):
        # Subselect of this result's ids, matched against the other
        # class's back-reference column.
        otherClass = join.otherClass
        colName = join.soClass.sqlmeta.style.dbColumnToPythonAttr(join.joinColumn)
        query = self.queryForSelect().newItems([sqlbuilder.ColumnAS(self.sourceClass.q.id, 'id')]).orderBy(None).distinct()
        query = sqlbuilder.Alias(query, "%s_%s" % (self.sourceClass.__name__, join.joinMethodName))
        joinColumn = getattr(otherClass.q, colName)
        return otherClass, joinColumn==query.q.id
    def _throughToRelatedJoin(self, join):
        # Many-to-many: go through the intermediate table, matching this
        # result's ids on one side and the other class's ids on the other.
        otherClass = join.otherClass
        intTable = sqlbuilder.Table(join.intermediateTable)
        colName = join.joinColumn
        query = self.queryForSelect().newItems([sqlbuilder.ColumnAS(self.sourceClass.q.id, 'id')]).orderBy(None).distinct()
        query = sqlbuilder.Alias(query, "%s_%s" % (self.sourceClass.__name__, join.joinMethodName))
        clause = sqlbuilder.AND(otherClass.q.id == getattr(intTable, join.otherColumn),
                                getattr(intTable, colName) == query.q.id)
        return otherClass, clause

View File

@@ -0,0 +1,154 @@
import re
__all__ = ["Style", "MixedCaseUnderscoreStyle", "DefaultStyle",
"MixedCaseStyle"]
class Style(object):
    """
    The base Style class, and also the simplest implementation. No
    translation occurs -- column names and attribute names match,
    as do class names and table names (when using auto class or
    schema generation).
    """
    def __init__(self, pythonAttrToDBColumn=None,
                 dbColumnToPythonAttr=None,
                 pythonClassToDBTable=None,
                 dbTableToPythonClass=None,
                 idForTable=None,
                 longID=False):
        # Each constructor argument, when supplied, overrides the
        # corresponding method on this instance.  The lambda re-binds the
        # callable so it is invoked as callback(style, value), i.e. with
        # the style instance first (``s=self`` freezes the binding).
        if pythonAttrToDBColumn:
            self.pythonAttrToDBColumn = lambda a, s=self: pythonAttrToDBColumn(s, a)
        if dbColumnToPythonAttr:
            self.dbColumnToPythonAttr = lambda a, s=self: dbColumnToPythonAttr(s, a)
        if pythonClassToDBTable:
            self.pythonClassToDBTable = lambda a, s=self: pythonClassToDBTable(s, a)
        if dbTableToPythonClass:
            self.dbTableToPythonClass = lambda a, s=self: dbTableToPythonClass(s, a)
        if idForTable:
            self.idForTable = lambda a, s=self: idForTable(s, a)
        self.longID = longID
    def pythonAttrToDBColumn(self, attr):
        # Identity mapping: attribute names are used as column names.
        return attr
    def dbColumnToPythonAttr(self, col):
        return col
    def pythonClassToDBTable(self, className):
        return className
    def dbTableToPythonClass(self, table):
        return table
    def idForTable(self, table):
        # longID styles name the id column after the table
        # (e.g. "product_id"); otherwise it is simply "id".
        if self.longID:
            return self.tableReference(table)
        else:
            return 'id'
    def pythonClassToAttr(self, className):
        return lowerword(className)
    def instanceAttrToIDAttr(self, attr):
        # "product" -> "productID"
        return attr + "ID"
    def instanceIDAttrToAttr(self, attr):
        # "productID" -> "product"
        return attr[:-2]
    def tableReference(self, table):
        return table + "_id"
class MixedCaseUnderscoreStyle(Style):
    """
    This is the default style. Python attributes use mixedCase,
    while database columns use underscore_separated.
    """
    def pythonAttrToDBColumn(self, attr):
        return mixedToUnder(attr)
    def dbColumnToPythonAttr(self, col):
        return underToMixed(col)
    def pythonClassToDBTable(self, className):
        # "MyClass" -> "my_class" (first letter lowered, rest converted).
        return className[0].lower() \
               + mixedToUnder(className[1:])
    def dbTableToPythonClass(self, table):
        # "my_class" -> "MyClass".
        return table[0].upper() \
               + underToMixed(table[1:])
    def pythonClassToDBTableReference(self, className):
        return self.tableReference(self.pythonClassToDBTable(className))
    def tableReference(self, table):
        return table + "_id"
# Alias: the underscore style is SQLObject's default.
DefaultStyle = MixedCaseUnderscoreStyle
class MixedCaseStyle(Style):
    """
    This style leaves columns as mixed-case, and uses long
    ID names (like ProductID instead of simply id).
    """
    def pythonAttrToDBColumn(self, attr):
        # "myAttr" -> "MyAttr"
        return capword(attr)
    def dbColumnToPythonAttr(self, col):
        # "MyAttr" -> "myAttr"
        return lowerword(col)
    def dbTableToPythonClass(self, table):
        return capword(table)
    def tableReference(self, table):
        # "Product" -> "ProductID"
        return table + "ID"
defaultStyle = DefaultStyle()
def getStyle(soClass, dbConnection=None):
    """Return the naming Style to use for ``soClass``.

    Preference order: the class's own ``sqlmeta.style``, then the style
    of ``dbConnection`` (falling back to the class's ``_connection``),
    and finally the module-wide ``defaultStyle``.
    """
    if dbConnection is None:
        dbConnection = getattr(soClass, '_connection', None)
    style = getattr(soClass.sqlmeta, 'style', None)
    if style:
        return style
    if dbConnection and dbConnection.style:
        return dbConnection.style
    return defaultStyle
############################################################
## Text utilities
############################################################
# Matches runs of consecutive uppercase letters (word boundaries in mixedCase).
_mixedToUnderRE = re.compile(r'[A-Z]+')
def mixedToUnder(s):
    """Convert mixedCase ``s`` to underscore_separated form.

    A trailing ``ID`` is treated as the word "id" (``personID`` ->
    ``person_id``), and the leading underscore produced by an initial
    capital is stripped (``MixedCase`` -> ``mixed_case``).
    """
    if s.endswith('ID'):
        return mixedToUnder(s[:-2] + "_id")
    def _lower_run(match):
        # An uppercase run 'XYZ' becomes '_xy_z': the final letter of the
        # run starts the next word.
        run = match.group(0).lower()
        if len(run) > 1:
            return '_%s_%s' % (run[:-1], run[-1])
        return '_%s' % run
    converted = re.sub(r'[A-Z]+', _lower_run, s)
    return converted[1:] if converted.startswith('_') else converted
def mixedToUnderSub(match):
    """re.sub callback: lower-case an uppercase run with underscores.

    A run ``XYZ`` becomes ``_xy_z`` so that the last letter of the run
    starts the following word; a single letter becomes ``_x``.
    """
    run = match.group(0).lower()
    if len(run) == 1:
        return '_' + run
    return '_%s_%s' % (run[:-1], run[-1])
def capword(s):
    """Return ``s`` with its first character upper-cased, rest untouched."""
    first, rest = s[0], s[1:]
    return first.upper() + rest
def lowerword(s):
    """Return ``s`` with its first character lower-cased, rest untouched."""
    first, rest = s[0], s[1:]
    return first.lower() + rest
_underToMixedRE = re.compile('_.')
def underToMixed(name):
    """Convert underscore_separated ``name`` to mixedCase.

    A trailing ``_id`` becomes ``ID`` (``person_id`` -> ``personID``).
    """
    if name.endswith('_id'):
        return underToMixed(name[:-3] + "ID")
    return re.sub('_.', lambda m: m.group(0)[1].upper(), name)

View File

@@ -0,0 +1,7 @@
from sqlobject.dbconnection import registerConnection
def builder():
    # Imported lazily so the Sybase driver is only required when a
    # sybase:// URI is actually used.
    import sybaseconnection
    return sybaseconnection.SybaseConnection
registerConnection(['sybase'], builder)

View File

@@ -0,0 +1,168 @@
from sqlobject.dbconnection import DBAPI
from sqlobject import col
class SybaseConnection(DBAPI):
supportTransactions = False
dbName = 'sybase'
schemes = [dbName]
NumericType = None
def __init__(self, db, user, password='', host='localhost', port=None,
locking=1, **kw):
db = db.strip('/')
import Sybase
Sybase._ctx.debug = 0
if SybaseConnection.NumericType is None:
from Sybase import NumericType
SybaseConnection.NumericType = NumericType
from sqlobject.converters import registerConverter, IntConverter
registerConverter(NumericType, IntConverter)
self.module = Sybase
self.locking = int(locking)
self.host = host
self.port = port
self.db = db
self.user = user
self.password = password
autoCommit = kw.get('autoCommit')
if autoCommit:
autoCommmit = int(autoCommit)
else:
autoCommit = None
kw['autoCommit'] = autoCommit
DBAPI.__init__(self, **kw)
@classmethod
def _connectionFromParams(cls, user, password, host, port, path, args):
return cls(user=user, password=password,
host=host or 'localhost', port=port, db=path, **args)
def insert_id(self, conn):
"""
Sybase adapter/cursor does not support the
insert_id method.
"""
c = conn.cursor()
c.execute('SELECT @@IDENTITY')
return c.fetchone()[0]
def makeConnection(self):
return self.module.connect(self.host, self.user, self.password,
database=self.db, auto_commit=self.autoCommit,
locking=self.locking)
HAS_IDENTITY = """
SELECT col.name, col.status, obj.name
FROM syscolumns col
JOIN sysobjects obj
ON obj.id = col.id
WHERE obj.name = '%s'
AND (col.status & 0x80) = 0x80
"""
def _hasIdentity(self, conn, table):
query = self.HAS_IDENTITY % table
c = conn.cursor()
c.execute(query)
r = c.fetchone()
return r is not None
def _queryInsertID(self, conn, soInstance, id, names, values):
table = soInstance.sqlmeta.table
idName = soInstance.sqlmeta.idName
c = conn.cursor()
if id is not None:
names = [idName] + names
values = [id] + values
has_identity = self._hasIdentity(conn, table)
identity_insert_on = False
if has_identity and (id is not None):
identity_insert_on = True
c.execute('SET IDENTITY_INSERT %s ON' % table)
q = self._insertSQL(table, names, values)
if self.debug:
print 'QueryIns: %s' % q
c.execute(q)
if has_identity and identity_insert_on:
c.execute('SET IDENTITY_INSERT %s OFF' % table)
if id is None:
id = self.insert_id(conn)
if self.debugOutput:
self.printDebug(conn, id, 'QueryIns', 'result')
return id
@classmethod
def _queryAddLimitOffset(cls, query, start, end):
# XXX Sybase doesn't support OFFSET
if end:
return "SET ROWCOUNT %i %s SET ROWCOUNT 0" % (end, query)
return query
def createReferenceConstraint(self, soClass, col):
return None
def createColumn(self, soClass, col):
return col.sybaseCreateSQL()
def createIDColumn(self, soClass):
key_type = {int: "NUMERIC(18,0)", str: "TEXT"}[soClass.sqlmeta.idType]
return '%s %s IDENTITY UNIQUE' % (soClass.sqlmeta.idName, key_type)
def createIndexSQL(self, soClass, index):
return index.sybaseCreateIndexSQL(soClass)
def joinSQLType(self, join):
return 'NUMERIC(18,0) NOT NULL'
SHOW_TABLES="SELECT name FROM sysobjects WHERE type='U'"
def tableExists(self, tableName):
for (table,) in self.queryAll(self.SHOW_TABLES):
if table.lower() == tableName.lower():
return True
return False
def addColumn(self, tableName, column):
self.query('ALTER TABLE %s ADD COLUMN %s' %
(tableName,
column.sybaseCreateSQL()))
def delColumn(self, sqlmeta, column):
self.query('ALTER TABLE %s DROP COLUMN %s' % (sqlmeta.table, column.dbName))
SHOW_COLUMNS=('SELECT COLUMN_NAME, DATA_TYPE, IS_NULLABLE, COLUMN_DEFAULT FROM INFORMATION_SCHEMA.COLUMNS '
'WHERE TABLE_NAME = \'%s\'')
def columnsFromSchema(self, tableName, soClass):
colData = self.queryAll(self.SHOW_COLUMNS
% tableName)
results = []
for field, t, nullAllowed, default in colData:
if field == soClass.sqlmeta.idName:
continue
colClass, kw = self.guessClass(t)
kw['name'] = soClass.sqlmeta.style.dbColumnToPythonAttr(field)
kw['dbName'] = field
kw['notNone'] = not nullAllowed
kw['default'] = default
# @@ skip key...
# @@ skip extra...
kw['forceDBName'] = True
results.append(colClass(**kw))
return results
def _setAutoCommit(self, conn, auto):
conn.auto_commit = auto
def guessClass(self, t):
if t.startswith('int'):
return col.IntCol, {}
elif t.startswith('varchar'):
return col.StringCol, {'length': int(t[8:-1])}
elif t.startswith('char'):
return col.StringCol, {'length': int(t[5:-1]),
'varchar': False}
elif t.startswith('datetime'):
return col.DateTimeCol, {}
else:
return col.Col, {}

View File

@@ -0,0 +1 @@
#

View File

@@ -0,0 +1,196 @@
"""
Exports a SQLObject class (possibly annotated) to a CSV file.
"""
import os
import csv
try:
from cStringIO import StringIO
except ImportError:
from StringIO import StringIO
import sqlobject
__all__ = ['export_csv', 'export_csv_zip']
def export_csv(soClass, select=None, writer=None, connection=None,
               orderBy=None):
    """
    Export the SQLObject class ``soClass`` to a CSV file.

    ``soClass`` can also be a SelectResult object, as returned by
    ``.select()``.  If it is a class, all objects will be retrieved,
    ordered by ``orderBy`` if given, or the ``.csvOrderBy`` attribute
    if present (but csvOrderBy will only be applied when no select
    result is given).

    You can also pass in select results (or simply a list of
    instances) in ``select`` -- if you have a list of objects (not a
    SelectResult instance, as produced by ``.select()``) then you must
    pass it in with ``select`` and pass the class in as the first
    argument.

    ``writer`` is a ``csv.writer()`` object, or a file-like object.
    If not given, the string of the file will be returned.

    Uses ``connection`` as the data source, if given, otherwise the
    default connection.

    Columns can be annotated with ``.csvTitle`` attributes, which will
    form the attributes of the columns, or 'title' (secondarily), or
    if nothing then the column attribute name.

    If a column has a ``.noCSV`` attribute which is true, then the
    column will be suppressed.

    Additionally a class can have an ``.extraCSVColumns`` attribute,
    which should be a list of strings/tuples.  If a tuple, it should
    be like ``(attribute, title)``, otherwise it is the attribute,
    which will also be the title.  These will be appended to the end
    of the CSV file; the attribute will be retrieved from instances.

    Also a ``.csvColumnOrder`` attribute can be on the class, which is
    the string names of attributes in the order they should be
    presented.
    """
    return_fileobj = None
    if not writer:
        # No destination at all: accumulate into a string and return it.
        return_fileobj = StringIO()
        writer = csv.writer(return_fileobj)
    elif not hasattr(writer, 'writerow'):
        # A plain file-like object was passed; wrap it in a csv writer.
        writer = csv.writer(writer)
    if isinstance(soClass, sqlobject.SQLObject.SelectResultsClass):
        assert select is None, (
            "You cannot pass in a select argument (%r) and a SelectResult argument (%r) for soClass"
            % (select, soClass))
        select = soClass
        soClass = select.sourceClass
    elif select is None:
        select = soClass.select()
        # csvOrderBy only applies when we created the select ourselves.
        if getattr(soClass, 'csvOrderBy', None):
            select = select.orderBy(soClass.csvOrderBy)
    if orderBy:
        select = select.orderBy(orderBy)
    if connection:
        select = select.connection(connection)
    _actually_export_csv(soClass, select, writer)
    if return_fileobj:
        # They didn't pass any writer or file object in, so we return
        # the string result:
        return return_fileobj.getvalue()
def _actually_export_csv(soClass, select, writer):
    """Emit the title row, then one CSV row per instance in ``select``."""
    attributes, titles = _find_columns(soClass)
    writer.writerow(titles)
    for instance in select:
        writer.writerow([getattr(instance, attribute)
                         for attribute in attributes])
def _find_columns(soClass):
order = []
attrs = {}
for col in soClass.sqlmeta.columnList:
if getattr(col, 'noCSV', False):
continue
order.append(col.name)
title = col.name
if hasattr(col, 'csvTitle'):
title = col.csvTitle
elif getattr(col, 'title', None) is not None:
title = col.title
attrs[col.name] = title
for attrDesc in getattr(soClass, 'extraCSVColumns', []):
if isinstance(attrDesc, (list, tuple)):
attr, title = attrDesc
else:
attr = title = attrDesc
order.append(attr)
attrs[attr] = title
if hasattr(soClass, 'csvColumnOrder'):
oldOrder = order
order = soClass.csvColumnOrder
for attr in order:
if attr not in oldOrder:
raise KeyError(
"Attribute %r in csvColumnOrder (on class %r) does not exist as a column or in .extraCSVColumns (I have: %r)"
% (attr, soClass, oldOrder))
oldOrder.remove(attr)
order.extend(oldOrder)
titles = [attrs[attr] for attr in order]
return order, titles
def export_csv_zip(soClasses, file=None, zip=None, filename_prefix='',
                   connection=None):
    """
    Export several SQLObject classes into a .zip file.  Each
    item in the ``soClasses`` list may be a SQLObject class,
    select result, or ``(soClass, select)`` tuple.

    Each file in the zip will be named after the class name (with
    ``.csv`` appended), or using the filename in the ``.csvFilename``
    attribute.

    If ``file`` is given, the zip will be written to that.  ``file``
    may be a string (a filename) or a file-like object.  If not given,
    a string will be returned.

    If ``zip`` is given, then the files will be written to that zip
    file.

    All filenames will be prefixed with ``filename_prefix`` (which may
    be a directory name, for instance).
    """
    import zipfile
    close_file_when_finished = False
    close_zip_when_finished = True
    return_when_finished = False
    if file:
        if isinstance(file, basestring):
            # Bug fix: this flag used to be assigned to a misspelled
            # name ('close_when_finished'), so files opened here were
            # never closed.
            close_file_when_finished = True
            file = open(file, 'wb')
    elif zip:
        close_zip_when_finished = False
    else:
        return_when_finished = True
        file = StringIO()
    if not zip:
        zip = zipfile.ZipFile(file, mode='w')
    try:
        _actually_export_classes(soClasses, zip, filename_prefix,
                                 connection)
    finally:
        if close_zip_when_finished:
            zip.close()
        if close_file_when_finished:
            file.close()
    if return_when_finished:
        return file.getvalue()
def _actually_export_classes(soClasses, zip, filename_prefix,
                             connection):
    """Write one CSV member into ``zip`` for every entry of ``soClasses``.

    Entries may be a class, a select result, or a ``(class, select)`` pair.
    """
    for entry in soClasses:
        select = None
        if isinstance(entry, (tuple, list)):
            soClass, select = entry
        elif isinstance(entry, sqlobject.SQLObject.SelectResultsClass):
            select = entry
            soClass = select.sourceClass
        else:
            soClass = entry
        filename = getattr(soClass, 'csvFilename', soClass.__name__)
        if not os.path.splitext(filename)[1]:
            filename += '.csv'
        zip.writestr(filename_prefix + filename,
                     export_csv(soClass, select, connection=connection))

View File

@@ -0,0 +1,349 @@
"""
Import from a CSV file or directory of files.
CSV files should have a header line that lists columns. Headers can
also be appended with ``:type`` to indicate the type of the field.
``escaped`` is the default, though it can be overridden by the importer.
Supported types:
``:python``:
A python expression, run through ``eval()``. This can be a
security risk, pass in ``allow_python=False`` if you don't want to
allow it.
``:int``:
Integer
``:float``:
Float
``:str``:
String
``:escaped``:
A string with backslash escapes (note that you don't put quotation
marks around the value)
``:base64``:
A base64-encoded string
``:date``:
ISO date, like YYYY-MM-DD; this can also be ``NOW+days`` or
``NOW-days``
``:datetime``:
ISO date/time like YYYY-MM-DDTHH:MM:SS (either T or a space can be
used to separate the time, and seconds are optional). This can
also be ``NOW+seconds`` or ``NOW-seconds``
``:bool``:
Converts true/false/yes/no/on/off/1/0 to boolean value
``:ref``:
This will be resolved to the ID of the object named in this column
(None if the column is empty). @@: Since there's no ordering,
there's no way to promise the object already exists.
You can also get back references to the objects if you have a special
``[name]`` column.
Any column named ``[comment]`` or with no name will be ignored.
In any column you can put ``[default]`` to exclude the value and use
whatever default the class wants. ``[null]`` will use NULL.
Lines that begin with ``[comment]`` are ignored.
"""
from datetime import datetime, date, timedelta
import os
import csv
import types
__all__ = ['load_csv_from_directory',
'load_csv',
'create_data']
DEFAULT_TYPE = 'escaped'
def create_data(data, class_getter, keyorder=None):
    """
    Create the ``data``, which is the return value from
    ``load_csv()``.  Classes will be resolved with the callable
    ``class_getter``; or if ``class_getter`` is a module then the
    class names will be attributes of that.

    Returns a dictionary of ``{object_name: object(s)}``, using the
    names from the ``[name]`` columns (if there are any).  If a name
    is used multiple times, you get a list of objects, not a single
    object.

    If ``keyorder`` is given, then the keys will be retrieved in that
    order.  It can be a list/tuple of names, or a sorting function.
    If not given and ``class_getter`` is a module and has a
    ``soClasses`` function, then that will be used for the order.
    """
    objects = {}
    classnames = data.keys()
    if (not keyorder and isinstance(class_getter, types.ModuleType)
        and hasattr(class_getter, 'soClasses')):
        keyorder = [c.__name__ for c in class_getter.soClasses]
    if not keyorder:
        classnames.sort()
    elif isinstance(keyorder, (list, tuple)):
        # Named classes first (in the given order), then any leftovers.
        all = classnames
        classnames = [name for name in keyorder if name in classnames]
        for name in all:
            if name not in classnames:
                classnames.append(name)
    else:
        # keyorder is a comparison function (Python 2 cmp-style sort).
        classnames.sort(keyorder)
    for classname in classnames:
        items = data[classname]
        if not items:
            continue
        if isinstance(class_getter, types.ModuleType):
            soClass = getattr(class_getter, classname)
        else:
            soClass = class_getter(classname)
        for item in items:
            # Resolve Reference placeholders (from :ref columns) to the
            # id of the already-created object with that [name].  Note
            # there is no dependency ordering, so the target must have
            # been created by an earlier class/row.
            for key, value in item.items():
                if isinstance(value, Reference):
                    resolved = objects.get(value.name)
                    if not resolved:
                        raise ValueError(
                            "Object reference to %r does not have target"
                            % value.name)
                    elif (isinstance(resolved, list)
                          and len(resolved) > 1):
                        raise ValueError(
                            "Object reference to %r is ambiguous (got %r)"
                            % (value.name, resolved))
                    item[key] = resolved.id
            if '[name]' in item:
                name = item.pop('[name]').strip()
            else:
                name = None
            inst = soClass(**item)
            if name:
                # A repeated [name] accumulates its objects into a list.
                if name in objects:
                    if isinstance(objects[name], list):
                        objects[name].append(inst)
                    else:
                        objects[name] = [objects[name], inst]
                else:
                    objects[name] = inst
    return objects
def load_csv_from_directory(directory,
                            allow_python=True, default_type=DEFAULT_TYPE,
                            allow_multiple_classes=True):
    """
    Load the data from every ``*.csv`` file in ``directory``.  The base
    filename names the class (``general.csv`` for data not associated
    with a class).  Returns data in the same shape as ``load_csv``.

    This might cause problems on case-insensitive filesystems.
    """
    results = {}
    for filename in os.listdir(directory):
        base, extension = os.path.splitext(filename)
        if extension.lower() != '.csv':
            continue
        source = open(os.path.join(directory, filename), 'rb')
        loaded = load_csv(csv.reader(source),
                          allow_python=allow_python,
                          default_type=default_type,
                          default_class=base,
                          allow_multiple_classes=allow_multiple_classes)
        source.close()
        for classname, items in loaded.items():
            results.setdefault(classname, []).extend(items)
    return results
def load_csv(csvreader, allow_python=True, default_type=DEFAULT_TYPE,
             default_class=None, allow_multiple_classes=True):
    """
    Loads the CSV file, returning a dictionary of
    ``{classname: [row_dict, ...]}`` with cell values coerced
    according to the header types.
    """
    current_class = default_class
    current_headers = None
    results = {}
    for row in csvreader:
        if not [cell for cell in row if cell.strip()]:
            # empty row
            continue
        if row and row[0].strip() == 'CLASS:':
            # Switch to a new class section; headers restart on the
            # next row.
            if not allow_multiple_classes:
                raise ValueError(
                    "CLASS: line in CSV file, but multiple classes are not allowed in this file (line: %r)"
                    % row)
            if not row[1:]:
                raise ValueError(
                    "CLASS: in line in CSV file, with no class name in next column (line: %r)"
                    % row)
            current_class = row[1]
            current_headers = None
            continue
        if not current_class:
            raise ValueError(
                "No CLASS: line given, and there is no default class for this file (line: %r"
                % row)
        if current_headers is None:
            # The first row of each class section is the header row.
            current_headers = _parse_headers(row, default_type)
            continue
        if row[0] == '[comment]':
            continue
        # Pad row with empty strings:
        row += ['']*(len(current_headers) - len(row))
        row_converted = {}
        for value, (name, coercer, args) in zip(row, current_headers):
            if name is None:
                # Comment
                continue
            if value == '[default]':
                # Omit the key so the class-level default applies.
                continue
            if value == '[null]':
                row_converted[name] = None
                continue
            args = (value,) + args
            row_converted[name] = coercer(*args)
        results.setdefault(current_class, []).append(row_converted)
    return results
def _parse_headers(header_row, default_type, allow_python=True):
    """Parse a CSV header row into ``(name, coercer, args)`` triples.

    Header cells look like ``name``, ``name:type`` or ``name:type(arg)``;
    ``default_type`` applies when no ``:type`` suffix is given.  Cells
    named ``[comment]`` (or empty) yield ``(None, None, None)`` and are
    skipped by the row loop.

    Bug fixes: ``allow_python`` used to be referenced without being in
    scope at all (NameError whenever a ``:python`` header appeared); it
    is now a parameter defaulting to True.  Also, an arg parsed from a
    ``type(arg)`` header was discarded in favor of the coercer's
    registered default args; header args now take precedence.
    """
    headers = []
    for name in header_row:
        original_name = name
        if ':' in name:
            name, type = name.split(':', 1)
        else:
            type = default_type
        if type == 'python' and not allow_python:
            raise ValueError(
                ":python header given when python headers are not allowed (with header %r"
                % original_name)
        name = name.strip()
        if name == '[comment]' or not name:
            headers.append((None, None, None))
            continue
        type = type.strip().lower()
        if '(' in type:
            type, arg = type.split('(', 1)
            if not arg.endswith(')'):
                raise ValueError(
                    "Arguments (in ()'s) do not end with ): %r"
                    % original_name)
            args = (arg[:-1],)
        else:
            args = ()
        if name == '[name]':
            type = 'str'
        coercer, default_args = get_coercer(type)
        # Header-supplied args (if any) override the registered defaults.
        headers.append((name, coercer, args or default_args))
    return headers
_coercers = {}
def get_coercer(type):
    """Return the registered ``(coercer, args)`` pair for type name ``type``.

    Raises ValueError (listing the known types) when unregistered.
    """
    try:
        return _coercers[type]
    except KeyError:
        raise ValueError(
            "Coercion type %r not known (I know: %s)"
            % (type, ', '.join(_coercers.keys())))
def register_coercer(type, coercer, *args):
    # Register ``coercer`` (with default extra args) under type name ``type``.
    _coercers[type] = (coercer, args)
def identity(v):
    # Coercer for ':str': return the cell text unchanged.
    return v
register_coercer('str', identity)
register_coercer('string', identity)
def decode_string(v, encoding):
    # Shared coercer for :escaped/:base64 cells; relies on Python 2
    # bytes-to-bytes codecs like 'string_escape' and 'base64'.
    return v.decode(encoding)
register_coercer('escaped', decode_string, 'string_escape')
register_coercer('strescaped', decode_string, 'string_escape')
register_coercer('base64', decode_string, 'base64')
register_coercer('int', int)
register_coercer('float', float)
def parse_python(v):
    # SECURITY: eval() of cell contents.  Callers can disable :python
    # headers via allow_python=False; never use on untrusted CSV files.
    return eval(v, {}, {})
register_coercer('python', parse_python)
def parse_date(v):
    """Coerce an ISO date (``YYYY-MM-DD``) or ``NOW+days``/``NOW-days``
    to a ``datetime.date``; empty/whitespace input yields None.

    Bug fix: the ``time`` module was used here without ever being
    imported, so any absolute date raised NameError.
    """
    import time
    v = v.strip()
    if not v:
        return None
    if v.startswith('NOW-') or v.startswith('NOW+'):
        # v[3:] keeps the sign, e.g. 'NOW-3' -> -3 days.
        days = int(v[3:])
        now = date.today()
        return now + timedelta(days)
    else:
        parsed = time.strptime(v, '%Y-%m-%d')
        return date.fromtimestamp(time.mktime(parsed))
register_coercer('date', parse_date)
def parse_datetime(v):
    """Coerce an ISO datetime (``YYYY-MM-DD[T ]HH:MM[:SS]``) or
    ``NOW+seconds``/``NOW-seconds`` to a ``datetime``; empty input
    yields None.

    Bug fix: the ``time`` module was used here without ever being
    imported, so any absolute datetime raised NameError.
    """
    import time
    v = v.strip()
    if not v:
        return None
    if v.startswith('NOW-') or v.startswith('NOW+'):
        # v[3:] keeps the sign, e.g. 'NOW-30' -> -30 seconds.
        seconds = int(v[3:])
        now = datetime.now()
        return now + timedelta(0, seconds)
    else:
        fmts = ['%Y-%m-%dT%H:%M:%S',
                '%Y-%m-%d %H:%M:%S',
                '%Y-%m-%dT%H:%M',
                '%Y-%m-%d %H:%M']
        # Try each accepted format; only the last is allowed to raise.
        for fmt in fmts[:-1]:
            try:
                parsed = time.strptime(v, fmt)
                break
            except ValueError:
                pass
        else:
            parsed = time.strptime(v, fmts[-1])
        return datetime.fromtimestamp(time.mktime(parsed))
register_coercer('datetime', parse_datetime)
class Reference(object):
    """Placeholder for a ``:ref`` cell; ``create_data`` later replaces it
    with the id of the object created under this ``[name]``."""
    def __init__(self, name):
        self.name = name
def parse_ref(v):
    """Coerce a ``:ref`` cell: empty -> None, otherwise a Reference
    marker resolved to an object id by ``create_data``."""
    return Reference(v) if v.strip() else None
register_coercer('ref', parse_ref)
def parse_bool(v):
    """Coerce common true/false spellings (yes/no, on/off, 1/0, ...) to bool.

    Raises ValueError for anything unrecognized.  Bug fix: the error
    message previously referenced an undefined name ``value`` and so
    raised NameError instead of the intended ValueError.
    """
    v = v.strip().lower()
    if v in ('y', 'yes', 't', 'true', 'on', '1'):
        return True
    elif v in ('n', 'no', 'f', 'false', 'off', '0'):
        return False
    raise ValueError(
        "Value is not boolean-like: %r" % v)
register_coercer('bool', parse_bool)
register_coercer('boolean', parse_bool)

View File

@@ -0,0 +1,42 @@
import sys
import imp
def load_module(module_name):
    """Import a (possibly dotted) module path and return the leaf module.

    ``__import__`` returns the top-level package, so walk down the
    remaining components with getattr.
    """
    module = __import__(module_name)
    for component in module_name.split('.')[1:]:
        module = getattr(module, component)
    return module
def load_module_from_name(filename, module_name):
if module_name in sys.modules:
return sys.modules[module_name]
init_filename = os.path.join(os.path.dirname(filename), '__init__.py')
if not os.path.exists(init_filename):
try:
f = open(init_filename, 'w')
except (OSError, IOError), e:
raise IOError(
'Cannot write __init__.py file into directory %s (%s)\n'
% (os.path.dirname(filename), e))
f.write('#\n')
f.close()
fp = None
if module_name in sys.modules:
return sys.modules[module_name]
if '.' in module_name:
parent_name = '.'.join(module_name.split('.')[:-1])
base_name = module_name.split('.')[-1]
parent = load_module_from_name(os.path.dirname(filename),
parent_name)
else:
base_name = module_name
fp = None
try:
fp, pathname, stuff = imp.find_module(
base_name, [os.path.dirname(filename)])
module = imp.load_module(module_name, fp, pathname, stuff)
finally:
if fp is not None:
fp.close()
return module

View File

@@ -0,0 +1,6 @@
# Compatibility shim: use real thread-local storage when the threading
# module is available; otherwise a plain object (process-global storage
# is equivalent in a single-threaded interpreter).
try:
    from threading import local
except ImportError:
    # No threads, so "thread local" means process-global
    class local(object):
        pass

View File

@@ -0,0 +1,115 @@
from sqlobject import *
from datetime import datetime
class Version(SQLObject):
    """Base class for the auto-generated ``<Class>Versions`` archive
    classes built by ``Versioning``.

    Each row is a snapshot of a master object taken just before an
    UPDATE; ``masterID``/``master`` point back at the live row.
    """
    def restore(self):
        """Copy this archived snapshot's values back onto the master row."""
        values = self.sqlmeta.asDict()
        # Strip bookkeeping fields and version-only extra columns so only
        # real master columns are written back.
        del values['id']
        del values['masterID']
        del values['dateArchived']
        for col in self.extraCols:
            del values[col]
        self.masterClass.get(self.masterID).set(**values)
    def nextVersion(self):
        """Return the next newer snapshot of the same master, or the live
        master object if this is the most recent snapshot."""
        version = self.select(AND(self.q.masterID == self.masterID, self.q.id > self.id), orderBy=self.q.id)
        if version.count():
            return version[0]
        else:
            return self.master
    def getChangedFields(self):
        """List (title-cased) column names whose values differ between this
        snapshot and the next version."""
        next = self.nextVersion()
        columns = self.masterClass.sqlmeta.columns
        fields = []
        for column in columns:
            if column not in ["dateArchived", "id", "masterID"]:
                if getattr(self, column) != getattr(next, column):
                    fields.append(column.title())
        return fields
    @classmethod
    def select(cls, clause=None, *args, **kw):
        # Versions classes may be created without their own connection;
        # lazily borrow the master class's connection.
        if not getattr(cls, '_connection', None):
            cls._connection = cls.masterClass._connection
        return super(Version, cls).select(clause, *args, **kw)
    def __getattr__(self, attr):
        # Fall through to the live master object for anything this
        # snapshot does not store itself.
        if attr in self.__dict__:
            return self.__dict__[attr]
        else:
            return getattr(self.master, attr)
def getColumns(columns, cls):
    """Collect copies of ``cls``'s column definitions into ``columns``.

    Foreign-key columns keep their ``ID``-less attribute name, and
    constraints that must not apply to archived copies (alternateID,
    unique) are stripped.  Walks up the inheritance chain via
    ``sqlmeta.parentClass``.
    """
    for column, defi in cls.sqlmeta.columnDefinitions.items():
        if column.endswith("ID") and isinstance(defi, ForeignKey):
            column = column[:-2]
        # remove constraints incompatible with duplicate archive rows
        kwds = dict(defi._kw)
        for kw in ["alternateID", "unique"]:
            if kw in kwds: del kwds[kw]
        columns[column] = defi.__class__(**kwds)
    # ascend the class hierarchy
    if cls.sqlmeta.parentClass:
        getColumns(columns, cls.sqlmeta.parentClass)
class Versioning(object):
    """Declarative attribute that archives a row's state on every update.

    Adding ``versions = Versioning()`` to a SQLObject class creates a
    companion ``<Class>Versions`` table; accessing the attribute on an
    instance returns a select of that instance's archived versions.
    """
    def __init__(self, extraCols = None):
        # extraCols: extra column definitions for the versions table only.
        if extraCols:
            self.extraCols = extraCols
        else:
            self.extraCols = {}
        pass
    def __addtoclass__(self, soClass, name):
        """Bound-attribute hook: build the Versions class and subscribe to
        the table-creation and row-update events of ``soClass``."""
        self.name = name
        self.soClass = soClass
        attrs = {'dateArchived': DateTimeCol(default=datetime.now),
                 'master': ForeignKey(self.soClass.__name__),
                 'masterClass' : self.soClass,
                 'extraCols' : self.extraCols
                 }
        getColumns (attrs, self.soClass)
        attrs.update(self.extraCols)
        self.versionClass = type(self.soClass.__name__+'Versions',
                                 (Version,),
                                 attrs)
        if '_connection' in self.soClass.__dict__:
            self.versionClass._connection = self.soClass.__dict__['_connection']
        events.listen(self.createTable,
                      soClass, events.CreateTableSignal)
        events.listen(self.rowUpdate, soClass,
                      events.RowUpdateSignal)
    def createVersionTable(self, cls, conn):
        self.versionClass.createTable(ifNotExists=True, connection=conn)
    def createTable(self, soClass, connection, extra_sql, post_funcs):
        # Schedule the versions table right after the master table.
        assert soClass is self.soClass
        post_funcs.append(self.createVersionTable)
    def rowUpdate(self, instance, kwargs):
        """Archive the pre-update state of ``instance`` as a new version row."""
        if instance.childName and instance.childName != self.soClass.__name__:
            return #if you want your child class versioned, version it.
        values = instance.sqlmeta.asDict()
        del values['id']
        values['masterID'] = instance.id
        self.versionClass(connection=instance._connection, **values)
    def __get__(self, obj, type=None):
        if obj is None:
            return self
        return self.versionClass.select(
            self.versionClass.q.masterID==obj.id, connection=obj._connection)

View File

@@ -0,0 +1,134 @@
from sqlbuilder import *
from main import SQLObject, sqlmeta
import types, threading
####
class ViewSQLObjectField(SQLObjectField):
    """SQLObjectField that renders as ``alias.field`` instead of
    ``table.field``, for columns of a view-backed class."""
    def __init__(self, alias, *arg):
        SQLObjectField.__init__(self, *arg)
        self.alias = alias
    def __sqlrepr__(self, db):
        return self.alias + "." + self.fieldName
    def tablesUsedImmediate(self):
        # The aliased select itself must appear in the FROM clause.
        return [self.tableName]
class ViewSQLObjectTable(SQLObjectTable):
    """``q``-namespace for ViewSQLObject; produces alias-qualified fields."""
    FieldClass = ViewSQLObjectField
    def __getattr__(self, attr):
        if attr == 'sqlmeta':
            # NOTE(review): explicitly refuse 'sqlmeta' here, presumably so
            # hasattr(q, 'sqlmeta') is False rather than proxying -- confirm.
            raise AttributeError
        return SQLObjectTable.__getattr__(self, attr)
    def _getattrFromID(self, attr):
        return self.FieldClass(self.soClass.sqlmeta.alias, self.tableName, 'id', attr, self.soClass, None)
    def _getattrFromColumn(self, column, attr):
        return self.FieldClass(self.soClass.sqlmeta.alias, self.tableName, column.name, attr, self.soClass, column)
class ViewSQLObject(SQLObject):
    """
    A SQLObject class that derives all its values from other SQLObject
    classes.

    Columns on subclasses should use SQLBuilder constructs for dbName,
    and sqlmeta should specify:

      * idName as a SQLBuilder construction
      * clause as SQLBuilder clause for specifying join conditions
        or other restrictions
      * table as an optional alternate name for the class alias

    See test_views.py for simple examples.
    """
    def __classinit__(cls, new_attrs):
        SQLObject.__classinit__(cls, new_attrs)
        # like is_base
        if cls.__name__ != 'ViewSQLObject':
            dbName = hasattr(cls,'_connection') and (cls._connection and cls._connection.dbName) or None
            if getattr(cls.sqlmeta, 'table', None):
                cls.sqlmeta.alias = cls.sqlmeta.table
            else:
                cls.sqlmeta.alias = cls.sqlmeta.style.pythonClassToDBTable(cls.__name__)
            alias = cls.sqlmeta.alias
            columns = [ColumnAS(cls.sqlmeta.idName, 'id')]
            # {sqlrepr-key: [restriction, *aggregate-column]}
            aggregates = {'':[None]}
            # Map column objects back to their attribute names.
            inverseColumns = dict([(y,x) for x,y in cls.sqlmeta.columns.iteritems()])
            for col in cls.sqlmeta.columnList:
                n = inverseColumns[col]
                ascol = ColumnAS(col.dbName, n)
                if isAggregate(col.dbName):
                    # Aggregate columns are bucketed by their restriction
                    # clause; each bucket becomes its own grouped
                    # subselect below.
                    restriction = getattr(col, 'aggregateClause',None)
                    if restriction:
                        restrictkey = sqlrepr(restriction, dbName)
                        aggregates[restrictkey] = aggregates.get(restrictkey, [restriction]) + [ascol]
                    else:
                        aggregates[''].append(ascol)
                else:
                    columns.append(ascol)
            metajoin = getattr(cls.sqlmeta, 'join', NoDefault)
            clause = getattr(cls.sqlmeta, 'clause', NoDefault)
            select = Select(columns,
                            distinct=True,
                            # @@ LDO check if this really mattered for performance
                            # @@ Postgres (and MySQL?) extension!
                            #distinctOn=cls.sqlmeta.idName,
                            join=metajoin,
                            clause=clause)
            aggregates = aggregates.values()
            #print cls.__name__, sqlrepr(aggregates, dbName)
            if aggregates != [[None]]:
                # There are aggregate columns: build a chain of LEFT
                # JOINs from the non-aggregate base select to one grouped
                # subselect per restriction bucket.
                join = []
                last_alias = "%s_base" % alias
                last_id = "id"
                last = Alias(select, last_alias)
                columns = [ColumnAS(SQLConstant("%s.%s"%(last_alias,x.expr2)), x.expr2) for x in columns]
                for i, agg in enumerate(aggregates):
                    restriction = agg[0]
                    if restriction is None:
                        restriction = clause
                    else:
                        restriction = AND(clause, restriction)
                    agg = agg[1:]
                    agg_alias = "%s_%s" % (alias, i)
                    agg_id = '%s_id'%agg_alias
                    if not last.q.alias.endswith('base'):
                        # Only the first join names the base alias
                        # explicitly; later ones chain by name.
                        last = None
                    new_alias = Alias(Select([ColumnAS(cls.sqlmeta.idName, agg_id)]+agg,
                                             groupBy=cls.sqlmeta.idName,
                                             join=metajoin,
                                             clause=restriction),
                                      agg_alias)
                    agg_join = LEFTJOINOn(last,
                                          new_alias,
                                          "%s.%s = %s.%s" % (last_alias, last_id, agg_alias, agg_id))
                    join.append(agg_join)
                    for col in agg:
                        columns.append(ColumnAS(SQLConstant("%s.%s"%(agg_alias, col.expr2)),col.expr2))
                    last = new_alias
                    last_alias = agg_alias
                    last_id = agg_id
                select = Select(columns,
                                join=join)
            cls.sqlmeta.table = Alias(select, alias)
            cls.q = ViewSQLObjectTable(cls)
            # Fields now select directly from the view alias by attribute name.
            for n,col in cls.sqlmeta.columns.iteritems():
                col.dbName = n
def isAggregate(expr):
    """True if ``expr`` contains a SQL function call anywhere in its tree."""
    if isinstance(expr, SQLCall):
        return True
    if not isinstance(expr, SQLOp):
        return False
    return isAggregate(expr.expr1) or isAggregate(expr.expr2)
######

View File

@@ -0,0 +1,97 @@
from paste.deploy.converters import asbool
from paste.wsgilib import catch_errors
from paste.util import import_string
import sqlobject
import threading
def make_middleware(app, global_conf, database=None, use_transaction=False,
                    hub=None):
    """
    WSGI middleware that sets the connection for the request (using
    the database URI or connection object) and the given hub (or
    ``sqlobject.sqlhub`` if not given).

    If ``use_transaction`` is true, then the request will be run in a
    transaction.

    Applications can use the keys (which are all no-argument functions):

    ``sqlobject.get_connection()``:
      Returns the connection object

    ``sqlobject.abort()``:
      Aborts the transaction.  Does not raise an error, but at the *end*
      of the request there will be a rollback.

    ``sqlobject.begin()``:
      Starts a transaction.  First commits (or rolls back if aborted) if
      this is run in a transaction.

    ``sqlobject.in_transaction()``:
      Returns true or false, depending if we are currently in a
      transaction.
    """
    use_transaction = asbool(use_transaction)
    if database is None:
        database = global_conf.get('database')
    if not database:
        raise ValueError(
            "You must provide a 'database' configuration value")
    if isinstance(hub, basestring):
        # A dotted-name string: resolve it to the actual hub object.
        hub = import_string.eval_import(hub)
    if not hub:
        hub = sqlobject.sqlhub
    if isinstance(database, basestring):
        database = sqlobject.connectionForURI(database)
    return SQLObjectMiddleware(app, database, use_transaction, hub)
class SQLObjectMiddleware(object):
    """WSGI middleware that installs a SQLObject connection (optionally a
    transaction) on ``hub.threadConnection`` for the duration of each
    request, committing or rolling back at the end."""
    def __init__(self, app, conn, use_transaction, hub):
        self.app = app
        self.conn = conn
        self.use_transaction = use_transaction
        self.hub = hub
    def __call__(self, environ, start_response):
        # One-element lists let the nested callbacks below rebind state.
        conn = [self.conn]
        if self.use_transaction:
            conn[0] = conn[0].transaction()
        any_errors = []
        use_transaction = [self.use_transaction]
        self.hub.threadConnection = conn[0]
        def abort():
            # Flag the transaction so the request ends with a rollback.
            assert use_transaction[0], (
                "You cannot abort, because a transaction is not being used")
            any_errors.append(None)
        def begin():
            # Finish the current transaction (rollback if aborted,
            # commit otherwise) and start a fresh one.
            if use_transaction[0]:
                if any_errors:
                    conn[0].rollback()
                else:
                    conn[0].commit()
            any_errors[:] = []
            use_transaction[0] = True
            conn[0] = self.conn.transaction()
            self.hub.threadConnection = conn[0]
        def error(exc_info=None):
            # The wrapped app raised: force a rollback via the ok() path.
            any_errors.append(None)
            ok()
        def ok():
            # Normal end of request: commit (or roll back) and detach the
            # connection from the hub.
            if use_transaction[0]:
                if any_errors:
                    conn[0].rollback()
                else:
                    conn[0].commit(close=True)
            self.hub.threadConnection = None
        def in_transaction():
            return use_transaction[0]
        def get_connection():
            return conn[0]
        environ['sqlobject.get_connection'] = get_connection
        environ['sqlobject.abort'] = abort
        environ['sqlobject.begin'] = begin
        environ['sqlobject.in_transaction'] = in_transaction
        return catch_errors(self.app, environ, start_response,
                            error_callback=error, ok_callback=ok)