Initial commit. Basic models mostly done.
This commit is contained in:
@@ -0,0 +1,2 @@
|
||||
from .migration import Migration, swappable_dependency # NOQA
|
||||
from .operations import * # NOQA
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,54 @@
|
||||
from django.db.utils import DatabaseError
|
||||
|
||||
|
||||
class AmbiguityError(Exception):
|
||||
"""More than one migration matches a name prefix."""
|
||||
pass
|
||||
|
||||
|
||||
class BadMigrationError(Exception):
|
||||
"""There's a bad migration (unreadable/bad format/etc.)."""
|
||||
pass
|
||||
|
||||
|
||||
class CircularDependencyError(Exception):
|
||||
"""There's an impossible-to-resolve circular dependency."""
|
||||
pass
|
||||
|
||||
|
||||
class InconsistentMigrationHistory(Exception):
|
||||
"""An applied migration has some of its dependencies not applied."""
|
||||
pass
|
||||
|
||||
|
||||
class InvalidBasesError(ValueError):
|
||||
"""A model's base classes can't be resolved."""
|
||||
pass
|
||||
|
||||
|
||||
class IrreversibleError(RuntimeError):
|
||||
"""An irreversible migration is about to be reversed."""
|
||||
pass
|
||||
|
||||
|
||||
class NodeNotFoundError(LookupError):
|
||||
"""An attempt on a node is made that is not available in the graph."""
|
||||
|
||||
def __init__(self, message, node, origin=None):
|
||||
self.message = message
|
||||
self.origin = origin
|
||||
self.node = node
|
||||
|
||||
def __str__(self):
|
||||
return self.message
|
||||
|
||||
def __repr__(self):
|
||||
return "NodeNotFoundError(%r)" % (self.node,)
|
||||
|
||||
|
||||
class MigrationSchemaMissing(DatabaseError):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidMigrationPlan(ValueError):
|
||||
pass
|
||||
@@ -0,0 +1,376 @@
|
||||
from django.apps.registry import apps as global_apps
|
||||
from django.db import migrations, router
|
||||
|
||||
from .exceptions import InvalidMigrationPlan
|
||||
from .loader import MigrationLoader
|
||||
from .recorder import MigrationRecorder
|
||||
from .state import ProjectState
|
||||
|
||||
|
||||
class MigrationExecutor:
|
||||
"""
|
||||
End-to-end migration execution - load migrations and run them up or down
|
||||
to a specified set of targets.
|
||||
"""
|
||||
|
||||
def __init__(self, connection, progress_callback=None):
|
||||
self.connection = connection
|
||||
self.loader = MigrationLoader(self.connection)
|
||||
self.recorder = MigrationRecorder(self.connection)
|
||||
self.progress_callback = progress_callback
|
||||
|
||||
def migration_plan(self, targets, clean_start=False):
|
||||
"""
|
||||
Given a set of targets, return a list of (Migration instance, backwards?).
|
||||
"""
|
||||
plan = []
|
||||
if clean_start:
|
||||
applied = {}
|
||||
else:
|
||||
applied = dict(self.loader.applied_migrations)
|
||||
for target in targets:
|
||||
# If the target is (app_label, None), that means unmigrate everything
|
||||
if target[1] is None:
|
||||
for root in self.loader.graph.root_nodes():
|
||||
if root[0] == target[0]:
|
||||
for migration in self.loader.graph.backwards_plan(root):
|
||||
if migration in applied:
|
||||
plan.append((self.loader.graph.nodes[migration], True))
|
||||
applied.pop(migration)
|
||||
# If the migration is already applied, do backwards mode,
|
||||
# otherwise do forwards mode.
|
||||
elif target in applied:
|
||||
# Don't migrate backwards all the way to the target node (that
|
||||
# may roll back dependencies in other apps that don't need to
|
||||
# be rolled back); instead roll back through target's immediate
|
||||
# child(ren) in the same app, and no further.
|
||||
next_in_app = sorted(
|
||||
n for n in
|
||||
self.loader.graph.node_map[target].children
|
||||
if n[0] == target[0]
|
||||
)
|
||||
for node in next_in_app:
|
||||
for migration in self.loader.graph.backwards_plan(node):
|
||||
if migration in applied:
|
||||
plan.append((self.loader.graph.nodes[migration], True))
|
||||
applied.pop(migration)
|
||||
else:
|
||||
for migration in self.loader.graph.forwards_plan(target):
|
||||
if migration not in applied:
|
||||
plan.append((self.loader.graph.nodes[migration], False))
|
||||
applied[migration] = self.loader.graph.nodes[migration]
|
||||
return plan
|
||||
|
||||
def _create_project_state(self, with_applied_migrations=False):
|
||||
"""
|
||||
Create a project state including all the applications without
|
||||
migrations and applied migrations if with_applied_migrations=True.
|
||||
"""
|
||||
state = ProjectState(real_apps=list(self.loader.unmigrated_apps))
|
||||
if with_applied_migrations:
|
||||
# Create the forwards plan Django would follow on an empty database
|
||||
full_plan = self.migration_plan(self.loader.graph.leaf_nodes(), clean_start=True)
|
||||
applied_migrations = {
|
||||
self.loader.graph.nodes[key] for key in self.loader.applied_migrations
|
||||
if key in self.loader.graph.nodes
|
||||
}
|
||||
for migration, _ in full_plan:
|
||||
if migration in applied_migrations:
|
||||
migration.mutate_state(state, preserve=False)
|
||||
return state
|
||||
|
||||
def migrate(self, targets, plan=None, state=None, fake=False, fake_initial=False):
|
||||
"""
|
||||
Migrate the database up to the given targets.
|
||||
|
||||
Django first needs to create all project states before a migration is
|
||||
(un)applied and in a second step run all the database operations.
|
||||
"""
|
||||
# The django_migrations table must be present to record applied
|
||||
# migrations.
|
||||
self.recorder.ensure_schema()
|
||||
|
||||
if plan is None:
|
||||
plan = self.migration_plan(targets)
|
||||
# Create the forwards plan Django would follow on an empty database
|
||||
full_plan = self.migration_plan(self.loader.graph.leaf_nodes(), clean_start=True)
|
||||
|
||||
all_forwards = all(not backwards for mig, backwards in plan)
|
||||
all_backwards = all(backwards for mig, backwards in plan)
|
||||
|
||||
if not plan:
|
||||
if state is None:
|
||||
# The resulting state should include applied migrations.
|
||||
state = self._create_project_state(with_applied_migrations=True)
|
||||
elif all_forwards == all_backwards:
|
||||
# This should only happen if there's a mixed plan
|
||||
raise InvalidMigrationPlan(
|
||||
"Migration plans with both forwards and backwards migrations "
|
||||
"are not supported. Please split your migration process into "
|
||||
"separate plans of only forwards OR backwards migrations.",
|
||||
plan
|
||||
)
|
||||
elif all_forwards:
|
||||
if state is None:
|
||||
# The resulting state should still include applied migrations.
|
||||
state = self._create_project_state(with_applied_migrations=True)
|
||||
state = self._migrate_all_forwards(state, plan, full_plan, fake=fake, fake_initial=fake_initial)
|
||||
else:
|
||||
# No need to check for `elif all_backwards` here, as that condition
|
||||
# would always evaluate to true.
|
||||
state = self._migrate_all_backwards(plan, full_plan, fake=fake)
|
||||
|
||||
self.check_replacements()
|
||||
|
||||
return state
|
||||
|
||||
def _migrate_all_forwards(self, state, plan, full_plan, fake, fake_initial):
|
||||
"""
|
||||
Take a list of 2-tuples of the form (migration instance, False) and
|
||||
apply them in the order they occur in the full_plan.
|
||||
"""
|
||||
migrations_to_run = {m[0] for m in plan}
|
||||
for migration, _ in full_plan:
|
||||
if not migrations_to_run:
|
||||
# We remove every migration that we applied from these sets so
|
||||
# that we can bail out once the last migration has been applied
|
||||
# and don't always run until the very end of the migration
|
||||
# process.
|
||||
break
|
||||
if migration in migrations_to_run:
|
||||
if 'apps' not in state.__dict__:
|
||||
if self.progress_callback:
|
||||
self.progress_callback("render_start")
|
||||
state.apps # Render all -- performance critical
|
||||
if self.progress_callback:
|
||||
self.progress_callback("render_success")
|
||||
state = self.apply_migration(state, migration, fake=fake, fake_initial=fake_initial)
|
||||
migrations_to_run.remove(migration)
|
||||
|
||||
return state
|
||||
|
||||
def _migrate_all_backwards(self, plan, full_plan, fake):
|
||||
"""
|
||||
Take a list of 2-tuples of the form (migration instance, True) and
|
||||
unapply them in reverse order they occur in the full_plan.
|
||||
|
||||
Since unapplying a migration requires the project state prior to that
|
||||
migration, Django will compute the migration states before each of them
|
||||
in a first run over the plan and then unapply them in a second run over
|
||||
the plan.
|
||||
"""
|
||||
migrations_to_run = {m[0] for m in plan}
|
||||
# Holds all migration states prior to the migrations being unapplied
|
||||
states = {}
|
||||
state = self._create_project_state()
|
||||
applied_migrations = {
|
||||
self.loader.graph.nodes[key] for key in self.loader.applied_migrations
|
||||
if key in self.loader.graph.nodes
|
||||
}
|
||||
if self.progress_callback:
|
||||
self.progress_callback("render_start")
|
||||
for migration, _ in full_plan:
|
||||
if not migrations_to_run:
|
||||
# We remove every migration that we applied from this set so
|
||||
# that we can bail out once the last migration has been applied
|
||||
# and don't always run until the very end of the migration
|
||||
# process.
|
||||
break
|
||||
if migration in migrations_to_run:
|
||||
if 'apps' not in state.__dict__:
|
||||
state.apps # Render all -- performance critical
|
||||
# The state before this migration
|
||||
states[migration] = state
|
||||
# The old state keeps as-is, we continue with the new state
|
||||
state = migration.mutate_state(state, preserve=True)
|
||||
migrations_to_run.remove(migration)
|
||||
elif migration in applied_migrations:
|
||||
# Only mutate the state if the migration is actually applied
|
||||
# to make sure the resulting state doesn't include changes
|
||||
# from unrelated migrations.
|
||||
migration.mutate_state(state, preserve=False)
|
||||
if self.progress_callback:
|
||||
self.progress_callback("render_success")
|
||||
|
||||
for migration, _ in plan:
|
||||
self.unapply_migration(states[migration], migration, fake=fake)
|
||||
applied_migrations.remove(migration)
|
||||
|
||||
# Generate the post migration state by starting from the state before
|
||||
# the last migration is unapplied and mutating it to include all the
|
||||
# remaining applied migrations.
|
||||
last_unapplied_migration = plan[-1][0]
|
||||
state = states[last_unapplied_migration]
|
||||
for index, (migration, _) in enumerate(full_plan):
|
||||
if migration == last_unapplied_migration:
|
||||
for migration, _ in full_plan[index:]:
|
||||
if migration in applied_migrations:
|
||||
migration.mutate_state(state, preserve=False)
|
||||
break
|
||||
|
||||
return state
|
||||
|
||||
def collect_sql(self, plan):
|
||||
"""
|
||||
Take a migration plan and return a list of collected SQL statements
|
||||
that represent the best-efforts version of that plan.
|
||||
"""
|
||||
statements = []
|
||||
state = None
|
||||
for migration, backwards in plan:
|
||||
with self.connection.schema_editor(collect_sql=True, atomic=migration.atomic) as schema_editor:
|
||||
if state is None:
|
||||
state = self.loader.project_state((migration.app_label, migration.name), at_end=False)
|
||||
if not backwards:
|
||||
state = migration.apply(state, schema_editor, collect_sql=True)
|
||||
else:
|
||||
state = migration.unapply(state, schema_editor, collect_sql=True)
|
||||
statements.extend(schema_editor.collected_sql)
|
||||
return statements
|
||||
|
||||
def apply_migration(self, state, migration, fake=False, fake_initial=False):
|
||||
"""Run a migration forwards."""
|
||||
migration_recorded = False
|
||||
if self.progress_callback:
|
||||
self.progress_callback("apply_start", migration, fake)
|
||||
if not fake:
|
||||
if fake_initial:
|
||||
# Test to see if this is an already-applied initial migration
|
||||
applied, state = self.detect_soft_applied(state, migration)
|
||||
if applied:
|
||||
fake = True
|
||||
if not fake:
|
||||
# Alright, do it normally
|
||||
with self.connection.schema_editor(atomic=migration.atomic) as schema_editor:
|
||||
state = migration.apply(state, schema_editor)
|
||||
self.record_migration(migration)
|
||||
migration_recorded = True
|
||||
if not migration_recorded:
|
||||
self.record_migration(migration)
|
||||
# Report progress
|
||||
if self.progress_callback:
|
||||
self.progress_callback("apply_success", migration, fake)
|
||||
return state
|
||||
|
||||
def record_migration(self, migration):
|
||||
# For replacement migrations, record individual statuses
|
||||
if migration.replaces:
|
||||
for app_label, name in migration.replaces:
|
||||
self.recorder.record_applied(app_label, name)
|
||||
else:
|
||||
self.recorder.record_applied(migration.app_label, migration.name)
|
||||
|
||||
def unapply_migration(self, state, migration, fake=False):
|
||||
"""Run a migration backwards."""
|
||||
if self.progress_callback:
|
||||
self.progress_callback("unapply_start", migration, fake)
|
||||
if not fake:
|
||||
with self.connection.schema_editor(atomic=migration.atomic) as schema_editor:
|
||||
state = migration.unapply(state, schema_editor)
|
||||
# For replacement migrations, record individual statuses
|
||||
if migration.replaces:
|
||||
for app_label, name in migration.replaces:
|
||||
self.recorder.record_unapplied(app_label, name)
|
||||
else:
|
||||
self.recorder.record_unapplied(migration.app_label, migration.name)
|
||||
# Report progress
|
||||
if self.progress_callback:
|
||||
self.progress_callback("unapply_success", migration, fake)
|
||||
return state
|
||||
|
||||
def check_replacements(self):
|
||||
"""
|
||||
Mark replacement migrations applied if their replaced set all are.
|
||||
|
||||
Do this unconditionally on every migrate, rather than just when
|
||||
migrations are applied or unapplied, to correctly handle the case
|
||||
when a new squash migration is pushed to a deployment that already had
|
||||
all its replaced migrations applied. In this case no new migration will
|
||||
be applied, but the applied state of the squashed migration must be
|
||||
maintained.
|
||||
"""
|
||||
applied = self.recorder.applied_migrations()
|
||||
for key, migration in self.loader.replacements.items():
|
||||
all_applied = all(m in applied for m in migration.replaces)
|
||||
if all_applied and key not in applied:
|
||||
self.recorder.record_applied(*key)
|
||||
|
||||
def detect_soft_applied(self, project_state, migration):
|
||||
"""
|
||||
Test whether a migration has been implicitly applied - that the
|
||||
tables or columns it would create exist. This is intended only for use
|
||||
on initial migrations (as it only looks for CreateModel and AddField).
|
||||
"""
|
||||
def should_skip_detecting_model(migration, model):
|
||||
"""
|
||||
No need to detect tables for proxy models, unmanaged models, or
|
||||
models that can't be migrated on the current database.
|
||||
"""
|
||||
return (
|
||||
model._meta.proxy or not model._meta.managed or not
|
||||
router.allow_migrate(
|
||||
self.connection.alias, migration.app_label,
|
||||
model_name=model._meta.model_name,
|
||||
)
|
||||
)
|
||||
|
||||
if migration.initial is None:
|
||||
# Bail if the migration isn't the first one in its app
|
||||
if any(app == migration.app_label for app, name in migration.dependencies):
|
||||
return False, project_state
|
||||
elif migration.initial is False:
|
||||
# Bail if it's NOT an initial migration
|
||||
return False, project_state
|
||||
|
||||
if project_state is None:
|
||||
after_state = self.loader.project_state((migration.app_label, migration.name), at_end=True)
|
||||
else:
|
||||
after_state = migration.mutate_state(project_state)
|
||||
apps = after_state.apps
|
||||
found_create_model_migration = False
|
||||
found_add_field_migration = False
|
||||
with self.connection.cursor() as cursor:
|
||||
existing_table_names = self.connection.introspection.table_names(cursor)
|
||||
# Make sure all create model and add field operations are done
|
||||
for operation in migration.operations:
|
||||
if isinstance(operation, migrations.CreateModel):
|
||||
model = apps.get_model(migration.app_label, operation.name)
|
||||
if model._meta.swapped:
|
||||
# We have to fetch the model to test with from the
|
||||
# main app cache, as it's not a direct dependency.
|
||||
model = global_apps.get_model(model._meta.swapped)
|
||||
if should_skip_detecting_model(migration, model):
|
||||
continue
|
||||
if model._meta.db_table not in existing_table_names:
|
||||
return False, project_state
|
||||
found_create_model_migration = True
|
||||
elif isinstance(operation, migrations.AddField):
|
||||
model = apps.get_model(migration.app_label, operation.model_name)
|
||||
if model._meta.swapped:
|
||||
# We have to fetch the model to test with from the
|
||||
# main app cache, as it's not a direct dependency.
|
||||
model = global_apps.get_model(model._meta.swapped)
|
||||
if should_skip_detecting_model(migration, model):
|
||||
continue
|
||||
|
||||
table = model._meta.db_table
|
||||
field = model._meta.get_field(operation.name)
|
||||
|
||||
# Handle implicit many-to-many tables created by AddField.
|
||||
if field.many_to_many:
|
||||
if field.remote_field.through._meta.db_table not in existing_table_names:
|
||||
return False, project_state
|
||||
else:
|
||||
found_add_field_migration = True
|
||||
continue
|
||||
|
||||
column_names = [
|
||||
column.name for column in
|
||||
self.connection.introspection.get_table_description(self.connection.cursor(), table)
|
||||
]
|
||||
if field.column not in column_names:
|
||||
return False, project_state
|
||||
found_add_field_migration = True
|
||||
# If we get this far and we found at least one CreateModel or AddField migration,
|
||||
# the migration is considered implicitly applied.
|
||||
return (found_create_model_migration or found_add_field_migration), after_state
|
||||
319
venv/lib/python3.8/site-packages/django/db/migrations/graph.py
Normal file
319
venv/lib/python3.8/site-packages/django/db/migrations/graph.py
Normal file
@@ -0,0 +1,319 @@
|
||||
from functools import total_ordering
|
||||
|
||||
from django.db.migrations.state import ProjectState
|
||||
|
||||
from .exceptions import CircularDependencyError, NodeNotFoundError
|
||||
|
||||
|
||||
@total_ordering
|
||||
class Node:
|
||||
"""
|
||||
A single node in the migration graph. Contains direct links to adjacent
|
||||
nodes in either direction.
|
||||
"""
|
||||
def __init__(self, key):
|
||||
self.key = key
|
||||
self.children = set()
|
||||
self.parents = set()
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.key == other
|
||||
|
||||
def __lt__(self, other):
|
||||
return self.key < other
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.key)
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self.key[item]
|
||||
|
||||
def __str__(self):
|
||||
return str(self.key)
|
||||
|
||||
def __repr__(self):
|
||||
return '<%s: (%r, %r)>' % (self.__class__.__name__, self.key[0], self.key[1])
|
||||
|
||||
def add_child(self, child):
|
||||
self.children.add(child)
|
||||
|
||||
def add_parent(self, parent):
|
||||
self.parents.add(parent)
|
||||
|
||||
|
||||
class DummyNode(Node):
|
||||
"""
|
||||
A node that doesn't correspond to a migration file on disk.
|
||||
(A squashed migration that was removed, for example.)
|
||||
|
||||
After the migration graph is processed, all dummy nodes should be removed.
|
||||
If there are any left, a nonexistent dependency error is raised.
|
||||
"""
|
||||
def __init__(self, key, origin, error_message):
|
||||
super().__init__(key)
|
||||
self.origin = origin
|
||||
self.error_message = error_message
|
||||
|
||||
def raise_error(self):
|
||||
raise NodeNotFoundError(self.error_message, self.key, origin=self.origin)
|
||||
|
||||
|
||||
class MigrationGraph:
|
||||
"""
|
||||
Represent the digraph of all migrations in a project.
|
||||
|
||||
Each migration is a node, and each dependency is an edge. There are
|
||||
no implicit dependencies between numbered migrations - the numbering is
|
||||
merely a convention to aid file listing. Every new numbered migration
|
||||
has a declared dependency to the previous number, meaning that VCS
|
||||
branch merges can be detected and resolved.
|
||||
|
||||
Migrations files can be marked as replacing another set of migrations -
|
||||
this is to support the "squash" feature. The graph handler isn't responsible
|
||||
for these; instead, the code to load them in here should examine the
|
||||
migration files and if the replaced migrations are all either unapplied
|
||||
or not present, it should ignore the replaced ones, load in just the
|
||||
replacing migration, and repoint any dependencies that pointed to the
|
||||
replaced migrations to point to the replacing one.
|
||||
|
||||
A node should be a tuple: (app_path, migration_name). The tree special-cases
|
||||
things within an app - namely, root nodes and leaf nodes ignore dependencies
|
||||
to other apps.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.node_map = {}
|
||||
self.nodes = {}
|
||||
|
||||
def add_node(self, key, migration):
|
||||
assert key not in self.node_map
|
||||
node = Node(key)
|
||||
self.node_map[key] = node
|
||||
self.nodes[key] = migration
|
||||
|
||||
def add_dummy_node(self, key, origin, error_message):
|
||||
node = DummyNode(key, origin, error_message)
|
||||
self.node_map[key] = node
|
||||
self.nodes[key] = None
|
||||
|
||||
def add_dependency(self, migration, child, parent, skip_validation=False):
|
||||
"""
|
||||
This may create dummy nodes if they don't yet exist. If
|
||||
`skip_validation=True`, validate_consistency() should be called
|
||||
afterwards.
|
||||
"""
|
||||
if child not in self.nodes:
|
||||
error_message = (
|
||||
"Migration %s dependencies reference nonexistent"
|
||||
" child node %r" % (migration, child)
|
||||
)
|
||||
self.add_dummy_node(child, migration, error_message)
|
||||
if parent not in self.nodes:
|
||||
error_message = (
|
||||
"Migration %s dependencies reference nonexistent"
|
||||
" parent node %r" % (migration, parent)
|
||||
)
|
||||
self.add_dummy_node(parent, migration, error_message)
|
||||
self.node_map[child].add_parent(self.node_map[parent])
|
||||
self.node_map[parent].add_child(self.node_map[child])
|
||||
if not skip_validation:
|
||||
self.validate_consistency()
|
||||
|
||||
def remove_replaced_nodes(self, replacement, replaced):
|
||||
"""
|
||||
Remove each of the `replaced` nodes (when they exist). Any
|
||||
dependencies that were referencing them are changed to reference the
|
||||
`replacement` node instead.
|
||||
"""
|
||||
# Cast list of replaced keys to set to speed up lookup later.
|
||||
replaced = set(replaced)
|
||||
try:
|
||||
replacement_node = self.node_map[replacement]
|
||||
except KeyError as err:
|
||||
raise NodeNotFoundError(
|
||||
"Unable to find replacement node %r. It was either never added"
|
||||
" to the migration graph, or has been removed." % (replacement,),
|
||||
replacement
|
||||
) from err
|
||||
for replaced_key in replaced:
|
||||
self.nodes.pop(replaced_key, None)
|
||||
replaced_node = self.node_map.pop(replaced_key, None)
|
||||
if replaced_node:
|
||||
for child in replaced_node.children:
|
||||
child.parents.remove(replaced_node)
|
||||
# We don't want to create dependencies between the replaced
|
||||
# node and the replacement node as this would lead to
|
||||
# self-referencing on the replacement node at a later iteration.
|
||||
if child.key not in replaced:
|
||||
replacement_node.add_child(child)
|
||||
child.add_parent(replacement_node)
|
||||
for parent in replaced_node.parents:
|
||||
parent.children.remove(replaced_node)
|
||||
# Again, to avoid self-referencing.
|
||||
if parent.key not in replaced:
|
||||
replacement_node.add_parent(parent)
|
||||
parent.add_child(replacement_node)
|
||||
|
||||
def remove_replacement_node(self, replacement, replaced):
|
||||
"""
|
||||
The inverse operation to `remove_replaced_nodes`. Almost. Remove the
|
||||
replacement node `replacement` and remap its child nodes to `replaced`
|
||||
- the list of nodes it would have replaced. Don't remap its parent
|
||||
nodes as they are expected to be correct already.
|
||||
"""
|
||||
self.nodes.pop(replacement, None)
|
||||
try:
|
||||
replacement_node = self.node_map.pop(replacement)
|
||||
except KeyError as err:
|
||||
raise NodeNotFoundError(
|
||||
"Unable to remove replacement node %r. It was either never added"
|
||||
" to the migration graph, or has been removed already." % (replacement,),
|
||||
replacement
|
||||
) from err
|
||||
replaced_nodes = set()
|
||||
replaced_nodes_parents = set()
|
||||
for key in replaced:
|
||||
replaced_node = self.node_map.get(key)
|
||||
if replaced_node:
|
||||
replaced_nodes.add(replaced_node)
|
||||
replaced_nodes_parents |= replaced_node.parents
|
||||
# We're only interested in the latest replaced node, so filter out
|
||||
# replaced nodes that are parents of other replaced nodes.
|
||||
replaced_nodes -= replaced_nodes_parents
|
||||
for child in replacement_node.children:
|
||||
child.parents.remove(replacement_node)
|
||||
for replaced_node in replaced_nodes:
|
||||
replaced_node.add_child(child)
|
||||
child.add_parent(replaced_node)
|
||||
for parent in replacement_node.parents:
|
||||
parent.children.remove(replacement_node)
|
||||
# NOTE: There is no need to remap parent dependencies as we can
|
||||
# assume the replaced nodes already have the correct ancestry.
|
||||
|
||||
def validate_consistency(self):
|
||||
"""Ensure there are no dummy nodes remaining in the graph."""
|
||||
[n.raise_error() for n in self.node_map.values() if isinstance(n, DummyNode)]
|
||||
|
||||
def forwards_plan(self, target):
|
||||
"""
|
||||
Given a node, return a list of which previous nodes (dependencies) must
|
||||
be applied, ending with the node itself. This is the list you would
|
||||
follow if applying the migrations to a database.
|
||||
"""
|
||||
if target not in self.nodes:
|
||||
raise NodeNotFoundError("Node %r not a valid node" % (target,), target)
|
||||
return self.iterative_dfs(self.node_map[target])
|
||||
|
||||
def backwards_plan(self, target):
|
||||
"""
|
||||
Given a node, return a list of which dependent nodes (dependencies)
|
||||
must be unapplied, ending with the node itself. This is the list you
|
||||
would follow if removing the migrations from a database.
|
||||
"""
|
||||
if target not in self.nodes:
|
||||
raise NodeNotFoundError("Node %r not a valid node" % (target,), target)
|
||||
return self.iterative_dfs(self.node_map[target], forwards=False)
|
||||
|
||||
def iterative_dfs(self, start, forwards=True):
|
||||
"""Iterative depth-first search for finding dependencies."""
|
||||
visited = []
|
||||
visited_set = set()
|
||||
stack = [(start, False)]
|
||||
while stack:
|
||||
node, processed = stack.pop()
|
||||
if node in visited_set:
|
||||
pass
|
||||
elif processed:
|
||||
visited_set.add(node)
|
||||
visited.append(node.key)
|
||||
else:
|
||||
stack.append((node, True))
|
||||
stack += [(n, False) for n in sorted(node.parents if forwards else node.children)]
|
||||
return visited
|
||||
|
||||
def root_nodes(self, app=None):
|
||||
"""
|
||||
Return all root nodes - that is, nodes with no dependencies inside
|
||||
their app. These are the starting point for an app.
|
||||
"""
|
||||
roots = set()
|
||||
for node in self.nodes:
|
||||
if all(key[0] != node[0] for key in self.node_map[node].parents) and (not app or app == node[0]):
|
||||
roots.add(node)
|
||||
return sorted(roots)
|
||||
|
||||
def leaf_nodes(self, app=None):
|
||||
"""
|
||||
Return all leaf nodes - that is, nodes with no dependents in their app.
|
||||
These are the "most current" version of an app's schema.
|
||||
Having more than one per app is technically an error, but one that
|
||||
gets handled further up, in the interactive command - it's usually the
|
||||
result of a VCS merge and needs some user input.
|
||||
"""
|
||||
leaves = set()
|
||||
for node in self.nodes:
|
||||
if all(key[0] != node[0] for key in self.node_map[node].children) and (not app or app == node[0]):
|
||||
leaves.add(node)
|
||||
return sorted(leaves)
|
||||
|
||||
def ensure_not_cyclic(self):
|
||||
# Algo from GvR:
|
||||
# https://neopythonic.blogspot.com/2009/01/detecting-cycles-in-directed-graph.html
|
||||
todo = set(self.nodes)
|
||||
while todo:
|
||||
node = todo.pop()
|
||||
stack = [node]
|
||||
while stack:
|
||||
top = stack[-1]
|
||||
for child in self.node_map[top].children:
|
||||
# Use child.key instead of child to speed up the frequent
|
||||
# hashing.
|
||||
node = child.key
|
||||
if node in stack:
|
||||
cycle = stack[stack.index(node):]
|
||||
raise CircularDependencyError(", ".join("%s.%s" % n for n in cycle))
|
||||
if node in todo:
|
||||
stack.append(node)
|
||||
todo.remove(node)
|
||||
break
|
||||
else:
|
||||
node = stack.pop()
|
||||
|
||||
def __str__(self):
|
||||
return 'Graph: %s nodes, %s edges' % self._nodes_and_edges()
|
||||
|
||||
def __repr__(self):
|
||||
nodes, edges = self._nodes_and_edges()
|
||||
return '<%s: nodes=%s, edges=%s>' % (self.__class__.__name__, nodes, edges)
|
||||
|
||||
def _nodes_and_edges(self):
|
||||
return len(self.nodes), sum(len(node.parents) for node in self.node_map.values())
|
||||
|
||||
def _generate_plan(self, nodes, at_end):
|
||||
plan = []
|
||||
for node in nodes:
|
||||
for migration in self.forwards_plan(node):
|
||||
if migration not in plan and (at_end or migration not in nodes):
|
||||
plan.append(migration)
|
||||
return plan
|
||||
|
||||
def make_state(self, nodes=None, at_end=True, real_apps=None):
|
||||
"""
|
||||
Given a migration node or nodes, return a complete ProjectState for it.
|
||||
If at_end is False, return the state before the migration has run.
|
||||
If nodes is not provided, return the overall most current project state.
|
||||
"""
|
||||
if nodes is None:
|
||||
nodes = list(self.leaf_nodes())
|
||||
if not nodes:
|
||||
return ProjectState()
|
||||
if not isinstance(nodes[0], tuple):
|
||||
nodes = [nodes]
|
||||
plan = self._generate_plan(nodes, at_end)
|
||||
project_state = ProjectState(real_apps=real_apps)
|
||||
for node in plan:
|
||||
project_state = self.nodes[node].mutate_state(project_state, preserve=False)
|
||||
return project_state
|
||||
|
||||
def __contains__(self, node):
|
||||
return node in self.nodes
|
||||
324
venv/lib/python3.8/site-packages/django/db/migrations/loader.py
Normal file
324
venv/lib/python3.8/site-packages/django/db/migrations/loader.py
Normal file
@@ -0,0 +1,324 @@
|
||||
import pkgutil
|
||||
import sys
|
||||
from importlib import import_module, reload
|
||||
|
||||
from django.apps import apps
|
||||
from django.conf import settings
|
||||
from django.db.migrations.graph import MigrationGraph
|
||||
from django.db.migrations.recorder import MigrationRecorder
|
||||
|
||||
from .exceptions import (
|
||||
AmbiguityError, BadMigrationError, InconsistentMigrationHistory,
|
||||
NodeNotFoundError,
|
||||
)
|
||||
|
||||
MIGRATIONS_MODULE_NAME = 'migrations'
|
||||
|
||||
|
||||
class MigrationLoader:
|
||||
"""
|
||||
Load migration files from disk and their status from the database.
|
||||
|
||||
Migration files are expected to live in the "migrations" directory of
|
||||
an app. Their names are entirely unimportant from a code perspective,
|
||||
but will probably follow the 1234_name.py convention.
|
||||
|
||||
On initialization, this class will scan those directories, and open and
|
||||
read the Python files, looking for a class called Migration, which should
|
||||
inherit from django.db.migrations.Migration. See
|
||||
django.db.migrations.migration for what that looks like.
|
||||
|
||||
Some migrations will be marked as "replacing" another set of migrations.
|
||||
These are loaded into a separate set of migrations away from the main ones.
|
||||
If all the migrations they replace are either unapplied or missing from
|
||||
disk, then they are injected into the main set, replacing the named migrations.
|
||||
Any dependency pointers to the replaced migrations are re-pointed to the
|
||||
new migration.
|
||||
|
||||
This does mean that this class MUST also talk to the database as well as
|
||||
to disk, but this is probably fine. We're already not just operating
|
||||
in memory.
|
||||
"""
|
||||
|
||||
def __init__(self, connection, load=True, ignore_no_migrations=False):
|
||||
self.connection = connection
|
||||
self.disk_migrations = None
|
||||
self.applied_migrations = None
|
||||
self.ignore_no_migrations = ignore_no_migrations
|
||||
if load:
|
||||
self.build_graph()
|
||||
|
||||
@classmethod
|
||||
def migrations_module(cls, app_label):
|
||||
"""
|
||||
Return the path to the migrations module for the specified app_label
|
||||
and a boolean indicating if the module is specified in
|
||||
settings.MIGRATION_MODULE.
|
||||
"""
|
||||
if app_label in settings.MIGRATION_MODULES:
|
||||
return settings.MIGRATION_MODULES[app_label], True
|
||||
else:
|
||||
app_package_name = apps.get_app_config(app_label).name
|
||||
return '%s.%s' % (app_package_name, MIGRATIONS_MODULE_NAME), False
|
||||
|
||||
def load_disk(self):
|
||||
"""Load the migrations from all INSTALLED_APPS from disk."""
|
||||
self.disk_migrations = {}
|
||||
self.unmigrated_apps = set()
|
||||
self.migrated_apps = set()
|
||||
for app_config in apps.get_app_configs():
|
||||
# Get the migrations module directory
|
||||
module_name, explicit = self.migrations_module(app_config.label)
|
||||
if module_name is None:
|
||||
self.unmigrated_apps.add(app_config.label)
|
||||
continue
|
||||
was_loaded = module_name in sys.modules
|
||||
try:
|
||||
module = import_module(module_name)
|
||||
except ImportError as e:
|
||||
# I hate doing this, but I don't want to squash other import errors.
|
||||
# Might be better to try a directory check directly.
|
||||
if ((explicit and self.ignore_no_migrations) or (
|
||||
not explicit and "No module named" in str(e) and MIGRATIONS_MODULE_NAME in str(e))):
|
||||
self.unmigrated_apps.add(app_config.label)
|
||||
continue
|
||||
raise
|
||||
else:
|
||||
# Empty directories are namespaces.
|
||||
# getattr() needed on PY36 and older (replace w/attribute access).
|
||||
if getattr(module, '__file__', None) is None:
|
||||
self.unmigrated_apps.add(app_config.label)
|
||||
continue
|
||||
# Module is not a package (e.g. migrations.py).
|
||||
if not hasattr(module, '__path__'):
|
||||
self.unmigrated_apps.add(app_config.label)
|
||||
continue
|
||||
# Force a reload if it's already loaded (tests need this)
|
||||
if was_loaded:
|
||||
reload(module)
|
||||
self.migrated_apps.add(app_config.label)
|
||||
migration_names = {
|
||||
name for _, name, is_pkg in pkgutil.iter_modules(module.__path__)
|
||||
if not is_pkg and name[0] not in '_~'
|
||||
}
|
||||
# Load migrations
|
||||
for migration_name in migration_names:
|
||||
migration_path = '%s.%s' % (module_name, migration_name)
|
||||
try:
|
||||
migration_module = import_module(migration_path)
|
||||
except ImportError as e:
|
||||
if 'bad magic number' in str(e):
|
||||
raise ImportError(
|
||||
"Couldn't import %r as it appears to be a stale "
|
||||
".pyc file." % migration_path
|
||||
) from e
|
||||
else:
|
||||
raise
|
||||
if not hasattr(migration_module, "Migration"):
|
||||
raise BadMigrationError(
|
||||
"Migration %s in app %s has no Migration class" % (migration_name, app_config.label)
|
||||
)
|
||||
self.disk_migrations[app_config.label, migration_name] = migration_module.Migration(
|
||||
migration_name,
|
||||
app_config.label,
|
||||
)
|
||||
|
||||
def get_migration(self, app_label, name_prefix):
|
||||
"""Return the named migration or raise NodeNotFoundError."""
|
||||
return self.graph.nodes[app_label, name_prefix]
|
||||
|
||||
def get_migration_by_prefix(self, app_label, name_prefix):
|
||||
"""
|
||||
Return the migration(s) which match the given app label and name_prefix.
|
||||
"""
|
||||
# Do the search
|
||||
results = []
|
||||
for migration_app_label, migration_name in self.disk_migrations:
|
||||
if migration_app_label == app_label and migration_name.startswith(name_prefix):
|
||||
results.append((migration_app_label, migration_name))
|
||||
if len(results) > 1:
|
||||
raise AmbiguityError(
|
||||
"There is more than one migration for '%s' with the prefix '%s'" % (app_label, name_prefix)
|
||||
)
|
||||
elif not results:
|
||||
raise KeyError("There no migrations for '%s' with the prefix '%s'" % (app_label, name_prefix))
|
||||
else:
|
||||
return self.disk_migrations[results[0]]
|
||||
|
||||
def check_key(self, key, current_app):
|
||||
if (key[1] != "__first__" and key[1] != "__latest__") or key in self.graph:
|
||||
return key
|
||||
# Special-case __first__, which means "the first migration" for
|
||||
# migrated apps, and is ignored for unmigrated apps. It allows
|
||||
# makemigrations to declare dependencies on apps before they even have
|
||||
# migrations.
|
||||
if key[0] == current_app:
|
||||
# Ignore __first__ references to the same app (#22325)
|
||||
return
|
||||
if key[0] in self.unmigrated_apps:
|
||||
# This app isn't migrated, but something depends on it.
|
||||
# The models will get auto-added into the state, though
|
||||
# so we're fine.
|
||||
return
|
||||
if key[0] in self.migrated_apps:
|
||||
try:
|
||||
if key[1] == "__first__":
|
||||
return self.graph.root_nodes(key[0])[0]
|
||||
else: # "__latest__"
|
||||
return self.graph.leaf_nodes(key[0])[0]
|
||||
except IndexError:
|
||||
if self.ignore_no_migrations:
|
||||
return None
|
||||
else:
|
||||
raise ValueError("Dependency on app with no migrations: %s" % key[0])
|
||||
raise ValueError("Dependency on unknown app: %s" % key[0])
|
||||
|
||||
def add_internal_dependencies(self, key, migration):
|
||||
"""
|
||||
Internal dependencies need to be added first to ensure `__first__`
|
||||
dependencies find the correct root node.
|
||||
"""
|
||||
for parent in migration.dependencies:
|
||||
# Ignore __first__ references to the same app.
|
||||
if parent[0] == key[0] and parent[1] != '__first__':
|
||||
self.graph.add_dependency(migration, key, parent, skip_validation=True)
|
||||
|
||||
def add_external_dependencies(self, key, migration):
|
||||
for parent in migration.dependencies:
|
||||
# Skip internal dependencies
|
||||
if key[0] == parent[0]:
|
||||
continue
|
||||
parent = self.check_key(parent, key[0])
|
||||
if parent is not None:
|
||||
self.graph.add_dependency(migration, key, parent, skip_validation=True)
|
||||
for child in migration.run_before:
|
||||
child = self.check_key(child, key[0])
|
||||
if child is not None:
|
||||
self.graph.add_dependency(migration, child, key, skip_validation=True)
|
||||
|
||||
def build_graph(self):
|
||||
"""
|
||||
Build a migration dependency graph using both the disk and database.
|
||||
You'll need to rebuild the graph if you apply migrations. This isn't
|
||||
usually a problem as generally migration stuff runs in a one-shot process.
|
||||
"""
|
||||
# Load disk data
|
||||
self.load_disk()
|
||||
# Load database data
|
||||
if self.connection is None:
|
||||
self.applied_migrations = {}
|
||||
else:
|
||||
recorder = MigrationRecorder(self.connection)
|
||||
self.applied_migrations = recorder.applied_migrations()
|
||||
# To start, populate the migration graph with nodes for ALL migrations
|
||||
# and their dependencies. Also make note of replacing migrations at this step.
|
||||
self.graph = MigrationGraph()
|
||||
self.replacements = {}
|
||||
for key, migration in self.disk_migrations.items():
|
||||
self.graph.add_node(key, migration)
|
||||
# Replacing migrations.
|
||||
if migration.replaces:
|
||||
self.replacements[key] = migration
|
||||
for key, migration in self.disk_migrations.items():
|
||||
# Internal (same app) dependencies.
|
||||
self.add_internal_dependencies(key, migration)
|
||||
# Add external dependencies now that the internal ones have been resolved.
|
||||
for key, migration in self.disk_migrations.items():
|
||||
self.add_external_dependencies(key, migration)
|
||||
# Carry out replacements where possible.
|
||||
for key, migration in self.replacements.items():
|
||||
# Get applied status of each of this migration's replacement targets.
|
||||
applied_statuses = [(target in self.applied_migrations) for target in migration.replaces]
|
||||
# Ensure the replacing migration is only marked as applied if all of
|
||||
# its replacement targets are.
|
||||
if all(applied_statuses):
|
||||
self.applied_migrations[key] = migration
|
||||
else:
|
||||
self.applied_migrations.pop(key, None)
|
||||
# A replacing migration can be used if either all or none of its
|
||||
# replacement targets have been applied.
|
||||
if all(applied_statuses) or (not any(applied_statuses)):
|
||||
self.graph.remove_replaced_nodes(key, migration.replaces)
|
||||
else:
|
||||
# This replacing migration cannot be used because it is partially applied.
|
||||
# Remove it from the graph and remap dependencies to it (#25945).
|
||||
self.graph.remove_replacement_node(key, migration.replaces)
|
||||
# Ensure the graph is consistent.
|
||||
try:
|
||||
self.graph.validate_consistency()
|
||||
except NodeNotFoundError as exc:
|
||||
# Check if the missing node could have been replaced by any squash
|
||||
# migration but wasn't because the squash migration was partially
|
||||
# applied before. In that case raise a more understandable exception
|
||||
# (#23556).
|
||||
# Get reverse replacements.
|
||||
reverse_replacements = {}
|
||||
for key, migration in self.replacements.items():
|
||||
for replaced in migration.replaces:
|
||||
reverse_replacements.setdefault(replaced, set()).add(key)
|
||||
# Try to reraise exception with more detail.
|
||||
if exc.node in reverse_replacements:
|
||||
candidates = reverse_replacements.get(exc.node, set())
|
||||
is_replaced = any(candidate in self.graph.nodes for candidate in candidates)
|
||||
if not is_replaced:
|
||||
tries = ', '.join('%s.%s' % c for c in candidates)
|
||||
raise NodeNotFoundError(
|
||||
"Migration {0} depends on nonexistent node ('{1}', '{2}'). "
|
||||
"Django tried to replace migration {1}.{2} with any of [{3}] "
|
||||
"but wasn't able to because some of the replaced migrations "
|
||||
"are already applied.".format(
|
||||
exc.origin, exc.node[0], exc.node[1], tries
|
||||
),
|
||||
exc.node
|
||||
) from exc
|
||||
raise exc
|
||||
self.graph.ensure_not_cyclic()
|
||||
|
||||
def check_consistent_history(self, connection):
|
||||
"""
|
||||
Raise InconsistentMigrationHistory if any applied migrations have
|
||||
unapplied dependencies.
|
||||
"""
|
||||
recorder = MigrationRecorder(connection)
|
||||
applied = recorder.applied_migrations()
|
||||
for migration in applied:
|
||||
# If the migration is unknown, skip it.
|
||||
if migration not in self.graph.nodes:
|
||||
continue
|
||||
for parent in self.graph.node_map[migration].parents:
|
||||
if parent not in applied:
|
||||
# Skip unapplied squashed migrations that have all of their
|
||||
# `replaces` applied.
|
||||
if parent in self.replacements:
|
||||
if all(m in applied for m in self.replacements[parent].replaces):
|
||||
continue
|
||||
raise InconsistentMigrationHistory(
|
||||
"Migration {}.{} is applied before its dependency "
|
||||
"{}.{} on database '{}'.".format(
|
||||
migration[0], migration[1], parent[0], parent[1],
|
||||
connection.alias,
|
||||
)
|
||||
)
|
||||
|
||||
def detect_conflicts(self):
|
||||
"""
|
||||
Look through the loaded graph and detect any conflicts - apps
|
||||
with more than one leaf migration. Return a dict of the app labels
|
||||
that conflict with the migration names that conflict.
|
||||
"""
|
||||
seen_apps = {}
|
||||
conflicting_apps = set()
|
||||
for app_label, migration_name in self.graph.leaf_nodes():
|
||||
if app_label in seen_apps:
|
||||
conflicting_apps.add(app_label)
|
||||
seen_apps.setdefault(app_label, set()).add(migration_name)
|
||||
return {app_label: seen_apps[app_label] for app_label in conflicting_apps}
|
||||
|
||||
def project_state(self, nodes=None, at_end=True):
|
||||
"""
|
||||
Return a ProjectState object representing the most recent state
|
||||
that the loaded migrations represent.
|
||||
|
||||
See graph.make_state() for the meaning of "nodes" and "at_end".
|
||||
"""
|
||||
return self.graph.make_state(nodes=nodes, at_end=at_end, real_apps=list(self.unmigrated_apps))
|
||||
@@ -0,0 +1,193 @@
|
||||
from django.db.transaction import atomic
|
||||
|
||||
from .exceptions import IrreversibleError
|
||||
|
||||
|
||||
class Migration:
|
||||
"""
|
||||
The base class for all migrations.
|
||||
|
||||
Migration files will import this from django.db.migrations.Migration
|
||||
and subclass it as a class called Migration. It will have one or more
|
||||
of the following attributes:
|
||||
|
||||
- operations: A list of Operation instances, probably from django.db.migrations.operations
|
||||
- dependencies: A list of tuples of (app_path, migration_name)
|
||||
- run_before: A list of tuples of (app_path, migration_name)
|
||||
- replaces: A list of migration_names
|
||||
|
||||
Note that all migrations come out of migrations and into the Loader or
|
||||
Graph as instances, having been initialized with their app label and name.
|
||||
"""
|
||||
|
||||
# Operations to apply during this migration, in order.
|
||||
operations = []
|
||||
|
||||
# Other migrations that should be run before this migration.
|
||||
# Should be a list of (app, migration_name).
|
||||
dependencies = []
|
||||
|
||||
# Other migrations that should be run after this one (i.e. have
|
||||
# this migration added to their dependencies). Useful to make third-party
|
||||
# apps' migrations run after your AUTH_USER replacement, for example.
|
||||
run_before = []
|
||||
|
||||
# Migration names in this app that this migration replaces. If this is
|
||||
# non-empty, this migration will only be applied if all these migrations
|
||||
# are not applied.
|
||||
replaces = []
|
||||
|
||||
# Is this an initial migration? Initial migrations are skipped on
|
||||
# --fake-initial if the table or fields already exist. If None, check if
|
||||
# the migration has any dependencies to determine if there are dependencies
|
||||
# to tell if db introspection needs to be done. If True, always perform
|
||||
# introspection. If False, never perform introspection.
|
||||
initial = None
|
||||
|
||||
# Whether to wrap the whole migration in a transaction. Only has an effect
|
||||
# on database backends which support transactional DDL.
|
||||
atomic = True
|
||||
|
||||
def __init__(self, name, app_label):
|
||||
self.name = name
|
||||
self.app_label = app_label
|
||||
# Copy dependencies & other attrs as we might mutate them at runtime
|
||||
self.operations = list(self.__class__.operations)
|
||||
self.dependencies = list(self.__class__.dependencies)
|
||||
self.run_before = list(self.__class__.run_before)
|
||||
self.replaces = list(self.__class__.replaces)
|
||||
|
||||
def __eq__(self, other):
|
||||
return (
|
||||
isinstance(other, Migration) and
|
||||
self.name == other.name and
|
||||
self.app_label == other.app_label
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return "<Migration %s.%s>" % (self.app_label, self.name)
|
||||
|
||||
def __str__(self):
|
||||
return "%s.%s" % (self.app_label, self.name)
|
||||
|
||||
def __hash__(self):
|
||||
return hash("%s.%s" % (self.app_label, self.name))
|
||||
|
||||
def mutate_state(self, project_state, preserve=True):
|
||||
"""
|
||||
Take a ProjectState and return a new one with the migration's
|
||||
operations applied to it. Preserve the original object state by
|
||||
default and return a mutated state from a copy.
|
||||
"""
|
||||
new_state = project_state
|
||||
if preserve:
|
||||
new_state = project_state.clone()
|
||||
|
||||
for operation in self.operations:
|
||||
operation.state_forwards(self.app_label, new_state)
|
||||
return new_state
|
||||
|
||||
def apply(self, project_state, schema_editor, collect_sql=False):
|
||||
"""
|
||||
Take a project_state representing all migrations prior to this one
|
||||
and a schema_editor for a live database and apply the migration
|
||||
in a forwards order.
|
||||
|
||||
Return the resulting project state for efficient reuse by following
|
||||
Migrations.
|
||||
"""
|
||||
for operation in self.operations:
|
||||
# If this operation cannot be represented as SQL, place a comment
|
||||
# there instead
|
||||
if collect_sql:
|
||||
schema_editor.collected_sql.append("--")
|
||||
if not operation.reduces_to_sql:
|
||||
schema_editor.collected_sql.append(
|
||||
"-- MIGRATION NOW PERFORMS OPERATION THAT CANNOT BE WRITTEN AS SQL:"
|
||||
)
|
||||
schema_editor.collected_sql.append("-- %s" % operation.describe())
|
||||
schema_editor.collected_sql.append("--")
|
||||
if not operation.reduces_to_sql:
|
||||
continue
|
||||
# Save the state before the operation has run
|
||||
old_state = project_state.clone()
|
||||
operation.state_forwards(self.app_label, project_state)
|
||||
# Run the operation
|
||||
atomic_operation = operation.atomic or (self.atomic and operation.atomic is not False)
|
||||
if not schema_editor.atomic_migration and atomic_operation:
|
||||
# Force a transaction on a non-transactional-DDL backend or an
|
||||
# atomic operation inside a non-atomic migration.
|
||||
with atomic(schema_editor.connection.alias):
|
||||
operation.database_forwards(self.app_label, schema_editor, old_state, project_state)
|
||||
else:
|
||||
# Normal behaviour
|
||||
operation.database_forwards(self.app_label, schema_editor, old_state, project_state)
|
||||
return project_state
|
||||
|
||||
def unapply(self, project_state, schema_editor, collect_sql=False):
|
||||
"""
|
||||
Take a project_state representing all migrations prior to this one
|
||||
and a schema_editor for a live database and apply the migration
|
||||
in a reverse order.
|
||||
|
||||
The backwards migration process consists of two phases:
|
||||
|
||||
1. The intermediate states from right before the first until right
|
||||
after the last operation inside this migration are preserved.
|
||||
2. The operations are applied in reverse order using the states
|
||||
recorded in step 1.
|
||||
"""
|
||||
# Construct all the intermediate states we need for a reverse migration
|
||||
to_run = []
|
||||
new_state = project_state
|
||||
# Phase 1
|
||||
for operation in self.operations:
|
||||
# If it's irreversible, error out
|
||||
if not operation.reversible:
|
||||
raise IrreversibleError("Operation %s in %s is not reversible" % (operation, self))
|
||||
# Preserve new state from previous run to not tamper the same state
|
||||
# over all operations
|
||||
new_state = new_state.clone()
|
||||
old_state = new_state.clone()
|
||||
operation.state_forwards(self.app_label, new_state)
|
||||
to_run.insert(0, (operation, old_state, new_state))
|
||||
|
||||
# Phase 2
|
||||
for operation, to_state, from_state in to_run:
|
||||
if collect_sql:
|
||||
schema_editor.collected_sql.append("--")
|
||||
if not operation.reduces_to_sql:
|
||||
schema_editor.collected_sql.append(
|
||||
"-- MIGRATION NOW PERFORMS OPERATION THAT CANNOT BE WRITTEN AS SQL:"
|
||||
)
|
||||
schema_editor.collected_sql.append("-- %s" % operation.describe())
|
||||
schema_editor.collected_sql.append("--")
|
||||
if not operation.reduces_to_sql:
|
||||
continue
|
||||
atomic_operation = operation.atomic or (self.atomic and operation.atomic is not False)
|
||||
if not schema_editor.atomic_migration and atomic_operation:
|
||||
# Force a transaction on a non-transactional-DDL backend or an
|
||||
# atomic operation inside a non-atomic migration.
|
||||
with atomic(schema_editor.connection.alias):
|
||||
operation.database_backwards(self.app_label, schema_editor, from_state, to_state)
|
||||
else:
|
||||
# Normal behaviour
|
||||
operation.database_backwards(self.app_label, schema_editor, from_state, to_state)
|
||||
return project_state
|
||||
|
||||
|
||||
class SwappableTuple(tuple):
|
||||
"""
|
||||
Subclass of tuple so Django can tell this was originally a swappable
|
||||
dependency when it reads the migration file.
|
||||
"""
|
||||
|
||||
def __new__(cls, value, setting):
|
||||
self = tuple.__new__(cls, value)
|
||||
self.setting = setting
|
||||
return self
|
||||
|
||||
|
||||
def swappable_dependency(value):
|
||||
"""Turn a setting value into a dependency."""
|
||||
return SwappableTuple((value.split(".", 1)[0], "__first__"), value)
|
||||
@@ -0,0 +1,17 @@
|
||||
from .fields import AddField, AlterField, RemoveField, RenameField
|
||||
from .models import (
|
||||
AddConstraint, AddIndex, AlterIndexTogether, AlterModelManagers,
|
||||
AlterModelOptions, AlterModelTable, AlterOrderWithRespectTo,
|
||||
AlterUniqueTogether, CreateModel, DeleteModel, RemoveConstraint,
|
||||
RemoveIndex, RenameModel,
|
||||
)
|
||||
from .special import RunPython, RunSQL, SeparateDatabaseAndState
|
||||
|
||||
__all__ = [
|
||||
'CreateModel', 'DeleteModel', 'AlterModelTable', 'AlterUniqueTogether',
|
||||
'RenameModel', 'AlterIndexTogether', 'AlterModelOptions', 'AddIndex',
|
||||
'RemoveIndex', 'AddField', 'RemoveField', 'AlterField', 'RenameField',
|
||||
'AddConstraint', 'RemoveConstraint',
|
||||
'SeparateDatabaseAndState', 'RunSQL', 'RunPython',
|
||||
'AlterOrderWithRespectTo', 'AlterModelManagers',
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,141 @@
|
||||
from django.db import router
|
||||
from django.db.models.fields.related import RECURSIVE_RELATIONSHIP_CONSTANT
|
||||
|
||||
|
||||
class Operation:
|
||||
"""
|
||||
Base class for migration operations.
|
||||
|
||||
It's responsible for both mutating the in-memory model state
|
||||
(see db/migrations/state.py) to represent what it performs, as well
|
||||
as actually performing it against a live database.
|
||||
|
||||
Note that some operations won't modify memory state at all (e.g. data
|
||||
copying operations), and some will need their modifications to be
|
||||
optionally specified by the user (e.g. custom Python code snippets)
|
||||
|
||||
Due to the way this class deals with deconstruction, it should be
|
||||
considered immutable.
|
||||
"""
|
||||
|
||||
# If this migration can be run in reverse.
|
||||
# Some operations are impossible to reverse, like deleting data.
|
||||
reversible = True
|
||||
|
||||
# Can this migration be represented as SQL? (things like RunPython cannot)
|
||||
reduces_to_sql = True
|
||||
|
||||
# Should this operation be forced as atomic even on backends with no
|
||||
# DDL transaction support (i.e., does it have no DDL, like RunPython)
|
||||
atomic = False
|
||||
|
||||
# Should this operation be considered safe to elide and optimize across?
|
||||
elidable = False
|
||||
|
||||
serialization_expand_args = []
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
# We capture the arguments to make returning them trivial
|
||||
self = object.__new__(cls)
|
||||
self._constructor_args = (args, kwargs)
|
||||
return self
|
||||
|
||||
def deconstruct(self):
|
||||
"""
|
||||
Return a 3-tuple of class import path (or just name if it lives
|
||||
under django.db.migrations), positional arguments, and keyword
|
||||
arguments.
|
||||
"""
|
||||
return (
|
||||
self.__class__.__name__,
|
||||
self._constructor_args[0],
|
||||
self._constructor_args[1],
|
||||
)
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
"""
|
||||
Take the state from the previous migration, and mutate it
|
||||
so that it matches what this migration would perform.
|
||||
"""
|
||||
raise NotImplementedError('subclasses of Operation must provide a state_forwards() method')
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
"""
|
||||
Perform the mutation on the database schema in the normal
|
||||
(forwards) direction.
|
||||
"""
|
||||
raise NotImplementedError('subclasses of Operation must provide a database_forwards() method')
|
||||
|
||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||
"""
|
||||
Perform the mutation on the database schema in the reverse
|
||||
direction - e.g. if this were CreateModel, it would in fact
|
||||
drop the model's table.
|
||||
"""
|
||||
raise NotImplementedError('subclasses of Operation must provide a database_backwards() method')
|
||||
|
||||
def describe(self):
|
||||
"""
|
||||
Output a brief summary of what the action does.
|
||||
"""
|
||||
return "%s: %s" % (self.__class__.__name__, self._constructor_args)
|
||||
|
||||
def references_model(self, name, app_label=None):
|
||||
"""
|
||||
Return True if there is a chance this operation references the given
|
||||
model name (as a string), with an optional app label for accuracy.
|
||||
|
||||
Used for optimization. If in doubt, return True;
|
||||
returning a false positive will merely make the optimizer a little
|
||||
less efficient, while returning a false negative may result in an
|
||||
unusable optimized migration.
|
||||
"""
|
||||
return True
|
||||
|
||||
def references_field(self, model_name, name, app_label=None):
|
||||
"""
|
||||
Return True if there is a chance this operation references the given
|
||||
field name, with an optional app label for accuracy.
|
||||
|
||||
Used for optimization. If in doubt, return True.
|
||||
"""
|
||||
return self.references_model(model_name, app_label)
|
||||
|
||||
def allow_migrate_model(self, connection_alias, model):
|
||||
"""
|
||||
Return whether or not a model may be migrated.
|
||||
|
||||
This is a thin wrapper around router.allow_migrate_model() that
|
||||
preemptively rejects any proxy, swapped out, or unmanaged model.
|
||||
"""
|
||||
if not model._meta.can_migrate(connection_alias):
|
||||
return False
|
||||
|
||||
return router.allow_migrate_model(connection_alias, model)
|
||||
|
||||
def reduce(self, operation, app_label=None):
|
||||
"""
|
||||
Return either a list of operations the actual operation should be
|
||||
replaced with or a boolean that indicates whether or not the specified
|
||||
operation can be optimized across.
|
||||
"""
|
||||
if self.elidable:
|
||||
return [operation]
|
||||
elif operation.elidable:
|
||||
return [self]
|
||||
return False
|
||||
|
||||
def _get_model_tuple(self, remote_model, app_label, model_name):
|
||||
if remote_model == RECURSIVE_RELATIONSHIP_CONSTANT:
|
||||
return app_label, model_name.lower()
|
||||
elif '.' in remote_model:
|
||||
return tuple(remote_model.lower().split('.'))
|
||||
else:
|
||||
return app_label, remote_model.lower()
|
||||
|
||||
def __repr__(self):
|
||||
return "<%s %s%s>" % (
|
||||
self.__class__.__name__,
|
||||
", ".join(map(repr, self._constructor_args[0])),
|
||||
",".join(" %s=%r" % x for x in self._constructor_args[1].items()),
|
||||
)
|
||||
@@ -0,0 +1,402 @@
|
||||
from django.core.exceptions import FieldDoesNotExist
|
||||
from django.db.models.fields import NOT_PROVIDED
|
||||
from django.utils.functional import cached_property
|
||||
|
||||
from .base import Operation
|
||||
from .utils import (
|
||||
ModelTuple, field_references_model, is_referenced_by_foreign_key,
|
||||
)
|
||||
|
||||
|
||||
class FieldOperation(Operation):
|
||||
def __init__(self, model_name, name, field=None):
|
||||
self.model_name = model_name
|
||||
self.name = name
|
||||
self.field = field
|
||||
|
||||
@cached_property
|
||||
def model_name_lower(self):
|
||||
return self.model_name.lower()
|
||||
|
||||
@cached_property
|
||||
def name_lower(self):
|
||||
return self.name.lower()
|
||||
|
||||
def is_same_model_operation(self, operation):
|
||||
return self.model_name_lower == operation.model_name_lower
|
||||
|
||||
def is_same_field_operation(self, operation):
|
||||
return self.is_same_model_operation(operation) and self.name_lower == operation.name_lower
|
||||
|
||||
def references_model(self, name, app_label=None):
|
||||
name_lower = name.lower()
|
||||
if name_lower == self.model_name_lower:
|
||||
return True
|
||||
if self.field:
|
||||
return field_references_model(self.field, ModelTuple(app_label, name_lower))
|
||||
return False
|
||||
|
||||
def references_field(self, model_name, name, app_label=None):
|
||||
model_name_lower = model_name.lower()
|
||||
# Check if this operation locally references the field.
|
||||
if model_name_lower == self.model_name_lower:
|
||||
if name == self.name:
|
||||
return True
|
||||
elif self.field and hasattr(self.field, 'from_fields') and name in self.field.from_fields:
|
||||
return True
|
||||
# Check if this operation remotely references the field.
|
||||
if self.field:
|
||||
model_tuple = ModelTuple(app_label, model_name_lower)
|
||||
remote_field = self.field.remote_field
|
||||
if remote_field:
|
||||
if (ModelTuple.from_model(remote_field.model) == model_tuple and
|
||||
(not hasattr(self.field, 'to_fields') or
|
||||
name in self.field.to_fields or None in self.field.to_fields)):
|
||||
return True
|
||||
through = getattr(remote_field, 'through', None)
|
||||
if (through and ModelTuple.from_model(through) == model_tuple and
|
||||
(getattr(remote_field, 'through_fields', None) is None or
|
||||
name in remote_field.through_fields)):
|
||||
return True
|
||||
return False
|
||||
|
||||
def reduce(self, operation, app_label=None):
|
||||
return (
|
||||
super().reduce(operation, app_label=app_label) or
|
||||
not operation.references_field(self.model_name, self.name, app_label)
|
||||
)
|
||||
|
||||
|
||||
class AddField(FieldOperation):
|
||||
"""Add a field to a model."""
|
||||
|
||||
def __init__(self, model_name, name, field, preserve_default=True):
|
||||
self.preserve_default = preserve_default
|
||||
super().__init__(model_name, name, field)
|
||||
|
||||
def deconstruct(self):
|
||||
kwargs = {
|
||||
'model_name': self.model_name,
|
||||
'name': self.name,
|
||||
'field': self.field,
|
||||
}
|
||||
if self.preserve_default is not True:
|
||||
kwargs['preserve_default'] = self.preserve_default
|
||||
return (
|
||||
self.__class__.__name__,
|
||||
[],
|
||||
kwargs
|
||||
)
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
# If preserve default is off, don't use the default for future state
|
||||
if not self.preserve_default:
|
||||
field = self.field.clone()
|
||||
field.default = NOT_PROVIDED
|
||||
else:
|
||||
field = self.field
|
||||
state.models[app_label, self.model_name_lower].fields.append((self.name, field))
|
||||
# Delay rendering of relationships if it's not a relational field
|
||||
delay = not field.is_relation
|
||||
state.reload_model(app_label, self.model_name_lower, delay=delay)
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
to_model = to_state.apps.get_model(app_label, self.model_name)
|
||||
if self.allow_migrate_model(schema_editor.connection.alias, to_model):
|
||||
from_model = from_state.apps.get_model(app_label, self.model_name)
|
||||
field = to_model._meta.get_field(self.name)
|
||||
if not self.preserve_default:
|
||||
field.default = self.field.default
|
||||
schema_editor.add_field(
|
||||
from_model,
|
||||
field,
|
||||
)
|
||||
if not self.preserve_default:
|
||||
field.default = NOT_PROVIDED
|
||||
|
||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||
from_model = from_state.apps.get_model(app_label, self.model_name)
|
||||
if self.allow_migrate_model(schema_editor.connection.alias, from_model):
|
||||
schema_editor.remove_field(from_model, from_model._meta.get_field(self.name))
|
||||
|
||||
def describe(self):
|
||||
return "Add field %s to %s" % (self.name, self.model_name)
|
||||
|
||||
def reduce(self, operation, app_label=None):
|
||||
if isinstance(operation, FieldOperation) and self.is_same_field_operation(operation):
|
||||
if isinstance(operation, AlterField):
|
||||
return [
|
||||
AddField(
|
||||
model_name=self.model_name,
|
||||
name=operation.name,
|
||||
field=operation.field,
|
||||
),
|
||||
]
|
||||
elif isinstance(operation, RemoveField):
|
||||
return []
|
||||
elif isinstance(operation, RenameField):
|
||||
return [
|
||||
AddField(
|
||||
model_name=self.model_name,
|
||||
name=operation.new_name,
|
||||
field=self.field,
|
||||
),
|
||||
]
|
||||
return super().reduce(operation, app_label=app_label)
|
||||
|
||||
|
||||
class RemoveField(FieldOperation):
|
||||
"""Remove a field from a model."""
|
||||
|
||||
def deconstruct(self):
|
||||
kwargs = {
|
||||
'model_name': self.model_name,
|
||||
'name': self.name,
|
||||
}
|
||||
return (
|
||||
self.__class__.__name__,
|
||||
[],
|
||||
kwargs
|
||||
)
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
new_fields = []
|
||||
old_field = None
|
||||
for name, instance in state.models[app_label, self.model_name_lower].fields:
|
||||
if name != self.name:
|
||||
new_fields.append((name, instance))
|
||||
else:
|
||||
old_field = instance
|
||||
state.models[app_label, self.model_name_lower].fields = new_fields
|
||||
# Delay rendering of relationships if it's not a relational field
|
||||
delay = not old_field.is_relation
|
||||
state.reload_model(app_label, self.model_name_lower, delay=delay)
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
from_model = from_state.apps.get_model(app_label, self.model_name)
|
||||
if self.allow_migrate_model(schema_editor.connection.alias, from_model):
|
||||
schema_editor.remove_field(from_model, from_model._meta.get_field(self.name))
|
||||
|
||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||
to_model = to_state.apps.get_model(app_label, self.model_name)
|
||||
if self.allow_migrate_model(schema_editor.connection.alias, to_model):
|
||||
from_model = from_state.apps.get_model(app_label, self.model_name)
|
||||
schema_editor.add_field(from_model, to_model._meta.get_field(self.name))
|
||||
|
||||
def describe(self):
|
||||
return "Remove field %s from %s" % (self.name, self.model_name)
|
||||
|
||||
def reduce(self, operation, app_label=None):
|
||||
from .models import DeleteModel
|
||||
if isinstance(operation, DeleteModel) and operation.name_lower == self.model_name_lower:
|
||||
return [operation]
|
||||
return super().reduce(operation, app_label=app_label)
|
||||
|
||||
|
||||
class AlterField(FieldOperation):
|
||||
"""
|
||||
Alter a field's database column (e.g. null, max_length) to the provided
|
||||
new field.
|
||||
"""
|
||||
|
||||
def __init__(self, model_name, name, field, preserve_default=True):
|
||||
self.preserve_default = preserve_default
|
||||
super().__init__(model_name, name, field)
|
||||
|
||||
def deconstruct(self):
|
||||
kwargs = {
|
||||
'model_name': self.model_name,
|
||||
'name': self.name,
|
||||
'field': self.field,
|
||||
}
|
||||
if self.preserve_default is not True:
|
||||
kwargs['preserve_default'] = self.preserve_default
|
||||
return (
|
||||
self.__class__.__name__,
|
||||
[],
|
||||
kwargs
|
||||
)
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
if not self.preserve_default:
|
||||
field = self.field.clone()
|
||||
field.default = NOT_PROVIDED
|
||||
else:
|
||||
field = self.field
|
||||
state.models[app_label, self.model_name_lower].fields = [
|
||||
(n, field if n == self.name else f)
|
||||
for n, f in
|
||||
state.models[app_label, self.model_name_lower].fields
|
||||
]
|
||||
# TODO: investigate if old relational fields must be reloaded or if it's
|
||||
# sufficient if the new field is (#27737).
|
||||
# Delay rendering of relationships if it's not a relational field and
|
||||
# not referenced by a foreign key.
|
||||
delay = (
|
||||
not field.is_relation and
|
||||
not is_referenced_by_foreign_key(state, self.model_name_lower, self.field, self.name)
|
||||
)
|
||||
state.reload_model(app_label, self.model_name_lower, delay=delay)
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
to_model = to_state.apps.get_model(app_label, self.model_name)
|
||||
if self.allow_migrate_model(schema_editor.connection.alias, to_model):
|
||||
from_model = from_state.apps.get_model(app_label, self.model_name)
|
||||
from_field = from_model._meta.get_field(self.name)
|
||||
to_field = to_model._meta.get_field(self.name)
|
||||
if not self.preserve_default:
|
||||
to_field.default = self.field.default
|
||||
schema_editor.alter_field(from_model, from_field, to_field)
|
||||
if not self.preserve_default:
|
||||
to_field.default = NOT_PROVIDED
|
||||
|
||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||
self.database_forwards(app_label, schema_editor, from_state, to_state)
|
||||
|
||||
def describe(self):
|
||||
return "Alter field %s on %s" % (self.name, self.model_name)
|
||||
|
||||
def reduce(self, operation, app_label=None):
|
||||
if isinstance(operation, RemoveField) and self.is_same_field_operation(operation):
|
||||
return [operation]
|
||||
elif isinstance(operation, RenameField) and self.is_same_field_operation(operation):
|
||||
return [
|
||||
operation,
|
||||
AlterField(
|
||||
model_name=self.model_name,
|
||||
name=operation.new_name,
|
||||
field=self.field,
|
||||
),
|
||||
]
|
||||
return super().reduce(operation, app_label=app_label)
|
||||
|
||||
|
||||
class RenameField(FieldOperation):
|
||||
"""Rename a field on the model. Might affect db_column too."""
|
||||
|
||||
def __init__(self, model_name, old_name, new_name):
|
||||
self.old_name = old_name
|
||||
self.new_name = new_name
|
||||
super().__init__(model_name, old_name)
|
||||
|
||||
@cached_property
|
||||
def old_name_lower(self):
|
||||
return self.old_name.lower()
|
||||
|
||||
@cached_property
|
||||
def new_name_lower(self):
|
||||
return self.new_name.lower()
|
||||
|
||||
def deconstruct(self):
|
||||
kwargs = {
|
||||
'model_name': self.model_name,
|
||||
'old_name': self.old_name,
|
||||
'new_name': self.new_name,
|
||||
}
|
||||
return (
|
||||
self.__class__.__name__,
|
||||
[],
|
||||
kwargs
|
||||
)
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
model_state = state.models[app_label, self.model_name_lower]
|
||||
# Rename the field
|
||||
fields = model_state.fields
|
||||
found = False
|
||||
delay = True
|
||||
for index, (name, field) in enumerate(fields):
|
||||
if not found and name == self.old_name:
|
||||
fields[index] = (self.new_name, field)
|
||||
found = True
|
||||
# Fix from_fields to refer to the new field.
|
||||
from_fields = getattr(field, 'from_fields', None)
|
||||
if from_fields:
|
||||
field.from_fields = tuple([
|
||||
self.new_name if from_field_name == self.old_name else from_field_name
|
||||
for from_field_name in from_fields
|
||||
])
|
||||
# Delay rendering of relationships if it's not a relational
|
||||
# field and not referenced by a foreign key.
|
||||
delay = delay and (
|
||||
not field.is_relation and
|
||||
not is_referenced_by_foreign_key(state, self.model_name_lower, field, self.name)
|
||||
)
|
||||
if not found:
|
||||
raise FieldDoesNotExist(
|
||||
"%s.%s has no field named '%s'" % (app_label, self.model_name, self.old_name)
|
||||
)
|
||||
# Fix index/unique_together to refer to the new field
|
||||
options = model_state.options
|
||||
for option in ('index_together', 'unique_together'):
|
||||
if option in options:
|
||||
options[option] = [
|
||||
[self.new_name if n == self.old_name else n for n in together]
|
||||
for together in options[option]
|
||||
]
|
||||
# Fix to_fields to refer to the new field.
|
||||
model_tuple = app_label, self.model_name_lower
|
||||
for (model_app_label, model_name), model_state in state.models.items():
|
||||
for index, (name, field) in enumerate(model_state.fields):
|
||||
remote_field = field.remote_field
|
||||
if remote_field:
|
||||
remote_model_tuple = self._get_model_tuple(
|
||||
remote_field.model, model_app_label, model_name
|
||||
)
|
||||
if remote_model_tuple == model_tuple:
|
||||
if getattr(remote_field, 'field_name', None) == self.old_name:
|
||||
remote_field.field_name = self.new_name
|
||||
to_fields = getattr(field, 'to_fields', None)
|
||||
if to_fields:
|
||||
field.to_fields = tuple([
|
||||
self.new_name if to_field_name == self.old_name else to_field_name
|
||||
for to_field_name in to_fields
|
||||
])
|
||||
state.reload_model(app_label, self.model_name_lower, delay=delay)
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
to_model = to_state.apps.get_model(app_label, self.model_name)
|
||||
if self.allow_migrate_model(schema_editor.connection.alias, to_model):
|
||||
from_model = from_state.apps.get_model(app_label, self.model_name)
|
||||
schema_editor.alter_field(
|
||||
from_model,
|
||||
from_model._meta.get_field(self.old_name),
|
||||
to_model._meta.get_field(self.new_name),
|
||||
)
|
||||
|
||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||
to_model = to_state.apps.get_model(app_label, self.model_name)
|
||||
if self.allow_migrate_model(schema_editor.connection.alias, to_model):
|
||||
from_model = from_state.apps.get_model(app_label, self.model_name)
|
||||
schema_editor.alter_field(
|
||||
from_model,
|
||||
from_model._meta.get_field(self.new_name),
|
||||
to_model._meta.get_field(self.old_name),
|
||||
)
|
||||
|
||||
def describe(self):
|
||||
return "Rename field %s on %s to %s" % (self.old_name, self.model_name, self.new_name)
|
||||
|
||||
def references_field(self, model_name, name, app_label=None):
|
||||
return self.references_model(model_name) and (
|
||||
name.lower() == self.old_name_lower or
|
||||
name.lower() == self.new_name_lower
|
||||
)
|
||||
|
||||
def reduce(self, operation, app_label=None):
|
||||
if (isinstance(operation, RenameField) and
|
||||
self.is_same_model_operation(operation) and
|
||||
self.new_name_lower == operation.old_name_lower):
|
||||
return [
|
||||
RenameField(
|
||||
self.model_name,
|
||||
self.old_name,
|
||||
operation.new_name,
|
||||
),
|
||||
]
|
||||
# Skip `FieldOperation.reduce` as we want to run `references_field`
|
||||
# against self.new_name.
|
||||
return (
|
||||
super(FieldOperation, self).reduce(operation, app_label=app_label) or
|
||||
not operation.references_field(self.model_name, self.new_name, app_label)
|
||||
)
|
||||
@@ -0,0 +1,873 @@
|
||||
from django.db import models
|
||||
from django.db.migrations.operations.base import Operation
|
||||
from django.db.migrations.state import ModelState
|
||||
from django.db.models.options import normalize_together
|
||||
from django.utils.functional import cached_property
|
||||
|
||||
from .fields import (
|
||||
AddField, AlterField, FieldOperation, RemoveField, RenameField,
|
||||
)
|
||||
from .utils import ModelTuple, field_references_model
|
||||
|
||||
|
||||
def _check_for_duplicates(arg_name, objs):
|
||||
used_vals = set()
|
||||
for val in objs:
|
||||
if val in used_vals:
|
||||
raise ValueError(
|
||||
"Found duplicate value %s in CreateModel %s argument." % (val, arg_name)
|
||||
)
|
||||
used_vals.add(val)
|
||||
|
||||
|
||||
class ModelOperation(Operation):
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
@cached_property
|
||||
def name_lower(self):
|
||||
return self.name.lower()
|
||||
|
||||
def references_model(self, name, app_label=None):
|
||||
return name.lower() == self.name_lower
|
||||
|
||||
def reduce(self, operation, app_label=None):
|
||||
return (
|
||||
super().reduce(operation, app_label=app_label) or
|
||||
not operation.references_model(self.name, app_label)
|
||||
)
|
||||
|
||||
|
||||
class CreateModel(ModelOperation):
|
||||
"""Create a model's table."""
|
||||
|
||||
serialization_expand_args = ['fields', 'options', 'managers']
|
||||
|
||||
def __init__(self, name, fields, options=None, bases=None, managers=None):
|
||||
self.fields = fields
|
||||
self.options = options or {}
|
||||
self.bases = bases or (models.Model,)
|
||||
self.managers = managers or []
|
||||
super().__init__(name)
|
||||
# Sanity-check that there are no duplicated field names, bases, or
|
||||
# manager names
|
||||
_check_for_duplicates('fields', (name for name, _ in self.fields))
|
||||
_check_for_duplicates('bases', (
|
||||
base._meta.label_lower if hasattr(base, '_meta') else
|
||||
base.lower() if isinstance(base, str) else base
|
||||
for base in self.bases
|
||||
))
|
||||
_check_for_duplicates('managers', (name for name, _ in self.managers))
|
||||
|
||||
def deconstruct(self):
|
||||
kwargs = {
|
||||
'name': self.name,
|
||||
'fields': self.fields,
|
||||
}
|
||||
if self.options:
|
||||
kwargs['options'] = self.options
|
||||
if self.bases and self.bases != (models.Model,):
|
||||
kwargs['bases'] = self.bases
|
||||
if self.managers and self.managers != [('objects', models.Manager())]:
|
||||
kwargs['managers'] = self.managers
|
||||
return (
|
||||
self.__class__.__qualname__,
|
||||
[],
|
||||
kwargs
|
||||
)
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
state.add_model(ModelState(
|
||||
app_label,
|
||||
self.name,
|
||||
list(self.fields),
|
||||
dict(self.options),
|
||||
tuple(self.bases),
|
||||
list(self.managers),
|
||||
))
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
model = to_state.apps.get_model(app_label, self.name)
|
||||
if self.allow_migrate_model(schema_editor.connection.alias, model):
|
||||
schema_editor.create_model(model)
|
||||
|
||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||
model = from_state.apps.get_model(app_label, self.name)
|
||||
if self.allow_migrate_model(schema_editor.connection.alias, model):
|
||||
schema_editor.delete_model(model)
|
||||
|
||||
def describe(self):
|
||||
return "Create %smodel %s" % ("proxy " if self.options.get("proxy", False) else "", self.name)
|
||||
|
||||
def references_model(self, name, app_label=None):
|
||||
name_lower = name.lower()
|
||||
if name_lower == self.name_lower:
|
||||
return True
|
||||
|
||||
# Check we didn't inherit from the model
|
||||
model_tuple = ModelTuple(app_label, name_lower)
|
||||
for base in self.bases:
|
||||
if (base is not models.Model and isinstance(base, (models.base.ModelBase, str)) and
|
||||
ModelTuple.from_model(base) == model_tuple):
|
||||
return True
|
||||
|
||||
# Check we have no FKs/M2Ms with it
|
||||
for _name, field in self.fields:
|
||||
if field_references_model(field, model_tuple):
|
||||
return True
|
||||
return False
|
||||
|
||||
def reduce(self, operation, app_label=None):
|
||||
if (isinstance(operation, DeleteModel) and
|
||||
self.name_lower == operation.name_lower and
|
||||
not self.options.get("proxy", False)):
|
||||
return []
|
||||
elif isinstance(operation, RenameModel) and self.name_lower == operation.old_name_lower:
|
||||
return [
|
||||
CreateModel(
|
||||
operation.new_name,
|
||||
fields=self.fields,
|
||||
options=self.options,
|
||||
bases=self.bases,
|
||||
managers=self.managers,
|
||||
),
|
||||
]
|
||||
elif isinstance(operation, AlterModelOptions) and self.name_lower == operation.name_lower:
|
||||
return [
|
||||
CreateModel(
|
||||
self.name,
|
||||
fields=self.fields,
|
||||
options={**self.options, **operation.options},
|
||||
bases=self.bases,
|
||||
managers=self.managers,
|
||||
),
|
||||
]
|
||||
elif isinstance(operation, AlterTogetherOptionOperation) and self.name_lower == operation.name_lower:
|
||||
return [
|
||||
CreateModel(
|
||||
self.name,
|
||||
fields=self.fields,
|
||||
options={**self.options, **{operation.option_name: operation.option_value}},
|
||||
bases=self.bases,
|
||||
managers=self.managers,
|
||||
),
|
||||
]
|
||||
elif isinstance(operation, AlterOrderWithRespectTo) and self.name_lower == operation.name_lower:
|
||||
return [
|
||||
CreateModel(
|
||||
self.name,
|
||||
fields=self.fields,
|
||||
options={**self.options, 'order_with_respect_to': operation.order_with_respect_to},
|
||||
bases=self.bases,
|
||||
managers=self.managers,
|
||||
),
|
||||
]
|
||||
elif isinstance(operation, FieldOperation) and self.name_lower == operation.model_name_lower:
|
||||
if isinstance(operation, AddField):
|
||||
return [
|
||||
CreateModel(
|
||||
self.name,
|
||||
fields=self.fields + [(operation.name, operation.field)],
|
||||
options=self.options,
|
||||
bases=self.bases,
|
||||
managers=self.managers,
|
||||
),
|
||||
]
|
||||
elif isinstance(operation, AlterField):
|
||||
return [
|
||||
CreateModel(
|
||||
self.name,
|
||||
fields=[
|
||||
(n, operation.field if n == operation.name else v)
|
||||
for n, v in self.fields
|
||||
],
|
||||
options=self.options,
|
||||
bases=self.bases,
|
||||
managers=self.managers,
|
||||
),
|
||||
]
|
||||
elif isinstance(operation, RemoveField):
|
||||
options = self.options.copy()
|
||||
for option_name in ('unique_together', 'index_together'):
|
||||
option = options.pop(option_name, None)
|
||||
if option:
|
||||
option = set(filter(bool, (
|
||||
tuple(f for f in fields if f != operation.name_lower) for fields in option
|
||||
)))
|
||||
if option:
|
||||
options[option_name] = option
|
||||
order_with_respect_to = options.get('order_with_respect_to')
|
||||
if order_with_respect_to == operation.name_lower:
|
||||
del options['order_with_respect_to']
|
||||
return [
|
||||
CreateModel(
|
||||
self.name,
|
||||
fields=[
|
||||
(n, v)
|
||||
for n, v in self.fields
|
||||
if n.lower() != operation.name_lower
|
||||
],
|
||||
options=options,
|
||||
bases=self.bases,
|
||||
managers=self.managers,
|
||||
),
|
||||
]
|
||||
elif isinstance(operation, RenameField):
|
||||
options = self.options.copy()
|
||||
for option_name in ('unique_together', 'index_together'):
|
||||
option = options.get(option_name)
|
||||
if option:
|
||||
options[option_name] = {
|
||||
tuple(operation.new_name if f == operation.old_name else f for f in fields)
|
||||
for fields in option
|
||||
}
|
||||
order_with_respect_to = options.get('order_with_respect_to')
|
||||
if order_with_respect_to == operation.old_name:
|
||||
options['order_with_respect_to'] = operation.new_name
|
||||
return [
|
||||
CreateModel(
|
||||
self.name,
|
||||
fields=[
|
||||
(operation.new_name if n == operation.old_name else n, v)
|
||||
for n, v in self.fields
|
||||
],
|
||||
options=options,
|
||||
bases=self.bases,
|
||||
managers=self.managers,
|
||||
),
|
||||
]
|
||||
return super().reduce(operation, app_label=app_label)
|
||||
|
||||
|
||||
class DeleteModel(ModelOperation):
|
||||
"""Drop a model's table."""
|
||||
|
||||
def deconstruct(self):
|
||||
kwargs = {
|
||||
'name': self.name,
|
||||
}
|
||||
return (
|
||||
self.__class__.__qualname__,
|
||||
[],
|
||||
kwargs
|
||||
)
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
state.remove_model(app_label, self.name_lower)
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
model = from_state.apps.get_model(app_label, self.name)
|
||||
if self.allow_migrate_model(schema_editor.connection.alias, model):
|
||||
schema_editor.delete_model(model)
|
||||
|
||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||
model = to_state.apps.get_model(app_label, self.name)
|
||||
if self.allow_migrate_model(schema_editor.connection.alias, model):
|
||||
schema_editor.create_model(model)
|
||||
|
||||
def references_model(self, name, app_label=None):
|
||||
# The deleted model could be referencing the specified model through
|
||||
# related fields.
|
||||
return True
|
||||
|
||||
def describe(self):
|
||||
return "Delete model %s" % self.name
|
||||
|
||||
|
||||
class RenameModel(ModelOperation):
|
||||
"""Rename a model."""
|
||||
|
||||
def __init__(self, old_name, new_name):
|
||||
self.old_name = old_name
|
||||
self.new_name = new_name
|
||||
super().__init__(old_name)
|
||||
|
||||
@cached_property
|
||||
def old_name_lower(self):
|
||||
return self.old_name.lower()
|
||||
|
||||
@cached_property
|
||||
def new_name_lower(self):
|
||||
return self.new_name.lower()
|
||||
|
||||
def deconstruct(self):
|
||||
kwargs = {
|
||||
'old_name': self.old_name,
|
||||
'new_name': self.new_name,
|
||||
}
|
||||
return (
|
||||
self.__class__.__qualname__,
|
||||
[],
|
||||
kwargs
|
||||
)
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
# Add a new model.
|
||||
renamed_model = state.models[app_label, self.old_name_lower].clone()
|
||||
renamed_model.name = self.new_name
|
||||
state.models[app_label, self.new_name_lower] = renamed_model
|
||||
# Repoint all fields pointing to the old model to the new one.
|
||||
old_model_tuple = ModelTuple(app_label, self.old_name_lower)
|
||||
new_remote_model = '%s.%s' % (app_label, self.new_name)
|
||||
to_reload = []
|
||||
for (model_app_label, model_name), model_state in state.models.items():
|
||||
model_changed = False
|
||||
for index, (name, field) in enumerate(model_state.fields):
|
||||
changed_field = None
|
||||
remote_field = field.remote_field
|
||||
if remote_field:
|
||||
remote_model_tuple = ModelTuple.from_model(
|
||||
remote_field.model, model_app_label, model_name
|
||||
)
|
||||
if remote_model_tuple == old_model_tuple:
|
||||
changed_field = field.clone()
|
||||
changed_field.remote_field.model = new_remote_model
|
||||
through_model = getattr(remote_field, 'through', None)
|
||||
if through_model:
|
||||
through_model_tuple = ModelTuple.from_model(
|
||||
through_model, model_app_label, model_name
|
||||
)
|
||||
if through_model_tuple == old_model_tuple:
|
||||
if changed_field is None:
|
||||
changed_field = field.clone()
|
||||
changed_field.remote_field.through = new_remote_model
|
||||
if changed_field:
|
||||
model_state.fields[index] = name, changed_field
|
||||
model_changed = True
|
||||
if model_changed:
|
||||
to_reload.append((model_app_label, model_name))
|
||||
# Reload models related to old model before removing the old model.
|
||||
state.reload_models(to_reload, delay=True)
|
||||
# Remove the old model.
|
||||
state.remove_model(app_label, self.old_name_lower)
|
||||
state.reload_model(app_label, self.new_name_lower, delay=True)
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
new_model = to_state.apps.get_model(app_label, self.new_name)
|
||||
if self.allow_migrate_model(schema_editor.connection.alias, new_model):
|
||||
old_model = from_state.apps.get_model(app_label, self.old_name)
|
||||
# Move the main table
|
||||
schema_editor.alter_db_table(
|
||||
new_model,
|
||||
old_model._meta.db_table,
|
||||
new_model._meta.db_table,
|
||||
)
|
||||
# Alter the fields pointing to us
|
||||
for related_object in old_model._meta.related_objects:
|
||||
if related_object.related_model == old_model:
|
||||
model = new_model
|
||||
related_key = (app_label, self.new_name_lower)
|
||||
else:
|
||||
model = related_object.related_model
|
||||
related_key = (
|
||||
related_object.related_model._meta.app_label,
|
||||
related_object.related_model._meta.model_name,
|
||||
)
|
||||
to_field = to_state.apps.get_model(
|
||||
*related_key
|
||||
)._meta.get_field(related_object.field.name)
|
||||
schema_editor.alter_field(
|
||||
model,
|
||||
related_object.field,
|
||||
to_field,
|
||||
)
|
||||
# Rename M2M fields whose name is based on this model's name.
|
||||
fields = zip(old_model._meta.local_many_to_many, new_model._meta.local_many_to_many)
|
||||
for (old_field, new_field) in fields:
|
||||
# Skip self-referential fields as these are renamed above.
|
||||
if new_field.model == new_field.related_model or not new_field.remote_field.through._meta.auto_created:
|
||||
continue
|
||||
# Rename the M2M table that's based on this model's name.
|
||||
old_m2m_model = old_field.remote_field.through
|
||||
new_m2m_model = new_field.remote_field.through
|
||||
schema_editor.alter_db_table(
|
||||
new_m2m_model,
|
||||
old_m2m_model._meta.db_table,
|
||||
new_m2m_model._meta.db_table,
|
||||
)
|
||||
# Rename the column in the M2M table that's based on this
|
||||
# model's name.
|
||||
schema_editor.alter_field(
|
||||
new_m2m_model,
|
||||
old_m2m_model._meta.get_field(old_model._meta.model_name),
|
||||
new_m2m_model._meta.get_field(new_model._meta.model_name),
|
||||
)
|
||||
|
||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||
self.new_name_lower, self.old_name_lower = self.old_name_lower, self.new_name_lower
|
||||
self.new_name, self.old_name = self.old_name, self.new_name
|
||||
|
||||
self.database_forwards(app_label, schema_editor, from_state, to_state)
|
||||
|
||||
self.new_name_lower, self.old_name_lower = self.old_name_lower, self.new_name_lower
|
||||
self.new_name, self.old_name = self.old_name, self.new_name
|
||||
|
||||
def references_model(self, name, app_label=None):
|
||||
return (
|
||||
name.lower() == self.old_name_lower or
|
||||
name.lower() == self.new_name_lower
|
||||
)
|
||||
|
||||
def describe(self):
|
||||
return "Rename model %s to %s" % (self.old_name, self.new_name)
|
||||
|
||||
def reduce(self, operation, app_label=None):
|
||||
if (isinstance(operation, RenameModel) and
|
||||
self.new_name_lower == operation.old_name_lower):
|
||||
return [
|
||||
RenameModel(
|
||||
self.old_name,
|
||||
operation.new_name,
|
||||
),
|
||||
]
|
||||
# Skip `ModelOperation.reduce` as we want to run `references_model`
|
||||
# against self.new_name.
|
||||
return (
|
||||
super(ModelOperation, self).reduce(operation, app_label=app_label) or
|
||||
not operation.references_model(self.new_name, app_label)
|
||||
)
|
||||
|
||||
|
||||
class ModelOptionOperation(ModelOperation):
|
||||
def reduce(self, operation, app_label=None):
|
||||
if isinstance(operation, (self.__class__, DeleteModel)) and self.name_lower == operation.name_lower:
|
||||
return [operation]
|
||||
return super().reduce(operation, app_label=app_label)
|
||||
|
||||
|
||||
class AlterModelTable(ModelOptionOperation):
|
||||
"""Rename a model's table."""
|
||||
|
||||
def __init__(self, name, table):
|
||||
self.table = table
|
||||
super().__init__(name)
|
||||
|
||||
def deconstruct(self):
|
||||
kwargs = {
|
||||
'name': self.name,
|
||||
'table': self.table,
|
||||
}
|
||||
return (
|
||||
self.__class__.__qualname__,
|
||||
[],
|
||||
kwargs
|
||||
)
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
state.models[app_label, self.name_lower].options["db_table"] = self.table
|
||||
state.reload_model(app_label, self.name_lower, delay=True)
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
new_model = to_state.apps.get_model(app_label, self.name)
|
||||
if self.allow_migrate_model(schema_editor.connection.alias, new_model):
|
||||
old_model = from_state.apps.get_model(app_label, self.name)
|
||||
schema_editor.alter_db_table(
|
||||
new_model,
|
||||
old_model._meta.db_table,
|
||||
new_model._meta.db_table,
|
||||
)
|
||||
# Rename M2M fields whose name is based on this model's db_table
|
||||
for (old_field, new_field) in zip(old_model._meta.local_many_to_many, new_model._meta.local_many_to_many):
|
||||
if new_field.remote_field.through._meta.auto_created:
|
||||
schema_editor.alter_db_table(
|
||||
new_field.remote_field.through,
|
||||
old_field.remote_field.through._meta.db_table,
|
||||
new_field.remote_field.through._meta.db_table,
|
||||
)
|
||||
|
||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||
return self.database_forwards(app_label, schema_editor, from_state, to_state)
|
||||
|
||||
def describe(self):
|
||||
return "Rename table for %s to %s" % (
|
||||
self.name,
|
||||
self.table if self.table is not None else "(default)"
|
||||
)
|
||||
|
||||
|
||||
class AlterTogetherOptionOperation(ModelOptionOperation):
|
||||
option_name = None
|
||||
|
||||
def __init__(self, name, option_value):
|
||||
if option_value:
|
||||
option_value = set(normalize_together(option_value))
|
||||
setattr(self, self.option_name, option_value)
|
||||
super().__init__(name)
|
||||
|
||||
@cached_property
|
||||
def option_value(self):
|
||||
return getattr(self, self.option_name)
|
||||
|
||||
def deconstruct(self):
|
||||
kwargs = {
|
||||
'name': self.name,
|
||||
self.option_name: self.option_value,
|
||||
}
|
||||
return (
|
||||
self.__class__.__qualname__,
|
||||
[],
|
||||
kwargs
|
||||
)
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
model_state = state.models[app_label, self.name_lower]
|
||||
model_state.options[self.option_name] = self.option_value
|
||||
state.reload_model(app_label, self.name_lower, delay=True)
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
new_model = to_state.apps.get_model(app_label, self.name)
|
||||
if self.allow_migrate_model(schema_editor.connection.alias, new_model):
|
||||
old_model = from_state.apps.get_model(app_label, self.name)
|
||||
alter_together = getattr(schema_editor, 'alter_%s' % self.option_name)
|
||||
alter_together(
|
||||
new_model,
|
||||
getattr(old_model._meta, self.option_name, set()),
|
||||
getattr(new_model._meta, self.option_name, set()),
|
||||
)
|
||||
|
||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||
return self.database_forwards(app_label, schema_editor, from_state, to_state)
|
||||
|
||||
def references_field(self, model_name, name, app_label=None):
|
||||
return (
|
||||
self.references_model(model_name, app_label) and
|
||||
(
|
||||
not self.option_value or
|
||||
any((name in fields) for fields in self.option_value)
|
||||
)
|
||||
)
|
||||
|
||||
def describe(self):
|
||||
return "Alter %s for %s (%s constraint(s))" % (self.option_name, self.name, len(self.option_value or ''))
|
||||
|
||||
|
||||
class AlterUniqueTogether(AlterTogetherOptionOperation):
|
||||
"""
|
||||
Change the value of unique_together to the target one.
|
||||
Input value of unique_together must be a set of tuples.
|
||||
"""
|
||||
option_name = 'unique_together'
|
||||
|
||||
def __init__(self, name, unique_together):
|
||||
super().__init__(name, unique_together)
|
||||
|
||||
|
||||
class AlterIndexTogether(AlterTogetherOptionOperation):
|
||||
"""
|
||||
Change the value of index_together to the target one.
|
||||
Input value of index_together must be a set of tuples.
|
||||
"""
|
||||
option_name = "index_together"
|
||||
|
||||
def __init__(self, name, index_together):
|
||||
super().__init__(name, index_together)
|
||||
|
||||
|
||||
class AlterOrderWithRespectTo(ModelOptionOperation):
|
||||
"""Represent a change with the order_with_respect_to option."""
|
||||
|
||||
option_name = 'order_with_respect_to'
|
||||
|
||||
def __init__(self, name, order_with_respect_to):
|
||||
self.order_with_respect_to = order_with_respect_to
|
||||
super().__init__(name)
|
||||
|
||||
def deconstruct(self):
|
||||
kwargs = {
|
||||
'name': self.name,
|
||||
'order_with_respect_to': self.order_with_respect_to,
|
||||
}
|
||||
return (
|
||||
self.__class__.__qualname__,
|
||||
[],
|
||||
kwargs
|
||||
)
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
model_state = state.models[app_label, self.name_lower]
|
||||
model_state.options['order_with_respect_to'] = self.order_with_respect_to
|
||||
state.reload_model(app_label, self.name_lower, delay=True)
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
to_model = to_state.apps.get_model(app_label, self.name)
|
||||
if self.allow_migrate_model(schema_editor.connection.alias, to_model):
|
||||
from_model = from_state.apps.get_model(app_label, self.name)
|
||||
# Remove a field if we need to
|
||||
if from_model._meta.order_with_respect_to and not to_model._meta.order_with_respect_to:
|
||||
schema_editor.remove_field(from_model, from_model._meta.get_field("_order"))
|
||||
# Add a field if we need to (altering the column is untouched as
|
||||
# it's likely a rename)
|
||||
elif to_model._meta.order_with_respect_to and not from_model._meta.order_with_respect_to:
|
||||
field = to_model._meta.get_field("_order")
|
||||
if not field.has_default():
|
||||
field.default = 0
|
||||
schema_editor.add_field(
|
||||
from_model,
|
||||
field,
|
||||
)
|
||||
|
||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||
self.database_forwards(app_label, schema_editor, from_state, to_state)
|
||||
|
||||
def references_field(self, model_name, name, app_label=None):
|
||||
return (
|
||||
self.references_model(model_name, app_label) and
|
||||
(
|
||||
self.order_with_respect_to is None or
|
||||
name == self.order_with_respect_to
|
||||
)
|
||||
)
|
||||
|
||||
def describe(self):
|
||||
return "Set order_with_respect_to on %s to %s" % (self.name, self.order_with_respect_to)
|
||||
|
||||
|
||||
class AlterModelOptions(ModelOptionOperation):
|
||||
"""
|
||||
Set new model options that don't directly affect the database schema
|
||||
(like verbose_name, permissions, ordering). Python code in migrations
|
||||
may still need them.
|
||||
"""
|
||||
|
||||
# Model options we want to compare and preserve in an AlterModelOptions op
|
||||
ALTER_OPTION_KEYS = [
|
||||
"base_manager_name",
|
||||
"default_manager_name",
|
||||
"default_related_name",
|
||||
"get_latest_by",
|
||||
"managed",
|
||||
"ordering",
|
||||
"permissions",
|
||||
"default_permissions",
|
||||
"select_on_save",
|
||||
"verbose_name",
|
||||
"verbose_name_plural",
|
||||
]
|
||||
|
||||
def __init__(self, name, options):
|
||||
self.options = options
|
||||
super().__init__(name)
|
||||
|
||||
def deconstruct(self):
|
||||
kwargs = {
|
||||
'name': self.name,
|
||||
'options': self.options,
|
||||
}
|
||||
return (
|
||||
self.__class__.__qualname__,
|
||||
[],
|
||||
kwargs
|
||||
)
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
model_state = state.models[app_label, self.name_lower]
|
||||
model_state.options = {**model_state.options, **self.options}
|
||||
for key in self.ALTER_OPTION_KEYS:
|
||||
if key not in self.options:
|
||||
model_state.options.pop(key, False)
|
||||
state.reload_model(app_label, self.name_lower, delay=True)
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
pass
|
||||
|
||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||
pass
|
||||
|
||||
def describe(self):
|
||||
return "Change Meta options on %s" % self.name
|
||||
|
||||
|
||||
class AlterModelManagers(ModelOptionOperation):
|
||||
"""Alter the model's managers."""
|
||||
|
||||
serialization_expand_args = ['managers']
|
||||
|
||||
def __init__(self, name, managers):
|
||||
self.managers = managers
|
||||
super().__init__(name)
|
||||
|
||||
def deconstruct(self):
|
||||
return (
|
||||
self.__class__.__qualname__,
|
||||
[self.name, self.managers],
|
||||
{}
|
||||
)
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
model_state = state.models[app_label, self.name_lower]
|
||||
model_state.managers = list(self.managers)
|
||||
state.reload_model(app_label, self.name_lower, delay=True)
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
pass
|
||||
|
||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||
pass
|
||||
|
||||
def describe(self):
|
||||
return "Change managers on %s" % self.name
|
||||
|
||||
|
||||
class IndexOperation(Operation):
|
||||
option_name = 'indexes'
|
||||
|
||||
@cached_property
|
||||
def model_name_lower(self):
|
||||
return self.model_name.lower()
|
||||
|
||||
|
||||
class AddIndex(IndexOperation):
|
||||
"""Add an index on a model."""
|
||||
|
||||
def __init__(self, model_name, index):
|
||||
self.model_name = model_name
|
||||
if not index.name:
|
||||
raise ValueError(
|
||||
"Indexes passed to AddIndex operations require a name "
|
||||
"argument. %r doesn't have one." % index
|
||||
)
|
||||
self.index = index
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
model_state = state.models[app_label, self.model_name_lower]
|
||||
model_state.options[self.option_name] = [*model_state.options[self.option_name], self.index.clone()]
|
||||
state.reload_model(app_label, self.model_name_lower, delay=True)
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
model = to_state.apps.get_model(app_label, self.model_name)
|
||||
if self.allow_migrate_model(schema_editor.connection.alias, model):
|
||||
schema_editor.add_index(model, self.index)
|
||||
|
||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||
model = from_state.apps.get_model(app_label, self.model_name)
|
||||
if self.allow_migrate_model(schema_editor.connection.alias, model):
|
||||
schema_editor.remove_index(model, self.index)
|
||||
|
||||
def deconstruct(self):
|
||||
kwargs = {
|
||||
'model_name': self.model_name,
|
||||
'index': self.index,
|
||||
}
|
||||
return (
|
||||
self.__class__.__qualname__,
|
||||
[],
|
||||
kwargs,
|
||||
)
|
||||
|
||||
def describe(self):
|
||||
return 'Create index %s on field(s) %s of model %s' % (
|
||||
self.index.name,
|
||||
', '.join(self.index.fields),
|
||||
self.model_name,
|
||||
)
|
||||
|
||||
|
||||
class RemoveIndex(IndexOperation):
|
||||
"""Remove an index from a model."""
|
||||
|
||||
def __init__(self, model_name, name):
|
||||
self.model_name = model_name
|
||||
self.name = name
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
model_state = state.models[app_label, self.model_name_lower]
|
||||
indexes = model_state.options[self.option_name]
|
||||
model_state.options[self.option_name] = [idx for idx in indexes if idx.name != self.name]
|
||||
state.reload_model(app_label, self.model_name_lower, delay=True)
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
model = from_state.apps.get_model(app_label, self.model_name)
|
||||
if self.allow_migrate_model(schema_editor.connection.alias, model):
|
||||
from_model_state = from_state.models[app_label, self.model_name_lower]
|
||||
index = from_model_state.get_index_by_name(self.name)
|
||||
schema_editor.remove_index(model, index)
|
||||
|
||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||
model = to_state.apps.get_model(app_label, self.model_name)
|
||||
if self.allow_migrate_model(schema_editor.connection.alias, model):
|
||||
to_model_state = to_state.models[app_label, self.model_name_lower]
|
||||
index = to_model_state.get_index_by_name(self.name)
|
||||
schema_editor.add_index(model, index)
|
||||
|
||||
def deconstruct(self):
|
||||
kwargs = {
|
||||
'model_name': self.model_name,
|
||||
'name': self.name,
|
||||
}
|
||||
return (
|
||||
self.__class__.__qualname__,
|
||||
[],
|
||||
kwargs,
|
||||
)
|
||||
|
||||
def describe(self):
|
||||
return 'Remove index %s from %s' % (self.name, self.model_name)
|
||||
|
||||
|
||||
class AddConstraint(IndexOperation):
|
||||
option_name = 'constraints'
|
||||
|
||||
def __init__(self, model_name, constraint):
|
||||
self.model_name = model_name
|
||||
self.constraint = constraint
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
model_state = state.models[app_label, self.model_name_lower]
|
||||
model_state.options[self.option_name] = [*model_state.options[self.option_name], self.constraint]
|
||||
state.reload_model(app_label, self.model_name_lower, delay=True)
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
model = to_state.apps.get_model(app_label, self.model_name)
|
||||
if self.allow_migrate_model(schema_editor.connection.alias, model):
|
||||
schema_editor.add_constraint(model, self.constraint)
|
||||
|
||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||
model = to_state.apps.get_model(app_label, self.model_name)
|
||||
if self.allow_migrate_model(schema_editor.connection.alias, model):
|
||||
schema_editor.remove_constraint(model, self.constraint)
|
||||
|
||||
def deconstruct(self):
|
||||
return self.__class__.__name__, [], {
|
||||
'model_name': self.model_name,
|
||||
'constraint': self.constraint,
|
||||
}
|
||||
|
||||
def describe(self):
|
||||
return 'Create constraint %s on model %s' % (self.constraint.name, self.model_name)
|
||||
|
||||
|
||||
class RemoveConstraint(IndexOperation):
|
||||
option_name = 'constraints'
|
||||
|
||||
def __init__(self, model_name, name):
|
||||
self.model_name = model_name
|
||||
self.name = name
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
model_state = state.models[app_label, self.model_name_lower]
|
||||
constraints = model_state.options[self.option_name]
|
||||
model_state.options[self.option_name] = [c for c in constraints if c.name != self.name]
|
||||
state.reload_model(app_label, self.model_name_lower, delay=True)
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
model = to_state.apps.get_model(app_label, self.model_name)
|
||||
if self.allow_migrate_model(schema_editor.connection.alias, model):
|
||||
from_model_state = from_state.models[app_label, self.model_name_lower]
|
||||
constraint = from_model_state.get_constraint_by_name(self.name)
|
||||
schema_editor.remove_constraint(model, constraint)
|
||||
|
||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||
model = to_state.apps.get_model(app_label, self.model_name)
|
||||
if self.allow_migrate_model(schema_editor.connection.alias, model):
|
||||
to_model_state = to_state.models[app_label, self.model_name_lower]
|
||||
constraint = to_model_state.get_constraint_by_name(self.name)
|
||||
schema_editor.add_constraint(model, constraint)
|
||||
|
||||
def deconstruct(self):
|
||||
return self.__class__.__name__, [], {
|
||||
'model_name': self.model_name,
|
||||
'name': self.name,
|
||||
}
|
||||
|
||||
def describe(self):
|
||||
return 'Remove constraint %s from model %s' % (self.name, self.model_name)
|
||||
@@ -0,0 +1,203 @@
|
||||
from django.db import router
|
||||
|
||||
from .base import Operation
|
||||
|
||||
|
||||
class SeparateDatabaseAndState(Operation):
|
||||
"""
|
||||
Take two lists of operations - ones that will be used for the database,
|
||||
and ones that will be used for the state change. This allows operations
|
||||
that don't support state change to have it applied, or have operations
|
||||
that affect the state or not the database, or so on.
|
||||
"""
|
||||
|
||||
serialization_expand_args = ['database_operations', 'state_operations']
|
||||
|
||||
def __init__(self, database_operations=None, state_operations=None):
|
||||
self.database_operations = database_operations or []
|
||||
self.state_operations = state_operations or []
|
||||
|
||||
def deconstruct(self):
|
||||
kwargs = {}
|
||||
if self.database_operations:
|
||||
kwargs['database_operations'] = self.database_operations
|
||||
if self.state_operations:
|
||||
kwargs['state_operations'] = self.state_operations
|
||||
return (
|
||||
self.__class__.__qualname__,
|
||||
[],
|
||||
kwargs
|
||||
)
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
for state_operation in self.state_operations:
|
||||
state_operation.state_forwards(app_label, state)
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
# We calculate state separately in here since our state functions aren't useful
|
||||
for database_operation in self.database_operations:
|
||||
to_state = from_state.clone()
|
||||
database_operation.state_forwards(app_label, to_state)
|
||||
database_operation.database_forwards(app_label, schema_editor, from_state, to_state)
|
||||
from_state = to_state
|
||||
|
||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||
# We calculate state separately in here since our state functions aren't useful
|
||||
to_states = {}
|
||||
for dbop in self.database_operations:
|
||||
to_states[dbop] = to_state
|
||||
to_state = to_state.clone()
|
||||
dbop.state_forwards(app_label, to_state)
|
||||
# to_state now has the states of all the database_operations applied
|
||||
# which is the from_state for the backwards migration of the last
|
||||
# operation.
|
||||
for database_operation in reversed(self.database_operations):
|
||||
from_state = to_state
|
||||
to_state = to_states[database_operation]
|
||||
database_operation.database_backwards(app_label, schema_editor, from_state, to_state)
|
||||
|
||||
def describe(self):
|
||||
return "Custom state/database change combination"
|
||||
|
||||
|
||||
class RunSQL(Operation):
|
||||
"""
|
||||
Run some raw SQL. A reverse SQL statement may be provided.
|
||||
|
||||
Also accept a list of operations that represent the state change effected
|
||||
by this SQL change, in case it's custom column/table creation/deletion.
|
||||
"""
|
||||
noop = ''
|
||||
|
||||
def __init__(self, sql, reverse_sql=None, state_operations=None, hints=None, elidable=False):
|
||||
self.sql = sql
|
||||
self.reverse_sql = reverse_sql
|
||||
self.state_operations = state_operations or []
|
||||
self.hints = hints or {}
|
||||
self.elidable = elidable
|
||||
|
||||
def deconstruct(self):
|
||||
kwargs = {
|
||||
'sql': self.sql,
|
||||
}
|
||||
if self.reverse_sql is not None:
|
||||
kwargs['reverse_sql'] = self.reverse_sql
|
||||
if self.state_operations:
|
||||
kwargs['state_operations'] = self.state_operations
|
||||
if self.hints:
|
||||
kwargs['hints'] = self.hints
|
||||
return (
|
||||
self.__class__.__qualname__,
|
||||
[],
|
||||
kwargs
|
||||
)
|
||||
|
||||
@property
|
||||
def reversible(self):
|
||||
return self.reverse_sql is not None
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
for state_operation in self.state_operations:
|
||||
state_operation.state_forwards(app_label, state)
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
if router.allow_migrate(schema_editor.connection.alias, app_label, **self.hints):
|
||||
self._run_sql(schema_editor, self.sql)
|
||||
|
||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||
if self.reverse_sql is None:
|
||||
raise NotImplementedError("You cannot reverse this operation")
|
||||
if router.allow_migrate(schema_editor.connection.alias, app_label, **self.hints):
|
||||
self._run_sql(schema_editor, self.reverse_sql)
|
||||
|
||||
def describe(self):
|
||||
return "Raw SQL operation"
|
||||
|
||||
def _run_sql(self, schema_editor, sqls):
|
||||
if isinstance(sqls, (list, tuple)):
|
||||
for sql in sqls:
|
||||
params = None
|
||||
if isinstance(sql, (list, tuple)):
|
||||
elements = len(sql)
|
||||
if elements == 2:
|
||||
sql, params = sql
|
||||
else:
|
||||
raise ValueError("Expected a 2-tuple but got %d" % elements)
|
||||
schema_editor.execute(sql, params=params)
|
||||
elif sqls != RunSQL.noop:
|
||||
statements = schema_editor.connection.ops.prepare_sql_script(sqls)
|
||||
for statement in statements:
|
||||
schema_editor.execute(statement, params=None)
|
||||
|
||||
|
||||
class RunPython(Operation):
|
||||
"""
|
||||
Run Python code in a context suitable for doing versioned ORM operations.
|
||||
"""
|
||||
|
||||
reduces_to_sql = False
|
||||
|
||||
def __init__(self, code, reverse_code=None, atomic=None, hints=None, elidable=False):
|
||||
self.atomic = atomic
|
||||
# Forwards code
|
||||
if not callable(code):
|
||||
raise ValueError("RunPython must be supplied with a callable")
|
||||
self.code = code
|
||||
# Reverse code
|
||||
if reverse_code is None:
|
||||
self.reverse_code = None
|
||||
else:
|
||||
if not callable(reverse_code):
|
||||
raise ValueError("RunPython must be supplied with callable arguments")
|
||||
self.reverse_code = reverse_code
|
||||
self.hints = hints or {}
|
||||
self.elidable = elidable
|
||||
|
||||
def deconstruct(self):
|
||||
kwargs = {
|
||||
'code': self.code,
|
||||
}
|
||||
if self.reverse_code is not None:
|
||||
kwargs['reverse_code'] = self.reverse_code
|
||||
if self.atomic is not None:
|
||||
kwargs['atomic'] = self.atomic
|
||||
if self.hints:
|
||||
kwargs['hints'] = self.hints
|
||||
return (
|
||||
self.__class__.__qualname__,
|
||||
[],
|
||||
kwargs
|
||||
)
|
||||
|
||||
@property
|
||||
def reversible(self):
|
||||
return self.reverse_code is not None
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
# RunPython objects have no state effect. To add some, combine this
|
||||
# with SeparateDatabaseAndState.
|
||||
pass
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
# RunPython has access to all models. Ensure that all models are
|
||||
# reloaded in case any are delayed.
|
||||
from_state.clear_delayed_apps_cache()
|
||||
if router.allow_migrate(schema_editor.connection.alias, app_label, **self.hints):
|
||||
# We now execute the Python code in a context that contains a 'models'
|
||||
# object, representing the versioned models as an app registry.
|
||||
# We could try to override the global cache, but then people will still
|
||||
# use direct imports, so we go with a documentation approach instead.
|
||||
self.code(from_state.apps, schema_editor)
|
||||
|
||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||
if self.reverse_code is None:
|
||||
raise NotImplementedError("You cannot reverse this operation")
|
||||
if router.allow_migrate(schema_editor.connection.alias, app_label, **self.hints):
|
||||
self.reverse_code(from_state.apps, schema_editor)
|
||||
|
||||
def describe(self):
|
||||
return "Raw Python operation"
|
||||
|
||||
@staticmethod
|
||||
def noop(apps, schema_editor):
|
||||
return None
|
||||
@@ -0,0 +1,53 @@
|
||||
from collections import namedtuple
|
||||
|
||||
from django.db.models.fields.related import RECURSIVE_RELATIONSHIP_CONSTANT
|
||||
|
||||
|
||||
def is_referenced_by_foreign_key(state, model_name_lower, field, field_name):
|
||||
for state_app_label, state_model in state.models:
|
||||
for _, f in state.models[state_app_label, state_model].fields:
|
||||
if (f.related_model and
|
||||
'%s.%s' % (state_app_label, model_name_lower) == f.related_model.lower() and
|
||||
hasattr(f, 'to_fields')):
|
||||
if (f.to_fields[0] is None and field.primary_key) or field_name in f.to_fields:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class ModelTuple(namedtuple('ModelTupleBase', ('app_label', 'model_name'))):
|
||||
@classmethod
|
||||
def from_model(cls, model, app_label=None, model_name=None):
|
||||
"""
|
||||
Take a model class or an 'app_label.ModelName' string and return a
|
||||
ModelTuple('app_label', 'modelname'). The optional app_label and
|
||||
model_name arguments are the defaults if "self" or "ModelName" are
|
||||
passed.
|
||||
"""
|
||||
if isinstance(model, str):
|
||||
if model == RECURSIVE_RELATIONSHIP_CONSTANT:
|
||||
return cls(app_label, model_name)
|
||||
if '.' in model:
|
||||
return cls(*model.lower().split('.', 1))
|
||||
return cls(app_label, model.lower())
|
||||
return cls(model._meta.app_label, model._meta.model_name)
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, ModelTuple):
|
||||
# Consider ModelTuple equal if their model_name is equal and either
|
||||
# one of them is missing an app_label.
|
||||
return self.model_name == other.model_name and (
|
||||
self.app_label is None or other.app_label is None or self.app_label == other.app_label
|
||||
)
|
||||
return super().__eq__(other)
|
||||
|
||||
|
||||
def field_references_model(field, model_tuple):
|
||||
"""Return whether or not field references model_tuple."""
|
||||
remote_field = field.remote_field
|
||||
if remote_field:
|
||||
if ModelTuple.from_model(remote_field.model) == model_tuple:
|
||||
return True
|
||||
through = getattr(remote_field, 'through', None)
|
||||
if through and ModelTuple.from_model(through) == model_tuple:
|
||||
return True
|
||||
return False
|
||||
@@ -0,0 +1,70 @@
|
||||
class MigrationOptimizer:
|
||||
"""
|
||||
Power the optimization process, where you provide a list of Operations
|
||||
and you are returned a list of equal or shorter length - operations
|
||||
are merged into one if possible.
|
||||
|
||||
For example, a CreateModel and an AddField can be optimized into a
|
||||
new CreateModel, and CreateModel and DeleteModel can be optimized into
|
||||
nothing.
|
||||
"""
|
||||
|
||||
def optimize(self, operations, app_label=None):
|
||||
"""
|
||||
Main optimization entry point. Pass in a list of Operation instances,
|
||||
get out a new list of Operation instances.
|
||||
|
||||
Unfortunately, due to the scope of the optimization (two combinable
|
||||
operations might be separated by several hundred others), this can't be
|
||||
done as a peephole optimization with checks/output implemented on
|
||||
the Operations themselves; instead, the optimizer looks at each
|
||||
individual operation and scans forwards in the list to see if there
|
||||
are any matches, stopping at boundaries - operations which can't
|
||||
be optimized over (RunSQL, operations on the same field/model, etc.)
|
||||
|
||||
The inner loop is run until the starting list is the same as the result
|
||||
list, and then the result is returned. This means that operation
|
||||
optimization must be stable and always return an equal or shorter list.
|
||||
|
||||
The app_label argument is optional, but if you pass it you'll get more
|
||||
efficient optimization.
|
||||
"""
|
||||
# Internal tracking variable for test assertions about # of loops
|
||||
self._iterations = 0
|
||||
while True:
|
||||
result = self.optimize_inner(operations, app_label)
|
||||
self._iterations += 1
|
||||
if result == operations:
|
||||
return result
|
||||
operations = result
|
||||
|
||||
def optimize_inner(self, operations, app_label=None):
|
||||
"""Inner optimization loop."""
|
||||
new_operations = []
|
||||
for i, operation in enumerate(operations):
|
||||
right = True # Should we reduce on the right or on the left.
|
||||
# Compare it to each operation after it
|
||||
for j, other in enumerate(operations[i + 1:]):
|
||||
in_between = operations[i + 1:i + j + 1]
|
||||
result = operation.reduce(other, app_label)
|
||||
if isinstance(result, list):
|
||||
if right:
|
||||
new_operations.extend(in_between)
|
||||
new_operations.extend(result)
|
||||
elif all(op.reduce(other, app_label) is True for op in in_between):
|
||||
# Perform a left reduction if all of the in-between
|
||||
# operations can optimize through other.
|
||||
new_operations.extend(result)
|
||||
new_operations.extend(in_between)
|
||||
else:
|
||||
# Otherwise keep trying.
|
||||
new_operations.append(operation)
|
||||
break
|
||||
new_operations.extend(operations[i + j + 2:])
|
||||
return new_operations
|
||||
elif not result:
|
||||
# Can't perform a right reduction.
|
||||
right = False
|
||||
else:
|
||||
new_operations.append(operation)
|
||||
return new_operations
|
||||
@@ -0,0 +1,239 @@
|
||||
import datetime
|
||||
import importlib
|
||||
import os
|
||||
import sys
|
||||
|
||||
from django.apps import apps
|
||||
from django.db.models.fields import NOT_PROVIDED
|
||||
from django.utils import timezone
|
||||
|
||||
from .loader import MigrationLoader
|
||||
|
||||
|
||||
class MigrationQuestioner:
|
||||
"""
|
||||
Give the autodetector responses to questions it might have.
|
||||
This base class has a built-in noninteractive mode, but the
|
||||
interactive subclass is what the command-line arguments will use.
|
||||
"""
|
||||
|
||||
def __init__(self, defaults=None, specified_apps=None, dry_run=None):
|
||||
self.defaults = defaults or {}
|
||||
self.specified_apps = specified_apps or set()
|
||||
self.dry_run = dry_run
|
||||
|
||||
def ask_initial(self, app_label):
|
||||
"""Should we create an initial migration for the app?"""
|
||||
# If it was specified on the command line, definitely true
|
||||
if app_label in self.specified_apps:
|
||||
return True
|
||||
# Otherwise, we look to see if it has a migrations module
|
||||
# without any Python files in it, apart from __init__.py.
|
||||
# Apps from the new app template will have these; the Python
|
||||
# file check will ensure we skip South ones.
|
||||
try:
|
||||
app_config = apps.get_app_config(app_label)
|
||||
except LookupError: # It's a fake app.
|
||||
return self.defaults.get("ask_initial", False)
|
||||
migrations_import_path, _ = MigrationLoader.migrations_module(app_config.label)
|
||||
if migrations_import_path is None:
|
||||
# It's an application with migrations disabled.
|
||||
return self.defaults.get("ask_initial", False)
|
||||
try:
|
||||
migrations_module = importlib.import_module(migrations_import_path)
|
||||
except ImportError:
|
||||
return self.defaults.get("ask_initial", False)
|
||||
else:
|
||||
# getattr() needed on PY36 and older (replace with attribute access).
|
||||
if getattr(migrations_module, "__file__", None):
|
||||
filenames = os.listdir(os.path.dirname(migrations_module.__file__))
|
||||
elif hasattr(migrations_module, "__path__"):
|
||||
if len(migrations_module.__path__) > 1:
|
||||
return False
|
||||
filenames = os.listdir(list(migrations_module.__path__)[0])
|
||||
return not any(x.endswith(".py") for x in filenames if x != "__init__.py")
|
||||
|
||||
def ask_not_null_addition(self, field_name, model_name):
|
||||
"""Adding a NOT NULL field to a model."""
|
||||
# None means quit
|
||||
return None
|
||||
|
||||
def ask_not_null_alteration(self, field_name, model_name):
|
||||
"""Changing a NULL field to NOT NULL."""
|
||||
# None means quit
|
||||
return None
|
||||
|
||||
def ask_rename(self, model_name, old_name, new_name, field_instance):
|
||||
"""Was this field really renamed?"""
|
||||
return self.defaults.get("ask_rename", False)
|
||||
|
||||
def ask_rename_model(self, old_model_state, new_model_state):
|
||||
"""Was this model really renamed?"""
|
||||
return self.defaults.get("ask_rename_model", False)
|
||||
|
||||
def ask_merge(self, app_label):
|
||||
"""Do you really want to merge these migrations?"""
|
||||
return self.defaults.get("ask_merge", False)
|
||||
|
||||
def ask_auto_now_add_addition(self, field_name, model_name):
|
||||
"""Adding an auto_now_add field to a model."""
|
||||
# None means quit
|
||||
return None
|
||||
|
||||
|
||||
class InteractiveMigrationQuestioner(MigrationQuestioner):
|
||||
|
||||
def _boolean_input(self, question, default=None):
|
||||
result = input("%s " % question)
|
||||
if not result and default is not None:
|
||||
return default
|
||||
while not result or result[0].lower() not in "yn":
|
||||
result = input("Please answer yes or no: ")
|
||||
return result[0].lower() == "y"
|
||||
|
||||
def _choice_input(self, question, choices):
|
||||
print(question)
|
||||
for i, choice in enumerate(choices):
|
||||
print(" %s) %s" % (i + 1, choice))
|
||||
result = input("Select an option: ")
|
||||
while True:
|
||||
try:
|
||||
value = int(result)
|
||||
except ValueError:
|
||||
pass
|
||||
else:
|
||||
if 0 < value <= len(choices):
|
||||
return value
|
||||
result = input("Please select a valid option: ")
|
||||
|
||||
def _ask_default(self, default=''):
|
||||
"""
|
||||
Prompt for a default value.
|
||||
|
||||
The ``default`` argument allows providing a custom default value (as a
|
||||
string) which will be shown to the user and used as the return value
|
||||
if the user doesn't provide any other input.
|
||||
"""
|
||||
print("Please enter the default value now, as valid Python")
|
||||
if default:
|
||||
print(
|
||||
"You can accept the default '{}' by pressing 'Enter' or you "
|
||||
"can provide another value.".format(default)
|
||||
)
|
||||
print("The datetime and django.utils.timezone modules are available, so you can do e.g. timezone.now")
|
||||
print("Type 'exit' to exit this prompt")
|
||||
while True:
|
||||
if default:
|
||||
prompt = "[default: {}] >>> ".format(default)
|
||||
else:
|
||||
prompt = ">>> "
|
||||
code = input(prompt)
|
||||
if not code and default:
|
||||
code = default
|
||||
if not code:
|
||||
print("Please enter some code, or 'exit' (with no quotes) to exit.")
|
||||
elif code == "exit":
|
||||
sys.exit(1)
|
||||
else:
|
||||
try:
|
||||
return eval(code, {}, {'datetime': datetime, 'timezone': timezone})
|
||||
except (SyntaxError, NameError) as e:
|
||||
print("Invalid input: %s" % e)
|
||||
|
||||
def ask_not_null_addition(self, field_name, model_name):
|
||||
"""Adding a NOT NULL field to a model."""
|
||||
if not self.dry_run:
|
||||
choice = self._choice_input(
|
||||
"You are trying to add a non-nullable field '%s' to %s without a default; "
|
||||
"we can't do that (the database needs something to populate existing rows).\n"
|
||||
"Please select a fix:" % (field_name, model_name),
|
||||
[
|
||||
("Provide a one-off default now (will be set on all existing "
|
||||
"rows with a null value for this column)"),
|
||||
"Quit, and let me add a default in models.py",
|
||||
]
|
||||
)
|
||||
if choice == 2:
|
||||
sys.exit(3)
|
||||
else:
|
||||
return self._ask_default()
|
||||
return None
|
||||
|
||||
def ask_not_null_alteration(self, field_name, model_name):
|
||||
"""Changing a NULL field to NOT NULL."""
|
||||
if not self.dry_run:
|
||||
choice = self._choice_input(
|
||||
"You are trying to change the nullable field '%s' on %s to non-nullable "
|
||||
"without a default; we can't do that (the database needs something to "
|
||||
"populate existing rows).\n"
|
||||
"Please select a fix:" % (field_name, model_name),
|
||||
[
|
||||
("Provide a one-off default now (will be set on all existing "
|
||||
"rows with a null value for this column)"),
|
||||
("Ignore for now, and let me handle existing rows with NULL myself "
|
||||
"(e.g. because you added a RunPython or RunSQL operation to handle "
|
||||
"NULL values in a previous data migration)"),
|
||||
"Quit, and let me add a default in models.py",
|
||||
]
|
||||
)
|
||||
if choice == 2:
|
||||
return NOT_PROVIDED
|
||||
elif choice == 3:
|
||||
sys.exit(3)
|
||||
else:
|
||||
return self._ask_default()
|
||||
return None
|
||||
|
||||
def ask_rename(self, model_name, old_name, new_name, field_instance):
|
||||
"""Was this field really renamed?"""
|
||||
msg = "Did you rename %s.%s to %s.%s (a %s)? [y/N]"
|
||||
return self._boolean_input(msg % (model_name, old_name, model_name, new_name,
|
||||
field_instance.__class__.__name__), False)
|
||||
|
||||
def ask_rename_model(self, old_model_state, new_model_state):
|
||||
"""Was this model really renamed?"""
|
||||
msg = "Did you rename the %s.%s model to %s? [y/N]"
|
||||
return self._boolean_input(msg % (old_model_state.app_label, old_model_state.name,
|
||||
new_model_state.name), False)
|
||||
|
||||
def ask_merge(self, app_label):
|
||||
return self._boolean_input(
|
||||
"\nMerging will only work if the operations printed above do not conflict\n" +
|
||||
"with each other (working on different fields or models)\n" +
|
||||
"Do you want to merge these migration branches? [y/N]",
|
||||
False,
|
||||
)
|
||||
|
||||
def ask_auto_now_add_addition(self, field_name, model_name):
|
||||
"""Adding an auto_now_add field to a model."""
|
||||
if not self.dry_run:
|
||||
choice = self._choice_input(
|
||||
"You are trying to add the field '{}' with 'auto_now_add=True' "
|
||||
"to {} without a default; the database needs something to "
|
||||
"populate existing rows.\n".format(field_name, model_name),
|
||||
[
|
||||
"Provide a one-off default now (will be set on all "
|
||||
"existing rows)",
|
||||
"Quit, and let me add a default in models.py",
|
||||
]
|
||||
)
|
||||
if choice == 2:
|
||||
sys.exit(3)
|
||||
else:
|
||||
return self._ask_default(default='timezone.now')
|
||||
return None
|
||||
|
||||
|
||||
class NonInteractiveMigrationQuestioner(MigrationQuestioner):
|
||||
|
||||
def ask_not_null_addition(self, field_name, model_name):
|
||||
# We can't ask the user, so act like the user aborted.
|
||||
sys.exit(3)
|
||||
|
||||
def ask_not_null_alteration(self, field_name, model_name):
|
||||
# We can't ask the user, so set as not provided.
|
||||
return NOT_PROVIDED
|
||||
|
||||
def ask_auto_now_add_addition(self, field_name, model_name):
|
||||
# We can't ask the user, so act like the user aborted.
|
||||
sys.exit(3)
|
||||
@@ -0,0 +1,95 @@
|
||||
from django.apps.registry import Apps
|
||||
from django.db import models
|
||||
from django.db.utils import DatabaseError
|
||||
from django.utils.decorators import classproperty
|
||||
from django.utils.timezone import now
|
||||
|
||||
from .exceptions import MigrationSchemaMissing
|
||||
|
||||
|
||||
class MigrationRecorder:
|
||||
"""
|
||||
Deal with storing migration records in the database.
|
||||
|
||||
Because this table is actually itself used for dealing with model
|
||||
creation, it's the one thing we can't do normally via migrations.
|
||||
We manually handle table creation/schema updating (using schema backend)
|
||||
and then have a floating model to do queries with.
|
||||
|
||||
If a migration is unapplied its row is removed from the table. Having
|
||||
a row in the table always means a migration is applied.
|
||||
"""
|
||||
_migration_class = None
|
||||
|
||||
@classproperty
|
||||
def Migration(cls):
|
||||
"""
|
||||
Lazy load to avoid AppRegistryNotReady if installed apps import
|
||||
MigrationRecorder.
|
||||
"""
|
||||
if cls._migration_class is None:
|
||||
class Migration(models.Model):
|
||||
app = models.CharField(max_length=255)
|
||||
name = models.CharField(max_length=255)
|
||||
applied = models.DateTimeField(default=now)
|
||||
|
||||
class Meta:
|
||||
apps = Apps()
|
||||
app_label = 'migrations'
|
||||
db_table = 'django_migrations'
|
||||
|
||||
def __str__(self):
|
||||
return 'Migration %s for %s' % (self.name, self.app)
|
||||
|
||||
cls._migration_class = Migration
|
||||
return cls._migration_class
|
||||
|
||||
def __init__(self, connection):
|
||||
self.connection = connection
|
||||
|
||||
@property
|
||||
def migration_qs(self):
|
||||
return self.Migration.objects.using(self.connection.alias)
|
||||
|
||||
def has_table(self):
|
||||
"""Return True if the django_migrations table exists."""
|
||||
return self.Migration._meta.db_table in self.connection.introspection.table_names(self.connection.cursor())
|
||||
|
||||
def ensure_schema(self):
|
||||
"""Ensure the table exists and has the correct schema."""
|
||||
# If the table's there, that's fine - we've never changed its schema
|
||||
# in the codebase.
|
||||
if self.has_table():
|
||||
return
|
||||
# Make the table
|
||||
try:
|
||||
with self.connection.schema_editor() as editor:
|
||||
editor.create_model(self.Migration)
|
||||
except DatabaseError as exc:
|
||||
raise MigrationSchemaMissing("Unable to create the django_migrations table (%s)" % exc)
|
||||
|
||||
def applied_migrations(self):
|
||||
"""
|
||||
Return a dict mapping (app_name, migration_name) to Migration instances
|
||||
for all applied migrations.
|
||||
"""
|
||||
if self.has_table():
|
||||
return {(migration.app, migration.name): migration for migration in self.migration_qs}
|
||||
else:
|
||||
# If the django_migrations table doesn't exist, then no migrations
|
||||
# are applied.
|
||||
return {}
|
||||
|
||||
def record_applied(self, app, name):
|
||||
"""Record that a migration was applied."""
|
||||
self.ensure_schema()
|
||||
self.migration_qs.create(app=app, name=name)
|
||||
|
||||
def record_unapplied(self, app, name):
|
||||
"""Record that a migration was unapplied."""
|
||||
self.ensure_schema()
|
||||
self.migration_qs.filter(app=app, name=name).delete()
|
||||
|
||||
def flush(self):
|
||||
"""Delete all migration records. Useful for testing migrations."""
|
||||
self.migration_qs.all().delete()
|
||||
@@ -0,0 +1,340 @@
|
||||
import builtins
|
||||
import collections.abc
|
||||
import datetime
|
||||
import decimal
|
||||
import enum
|
||||
import functools
|
||||
import math
|
||||
import re
|
||||
import types
|
||||
import uuid
|
||||
|
||||
from django.conf import SettingsReference
|
||||
from django.db import models
|
||||
from django.db.migrations.operations.base import Operation
|
||||
from django.db.migrations.utils import COMPILED_REGEX_TYPE, RegexObject
|
||||
from django.utils.functional import LazyObject, Promise
|
||||
from django.utils.timezone import utc
|
||||
from django.utils.version import get_docs_version
|
||||
|
||||
|
||||
class BaseSerializer:
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
def serialize(self):
|
||||
raise NotImplementedError('Subclasses of BaseSerializer must implement the serialize() method.')
|
||||
|
||||
|
||||
class BaseSequenceSerializer(BaseSerializer):
|
||||
def _format(self):
|
||||
raise NotImplementedError('Subclasses of BaseSequenceSerializer must implement the _format() method.')
|
||||
|
||||
def serialize(self):
|
||||
imports = set()
|
||||
strings = []
|
||||
for item in self.value:
|
||||
item_string, item_imports = serializer_factory(item).serialize()
|
||||
imports.update(item_imports)
|
||||
strings.append(item_string)
|
||||
value = self._format()
|
||||
return value % (", ".join(strings)), imports
|
||||
|
||||
|
||||
class BaseSimpleSerializer(BaseSerializer):
|
||||
def serialize(self):
|
||||
return repr(self.value), set()
|
||||
|
||||
|
||||
class ChoicesSerializer(BaseSerializer):
|
||||
def serialize(self):
|
||||
return serializer_factory(self.value.value).serialize()
|
||||
|
||||
|
||||
class DateTimeSerializer(BaseSerializer):
|
||||
"""For datetime.*, except datetime.datetime."""
|
||||
def serialize(self):
|
||||
return repr(self.value), {'import datetime'}
|
||||
|
||||
|
||||
class DatetimeDatetimeSerializer(BaseSerializer):
|
||||
"""For datetime.datetime."""
|
||||
def serialize(self):
|
||||
if self.value.tzinfo is not None and self.value.tzinfo != utc:
|
||||
self.value = self.value.astimezone(utc)
|
||||
imports = ["import datetime"]
|
||||
if self.value.tzinfo is not None:
|
||||
imports.append("from django.utils.timezone import utc")
|
||||
return repr(self.value).replace('<UTC>', 'utc'), set(imports)
|
||||
|
||||
|
||||
class DecimalSerializer(BaseSerializer):
|
||||
def serialize(self):
|
||||
return repr(self.value), {"from decimal import Decimal"}
|
||||
|
||||
|
||||
class DeconstructableSerializer(BaseSerializer):
|
||||
@staticmethod
|
||||
def serialize_deconstructed(path, args, kwargs):
|
||||
name, imports = DeconstructableSerializer._serialize_path(path)
|
||||
strings = []
|
||||
for arg in args:
|
||||
arg_string, arg_imports = serializer_factory(arg).serialize()
|
||||
strings.append(arg_string)
|
||||
imports.update(arg_imports)
|
||||
for kw, arg in sorted(kwargs.items()):
|
||||
arg_string, arg_imports = serializer_factory(arg).serialize()
|
||||
imports.update(arg_imports)
|
||||
strings.append("%s=%s" % (kw, arg_string))
|
||||
return "%s(%s)" % (name, ", ".join(strings)), imports
|
||||
|
||||
@staticmethod
|
||||
def _serialize_path(path):
|
||||
module, name = path.rsplit(".", 1)
|
||||
if module == "django.db.models":
|
||||
imports = {"from django.db import models"}
|
||||
name = "models.%s" % name
|
||||
else:
|
||||
imports = {"import %s" % module}
|
||||
name = path
|
||||
return name, imports
|
||||
|
||||
def serialize(self):
|
||||
return self.serialize_deconstructed(*self.value.deconstruct())
|
||||
|
||||
|
||||
class DictionarySerializer(BaseSerializer):
|
||||
def serialize(self):
|
||||
imports = set()
|
||||
strings = []
|
||||
for k, v in sorted(self.value.items()):
|
||||
k_string, k_imports = serializer_factory(k).serialize()
|
||||
v_string, v_imports = serializer_factory(v).serialize()
|
||||
imports.update(k_imports)
|
||||
imports.update(v_imports)
|
||||
strings.append((k_string, v_string))
|
||||
return "{%s}" % (", ".join("%s: %s" % (k, v) for k, v in strings)), imports
|
||||
|
||||
|
||||
class EnumSerializer(BaseSerializer):
|
||||
def serialize(self):
|
||||
enum_class = self.value.__class__
|
||||
module = enum_class.__module__
|
||||
return (
|
||||
'%s.%s[%r]' % (module, enum_class.__qualname__, self.value.name),
|
||||
{'import %s' % module},
|
||||
)
|
||||
|
||||
|
||||
class FloatSerializer(BaseSimpleSerializer):
|
||||
def serialize(self):
|
||||
if math.isnan(self.value) or math.isinf(self.value):
|
||||
return 'float("{}")'.format(self.value), set()
|
||||
return super().serialize()
|
||||
|
||||
|
||||
class FrozensetSerializer(BaseSequenceSerializer):
|
||||
def _format(self):
|
||||
return "frozenset([%s])"
|
||||
|
||||
|
||||
class FunctionTypeSerializer(BaseSerializer):
|
||||
def serialize(self):
|
||||
if getattr(self.value, "__self__", None) and isinstance(self.value.__self__, type):
|
||||
klass = self.value.__self__
|
||||
module = klass.__module__
|
||||
return "%s.%s.%s" % (module, klass.__name__, self.value.__name__), {"import %s" % module}
|
||||
# Further error checking
|
||||
if self.value.__name__ == '<lambda>':
|
||||
raise ValueError("Cannot serialize function: lambda")
|
||||
if self.value.__module__ is None:
|
||||
raise ValueError("Cannot serialize function %r: No module" % self.value)
|
||||
|
||||
module_name = self.value.__module__
|
||||
|
||||
if '<' not in self.value.__qualname__: # Qualname can include <locals>
|
||||
return '%s.%s' % (module_name, self.value.__qualname__), {'import %s' % self.value.__module__}
|
||||
|
||||
raise ValueError(
|
||||
'Could not find function %s in %s.\n' % (self.value.__name__, module_name)
|
||||
)
|
||||
|
||||
|
||||
class FunctoolsPartialSerializer(BaseSerializer):
|
||||
def serialize(self):
|
||||
# Serialize functools.partial() arguments
|
||||
func_string, func_imports = serializer_factory(self.value.func).serialize()
|
||||
args_string, args_imports = serializer_factory(self.value.args).serialize()
|
||||
keywords_string, keywords_imports = serializer_factory(self.value.keywords).serialize()
|
||||
# Add any imports needed by arguments
|
||||
imports = {'import functools', *func_imports, *args_imports, *keywords_imports}
|
||||
return (
|
||||
'functools.%s(%s, *%s, **%s)' % (
|
||||
self.value.__class__.__name__,
|
||||
func_string,
|
||||
args_string,
|
||||
keywords_string,
|
||||
),
|
||||
imports,
|
||||
)
|
||||
|
||||
|
||||
class IterableSerializer(BaseSerializer):
|
||||
def serialize(self):
|
||||
imports = set()
|
||||
strings = []
|
||||
for item in self.value:
|
||||
item_string, item_imports = serializer_factory(item).serialize()
|
||||
imports.update(item_imports)
|
||||
strings.append(item_string)
|
||||
# When len(strings)==0, the empty iterable should be serialized as
|
||||
# "()", not "(,)" because (,) is invalid Python syntax.
|
||||
value = "(%s)" if len(strings) != 1 else "(%s,)"
|
||||
return value % (", ".join(strings)), imports
|
||||
|
||||
|
||||
class ModelFieldSerializer(DeconstructableSerializer):
|
||||
def serialize(self):
|
||||
attr_name, path, args, kwargs = self.value.deconstruct()
|
||||
return self.serialize_deconstructed(path, args, kwargs)
|
||||
|
||||
|
||||
class ModelManagerSerializer(DeconstructableSerializer):
|
||||
def serialize(self):
|
||||
as_manager, manager_path, qs_path, args, kwargs = self.value.deconstruct()
|
||||
if as_manager:
|
||||
name, imports = self._serialize_path(qs_path)
|
||||
return "%s.as_manager()" % name, imports
|
||||
else:
|
||||
return self.serialize_deconstructed(manager_path, args, kwargs)
|
||||
|
||||
|
||||
class OperationSerializer(BaseSerializer):
|
||||
def serialize(self):
|
||||
from django.db.migrations.writer import OperationWriter
|
||||
string, imports = OperationWriter(self.value, indentation=0).serialize()
|
||||
# Nested operation, trailing comma is handled in upper OperationWriter._write()
|
||||
return string.rstrip(','), imports
|
||||
|
||||
|
||||
class RegexSerializer(BaseSerializer):
|
||||
def serialize(self):
|
||||
regex_pattern, pattern_imports = serializer_factory(self.value.pattern).serialize()
|
||||
# Turn off default implicit flags (e.g. re.U) because regexes with the
|
||||
# same implicit and explicit flags aren't equal.
|
||||
flags = self.value.flags ^ re.compile('').flags
|
||||
regex_flags, flag_imports = serializer_factory(flags).serialize()
|
||||
imports = {'import re', *pattern_imports, *flag_imports}
|
||||
args = [regex_pattern]
|
||||
if flags:
|
||||
args.append(regex_flags)
|
||||
return "re.compile(%s)" % ', '.join(args), imports
|
||||
|
||||
|
||||
class SequenceSerializer(BaseSequenceSerializer):
|
||||
def _format(self):
|
||||
return "[%s]"
|
||||
|
||||
|
||||
class SetSerializer(BaseSequenceSerializer):
|
||||
def _format(self):
|
||||
# Serialize as a set literal except when value is empty because {}
|
||||
# is an empty dict.
|
||||
return '{%s}' if self.value else 'set(%s)'
|
||||
|
||||
|
||||
class SettingsReferenceSerializer(BaseSerializer):
|
||||
def serialize(self):
|
||||
return "settings.%s" % self.value.setting_name, {"from django.conf import settings"}
|
||||
|
||||
|
||||
class TupleSerializer(BaseSequenceSerializer):
|
||||
def _format(self):
|
||||
# When len(value)==0, the empty tuple should be serialized as "()",
|
||||
# not "(,)" because (,) is invalid Python syntax.
|
||||
return "(%s)" if len(self.value) != 1 else "(%s,)"
|
||||
|
||||
|
||||
class TypeSerializer(BaseSerializer):
|
||||
def serialize(self):
|
||||
special_cases = [
|
||||
(models.Model, "models.Model", []),
|
||||
(type(None), 'type(None)', []),
|
||||
]
|
||||
for case, string, imports in special_cases:
|
||||
if case is self.value:
|
||||
return string, set(imports)
|
||||
if hasattr(self.value, "__module__"):
|
||||
module = self.value.__module__
|
||||
if module == builtins.__name__:
|
||||
return self.value.__name__, set()
|
||||
else:
|
||||
return "%s.%s" % (module, self.value.__qualname__), {"import %s" % module}
|
||||
|
||||
|
||||
class UUIDSerializer(BaseSerializer):
|
||||
def serialize(self):
|
||||
return "uuid.%s" % repr(self.value), {"import uuid"}
|
||||
|
||||
|
||||
class Serializer:
|
||||
_registry = {
|
||||
# Some of these are order-dependent.
|
||||
frozenset: FrozensetSerializer,
|
||||
list: SequenceSerializer,
|
||||
set: SetSerializer,
|
||||
tuple: TupleSerializer,
|
||||
dict: DictionarySerializer,
|
||||
models.Choices: ChoicesSerializer,
|
||||
enum.Enum: EnumSerializer,
|
||||
datetime.datetime: DatetimeDatetimeSerializer,
|
||||
(datetime.date, datetime.timedelta, datetime.time): DateTimeSerializer,
|
||||
SettingsReference: SettingsReferenceSerializer,
|
||||
float: FloatSerializer,
|
||||
(bool, int, type(None), bytes, str, range): BaseSimpleSerializer,
|
||||
decimal.Decimal: DecimalSerializer,
|
||||
(functools.partial, functools.partialmethod): FunctoolsPartialSerializer,
|
||||
(types.FunctionType, types.BuiltinFunctionType, types.MethodType): FunctionTypeSerializer,
|
||||
collections.abc.Iterable: IterableSerializer,
|
||||
(COMPILED_REGEX_TYPE, RegexObject): RegexSerializer,
|
||||
uuid.UUID: UUIDSerializer,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def register(cls, type_, serializer):
|
||||
if not issubclass(serializer, BaseSerializer):
|
||||
raise ValueError("'%s' must inherit from 'BaseSerializer'." % serializer.__name__)
|
||||
cls._registry[type_] = serializer
|
||||
|
||||
@classmethod
|
||||
def unregister(cls, type_):
|
||||
cls._registry.pop(type_)
|
||||
|
||||
|
||||
def serializer_factory(value):
|
||||
if isinstance(value, Promise):
|
||||
value = str(value)
|
||||
elif isinstance(value, LazyObject):
|
||||
# The unwrapped value is returned as the first item of the arguments
|
||||
# tuple.
|
||||
value = value.__reduce__()[1][0]
|
||||
|
||||
if isinstance(value, models.Field):
|
||||
return ModelFieldSerializer(value)
|
||||
if isinstance(value, models.manager.BaseManager):
|
||||
return ModelManagerSerializer(value)
|
||||
if isinstance(value, Operation):
|
||||
return OperationSerializer(value)
|
||||
if isinstance(value, type):
|
||||
return TypeSerializer(value)
|
||||
# Anything that knows how to deconstruct itself.
|
||||
if hasattr(value, 'deconstruct'):
|
||||
return DeconstructableSerializer(value)
|
||||
for type_, serializer_cls in Serializer._registry.items():
|
||||
if isinstance(value, type_):
|
||||
return serializer_cls(value)
|
||||
raise ValueError(
|
||||
"Cannot serialize: %r\nThere are some values Django cannot serialize into "
|
||||
"migration files.\nFor more, see https://docs.djangoproject.com/en/%s/"
|
||||
"topics/migrations/#migration-serializing" % (value, get_docs_version())
|
||||
)
|
||||
611
venv/lib/python3.8/site-packages/django/db/migrations/state.py
Normal file
611
venv/lib/python3.8/site-packages/django/db/migrations/state.py
Normal file
@@ -0,0 +1,611 @@
|
||||
import copy
|
||||
from contextlib import contextmanager
|
||||
|
||||
from django.apps import AppConfig
|
||||
from django.apps.registry import Apps, apps as global_apps
|
||||
from django.conf import settings
|
||||
from django.db import models
|
||||
from django.db.models.fields.proxy import OrderWrt
|
||||
from django.db.models.fields.related import RECURSIVE_RELATIONSHIP_CONSTANT
|
||||
from django.db.models.options import DEFAULT_NAMES, normalize_together
|
||||
from django.db.models.utils import make_model_tuple
|
||||
from django.utils.functional import cached_property
|
||||
from django.utils.module_loading import import_string
|
||||
from django.utils.version import get_docs_version
|
||||
|
||||
from .exceptions import InvalidBasesError
|
||||
|
||||
|
||||
def _get_app_label_and_model_name(model, app_label=''):
|
||||
if isinstance(model, str):
|
||||
split = model.split('.', 1)
|
||||
return tuple(split) if len(split) == 2 else (app_label, split[0])
|
||||
else:
|
||||
return model._meta.app_label, model._meta.model_name
|
||||
|
||||
|
||||
def _get_related_models(m):
|
||||
"""Return all models that have a direct relationship to the given model."""
|
||||
related_models = [
|
||||
subclass for subclass in m.__subclasses__()
|
||||
if issubclass(subclass, models.Model)
|
||||
]
|
||||
related_fields_models = set()
|
||||
for f in m._meta.get_fields(include_parents=True, include_hidden=True):
|
||||
if f.is_relation and f.related_model is not None and not isinstance(f.related_model, str):
|
||||
related_fields_models.add(f.model)
|
||||
related_models.append(f.related_model)
|
||||
# Reverse accessors of foreign keys to proxy models are attached to their
|
||||
# concrete proxied model.
|
||||
opts = m._meta
|
||||
if opts.proxy and m in related_fields_models:
|
||||
related_models.append(opts.concrete_model)
|
||||
return related_models
|
||||
|
||||
|
||||
def get_related_models_tuples(model):
|
||||
"""
|
||||
Return a list of typical (app_label, model_name) tuples for all related
|
||||
models for the given model.
|
||||
"""
|
||||
return {
|
||||
(rel_mod._meta.app_label, rel_mod._meta.model_name)
|
||||
for rel_mod in _get_related_models(model)
|
||||
}
|
||||
|
||||
|
||||
def get_related_models_recursive(model):
|
||||
"""
|
||||
Return all models that have a direct or indirect relationship
|
||||
to the given model.
|
||||
|
||||
Relationships are either defined by explicit relational fields, like
|
||||
ForeignKey, ManyToManyField or OneToOneField, or by inheriting from another
|
||||
model (a superclass is related to its subclasses, but not vice versa). Note,
|
||||
however, that a model inheriting from a concrete model is also related to
|
||||
its superclass through the implicit *_ptr OneToOneField on the subclass.
|
||||
"""
|
||||
seen = set()
|
||||
queue = _get_related_models(model)
|
||||
for rel_mod in queue:
|
||||
rel_app_label, rel_model_name = rel_mod._meta.app_label, rel_mod._meta.model_name
|
||||
if (rel_app_label, rel_model_name) in seen:
|
||||
continue
|
||||
seen.add((rel_app_label, rel_model_name))
|
||||
queue.extend(_get_related_models(rel_mod))
|
||||
return seen - {(model._meta.app_label, model._meta.model_name)}
|
||||
|
||||
|
||||
class ProjectState:
|
||||
"""
|
||||
Represent the entire project's overall state. This is the item that is
|
||||
passed around - do it here rather than at the app level so that cross-app
|
||||
FKs/etc. resolve properly.
|
||||
"""
|
||||
|
||||
def __init__(self, models=None, real_apps=None):
|
||||
self.models = models or {}
|
||||
# Apps to include from main registry, usually unmigrated ones
|
||||
self.real_apps = real_apps or []
|
||||
self.is_delayed = False
|
||||
|
||||
def add_model(self, model_state):
|
||||
app_label, model_name = model_state.app_label, model_state.name_lower
|
||||
self.models[(app_label, model_name)] = model_state
|
||||
if 'apps' in self.__dict__: # hasattr would cache the property
|
||||
self.reload_model(app_label, model_name)
|
||||
|
||||
def remove_model(self, app_label, model_name):
|
||||
del self.models[app_label, model_name]
|
||||
if 'apps' in self.__dict__: # hasattr would cache the property
|
||||
self.apps.unregister_model(app_label, model_name)
|
||||
# Need to do this explicitly since unregister_model() doesn't clear
|
||||
# the cache automatically (#24513)
|
||||
self.apps.clear_cache()
|
||||
|
||||
def _find_reload_model(self, app_label, model_name, delay=False):
|
||||
if delay:
|
||||
self.is_delayed = True
|
||||
|
||||
related_models = set()
|
||||
|
||||
try:
|
||||
old_model = self.apps.get_model(app_label, model_name)
|
||||
except LookupError:
|
||||
pass
|
||||
else:
|
||||
# Get all relations to and from the old model before reloading,
|
||||
# as _meta.apps may change
|
||||
if delay:
|
||||
related_models = get_related_models_tuples(old_model)
|
||||
else:
|
||||
related_models = get_related_models_recursive(old_model)
|
||||
|
||||
# Get all outgoing references from the model to be rendered
|
||||
model_state = self.models[(app_label, model_name)]
|
||||
# Directly related models are the models pointed to by ForeignKeys,
|
||||
# OneToOneFields, and ManyToManyFields.
|
||||
direct_related_models = set()
|
||||
for name, field in model_state.fields:
|
||||
if field.is_relation:
|
||||
if field.remote_field.model == RECURSIVE_RELATIONSHIP_CONSTANT:
|
||||
continue
|
||||
rel_app_label, rel_model_name = _get_app_label_and_model_name(field.related_model, app_label)
|
||||
direct_related_models.add((rel_app_label, rel_model_name.lower()))
|
||||
|
||||
# For all direct related models recursively get all related models.
|
||||
related_models.update(direct_related_models)
|
||||
for rel_app_label, rel_model_name in direct_related_models:
|
||||
try:
|
||||
rel_model = self.apps.get_model(rel_app_label, rel_model_name)
|
||||
except LookupError:
|
||||
pass
|
||||
else:
|
||||
if delay:
|
||||
related_models.update(get_related_models_tuples(rel_model))
|
||||
else:
|
||||
related_models.update(get_related_models_recursive(rel_model))
|
||||
|
||||
# Include the model itself
|
||||
related_models.add((app_label, model_name))
|
||||
|
||||
return related_models
|
||||
|
||||
def reload_model(self, app_label, model_name, delay=False):
|
||||
if 'apps' in self.__dict__: # hasattr would cache the property
|
||||
related_models = self._find_reload_model(app_label, model_name, delay)
|
||||
self._reload(related_models)
|
||||
|
||||
def reload_models(self, models, delay=True):
|
||||
if 'apps' in self.__dict__: # hasattr would cache the property
|
||||
related_models = set()
|
||||
for app_label, model_name in models:
|
||||
related_models.update(self._find_reload_model(app_label, model_name, delay))
|
||||
self._reload(related_models)
|
||||
|
||||
def _reload(self, related_models):
|
||||
# Unregister all related models
|
||||
with self.apps.bulk_update():
|
||||
for rel_app_label, rel_model_name in related_models:
|
||||
self.apps.unregister_model(rel_app_label, rel_model_name)
|
||||
|
||||
states_to_be_rendered = []
|
||||
# Gather all models states of those models that will be rerendered.
|
||||
# This includes:
|
||||
# 1. All related models of unmigrated apps
|
||||
for model_state in self.apps.real_models:
|
||||
if (model_state.app_label, model_state.name_lower) in related_models:
|
||||
states_to_be_rendered.append(model_state)
|
||||
|
||||
# 2. All related models of migrated apps
|
||||
for rel_app_label, rel_model_name in related_models:
|
||||
try:
|
||||
model_state = self.models[rel_app_label, rel_model_name]
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
states_to_be_rendered.append(model_state)
|
||||
|
||||
# Render all models
|
||||
self.apps.render_multiple(states_to_be_rendered)
|
||||
|
||||
def clone(self):
|
||||
"""Return an exact copy of this ProjectState."""
|
||||
new_state = ProjectState(
|
||||
models={k: v.clone() for k, v in self.models.items()},
|
||||
real_apps=self.real_apps,
|
||||
)
|
||||
if 'apps' in self.__dict__:
|
||||
new_state.apps = self.apps.clone()
|
||||
new_state.is_delayed = self.is_delayed
|
||||
return new_state
|
||||
|
||||
def clear_delayed_apps_cache(self):
|
||||
if self.is_delayed and 'apps' in self.__dict__:
|
||||
del self.__dict__['apps']
|
||||
|
||||
@cached_property
|
||||
def apps(self):
|
||||
return StateApps(self.real_apps, self.models)
|
||||
|
||||
@property
|
||||
def concrete_apps(self):
|
||||
self.apps = StateApps(self.real_apps, self.models, ignore_swappable=True)
|
||||
return self.apps
|
||||
|
||||
@classmethod
|
||||
def from_apps(cls, apps):
|
||||
"""Take an Apps and return a ProjectState matching it."""
|
||||
app_models = {}
|
||||
for model in apps.get_models(include_swapped=True):
|
||||
model_state = ModelState.from_model(model)
|
||||
app_models[(model_state.app_label, model_state.name_lower)] = model_state
|
||||
return cls(app_models)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.models == other.models and set(self.real_apps) == set(other.real_apps)
|
||||
|
||||
|
||||
class AppConfigStub(AppConfig):
|
||||
"""Stub of an AppConfig. Only provides a label and a dict of models."""
|
||||
# Not used, but required by AppConfig.__init__
|
||||
path = ''
|
||||
|
||||
def __init__(self, label):
|
||||
self.label = label
|
||||
# App-label and app-name are not the same thing, so technically passing
|
||||
# in the label here is wrong. In practice, migrations don't care about
|
||||
# the app name, but we need something unique, and the label works fine.
|
||||
super().__init__(label, None)
|
||||
|
||||
def import_models(self):
|
||||
self.models = self.apps.all_models[self.label]
|
||||
|
||||
|
||||
class StateApps(Apps):
|
||||
"""
|
||||
Subclass of the global Apps registry class to better handle dynamic model
|
||||
additions and removals.
|
||||
"""
|
||||
def __init__(self, real_apps, models, ignore_swappable=False):
|
||||
# Any apps in self.real_apps should have all their models included
|
||||
# in the render. We don't use the original model instances as there
|
||||
# are some variables that refer to the Apps object.
|
||||
# FKs/M2Ms from real apps are also not included as they just
|
||||
# mess things up with partial states (due to lack of dependencies)
|
||||
self.real_models = []
|
||||
for app_label in real_apps:
|
||||
app = global_apps.get_app_config(app_label)
|
||||
for model in app.get_models():
|
||||
self.real_models.append(ModelState.from_model(model, exclude_rels=True))
|
||||
# Populate the app registry with a stub for each application.
|
||||
app_labels = {model_state.app_label for model_state in models.values()}
|
||||
app_configs = [AppConfigStub(label) for label in sorted([*real_apps, *app_labels])]
|
||||
super().__init__(app_configs)
|
||||
|
||||
# These locks get in the way of copying as implemented in clone(),
|
||||
# which is called whenever Django duplicates a StateApps before
|
||||
# updating it.
|
||||
self._lock = None
|
||||
self.ready_event = None
|
||||
|
||||
self.render_multiple([*models.values(), *self.real_models])
|
||||
|
||||
# There shouldn't be any operations pending at this point.
|
||||
from django.core.checks.model_checks import _check_lazy_references
|
||||
ignore = {make_model_tuple(settings.AUTH_USER_MODEL)} if ignore_swappable else set()
|
||||
errors = _check_lazy_references(self, ignore=ignore)
|
||||
if errors:
|
||||
raise ValueError("\n".join(error.msg for error in errors))
|
||||
|
||||
@contextmanager
|
||||
def bulk_update(self):
|
||||
# Avoid clearing each model's cache for each change. Instead, clear
|
||||
# all caches when we're finished updating the model instances.
|
||||
ready = self.ready
|
||||
self.ready = False
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self.ready = ready
|
||||
self.clear_cache()
|
||||
|
||||
def render_multiple(self, model_states):
|
||||
# We keep trying to render the models in a loop, ignoring invalid
|
||||
# base errors, until the size of the unrendered models doesn't
|
||||
# decrease by at least one, meaning there's a base dependency loop/
|
||||
# missing base.
|
||||
if not model_states:
|
||||
return
|
||||
# Prevent that all model caches are expired for each render.
|
||||
with self.bulk_update():
|
||||
unrendered_models = model_states
|
||||
while unrendered_models:
|
||||
new_unrendered_models = []
|
||||
for model in unrendered_models:
|
||||
try:
|
||||
model.render(self)
|
||||
except InvalidBasesError:
|
||||
new_unrendered_models.append(model)
|
||||
if len(new_unrendered_models) == len(unrendered_models):
|
||||
raise InvalidBasesError(
|
||||
"Cannot resolve bases for %r\nThis can happen if you are inheriting models from an "
|
||||
"app with migrations (e.g. contrib.auth)\n in an app with no migrations; see "
|
||||
"https://docs.djangoproject.com/en/%s/topics/migrations/#dependencies "
|
||||
"for more" % (new_unrendered_models, get_docs_version())
|
||||
)
|
||||
unrendered_models = new_unrendered_models
|
||||
|
||||
def clone(self):
|
||||
"""Return a clone of this registry."""
|
||||
clone = StateApps([], {})
|
||||
clone.all_models = copy.deepcopy(self.all_models)
|
||||
clone.app_configs = copy.deepcopy(self.app_configs)
|
||||
# Set the pointer to the correct app registry.
|
||||
for app_config in clone.app_configs.values():
|
||||
app_config.apps = clone
|
||||
# No need to actually clone them, they'll never change
|
||||
clone.real_models = self.real_models
|
||||
return clone
|
||||
|
||||
def register_model(self, app_label, model):
|
||||
self.all_models[app_label][model._meta.model_name] = model
|
||||
if app_label not in self.app_configs:
|
||||
self.app_configs[app_label] = AppConfigStub(app_label)
|
||||
self.app_configs[app_label].apps = self
|
||||
self.app_configs[app_label].models = {}
|
||||
self.app_configs[app_label].models[model._meta.model_name] = model
|
||||
self.do_pending_operations(model)
|
||||
self.clear_cache()
|
||||
|
||||
def unregister_model(self, app_label, model_name):
|
||||
try:
|
||||
del self.all_models[app_label][model_name]
|
||||
del self.app_configs[app_label].models[model_name]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
|
||||
class ModelState:
|
||||
"""
|
||||
Represent a Django Model. Don't use the actual Model class as it's not
|
||||
designed to have its options changed - instead, mutate this one and then
|
||||
render it into a Model as required.
|
||||
|
||||
Note that while you are allowed to mutate .fields, you are not allowed
|
||||
to mutate the Field instances inside there themselves - you must instead
|
||||
assign new ones, as these are not detached during a clone.
|
||||
"""
|
||||
|
||||
def __init__(self, app_label, name, fields, options=None, bases=None, managers=None):
|
||||
self.app_label = app_label
|
||||
self.name = name
|
||||
self.fields = fields
|
||||
self.options = options or {}
|
||||
self.options.setdefault('indexes', [])
|
||||
self.options.setdefault('constraints', [])
|
||||
self.bases = bases or (models.Model,)
|
||||
self.managers = managers or []
|
||||
# Sanity-check that fields is NOT a dict. It must be ordered.
|
||||
if isinstance(self.fields, dict):
|
||||
raise ValueError("ModelState.fields cannot be a dict - it must be a list of 2-tuples.")
|
||||
for name, field in fields:
|
||||
# Sanity-check that fields are NOT already bound to a model.
|
||||
if hasattr(field, 'model'):
|
||||
raise ValueError(
|
||||
'ModelState.fields cannot be bound to a model - "%s" is.' % name
|
||||
)
|
||||
# Sanity-check that relation fields are NOT referring to a model class.
|
||||
if field.is_relation and hasattr(field.related_model, '_meta'):
|
||||
raise ValueError(
|
||||
'ModelState.fields cannot refer to a model class - "%s.to" does. '
|
||||
'Use a string reference instead.' % name
|
||||
)
|
||||
if field.many_to_many and hasattr(field.remote_field.through, '_meta'):
|
||||
raise ValueError(
|
||||
'ModelState.fields cannot refer to a model class - "%s.through" does. '
|
||||
'Use a string reference instead.' % name
|
||||
)
|
||||
# Sanity-check that indexes have their name set.
|
||||
for index in self.options['indexes']:
|
||||
if not index.name:
|
||||
raise ValueError(
|
||||
"Indexes passed to ModelState require a name attribute. "
|
||||
"%r doesn't have one." % index
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def name_lower(self):
|
||||
return self.name.lower()
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, model, exclude_rels=False):
|
||||
"""Given a model, return a ModelState representing it."""
|
||||
# Deconstruct the fields
|
||||
fields = []
|
||||
for field in model._meta.local_fields:
|
||||
if getattr(field, "remote_field", None) and exclude_rels:
|
||||
continue
|
||||
if isinstance(field, OrderWrt):
|
||||
continue
|
||||
name = field.name
|
||||
try:
|
||||
fields.append((name, field.clone()))
|
||||
except TypeError as e:
|
||||
raise TypeError("Couldn't reconstruct field %s on %s: %s" % (
|
||||
name,
|
||||
model._meta.label,
|
||||
e,
|
||||
))
|
||||
if not exclude_rels:
|
||||
for field in model._meta.local_many_to_many:
|
||||
name = field.name
|
||||
try:
|
||||
fields.append((name, field.clone()))
|
||||
except TypeError as e:
|
||||
raise TypeError("Couldn't reconstruct m2m field %s on %s: %s" % (
|
||||
name,
|
||||
model._meta.object_name,
|
||||
e,
|
||||
))
|
||||
# Extract the options
|
||||
options = {}
|
||||
for name in DEFAULT_NAMES:
|
||||
# Ignore some special options
|
||||
if name in ["apps", "app_label"]:
|
||||
continue
|
||||
elif name in model._meta.original_attrs:
|
||||
if name == "unique_together":
|
||||
ut = model._meta.original_attrs["unique_together"]
|
||||
options[name] = set(normalize_together(ut))
|
||||
elif name == "index_together":
|
||||
it = model._meta.original_attrs["index_together"]
|
||||
options[name] = set(normalize_together(it))
|
||||
elif name == "indexes":
|
||||
indexes = [idx.clone() for idx in model._meta.indexes]
|
||||
for index in indexes:
|
||||
if not index.name:
|
||||
index.set_name_with_model(model)
|
||||
options['indexes'] = indexes
|
||||
elif name == 'constraints':
|
||||
options['constraints'] = [con.clone() for con in model._meta.constraints]
|
||||
else:
|
||||
options[name] = model._meta.original_attrs[name]
|
||||
# If we're ignoring relationships, remove all field-listing model
|
||||
# options (that option basically just means "make a stub model")
|
||||
if exclude_rels:
|
||||
for key in ["unique_together", "index_together", "order_with_respect_to"]:
|
||||
if key in options:
|
||||
del options[key]
|
||||
# Private fields are ignored, so remove options that refer to them.
|
||||
elif options.get('order_with_respect_to') in {field.name for field in model._meta.private_fields}:
|
||||
del options['order_with_respect_to']
|
||||
|
||||
def flatten_bases(model):
|
||||
bases = []
|
||||
for base in model.__bases__:
|
||||
if hasattr(base, "_meta") and base._meta.abstract:
|
||||
bases.extend(flatten_bases(base))
|
||||
else:
|
||||
bases.append(base)
|
||||
return bases
|
||||
|
||||
# We can't rely on __mro__ directly because we only want to flatten
|
||||
# abstract models and not the whole tree. However by recursing on
|
||||
# __bases__ we may end up with duplicates and ordering issues, we
|
||||
# therefore discard any duplicates and reorder the bases according
|
||||
# to their index in the MRO.
|
||||
flattened_bases = sorted(set(flatten_bases(model)), key=lambda x: model.__mro__.index(x))
|
||||
|
||||
# Make our record
|
||||
bases = tuple(
|
||||
(
|
||||
base._meta.label_lower
|
||||
if hasattr(base, "_meta") else
|
||||
base
|
||||
)
|
||||
for base in flattened_bases
|
||||
)
|
||||
# Ensure at least one base inherits from models.Model
|
||||
if not any((isinstance(base, str) or issubclass(base, models.Model)) for base in bases):
|
||||
bases = (models.Model,)
|
||||
|
||||
managers = []
|
||||
manager_names = set()
|
||||
default_manager_shim = None
|
||||
for manager in model._meta.managers:
|
||||
if manager.name in manager_names:
|
||||
# Skip overridden managers.
|
||||
continue
|
||||
elif manager.use_in_migrations:
|
||||
# Copy managers usable in migrations.
|
||||
new_manager = copy.copy(manager)
|
||||
new_manager._set_creation_counter()
|
||||
elif manager is model._base_manager or manager is model._default_manager:
|
||||
# Shim custom managers used as default and base managers.
|
||||
new_manager = models.Manager()
|
||||
new_manager.model = manager.model
|
||||
new_manager.name = manager.name
|
||||
if manager is model._default_manager:
|
||||
default_manager_shim = new_manager
|
||||
else:
|
||||
continue
|
||||
manager_names.add(manager.name)
|
||||
managers.append((manager.name, new_manager))
|
||||
|
||||
# Ignore a shimmed default manager called objects if it's the only one.
|
||||
if managers == [('objects', default_manager_shim)]:
|
||||
managers = []
|
||||
|
||||
# Construct the new ModelState
|
||||
return cls(
|
||||
model._meta.app_label,
|
||||
model._meta.object_name,
|
||||
fields,
|
||||
options,
|
||||
bases,
|
||||
managers,
|
||||
)
|
||||
|
||||
def construct_managers(self):
|
||||
"""Deep-clone the managers using deconstruction."""
|
||||
# Sort all managers by their creation counter
|
||||
sorted_managers = sorted(self.managers, key=lambda v: v[1].creation_counter)
|
||||
for mgr_name, manager in sorted_managers:
|
||||
as_manager, manager_path, qs_path, args, kwargs = manager.deconstruct()
|
||||
if as_manager:
|
||||
qs_class = import_string(qs_path)
|
||||
yield mgr_name, qs_class.as_manager()
|
||||
else:
|
||||
manager_class = import_string(manager_path)
|
||||
yield mgr_name, manager_class(*args, **kwargs)
|
||||
|
||||
def clone(self):
|
||||
"""Return an exact copy of this ModelState."""
|
||||
return self.__class__(
|
||||
app_label=self.app_label,
|
||||
name=self.name,
|
||||
fields=list(self.fields),
|
||||
# Since options are shallow-copied here, operations such as
|
||||
# AddIndex must replace their option (e.g 'indexes') rather
|
||||
# than mutating it.
|
||||
options=dict(self.options),
|
||||
bases=self.bases,
|
||||
managers=list(self.managers),
|
||||
)
|
||||
|
||||
def render(self, apps):
|
||||
"""Create a Model object from our current state into the given apps."""
|
||||
# First, make a Meta object
|
||||
meta_contents = {'app_label': self.app_label, 'apps': apps, **self.options}
|
||||
meta = type("Meta", (), meta_contents)
|
||||
# Then, work out our bases
|
||||
try:
|
||||
bases = tuple(
|
||||
(apps.get_model(base) if isinstance(base, str) else base)
|
||||
for base in self.bases
|
||||
)
|
||||
except LookupError:
|
||||
raise InvalidBasesError("Cannot resolve one or more bases from %r" % (self.bases,))
|
||||
# Turn fields into a dict for the body, add other bits
|
||||
body = {name: field.clone() for name, field in self.fields}
|
||||
body['Meta'] = meta
|
||||
body['__module__'] = "__fake__"
|
||||
|
||||
# Restore managers
|
||||
body.update(self.construct_managers())
|
||||
# Then, make a Model object (apps.register_model is called in __new__)
|
||||
return type(self.name, bases, body)
|
||||
|
||||
def get_field_by_name(self, name):
|
||||
for fname, field in self.fields:
|
||||
if fname == name:
|
||||
return field
|
||||
raise ValueError("No field called %s on model %s" % (name, self.name))
|
||||
|
||||
def get_index_by_name(self, name):
|
||||
for index in self.options['indexes']:
|
||||
if index.name == name:
|
||||
return index
|
||||
raise ValueError("No index named %s on model %s" % (name, self.name))
|
||||
|
||||
def get_constraint_by_name(self, name):
|
||||
for constraint in self.options['constraints']:
|
||||
if constraint.name == name:
|
||||
return constraint
|
||||
raise ValueError('No constraint named %s on model %s' % (name, self.name))
|
||||
|
||||
def __repr__(self):
|
||||
return "<%s: '%s.%s'>" % (self.__class__.__name__, self.app_label, self.name)
|
||||
|
||||
def __eq__(self, other):
|
||||
return (
|
||||
(self.app_label == other.app_label) and
|
||||
(self.name == other.name) and
|
||||
(len(self.fields) == len(other.fields)) and
|
||||
all((k1 == k2 and (f1.deconstruct()[1:] == f2.deconstruct()[1:]))
|
||||
for (k1, f1), (k2, f2) in zip(self.fields, other.fields)) and
|
||||
(self.options == other.options) and
|
||||
(self.bases == other.bases) and
|
||||
(self.managers == other.managers)
|
||||
)
|
||||
@@ -0,0 +1,17 @@
|
||||
import datetime
|
||||
import re
|
||||
|
||||
COMPILED_REGEX_TYPE = type(re.compile(''))
|
||||
|
||||
|
||||
class RegexObject:
|
||||
def __init__(self, obj):
|
||||
self.pattern = obj.pattern
|
||||
self.flags = obj.flags
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.pattern == other.pattern and self.flags == other.flags
|
||||
|
||||
|
||||
def get_migration_name_timestamp():
|
||||
return datetime.datetime.now().strftime("%Y%m%d_%H%M")
|
||||
300
venv/lib/python3.8/site-packages/django/db/migrations/writer.py
Normal file
300
venv/lib/python3.8/site-packages/django/db/migrations/writer.py
Normal file
@@ -0,0 +1,300 @@
|
||||
|
||||
import os
|
||||
import re
|
||||
from importlib import import_module
|
||||
|
||||
from django import get_version
|
||||
from django.apps import apps
|
||||
# SettingsReference imported for backwards compatibility in Django 2.2.
|
||||
from django.conf import SettingsReference # NOQA
|
||||
from django.db import migrations
|
||||
from django.db.migrations.loader import MigrationLoader
|
||||
from django.db.migrations.serializer import Serializer, serializer_factory
|
||||
from django.utils.inspect import get_func_args
|
||||
from django.utils.module_loading import module_dir
|
||||
from django.utils.timezone import now
|
||||
|
||||
|
||||
class OperationWriter:
|
||||
def __init__(self, operation, indentation=2):
|
||||
self.operation = operation
|
||||
self.buff = []
|
||||
self.indentation = indentation
|
||||
|
||||
def serialize(self):
|
||||
|
||||
def _write(_arg_name, _arg_value):
|
||||
if (_arg_name in self.operation.serialization_expand_args and
|
||||
isinstance(_arg_value, (list, tuple, dict))):
|
||||
if isinstance(_arg_value, dict):
|
||||
self.feed('%s={' % _arg_name)
|
||||
self.indent()
|
||||
for key, value in _arg_value.items():
|
||||
key_string, key_imports = MigrationWriter.serialize(key)
|
||||
arg_string, arg_imports = MigrationWriter.serialize(value)
|
||||
args = arg_string.splitlines()
|
||||
if len(args) > 1:
|
||||
self.feed('%s: %s' % (key_string, args[0]))
|
||||
for arg in args[1:-1]:
|
||||
self.feed(arg)
|
||||
self.feed('%s,' % args[-1])
|
||||
else:
|
||||
self.feed('%s: %s,' % (key_string, arg_string))
|
||||
imports.update(key_imports)
|
||||
imports.update(arg_imports)
|
||||
self.unindent()
|
||||
self.feed('},')
|
||||
else:
|
||||
self.feed('%s=[' % _arg_name)
|
||||
self.indent()
|
||||
for item in _arg_value:
|
||||
arg_string, arg_imports = MigrationWriter.serialize(item)
|
||||
args = arg_string.splitlines()
|
||||
if len(args) > 1:
|
||||
for arg in args[:-1]:
|
||||
self.feed(arg)
|
||||
self.feed('%s,' % args[-1])
|
||||
else:
|
||||
self.feed('%s,' % arg_string)
|
||||
imports.update(arg_imports)
|
||||
self.unindent()
|
||||
self.feed('],')
|
||||
else:
|
||||
arg_string, arg_imports = MigrationWriter.serialize(_arg_value)
|
||||
args = arg_string.splitlines()
|
||||
if len(args) > 1:
|
||||
self.feed('%s=%s' % (_arg_name, args[0]))
|
||||
for arg in args[1:-1]:
|
||||
self.feed(arg)
|
||||
self.feed('%s,' % args[-1])
|
||||
else:
|
||||
self.feed('%s=%s,' % (_arg_name, arg_string))
|
||||
imports.update(arg_imports)
|
||||
|
||||
imports = set()
|
||||
name, args, kwargs = self.operation.deconstruct()
|
||||
operation_args = get_func_args(self.operation.__init__)
|
||||
|
||||
# See if this operation is in django.db.migrations. If it is,
|
||||
# We can just use the fact we already have that imported,
|
||||
# otherwise, we need to add an import for the operation class.
|
||||
if getattr(migrations, name, None) == self.operation.__class__:
|
||||
self.feed('migrations.%s(' % name)
|
||||
else:
|
||||
imports.add('import %s' % (self.operation.__class__.__module__))
|
||||
self.feed('%s.%s(' % (self.operation.__class__.__module__, name))
|
||||
|
||||
self.indent()
|
||||
|
||||
for i, arg in enumerate(args):
|
||||
arg_value = arg
|
||||
arg_name = operation_args[i]
|
||||
_write(arg_name, arg_value)
|
||||
|
||||
i = len(args)
|
||||
# Only iterate over remaining arguments
|
||||
for arg_name in operation_args[i:]:
|
||||
if arg_name in kwargs: # Don't sort to maintain signature order
|
||||
arg_value = kwargs[arg_name]
|
||||
_write(arg_name, arg_value)
|
||||
|
||||
self.unindent()
|
||||
self.feed('),')
|
||||
return self.render(), imports
|
||||
|
||||
def indent(self):
|
||||
self.indentation += 1
|
||||
|
||||
def unindent(self):
|
||||
self.indentation -= 1
|
||||
|
||||
def feed(self, line):
|
||||
self.buff.append(' ' * (self.indentation * 4) + line)
|
||||
|
||||
def render(self):
|
||||
return '\n'.join(self.buff)
|
||||
|
||||
|
||||
class MigrationWriter:
|
||||
"""
|
||||
Take a Migration instance and is able to produce the contents
|
||||
of the migration file from it.
|
||||
"""
|
||||
|
||||
def __init__(self, migration, include_header=True):
|
||||
self.migration = migration
|
||||
self.include_header = include_header
|
||||
self.needs_manual_porting = False
|
||||
|
||||
def as_string(self):
|
||||
"""Return a string of the file contents."""
|
||||
items = {
|
||||
"replaces_str": "",
|
||||
"initial_str": "",
|
||||
}
|
||||
|
||||
imports = set()
|
||||
|
||||
# Deconstruct operations
|
||||
operations = []
|
||||
for operation in self.migration.operations:
|
||||
operation_string, operation_imports = OperationWriter(operation).serialize()
|
||||
imports.update(operation_imports)
|
||||
operations.append(operation_string)
|
||||
items["operations"] = "\n".join(operations) + "\n" if operations else ""
|
||||
|
||||
# Format dependencies and write out swappable dependencies right
|
||||
dependencies = []
|
||||
for dependency in self.migration.dependencies:
|
||||
if dependency[0] == "__setting__":
|
||||
dependencies.append(" migrations.swappable_dependency(settings.%s)," % dependency[1])
|
||||
imports.add("from django.conf import settings")
|
||||
else:
|
||||
dependencies.append(" %s," % self.serialize(dependency)[0])
|
||||
items["dependencies"] = "\n".join(dependencies) + "\n" if dependencies else ""
|
||||
|
||||
# Format imports nicely, swapping imports of functions from migration files
|
||||
# for comments
|
||||
migration_imports = set()
|
||||
for line in list(imports):
|
||||
if re.match(r"^import (.*)\.\d+[^\s]*$", line):
|
||||
migration_imports.add(line.split("import")[1].strip())
|
||||
imports.remove(line)
|
||||
self.needs_manual_porting = True
|
||||
|
||||
# django.db.migrations is always used, but models import may not be.
|
||||
# If models import exists, merge it with migrations import.
|
||||
if "from django.db import models" in imports:
|
||||
imports.discard("from django.db import models")
|
||||
imports.add("from django.db import migrations, models")
|
||||
else:
|
||||
imports.add("from django.db import migrations")
|
||||
|
||||
# Sort imports by the package / module to be imported (the part after
|
||||
# "from" in "from ... import ..." or after "import" in "import ...").
|
||||
sorted_imports = sorted(imports, key=lambda i: i.split()[1])
|
||||
items["imports"] = "\n".join(sorted_imports) + "\n" if imports else ""
|
||||
if migration_imports:
|
||||
items["imports"] += (
|
||||
"\n\n# Functions from the following migrations need manual "
|
||||
"copying.\n# Move them and any dependencies into this file, "
|
||||
"then update the\n# RunPython operations to refer to the local "
|
||||
"versions:\n# %s"
|
||||
) % "\n# ".join(sorted(migration_imports))
|
||||
# If there's a replaces, make a string for it
|
||||
if self.migration.replaces:
|
||||
items['replaces_str'] = "\n replaces = %s\n" % self.serialize(self.migration.replaces)[0]
|
||||
# Hinting that goes into comment
|
||||
if self.include_header:
|
||||
items['migration_header'] = MIGRATION_HEADER_TEMPLATE % {
|
||||
'version': get_version(),
|
||||
'timestamp': now().strftime("%Y-%m-%d %H:%M"),
|
||||
}
|
||||
else:
|
||||
items['migration_header'] = ""
|
||||
|
||||
if self.migration.initial:
|
||||
items['initial_str'] = "\n initial = True\n"
|
||||
|
||||
return MIGRATION_TEMPLATE % items
|
||||
|
||||
@property
|
||||
def basedir(self):
|
||||
migrations_package_name, _ = MigrationLoader.migrations_module(self.migration.app_label)
|
||||
|
||||
if migrations_package_name is None:
|
||||
raise ValueError(
|
||||
"Django can't create migrations for app '%s' because "
|
||||
"migrations have been disabled via the MIGRATION_MODULES "
|
||||
"setting." % self.migration.app_label
|
||||
)
|
||||
|
||||
# See if we can import the migrations module directly
|
||||
try:
|
||||
migrations_module = import_module(migrations_package_name)
|
||||
except ImportError:
|
||||
pass
|
||||
else:
|
||||
try:
|
||||
return module_dir(migrations_module)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Alright, see if it's a direct submodule of the app
|
||||
app_config = apps.get_app_config(self.migration.app_label)
|
||||
maybe_app_name, _, migrations_package_basename = migrations_package_name.rpartition(".")
|
||||
if app_config.name == maybe_app_name:
|
||||
return os.path.join(app_config.path, migrations_package_basename)
|
||||
|
||||
# In case of using MIGRATION_MODULES setting and the custom package
|
||||
# doesn't exist, create one, starting from an existing package
|
||||
existing_dirs, missing_dirs = migrations_package_name.split("."), []
|
||||
while existing_dirs:
|
||||
missing_dirs.insert(0, existing_dirs.pop(-1))
|
||||
try:
|
||||
base_module = import_module(".".join(existing_dirs))
|
||||
except (ImportError, ValueError):
|
||||
continue
|
||||
else:
|
||||
try:
|
||||
base_dir = module_dir(base_module)
|
||||
except ValueError:
|
||||
continue
|
||||
else:
|
||||
break
|
||||
else:
|
||||
raise ValueError(
|
||||
"Could not locate an appropriate location to create "
|
||||
"migrations package %s. Make sure the toplevel "
|
||||
"package exists and can be imported." %
|
||||
migrations_package_name)
|
||||
|
||||
final_dir = os.path.join(base_dir, *missing_dirs)
|
||||
os.makedirs(final_dir, exist_ok=True)
|
||||
for missing_dir in missing_dirs:
|
||||
base_dir = os.path.join(base_dir, missing_dir)
|
||||
with open(os.path.join(base_dir, "__init__.py"), "w"):
|
||||
pass
|
||||
|
||||
return final_dir
|
||||
|
||||
@property
|
||||
def filename(self):
|
||||
return "%s.py" % self.migration.name
|
||||
|
||||
@property
|
||||
def path(self):
|
||||
return os.path.join(self.basedir, self.filename)
|
||||
|
||||
@classmethod
|
||||
def serialize(cls, value):
|
||||
return serializer_factory(value).serialize()
|
||||
|
||||
@classmethod
|
||||
def register_serializer(cls, type_, serializer):
|
||||
Serializer.register(type_, serializer)
|
||||
|
||||
@classmethod
|
||||
def unregister_serializer(cls, type_):
|
||||
Serializer.unregister(type_)
|
||||
|
||||
|
||||
MIGRATION_HEADER_TEMPLATE = """\
|
||||
# Generated by Django %(version)s on %(timestamp)s
|
||||
|
||||
"""
|
||||
|
||||
|
||||
MIGRATION_TEMPLATE = """\
|
||||
%(migration_header)s%(imports)s
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
%(replaces_str)s%(initial_str)s
|
||||
dependencies = [
|
||||
%(dependencies)s\
|
||||
]
|
||||
|
||||
operations = [
|
||||
%(operations)s\
|
||||
]
|
||||
"""
|
||||
Reference in New Issue
Block a user