Initial commit. Basic models mostly done.

This commit is contained in:
tcaxle
2020-04-11 13:03:48 +01:00
commit 840e3c86f9
5761 changed files with 650959 additions and 0 deletions

View File

@@ -0,0 +1,670 @@
import copy
import threading
import time
import warnings
from collections import deque
from contextlib import contextmanager
import _thread
import pytz
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.db import DEFAULT_DB_ALIAS
from django.db.backends import utils
from django.db.backends.base.validation import BaseDatabaseValidation
from django.db.backends.signals import connection_created
from django.db.transaction import TransactionManagementError
from django.db.utils import DatabaseError, DatabaseErrorWrapper
from django.utils import timezone
from django.utils.asyncio import async_unsafe
from django.utils.functional import cached_property
NO_DB_ALIAS = '__no_db__'
class BaseDatabaseWrapper:
"""Represent a database connection."""
# Mapping of Field objects to their column types.
data_types = {}
# Mapping of Field objects to their SQL suffix such as AUTOINCREMENT.
data_types_suffix = {}
# Mapping of Field objects to their SQL for CHECK constraints.
data_type_check_constraints = {}
ops = None
vendor = 'unknown'
display_name = 'unknown'
SchemaEditorClass = None
# Classes instantiated in __init__().
client_class = None
creation_class = None
features_class = None
introspection_class = None
ops_class = None
validation_class = BaseDatabaseValidation
queries_limit = 9000
def __init__(self, settings_dict, alias=DEFAULT_DB_ALIAS):
# Connection related attributes.
# The underlying database connection.
self.connection = None
# `settings_dict` should be a dictionary containing keys such as
# NAME, USER, etc. It's called `settings_dict` instead of `settings`
# to disambiguate it from Django settings modules.
self.settings_dict = settings_dict
self.alias = alias
# Query logging in debug mode or when explicitly enabled.
self.queries_log = deque(maxlen=self.queries_limit)
self.force_debug_cursor = False
# Transaction related attributes.
# Tracks if the connection is in autocommit mode. Per PEP 249, by
# default, it isn't.
self.autocommit = False
# Tracks if the connection is in a transaction managed by 'atomic'.
self.in_atomic_block = False
# Increment to generate unique savepoint ids.
self.savepoint_state = 0
# List of savepoints created by 'atomic'.
self.savepoint_ids = []
# Tracks if the outermost 'atomic' block should commit on exit,
# ie. if autocommit was active on entry.
self.commit_on_exit = True
# Tracks if the transaction should be rolled back to the next
# available savepoint because of an exception in an inner block.
self.needs_rollback = False
# Connection termination related attributes.
self.close_at = None
self.closed_in_transaction = False
self.errors_occurred = False
# Thread-safety related attributes.
self._thread_sharing_lock = threading.Lock()
self._thread_sharing_count = 0
self._thread_ident = _thread.get_ident()
# A list of no-argument functions to run when the transaction commits.
# Each entry is an (sids, func) tuple, where sids is a set of the
# active savepoint IDs when this function was registered.
self.run_on_commit = []
# Should we run the on-commit hooks the next time set_autocommit(True)
# is called?
self.run_commit_hooks_on_set_autocommit_on = False
# A stack of wrappers to be invoked around execute()/executemany()
# calls. Each entry is a function taking five arguments: execute, sql,
# params, many, and context. It's the function's responsibility to
# call execute(sql, params, many, context).
self.execute_wrappers = []
self.client = self.client_class(self)
self.creation = self.creation_class(self)
self.features = self.features_class(self)
self.introspection = self.introspection_class(self)
self.ops = self.ops_class(self)
self.validation = self.validation_class(self)
def ensure_timezone(self):
"""
Ensure the connection's timezone is set to `self.timezone_name` and
return whether it changed or not.
"""
return False
@cached_property
def timezone(self):
"""
Time zone for datetimes stored as naive values in the database.
Return a tzinfo object or None.
This is only needed when time zone support is enabled and the database
doesn't support time zones. (When the database supports time zones,
the adapter handles aware datetimes so Django doesn't need to.)
"""
if not settings.USE_TZ:
return None
elif self.features.supports_timezones:
return None
elif self.settings_dict['TIME_ZONE'] is None:
return timezone.utc
else:
return pytz.timezone(self.settings_dict['TIME_ZONE'])
@cached_property
def timezone_name(self):
"""
Name of the time zone of the database connection.
"""
if not settings.USE_TZ:
return settings.TIME_ZONE
elif self.settings_dict['TIME_ZONE'] is None:
return 'UTC'
else:
return self.settings_dict['TIME_ZONE']
@property
def queries_logged(self):
return self.force_debug_cursor or settings.DEBUG
@property
def queries(self):
if len(self.queries_log) == self.queries_log.maxlen:
warnings.warn(
"Limit for query logging exceeded, only the last {} queries "
"will be returned.".format(self.queries_log.maxlen))
return list(self.queries_log)
# ##### Backend-specific methods for creating connections and cursors #####
def get_connection_params(self):
"""Return a dict of parameters suitable for get_new_connection."""
raise NotImplementedError('subclasses of BaseDatabaseWrapper may require a get_connection_params() method')
def get_new_connection(self, conn_params):
"""Open a connection to the database."""
raise NotImplementedError('subclasses of BaseDatabaseWrapper may require a get_new_connection() method')
def init_connection_state(self):
"""Initialize the database connection settings."""
raise NotImplementedError('subclasses of BaseDatabaseWrapper may require an init_connection_state() method')
def create_cursor(self, name=None):
"""Create a cursor. Assume that a connection is established."""
raise NotImplementedError('subclasses of BaseDatabaseWrapper may require a create_cursor() method')
# ##### Backend-specific methods for creating connections #####
@async_unsafe
def connect(self):
"""Connect to the database. Assume that the connection is closed."""
# Check for invalid configurations.
self.check_settings()
# In case the previous connection was closed while in an atomic block
self.in_atomic_block = False
self.savepoint_ids = []
self.needs_rollback = False
# Reset parameters defining when to close the connection
max_age = self.settings_dict['CONN_MAX_AGE']
self.close_at = None if max_age is None else time.monotonic() + max_age
self.closed_in_transaction = False
self.errors_occurred = False
# Establish the connection
conn_params = self.get_connection_params()
self.connection = self.get_new_connection(conn_params)
self.set_autocommit(self.settings_dict['AUTOCOMMIT'])
self.init_connection_state()
connection_created.send(sender=self.__class__, connection=self)
self.run_on_commit = []
def check_settings(self):
if self.settings_dict['TIME_ZONE'] is not None:
if not settings.USE_TZ:
raise ImproperlyConfigured(
"Connection '%s' cannot set TIME_ZONE because USE_TZ is "
"False." % self.alias)
elif self.features.supports_timezones:
raise ImproperlyConfigured(
"Connection '%s' cannot set TIME_ZONE because its engine "
"handles time zones conversions natively." % self.alias)
@async_unsafe
def ensure_connection(self):
"""Guarantee that a connection to the database is established."""
if self.connection is None:
with self.wrap_database_errors:
self.connect()
# ##### Backend-specific wrappers for PEP-249 connection methods #####
def _prepare_cursor(self, cursor):
"""
Validate the connection is usable and perform database cursor wrapping.
"""
self.validate_thread_sharing()
if self.queries_logged:
wrapped_cursor = self.make_debug_cursor(cursor)
else:
wrapped_cursor = self.make_cursor(cursor)
return wrapped_cursor
def _cursor(self, name=None):
self.ensure_connection()
with self.wrap_database_errors:
return self._prepare_cursor(self.create_cursor(name))
def _commit(self):
if self.connection is not None:
with self.wrap_database_errors:
return self.connection.commit()
def _rollback(self):
if self.connection is not None:
with self.wrap_database_errors:
return self.connection.rollback()
def _close(self):
if self.connection is not None:
with self.wrap_database_errors:
return self.connection.close()
# ##### Generic wrappers for PEP-249 connection methods #####
@async_unsafe
def cursor(self):
"""Create a cursor, opening a connection if necessary."""
return self._cursor()
@async_unsafe
def commit(self):
"""Commit a transaction and reset the dirty flag."""
self.validate_thread_sharing()
self.validate_no_atomic_block()
self._commit()
# A successful commit means that the database connection works.
self.errors_occurred = False
self.run_commit_hooks_on_set_autocommit_on = True
@async_unsafe
def rollback(self):
"""Roll back a transaction and reset the dirty flag."""
self.validate_thread_sharing()
self.validate_no_atomic_block()
self._rollback()
# A successful rollback means that the database connection works.
self.errors_occurred = False
self.needs_rollback = False
self.run_on_commit = []
@async_unsafe
def close(self):
"""Close the connection to the database."""
self.validate_thread_sharing()
self.run_on_commit = []
# Don't call validate_no_atomic_block() to avoid making it difficult
# to get rid of a connection in an invalid state. The next connect()
# will reset the transaction state anyway.
if self.closed_in_transaction or self.connection is None:
return
try:
self._close()
finally:
if self.in_atomic_block:
self.closed_in_transaction = True
self.needs_rollback = True
else:
self.connection = None
# ##### Backend-specific savepoint management methods #####
def _savepoint(self, sid):
with self.cursor() as cursor:
cursor.execute(self.ops.savepoint_create_sql(sid))
def _savepoint_rollback(self, sid):
with self.cursor() as cursor:
cursor.execute(self.ops.savepoint_rollback_sql(sid))
def _savepoint_commit(self, sid):
with self.cursor() as cursor:
cursor.execute(self.ops.savepoint_commit_sql(sid))
def _savepoint_allowed(self):
# Savepoints cannot be created outside a transaction
return self.features.uses_savepoints and not self.get_autocommit()
# ##### Generic savepoint management methods #####
@async_unsafe
def savepoint(self):
"""
Create a savepoint inside the current transaction. Return an
identifier for the savepoint that will be used for the subsequent
rollback or commit. Do nothing if savepoints are not supported.
"""
if not self._savepoint_allowed():
return
thread_ident = _thread.get_ident()
tid = str(thread_ident).replace('-', '')
self.savepoint_state += 1
sid = "s%s_x%d" % (tid, self.savepoint_state)
self.validate_thread_sharing()
self._savepoint(sid)
return sid
@async_unsafe
def savepoint_rollback(self, sid):
"""
Roll back to a savepoint. Do nothing if savepoints are not supported.
"""
if not self._savepoint_allowed():
return
self.validate_thread_sharing()
self._savepoint_rollback(sid)
# Remove any callbacks registered while this savepoint was active.
self.run_on_commit = [
(sids, func) for (sids, func) in self.run_on_commit if sid not in sids
]
@async_unsafe
def savepoint_commit(self, sid):
"""
Release a savepoint. Do nothing if savepoints are not supported.
"""
if not self._savepoint_allowed():
return
self.validate_thread_sharing()
self._savepoint_commit(sid)
@async_unsafe
def clean_savepoints(self):
"""
Reset the counter used to generate unique savepoint ids in this thread.
"""
self.savepoint_state = 0
# ##### Backend-specific transaction management methods #####
def _set_autocommit(self, autocommit):
"""
Backend-specific implementation to enable or disable autocommit.
"""
raise NotImplementedError('subclasses of BaseDatabaseWrapper may require a _set_autocommit() method')
# ##### Generic transaction management methods #####
def get_autocommit(self):
"""Get the autocommit state."""
self.ensure_connection()
return self.autocommit
def set_autocommit(self, autocommit, force_begin_transaction_with_broken_autocommit=False):
"""
Enable or disable autocommit.
The usual way to start a transaction is to turn autocommit off.
SQLite does not properly start a transaction when disabling
autocommit. To avoid this buggy behavior and to actually enter a new
transaction, an explicit BEGIN is required. Using
force_begin_transaction_with_broken_autocommit=True will issue an
explicit BEGIN with SQLite. This option will be ignored for other
backends.
"""
self.validate_no_atomic_block()
self.ensure_connection()
start_transaction_under_autocommit = (
force_begin_transaction_with_broken_autocommit and not autocommit and
hasattr(self, '_start_transaction_under_autocommit')
)
if start_transaction_under_autocommit:
self._start_transaction_under_autocommit()
else:
self._set_autocommit(autocommit)
self.autocommit = autocommit
if autocommit and self.run_commit_hooks_on_set_autocommit_on:
self.run_and_clear_commit_hooks()
self.run_commit_hooks_on_set_autocommit_on = False
def get_rollback(self):
"""Get the "needs rollback" flag -- for *advanced use* only."""
if not self.in_atomic_block:
raise TransactionManagementError(
"The rollback flag doesn't work outside of an 'atomic' block.")
return self.needs_rollback
def set_rollback(self, rollback):
"""
Set or unset the "needs rollback" flag -- for *advanced use* only.
"""
if not self.in_atomic_block:
raise TransactionManagementError(
"The rollback flag doesn't work outside of an 'atomic' block.")
self.needs_rollback = rollback
def validate_no_atomic_block(self):
"""Raise an error if an atomic block is active."""
if self.in_atomic_block:
raise TransactionManagementError(
"This is forbidden when an 'atomic' block is active.")
def validate_no_broken_transaction(self):
if self.needs_rollback:
raise TransactionManagementError(
"An error occurred in the current transaction. You can't "
"execute queries until the end of the 'atomic' block.")
# ##### Foreign key constraints checks handling #####
@contextmanager
def constraint_checks_disabled(self):
"""
Disable foreign key constraint checking.
"""
disabled = self.disable_constraint_checking()
try:
yield
finally:
if disabled:
self.enable_constraint_checking()
def disable_constraint_checking(self):
"""
Backends can implement as needed to temporarily disable foreign key
constraint checking. Should return True if the constraints were
disabled and will need to be reenabled.
"""
return False
def enable_constraint_checking(self):
"""
Backends can implement as needed to re-enable foreign key constraint
checking.
"""
pass
def check_constraints(self, table_names=None):
"""
Backends can override this method if they can apply constraint
checking (e.g. via "SET CONSTRAINTS ALL IMMEDIATE"). Should raise an
IntegrityError if any invalid foreign key references are encountered.
"""
pass
# ##### Connection termination handling #####
def is_usable(self):
"""
Test if the database connection is usable.
This method may assume that self.connection is not None.
Actual implementations should take care not to raise exceptions
as that may prevent Django from recycling unusable connections.
"""
raise NotImplementedError(
"subclasses of BaseDatabaseWrapper may require an is_usable() method")
def close_if_unusable_or_obsolete(self):
"""
Close the current connection if unrecoverable errors have occurred
or if it outlived its maximum age.
"""
if self.connection is not None:
# If the application didn't restore the original autocommit setting,
# don't take chances, drop the connection.
if self.get_autocommit() != self.settings_dict['AUTOCOMMIT']:
self.close()
return
# If an exception other than DataError or IntegrityError occurred
# since the last commit / rollback, check if the connection works.
if self.errors_occurred:
if self.is_usable():
self.errors_occurred = False
else:
self.close()
return
if self.close_at is not None and time.monotonic() >= self.close_at:
self.close()
return
# ##### Thread safety handling #####
@property
def allow_thread_sharing(self):
with self._thread_sharing_lock:
return self._thread_sharing_count > 0
def inc_thread_sharing(self):
with self._thread_sharing_lock:
self._thread_sharing_count += 1
def dec_thread_sharing(self):
with self._thread_sharing_lock:
if self._thread_sharing_count <= 0:
raise RuntimeError('Cannot decrement the thread sharing count below zero.')
self._thread_sharing_count -= 1
def validate_thread_sharing(self):
"""
Validate that the connection isn't accessed by another thread than the
one which originally created it, unless the connection was explicitly
authorized to be shared between threads (via the `inc_thread_sharing()`
method). Raise an exception if the validation fails.
"""
if not (self.allow_thread_sharing or self._thread_ident == _thread.get_ident()):
raise DatabaseError(
"DatabaseWrapper objects created in a "
"thread can only be used in that same thread. The object "
"with alias '%s' was created in thread id %s and this is "
"thread id %s."
% (self.alias, self._thread_ident, _thread.get_ident())
)
# ##### Miscellaneous #####
def prepare_database(self):
"""
Hook to do any database check or preparation, generally called before
migrating a project or an app.
"""
pass
@cached_property
def wrap_database_errors(self):
"""
Context manager and decorator that re-throws backend-specific database
exceptions using Django's common wrappers.
"""
return DatabaseErrorWrapper(self)
def chunked_cursor(self):
"""
Return a cursor that tries to avoid caching in the database (if
supported by the database), otherwise return a regular cursor.
"""
return self.cursor()
def make_debug_cursor(self, cursor):
"""Create a cursor that logs all queries in self.queries_log."""
return utils.CursorDebugWrapper(cursor, self)
def make_cursor(self, cursor):
"""Create a cursor without debug logging."""
return utils.CursorWrapper(cursor, self)
@contextmanager
def temporary_connection(self):
"""
Context manager that ensures that a connection is established, and
if it opened one, closes it to avoid leaving a dangling connection.
This is useful for operations outside of the request-response cycle.
Provide a cursor: with self.temporary_connection() as cursor: ...
"""
must_close = self.connection is None
try:
with self.cursor() as cursor:
yield cursor
finally:
if must_close:
self.close()
@property
def _nodb_connection(self):
"""
Return an alternative connection to be used when there is no need to
access the main database, specifically for test db creation/deletion.
This also prevents the production database from being exposed to
potential child threads while (or after) the test database is destroyed.
Refs #10868, #17786, #16969.
"""
return self.__class__({**self.settings_dict, 'NAME': None}, alias=NO_DB_ALIAS)
def schema_editor(self, *args, **kwargs):
"""
Return a new instance of this backend's SchemaEditor.
"""
if self.SchemaEditorClass is None:
raise NotImplementedError(
'The SchemaEditorClass attribute of this database wrapper is still None')
return self.SchemaEditorClass(self, *args, **kwargs)
def on_commit(self, func):
if self.in_atomic_block:
# Transaction in progress; save for execution on commit.
self.run_on_commit.append((set(self.savepoint_ids), func))
elif not self.get_autocommit():
raise TransactionManagementError('on_commit() cannot be used in manual transaction management')
else:
# No transaction in progress and in autocommit mode; execute
# immediately.
func()
def run_and_clear_commit_hooks(self):
self.validate_no_atomic_block()
current_run_on_commit = self.run_on_commit
self.run_on_commit = []
while current_run_on_commit:
sids, func = current_run_on_commit.pop(0)
func()
@contextmanager
def execute_wrapper(self, wrapper):
"""
Return a context manager under which the wrapper is applied to suitable
database query executions.
"""
self.execute_wrappers.append(wrapper)
try:
yield
finally:
self.execute_wrappers.pop()
def copy(self, alias=None):
"""
Return a copy of this connection.
For tests that require two connections to the same database.
"""
settings_dict = copy.deepcopy(self.settings_dict)
if alias is None:
alias = self.alias
return type(self)(settings_dict, alias)

View File

@@ -0,0 +1,12 @@
class BaseDatabaseClient:
"""Encapsulate backend-specific methods for opening a client shell."""
# This should be a string representing the name of the executable
# (e.g., "psql"). Subclasses must override this.
executable_name = None
def __init__(self, connection):
# connection is an instance of BaseDatabaseWrapper.
self.connection = connection
def runshell(self):
raise NotImplementedError('subclasses of BaseDatabaseClient must provide a runshell() method')

View File

@@ -0,0 +1,296 @@
import os
import sys
from io import StringIO
from django.apps import apps
from django.conf import settings
from django.core import serializers
from django.db import router
# The prefix to put on the default database name when creating
# the test database.
TEST_DATABASE_PREFIX = 'test_'
class BaseDatabaseCreation:
"""
Encapsulate backend-specific differences pertaining to creation and
destruction of the test database.
"""
def __init__(self, connection):
self.connection = connection
@property
def _nodb_connection(self):
"""
Used to be defined here, now moved to DatabaseWrapper.
"""
return self.connection._nodb_connection
def log(self, msg):
sys.stderr.write(msg + os.linesep)
def create_test_db(self, verbosity=1, autoclobber=False, serialize=True, keepdb=False):
"""
Create a test database, prompting the user for confirmation if the
database already exists. Return the name of the test database created.
"""
# Don't import django.core.management if it isn't needed.
from django.core.management import call_command
test_database_name = self._get_test_db_name()
if verbosity >= 1:
action = 'Creating'
if keepdb:
action = "Using existing"
self.log('%s test database for alias %s...' % (
action,
self._get_database_display_str(verbosity, test_database_name),
))
# We could skip this call if keepdb is True, but we instead
# give it the keepdb param. This is to handle the case
# where the test DB doesn't exist, in which case we need to
# create it, then just not destroy it. If we instead skip
# this, we will get an exception.
self._create_test_db(verbosity, autoclobber, keepdb)
self.connection.close()
settings.DATABASES[self.connection.alias]["NAME"] = test_database_name
self.connection.settings_dict["NAME"] = test_database_name
# We report migrate messages at one level lower than that requested.
# This ensures we don't get flooded with messages during testing
# (unless you really ask to be flooded).
call_command(
'migrate',
verbosity=max(verbosity - 1, 0),
interactive=False,
database=self.connection.alias,
run_syncdb=True,
)
# We then serialize the current state of the database into a string
# and store it on the connection. This slightly horrific process is so people
# who are testing on databases without transactions or who are using
# a TransactionTestCase still get a clean database on every test run.
if serialize:
self.connection._test_serialized_contents = self.serialize_db_to_string()
call_command('createcachetable', database=self.connection.alias)
# Ensure a connection for the side effect of initializing the test database.
self.connection.ensure_connection()
return test_database_name
def set_as_test_mirror(self, primary_settings_dict):
"""
Set this database up to be used in testing as a mirror of a primary
database whose settings are given.
"""
self.connection.settings_dict['NAME'] = primary_settings_dict['NAME']
def serialize_db_to_string(self):
"""
Serialize all data in the database into a JSON string.
Designed only for test runner usage; will not handle large
amounts of data.
"""
# Build list of all apps to serialize
from django.db.migrations.loader import MigrationLoader
loader = MigrationLoader(self.connection)
app_list = []
for app_config in apps.get_app_configs():
if (
app_config.models_module is not None and
app_config.label in loader.migrated_apps and
app_config.name not in settings.TEST_NON_SERIALIZED_APPS
):
app_list.append((app_config, None))
# Make a function to iteratively return every object
def get_objects():
for model in serializers.sort_dependencies(app_list):
if (model._meta.can_migrate(self.connection) and
router.allow_migrate_model(self.connection.alias, model)):
queryset = model._default_manager.using(self.connection.alias).order_by(model._meta.pk.name)
yield from queryset.iterator()
# Serialize to a string
out = StringIO()
serializers.serialize("json", get_objects(), indent=None, stream=out)
return out.getvalue()
def deserialize_db_from_string(self, data):
"""
Reload the database with data from a string generated by
the serialize_db_to_string() method.
"""
data = StringIO(data)
for obj in serializers.deserialize("json", data, using=self.connection.alias):
obj.save()
def _get_database_display_str(self, verbosity, database_name):
"""
Return display string for a database for use in various actions.
"""
return "'%s'%s" % (
self.connection.alias,
(" ('%s')" % database_name) if verbosity >= 2 else '',
)
def _get_test_db_name(self):
"""
Internal implementation - return the name of the test DB that will be
created. Only useful when called from create_test_db() and
_create_test_db() and when no external munging is done with the 'NAME'
settings.
"""
if self.connection.settings_dict['TEST']['NAME']:
return self.connection.settings_dict['TEST']['NAME']
return TEST_DATABASE_PREFIX + self.connection.settings_dict['NAME']
def _execute_create_test_db(self, cursor, parameters, keepdb=False):
cursor.execute('CREATE DATABASE %(dbname)s %(suffix)s' % parameters)
def _create_test_db(self, verbosity, autoclobber, keepdb=False):
"""
Internal implementation - create the test db tables.
"""
test_database_name = self._get_test_db_name()
test_db_params = {
'dbname': self.connection.ops.quote_name(test_database_name),
'suffix': self.sql_table_creation_suffix(),
}
# Create the test database and connect to it.
with self._nodb_connection.cursor() as cursor:
try:
self._execute_create_test_db(cursor, test_db_params, keepdb)
except Exception as e:
# if we want to keep the db, then no need to do any of the below,
# just return and skip it all.
if keepdb:
return test_database_name
self.log('Got an error creating the test database: %s' % e)
if not autoclobber:
confirm = input(
"Type 'yes' if you would like to try deleting the test "
"database '%s', or 'no' to cancel: " % test_database_name)
if autoclobber or confirm == 'yes':
try:
if verbosity >= 1:
self.log('Destroying old test database for alias %s...' % (
self._get_database_display_str(verbosity, test_database_name),
))
cursor.execute('DROP DATABASE %(dbname)s' % test_db_params)
self._execute_create_test_db(cursor, test_db_params, keepdb)
except Exception as e:
self.log('Got an error recreating the test database: %s' % e)
sys.exit(2)
else:
self.log('Tests cancelled.')
sys.exit(1)
return test_database_name
def clone_test_db(self, suffix, verbosity=1, autoclobber=False, keepdb=False):
"""
Clone a test database.
"""
source_database_name = self.connection.settings_dict['NAME']
if verbosity >= 1:
action = 'Cloning test database'
if keepdb:
action = 'Using existing clone'
self.log('%s for alias %s...' % (
action,
self._get_database_display_str(verbosity, source_database_name),
))
# We could skip this call if keepdb is True, but we instead
# give it the keepdb param. See create_test_db for details.
self._clone_test_db(suffix, verbosity, keepdb)
def get_test_db_clone_settings(self, suffix):
"""
Return a modified connection settings dict for the n-th clone of a DB.
"""
# When this function is called, the test database has been created
# already and its name has been copied to settings_dict['NAME'] so
# we don't need to call _get_test_db_name.
orig_settings_dict = self.connection.settings_dict
return {**orig_settings_dict, 'NAME': '{}_{}'.format(orig_settings_dict['NAME'], suffix)}
def _clone_test_db(self, suffix, verbosity, keepdb=False):
"""
Internal implementation - duplicate the test db tables.
"""
raise NotImplementedError(
"The database backend doesn't support cloning databases. "
"Disable the option to run tests in parallel processes.")
def destroy_test_db(self, old_database_name=None, verbosity=1, keepdb=False, suffix=None):
"""
Destroy a test database, prompting the user for confirmation if the
database already exists.
"""
self.connection.close()
if suffix is None:
test_database_name = self.connection.settings_dict['NAME']
else:
test_database_name = self.get_test_db_clone_settings(suffix)['NAME']
if verbosity >= 1:
action = 'Destroying'
if keepdb:
action = 'Preserving'
self.log('%s test database for alias %s...' % (
action,
self._get_database_display_str(verbosity, test_database_name),
))
# if we want to preserve the database
# skip the actual destroying piece.
if not keepdb:
self._destroy_test_db(test_database_name, verbosity)
# Restore the original database name
if old_database_name is not None:
settings.DATABASES[self.connection.alias]["NAME"] = old_database_name
self.connection.settings_dict["NAME"] = old_database_name
def _destroy_test_db(self, test_database_name, verbosity):
"""
Internal implementation - remove the test db tables.
"""
# Remove the test database to clean up after
# ourselves. Connect to the previous database (not the test database)
# to do so, because it's not allowed to delete a database while being
# connected to it.
with self.connection._nodb_connection.cursor() as cursor:
cursor.execute("DROP DATABASE %s"
% self.connection.ops.quote_name(test_database_name))
def sql_table_creation_suffix(self):
"""
SQL to append to the end of the test table creation statements.
"""
return ''
def test_db_signature(self):
"""
Return a tuple with elements of self.connection.settings_dict (a
DATABASES setting value) that uniquely identify a database
accordingly to the RDBMS particularities.
"""
settings_dict = self.connection.settings_dict
return (
settings_dict['HOST'],
settings_dict['PORT'],
settings_dict['ENGINE'],
self._get_test_db_name(),
)

View File

