import dbconnection
import joins
import main
import sqlbuilder

__all__ = ['SelectResults']

class SelectResults(object):
    IterationClass = dbconnection.Iteration

    def __init__(self, sourceClass, clause, clauseTables=None,
                 **ops):
        self.sourceClass = sourceClass
        if clause is None or isinstance(clause, str) and clause == 'all':
            clause = sqlbuilder.SQLTrueClause
        if not isinstance(clause, sqlbuilder.SQLExpression):
            clause = sqlbuilder.SQLConstant(clause)
        self.clause = clause
        self.ops = ops
        if ops.get('orderBy', sqlbuilder.NoDefault) is sqlbuilder.NoDefault:
            ops['orderBy'] = sourceClass.sqlmeta.defaultOrder
        orderBy = ops['orderBy']
        if isinstance(orderBy, (tuple, list)):
            orderBy = map(self._mungeOrderBy, orderBy)
        else:
            orderBy = self._mungeOrderBy(orderBy)
        ops['dbOrderBy'] = orderBy
        if 'connection' in ops and ops['connection'] is None:
            del ops['connection']
        if ops.get('limit', None):
            assert not ops.get('start', None) and not ops.get('end', None), \
               "'limit' cannot be used with 'start' or 'end'"
            ops["start"] = 0
            ops["end"] = ops.pop("limit")

        tablesSet = sqlbuilder.tablesUsedSet(self.clause, self._getConnection().dbName)
        if clauseTables:
            for table in clauseTables:
                tablesSet.add(table)
        self.clauseTables = clauseTables
        # Explicitly post-adding-in sqlmeta.table, sqlbuilder.Select will handle sqlrepr'ing and dupes
        self.tables = list(tablesSet) + [sourceClass.sqlmeta.table]

    def queryForSelect(self):
        columns = [self.sourceClass.q.id] + [getattr(self.sourceClass.q, x.name) for x in self.sourceClass.sqlmeta.columnList]
        query = sqlbuilder.Select(columns,
                                  where=self.clause,
                                  join=self.ops.get('join', sqlbuilder.NoDefault),
                                  distinct=self.ops.get('distinct',False),
                                  lazyColumns=self.ops.get('lazyColumns', False),
                                  start=self.ops.get('start', 0),
                                  end=self.ops.get('end', None),
                                  orderBy=self.ops.get('dbOrderBy',sqlbuilder.NoDefault),
                                  reversed=self.ops.get('reversed', False),
                                  staticTables=self.tables,
                                  forUpdate=self.ops.get('forUpdate', False))
        return query

    def __repr__(self):
        return "<%s at %x>" % (self.__class__.__name__, id(self))

    def _getConnection(self):
        return self.ops.get('connection') or self.sourceClass._connection

    def __str__(self):
        conn = self._getConnection()
        return conn.queryForSelect(self)

    def _mungeOrderBy(self, orderBy):
        if isinstance(orderBy, str) and orderBy.startswith('-'):
            orderBy = orderBy[1:]
            desc = True
        else:
            desc = False
        if isinstance(orderBy, basestring):
            if orderBy in self.sourceClass.sqlmeta.columns:
                val = getattr(self.sourceClass.q, self.sourceClass.sqlmeta.columns[orderBy].name)
                if desc:
                    return sqlbuilder.DESC(val)
                else:
                    return val
            else:
                orderBy = sqlbuilder.SQLConstant(orderBy)
                if desc:
                    return sqlbuilder.DESC(orderBy)
                else:
                    return orderBy
        else:
            return orderBy

    def clone(self, **newOps):
        ops = self.ops.copy()
        ops.update(newOps)
        return self.__class__(self.sourceClass, self.clause,
                              self.clauseTables, **ops)

    def orderBy(self, orderBy):
        return self.clone(orderBy=orderBy)

    def connection(self, conn):
        return self.clone(connection=conn)

    def limit(self, limit):
        return self[:limit]

    def lazyColumns(self, value):
        return self.clone(lazyColumns=value)

    def reversed(self):
        return self.clone(reversed=not self.ops.get('reversed', False))

    def distinct(self):
        return self.clone(distinct=True)

    def newClause(self, new_clause):
        return self.__class__(self.sourceClass, new_clause,
                              self.clauseTables, **self.ops)

    def filter(self, filter_clause):
        if filter_clause is None:
            # None doesn't filter anything, it's just a no-op:
            return self
        clause = self.clause
        if isinstance(clause, basestring):
            clause = sqlbuilder.SQLConstant('(%s)' % self.clause)
        return self.newClause(sqlbuilder.AND(clause, filter_clause))

    def __getitem__(self, value):
        if isinstance(value, slice):
            assert not value.step, "Slices do not support steps"
            if not value.start and not value.stop:
                # No need to copy, I'm immutable
                return self

            # Negative indexes aren't handled (and everything we
            # don't handle ourselves we just create a list to
            # handle)
            if (value.start and value.start < 0) \
               or (value.stop and value.stop < 0):
                if value.start:
                    if value.stop:
                        return list(self)[value.start:value.stop]
                    return list(self)[value.start:]
                return list(self)[:value.stop]


            if value.start:
                assert value.start >= 0
                start = self.ops.get('start', 0) + value.start
                if value.stop is not None:
                    assert value.stop >= 0
                    if value.stop < value.start:
                        # an empty result:
                        end = start
                    else:
                        end = value.stop + self.ops.get('start', 0)
                        if self.ops.get('end', None) is not None and \
                                self.ops['end'] < end:
                            # truncated by previous slice:
                            end = self.ops['end']
                else:
                    end = self.ops.get('end', None)
            else:
                start = self.ops.get('start', 0)
                end = value.stop + start
                if self.ops.get('end', None) is not None \
                   and self.ops['end'] < end:
                    end = self.ops['end']
            return self.clone(start=start, end=end)
        else:
            if value < 0:
                return list(iter(self))[value]
            else:
                start = self.ops.get('start', 0) + value
                return list(self.clone(start=start, end=start+1))[0]

    def __iter__(self):
        # @@: This could be optimized, using a simpler algorithm
        # since we don't have to worry about garbage collection,
        # etc., like we do with .lazyIter()
        return iter(list(self.lazyIter()))

    def lazyIter(self):
        """
        Returns an iterator that will lazily pull rows out of the
        database and return SQLObject instances
        """
        conn = self._getConnection()
        return conn.iterSelect(self)

    def accumulate(self, *expressions):
        """ Use accumulate expression(s) to select result
            using another SQL select through current
            connection.
            Return the accumulate result
        """
        conn = self._getConnection()
        exprs = []
        for expr in expressions:
            if not isinstance(expr, sqlbuilder.SQLExpression):
                expr = sqlbuilder.SQLConstant(expr)
            exprs.append(expr)
        return conn.accumulateSelect(self, *exprs)

    def count(self):
        """ Counting elements of current select results """
        assert not self.ops.get('start') and not self.ops.get('end'), \
            "start/end/limit have no meaning with 'count'"
        assert not (self.ops.get('distinct') and (self.ops.get('start')
                                                  or self.ops.get('end'))), \
            "distinct-counting of sliced objects is not supported"
        if self.ops.get('distinct'):
            # Column must be specified, so we are using unique ID column.
            # COUNT(DISTINCT column) is supported by MySQL and PostgreSQL,
            # but not by SQLite. Perhaps more portable would be subquery:
            #  SELECT COUNT(*) FROM (SELECT DISTINCT id FROM table)
            count = self.accumulate('COUNT(DISTINCT %s)' % self._getConnection().sqlrepr(self.sourceClass.q.id))
        else:
            count = self.accumulate('COUNT(*)')
        if self.ops.get('start'):
            count -= self.ops['start']
        if self.ops.get('end'):
            count = min(self.ops['end'] - self.ops.get('start', 0), count)
        return count

    def accumulateMany(self, *attributes):
        """ Making the expressions for count/sum/min/max/avg
            of a given select result attributes.
            `attributes` must be a list/tuple of pairs (func_name, attribute);
            `attribute` can be a column name (like 'a_column')
            or a dot-q attribute (like Table.q.aColumn)
        """
        expressions = []
        conn = self._getConnection()
        if self.ops.get('distinct'):
            distinct = 'DISTINCT '
        else:
            distinct = ''
        for func_name, attribute in attributes:
            if not isinstance(attribute, str):
                attribute = conn.sqlrepr(attribute)
            expression = '%s(%s%s)' % (func_name, distinct, attribute)
            expressions.append(expression)
        return self.accumulate(*expressions)

    def accumulateOne(self, func_name, attribute):
        """ Making the sum/min/max/avg of a given select result attribute.
            `attribute` can be a column name (like 'a_column')
            or a dot-q attribute (like Table.q.aColumn)
        """
        return self.accumulateMany((func_name, attribute))

    def sum(self, attribute):
        return self.accumulateOne("SUM", attribute)

    def min(self, attribute):
        return self.accumulateOne("MIN", attribute)

    def avg(self, attribute):
        return self.accumulateOne("AVG", attribute)

    def max(self, attribute):
        return self.accumulateOne("MAX", attribute)

    def getOne(self, default=sqlbuilder.NoDefault):
        """
        If a query is expected to only return a single value,
        using ``.getOne()`` will return just that value.

        If not results are found, ``SQLObjectNotFound`` will be
        raised, unless you pass in a default value (like
        ``.getOne(None)``).

        If more than one result is returned,
        ``SQLObjectIntegrityError`` will be raised.
        """
        results = list(self)
        if not results:
            if default is sqlbuilder.NoDefault:
                raise main.SQLObjectNotFound(
                    "No results matched the query for %s"
                    % self.sourceClass.__name__)
            return default
        if len(results) > 1:
            raise main.SQLObjectIntegrityError(
                "More than one result returned from query: %s"
                % results)
        return results[0]

    def throughTo(self):
        class _throughTo_getter(object):
            def __init__(self, inst):
                self.sresult = inst
            def __getattr__(self, attr):
                return self.sresult._throughTo(attr)
        return _throughTo_getter(self)
    throughTo = property(throughTo)

    def _throughTo(self, attr):
        otherClass = None
        orderBy = sqlbuilder.NoDefault

        ref = self.sourceClass.sqlmeta.columns.get(attr.endswith('ID') and attr or attr+'ID', None)
        if ref and ref.foreignKey:
            otherClass, clause = self._throughToFK(ref)
        else:
            join = [x for x in self.sourceClass.sqlmeta.joins if x.joinMethodName==attr]
            if join:
                join = join[0]
                orderBy = join.orderBy
                if hasattr(join, 'otherColumn'):
                    otherClass, clause = self._throughToRelatedJoin(join)
                else:
                    otherClass, clause = self._throughToMultipleJoin(join)

        if not otherClass:
            raise AttributeError("throughTo argument (got %s) should be name of foreignKey or SQL*Join in %s" % (attr, self.sourceClass))

        return otherClass.select(clause,
                                orderBy=orderBy,
                                connection=self._getConnection())

    def _throughToFK(self, col):
        otherClass = getattr(self.sourceClass, "_SO_class_"+col.foreignKey)
        colName = col.name
        query = self.queryForSelect().newItems([sqlbuilder.ColumnAS(getattr(self.sourceClass.q, colName), colName)]).orderBy(None).distinct()
        query = sqlbuilder.Alias(query, "%s_%s" % (self.sourceClass.__name__, col.name))
        return otherClass, otherClass.q.id==getattr(query.q, colName)

    def _throughToMultipleJoin(self, join):
        otherClass = join.otherClass
        colName = join.soClass.sqlmeta.style.dbColumnToPythonAttr(join.joinColumn)
        query = self.queryForSelect().newItems([sqlbuilder.ColumnAS(self.sourceClass.q.id, 'id')]).orderBy(None).distinct()
        query = sqlbuilder.Alias(query, "%s_%s" % (self.sourceClass.__name__, join.joinMethodName))
        joinColumn = getattr(otherClass.q, colName)
        return otherClass, joinColumn==query.q.id

    def _throughToRelatedJoin(self, join):
        otherClass = join.otherClass
        intTable = sqlbuilder.Table(join.intermediateTable)
        colName = join.joinColumn
        query = self.queryForSelect().newItems([sqlbuilder.ColumnAS(self.sourceClass.q.id, 'id')]).orderBy(None).distinct()
        query = sqlbuilder.Alias(query, "%s_%s" % (self.sourceClass.__name__, join.joinMethodName))
        clause = sqlbuilder.AND(otherClass.q.id == getattr(intTable, join.otherColumn),
                     getattr(intTable, colName) == query.q.id)
        return otherClass, clause