Initial commit. Basic models mostly done.
This commit is contained in:
@@ -0,0 +1,5 @@
|
||||
from .array import * # NOQA
|
||||
from .citext import * # NOQA
|
||||
from .hstore import * # NOQA
|
||||
from .jsonb import * # NOQA
|
||||
from .ranges import * # NOQA
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,300 @@
|
||||
import json
|
||||
|
||||
from django.contrib.postgres import lookups
|
||||
from django.contrib.postgres.forms import SimpleArrayField
|
||||
from django.contrib.postgres.validators import ArrayMaxLengthValidator
|
||||
from django.core import checks, exceptions
|
||||
from django.db.models import Field, IntegerField, Transform
|
||||
from django.db.models.lookups import Exact, In
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from ..utils import prefix_validation_error
|
||||
from .mixins import CheckFieldDefaultMixin
|
||||
from .utils import AttributeSetter
|
||||
|
||||
__all__ = ['ArrayField']
|
||||
|
||||
|
||||
class ArrayField(CheckFieldDefaultMixin, Field):
    """PostgreSQL array field wrapping an arbitrary non-relational base field.

    Values are Python lists whose items are converted, validated, and
    serialized by ``base_field``. Stored as ``<base type>[]`` (or
    ``<base type>[size]`` when ``size`` is given).
    """
    # Empty lists are meaningful array values, so '' must not stand in for them.
    empty_strings_allowed = False
    default_error_messages = {
        'item_invalid': _('Item %(nth)s in the array did not validate:'),
        'nested_array_mismatch': _('Nested arrays must have the same length.'),
    }
    # Used by CheckFieldDefaultMixin to suggest `list` over a shared [] default.
    _default_hint = ('list', '[]')

    def __init__(self, base_field, size=None, **kwargs):
        """``base_field`` is a field *instance*; ``size`` optionally caps length."""
        self.base_field = base_field
        self.size = size
        if self.size:
            self.default_validators = [*self.default_validators, ArrayMaxLengthValidator(self.size)]
        # For performance, only add a from_db_value() method if the base field
        # implements it.
        if hasattr(self.base_field, 'from_db_value'):
            self.from_db_value = self._from_db_value
        super().__init__(**kwargs)

    @property
    def model(self):
        try:
            return self.__dict__['model']
        except KeyError:
            raise AttributeError("'%s' object has no attribute 'model'" % self.__class__.__name__)

    @model.setter
    def model(self, model):
        # Keep the base field's model in sync with the array field's.
        self.__dict__['model'] = model
        self.base_field.model = model

    def check(self, **kwargs):
        """Run system checks; also surface base-field errors (postgres.E001/E002)."""
        errors = super().check(**kwargs)
        if self.base_field.remote_field:
            errors.append(
                checks.Error(
                    'Base field for array cannot be a related field.',
                    obj=self,
                    id='postgres.E002'
                )
            )
        else:
            # Remove the field name checks as they are not needed here.
            base_errors = self.base_field.check()
            if base_errors:
                messages = '\n '.join('%s (%s)' % (error.msg, error.id) for error in base_errors)
                errors.append(
                    checks.Error(
                        'Base field for array has errors:\n %s' % messages,
                        obj=self,
                        id='postgres.E001'
                    )
                )
        return errors

    def set_attributes_from_name(self, name):
        super().set_attributes_from_name(name)
        # The base field shares the array field's name/attname.
        self.base_field.set_attributes_from_name(name)

    @property
    def description(self):
        return 'Array of %s' % self.base_field.description

    def db_type(self, connection):
        # e.g. 'integer[3]' when size is set, otherwise 'integer[]'.
        size = self.size or ''
        return '%s[%s]' % (self.base_field.db_type(connection), size)

    def cast_db_type(self, connection):
        size = self.size or ''
        return '%s[%s]' % (self.base_field.cast_db_type(connection), size)

    def get_placeholder(self, value, compiler, connection):
        # Explicit cast so the server resolves the intended array type.
        return '%s::{}'.format(self.db_type(connection))

    def get_db_prep_value(self, value, connection, prepared=False):
        # Prepare each item with the base field; non-sequence values pass through.
        if isinstance(value, (list, tuple)):
            return [self.base_field.get_db_prep_value(i, connection, prepared=False) for i in value]
        return value

    def deconstruct(self):
        name, path, args, kwargs = super().deconstruct()
        # Normalize the import path written into migrations.
        if path == 'django.contrib.postgres.fields.array.ArrayField':
            path = 'django.contrib.postgres.fields.ArrayField'
        kwargs.update({
            'base_field': self.base_field.clone(),
            'size': self.size,
        })
        return name, path, args, kwargs

    def to_python(self, value):
        if isinstance(value, str):
            # Assume we're deserializing
            vals = json.loads(value)
            value = [self.base_field.to_python(val) for val in vals]
        return value

    def _from_db_value(self, value, expression, connection):
        # Installed as from_db_value in __init__ only when the base field has one.
        if value is None:
            return value
        return [
            self.base_field.from_db_value(item, expression, connection)
            for item in value
        ]

    def value_to_string(self, obj):
        """Serialize to a JSON list; items are serialized by the base field."""
        values = []
        vals = self.value_from_object(obj)
        base_field = self.base_field

        for val in vals:
            if val is None:
                values.append(None)
            else:
                # base_field.value_to_string() reads the value off an object
                # by attname, so wrap each item in a throwaway holder.
                obj = AttributeSetter(base_field.attname, val)
                values.append(base_field.value_to_string(obj))
        return json.dumps(values)

    def get_transform(self, name):
        """Support index ('field__0') and slice ('field__0_2') transforms."""
        transform = super().get_transform(name)
        if transform:
            return transform
        if '_' not in name:
            try:
                index = int(name)
            except ValueError:
                pass
            else:
                index += 1 # postgres uses 1-indexing
                return IndexTransformFactory(index, self.base_field)
        try:
            start, end = name.split('_')
            start = int(start) + 1
            end = int(end) # don't add one here because postgres slices are weird
        except ValueError:
            pass
        else:
            return SliceTransformFactory(start, end)

    def validate(self, value, model_instance):
        """Validate each item, prefixing errors with the 1-based position."""
        super().validate(value, model_instance)
        for index, part in enumerate(value):
            try:
                self.base_field.validate(part, model_instance)
            except exceptions.ValidationError as error:
                raise prefix_validation_error(
                    error,
                    prefix=self.error_messages['item_invalid'],
                    code='item_invalid',
                    params={'nth': index + 1},
                )
        if isinstance(self.base_field, ArrayField):
            # Nested arrays must be rectangular to map onto a postgres array.
            if len({len(i) for i in value}) > 1:
                raise exceptions.ValidationError(
                    self.error_messages['nested_array_mismatch'],
                    code='nested_array_mismatch',
                )

    def run_validators(self, value):
        """Run the base field's validators on each item, like validate()."""
        super().run_validators(value)
        for index, part in enumerate(value):
            try:
                self.base_field.run_validators(part)
            except exceptions.ValidationError as error:
                raise prefix_validation_error(
                    error,
                    prefix=self.error_messages['item_invalid'],
                    code='item_invalid',
                    params={'nth': index + 1},
                )

    def formfield(self, **kwargs):
        return super().formfield(**{
            'form_class': SimpleArrayField,
            'base_field': self.base_field.formfield(),
            'max_length': self.size,
            **kwargs,
        })