@@ -0,0 +1,320 @@
from django.db.utils import ProgrammingError
from django.utils.functional import cached_property
class BaseDatabaseFeatures:
gis_enabled = False
allows_group_by_pk = False
allows_group_by_selected_pks = False
empty_fetchmany_value = []
update_can_self_select = True
# Does the backend distinguish between '' and None?
interprets_empty_strings_as_nulls = False
# Does the backend allow inserting duplicate NULL rows in a nullable
# unique field? All core backends implement this correctly, but other
# databases such as SQL Server do not.
supports_nullable_unique_constraints = True
# Does the backend allow inserting duplicate rows when a unique_together
# constraint exists and some fields are nullable but not all of them?
supports_partially_nullable_unique_constraints = True
can_use_chunked_reads = True
can_return_columns_from_insert = False
can_return_multiple_columns_from_insert = False
can_return_rows_from_bulk_insert = False
has_bulk_insert = True
uses_savepoints = True
can_release_savepoints = False
# If True, don't use integer foreign keys referring to, e.g., positive
# integer primary keys.
related_fields_match_type = False
allow_sliced_subqueries_with_in = True
has_select_for_update = False
has_select_for_update_nowait = False
has_select_for_update_skip_locked = False
has_select_for_update_of = False
# Does the database's SELECT FOR UPDATE OF syntax require a column rather
# than a table?
select_for_update_of_column = False
# Does the default test database allow multiple connections?
# Usually an indication that the test database is in-memory
test_db_allows_multiple_connections = True
# Can an object be saved without an explicit primary key?
supports_unspecified_pk = False
# Can a fixture contain forward references? i.e., are
# FK constraints checked at the end of transaction, or
# at the end of each save operation?
supports_forward_references = True
# Does the backend truncate names properly when they are too long?
truncates_names = False
# Is there a REAL datatype in addition to floats/doubles?
has_real_datatype = False
supports_subqueries_in_group_by = True
# Is there a true datatype for uuid?
has_native_uuid_field = False
# Is there a true datatype for timedeltas?
has_native_duration_field = False
# Does the database driver supports same type temporal data subtraction
# by returning the type used to store duration field?
supports_temporal_subtraction = False
# Does the __regex lookup support backreferencing and grouping?
supports_regex_backreferencing = True
# Can date/datetime lookups be performed using a string?
supports_date_lookup_using_string = True
# Can datetimes with timezones be used?
supports_timezones = True
# Does the database have a copy of the zoneinfo database?
has_zoneinfo_database = True
# When performing a GROUP BY, is an ORDER BY NULL required
# to remove any ordering?
requires_explicit_null_ordering_when_grouping = False
# Does the backend order NULL values as largest or smallest?
nulls_order_largest = False
# The database's limit on the number of query parameters.
max_query_params = None
# Can an object have an autoincrement primary key of 0? MySQL says No.
allows_auto_pk_0 = True
# Do we need to NULL a ForeignKey out, or can the constraint check be
# deferred
can_defer_constraint_checks = False
# date_interval_sql can properly handle mixed Date/DateTime fields and timedeltas
supports_mixed_date_datetime_comparisons = True
# Does the backend support tablespaces? Default to False because it isn't
# in the SQL standard.
supports_tablespaces = False
# Does the backend reset sequences between tests?
supports_sequence_reset = True
# Can the backend introspect the default value of a column?
can_introspect_default = True
# Confirm support for introspected foreign keys
# Every database can do this reliably, except MySQL,
# which can't do it for MyISAM tables
can_introspect_foreign_keys = True
# Can the backend introspect an AutoField, instead of an IntegerField?
can_introspect_autofield = False
# Can the backend introspect a BigIntegerField, instead of an IntegerField?
can_introspect_big_integer_field = True
# Can the backend introspect an BinaryField, instead of an TextField?
can_introspect_binary_field = True
# Can the backend introspect an DecimalField, instead of an FloatField?
can_introspect_decimal_field = True
# Can the backend introspect a DurationField, instead of a BigIntegerField?
can_introspect_duration_field = True
# Can the backend introspect an IPAddressField, instead of an CharField?
can_introspect_ip_address_field = False
# Can the backend introspect a PositiveIntegerField, instead of an IntegerField?
can_introspect_positive_integer_field = False
# Can the backend introspect a SmallIntegerField, instead of an IntegerField?
can_introspect_small_integer_field = False
# Can the backend introspect a TimeField, instead of a DateTimeField?
can_introspect_time_field = True
# Some backends may not be able to differentiate BigAutoField or
# SmallAutoField from other fields such as AutoField.
introspected_big_auto_field_type = 'BigAutoField'
introspected_small_auto_field_type = 'SmallAutoField'
# Some backends may not be able to differentiate BooleanField from other
# fields such as IntegerField.
introspected_boolean_field_type = 'BooleanField'
# Can the backend introspect the column order (ASC/DESC) for indexes?
supports_index_column_ordering = True
# Does the backend support introspection of materialized views?
can_introspect_materialized_views = False
# Support for the DISTINCT ON clause
can_distinct_on_fields = False
# Does the backend prevent running SQL queries in broken transactions?
atomic_transactions = True
# Can we roll back DDL in a transaction?
can_rollback_ddl = False
# Does it support operations requiring references rename in a transaction?
supports_atomic_references_rename = True
# Can we issue more than one ALTER COLUMN clause in an ALTER TABLE?
supports_combined_alters = False
# Does it support foreign keys?
supports_foreign_keys = True
# Can it create foreign key constraints inline when adding columns?
can_create_inline_fk = True
# Does it support CHECK constraints?
supports_column_check_constraints = True
supports_table_check_constraints = True
# Does the backend support introspection of CHECK constraints?
can_introspect_check_constraints = True
# Does the backend support 'pyformat' style ("... %(name)s ...", {'name': value})
# parameter passing? Note this can be provided by the backend even if not
# supported by the Python driver
supports_paramstyle_pyformat = True
# Does the backend require literal defaults, rather than parameterized ones?
requires_literal_defaults = False
# Does the backend require a connection reset after each material schema change?
connection_persists_old_columns = False
# What kind of error does the backend throw when accessing closed cursor?
closed_cursor_error_class = ProgrammingError
# Does 'a' LIKE 'A' match?
has_case_insensitive_like = True
# Suffix for backends that don't support "SELECT xxx;" queries.
bare_select_suffix = ''
# If NULL is implied on columns without needing to be explicitly specified
implied_column_null = False
# Does the backend support "select for update" queries with limit (and offset)?
supports_select_for_update_with_limit = True
# Does the backend ignore null expressions in GREATEST and LEAST queries unless
# every expression is null?
greatest_least_ignores_nulls = False
# Can the backend clone databases for parallel test execution?
# Defaults to False to allow third-party backends to opt-in.
can_clone_databases = False
# Does the backend consider table names with different casing to
# be equal?
ignores_table_name_case = False
# Place FOR UPDATE right after FROM clause. Used on MSSQL.
for_update_after_from = False
# Combinatorial flags
supports_select_union = True
supports_select_intersection = True
supports_select_difference = True
supports_slicing_ordering_in_compound = False
supports_parentheses_in_compound = True
# Does the database support SQL 2003 FILTER (WHERE ...) in aggregate
# expressions?
supports_aggregate_filter_clause = False
# Does the backend support indexing a TextField?
supports_index_on_text_field = True
# Does the backend support window expressions (expression OVER (...))?
supports_over_clause = False
supports_frame_range_fixed_distance = False
# Does the backend support CAST with precision?
supports_cast_with_precision = True
# How many second decimals does the database return when casting a value to
# a type with time?
time_cast_precision = 6
# SQL to create a procedure for use by the Django test suite. The
# functionality of the procedure isn't important.
create_test_procedure_without_params_sql = None
create_test_procedure_with_int_param_sql = None
# Does the backend support keyword parameters for cursor.callproc()?
supports_callproc_kwargs = False
# Convert CharField results from bytes to str in database functions.
db_functions_convert_bytes_to_str = False
# What formats does the backend EXPLAIN syntax support?
supported_explain_formats = set()
# Does DatabaseOperations.explain_query_prefix() raise ValueError if
# unknown kwargs are passed to QuerySet.explain()?
validates_explain_options = True
# Does the backend support the default parameter in lead() and lag()?
supports_default_in_lead_lag = True
# Does the backend support ignoring constraint or uniqueness errors during
# INSERT?
supports_ignore_conflicts = True
# Does this backend require casting the results of CASE expressions used
# in UPDATE statements to ensure the expression has the correct type?
requires_casted_case_in_updates = False
# Does the backend support partial indexes (CREATE INDEX ... WHERE ...)?
supports_partial_indexes = True
supports_functions_in_partial_indexes = True
# Does the database allow more than one constraint or index on the same
# field(s)?
allows_multiple_constraints_on_same_fields = True
# Does the backend support boolean expressions in the SELECT clause?
supports_boolean_expr_in_select_clause = True
def __init__(self, connection):
self.connection = connection
@cached_property
def supports_explaining_query_execution(self):
"""Does this backend support explaining query execution?"""
return self.connection.ops.explain_prefix is not None
@cached_property
def supports_transactions(self):
"""Confirm support for transactions."""
with self.connection.cursor() as cursor:
cursor.execute('CREATE TABLE ROLLBACK_TEST (X INT)')
self.connection.set_autocommit(False)
cursor.execute('INSERT INTO ROLLBACK_TEST (X) VALUES (8)')
self.connection.rollback()
self.connection.set_autocommit(True)
cursor.execute('SELECT COUNT(X) FROM ROLLBACK_TEST')
count, = cursor.fetchone()
cursor.execute('DROP TABLE ROLLBACK_TEST')
return count == 0
def allows_group_by_selected_pks_on_model(self, model):
if not self.allows_group_by_selected_pks:
return False
return model._meta.managed

View File

@@ -0,0 +1,169 @@
from collections import namedtuple
# Structure returned by DatabaseIntrospection.get_table_list()
TableInfo = namedtuple('TableInfo', ['name', 'type'])
# Structure returned by the DB-API cursor.description interface (PEP 249)
FieldInfo = namedtuple('FieldInfo', 'name type_code display_size internal_size precision scale null_ok default')
class BaseDatabaseIntrospection:
"""Encapsulate backend-specific introspection utilities."""
data_types_reverse = {}
def __init__(self, connection):
self.connection = connection
def get_field_type(self, data_type, description):
"""
Hook for a database backend to use the cursor description to
match a Django field type to a database column.
For Oracle, the column data_type on its own is insufficient to
distinguish between a FloatField and IntegerField, for example.
"""
return self.data_types_reverse[data_type]
def identifier_converter(self, name):
"""
Apply a conversion to the identifier for the purposes of comparison.
The default identifier converter is for case sensitive comparison.
"""
return name
def table_names(self, cursor=None, include_views=False):
"""
Return a list of names of all tables that exist in the database.
Sort the returned table list by Python's default sorting. Do NOT use
the database's ORDER BY here to avoid subtle differences in sorting
order between databases.
"""
def get_names(cursor):
return sorted(ti.name for ti in self.get_table_list(cursor)
if include_views or ti.type == 't')
if cursor is None:
with self.connection.cursor() as cursor:
return get_names(cursor)
return get_names(cursor)
def get_table_list(self, cursor):
"""
Return an unsorted list of TableInfo named tuples of all tables and
views that exist in the database.
"""
raise NotImplementedError('subclasses of BaseDatabaseIntrospection may require a get_table_list() method')
def get_migratable_models(self):
from django.apps import apps
from django.db import router
return (
model
for app_config in apps.get_app_configs()
for model in router.get_migratable_models(app_config, self.connection.alias)
if model._meta.can_migrate(self.connection)
)
def django_table_names(self, only_existing=False, include_views=True):
"""
Return a list of all table names that have associated Django models and
are in INSTALLED_APPS.
If only_existing is True, include only the tables in the database.
"""
tables = set()
for model in self.get_migratable_models():
if not model._meta.managed:
continue
tables.add(model._meta.db_table)
tables.update(
f.m2m_db_table() for f in model._meta.local_many_to_many
if f.remote_field.through._meta.managed
)
tables = list(tables)
if only_existing:
existing_tables = set(self.table_names(include_views=include_views))
tables = [
t
for t in tables
if self.identifier_converter(t) in existing_tables
]
return tables
def installed_models(self, tables):
"""
Return a set of all models represented by the provided list of table
names.
"""
tables = set(map(self.identifier_converter, tables))
return {
m for m in self.get_migratable_models()
if self.identifier_converter(m._meta.db_table) in tables
}
def sequence_list(self):
"""
Return a list of information about all DB sequences for all models in
all apps.
"""
sequence_list = []
with self.connection.cursor() as cursor:
for model in self.get_migratable_models():
if not model._meta.managed:
continue
if model._meta.swapped:
continue
sequence_list.extend(self.get_sequences(cursor, model._meta.db_table, model._meta.local_fields))
for f in model._meta.local_many_to_many:
# If this is an m2m using an intermediate table,
# we don't need to reset the sequence.
if f.remote_field.through._meta.auto_created:
sequence = self.get_sequences(cursor, f.m2m_db_table())
sequence_list.extend(sequence or [{'table': f.m2m_db_table(), 'column': None}])
return sequence_list
def get_sequences(self, cursor, table_name, table_fields=()):
"""
Return a list of introspected sequences for table_name. Each sequence
is a dict: {'table': <table_name>, 'column': <column_name>}. An optional
'name' key can be added if the backend supports named sequences.
"""
raise NotImplementedError('subclasses of BaseDatabaseIntrospection may require a get_sequences() method')
def get_key_columns(self, cursor, table_name):
"""
Backends can override this to return a list of:
(column_name, referenced_table_name, referenced_column_name)
for all key columns in given table.
"""
raise NotImplementedError('subclasses of BaseDatabaseIntrospection may require a get_key_columns() method')
def get_primary_key_column(self, cursor, table_name):
"""
Return the name of the primary key column for the given table.
"""
for constraint in self.get_constraints(cursor, table_name).values():
if constraint['primary_key']:
return constraint['columns'][0]
return None
def get_constraints(self, cursor, table_name):
"""
Retrieve any constraints or keys (unique, pk, fk, check, index)
across one or more columns.
Return a dict mapping constraint names to their attributes,
where attributes is a dict with keys:
* columns: List of columns this covers
* primary_key: True if primary key, False otherwise
* unique: True if this is a unique constraint, False otherwise
* foreign_key: (table, column) of target, or None
* check: True if check constraint, False otherwise
* index: True if index, False otherwise.
* orders: The order (ASC/DESC) defined for the columns of indexes
* type: The type of the index (btree, hash, etc.)
Some backends may return special constraint names that don't exist
if they don't name constraints of a certain type (e.g. SQLite)
"""
raise NotImplementedError('subclasses of BaseDatabaseIntrospection may require a get_constraints() method')

View File

@@ -0,0 +1,681 @@
import datetime
import decimal
from importlib import import_module
import sqlparse
from django.conf import settings
from django.db import NotSupportedError, transaction
from django.db.backends import utils
from django.utils import timezone
from django.utils.encoding import force_str
class BaseDatabaseOperations:
"""
Encapsulate backend-specific differences, such as the way a backend
performs ordering or calculates the ID of a recently-inserted row.
"""
compiler_module = "django.db.models.sql.compiler"
# Integer field safe ranges by `internal_type` as documented
# in docs/ref/models/fields.txt.
integer_field_ranges = {
'SmallIntegerField': (-32768, 32767),
'IntegerField': (-2147483648, 2147483647),
'BigIntegerField': (-9223372036854775808, 9223372036854775807),
'PositiveSmallIntegerField': (0, 32767),
'PositiveIntegerField': (0, 2147483647),
'SmallAutoField': (-32768, 32767),
'AutoField': (-2147483648, 2147483647),
'BigAutoField': (-9223372036854775808, 9223372036854775807),
}
set_operators = {
'union': 'UNION',
'intersection': 'INTERSECT',
'difference': 'EXCEPT',
}
# Mapping of Field.get_internal_type() (typically the model field's class
# name) to the data type to use for the Cast() function, if different from
# DatabaseWrapper.data_types.
cast_data_types = {}
# CharField data type if the max_length argument isn't provided.
cast_char_field_without_max_length = None
# Start and end points for window expressions.
PRECEDING = 'PRECEDING'
FOLLOWING = 'FOLLOWING'
UNBOUNDED_PRECEDING = 'UNBOUNDED ' + PRECEDING
UNBOUNDED_FOLLOWING = 'UNBOUNDED ' + FOLLOWING
CURRENT_ROW = 'CURRENT ROW'
# Prefix for EXPLAIN queries, or None EXPLAIN isn't supported.
explain_prefix = None
def __init__(self, connection):
self.connection = connection
self._cache = None
def autoinc_sql(self, table, column):
"""
Return any SQL needed to support auto-incrementing primary keys, or
None if no SQL is necessary.
This SQL is executed when a table is created.
"""
return None
def bulk_batch_size(self, fields, objs):
"""
Return the maximum allowed batch size for the backend. The fields
are the fields going to be inserted in the batch, the objs contains
all the objects to be inserted.
"""
return len(objs)
def cache_key_culling_sql(self):
"""
Return an SQL query that retrieves the first cache key greater than the
n smallest.
This is used by the 'db' cache backend to determine where to start
culling.
"""
return "SELECT cache_key FROM %s ORDER BY cache_key LIMIT 1 OFFSET %%s"
def unification_cast_sql(self, output_field):
"""
Given a field instance, return the SQL that casts the result of a union
to that type. The resulting string should contain a '%s' placeholder
for the expression being cast.
"""
return '%s'
def date_extract_sql(self, lookup_type, field_name):
"""
Given a lookup_type of 'year', 'month', or 'day', return the SQL that
extracts a value from the given date field field_name.
"""
raise NotImplementedError('subclasses of BaseDatabaseOperations may require a date_extract_sql() method')
def date_interval_sql(self, timedelta):
"""
Implement the date interval functionality for expressions.
"""
raise NotImplementedError('subclasses of BaseDatabaseOperations may require a date_interval_sql() method')
def date_trunc_sql(self, lookup_type, field_name):
"""
Given a lookup_type of 'year', 'month', or 'day', return the SQL that
truncates the given date field field_name to a date object with only
the given specificity.
"""
raise NotImplementedError('subclasses of BaseDatabaseOperations may require a date_trunc_sql() method.')
def datetime_cast_date_sql(self, field_name, tzname):
"""
Return the SQL to cast a datetime value to date value.
"""
raise NotImplementedError(
'subclasses of BaseDatabaseOperations may require a '
'datetime_cast_date_sql() method.'
)
def datetime_cast_time_sql(self, field_name, tzname):
"""
Return the SQL to cast a datetime value to time value.
"""
raise NotImplementedError('subclasses of BaseDatabaseOperations may require a datetime_cast_time_sql() method')
def datetime_extract_sql(self, lookup_type, field_name, tzname):
"""
Given a lookup_type of 'year', 'month', 'day', 'hour', 'minute', or
'second', return the SQL that extracts a value from the given
datetime field field_name.
"""
raise NotImplementedError('subclasses of BaseDatabaseOperations may require a datetime_extract_sql() method')
def datetime_trunc_sql(self, lookup_type, field_name, tzname):
"""
Given a lookup_type of 'year', 'month', 'day', 'hour', 'minute', or
'second', return the SQL that truncates the given datetime field
field_name to a datetime object with only the given specificity.
"""
raise NotImplementedError('subclasses of BaseDatabaseOperations may require a datetime_trunc_sql() method')
def time_trunc_sql(self, lookup_type, field_name):
"""
Given a lookup_type of 'hour', 'minute' or 'second', return the SQL
that truncates the given time field field_name to a time object with
only the given specificity.
"""
raise NotImplementedError('subclasses of BaseDatabaseOperations may require a time_trunc_sql() method')
def time_extract_sql(self, lookup_type, field_name):
"""
Given a lookup_type of 'hour', 'minute', or 'second', return the SQL
that extracts a value from the given time field field_name.
"""
return self.date_extract_sql(lookup_type, field_name)
def deferrable_sql(self):
"""
Return the SQL to make a constraint "initially deferred" during a
CREATE TABLE statement.
"""
return ''
def distinct_sql(self, fields, params):
"""
Return an SQL DISTINCT clause which removes duplicate rows from the
result set. If any fields are given, only check the given fields for
duplicates.
"""
if fields:
raise NotSupportedError('DISTINCT ON fields is not supported by this database backend')
else:
return ['DISTINCT'], []
def fetch_returned_insert_columns(self, cursor):
"""
Given a cursor object that has just performed an INSERT...RETURNING
statement into a table, return the newly created data.
"""
return cursor.fetchone()
def field_cast_sql(self, db_type, internal_type):
"""
Given a column type (e.g. 'BLOB', 'VARCHAR') and an internal type
(e.g. 'GenericIPAddressField'), return the SQL to cast it before using
it in a WHERE statement. The resulting string should contain a '%s'
placeholder for the column being searched against.
"""
return '%s'
def force_no_ordering(self):
"""
Return a list used in the "ORDER BY" clause to force no ordering at
all. Return an empty list to include nothing in the ordering.
"""
return []
def for_update_sql(self, nowait=False, skip_locked=False, of=()):
"""
Return the FOR UPDATE SQL clause to lock rows for an update operation.
"""
return 'FOR UPDATE%s%s%s' % (
' OF %s' % ', '.join(of) if of else '',
' NOWAIT' if nowait else '',
' SKIP LOCKED' if skip_locked else '',
)
def _get_limit_offset_params(self, low_mark, high_mark):
offset = low_mark or 0
if high_mark is not None:
return (high_mark - offset), offset
elif offset:
return self.connection.ops.no_limit_value(), offset
return None, offset
def limit_offset_sql(self, low_mark, high_mark):
"""Return LIMIT/OFFSET SQL clause."""
limit, offset = self._get_limit_offset_params(low_mark, high_mark)
return ' '.join(sql for sql in (
('LIMIT %d' % limit) if limit else None,
('OFFSET %d' % offset) if offset else None,
) if sql)
def last_executed_query(self, cursor, sql, params):
"""
Return a string of the query last executed by the given cursor, with
placeholders replaced with actual values.
`sql` is the raw query containing placeholders and `params` is the
sequence of parameters. These are used by default, but this method
exists for database backends to provide a better implementation
according to their own quoting schemes.
"""
# Convert params to contain string values.
def to_string(s):
return force_str(s, strings_only=True, errors='replace')
if isinstance(params, (list, tuple)):
u_params = tuple(to_string(val) for val in params)
elif params is None:
u_params = ()
else:
u_params = {to_string(k): to_string(v) for k, v in params.items()}
return "QUERY = %r - PARAMS = %r" % (sql, u_params)
def last_insert_id(self, cursor, table_name, pk_name):
"""
Given a cursor object that has just performed an INSERT statement into
a table that has an auto-incrementing ID, return the newly created ID.
`pk_name` is the name of the primary-key column.
"""
return cursor.lastrowid
def lookup_cast(self, lookup_type, internal_type=None):
"""
Return the string to use in a query when performing lookups
("contains", "like", etc.). It should contain a '%s' placeholder for
the column being searched against.
"""
return "%s"
def max_in_list_size(self):
"""
Return the maximum number of items that can be passed in a single 'IN'
list condition, or None if the backend does not impose a limit.
"""
return None
def max_name_length(self):
"""
Return the maximum length of table and column names, or None if there
is no limit.
"""
return None
def no_limit_value(self):
"""
Return the value to use for the LIMIT when we are wanting "LIMIT
infinity". Return None if the limit clause can be omitted in this case.
"""
raise NotImplementedError('subclasses of BaseDatabaseOperations may require a no_limit_value() method')
def pk_default_value(self):
"""
Return the value to use during an INSERT statement to specify that
the field should use its default value.
"""
return 'DEFAULT'
def prepare_sql_script(self, sql):
"""
Take an SQL script that may contain multiple lines and return a list
of statements to feed to successive cursor.execute() calls.
Since few databases are able to process raw SQL scripts in a single
cursor.execute() call and PEP 249 doesn't talk about this use case,
the default implementation is conservative.
"""
return [
sqlparse.format(statement, strip_comments=True)
for statement in sqlparse.split(sql) if statement
]
def process_clob(self, value):
"""
Return the value of a CLOB column, for backends that return a locator
object that requires additional processing.
"""
return value
def return_insert_columns(self, fields):
"""
For backends that support returning columns as part of an insert query,
return the SQL and params to append to the INSERT query. The returned
fragment should contain a format string to hold the appropriate column.
"""
pass
def compiler(self, compiler_name):
"""
Return the SQLCompiler class corresponding to the given name,
in the namespace corresponding to the `compiler_module` attribute
on this backend.
"""
if self._cache is None:
self._cache = import_module(self.compiler_module)
return getattr(self._cache, compiler_name)
def quote_name(self, name):
"""
Return a quoted version of the given table, index, or column name. Do
not quote the given name if it's already been quoted.
"""
raise NotImplementedError('subclasses of BaseDatabaseOperations may require a quote_name() method')
def random_function_sql(self):
"""Return an SQL expression that returns a random value."""
return 'RANDOM()'
def regex_lookup(self, lookup_type):
"""
Return the string to use in a query when performing regular expression
lookups (using "regex" or "iregex"). It should contain a '%s'
placeholder for the column being searched against.
If the feature is not supported (or part of it is not supported), raise
NotImplementedError.
"""
raise NotImplementedError('subclasses of BaseDatabaseOperations may require a regex_lookup() method')
def savepoint_create_sql(self, sid):
"""
Return the SQL for starting a new savepoint. Only required if the
"uses_savepoints" feature is True. The "sid" parameter is a string
for the savepoint id.
"""
return "SAVEPOINT %s" % self.quote_name(sid)
def savepoint_commit_sql(self, sid):
"""
Return the SQL for committing the given savepoint.
"""
return "RELEASE SAVEPOINT %s" % self.quote_name(sid)
def savepoint_rollback_sql(self, sid):
"""
Return the SQL for rolling back the given savepoint.
"""
return "ROLLBACK TO SAVEPOINT %s" % self.quote_name(sid)
def set_time_zone_sql(self):
"""
Return the SQL that will set the connection's time zone.
Return '' if the backend doesn't support time zones.
"""
return ''
def sql_flush(self, style, tables, sequences, allow_cascade=False):
"""
Return a list of SQL statements required to remove all data from
the given database tables (without actually removing the tables
themselves) and the SQL statements required to reset the sequences
passed in `sequences`.
The `style` argument is a Style object as returned by either
color_style() or no_style() in django.core.management.color.
The `allow_cascade` argument determines whether truncation may cascade
to tables with foreign keys pointing the tables being truncated.
PostgreSQL requires a cascade even if these tables are empty.
"""
raise NotImplementedError('subclasses of BaseDatabaseOperations must provide a sql_flush() method')
def execute_sql_flush(self, using, sql_list):
"""Execute a list of SQL statements to flush the database."""
with transaction.atomic(using=using, savepoint=self.connection.features.can_rollback_ddl):
with self.connection.cursor() as cursor:
for sql in sql_list:
cursor.execute(sql)
def sequence_reset_by_name_sql(self, style, sequences):
"""
Return a list of the SQL statements required to reset sequences
passed in `sequences`.
The `style` argument is a Style object as returned by either
color_style() or no_style() in django.core.management.color.
"""
return []
def sequence_reset_sql(self, style, model_list):
"""
Return a list of the SQL statements required to reset sequences for
the given models.
The `style` argument is a Style object as returned by either
color_style() or no_style() in django.core.management.color.
"""
return [] # No sequence reset required by default.
def start_transaction_sql(self):
"""Return the SQL statement required to start a transaction."""
return "BEGIN;"
def end_transaction_sql(self, success=True):
"""Return the SQL statement required to end a transaction."""
if not success:
return "ROLLBACK;"
return "COMMIT;"
def tablespace_sql(self, tablespace, inline=False):
"""
Return the SQL that will be used in a query to define the tablespace.
Return '' if the backend doesn't support tablespaces.
If `inline` is True, append the SQL to a row; otherwise append it to
the entire CREATE TABLE or CREATE INDEX statement.
"""
return ''
def prep_for_like_query(self, x):
"""Prepare a value for use in a LIKE query."""
return str(x).replace("\\", "\\\\").replace("%", r"\%").replace("_", r"\_")
# Same as prep_for_like_query(), but called for "iexact" matches, which
# need not necessarily be implemented using "LIKE" in the backend.
prep_for_iexact_query = prep_for_like_query
def validate_autopk_value(self, value):
"""
Certain backends do not accept some values for "serial" fields
(for example zero in MySQL). Raise a ValueError if the value is
invalid, otherwise return the validated value.
"""
return value
def adapt_unknown_value(self, value):
"""
Transform a value to something compatible with the backend driver.
This method only depends on the type of the value. It's designed for
cases where the target type isn't known, such as .raw() SQL queries.
As a consequence it may not work perfectly in all circumstances.
"""
if isinstance(value, datetime.datetime): # must be before date
return self.adapt_datetimefield_value(value)
elif isinstance(value, datetime.date):
return self.adapt_datefield_value(value)
elif isinstance(value, datetime.time):
return self.adapt_timefield_value(value)
elif isinstance(value, decimal.Decimal):
return self.adapt_decimalfield_value(value)
else:
return value
def adapt_datefield_value(self, value):
"""
Transform a date value to an object compatible with what is expected
by the backend driver for date columns.
"""
if value is None:
return None
return str(value)
def adapt_datetimefield_value(self, value):
"""
Transform a datetime value to an object compatible with what is expected
by the backend driver for datetime columns.
"""
if value is None:
return None
return str(value)
def adapt_timefield_value(self, value):
"""
Transform a time value to an object compatible with what is expected
by the backend driver for time columns.
"""
if value is None:
return None
if timezone.is_aware(value):
raise ValueError("Django does not support timezone-aware times.")
return str(value)
def adapt_decimalfield_value(self, value, max_digits=None, decimal_places=None):
"""
Transform a decimal.Decimal value to an object compatible with what is
expected by the backend driver for decimal (numeric) columns.
"""
return utils.format_number(value, max_digits, decimal_places)
def adapt_ipaddressfield_value(self, value):
"""
Transform a string representation of an IP address into the expected
type for the backend driver.
"""
return value or None
def year_lookup_bounds_for_date_field(self, value):
"""
Return a two-elements list with the lower and upper bound to be used
with a BETWEEN operator to query a DateField value using a year
lookup.
`value` is an int, containing the looked-up year.
"""
first = datetime.date(value, 1, 1)
second = datetime.date(value, 12, 31)
first = self.adapt_datefield_value(first)
second = self.adapt_datefield_value(second)
return [first, second]
def year_lookup_bounds_for_datetime_field(self, value):
"""
Return a two-elements list with the lower and upper bound to be used
with a BETWEEN operator to query a DateTimeField value using a year
lookup.
`value` is an int, containing the looked-up year.
"""
first = datetime.datetime(value, 1, 1)
second = datetime.datetime(value, 12, 31, 23, 59, 59, 999999)
if settings.USE_TZ:
tz = timezone.get_current_timezone()
first = timezone.make_aware(first, tz)
second = timezone.make_aware(second, tz)
first = self.adapt_datetimefield_value(first)
second = self.adapt_datetimefield_value(second)
return [first, second]
def get_db_converters(self, expression):
"""
Return a list of functions needed to convert field data.
Some field types on some backends do not provide data in the correct
format, this is the hook for converter functions.
"""
return []
def convert_durationfield_value(self, value, expression, connection):
if value is not None:
return datetime.timedelta(0, 0, value)
def check_expression_support(self, expression):
"""
Check that the backend supports the provided expression.
This is used on specific backends to rule out known expressions
that have problematic or nonexistent implementations. If the
expression has a known problem, the backend should raise
NotSupportedError.
"""
pass
def conditional_expression_supported_in_where_clause(self, expression):
"""
Return True, if the conditional expression is supported in the WHERE
clause.
"""
return True
def combine_expression(self, connector, sub_expressions):
"""
Combine a list of subexpressions into a single expression, using
the provided connecting operator. This is required because operators
can vary between backends (e.g., Oracle with %% and &) and between
subexpression types (e.g., date expressions).
"""
conn = ' %s ' % connector
return conn.join(sub_expressions)
def combine_duration_expression(self, connector, sub_expressions):
return self.combine_expression(connector, sub_expressions)
def binary_placeholder_sql(self, value):
"""
Some backends require special syntax to insert binary content (MySQL
for example uses '_binary %s').
"""
return '%s'
def modify_insert_params(self, placeholder, params):
"""
Allow modification of insert parameters. Needed for Oracle Spatial
backend due to #10888.
"""
return params
def integer_field_range(self, internal_type):
"""
Given an integer field internal type (e.g. 'PositiveIntegerField'),
return a tuple of the (min_value, max_value) form representing the
range of the column type bound to the field.
"""
return self.integer_field_ranges[internal_type]
def subtract_temporals(self, internal_type, lhs, rhs):
if self.connection.features.supports_temporal_subtraction:
lhs_sql, lhs_params = lhs
rhs_sql, rhs_params = rhs
return '(%s - %s)' % (lhs_sql, rhs_sql), (*lhs_params, *rhs_params)
raise NotSupportedError("This backend does not support %s subtraction." % internal_type)
def window_frame_start(self, start):
if isinstance(start, int):
if start < 0:
return '%d %s' % (abs(start), self.PRECEDING)
elif start == 0:
return self.CURRENT_ROW
elif start is None:
return self.UNBOUNDED_PRECEDING
raise ValueError("start argument must be a negative integer, zero, or None, but got '%s'." % start)
def window_frame_end(self, end):
if isinstance(end, int):
if end == 0:
return self.CURRENT_ROW
elif end > 0:
return '%d %s' % (end, self.FOLLOWING)
elif end is None:
return self.UNBOUNDED_FOLLOWING
raise ValueError("end argument must be a positive integer, zero, or None, but got '%s'." % end)
def window_frame_rows_start_end(self, start=None, end=None):
"""
Return SQL for start and end points in an OVER clause window frame.
"""
if not self.connection.features.supports_over_clause:
raise NotSupportedError('This backend does not support window expressions.')
return self.window_frame_start(start), self.window_frame_end(end)
def window_frame_range_start_end(self, start=None, end=None):
return self.window_frame_rows_start_end(start, end)
def explain_query_prefix(self, format=None, **options):
if not self.connection.features.supports_explaining_query_execution:
raise NotSupportedError('This backend does not support explaining query execution.')
if format:
supported_formats = self.connection.features.supported_explain_formats
normalized_format = format.upper()
if normalized_format not in supported_formats:
msg = '%s is not a recognized format.' % normalized_format
if supported_formats:
msg += ' Allowed formats: %s' % ', '.join(sorted(supported_formats))
raise ValueError(msg)
if options:
raise ValueError('Unknown options: %s' % ', '.join(sorted(options.keys())))
return self.explain_prefix
def insert_statement(self, ignore_conflicts=False):
return 'INSERT INTO'
def ignore_conflicts_suffix_sql(self, ignore_conflicts=None):
return ''

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,25 @@
class BaseDatabaseValidation:
"""Encapsulate backend-specific validation."""
def __init__(self, connection):
self.connection = connection
def check(self, **kwargs):
return []
def check_field(self, field, **kwargs):
errors = []
# Backends may implement a check_field_type() method.
if (hasattr(self, 'check_field_type') and
# Ignore any related fields.
not getattr(field, 'remote_field', None)):
# Ignore fields with unsupported features.
db_supports_all_required_features = all(
getattr(self.connection.features, feature, False)
for feature in field.model._meta.required_db_features
)
if db_supports_all_required_features:
field_type = field.db_type(self.connection)
# Ignore non-concrete fields.
if field_type is not None:
errors.extend(self.check_field_type(field, field_type))
return errors

