Initial commit. Basic models mostly done.
@@ -0,0 +1,583 @@
"""
SQLite backend for the sqlite3 module in the standard library.
"""
import datetime
import decimal
import functools
import hashlib
import math
import operator
import re
import statistics
import warnings
from itertools import chain
from sqlite3 import dbapi2 as Database

import pytz

from django.core.exceptions import ImproperlyConfigured
from django.db import utils
from django.db.backends import utils as backend_utils
from django.db.backends.base.base import BaseDatabaseWrapper
from django.utils import timezone
from django.utils.asyncio import async_unsafe
from django.utils.dateparse import parse_datetime, parse_time
from django.utils.duration import duration_microseconds

from .client import DatabaseClient  # isort:skip
from .creation import DatabaseCreation  # isort:skip
from .features import DatabaseFeatures  # isort:skip
from .introspection import DatabaseIntrospection  # isort:skip
from .operations import DatabaseOperations  # isort:skip
from .schema import DatabaseSchemaEditor  # isort:skip


def decoder(conv_func):
    """
    Convert bytestrings from Python's sqlite3 interface to a regular string.
    """
    return lambda s: conv_func(s.decode())


def none_guard(func):
    """
    Decorator that returns None if any of the arguments to the decorated
    function are None. Many SQL functions return NULL if any of their
    arguments are NULL. This decorator simplifies the implementation of this
    for the custom functions registered below.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return None if None in args else func(*args, **kwargs)
    return wrapper
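
# A quick sketch of none_guard's NULL propagation (the wrapped name below is
# just for illustration): wrapping math.hypot yields a function that mimics
# SQL's NULL-in, NULL-out behavior.
#
#     >>> safe_hypot = none_guard(math.hypot)
#     >>> safe_hypot(3, 4)
#     5.0
#     >>> safe_hypot(3, None) is None
#     True
#
# Note the guard only inspects positional arguments; SQLite passes all
# function arguments positionally, so that's sufficient here.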


def list_aggregate(function):
    """
    Return an aggregate class that accumulates values in a list and applies
    the provided function to the data.
    """
    return type('ListAggregate', (list,), {'finalize': function, 'step': list.append})
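
# Illustrative sketch: the generated class satisfies the sqlite3 aggregate
# protocol (step()/finalize()). For STDDEV_POP, registered below, SQLite
# effectively does the equivalent of:
#
#     >>> agg = list_aggregate(statistics.pstdev)()
#     >>> for value in (2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0):
#     ...     agg.step(value)
#     >>> agg.finalize()
#     2.0
#
# 'finalize' is looked up on the class, so agg.finalize() passes the
# accumulated list itself as the argument to the provided function.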


def check_sqlite_version():
    if Database.sqlite_version_info < (3, 8, 3):
        raise ImproperlyConfigured('SQLite 3.8.3 or later is required (found %s).' % Database.sqlite_version)


check_sqlite_version()

Database.register_converter("bool", b'1'.__eq__)
Database.register_converter("time", decoder(parse_time))
Database.register_converter("datetime", decoder(parse_datetime))
Database.register_converter("timestamp", decoder(parse_datetime))
Database.register_converter("TIMESTAMP", decoder(parse_datetime))

Database.register_adapter(decimal.Decimal, str)


class DatabaseWrapper(BaseDatabaseWrapper):
    vendor = 'sqlite'
    display_name = 'SQLite'
    # SQLite doesn't actually support most of these types, but it "does the right
    # thing" given more verbose field definitions, so leave them as is so that
    # schema inspection is more useful.
    data_types = {
        'AutoField': 'integer',
        'BigAutoField': 'integer',
        'BinaryField': 'BLOB',
        'BooleanField': 'bool',
        'CharField': 'varchar(%(max_length)s)',
        'DateField': 'date',
        'DateTimeField': 'datetime',
        'DecimalField': 'decimal',
        'DurationField': 'bigint',
        'FileField': 'varchar(%(max_length)s)',
        'FilePathField': 'varchar(%(max_length)s)',
        'FloatField': 'real',
        'IntegerField': 'integer',
        'BigIntegerField': 'bigint',
        'IPAddressField': 'char(15)',
        'GenericIPAddressField': 'char(39)',
        'NullBooleanField': 'bool',
        'OneToOneField': 'integer',
        'PositiveIntegerField': 'integer unsigned',
        'PositiveSmallIntegerField': 'smallint unsigned',
        'SlugField': 'varchar(%(max_length)s)',
        'SmallAutoField': 'integer',
        'SmallIntegerField': 'smallint',
        'TextField': 'text',
        'TimeField': 'time',
        'UUIDField': 'char(32)',
    }
    data_type_check_constraints = {
        'PositiveIntegerField': '"%(column)s" >= 0',
        'PositiveSmallIntegerField': '"%(column)s" >= 0',
    }
    data_types_suffix = {
        'AutoField': 'AUTOINCREMENT',
        'BigAutoField': 'AUTOINCREMENT',
        'SmallAutoField': 'AUTOINCREMENT',
    }
    # SQLite requires LIKE statements to include an ESCAPE clause if the value
    # being escaped has a percent or underscore in it.
    # See https://www.sqlite.org/lang_expr.html for an explanation.
    operators = {
        'exact': '= %s',
        'iexact': "LIKE %s ESCAPE '\\'",
        'contains': "LIKE %s ESCAPE '\\'",
        'icontains': "LIKE %s ESCAPE '\\'",
        'regex': 'REGEXP %s',
        'iregex': "REGEXP '(?i)' || %s",
        'gt': '> %s',
        'gte': '>= %s',
        'lt': '< %s',
        'lte': '<= %s',
        'startswith': "LIKE %s ESCAPE '\\'",
        'endswith': "LIKE %s ESCAPE '\\'",
        'istartswith': "LIKE %s ESCAPE '\\'",
        'iendswith': "LIKE %s ESCAPE '\\'",
    }

    # The patterns below are used to generate SQL pattern lookup clauses when
    # the right-hand side of the lookup isn't a raw string (it might be an
    # expression or the result of a bilateral transformation). In those cases,
    # special characters for LIKE operators (e.g. \, %, _) should be escaped
    # on the database side.
    #
    # Note: we use str.format() here for readability as '%' is used as a
    # wildcard for the LIKE operator.
    pattern_esc = r"REPLACE(REPLACE(REPLACE({}, '\', '\\'), '%%', '\%%'), '_', '\_')"
    pattern_ops = {
        'contains': r"LIKE '%%' || {} || '%%' ESCAPE '\'",
        'icontains': r"LIKE '%%' || UPPER({}) || '%%' ESCAPE '\'",
        'startswith': r"LIKE {} || '%%' ESCAPE '\'",
        'istartswith': r"LIKE UPPER({}) || '%%' ESCAPE '\'",
        'endswith': r"LIKE '%%' || {} ESCAPE '\'",
        'iendswith': r"LIKE '%%' || UPPER({}) ESCAPE '\'",
    }

    Database = Database
    SchemaEditorClass = DatabaseSchemaEditor
    # Classes instantiated in __init__().
    client_class = DatabaseClient
    creation_class = DatabaseCreation
    features_class = DatabaseFeatures
    introspection_class = DatabaseIntrospection
    ops_class = DatabaseOperations

    def get_connection_params(self):
        settings_dict = self.settings_dict
        if not settings_dict['NAME']:
            raise ImproperlyConfigured(
                "settings.DATABASES is improperly configured. "
                "Please supply the NAME value.")
        kwargs = {
            'database': settings_dict['NAME'],
            'detect_types': Database.PARSE_DECLTYPES | Database.PARSE_COLNAMES,
            **settings_dict['OPTIONS'],
        }
        # Always allow the underlying SQLite connection to be shareable
        # between multiple threads. The safe-guarding will be handled at a
        # higher level by the `BaseDatabaseWrapper.allow_thread_sharing`
        # property. This is necessary as the shareability is disabled by
        # default in pysqlite and it cannot be changed once a connection is
        # opened.
        if 'check_same_thread' in kwargs and kwargs['check_same_thread']:
            warnings.warn(
                'The `check_same_thread` option was provided and set to '
                'True. It will be overridden with False. Use the '
                '`DatabaseWrapper.allow_thread_sharing` property instead '
                'for controlling thread shareability.',
                RuntimeWarning
            )
        kwargs.update({'check_same_thread': False, 'uri': True})
        return kwargs

    @async_unsafe
    def get_new_connection(self, conn_params):
        conn = Database.connect(**conn_params)
        conn.create_function("django_date_extract", 2, _sqlite_datetime_extract)
        conn.create_function("django_date_trunc", 2, _sqlite_date_trunc)
        conn.create_function('django_datetime_cast_date', 3, _sqlite_datetime_cast_date)
        conn.create_function('django_datetime_cast_time', 3, _sqlite_datetime_cast_time)
        conn.create_function('django_datetime_extract', 4, _sqlite_datetime_extract)
        conn.create_function('django_datetime_trunc', 4, _sqlite_datetime_trunc)
        conn.create_function("django_time_extract", 2, _sqlite_time_extract)
        conn.create_function("django_time_trunc", 2, _sqlite_time_trunc)
        conn.create_function("django_time_diff", 2, _sqlite_time_diff)
        conn.create_function("django_timestamp_diff", 2, _sqlite_timestamp_diff)
        conn.create_function("django_format_dtdelta", 3, _sqlite_format_dtdelta)
        conn.create_function('regexp', 2, _sqlite_regexp)
        conn.create_function('ACOS', 1, none_guard(math.acos))
        conn.create_function('ASIN', 1, none_guard(math.asin))
        conn.create_function('ATAN', 1, none_guard(math.atan))
        conn.create_function('ATAN2', 2, none_guard(math.atan2))
        conn.create_function('CEILING', 1, none_guard(math.ceil))
        conn.create_function('COS', 1, none_guard(math.cos))
        conn.create_function('COT', 1, none_guard(lambda x: 1 / math.tan(x)))
        conn.create_function('DEGREES', 1, none_guard(math.degrees))
        conn.create_function('EXP', 1, none_guard(math.exp))
        conn.create_function('FLOOR', 1, none_guard(math.floor))
        conn.create_function('LN', 1, none_guard(math.log))
        conn.create_function('LOG', 2, none_guard(lambda x, y: math.log(y, x)))
        conn.create_function('LPAD', 3, _sqlite_lpad)
        conn.create_function('MD5', 1, none_guard(lambda x: hashlib.md5(x.encode()).hexdigest()))
        conn.create_function('MOD', 2, none_guard(math.fmod))
        conn.create_function('PI', 0, lambda: math.pi)
        conn.create_function('POWER', 2, none_guard(operator.pow))
        conn.create_function('RADIANS', 1, none_guard(math.radians))
        conn.create_function('REPEAT', 2, none_guard(operator.mul))
        conn.create_function('REVERSE', 1, none_guard(lambda x: x[::-1]))
        conn.create_function('RPAD', 3, _sqlite_rpad)
        conn.create_function('SHA1', 1, none_guard(lambda x: hashlib.sha1(x.encode()).hexdigest()))
        conn.create_function('SHA224', 1, none_guard(lambda x: hashlib.sha224(x.encode()).hexdigest()))
        conn.create_function('SHA256', 1, none_guard(lambda x: hashlib.sha256(x.encode()).hexdigest()))
        conn.create_function('SHA384', 1, none_guard(lambda x: hashlib.sha384(x.encode()).hexdigest()))
        conn.create_function('SHA512', 1, none_guard(lambda x: hashlib.sha512(x.encode()).hexdigest()))
        conn.create_function('SIGN', 1, none_guard(lambda x: (x > 0) - (x < 0)))
        conn.create_function('SIN', 1, none_guard(math.sin))
        conn.create_function('SQRT', 1, none_guard(math.sqrt))
        conn.create_function('TAN', 1, none_guard(math.tan))
        conn.create_aggregate('STDDEV_POP', 1, list_aggregate(statistics.pstdev))
        conn.create_aggregate('STDDEV_SAMP', 1, list_aggregate(statistics.stdev))
        conn.create_aggregate('VAR_POP', 1, list_aggregate(statistics.pvariance))
        conn.create_aggregate('VAR_SAMP', 1, list_aggregate(statistics.variance))
        conn.execute('PRAGMA foreign_keys = ON')
        return conn
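
    # A minimal sketch of what the registrations above enable: once
    # connected, ordinary SQL can call these functions directly (the table
    # and column names below are hypothetical).
    #
    #     SELECT "name" FROM "person" WHERE "name" REGEXP '^A';
    #     SELECT SHA256("token"), PI() FROM "person";
    #     SELECT STDDEV_POP("height") FROM "person";
    #
    # The aggregates use list_aggregate() above, so they buffer all values in
    # memory before applying the statistics function.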

    def init_connection_state(self):
        pass

    def create_cursor(self, name=None):
        return self.connection.cursor(factory=SQLiteCursorWrapper)

    @async_unsafe
    def close(self):
        self.validate_thread_sharing()
        # If database is in memory, closing the connection destroys the
        # database. To prevent accidental data loss, ignore close requests on
        # an in-memory db.
        if not self.is_in_memory_db():
            BaseDatabaseWrapper.close(self)

    def _savepoint_allowed(self):
        # When 'isolation_level' is not None, sqlite3 commits before each
        # savepoint; it's a bug. When it is None, savepoints don't make sense
        # because autocommit is enabled. The only exception is inside 'atomic'
        # blocks. To work around that bug, on SQLite, 'atomic' starts a
        # transaction explicitly rather than simply disabling autocommit.
        return self.in_atomic_block

    def _set_autocommit(self, autocommit):
        if autocommit:
            level = None
        else:
            # sqlite3's internal default is ''. It's different from None.
            # See Modules/_sqlite/connection.c.
            level = ''
        # 'isolation_level' is a misleading API.
        # SQLite always runs at the SERIALIZABLE isolation level.
        with self.wrap_database_errors:
            self.connection.isolation_level = level

    def disable_constraint_checking(self):
        with self.cursor() as cursor:
            cursor.execute('PRAGMA foreign_keys = OFF')
            # Foreign key constraints cannot be turned off while in a multi-
            # statement transaction. Fetch the current state of the pragma
            # to determine if constraints are effectively disabled.
            enabled = cursor.execute('PRAGMA foreign_keys').fetchone()[0]
        return not bool(enabled)

    def enable_constraint_checking(self):
        with self.cursor() as cursor:
            cursor.execute('PRAGMA foreign_keys = ON')

    def check_constraints(self, table_names=None):
        """
        Check each table name in `table_names` for rows with invalid foreign
        key references. This method is intended to be used in conjunction with
        `disable_constraint_checking()` and `enable_constraint_checking()`, to
        determine if rows with invalid references were entered while
        constraint checks were off.
        """
        if self.features.supports_pragma_foreign_key_check:
            with self.cursor() as cursor:
                if table_names is None:
                    violations = cursor.execute('PRAGMA foreign_key_check').fetchall()
                else:
                    violations = chain.from_iterable(
                        cursor.execute('PRAGMA foreign_key_check(%s)' % table_name).fetchall()
                        for table_name in table_names
                    )
                # See https://www.sqlite.org/pragma.html#pragma_foreign_key_check
                for table_name, rowid, referenced_table_name, foreign_key_index in violations:
                    foreign_key = cursor.execute(
                        'PRAGMA foreign_key_list(%s)' % table_name
                    ).fetchall()[foreign_key_index]
                    column_name, referenced_column_name = foreign_key[3:5]
                    primary_key_column_name = self.introspection.get_primary_key_column(cursor, table_name)
                    primary_key_value, bad_value = cursor.execute(
                        'SELECT %s, %s FROM %s WHERE rowid = %%s' % (
                            primary_key_column_name, column_name, table_name
                        ),
                        (rowid,),
                    ).fetchone()
                    raise utils.IntegrityError(
                        "The row in table '%s' with primary key '%s' has an "
                        "invalid foreign key: %s.%s contains a value '%s' that "
                        "does not have a corresponding value in %s.%s." % (
                            table_name, primary_key_value, table_name, column_name,
                            bad_value, referenced_table_name, referenced_column_name
                        )
                    )
        else:
            with self.cursor() as cursor:
                if table_names is None:
                    table_names = self.introspection.table_names(cursor)
                for table_name in table_names:
                    primary_key_column_name = self.introspection.get_primary_key_column(cursor, table_name)
                    if not primary_key_column_name:
                        continue
                    key_columns = self.introspection.get_key_columns(cursor, table_name)
                    for column_name, referenced_table_name, referenced_column_name in key_columns:
                        cursor.execute(
                            """
                            SELECT REFERRING.`%s`, REFERRING.`%s` FROM `%s` as REFERRING
                            LEFT JOIN `%s` as REFERRED
                            ON (REFERRING.`%s` = REFERRED.`%s`)
                            WHERE REFERRING.`%s` IS NOT NULL AND REFERRED.`%s` IS NULL
                            """
                            % (
                                primary_key_column_name, column_name, table_name,
                                referenced_table_name, column_name, referenced_column_name,
                                column_name, referenced_column_name,
                            )
                        )
                        for bad_row in cursor.fetchall():
                            raise utils.IntegrityError(
                                "The row in table '%s' with primary key '%s' has an "
                                "invalid foreign key: %s.%s contains a value '%s' that "
                                "does not have a corresponding value in %s.%s." % (
                                    table_name, bad_row[0], table_name, column_name,
                                    bad_row[1], referenced_table_name, referenced_column_name,
                                )
                            )

    def is_usable(self):
        return True

    def _start_transaction_under_autocommit(self):
        """
        Start a transaction explicitly in autocommit mode.

        Staying in autocommit mode works around a bug of sqlite3 that breaks
        savepoints when autocommit is disabled.
        """
        self.cursor().execute("BEGIN")

    def is_in_memory_db(self):
        return self.creation.is_in_memory_db(self.settings_dict['NAME'])


FORMAT_QMARK_REGEX = re.compile(r'(?<!%)%s')


class SQLiteCursorWrapper(Database.Cursor):
    """
    Django uses "format" style placeholders, but pysqlite2 uses "qmark" style.
    This fixes it -- but note that if you want to use a literal "%s" in a
    query, you'll need to use "%%s".
    """
    def execute(self, query, params=None):
        if params is None:
            return Database.Cursor.execute(self, query)
        query = self.convert_query(query)
        return Database.Cursor.execute(self, query, params)

    def executemany(self, query, param_list):
        query = self.convert_query(query)
        return Database.Cursor.executemany(self, query, param_list)

    def convert_query(self, query):
        return FORMAT_QMARK_REGEX.sub('?', query).replace('%%', '%')
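
# Illustrative conversion (the query is hypothetical): FORMAT_QMARK_REGEX only
# matches '%s' that isn't preceded by another '%', so escaped literals survive.
#
#     >>> FORMAT_QMARK_REGEX.sub('?', 'SELECT %s WHERE x LIKE "100%%"').replace('%%', '%')
#     'SELECT ? WHERE x LIKE "100%"'
#
# i.e. '%s' becomes the qmark placeholder '?' and '%%' collapses to a
# literal '%'.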


def _sqlite_datetime_parse(dt, tzname=None, conn_tzname=None):
    if dt is None:
        return None
    try:
        dt = backend_utils.typecast_timestamp(dt)
    except (TypeError, ValueError):
        return None
    if conn_tzname:
        dt = dt.replace(tzinfo=pytz.timezone(conn_tzname))
    if tzname is not None and tzname != conn_tzname:
        sign_index = tzname.find('+') + tzname.find('-') + 1
        if sign_index > -1:
            sign = tzname[sign_index]
            tzname, offset = tzname.split(sign)
            if offset:
                hours, minutes = offset.split(':')
                offset_delta = datetime.timedelta(hours=int(hours), minutes=int(minutes))
                dt += offset_delta if sign == '+' else -offset_delta
        dt = timezone.localtime(dt, pytz.timezone(tzname))
    return dt
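
# A brief sketch of the tzname handling above (values are illustrative):
# Django can pass zone names carrying an explicit offset suffix, such as
# 'UTC+05:30'. The sign scan splits that into a base zone 'UTC', sign '+',
# and offset '05:30'; the offset is applied to the datetime, and the result
# is then localized to the base zone, roughly:
#
#     >>> _sqlite_datetime_parse('2019-01-01 12:00:00', 'UTC+05:30', 'UTC')
#     datetime.datetime(2019, 1, 1, 17, 30, tzinfo=<UTC>)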


def _sqlite_date_trunc(lookup_type, dt):
    dt = _sqlite_datetime_parse(dt)
    if dt is None:
        return None
    if lookup_type == 'year':
        return "%i-01-01" % dt.year
    elif lookup_type == 'quarter':
        month_in_quarter = dt.month - (dt.month - 1) % 3
        return '%i-%02i-01' % (dt.year, month_in_quarter)
    elif lookup_type == 'month':
        return "%i-%02i-01" % (dt.year, dt.month)
    elif lookup_type == 'week':
        dt = dt - datetime.timedelta(days=dt.weekday())
        return "%i-%02i-%02i" % (dt.year, dt.month, dt.day)
    elif lookup_type == 'day':
        return "%i-%02i-%02i" % (dt.year, dt.month, dt.day)


def _sqlite_time_trunc(lookup_type, dt):
    if dt is None:
        return None
    try:
        dt = backend_utils.typecast_time(dt)
    except (ValueError, TypeError):
        return None
    if lookup_type == 'hour':
        return "%02i:00:00" % dt.hour
    elif lookup_type == 'minute':
        return "%02i:%02i:00" % (dt.hour, dt.minute)
    elif lookup_type == 'second':
        return "%02i:%02i:%02i" % (dt.hour, dt.minute, dt.second)


def _sqlite_datetime_cast_date(dt, tzname, conn_tzname):
    dt = _sqlite_datetime_parse(dt, tzname, conn_tzname)
    if dt is None:
        return None
    return dt.date().isoformat()


def _sqlite_datetime_cast_time(dt, tzname, conn_tzname):
    dt = _sqlite_datetime_parse(dt, tzname, conn_tzname)
    if dt is None:
        return None
    return dt.time().isoformat()


def _sqlite_datetime_extract(lookup_type, dt, tzname=None, conn_tzname=None):
    dt = _sqlite_datetime_parse(dt, tzname, conn_tzname)
    if dt is None:
        return None
    if lookup_type == 'week_day':
        return (dt.isoweekday() % 7) + 1
    elif lookup_type == 'week':
        return dt.isocalendar()[1]
    elif lookup_type == 'quarter':
        return math.ceil(dt.month / 3)
    elif lookup_type == 'iso_year':
        return dt.isocalendar()[0]
    else:
        return getattr(dt, lookup_type)


def _sqlite_datetime_trunc(lookup_type, dt, tzname, conn_tzname):
    dt = _sqlite_datetime_parse(dt, tzname, conn_tzname)
    if dt is None:
        return None
    if lookup_type == 'year':
        return "%i-01-01 00:00:00" % dt.year
    elif lookup_type == 'quarter':
        month_in_quarter = dt.month - (dt.month - 1) % 3
        return '%i-%02i-01 00:00:00' % (dt.year, month_in_quarter)
    elif lookup_type == 'month':
        return "%i-%02i-01 00:00:00" % (dt.year, dt.month)
    elif lookup_type == 'week':
        dt = dt - datetime.timedelta(days=dt.weekday())
        return "%i-%02i-%02i 00:00:00" % (dt.year, dt.month, dt.day)
    elif lookup_type == 'day':
        return "%i-%02i-%02i 00:00:00" % (dt.year, dt.month, dt.day)
    elif lookup_type == 'hour':
        return "%i-%02i-%02i %02i:00:00" % (dt.year, dt.month, dt.day, dt.hour)
    elif lookup_type == 'minute':
        return "%i-%02i-%02i %02i:%02i:00" % (dt.year, dt.month, dt.day, dt.hour, dt.minute)
    elif lookup_type == 'second':
        return "%i-%02i-%02i %02i:%02i:%02i" % (dt.year, dt.month, dt.day, dt.hour, dt.minute, dt.second)


def _sqlite_time_extract(lookup_type, dt):
    if dt is None:
        return None
    try:
        dt = backend_utils.typecast_time(dt)
    except (ValueError, TypeError):
        return None
    return getattr(dt, lookup_type)


@none_guard
def _sqlite_format_dtdelta(conn, lhs, rhs):
    """
    LHS and RHS can be either:
    - An integer number of microseconds
    - A string representing a datetime
    """
    try:
        real_lhs = datetime.timedelta(0, 0, lhs) if isinstance(lhs, int) else backend_utils.typecast_timestamp(lhs)
        real_rhs = datetime.timedelta(0, 0, rhs) if isinstance(rhs, int) else backend_utils.typecast_timestamp(rhs)
        if conn.strip() == '+':
            out = real_lhs + real_rhs
        else:
            out = real_lhs - real_rhs
    except (ValueError, TypeError):
        return None
    # typecast_timestamp() returns a date or a datetime without timezone.
    # It will be formatted as "%Y-%m-%d" or "%Y-%m-%d %H:%M:%S[.%f]"
    return str(out)


@none_guard
def _sqlite_time_diff(lhs, rhs):
    left = backend_utils.typecast_time(lhs)
    right = backend_utils.typecast_time(rhs)
    return (
        (left.hour * 60 * 60 * 1000000) +
        (left.minute * 60 * 1000000) +
        (left.second * 1000000) +
        (left.microsecond) -
        (right.hour * 60 * 60 * 1000000) -
        (right.minute * 60 * 1000000) -
        (right.second * 1000000) -
        (right.microsecond)
    )


@none_guard
def _sqlite_timestamp_diff(lhs, rhs):
    left = backend_utils.typecast_timestamp(lhs)
    right = backend_utils.typecast_timestamp(rhs)
    return duration_microseconds(left - right)


@none_guard
def _sqlite_regexp(re_pattern, re_string):
    return bool(re.search(re_pattern, str(re_string)))


@none_guard
def _sqlite_lpad(text, length, fill_text):
    if len(text) >= length:
        return text[:length]
    return (fill_text * length)[:length - len(text)] + text


@none_guard
def _sqlite_rpad(text, length, fill_text):
    return (text + fill_text * length)[:length]
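
# Illustrative behavior (values hypothetical): LPAD truncates when the
# requested length is shorter than the input, otherwise fills from the left;
# RPAD overfills and slices.
#
#     >>> _sqlite_lpad('abc', 5, '*')
#     '**abc'
#     >>> _sqlite_lpad('abcdef', 4, '*')
#     'abcd'
#     >>> _sqlite_rpad('abc', 5, '*')
#     'abc**'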

@@ -0,0 +1,12 @@
import subprocess

from django.db.backends.base.client import BaseDatabaseClient


class DatabaseClient(BaseDatabaseClient):
    executable_name = 'sqlite3'

    def runshell(self):
        args = [self.executable_name,
                self.connection.settings_dict['NAME']]
        subprocess.run(args, check=True)

@@ -0,0 +1,98 @@
import os
import shutil
import sys

from django.db.backends.base.creation import BaseDatabaseCreation


class DatabaseCreation(BaseDatabaseCreation):

    @staticmethod
    def is_in_memory_db(database_name):
        return database_name == ':memory:' or 'mode=memory' in database_name

    def _get_test_db_name(self):
        test_database_name = self.connection.settings_dict['TEST']['NAME'] or ':memory:'
        if test_database_name == ':memory:':
            return 'file:memorydb_%s?mode=memory&cache=shared' % self.connection.alias
        return test_database_name
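
    # Quick sketch (the alias is hypothetical): with no TEST NAME configured,
    # an alias of 'default' yields the shared-cache URI, which
    # is_in_memory_db() recognizes via its 'mode=memory' substring. The
    # shared-cache form lets multiple connections in one process see the same
    # in-memory database.
    #
    #     >>> DatabaseCreation.is_in_memory_db(':memory:')
    #     True
    #     >>> DatabaseCreation.is_in_memory_db('file:memorydb_default?mode=memory&cache=shared')
    #     True
    #     >>> DatabaseCreation.is_in_memory_db('/var/db/test.sqlite3')
    #     False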

    def _create_test_db(self, verbosity, autoclobber, keepdb=False):
        test_database_name = self._get_test_db_name()

        if keepdb:
            return test_database_name
        if not self.is_in_memory_db(test_database_name):
            # Erase the old test database
            if verbosity >= 1:
                self.log('Destroying old test database for alias %s...' % (
                    self._get_database_display_str(verbosity, test_database_name),
                ))
            if os.access(test_database_name, os.F_OK):
                if not autoclobber:
                    confirm = input(
                        "Type 'yes' if you would like to try deleting the test "
                        "database '%s', or 'no' to cancel: " % test_database_name
                    )
                if autoclobber or confirm == 'yes':
                    try:
                        os.remove(test_database_name)
                    except Exception as e:
                        self.log('Got an error deleting the old test database: %s' % e)
                        sys.exit(2)
                else:
                    self.log('Tests cancelled.')
                    sys.exit(1)
        return test_database_name

    def get_test_db_clone_settings(self, suffix):
        orig_settings_dict = self.connection.settings_dict
        source_database_name = orig_settings_dict['NAME']
        if self.is_in_memory_db(source_database_name):
            return orig_settings_dict
        else:
            root, ext = os.path.splitext(orig_settings_dict['NAME'])
            # ext already includes the leading dot, so no separator is needed.
            return {**orig_settings_dict, 'NAME': '{}_{}{}'.format(root, suffix, ext)}

    def _clone_test_db(self, suffix, verbosity, keepdb=False):
        source_database_name = self.connection.settings_dict['NAME']
        target_database_name = self.get_test_db_clone_settings(suffix)['NAME']
        # Forking automatically makes a copy of an in-memory database.
        if not self.is_in_memory_db(source_database_name):
            # Erase the old test database
            if os.access(target_database_name, os.F_OK):
                if keepdb:
                    return
                if verbosity >= 1:
                    self.log('Destroying old test database for alias %s...' % (
                        self._get_database_display_str(verbosity, target_database_name),
                    ))
                try:
                    os.remove(target_database_name)
                except Exception as e:
                    self.log('Got an error deleting the old test database: %s' % e)
                    sys.exit(2)
            try:
                shutil.copy(source_database_name, target_database_name)
            except Exception as e:
                self.log('Got an error cloning the test database: %s' % e)
                sys.exit(2)

    def _destroy_test_db(self, test_database_name, verbosity):
        if test_database_name and not self.is_in_memory_db(test_database_name):
            # Remove the SQLite database file
            os.remove(test_database_name)

    def test_db_signature(self):
        """
        Return a tuple that uniquely identifies a test database.

        This takes into account the special cases of ":memory:" and "" for
        SQLite since the databases will be distinct despite having the same
        TEST NAME. See https://www.sqlite.org/inmemorydb.html
        """
        test_database_name = self._get_test_db_name()
        sig = [self.connection.settings_dict['NAME']]
        if self.is_in_memory_db(test_database_name):
            sig.append(self.connection.alias)
        return tuple(sig)

@@ -0,0 +1,44 @@
from django.db.backends.base.features import BaseDatabaseFeatures

from .base import Database


class DatabaseFeatures(BaseDatabaseFeatures):
    # SQLite can read from a cursor since SQLite 3.6.5, subject to the caveat
    # that statements within a connection aren't isolated from each other. See
    # https://sqlite.org/isolation.html.
    can_use_chunked_reads = True
    test_db_allows_multiple_connections = False
    supports_unspecified_pk = True
    supports_timezones = False
    max_query_params = 999
    supports_mixed_date_datetime_comparisons = False
    can_introspect_autofield = True
    can_introspect_decimal_field = False
    can_introspect_duration_field = False
    can_introspect_positive_integer_field = True
    can_introspect_small_integer_field = True
    introspected_big_auto_field_type = 'AutoField'
    introspected_small_auto_field_type = 'AutoField'
    supports_transactions = True
    atomic_transactions = False
    can_rollback_ddl = True
    supports_atomic_references_rename = Database.sqlite_version_info >= (3, 26, 0)
    can_create_inline_fk = False
    supports_paramstyle_pyformat = False
    supports_sequence_reset = False
    can_clone_databases = True
    supports_temporal_subtraction = True
    ignores_table_name_case = True
    supports_cast_with_precision = False
    time_cast_precision = 3
    can_release_savepoints = True
    # Is "ALTER TABLE ... RENAME COLUMN" supported?
    can_alter_table_rename_column = Database.sqlite_version_info >= (3, 25, 0)
    supports_parentheses_in_compound = False
    # Deferred constraint checks can be emulated on SQLite < 3.20 but not in a
    # reasonably performant way.
    supports_pragma_foreign_key_check = Database.sqlite_version_info >= (3, 20, 0)
    can_defer_constraint_checks = supports_pragma_foreign_key_check
    supports_functions_in_partial_indexes = Database.sqlite_version_info >= (3, 15, 0)
    supports_over_clause = Database.sqlite_version_info >= (3, 25, 0)
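
    # A sketch of how the version-gated flags above resolve (the runtime
    # version is hypothetical): with SQLite 3.24.0,
    #     supports_pragma_foreign_key_check  -> True   (needs >= 3.20.0)
    #     can_alter_table_rename_column      -> False  (needs >= 3.25.0)
    #     supports_over_clause               -> False  (needs >= 3.25.0)
    #     supports_atomic_references_rename  -> False  (needs >= 3.26.0)
    # These are evaluated once at import time from Database.sqlite_version_info.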

@@ -0,0 +1,417 @@
import re
from collections import namedtuple

import sqlparse

from django.db.backends.base.introspection import (
    BaseDatabaseIntrospection, FieldInfo as BaseFieldInfo, TableInfo,
)
from django.db.models.indexes import Index

FieldInfo = namedtuple('FieldInfo', BaseFieldInfo._fields + ('pk',))

field_size_re = re.compile(r'^\s*(?:var)?char\s*\(\s*(\d+)\s*\)\s*$')


def get_field_size(name):
    """Extract the size number from a "varchar(11)" type name."""
    m = field_size_re.search(name)
    return int(m.group(1)) if m else None
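
# Illustrative matches for field_size_re (inputs hypothetical):
#
#     >>> get_field_size('varchar(11)')
#     11
#     >>> get_field_size('char( 5 )')
#     5
#     >>> get_field_size('text') is None
#     True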


# This light wrapper "fakes" a dictionary interface, because some SQLite data
# types include variables in them -- e.g. "varchar(30)" -- and can't be matched
# as a simple dictionary lookup.
class FlexibleFieldLookupDict:
    # Maps SQL types to Django Field types. Some of the SQL types have multiple
    # entries here because SQLite allows for anything and doesn't normalize the
    # field type; it uses whatever was given.
    base_data_types_reverse = {
        'bool': 'BooleanField',
        'boolean': 'BooleanField',
        'smallint': 'SmallIntegerField',
        'smallint unsigned': 'PositiveSmallIntegerField',
        'smallinteger': 'SmallIntegerField',
        'int': 'IntegerField',
        'integer': 'IntegerField',
        'bigint': 'BigIntegerField',
        'integer unsigned': 'PositiveIntegerField',
        'decimal': 'DecimalField',
        'real': 'FloatField',
        'text': 'TextField',
        'char': 'CharField',
        'varchar': 'CharField',
        'blob': 'BinaryField',
        'date': 'DateField',
        'datetime': 'DateTimeField',
        'time': 'TimeField',
    }

    def __getitem__(self, key):
        key = key.lower().split('(', 1)[0].strip()
        return self.base_data_types_reverse[key]
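
# Illustrative lookups: the key is lowercased and anything from the first '('
# onward is dropped before the plain dict lookup.
#
#     >>> FlexibleFieldLookupDict()['varchar(30)']
#     'CharField'
#     >>> FlexibleFieldLookupDict()['INTEGER UNSIGNED']
#     'PositiveIntegerField'
#
# Unknown types raise KeyError, which callers treat as "no Django field
# mapping".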


class DatabaseIntrospection(BaseDatabaseIntrospection):
    data_types_reverse = FlexibleFieldLookupDict()

    def get_field_type(self, data_type, description):
        field_type = super().get_field_type(data_type, description)
        if description.pk and field_type in {'BigIntegerField', 'IntegerField', 'SmallIntegerField'}:
            # No support for BigAutoField or SmallAutoField as SQLite treats
            # all integer primary keys as signed 64-bit integers.
            return 'AutoField'
        return field_type

    def get_table_list(self, cursor):
        """Return a list of table and view names in the current database."""
        # Skip the sqlite_sequence system table used for autoincrement key
        # generation.
        cursor.execute("""
            SELECT name, type FROM sqlite_master
            WHERE type in ('table', 'view') AND NOT name='sqlite_sequence'
            ORDER BY name""")
        return [TableInfo(row[0], row[1][0]) for row in cursor.fetchall()]

    def get_table_description(self, cursor, table_name):
        """
        Return a description of the table with the DB-API cursor.description
        interface.
        """
        cursor.execute('PRAGMA table_info(%s)' % self.connection.ops.quote_name(table_name))
        return [
            FieldInfo(
                name, data_type, None, get_field_size(data_type), None, None,
                not notnull, default, pk == 1,
            )
            for cid, name, data_type, notnull, default, pk in cursor.fetchall()
        ]

    def get_sequences(self, cursor, table_name, table_fields=()):
        pk_col = self.get_primary_key_column(cursor, table_name)
        return [{'table': table_name, 'column': pk_col}]

    def get_relations(self, cursor, table_name):
        """
        Return a dictionary of {field_name: (field_name_other_table, other_table)}
        representing all relationships to the given table.
        """
        # Dictionary of relations to return
        relations = {}

        # Schema for this table
        cursor.execute(
            "SELECT sql, type FROM sqlite_master "
            "WHERE tbl_name = %s AND type IN ('table', 'view')",
            [table_name]
        )
        create_sql, table_type = cursor.fetchone()
        if table_type == 'view':
            # Views have no REFERENCES clauses to inspect, so no relations.
            return relations
        results = create_sql[create_sql.index('(') + 1:create_sql.rindex(')')]

        # Walk through and look for references to other tables. SQLite doesn't
        # really have enforced references, but since it echoes out the SQL used
        # to create the table we can look for REFERENCES statements used there.
        for field_desc in results.split(','):
            field_desc = field_desc.strip()
            if field_desc.startswith("UNIQUE"):
                continue

            m = re.search(r'references (\S*) ?\(["|]?(.*)["|]?\)', field_desc, re.I)
            if not m:
                continue
            table, column = [s.strip('"') for s in m.groups()]

            if field_desc.startswith("FOREIGN KEY"):
                # Find the name of the referring FK column.
                m = re.match(r'FOREIGN KEY\s*\(([^\)]*)\).*', field_desc, re.I)
                field_name = m.groups()[0].strip('"')
            else:
                field_name = field_desc.split()[0].strip('"')

            cursor.execute("SELECT sql FROM sqlite_master WHERE tbl_name = %s", [table])
            result = cursor.fetchall()[0]
            other_table_results = result[0].strip()
            li, ri = other_table_results.index('('), other_table_results.rindex(')')
            other_table_results = other_table_results[li + 1:ri]

            for other_desc in other_table_results.split(','):
                other_desc = other_desc.strip()
                if other_desc.startswith('UNIQUE'):
                    continue

                other_name = other_desc.split(' ', 1)[0].strip('"')
                if other_name == column:
                    relations[field_name] = (other_name, table)
                    break

        return relations

    def get_key_columns(self, cursor, table_name):
        """
        Return a list of (column_name, referenced_table_name, referenced_column_name)
        for all key columns in the given table.
        """
        key_columns = []

        # Schema for this table
        cursor.execute("SELECT sql FROM sqlite_master WHERE tbl_name = %s AND type = %s", [table_name, "table"])
        results = cursor.fetchone()[0].strip()
        results = results[results.index('(') + 1:results.rindex(')')]

        # Walk through and look for references to other tables. SQLite doesn't
        # really have enforced references, but since it echoes out the SQL used
        # to create the table we can look for REFERENCES statements used there.
        for field_index, field_desc in enumerate(results.split(',')):
            field_desc = field_desc.strip()
            if field_desc.startswith("UNIQUE"):
                continue

            m = re.search(r'"(.*)".*references (.*) \(["|](.*)["|]\)', field_desc, re.I)
            if not m:
                continue

            # Append (column_name, referenced_table_name, referenced_column_name).
            key_columns.append(tuple(s.strip('"') for s in m.groups()))

        return key_columns

    def get_primary_key_column(self, cursor, table_name):
        """Return the column name of the primary key for the given table."""
        # Don't use PRAGMA because that causes issues with some transactions
        cursor.execute(
            "SELECT sql, type FROM sqlite_master "
            "WHERE tbl_name = %s AND type IN ('table', 'view')",
            [table_name]
        )
        row = cursor.fetchone()
        if row is None:
            raise ValueError("Table %s does not exist" % table_name)
        create_sql, table_type = row
        if table_type == 'view':
            # Views don't have a primary key.
            return None
        fields_sql = create_sql[create_sql.index('(') + 1:create_sql.rindex(')')]
        for field_desc in fields_sql.split(','):
            field_desc = field_desc.strip()
            m = re.match(r'(?:(?:["`\[])(.*)(?:["`\]])|(\w+)).*PRIMARY KEY.*', field_desc)
            if m:
                return m.group(1) if m.group(1) else m.group(2)
        return None

    def _get_foreign_key_constraints(self, cursor, table_name):
        constraints = {}
        cursor.execute('PRAGMA foreign_key_list(%s)' % self.connection.ops.quote_name(table_name))
        for row in cursor.fetchall():
            # Remaining on_update/on_delete/match values are of no interest.
            id_, _, table, from_, to = row[:5]
            constraints['fk_%d' % id_] = {
                'columns': [from_],
                'primary_key': False,
                'unique': False,
                'foreign_key': (table, to),
                'check': False,
                'index': False,
            }
        return constraints

    def _parse_column_or_constraint_definition(self, tokens, columns):
        token = None
        is_constraint_definition = None
        field_name = None
        constraint_name = None
        unique = False
        unique_columns = []
        check = False
        check_columns = []
        braces_deep = 0
        for token in tokens:
            if token.match(sqlparse.tokens.Punctuation, '('):
                braces_deep += 1
            elif token.match(sqlparse.tokens.Punctuation, ')'):
                braces_deep -= 1
                if braces_deep < 0:
                    # End of columns and constraints for table definition.
                    break
            elif braces_deep == 0 and token.match(sqlparse.tokens.Punctuation, ','):
                # End of current column or constraint definition.
                break
            # Detect column or constraint definition by first token.
            if is_constraint_definition is None:
                is_constraint_definition = token.match(sqlparse.tokens.Keyword, 'CONSTRAINT')
                if is_constraint_definition:
                    continue
            if is_constraint_definition:
                # Detect constraint name by second token.
                if constraint_name is None:
                    if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
                        constraint_name = token.value
                    elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
                        constraint_name = token.value[1:-1]
                # Start constraint columns parsing after UNIQUE keyword.
                if token.match(sqlparse.tokens.Keyword, 'UNIQUE'):
                    unique = True
                    unique_braces_deep = braces_deep
                elif unique:
                    if unique_braces_deep == braces_deep:
                        if unique_columns:
                            # Stop constraint parsing.
                            unique = False
                        continue
                    if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
                        unique_columns.append(token.value)
                    elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
                        unique_columns.append(token.value[1:-1])
            else:
                # Detect field name by first token.
                if field_name is None:
                    if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
                        field_name = token.value
                    elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
                        field_name = token.value[1:-1]
                if token.match(sqlparse.tokens.Keyword, 'UNIQUE'):
                    unique_columns = [field_name]
            # Start constraint columns parsing after CHECK keyword.
            if token.match(sqlparse.tokens.Keyword, 'CHECK'):
                check = True
                check_braces_deep = braces_deep
            elif check:
                if check_braces_deep == braces_deep:
                    if check_columns:
                        # Stop constraint parsing.
                        check = False
                    continue
                if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
                    if token.value in columns:
                        check_columns.append(token.value)
                elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
                    if token.value[1:-1] in columns:
                        check_columns.append(token.value[1:-1])
        unique_constraint = {
            'unique': True,
            'columns': unique_columns,
            'primary_key': False,
            'foreign_key': None,
            'check': False,
            'index': False,
        } if unique_columns else None
        check_constraint = {
            'check': True,
            'columns': check_columns,
            'primary_key': False,
            'unique': False,
            'foreign_key': None,
            'index': False,
        } if check_columns else None
        return constraint_name, unique_constraint, check_constraint, token

    def _parse_table_constraints(self, sql, columns):
        # Check constraint parsing is based on the SQLite syntax diagram.
        # https://www.sqlite.org/syntaxdiagrams.html#table-constraint
        statement = sqlparse.parse(sql)[0]
        constraints = {}
        unnamed_constraints_index = 0
        tokens = (token for token in statement.flatten() if not token.is_whitespace)
        # Go to columns and constraint definition
        for token in tokens:
            if token.match(sqlparse.tokens.Punctuation, '('):
                break
        # Parse columns and constraint definition
        while True:
            constraint_name, unique, check, end_token = self._parse_column_or_constraint_definition(tokens, columns)
            if unique:
                if constraint_name:
                    constraints[constraint_name] = unique
                else:
                    unnamed_constraints_index += 1
                    constraints['__unnamed_constraint_%s__' % unnamed_constraints_index] = unique
            if check:
                if constraint_name:
                    constraints[constraint_name] = check
                else:
                    unnamed_constraints_index += 1
                    constraints['__unnamed_constraint_%s__' % unnamed_constraints_index] = check
            if end_token.match(sqlparse.tokens.Punctuation, ')'):
                break
        return constraints
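
    # A sketch of what the parser extracts from a hypothetical schema: given
    #
    #     CREATE TABLE "x" ("id" integer PRIMARY KEY, "a" int UNIQUE,
    #                       CONSTRAINT "pos" CHECK ("a" > 0))
    #
    # _parse_table_constraints() returns roughly:
    #
    #     {'__unnamed_constraint_1__': {'unique': True, 'columns': ['a'], ...},
    #      'pos': {'check': True, 'columns': ['a'], ...}}
    #
    # Named constraints keep their name; unnamed ones get a synthetic
    # '__unnamed_constraint_N__' key.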

    def get_constraints(self, cursor, table_name):
        """
        Retrieve any constraints or keys (unique, pk, fk, check, index) across
        one or more columns.
        """
        constraints = {}
        # Find inline check constraints.
        try:
            table_schema = cursor.execute(
                "SELECT sql FROM sqlite_master WHERE type='table' and name=%s" % (
                    self.connection.ops.quote_name(table_name),
                )
            ).fetchone()[0]
        except TypeError:
            # table_name is a view.
            pass
        else:
            columns = {info.name for info in self.get_table_description(cursor, table_name)}
            constraints.update(self._parse_table_constraints(table_schema, columns))

        # Get the index info
        cursor.execute("PRAGMA index_list(%s)" % self.connection.ops.quote_name(table_name))
        for row in cursor.fetchall():
            # SQLite 3.8.9+ has 5 columns, but older versions only give 3.
            # Discard the last two columns if present.
            number, index, unique = row[:3]
            cursor.execute(
                "SELECT sql FROM sqlite_master "
                "WHERE type='index' AND name=%s" % self.connection.ops.quote_name(index)
            )
            # There's at most one row.
            sql, = cursor.fetchone() or (None,)
            # Inline constraints are already detected in
            # _parse_table_constraints(). The reasons to avoid fetching inline
            # constraints from `PRAGMA index_list` are:
            # - Inline constraints can have a different name and information
            #   than what `PRAGMA index_list` gives.
            # - Not all inline constraints may appear in `PRAGMA index_list`.
            if not sql:
                # An inline constraint
                continue
            # Get the index info for that index
            cursor.execute('PRAGMA index_info(%s)' % self.connection.ops.quote_name(index))
            for index_rank, column_rank, column in cursor.fetchall():
                if index not in constraints:
                    constraints[index] = {
                        "columns": [],
                        "primary_key": False,
                        "unique": bool(unique),
                        "foreign_key": None,
                        "check": False,
                        "index": True,
                    }
                constraints[index]['columns'].append(column)
            # Add type and column orders for indexes
            if constraints[index]['index'] and not constraints[index]['unique']:
                # SQLite doesn't support any index type other than b-tree
                constraints[index]['type'] = Index.suffix
                order_info = sql.split('(')[-1].split(')')[0].split(',')
                orders = ['DESC' if info.endswith('DESC') else 'ASC' for info in order_info]
                constraints[index]['orders'] = orders
        # Get the PK
        pk_column = self.get_primary_key_column(cursor, table_name)
        if pk_column:
            # SQLite doesn't actually give a name to the PK constraint,
            # so we invent one. This is fine, as the SQLite backend never
            # deletes PK constraints by name, as you can't delete constraints
            # in SQLite; we remake the table with a new PK instead.
            constraints["__primary__"] = {
                "columns": [pk_column],
                "primary_key": True,
                "unique": False,  # It's not actually a unique constraint.
                "foreign_key": None,
                "check": False,
                "index": False,
            }
        constraints.update(self._get_foreign_key_constraints(cursor, table_name))
        return constraints

@@ -0,0 +1,335 @@
import datetime
import decimal
import uuid
from functools import lru_cache
from itertools import chain

from django.conf import settings
from django.core.exceptions import FieldError
from django.db import utils
from django.db.backends.base.operations import BaseDatabaseOperations
from django.db.models import aggregates, fields
from django.db.models.expressions import Col
from django.utils import timezone
from django.utils.dateparse import parse_date, parse_datetime, parse_time
from django.utils.duration import duration_microseconds
from django.utils.functional import cached_property


class DatabaseOperations(BaseDatabaseOperations):
    cast_char_field_without_max_length = 'text'
    cast_data_types = {
        'DateField': 'TEXT',
        'DateTimeField': 'TEXT',
    }
    explain_prefix = 'EXPLAIN QUERY PLAN'

    def bulk_batch_size(self, fields, objs):
        """
        SQLite has a compile-time default (SQLITE_LIMIT_VARIABLE_NUMBER) of
        999 variables per query.

        If there's only a single field to insert, the limit is 500
        (SQLITE_MAX_COMPOUND_SELECT).
        """
        if len(fields) == 1:
            return 500
        elif len(fields) > 1:
            return self.connection.features.max_query_params // len(fields)
        else:
            return len(objs)
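
    # Worked example (field counts hypothetical): inserting objects with
    # three concrete fields gives 999 // 3 == 333 rows per batch, while a
    # single-field insert is capped at 500 by SQLITE_MAX_COMPOUND_SELECT
    # because the backend builds bulk inserts as a UNION ALL compound select
    # (see bulk_insert_sql() below).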
|
||||
|
||||
def check_expression_support(self, expression):
|
||||
bad_fields = (fields.DateField, fields.DateTimeField, fields.TimeField)
|
||||
bad_aggregates = (aggregates.Sum, aggregates.Avg, aggregates.Variance, aggregates.StdDev)
|
||||
if isinstance(expression, bad_aggregates):
|
||||
for expr in expression.get_source_expressions():
|
||||
try:
|
||||
output_field = expr.output_field
|
||||
except FieldError:
|
||||
# Not every subexpression has an output_field which is fine
|
||||
# to ignore.
|
||||
pass
|
||||
else:
|
||||
if isinstance(output_field, bad_fields):
|
||||
raise utils.NotSupportedError(
|
||||
'You cannot use Sum, Avg, StdDev, and Variance '
|
||||
'aggregations on date/time fields in sqlite3 '
|
||||
'since date/time is saved as text.'
|
||||
)
|
||||
if isinstance(expression, aggregates.Aggregate) and len(expression.source_expressions) > 1:
|
||||
raise utils.NotSupportedError(
|
||||
"SQLite doesn't support DISTINCT on aggregate functions "
|
||||
"accepting multiple arguments."
|
||||
)
|
||||
|
||||
def date_extract_sql(self, lookup_type, field_name):
|
||||
"""
|
||||
Support EXTRACT with a user-defined function django_date_extract()
|
||||
that's registered in connect(). Use single quotes because this is a
|
||||
string and could otherwise cause a collision with a field name.
|
||||
"""
|
||||
return "django_date_extract('%s', %s)" % (lookup_type.lower(), field_name)
|
||||
|
||||
def date_interval_sql(self, timedelta):
|
||||
return str(duration_microseconds(timedelta))
|
||||
|
||||
def format_for_duration_arithmetic(self, sql):
|
||||
"""Do nothing since formatting is handled in the custom function."""
|
||||
return sql
|
||||
|
||||
def date_trunc_sql(self, lookup_type, field_name):
|
||||
return "django_date_trunc('%s', %s)" % (lookup_type.lower(), field_name)
|
||||
|
||||
def time_trunc_sql(self, lookup_type, field_name):
|
||||
return "django_time_trunc('%s', %s)" % (lookup_type.lower(), field_name)
|
||||
|
||||
def _convert_tznames_to_sql(self, tzname):
|
||||
if settings.USE_TZ:
|
||||
return "'%s'" % tzname, "'%s'" % self.connection.timezone_name
|
||||
return 'NULL', 'NULL'
|
||||
|
||||
def datetime_cast_date_sql(self, field_name, tzname):
|
||||
return 'django_datetime_cast_date(%s, %s, %s)' % (
|
||||
field_name, *self._convert_tznames_to_sql(tzname),
|
||||
)
|
||||
|
||||
def datetime_cast_time_sql(self, field_name, tzname):
|
||||
return 'django_datetime_cast_time(%s, %s, %s)' % (
|
||||
field_name, *self._convert_tznames_to_sql(tzname),
|
||||
)
|
||||
|
||||
def datetime_extract_sql(self, lookup_type, field_name, tzname):
|
||||
return "django_datetime_extract('%s', %s, %s, %s)" % (
|
||||
lookup_type.lower(), field_name, *self._convert_tznames_to_sql(tzname),
|
||||
)
|
||||
|
||||
def datetime_trunc_sql(self, lookup_type, field_name, tzname):
|
||||
return "django_datetime_trunc('%s', %s, %s, %s)" % (
|
||||
lookup_type.lower(), field_name, *self._convert_tznames_to_sql(tzname),
|
||||
)
|
||||
|
||||
def time_extract_sql(self, lookup_type, field_name):
|
||||
return "django_time_extract('%s', %s)" % (lookup_type.lower(), field_name)
|
||||
|
||||
def pk_default_value(self):
|
||||
return "NULL"
|
||||
|
||||
def _quote_params_for_last_executed_query(self, params):
|
||||
"""
|
||||
Only for last_executed_query! Don't use this to execute SQL queries!
|
||||
"""
|
||||
# This function is limited both by SQLITE_LIMIT_VARIABLE_NUMBER (the
|
||||
# number of parameters, default = 999) and SQLITE_MAX_COLUMN (the
|
||||
# number of return values, default = 2000). Since Python's sqlite3
|
||||
# module doesn't expose the get_limit() C API, assume the default
|
||||
# limits are in effect and split the work in batches if needed.
|
||||
BATCH_SIZE = 999
|
||||
if len(params) > BATCH_SIZE:
|
||||
results = ()
|
||||
for index in range(0, len(params), BATCH_SIZE):
|
||||
chunk = params[index:index + BATCH_SIZE]
|
||||
results += self._quote_params_for_last_executed_query(chunk)
|
||||
return results
|
||||
|
||||
sql = 'SELECT ' + ', '.join(['QUOTE(?)'] * len(params))
|
||||
# Bypass Django's wrappers and use the underlying sqlite3 connection
|
||||
# to avoid logging this query - it would trigger infinite recursion.
|
||||
cursor = self.connection.connection.cursor()
|
||||
# Native sqlite3 cursors cannot be used as context managers.
|
||||
try:
|
||||
return cursor.execute(sql, params).fetchone()
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
def last_executed_query(self, cursor, sql, params):
|
||||
# Python substitutes parameters in Modules/_sqlite/cursor.c with:
|
||||
# pysqlite_statement_bind_parameters(self->statement, parameters, allow_8bit_chars);
|
||||
# Unfortunately there is no way to reach self->statement from Python,
|
||||
# so we quote and substitute parameters manually.
|
||||
if params:
|
||||
if isinstance(params, (list, tuple)):
|
||||
params = self._quote_params_for_last_executed_query(params)
|
||||
else:
|
||||
values = tuple(params.values())
|
||||
values = self._quote_params_for_last_executed_query(values)
|
||||
params = dict(zip(params, values))
|
||||
return sql % params
|
||||
# For consistency with SQLiteCursorWrapper.execute(), just return sql
|
||||
# when there are no parameters. See #13648 and #17158.
|
||||
else:
|
||||
return sql
|
||||
|
||||
def quote_name(self, name):
|
||||
if name.startswith('"') and name.endswith('"'):
|
||||
return name # Quoting once is enough.
|
||||
return '"%s"' % name
|
||||
|
||||
def no_limit_value(self):
|
||||
return -1
|
||||
|
||||
def __references_graph(self, table_name):
|
||||
query = """
|
||||
WITH tables AS (
|
||||
SELECT %s name
|
||||
UNION
|
||||
SELECT sqlite_master.name
|
||||
FROM sqlite_master
|
||||
JOIN tables ON (sql REGEXP %s || tables.name || %s)
|
||||
) SELECT name FROM tables;
|
||||
"""
|
||||
params = (
|
||||
table_name,
|
||||
r'(?i)\s+references\s+("|\')?',
|
||||
r'("|\')?\s*\(',
|
||||
)
|
||||
with self.connection.cursor() as cursor:
|
||||
results = cursor.execute(query, params)
|
||||
return [row[0] for row in results.fetchall()]
|
||||
|
||||
@cached_property
|
||||
def _references_graph(self):
|
||||
# 512 is large enough to fit the ~330 tables (as of this writing) in
|
||||
# Django's test suite.
|
||||
return lru_cache(maxsize=512)(self.__references_graph)
|
||||
|
||||
def sql_flush(self, style, tables, sequences, allow_cascade=False):
|
||||
if tables and allow_cascade:
|
||||
# Simulate TRUNCATE CASCADE by recursively collecting the tables
|
||||
# referencing the tables to be flushed.
|
||||
tables = set(chain.from_iterable(self._references_graph(table) for table in tables))
|
||||
# Note: No requirement for reset of auto-incremented indices (cf. other
|
||||
# sql_flush() implementations). Just return SQL at this point
|
||||
return ['%s %s %s;' % (
|
||||
style.SQL_KEYWORD('DELETE'),
|
||||
style.SQL_KEYWORD('FROM'),
|
||||
style.SQL_FIELD(self.quote_name(table))
|
||||
) for table in tables]
|
||||
|
||||
    def adapt_datetimefield_value(self, value):
        if value is None:
            return None

        # Expression values are adapted by the database.
        if hasattr(value, 'resolve_expression'):
            return value

        # SQLite doesn't support tz-aware datetimes
        if timezone.is_aware(value):
            if settings.USE_TZ:
                value = timezone.make_naive(value, self.connection.timezone)
            else:
                raise ValueError("SQLite backend does not support timezone-aware datetimes when USE_TZ is False.")

        return str(value)

    def adapt_timefield_value(self, value):
        if value is None:
            return None

        # Expression values are adapted by the database.
        if hasattr(value, 'resolve_expression'):
            return value

        # SQLite doesn't support tz-aware datetimes
        if timezone.is_aware(value):
            raise ValueError("SQLite backend does not support timezone-aware times.")

        return str(value)

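    # Illustration only: with USE_TZ=True, an aware value such as
    # datetime(2019, 1, 1, 12, 0, tzinfo=utc) is converted to the
    # connection's timezone, made naive, and stored as the string
    # '2019-01-01 12:00:00'.
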
    def get_db_converters(self, expression):
        converters = super().get_db_converters(expression)
        internal_type = expression.output_field.get_internal_type()
        if internal_type == 'DateTimeField':
            converters.append(self.convert_datetimefield_value)
        elif internal_type == 'DateField':
            converters.append(self.convert_datefield_value)
        elif internal_type == 'TimeField':
            converters.append(self.convert_timefield_value)
        elif internal_type == 'DecimalField':
            converters.append(self.get_decimalfield_converter(expression))
        elif internal_type == 'UUIDField':
            converters.append(self.convert_uuidfield_value)
        elif internal_type in ('NullBooleanField', 'BooleanField'):
            converters.append(self.convert_booleanfield_value)
        return converters

    def convert_datetimefield_value(self, value, expression, connection):
        if value is not None:
            if not isinstance(value, datetime.datetime):
                value = parse_datetime(value)
            if settings.USE_TZ and not timezone.is_aware(value):
                value = timezone.make_aware(value, self.connection.timezone)
        return value

    def convert_datefield_value(self, value, expression, connection):
        if value is not None:
            if not isinstance(value, datetime.date):
                value = parse_date(value)
        return value

    def convert_timefield_value(self, value, expression, connection):
        if value is not None:
            if not isinstance(value, datetime.time):
                value = parse_time(value)
        return value

    def get_decimalfield_converter(self, expression):
        # SQLite stores only 15 significant digits. Digits coming from
        # float inaccuracy must be removed.
        create_decimal = decimal.Context(prec=15).create_decimal_from_float
        if isinstance(expression, Col):
            quantize_value = decimal.Decimal(1).scaleb(-expression.output_field.decimal_places)

            def converter(value, expression, connection):
                if value is not None:
                    return create_decimal(value).quantize(quantize_value, context=expression.output_field.context)
        else:
            def converter(value, expression, connection):
                if value is not None:
                    return create_decimal(value)
        return converter

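    # Illustration only: a column with decimal_places=2 quantizes against
    # Decimal('0.01'), so a stored float like 1.10000000000000009 comes back
    # as Decimal('1.10') instead of leaking float noise into the result.
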
    def convert_uuidfield_value(self, value, expression, connection):
        if value is not None:
            value = uuid.UUID(value)
        return value

    def convert_booleanfield_value(self, value, expression, connection):
        return bool(value) if value in (1, 0) else value

    def bulk_insert_sql(self, fields, placeholder_rows):
        return " UNION ALL ".join(
            "SELECT %s" % ", ".join(row)
            for row in placeholder_rows
        )

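    # Illustration only: for two rows of two placeholders each, this produces
    # "SELECT %s, %s UNION ALL SELECT %s, %s", the compound-SELECT form this
    # backend feeds into "INSERT INTO ..." for multi-row inserts.
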
    def combine_expression(self, connector, sub_expressions):
        # SQLite doesn't have a ^ operator, so use the user-defined POWER
        # function that's registered in connect().
        if connector == '^':
            return 'POWER(%s)' % ','.join(sub_expressions)
        return super().combine_expression(connector, sub_expressions)

    def combine_duration_expression(self, connector, sub_expressions):
        if connector not in ['+', '-']:
            raise utils.DatabaseError('Invalid connector for timedelta: %s.' % connector)
        fn_params = ["'%s'" % connector] + sub_expressions
        if len(fn_params) > 3:
            raise ValueError('Too many params for timedelta operations.')
        return "django_format_dtdelta(%s)" % ', '.join(fn_params)

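    # Illustration only: adding a duration to a datetime compiles to SQL
    # along the lines of django_format_dtdelta('+', "created", 3600000000),
    # where the duration is assumed to arrive as microseconds (per the
    # duration adapter) and the function is one of the Python helpers
    # registered on the connection.
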
    def integer_field_range(self, internal_type):
        # SQLite doesn't enforce any integer constraints
        return (None, None)

    def subtract_temporals(self, internal_type, lhs, rhs):
        lhs_sql, lhs_params = lhs
        rhs_sql, rhs_params = rhs
        params = (*lhs_params, *rhs_params)
        if internal_type == 'TimeField':
            return 'django_time_diff(%s, %s)' % (lhs_sql, rhs_sql), params
        return 'django_timestamp_diff(%s, %s)' % (lhs_sql, rhs_sql), params

    def insert_statement(self, ignore_conflicts=False):
        return 'INSERT OR IGNORE INTO' if ignore_conflicts else super().insert_statement(ignore_conflicts)
@@ -0,0 +1,412 @@
import copy
from decimal import Decimal

from django.apps.registry import Apps
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
from django.db.backends.ddl_references import Statement
from django.db.backends.utils import strip_quotes
from django.db.models import UniqueConstraint
from django.db.transaction import atomic
from django.db.utils import NotSupportedError


class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):

    sql_delete_table = "DROP TABLE %(table)s"
    sql_create_fk = None
    sql_create_inline_fk = "REFERENCES %(to_table)s (%(to_column)s) DEFERRABLE INITIALLY DEFERRED"
    sql_create_unique = "CREATE UNIQUE INDEX %(name)s ON %(table)s (%(columns)s)"
    sql_delete_unique = "DROP INDEX %(name)s"

    def __enter__(self):
        # Some SQLite schema alterations need foreign key constraints to be
        # disabled. Enforce it here for the duration of the schema editing.
        if not self.connection.disable_constraint_checking():
            raise NotSupportedError(
                'SQLite schema editor cannot be used while foreign key '
                'constraint checks are enabled. Make sure to disable them '
                'before entering a transaction.atomic() context because '
                'SQLite does not support disabling them in the middle of '
                'a multi-statement transaction.'
            )
        return super().__enter__()

    def __exit__(self, exc_type, exc_value, traceback):
        self.connection.check_constraints()
        super().__exit__(exc_type, exc_value, traceback)
        self.connection.enable_constraint_checking()

    def quote_value(self, value):
        # The backend "mostly works" without this function and there are use
        # cases for compiling Python without the sqlite3 libraries (e.g.
        # security hardening).
        try:
            import sqlite3
            value = sqlite3.adapt(value)
        except ImportError:
            pass
        except sqlite3.ProgrammingError:
            pass
        # Manual emulation of SQLite parameter quoting
        if isinstance(value, bool):
            return str(int(value))
        elif isinstance(value, (Decimal, float, int)):
            return str(value)
        elif isinstance(value, str):
            return "'%s'" % value.replace("\'", "\'\'")
        elif value is None:
            return "NULL"
        elif isinstance(value, (bytes, bytearray, memoryview)):
            # Bytes are only allowed for BLOB fields, encoded as string
            # literals containing hexadecimal data and preceded by a single "X"
            # character.
            return "X'%s'" % value.hex()
        else:
            raise ValueError("Cannot quote parameter value %r of type %s" % (value, type(value)))

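    # Illustration only: quote_value(True) -> '1', quote_value("it's") ->
    # "'it''s'" (embedded quotes doubled), quote_value(b'\x00\xff') ->
    # "X'00ff'", and quote_value(None) -> 'NULL'.
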
    def _is_referenced_by_fk_constraint(self, table_name, column_name=None, ignore_self=False):
        """
        Return whether or not the provided table name is referenced by another
        one. If `column_name` is specified, only references pointing to that
        column are considered. If `ignore_self` is True, self-referential
        constraints are ignored.
        """
        with self.connection.cursor() as cursor:
            for other_table in self.connection.introspection.get_table_list(cursor):
                if ignore_self and other_table.name == table_name:
                    continue
                constraints = self.connection.introspection._get_foreign_key_constraints(cursor, other_table.name)
                for constraint in constraints.values():
                    constraint_table, constraint_column = constraint['foreign_key']
                    if (constraint_table == table_name and
                            (column_name is None or constraint_column == column_name)):
                        return True
        return False

    def alter_db_table(self, model, old_db_table, new_db_table, disable_constraints=True):
        if (not self.connection.features.supports_atomic_references_rename and
                disable_constraints and self._is_referenced_by_fk_constraint(old_db_table)):
            if self.connection.in_atomic_block:
                raise NotSupportedError((
                    'Renaming the %r table while in a transaction is not '
                    'supported on SQLite < 3.26 because it would break referential '
                    'integrity. Try adding `atomic = False` to the Migration class.'
                ) % old_db_table)
            self.connection.enable_constraint_checking()
            super().alter_db_table(model, old_db_table, new_db_table)
            self.connection.disable_constraint_checking()
        else:
            super().alter_db_table(model, old_db_table, new_db_table)

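    # Note on the toggling above: with the foreign_keys pragma enabled,
    # older SQLite rewrites the REFERENCES clauses of tables pointing at a
    # renamed table as part of "ALTER TABLE ... RENAME", which is why
    # constraint checking is briefly re-enabled around the rename.
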
    def alter_field(self, model, old_field, new_field, strict=False):
        old_field_name = old_field.name
        table_name = model._meta.db_table
        _, old_column_name = old_field.get_attname_column()
        if (new_field.name != old_field_name and
                not self.connection.features.supports_atomic_references_rename and
                self._is_referenced_by_fk_constraint(table_name, old_column_name, ignore_self=True)):
            if self.connection.in_atomic_block:
                raise NotSupportedError((
                    'Renaming the %r.%r column while in a transaction is not '
                    'supported on SQLite < 3.26 because it would break referential '
                    'integrity. Try adding `atomic = False` to the Migration class.'
                ) % (model._meta.db_table, old_field_name))
            with atomic(self.connection.alias):
                super().alter_field(model, old_field, new_field, strict=strict)
                # Follow SQLite's documented procedure for performing changes
                # that don't affect the on-disk content.
                # https://sqlite.org/lang_altertable.html#otheralter
                with self.connection.cursor() as cursor:
                    schema_version = cursor.execute('PRAGMA schema_version').fetchone()[0]
                    cursor.execute('PRAGMA writable_schema = 1')
                    references_template = ' REFERENCES "%s" ("%%s") ' % table_name
                    new_column_name = new_field.get_attname_column()[1]
                    search = references_template % old_column_name
                    replacement = references_template % new_column_name
                    cursor.execute('UPDATE sqlite_master SET sql = replace(sql, %s, %s)', (search, replacement))
                    cursor.execute('PRAGMA schema_version = %d' % (schema_version + 1))
                    cursor.execute('PRAGMA writable_schema = 0')
                    # The integrity check will raise an exception and rollback
                    # the transaction if the sqlite_master updates corrupt the
                    # database.
                    cursor.execute('PRAGMA integrity_check')
            # Perform a VACUUM to refresh the database representation from
            # the sqlite_master table.
            with self.connection.cursor() as cursor:
                cursor.execute('VACUUM')
        else:
            super().alter_field(model, old_field, new_field, strict=strict)

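    # Illustration only: for a hypothetical rename of "author_id" to
    # "writer_id" on table "book", the search/replace pair above rewrites
    # the raw DDL stored in sqlite_master, turning
    # ' REFERENCES "book" ("author_id") ' into
    # ' REFERENCES "book" ("writer_id") ' in every referencing table,
    # without touching any row data.
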
    def _remake_table(self, model, create_field=None, delete_field=None, alter_field=None):
        """
        Shortcut to transform a model from old_model into new_model

        This follows the correct procedure to perform non-rename or column
        addition operations based on SQLite's documentation

        https://www.sqlite.org/lang_altertable.html#caution

        The essential steps are:
          1. Create a table with the updated definition called "new__app_model"
          2. Copy the data from the existing "app_model" table to the new table
          3. Drop the "app_model" table
          4. Rename the "new__app_model" table to "app_model"
          5. Restore any index of the previous "app_model" table.
        """
        # Self-referential fields must be recreated rather than copied from
        # the old model to ensure their remote_field.field_name doesn't refer
        # to an altered field.
        def is_self_referential(f):
            return f.is_relation and f.remote_field.model is model
        # Work out the new fields dict / mapping
        body = {
            f.name: f.clone() if is_self_referential(f) else f
            for f in model._meta.local_concrete_fields
        }
        # Since mapping might mix column names and default values,
        # its values must be already quoted.
        mapping = {f.column: self.quote_name(f.column) for f in model._meta.local_concrete_fields}
        # This maps field names (not columns) for things like unique_together
        rename_mapping = {}
        # If any of the new or altered fields is introducing a new PK,
        # remove the old one
        restore_pk_field = None
        if getattr(create_field, 'primary_key', False) or (
                alter_field and getattr(alter_field[1], 'primary_key', False)):
            for name, field in list(body.items()):
                if field.primary_key:
                    field.primary_key = False
                    restore_pk_field = field
                    if field.auto_created:
                        del body[name]
                        del mapping[field.column]
        # Add in any created fields
        if create_field:
            body[create_field.name] = create_field
            # Choose a default and insert it into the copy map
            if not create_field.many_to_many and create_field.concrete:
                mapping[create_field.column] = self.quote_value(
                    self.effective_default(create_field)
                )
        # Add in any altered fields
        if alter_field:
            old_field, new_field = alter_field
            body.pop(old_field.name, None)
            mapping.pop(old_field.column, None)
            body[new_field.name] = new_field
            if old_field.null and not new_field.null:
                case_sql = "coalesce(%(col)s, %(default)s)" % {
                    'col': self.quote_name(old_field.column),
                    'default': self.quote_value(self.effective_default(new_field))
                }
                mapping[new_field.column] = case_sql
            else:
                mapping[new_field.column] = self.quote_name(old_field.column)
            rename_mapping[old_field.name] = new_field.name
        # Remove any deleted fields
        if delete_field:
            del body[delete_field.name]
            del mapping[delete_field.column]
            # Remove any implicit M2M tables
            if delete_field.many_to_many and delete_field.remote_field.through._meta.auto_created:
                return self.delete_model(delete_field.remote_field.through)
        # Work inside a new app registry
        apps = Apps()

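        # Illustration only: when a nullable column becomes NOT NULL, the
        # copy-map entry built above reads like coalesce("age", 0), so any
        # existing NULLs are backfilled with the field's effective default
        # during the INSERT ... SELECT further down.
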
        # Work out the new value of unique_together, taking renames into
        # account
        unique_together = [
            [rename_mapping.get(n, n) for n in unique]
            for unique in model._meta.unique_together
        ]

        # Work out the new value for index_together, taking renames into
        # account
        index_together = [
            [rename_mapping.get(n, n) for n in index]
            for index in model._meta.index_together
        ]

        indexes = model._meta.indexes
        if delete_field:
            indexes = [
                index for index in indexes
                if delete_field.name not in index.fields
            ]

        constraints = list(model._meta.constraints)

        # Provide isolated instances of the fields to the new model body so
        # that the existing model's internals aren't interfered with when
        # the dummy model is constructed.
        body_copy = copy.deepcopy(body)

        # Construct a new model with the new fields to allow self referential
        # primary key to resolve to. This model won't ever be materialized as a
        # table and solely exists for foreign key reference resolution purposes.
        # This wouldn't be required if the schema editor was operating on model
        # states instead of rendered models.
        meta_contents = {
            'app_label': model._meta.app_label,
            'db_table': model._meta.db_table,
            'unique_together': unique_together,
            'index_together': index_together,
            'indexes': indexes,
            'constraints': constraints,
            'apps': apps,
        }
        meta = type("Meta", (), meta_contents)
        body_copy['Meta'] = meta
        body_copy['__module__'] = model.__module__
        type(model._meta.object_name, model.__bases__, body_copy)

        # Construct a model with a renamed table name.
        body_copy = copy.deepcopy(body)
        meta_contents = {
            'app_label': model._meta.app_label,
            'db_table': 'new__%s' % strip_quotes(model._meta.db_table),
            'unique_together': unique_together,
            'index_together': index_together,
            'indexes': indexes,
            'constraints': constraints,
            'apps': apps,
        }
        meta = type("Meta", (), meta_contents)
        body_copy['Meta'] = meta
        body_copy['__module__'] = model.__module__
        new_model = type('New%s' % model._meta.object_name, model.__bases__, body_copy)

        # Create a new table with the updated schema.
        self.create_model(new_model)

        # Copy data from the old table into the new table
        self.execute("INSERT INTO %s (%s) SELECT %s FROM %s" % (
            self.quote_name(new_model._meta.db_table),
            ', '.join(self.quote_name(x) for x in mapping),
            ', '.join(mapping.values()),
            self.quote_name(model._meta.db_table),
        ))

        # Delete the old table to make way for the new
        self.delete_model(model, handle_autom2m=False)

        # Rename the new table to make way for the old
        self.alter_db_table(
            new_model, new_model._meta.db_table, model._meta.db_table,
            disable_constraints=False,
        )

        # Run deferred SQL on correct table
        for sql in self.deferred_sql:
            self.execute(sql)
        self.deferred_sql = []
        # Fix any PK-removed field
        if restore_pk_field:
            restore_pk_field.primary_key = True

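    # Illustration only: for a hypothetical "app_book" table gaining a
    # "rating" column, the copy step above emits SQL along the lines of
    #   INSERT INTO "new__app_book" ("id", "title", "rating")
    #       SELECT "id", "title", 0 FROM "app_book"
    # where the literal 0 is the quoted effective default of the new column.
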
    def delete_model(self, model, handle_autom2m=True):
        if handle_autom2m:
            super().delete_model(model)
        else:
            # Delete the table (and only that)
            self.execute(self.sql_delete_table % {
                "table": self.quote_name(model._meta.db_table),
            })
            # Remove all deferred statements referencing the deleted table.
            for sql in list(self.deferred_sql):
                if isinstance(sql, Statement) and sql.references_table(model._meta.db_table):
                    self.deferred_sql.remove(sql)

    def add_field(self, model, field):
        """
        Create a field on a model. Usually involves adding a column, but may
        involve adding a table instead (for M2M fields).
        """
        # Special-case implicit M2M tables
        if field.many_to_many and field.remote_field.through._meta.auto_created:
            return self.create_model(field.remote_field.through)
        self._remake_table(model, create_field=field)

    def remove_field(self, model, field):
        """
        Remove a field from a model. Usually involves deleting a column,
        but for M2Ms may involve deleting a table.
        """
        # M2M fields are a special case
        if field.many_to_many:
            # For implicit M2M tables, delete the auto-created table
            if field.remote_field.through._meta.auto_created:
                self.delete_model(field.remote_field.through)
            # For explicit "through" M2M fields, do nothing
        # For everything else, remake.
        else:
            # It might not actually have a column behind it
            if field.db_parameters(connection=self.connection)['type'] is None:
                return
            self._remake_table(model, delete_field=field)

    def _alter_field(self, model, old_field, new_field, old_type, new_type,
                     old_db_params, new_db_params, strict=False):
        """Perform a "physical" (non-ManyToMany) field update."""
        # Use "ALTER TABLE ... RENAME COLUMN" if only the column name
        # changed and there aren't any constraints.
        if (self.connection.features.can_alter_table_rename_column and
                old_field.column != new_field.column and
                self.column_sql(model, old_field) == self.column_sql(model, new_field) and
                not (old_field.remote_field and old_field.db_constraint or
                     new_field.remote_field and new_field.db_constraint)):
            return self.execute(self._rename_field_sql(model._meta.db_table, old_field, new_field, new_type))
        # Alter by remaking table
        self._remake_table(model, alter_field=(old_field, new_field))
        # Rebuild tables with FKs pointing to this field if the PK type changed.
        if old_field.primary_key and new_field.primary_key and old_type != new_type:
            for rel in new_field.model._meta.related_objects:
                if not rel.many_to_many:
                    self._remake_table(rel.related_model)

    def _alter_many_to_many(self, model, old_field, new_field, strict):
        """Alter M2Ms to repoint their to= endpoints."""
        if old_field.remote_field.through._meta.db_table == new_field.remote_field.through._meta.db_table:
            # The field name didn't change, but some options did; we have to
            # propagate this altering.
            self._remake_table(
                old_field.remote_field.through,
                alter_field=(
                    # We need the field that points to the target model, so
                    # we can tell alter_field to change it - this is
                    # m2m_reverse_field_name() (as opposed to m2m_field_name(),
                    # which points to our model).
                    old_field.remote_field.through._meta.get_field(old_field.m2m_reverse_field_name()),
                    new_field.remote_field.through._meta.get_field(new_field.m2m_reverse_field_name()),
                ),
            )
            return

        # Make a new through table
        self.create_model(new_field.remote_field.through)
        # Copy the data across
        self.execute("INSERT INTO %s (%s) SELECT %s FROM %s" % (
            self.quote_name(new_field.remote_field.through._meta.db_table),
            ', '.join([
                "id",
                new_field.m2m_column_name(),
                new_field.m2m_reverse_name(),
            ]),
            ', '.join([
                "id",
                old_field.m2m_column_name(),
                old_field.m2m_reverse_name(),
            ]),
            self.quote_name(old_field.remote_field.through._meta.db_table),
        ))
        # Delete the old through table
        self.delete_model(old_field.remote_field.through)

    def add_constraint(self, model, constraint):
        if isinstance(constraint, UniqueConstraint) and constraint.condition:
            super().add_constraint(model, constraint)
        else:
            self._remake_table(model)

    def remove_constraint(self, model, constraint):
        if isinstance(constraint, UniqueConstraint) and constraint.condition:
            super().remove_constraint(model, constraint)
        else:
            self._remake_table(model)
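
    # A conditional UniqueConstraint maps to a partial "CREATE UNIQUE INDEX
    # ... WHERE" statement, which SQLite can add or drop in place; every
    # other constraint lives inside the CREATE TABLE DDL, so the table has
    # to be remade.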