mirror of
https://github.com/djohnlewis/stackdump
synced 2025-01-22 22:51:36 +00:00
536 lines
18 KiB
Python
536 lines
18 KiB
Python
|
from itertools import count
|
||
|
import classregistry
|
||
|
import events
|
||
|
import styles
|
||
|
import sqlbuilder
|
||
|
|
||
|
# Names exported by a star-import of this module.
__all__ = [
    'MultipleJoin',
    'SQLMultipleJoin',
    'RelatedJoin',
    'SQLRelatedJoin',
    'SingleJoin',
    'ManyToMany',
    'OneToMany',
]

# Module-wide counter: Join.__init__ takes the next value from it and stores
# it on the join as ``creationOrder``.
creationOrder = count()

# Sentinel borrowed from sqlbuilder; SOJoin uses it as the "orderBy was not
# given" default.
NoDefault = sqlbuilder.NoDefault
|
||
|
|
||
|
def getID(obj):
    """Return the ID of ``obj``.

    When ``obj`` has an ``id`` attribute that attribute is returned as-is;
    otherwise ``obj`` is assumed to already be an ID and is passed through
    ``int()``.
    """
    missing = object()
    ident = getattr(obj, 'id', missing)
    if ident is missing:
        # No usable ``id`` attribute: treat the argument itself as the ID.
        return int(obj)
    return ident
|
||
|
|
||
|
class Join(object):
    """Declaration of a join, not yet bound to the class that owns it.

    The constructor only records its keyword arguments. ``withClass`` later
    hands them to ``self.baseClass`` (provided by subclasses) together with
    the owning class to build the bound join object.
    """

    def __init__(self, otherClass=None, **kw):
        kw['otherClass'] = otherClass
        self.kw = kw
        # An explicit joinMethodName is held separately from the stored
        # keywords; withClass passes it on as its own argument.
        self._joinMethodName = kw.pop('joinMethodName', None)
        self.creationOrder = creationOrder.next()

    def _get_joinMethodName(self):
        return self._joinMethodName

    def _set_joinMethodName(self, value):
        # The name may be set once, or re-set to the very same value.
        current = self._joinMethodName
        assert current == value or current is None, "You have already given an explicit joinMethodName (%s), and you are now setting it to %s" % (current, value)
        self._joinMethodName = value

    joinMethodName = property(_get_joinMethodName, _set_joinMethodName)
    # ``name`` is the same property object under a second attribute name.
    name = joinMethodName

    def withClass(self, soClass):
        """Build and return ``self.baseClass(...)`` bound to ``soClass``."""
        if 'joinMethodName' in self.kw:
            # A name placed into self.kw after construction overrides the
            # stored one and is removed so it is not passed twice.
            self._joinMethodName = self.kw.pop('joinMethodName')
        return self.baseClass(
            creationOrder=self.creationOrder,
            soClass=soClass,
            joinDef=self,
            joinMethodName=self._joinMethodName,
            **self.kw)
|
||
|
|
||
|
# A join is separate from a foreign key, i.e., it is
|
||
|
# many-to-many, or one-to-many where the *other* class
|
||
|
# has the foreign key.
|
||
|
class SOJoin(object):
    """A join bound to the class that declares it.

    A join is distinct from a foreign key: it is many-to-many, or
    one-to-many where the *other* class holds the foreign key.
    """

    def __init__(self,
                 creationOrder,
                 soClass=None,
                 otherClass=None,
                 joinColumn=None,
                 joinMethodName=None,
                 orderBy=NoDefault,
                 joinDef=None):
        self.creationOrder = creationOrder
        self.soClass = soClass
        self.joinDef = joinDef
        # ``otherClass`` arrives as a name; the class object itself is
        # delivered later to _setOtherClass through the registry callback.
        self.otherClassName = otherClass
        registry = classregistry.registry(soClass.sqlmeta.registry)
        registry.addClassCallback(otherClass, self._setOtherClass)
        self.joinColumn = joinColumn
        self.joinMethodName = joinMethodName
        self._orderBy = orderBy
        if not self.joinColumn:
            # Basic one-to-many set-up, where the other class points to us:
            # derive the column name from our own table name.
            style = styles.getStyle(self.soClass)
            self.joinColumn = style.tableReference(self.soClass.sqlmeta.table)

    @property
    def orderBy(self):
        """Ordering for the join; falls back to the other class's
        ``sqlmeta.defaultOrder`` (and remembers it) when none was given."""
        if self._orderBy is NoDefault:
            self._orderBy = self.otherClass.sqlmeta.defaultOrder
        return self._orderBy

    def _setOtherClass(self, cls):
        self.otherClass = cls

    def hasIntermediateTable(self):
        return False

    def _applyOrderBy(self, results, defaultSortClass):
        """Sort the list ``results`` in place by ``self.orderBy`` (if it is
        not None) and return it. ``defaultSortClass`` is unused here."""
        ordering = self.orderBy
        if ordering is not None:
            results.sort(sorter(ordering))
        return results