View File

@@ -0,0 +1,194 @@
"""
Helpers to manipulate deferred DDL statements that might need to be adjusted or
discarded within when executing a migration.
"""
class Reference:
"""Base class that defines the reference interface."""
def references_table(self, table):
"""
Return whether or not this instance references the specified table.
"""
return False
def references_column(self, table, column):
"""
Return whether or not this instance references the specified column.
"""
return False
def rename_table_references(self, old_table, new_table):
"""
Rename all references to the old_name to the new_table.
"""
pass
def rename_column_references(self, table, old_column, new_column):
"""
Rename all references to the old_column to the new_column.
"""
pass
def __repr__(self):
return '<%s %r>' % (self.__class__.__name__, str(self))
def __str__(self):
raise NotImplementedError('Subclasses must define how they should be converted to string.')
class Table(Reference):
"""Hold a reference to a table."""
def __init__(self, table, quote_name):
self.table = table
self.quote_name = quote_name
def references_table(self, table):
return self.table == table
def rename_table_references(self, old_table, new_table):
if self.table == old_table:
self.table = new_table
def __str__(self):
return self.quote_name(self.table)
class TableColumns(Table):
"""Base class for references to multiple columns of a table."""
def __init__(self, table, columns):
self.table = table
self.columns = columns
def references_column(self, table, column):
return self.table == table and column in self.columns
def rename_column_references(self, table, old_column, new_column):
if self.table == table:
for index, column in enumerate(self.columns):
if column == old_column:
self.columns[index] = new_column
class Columns(TableColumns):
"""Hold a reference to one or many columns."""
def __init__(self, table, columns, quote_name, col_suffixes=()):
self.quote_name = quote_name
self.col_suffixes = col_suffixes
super().__init__(table, columns)
def __str__(self):
def col_str(column, idx):
try:
return self.quote_name(column) + self.col_suffixes[idx]
except IndexError:
return self.quote_name(column)
return ', '.join(col_str(column, idx) for idx, column in enumerate(self.columns))
class IndexName(TableColumns):
"""Hold a reference to an index name."""
def __init__(self, table, columns, suffix, create_index_name):
self.suffix = suffix
self.create_index_name = create_index_name
super().__init__(table, columns)
def __str__(self):
return self.create_index_name(self.table, self.columns, self.suffix)
class IndexColumns(Columns):
def __init__(self, table, columns, quote_name, col_suffixes=(), opclasses=()):
self.opclasses = opclasses
super().__init__(table, columns, quote_name, col_suffixes)
def __str__(self):
def col_str(column, idx):
# Index.__init__() guarantees that self.opclasses is the same
# length as self.columns.
col = '{} {}'.format(self.quote_name(column), self.opclasses[idx])
try:
col = '{} {}'.format(col, self.col_suffixes[idx])
except IndexError:
pass
return col
return ', '.join(col_str(column, idx) for idx, column in enumerate(self.columns))
class ForeignKeyName(TableColumns):
"""Hold a reference to a foreign key name."""
def __init__(self, from_table, from_columns, to_table, to_columns, suffix_template, create_fk_name):
self.to_reference = TableColumns(to_table, to_columns)
self.suffix_template = suffix_template
self.create_fk_name = create_fk_name
super().__init__(from_table, from_columns,)
def references_table(self, table):
return super().references_table(table) or self.to_reference.references_table(table)
def references_column(self, table, column):
return (
super().references_column(table, column) or
self.to_reference.references_column(table, column)
)
def rename_table_references(self, old_table, new_table):
super().rename_table_references(old_table, new_table)
self.to_reference.rename_table_references(old_table, new_table)
def rename_column_references(self, table, old_column, new_column):
super().rename_column_references(table, old_column, new_column)
self.to_reference.rename_column_references(table, old_column, new_column)
def __str__(self):
suffix = self.suffix_template % {
'to_table': self.to_reference.table,
'to_column': self.to_reference.columns[0],
}
return self.create_fk_name(self.table, self.columns, suffix)
class Statement(Reference):
"""
Statement template and formatting parameters container.
Allows keeping a reference to a statement without interpolating identifiers
that might have to be adjusted if they're referencing a table or column
that is removed
"""
def __init__(self, template, **parts):
self.template = template
self.parts = parts
def references_table(self, table):
return any(
hasattr(part, 'references_table') and part.references_table(table)
for part in self.parts.values()
)
def references_column(self, table, column):
return any(
hasattr(part, 'references_column') and part.references_column(table, column)
for part in self.parts.values()
)
def rename_table_references(self, old_table, new_table):
for part in self.parts.values():
if hasattr(part, 'rename_table_references'):
part.rename_table_references(old_table, new_table)
def rename_column_references(self, table, old_column, new_column):
for part in self.parts.values():
if hasattr(part, 'rename_column_references'):
part.rename_column_references(table, old_column, new_column)
def __str__(self):
return self.template % self.parts

View File

@@ -0,0 +1,73 @@
"""
Dummy database backend for Django.
Django uses this if the database ENGINE setting is empty (None or empty string).
Each of these API functions, except connection.close(), raise
ImproperlyConfigured.
"""
from django.core.exceptions import ImproperlyConfigured
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.backends.base.client import BaseDatabaseClient
from django.db.backends.base.creation import BaseDatabaseCreation
from django.db.backends.base.introspection import BaseDatabaseIntrospection
from django.db.backends.base.operations import BaseDatabaseOperations
from django.db.backends.dummy.features import DummyDatabaseFeatures
def complain(*args, **kwargs):
raise ImproperlyConfigured("settings.DATABASES is improperly configured. "
"Please supply the ENGINE value. Check "
"settings documentation for more details.")
def ignore(*args, **kwargs):
pass
class DatabaseOperations(BaseDatabaseOperations):
quote_name = complain
class DatabaseClient(BaseDatabaseClient):
runshell = complain
class DatabaseCreation(BaseDatabaseCreation):
create_test_db = ignore
destroy_test_db = ignore
class DatabaseIntrospection(BaseDatabaseIntrospection):
get_table_list = complain
get_table_description = complain
get_relations = complain
get_indexes = complain
get_key_columns = complain
class DatabaseWrapper(BaseDatabaseWrapper):
operators = {}
# Override the base class implementations with null
# implementations. Anything that tries to actually
# do something raises complain; anything that tries
# to rollback or undo something raises ignore.
_cursor = complain
ensure_connection = complain
_commit = complain
_rollback = ignore
_close = ignore
_savepoint = ignore
_savepoint_commit = complain
_savepoint_rollback = ignore
_set_autocommit = complain
# Classes instantiated in __init__().
client_class = DatabaseClient
creation_class = DatabaseCreation
features_class = DummyDatabaseFeatures
introspection_class = DatabaseIntrospection
ops_class = DatabaseOperations
def is_usable(self):
return True

View File

@@ -0,0 +1,6 @@
from django.db.backends.base.features import BaseDatabaseFeatures
class DummyDatabaseFeatures(BaseDatabaseFeatures):
supports_transactions = False
uses_savepoints = False

View File

@@ -0,0 +1,363 @@
"""
MySQL database backend for Django.
Requires mysqlclient: https://pypi.org/project/mysqlclient/
"""
import re
from django.core.exceptions import ImproperlyConfigured
from django.db import utils
from django.db.backends import utils as backend_utils
from django.db.backends.base.base import BaseDatabaseWrapper
from django.utils.asyncio import async_unsafe
from django.utils.functional import cached_property
try:
import MySQLdb as Database
except ImportError as err:
raise ImproperlyConfigured(
'Error loading MySQLdb module.\n'
'Did you install mysqlclient?'
) from err
from MySQLdb.constants import CLIENT, FIELD_TYPE # isort:skip
from MySQLdb.converters import conversions # isort:skip
# Some of these import MySQLdb, so import them after checking if it's installed.
from .client import DatabaseClient # isort:skip
from .creation import DatabaseCreation # isort:skip
from .features import DatabaseFeatures # isort:skip
from .introspection import DatabaseIntrospection # isort:skip
from .operations import DatabaseOperations # isort:skip
from .schema import DatabaseSchemaEditor # isort:skip
from .validation import DatabaseValidation # isort:skip
version = Database.version_info
if version < (1, 3, 13):
raise ImproperlyConfigured('mysqlclient 1.3.13 or newer is required; you have %s.' % Database.__version__)
# MySQLdb returns TIME columns as timedelta -- they are more like timedelta in
# terms of actual behavior as they are signed and include days -- and Django
# expects time.
django_conversions = {
**conversions,
**{FIELD_TYPE.TIME: backend_utils.typecast_time},
}
# This should match the numerical portion of the version numbers (we can treat
# versions like 5.0.24 and 5.0.24a as the same).
server_version_re = re.compile(r'(\d{1,2})\.(\d{1,2})\.(\d{1,2})')
class CursorWrapper:
"""
A thin wrapper around MySQLdb's normal cursor class that catches particular
exception instances and reraises them with the correct types.
Implemented as a wrapper, rather than a subclass, so that it isn't stuck
to the particular underlying representation returned by Connection.cursor().
"""
codes_for_integrityerror = (
1048, # Column cannot be null
1690, # BIGINT UNSIGNED value is out of range
3819, # CHECK constraint is violated
4025, # CHECK constraint failed
)
def __init__(self, cursor):
self.cursor = cursor
def execute(self, query, args=None):
try:
# args is None means no string interpolation
return self.cursor.execute(query, args)
except Database.OperationalError as e:
# Map some error codes to IntegrityError, since they seem to be
# misclassified and Django would prefer the more logical place.
if e.args[0] in self.codes_for_integrityerror:
raise utils.IntegrityError(*tuple(e.args))
raise
def executemany(self, query, args):
try:
return self.cursor.executemany(query, args)
except Database.OperationalError as e:
# Map some error codes to IntegrityError, since they seem to be
# misclassified and Django would prefer the more logical place.
if e.args[0] in self.codes_for_integrityerror:
raise utils.IntegrityError(*tuple(e.args))
raise
def __getattr__(self, attr):
return getattr(self.cursor, attr)
def __iter__(self):
return iter(self.cursor)
class DatabaseWrapper(BaseDatabaseWrapper):
vendor = 'mysql'
# This dictionary maps Field objects to their associated MySQL column
# types, as strings. Column-type strings can contain format strings; they'll
# be interpolated against the values of Field.__dict__ before being output.
# If a column type is set to None, it won't be included in the output.
data_types = {
'AutoField': 'integer AUTO_INCREMENT',
'BigAutoField': 'bigint AUTO_INCREMENT',
'BinaryField': 'longblob',
'BooleanField': 'bool',
'CharField': 'varchar(%(max_length)s)',
'DateField': 'date',
'DateTimeField': 'datetime(6)',
'DecimalField': 'numeric(%(max_digits)s, %(decimal_places)s)',
'DurationField': 'bigint',
'FileField': 'varchar(%(max_length)s)',
'FilePathField': 'varchar(%(max_length)s)',
'FloatField': 'double precision',
'IntegerField': 'integer',
'BigIntegerField': 'bigint',
'IPAddressField': 'char(15)',
'GenericIPAddressField': 'char(39)',
'NullBooleanField': 'bool',
'OneToOneField': 'integer',
'PositiveIntegerField': 'integer UNSIGNED',
'PositiveSmallIntegerField': 'smallint UNSIGNED',
'SlugField': 'varchar(%(max_length)s)',
'SmallAutoField': 'smallint AUTO_INCREMENT',
'SmallIntegerField': 'smallint',
'TextField': 'longtext',
'TimeField': 'time(6)',
'UUIDField': 'char(32)',
}
# For these data types:
# - MySQL < 8.0.13 and MariaDB < 10.2.1 don't accept default values and
# implicitly treat them as nullable
# - all versions of MySQL and MariaDB don't support full width database
# indexes
_limited_data_types = (
'tinyblob', 'blob', 'mediumblob', 'longblob', 'tinytext', 'text',
'mediumtext', 'longtext', 'json',
)
operators = {
'exact': '= %s',
'iexact': 'LIKE %s',
'contains': 'LIKE BINARY %s',
'icontains': 'LIKE %s',
'gt': '> %s',
'gte': '>= %s',
'lt': '< %s',
'lte': '<= %s',
'startswith': 'LIKE BINARY %s',
'endswith': 'LIKE BINARY %s',
'istartswith': 'LIKE %s',
'iendswith': 'LIKE %s',
}
# The patterns below are used to generate SQL pattern lookup clauses when
# the right-hand side of the lookup isn't a raw string (it might be an expression
# or the result of a bilateral transformation).
# In those cases, special characters for LIKE operators (e.g. \, *, _) should be
# escaped on database side.
#
# Note: we use str.format() here for readability as '%' is used as a wildcard for
# the LIKE operator.
pattern_esc = r"REPLACE(REPLACE(REPLACE({}, '\\', '\\\\'), '%%', '\%%'), '_', '\_')"
pattern_ops = {
'contains': "LIKE BINARY CONCAT('%%', {}, '%%')",
'icontains': "LIKE CONCAT('%%', {}, '%%')",
'startswith': "LIKE BINARY CONCAT({}, '%%')",
'istartswith': "LIKE CONCAT({}, '%%')",
'endswith': "LIKE BINARY CONCAT('%%', {})",
'iendswith': "LIKE CONCAT('%%', {})",
}
isolation_levels = {
'read uncommitted',
'read committed',
'repeatable read',
'serializable',
}
Database = Database
SchemaEditorClass = DatabaseSchemaEditor
# Classes instantiated in __init__().
client_class = DatabaseClient
creation_class = DatabaseCreation
features_class = DatabaseFeatures
introspection_class = DatabaseIntrospection
ops_class = DatabaseOperations
validation_class = DatabaseValidation
def get_connection_params(self):
kwargs = {
'conv': django_conversions,
'charset': 'utf8',
}
settings_dict = self.settings_dict
if settings_dict['USER']:
kwargs['user'] = settings_dict['USER']
if settings_dict['NAME']:
kwargs['db'] = settings_dict['NAME']
if settings_dict['PASSWORD']:
kwargs['passwd'] = settings_dict['PASSWORD']
if settings_dict['HOST'].startswith('/'):
kwargs['unix_socket'] = settings_dict['HOST']
elif settings_dict['HOST']:
kwargs['host'] = settings_dict['HOST']
if settings_dict['PORT']:
kwargs['port'] = int(settings_dict['PORT'])
# We need the number of potentially affected rows after an
# "UPDATE", not the number of changed rows.
kwargs['client_flag'] = CLIENT.FOUND_ROWS
# Validate the transaction isolation level, if specified.
options = settings_dict['OPTIONS'].copy()
isolation_level = options.pop('isolation_level', 'read committed')
if isolation_level:
isolation_level = isolation_level.lower()
if isolation_level not in self.isolation_levels:
raise ImproperlyConfigured(
"Invalid transaction isolation level '%s' specified.\n"
"Use one of %s, or None." % (
isolation_level,
', '.join("'%s'" % s for s in sorted(self.isolation_levels))
))
self.isolation_level = isolation_level
kwargs.update(options)
return kwargs
@async_unsafe
def get_new_connection(self, conn_params):
return Database.connect(**conn_params)
def init_connection_state(self):
assignments = []
if self.features.is_sql_auto_is_null_enabled:
# SQL_AUTO_IS_NULL controls whether an AUTO_INCREMENT column on
# a recently inserted row will return when the field is tested
# for NULL. Disabling this brings this aspect of MySQL in line
# with SQL standards.
assignments.append('SET SQL_AUTO_IS_NULL = 0')
if self.isolation_level:
assignments.append('SET SESSION TRANSACTION ISOLATION LEVEL %s' % self.isolation_level.upper())
if assignments:
with self.cursor() as cursor:
cursor.execute('; '.join(assignments))
@async_unsafe
def create_cursor(self, name=None):
cursor = self.connection.cursor()
return CursorWrapper(cursor)
def _rollback(self):
try:
BaseDatabaseWrapper._rollback(self)
except Database.NotSupportedError:
pass
def _set_autocommit(self, autocommit):
with self.wrap_database_errors:
self.connection.autocommit(autocommit)
def disable_constraint_checking(self):
"""
Disable foreign key checks, primarily for use in adding rows with
forward references. Always return True to indicate constraint checks
need to be re-enabled.
"""
self.cursor().execute('SET foreign_key_checks=0')
return True
def enable_constraint_checking(self):
"""
Re-enable foreign key checks after they have been disabled.
"""
# Override needs_rollback in case constraint_checks_disabled is
# nested inside transaction.atomic.
self.needs_rollback, needs_rollback = False, self.needs_rollback
try:
self.cursor().execute('SET foreign_key_checks=1')
finally:
self.needs_rollback = needs_rollback
def check_constraints(self, table_names=None):
"""
Check each table name in `table_names` for rows with invalid foreign
key references. This method is intended to be used in conjunction with
`disable_constraint_checking()` and `enable_constraint_checking()`, to
determine if rows with invalid references were entered while constraint
checks were off.
"""
with self.cursor() as cursor:
if table_names is None:
table_names = self.introspection.table_names(cursor)
for table_name in table_names:
primary_key_column_name = self.introspection.get_primary_key_column(cursor, table_name)
if not primary_key_column_name:
continue
key_columns = self.introspection.get_key_columns(cursor, table_name)
for column_name, referenced_table_name, referenced_column_name in key_columns:
cursor.execute(
"""
SELECT REFERRING.`%s`, REFERRING.`%s` FROM `%s` as REFERRING
LEFT JOIN `%s` as REFERRED
ON (REFERRING.`%s` = REFERRED.`%s`)
WHERE REFERRING.`%s` IS NOT NULL AND REFERRED.`%s` IS NULL
""" % (
primary_key_column_name, column_name, table_name,
referenced_table_name, column_name, referenced_column_name,
column_name, referenced_column_name,
)
)
for bad_row in cursor.fetchall():
raise utils.IntegrityError(
"The row in table '%s' with primary key '%s' has an invalid "
"foreign key: %s.%s contains a value '%s' that does not "
"have a corresponding value in %s.%s."
% (
table_name, bad_row[0], table_name, column_name,
bad_row[1], referenced_table_name, referenced_column_name,
)
)
def is_usable(self):
try:
self.connection.ping()
except Database.Error:
return False
else:
return True
@cached_property
def display_name(self):
return 'MariaDB' if self.mysql_is_mariadb else 'MySQL'
@cached_property
def data_type_check_constraints(self):
if self.features.supports_column_check_constraints:
return {
'PositiveIntegerField': '`%(column)s` >= 0',
'PositiveSmallIntegerField': '`%(column)s` >= 0',
}
return {}
@cached_property
def mysql_server_info(self):
with self.temporary_connection() as cursor:
cursor.execute('SELECT VERSION()')
return cursor.fetchone()[0]
@cached_property
def mysql_version(self):
match = server_version_re.match(self.mysql_server_info)
if not match:
raise Exception('Unable to determine MySQL version from version string %r' % self.mysql_server_info)
return tuple(int(x) for x in match.groups())
@cached_property
def mysql_is_mariadb(self):
return 'mariadb' in self.mysql_server_info.lower()

View File

@@ -0,0 +1,48 @@
import subprocess
from django.db.backends.base.client import BaseDatabaseClient
class DatabaseClient(BaseDatabaseClient):
executable_name = 'mysql'
@classmethod
def settings_to_cmd_args(cls, settings_dict):
args = [cls.executable_name]
db = settings_dict['OPTIONS'].get('db', settings_dict['NAME'])
user = settings_dict['OPTIONS'].get('user', settings_dict['USER'])
passwd = settings_dict['OPTIONS'].get('passwd', settings_dict['PASSWORD'])
host = settings_dict['OPTIONS'].get('host', settings_dict['HOST'])
port = settings_dict['OPTIONS'].get('port', settings_dict['PORT'])
server_ca = settings_dict['OPTIONS'].get('ssl', {}).get('ca')
client_cert = settings_dict['OPTIONS'].get('ssl', {}).get('cert')
client_key = settings_dict['OPTIONS'].get('ssl', {}).get('key')
defaults_file = settings_dict['OPTIONS'].get('read_default_file')
# Seems to be no good way to set sql_mode with CLI.
if defaults_file:
args += ["--defaults-file=%s" % defaults_file]
if user:
args += ["--user=%s" % user]
if passwd:
args += ["--password=%s" % passwd]
if host:
if '/' in host:
args += ["--socket=%s" % host]
else:
args += ["--host=%s" % host]
if port:
args += ["--port=%s" % port]
if server_ca:
args += ["--ssl-ca=%s" % server_ca]
if client_cert:
args += ["--ssl-cert=%s" % client_cert]
if client_key:
args += ["--ssl-key=%s" % client_key]
if db:
args += [db]
return args
def runshell(self):
args = DatabaseClient.settings_to_cmd_args(self.connection.settings_dict)
subprocess.run(args, check=True)

View File

@@ -0,0 +1,25 @@
from django.db.models.sql import compiler
class SQLCompiler(compiler.SQLCompiler):
def as_subquery_condition(self, alias, columns, compiler):
qn = compiler.quote_name_unless_alias
qn2 = self.connection.ops.quote_name
sql, params = self.as_sql()
return '(%s) IN (%s)' % (', '.join('%s.%s' % (qn(alias), qn2(column)) for column in columns), sql), params
class SQLInsertCompiler(compiler.SQLInsertCompiler, SQLCompiler):
pass
class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler):
pass
class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler):
pass
class SQLAggregateCompiler(compiler.SQLAggregateCompiler, SQLCompiler):
pass

View File

@@ -0,0 +1,66 @@
import subprocess
import sys
from django.db.backends.base.creation import BaseDatabaseCreation
from .client import DatabaseClient
class DatabaseCreation(BaseDatabaseCreation):
def sql_table_creation_suffix(self):
suffix = []
test_settings = self.connection.settings_dict['TEST']
if test_settings['CHARSET']:
suffix.append('CHARACTER SET %s' % test_settings['CHARSET'])
if test_settings['COLLATION']:
suffix.append('COLLATE %s' % test_settings['COLLATION'])
return ' '.join(suffix)
def _execute_create_test_db(self, cursor, parameters, keepdb=False):
try:
super()._execute_create_test_db(cursor, parameters, keepdb)
except Exception as e:
if len(e.args) < 1 or e.args[0] != 1007:
# All errors except "database exists" (1007) cancel tests.
self.log('Got an error creating the test database: %s' % e)
sys.exit(2)
else:
raise e
def _clone_test_db(self, suffix, verbosity, keepdb=False):
source_database_name = self.connection.settings_dict['NAME']
target_database_name = self.get_test_db_clone_settings(suffix)['NAME']
test_db_params = {
'dbname': self.connection.ops.quote_name(target_database_name),
'suffix': self.sql_table_creation_suffix(),
}
with self._nodb_connection.cursor() as cursor:
try:
self._execute_create_test_db(cursor, test_db_params, keepdb)
except Exception:
if keepdb:
# If the database should be kept, skip everything else.
return
try:
if verbosity >= 1:
self.log('Destroying old test database for alias %s...' % (
self._get_database_display_str(verbosity, target_database_name),
))
cursor.execute('DROP DATABASE %(dbname)s' % test_db_params)
self._execute_create_test_db(cursor, test_db_params, keepdb)
except Exception as e:
self.log('Got an error recreating the test database: %s' % e)
sys.exit(2)
self._clone_db(source_database_name, target_database_name)
def _clone_db(self, source_database_name, target_database_name):
dump_args = DatabaseClient.settings_to_cmd_args(self.connection.settings_dict)[1:]
dump_cmd = ['mysqldump', *dump_args[:-1], '--routines', '--events', source_database_name]
load_cmd = DatabaseClient.settings_to_cmd_args(self.connection.settings_dict)
load_cmd[-1] = target_database_name
with subprocess.Popen(dump_cmd, stdout=subprocess.PIPE) as dump_proc:
with subprocess.Popen(load_cmd, stdin=dump_proc.stdout, stdout=subprocess.DEVNULL):
# Allow dump_proc to receive a SIGPIPE if the load process exits.
dump_proc.stdout.close()

View File

@@ -0,0 +1,138 @@
import operator
from django.db.backends.base.features import BaseDatabaseFeatures
from django.utils.functional import cached_property
class DatabaseFeatures(BaseDatabaseFeatures):
empty_fetchmany_value = ()
update_can_self_select = False
allows_group_by_pk = True
related_fields_match_type = True
# MySQL doesn't support sliced subqueries with IN/ALL/ANY/SOME.
allow_sliced_subqueries_with_in = False
has_select_for_update = True
supports_forward_references = False
supports_regex_backreferencing = False
supports_date_lookup_using_string = False
can_introspect_autofield = True
can_introspect_binary_field = False
can_introspect_duration_field = False
can_introspect_small_integer_field = True
can_introspect_positive_integer_field = True
introspected_boolean_field_type = 'IntegerField'
supports_index_column_ordering = False
supports_timezones = False
requires_explicit_null_ordering_when_grouping = True
allows_auto_pk_0 = False
can_release_savepoints = True
atomic_transactions = False
can_clone_databases = True
supports_temporal_subtraction = True
supports_select_intersection = False
supports_select_difference = False
supports_slicing_ordering_in_compound = True
supports_index_on_text_field = False
has_case_insensitive_like = False
create_test_procedure_without_params_sql = """
CREATE PROCEDURE test_procedure ()
BEGIN
DECLARE V_I INTEGER;
SET V_I = 1;
END;
"""
create_test_procedure_with_int_param_sql = """
CREATE PROCEDURE test_procedure (P_I INTEGER)
BEGIN
DECLARE V_I INTEGER;
SET V_I = P_I;
END;
"""
db_functions_convert_bytes_to_str = True
# Alias MySQL's TRADITIONAL to TEXT for consistency with other backends.
supported_explain_formats = {'JSON', 'TEXT', 'TRADITIONAL'}
# Neither MySQL nor MariaDB support partial indexes.
supports_partial_indexes = False
@cached_property
def _mysql_storage_engine(self):
"Internal method used in Django tests. Don't rely on this from your code"
with self.connection.cursor() as cursor:
cursor.execute("SELECT ENGINE FROM INFORMATION_SCHEMA.ENGINES WHERE SUPPORT = 'DEFAULT'")
result = cursor.fetchone()
return result[0]
@cached_property
def can_introspect_foreign_keys(self):
"Confirm support for introspected foreign keys"
return self._mysql_storage_engine != 'MyISAM'
@cached_property
def has_zoneinfo_database(self):
# Test if the time zone definitions are installed. CONVERT_TZ returns
# NULL if 'UTC' timezone isn't loaded into the mysql.time_zone.
with self.connection.cursor() as cursor:
cursor.execute("SELECT CONVERT_TZ('2001-01-01 01:00:00', 'UTC', 'UTC')")
return cursor.fetchone()[0] is not None
@cached_property
def is_sql_auto_is_null_enabled(self):
with self.connection.cursor() as cursor:
cursor.execute('SELECT @@SQL_AUTO_IS_NULL')
result = cursor.fetchone()
return result and result[0] == 1
@cached_property
def supports_over_clause(self):
if self.connection.mysql_is_mariadb:
return self.connection.mysql_version >= (10, 2)
return self.connection.mysql_version >= (8, 0, 2)
@cached_property
def supports_column_check_constraints(self):
if self.connection.mysql_is_mariadb:
return self.connection.mysql_version >= (10, 2, 1)
return self.connection.mysql_version >= (8, 0, 16)
supports_table_check_constraints = property(operator.attrgetter('supports_column_check_constraints'))
@cached_property
def can_introspect_check_constraints(self):
if self.connection.mysql_is_mariadb:
version = self.connection.mysql_version
return (version >= (10, 2, 22) and version < (10, 3)) or version >= (10, 3, 10)
return self.connection.mysql_version >= (8, 0, 16)
@cached_property
def has_select_for_update_skip_locked(self):
return not self.connection.mysql_is_mariadb and self.connection.mysql_version >= (8, 0, 1)
@cached_property
def has_select_for_update_nowait(self):
if self.connection.mysql_is_mariadb:
return self.connection.mysql_version >= (10, 3, 0)
return self.connection.mysql_version >= (8, 0, 1)
@cached_property
def needs_explain_extended(self):
# EXTENDED is deprecated (and not required) in MySQL 5.7.
return not self.connection.mysql_is_mariadb and self.connection.mysql_version < (5, 7)
@cached_property
def supports_transactions(self):
"""
All storage engines except MyISAM support transactions.
"""
return self._mysql_storage_engine != 'MyISAM'
@cached_property
def ignores_table_name_case(self):
with self.connection.cursor() as cursor:
cursor.execute('SELECT @@LOWER_CASE_TABLE_NAMES')
result = cursor.fetchone()
return result and result[0] != 0
@cached_property
def supports_default_in_lead_lag(self):
# To be added in https://jira.mariadb.org/browse/MDEV-12981.
return not self.connection.mysql_is_mariadb

