mirror of
https://github.com/djohnlewis/stackdump
synced 2025-01-23 15:11:36 +00:00
1272 lines
48 KiB
Python
Executable File
1272 lines
48 KiB
Python
Executable File
#!/usr/bin/env python
|
|
|
|
import fnmatch
|
|
import optparse
|
|
import os
|
|
import re
|
|
import sys
|
|
import textwrap
|
|
import time
|
|
import warnings
|
|
|
|
try:
|
|
from paste.deploy import appconfig
|
|
except ImportError:
|
|
appconfig = None
|
|
|
|
import sqlobject
|
|
from sqlobject import col
|
|
from sqlobject.classregistry import findClass
|
|
from sqlobject.declarative import DeclarativeMeta
|
|
from sqlobject.util import moduleloader
|
|
|
|
# It's not very unsafe to use tempnam like we are doing:
|
|
warnings.filterwarnings(
|
|
'ignore', 'tempnam is a potential security risk.*',
|
|
RuntimeWarning, '.*command', 28)
|
|
|
|
def nowarning_tempnam(*args, **kw):
|
|
return os.tempnam(*args, **kw)
|
|
|
|
class SQLObjectVersionTable(sqlobject.SQLObject):
|
|
"""
|
|
This table is used to store information about the database and
|
|
its version (used with record and update commands).
|
|
"""
|
|
class sqlmeta:
|
|
table = 'sqlobject_db_version'
|
|
version = col.StringCol()
|
|
updated = col.DateTimeCol(default=col.DateTimeCol.now)
|
|
|
|
def db_differences(soClass, conn):
|
|
"""
|
|
Returns the differences between a class and the table in a
|
|
connection. Returns [] if no differences are found. This
|
|
function does the best it can; it can miss many differences.
|
|
"""
|
|
# @@: Repeats a lot from CommandStatus.command, but it's hard
|
|
# to actually factor out the display logic. Or I'm too lazy
|
|
# to do so.
|
|
diffs = []
|
|
if not conn.tableExists(soClass.sqlmeta.table):
|
|
if soClass.sqlmeta.columns:
|
|
diffs.append('Does not exist in database')
|
|
else:
|
|
try:
|
|
columns = conn.columnsFromSchema(soClass.sqlmeta.table,
|
|
soClass)
|
|
except AttributeError:
|
|
# Database does not support reading columns
|
|
pass
|
|
else:
|
|
existing = {}
|
|
for col in columns:
|
|
col = col.withClass(soClass)
|
|
existing[col.dbName] = col
|
|
missing = {}
|
|
for col in soClass.sqlmeta.columnList:
|
|
if col.dbName in existing:
|
|
del existing[col.dbName]
|
|
else:
|
|
missing[col.dbName] = col
|
|
for col in existing.values():
|
|
diffs.append('Database has extra column: %s'
|
|
% col.dbName)
|
|
for col in missing.values():
|
|
diffs.append('Database missing column: %s' % col.dbName)
|
|
return diffs
|
|
|
|
class CommandRunner(object):
|
|
|
|
def __init__(self):
|
|
self.commands = {}
|
|
self.command_aliases = {}
|
|
|
|
def run(self, argv):
|
|
invoked_as = argv[0]
|
|
args = argv[1:]
|
|
for i in range(len(args)):
|
|
if not args[i].startswith('-'):
|
|
# this must be a command
|
|
command = args[i].lower()
|
|
del args[i]
|
|
break
|
|
else:
|
|
# no command found
|
|
self.invalid('No COMMAND given (try "%s help")'
|
|
% os.path.basename(invoked_as))
|
|
real_command = self.command_aliases.get(command, command)
|
|
if real_command not in self.commands.keys():
|
|
self.invalid('COMMAND %s unknown' % command)
|
|
runner = self.commands[real_command](
|
|
invoked_as, command, args, self)
|
|
runner.run()
|
|
|
|
def register(self, command):
|
|
name = command.name
|
|
self.commands[name] = command
|
|
for alias in command.aliases:
|
|
self.command_aliases[alias] = name
|
|
|
|
def invalid(self, msg, code=2):
|
|
print msg
|
|
sys.exit(code)
|
|
|
|
the_runner = CommandRunner()
|
|
register = the_runner.register
|
|
|
|
def standard_parser(connection=True, simulate=True,
|
|
interactive=False, find_modules=True):
|
|
parser = optparse.OptionParser()
|
|
parser.add_option('-v', '--verbose',
|
|
help='Be verbose (multiple times for more verbosity)',
|
|
action='count',
|
|
dest='verbose',
|
|
default=0)
|
|
if simulate:
|
|
parser.add_option('-n', '--simulate',
|
|
help="Don't actually do anything (implies -v)",
|
|
action='store_true',
|
|
dest='simulate')
|
|
if connection:
|
|
parser.add_option('-c', '--connection',
|
|
help="The database connection URI",
|
|
metavar='URI',
|
|
dest='connection_uri')
|
|
parser.add_option('-f', '--config-file',
|
|
help="The Paste config file that contains the database URI (in the database key)",
|
|
metavar="FILE",
|
|
dest="config_file")
|
|
if find_modules:
|
|
parser.add_option('-m', '--module',
|
|
help="Module in which to find SQLObject classes",
|
|
action='append',
|
|
metavar='MODULE',
|
|
dest='modules',
|
|
default=[])
|
|
parser.add_option('-p', '--package',
|
|
help="Package to search for SQLObject classes",
|
|
action="append",
|
|
metavar="PACKAGE",
|
|
dest="packages",
|
|
default=[])
|
|
parser.add_option('--class',
|
|
help="Select only named classes (wildcards allowed)",
|
|
action="append",
|
|
metavar="NAME",
|
|
dest="class_matchers",
|
|
default=[])
|
|
if interactive:
|
|
parser.add_option('-i', '--interactive',
|
|
help="Ask before doing anything (use twice to be more careful)",
|
|
action="count",
|
|
dest="interactive",
|
|
default=0)
|
|
parser.add_option('--egg',
|
|
help="Select modules from the given Egg, using sqlobject.txt",
|
|
action="append",
|
|
metavar="EGG_SPEC",
|
|
dest="eggs",
|
|
default=[])
|
|
return parser
|
|
|
|
class Command(object):
|
|
|
|
__metaclass__ = DeclarativeMeta
|
|
|
|
min_args = 0
|
|
min_args_error = 'You must provide at least %(min_args)s arguments'
|
|
max_args = 0
|
|
max_args_error = 'You must provide no more than %(max_args)s arguments'
|
|
aliases = ()
|
|
required_args = []
|
|
description = None
|
|
|
|
help = ''
|
|
|
|
def orderClassesByDependencyLevel(self, classes):
|
|
"""
|
|
Return classes ordered by their depth in the class dependency
|
|
tree (this is *not* the inheritance tree), from the
|
|
top level (independant) classes to the deepest level.
|
|
The dependency tree is defined by the foreign key relations.
|
|
"""
|
|
# @@: written as a self-contained function for now, to prevent
|
|
# having to modify any core SQLObject component and namespace
|
|
# contamination.
|
|
# yemartin - 2006-08-08
|
|
|
|
class SQLObjectCircularReferenceError(Exception): pass
|
|
|
|
def findReverseDependencies(cls):
|
|
"""
|
|
Return a list of classes that cls depends on. Note that
|
|
"depends on" here mean "has a foreign key pointing to".
|
|
"""
|
|
depended = []
|
|
for col in cls.sqlmeta.columnList:
|
|
if col.foreignKey:
|
|
other = findClass(col.foreignKey,
|
|
col.soClass.sqlmeta.registry)
|
|
if (other is not cls) and (other not in depended):
|
|
depended.append(other)
|
|
return depended
|
|
|
|
# Cache to save already calculated dependency levels.
|
|
dependency_levels = {}
|
|
def calculateDependencyLevel(cls, dependency_stack=[]):
|
|
"""
|
|
Recursively calculate the dependency level of cls, while
|
|
using the dependency_stack to detect any circular reference.
|
|
"""
|
|
# Return value from the cache if already calculated
|
|
if cls in dependency_levels:
|
|
return dependency_levels[cls]
|
|
# Check for circular references
|
|
if cls in dependency_stack:
|
|
dependency_stack.append(cls)
|
|
raise SQLObjectCircularReferenceError, (
|
|
"Found a circular reference: %s " %
|
|
(' --> '.join([x.__name__
|
|
for x in dependency_stack])))
|
|
dependency_stack.append(cls)
|
|
# Recursively inspect dependent classes.
|
|
depended = findReverseDependencies(cls)
|
|
if depended:
|
|
level = max([calculateDependencyLevel(x, dependency_stack)
|
|
for x in depended]) + 1
|
|
else:
|
|
level = 0
|
|
dependency_levels[cls] = level
|
|
return level
|
|
|
|
# Now simply calculate and sort by dependency levels:
|
|
try:
|
|
sorter = []
|
|
for cls in classes:
|
|
level = calculateDependencyLevel(cls)
|
|
sorter.append((level, cls))
|
|
sorter.sort()
|
|
ordered_classes = [cls for level, cls in sorter]
|
|
except SQLObjectCircularReferenceError, msg:
|
|
# Failsafe: return the classes as-is if a circular reference
|
|
# prevented the dependency levels to be calculated.
|
|
print ("Warning: a circular reference was detected in the "
|
|
"model. Unable to sort the classes by dependency: they "
|
|
"will be treated in alphabetic order. This may or may "
|
|
"not work depending on your database backend. "
|
|
"The error was:\n%s" % msg)
|
|
return classes
|
|
return ordered_classes
|
|
|
|
def __classinit__(cls, new_args):
|
|
if cls.__bases__ == (object,):
|
|
# This abstract base class
|
|
return
|
|
register(cls)
|
|
|
|
def __init__(self, invoked_as, command_name, args, runner):
|
|
self.invoked_as = invoked_as
|
|
self.command_name = command_name
|
|
self.raw_args = args
|
|
self.runner = runner
|
|
|
|
def run(self):
|
|
self.parser.usage = "%%prog [options]\n%s" % self.summary
|
|
if self.help:
|
|
help = textwrap.fill(
|
|
self.help, int(os.environ.get('COLUMNS', 80))-4)
|
|
self.parser.usage += '\n' + help
|
|
self.parser.prog = '%s %s' % (
|
|
os.path.basename(self.invoked_as),
|
|
self.command_name)
|
|
if self.description:
|
|
self.parser.description = description
|
|
self.options, self.args = self.parser.parse_args(self.raw_args)
|
|
if (getattr(self.options, 'simulate', False)
|
|
and not self.options.verbose):
|
|
self.options.verbose = 1
|
|
if self.min_args is not None and len(self.args) < self.min_args:
|
|
self.runner.invalid(
|
|
self.min_args_error % {'min_args': self.min_args,
|
|
'actual_args': len(self.args)})
|
|
if self.max_args is not None and len(self.args) > self.max_args:
|
|
self.runner.invalid(
|
|
self.max_args_error % {'max_args': self.max_args,
|
|
'actual_args': len(self.args)})
|
|
for var_name, option_name in self.required_args:
|
|
if not getattr(self.options, var_name, None):
|
|
self.runner.invalid(
|
|
'You must provide the option %s' % option_name)
|
|
conf = self.config()
|
|
if conf and conf.get('sys_path'):
|
|
update_sys_path(conf['sys_path'], self.options.verbose)
|
|
if conf and conf.get('database'):
|
|
conn = sqlobject.connectionForURI(conf['database'])
|
|
sqlobject.sqlhub.processConnection = conn
|
|
for egg_spec in getattr(self.options, 'eggs', []):
|
|
self.load_options_from_egg(egg_spec)
|
|
self.command()
|
|
|
|
def classes(self, require_connection=True,
|
|
require_some=False):
|
|
all = []
|
|
conf = self.config()
|
|
for module_name in self.options.modules:
|
|
all.extend(self.classes_from_module(
|
|
moduleloader.load_module(module_name)))
|
|
for package_name in self.options.packages:
|
|
all.extend(self.classes_from_package(package_name))
|
|
for egg_spec in self.options.eggs:
|
|
all.extend(self.classes_from_egg(egg_spec))
|
|
if self.options.class_matchers:
|
|
filtered = []
|
|
for soClass in all:
|
|
name = soClass.__name__
|
|
for matcher in self.options.class_matchers:
|
|
if fnmatch.fnmatch(name, matcher):
|
|
filtered.append(soClass)
|
|
break
|
|
all = filtered
|
|
conn = self.connection()
|
|
if conn:
|
|
for soClass in all:
|
|
soClass._connection = conn
|
|
else:
|
|
missing = []
|
|
for soClass in all:
|
|
try:
|
|
if not soClass._connection:
|
|
missing.append(soClass)
|
|
except AttributeError:
|
|
missing.append(soClass)
|
|
if missing and require_connection:
|
|
self.runner.invalid(
|
|
'These classes do not have connections set:\n * %s\n'
|
|
'You must indicate --connection=URI'
|
|
% '\n * '.join([soClass.__name__
|
|
for soClass in missing]))
|
|
if require_some and not all:
|
|
print 'No classes found!'
|
|
if self.options.modules:
|
|
print 'Looked in modules: %s' % ', '.join(self.options.modules)
|
|
else:
|
|
print 'No modules specified'
|
|
if self.options.packages:
|
|
print 'Looked in packages: %s' % ', '.join(self.options.packages)
|
|
else:
|
|
print 'No packages specified'
|
|
if self.options.class_matchers:
|
|
print 'Matching class pattern: %s' % self.options.class_matches
|
|
if self.options.eggs:
|
|
print 'Looked in eggs: %s' % ', '.join(self.options.eggs)
|
|
else:
|
|
print 'No eggs specified'
|
|
sys.exit(1)
|
|
return self.orderClassesByDependencyLevel(all)
|
|
|
|
def classes_from_module(self, module):
|
|
all = []
|
|
if hasattr(module, 'soClasses'):
|
|
for name_or_class in module.soClasses:
|
|
if isinstance(name_or_class, str):
|
|
name_or_class = getattr(module, name_or_class)
|
|
all.append(name_or_class)
|
|
else:
|
|
for name in dir(module):
|
|
value = getattr(module, name)
|
|
if (isinstance(value, type)
|
|
and issubclass(value, sqlobject.SQLObject)
|
|
and value.__module__ == module.__name__):
|
|
all.append(value)
|
|
return all
|
|
|
|
def connection(self):
|
|
config = self.config()
|
|
if config is not None:
|
|
assert config.get('database'), (
|
|
"No database variable found in config file %s"
|
|
% self.options.config_file)
|
|
return sqlobject.connectionForURI(config['database'])
|
|
elif getattr(self.options, 'connection_uri', None):
|
|
return sqlobject.connectionForURI(self.options.connection_uri)
|
|
else:
|
|
return None
|
|
|
|
def config(self):
|
|
if not getattr(self.options, 'config_file', None):
|
|
return None
|
|
config_file = self.options.config_file
|
|
if appconfig:
|
|
if (not config_file.startswith('egg:')
|
|
and not config_file.startswith('config:')):
|
|
config_file = 'config:' + config_file
|
|
return appconfig(config_file,
|
|
relative_to=os.getcwd())
|
|
else:
|
|
return self.ini_config(config_file)
|
|
|
|
def ini_config(self, conf_fn):
|
|
conf_section = 'main'
|
|
if '#' in conf_fn:
|
|
conf_fn, conf_section = conf_fn.split('#', 1)
|
|
|
|
from ConfigParser import ConfigParser
|
|
p = ConfigParser()
|
|
# Case-sensitive:
|
|
p.optionxform = str
|
|
if not os.path.exists(conf_fn):
|
|
# Stupid RawConfigParser doesn't give an error for
|
|
# non-existant files:
|
|
raise OSError(
|
|
"Config file %s does not exist" % self.options.config_file)
|
|
p.read([conf_fn])
|
|
p._defaults.setdefault(
|
|
'here', os.path.dirname(os.path.abspath(conf_fn)))
|
|
|
|
possible_sections = []
|
|
for section in p.sections():
|
|
name = section.strip().lower()
|
|
if (conf_section == name or
|
|
(conf_section == name.split(':')[-1]
|
|
and name.split(':')[0] in ('app', 'application'))):
|
|
possible_sections.append(section)
|
|
|
|
if not possible_sections:
|
|
raise OSError(
|
|
"Config file %s does not have a section [%s] or [*:%s]"
|
|
% (conf_fn, conf_section, conf_section))
|
|
if len(possible_sections) > 1:
|
|
raise OSError(
|
|
"Config file %s has multiple sections matching %s: %s"
|
|
% (conf_fn, conf_section, ', '.join(possible_sections)))
|
|
|
|
config = {}
|
|
for op in p.options(possible_sections[0]):
|
|
config[op] = p.get(possible_sections[0], op)
|
|
return config
|
|
|
|
def classes_from_package(self, package_name):
|
|
all = []
|
|
package = moduleloader.load_module(package_name)
|
|
package_dir = os.path.dirname(package.__file__)
|
|
|
|
def find_classes_in_file(arg, dir_name, filenames):
|
|
if dir_name.startswith('.svn'):
|
|
return
|
|
filenames = filter(lambda fname: fname.endswith('.py') and fname != '__init__.py',
|
|
filenames)
|
|
for fname in filenames:
|
|
module_name = os.path.join(dir_name, fname)
|
|
module_name = module_name[module_name.find(package_name):]
|
|
module_name = module_name.replace(os.path.sep,'.')[:-3]
|
|
try:
|
|
module = moduleloader.load_module(module_name)
|
|
except ImportError, err:
|
|
if self.options.verbose:
|
|
print 'Could not import module "%s". Error was : "%s"' % (module_name, err)
|
|
continue
|
|
except Exception, exc:
|
|
if self.options.verbose:
|
|
print 'Unknown exception while processing module "%s" : "%s"' % (module_name, exc)
|
|
continue
|
|
classes = self.classes_from_module(module)
|
|
all.extend(classes)
|
|
|
|
os.path.walk(package_dir, find_classes_in_file, None)
|
|
return all
|
|
|
|
def classes_from_egg(self, egg_spec):
|
|
modules = []
|
|
dist, conf = self.config_from_egg(egg_spec, warn_no_sqlobject=True)
|
|
for mod in conf.get('db_module', '').split(','):
|
|
mod = mod.strip()
|
|
if not mod:
|
|
continue
|
|
if self.options.verbose:
|
|
print 'Looking in module %s' % mod
|
|
modules.extend(self.classes_from_module(
|
|
moduleloader.load_module(mod)))
|
|
return modules
|
|
|
|
def load_options_from_egg(self, egg_spec):
|
|
dist, conf = self.config_from_egg(egg_spec)
|
|
if (hasattr(self.options, 'output_dir')
|
|
and not self.options.output_dir
|
|
and conf.get('history_dir')):
|
|
dir = conf['history_dir']
|
|
dir = dir.replace('$base', dist.location)
|
|
self.options.output_dir = dir
|
|
|
|
def config_from_egg(self, egg_spec, warn_no_sqlobject=True):
|
|
import pkg_resources
|
|
dist = pkg_resources.get_distribution(egg_spec)
|
|
if not dist.has_metadata('sqlobject.txt'):
|
|
if warn_no_sqlobject:
|
|
print 'No sqlobject.txt in %s egg info' % egg_spec
|
|
return None, {}
|
|
result = {}
|
|
for line in dist.get_metadata_lines('sqlobject.txt'):
|
|
line = line.strip()
|
|
if not line or line.startswith('#'):
|
|
continue
|
|
name, value = line.split('=', 1)
|
|
name = name.strip().lower()
|
|
if name in result:
|
|
print 'Warning: %s appears more than once in sqlobject.txt' % name
|
|
result[name.strip().lower()] = value.strip()
|
|
return dist, result
|
|
|
|
def command(self):
|
|
raise NotImplementedError
|
|
|
|
def _get_prog_name(self):
|
|
return os.path.basename(self.invoked_as)
|
|
prog_name = property(_get_prog_name)
|
|
|
|
def ask(self, prompt, safe=False, default=True):
|
|
if self.options.interactive >= 2:
|
|
default = safe
|
|
if default:
|
|
prompt += ' [Y/n]? '
|
|
else:
|
|
prompt += ' [y/N]? '
|
|
while 1:
|
|
response = raw_input(prompt).strip()
|
|
if not response.strip():
|
|
return default
|
|
if response and response[0].lower() in ('y', 'n'):
|
|
return response[0].lower() == 'y'
|
|
print 'Y or N please'
|
|
|
|
def shorten_filename(self, fn):
|
|
"""
|
|
Shortens a filename to make it relative to the current
|
|
directory (if it can). For display purposes.
|
|
"""
|
|
if fn.startswith(os.getcwd() + '/'):
|
|
fn = fn[len(os.getcwd())+1:]
|
|
return fn
|
|
|
|
def open_editor(self, pretext, breaker=None, extension='.txt'):
|
|
"""
|
|
Open an editor with the given text. Return the new text,
|
|
or None if no edits were made. If given, everything after
|
|
`breaker` will be ignored.
|
|
"""
|
|
fn = nowarning_tempnam() + extension
|
|
f = open(fn, 'w')
|
|
f.write(pretext)
|
|
f.close()
|
|
print '$EDITOR %s' % fn
|
|
os.system('$EDITOR %s' % fn)
|
|
f = open(fn, 'r')
|
|
content = f.read()
|
|
f.close()
|
|
if breaker:
|
|
content = content.split(breaker)[0]
|
|
pretext = pretext.split(breaker)[0]
|
|
if content == pretext or not content.strip():
|
|
return None
|
|
return content
|
|
|
|
class CommandSQL(Command):
|
|
|
|
name = 'sql'
|
|
summary = 'Show SQL CREATE statements'
|
|
|
|
parser = standard_parser(simulate=False)
|
|
|
|
def command(self):
|
|
classes = self.classes()
|
|
allConstraints = []
|
|
for cls in classes:
|
|
if self.options.verbose >= 1:
|
|
print '-- %s from %s' % (
|
|
cls.__name__, cls.__module__)
|
|
createSql, constraints = cls.createTableSQL()
|
|
print createSql.strip() + ';\n'
|
|
allConstraints.append(constraints)
|
|
for constraints in allConstraints:
|
|
if constraints:
|
|
for constraint in constraints:
|
|
if constraint:
|
|
print constraint.strip() + ';\n'
|
|
|
|
|
|
class CommandList(Command):
|
|
|
|
name = 'list'
|
|
summary = 'Show all SQLObject classes found'
|
|
|
|
parser = standard_parser(simulate=False, connection=False)
|
|
|
|
def command(self):
|
|
if self.options.verbose >= 1:
|
|
print 'Classes found:'
|
|
classes = self.classes(require_connection=False)
|
|
for soClass in classes:
|
|
print '%s.%s' % (soClass.__module__, soClass.__name__)
|
|
if self.options.verbose >= 1:
|
|
print ' Table: %s' % soClass.sqlmeta.table
|
|
|
|
class CommandCreate(Command):
|
|
|
|
name = 'create'
|
|
summary = 'Create tables'
|
|
|
|
parser = standard_parser(interactive=True)
|
|
parser.add_option('--create-db',
|
|
action='store_true',
|
|
dest='create_db',
|
|
help="Create the database")
|
|
|
|
def command(self):
|
|
v = self.options.verbose
|
|
created = 0
|
|
existing = 0
|
|
dbs_created = []
|
|
constraints = {}
|
|
for soClass in self.classes(require_some=True):
|
|
if (self.options.create_db
|
|
and soClass._connection not in dbs_created):
|
|
if not self.options.simulate:
|
|
try:
|
|
soClass._connection.createEmptyDatabase()
|
|
except soClass._connection.module.ProgrammingError, e:
|
|
if str(e).find('already exists') != -1:
|
|
print 'Database already exists'
|
|
else:
|
|
raise
|
|
else:
|
|
print '(simulating; cannot create database)'
|
|
dbs_created.append(soClass._connection)
|
|
if soClass._connection not in constraints.keys():
|
|
constraints[soClass._connection] = []
|
|
exists = soClass._connection.tableExists(soClass.sqlmeta.table)
|
|
if v >= 1:
|
|
if exists:
|
|
existing += 1
|
|
print '%s already exists.' % soClass.__name__
|
|
else:
|
|
print 'Creating %s' % soClass.__name__
|
|
if v >= 2:
|
|
sql, extra = soClass.createTableSQL()
|
|
print sql
|
|
if (not self.options.simulate
|
|
and not exists):
|
|
if self.options.interactive:
|
|
if self.ask('Create %s' % soClass.__name__):
|
|
created += 1
|
|
tableConstraints = soClass.createTable(applyConstraints=False)
|
|
if tableConstraints:
|
|
constraints[soClass._connection].append(tableConstraints)
|
|
else:
|
|
print 'Cancelled'
|
|
else:
|
|
created += 1
|
|
tableConstraints = soClass.createTable(applyConstraints=False)
|
|
if tableConstraints:
|
|
constraints[soClass._connection].append(tableConstraints)
|
|
for connection in constraints.keys():
|
|
if v >= 2:
|
|
print 'Creating constraints'
|
|
for constraintList in constraints[connection]:
|
|
for constraint in constraintList:
|
|
if constraint:
|
|
connection.query(constraint)
|
|
if v >= 1:
|
|
print '%i tables created (%i already exist)' % (
|
|
created, existing)
|
|
|
|
|
|
class CommandDrop(Command):
|
|
|
|
name = 'drop'
|
|
summary = 'Drop tables'
|
|
|
|
parser = standard_parser(interactive=True)
|
|
|
|
def command(self):
|
|
v = self.options.verbose
|
|
dropped = 0
|
|
not_existing = 0
|
|
for soClass in reversed(self.classes()):
|
|
exists = soClass._connection.tableExists(soClass.sqlmeta.table)
|
|
if v >= 1:
|
|
if exists:
|
|
print 'Dropping %s' % soClass.__name__
|
|
else:
|
|
not_existing += 1
|
|
print '%s does not exist.' % soClass.__name__
|
|
if (not self.options.simulate
|
|
and exists):
|
|
if self.options.interactive:
|
|
if self.ask('Drop %s' % soClass.__name__):
|
|
dropped += 1
|
|
soClass.dropTable()
|
|
else:
|
|
print 'Cancelled'
|
|
else:
|
|
dropped += 1
|
|
soClass.dropTable()
|
|
if v >= 1:
|
|
print '%i tables dropped (%i didn\'t exist)' % (
|
|
dropped, not_existing)
|
|
|
|
class CommandStatus(Command):
|
|
|
|
name = 'status'
|
|
summary = 'Show status of classes vs. database'
|
|
help = ('This command checks the SQLObject definition and checks if '
|
|
'the tables in the database match. It can always test for '
|
|
'missing tables, and on some databases can test for the '
|
|
'existance of other tables. Column types are not currently '
|
|
'checked.')
|
|
|
|
parser = standard_parser(simulate=False)
|
|
|
|
def print_class(self, soClass):
|
|
if self.printed:
|
|
return
|
|
self.printed = True
|
|
print 'Checking %s...' % soClass.__name__
|
|
|
|
def command(self):
|
|
good = 0
|
|
bad = 0
|
|
missing_tables = 0
|
|
columnsFromSchema_warning = False
|
|
for soClass in self.classes(require_some=True):
|
|
conn = soClass._connection
|
|
self.printed = False
|
|
if self.options.verbose:
|
|
self.print_class(soClass)
|
|
if not conn.tableExists(soClass.sqlmeta.table):
|
|
self.print_class(soClass)
|
|
print ' Does not exist in database'
|
|
missing_tables += 1
|
|
continue
|
|
try:
|
|
columns = conn.columnsFromSchema(soClass.sqlmeta.table,
|
|
soClass)
|
|
except AttributeError:
|
|
if not columnsFromSchema_warning:
|
|
print 'Database does not support reading columns'
|
|
columnsFromSchema_warning = True
|
|
good += 1
|
|
continue
|
|
except AssertionError, e:
|
|
print 'Cannot read db table %s: %s' % (
|
|
soClass.sqlmeta.table, e)
|
|
continue
|
|
existing = {}
|
|
for col in columns:
|
|
col = col.withClass(soClass)
|
|
existing[col.dbName] = col
|
|
missing = {}
|
|
for col in soClass.sqlmeta.columnList:
|
|
if col.dbName in existing:
|
|
del existing[col.dbName]
|
|
else:
|
|
missing[col.dbName] = col
|
|
if existing:
|
|
self.print_class(soClass)
|
|
for col in existing.values():
|
|
print ' Database has extra column: %s' % col.dbName
|
|
if missing:
|
|
self.print_class(soClass)
|
|
for col in missing.values():
|
|
print ' Database missing column: %s' % col.dbName
|
|
if existing or missing:
|
|
bad += 1
|
|
else:
|
|
good += 1
|
|
if self.options.verbose:
|
|
print '%i in sync; %i out of sync; %i not in database' % (
|
|
good, bad, missing_tables)
|
|
|
|
class CommandHelp(Command):
|
|
|
|
name = 'help'
|
|
summary = 'Show help'
|
|
|
|
parser = optparse.OptionParser()
|
|
|
|
max_args = 1
|
|
|
|
def command(self):
|
|
if self.args:
|
|
the_runner.run([self.invoked_as, self.args[0], '-h'])
|
|
else:
|
|
print 'Available commands:'
|
|
print ' (use "%s help COMMAND" or "%s COMMAND -h" ' % (
|
|
self.prog_name, self.prog_name)
|
|
print ' for more information)'
|
|
items = the_runner.commands.items()
|
|
items.sort()
|
|
max_len = max([len(cn) for cn, c in items])
|
|
for command_name, command in items:
|
|
print '%s:%s %s' % (command_name,
|
|
' '*(max_len-len(command_name)),
|
|
command.summary)
|
|
if command.aliases:
|
|
print '%s (Aliases: %s)' % (
|
|
' '*max_len, ', '.join(command.aliases))
|
|
|
|
class CommandExecute(Command):
|
|
|
|
name = 'execute'
|
|
summary = 'Execute SQL statements'
|
|
help = ('Runs SQL statements directly in the database, with no '
|
|
'intervention. Useful when used with a configuration file. '
|
|
'Each argument is executed as an individual statement.')
|
|
|
|
parser = standard_parser(find_modules=False)
|
|
parser.add_option('--stdin',
|
|
help="Read SQL from stdin (normally takes SQL from the command line)",
|
|
dest="use_stdin",
|
|
action="store_true")
|
|
|
|
max_args = None
|
|
|
|
def command(self):
|
|
args = self.args
|
|
if self.options.use_stdin:
|
|
if self.options.verbose:
|
|
print "Reading additional SQL from stdin (Ctrl-D or Ctrl-Z to finish)..."
|
|
args.append(sys.stdin.read())
|
|
self.conn = self.connection().getConnection()
|
|
self.cursor = self.conn.cursor()
|
|
for sql in args:
|
|
self.execute_sql(sql)
|
|
|
|
def execute_sql(self, sql):
|
|
if self.options.verbose:
|
|
print sql
|
|
try:
|
|
self.cursor.execute(sql)
|
|
except Exception, e:
|
|
if not self.options.verbose:
|
|
print sql
|
|
print "****Error:"
|
|
print ' ', e
|
|
return
|
|
desc = self.cursor.description
|
|
rows = self.cursor.fetchall()
|
|
if self.options.verbose:
|
|
if not self.cursor.rowcount:
|
|
print "No rows accessed"
|
|
else:
|
|
print "%i rows accessed" % self.cursor.rowcount
|
|
if desc:
|
|
for name, type_code, display_size, internal_size, precision, scale, null_ok in desc:
|
|
sys.stdout.write("%s\t" % name)
|
|
sys.stdout.write("\n")
|
|
for row in rows:
|
|
for col in row:
|
|
sys.stdout.write("%r\t" % col)
|
|
sys.stdout.write("\n")
|
|
print
|
|
|
|
class CommandRecord(Command):
|
|
|
|
name = 'record'
|
|
summary = 'Record historical information about the database status'
|
|
help = ('Record state of table definitions. The state of each '
|
|
'table is written out to a separate file in a directory, '
|
|
'and that directory forms a "version". A table is also '
|
|
'added to your database (%s) that reflects the version the '
|
|
'database is currently at. Use the upgrade command to '
|
|
'sync databases with code.'
|
|
% SQLObjectVersionTable.sqlmeta.table)
|
|
|
|
parser = standard_parser()
|
|
parser.add_option('--output-dir',
|
|
help="Base directory for recorded definitions",
|
|
dest="output_dir",
|
|
metavar="DIR",
|
|
default=None)
|
|
parser.add_option('--no-db-record',
|
|
help="Don't record version to database",
|
|
dest="db_record",
|
|
action="store_false",
|
|
default=True)
|
|
parser.add_option('--force-create',
|
|
help="Create a new version even if appears to be "
|
|
"identical to the last version",
|
|
action="store_true",
|
|
dest="force_create")
|
|
parser.add_option('--name',
|
|
help="The name to append to the version. The "
|
|
"version should sort after previous versions (so "
|
|
"any versions from the same day should come "
|
|
"alphabetically before this version).",
|
|
dest="version_name",
|
|
metavar="NAME")
|
|
parser.add_option('--force-db-version',
|
|
help="Update the database version, and include no "
|
|
"database information. This is for databases that "
|
|
"were developed without any interaction with "
|
|
"this tool, to create a 'beginning' revision.",
|
|
metavar="VERSION_NAME",
|
|
dest="force_db_version")
|
|
parser.add_option('--edit',
|
|
help="Open an editor for the upgrader in the last "
|
|
"version (using $EDITOR).",
|
|
action="store_true",
|
|
dest="open_editor")
|
|
|
|
version_regex = re.compile(r'^\d\d\d\d-\d\d-\d\d')
|
|
|
|
def command(self):
|
|
if self.options.force_db_version:
|
|
self.command_force_db_version()
|
|
return
|
|
|
|
v = self.options.verbose
|
|
sim = self.options.simulate
|
|
classes = self.classes()
|
|
if not classes:
|
|
print "No classes found!"
|
|
return
|
|
|
|
output_dir = self.find_output_dir()
|
|
version = os.path.basename(output_dir)
|
|
print "Creating version %s" % version
|
|
conns = []
|
|
files = {}
|
|
for cls in self.classes():
|
|
dbName = cls._connection.dbName
|
|
if cls._connection not in conns:
|
|
conns.append(cls._connection)
|
|
fn = os.path.join(cls.__name__
|
|
+ '_' + dbName + '.sql')
|
|
if sim:
|
|
continue
|
|
create, constraints = cls.createTableSQL()
|
|
if constraints:
|
|
constraints = '\n-- Constraints:\n%s\n' % (
|
|
'\n'.join(constraints))
|
|
else:
|
|
constraints = ''
|
|
files[fn] = ''.join([
|
|
'-- Exported definition from %s\n'
|
|
% time.strftime('%Y-%m-%dT%H:%M:%S'),
|
|
'-- Class %s.%s\n'
|
|
% (cls.__module__, cls.__name__),
|
|
'-- Database: %s\n'
|
|
% dbName,
|
|
create.strip(),
|
|
'\n',
|
|
constraints])
|
|
last_version_dir = self.find_last_version()
|
|
if last_version_dir and not self.options.force_create:
|
|
if v > 1:
|
|
print "Checking %s to see if it is current" % last_version_dir
|
|
files_copy = files.copy()
|
|
for fn in os.listdir(last_version_dir):
|
|
if not fn.endswith('.sql'):
|
|
continue
|
|
if not fn in files_copy:
|
|
if v > 1:
|
|
print "Missing file %s" % fn
|
|
break
|
|
f = open(os.path.join(last_version_dir, fn), 'r')
|
|
content = f.read()
|
|
f.close()
|
|
if (self.strip_comments(files_copy[fn])
|
|
!= self.strip_comments(content)):
|
|
if v > 1:
|
|
print "Content does not match: %s" % fn
|
|
break
|
|
del files_copy[fn]
|
|
else:
|
|
# No differences so far
|
|
if not files_copy:
|
|
# Used up all files
|
|
print ("Current status matches version %s"
|
|
% os.path.basename(last_version_dir))
|
|
return
|
|
if v > 1:
|
|
print "Extra files: %s" % ', '.join(files_copy.keys())
|
|
if v:
|
|
print ("Current state does not match %s"
|
|
% os.path.basename(last_version_dir))
|
|
if v > 1 and not last_version_dir:
|
|
print "No last version to check"
|
|
if not sim:
|
|
os.mkdir(output_dir)
|
|
if v:
|
|
print 'Making directory %s' % self.shorten_filename(output_dir)
|
|
files = files.items()
|
|
files.sort()
|
|
for fn, content in files:
|
|
if v:
|
|
print ' Writing %s' % self.shorten_filename(fn)
|
|
if not sim:
|
|
f = open(os.path.join(output_dir, fn), 'w')
|
|
f.write(content)
|
|
f.close()
|
|
if self.options.db_record:
|
|
all_diffs = []
|
|
for cls in self.classes():
|
|
for conn in conns:
|
|
diffs = db_differences(cls, conn)
|
|
for diff in diffs:
|
|
if len(conns) > 1:
|
|
diff = ' (%s).%s: %s' % (
|
|
conn.uri(), cls.sqlmeta.table, diff)
|
|
else:
|
|
diff = ' %s: %s' % (cls.sqlmeta.table, diff)
|
|
all_diffs.append(diff)
|
|
if all_diffs:
|
|
print 'Database does not match schema:'
|
|
print '\n'.join(all_diffs)
|
|
for conn in conns:
|
|
self.update_db(version, conn)
|
|
else:
|
|
all_diffs = []
|
|
if self.options.open_editor:
|
|
if not last_version_dir:
|
|
print ("Cannot edit upgrader because there is no "
|
|
"previous version")
|
|
else:
|
|
breaker = ('-'*20 + ' lines below this will be ignored '
|
|
+ '-'*20)
|
|
pre_text = breaker + '\n' + '\n'.join(all_diffs)
|
|
text = self.open_editor('\n\n' + pre_text, breaker=breaker,
|
|
extension='.sql')
|
|
if text is not None:
|
|
fn = os.path.join(last_version_dir,
|
|
'upgrade_%s_%s.sql' %
|
|
(dbName, version))
|
|
f = open(fn, 'w')
|
|
f.write(text)
|
|
f.close()
|
|
print 'Wrote to %s' % fn
|
|
|
|
def update_db(self, version, conn):
|
|
v = self.options.verbose
|
|
if not conn.tableExists(SQLObjectVersionTable.sqlmeta.table):
|
|
if v:
|
|
print ('Creating table %s'
|
|
% SQLObjectVersionTable.sqlmeta.table)
|
|
sql = SQLObjectVersionTable.createTableSQL(connection=conn)
|
|
if v > 1:
|
|
print sql
|
|
if not self.options.simulate:
|
|
SQLObjectVersionTable.createTable(connection=conn)
|
|
if not self.options.simulate:
|
|
SQLObjectVersionTable.clearTable(connection=conn)
|
|
SQLObjectVersionTable(
|
|
version=version,
|
|
connection=conn)
|
|
|
|
def strip_comments(self, sql):
|
|
lines = [l for l in sql.splitlines()
|
|
if not l.strip().startswith('--')]
|
|
return '\n'.join(lines)
|
|
|
|
def base_dir(self):
|
|
base = self.options.output_dir
|
|
if base is None:
|
|
base = CONFIG.get('sqlobject_history_dir', '.')
|
|
if not os.path.exists(base):
|
|
print 'Creating history directory %s' % self.shorten_filename(base)
|
|
if not self.options.simulate:
|
|
os.makedirs(base)
|
|
return base
|
|
|
|
def find_output_dir(self):
|
|
today = time.strftime('%Y-%m-%d', time.localtime())
|
|
if self.options.version_name:
|
|
dir = os.path.join(self.base_dir(), today + '-' +
|
|
self.options.version_name)
|
|
if os.path.exists(dir):
|
|
print ("Error, directory already exists: %s"
|
|
% dir)
|
|
sys.exit(1)
|
|
return dir
|
|
extra = ''
|
|
while 1:
|
|
dir = os.path.join(self.base_dir(), today + extra)
|
|
if not os.path.exists(dir):
|
|
return dir
|
|
if not extra:
|
|
extra = 'a'
|
|
else:
|
|
extra = chr(ord(extra)+1)
|
|
|
|
def find_last_version(self):
|
|
names = []
|
|
for fn in os.listdir(self.base_dir()):
|
|
if not self.version_regex.search(fn):
|
|
continue
|
|
names.append(fn)
|
|
if not names:
|
|
return None
|
|
names.sort()
|
|
return os.path.join(self.base_dir(), names[-1])
|
|
|
|
def command_force_db_version(self):
|
|
v = self.options.verbose
|
|
sim = self.options.simulate
|
|
version = self.options.force_db_version
|
|
if not self.version_regex.search(version):
|
|
print "Versions must be in the format YYYY-MM-DD..."
|
|
print "You version %s does not fit this" % version
|
|
return
|
|
version_dir = os.path.join(self.base_dir(), version)
|
|
if not os.path.exists(version_dir):
|
|
if v:
|
|
print 'Creating %s' % self.shorten_filename(version_dir)
|
|
if not sim:
|
|
os.mkdir(version_dir)
|
|
elif v:
|
|
print ('Directory %s exists'
|
|
% self.shorten_filename(version_dir))
|
|
if self.options.db_record:
|
|
self.update_db(version, self.connection())
|
|
|
|
class CommandUpgrade(CommandRecord):
|
|
|
|
name = 'upgrade'
|
|
summary = 'Update the database to a new version (as created by record)'
|
|
help = ('This command runs scripts (that you write by hand) to '
|
|
'upgrade a database. The database\'s current version is in '
|
|
'the sqlobject_version table (use record --force-db-version '
|
|
'if a database does not have a sqlobject_version table), '
|
|
'and upgrade scripts are in the version directory you are '
|
|
'upgrading FROM, named upgrade_DBNAME_VERSION.sql, like '
|
|
'"upgrade_mysql_2004-12-01b.sql".')
|
|
|
|
parser = standard_parser(find_modules=False)
|
|
parser.add_option('--upgrade-to',
|
|
help="Upgrade to the given version (default: newest version)",
|
|
dest="upgrade_to",
|
|
metavar="VERSION")
|
|
parser.add_option('--output-dir',
|
|
help="Base directory for recorded definitions",
|
|
dest="output_dir",
|
|
metavar="DIR",
|
|
default=None)
|
|
|
|
upgrade_regex = re.compile(r'^upgrade_([a-z]*)_([^.]*)\.sql$', re.I)
|
|
|
|
def command(self):
|
|
v = self.options.verbose
|
|
sim = self.options.simulate
|
|
if self.options.upgrade_to:
|
|
version_to = self.options.upgrade_to
|
|
else:
|
|
fname = self.find_last_version()
|
|
if fname is None:
|
|
print "No version exists, use 'record' command to create one"
|
|
return
|
|
version_to = os.path.basename(fname)
|
|
current = self.current_version()
|
|
if v:
|
|
print 'Current version: %s' % current
|
|
version_list = self.make_plan(current, version_to)
|
|
if not version_list:
|
|
print 'Database up to date'
|
|
return
|
|
if v:
|
|
print 'Plan:'
|
|
for next_version, upgrader in version_list:
|
|
print ' Use %s to upgrade to %s' % (
|
|
self.shorten_filename(upgrader), next_version)
|
|
conn = self.connection()
|
|
for next_version, upgrader in version_list:
|
|
f = open(upgrader)
|
|
sql = f.read()
|
|
f.close()
|
|
if v:
|
|
print "Running:"
|
|
print sql
|
|
print '-'*60
|
|
if not sim:
|
|
try:
|
|
conn.query(sql)
|
|
except:
|
|
print "Error in script: %s" % upgrader
|
|
raise
|
|
self.update_db(next_version, conn)
|
|
print 'Done.'
|
|
|
|
|
|
def current_version(self):
|
|
conn = self.connection()
|
|
if not conn.tableExists(SQLObjectVersionTable.sqlmeta.table):
|
|
print 'No sqlobject_version table!'
|
|
sys.exit(1)
|
|
versions = list(SQLObjectVersionTable.select(connection=conn))
|
|
if not versions:
|
|
print 'No rows in sqlobject_version!'
|
|
sys.exit(1)
|
|
if len(versions) > 1:
|
|
print 'Ambiguous sqlobject_version_table'
|
|
sys.exit(1)
|
|
return versions[0].version
|
|
|
|
def make_plan(self, current, dest):
|
|
if current == dest:
|
|
return []
|
|
dbname = self.connection().dbName
|
|
next_version, upgrader = self.best_upgrade(current, dest, dbname)
|
|
if not upgrader:
|
|
print 'No way to upgrade from %s to %s' % (current, dest)
|
|
print ('(you need a %s/upgrade_%s_%s.sql script)'
|
|
% (current, dbname, dest))
|
|
sys.exit(1)
|
|
plan = [(next_version, upgrader)]
|
|
if next_version == dest:
|
|
return plan
|
|
else:
|
|
return plan + self.make_plan(next_version, dest)
|
|
|
|
def best_upgrade(self, current, dest, target_dbname):
|
|
current_dir = os.path.join(self.base_dir(), current)
|
|
if self.options.verbose > 1:
|
|
print ('Looking in %s for upgraders'
|
|
% self.shorten_filename(current_dir))
|
|
upgraders = []
|
|
for fn in os.listdir(current_dir):
|
|
match = self.upgrade_regex.search(fn)
|
|
if not match:
|
|
if self.options.verbose > 1:
|
|
print 'Not an upgrade script: %s' % fn
|
|
continue
|
|
dbname = match.group(1)
|
|
version = match.group(2)
|
|
if dbname != target_dbname:
|
|
if self.options.verbose > 1:
|
|
print 'Not for this database: %s (want %s)' % (
|
|
dbname, target_dbname)
|
|
continue
|
|
if version > dest:
|
|
if self.options.verbose > 1:
|
|
print 'Version too new: %s (only want %s)' % (
|
|
version, dest)
|
|
upgraders.append((version, os.path.join(current_dir, fn)))
|
|
if not upgraders:
|
|
if self.options.verbose > 1:
|
|
print 'No upgraders found in %s' % current_dir
|
|
return None, None
|
|
upgraders.sort()
|
|
return upgraders[-1]
|
|
|
|
def update_sys_path(paths, verbose):
|
|
if isinstance(paths, basestring):
|
|
paths = [paths]
|
|
for path in paths:
|
|
path = os.path.abspath(path)
|
|
if path not in sys.path:
|
|
if verbose > 1:
|
|
print 'Adding %s to path' % path
|
|
sys.path.insert(0, path)
|
|
|
|
if __name__ == '__main__':
|
|
the_runner.run(sys.argv)
|