from paste.deploy.converters import asbool
from paste.wsgilib import catch_errors
from paste.util import import_string
import sqlobject
import threading


def make_middleware(app, global_conf, database=None, use_transaction=False,
                    hub=None):
    """
    WSGI middleware that sets the connection for the request (using
    the database URI or connection object) and the given hub (or
    ``sqlobject.sqlhub`` if not given).

    If ``use_transaction`` is true, then the request will be run in a
    transaction.

    Applications can call the following functions, which are stored in
    the WSGI ``environ`` under these keys (all take no arguments):

    ``sqlobject.get_connection()``:
      Returns the connection object (the current transaction, when one
      is in use).

    ``sqlobject.abort()``:
      Aborts the transaction.  Does not raise an error when a
      transaction is in use, but at the *end* of the request there will
      be a rollback.  Raises ``AssertionError`` if no transaction is
      being used.

    ``sqlobject.begin()``:
      Starts a transaction.  First commits (or rolls back if aborted)
      if this is run in a transaction.

    ``sqlobject.in_transaction()``:
      Returns true or false, depending if we are currently in a
      transaction.

    :param app: the WSGI application to wrap.
    :param global_conf: Paste Deploy global configuration; its
        ``'database'`` entry is used when ``database`` is not given.
    :param database: a database URI string (opened with
        ``sqlobject.connectionForURI``) or an existing connection object.
    :param use_transaction: whether each request runs in a transaction;
        converted with ``asbool`` so config strings are accepted.
    :param hub: a connection hub object, or a string resolved with
        ``import_string.eval_import``; defaults to ``sqlobject.sqlhub``.
    :returns: a `SQLObjectMiddleware` wrapping ``app``.
    :raises ValueError: if no database is configured.
    """
    use_transaction = asbool(use_transaction)
    if database is None:
        database = global_conf.get('database')
    if not database:
        raise ValueError(
            "You must provide a 'database' configuration value")
    if isinstance(hub, basestring):
        hub = import_string.eval_import(hub)
    if not hub:
        hub = sqlobject.sqlhub
    if isinstance(database, basestring):
        database = sqlobject.connectionForURI(database)
    return SQLObjectMiddleware(app, database, use_transaction, hub)


class SQLObjectMiddleware(object):
    """
    Per-request connection/transaction binding for SQLObject.

    For every request this sets ``hub.threadConnection`` to ``conn`` (or
    to a fresh ``conn.transaction()`` when ``use_transaction`` is true),
    stores the helper functions ``sqlobject.get_connection``,
    ``sqlobject.abort``, ``sqlobject.begin`` and
    ``sqlobject.in_transaction`` in the environ, and when the request
    finishes commits or rolls back and unbinds ``hub.threadConnection``.
    """

    def __init__(self, app, conn, use_transaction, hub):
        self.app = app
        self.conn = conn
        self.use_transaction = use_transaction
        self.hub = hub

    def __call__(self, environ, start_response):
        # One-element lists so the nested callbacks can rebind these
        # values (this module targets Python 2, which has no ``nonlocal``).
        conn = [self.conn]
        if self.use_transaction:
            conn[0] = conn[0].transaction()
        # Non-empty means the current transaction must be rolled back
        # rather than committed.
        any_errors = []
        use_transaction = [self.use_transaction]
        self.hub.threadConnection = conn[0]

        def abort():
            # Explicit raise instead of ``assert`` so the check survives
            # ``python -O``; the exception type is unchanged for callers.
            if not use_transaction[0]:
                raise AssertionError(
                    "You cannot abort, because a transaction is not being used")
            any_errors.append(None)

        def begin():
            if use_transaction[0]:
                if any_errors:
                    conn[0].rollback()
                else:
                    # Close the old transaction (as ok() does): it is
                    # replaced just below and never used again.
                    conn[0].commit(close=True)
            any_errors[:] = []
            use_transaction[0] = True
            conn[0] = self.conn.transaction()
            self.hub.threadConnection = conn[0]

        def error(exc_info=None):
            # Error callback for catch_errors: force a rollback, then
            # run the normal finishing logic.
            any_errors.append(None)
            ok()

        def ok():
            try:
                if use_transaction[0]:
                    if any_errors:
                        conn[0].rollback()
                    else:
                        conn[0].commit(close=True)
            finally:
                # Always unbind, even if commit/rollback raised, so this
                # thread's hub is not left pointing at a dead transaction.
                self.hub.threadConnection = None

        def in_transaction():
            return use_transaction[0]

        def get_connection():
            return conn[0]

        environ['sqlobject.get_connection'] = get_connection
        environ['sqlobject.abort'] = abort
        environ['sqlobject.begin'] = begin
        environ['sqlobject.in_transaction'] = in_transaction
        return catch_errors(self.app, environ, start_response,
                            error_callback=error, ok_callback=ok)