View File

@@ -0,0 +1,267 @@
from collections import namedtuple
import sqlparse
from MySQLdb.constants import FIELD_TYPE
from django.db.backends.base.introspection import (
BaseDatabaseIntrospection, FieldInfo as BaseFieldInfo, TableInfo,
)
from django.db.models.indexes import Index
from django.utils.datastructures import OrderedSet
FieldInfo = namedtuple('FieldInfo', BaseFieldInfo._fields + ('extra', 'is_unsigned'))
InfoLine = namedtuple('InfoLine', 'col_name data_type max_len num_prec num_scale extra column_default is_unsigned')
class DatabaseIntrospection(BaseDatabaseIntrospection):
data_types_reverse = {
FIELD_TYPE.BLOB: 'TextField',
FIELD_TYPE.CHAR: 'CharField',
FIELD_TYPE.DECIMAL: 'DecimalField',
FIELD_TYPE.NEWDECIMAL: 'DecimalField',
FIELD_TYPE.DATE: 'DateField',
FIELD_TYPE.DATETIME: 'DateTimeField',
FIELD_TYPE.DOUBLE: 'FloatField',
FIELD_TYPE.FLOAT: 'FloatField',
FIELD_TYPE.INT24: 'IntegerField',
FIELD_TYPE.LONG: 'IntegerField',
FIELD_TYPE.LONGLONG: 'BigIntegerField',
FIELD_TYPE.SHORT: 'SmallIntegerField',
FIELD_TYPE.STRING: 'CharField',
FIELD_TYPE.TIME: 'TimeField',
FIELD_TYPE.TIMESTAMP: 'DateTimeField',
FIELD_TYPE.TINY: 'IntegerField',
FIELD_TYPE.TINY_BLOB: 'TextField',
FIELD_TYPE.MEDIUM_BLOB: 'TextField',
FIELD_TYPE.LONG_BLOB: 'TextField',
FIELD_TYPE.VAR_STRING: 'CharField',
}
def get_field_type(self, data_type, description):
field_type = super().get_field_type(data_type, description)
if 'auto_increment' in description.extra:
if field_type == 'IntegerField':
return 'AutoField'
elif field_type == 'BigIntegerField':
return 'BigAutoField'
elif field_type == 'SmallIntegerField':
return 'SmallAutoField'
if description.is_unsigned:
if field_type == 'IntegerField':
return 'PositiveIntegerField'
elif field_type == 'SmallIntegerField':
return 'PositiveSmallIntegerField'
return field_type
def get_table_list(self, cursor):
"""Return a list of table and view names in the current database."""
cursor.execute("SHOW FULL TABLES")
return [TableInfo(row[0], {'BASE TABLE': 't', 'VIEW': 'v'}.get(row[1]))
for row in cursor.fetchall()]
def get_table_description(self, cursor, table_name):
"""
Return a description of the table with the DB-API cursor.description
interface."
"""
# information_schema database gives more accurate results for some figures:
# - varchar length returned by cursor.description is an internal length,
# not visible length (#5725)
# - precision and scale (for decimal fields) (#5014)
# - auto_increment is not available in cursor.description
cursor.execute("""
SELECT
column_name, data_type, character_maximum_length,
numeric_precision, numeric_scale, extra, column_default,
CASE
WHEN column_type LIKE '%% unsigned' THEN 1
ELSE 0
END AS is_unsigned
FROM information_schema.columns
WHERE table_name = %s AND table_schema = DATABASE()""", [table_name])
field_info = {line[0]: InfoLine(*line) for line in cursor.fetchall()}
cursor.execute("SELECT * FROM %s LIMIT 1" % self.connection.ops.quote_name(table_name))
def to_int(i):
return int(i) if i is not None else i
fields = []
for line in cursor.description:
info = field_info[line[0]]
fields.append(FieldInfo(
*line[:3],
to_int(info.max_len) or line[3],
to_int(info.num_prec) or line[4],
to_int(info.num_scale) or line[5],
line[6],
info.column_default,
info.extra,
info.is_unsigned,
))
return fields
def get_sequences(self, cursor, table_name, table_fields=()):
for field_info in self.get_table_description(cursor, table_name):
if 'auto_increment' in field_info.extra:
# MySQL allows only one auto-increment column per table.
return [{'table': table_name, 'column': field_info.name}]
return []
def get_relations(self, cursor, table_name):
"""
Return a dictionary of {field_name: (field_name_other_table, other_table)}
representing all relationships to the given table.
"""
constraints = self.get_key_columns(cursor, table_name)
relations = {}
for my_fieldname, other_table, other_field in constraints:
relations[my_fieldname] = (other_field, other_table)
return relations
def get_key_columns(self, cursor, table_name):
"""
Return a list of (column_name, referenced_table_name, referenced_column_name)
for all key columns in the given table.
"""
key_columns = []
cursor.execute("""
SELECT column_name, referenced_table_name, referenced_column_name
FROM information_schema.key_column_usage
WHERE table_name = %s
AND table_schema = DATABASE()
AND referenced_table_name IS NOT NULL
AND referenced_column_name IS NOT NULL""", [table_name])
key_columns.extend(cursor.fetchall())
return key_columns
def get_storage_engine(self, cursor, table_name):
"""
Retrieve the storage engine for a given table. Return the default
storage engine if the table doesn't exist.
"""
cursor.execute(
"SELECT engine "
"FROM information_schema.tables "
"WHERE table_name = %s", [table_name])
result = cursor.fetchone()
if not result:
return self.connection.features._mysql_storage_engine
return result[0]
def _parse_constraint_columns(self, check_clause, columns):
check_columns = OrderedSet()
statement = sqlparse.parse(check_clause)[0]
tokens = (token for token in statement.flatten() if not token.is_whitespace)
for token in tokens:
if (
token.ttype == sqlparse.tokens.Name and
self.connection.ops.quote_name(token.value) == token.value and
token.value[1:-1] in columns
):
check_columns.add(token.value[1:-1])
return check_columns
def get_constraints(self, cursor, table_name):
"""
Retrieve any constraints or keys (unique, pk, fk, check, index) across
one or more columns.
"""
constraints = {}
# Get the actual constraint names and columns
name_query = """
SELECT kc.`constraint_name`, kc.`column_name`,
kc.`referenced_table_name`, kc.`referenced_column_name`
FROM information_schema.key_column_usage AS kc
WHERE
kc.table_schema = DATABASE() AND
kc.table_name = %s
ORDER BY kc.`ordinal_position`
"""
cursor.execute(name_query, [table_name])
for constraint, column, ref_table, ref_column in cursor.fetchall():
if constraint not in constraints:
constraints[constraint] = {
'columns': OrderedSet(),
'primary_key': False,
'unique': False,
'index': False,
'check': False,
'foreign_key': (ref_table, ref_column) if ref_column else None,
}
constraints[constraint]['columns'].add(column)
# Now get the constraint types
type_query = """
SELECT c.constraint_name, c.constraint_type
FROM information_schema.table_constraints AS c
WHERE
c.table_schema = DATABASE() AND
c.table_name = %s
"""
cursor.execute(type_query, [table_name])
for constraint, kind in cursor.fetchall():
if kind.lower() == "primary key":
constraints[constraint]['primary_key'] = True
constraints[constraint]['unique'] = True
elif kind.lower() == "unique":
constraints[constraint]['unique'] = True
# Add check constraints.
if self.connection.features.can_introspect_check_constraints:
unnamed_constraints_index = 0
columns = {info.name for info in self.get_table_description(cursor, table_name)}
if self.connection.mysql_is_mariadb:
type_query = """
SELECT c.constraint_name, c.check_clause
FROM information_schema.check_constraints AS c
WHERE
c.constraint_schema = DATABASE() AND
c.table_name = %s
"""
else:
type_query = """
SELECT cc.constraint_name, cc.check_clause
FROM
information_schema.check_constraints AS cc,
information_schema.table_constraints AS tc
WHERE
cc.constraint_schema = DATABASE() AND
tc.table_schema = cc.constraint_schema AND
cc.constraint_name = tc.constraint_name AND
tc.constraint_type = 'CHECK' AND
tc.table_name = %s
"""
cursor.execute(type_query, [table_name])
for constraint, check_clause in cursor.fetchall():
constraint_columns = self._parse_constraint_columns(check_clause, columns)
# Ensure uniqueness of unnamed constraints. Unnamed unique
# and check columns constraints have the same name as
# a column.
if set(constraint_columns) == {constraint}:
unnamed_constraints_index += 1
constraint = '__unnamed_constraint_%s__' % unnamed_constraints_index
constraints[constraint] = {
'columns': constraint_columns,
'primary_key': False,
'unique': False,
'index': False,
'check': True,
'foreign_key': None,
}
# Now add in the indexes
cursor.execute("SHOW INDEX FROM %s" % self.connection.ops.quote_name(table_name))
for table, non_unique, index, colseq, column, type_ in [x[:5] + (x[10],) for x in cursor.fetchall()]:
if index not in constraints:
constraints[index] = {
'columns': OrderedSet(),
'primary_key': False,
'unique': False,
'check': False,
'foreign_key': None,
}
constraints[index]['index'] = True
constraints[index]['type'] = Index.suffix if type_ == 'BTREE' else type_.lower()
constraints[index]['columns'].add(column)
# Convert the sorted sets to lists
for constraint in constraints.values():
constraint['columns'] = list(constraint['columns'])
return constraints

View File

@@ -0,0 +1,317 @@
import uuid
from django.conf import settings
from django.db.backends.base.operations import BaseDatabaseOperations
from django.utils import timezone
from django.utils.duration import duration_microseconds
from django.utils.encoding import force_str
class DatabaseOperations(BaseDatabaseOperations):
compiler_module = "django.db.backends.mysql.compiler"
# MySQL stores positive fields as UNSIGNED ints.
integer_field_ranges = {
**BaseDatabaseOperations.integer_field_ranges,
'PositiveSmallIntegerField': (0, 65535),
'PositiveIntegerField': (0, 4294967295),
}
cast_data_types = {
'AutoField': 'signed integer',
'BigAutoField': 'signed integer',
'SmallAutoField': 'signed integer',
'CharField': 'char(%(max_length)s)',
'DecimalField': 'decimal(%(max_digits)s, %(decimal_places)s)',
'TextField': 'char',
'IntegerField': 'signed integer',
'BigIntegerField': 'signed integer',
'SmallIntegerField': 'signed integer',
'PositiveIntegerField': 'unsigned integer',
'PositiveSmallIntegerField': 'unsigned integer',
}
cast_char_field_without_max_length = 'char'
explain_prefix = 'EXPLAIN'
def date_extract_sql(self, lookup_type, field_name):
# https://dev.mysql.com/doc/mysql/en/date-and-time-functions.html
if lookup_type == 'week_day':
# DAYOFWEEK() returns an integer, 1-7, Sunday=1.
# Note: WEEKDAY() returns 0-6, Monday=0.
return "DAYOFWEEK(%s)" % field_name
elif lookup_type == 'week':
# Override the value of default_week_format for consistency with
# other database backends.
# Mode 3: Monday, 1-53, with 4 or more days this year.
return "WEEK(%s, 3)" % field_name
elif lookup_type == 'iso_year':
# Get the year part from the YEARWEEK function, which returns a
# number as year * 100 + week.
return "TRUNCATE(YEARWEEK(%s, 3), -2) / 100" % field_name
else:
# EXTRACT returns 1-53 based on ISO-8601 for the week number.
return "EXTRACT(%s FROM %s)" % (lookup_type.upper(), field_name)
def date_trunc_sql(self, lookup_type, field_name):
fields = {
'year': '%%Y-01-01',
'month': '%%Y-%%m-01',
} # Use double percents to escape.
if lookup_type in fields:
format_str = fields[lookup_type]
return "CAST(DATE_FORMAT(%s, '%s') AS DATE)" % (field_name, format_str)
elif lookup_type == 'quarter':
return "MAKEDATE(YEAR(%s), 1) + INTERVAL QUARTER(%s) QUARTER - INTERVAL 1 QUARTER" % (
field_name, field_name
)
elif lookup_type == 'week':
return "DATE_SUB(%s, INTERVAL WEEKDAY(%s) DAY)" % (
field_name, field_name
)
else:
return "DATE(%s)" % (field_name)
def _prepare_tzname_delta(self, tzname):
if '+' in tzname:
return tzname[tzname.find('+'):]
elif '-' in tzname:
return tzname[tzname.find('-'):]
return tzname
def _convert_field_to_tz(self, field_name, tzname):
if settings.USE_TZ and self.connection.timezone_name != tzname:
field_name = "CONVERT_TZ(%s, '%s', '%s')" % (
field_name,
self.connection.timezone_name,
self._prepare_tzname_delta(tzname),
)
return field_name
def datetime_cast_date_sql(self, field_name, tzname):
field_name = self._convert_field_to_tz(field_name, tzname)
return "DATE(%s)" % field_name
def datetime_cast_time_sql(self, field_name, tzname):
field_name = self._convert_field_to_tz(field_name, tzname)
return "TIME(%s)" % field_name
def datetime_extract_sql(self, lookup_type, field_name, tzname):
field_name = self._convert_field_to_tz(field_name, tzname)
return self.date_extract_sql(lookup_type, field_name)
def datetime_trunc_sql(self, lookup_type, field_name, tzname):
field_name = self._convert_field_to_tz(field_name, tzname)
fields = ['year', 'month', 'day', 'hour', 'minute', 'second']
format = ('%%Y-', '%%m', '-%%d', ' %%H:', '%%i', ':%%s') # Use double percents to escape.
format_def = ('0000-', '01', '-01', ' 00:', '00', ':00')
if lookup_type == 'quarter':
return (
"CAST(DATE_FORMAT(MAKEDATE(YEAR({field_name}), 1) + "
"INTERVAL QUARTER({field_name}) QUARTER - " +
"INTERVAL 1 QUARTER, '%%Y-%%m-01 00:00:00') AS DATETIME)"
).format(field_name=field_name)
if lookup_type == 'week':
return (
"CAST(DATE_FORMAT(DATE_SUB({field_name}, "
"INTERVAL WEEKDAY({field_name}) DAY), "
"'%%Y-%%m-%%d 00:00:00') AS DATETIME)"
).format(field_name=field_name)
try:
i = fields.index(lookup_type) + 1
except ValueError:
sql = field_name
else:
format_str = ''.join(format[:i] + format_def[i:])
sql = "CAST(DATE_FORMAT(%s, '%s') AS DATETIME)" % (field_name, format_str)
return sql
def time_trunc_sql(self, lookup_type, field_name):
fields = {
'hour': '%%H:00:00',
'minute': '%%H:%%i:00',
'second': '%%H:%%i:%%s',
} # Use double percents to escape.
if lookup_type in fields:
format_str = fields[lookup_type]
return "CAST(DATE_FORMAT(%s, '%s') AS TIME)" % (field_name, format_str)
else:
return "TIME(%s)" % (field_name)
def date_interval_sql(self, timedelta):
return 'INTERVAL %s MICROSECOND' % duration_microseconds(timedelta)
def format_for_duration_arithmetic(self, sql):
return 'INTERVAL %s MICROSECOND' % sql
def force_no_ordering(self):
"""
"ORDER BY NULL" prevents MySQL from implicitly ordering by grouped
columns. If no ordering would otherwise be applied, we don't want any
implicit sorting going on.
"""
return [(None, ("NULL", [], False))]
def last_executed_query(self, cursor, sql, params):
# With MySQLdb, cursor objects have an (undocumented) "_executed"
# attribute where the exact query sent to the database is saved.
# See MySQLdb/cursors.py in the source distribution.
# MySQLdb returns string, PyMySQL bytes.
return force_str(getattr(cursor, '_executed', None), errors='replace')
def no_limit_value(self):
# 2**64 - 1, as recommended by the MySQL documentation
return 18446744073709551615
def quote_name(self, name):
if name.startswith("`") and name.endswith("`"):
return name # Quoting once is enough.
return "`%s`" % name
def random_function_sql(self):
return 'RAND()'
def sql_flush(self, style, tables, sequences, allow_cascade=False):
# NB: The generated SQL below is specific to MySQL
# 'TRUNCATE x;', 'TRUNCATE y;', 'TRUNCATE z;'... style SQL statements
# to clear all tables of all data
if tables:
sql = ['SET FOREIGN_KEY_CHECKS = 0;']
for table in tables:
sql.append('%s %s;' % (
style.SQL_KEYWORD('TRUNCATE'),
style.SQL_FIELD(self.quote_name(table)),
))
sql.append('SET FOREIGN_KEY_CHECKS = 1;')
sql.extend(self.sequence_reset_by_name_sql(style, sequences))
return sql
else:
return []
def validate_autopk_value(self, value):
# MySQLism: zero in AUTO_INCREMENT field does not work. Refs #17653.
if value == 0:
raise ValueError('The database backend does not accept 0 as a '
'value for AutoField.')
return value
def adapt_datetimefield_value(self, value):
if value is None:
return None
# Expression values are adapted by the database.
if hasattr(value, 'resolve_expression'):
return value
# MySQL doesn't support tz-aware datetimes
if timezone.is_aware(value):
if settings.USE_TZ:
value = timezone.make_naive(value, self.connection.timezone)
else:
raise ValueError("MySQL backend does not support timezone-aware datetimes when USE_TZ is False.")
return str(value)
def adapt_timefield_value(self, value):
if value is None:
return None
# Expression values are adapted by the database.
if hasattr(value, 'resolve_expression'):
return value
# MySQL doesn't support tz-aware times
if timezone.is_aware(value):
raise ValueError("MySQL backend does not support timezone-aware times.")
return str(value)
def max_name_length(self):
return 64
def bulk_insert_sql(self, fields, placeholder_rows):
placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)
values_sql = ", ".join("(%s)" % sql for sql in placeholder_rows_sql)
return "VALUES " + values_sql
def combine_expression(self, connector, sub_expressions):
if connector == '^':
return 'POW(%s)' % ','.join(sub_expressions)
# Convert the result to a signed integer since MySQL's binary operators
# return an unsigned integer.
elif connector in ('&', '|', '<<'):
return 'CONVERT(%s, SIGNED)' % connector.join(sub_expressions)
elif connector == '>>':
lhs, rhs = sub_expressions
return 'FLOOR(%(lhs)s / POW(2, %(rhs)s))' % {'lhs': lhs, 'rhs': rhs}
return super().combine_expression(connector, sub_expressions)
def get_db_converters(self, expression):
converters = super().get_db_converters(expression)
internal_type = expression.output_field.get_internal_type()
if internal_type in ['BooleanField', 'NullBooleanField']:
converters.append(self.convert_booleanfield_value)
elif internal_type == 'DateTimeField':
if settings.USE_TZ:
converters.append(self.convert_datetimefield_value)
elif internal_type == 'UUIDField':
converters.append(self.convert_uuidfield_value)
return converters
def convert_booleanfield_value(self, value, expression, connection):
if value in (0, 1):
value = bool(value)
return value
def convert_datetimefield_value(self, value, expression, connection):
if value is not None:
value = timezone.make_aware(value, self.connection.timezone)
return value
def convert_uuidfield_value(self, value, expression, connection):
if value is not None:
value = uuid.UUID(value)
return value
def binary_placeholder_sql(self, value):
return '_binary %s' if value is not None and not hasattr(value, 'as_sql') else '%s'
def subtract_temporals(self, internal_type, lhs, rhs):
lhs_sql, lhs_params = lhs
rhs_sql, rhs_params = rhs
if internal_type == 'TimeField':
if self.connection.mysql_is_mariadb:
# MariaDB includes the microsecond component in TIME_TO_SEC as
# a decimal. MySQL returns an integer without microseconds.
return 'CAST((TIME_TO_SEC(%(lhs)s) - TIME_TO_SEC(%(rhs)s)) * 1000000 AS SIGNED)' % {
'lhs': lhs_sql, 'rhs': rhs_sql
}, (*lhs_params, *rhs_params)
return (
"((TIME_TO_SEC(%(lhs)s) * 1000000 + MICROSECOND(%(lhs)s)) -"
" (TIME_TO_SEC(%(rhs)s) * 1000000 + MICROSECOND(%(rhs)s)))"
) % {'lhs': lhs_sql, 'rhs': rhs_sql}, tuple(lhs_params) * 2 + tuple(rhs_params) * 2
params = (*rhs_params, *lhs_params)
return "TIMESTAMPDIFF(MICROSECOND, %s, %s)" % (rhs_sql, lhs_sql), params
def explain_query_prefix(self, format=None, **options):
# Alias MySQL's TRADITIONAL to TEXT for consistency with other backends.
if format and format.upper() == 'TEXT':
format = 'TRADITIONAL'
prefix = super().explain_query_prefix(format, **options)
if format:
prefix += ' FORMAT=%s' % format
if self.connection.features.needs_explain_extended and format is None:
# EXTENDED and FORMAT are mutually exclusive options.
prefix += ' EXTENDED'
return prefix
def regex_lookup(self, lookup_type):
# REGEXP BINARY doesn't work correctly in MySQL 8+ and REGEXP_LIKE
# doesn't exist in MySQL 5.6 or in MariaDB.
if self.connection.mysql_version < (8, 0, 0) or self.connection.mysql_is_mariadb:
if lookup_type == 'regex':
return '%s REGEXP BINARY %s'
return '%s REGEXP %s'
match_option = 'c' if lookup_type == 'regex' else 'i'
return "REGEXP_LIKE(%%s, %%s, '%s')" % match_option
def insert_statement(self, ignore_conflicts=False):
return 'INSERT IGNORE INTO' if ignore_conflicts else super().insert_statement(ignore_conflicts)

View File

@@ -0,0 +1,139 @@
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
from django.db.models import NOT_PROVIDED
class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
sql_rename_table = "RENAME TABLE %(old_table)s TO %(new_table)s"
sql_alter_column_null = "MODIFY %(column)s %(type)s NULL"
sql_alter_column_not_null = "MODIFY %(column)s %(type)s NOT NULL"
sql_alter_column_type = "MODIFY %(column)s %(type)s"
# No 'CASCADE' which works as a no-op in MySQL but is undocumented
sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s"
sql_rename_column = "ALTER TABLE %(table)s CHANGE %(old_column)s %(new_column)s %(type)s"
sql_delete_unique = "ALTER TABLE %(table)s DROP INDEX %(name)s"
sql_create_column_inline_fk = (
', ADD CONSTRAINT %(name)s FOREIGN KEY (%(column)s) '
'REFERENCES %(to_table)s(%(to_column)s)'
)
sql_delete_fk = "ALTER TABLE %(table)s DROP FOREIGN KEY %(name)s"
sql_delete_index = "DROP INDEX %(name)s ON %(table)s"
sql_create_pk = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s PRIMARY KEY (%(columns)s)"
sql_delete_pk = "ALTER TABLE %(table)s DROP PRIMARY KEY"
sql_create_index = 'CREATE INDEX %(name)s ON %(table)s (%(columns)s)%(extra)s'
@property
def sql_delete_check(self):
if self.connection.mysql_is_mariadb:
# The name of the column check constraint is the same as the field
# name on MariaDB. Adding IF EXISTS clause prevents migrations
# crash. Constraint is removed during a "MODIFY" column statement.
return 'ALTER TABLE %(table)s DROP CONSTRAINT IF EXISTS %(name)s'
return 'ALTER TABLE %(table)s DROP CHECK %(name)s'
def quote_value(self, value):
self.connection.ensure_connection()
if isinstance(value, str):
value = value.replace('%', '%%')
# MySQLdb escapes to string, PyMySQL to bytes.
quoted = self.connection.connection.escape(value, self.connection.connection.encoders)
if isinstance(value, str) and isinstance(quoted, bytes):
quoted = quoted.decode()
return quoted
def _is_limited_data_type(self, field):
db_type = field.db_type(self.connection)
return db_type is not None and db_type.lower() in self.connection._limited_data_types
def skip_default(self, field):
if not self._supports_limited_data_type_defaults:
return self._is_limited_data_type(field)
return False
@property
def _supports_limited_data_type_defaults(self):
# MariaDB >= 10.2.1 and MySQL >= 8.0.13 supports defaults for BLOB
# and TEXT.
if self.connection.mysql_is_mariadb:
return self.connection.mysql_version >= (10, 2, 1)
return self.connection.mysql_version >= (8, 0, 13)
def _column_default_sql(self, field):
if (
not self.connection.mysql_is_mariadb and
self._supports_limited_data_type_defaults and
self._is_limited_data_type(field)
):
# MySQL supports defaults for BLOB and TEXT columns only if the
# default value is written as an expression i.e. in parentheses.
return '(%s)'
return super()._column_default_sql(field)
def add_field(self, model, field):
super().add_field(model, field)
# Simulate the effect of a one-off default.
# field.default may be unhashable, so a set isn't used for "in" check.
if self.skip_default(field) and field.default not in (None, NOT_PROVIDED):
effective_default = self.effective_default(field)
self.execute('UPDATE %(table)s SET %(column)s = %%s' % {
'table': self.quote_name(model._meta.db_table),
'column': self.quote_name(field.column),
}, [effective_default])
def _field_should_be_indexed(self, model, field):
create_index = super()._field_should_be_indexed(model, field)
storage = self.connection.introspection.get_storage_engine(
self.connection.cursor(), model._meta.db_table
)
# No need to create an index for ForeignKey fields except if
# db_constraint=False because the index from that constraint won't be
# created.
if (storage == "InnoDB" and
create_index and
field.get_internal_type() == 'ForeignKey' and
field.db_constraint):
return False
return not self._is_limited_data_type(field) and create_index
def _delete_composed_index(self, model, fields, *args):
"""
MySQL can remove an implicit FK index on a field when that field is
covered by another index like a unique_together. "covered" here means
that the more complex index starts like the simpler one.
http://bugs.mysql.com/bug.php?id=37910 / Django ticket #24757
We check here before removing the [unique|index]_together if we have to
recreate a FK index.
"""
first_field = model._meta.get_field(fields[0])
if first_field.get_internal_type() == 'ForeignKey':
constraint_names = self._constraint_names(model, [first_field.column], index=True)
if not constraint_names:
self.execute(self._create_index_sql(model, [first_field], suffix=""))
return super()._delete_composed_index(model, fields, *args)
def _set_field_new_type_null_status(self, field, new_type):
"""
Keep the null property of the old field. If it has changed, it will be
handled separately.
"""
if field.null:
new_type += " NULL"
else:
new_type += " NOT NULL"
return new_type
def _alter_column_type_sql(self, model, old_field, new_field, new_type):
new_type = self._set_field_new_type_null_status(old_field, new_type)
return super()._alter_column_type_sql(model, old_field, new_field, new_type)
def _rename_field_sql(self, table, old_field, new_field, new_type):
new_type = self._set_field_new_type_null_status(old_field, new_type)
return super()._rename_field_sql(table, old_field, new_field, new_type)

View File

@@ -0,0 +1,60 @@
from django.core import checks
from django.db.backends.base.validation import BaseDatabaseValidation
from django.utils.version import get_docs_version
class DatabaseValidation(BaseDatabaseValidation):
def check(self, **kwargs):
issues = super().check(**kwargs)
issues.extend(self._check_sql_mode(**kwargs))
return issues
def _check_sql_mode(self, **kwargs):
with self.connection.cursor() as cursor:
cursor.execute("SELECT @@sql_mode")
sql_mode = cursor.fetchone()
modes = set(sql_mode[0].split(',') if sql_mode else ())
if not (modes & {'STRICT_TRANS_TABLES', 'STRICT_ALL_TABLES'}):
return [checks.Warning(
"MySQL Strict Mode is not set for database connection '%s'" % self.connection.alias,
hint="MySQL's Strict Mode fixes many data integrity problems in MySQL, "
"such as data truncation upon insertion, by escalating warnings into "
"errors. It is strongly recommended you activate it. See: "
"https://docs.djangoproject.com/en/%s/ref/databases/#mysql-sql-mode"
% (get_docs_version(),),
id='mysql.W002',
)]
return []
def check_field_type(self, field, field_type):
"""
MySQL has the following field length restriction:
No character (varchar) fields can have a length exceeding 255
characters if they have a unique index on them.
MySQL doesn't support a database index on some data types.
"""
errors = []
if (field_type.startswith('varchar') and field.unique and
(field.max_length is None or int(field.max_length) > 255)):
errors.append(
checks.Error(
'MySQL does not allow unique CharFields to have a max_length > 255.',
obj=field,
id='mysql.E001',
)
)
if field.db_index and field_type.lower() in self.connection._limited_data_types:
errors.append(
checks.Warning(
'%s does not support a database index on %s columns.'
% (self.connection.display_name, field_type),
hint=(
"An index won't be created. Silence this warning if "
"you don't care about it."
),
obj=field,
id='fields.W162',
)
)
return errors

View File