|
||||
|
||||
|
||||
class ArrayCastRHSMixin:
    """Append an explicit array-type cast to the RHS SQL of a lookup."""

    def process_rhs(self, compiler, connection):
        sql, params = super().process_rhs(compiler, connection)
        # Cast to the lhs array's db type so postgres compares like with like.
        db_type = self.lhs.output_field.cast_db_type(connection)
        return '{}::{}'.format(sql, db_type), params
|
||||
|
||||
|
||||
@ArrayField.register_lookup
class ArrayContains(ArrayCastRHSMixin, lookups.DataContains):
    """Array containment lookup with the RHS cast to the array's type."""
    pass
|
||||
|
||||
|
||||
@ArrayField.register_lookup
class ArrayContainedBy(ArrayCastRHSMixin, lookups.ContainedBy):
    """Reverse containment lookup with the RHS cast to the array's type."""
    pass
|
||||
|
||||
|
||||
@ArrayField.register_lookup
class ArrayExact(ArrayCastRHSMixin, Exact):
    """Exact match lookup with the RHS cast to the array's type."""
    pass
|
||||
|
||||
|
||||
@ArrayField.register_lookup
class ArrayOverlap(ArrayCastRHSMixin, lookups.Overlap):
    """Overlap lookup with the RHS cast to the array's type."""
    pass
|
||||
|
||||
|
||||
@ArrayField.register_lookup
class ArrayLenTransform(Transform):
    """``__len`` transform: the length of the array as an integer."""
    lookup_name = 'len'
    output_field = IntegerField()

    def as_sql(self, compiler, connection):
        lhs, params = compiler.compile(self.lhs)
        # Distinguish NULL and empty arrays
        # (array_length returns NULL for empty arrays; coalesce maps that to 0,
        # while the CASE keeps a NULL column as NULL).
        return (
            'CASE WHEN %(lhs)s IS NULL THEN NULL ELSE '
            'coalesce(array_length(%(lhs)s, 1), 0) END'
        ) % {'lhs': lhs}, params
|
||||
|
||||
|
||||
@ArrayField.register_lookup
class ArrayInLookup(In):
    """``__in`` lookup for arrays; list values are made hashable as tuples."""

    def get_prep_lookup(self):
        values = super().get_prep_lookup()
        if hasattr(values, 'resolve_expression'):
            return values
        # In.process_rhs() expects values to be hashable, so convert lists
        # to tuples.
        return [
            value if hasattr(value, 'resolve_expression') else tuple(value)
            for value in values
        ]