|
||
|
|
||
|
def sorter(orderBy):
    """Build a cmp-style comparison function from an ``orderBy`` spec.

    ``orderBy`` may be:

    * an attribute name, optionally prefixed with ``'-'`` for descending;
    * a ``sqlbuilder.SQLObjectField`` (its ``original`` name is used), or a
      ``sqlbuilder.DESC`` wrapping one (descending);
    * a list/tuple of the above, compared left to right until one
      criterion distinguishes the two objects.

    The returned function takes two objects and returns a negative number,
    zero or a positive number.  ``None`` attribute values are ordered
    before every other value.
    """
    if isinstance(orderBy, (tuple, list)):
        if not orderBy:
            # No criteria at all: every pair compares equal.  Previously
            # this fell through to orderBy[0] and raised IndexError.
            return lambda a, b: 0
        if len(orderBy) == 1:
            orderBy = orderBy[0]
        else:
            fhead = sorter(orderBy[0])
            frest = sorter(orderBy[1:])
            # ``or`` moves on to the remaining criteria only when the
            # first one reports equality (0).
            return lambda a, b, fhead=fhead, frest=frest: fhead(a, b) or frest(a, b)
    if isinstance(orderBy, sqlbuilder.DESC) \
       and isinstance(orderBy.expr, sqlbuilder.SQLObjectField):
        orderBy = '-' + orderBy.expr.original
    elif isinstance(orderBy, sqlbuilder.SQLObjectField):
        orderBy = orderBy.original
    # @@: but we don't handle more complex expressions for orderings
    if orderBy.startswith('-'):
        orderBy = orderBy[1:]
        reverse = True
    else:
        reverse = False

    def cmper(a, b, attr=orderBy, rev=reverse):
        a = getattr(a, attr)
        b = getattr(b, attr)
        if rev:
            # Descending order is implemented by swapping the operands.
            a, b = b, a
        if a is None:
            if b is None:
                return 0
            return -1
        if b is None:
            return 1
        # Same result as cmp(a, b) without needing the ``cmp`` builtin,
        # which does not exist on Python 3.
        return (a > b) - (a < b)
    return cmper
|
||
|
|
||
|
# This is a one-to-many
|
||
|
class SOMultipleJoin(SOJoin):
    """One-to-many join; ``performJoin`` returns a list of related objects."""

    def __init__(self, addRemoveName=None, **kw):
        # addRemovePrefix is something like @@
        SOJoin.__init__(self, **kw)

        # Generate the method name from the other class's name when the
        # caller did not supply one: lower-case the first letter and
        # append a plural suffix.
        if not self.joinMethodName:
            other = self.otherClassName
            stem = other[0].lower() + other[1:]
            if stem.endswith('s'):
                suffix = "es"
            else:
                suffix = "s"
            self.joinMethodName = stem + suffix

        if addRemoveName:
            self.addRemoveName = addRemoveName
        else:
            self.addRemoveName = capitalize(self.otherClassName)

    def performJoin(self, inst):
        """Return the objects of the other class joined to ``inst``,
        sorted according to ``orderBy``."""
        rows = inst._connection._SO_selectJoin(
            self.otherClass, self.joinColumn, inst.id)
        conn = None
        if inst.sqlmeta._perConnection:
            conn = inst._connection
        related = []
        for (otherID,) in rows:
            # Rows whose ID is NULL are skipped.
            if otherID is not None:
                related.append(self.otherClass.get(otherID, conn))
        return self._applyOrderBy(related, self.otherClass)

    def _dbNameToPythonName(self):
        """Return the Python attribute name matching ``self.joinColumn``.

        The other class's columns are searched by ``dbName`` first; if none
        matches, the name is derived through this class's style.
        """
        target = self.joinColumn
        for column in self.otherClass.sqlmeta.columns.values():
            if column.dbName == target:
                return column.name
        return self.soClass.sqlmeta.style.dbColumnToPythonAttr(target)