@@ -0,0 +1,550 @@
"""
Oracle database backend for Django.
Requires cx_Oracle: https://oracle.github.io/python-cx_Oracle/
"""
import datetime
import decimal
import os
import platform
from contextlib import contextmanager
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.db import utils
from django.db.backends.base.base import BaseDatabaseWrapper
from django.utils.asyncio import async_unsafe
from django.utils.encoding import force_bytes, force_str
from django.utils.functional import cached_property
def _setup_environment(environ):
# Cygwin requires some special voodoo to set the environment variables
# properly so that Oracle will see them.
if platform.system().upper().startswith('CYGWIN'):
try:
import ctypes
except ImportError as e:
raise ImproperlyConfigured("Error loading ctypes: %s; "
"the Oracle backend requires ctypes to "
"operate correctly under Cygwin." % e)
kernel32 = ctypes.CDLL('kernel32')
for name, value in environ:
kernel32.SetEnvironmentVariableA(name, value)
else:
os.environ.update(environ)
_setup_environment([
# Oracle takes client-side character set encoding from the environment.
('NLS_LANG', '.AL32UTF8'),
# This prevents unicode from getting mangled by getting encoded into the
# potentially non-unicode database character set.
('ORA_NCHAR_LITERAL_REPLACE', 'TRUE'),
])
try:
import cx_Oracle as Database
except ImportError as e:
raise ImproperlyConfigured("Error loading cx_Oracle module: %s" % e)
# Some of these import cx_Oracle, so import them after checking if it's installed.
from .client import DatabaseClient # NOQA isort:skip
from .creation import DatabaseCreation # NOQA isort:skip
from .features import DatabaseFeatures # NOQA isort:skip
from .introspection import DatabaseIntrospection # NOQA isort:skip
from .operations import DatabaseOperations # NOQA isort:skip
from .schema import DatabaseSchemaEditor # NOQA isort:skip
from .utils import Oracle_datetime # NOQA isort:skip
from .validation import DatabaseValidation # NOQA isort:skip
@contextmanager
def wrap_oracle_errors():
try:
yield
except Database.DatabaseError as e:
# cx_Oracle raises a cx_Oracle.DatabaseError exception with the
# following attributes and values:
# code = 2091
# message = 'ORA-02091: transaction rolled back
# 'ORA-02291: integrity constraint (TEST_DJANGOTEST.SYS
# _C00102056) violated - parent key not found'
# Convert that case to Django's IntegrityError exception.
x = e.args[0]
if hasattr(x, 'code') and hasattr(x, 'message') and x.code == 2091 and 'ORA-02291' in x.message:
raise utils.IntegrityError(*tuple(e.args))
raise
class _UninitializedOperatorsDescriptor:
def __get__(self, instance, cls=None):
# If connection.operators is looked up before a connection has been
# created, transparently initialize connection.operators to avert an
# AttributeError.
if instance is None:
raise AttributeError("operators not available as class attribute")
# Creating a cursor will initialize the operators.
instance.cursor().close()
return instance.__dict__['operators']
class DatabaseWrapper(BaseDatabaseWrapper):
vendor = 'oracle'
display_name = 'Oracle'
# This dictionary maps Field objects to their associated Oracle column
# types, as strings. Column-type strings can contain format strings; they'll
# be interpolated against the values of Field.__dict__ before being output.
# If a column type is set to None, it won't be included in the output.
#
# Any format strings starting with "qn_" are quoted before being used in the
# output (the "qn_" prefix is stripped before the lookup is performed.
data_types = {
'AutoField': 'NUMBER(11) GENERATED BY DEFAULT ON NULL AS IDENTITY',
'BigAutoField': 'NUMBER(19) GENERATED BY DEFAULT ON NULL AS IDENTITY',
'BinaryField': 'BLOB',
'BooleanField': 'NUMBER(1)',
'CharField': 'NVARCHAR2(%(max_length)s)',
'DateField': 'DATE',
'DateTimeField': 'TIMESTAMP',
'DecimalField': 'NUMBER(%(max_digits)s, %(decimal_places)s)',
'DurationField': 'INTERVAL DAY(9) TO SECOND(6)',
'FileField': 'NVARCHAR2(%(max_length)s)',
'FilePathField': 'NVARCHAR2(%(max_length)s)',
'FloatField': 'DOUBLE PRECISION',
'IntegerField': 'NUMBER(11)',
'BigIntegerField': 'NUMBER(19)',
'IPAddressField': 'VARCHAR2(15)',
'GenericIPAddressField': 'VARCHAR2(39)',
'NullBooleanField': 'NUMBER(1)',
'OneToOneField': 'NUMBER(11)',
'PositiveIntegerField': 'NUMBER(11)',
'PositiveSmallIntegerField': 'NUMBER(11)',
'SlugField': 'NVARCHAR2(%(max_length)s)',
'SmallAutoField': 'NUMBER(5) GENERATED BY DEFAULT ON NULL AS IDENTITY',
'SmallIntegerField': 'NUMBER(11)',
'TextField': 'NCLOB',
'TimeField': 'TIMESTAMP',
'URLField': 'VARCHAR2(%(max_length)s)',
'UUIDField': 'VARCHAR2(32)',
}
data_type_check_constraints = {
'BooleanField': '%(qn_column)s IN (0,1)',
'NullBooleanField': '%(qn_column)s IN (0,1)',
'PositiveIntegerField': '%(qn_column)s >= 0',
'PositiveSmallIntegerField': '%(qn_column)s >= 0',
}
# Oracle doesn't support a database index on these columns.
_limited_data_types = ('clob', 'nclob', 'blob')
operators = _UninitializedOperatorsDescriptor()
_standard_operators = {
'exact': '= %s',
'iexact': '= UPPER(%s)',
'contains': "LIKE TRANSLATE(%s USING NCHAR_CS) ESCAPE TRANSLATE('\\' USING NCHAR_CS)",
'icontains': "LIKE UPPER(TRANSLATE(%s USING NCHAR_CS)) ESCAPE TRANSLATE('\\' USING NCHAR_CS)",
'gt': '> %s',
'gte': '>= %s',
'lt': '< %s',
'lte': '<= %s',
'startswith': "LIKE TRANSLATE(%s USING NCHAR_CS) ESCAPE TRANSLATE('\\' USING NCHAR_CS)",
'endswith': "LIKE TRANSLATE(%s USING NCHAR_CS) ESCAPE TRANSLATE('\\' USING NCHAR_CS)",
'istartswith': "LIKE UPPER(TRANSLATE(%s USING NCHAR_CS)) ESCAPE TRANSLATE('\\' USING NCHAR_CS)",
'iendswith': "LIKE UPPER(TRANSLATE(%s USING NCHAR_CS)) ESCAPE TRANSLATE('\\' USING NCHAR_CS)",
}
_likec_operators = {
**_standard_operators,
'contains': "LIKEC %s ESCAPE '\\'",
'icontains': "LIKEC UPPER(%s) ESCAPE '\\'",
'startswith': "LIKEC %s ESCAPE '\\'",
'endswith': "LIKEC %s ESCAPE '\\'",
'istartswith': "LIKEC UPPER(%s) ESCAPE '\\'",
'iendswith': "LIKEC UPPER(%s) ESCAPE '\\'",
}
# The patterns below are used to generate SQL pattern lookup clauses when
# the right-hand side of the lookup isn't a raw string (it might be an expression
# or the result of a bilateral transformation).
# In those cases, special characters for LIKE operators (e.g. \, %, _)
# should be escaped on the database side.
#
# Note: we use str.format() here for readability as '%' is used as a wildcard for
# the LIKE operator.
pattern_esc = r"REPLACE(REPLACE(REPLACE({}, '\', '\\'), '%%', '\%%'), '_', '\_')"
_pattern_ops = {
'contains': "'%%' || {} || '%%'",
'icontains': "'%%' || UPPER({}) || '%%'",
'startswith': "{} || '%%'",
'istartswith': "UPPER({}) || '%%'",
'endswith': "'%%' || {}",
'iendswith': "'%%' || UPPER({})",
}
_standard_pattern_ops = {k: "LIKE TRANSLATE( " + v + " USING NCHAR_CS)"
" ESCAPE TRANSLATE('\\' USING NCHAR_CS)"
for k, v in _pattern_ops.items()}
_likec_pattern_ops = {k: "LIKEC " + v + " ESCAPE '\\'"
for k, v in _pattern_ops.items()}
Database = Database
SchemaEditorClass = DatabaseSchemaEditor
# Classes instantiated in __init__().
client_class = DatabaseClient
creation_class = DatabaseCreation
features_class = DatabaseFeatures
introspection_class = DatabaseIntrospection
ops_class = DatabaseOperations
validation_class = DatabaseValidation
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
use_returning_into = self.settings_dict["OPTIONS"].get('use_returning_into', True)
self.features.can_return_columns_from_insert = use_returning_into
def _dsn(self):
settings_dict = self.settings_dict
if not settings_dict['HOST'].strip():
settings_dict['HOST'] = 'localhost'
if settings_dict['PORT']:
return Database.makedsn(settings_dict['HOST'], int(settings_dict['PORT']), settings_dict['NAME'])
return settings_dict['NAME']
def _connect_string(self):
return '%s/"%s"@%s' % (self.settings_dict['USER'], self.settings_dict['PASSWORD'], self._dsn())
def get_connection_params(self):
conn_params = self.settings_dict['OPTIONS'].copy()
if 'use_returning_into' in conn_params:
del conn_params['use_returning_into']
return conn_params
@async_unsafe
def get_new_connection(self, conn_params):
return Database.connect(
user=self.settings_dict['USER'],
password=self.settings_dict['PASSWORD'],
dsn=self._dsn(),
**conn_params,
)
def init_connection_state(self):
cursor = self.create_cursor()
# Set the territory first. The territory overrides NLS_DATE_FORMAT
# and NLS_TIMESTAMP_FORMAT to the territory default. When all of
# these are set in single statement it isn't clear what is supposed
# to happen.
cursor.execute("ALTER SESSION SET NLS_TERRITORY = 'AMERICA'")
# Set Oracle date to ANSI date format. This only needs to execute
# once when we create a new connection. We also set the Territory
# to 'AMERICA' which forces Sunday to evaluate to a '1' in
# TO_CHAR().
cursor.execute(
"ALTER SESSION SET NLS_DATE_FORMAT = 'YYYY-MM-DD HH24:MI:SS'"
" NLS_TIMESTAMP_FORMAT = 'YYYY-MM-DD HH24:MI:SS.FF'" +
(" TIME_ZONE = 'UTC'" if settings.USE_TZ else '')
)
cursor.close()
if 'operators' not in self.__dict__:
# Ticket #14149: Check whether our LIKE implementation will
# work for this connection or we need to fall back on LIKEC.
# This check is performed only once per DatabaseWrapper
# instance per thread, since subsequent connections will use
# the same settings.
cursor = self.create_cursor()
try:
cursor.execute("SELECT 1 FROM DUAL WHERE DUMMY %s"
% self._standard_operators['contains'],
['X'])
except Database.DatabaseError:
self.operators = self._likec_operators
self.pattern_ops = self._likec_pattern_ops
else:
self.operators = self._standard_operators
self.pattern_ops = self._standard_pattern_ops
cursor.close()
self.connection.stmtcachesize = 20
# Ensure all changes are preserved even when AUTOCOMMIT is False.
if not self.get_autocommit():
self.commit()
@async_unsafe
def create_cursor(self, name=None):
return FormatStylePlaceholderCursor(self.connection)
def _commit(self):
if self.connection is not None:
with wrap_oracle_errors():
return self.connection.commit()
# Oracle doesn't support releasing savepoints. But we fake them when query
# logging is enabled to keep query counts consistent with other backends.
def _savepoint_commit(self, sid):
if self.queries_logged:
self.queries_log.append({
'sql': '-- RELEASE SAVEPOINT %s (faked)' % self.ops.quote_name(sid),
'time': '0.000',
})
def _set_autocommit(self, autocommit):
with self.wrap_database_errors:
self.connection.autocommit = autocommit
def check_constraints(self, table_names=None):
"""
Check constraints by setting them to immediate. Return them to deferred
afterward.
"""
self.cursor().execute('SET CONSTRAINTS ALL IMMEDIATE')
self.cursor().execute('SET CONSTRAINTS ALL DEFERRED')
def is_usable(self):
try:
self.connection.ping()
except Database.Error:
return False
else:
return True
@cached_property
def oracle_version(self):
with self.temporary_connection():
return tuple(int(x) for x in self.connection.version.split('.'))
class OracleParam:
"""
Wrapper object for formatting parameters for Oracle. If the string
representation of the value is large enough (greater than 4000 characters)
the input size needs to be set as CLOB. Alternatively, if the parameter
has an `input_size` attribute, then the value of the `input_size` attribute
will be used instead. Otherwise, no input size will be set for the
parameter when executing the query.
"""
def __init__(self, param, cursor, strings_only=False):
# With raw SQL queries, datetimes can reach this function
# without being converted by DateTimeField.get_db_prep_value.
if settings.USE_TZ and (isinstance(param, datetime.datetime) and
not isinstance(param, Oracle_datetime)):
param = Oracle_datetime.from_datetime(param)
string_size = 0
# Oracle doesn't recognize True and False correctly.
if param is True:
param = 1
elif param is False:
param = 0
if hasattr(param, 'bind_parameter'):
self.force_bytes = param.bind_parameter(cursor)
elif isinstance(param, (Database.Binary, datetime.timedelta)):
self.force_bytes = param
else:
# To transmit to the database, we need Unicode if supported
# To get size right, we must consider bytes.
self.force_bytes = force_str(param, cursor.charset, strings_only)
if isinstance(self.force_bytes, str):
# We could optimize by only converting up to 4000 bytes here
string_size = len(force_bytes(param, cursor.charset, strings_only))
if hasattr(param, 'input_size'):
# If parameter has `input_size` attribute, use that.
self.input_size = param.input_size
elif string_size > 4000:
# Mark any string param greater than 4000 characters as a CLOB.
self.input_size = Database.CLOB
elif isinstance(param, datetime.datetime):
self.input_size = Database.TIMESTAMP
else:
self.input_size = None
class VariableWrapper:
"""
An adapter class for cursor variables that prevents the wrapped object
from being converted into a string when used to instantiate an OracleParam.
This can be used generally for any other object that should be passed into
Cursor.execute as-is.
"""
def __init__(self, var):
self.var = var
def bind_parameter(self, cursor):
return self.var
def __getattr__(self, key):
return getattr(self.var, key)
def __setattr__(self, key, value):
if key == 'var':
self.__dict__[key] = value
else:
setattr(self.var, key, value)
class FormatStylePlaceholderCursor:
"""
Django uses "format" (e.g. '%s') style placeholders, but Oracle uses ":var"
style. This fixes it -- but note that if you want to use a literal "%s" in
a query, you'll need to use "%%s".
"""
charset = 'utf-8'
def __init__(self, connection):
self.cursor = connection.cursor()
self.cursor.outputtypehandler = self._output_type_handler
@staticmethod
def _output_number_converter(value):
return decimal.Decimal(value) if '.' in value else int(value)
@staticmethod
def _get_decimal_converter(precision, scale):
if scale == 0:
return int
context = decimal.Context(prec=precision)
quantize_value = decimal.Decimal(1).scaleb(-scale)
return lambda v: decimal.Decimal(v).quantize(quantize_value, context=context)
@staticmethod
def _output_type_handler(cursor, name, defaultType, length, precision, scale):
"""
Called for each db column fetched from cursors. Return numbers as the
appropriate Python type.
"""
if defaultType == Database.NUMBER:
if scale == -127:
if precision == 0:
# NUMBER column: decimal-precision floating point.
# This will normally be an integer from a sequence,
# but it could be a decimal value.
outconverter = FormatStylePlaceholderCursor._output_number_converter
else:
# FLOAT column: binary-precision floating point.
# This comes from FloatField columns.
outconverter = float
elif precision > 0:
# NUMBER(p,s) column: decimal-precision fixed point.
# This comes from IntegerField and DecimalField columns.
outconverter = FormatStylePlaceholderCursor._get_decimal_converter(precision, scale)
else:
# No type information. This normally comes from a
# mathematical expression in the SELECT list. Guess int
# or Decimal based on whether it has a decimal point.
outconverter = FormatStylePlaceholderCursor._output_number_converter
return cursor.var(
Database.STRING,
size=255,
arraysize=cursor.arraysize,
outconverter=outconverter,
)
def _format_params(self, params):
try:
return {k: OracleParam(v, self, True) for k, v in params.items()}
except AttributeError:
return tuple(OracleParam(p, self, True) for p in params)
def _guess_input_sizes(self, params_list):
# Try dict handling; if that fails, treat as sequence
if hasattr(params_list[0], 'keys'):
sizes = {}
for params in params_list:
for k, value in params.items():
if value.input_size:
sizes[k] = value.input_size
if sizes:
self.setinputsizes(**sizes)
else:
# It's not a list of dicts; it's a list of sequences
sizes = [None] * len(params_list[0])
for params in params_list:
for i, value in enumerate(params):
if value.input_size:
sizes[i] = value.input_size
if sizes:
self.setinputsizes(*sizes)
def _param_generator(self, params):
# Try dict handling; if that fails, treat as sequence
if hasattr(params, 'items'):
return {k: v.force_bytes for k, v in params.items()}
else:
return [p.force_bytes for p in params]
def _fix_for_params(self, query, params, unify_by_values=False):
# cx_Oracle wants no trailing ';' for SQL statements. For PL/SQL, it
# it does want a trailing ';' but not a trailing '/'. However, these
# characters must be included in the original query in case the query
# is being passed to SQL*Plus.
if query.endswith(';') or query.endswith('/'):
query = query[:-1]
if params is None:
params = []
elif hasattr(params, 'keys'):
# Handle params as dict
args = {k: ":%s" % k for k in params}
query = query % args
elif unify_by_values and params:
# Handle params as a dict with unified query parameters by their
# values. It can be used only in single query execute() because
# executemany() shares the formatted query with each of the params
# list. e.g. for input params = [0.75, 2, 0.75, 'sth', 0.75]
# params_dict = {0.75: ':arg0', 2: ':arg1', 'sth': ':arg2'}
# args = [':arg0', ':arg1', ':arg0', ':arg2', ':arg0']
# params = {':arg0': 0.75, ':arg1': 2, ':arg2': 'sth'}
params_dict = {
param: ':arg%d' % i
for i, param in enumerate(dict.fromkeys(params))
}
args = [params_dict[param] for param in params]
params = {value: key for key, value in params_dict.items()}
query = query % tuple(args)
else:
# Handle params as sequence
args = [(':arg%d' % i) for i in range(len(params))]
query = query % tuple(args)
return query, self._format_params(params)
def execute(self, query, params=None):
query, params = self._fix_for_params(query, params, unify_by_values=True)
self._guess_input_sizes([params])
with wrap_oracle_errors():
return self.cursor.execute(query, self._param_generator(params))
def executemany(self, query, params=None):
if not params:
# No params given, nothing to do
return None
# uniform treatment for sequences and iterables
params_iter = iter(params)
query, firstparams = self._fix_for_params(query, next(params_iter))
# we build a list of formatted params; as we're going to traverse it
# more than once, we can't make it lazy by using a generator
formatted = [firstparams] + [self._format_params(p) for p in params_iter]
self._guess_input_sizes(formatted)
with wrap_oracle_errors():
return self.cursor.executemany(query, [self._param_generator(p) for p in formatted])
def close(self):
try:
self.cursor.close()
except Database.InterfaceError:
# already closed
pass
def var(self, *args):
return VariableWrapper(self.cursor.var(*args))
def arrayvar(self, *args):
return VariableWrapper(self.cursor.arrayvar(*args))
def __getattr__(self, attr):
return getattr(self.cursor, attr)
def __iter__(self):
return iter(self.cursor)

View File

@@ -0,0 +1,17 @@
import shutil
import subprocess
from django.db.backends.base.client import BaseDatabaseClient
class DatabaseClient(BaseDatabaseClient):
executable_name = 'sqlplus'
wrapper_name = 'rlwrap'
def runshell(self):
conn_string = self.connection._connect_string()
args = [self.executable_name, "-L", conn_string]
wrapper_path = shutil.which(self.wrapper_name)
if wrapper_path:
args = [wrapper_path, *args]
subprocess.run(args, check=True)

View File

@@ -0,0 +1,400 @@
import sys
from django.conf import settings
from django.db.backends.base.creation import BaseDatabaseCreation
from django.db.utils import DatabaseError
from django.utils.crypto import get_random_string
from django.utils.functional import cached_property
TEST_DATABASE_PREFIX = 'test_'
class DatabaseCreation(BaseDatabaseCreation):
@cached_property
def _maindb_connection(self):
"""
This is analogous to other backends' `_nodb_connection` property,
which allows access to an "administrative" connection which can
be used to manage the test databases.
For Oracle, the only connection that can be used for that purpose
is the main (non-test) connection.
"""
settings_dict = settings.DATABASES[self.connection.alias]
user = settings_dict.get('SAVED_USER') or settings_dict['USER']
password = settings_dict.get('SAVED_PASSWORD') or settings_dict['PASSWORD']
settings_dict = {**settings_dict, 'USER': user, 'PASSWORD': password}
DatabaseWrapper = type(self.connection)
return DatabaseWrapper(settings_dict, alias=self.connection.alias)
def _create_test_db(self, verbosity=1, autoclobber=False, keepdb=False):
parameters = self._get_test_db_params()
with self._maindb_connection.cursor() as cursor:
if self._test_database_create():
try:
self._execute_test_db_creation(cursor, parameters, verbosity, keepdb)
except Exception as e:
if 'ORA-01543' not in str(e):
# All errors except "tablespace already exists" cancel tests
self.log('Got an error creating the test database: %s' % e)
sys.exit(2)
if not autoclobber:
confirm = input(
"It appears the test database, %s, already exists. "
"Type 'yes' to delete it, or 'no' to cancel: " % parameters['user'])
if autoclobber or confirm == 'yes':
if verbosity >= 1:
self.log("Destroying old test database for alias '%s'..." % self.connection.alias)
try:
self._execute_test_db_destruction(cursor, parameters, verbosity)
except DatabaseError as e:
if 'ORA-29857' in str(e):
self._handle_objects_preventing_db_destruction(cursor, parameters,
verbosity, autoclobber)
else:
# Ran into a database error that isn't about leftover objects in the tablespace
self.log('Got an error destroying the old test database: %s' % e)
sys.exit(2)
except Exception as e:
self.log('Got an error destroying the old test database: %s' % e)
sys.exit(2)
try:
self._execute_test_db_creation(cursor, parameters, verbosity, keepdb)
except Exception as e:
self.log('Got an error recreating the test database: %s' % e)
sys.exit(2)
else:
self.log('Tests cancelled.')
sys.exit(1)
if self._test_user_create():
if verbosity >= 1:
self.log('Creating test user...')
try:
self._create_test_user(cursor, parameters, verbosity, keepdb)
except Exception as e:
if 'ORA-01920' not in str(e):
# All errors except "user already exists" cancel tests
self.log('Got an error creating the test user: %s' % e)
sys.exit(2)
if not autoclobber:
confirm = input(
"It appears the test user, %s, already exists. Type "
"'yes' to delete it, or 'no' to cancel: " % parameters['user'])
if autoclobber or confirm == 'yes':
try:
if verbosity >= 1:
self.log('Destroying old test user...')
self._destroy_test_user(cursor, parameters, verbosity)
if verbosity >= 1:
self.log('Creating test user...')
self._create_test_user(cursor, parameters, verbosity, keepdb)
except Exception as e:
self.log('Got an error recreating the test user: %s' % e)
sys.exit(2)
else:
self.log('Tests cancelled.')
sys.exit(1)
self._maindb_connection.close() # done with main user -- test user and tablespaces created
self._switch_to_test_user(parameters)
return self.connection.settings_dict['NAME']
def _switch_to_test_user(self, parameters):
"""
Switch to the user that's used for creating the test database.
Oracle doesn't have the concept of separate databases under the same
user, so a separate user is used; see _create_test_db(). The main user
is also needed for cleanup when testing is completed, so save its
credentials in the SAVED_USER/SAVED_PASSWORD key in the settings dict.
"""
real_settings = settings.DATABASES[self.connection.alias]
real_settings['SAVED_USER'] = self.connection.settings_dict['SAVED_USER'] = \
self.connection.settings_dict['USER']
real_settings['SAVED_PASSWORD'] = self.connection.settings_dict['SAVED_PASSWORD'] = \
self.connection.settings_dict['PASSWORD']
real_test_settings = real_settings['TEST']
test_settings = self.connection.settings_dict['TEST']
real_test_settings['USER'] = real_settings['USER'] = test_settings['USER'] = \
self.connection.settings_dict['USER'] = parameters['user']
real_settings['PASSWORD'] = self.connection.settings_dict['PASSWORD'] = parameters['password']
def set_as_test_mirror(self, primary_settings_dict):
"""
Set this database up to be used in testing as a mirror of a primary
database whose settings are given.
"""
self.connection.settings_dict['USER'] = primary_settings_dict['USER']
self.connection.settings_dict['PASSWORD'] = primary_settings_dict['PASSWORD']
def _handle_objects_preventing_db_destruction(self, cursor, parameters, verbosity, autoclobber):
# There are objects in the test tablespace which prevent dropping it
# The easy fix is to drop the test user -- but are we allowed to do so?
self.log(
'There are objects in the old test database which prevent its destruction.\n'
'If they belong to the test user, deleting the user will allow the test '
'database to be recreated.\n'
'Otherwise, you will need to find and remove each of these objects, '
'or use a different tablespace.\n'
)
if self._test_user_create():
if not autoclobber:
confirm = input("Type 'yes' to delete user %s: " % parameters['user'])
if autoclobber or confirm == 'yes':
try:
if verbosity >= 1:
self.log('Destroying old test user...')
self._destroy_test_user(cursor, parameters, verbosity)
except Exception as e:
self.log('Got an error destroying the test user: %s' % e)
sys.exit(2)
try:
if verbosity >= 1:
self.log("Destroying old test database for alias '%s'..." % self.connection.alias)
self._execute_test_db_destruction(cursor, parameters, verbosity)
except Exception as e:
self.log('Got an error destroying the test database: %s' % e)
sys.exit(2)
else:
self.log('Tests cancelled -- test database cannot be recreated.')
sys.exit(1)
else:
self.log("Django is configured to use pre-existing test user '%s',"
" and will not attempt to delete it." % parameters['user'])
self.log('Tests cancelled -- test database cannot be recreated.')
sys.exit(1)
def _destroy_test_db(self, test_database_name, verbosity=1):
"""
Destroy a test database, prompting the user for confirmation if the
database already exists. Return the name of the test database created.
"""
self.connection.settings_dict['USER'] = self.connection.settings_dict['SAVED_USER']
self.connection.settings_dict['PASSWORD'] = self.connection.settings_dict['SAVED_PASSWORD']
self.connection.close()
parameters = self._get_test_db_params()
with self._maindb_connection.cursor() as cursor:
if self._test_user_create():
if verbosity >= 1:
self.log('Destroying test user...')
self._destroy_test_user(cursor, parameters, verbosity)
if self._test_database_create():
if verbosity >= 1:
self.log('Destroying test database tables...')
self._execute_test_db_destruction(cursor, parameters, verbosity)
self._maindb_connection.close()
def _execute_test_db_creation(self, cursor, parameters, verbosity, keepdb=False):
if verbosity >= 2:
self.log('_create_test_db(): dbname = %s' % parameters['user'])
if self._test_database_oracle_managed_files():
statements = [
"""
CREATE TABLESPACE %(tblspace)s
DATAFILE SIZE %(size)s
AUTOEXTEND ON NEXT %(extsize)s MAXSIZE %(maxsize)s
""",
"""
CREATE TEMPORARY TABLESPACE %(tblspace_temp)s
TEMPFILE SIZE %(size_tmp)s
AUTOEXTEND ON NEXT %(extsize_tmp)s MAXSIZE %(maxsize_tmp)s
""",
]
else:
statements = [
"""
CREATE TABLESPACE %(tblspace)s
DATAFILE '%(datafile)s' SIZE %(size)s REUSE
AUTOEXTEND ON NEXT %(extsize)s MAXSIZE %(maxsize)s
""",
"""
CREATE TEMPORARY TABLESPACE %(tblspace_temp)s
TEMPFILE '%(datafile_tmp)s' SIZE %(size_tmp)s REUSE
AUTOEXTEND ON NEXT %(extsize_tmp)s MAXSIZE %(maxsize_tmp)s
""",
]
# Ignore "tablespace already exists" error when keepdb is on.
acceptable_ora_err = 'ORA-01543' if keepdb else None
self._execute_allow_fail_statements(cursor, statements, parameters, verbosity, acceptable_ora_err)
def _create_test_user(self, cursor, parameters, verbosity, keepdb=False):
if verbosity >= 2:
self.log('_create_test_user(): username = %s' % parameters['user'])
statements = [
"""CREATE USER %(user)s
IDENTIFIED BY "%(password)s"
DEFAULT TABLESPACE %(tblspace)s
TEMPORARY TABLESPACE %(tblspace_temp)s
QUOTA UNLIMITED ON %(tblspace)s
""",
"""GRANT CREATE SESSION,
CREATE TABLE,
CREATE SEQUENCE,
CREATE PROCEDURE,
CREATE TRIGGER
TO %(user)s""",
]
# Ignore "user already exists" error when keepdb is on
acceptable_ora_err = 'ORA-01920' if keepdb else None
success = self._execute_allow_fail_statements(cursor, statements, parameters, verbosity, acceptable_ora_err)
# If the password was randomly generated, change the user accordingly.
if not success and self._test_settings_get('PASSWORD') is None:
set_password = 'ALTER USER %(user)s IDENTIFIED BY "%(password)s"'
self._execute_statements(cursor, [set_password], parameters, verbosity)
# Most test suites can be run without "create view" and
# "create materialized view" privileges. But some need it.
for object_type in ('VIEW', 'MATERIALIZED VIEW'):
extra = 'GRANT CREATE %(object_type)s TO %(user)s'
parameters['object_type'] = object_type
success = self._execute_allow_fail_statements(cursor, [extra], parameters, verbosity, 'ORA-01031')
if not success and verbosity >= 2:
self.log('Failed to grant CREATE %s permission to test user. This may be ok.' % object_type)
def _execute_test_db_destruction(self, cursor, parameters, verbosity):
if verbosity >= 2:
self.log('_execute_test_db_destruction(): dbname=%s' % parameters['user'])
statements = [
'DROP TABLESPACE %(tblspace)s INCLUDING CONTENTS AND DATAFILES CASCADE CONSTRAINTS',
'DROP TABLESPACE %(tblspace_temp)s INCLUDING CONTENTS AND DATAFILES CASCADE CONSTRAINTS',
]
self._execute_statements(cursor, statements, parameters, verbosity)
def _destroy_test_user(self, cursor, parameters, verbosity):
if verbosity >= 2:
self.log('_destroy_test_user(): user=%s' % parameters['user'])
self.log('Be patient. This can take some time...')
statements = [
'DROP USER %(user)s CASCADE',
]
self._execute_statements(cursor, statements, parameters, verbosity)
def _execute_statements(self, cursor, statements, parameters, verbosity, allow_quiet_fail=False):
for template in statements:
stmt = template % parameters
if verbosity >= 2:
print(stmt)
try:
cursor.execute(stmt)
except Exception as err:
if (not allow_quiet_fail) or verbosity >= 2:
self.log('Failed (%s)' % (err))
raise
def _execute_allow_fail_statements(self, cursor, statements, parameters, verbosity, acceptable_ora_err):
"""
Execute statements which are allowed to fail silently if the Oracle
error code given by `acceptable_ora_err` is raised. Return True if the
statements execute without an exception, or False otherwise.
"""
try:
# Statement can fail when acceptable_ora_err is not None
allow_quiet_fail = acceptable_ora_err is not None and len(acceptable_ora_err) > 0
self._execute_statements(cursor, statements, parameters, verbosity, allow_quiet_fail=allow_quiet_fail)
return True
except DatabaseError as err:
description = str(err)
if acceptable_ora_err is None or acceptable_ora_err not in description:
raise
return False
def _get_test_db_params(self):
return {
'dbname': self._test_database_name(),
'user': self._test_database_user(),
'password': self._test_database_passwd(),
'tblspace': self._test_database_tblspace(),
'tblspace_temp': self._test_database_tblspace_tmp(),
'datafile': self._test_database_tblspace_datafile(),
'datafile_tmp': self._test_database_tblspace_tmp_datafile(),
'maxsize': self._test_database_tblspace_maxsize(),
'maxsize_tmp': self._test_database_tblspace_tmp_maxsize(),
'size': self._test_database_tblspace_size(),
'size_tmp': self._test_database_tblspace_tmp_size(),
'extsize': self._test_database_tblspace_extsize(),
'extsize_tmp': self._test_database_tblspace_tmp_extsize(),
}
def _test_settings_get(self, key, default=None, prefixed=None):
"""
Return a value from the test settings dict, or a given default, or a
prefixed entry from the main settings dict.
"""
settings_dict = self.connection.settings_dict
val = settings_dict['TEST'].get(key, default)
if val is None and prefixed:
val = TEST_DATABASE_PREFIX + settings_dict[prefixed]
return val
def _test_database_name(self):
return self._test_settings_get('NAME', prefixed='NAME')
def _test_database_create(self):
return self._test_settings_get('CREATE_DB', default=True)
def _test_user_create(self):
return self._test_settings_get('CREATE_USER', default=True)
def _test_database_user(self):
return self._test_settings_get('USER', prefixed='USER')
def _test_database_passwd(self):
password = self._test_settings_get('PASSWORD')
if password is None and self._test_user_create():
# Oracle passwords are limited to 30 chars and can't contain symbols.
password = get_random_string(length=30)
return password
def _test_database_tblspace(self):
return self._test_settings_get('TBLSPACE', prefixed='USER')
def _test_database_tblspace_tmp(self):
settings_dict = self.connection.settings_dict
return settings_dict['TEST'].get('TBLSPACE_TMP',
TEST_DATABASE_PREFIX + settings_dict['USER'] + '_temp')
def _test_database_tblspace_datafile(self):
tblspace = '%s.dbf' % self._test_database_tblspace()
return self._test_settings_get('DATAFILE', default=tblspace)
def _test_database_tblspace_tmp_datafile(self):
tblspace = '%s.dbf' % self._test_database_tblspace_tmp()
return self._test_settings_get('DATAFILE_TMP', default=tblspace)
def _test_database_tblspace_maxsize(self):
return self._test_settings_get('DATAFILE_MAXSIZE', default='500M')
def _test_database_tblspace_tmp_maxsize(self):
return self._test_settings_get('DATAFILE_TMP_MAXSIZE', default='500M')
def _test_database_tblspace_size(self):
return self._test_settings_get('DATAFILE_SIZE', default='50M')
def _test_database_tblspace_tmp_size(self):
return self._test_settings_get('DATAFILE_TMP_SIZE', default='50M')
def _test_database_tblspace_extsize(self):
return self._test_settings_get('DATAFILE_EXTSIZE', default='25M')
def _test_database_tblspace_tmp_extsize(self):
return self._test_settings_get('DATAFILE_TMP_EXTSIZE', default='25M')
def _test_database_oracle_managed_files(self):
return self._test_settings_get('ORACLE_MANAGED_FILES', default=False)
def _get_test_db_name(self):
"""
Return the 'production' DB name to get the test DB creation machinery
to work. This isn't a great deal in this case because DB names as
handled by Django don't have real counterparts in Oracle.
"""
return self.connection.settings_dict['NAME']
def test_db_signature(self):
settings_dict = self.connection.settings_dict
return (
settings_dict['HOST'],
settings_dict['PORT'],
settings_dict['ENGINE'],
settings_dict['NAME'],
self._test_database_user(),
)