|
||||
|
||||
|
||||
class IndexTransform(Transform):
    """Transform selecting a single array item by (1-based) postgres index."""

    def __init__(self, index, base_field, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.index = index
        self.base_field = base_field

    def as_sql(self, compiler, connection):
        lhs, params = compiler.compile(self.lhs)
        # Index is passed as a query parameter, not interpolated into SQL.
        return '%s[%%s]' % lhs, params + [self.index]

    @property
    def output_field(self):
        # A single item has the base field's type.
        return self.base_field
|
||||
|
||||
|
||||
class IndexTransformFactory:
    """Callable that builds IndexTransforms bound to a fixed index and base field."""

    def __init__(self, index, base_field):
        self.index = index
        self.base_field = base_field

    def __call__(self, *args, **kwargs):
        transform = IndexTransform(self.index, self.base_field, *args, **kwargs)
        return transform
|
||||
|
||||
|
||||
class SliceTransform(Transform):
    """Transform selecting an array slice with postgres ``[start:end]`` syntax."""

    def __init__(self, start, end, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.start = start
        self.end = end

    def as_sql(self, compiler, connection):
        lhs, params = compiler.compile(self.lhs)
        # Bounds are passed as query parameters, not interpolated into SQL.
        return '%s[%%s:%%s]' % lhs, params + [self.start, self.end]
|
||||
|
||||
|
||||
class SliceTransformFactory:
    """Callable that builds SliceTransforms bound to fixed start/end bounds."""

    def __init__(self, start, end):
        self.start = start
        self.end = end

    def __call__(self, *args, **kwargs):
        transform = SliceTransform(self.start, self.end, *args, **kwargs)
        return transform
|
||||
@@ -0,0 +1,24 @@
|
||||
from django.db.models import CharField, EmailField, TextField
|
||||
|
||||
__all__ = ['CICharField', 'CIEmailField', 'CIText', 'CITextField']
|
||||
|
||||
|
||||
class CIText:
    """Mixin making a text-based field use PostgreSQL's citext type."""

    def get_internal_type(self):
        # Prefix the concrete field's internal type, e.g. 'CICharField'.
        return 'CI%s' % super().get_internal_type()

    def db_type(self, connection):
        return 'citext'
|
||||
|
||||
|
||||
class CICharField(CIText, CharField):
    """CharField stored as case-insensitive citext."""
    pass
|
||||
|
||||
|
||||
class CIEmailField(CIText, EmailField):
    """EmailField stored as case-insensitive citext."""
    pass
|
||||
|
||||
|
||||
class CITextField(CIText, TextField):
    """TextField stored as case-insensitive citext."""
    pass
|
||||
@@ -0,0 +1,112 @@
|
||||
import json
|
||||
|
||||
from django.contrib.postgres import forms, lookups
|
||||
from django.contrib.postgres.fields.array import ArrayField
|
||||
from django.core import exceptions
|
||||
from django.db.models import Field, TextField, Transform
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from .mixins import CheckFieldDefaultMixin
|
||||
|
||||
__all__ = ['HStoreField']
|
||||
|
||||
|
||||
class HStoreField(CheckFieldDefaultMixin, Field):
    """Field for PostgreSQL's hstore type: a map of strings to strings/nulls."""
    empty_strings_allowed = False
    description = _('Map of strings to strings/nulls')
    default_error_messages = {
        'not_a_string': _('The value of “%(key)s” is not a string or null.'),
    }
    # Used by CheckFieldDefaultMixin to suggest `dict` over a shared {} default.
    _default_hint = ('dict', '{}')

    def db_type(self, connection):
        return 'hstore'

    def get_transform(self, name):
        """Fall back to a key lookup transform for any unrecognized name."""
        transform = super().get_transform(name)
        if transform:
            return transform
        return KeyTransformFactory(name)

    def validate(self, value, model_instance):
        """Ensure every value in the mapping is a string or None."""
        super().validate(value, model_instance)
        for key, val in value.items():
            if not isinstance(val, str) and val is not None:
                raise exceptions.ValidationError(
                    self.error_messages['not_a_string'],
                    code='not_a_string',
                    params={'key': key},
                )

    def to_python(self, value):
        if isinstance(value, str):
            # Deserialize the JSON produced by value_to_string().
            value = json.loads(value)
        return value

    def value_to_string(self, obj):
        return json.dumps(self.value_from_object(obj))

    def formfield(self, **kwargs):
        return super().formfield(**{
            'form_class': forms.HStoreField,
            **kwargs,
        })

    def get_prep_value(self, value):
        value = super().get_prep_value(value)

        if isinstance(value, dict):
            # hstore stores text only: stringify keys and non-null values.
            prep_value = {}
            for key, val in value.items():
                key = str(key)
                if val is not None:
                    val = str(val)
                prep_value[key] = val
            value = prep_value

        if isinstance(value, list):
            # NOTE(review): list values presumably come from key-based lookups
            # (e.g. has_keys) — stringify the items; confirm against callers.
            value = [str(item) for item in value]

        return value
|
||||
|
||||
|
||||
# Share the generic PostgreSQL containment/key lookups with HStoreField.
HStoreField.register_lookup(lookups.DataContains)
HStoreField.register_lookup(lookups.ContainedBy)
HStoreField.register_lookup(lookups.HasKey)
HStoreField.register_lookup(lookups.HasKeys)
HStoreField.register_lookup(lookups.HasAnyKeys)
|
||||
|
||||
|
||||
class KeyTransform(Transform):
    """Transform selecting a single hstore value as text: (col -> 'key')."""
    output_field = TextField()

    def __init__(self, key_name, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.key_name = key_name

    def as_sql(self, compiler, connection):
        lhs, params = compiler.compile(self.lhs)
        # The key is passed as a query parameter, not interpolated into SQL.
        return '(%s -> %%s)' % lhs, tuple(params) + (self.key_name,)
|
||||
|
||||
|
||||
class KeyTransformFactory:
    """Callable that builds KeyTransforms bound to a fixed hstore key name."""

    def __init__(self, key_name):
        self.key_name = key_name

    def __call__(self, *args, **kwargs):
        transform = KeyTransform(self.key_name, *args, **kwargs)
        return transform
|
||||
|
||||
|
||||
@HStoreField.register_lookup
class KeysTransform(Transform):
    """``__keys`` transform: akeys(col) returning a text array."""
    lookup_name = 'keys'
    function = 'akeys'
    output_field = ArrayField(TextField())
|
||||
|
||||
|
||||
@HStoreField.register_lookup
class ValuesTransform(Transform):
    """``__values`` transform: avals(col) returning a text array."""
    lookup_name = 'values'
    function = 'avals'
    output_field = ArrayField(TextField())
|
||||
@@ -0,0 +1,186 @@
|
||||
import json
|
||||
|
||||
from psycopg2.extras import Json
|
||||
|
||||
from django.contrib.postgres import forms, lookups
|
||||
from django.core import exceptions
|
||||
from django.db.models import (
|
||||
Field, TextField, Transform, lookups as builtin_lookups,
|
||||
)
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from .mixins import CheckFieldDefaultMixin
|
||||
|
||||
__all__ = ['JSONField']
|
||||
|
||||
|
||||
class JsonAdapter(Json):
    """
    Customized psycopg2.extras.Json to allow for a custom encoder.
    """
    def __init__(self, adapted, dumps=None, encoder=None):
        self.encoder = encoder
        super().__init__(adapted, dumps=dumps)

    def dumps(self, obj):
        # Serialize with the configured encoder class when one was given.
        if self.encoder:
            return json.dumps(obj, cls=self.encoder)
        return json.dumps(obj)
|
||||
|
||||
|
||||
class JSONField(CheckFieldDefaultMixin, Field):
    """Field for PostgreSQL's jsonb type, storing any JSON-serializable value."""
    empty_strings_allowed = False
    description = _('A JSON object')
    default_error_messages = {
        'invalid': _("Value must be valid JSON."),
    }
    # Used by CheckFieldDefaultMixin to suggest `dict` over a shared {} default.
    _default_hint = ('dict', '{}')

    def __init__(self, verbose_name=None, name=None, encoder=None, **kwargs):
        """``encoder`` must be callable; it is used as the ``cls`` for json.dumps."""
        if encoder and not callable(encoder):
            raise ValueError("The encoder parameter must be a callable object.")
        self.encoder = encoder
        super().__init__(verbose_name, name, **kwargs)

    def db_type(self, connection):
        return 'jsonb'

    def deconstruct(self):
        name, path, args, kwargs = super().deconstruct()
        # Only write the encoder kwarg into migrations when one was given.
        if self.encoder is not None:
            kwargs['encoder'] = self.encoder
        return name, path, args, kwargs

    def get_transform(self, name):
        """Fall back to a key/index lookup transform for any unrecognized name."""
        transform = super().get_transform(name)
        if transform:
            return transform
        return KeyTransformFactory(name)

    def get_prep_value(self, value):
        # Wrap in JsonAdapter so psycopg2 serializes with the custom encoder.
        if value is not None:
            return JsonAdapter(value, encoder=self.encoder)
        return value

    def validate(self, value, model_instance):
        """Reject values that cannot be serialized with the configured encoder."""
        super().validate(value, model_instance)
        options = {'cls': self.encoder} if self.encoder else {}
        try:
            json.dumps(value, **options)
        except TypeError:
            raise exceptions.ValidationError(
                self.error_messages['invalid'],
                code='invalid',
                params={'value': value},
            )

    def value_to_string(self, obj):
        return self.value_from_object(obj)

    def formfield(self, **kwargs):
        return super().formfield(**{
            'form_class': forms.JSONField,
            **kwargs,
        })
|
||||
|
||||
|
||||
# Reuse the shared PostgreSQL lookups, plus a jsonb-aware exact match.
JSONField.register_lookup(lookups.DataContains)
JSONField.register_lookup(lookups.ContainedBy)
JSONField.register_lookup(lookups.HasKey)
JSONField.register_lookup(lookups.HasKeys)
JSONField.register_lookup(lookups.HasAnyKeys)
JSONField.register_lookup(lookups.JSONExact)
|
||||
|
||||
|
||||
class KeyTransform(Transform):
    """Transform selecting a JSON key or array index: (col -> 'key').

    Chained key transforms collapse into a single path lookup using the
    nested operator: (col #> path), where the path is passed as one
    list parameter.
    """
    operator = '->'
    nested_operator = '#>'

    def __init__(self, key_name, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.key_name = key_name

    def as_sql(self, compiler, connection):
        # Collect the full key path by unwinding chained KeyTransforms.
        key_transforms = [self.key_name]
        previous = self.lhs
        while isinstance(previous, KeyTransform):
            key_transforms.insert(0, previous.key_name)
            previous = previous.lhs
        lhs, params = compiler.compile(previous)
        if len(key_transforms) > 1:
            # Normalize params with tuple() so this works whether compile()
            # returned a list or a tuple (the single-key branch below already
            # did this; `params + [key_transforms]` would raise TypeError for
            # tuple params). The key path stays a single list parameter.
            return '(%s %s %%s)' % (lhs, self.nested_operator), tuple(params) + (key_transforms,)
        try:
            # Integer keys index into JSON arrays rather than objects.
            lookup = int(self.key_name)
        except ValueError:
            lookup = self.key_name
        return '(%s %s %%s)' % (lhs, self.operator), tuple(params) + (lookup,)
|
||||
|
||||
|
||||
class KeyTextTransform(KeyTransform):
    """Key lookup returning text (->> / #>>) instead of jsonb."""
    operator = '->>'
    nested_operator = '#>>'
    output_field = TextField()
|
||||
|
||||
|
||||
class KeyTransformTextLookupMixin:
    """
    Mixin for combining with a lookup expecting a text lhs from a JSONField
    key lookup. Make use of the ->> operator instead of casting key values to
    text and performing the lookup on the resulting representation.
    """
    def __init__(self, key_transform, *args, **kwargs):
        assert isinstance(key_transform, KeyTransform)
        # Rebuild the key transform as its text-returning (->>) counterpart.
        key_text_transform = KeyTextTransform(
            key_transform.key_name, *key_transform.source_expressions, **key_transform.extra
        )
        super().__init__(key_text_transform, *args, **kwargs)
|
||||
|
||||
|
||||
class KeyTransformIExact(KeyTransformTextLookupMixin, builtin_lookups.IExact):
    """Case-insensitive exact match on the text value of a JSON key."""
    pass
|
||||
|
||||
|
||||
class KeyTransformIContains(KeyTransformTextLookupMixin, builtin_lookups.IContains):
    """Case-insensitive containment on the text value of a JSON key."""
    pass
|
||||
|
||||
|
||||
class KeyTransformStartsWith(KeyTransformTextLookupMixin, builtin_lookups.StartsWith):
    """Prefix match on the text value of a JSON key."""
    pass
|
||||
|
||||
|
||||
class KeyTransformIStartsWith(KeyTransformTextLookupMixin, builtin_lookups.IStartsWith):
    """Case-insensitive prefix match on the text value of a JSON key."""
    pass
|
||||
|
||||
|
||||
class KeyTransformEndsWith(KeyTransformTextLookupMixin, builtin_lookups.EndsWith):
    """Suffix match on the text value of a JSON key."""
    pass
|
||||
|
||||
|
||||
class KeyTransformIEndsWith(KeyTransformTextLookupMixin, builtin_lookups.IEndsWith):
    """Case-insensitive suffix match on the text value of a JSON key."""
    pass
|
||||
|
||||
|
||||
class KeyTransformRegex(KeyTransformTextLookupMixin, builtin_lookups.Regex):
    """Regex match on the text value of a JSON key."""
    pass
|
||||
|
||||
|
||||
class KeyTransformIRegex(KeyTransformTextLookupMixin, builtin_lookups.IRegex):
    """Case-insensitive regex match on the text value of a JSON key."""
    pass
|
||||
|
||||
|
||||
# Text-based lookups on JSON keys use the text (->>) variant of the transform.
KeyTransform.register_lookup(KeyTransformIExact)
KeyTransform.register_lookup(KeyTransformIContains)
KeyTransform.register_lookup(KeyTransformStartsWith)
KeyTransform.register_lookup(KeyTransformIStartsWith)
KeyTransform.register_lookup(KeyTransformEndsWith)
KeyTransform.register_lookup(KeyTransformIEndsWith)
KeyTransform.register_lookup(KeyTransformRegex)
KeyTransform.register_lookup(KeyTransformIRegex)
|
||||
|
||||
|
||||
class KeyTransformFactory:
    """Callable that builds KeyTransforms bound to a fixed JSON key name."""

    def __init__(self, key_name):
        self.key_name = key_name

    def __call__(self, *args, **kwargs):
        transform = KeyTransform(self.key_name, *args, **kwargs)
        return transform
|
||||
@@ -0,0 +1,29 @@
|
||||
from django.core import checks
|
||||
|
||||
|
||||
class CheckFieldDefaultMixin:
    """Warn (postgres.E003) when a shared mutable instance is used as a default."""
    # (suggested callable, discouraged literal) pair rendered into the hint.
    _default_hint = ('<valid default>', '<invalid default>')

    def _check_default(self):
        # A callable default produces a fresh value per instance, so only a
        # concrete, non-None instance default is worth warning about.
        default_is_shared_instance = (
            self.has_default() and
            self.default is not None and
            not callable(self.default)
        )
        if not default_is_shared_instance:
            return []
        return [
            checks.Warning(
                "%s default should be a callable instead of an instance so "
                "that it's not shared between all field instances." % (
                    self.__class__.__name__,
                ),
                hint=(
                    'Use a callable instead, e.g., use `%s` instead of '
                    '`%s`.' % self._default_hint
                ),
                obj=self,
                id='postgres.E003',
            )
        ]

    def check(self, **kwargs):
        errors = super().check(**kwargs)
        errors.extend(self._check_default())
        return errors
|
||||
@@ -0,0 +1,299 @@
|
||||
import datetime
|
||||
import json
|
||||
|
||||
from psycopg2.extras import DateRange, DateTimeTZRange, NumericRange, Range
|
||||
|
||||
from django.contrib.postgres import forms, lookups
|
||||
from django.db import models
|
||||
|
||||
from .utils import AttributeSetter
|
||||
|
||||
__all__ = [
|
||||
'RangeField', 'IntegerRangeField', 'BigIntegerRangeField',
|
||||
'DecimalRangeField', 'DateTimeRangeField', 'DateRangeField',
|
||||
'FloatRangeField',
|
||||
'RangeBoundary', 'RangeOperators',
|
||||
]
|
||||
|
||||
|
||||
class RangeBoundary(models.Expression):
    """A class that represents range boundaries."""
    def __init__(self, inclusive_lower=True, inclusive_upper=False):
        # '[' / ']' denote inclusive bounds, '(' / ')' exclusive ones.
        if inclusive_lower:
            self.lower = '['
        else:
            self.lower = '('
        if inclusive_upper:
            self.upper = ']'
        else:
            self.upper = ')'

    def as_sql(self, compiler, connection):
        boundary = "'%s%s'" % (self.lower, self.upper)
        return boundary, []
|
||||
|
||||
|
||||
class RangeOperators:
    """PostgreSQL range operator constants."""
    # https://www.postgresql.org/docs/current/functions-range.html#RANGE-OPERATORS-TABLE
    EQUAL = '='
    NOT_EQUAL = '<>'
    CONTAINS = '@>'
    CONTAINED_BY = '<@'
    OVERLAPS = '&&'
    FULLY_LT = '<<'
    FULLY_GT = '>>'
    NOT_LT = '&>'
    NOT_GT = '&<'
    ADJACENT_TO = '-|-'
|
||||
|
||||
|
||||
class RangeField(models.Field):
    """Base class for PostgreSQL range fields.

    Subclasses supply ``base_field`` (a field class), ``range_type`` (a
    psycopg2 range class) and ``form_field``.
    """
    # An empty range is a meaningful value, so '' must not stand in for it.
    empty_strings_allowed = False

    def __init__(self, *args, **kwargs):
        # Initializing base_field here ensures that its model matches the model for self.
        if hasattr(self, 'base_field'):
            self.base_field = self.base_field()
        super().__init__(*args, **kwargs)

    @property
    def model(self):
        try:
            return self.__dict__['model']
        except KeyError:
            raise AttributeError("'%s' object has no attribute 'model'" % self.__class__.__name__)

    @model.setter
    def model(self, model):
        # Keep the base field's model in sync with the range field's.
        self.__dict__['model'] = model
        self.base_field.model = model

    def get_prep_value(self, value):
        if value is None:
            return None
        elif isinstance(value, Range):
            return value
        elif isinstance(value, (list, tuple)):
            # Accept a (lower, upper) pair.
            return self.range_type(value[0], value[1])
        return value

    def to_python(self, value):
        if isinstance(value, str):
            # Assume we're deserializing
            vals = json.loads(value)
            for end in ('lower', 'upper'):
                if end in vals:
                    vals[end] = self.base_field.to_python(vals[end])
            value = self.range_type(**vals)
        elif isinstance(value, (list, tuple)):
            value = self.range_type(value[0], value[1])
        return value

    def set_attributes_from_name(self, name):
        super().set_attributes_from_name(name)
        # The base field shares the range field's name/attname.
        self.base_field.set_attributes_from_name(name)

    def value_to_string(self, obj):
        """Serialize to JSON: {"empty": true} or {"bounds": ..., "lower": ..., "upper": ...}."""
        value = self.value_from_object(obj)
        if value is None:
            return None
        if value.isempty:
            return json.dumps({"empty": True})
        base_field = self.base_field
        result = {"bounds": value._bounds}
        for end in ('lower', 'upper'):
            val = getattr(value, end)
            if val is None:
                result[end] = None
            else:
                # base_field.value_to_string() reads the value off an object
                # by attname, so wrap each bound in a throwaway holder.
                obj = AttributeSetter(base_field.attname, val)
                result[end] = base_field.value_to_string(obj)
        return json.dumps(result)

    def formfield(self, **kwargs):
        kwargs.setdefault('form_class', self.form_field)
        return super().formfield(**kwargs)
|
||||
|
||||
|
||||
class IntegerRangeField(RangeField):
    """Range of integers, stored as PostgreSQL int4range."""
    base_field = models.IntegerField
    range_type = NumericRange
    form_field = forms.IntegerRangeField

    def db_type(self, connection):
        return 'int4range'
|
||||
|
||||
|
||||
class BigIntegerRangeField(RangeField):
    """Range of big integers, stored as PostgreSQL int8range."""
    base_field = models.BigIntegerField
    range_type = NumericRange
    form_field = forms.IntegerRangeField

    def db_type(self, connection):
        return 'int8range'
|
||||
|
||||
|
||||
class DecimalRangeField(RangeField):
    """Range of decimals, stored as PostgreSQL numrange."""
    base_field = models.DecimalField
    range_type = NumericRange
    form_field = forms.DecimalRangeField

    def db_type(self, connection):
        return 'numrange'
|
||||
|
||||
|
||||
class FloatRangeField(RangeField):
    """Range of floats, stored as PostgreSQL numrange.

    Deprecated (fields.W902) in favor of DecimalRangeField.
    """
    system_check_deprecated_details = {
        'msg': (
            'FloatRangeField is deprecated and will be removed in Django 3.1.'
        ),
        'hint': 'Use DecimalRangeField instead.',
        'id': 'fields.W902',
    }
    base_field = models.FloatField
    range_type = NumericRange
    form_field = forms.FloatRangeField

    def db_type(self, connection):
        return 'numrange'
|
||||
|
||||
|
||||
class DateTimeRangeField(RangeField):
    """Range of timestamps, stored as PostgreSQL tstzrange."""
    base_field = models.DateTimeField
    range_type = DateTimeTZRange
    form_field = forms.DateTimeRangeField

    def db_type(self, connection):
        return 'tstzrange'
|
||||
|
||||
|
||||
class DateRangeField(RangeField):
    """Range of dates, stored as PostgreSQL daterange."""
    base_field = models.DateField
    range_type = DateRange
    form_field = forms.DateRangeField

    def db_type(self, connection):
        return 'daterange'
|
||||
|
||||
|
||||
# Range fields share the generic containment and overlap lookups.
RangeField.register_lookup(lookups.DataContains)
RangeField.register_lookup(lookups.ContainedBy)
RangeField.register_lookup(lookups.Overlap)
|
||||
|
||||
|
||||
class DateTimeRangeContains(lookups.PostgresSimpleLookup):
    """
    Lookup for Date/DateTimeRange containment to cast the rhs to the correct
    type.
    """
    lookup_name = 'contains'
    operator = RangeOperators.CONTAINS

    def process_rhs(self, compiler, connection):
        # Transform rhs value for db lookup.
        if isinstance(self.rhs, datetime.date):
            # datetime is a subclass of date, so check for it first to pick
            # the matching output field.
            output_field = models.DateTimeField() if isinstance(self.rhs, datetime.datetime) else models.DateField()
            value = models.Value(self.rhs, output_field=output_field)
            self.rhs = value.resolve_expression(compiler.query)
        return super().process_rhs(compiler, connection)

    def as_sql(self, compiler, connection):
        sql, params = super().as_sql(compiler, connection)
        # Cast the rhs if needed.
        cast_sql = ''
        if (
            isinstance(self.rhs, models.Expression) and
            self.rhs._output_field_or_none and
            # Skip cast if rhs has a matching range type.
            not isinstance(self.rhs._output_field_or_none, self.lhs.output_field.__class__)
        ):
            # Cast to the db type of the range's base field (e.g. '::date').
            cast_internal_type = self.lhs.output_field.base_field.get_internal_type()
            cast_sql = '::{}'.format(connection.data_types.get(cast_internal_type))
        return '%s%s' % (sql, cast_sql), params
|
||||
|
||||
|
||||
# Only date/datetime ranges need the rhs cast; register on those two fields.
DateRangeField.register_lookup(DateTimeRangeContains)
DateTimeRangeField.register_lookup(DateTimeRangeContains)
|
||||
|
||||
|
||||
class RangeContainedBy(lookups.PostgresSimpleLookup):
    """``contained_by`` lookup on scalar fields against a range rhs."""
    lookup_name = 'contained_by'
    # Map a scalar column's db type to the range type used for the rhs cast.
    type_mapping = {
        'integer': 'int4range',
        'bigint': 'int8range',
        'double precision': 'numrange',
        'date': 'daterange',
        'timestamp with time zone': 'tstzrange',
    }
    operator = RangeOperators.CONTAINED_BY

    def process_rhs(self, compiler, connection):
        rhs, rhs_params = super().process_rhs(compiler, connection)
        # Cast the rhs to the range type matching the lhs column type.
        cast_type = self.type_mapping[self.lhs.output_field.db_type(connection)]
        return '%s::%s' % (rhs, cast_type), rhs_params

    def process_lhs(self, compiler, connection):
        lhs, lhs_params = super().process_lhs(compiler, connection)
        if isinstance(self.lhs.output_field, models.FloatField):
            # Float columns are compared against numrange, so cast to numeric.
            lhs = '%s::numeric' % lhs
        return lhs, lhs_params

    def get_prep_lookup(self):
        # Reuse RangeField's value preparation for the rhs range.
        return RangeField().get_prep_value(self.rhs)
|
||||
|
||||
|
||||
# Let scalar model fields be tested for membership in a range rhs.
models.DateField.register_lookup(RangeContainedBy)
models.DateTimeField.register_lookup(RangeContainedBy)
models.IntegerField.register_lookup(RangeContainedBy)
models.BigIntegerField.register_lookup(RangeContainedBy)
models.FloatField.register_lookup(RangeContainedBy)
|
||||
|
||||
|
||||
@RangeField.register_lookup
class FullyLessThan(lookups.PostgresSimpleLookup):
    """``fully_lt`` (<<): the whole range is strictly below the rhs range."""
    lookup_name = 'fully_lt'
    operator = RangeOperators.FULLY_LT
|
||||
|
||||
|
||||
@RangeField.register_lookup
class FullGreaterThan(lookups.PostgresSimpleLookup):
    """``fully_gt`` (>>): the whole range is strictly above the rhs range."""
    lookup_name = 'fully_gt'
    operator = RangeOperators.FULLY_GT
|
||||
|
||||
|
||||
@RangeField.register_lookup
class NotLessThan(lookups.PostgresSimpleLookup):
    """``not_lt`` (&>): the range does not extend below the rhs range."""
    lookup_name = 'not_lt'
    operator = RangeOperators.NOT_LT
|
||||
|
||||
|
||||
@RangeField.register_lookup
class NotGreaterThan(lookups.PostgresSimpleLookup):
    """``not_gt`` (&<): the range does not extend above the rhs range."""
    lookup_name = 'not_gt'
    operator = RangeOperators.NOT_GT
|
||||
|
||||
|
||||
@RangeField.register_lookup
class AdjacentToLookup(lookups.PostgresSimpleLookup):
    """``adjacent_to`` (-|-): the ranges touch without overlapping."""
    lookup_name = 'adjacent_to'
    operator = RangeOperators.ADJACENT_TO
|
||||
|
||||
|
||||
@RangeField.register_lookup
class RangeStartsWith(models.Transform):
    """``startswith`` transform: lower(range), typed as the base field."""
    lookup_name = 'startswith'
    function = 'lower'

    @property
    def output_field(self):
        return self.lhs.output_field.base_field
|
||||
|
||||
|
||||
@RangeField.register_lookup
class RangeEndsWith(models.Transform):
    """``endswith`` transform: upper(range), typed as the base field."""
    lookup_name = 'endswith'
    function = 'upper'

    @property
    def output_field(self):
        return self.lhs.output_field.base_field
|
||||
|
||||
|
||||
@RangeField.register_lookup
class IsEmpty(models.Transform):
    """``isempty`` transform: true when the range contains no points."""
    lookup_name = 'isempty'
    function = 'isempty'
    output_field = models.BooleanField()
|
||||
@@ -0,0 +1,3 @@
|
||||
class AttributeSetter:
    """Throwaway object exposing one named attribute.

    Used by value_to_string() implementations that delegate to a base field,
    which reads the value off an object by attname.
    """
    def __init__(self, name, value):
        setattr(self, name, value)
|
||||
Reference in New Issue
Block a user