|
||
|
|
||
|
class MultipleJoin(Join):
    """Join declaration that is bound through ``SOMultipleJoin``."""

    baseClass = SOMultipleJoin
|
||
|
|
||
|
class SOSQLMultipleJoin(SOMultipleJoin):
    """One-to-many join whose ``performJoin`` returns the ordered result of
    ``otherClass.select(...)`` instead of a list built from fetched IDs."""

    def performJoin(self, inst):
        conn = None
        if inst.sqlmeta._perConnection:
            conn = inst._connection
        pythonColumn = self._dbNameToPythonName()
        # Select rows of the other class whose join column equals our ID.
        matchesInst = getattr(self.otherClass.q, pythonColumn) == inst.id
        selection = self.otherClass.select(matchesInst, connection=conn)
        return selection.orderBy(self.orderBy)
|
||
|
|
||
|
class SQLMultipleJoin(Join):
    """Join declaration that is bound through ``SOSQLMultipleJoin``."""

    baseClass = SOSQLMultipleJoin
|
||
|
|
||
|
# This is a many-to-many join, with an intermediary table
|
||
|
class SORelatedJoin(SOMultipleJoin):
    """Many-to-many join that goes through an intermediate table."""

    def __init__(self,
                 otherColumn=None,
                 intermediateTable=None,
                 createRelatedTable=True,
                 **kw):
        self.intermediateTable = intermediateTable
        self.otherColumn = otherColumn
        self.createRelatedTable = createRelatedTable
        SOMultipleJoin.__init__(self, **kw)
        # Table and column defaults need the other class's table name, so
        # they are filled in once the registry delivers that class.
        registry = classregistry.registry(self.soClass.sqlmeta.registry)
        registry.addClassCallback(
            self.otherClassName, self._setOtherRelatedClass)

    def _setOtherRelatedClass(self, otherClass):
        """Fill in ``intermediateTable`` / ``otherColumn`` if still unset."""
        if not self.intermediateTable:
            # Both table names, sorted and joined with an underscore.
            tableNames = sorted([self.soClass.sqlmeta.table,
                                 otherClass.sqlmeta.table])
            self.intermediateTable = '%s_%s' % (tableNames[0], tableNames[1])
        if not self.otherColumn:
            style = self.soClass.sqlmeta.style
            self.otherColumn = style.tableReference(otherClass.sqlmeta.table)

    def hasIntermediateTable(self):
        return True

    def performJoin(self, inst):
        """Return the other-class objects linked to ``inst`` through the
        intermediate table, sorted according to ``orderBy``."""
        rows = inst._connection._SO_intermediateJoin(
            self.intermediateTable,
            self.otherColumn,
            self.joinColumn,
            inst.id)
        conn = None
        if inst.sqlmeta._perConnection:
            conn = inst._connection
        related = []
        for (otherID,) in rows:
            # Rows whose ID is NULL are skipped.
            if otherID is not None:
                related.append(self.otherClass.get(otherID, conn))
        return self._applyOrderBy(related, self.otherClass)

    def remove(self, inst, other):
        """Delete the intermediate-table row linking ``inst`` and ``other``."""
        connection = inst._connection
        connection._SO_intermediateDelete(
            self.intermediateTable,
            self.joinColumn, getID(inst),
            self.otherColumn, getID(other))

    def add(self, inst, other):
        """Insert an intermediate-table row linking ``inst`` and ``other``."""
        connection = inst._connection
        connection._SO_intermediateInsert(
            self.intermediateTable,
            self.joinColumn, getID(inst),
            self.otherColumn, getID(other))
|
||
|
|
||
|
class RelatedJoin(MultipleJoin):
    """Join declaration that is bound through ``SORelatedJoin``."""

    baseClass = SORelatedJoin
|
||
|
|
||
|
# helper classes to SQLRelatedJoin
|
||
|
class OtherTableToJoin(sqlbuilder.SQLExpression):
    """SQL condition ``otherTable.otherIdName = interTable.joinColumn``."""

    def __init__(self, otherTable, otherIdName, interTable, joinColumn):
        self.otherTable, self.otherIdName = otherTable, otherIdName
        self.interTable, self.joinColumn = interTable, joinColumn

    def tablesUsedImmediate(self):
        # Both tables named in the condition.
        return [self.otherTable, self.interTable]

    def __sqlrepr__(self, db):
        leftSide = (self.otherTable, self.otherIdName)
        rightSide = (self.interTable, self.joinColumn)
        return '%s.%s = %s.%s' % (leftSide + rightSide)