View File

@@ -0,0 +1,61 @@
from django.db.backends.base.features import BaseDatabaseFeatures
from django.db.utils import InterfaceError
class DatabaseFeatures(BaseDatabaseFeatures):
interprets_empty_strings_as_nulls = True
has_select_for_update = True
has_select_for_update_nowait = True
has_select_for_update_skip_locked = True
has_select_for_update_of = True
select_for_update_of_column = True
can_return_columns_from_insert = True
can_introspect_autofield = True
supports_subqueries_in_group_by = False
supports_transactions = True
supports_timezones = False
has_native_duration_field = True
can_defer_constraint_checks = True
supports_partially_nullable_unique_constraints = False
truncates_names = True
supports_tablespaces = True
supports_sequence_reset = False
can_introspect_materialized_views = True
can_introspect_time_field = False
atomic_transactions = False
supports_combined_alters = False
nulls_order_largest = True
requires_literal_defaults = True
closed_cursor_error_class = InterfaceError
bare_select_suffix = " FROM DUAL"
# select for update with limit can be achieved on Oracle, but not with the current backend.
supports_select_for_update_with_limit = False
supports_temporal_subtraction = True
# Oracle doesn't ignore quoted identifiers case but the current backend
# does by uppercasing all identifiers.
ignores_table_name_case = True
supports_index_on_text_field = False
has_case_insensitive_like = False
create_test_procedure_without_params_sql = """
CREATE PROCEDURE "TEST_PROCEDURE" AS
V_I INTEGER;
BEGIN
V_I := 1;
END;
"""
create_test_procedure_with_int_param_sql = """
CREATE PROCEDURE "TEST_PROCEDURE" (P_I INTEGER) AS
V_I INTEGER;
BEGIN
V_I := P_I;
END;
"""
supports_callproc_kwargs = True
supports_over_clause = True
supports_frame_range_fixed_distance = True
supports_ignore_conflicts = False
max_query_params = 2**16 - 1
supports_partial_indexes = False
supports_slicing_ordering_in_compound = True
allows_multiple_constraints_on_same_fields = False
supports_boolean_expr_in_select_clause = False

View File

@@ -0,0 +1,22 @@
from django.db.models import DecimalField, DurationField, Func
class IntervalToSeconds(Func):
function = ''
template = """
EXTRACT(day from %(expressions)s) * 86400 +
EXTRACT(hour from %(expressions)s) * 3600 +
EXTRACT(minute from %(expressions)s) * 60 +
EXTRACT(second from %(expressions)s)
"""
def __init__(self, expression, *, output_field=None, **extra):
super().__init__(expression, output_field=output_field or DecimalField(), **extra)
class SecondsToInterval(Func):
function = 'NUMTODSINTERVAL'
template = "%(function)s(%(expressions)s, 'SECOND')"
def __init__(self, expression, *, output_field=None, **extra):
super().__init__(expression, output_field=output_field or DurationField(), **extra)

View File

@@ -0,0 +1,294 @@
from collections import namedtuple
import cx_Oracle
from django.db import models
from django.db.backends.base.introspection import (
BaseDatabaseIntrospection, FieldInfo as BaseFieldInfo, TableInfo,
)
FieldInfo = namedtuple('FieldInfo', BaseFieldInfo._fields + ('is_autofield',))
class DatabaseIntrospection(BaseDatabaseIntrospection):
# Maps type objects to Django Field types.
data_types_reverse = {
cx_Oracle.BLOB: 'BinaryField',
cx_Oracle.CLOB: 'TextField',
cx_Oracle.DATETIME: 'DateField',
cx_Oracle.FIXED_CHAR: 'CharField',
cx_Oracle.FIXED_NCHAR: 'CharField',
cx_Oracle.INTERVAL: 'DurationField',
cx_Oracle.NATIVE_FLOAT: 'FloatField',
cx_Oracle.NCHAR: 'CharField',
cx_Oracle.NCLOB: 'TextField',
cx_Oracle.NUMBER: 'DecimalField',
cx_Oracle.STRING: 'CharField',
cx_Oracle.TIMESTAMP: 'DateTimeField',
}
cache_bust_counter = 1
def get_field_type(self, data_type, description):
if data_type == cx_Oracle.NUMBER:
precision, scale = description[4:6]
if scale == 0:
if precision > 11:
return 'BigAutoField' if description.is_autofield else 'BigIntegerField'
elif 1 < precision < 6 and description.is_autofield:
return 'SmallAutoField'
elif precision == 1:
return 'BooleanField'
elif description.is_autofield:
return 'AutoField'
else:
return 'IntegerField'
elif scale == -127:
return 'FloatField'
return super().get_field_type(data_type, description)
def get_table_list(self, cursor):
"""Return a list of table and view names in the current database."""
cursor.execute("""
SELECT table_name, 't'
FROM user_tables
WHERE
NOT EXISTS (
SELECT 1
FROM user_mviews
WHERE user_mviews.mview_name = user_tables.table_name
)
UNION ALL
SELECT view_name, 'v' FROM user_views
UNION ALL
SELECT mview_name, 'v' FROM user_mviews
""")
return [TableInfo(self.identifier_converter(row[0]), row[1]) for row in cursor.fetchall()]
def get_table_description(self, cursor, table_name):
"""
Return a description of the table with the DB-API cursor.description
interface.
"""
# user_tab_columns gives data default for columns
cursor.execute("""
SELECT
column_name,
data_default,
CASE
WHEN char_used IS NULL THEN data_length
ELSE char_length
END as internal_size,
CASE
WHEN identity_column = 'YES' THEN 1
ELSE 0
END as is_autofield
FROM user_tab_cols
WHERE table_name = UPPER(%s)""", [table_name])
field_map = {
column: (internal_size, default if default != 'NULL' else None, is_autofield)
for column, default, internal_size, is_autofield in cursor.fetchall()
}
self.cache_bust_counter += 1
cursor.execute("SELECT * FROM {} WHERE ROWNUM < 2 AND {} > 0".format(
self.connection.ops.quote_name(table_name),
self.cache_bust_counter))
description = []
for desc in cursor.description:
name = desc[0]
internal_size, default, is_autofield = field_map[name]
name = name % {} # cx_Oracle, for some reason, doubles percent signs.
description.append(FieldInfo(
self.identifier_converter(name), *desc[1:3], internal_size, desc[4] or 0,
desc[5] or 0, *desc[6:], default, is_autofield,
))
return description
def identifier_converter(self, name):
"""Identifier comparison is case insensitive under Oracle."""
return name.lower()
def get_sequences(self, cursor, table_name, table_fields=()):
cursor.execute("""
SELECT
user_tab_identity_cols.sequence_name,
user_tab_identity_cols.column_name
FROM
user_tab_identity_cols,
user_constraints,
user_cons_columns cols
WHERE
user_constraints.constraint_name = cols.constraint_name
AND user_constraints.table_name = user_tab_identity_cols.table_name
AND cols.column_name = user_tab_identity_cols.column_name
AND user_constraints.constraint_type = 'P'
AND user_tab_identity_cols.table_name = UPPER(%s)
""", [table_name])
# Oracle allows only one identity column per table.
row = cursor.fetchone()
if row:
return [{
'name': self.identifier_converter(row[0]),
'table': self.identifier_converter(table_name),
'column': self.identifier_converter(row[1]),
}]
# To keep backward compatibility for AutoFields that aren't Oracle
# identity columns.
for f in table_fields:
if isinstance(f, models.AutoField):
return [{'table': table_name, 'column': f.column}]
return []
def get_relations(self, cursor, table_name):
"""
Return a dictionary of {field_name: (field_name_other_table, other_table)}
representing all relationships to the given table.
"""
table_name = table_name.upper()
cursor.execute("""
SELECT ca.column_name, cb.table_name, cb.column_name
FROM user_constraints, USER_CONS_COLUMNS ca, USER_CONS_COLUMNS cb
WHERE user_constraints.table_name = %s AND
user_constraints.constraint_name = ca.constraint_name AND
user_constraints.r_constraint_name = cb.constraint_name AND
ca.position = cb.position""", [table_name])
return {
self.identifier_converter(field_name): (
self.identifier_converter(rel_field_name),
self.identifier_converter(rel_table_name),
) for field_name, rel_table_name, rel_field_name in cursor.fetchall()
}
def get_key_columns(self, cursor, table_name):
cursor.execute("""
SELECT ccol.column_name, rcol.table_name AS referenced_table, rcol.column_name AS referenced_column
FROM user_constraints c
JOIN user_cons_columns ccol
ON ccol.constraint_name = c.constraint_name
JOIN user_cons_columns rcol
ON rcol.constraint_name = c.r_constraint_name
WHERE c.table_name = %s AND c.constraint_type = 'R'""", [table_name.upper()])
return [
tuple(self.identifier_converter(cell) for cell in row)
for row in cursor.fetchall()
]
def get_primary_key_column(self, cursor, table_name):
cursor.execute("""
SELECT
cols.column_name
FROM
user_constraints,
user_cons_columns cols
WHERE
user_constraints.constraint_name = cols.constraint_name AND
user_constraints.constraint_type = 'P' AND
user_constraints.table_name = UPPER(%s) AND
cols.position = 1
""", [table_name])
row = cursor.fetchone()
return self.identifier_converter(row[0]) if row else None
def get_constraints(self, cursor, table_name):
"""
Retrieve any constraints or keys (unique, pk, fk, check, index) across
one or more columns.
"""
constraints = {}
# Loop over the constraints, getting PKs, uniques, and checks
cursor.execute("""
SELECT
user_constraints.constraint_name,
LISTAGG(LOWER(cols.column_name), ',') WITHIN GROUP (ORDER BY cols.position),
CASE user_constraints.constraint_type
WHEN 'P' THEN 1
ELSE 0
END AS is_primary_key,
CASE
WHEN user_constraints.constraint_type IN ('P', 'U') THEN 1
ELSE 0
END AS is_unique,
CASE user_constraints.constraint_type
WHEN 'C' THEN 1
ELSE 0
END AS is_check_constraint
FROM
user_constraints
LEFT OUTER JOIN
user_cons_columns cols ON user_constraints.constraint_name = cols.constraint_name
WHERE
user_constraints.constraint_type = ANY('P', 'U', 'C')
AND user_constraints.table_name = UPPER(%s)
GROUP BY user_constraints.constraint_name, user_constraints.constraint_type
""", [table_name])
for constraint, columns, pk, unique, check in cursor.fetchall():
constraint = self.identifier_converter(constraint)
constraints[constraint] = {
'columns': columns.split(','),
'primary_key': pk,
'unique': unique,
'foreign_key': None,
'check': check,
'index': unique, # All uniques come with an index
}
# Foreign key constraints
cursor.execute("""
SELECT
cons.constraint_name,
LISTAGG(LOWER(cols.column_name), ',') WITHIN GROUP (ORDER BY cols.position),
LOWER(rcols.table_name),
LOWER(rcols.column_name)
FROM
user_constraints cons
INNER JOIN
user_cons_columns rcols ON rcols.constraint_name = cons.r_constraint_name AND rcols.position = 1
LEFT OUTER JOIN
user_cons_columns cols ON cons.constraint_name = cols.constraint_name
WHERE
cons.constraint_type = 'R' AND
cons.table_name = UPPER(%s)
GROUP BY cons.constraint_name, rcols.table_name, rcols.column_name
""", [table_name])
for constraint, columns, other_table, other_column in cursor.fetchall():
constraint = self.identifier_converter(constraint)
constraints[constraint] = {
'primary_key': False,
'unique': False,
'foreign_key': (other_table, other_column),
'check': False,
'index': False,
'columns': columns.split(','),
}
# Now get indexes
cursor.execute("""
SELECT
ind.index_name,
LOWER(ind.index_type),
LISTAGG(LOWER(cols.column_name), ',') WITHIN GROUP (ORDER BY cols.column_position),
LISTAGG(cols.descend, ',') WITHIN GROUP (ORDER BY cols.column_position)
FROM
user_ind_columns cols, user_indexes ind
WHERE
cols.table_name = UPPER(%s) AND
NOT EXISTS (
SELECT 1
FROM user_constraints cons
WHERE ind.index_name = cons.index_name
) AND cols.index_name = ind.index_name
GROUP BY ind.index_name, ind.index_type
""", [table_name])
for constraint, type_, columns, orders in cursor.fetchall():
constraint = self.identifier_converter(constraint)
constraints[constraint] = {
'primary_key': False,
'unique': False,
'foreign_key': None,
'check': False,
'index': True,
'type': 'idx' if type_ == 'normal' else type_,
'columns': columns.split(','),
'orders': orders.split(','),
}
return constraints

View File

@@ -0,0 +1,631 @@
import datetime
import re
import uuid
from functools import lru_cache
from django.conf import settings
from django.db.backends.base.operations import BaseDatabaseOperations
from django.db.backends.utils import strip_quotes, truncate_name
from django.db.models.expressions import Exists, ExpressionWrapper, RawSQL
from django.db.models.query_utils import Q
from django.db.utils import DatabaseError
from django.utils import timezone
from django.utils.encoding import force_bytes, force_str
from django.utils.functional import cached_property
from .base import Database
from .utils import BulkInsertMapper, InsertVar, Oracle_datetime
class DatabaseOperations(BaseDatabaseOperations):
# Oracle uses NUMBER(5), NUMBER(11), and NUMBER(19) for integer fields.
# SmallIntegerField uses NUMBER(11) instead of NUMBER(5), which is used by
# SmallAutoField, to preserve backward compatibility.
integer_field_ranges = {
'SmallIntegerField': (-99999999999, 99999999999),
'IntegerField': (-99999999999, 99999999999),
'BigIntegerField': (-9999999999999999999, 9999999999999999999),
'PositiveSmallIntegerField': (0, 99999999999),
'PositiveIntegerField': (0, 99999999999),
'SmallAutoField': (-99999, 99999),
'AutoField': (-99999999999, 99999999999),
'BigAutoField': (-9999999999999999999, 9999999999999999999),
}
set_operators = {**BaseDatabaseOperations.set_operators, 'difference': 'MINUS'}
# TODO: colorize this SQL code with style.SQL_KEYWORD(), etc.
_sequence_reset_sql = """
DECLARE
table_value integer;
seq_value integer;
seq_name user_tab_identity_cols.sequence_name%%TYPE;
BEGIN
BEGIN
SELECT sequence_name INTO seq_name FROM user_tab_identity_cols
WHERE table_name = '%(table_name)s' AND
column_name = '%(column_name)s';
EXCEPTION WHEN NO_DATA_FOUND THEN
seq_name := '%(no_autofield_sequence_name)s';
END;
SELECT NVL(MAX(%(column)s), 0) INTO table_value FROM %(table)s;
SELECT NVL(last_number - cache_size, 0) INTO seq_value FROM user_sequences
WHERE sequence_name = seq_name;
WHILE table_value > seq_value LOOP
EXECUTE IMMEDIATE 'SELECT "'||seq_name||'".nextval FROM DUAL'
INTO seq_value;
END LOOP;
END;
/"""
# Oracle doesn't support string without precision; use the max string size.
cast_char_field_without_max_length = 'NVARCHAR2(2000)'
cast_data_types = {
'AutoField': 'NUMBER(11)',
'BigAutoField': 'NUMBER(19)',
'SmallAutoField': 'NUMBER(5)',
'TextField': cast_char_field_without_max_length,
}
def cache_key_culling_sql(self):
return 'SELECT cache_key FROM %s ORDER BY cache_key OFFSET %%s ROWS FETCH FIRST 1 ROWS ONLY'
def date_extract_sql(self, lookup_type, field_name):
if lookup_type == 'week_day':
# TO_CHAR(field, 'D') returns an integer from 1-7, where 1=Sunday.
return "TO_CHAR(%s, 'D')" % field_name
elif lookup_type == 'week':
# IW = ISO week number
return "TO_CHAR(%s, 'IW')" % field_name
elif lookup_type == 'quarter':
return "TO_CHAR(%s, 'Q')" % field_name
elif lookup_type == 'iso_year':
return "TO_CHAR(%s, 'IYYY')" % field_name
else:
# https://docs.oracle.com/en/database/oracle/oracle-database/18/sqlrf/EXTRACT-datetime.html
return "EXTRACT(%s FROM %s)" % (lookup_type.upper(), field_name)
def date_trunc_sql(self, lookup_type, field_name):
# https://docs.oracle.com/en/database/oracle/oracle-database/18/sqlrf/ROUND-and-TRUNC-Date-Functions.html
if lookup_type in ('year', 'month'):
return "TRUNC(%s, '%s')" % (field_name, lookup_type.upper())
elif lookup_type == 'quarter':
return "TRUNC(%s, 'Q')" % field_name
elif lookup_type == 'week':
return "TRUNC(%s, 'IW')" % field_name
else:
return "TRUNC(%s)" % field_name
# Oracle crashes with "ORA-03113: end-of-file on communication channel"
# if the time zone name is passed in parameter. Use interpolation instead.
# https://groups.google.com/forum/#!msg/django-developers/zwQju7hbG78/9l934yelwfsJ
# This regexp matches all time zone names from the zoneinfo database.
_tzname_re = re.compile(r'^[\w/:+-]+$')
def _prepare_tzname_delta(self, tzname):
if '+' in tzname:
return tzname[tzname.find('+'):]
elif '-' in tzname:
return tzname[tzname.find('-'):]
return tzname
def _convert_field_to_tz(self, field_name, tzname):
if not settings.USE_TZ:
return field_name
if not self._tzname_re.match(tzname):
raise ValueError("Invalid time zone name: %s" % tzname)
# Convert from connection timezone to the local time, returning
# TIMESTAMP WITH TIME ZONE and cast it back to TIMESTAMP to strip the
# TIME ZONE details.
if self.connection.timezone_name != tzname:
return "CAST((FROM_TZ(%s, '%s') AT TIME ZONE '%s') AS TIMESTAMP)" % (
field_name,
self.connection.timezone_name,
self._prepare_tzname_delta(tzname),
)
return field_name
def datetime_cast_date_sql(self, field_name, tzname):
field_name = self._convert_field_to_tz(field_name, tzname)
return 'TRUNC(%s)' % field_name
def datetime_cast_time_sql(self, field_name, tzname):
# Since `TimeField` values are stored as TIMESTAMP where only the date
# part is ignored, convert the field to the specified timezone.
return self._convert_field_to_tz(field_name, tzname)
def datetime_extract_sql(self, lookup_type, field_name, tzname):
field_name = self._convert_field_to_tz(field_name, tzname)
return self.date_extract_sql(lookup_type, field_name)
def datetime_trunc_sql(self, lookup_type, field_name, tzname):
field_name = self._convert_field_to_tz(field_name, tzname)
# https://docs.oracle.com/en/database/oracle/oracle-database/18/sqlrf/ROUND-and-TRUNC-Date-Functions.html
if lookup_type in ('year', 'month'):
sql = "TRUNC(%s, '%s')" % (field_name, lookup_type.upper())
elif lookup_type == 'quarter':
sql = "TRUNC(%s, 'Q')" % field_name
elif lookup_type == 'week':
sql = "TRUNC(%s, 'IW')" % field_name
elif lookup_type == 'day':
sql = "TRUNC(%s)" % field_name
elif lookup_type == 'hour':
sql = "TRUNC(%s, 'HH24')" % field_name
elif lookup_type == 'minute':
sql = "TRUNC(%s, 'MI')" % field_name
else:
sql = "CAST(%s AS DATE)" % field_name # Cast to DATE removes sub-second precision.
return sql
def time_trunc_sql(self, lookup_type, field_name):
# The implementation is similar to `datetime_trunc_sql` as both
# `DateTimeField` and `TimeField` are stored as TIMESTAMP where
# the date part of the later is ignored.
if lookup_type == 'hour':
sql = "TRUNC(%s, 'HH24')" % field_name
elif lookup_type == 'minute':
sql = "TRUNC(%s, 'MI')" % field_name
elif lookup_type == 'second':
sql = "CAST(%s AS DATE)" % field_name # Cast to DATE removes sub-second precision.
return sql
def get_db_converters(self, expression):
converters = super().get_db_converters(expression)
internal_type = expression.output_field.get_internal_type()
if internal_type == 'TextField':
converters.append(self.convert_textfield_value)
elif internal_type == 'BinaryField':
converters.append(self.convert_binaryfield_value)
elif internal_type in ['BooleanField', 'NullBooleanField']:
converters.append(self.convert_booleanfield_value)
elif internal_type == 'DateTimeField':
if settings.USE_TZ:
converters.append(self.convert_datetimefield_value)
elif internal_type == 'DateField':
converters.append(self.convert_datefield_value)
elif internal_type == 'TimeField':
converters.append(self.convert_timefield_value)
elif internal_type == 'UUIDField':
converters.append(self.convert_uuidfield_value)
# Oracle stores empty strings as null. If the field accepts the empty
# string, undo this to adhere to the Django convention of using
# the empty string instead of null.
if expression.field.empty_strings_allowed:
converters.append(
self.convert_empty_bytes
if internal_type == 'BinaryField' else
self.convert_empty_string
)
return converters
def convert_textfield_value(self, value, expression, connection):
if isinstance(value, Database.LOB):
value = value.read()
return value
def convert_binaryfield_value(self, value, expression, connection):
if isinstance(value, Database.LOB):
value = force_bytes(value.read())
return value
def convert_booleanfield_value(self, value, expression, connection):
if value in (0, 1):
value = bool(value)
return value
# cx_Oracle always returns datetime.datetime objects for
# DATE and TIMESTAMP columns, but Django wants to see a
# python datetime.date, .time, or .datetime.
def convert_datetimefield_value(self, value, expression, connection):
if value is not None:
value = timezone.make_aware(value, self.connection.timezone)
return value
def convert_datefield_value(self, value, expression, connection):
if isinstance(value, Database.Timestamp):
value = value.date()
return value
def convert_timefield_value(self, value, expression, connection):
if isinstance(value, Database.Timestamp):
value = value.time()
return value
def convert_uuidfield_value(self, value, expression, connection):
if value is not None:
value = uuid.UUID(value)
return value
@staticmethod
def convert_empty_string(value, expression, connection):
return '' if value is None else value
@staticmethod
def convert_empty_bytes(value, expression, connection):
return b'' if value is None else value
def deferrable_sql(self):
return " DEFERRABLE INITIALLY DEFERRED"
def fetch_returned_insert_columns(self, cursor):
value = cursor._insert_id_var.getvalue()
if value is None or value == []:
# cx_Oracle < 6.3 returns None, >= 6.3 returns empty list.
raise DatabaseError(
'The database did not return a new row id. Probably "ORA-1403: '
'no data found" was raised internally but was hidden by the '
'Oracle OCI library (see https://code.djangoproject.com/ticket/28859).'
)
# cx_Oracle < 7 returns value, >= 7 returns list with single value.
return value if isinstance(value, list) else [value]
def field_cast_sql(self, db_type, internal_type):
if db_type and db_type.endswith('LOB'):
return "DBMS_LOB.SUBSTR(%s)"
else:
return "%s"
def no_limit_value(self):
return None
def limit_offset_sql(self, low_mark, high_mark):
fetch, offset = self._get_limit_offset_params(low_mark, high_mark)
return ' '.join(sql for sql in (
('OFFSET %d ROWS' % offset) if offset else None,
('FETCH FIRST %d ROWS ONLY' % fetch) if fetch else None,
) if sql)
def last_executed_query(self, cursor, sql, params):
# https://cx-oracle.readthedocs.io/en/latest/cursor.html#Cursor.statement
# The DB API definition does not define this attribute.
statement = cursor.statement
# Unlike Psycopg's `query` and MySQLdb`'s `_executed`, cx_Oracle's
# `statement` doesn't contain the query parameters. Substitute
# parameters manually.
if isinstance(params, (tuple, list)):
for i, param in enumerate(params):
statement = statement.replace(':arg%d' % i, force_str(param, errors='replace'))
elif isinstance(params, dict):
for key, param in params.items():
statement = statement.replace(':%s' % key, force_str(param, errors='replace'))
return statement
def last_insert_id(self, cursor, table_name, pk_name):
sq_name = self._get_sequence_name(cursor, strip_quotes(table_name), pk_name)
cursor.execute('"%s".currval' % sq_name)
return cursor.fetchone()[0]
def lookup_cast(self, lookup_type, internal_type=None):
if lookup_type in ('iexact', 'icontains', 'istartswith', 'iendswith'):
return "UPPER(%s)"
return "%s"
def max_in_list_size(self):
return 1000
def max_name_length(self):
return 30
def pk_default_value(self):
return "NULL"
def prep_for_iexact_query(self, x):
return x
def process_clob(self, value):
if value is None:
return ''
return value.read()
def quote_name(self, name):
# SQL92 requires delimited (quoted) names to be case-sensitive. When
# not quoted, Oracle has case-insensitive behavior for identifiers, but
# always defaults to uppercase.
# We simplify things by making Oracle identifiers always uppercase.
if not name.startswith('"') and not name.endswith('"'):
name = '"%s"' % truncate_name(name.upper(), self.max_name_length())
# Oracle puts the query text into a (query % args) construct, so % signs
# in names need to be escaped. The '%%' will be collapsed back to '%' at
# that stage so we aren't really making the name longer here.
name = name.replace('%', '%%')
return name.upper()
def random_function_sql(self):
return "DBMS_RANDOM.RANDOM"
def regex_lookup(self, lookup_type):
if lookup_type == 'regex':
match_option = "'c'"
else:
match_option = "'i'"
return 'REGEXP_LIKE(%%s, %%s, %s)' % match_option
def return_insert_columns(self, fields):
if not fields:
return '', ()
sql = 'RETURNING %s.%s INTO %%s' % (
self.quote_name(fields[0].model._meta.db_table),
self.quote_name(fields[0].column),
)
return sql, (InsertVar(fields[0]),)
def __foreign_key_constraints(self, table_name, recursive):
with self.connection.cursor() as cursor:
if recursive:
cursor.execute("""
SELECT
user_tables.table_name, rcons.constraint_name
FROM
user_tables
JOIN
user_constraints cons
ON (user_tables.table_name = cons.table_name AND cons.constraint_type = ANY('P', 'U'))
LEFT JOIN
user_constraints rcons
ON (user_tables.table_name = rcons.table_name AND rcons.constraint_type = 'R')
START WITH user_tables.table_name = UPPER(%s)
CONNECT BY NOCYCLE PRIOR cons.constraint_name = rcons.r_constraint_name
GROUP BY
user_tables.table_name, rcons.constraint_name
HAVING user_tables.table_name != UPPER(%s)
ORDER BY MAX(level) DESC
""", (table_name, table_name))
else:
cursor.execute("""
SELECT
cons.table_name, cons.constraint_name
FROM
user_constraints cons
WHERE
cons.constraint_type = 'R'
AND cons.table_name = UPPER(%s)
""", (table_name,))
return cursor.fetchall()
@cached_property
def _foreign_key_constraints(self):
# 512 is large enough to fit the ~330 tables (as of this writing) in
# Django's test suite.
return lru_cache(maxsize=512)(self.__foreign_key_constraints)
def sql_flush(self, style, tables, sequences, allow_cascade=False):
if tables:
truncated_tables = {table.upper() for table in tables}
constraints = set()
# Oracle's TRUNCATE CASCADE only works with ON DELETE CASCADE
# foreign keys which Django doesn't define. Emulate the
# PostgreSQL behavior which truncates all dependent tables by
# manually retrieving all foreign key constraints and resolving
# dependencies.
for table in tables:
for foreign_table, constraint in self._foreign_key_constraints(table, recursive=allow_cascade):
if allow_cascade:
truncated_tables.add(foreign_table)
constraints.add((foreign_table, constraint))
sql = [
"%s %s %s %s %s %s %s %s;" % (
style.SQL_KEYWORD('ALTER'),
style.SQL_KEYWORD('TABLE'),
style.SQL_FIELD(self.quote_name(table)),
style.SQL_KEYWORD('DISABLE'),
style.SQL_KEYWORD('CONSTRAINT'),
style.SQL_FIELD(self.quote_name(constraint)),
style.SQL_KEYWORD('KEEP'),
style.SQL_KEYWORD('INDEX'),
) for table, constraint in constraints
] + [
"%s %s %s;" % (
style.SQL_KEYWORD('TRUNCATE'),
style.SQL_KEYWORD('TABLE'),
style.SQL_FIELD(self.quote_name(table)),
) for table in truncated_tables
] + [
"%s %s %s %s %s %s;" % (
style.SQL_KEYWORD('ALTER'),
style.SQL_KEYWORD('TABLE'),
style.SQL_FIELD(self.quote_name(table)),
style.SQL_KEYWORD('ENABLE'),
style.SQL_KEYWORD('CONSTRAINT'),
style.SQL_FIELD(self.quote_name(constraint)),
) for table, constraint in constraints
]
# Since we've just deleted all the rows, running our sequence
# ALTER code will reset the sequence to 0.
sql.extend(self.sequence_reset_by_name_sql(style, sequences))
return sql
else:
return []
def sequence_reset_by_name_sql(self, style, sequences):
sql = []
for sequence_info in sequences:
no_autofield_sequence_name = self._get_no_autofield_sequence_name(sequence_info['table'])
table = self.quote_name(sequence_info['table'])
column = self.quote_name(sequence_info['column'] or 'id')
query = self._sequence_reset_sql % {
'no_autofield_sequence_name': no_autofield_sequence_name,
'table': table,
'column': column,
'table_name': strip_quotes(table),
'column_name': strip_quotes(column),
}
sql.append(query)
return sql
def sequence_reset_sql(self, style, model_list):
from django.db import models
output = []
query = self._sequence_reset_sql
for model in model_list:
for f in model._meta.local_fields:
if isinstance(f, models.AutoField):
no_autofield_sequence_name = self._get_no_autofield_sequence_name(model._meta.db_table)
table = self.quote_name(model._meta.db_table)
column = self.quote_name(f.column)
output.append(query % {
'no_autofield_sequence_name': no_autofield_sequence_name,
'table': table,
'column': column,
'table_name': strip_quotes(table),
'column_name': strip_quotes(column),
})
# Only one AutoField is allowed per model, so don't
# continue to loop
break
for f in model._meta.many_to_many:
if not f.remote_field.through:
no_autofield_sequence_name = self._get_no_autofield_sequence_name(f.m2m_db_table())
table = self.quote_name(f.m2m_db_table())
column = self.quote_name('id')
output.append(query % {
'no_autofield_sequence_name': no_autofield_sequence_name,
'table': table,
'column': column,
'table_name': strip_quotes(table),
'column_name': 'ID',
})
return output
def start_transaction_sql(self):
return ''
def tablespace_sql(self, tablespace, inline=False):
if inline:
return "USING INDEX TABLESPACE %s" % self.quote_name(tablespace)
else:
return "TABLESPACE %s" % self.quote_name(tablespace)
def adapt_datefield_value(self, value):
"""
Transform a date value to an object compatible with what is expected
by the backend driver for date columns.
The default implementation transforms the date to text, but that is not
necessary for Oracle.
"""
return value
def adapt_datetimefield_value(self, value):
"""
Transform a datetime value to an object compatible with what is expected
by the backend driver for datetime columns.
If naive datetime is passed assumes that is in UTC. Normally Django
models.DateTimeField makes sure that if USE_TZ is True passed datetime
is timezone aware.
"""
if value is None:
return None
# Expression values are adapted by the database.
if hasattr(value, 'resolve_expression'):
return value
# cx_Oracle doesn't support tz-aware datetimes
if timezone.is_aware(value):
if settings.USE_TZ:
value = timezone.make_naive(value, self.connection.timezone)
else:
raise ValueError("Oracle backend does not support timezone-aware datetimes when USE_TZ is False.")
return Oracle_datetime.from_datetime(value)
def adapt_timefield_value(self, value):
if value is None:
return None
# Expression values are adapted by the database.
if hasattr(value, 'resolve_expression'):
return value
if isinstance(value, str):
return datetime.datetime.strptime(value, '%H:%M:%S')
# Oracle doesn't support tz-aware times
if timezone.is_aware(value):
raise ValueError("Oracle backend does not support timezone-aware times.")
return Oracle_datetime(1900, 1, 1, value.hour, value.minute,
value.second, value.microsecond)
def combine_expression(self, connector, sub_expressions):
lhs, rhs = sub_expressions
if connector == '%%':
return 'MOD(%s)' % ','.join(sub_expressions)
elif connector == '&':
return 'BITAND(%s)' % ','.join(sub_expressions)
elif connector == '|':
return 'BITAND(-%(lhs)s-1,%(rhs)s)+%(lhs)s' % {'lhs': lhs, 'rhs': rhs}
elif connector == '<<':
return '(%(lhs)s * POWER(2, %(rhs)s))' % {'lhs': lhs, 'rhs': rhs}
elif connector == '>>':
return 'FLOOR(%(lhs)s / POWER(2, %(rhs)s))' % {'lhs': lhs, 'rhs': rhs}
elif connector == '^':
return 'POWER(%s)' % ','.join(sub_expressions)
return super().combine_expression(connector, sub_expressions)
def _get_no_autofield_sequence_name(self, table):
"""
Manually created sequence name to keep backward compatibility for
AutoFields that aren't Oracle identity columns.
"""
name_length = self.max_name_length() - 3
return '%s_SQ' % truncate_name(strip_quotes(table), name_length).upper()
def _get_sequence_name(self, cursor, table, pk_name):
cursor.execute("""
SELECT sequence_name
FROM user_tab_identity_cols
WHERE table_name = UPPER(%s)
AND column_name = UPPER(%s)""", [table, pk_name])
row = cursor.fetchone()
return self._get_no_autofield_sequence_name(table) if row is None else row[0]
def bulk_insert_sql(self, fields, placeholder_rows):
query = []
for row in placeholder_rows:
select = []
for i, placeholder in enumerate(row):
# A model without any fields has fields=[None].
if fields[i]:
internal_type = getattr(fields[i], 'target_field', fields[i]).get_internal_type()
placeholder = BulkInsertMapper.types.get(internal_type, '%s') % placeholder
# Add columns aliases to the first select to avoid "ORA-00918:
# column ambiguously defined" when two or more columns in the
# first select have the same value.
if not query:
placeholder = '%s col_%s' % (placeholder, i)
select.append(placeholder)
query.append('SELECT %s FROM DUAL' % ', '.join(select))
# Bulk insert to tables with Oracle identity columns causes Oracle to
# add sequence.nextval to it. Sequence.nextval cannot be used with the
# UNION operator. To prevent incorrect SQL, move UNION to a subquery.
return 'SELECT * FROM (%s)' % ' UNION ALL '.join(query)
def subtract_temporals(self, internal_type, lhs, rhs):
if internal_type == 'DateField':
lhs_sql, lhs_params = lhs
rhs_sql, rhs_params = rhs
params = (*lhs_params, *rhs_params)
return "NUMTODSINTERVAL(TO_NUMBER(%s - %s), 'DAY')" % (lhs_sql, rhs_sql), params
return super().subtract_temporals(internal_type, lhs, rhs)
def bulk_batch_size(self, fields, objs):
"""Oracle restricts the number of parameters in a query."""
if fields:
return self.connection.features.max_query_params // len(fields)
return len(objs)
def conditional_expression_supported_in_where_clause(self, expression):
"""
Oracle supports only EXISTS(...) or filters in the WHERE clause, others
must be compared with True.
"""
if isinstance(expression, Exists):
return True
if isinstance(expression, ExpressionWrapper) and isinstance(expression.expression, Q):
return True
if isinstance(expression, RawSQL) and expression.conditional:
return True
return False

