mirror of
https://github.com/djohnlewis/stackdump
synced 2025-01-22 22:51:36 +00:00
98 lines
3.3 KiB
Python
98 lines
3.3 KiB
Python
from paste.deploy.converters import asbool
|
|
from paste.wsgilib import catch_errors
|
|
from paste.util import import_string
|
|
import sqlobject
|
|
import threading
|
|
|
|
def make_middleware(app, global_conf, database=None, use_transaction=False,
                    hub=None):
    """
    WSGI middleware that sets the connection for the request (using
    the database URI or connection object) and the given hub (or
    ``sqlobject.sqlhub`` if not given).

    If ``use_transaction`` is true, then the request will be run in a
    transaction.

    Applications can use the keys (which are all no-argument functions):

    ``sqlobject.get_connection()``:
      Returns the connection object

    ``sqlobject.abort()``:
      Aborts the transaction.  Does not raise an error, but at the *end*
      of the request there will be a rollback.

    ``sqlobject.begin()``:
      Starts a transaction.  First commits (or rolls back if aborted) if
      this is run in a transaction.

    ``sqlobject.in_transaction()``:
      Returns true or false, depending if we are currently in a
      transaction.

    :param app: the WSGI application to wrap.
    :param global_conf: Paste Deploy global configuration mapping; its
        ``'database'`` entry is used when ``database`` is not given.
    :param database: a database URI string (turned into a connection with
        ``sqlobject.connectionForURI``) or an existing connection object.
    :param use_transaction: converted with ``asbool``; when true each
        request starts inside a transaction.
    :param hub: the hub object whose ``threadConnection`` is set per
        request, or a string evaluated with ``import_string.eval_import``.
        Falls back to ``sqlobject.sqlhub`` when empty.
    :returns: a ``SQLObjectMiddleware`` wrapping ``app``.
    :raises ValueError: if no database value is available.
    """
    # ``basestring`` only exists on Python 2; fall back to ``str`` so the
    # string checks below do not raise NameError on Python 3.
    try:
        string_types = basestring  # noqa: F821
    except NameError:
        string_types = str

    use_transaction = asbool(use_transaction)
    if database is None:
        database = global_conf.get('database')
    if not database:
        raise ValueError(
            "You must provide a 'database' configuration value")
    if isinstance(hub, string_types):
        hub = import_string.eval_import(hub)
    if not hub:
        hub = sqlobject.sqlhub
    if isinstance(database, string_types):
        database = sqlobject.connectionForURI(database)
    return SQLObjectMiddleware(app, database, use_transaction, hub)
|
|
|
|
class SQLObjectMiddleware(object):
    """
    WSGI middleware that binds a SQLObject connection to ``hub`` while a
    request runs, optionally wrapping the request in a transaction.

    It exposes ``sqlobject.get_connection``, ``sqlobject.abort``,
    ``sqlobject.begin`` and ``sqlobject.in_transaction`` in the WSGI
    environ as no-argument callables.
    """

    def __init__(self, app, conn, use_transaction, hub):
        # The wrapped WSGI application.
        self.app = app
        # Base connection; per-request transactions are created from it.
        self.conn = conn
        # Whether each request starts inside a transaction.
        self.use_transaction = use_transaction
        # Object whose ``threadConnection`` attribute is set per request.
        self.hub = hub

    def __call__(self, environ, start_response):
        # Per-request state is kept in one-element lists so the nested
        # callbacks can rebind it (presumably because Python 2 has no
        # ``nonlocal``).
        conn = [self.conn]
        if self.use_transaction:
            conn[0] = conn[0].transaction()
        # Non-empty once abort() or an error was recorded; turns the final
        # commit into a rollback.
        any_errors = []
        use_transaction = [self.use_transaction]
        self.hub.threadConnection = conn[0]

        def abort():
            # Explicit raise instead of ``assert`` so the check survives
            # ``python -O``; same exception type and message as before.
            if not use_transaction[0]:
                raise AssertionError(
                    "You cannot abort, because a transaction is not being used")
            any_errors.append(None)

        def begin():
            # Finish the current transaction (if any), then start a fresh one
            # and bind it to the hub.
            if use_transaction[0]:
                if any_errors:
                    conn[0].rollback()
                else:
                    conn[0].commit()
            any_errors[:] = []
            use_transaction[0] = True
            conn[0] = self.conn.transaction()
            self.hub.threadConnection = conn[0]

        def error(exc_info=None):
            # Registered as catch_errors' error callback: mark the request as
            # failed so ok() rolls back instead of committing.
            any_errors.append(None)
            ok()

        def ok():
            # End-of-request hook: commit (or roll back after abort/error),
            # and always unbind the connection from the hub -- even if the
            # commit or rollback itself raises -- so the thread is not left
            # pointing at a finished transaction.
            try:
                if use_transaction[0]:
                    if any_errors:
                        conn[0].rollback()
                    else:
                        conn[0].commit(close=True)
            finally:
                self.hub.threadConnection = None

        def in_transaction():
            return use_transaction[0]

        def get_connection():
            return conn[0]

        environ['sqlobject.get_connection'] = get_connection
        environ['sqlobject.abort'] = abort
        environ['sqlobject.begin'] = begin
        environ['sqlobject.in_transaction'] = in_transaction
        return catch_errors(self.app, environ, start_response,
                            error_callback=error, ok_callback=ok)
|
|
|