|
||
|
|
||
|
class JoinToTable(sqlbuilder.SQLExpression):
    """SQL condition ``interTable.joinColumn = table.idName``."""

    def __init__(self, table, idName, interTable, joinColumn):
        self.table, self.idName = table, idName
        self.interTable, self.joinColumn = interTable, joinColumn

    def tablesUsedImmediate(self):
        # Both tables named in the condition.
        return [self.table, self.interTable]

    def __sqlrepr__(self, db):
        leftSide = (self.interTable, self.joinColumn)
        rightSide = (self.table, self.idName)
        return '%s.%s = %s.%s' % (leftSide + rightSide)
|
||
|
|
||
|
class TableToId(sqlbuilder.SQLExpression):
    """SQL condition ``table.idName = idValue``."""

    def __init__(self, table, idName, idValue):
        self.table, self.idName = table, idName
        self.idValue = idValue

    def tablesUsedImmediate(self):
        # Only the one table is referenced.
        return [self.table]

    def __sqlrepr__(self, db):
        # NOTE(review): idValue is interpolated with plain %s, i.e. without
        # any quoting/escaping -- confirm callers only ever pass trusted IDs.
        parts = (self.table, self.idName, self.idValue)
        return '%s.%s = %s' % parts
|
||
|
|
||
|
class SOSQLRelatedJoin(SORelatedJoin):
    """Many-to-many join whose ``performJoin`` returns the ordered result of
    ``otherClass.select(...)`` instead of a list built from fetched IDs."""

    def performJoin(self, inst):
        conn = None
        if inst.sqlmeta._perConnection:
            conn = inst._connection

        thisMeta = self.soClass.sqlmeta
        otherMeta = self.otherClass.sqlmeta
        linkTable = self.intermediateTable

        # other.id = link.otherColumn AND link.joinColumn = this.id
        # AND this.id = <ID of inst>
        condition = sqlbuilder.AND(
            OtherTableToJoin(
                otherMeta.table, otherMeta.idName,
                linkTable, self.otherColumn
            ),
            JoinToTable(
                thisMeta.table, thisMeta.idName,
                linkTable, self.joinColumn
            ),
            TableToId(thisMeta.table, thisMeta.idName, inst.id),
        )
        selection = self.otherClass.select(
            condition,
            clauseTables=(thisMeta.table, otherMeta.table, linkTable),
            connection=conn)
        return selection.orderBy(self.orderBy)
|
||
|
|
||
|
class SQLRelatedJoin(RelatedJoin):
    """Join declaration that is bound through ``SOSQLRelatedJoin``."""

    baseClass = SOSQLRelatedJoin
|
||
|
|
||
|
def capitalize(name):
    """Return ``name`` with its first character capitalized.

    Unlike ``str.capitalize`` the remainder of the string is left
    untouched (``'fooBar'`` becomes ``'FooBar'``).  An empty string is
    returned unchanged; slicing with ``name[:1]`` avoids the IndexError
    that ``name[0]`` would raise for it.
    """
    return name[:1].capitalize() + name[1:]
|
||
|
|
||
|
class SOSingleJoin(SOMultipleJoin):
    """Join whose ``performJoin`` yields a single object of the other class
    (or None) rather than a collection."""

    def __init__(self, **kw):
        # ``makeDefault`` is consumed here and not forwarded to the parent.
        self.makeDefault = kw.pop('makeDefault', False)
        SOMultipleJoin.__init__(self, **kw)

    def performJoin(self, inst):
        """Return the first other-class object pointing at ``inst``.

        When there is none, return None -- or, if ``makeDefault`` is set,
        a newly created other-class instance that refers to ``inst``.
        """
        conn = None
        if inst.sqlmeta._perConnection:
            conn = inst._connection
        pythonColumn = self._dbNameToPythonName()
        matches = self.otherClass.select(
            getattr(self.otherClass.q, pythonColumn) == inst.id,
            connection=conn
        )
        if matches.count() != 0:
            return matches[0]
        if not self.makeDefault:
            return None
        # Instantiate the other class, passing ``inst`` under the attribute
        # name derived from the ID column.
        style = self.soClass.sqlmeta.style
        attrName = style.instanceIDAttrToAttr(pythonColumn)
        return self.otherClass(**{attrName: inst})
|
||
|
|
||
|
class SingleJoin(Join):
    """Join declaration that is bound through ``SOSingleJoin``."""

    baseClass = SOSingleJoin