View File

@@ -0,0 +1,172 @@
import copy
import datetime
import re
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
from django.db.utils import DatabaseError
class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
sql_create_column = "ALTER TABLE %(table)s ADD %(column)s %(definition)s"
sql_alter_column_type = "MODIFY %(column)s %(type)s"
sql_alter_column_null = "MODIFY %(column)s NULL"
sql_alter_column_not_null = "MODIFY %(column)s NOT NULL"
sql_alter_column_default = "MODIFY %(column)s DEFAULT %(default)s"
sql_alter_column_no_default = "MODIFY %(column)s DEFAULT NULL"
sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s"
sql_create_column_inline_fk = 'CONSTRAINT %(name)s REFERENCES %(to_table)s(%(to_column)s)%(deferrable)s'
sql_delete_table = "DROP TABLE %(table)s CASCADE CONSTRAINTS"
sql_create_index = "CREATE INDEX %(name)s ON %(table)s (%(columns)s)%(extra)s"
def quote_value(self, value):
if isinstance(value, (datetime.date, datetime.time, datetime.datetime)):
return "'%s'" % value
elif isinstance(value, str):
return "'%s'" % value.replace("\'", "\'\'").replace('%', '%%')
elif isinstance(value, (bytes, bytearray, memoryview)):
return "'%s'" % value.hex()
elif isinstance(value, bool):
return "1" if value else "0"
else:
return str(value)
def remove_field(self, model, field):
# If the column is an identity column, drop the identity before
# removing the field.
if self._is_identity_column(model._meta.db_table, field.column):
self._drop_identity(model._meta.db_table, field.column)
super().remove_field(model, field)
def delete_model(self, model):
# Run superclass action
super().delete_model(model)
# Clean up manually created sequence.
self.execute("""
DECLARE
i INTEGER;
BEGIN
SELECT COUNT(1) INTO i FROM USER_SEQUENCES
WHERE SEQUENCE_NAME = '%(sq_name)s';
IF i = 1 THEN
EXECUTE IMMEDIATE 'DROP SEQUENCE "%(sq_name)s"';
END IF;
END;
/""" % {'sq_name': self.connection.ops._get_no_autofield_sequence_name(model._meta.db_table)})
def alter_field(self, model, old_field, new_field, strict=False):
try:
super().alter_field(model, old_field, new_field, strict)
except DatabaseError as e:
description = str(e)
# If we're changing type to an unsupported type we need a
# SQLite-ish workaround
if 'ORA-22858' in description or 'ORA-22859' in description:
self._alter_field_type_workaround(model, old_field, new_field)
# If an identity column is changing to a non-numeric type, drop the
# identity first.
elif 'ORA-30675' in description:
self._drop_identity(model._meta.db_table, old_field.column)
self.alter_field(model, old_field, new_field, strict)
# If a primary key column is changing to an identity column, drop
# the primary key first.
elif 'ORA-30673' in description and old_field.primary_key:
self._delete_primary_key(model, strict=True)
self._alter_field_type_workaround(model, old_field, new_field)
else:
raise
def _alter_field_type_workaround(self, model, old_field, new_field):
"""
Oracle refuses to change from some type to other type.
What we need to do instead is:
- Add a nullable version of the desired field with a temporary name. If
the new column is an auto field, then the temporary column can't be
nullable.
- Update the table to transfer values from old to new
- Drop old column
- Rename the new column and possibly drop the nullable property
"""
# Make a new field that's like the new one but with a temporary
# column name.
new_temp_field = copy.deepcopy(new_field)
new_temp_field.null = (new_field.get_internal_type() not in ('AutoField', 'BigAutoField', 'SmallAutoField'))
new_temp_field.column = self._generate_temp_name(new_field.column)
# Add it
self.add_field(model, new_temp_field)
# Explicit data type conversion
# https://docs.oracle.com/en/database/oracle/oracle-database/18/sqlrf
# /Data-Type-Comparison-Rules.html#GUID-D0C5A47E-6F93-4C2D-9E49-4F2B86B359DD
new_value = self.quote_name(old_field.column)
old_type = old_field.db_type(self.connection)
if re.match('^N?CLOB', old_type):
new_value = "TO_CHAR(%s)" % new_value
old_type = 'VARCHAR2'
if re.match('^N?VARCHAR2', old_type):
new_internal_type = new_field.get_internal_type()
if new_internal_type == 'DateField':
new_value = "TO_DATE(%s, 'YYYY-MM-DD')" % new_value
elif new_internal_type == 'DateTimeField':
new_value = "TO_TIMESTAMP(%s, 'YYYY-MM-DD HH24:MI:SS.FF')" % new_value
elif new_internal_type == 'TimeField':
# TimeField are stored as TIMESTAMP with a 1900-01-01 date part.
new_value = "TO_TIMESTAMP(CONCAT('1900-01-01 ', %s), 'YYYY-MM-DD HH24:MI:SS.FF')" % new_value
# Transfer values across
self.execute("UPDATE %s set %s=%s" % (
self.quote_name(model._meta.db_table),
self.quote_name(new_temp_field.column),
new_value,
))
# Drop the old field
self.remove_field(model, old_field)
# Rename and possibly make the new field NOT NULL
super().alter_field(model, new_temp_field, new_field)
def normalize_name(self, name):
"""
Get the properly shortened and uppercased identifier as returned by
quote_name() but without the quotes.
"""
nn = self.quote_name(name)
if nn[0] == '"' and nn[-1] == '"':
nn = nn[1:-1]
return nn
def _generate_temp_name(self, for_name):
"""Generate temporary names for workarounds that need temp columns."""
suffix = hex(hash(for_name)).upper()[1:]
return self.normalize_name(for_name + "_" + suffix)
def prepare_default(self, value):
return self.quote_value(value)
def _field_should_be_indexed(self, model, field):
create_index = super()._field_should_be_indexed(model, field)
db_type = field.db_type(self.connection)
if db_type is not None and db_type.lower() in self.connection._limited_data_types:
return False
return create_index
def _unique_should_be_added(self, old_field, new_field):
return (
super()._unique_should_be_added(old_field, new_field) and
not self._field_became_primary_key(old_field, new_field)
)
def _is_identity_column(self, table_name, column_name):
with self.connection.cursor() as cursor:
cursor.execute("""
SELECT
CASE WHEN identity_column = 'YES' THEN 1 ELSE 0 END
FROM user_tab_cols
WHERE table_name = %s AND
column_name = %s
""", [self.normalize_name(table_name), self.normalize_name(column_name)])
row = cursor.fetchone()
return row[0] if row else False
def _drop_identity(self, table_name, column_name):
self.execute('ALTER TABLE %(table)s MODIFY %(column)s DROP IDENTITY' % {
'table': self.quote_name(table_name),
'column': self.quote_name(column_name),
})

View File

@@ -0,0 +1,76 @@
import datetime
from .base import Database
class InsertVar:
"""
A late-binding cursor variable that can be passed to Cursor.execute
as a parameter, in order to receive the id of the row created by an
insert statement.
"""
types = {
'AutoField': int,
'BigAutoField': int,
'SmallAutoField': int,
'IntegerField': int,
'BigIntegerField': int,
'SmallIntegerField': int,
'PositiveSmallIntegerField': int,
'PositiveIntegerField': int,
'FloatField': Database.NATIVE_FLOAT,
'DateTimeField': Database.TIMESTAMP,
'DateField': Database.Date,
'DecimalField': Database.NUMBER,
}
def __init__(self, field):
internal_type = getattr(field, 'target_field', field).get_internal_type()
self.db_type = self.types.get(internal_type, str)
def bind_parameter(self, cursor):
param = cursor.cursor.var(self.db_type)
cursor._insert_id_var = param
return param
class Oracle_datetime(datetime.datetime):
"""
A datetime object, with an additional class attribute
to tell cx_Oracle to save the microseconds too.
"""
input_size = Database.TIMESTAMP
@classmethod
def from_datetime(cls, dt):
return Oracle_datetime(
dt.year, dt.month, dt.day,
dt.hour, dt.minute, dt.second, dt.microsecond,
)
class BulkInsertMapper:
BLOB = 'TO_BLOB(%s)'
CLOB = 'TO_CLOB(%s)'
DATE = 'TO_DATE(%s)'
INTERVAL = 'CAST(%s as INTERVAL DAY(9) TO SECOND(6))'
NUMBER = 'TO_NUMBER(%s)'
TIMESTAMP = 'TO_TIMESTAMP(%s)'
types = {
'BigIntegerField': NUMBER,
'BinaryField': BLOB,
'BooleanField': NUMBER,
'DateField': DATE,
'DateTimeField': TIMESTAMP,
'DecimalField': NUMBER,
'DurationField': INTERVAL,
'FloatField': NUMBER,
'IntegerField': NUMBER,
'NullBooleanField': NUMBER,
'PositiveIntegerField': NUMBER,
'PositiveSmallIntegerField': NUMBER,
'SmallIntegerField': NUMBER,
'TextField': CLOB,
'TimeField': TIMESTAMP,
}

View File

@@ -0,0 +1,22 @@
from django.core import checks
from django.db.backends.base.validation import BaseDatabaseValidation
class DatabaseValidation(BaseDatabaseValidation):
def check_field_type(self, field, field_type):
"""Oracle doesn't support a database index on some data types."""
errors = []
if field.db_index and field_type.lower() in self.connection._limited_data_types:
errors.append(
checks.Warning(
'Oracle does not support a database index on %s columns.'
% field_type,
hint=(
"An index won't be created. Silence this warning if "
"you don't care about it."
),
obj=field,
id='fields.W162',
)
)
return errors

View File

@@ -0,0 +1,326 @@
"""
PostgreSQL database backend for Django.
Requires psycopg 2: https://www.psycopg.org/
"""
import asyncio
import threading
import warnings
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.db import connections
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.backends.utils import (
CursorDebugWrapper as BaseCursorDebugWrapper,
)
from django.db.utils import DatabaseError as WrappedDatabaseError
from django.utils.asyncio import async_unsafe
from django.utils.functional import cached_property
from django.utils.safestring import SafeString
from django.utils.version import get_version_tuple
try:
import psycopg2 as Database
import psycopg2.extensions
import psycopg2.extras
except ImportError as e:
raise ImproperlyConfigured("Error loading psycopg2 module: %s" % e)
def psycopg2_version():
version = psycopg2.__version__.split(' ', 1)[0]
return get_version_tuple(version)
PSYCOPG2_VERSION = psycopg2_version()
if PSYCOPG2_VERSION < (2, 5, 4):
raise ImproperlyConfigured("psycopg2_version 2.5.4 or newer is required; you have %s" % psycopg2.__version__)
# Some of these import psycopg2, so import them after checking if it's installed.
from .client import DatabaseClient # NOQA isort:skip
from .creation import DatabaseCreation # NOQA isort:skip
from .features import DatabaseFeatures # NOQA isort:skip
from .introspection import DatabaseIntrospection # NOQA isort:skip
from .operations import DatabaseOperations # NOQA isort:skip
from .schema import DatabaseSchemaEditor # NOQA isort:skip
from .utils import utc_tzinfo_factory # NOQA isort:skip
psycopg2.extensions.register_adapter(SafeString, psycopg2.extensions.QuotedString)
psycopg2.extras.register_uuid()
# Register support for inet[] manually so we don't have to handle the Inet()
# object on load all the time.
INETARRAY_OID = 1041
INETARRAY = psycopg2.extensions.new_array_type(
(INETARRAY_OID,),
'INETARRAY',
psycopg2.extensions.UNICODE,
)
psycopg2.extensions.register_type(INETARRAY)
class DatabaseWrapper(BaseDatabaseWrapper):
vendor = 'postgresql'
display_name = 'PostgreSQL'
# This dictionary maps Field objects to their associated PostgreSQL column
# types, as strings. Column-type strings can contain format strings; they'll
# be interpolated against the values of Field.__dict__ before being output.
# If a column type is set to None, it won't be included in the output.
data_types = {
'AutoField': 'serial',
'BigAutoField': 'bigserial',
'BinaryField': 'bytea',
'BooleanField': 'boolean',
'CharField': 'varchar(%(max_length)s)',
'DateField': 'date',
'DateTimeField': 'timestamp with time zone',
'DecimalField': 'numeric(%(max_digits)s, %(decimal_places)s)',
'DurationField': 'interval',
'FileField': 'varchar(%(max_length)s)',
'FilePathField': 'varchar(%(max_length)s)',
'FloatField': 'double precision',
'IntegerField': 'integer',
'BigIntegerField': 'bigint',
'IPAddressField': 'inet',
'GenericIPAddressField': 'inet',
'NullBooleanField': 'boolean',
'OneToOneField': 'integer',
'PositiveIntegerField': 'integer',
'PositiveSmallIntegerField': 'smallint',
'SlugField': 'varchar(%(max_length)s)',
'SmallAutoField': 'smallserial',
'SmallIntegerField': 'smallint',
'TextField': 'text',
'TimeField': 'time',
'UUIDField': 'uuid',
}
data_type_check_constraints = {
'PositiveIntegerField': '"%(column)s" >= 0',
'PositiveSmallIntegerField': '"%(column)s" >= 0',
}
operators = {
'exact': '= %s',
'iexact': '= UPPER(%s)',
'contains': 'LIKE %s',
'icontains': 'LIKE UPPER(%s)',
'regex': '~ %s',
'iregex': '~* %s',
'gt': '> %s',
'gte': '>= %s',
'lt': '< %s',
'lte': '<= %s',
'startswith': 'LIKE %s',
'endswith': 'LIKE %s',
'istartswith': 'LIKE UPPER(%s)',
'iendswith': 'LIKE UPPER(%s)',
}
# The patterns below are used to generate SQL pattern lookup clauses when
# the right-hand side of the lookup isn't a raw string (it might be an expression
# or the result of a bilateral transformation).
# In those cases, special characters for LIKE operators (e.g. \, *, _) should be
# escaped on database side.
#
# Note: we use str.format() here for readability as '%' is used as a wildcard for
# the LIKE operator.
pattern_esc = r"REPLACE(REPLACE(REPLACE({}, E'\\', E'\\\\'), E'%%', E'\\%%'), E'_', E'\\_')"
pattern_ops = {
'contains': "LIKE '%%' || {} || '%%'",
'icontains': "LIKE '%%' || UPPER({}) || '%%'",
'startswith': "LIKE {} || '%%'",
'istartswith': "LIKE UPPER({}) || '%%'",
'endswith': "LIKE '%%' || {}",
'iendswith': "LIKE '%%' || UPPER({})",
}
Database = Database
SchemaEditorClass = DatabaseSchemaEditor
# Classes instantiated in __init__().
client_class = DatabaseClient
creation_class = DatabaseCreation
features_class = DatabaseFeatures
introspection_class = DatabaseIntrospection
ops_class = DatabaseOperations
# PostgreSQL backend-specific attributes.
_named_cursor_idx = 0
def get_connection_params(self):
settings_dict = self.settings_dict
# None may be used to connect to the default 'postgres' db
if settings_dict['NAME'] == '':
raise ImproperlyConfigured(
"settings.DATABASES is improperly configured. "
"Please supply the NAME value.")
if len(settings_dict['NAME'] or '') > self.ops.max_name_length():
raise ImproperlyConfigured(
"The database name '%s' (%d characters) is longer than "
"PostgreSQL's limit of %d characters. Supply a shorter NAME "
"in settings.DATABASES." % (
settings_dict['NAME'],
len(settings_dict['NAME']),
self.ops.max_name_length(),
)
)
conn_params = {
'database': settings_dict['NAME'] or 'postgres',
**settings_dict['OPTIONS'],
}
conn_params.pop('isolation_level', None)
if settings_dict['USER']:
conn_params['user'] = settings_dict['USER']
if settings_dict['PASSWORD']:
conn_params['password'] = settings_dict['PASSWORD']
if settings_dict['HOST']:
conn_params['host'] = settings_dict['HOST']
if settings_dict['PORT']:
conn_params['port'] = settings_dict['PORT']
return conn_params
@async_unsafe
def get_new_connection(self, conn_params):
connection = Database.connect(**conn_params)
# self.isolation_level must be set:
# - after connecting to the database in order to obtain the database's
# default when no value is explicitly specified in options.
# - before calling _set_autocommit() because if autocommit is on, that
# will set connection.isolation_level to ISOLATION_LEVEL_AUTOCOMMIT.
options = self.settings_dict['OPTIONS']
try:
self.isolation_level = options['isolation_level']
except KeyError:
self.isolation_level = connection.isolation_level
else:
# Set the isolation level to the value from OPTIONS.
if self.isolation_level != connection.isolation_level:
connection.set_session(isolation_level=self.isolation_level)
return connection
def ensure_timezone(self):
if self.connection is None:
return False
conn_timezone_name = self.connection.get_parameter_status('TimeZone')
timezone_name = self.timezone_name
if timezone_name and conn_timezone_name != timezone_name:
with self.connection.cursor() as cursor:
cursor.execute(self.ops.set_time_zone_sql(), [timezone_name])
return True
return False
def init_connection_state(self):
self.connection.set_client_encoding('UTF8')
timezone_changed = self.ensure_timezone()
if timezone_changed:
# Commit after setting the time zone (see #17062)
if not self.get_autocommit():
self.connection.commit()
@async_unsafe
def create_cursor(self, name=None):
if name:
# In autocommit mode, the cursor will be used outside of a
# transaction, hence use a holdable cursor.
cursor = self.connection.cursor(name, scrollable=False, withhold=self.connection.autocommit)
else:
cursor = self.connection.cursor()
cursor.tzinfo_factory = utc_tzinfo_factory if settings.USE_TZ else None
return cursor
@async_unsafe
def chunked_cursor(self):
self._named_cursor_idx += 1
# Get the current async task
# Note that right now this is behind @async_unsafe, so this is
# unreachable, but in future we'll start loosening this restriction.
# For now, it's here so that every use of "threading" is
# also async-compatible.
try:
if hasattr(asyncio, 'current_task'):
# Python 3.7 and up
current_task = asyncio.current_task()
else:
# Python 3.6
current_task = asyncio.Task.current_task()
except RuntimeError:
current_task = None
# Current task can be none even if the current_task call didn't error
if current_task:
task_ident = str(id(current_task))
else:
task_ident = 'sync'
# Use that and the thread ident to get a unique name
return self._cursor(
name='_django_curs_%d_%s_%d' % (
# Avoid reusing name in other threads / tasks
threading.current_thread().ident,
task_ident,
self._named_cursor_idx,
)
)
def _set_autocommit(self, autocommit):
with self.wrap_database_errors:
self.connection.autocommit = autocommit
def check_constraints(self, table_names=None):
"""
Check constraints by setting them to immediate. Return them to deferred
afterward.
"""
self.cursor().execute('SET CONSTRAINTS ALL IMMEDIATE')
self.cursor().execute('SET CONSTRAINTS ALL DEFERRED')
def is_usable(self):
try:
# Use a psycopg cursor directly, bypassing Django's utilities.
self.connection.cursor().execute("SELECT 1")
except Database.Error:
return False
else:
return True
@property
def _nodb_connection(self):
nodb_connection = super()._nodb_connection
try:
nodb_connection.ensure_connection()
except (Database.DatabaseError, WrappedDatabaseError):
warnings.warn(
"Normally Django will use a connection to the 'postgres' database "
"to avoid running initialization queries against the production "
"database when it's not needed (for example, when running tests). "
"Django was unable to create a connection to the 'postgres' database "
"and will use the first PostgreSQL database instead.",
RuntimeWarning
)
for connection in connections.all():
if connection.vendor == 'postgresql' and connection.settings_dict['NAME'] != 'postgres':
return self.__class__(
{**self.settings_dict, 'NAME': connection.settings_dict['NAME']},
alias=self.alias,
)
return nodb_connection
@cached_property
def pg_version(self):
with self.temporary_connection():
return self.connection.server_version
def make_debug_cursor(self, cursor):
return CursorDebugWrapper(cursor, self)
class CursorDebugWrapper(BaseCursorDebugWrapper):
def copy_expert(self, sql, file, *args):
with self.debug_sql(sql):
return self.cursor.copy_expert(sql, file, *args)
def copy_to(self, file, table, *args, **kwargs):
with self.debug_sql(sql='COPY %s TO STDOUT' % table):
return self.cursor.copy_to(file, table, *args, **kwargs)

View File

@@ -0,0 +1,54 @@
import os
import signal
import subprocess
from django.db.backends.base.client import BaseDatabaseClient
class DatabaseClient(BaseDatabaseClient):
executable_name = 'psql'
@classmethod
def runshell_db(cls, conn_params):
args = [cls.executable_name]
host = conn_params.get('host', '')
port = conn_params.get('port', '')
dbname = conn_params.get('database', '')
user = conn_params.get('user', '')
passwd = conn_params.get('password', '')
sslmode = conn_params.get('sslmode', '')
sslrootcert = conn_params.get('sslrootcert', '')
sslcert = conn_params.get('sslcert', '')
sslkey = conn_params.get('sslkey', '')
if user:
args += ['-U', user]
if host:
args += ['-h', host]
if port:
args += ['-p', str(port)]
args += [dbname]
sigint_handler = signal.getsignal(signal.SIGINT)
subprocess_env = os.environ.copy()
if passwd:
subprocess_env['PGPASSWORD'] = str(passwd)
if sslmode:
subprocess_env['PGSSLMODE'] = str(sslmode)
if sslrootcert:
subprocess_env['PGSSLROOTCERT'] = str(sslrootcert)
if sslcert:
subprocess_env['PGSSLCERT'] = str(sslcert)
if sslkey:
subprocess_env['PGSSLKEY'] = str(sslkey)
try:
# Allow SIGINT to pass to psql to abort queries.
signal.signal(signal.SIGINT, signal.SIG_IGN)
subprocess.run(args, check=True, env=subprocess_env)
finally:
# Restore the original SIGINT handler.
signal.signal(signal.SIGINT, sigint_handler)
def runshell(self):
DatabaseClient.runshell_db(self.connection.get_connection_params())

View File

