import os, re, sys

# Make the current directory (where ``settings`` lives) and ./contrib importable,
# then let Django's setup_environ configure itself from the project settings.
sys.path.append('contrib')
sys.path.append(os.getcwd())
from django.core.management import setup_environ
import settings
setup_environ(settings)

# Name of the Django app whose tables are to be evolved (first CLI argument).
try:
    app = sys.argv[1]
except IndexError:
    sys.stderr.write("Usage: evolve.py <appname>\n")
    sys.exit(1)

# The generated SQL (ALTER TABLE ... RENAME TO) and the use of the sqlite3 shell
# to read the current schema only work for SQLite. This is an explicit check
# rather than an ``assert`` so it still runs under ``python -O`` and says why.
if settings.DATABASE_ENGINE != 'sqlite3':
    sys.stderr.write("evolve.py only supports DATABASE_ENGINE = 'sqlite3' (got %r)\n"
                     % (settings.DATABASE_ENGINE,))
    sys.exit(1)

# Path of the SQLite database file whose schema is inspected.
database = settings.DATABASE_NAME
|---|
| 16 | |
|---|
class Table(object):
    """One table of the app: its current schema in the database, its new schema
    from the models, and the SQL that migrates the data from one to the other.

    The old schema is parsed by the constructor from the CREATE TABLE statement
    stored in the database; the new one is handed to :meth:`new` as the CREATE
    TABLE statement produced by ``manage.py sql``. Both are expected in that
    multi-line layout: header line first, one column or constraint per line,
    closing bracket on the last line.
    """

    # Header line of a CREATE TABLE statement; captures the quoted table name.
    name_line = re.compile(r'CREATE TABLE "(.+)" \(')
    # A column definition starts with its double-quoted name; table constraints
    # such as ``UNIQUE ("a", "b")`` do not match.
    field_line = re.compile(r'^"([^"]+)".*')

    def __init__(self, sql):
        """Parse the table's current definition from the CREATE TABLE statement ``sql``."""
        lines = sql.splitlines()
        self.name = self.getTableName(lines[0])
        # column name -> column definition line, for the columns now in the table
        self.old_fields = {}
        for line in self._bodyLines(lines):
            field_name = self.getFieldName(line)
            if field_name is not None:  # table constraints are not columns
                self.old_fields[field_name] = line
        # definition lines (columns and constraints) of the new schema; stays
        # empty until new() is called
        self.new_fields = []

    def new(self, sql):
        """Record the new definition of the table from the CREATE TABLE statement ``sql``."""
        self.new_fields = self._bodyLines(sql.splitlines())

    @staticmethod
    def _bodyLines(lines):
        """Return the definition lines between the header and the closing bracket,
        stripped of whitespace and separating commas, with blank lines dropped."""
        body = [line.strip().strip(',') for line in lines[1:-1]]
        return [line for line in body if line]

    def getTableName(self, sql):
        """Return the table name from a ``CREATE TABLE "name" (`` line.

        Raises ValueError if ``sql`` is not such a line.
        """
        match = self.name_line.search(sql)
        if match is None:
            raise ValueError('not a CREATE TABLE line: %r' % (sql,))
        return match.group(1)

    def getFieldName(self, sql):
        """Return the quoted column name at the start of definition line ``sql``,
        or None when the line does not start with one (e.g. a table constraint)."""
        match = self.field_line.search(sql)
        if match is None:
            return None
        return match.group(1)

    def _getModel(self):
        """Return the model in ``<app>.models`` whose ``_meta.db_table`` is this
        table, or None when no model there owns it.

        Matching on ``db_table`` instead of rebuilding a class name from the table
        name also copes with app labels containing underscores, CamelCase model
        names and custom ``Meta.db_table`` values. Tables without a model class in
        the module (presumably e.g. auto-created many-to-many join tables) give None.
        """
        module_name = '%s.models' % app
        __import__(module_name)
        module = sys.modules[module_name]
        for attr_name in dir(module):
            candidate = getattr(module, attr_name)
            meta = getattr(candidate, '_meta', None)
            if meta is not None and getattr(meta, 'db_table', None) == self.name:
                return candidate
        return None

    def _getField(self, field):
        """Return the model field stored in column ``field`` (matched on the
        field's column, e.g. ``author_id``, or its name), or None."""
        model = self._getModel()
        if model is None:
            return None
        for f in model._meta.fields:
            if getattr(f, 'column', None) == field or getattr(f, 'name', None) == field:
                return f
        return None

    def getFieldOption(self, field, option):
        """Return attribute ``option`` of the model field for column ``field``, or
        None if the model, the field or the attribute cannot be found."""
        f = self._getField(field)
        if f is None:
            return None
        return getattr(f, option, None)

    def getFieldAka(self, field):
        """Return the field's ``aka`` option (previous column name or names), or None."""
        return self.getFieldOption(field, 'aka')

    def getFieldDefault(self, field):
        """Return the default value of the field for column ``field``, or None if it
        declares none.

        Django marks "no default" with a sentinel rather than None, so the field's
        own ``has_default()``/``get_default()`` are used when it provides them
        (``get_default`` also calls callable defaults).
        """
        f = self._getField(field)
        if f is None:
            return None
        if hasattr(f, 'has_default'):
            if not f.has_default():
                return None
            return f.get_default()
        default = getattr(f, 'default', None)
        if callable(default):
            default = default()
        return default

    @staticmethod
    def _sqlLiteral(value):
        """Render ``value`` as an SQLite literal for use in a SELECT list."""
        if value is None:
            return 'NULL'
        if isinstance(value, bool):  # before int: bool is an int subclass
            if value:
                return '1'
            return '0'
        if isinstance(value, (int, float)):
            return repr(value)
        # anything else becomes a single-quoted string with quotes doubled
        return "'%s'" % ('%s' % (value,)).replace("'", "''")

    def _sourceExpr(self, field_name):
        """Return the SELECT expression that fills new column ``field_name``."""
        if field_name in self.old_fields:
            # column kept (its definition may have been tweaked)
            return '"%s"' % field_name
        aka = self.getFieldAka(field_name)
        if aka:
            if not isinstance(aka, (list, tuple)):
                aka = [aka]
            for old_name in aka:
                if old_name in self.old_fields:
                    # column renamed
                    return '"%s"' % old_name
        # column added (or none of its aka names exist in the old table)
        return self._sqlLiteral(self.getFieldDefault(field_name))

    def getSql(self):
        """Return an SQL script that rebuilds this table with its new definition.

        The table is renamed to ``<name>__temp__``, recreated from the new
        definition (columns and table constraints), refilled from the renamed copy,
        and the copy is dropped. Each new column is filled from:

        * the same column of the old table, if it exists;
        * the old column named by the field's ``aka`` option, if the field was renamed;
        * otherwise the field's default as a literal, or NULL.

        Filling added columns with a literal, instead of ADD COLUMN on the
        temporary table, avoids SQLite's ADD COLUMN restrictions (no UNIQUE or
        PRIMARY KEY, no NOT NULL without a default). Columns missing from the new
        definition are discarded. NOTE: indexes on the old table are dropped along
        with the temporary copy and are not recreated here.
        """
        name = self.name
        create_lines = []   # body of the new CREATE TABLE
        select_exprs = []   # one SELECT expression per new column, in order

        for field_sql in self.new_fields:
            create_lines.append(field_sql)
            field_name = self.getFieldName(field_sql)
            if field_name is None:
                # table constraint: part of the new table, but nothing to copy
                continue
            select_exprs.append(self._sourceExpr(field_name))

        return '\n'.join([
            '',
            'ALTER TABLE "%s" RENAME TO "%s__temp__";' % (name, name),
            '',
            'CREATE TABLE "%s" (' % name,
            ',\n'.join(['    ' + line for line in create_lines]),
            ');',
            '',
            'INSERT INTO "%s" SELECT %s FROM "%s__temp__";'
            % (name, ', '.join(select_exprs), name),
            'DROP TABLE "%s__temp__";' % name,
            '',
        ])