|
||
|
|
||
|
|
||
|
|
||
|
import boundattributes
|
||
|
|
||
|
class SOManyToMany(object):
    """Descriptor for a many-to-many relation through an intermediate table.

    Reading the attribute on an instance returns a
    ``_ManyToManySelectWrapper`` around a select of the related objects;
    reading it on the class returns the descriptor itself.
    """

    def __init__(self, soClass, name, join,
                 intermediateTable, joinColumn, otherColumn,
                 createJoinTable, **attrs):
        self.name = name
        self.intermediateTable = intermediateTable
        self.joinColumn = joinColumn
        self.otherColumn = otherColumn
        self.createJoinTable = createJoinTable
        # Both classes are delivered through the registry callbacks below.
        self.soClass = None
        self.otherClass = None
        for attrName, attrValue in attrs.items():
            setattr(self, attrName, attrValue)
        otherRegistry = classregistry.registry(soClass.sqlmeta.registry)
        otherRegistry.addClassCallback(join, self._setOtherClass)
        thisRegistry = classregistry.registry(soClass.sqlmeta.registry)
        thisRegistry.addClassCallback(soClass.__name__, self._setThisClass)

    def _setThisClass(self, soClass):
        self.soClass = soClass
        if self.soClass and self.otherClass:
            self._finishSet()

    def _setOtherClass(self, otherClass):
        self.otherClass = otherClass
        if self.soClass and self.otherClass:
            self._finishSet()

    def _finishSet(self):
        """Complete set-up once both classes are known: fill in missing
        table/column names, listen for table creation, build ``clause``."""
        thisMeta = self.soClass.sqlmeta
        otherMeta = self.otherClass.sqlmeta
        if self.intermediateTable is None:
            # Both table names, sorted and joined with an underscore.
            tableNames = sorted([thisMeta.table, otherMeta.table])
            self.intermediateTable = '%s_%s' % (tableNames[0], tableNames[1])
        if not self.otherColumn:
            self.otherColumn = thisMeta.style.tableReference(otherMeta.table)
        if not self.joinColumn:
            style = styles.getStyle(self.soClass)
            self.joinColumn = style.tableReference(thisMeta.table)
        for cls in (self.soClass, self.otherClass):
            events.listen(self.event_CreateTableSignal,
                          cls, events.CreateTableSignal)
        linkToOther = sqlbuilder.Field(self.intermediateTable, self.otherColumn)
        linkToThis = sqlbuilder.Field(self.intermediateTable, self.joinColumn)
        self.clause = ((self.otherClass.q.id == linkToOther)
                       & (linkToThis == self.soClass.q.id))

    def __get__(self, obj, type):
        if obj is None:
            # Accessed on the class rather than an instance.
            return self
        linkToOther = sqlbuilder.Field(self.intermediateTable, self.otherColumn)
        linkToThis = sqlbuilder.Field(self.intermediateTable, self.joinColumn)
        query = ((self.otherClass.q.id == linkToOther)
                 & (linkToThis == obj.id))
        return _ManyToManySelectWrapper(
            obj, self, self.otherClass.select(query))

    def event_CreateTableSignal(self, soClass, connection, extra_sql,
                                post_funcs):
        # Queue creation of the join table, unless it was switched off.
        if not self.createJoinTable:
            return
        post_funcs.append(self.event_CreateTableSignalPost)

    def event_CreateTableSignalPost(self, soClass, connection):
        # Create the intermediate table only when it does not exist yet.
        if not connection.tableExists(self.intermediateTable):
            connection._SO_createJoinTable(self)
|
||
|
|
||
|
class ManyToMany(boundattributes.BoundFactory):
    # Declaration of a many-to-many relation; the bound object is built
    # by SOManyToMany.
    factory_class = SOManyToMany
    __restrict_attributes__ = (
        'join',
        'intermediateTable',
        'joinColumn',
        'otherColumn',
        'createJoinTable',
    )
    __unpackargs__ = ('join',)

    # Default values (SOManyToMany._finishSet derives the table and column
    # names that are left as None):
    intermediateTable = None
    joinColumn = None
    otherColumn = None
    createJoinTable = True