@@ -0,0 +1,77 @@
import sys
from psycopg2 import errorcodes
from django.db.backends.base.creation import BaseDatabaseCreation
from django.db.backends.utils import strip_quotes
class DatabaseCreation(BaseDatabaseCreation):
def _quote_name(self, name):
return self.connection.ops.quote_name(name)
def _get_database_create_suffix(self, encoding=None, template=None):
suffix = ""
if encoding:
suffix += " ENCODING '{}'".format(encoding)
if template:
suffix += " TEMPLATE {}".format(self._quote_name(template))
return suffix and "WITH" + suffix
def sql_table_creation_suffix(self):
test_settings = self.connection.settings_dict['TEST']
assert test_settings['COLLATION'] is None, (
"PostgreSQL does not support collation setting at database creation time."
)
return self._get_database_create_suffix(
encoding=test_settings['CHARSET'],
template=test_settings.get('TEMPLATE'),
)
def _database_exists(self, cursor, database_name):
cursor.execute('SELECT 1 FROM pg_catalog.pg_database WHERE datname = %s', [strip_quotes(database_name)])
return cursor.fetchone() is not None
def _execute_create_test_db(self, cursor, parameters, keepdb=False):
try:
if keepdb and self._database_exists(cursor, parameters['dbname']):
# If the database should be kept and it already exists, don't
# try to create a new one.
return
super()._execute_create_test_db(cursor, parameters, keepdb)
except Exception as e:
if getattr(e.__cause__, 'pgcode', '') != errorcodes.DUPLICATE_DATABASE:
# All errors except "database already exists" cancel tests.
self.log('Got an error creating the test database: %s' % e)
sys.exit(2)
elif not keepdb:
# If the database should be kept, ignore "database already
# exists".
raise e
def _clone_test_db(self, suffix, verbosity, keepdb=False):
# CREATE DATABASE ... WITH TEMPLATE ... requires closing connections
# to the template database.
self.connection.close()
source_database_name = self.connection.settings_dict['NAME']
target_database_name = self.get_test_db_clone_settings(suffix)['NAME']
test_db_params = {
'dbname': self._quote_name(target_database_name),
'suffix': self._get_database_create_suffix(template=source_database_name),
}
with self._nodb_connection.cursor() as cursor:
try:
self._execute_create_test_db(cursor, test_db_params, keepdb)
except Exception:
try:
if verbosity >= 1:
self.log('Destroying old test database for alias %s...' % (
self._get_database_display_str(verbosity, target_database_name),
))
cursor.execute('DROP DATABASE %(dbname)s' % test_db_params)
self._execute_create_test_db(cursor, test_db_params, keepdb)
except Exception as e:
self.log('Got an error cloning the test database: %s' % e)
sys.exit(2)

View File

@@ -0,0 +1,70 @@
import operator
from django.db.backends.base.features import BaseDatabaseFeatures
from django.db.utils import InterfaceError
from django.utils.functional import cached_property
class DatabaseFeatures(BaseDatabaseFeatures):
allows_group_by_selected_pks = True
can_return_columns_from_insert = True
can_return_multiple_columns_from_insert = True
can_return_rows_from_bulk_insert = True
has_real_datatype = True
has_native_uuid_field = True
has_native_duration_field = True
can_defer_constraint_checks = True
has_select_for_update = True
has_select_for_update_nowait = True
has_select_for_update_of = True
has_select_for_update_skip_locked = True
can_release_savepoints = True
supports_tablespaces = True
supports_transactions = True
can_introspect_autofield = True
can_introspect_ip_address_field = True
can_introspect_materialized_views = True
can_introspect_small_integer_field = True
can_distinct_on_fields = True
can_rollback_ddl = True
supports_combined_alters = True
nulls_order_largest = True
closed_cursor_error_class = InterfaceError
has_case_insensitive_like = False
greatest_least_ignores_nulls = True
can_clone_databases = True
supports_temporal_subtraction = True
supports_slicing_ordering_in_compound = True
create_test_procedure_without_params_sql = """
CREATE FUNCTION test_procedure () RETURNS void AS $$
DECLARE
V_I INTEGER;
BEGIN
V_I := 1;
END;
$$ LANGUAGE plpgsql;"""
create_test_procedure_with_int_param_sql = """
CREATE FUNCTION test_procedure (P_I INTEGER) RETURNS void AS $$
DECLARE
V_I INTEGER;
BEGIN
V_I := P_I;
END;
$$ LANGUAGE plpgsql;"""
requires_casted_case_in_updates = True
supports_over_clause = True
supports_aggregate_filter_clause = True
supported_explain_formats = {'JSON', 'TEXT', 'XML', 'YAML'}
validates_explain_options = False # A query will error on invalid options.
@cached_property
def is_postgresql_9_6(self):
return self.connection.pg_version >= 90600
@cached_property
def is_postgresql_10(self):
return self.connection.pg_version >= 100000
has_brin_autosummarize = property(operator.attrgetter('is_postgresql_10'))
has_phraseto_tsquery = property(operator.attrgetter('is_postgresql_9_6'))
supports_table_partitions = property(operator.attrgetter('is_postgresql_10'))

View File

@@ -0,0 +1,224 @@
from django.db.backends.base.introspection import (
BaseDatabaseIntrospection, FieldInfo, TableInfo,
)
from django.db.models.indexes import Index
class DatabaseIntrospection(BaseDatabaseIntrospection):
# Maps type codes to Django Field types.
data_types_reverse = {
16: 'BooleanField',
17: 'BinaryField',
20: 'BigIntegerField',
21: 'SmallIntegerField',
23: 'IntegerField',
25: 'TextField',
700: 'FloatField',
701: 'FloatField',
869: 'GenericIPAddressField',
1042: 'CharField', # blank-padded
1043: 'CharField',
1082: 'DateField',
1083: 'TimeField',
1114: 'DateTimeField',
1184: 'DateTimeField',
1186: 'DurationField',
1266: 'TimeField',
1700: 'DecimalField',
2950: 'UUIDField',
}
ignored_tables = []
def get_field_type(self, data_type, description):
field_type = super().get_field_type(data_type, description)
if description.default and 'nextval' in description.default:
if field_type == 'IntegerField':
return 'AutoField'
elif field_type == 'BigIntegerField':
return 'BigAutoField'
elif field_type == 'SmallIntegerField':
return 'SmallAutoField'
return field_type
def get_table_list(self, cursor):
"""Return a list of table and view names in the current database."""
cursor.execute("""
SELECT c.relname,
CASE WHEN {} THEN 'p' WHEN c.relkind IN ('m', 'v') THEN 'v' ELSE 't' END
FROM pg_catalog.pg_class c
LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
WHERE c.relkind IN ('f', 'm', 'p', 'r', 'v')
AND n.nspname NOT IN ('pg_catalog', 'pg_toast')
AND pg_catalog.pg_table_is_visible(c.oid)
""".format('c.relispartition' if self.connection.features.supports_table_partitions else 'FALSE'))
return [TableInfo(*row) for row in cursor.fetchall() if row[0] not in self.ignored_tables]
def get_table_description(self, cursor, table_name):
"""
Return a description of the table with the DB-API cursor.description
interface.
"""
# Query the pg_catalog tables as cursor.description does not reliably
# return the nullable property and information_schema.columns does not
# contain details of materialized views.
cursor.execute("""
SELECT
a.attname AS column_name,
NOT (a.attnotnull OR (t.typtype = 'd' AND t.typnotnull)) AS is_nullable,
pg_get_expr(ad.adbin, ad.adrelid) AS column_default
FROM pg_attribute a
LEFT JOIN pg_attrdef ad ON a.attrelid = ad.adrelid AND a.attnum = ad.adnum
JOIN pg_type t ON a.atttypid = t.oid
JOIN pg_class c ON a.attrelid = c.oid
JOIN pg_namespace n ON c.relnamespace = n.oid
WHERE c.relkind IN ('f', 'm', 'p', 'r', 'v')
AND c.relname = %s
AND n.nspname NOT IN ('pg_catalog', 'pg_toast')
AND pg_catalog.pg_table_is_visible(c.oid)
""", [table_name])
field_map = {line[0]: line[1:] for line in cursor.fetchall()}
cursor.execute("SELECT * FROM %s LIMIT 1" % self.connection.ops.quote_name(table_name))
return [
FieldInfo(
line.name,
line.type_code,
line.display_size,
line.internal_size,
line.precision,
line.scale,
*field_map[line.name],
)
for line in cursor.description
]
def get_sequences(self, cursor, table_name, table_fields=()):
cursor.execute("""
SELECT s.relname as sequence_name, col.attname
FROM pg_class s
JOIN pg_namespace sn ON sn.oid = s.relnamespace
JOIN pg_depend d ON d.refobjid = s.oid AND d.refclassid = 'pg_class'::regclass
JOIN pg_attrdef ad ON ad.oid = d.objid AND d.classid = 'pg_attrdef'::regclass
JOIN pg_attribute col ON col.attrelid = ad.adrelid AND col.attnum = ad.adnum
JOIN pg_class tbl ON tbl.oid = ad.adrelid
WHERE s.relkind = 'S'
AND d.deptype in ('a', 'n')
AND pg_catalog.pg_table_is_visible(tbl.oid)
AND tbl.relname = %s
""", [table_name])
return [
{'name': row[0], 'table': table_name, 'column': row[1]}
for row in cursor.fetchall()
]
def get_relations(self, cursor, table_name):
"""
Return a dictionary of {field_name: (field_name_other_table, other_table)}
representing all relationships to the given table.
"""
return {row[0]: (row[2], row[1]) for row in self.get_key_columns(cursor, table_name)}
def get_key_columns(self, cursor, table_name):
cursor.execute("""
SELECT a1.attname, c2.relname, a2.attname
FROM pg_constraint con
LEFT JOIN pg_class c1 ON con.conrelid = c1.oid
LEFT JOIN pg_class c2 ON con.confrelid = c2.oid
LEFT JOIN pg_attribute a1 ON c1.oid = a1.attrelid AND a1.attnum = con.conkey[1]
LEFT JOIN pg_attribute a2 ON c2.oid = a2.attrelid AND a2.attnum = con.confkey[1]
WHERE
c1.relname = %s AND
con.contype = 'f' AND
c1.relnamespace = c2.relnamespace AND
pg_catalog.pg_table_is_visible(c1.oid)
""", [table_name])
return cursor.fetchall()
def get_constraints(self, cursor, table_name):
"""
Retrieve any constraints or keys (unique, pk, fk, check, index) across
one or more columns. Also retrieve the definition of expression-based
indexes.
"""
constraints = {}
# Loop over the key table, collecting things as constraints. The column
# array must return column names in the same order in which they were
# created.
cursor.execute("""
SELECT
c.conname,
array(
SELECT attname
FROM unnest(c.conkey) WITH ORDINALITY cols(colid, arridx)
JOIN pg_attribute AS ca ON cols.colid = ca.attnum
WHERE ca.attrelid = c.conrelid
ORDER BY cols.arridx
),
c.contype,
(SELECT fkc.relname || '.' || fka.attname
FROM pg_attribute AS fka
JOIN pg_class AS fkc ON fka.attrelid = fkc.oid
WHERE fka.attrelid = c.confrelid AND fka.attnum = c.confkey[1]),
cl.reloptions
FROM pg_constraint AS c
JOIN pg_class AS cl ON c.conrelid = cl.oid
WHERE cl.relname = %s AND pg_catalog.pg_table_is_visible(cl.oid)
""", [table_name])
for constraint, columns, kind, used_cols, options in cursor.fetchall():
constraints[constraint] = {
"columns": columns,
"primary_key": kind == "p",
"unique": kind in ["p", "u"],
"foreign_key": tuple(used_cols.split(".", 1)) if kind == "f" else None,
"check": kind == "c",
"index": False,
"definition": None,
"options": options,
}
# Now get indexes
cursor.execute("""
SELECT
indexname, array_agg(attname ORDER BY arridx), indisunique, indisprimary,
array_agg(ordering ORDER BY arridx), amname, exprdef, s2.attoptions
FROM (
SELECT
c2.relname as indexname, idx.*, attr.attname, am.amname,
CASE
WHEN idx.indexprs IS NOT NULL THEN
pg_get_indexdef(idx.indexrelid)
END AS exprdef,
CASE am.amname
WHEN 'btree' THEN
CASE (option & 1)
WHEN 1 THEN 'DESC' ELSE 'ASC'
END
END as ordering,
c2.reloptions as attoptions
FROM (
SELECT *
FROM pg_index i, unnest(i.indkey, i.indoption) WITH ORDINALITY koi(key, option, arridx)
) idx
LEFT JOIN pg_class c ON idx.indrelid = c.oid
LEFT JOIN pg_class c2 ON idx.indexrelid = c2.oid
LEFT JOIN pg_am am ON c2.relam = am.oid
LEFT JOIN pg_attribute attr ON attr.attrelid = c.oid AND attr.attnum = idx.key
WHERE c.relname = %s AND pg_catalog.pg_table_is_visible(c.oid)
) s2
GROUP BY indexname, indisunique, indisprimary, amname, exprdef, attoptions;
""", [table_name])
for index, columns, unique, primary, orders, type_, definition, options in cursor.fetchall():
if index not in constraints:
basic_index = type_ == 'btree' and not index.endswith('_btree') and options is None
constraints[index] = {
"columns": columns if columns != [None] else [],
"orders": orders if orders != [None] else [],
"primary_key": primary,
"unique": unique,
"foreign_key": None,
"check": False,
"index": True,
"type": Index.suffix if basic_index else type_,
"definition": definition,
"options": options,
}
return constraints

View File

@@ -0,0 +1,300 @@
from psycopg2.extras import Inet
from django.conf import settings
from django.db import NotSupportedError
from django.db.backends.base.operations import BaseDatabaseOperations
class DatabaseOperations(BaseDatabaseOperations):
cast_char_field_without_max_length = 'varchar'
explain_prefix = 'EXPLAIN'
cast_data_types = {
'AutoField': 'integer',
'BigAutoField': 'bigint',
'SmallAutoField': 'smallint',
}
def unification_cast_sql(self, output_field):
internal_type = output_field.get_internal_type()
if internal_type in ("GenericIPAddressField", "IPAddressField", "TimeField", "UUIDField"):
# PostgreSQL will resolve a union as type 'text' if input types are
# 'unknown'.
# https://www.postgresql.org/docs/current/typeconv-union-case.html
# These fields cannot be implicitly cast back in the default
# PostgreSQL configuration so we need to explicitly cast them.
# We must also remove components of the type within brackets:
# varchar(255) -> varchar.
return 'CAST(%%s AS %s)' % output_field.db_type(self.connection).split('(')[0]
return '%s'
def date_extract_sql(self, lookup_type, field_name):
# https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-EXTRACT
if lookup_type == 'week_day':
# For consistency across backends, we return Sunday=1, Saturday=7.
return "EXTRACT('dow' FROM %s) + 1" % field_name
elif lookup_type == 'iso_year':
return "EXTRACT('isoyear' FROM %s)" % field_name
else:
return "EXTRACT('%s' FROM %s)" % (lookup_type, field_name)
def date_trunc_sql(self, lookup_type, field_name):
# https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC
return "DATE_TRUNC('%s', %s)" % (lookup_type, field_name)
def _prepare_tzname_delta(self, tzname):
if '+' in tzname:
return tzname.replace('+', '-')
elif '-' in tzname:
return tzname.replace('-', '+')
return tzname
def _convert_field_to_tz(self, field_name, tzname):
if settings.USE_TZ:
field_name = "%s AT TIME ZONE '%s'" % (field_name, self._prepare_tzname_delta(tzname))
return field_name
def datetime_cast_date_sql(self, field_name, tzname):
field_name = self._convert_field_to_tz(field_name, tzname)
return '(%s)::date' % field_name
def datetime_cast_time_sql(self, field_name, tzname):
field_name = self._convert_field_to_tz(field_name, tzname)
return '(%s)::time' % field_name
def datetime_extract_sql(self, lookup_type, field_name, tzname):
field_name = self._convert_field_to_tz(field_name, tzname)
return self.date_extract_sql(lookup_type, field_name)
def datetime_trunc_sql(self, lookup_type, field_name, tzname):
field_name = self._convert_field_to_tz(field_name, tzname)
# https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC
return "DATE_TRUNC('%s', %s)" % (lookup_type, field_name)
def time_trunc_sql(self, lookup_type, field_name):
return "DATE_TRUNC('%s', %s)::time" % (lookup_type, field_name)
def deferrable_sql(self):
return " DEFERRABLE INITIALLY DEFERRED"
def fetch_returned_insert_rows(self, cursor):
"""
Given a cursor object that has just performed an INSERT...RETURNING
statement into a table, return the tuple of returned data.
"""
return cursor.fetchall()
def lookup_cast(self, lookup_type, internal_type=None):
lookup = '%s'
# Cast text lookups to text to allow things like filter(x__contains=4)
if lookup_type in ('iexact', 'contains', 'icontains', 'startswith',
'istartswith', 'endswith', 'iendswith', 'regex', 'iregex'):
if internal_type in ('IPAddressField', 'GenericIPAddressField'):
lookup = "HOST(%s)"
elif internal_type in ('CICharField', 'CIEmailField', 'CITextField'):
lookup = '%s::citext'
else:
lookup = "%s::text"
# Use UPPER(x) for case-insensitive lookups; it's faster.
if lookup_type in ('iexact', 'icontains', 'istartswith', 'iendswith'):
lookup = 'UPPER(%s)' % lookup
return lookup
def no_limit_value(self):
return None
def prepare_sql_script(self, sql):
return [sql]
def quote_name(self, name):
if name.startswith('"') and name.endswith('"'):
return name # Quoting once is enough.
return '"%s"' % name
def set_time_zone_sql(self):
return "SET TIME ZONE %s"
def sql_flush(self, style, tables, sequences, allow_cascade=False):
if tables:
# Perform a single SQL 'TRUNCATE x, y, z...;' statement. It allows
# us to truncate tables referenced by a foreign key in any other
# table.
tables_sql = ', '.join(
style.SQL_FIELD(self.quote_name(table)) for table in tables)
if allow_cascade:
sql = ['%s %s %s;' % (
style.SQL_KEYWORD('TRUNCATE'),
tables_sql,
style.SQL_KEYWORD('CASCADE'),
)]
else:
sql = ['%s %s;' % (
style.SQL_KEYWORD('TRUNCATE'),
tables_sql,
)]
sql.extend(self.sequence_reset_by_name_sql(style, sequences))
return sql
else:
return []
def sequence_reset_by_name_sql(self, style, sequences):
# 'ALTER SEQUENCE sequence_name RESTART WITH 1;'... style SQL statements
# to reset sequence indices
sql = []
for sequence_info in sequences:
table_name = sequence_info['table']
# 'id' will be the case if it's an m2m using an autogenerated
# intermediate table (see BaseDatabaseIntrospection.sequence_list).
column_name = sequence_info['column'] or 'id'
sql.append("%s setval(pg_get_serial_sequence('%s','%s'), 1, false);" % (
style.SQL_KEYWORD('SELECT'),
style.SQL_TABLE(self.quote_name(table_name)),
style.SQL_FIELD(column_name),
))
return sql
def tablespace_sql(self, tablespace, inline=False):
if inline:
return "USING INDEX TABLESPACE %s" % self.quote_name(tablespace)
else:
return "TABLESPACE %s" % self.quote_name(tablespace)
def sequence_reset_sql(self, style, model_list):
from django.db import models
output = []
qn = self.quote_name
for model in model_list:
# Use `coalesce` to set the sequence for each model to the max pk value if there are records,
# or 1 if there are none. Set the `is_called` property (the third argument to `setval`) to true
# if there are records (as the max pk value is already in use), otherwise set it to false.
# Use pg_get_serial_sequence to get the underlying sequence name from the table name
# and column name (available since PostgreSQL 8)
for f in model._meta.local_fields:
if isinstance(f, models.AutoField):
output.append(
"%s setval(pg_get_serial_sequence('%s','%s'), "
"coalesce(max(%s), 1), max(%s) %s null) %s %s;" % (
style.SQL_KEYWORD('SELECT'),
style.SQL_TABLE(qn(model._meta.db_table)),
style.SQL_FIELD(f.column),
style.SQL_FIELD(qn(f.column)),
style.SQL_FIELD(qn(f.column)),
style.SQL_KEYWORD('IS NOT'),
style.SQL_KEYWORD('FROM'),
style.SQL_TABLE(qn(model._meta.db_table)),
)
)
break # Only one AutoField is allowed per model, so don't bother continuing.
for f in model._meta.many_to_many:
if not f.remote_field.through:
output.append(
"%s setval(pg_get_serial_sequence('%s','%s'), "
"coalesce(max(%s), 1), max(%s) %s null) %s %s;" % (
style.SQL_KEYWORD('SELECT'),
style.SQL_TABLE(qn(f.m2m_db_table())),
style.SQL_FIELD('id'),
style.SQL_FIELD(qn('id')),
style.SQL_FIELD(qn('id')),
style.SQL_KEYWORD('IS NOT'),
style.SQL_KEYWORD('FROM'),
style.SQL_TABLE(qn(f.m2m_db_table()))
)
)
return output
def prep_for_iexact_query(self, x):
return x
def max_name_length(self):
"""
Return the maximum length of an identifier.
The maximum length of an identifier is 63 by default, but can be
changed by recompiling PostgreSQL after editing the NAMEDATALEN
macro in src/include/pg_config_manual.h.
This implementation returns 63, but can be overridden by a custom
database backend that inherits most of its behavior from this one.
"""
return 63
def distinct_sql(self, fields, params):
if fields:
params = [param for param_list in params for param in param_list]
return (['DISTINCT ON (%s)' % ', '.join(fields)], params)
else:
return ['DISTINCT'], []
def last_executed_query(self, cursor, sql, params):
# https://www.psycopg.org/docs/cursor.html#cursor.query
# The query attribute is a Psycopg extension to the DB API 2.0.
if cursor.query is not None:
return cursor.query.decode()
return None
def return_insert_columns(self, fields):
if not fields:
return '', ()
columns = [
'%s.%s' % (
self.quote_name(field.model._meta.db_table),
self.quote_name(field.column),
) for field in fields
]
return 'RETURNING %s' % ', '.join(columns), ()
def bulk_insert_sql(self, fields, placeholder_rows):
placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)
values_sql = ", ".join("(%s)" % sql for sql in placeholder_rows_sql)
return "VALUES " + values_sql
def adapt_datefield_value(self, value):
return value
def adapt_datetimefield_value(self, value):
return value
def adapt_timefield_value(self, value):
return value
def adapt_ipaddressfield_value(self, value):
if value:
return Inet(value)
return None
def subtract_temporals(self, internal_type, lhs, rhs):
if internal_type == 'DateField':
lhs_sql, lhs_params = lhs
rhs_sql, rhs_params = rhs
params = (*lhs_params, *rhs_params)
return "(interval '1 day' * (%s - %s))" % (lhs_sql, rhs_sql), params
return super().subtract_temporals(internal_type, lhs, rhs)
def window_frame_range_start_end(self, start=None, end=None):
start_, end_ = super().window_frame_range_start_end(start, end)
if (start and start < 0) or (end and end > 0):
raise NotSupportedError(
'PostgreSQL only supports UNBOUNDED together with PRECEDING '
'and FOLLOWING.'
)
return start_, end_
def explain_query_prefix(self, format=None, **options):
prefix = super().explain_query_prefix(format)
extra = {}
if format:
extra['FORMAT'] = format
if options:
extra.update({
name.upper(): 'true' if value else 'false'
for name, value in options.items()
})
if extra:
prefix += ' (%s)' % ', '.join('%s %s' % i for i in extra.items())
return prefix
def ignore_conflicts_suffix_sql(self, ignore_conflicts=None):
return 'ON CONFLICT DO NOTHING' if ignore_conflicts else super().ignore_conflicts_suffix_sql(ignore_conflicts)

View File

@@ -0,0 +1,192 @@
import psycopg2
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
from django.db.backends.ddl_references import IndexColumns
from django.db.backends.utils import strip_quotes
class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
sql_create_sequence = "CREATE SEQUENCE %(sequence)s"
sql_delete_sequence = "DROP SEQUENCE IF EXISTS %(sequence)s CASCADE"
sql_set_sequence_max = "SELECT setval('%(sequence)s', MAX(%(column)s)) FROM %(table)s"
sql_set_sequence_owner = 'ALTER SEQUENCE %(sequence)s OWNED BY %(table)s.%(column)s'
sql_create_index = "CREATE INDEX %(name)s ON %(table)s%(using)s (%(columns)s)%(extra)s%(condition)s"
sql_create_index_concurrently = (
"CREATE INDEX CONCURRENTLY %(name)s ON %(table)s%(using)s (%(columns)s)%(extra)s%(condition)s"
)
sql_delete_index = "DROP INDEX IF EXISTS %(name)s"
sql_delete_index_concurrently = "DROP INDEX CONCURRENTLY IF EXISTS %(name)s"
# Setting the constraint to IMMEDIATE to allow changing data in the same
# transaction.
sql_create_column_inline_fk = (
'CONSTRAINT %(name)s REFERENCES %(to_table)s(%(to_column)s)%(deferrable)s'
'; SET CONSTRAINTS %(name)s IMMEDIATE'
)
# Setting the constraint to IMMEDIATE runs any deferred checks to allow
# dropping it in the same transaction.
sql_delete_fk = "SET CONSTRAINTS %(name)s IMMEDIATE; ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"
sql_delete_procedure = 'DROP FUNCTION %(procedure)s(%(param_types)s)'
def quote_value(self, value):
if isinstance(value, str):
value = value.replace('%', '%%')
# getquoted() returns a quoted bytestring of the adapted value.
return psycopg2.extensions.adapt(value).getquoted().decode()
def _field_indexes_sql(self, model, field):
output = super()._field_indexes_sql(model, field)
like_index_statement = self._create_like_index_sql(model, field)
if like_index_statement is not None:
output.append(like_index_statement)
return output
def _field_data_type(self, field):
if field.is_relation:
return field.rel_db_type(self.connection)
return self.connection.data_types.get(
field.get_internal_type(),
field.db_type(self.connection),
)
def _create_like_index_sql(self, model, field):
"""
Return the statement to create an index with varchar operator pattern
when the column type is 'varchar' or 'text', otherwise return None.
"""
db_type = field.db_type(connection=self.connection)
if db_type is not None and (field.db_index or field.unique):
# Fields with database column types of `varchar` and `text` need
# a second index that specifies their operator class, which is
# needed when performing correct LIKE queries outside the
# C locale. See #12234.
#
# The same doesn't apply to array fields such as varchar[size]
# and text[size], so skip them.
if '[' in db_type:
return None
if db_type.startswith('varchar'):
return self._create_index_sql(model, [field], suffix='_like', opclasses=['varchar_pattern_ops'])
elif db_type.startswith('text'):
return self._create_index_sql(model, [field], suffix='_like', opclasses=['text_pattern_ops'])
return None
def _alter_column_type_sql(self, model, old_field, new_field, new_type):
self.sql_alter_column_type = 'ALTER COLUMN %(column)s TYPE %(type)s'
# Cast when data type changed.
if self._field_data_type(old_field) != self._field_data_type(new_field):
self.sql_alter_column_type += ' USING %(column)s::%(type)s'
# Make ALTER TYPE with SERIAL make sense.
table = strip_quotes(model._meta.db_table)
serial_fields_map = {'bigserial': 'bigint', 'serial': 'integer', 'smallserial': 'smallint'}
if new_type.lower() in serial_fields_map:
column = strip_quotes(new_field.column)
sequence_name = "%s_%s_seq" % (table, column)
return (
(
self.sql_alter_column_type % {
"column": self.quote_name(column),
"type": serial_fields_map[new_type.lower()],
},
[],
),
[
(
self.sql_delete_sequence % {
"sequence": self.quote_name(sequence_name),
},
[],
),
(
self.sql_create_sequence % {
"sequence": self.quote_name(sequence_name),
},
[],
),
(
self.sql_alter_column % {
"table": self.quote_name(table),
"changes": self.sql_alter_column_default % {
"column": self.quote_name(column),
"default": "nextval('%s')" % self.quote_name(sequence_name),
}
},
[],
),
(
self.sql_set_sequence_max % {
"table": self.quote_name(table),
"column": self.quote_name(column),
"sequence": self.quote_name(sequence_name),
},
[],
),
(
self.sql_set_sequence_owner % {
'table': self.quote_name(table),
'column': self.quote_name(column),
'sequence': self.quote_name(sequence_name),
},
[],
),
],
)
else:
return super()._alter_column_type_sql(model, old_field, new_field, new_type)
def _alter_field(self, model, old_field, new_field, old_type, new_type,
old_db_params, new_db_params, strict=False):
# Drop indexes on varchar/text/citext columns that are changing to a
# different type.
if (old_field.db_index or old_field.unique) and (
(old_type.startswith('varchar') and not new_type.startswith('varchar')) or
(old_type.startswith('text') and not new_type.startswith('text')) or
(old_type.startswith('citext') and not new_type.startswith('citext'))
):
index_name = self._create_index_name(model._meta.db_table, [old_field.column], suffix='_like')
self.execute(self._delete_index_sql(model, index_name))
super()._alter_field(
model, old_field, new_field, old_type, new_type, old_db_params,
new_db_params, strict,
)
# Added an index? Create any PostgreSQL-specific indexes.
if ((not (old_field.db_index or old_field.unique) and new_field.db_index) or
(not old_field.unique and new_field.unique)):
like_index_statement = self._create_like_index_sql(model, new_field)
if like_index_statement is not None:
self.execute(like_index_statement)
# Removed an index? Drop any PostgreSQL-specific indexes.
if old_field.unique and not (new_field.db_index or new_field.unique):
index_to_remove = self._create_index_name(model._meta.db_table, [old_field.column], suffix='_like')
self.execute(self._delete_index_sql(model, index_to_remove))
def _index_columns(self, table, columns, col_suffixes, opclasses):
if opclasses:
return IndexColumns(table, columns, self.quote_name, col_suffixes=col_suffixes, opclasses=opclasses)
return super()._index_columns(table, columns, col_suffixes, opclasses)
def add_index(self, model, index, concurrently=False):
self.execute(index.create_sql(model, self, concurrently=concurrently), params=None)
def remove_index(self, model, index, concurrently=False):
self.execute(index.remove_sql(model, self, concurrently=concurrently))
def _delete_index_sql(self, model, name, sql=None, concurrently=False):
sql = self.sql_delete_index_concurrently if concurrently else self.sql_delete_index
return super()._delete_index_sql(model, name, sql)
def _create_index_sql(
self, model, fields, *, name=None, suffix='', using='',
db_tablespace=None, col_suffixes=(), sql=None, opclasses=(),
condition=None, concurrently=False,
):
sql = self.sql_create_index if not concurrently else self.sql_create_index_concurrently
return super()._create_index_sql(
model, fields, name=name, suffix=suffix, using=using, db_tablespace=db_tablespace,
col_suffixes=col_suffixes, sql=sql, opclasses=opclasses, condition=condition,
)

View File

@@ -0,0 +1,7 @@
from django.utils.timezone import utc
def utc_tzinfo_factory(offset):
if offset != 0:
raise AssertionError("database connection isn't set to UTC")
return utc

View File

@@ -0,0 +1,3 @@
from django.dispatch import Signal
connection_created = Signal(providing_args=["connection"])

Some files were not shown because too many files have changed in this diff Show More