root/django/trunk/contrib/evolve.py

Revision 92, 4.2 kB (checked in by steadicat, 15 months ago)

Evolve parser now ignores indexes and unique constraints.

  • Property svn:keywords set to Id
Line 
1import os, re, sys
2sys.path.append('contrib')
3sys.path.append(os.getcwd())
4from django.core.management import setup_environ
5import settings
6setup_environ(settings)
7
8try:
9    app = sys.argv[1]
10except IndexError:
11    sys.stderr.write("Usage: evolve.py <appname>")
12    sys.exit(1)
13
14assert settings.DATABASE_ENGINE == 'sqlite3'
15database = settings.DATABASE_NAME
16
17class Table(object):
18    name_line = re.compile(r'CREATE TABLE "(.+)" \(')
19    field_line = re.compile(r'^"([^"]+)".*')
20
21    def __init__(self, sql):
22        lines = sql.splitlines()
23        self.name = self.getTableName(lines[0])
24        self.old_fields = {}
25        for line in lines[1:-1]: # remove create table and closing bracket
26            line = line.strip().strip(',')
27            self.old_fields[self.getFieldName(line)] = line
28        self.new_fields = []
29
30    def new(self, sql):
31        self.new_fields = [line.strip().strip(',') for line in sql.splitlines()[1:-1]] # remove create table and closing bracket
32
33
34    def getTableName(self, sql):
35        return self.name_line.search(sql).groups()[0]
36    def getFieldName(self, sql):
37        s = self.field_line.search(sql)
38        if not s:
39            return None
40        else:
41            return s.groups()[0]
42    def getFieldOption(self, field, option):
43        a, m = self.name.split('_', 1)
44        module = __import__('%s.models' % app, {}, {}, [m.capitalize()])
45        model = getattr(module, m.capitalize())
46        f = model._meta.get_field(field)
47        return getattr(f, option, None)
48    def getFieldAka(self, field):
49        return self.getFieldOption(field, 'aka')
50    def getFieldDefault(self, field):
51        return self.getFieldOption(field, 'default')
52
53    def getSql(self):
54        name = self.name
55
56        new_fields = []
57        old_fields = []
58        new_columns = []
59
60        for field_sql in self.new_fields:
61            field_name = self.getFieldName(field_sql)
62            if not field_name: continue
63            new_fields.append(field_sql)
64            if field_name not in self.old_fields:
65                aka = self.getFieldAka(field_name)
66                if aka:
67                    # field renamed
68                    old_fields.append('"%s"' % aka)
69                else:
70                    # field added
71                    old_fields.append('"%s"' % field_name)
72                    default = self.getFieldDefault(field_name)
73                    if default is not None:
74                        new_columns.append('ALTER TABLE "%s__temp__" ADD COLUMN %s DEFAULT "%s";' % (self.name, field_sql, default))
75
76                    else:
77                        new_columns.append('ALTER TABLE "%s__temp__" ADD COLUMN %s;' % (self.name, field_sql))
78            else:
79                # field not changed (or slightly tweaked)
80                old_fields.append('"%s"' % field_name)
81
82        return """
83        ALTER TABLE "%(name)s" RENAME TO "%(name)s__temp__";
84
85        %(new_columns)s
86
87        CREATE TABLE "%(name)s" (
88           %(new_fields)s
89        );
90
91        INSERT INTO "%(name)s" SELECT %(old_fields)s FROM "%(name)s__temp__";
92        DROP TABLE "%(name)s__temp__";
93        """ % dict(name=name, new_columns='\n'.join(new_columns), new_fields=',\n'.join(new_fields), old_fields=', '.join(old_fields))
94
95if __name__=='__main__':
96    new_tables_sql = os.popen('manage.py sql %s' % app).read()
97    old_tables_sql = os.popen('sqlite3 %s ".schema %s_%%"' % (database, app)).read()
98
99    # remove control characters
100    ctrl = re.compile(r'[\000-\011\013-\037]')
101    new_tables_sql = ctrl.sub('', new_tables_sql)
102    old_tables_sql = ctrl.sub('', old_tables_sql)
103
104    new_tables_sql = new_tables_sql.split(';')[1:-2] # remove begin, commit and empty ; at the end
105    old_tables_sql = old_tables_sql.split(';')[:-1] # remove empty ; at the end
106
107    tables = {}
108
109    for sql in old_tables_sql:
110        if 'CREATE INDEX' in sql: continue
111        t = Table(sql.strip())
112        tables[t.name] = t
113
114    for sql in new_tables_sql:
115        if 'CREATE INDEX' in sql: continue
116        t = Table(sql.strip())
117        tables[t.name].new(sql.strip())
118
119    for table in tables.values():
120        print table.getSql()
121
122
123
Note: See TracBrowser for help on using the browser.