|
||
|
|
||
|
class _ManyToManySelectWrapper(object):
|
||
|
|
||
|
def __init__(self, forObject, join, select):
|
||
|
self.forObject = forObject
|
||
|
self.join = join
|
||
|
self.select = select
|
||
|
|
||
|
def __getattr__(self, attr):
|
||
|
# @@: This passes through private variable access too... should it?
|
||
|
# Also magic methods, like __str__
|
||
|
return getattr(self.select, attr)
|
||
|
|
||
|
def __repr__(self):
|
||
|
return '<%s for: %s>' % (self.__class__.__name__, repr(self.select))
|
||
|
|
||
|
def __str__(self):
|
||
|
return str(self.select)
|
||
|
|
||
|
def __iter__(self):
|
||
|
return iter(self.select)
|
||
|
|
||
|
def __getitem__(self, key):
|
||
|
return self.select[key]
|
||
|
|
||
|
def add(self, obj):
|
||
|
obj._connection._SO_intermediateInsert(
|
||
|
self.join.intermediateTable,
|
||
|
self.join.joinColumn,
|
||
|
getID(self.forObject),
|
||
|
self.join.otherColumn,
|
||
|
getID(obj))
|
||
|
|
||
|
def remove(self, obj):
|
||
|
obj._connection._SO_intermediateDelete(
|
||
|
self.join.intermediateTable,
|
||
|
self.join.joinColumn,
|
||
|
getID(self.forObject),
|
||
|
self.join.otherColumn,
|
||
|
getID(obj))
|
||
|
|
||
|
def create(self, **kw):
|
||
|
obj = self.join.otherClass(**kw)
|
||
|
self.add(obj)
|
||
|
return obj
|
||
|
|
||
|
class SOOneToMany(object):
    """Descriptor for a one-to-many relation.

    Reading the attribute on an instance returns a
    ``_OneToManySelectWrapper`` around a select of the other class's rows
    whose join column equals the instance's ID; reading it on the class
    returns the descriptor itself.
    """

    def __init__(self, soClass, name, join, joinColumn, **attrs):
        self.soClass = soClass
        self.name = name
        self.joinColumn = joinColumn
        for attrName, attrValue in attrs.items():
            setattr(self, attrName, attrValue)
        # The other class is delivered later through the registry callback.
        registry = classregistry.registry(soClass.sqlmeta.registry)
        registry.addClassCallback(join, self._setOtherClass)

    def _setOtherClass(self, otherClass):
        """Record the other class, default the join column, build ``clause``."""
        self.otherClass = otherClass
        if not self.joinColumn:
            style = styles.getStyle(self.soClass)
            self.joinColumn = style.tableReference(self.soClass.sqlmeta.table)
        foreignKey = sqlbuilder.Field(
            self.otherClass.sqlmeta.table, self.joinColumn)
        self.clause = (foreignKey == self.soClass.q.id)

    def __get__(self, obj, type):
        if obj is None:
            # Accessed on the class rather than an instance.
            return self
        foreignKey = sqlbuilder.Field(
            self.otherClass.sqlmeta.table, self.joinColumn)
        query = (foreignKey == obj.id)
        return _OneToManySelectWrapper(
            obj, self, self.otherClass.select(query))
|
||
|
|
||
|
class OneToMany(boundattributes.BoundFactory):
    # Declaration of a one-to-many relation; the bound object is built
    # by SOOneToMany.
    factory_class = SOOneToMany
    __restrict_attributes__ = (
        'join',
        'joinColumn',
    )
    __unpackargs__ = ('join',)

    # Default values (SOOneToMany._setOtherClass derives the column name
    # when it is left as None):
    joinColumn = None
|
||
|
|
||
|
class _OneToManySelectWrapper(object):
|
||
|
|
||
|
def __init__(self, forObject, join, select):
|
||
|
self.forObject = forObject
|
||
|
self.join = join
|
||
|
self.select = select
|
||
|
|
||
|
def __getattr__(self, attr):
|
||
|
# @@: This passes through private variable access too... should it?
|
||
|
# Also magic methods, like __str__
|
||
|
return getattr(self.select, attr)
|
||
|
|
||
|
def __repr__(self):
|
||
|
return '<%s for: %s>' % (self.__class__.__name__, repr(self.select))
|
||
|
|
||
|
def __str__(self):
|
||
|
return str(self.select)
|
||
|
|
||
|
def __iter__(self):
|
||
|
return iter(self.select)
|
||
|
|
||
|
def __getitem__(self, key):
|
||
|
return self.select[key]
|
||
|
|
||
|
def create(self, **kw):
|
||
|
kw[self.join.joinColumn] = self.forObject.id
|
||
|
return self.join.otherClass(**kw)
|