|---|
| 94 | |
|---|
if __name__ == '__main__':
    import subprocess

    def _run(args):
        """Run the command ``args`` (no shell involved) and return its stdout as text."""
        return subprocess.Popen(args, stdout=subprocess.PIPE,
                                universal_newlines=True).communicate()[0]

    # New schema, as generated by Django from the current models. manage.py is run
    # with the current interpreter, so it need not be executable or on PATH.
    new_tables_sql = _run([sys.executable, 'manage.py', 'sql', app])
    # Current schema stored in the SQLite database. "%" (and presumably "_") are
    # wildcards in this pattern, so other apps' tables with a similar prefix and
    # the app's indexes can show up too; both are filtered out below.
    old_tables_sql = _run(['sqlite3', database, '.schema %s_%%' % app])

    # Remove control characters (everything below space except "\n", e.g. "\r").
    ctrl = re.compile(r'[\000-\011\013-\037]')
    new_tables_sql = ctrl.sub('', new_tables_sql)
    old_tables_sql = ctrl.sub('', old_tables_sql)

    # Split into statements: drop BEGIN, COMMIT and the empty chunk after the last ";".
    new_tables_sql = new_tables_sql.split(';')[1:-2]
    # Drop the empty chunk after the last ";".
    old_tables_sql = old_tables_sql.split(';')[:-1]

    # table name -> Table, for every table currently in the database
    tables = {}
    for sql in old_tables_sql:
        sql = sql.strip()
        if not sql.startswith('CREATE TABLE'):
            continue  # indexes (unique or not), triggers, ...
        t = Table(sql)
        tables[t.name] = t

    # CREATE TABLE statements for models that have no table yet
    created = []
    for sql in new_tables_sql:
        sql = sql.strip()
        if not sql.startswith('CREATE TABLE'):
            continue
        t = Table(sql)
        if t.name in tables:
            tables[t.name].new(sql)
        else:
            created.append(sql + ';')

    for table in tables.values():
        if not table.new_fields:
            # No new definition (model removed, a table of another app matched by
            # the pattern, or manage.py produced nothing): leave the table alone
            # rather than rename it away and drop it.
            continue
        print(table.getSql())

    for sql in created:
        print(sql)
|---|
| 121 | |
|---|
| 122 | |
|---|
| 123 | |
|---|