mirror of
https://github.com/djohnlewis/stackdump
synced 2025-12-07 08:23:25 +00:00
Initial commit. Still building up the env and some parsing code.
This commit is contained in:
346
python/packages/sqlobject/sresults.py
Normal file
346
python/packages/sqlobject/sresults.py
Normal file
@@ -0,0 +1,346 @@
|
||||
import dbconnection
|
||||
import joins
|
||||
import main
|
||||
import sqlbuilder
|
||||
|
||||
__all__ = ['SelectResults']
|
||||
|
||||
class SelectResults(object):
|
||||
IterationClass = dbconnection.Iteration
|
||||
|
||||
def __init__(self, sourceClass, clause, clauseTables=None,
|
||||
**ops):
|
||||
self.sourceClass = sourceClass
|
||||
if clause is None or isinstance(clause, str) and clause == 'all':
|
||||
clause = sqlbuilder.SQLTrueClause
|
||||
if not isinstance(clause, sqlbuilder.SQLExpression):
|
||||
clause = sqlbuilder.SQLConstant(clause)
|
||||
self.clause = clause
|
||||
self.ops = ops
|
||||
if ops.get('orderBy', sqlbuilder.NoDefault) is sqlbuilder.NoDefault:
|
||||
ops['orderBy'] = sourceClass.sqlmeta.defaultOrder
|
||||
orderBy = ops['orderBy']
|
||||
if isinstance(orderBy, (tuple, list)):
|
||||
orderBy = map(self._mungeOrderBy, orderBy)
|
||||
else:
|
||||
orderBy = self._mungeOrderBy(orderBy)
|
||||
ops['dbOrderBy'] = orderBy
|
||||
if 'connection' in ops and ops['connection'] is None:
|
||||
del ops['connection']
|
||||
if ops.get('limit', None):
|
||||
assert not ops.get('start', None) and not ops.get('end', None), \
|
||||
"'limit' cannot be used with 'start' or 'end'"
|
||||
ops["start"] = 0
|
||||
ops["end"] = ops.pop("limit")
|
||||
|
||||
tablesSet = sqlbuilder.tablesUsedSet(self.clause, self._getConnection().dbName)
|
||||
if clauseTables:
|
||||
for table in clauseTables:
|
||||
tablesSet.add(table)
|
||||
self.clauseTables = clauseTables
|
||||
# Explicitly post-adding-in sqlmeta.table, sqlbuilder.Select will handle sqlrepr'ing and dupes
|
||||
self.tables = list(tablesSet) + [sourceClass.sqlmeta.table]
|
||||
|
||||
def queryForSelect(self):
|
||||
columns = [self.sourceClass.q.id] + [getattr(self.sourceClass.q, x.name) for x in self.sourceClass.sqlmeta.columnList]
|
||||
query = sqlbuilder.Select(columns,
|
||||
where=self.clause,
|
||||
join=self.ops.get('join', sqlbuilder.NoDefault),
|
||||
distinct=self.ops.get('distinct',False),
|
||||
lazyColumns=self.ops.get('lazyColumns', False),
|
||||
start=self.ops.get('start', 0),
|
||||
end=self.ops.get('end', None),
|
||||
orderBy=self.ops.get('dbOrderBy',sqlbuilder.NoDefault),
|
||||
reversed=self.ops.get('reversed', False),
|
||||
staticTables=self.tables,
|
||||
forUpdate=self.ops.get('forUpdate', False))
|
||||
return query
|
||||
|
||||
def __repr__(self):
|
||||
return "<%s at %x>" % (self.__class__.__name__, id(self))
|
||||
|
||||
def _getConnection(self):
|
||||
return self.ops.get('connection') or self.sourceClass._connection
|
||||
|
||||
def __str__(self):
|
||||
conn = self._getConnection()
|
||||
return conn.queryForSelect(self)
|
||||
|
||||
def _mungeOrderBy(self, orderBy):
|
||||
if isinstance(orderBy, str) and orderBy.startswith('-'):
|
||||
orderBy = orderBy[1:]
|
||||
desc = True
|
||||
else:
|
||||
desc = False
|
||||
if isinstance(orderBy, basestring):
|
||||
if orderBy in self.sourceClass.sqlmeta.columns:
|
||||
val = getattr(self.sourceClass.q, self.sourceClass.sqlmeta.columns[orderBy].name)
|
||||
if desc:
|
||||
return sqlbuilder.DESC(val)
|
||||
else:
|
||||
return val
|
||||
else:
|
||||
orderBy = sqlbuilder.SQLConstant(orderBy)
|
||||
if desc:
|
||||
return sqlbuilder.DESC(orderBy)
|
||||
else:
|
||||
return orderBy
|
||||
else:
|
||||
return orderBy
|
||||
|
||||
def clone(self, **newOps):
|
||||
ops = self.ops.copy()
|
||||
ops.update(newOps)
|
||||
return self.__class__(self.sourceClass, self.clause,
|
||||
self.clauseTables, **ops)
|
||||
|
||||
def orderBy(self, orderBy):
|
||||
return self.clone(orderBy=orderBy)
|
||||
|
||||
def connection(self, conn):
|
||||
return self.clone(connection=conn)
|
||||
|
||||
def limit(self, limit):
|
||||
return self[:limit]
|
||||
|
||||
def lazyColumns(self, value):
|
||||
return self.clone(lazyColumns=value)
|
||||
|
||||
def reversed(self):
|
||||
return self.clone(reversed=not self.ops.get('reversed', False))
|
||||
|
||||
def distinct(self):
|
||||
return self.clone(distinct=True)
|
||||
|
||||
def newClause(self, new_clause):
|
||||
return self.__class__(self.sourceClass, new_clause,
|
||||
self.clauseTables, **self.ops)
|
||||
|
||||
def filter(self, filter_clause):
|
||||
if filter_clause is None:
|
||||
# None doesn't filter anything, it's just a no-op:
|
||||
return self
|
||||
clause = self.clause
|
||||
if isinstance(clause, basestring):
|
||||
clause = sqlbuilder.SQLConstant('(%s)' % self.clause)
|
||||
return self.newClause(sqlbuilder.AND(clause, filter_clause))
|
||||
|
||||
def __getitem__(self, value):
|
||||
if isinstance(value, slice):
|
||||
assert not value.step, "Slices do not support steps"
|
||||
if not value.start and not value.stop:
|
||||
# No need to copy, I'm immutable
|
||||
return self
|
||||
|
||||
# Negative indexes aren't handled (and everything we
|
||||
# don't handle ourselves we just create a list to
|
||||
# handle)
|
||||
if (value.start and value.start < 0) \
|
||||
or (value.stop and value.stop < 0):
|
||||
if value.start:
|
||||
if value.stop:
|
||||
return list(self)[value.start:value.stop]
|
||||
return list(self)[value.start:]
|
||||
return list(self)[:value.stop]
|
||||
|
||||
|
||||
if value.start:
|
||||
assert value.start >= 0
|
||||
start = self.ops.get('start', 0) + value.start
|
||||
if value.stop is not None:
|
||||
assert value.stop >= 0
|
||||
if value.stop < value.start:
|
||||
# an empty result:
|
||||
end = start
|
||||
else:
|
||||
end = value.stop + self.ops.get('start', 0)
|
||||
if self.ops.get('end', None) is not None and \
|
||||
self.ops['end'] < end:
|
||||
# truncated by previous slice:
|
||||
end = self.ops['end']
|
||||
else:
|
||||
end = self.ops.get('end', None)
|
||||
else:
|
||||
start = self.ops.get('start', 0)
|
||||
end = value.stop + start
|
||||
if self.ops.get('end', None) is not None \
|
||||
and self.ops['end'] < end:
|
||||
end = self.ops['end']
|
||||
return self.clone(start=start, end=end)
|
||||
else:
|
||||
if value < 0:
|
||||
return list(iter(self))[value]
|
||||
else:
|
||||
start = self.ops.get('start', 0) + value
|
||||
return list(self.clone(start=start, end=start+1))[0]
|
||||
|
||||
def __iter__(self):
|
||||
# @@: This could be optimized, using a simpler algorithm
|
||||
# since we don't have to worry about garbage collection,
|
||||
# etc., like we do with .lazyIter()
|
||||
return iter(list(self.lazyIter()))
|
||||
|
||||
def lazyIter(self):
|
||||
"""
|
||||
Returns an iterator that will lazily pull rows out of the
|
||||
database and return SQLObject instances
|
||||
"""
|
||||
conn = self._getConnection()
|
||||
return conn.iterSelect(self)
|
||||
|
||||
def accumulate(self, *expressions):
|
||||
""" Use accumulate expression(s) to select result
|
||||
using another SQL select through current
|
||||
connection.
|
||||
Return the accumulate result
|
||||
"""
|
||||
conn = self._getConnection()
|
||||
exprs = []
|
||||
for expr in expressions:
|
||||
if not isinstance(expr, sqlbuilder.SQLExpression):
|
||||
expr = sqlbuilder.SQLConstant(expr)
|
||||
exprs.append(expr)
|
||||
return conn.accumulateSelect(self, *exprs)
|
||||
|
||||
def count(self):
|
||||
""" Counting elements of current select results """
|
||||
assert not self.ops.get('start') and not self.ops.get('end'), \
|
||||
"start/end/limit have no meaning with 'count'"
|
||||
assert not (self.ops.get('distinct') and (self.ops.get('start')
|
||||
or self.ops.get('end'))), \
|
||||
"distinct-counting of sliced objects is not supported"
|
||||
if self.ops.get('distinct'):
|
||||
# Column must be specified, so we are using unique ID column.
|
||||
# COUNT(DISTINCT column) is supported by MySQL and PostgreSQL,
|
||||
# but not by SQLite. Perhaps more portable would be subquery:
|
||||
# SELECT COUNT(*) FROM (SELECT DISTINCT id FROM table)
|
||||
count = self.accumulate('COUNT(DISTINCT %s)' % self._getConnection().sqlrepr(self.sourceClass.q.id))
|
||||
else:
|
||||
count = self.accumulate('COUNT(*)')
|
||||
if self.ops.get('start'):
|
||||
count -= self.ops['start']
|
||||
if self.ops.get('end'):
|
||||
count = min(self.ops['end'] - self.ops.get('start', 0), count)
|
||||
return count
|
||||
|
||||
def accumulateMany(self, *attributes):
|
||||
""" Making the expressions for count/sum/min/max/avg
|
||||
of a given select result attributes.
|
||||
`attributes` must be a list/tuple of pairs (func_name, attribute);
|
||||
`attribute` can be a column name (like 'a_column')
|
||||
or a dot-q attribute (like Table.q.aColumn)
|
||||
"""
|
||||
expressions = []
|
||||
conn = self._getConnection()
|
||||
if self.ops.get('distinct'):
|
||||
distinct = 'DISTINCT '
|
||||
else:
|
||||
distinct = ''
|
||||
for func_name, attribute in attributes:
|
||||
if not isinstance(attribute, str):
|
||||
attribute = conn.sqlrepr(attribute)
|
||||
expression = '%s(%s%s)' % (func_name, distinct, attribute)
|
||||
expressions.append(expression)
|
||||
return self.accumulate(*expressions)
|
||||
|
||||
def accumulateOne(self, func_name, attribute):
|
||||
""" Making the sum/min/max/avg of a given select result attribute.
|
||||
`attribute` can be a column name (like 'a_column')
|
||||
or a dot-q attribute (like Table.q.aColumn)
|
||||
"""
|
||||
return self.accumulateMany((func_name, attribute))
|
||||
|
||||
def sum(self, attribute):
|
||||
return self.accumulateOne("SUM", attribute)
|
||||
|
||||
def min(self, attribute):
|
||||
return self.accumulateOne("MIN", attribute)
|
||||
|
||||
def avg(self, attribute):
|
||||
return self.accumulateOne("AVG", attribute)
|
||||
|
||||
def max(self, attribute):
|
||||
return self.accumulateOne("MAX", attribute)
|
||||
|
||||
def getOne(self, default=sqlbuilder.NoDefault):
|
||||
"""
|
||||
If a query is expected to only return a single value,
|
||||
using ``.getOne()`` will return just that value.
|
||||
|
||||
If not results are found, ``SQLObjectNotFound`` will be
|
||||
raised, unless you pass in a default value (like
|
||||
``.getOne(None)``).
|
||||
|
||||
If more than one result is returned,
|
||||
``SQLObjectIntegrityError`` will be raised.
|
||||
"""
|
||||
results = list(self)
|
||||
if not results:
|
||||
if default is sqlbuilder.NoDefault:
|
||||
raise main.SQLObjectNotFound(
|
||||
"No results matched the query for %s"
|
||||
% self.sourceClass.__name__)
|
||||
return default
|
||||
if len(results) > 1:
|
||||
raise main.SQLObjectIntegrityError(
|
||||
"More than one result returned from query: %s"
|
||||
% results)
|
||||
return results[0]
|
||||
|
||||
def throughTo(self):
|
||||
class _throughTo_getter(object):
|
||||
def __init__(self, inst):
|
||||
self.sresult = inst
|
||||
def __getattr__(self, attr):
|
||||
return self.sresult._throughTo(attr)
|
||||
return _throughTo_getter(self)
|
||||
throughTo = property(throughTo)
|
||||
|
||||
def _throughTo(self, attr):
|
||||
otherClass = None
|
||||
orderBy = sqlbuilder.NoDefault
|
||||
|
||||
ref = self.sourceClass.sqlmeta.columns.get(attr.endswith('ID') and attr or attr+'ID', None)
|
||||
if ref and ref.foreignKey:
|
||||
otherClass, clause = self._throughToFK(ref)
|
||||
else:
|
||||
join = [x for x in self.sourceClass.sqlmeta.joins if x.joinMethodName==attr]
|
||||
if join:
|
||||
join = join[0]
|
||||
orderBy = join.orderBy
|
||||
if hasattr(join, 'otherColumn'):
|
||||
otherClass, clause = self._throughToRelatedJoin(join)
|
||||
else:
|
||||
otherClass, clause = self._throughToMultipleJoin(join)
|
||||
|
||||
if not otherClass:
|
||||
raise AttributeError("throughTo argument (got %s) should be name of foreignKey or SQL*Join in %s" % (attr, self.sourceClass))
|
||||
|
||||
return otherClass.select(clause,
|
||||
orderBy=orderBy,
|
||||
connection=self._getConnection())
|
||||
|
||||
def _throughToFK(self, col):
|
||||
otherClass = getattr(self.sourceClass, "_SO_class_"+col.foreignKey)
|
||||
colName = col.name
|
||||
query = self.queryForSelect().newItems([sqlbuilder.ColumnAS(getattr(self.sourceClass.q, colName), colName)]).orderBy(None).distinct()
|
||||
query = sqlbuilder.Alias(query, "%s_%s" % (self.sourceClass.__name__, col.name))
|
||||
return otherClass, otherClass.q.id==getattr(query.q, colName)
|
||||
|
||||
def _throughToMultipleJoin(self, join):
|
||||
otherClass = join.otherClass
|
||||
colName = join.soClass.sqlmeta.style.dbColumnToPythonAttr(join.joinColumn)
|
||||
query = self.queryForSelect().newItems([sqlbuilder.ColumnAS(self.sourceClass.q.id, 'id')]).orderBy(None).distinct()
|
||||
query = sqlbuilder.Alias(query, "%s_%s" % (self.sourceClass.__name__, join.joinMethodName))
|
||||
joinColumn = getattr(otherClass.q, colName)
|
||||
return otherClass, joinColumn==query.q.id
|
||||
|
||||
def _throughToRelatedJoin(self, join):
|
||||
otherClass = join.otherClass
|
||||
intTable = sqlbuilder.Table(join.intermediateTable)
|
||||
colName = join.joinColumn
|
||||
query = self.queryForSelect().newItems([sqlbuilder.ColumnAS(self.sourceClass.q.id, 'id')]).orderBy(None).distinct()
|
||||
query = sqlbuilder.Alias(query, "%s_%s" % (self.sourceClass.__name__, join.joinMethodName))
|
||||
clause = sqlbuilder.AND(otherClass.q.id == getattr(intTable, join.otherColumn),
|
||||
getattr(intTable, colName) == query.q.id)
|
||||
return otherClass, clause
|
||||
Reference in New Issue
Block a user