Initial commit. Basic models mostly done.
This commit is contained in:
18
venv/lib/python3.8/site-packages/django/test/__init__.py
Normal file
18
venv/lib/python3.8/site-packages/django/test/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
"""Django Unit Test framework."""
|
||||
|
||||
from django.test.client import Client, RequestFactory
|
||||
from django.test.testcases import (
|
||||
LiveServerTestCase, SimpleTestCase, TestCase, TransactionTestCase,
|
||||
skipIfDBFeature, skipUnlessAnyDBFeature, skipUnlessDBFeature,
|
||||
)
|
||||
from django.test.utils import (
|
||||
ignore_warnings, modify_settings, override_settings,
|
||||
override_system_checks, tag,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'Client', 'RequestFactory', 'TestCase', 'TransactionTestCase',
|
||||
'SimpleTestCase', 'LiveServerTestCase', 'skipIfDBFeature',
|
||||
'skipUnlessAnyDBFeature', 'skipUnlessDBFeature', 'ignore_warnings',
|
||||
'modify_settings', 'override_settings', 'override_system_checks', 'tag',
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
707
venv/lib/python3.8/site-packages/django/test/client.py
Normal file
707
venv/lib/python3.8/site-packages/django/test/client.py
Normal file
@@ -0,0 +1,707 @@
|
||||
import json
|
||||
import mimetypes
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from copy import copy
|
||||
from functools import partial
|
||||
from http import HTTPStatus
|
||||
from importlib import import_module
|
||||
from io import BytesIO
|
||||
from urllib.parse import unquote_to_bytes, urljoin, urlparse, urlsplit
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.handlers.base import BaseHandler
|
||||
from django.core.handlers.wsgi import WSGIRequest
|
||||
from django.core.serializers.json import DjangoJSONEncoder
|
||||
from django.core.signals import (
|
||||
got_request_exception, request_finished, request_started,
|
||||
)
|
||||
from django.db import close_old_connections
|
||||
from django.http import HttpRequest, QueryDict, SimpleCookie
|
||||
from django.test import signals
|
||||
from django.test.utils import ContextList
|
||||
from django.urls import resolve
|
||||
from django.utils.encoding import force_bytes
|
||||
from django.utils.functional import SimpleLazyObject
|
||||
from django.utils.http import urlencode
|
||||
from django.utils.itercompat import is_iterable
|
||||
|
||||
__all__ = ('Client', 'RedirectCycleError', 'RequestFactory', 'encode_file', 'encode_multipart')
|
||||
|
||||
|
||||
BOUNDARY = 'BoUnDaRyStRiNg'
|
||||
MULTIPART_CONTENT = 'multipart/form-data; boundary=%s' % BOUNDARY
|
||||
CONTENT_TYPE_RE = re.compile(r'.*; charset=([\w\d-]+);?')
|
||||
# Structured suffix spec: https://tools.ietf.org/html/rfc6838#section-4.2.8
|
||||
JSON_CONTENT_TYPE_RE = re.compile(r'^application\/(.+\+)?json')
|
||||
|
||||
|
||||
class RedirectCycleError(Exception):
|
||||
"""The test client has been asked to follow a redirect loop."""
|
||||
def __init__(self, message, last_response):
|
||||
super().__init__(message)
|
||||
self.last_response = last_response
|
||||
self.redirect_chain = last_response.redirect_chain
|
||||
|
||||
|
||||
class FakePayload:
|
||||
"""
|
||||
A wrapper around BytesIO that restricts what can be read since data from
|
||||
the network can't be sought and cannot be read outside of its content
|
||||
length. This makes sure that views can't do anything under the test client
|
||||
that wouldn't work in real life.
|
||||
"""
|
||||
def __init__(self, content=None):
|
||||
self.__content = BytesIO()
|
||||
self.__len = 0
|
||||
self.read_started = False
|
||||
if content is not None:
|
||||
self.write(content)
|
||||
|
||||
def __len__(self):
|
||||
return self.__len
|
||||
|
||||
def read(self, num_bytes=None):
|
||||
if not self.read_started:
|
||||
self.__content.seek(0)
|
||||
self.read_started = True
|
||||
if num_bytes is None:
|
||||
num_bytes = self.__len or 0
|
||||
assert self.__len >= num_bytes, "Cannot read more than the available bytes from the HTTP incoming data."
|
||||
content = self.__content.read(num_bytes)
|
||||
self.__len -= num_bytes
|
||||
return content
|
||||
|
||||
def write(self, content):
|
||||
if self.read_started:
|
||||
raise ValueError("Unable to write a payload after it's been read")
|
||||
content = force_bytes(content)
|
||||
self.__content.write(content)
|
||||
self.__len += len(content)
|
||||
|
||||
|
||||
def closing_iterator_wrapper(iterable, close):
|
||||
try:
|
||||
yield from iterable
|
||||
finally:
|
||||
request_finished.disconnect(close_old_connections)
|
||||
close() # will fire request_finished
|
||||
request_finished.connect(close_old_connections)
|
||||
|
||||
|
||||
def conditional_content_removal(request, response):
|
||||
"""
|
||||
Simulate the behavior of most Web servers by removing the content of
|
||||
responses for HEAD requests, 1xx, 204, and 304 responses. Ensure
|
||||
compliance with RFC 7230, section 3.3.3.
|
||||
"""
|
||||
if 100 <= response.status_code < 200 or response.status_code in (204, 304):
|
||||
if response.streaming:
|
||||
response.streaming_content = []
|
||||
else:
|
||||
response.content = b''
|
||||
if request.method == 'HEAD':
|
||||
if response.streaming:
|
||||
response.streaming_content = []
|
||||
else:
|
||||
response.content = b''
|
||||
return response
|
||||
|
||||
|
||||
class ClientHandler(BaseHandler):
|
||||
"""
|
||||
A HTTP Handler that can be used for testing purposes. Use the WSGI
|
||||
interface to compose requests, but return the raw HttpResponse object with
|
||||
the originating WSGIRequest attached to its ``wsgi_request`` attribute.
|
||||
"""
|
||||
def __init__(self, enforce_csrf_checks=True, *args, **kwargs):
|
||||
self.enforce_csrf_checks = enforce_csrf_checks
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def __call__(self, environ):
|
||||
# Set up middleware if needed. We couldn't do this earlier, because
|
||||
# settings weren't available.
|
||||
if self._middleware_chain is None:
|
||||
self.load_middleware()
|
||||
|
||||
request_started.disconnect(close_old_connections)
|
||||
request_started.send(sender=self.__class__, environ=environ)
|
||||
request_started.connect(close_old_connections)
|
||||
request = WSGIRequest(environ)
|
||||
# sneaky little hack so that we can easily get round
|
||||
# CsrfViewMiddleware. This makes life easier, and is probably
|
||||
# required for backwards compatibility with external tests against
|
||||
# admin views.
|
||||
request._dont_enforce_csrf_checks = not self.enforce_csrf_checks
|
||||
|
||||
# Request goes through middleware.
|
||||
response = self.get_response(request)
|
||||
|
||||
# Simulate behaviors of most Web servers.
|
||||
conditional_content_removal(request, response)
|
||||
|
||||
# Attach the originating request to the response so that it could be
|
||||
# later retrieved.
|
||||
response.wsgi_request = request
|
||||
|
||||
# Emulate a WSGI server by calling the close method on completion.
|
||||
if response.streaming:
|
||||
response.streaming_content = closing_iterator_wrapper(
|
||||
response.streaming_content, response.close)
|
||||
else:
|
||||
request_finished.disconnect(close_old_connections)
|
||||
response.close() # will fire request_finished
|
||||
request_finished.connect(close_old_connections)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
def store_rendered_templates(store, signal, sender, template, context, **kwargs):
|
||||
"""
|
||||
Store templates and contexts that are rendered.
|
||||
|
||||
The context is copied so that it is an accurate representation at the time
|
||||
of rendering.
|
||||
"""
|
||||
store.setdefault('templates', []).append(template)
|
||||
if 'context' not in store:
|
||||
store['context'] = ContextList()
|
||||
store['context'].append(copy(context))
|
||||
|
||||
|
||||
def encode_multipart(boundary, data):
|
||||
"""
|
||||
Encode multipart POST data from a dictionary of form values.
|
||||
|
||||
The key will be used as the form data name; the value will be transmitted
|
||||
as content. If the value is a file, the contents of the file will be sent
|
||||
as an application/octet-stream; otherwise, str(value) will be sent.
|
||||
"""
|
||||
lines = []
|
||||
|
||||
def to_bytes(s):
|
||||
return force_bytes(s, settings.DEFAULT_CHARSET)
|
||||
|
||||
# Not by any means perfect, but good enough for our purposes.
|
||||
def is_file(thing):
|
||||
return hasattr(thing, "read") and callable(thing.read)
|
||||
|
||||
# Each bit of the multipart form data could be either a form value or a
|
||||
# file, or a *list* of form values and/or files. Remember that HTTP field
|
||||
# names can be duplicated!
|
||||
for (key, value) in data.items():
|
||||
if value is None:
|
||||
raise TypeError(
|
||||
"Cannot encode None for key '%s' as POST data. Did you mean "
|
||||
"to pass an empty string or omit the value?" % key
|
||||
)
|
||||
elif is_file(value):
|
||||
lines.extend(encode_file(boundary, key, value))
|
||||
elif not isinstance(value, str) and is_iterable(value):
|
||||
for item in value:
|
||||
if is_file(item):
|
||||
lines.extend(encode_file(boundary, key, item))
|
||||
else:
|
||||
lines.extend(to_bytes(val) for val in [
|
||||
'--%s' % boundary,
|
||||
'Content-Disposition: form-data; name="%s"' % key,
|
||||
'',
|
||||
item
|
||||
])
|
||||
else:
|
||||
lines.extend(to_bytes(val) for val in [
|
||||
'--%s' % boundary,
|
||||
'Content-Disposition: form-data; name="%s"' % key,
|
||||
'',
|
||||
value
|
||||
])
|
||||
|
||||
lines.extend([
|
||||
to_bytes('--%s--' % boundary),
|
||||
b'',
|
||||
])
|
||||
return b'\r\n'.join(lines)
|
||||
|
||||
|
||||
def encode_file(boundary, key, file):
|
||||
def to_bytes(s):
|
||||
return force_bytes(s, settings.DEFAULT_CHARSET)
|
||||
|
||||
# file.name might not be a string. For example, it's an int for
|
||||
# tempfile.TemporaryFile().
|
||||
file_has_string_name = hasattr(file, 'name') and isinstance(file.name, str)
|
||||
filename = os.path.basename(file.name) if file_has_string_name else ''
|
||||
|
||||
if hasattr(file, 'content_type'):
|
||||
content_type = file.content_type
|
||||
elif filename:
|
||||
content_type = mimetypes.guess_type(filename)[0]
|
||||
else:
|
||||
content_type = None
|
||||
|
||||
if content_type is None:
|
||||
content_type = 'application/octet-stream'
|
||||
filename = filename or key
|
||||
return [
|
||||
to_bytes('--%s' % boundary),
|
||||
to_bytes('Content-Disposition: form-data; name="%s"; filename="%s"'
|
||||
% (key, filename)),
|
||||
to_bytes('Content-Type: %s' % content_type),
|
||||
b'',
|
||||
to_bytes(file.read())
|
||||
]
|
||||
|
||||
|
||||
class RequestFactory:
|
||||
"""
|
||||
Class that lets you create mock Request objects for use in testing.
|
||||
|
||||
Usage:
|
||||
|
||||
rf = RequestFactory()
|
||||
get_request = rf.get('/hello/')
|
||||
post_request = rf.post('/submit/', {'foo': 'bar'})
|
||||
|
||||
Once you have a request object you can pass it to any view function,
|
||||
just as if that view had been hooked up using a URLconf.
|
||||
"""
|
||||
def __init__(self, *, json_encoder=DjangoJSONEncoder, **defaults):
|
||||
self.json_encoder = json_encoder
|
||||
self.defaults = defaults
|
||||
self.cookies = SimpleCookie()
|
||||
self.errors = BytesIO()
|
||||
|
||||
def _base_environ(self, **request):
|
||||
"""
|
||||
The base environment for a request.
|
||||
"""
|
||||
# This is a minimal valid WSGI environ dictionary, plus:
|
||||
# - HTTP_COOKIE: for cookie support,
|
||||
# - REMOTE_ADDR: often useful, see #8551.
|
||||
# See https://www.python.org/dev/peps/pep-3333/#environ-variables
|
||||
return {
|
||||
'HTTP_COOKIE': '; '.join(sorted(
|
||||
'%s=%s' % (morsel.key, morsel.coded_value)
|
||||
for morsel in self.cookies.values()
|
||||
)),
|
||||
'PATH_INFO': '/',
|
||||
'REMOTE_ADDR': '127.0.0.1',
|
||||
'REQUEST_METHOD': 'GET',
|
||||
'SCRIPT_NAME': '',
|
||||
'SERVER_NAME': 'testserver',
|
||||
'SERVER_PORT': '80',
|
||||
'SERVER_PROTOCOL': 'HTTP/1.1',
|
||||
'wsgi.version': (1, 0),
|
||||
'wsgi.url_scheme': 'http',
|
||||
'wsgi.input': FakePayload(b''),
|
||||
'wsgi.errors': self.errors,
|
||||
'wsgi.multiprocess': True,
|
||||
'wsgi.multithread': False,
|
||||
'wsgi.run_once': False,
|
||||
**self.defaults,
|
||||
**request,
|
||||
}
|
||||
|
||||
def request(self, **request):
|
||||
"Construct a generic request object."
|
||||
return WSGIRequest(self._base_environ(**request))
|
||||
|
||||
def _encode_data(self, data, content_type):
|
||||
if content_type is MULTIPART_CONTENT:
|
||||
return encode_multipart(BOUNDARY, data)
|
||||
else:
|
||||
# Encode the content so that the byte representation is correct.
|
||||
match = CONTENT_TYPE_RE.match(content_type)
|
||||
if match:
|
||||
charset = match.group(1)
|
||||
else:
|
||||
charset = settings.DEFAULT_CHARSET
|
||||
return force_bytes(data, encoding=charset)
|
||||
|
||||
def _encode_json(self, data, content_type):
|
||||
"""
|
||||
Return encoded JSON if data is a dict, list, or tuple and content_type
|
||||
is application/json.
|
||||
"""
|
||||
should_encode = JSON_CONTENT_TYPE_RE.match(content_type) and isinstance(data, (dict, list, tuple))
|
||||
return json.dumps(data, cls=self.json_encoder) if should_encode else data
|
||||
|
||||
def _get_path(self, parsed):
|
||||
path = parsed.path
|
||||
# If there are parameters, add them
|
||||
if parsed.params:
|
||||
path += ";" + parsed.params
|
||||
path = unquote_to_bytes(path)
|
||||
# Replace the behavior where non-ASCII values in the WSGI environ are
|
||||
# arbitrarily decoded with ISO-8859-1.
|
||||
# Refs comment in `get_bytes_from_wsgi()`.
|
||||
return path.decode('iso-8859-1')
|
||||
|
||||
def get(self, path, data=None, secure=False, **extra):
|
||||
"""Construct a GET request."""
|
||||
data = {} if data is None else data
|
||||
return self.generic('GET', path, secure=secure, **{
|
||||
'QUERY_STRING': urlencode(data, doseq=True),
|
||||
**extra,
|
||||
})
|
||||
|
||||
def post(self, path, data=None, content_type=MULTIPART_CONTENT,
|
||||
secure=False, **extra):
|
||||
"""Construct a POST request."""
|
||||
data = self._encode_json({} if data is None else data, content_type)
|
||||
post_data = self._encode_data(data, content_type)
|
||||
|
||||
return self.generic('POST', path, post_data, content_type,
|
||||
secure=secure, **extra)
|
||||
|
||||
def head(self, path, data=None, secure=False, **extra):
|
||||
"""Construct a HEAD request."""
|
||||
data = {} if data is None else data
|
||||
return self.generic('HEAD', path, secure=secure, **{
|
||||
'QUERY_STRING': urlencode(data, doseq=True),
|
||||
**extra,
|
||||
})
|
||||
|
||||
def trace(self, path, secure=False, **extra):
|
||||
"""Construct a TRACE request."""
|
||||
return self.generic('TRACE', path, secure=secure, **extra)
|
||||
|
||||
def options(self, path, data='', content_type='application/octet-stream',
|
||||
secure=False, **extra):
|
||||
"Construct an OPTIONS request."
|
||||
return self.generic('OPTIONS', path, data, content_type,
|
||||
secure=secure, **extra)
|
||||
|
||||
def put(self, path, data='', content_type='application/octet-stream',
|
||||
secure=False, **extra):
|
||||
"""Construct a PUT request."""
|
||||
data = self._encode_json(data, content_type)
|
||||
return self.generic('PUT', path, data, content_type,
|
||||
secure=secure, **extra)
|
||||
|
||||
def patch(self, path, data='', content_type='application/octet-stream',
|
||||
secure=False, **extra):
|
||||
"""Construct a PATCH request."""
|
||||
data = self._encode_json(data, content_type)
|
||||
return self.generic('PATCH', path, data, content_type,
|
||||
secure=secure, **extra)
|
||||
|
||||
def delete(self, path, data='', content_type='application/octet-stream',
|
||||
secure=False, **extra):
|
||||
"""Construct a DELETE request."""
|
||||
data = self._encode_json(data, content_type)
|
||||
return self.generic('DELETE', path, data, content_type,
|
||||
secure=secure, **extra)
|
||||
|
||||
def generic(self, method, path, data='',
|
||||
content_type='application/octet-stream', secure=False,
|
||||
**extra):
|
||||
"""Construct an arbitrary HTTP request."""
|
||||
parsed = urlparse(str(path)) # path can be lazy
|
||||
data = force_bytes(data, settings.DEFAULT_CHARSET)
|
||||
r = {
|
||||
'PATH_INFO': self._get_path(parsed),
|
||||
'REQUEST_METHOD': method,
|
||||
'SERVER_PORT': '443' if secure else '80',
|
||||
'wsgi.url_scheme': 'https' if secure else 'http',
|
||||
}
|
||||
if data:
|
||||
r.update({
|
||||
'CONTENT_LENGTH': str(len(data)),
|
||||
'CONTENT_TYPE': content_type,
|
||||
'wsgi.input': FakePayload(data),
|
||||
})
|
||||
r.update(extra)
|
||||
# If QUERY_STRING is absent or empty, we want to extract it from the URL.
|
||||
if not r.get('QUERY_STRING'):
|
||||
# WSGI requires latin-1 encoded strings. See get_path_info().
|
||||
query_string = parsed[4].encode().decode('iso-8859-1')
|
||||
r['QUERY_STRING'] = query_string
|
||||
return self.request(**r)
|
||||
|
||||
|
||||
class Client(RequestFactory):
|
||||
"""
|
||||
A class that can act as a client for testing purposes.
|
||||
|
||||
It allows the user to compose GET and POST requests, and
|
||||
obtain the response that the server gave to those requests.
|
||||
The server Response objects are annotated with the details
|
||||
of the contexts and templates that were rendered during the
|
||||
process of serving the request.
|
||||
|
||||
Client objects are stateful - they will retain cookie (and
|
||||
thus session) details for the lifetime of the Client instance.
|
||||
|
||||
This is not intended as a replacement for Twill/Selenium or
|
||||
the like - it is here to allow testing against the
|
||||
contexts and templates produced by a view, rather than the
|
||||
HTML rendered to the end-user.
|
||||
"""
|
||||
def __init__(self, enforce_csrf_checks=False, raise_request_exception=True, **defaults):
|
||||
super().__init__(**defaults)
|
||||
self.handler = ClientHandler(enforce_csrf_checks)
|
||||
self.raise_request_exception = raise_request_exception
|
||||
self.exc_info = None
|
||||
|
||||
def store_exc_info(self, **kwargs):
|
||||
"""Store exceptions when they are generated by a view."""
|
||||
self.exc_info = sys.exc_info()
|
||||
|
||||
@property
|
||||
def session(self):
|
||||
"""Return the current session variables."""
|
||||
engine = import_module(settings.SESSION_ENGINE)
|
||||
cookie = self.cookies.get(settings.SESSION_COOKIE_NAME)
|
||||
if cookie:
|
||||
return engine.SessionStore(cookie.value)
|
||||
|
||||
session = engine.SessionStore()
|
||||
session.save()
|
||||
self.cookies[settings.SESSION_COOKIE_NAME] = session.session_key
|
||||
return session
|
||||
|
||||
def request(self, **request):
|
||||
"""
|
||||
The master request method. Compose the environment dictionary and pass
|
||||
to the handler, return the result of the handler. Assume defaults for
|
||||
the query environment, which can be overridden using the arguments to
|
||||
the request.
|
||||
"""
|
||||
environ = self._base_environ(**request)
|
||||
|
||||
# Curry a data dictionary into an instance of the template renderer
|
||||
# callback function.
|
||||
data = {}
|
||||
on_template_render = partial(store_rendered_templates, data)
|
||||
signal_uid = "template-render-%s" % id(request)
|
||||
signals.template_rendered.connect(on_template_render, dispatch_uid=signal_uid)
|
||||
# Capture exceptions created by the handler.
|
||||
exception_uid = "request-exception-%s" % id(request)
|
||||
got_request_exception.connect(self.store_exc_info, dispatch_uid=exception_uid)
|
||||
try:
|
||||
response = self.handler(environ)
|
||||
finally:
|
||||
signals.template_rendered.disconnect(dispatch_uid=signal_uid)
|
||||
got_request_exception.disconnect(dispatch_uid=exception_uid)
|
||||
# Look for a signaled exception, clear the current context exception
|
||||
# data, then re-raise the signaled exception. Also clear the signaled
|
||||
# exception from the local cache.
|
||||
response.exc_info = self.exc_info
|
||||
if self.exc_info:
|
||||
_, exc_value, _ = self.exc_info
|
||||
self.exc_info = None
|
||||
if self.raise_request_exception:
|
||||
raise exc_value
|
||||
# Save the client and request that stimulated the response.
|
||||
response.client = self
|
||||
response.request = request
|
||||
# Add any rendered template detail to the response.
|
||||
response.templates = data.get('templates', [])
|
||||
response.context = data.get('context')
|
||||
response.json = partial(self._parse_json, response)
|
||||
# Attach the ResolverMatch instance to the response.
|
||||
response.resolver_match = SimpleLazyObject(lambda: resolve(request['PATH_INFO']))
|
||||
# Flatten a single context. Not really necessary anymore thanks to the
|
||||
# __getattr__ flattening in ContextList, but has some edge case
|
||||
# backwards compatibility implications.
|
||||
if response.context and len(response.context) == 1:
|
||||
response.context = response.context[0]
|
||||
# Update persistent cookie data.
|
||||
if response.cookies:
|
||||
self.cookies.update(response.cookies)
|
||||
return response
|
||||
|
||||
def get(self, path, data=None, follow=False, secure=False, **extra):
|
||||
"""Request a response from the server using GET."""
|
||||
response = super().get(path, data=data, secure=secure, **extra)
|
||||
if follow:
|
||||
response = self._handle_redirects(response, data=data, **extra)
|
||||
return response
|
||||
|
||||
def post(self, path, data=None, content_type=MULTIPART_CONTENT,
|
||||
follow=False, secure=False, **extra):
|
||||
"""Request a response from the server using POST."""
|
||||
response = super().post(path, data=data, content_type=content_type, secure=secure, **extra)
|
||||
if follow:
|
||||
response = self._handle_redirects(response, data=data, content_type=content_type, **extra)
|
||||
return response
|
||||
|
||||
def head(self, path, data=None, follow=False, secure=False, **extra):
|
||||
"""Request a response from the server using HEAD."""
|
||||
response = super().head(path, data=data, secure=secure, **extra)
|
||||
if follow:
|
||||
response = self._handle_redirects(response, data=data, **extra)
|
||||
return response
|
||||
|
||||
def options(self, path, data='', content_type='application/octet-stream',
|
||||
follow=False, secure=False, **extra):
|
||||
"""Request a response from the server using OPTIONS."""
|
||||
response = super().options(path, data=data, content_type=content_type, secure=secure, **extra)
|
||||
if follow:
|
||||
response = self._handle_redirects(response, data=data, content_type=content_type, **extra)
|
||||
return response
|
||||
|
||||
def put(self, path, data='', content_type='application/octet-stream',
|
||||
follow=False, secure=False, **extra):
|
||||
"""Send a resource to the server using PUT."""
|
||||
response = super().put(path, data=data, content_type=content_type, secure=secure, **extra)
|
||||
if follow:
|
||||
response = self._handle_redirects(response, data=data, content_type=content_type, **extra)
|
||||
return response
|
||||
|
||||
def patch(self, path, data='', content_type='application/octet-stream',
|
||||
follow=False, secure=False, **extra):
|
||||
"""Send a resource to the server using PATCH."""
|
||||
response = super().patch(path, data=data, content_type=content_type, secure=secure, **extra)
|
||||
if follow:
|
||||
response = self._handle_redirects(response, data=data, content_type=content_type, **extra)
|
||||
return response
|
||||
|
||||
def delete(self, path, data='', content_type='application/octet-stream',
|
||||
follow=False, secure=False, **extra):
|
||||
"""Send a DELETE request to the server."""
|
||||
response = super().delete(path, data=data, content_type=content_type, secure=secure, **extra)
|
||||
if follow:
|
||||
response = self._handle_redirects(response, data=data, content_type=content_type, **extra)
|
||||
return response
|
||||
|
||||
def trace(self, path, data='', follow=False, secure=False, **extra):
|
||||
"""Send a TRACE request to the server."""
|
||||
response = super().trace(path, data=data, secure=secure, **extra)
|
||||
if follow:
|
||||
response = self._handle_redirects(response, data=data, **extra)
|
||||
return response
|
||||
|
||||
def login(self, **credentials):
|
||||
"""
|
||||
Set the Factory to appear as if it has successfully logged into a site.
|
||||
|
||||
Return True if login is possible; False if the provided credentials
|
||||
are incorrect.
|
||||
"""
|
||||
from django.contrib.auth import authenticate
|
||||
user = authenticate(**credentials)
|
||||
if user:
|
||||
self._login(user)
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def force_login(self, user, backend=None):
|
||||
def get_backend():
|
||||
from django.contrib.auth import load_backend
|
||||
for backend_path in settings.AUTHENTICATION_BACKENDS:
|
||||
backend = load_backend(backend_path)
|
||||
if hasattr(backend, 'get_user'):
|
||||
return backend_path
|
||||
if backend is None:
|
||||
backend = get_backend()
|
||||
user.backend = backend
|
||||
self._login(user, backend)
|
||||
|
||||
def _login(self, user, backend=None):
|
||||
from django.contrib.auth import login
|
||||
engine = import_module(settings.SESSION_ENGINE)
|
||||
|
||||
# Create a fake request to store login details.
|
||||
request = HttpRequest()
|
||||
|
||||
if self.session:
|
||||
request.session = self.session
|
||||
else:
|
||||
request.session = engine.SessionStore()
|
||||
login(request, user, backend)
|
||||
|
||||
# Save the session values.
|
||||
request.session.save()
|
||||
|
||||
# Set the cookie to represent the session.
|
||||
session_cookie = settings.SESSION_COOKIE_NAME
|
||||
self.cookies[session_cookie] = request.session.session_key
|
||||
cookie_data = {
|
||||
'max-age': None,
|
||||
'path': '/',
|
||||
'domain': settings.SESSION_COOKIE_DOMAIN,
|
||||
'secure': settings.SESSION_COOKIE_SECURE or None,
|
||||
'expires': None,
|
||||
}
|
||||
self.cookies[session_cookie].update(cookie_data)
|
||||
|
||||
def logout(self):
|
||||
"""Log out the user by removing the cookies and session object."""
|
||||
from django.contrib.auth import get_user, logout
|
||||
|
||||
request = HttpRequest()
|
||||
engine = import_module(settings.SESSION_ENGINE)
|
||||
if self.session:
|
||||
request.session = self.session
|
||||
request.user = get_user(request)
|
||||
else:
|
||||
request.session = engine.SessionStore()
|
||||
logout(request)
|
||||
self.cookies = SimpleCookie()
|
||||
|
||||
def _parse_json(self, response, **extra):
|
||||
if not hasattr(response, '_json'):
|
||||
if not JSON_CONTENT_TYPE_RE.match(response.get('Content-Type')):
|
||||
raise ValueError(
|
||||
'Content-Type header is "{0}", not "application/json"'
|
||||
.format(response.get('Content-Type'))
|
||||
)
|
||||
response._json = json.loads(response.content.decode(response.charset), **extra)
|
||||
return response._json
|
||||
|
||||
def _handle_redirects(self, response, data='', content_type='', **extra):
|
||||
"""
|
||||
Follow any redirects by requesting responses from the server using GET.
|
||||
"""
|
||||
response.redirect_chain = []
|
||||
redirect_status_codes = (
|
||||
HTTPStatus.MOVED_PERMANENTLY,
|
||||
HTTPStatus.FOUND,
|
||||
HTTPStatus.SEE_OTHER,
|
||||
HTTPStatus.TEMPORARY_REDIRECT,
|
||||
HTTPStatus.PERMANENT_REDIRECT,
|
||||
)
|
||||
while response.status_code in redirect_status_codes:
|
||||
response_url = response.url
|
||||
redirect_chain = response.redirect_chain
|
||||
redirect_chain.append((response_url, response.status_code))
|
||||
|
||||
url = urlsplit(response_url)
|
||||
if url.scheme:
|
||||
extra['wsgi.url_scheme'] = url.scheme
|
||||
if url.hostname:
|
||||
extra['SERVER_NAME'] = url.hostname
|
||||
if url.port:
|
||||
extra['SERVER_PORT'] = str(url.port)
|
||||
|
||||
# Prepend the request path to handle relative path redirects
|
||||
path = url.path
|
||||
if not path.startswith('/'):
|
||||
path = urljoin(response.request['PATH_INFO'], path)
|
||||
|
||||
if response.status_code in (HTTPStatus.TEMPORARY_REDIRECT, HTTPStatus.PERMANENT_REDIRECT):
|
||||
# Preserve request method post-redirect for 307/308 responses.
|
||||
request_method = getattr(self, response.request['REQUEST_METHOD'].lower())
|
||||
else:
|
||||
request_method = self.get
|
||||
data = QueryDict(url.query)
|
||||
content_type = None
|
||||
|
||||
response = request_method(path, data=data, content_type=content_type, follow=False, **extra)
|
||||
response.redirect_chain = redirect_chain
|
||||
|
||||
if redirect_chain[-1] in redirect_chain[:-1]:
|
||||
# Check that we're not redirecting to somewhere we've already
|
||||
# been to, to prevent loops.
|
||||
raise RedirectCycleError("Redirect loop detected.", last_response=response)
|
||||
if len(redirect_chain) > 20:
|
||||
# Such a lengthy chain likely also means a loop, but one with
|
||||
# a growing path, changing view, or changing query argument;
|
||||
# 20 is the value of "network.http.redirection-limit" from Firefox.
|
||||
raise RedirectCycleError("Too many redirects.", last_response=response)
|
||||
|
||||
return response
|
||||
228
venv/lib/python3.8/site-packages/django/test/html.py
Normal file
228
venv/lib/python3.8/site-packages/django/test/html.py
Normal file
@@ -0,0 +1,228 @@
|
||||
"""Compare two HTML documents."""
|
||||
|
||||
import re
|
||||
from html.parser import HTMLParser
|
||||
|
||||
# ASCII whitespace is U+0009 TAB, U+000A LF, U+000C FF, U+000D CR, or U+0020
|
||||
# SPACE.
|
||||
# https://infra.spec.whatwg.org/#ascii-whitespace
|
||||
ASCII_WHITESPACE = re.compile(r'[\t\n\f\r ]+')
|
||||
|
||||
|
||||
def normalize_whitespace(string):
|
||||
return ASCII_WHITESPACE.sub(' ', string)
|
||||
|
||||
|
||||
class Element:
|
||||
def __init__(self, name, attributes):
|
||||
self.name = name
|
||||
self.attributes = sorted(attributes)
|
||||
self.children = []
|
||||
|
||||
def append(self, element):
|
||||
if isinstance(element, str):
|
||||
element = normalize_whitespace(element)
|
||||
if self.children:
|
||||
if isinstance(self.children[-1], str):
|
||||
self.children[-1] += element
|
||||
self.children[-1] = normalize_whitespace(self.children[-1])
|
||||
return
|
||||
elif self.children:
|
||||
# removing last children if it is only whitespace
|
||||
# this can result in incorrect dom representations since
|
||||
# whitespace between inline tags like <span> is significant
|
||||
if isinstance(self.children[-1], str):
|
||||
if self.children[-1].isspace():
|
||||
self.children.pop()
|
||||
if element:
|
||||
self.children.append(element)
|
||||
|
||||
def finalize(self):
|
||||
def rstrip_last_element(children):
|
||||
if children:
|
||||
if isinstance(children[-1], str):
|
||||
children[-1] = children[-1].rstrip()
|
||||
if not children[-1]:
|
||||
children.pop()
|
||||
children = rstrip_last_element(children)
|
||||
return children
|
||||
|
||||
rstrip_last_element(self.children)
|
||||
for i, child in enumerate(self.children):
|
||||
if isinstance(child, str):
|
||||
self.children[i] = child.strip()
|
||||
elif hasattr(child, 'finalize'):
|
||||
child.finalize()
|
||||
|
||||
def __eq__(self, element):
|
||||
if not hasattr(element, 'name') or self.name != element.name:
|
||||
return False
|
||||
if len(self.attributes) != len(element.attributes):
|
||||
return False
|
||||
if self.attributes != element.attributes:
|
||||
# attributes without a value is same as attribute with value that
|
||||
# equals the attributes name:
|
||||
# <input checked> == <input checked="checked">
|
||||
for i in range(len(self.attributes)):
|
||||
attr, value = self.attributes[i]
|
||||
other_attr, other_value = element.attributes[i]
|
||||
if value is None:
|
||||
value = attr
|
||||
if other_value is None:
|
||||
other_value = other_attr
|
||||
if attr != other_attr or value != other_value:
|
||||
return False
|
||||
return self.children == element.children
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.name, *self.attributes))
|
||||
|
||||
def _count(self, element, count=True):
|
||||
if not isinstance(element, str):
|
||||
if self == element:
|
||||
return 1
|
||||
if isinstance(element, RootElement):
|
||||
if self.children == element.children:
|
||||
return 1
|
||||
i = 0
|
||||
for child in self.children:
|
||||
# child is text content and element is also text content, then
|
||||
# make a simple "text" in "text"
|
||||
if isinstance(child, str):
|
||||
if isinstance(element, str):
|
||||
if count:
|
||||
i += child.count(element)
|
||||
elif element in child:
|
||||
return 1
|
||||
else:
|
||||
i += child._count(element, count=count)
|
||||
if not count and i:
|
||||
return i
|
||||
return i
|
||||
|
||||
def __contains__(self, element):
|
||||
return self._count(element, count=False) > 0
|
||||
|
||||
def count(self, element):
|
||||
return self._count(element, count=True)
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.children[key]
|
||||
|
||||
def __str__(self):
|
||||
output = '<%s' % self.name
|
||||
for key, value in self.attributes:
|
||||
if value:
|
||||
output += ' %s="%s"' % (key, value)
|
||||
else:
|
||||
output += ' %s' % key
|
||||
if self.children:
|
||||
output += '>\n'
|
||||
output += ''.join(str(c) for c in self.children)
|
||||
output += '\n</%s>' % self.name
|
||||
else:
|
||||
output += '>'
|
||||
return output
|
||||
|
||||
def __repr__(self):
|
||||
return str(self)
|
||||
|
||||
|
||||
class RootElement(Element):
|
||||
def __init__(self):
|
||||
super().__init__(None, ())
|
||||
|
||||
def __str__(self):
|
||||
return ''.join(str(c) for c in self.children)
|
||||
|
||||
|
||||
class HTMLParseError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class Parser(HTMLParser):
|
||||
# https://html.spec.whatwg.org/#void-elements
|
||||
SELF_CLOSING_TAGS = {
|
||||
'area', 'base', 'br', 'col', 'embed', 'hr', 'img', 'input', 'link', 'meta',
|
||||
'param', 'source', 'track', 'wbr',
|
||||
# Deprecated tags
|
||||
'frame', 'spacer',
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.root = RootElement()
|
||||
self.open_tags = []
|
||||
self.element_positions = {}
|
||||
|
||||
def error(self, msg):
|
||||
raise HTMLParseError(msg, self.getpos())
|
||||
|
||||
def format_position(self, position=None, element=None):
|
||||
if not position and element:
|
||||
position = self.element_positions[element]
|
||||
if position is None:
|
||||
position = self.getpos()
|
||||
if hasattr(position, 'lineno'):
|
||||
position = position.lineno, position.offset
|
||||
return 'Line %d, Column %d' % position
|
||||
|
||||
@property
|
||||
def current(self):
|
||||
if self.open_tags:
|
||||
return self.open_tags[-1]
|
||||
else:
|
||||
return self.root
|
||||
|
||||
def handle_startendtag(self, tag, attrs):
|
||||
self.handle_starttag(tag, attrs)
|
||||
if tag not in self.SELF_CLOSING_TAGS:
|
||||
self.handle_endtag(tag)
|
||||
|
||||
def handle_starttag(self, tag, attrs):
|
||||
# Special case handling of 'class' attribute, so that comparisons of DOM
|
||||
# instances are not sensitive to ordering of classes.
|
||||
attrs = [
|
||||
(name, ' '.join(sorted(value for value in ASCII_WHITESPACE.split(value) if value)))
|
||||
if name == "class"
|
||||
else (name, value)
|
||||
for name, value in attrs
|
||||
]
|
||||
element = Element(tag, attrs)
|
||||
self.current.append(element)
|
||||
if tag not in self.SELF_CLOSING_TAGS:
|
||||
self.open_tags.append(element)
|
||||
self.element_positions[element] = self.getpos()
|
||||
|
||||
def handle_endtag(self, tag):
|
||||
if not self.open_tags:
|
||||
self.error("Unexpected end tag `%s` (%s)" % (
|
||||
tag, self.format_position()))
|
||||
element = self.open_tags.pop()
|
||||
while element.name != tag:
|
||||
if not self.open_tags:
|
||||
self.error("Unexpected end tag `%s` (%s)" % (
|
||||
tag, self.format_position()))
|
||||
element = self.open_tags.pop()
|
||||
|
||||
def handle_data(self, data):
|
||||
self.current.append(data)
|
||||
|
||||
|
||||
def parse_html(html):
|
||||
"""
|
||||
Take a string that contains *valid* HTML and turn it into a Python object
|
||||
structure that can be easily compared against other HTML on semantic
|
||||
equivalence. Syntactical differences like which quotation is used on
|
||||
arguments will be ignored.
|
||||
"""
|
||||
parser = Parser()
|
||||
parser.feed(html)
|
||||
parser.close()
|
||||
document = parser.root
|
||||
document.finalize()
|
||||
# Removing ROOT element if it's not necessary
|
||||
if len(document.children) == 1:
|
||||
if not isinstance(document.children[0], str):
|
||||
document = document.children[0]
|
||||
return document
|
||||
798
venv/lib/python3.8/site-packages/django/test/runner.py
Normal file
798
venv/lib/python3.8/site-packages/django/test/runner.py
Normal file
@@ -0,0 +1,798 @@
|
||||
import ctypes
|
||||
import itertools
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
import pickle
|
||||
import textwrap
|
||||
import unittest
|
||||
from importlib import import_module
|
||||
from io import StringIO
|
||||
|
||||
from django.core.management import call_command
|
||||
from django.db import connections
|
||||
from django.test import SimpleTestCase, TestCase
|
||||
from django.test.utils import (
|
||||
setup_databases as _setup_databases, setup_test_environment,
|
||||
teardown_databases as _teardown_databases, teardown_test_environment,
|
||||
)
|
||||
from django.utils.datastructures import OrderedSet
|
||||
from django.utils.version import PY37
|
||||
|
||||
try:
|
||||
import ipdb as pdb
|
||||
except ImportError:
|
||||
import pdb
|
||||
|
||||
try:
|
||||
import tblib.pickling_support
|
||||
except ImportError:
|
||||
tblib = None
|
||||
|
||||
|
||||
class DebugSQLTextTestResult(unittest.TextTestResult):
|
||||
def __init__(self, stream, descriptions, verbosity):
|
||||
self.logger = logging.getLogger('django.db.backends')
|
||||
self.logger.setLevel(logging.DEBUG)
|
||||
super().__init__(stream, descriptions, verbosity)
|
||||
|
||||
def startTest(self, test):
|
||||
self.debug_sql_stream = StringIO()
|
||||
self.handler = logging.StreamHandler(self.debug_sql_stream)
|
||||
self.logger.addHandler(self.handler)
|
||||
super().startTest(test)
|
||||
|
||||
def stopTest(self, test):
|
||||
super().stopTest(test)
|
||||
self.logger.removeHandler(self.handler)
|
||||
if self.showAll:
|
||||
self.debug_sql_stream.seek(0)
|
||||
self.stream.write(self.debug_sql_stream.read())
|
||||
self.stream.writeln(self.separator2)
|
||||
|
||||
def addError(self, test, err):
|
||||
super().addError(test, err)
|
||||
self.debug_sql_stream.seek(0)
|
||||
self.errors[-1] = self.errors[-1] + (self.debug_sql_stream.read(),)
|
||||
|
||||
def addFailure(self, test, err):
|
||||
super().addFailure(test, err)
|
||||
self.debug_sql_stream.seek(0)
|
||||
self.failures[-1] = self.failures[-1] + (self.debug_sql_stream.read(),)
|
||||
|
||||
def addSubTest(self, test, subtest, err):
|
||||
super().addSubTest(test, subtest, err)
|
||||
if err is not None:
|
||||
self.debug_sql_stream.seek(0)
|
||||
errors = self.failures if issubclass(err[0], test.failureException) else self.errors
|
||||
errors[-1] = errors[-1] + (self.debug_sql_stream.read(),)
|
||||
|
||||
def printErrorList(self, flavour, errors):
|
||||
for test, err, sql_debug in errors:
|
||||
self.stream.writeln(self.separator1)
|
||||
self.stream.writeln("%s: %s" % (flavour, self.getDescription(test)))
|
||||
self.stream.writeln(self.separator2)
|
||||
self.stream.writeln(err)
|
||||
self.stream.writeln(self.separator2)
|
||||
self.stream.writeln(sql_debug)
|
||||
|
||||
|
||||
class PDBDebugResult(unittest.TextTestResult):
|
||||
"""
|
||||
Custom result class that triggers a PDB session when an error or failure
|
||||
occurs.
|
||||
"""
|
||||
|
||||
def addError(self, test, err):
|
||||
super().addError(test, err)
|
||||
self.debug(err)
|
||||
|
||||
def addFailure(self, test, err):
|
||||
super().addFailure(test, err)
|
||||
self.debug(err)
|
||||
|
||||
def debug(self, error):
|
||||
exc_type, exc_value, traceback = error
|
||||
print("\nOpening PDB: %r" % exc_value)
|
||||
pdb.post_mortem(traceback)
|
||||
|
||||
|
||||
class RemoteTestResult:
|
||||
"""
|
||||
Record information about which tests have succeeded and which have failed.
|
||||
|
||||
The sole purpose of this class is to record events in the child processes
|
||||
so they can be replayed in the master process. As a consequence it doesn't
|
||||
inherit unittest.TestResult and doesn't attempt to implement all its API.
|
||||
|
||||
The implementation matches the unpythonic coding style of unittest2.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
if tblib is not None:
|
||||
tblib.pickling_support.install()
|
||||
|
||||
self.events = []
|
||||
self.failfast = False
|
||||
self.shouldStop = False
|
||||
self.testsRun = 0
|
||||
|
||||
@property
|
||||
def test_index(self):
|
||||
return self.testsRun - 1
|
||||
|
||||
def _confirm_picklable(self, obj):
|
||||
"""
|
||||
Confirm that obj can be pickled and unpickled as multiprocessing will
|
||||
need to pickle the exception in the child process and unpickle it in
|
||||
the parent process. Let the exception rise, if not.
|
||||
"""
|
||||
pickle.loads(pickle.dumps(obj))
|
||||
|
||||
def _print_unpicklable_subtest(self, test, subtest, pickle_exc):
|
||||
print("""
|
||||
Subtest failed:
|
||||
|
||||
test: {}
|
||||
subtest: {}
|
||||
|
||||
Unfortunately, the subtest that failed cannot be pickled, so the parallel
|
||||
test runner cannot handle it cleanly. Here is the pickling error:
|
||||
|
||||
> {}
|
||||
|
||||
You should re-run this test with --parallel=1 to reproduce the failure
|
||||
with a cleaner failure message.
|
||||
""".format(test, subtest, pickle_exc))
|
||||
|
||||
def check_picklable(self, test, err):
|
||||
# Ensure that sys.exc_info() tuples are picklable. This displays a
|
||||
# clear multiprocessing.pool.RemoteTraceback generated in the child
|
||||
# process instead of a multiprocessing.pool.MaybeEncodingError, making
|
||||
# the root cause easier to figure out for users who aren't familiar
|
||||
# with the multiprocessing module. Since we're in a forked process,
|
||||
# our best chance to communicate with them is to print to stdout.
|
||||
try:
|
||||
self._confirm_picklable(err)
|
||||
except Exception as exc:
|
||||
original_exc_txt = repr(err[1])
|
||||
original_exc_txt = textwrap.fill(original_exc_txt, 75, initial_indent=' ', subsequent_indent=' ')
|
||||
pickle_exc_txt = repr(exc)
|
||||
pickle_exc_txt = textwrap.fill(pickle_exc_txt, 75, initial_indent=' ', subsequent_indent=' ')
|
||||
if tblib is None:
|
||||
print("""
|
||||
|
||||
{} failed:
|
||||
|
||||
{}
|
||||
|
||||
Unfortunately, tracebacks cannot be pickled, making it impossible for the
|
||||
parallel test runner to handle this exception cleanly.
|
||||
|
||||
In order to see the traceback, you should install tblib:
|
||||
|
||||
python -m pip install tblib
|
||||
""".format(test, original_exc_txt))
|
||||
else:
|
||||
print("""
|
||||
|
||||
{} failed:
|
||||
|
||||
{}
|
||||
|
||||
Unfortunately, the exception it raised cannot be pickled, making it impossible
|
||||
for the parallel test runner to handle it cleanly.
|
||||
|
||||
Here's the error encountered while trying to pickle the exception:
|
||||
|
||||
{}
|
||||
|
||||
You should re-run this test with the --parallel=1 option to reproduce the
|
||||
failure and get a correct traceback.
|
||||
""".format(test, original_exc_txt, pickle_exc_txt))
|
||||
raise
|
||||
|
||||
def check_subtest_picklable(self, test, subtest):
|
||||
try:
|
||||
self._confirm_picklable(subtest)
|
||||
except Exception as exc:
|
||||
self._print_unpicklable_subtest(test, subtest, exc)
|
||||
raise
|
||||
|
||||
def stop_if_failfast(self):
|
||||
if self.failfast:
|
||||
self.stop()
|
||||
|
||||
def stop(self):
|
||||
self.shouldStop = True
|
||||
|
||||
def startTestRun(self):
|
||||
self.events.append(('startTestRun',))
|
||||
|
||||
def stopTestRun(self):
|
||||
self.events.append(('stopTestRun',))
|
||||
|
||||
def startTest(self, test):
|
||||
self.testsRun += 1
|
||||
self.events.append(('startTest', self.test_index))
|
||||
|
||||
def stopTest(self, test):
|
||||
self.events.append(('stopTest', self.test_index))
|
||||
|
||||
def addError(self, test, err):
|
||||
self.check_picklable(test, err)
|
||||
self.events.append(('addError', self.test_index, err))
|
||||
self.stop_if_failfast()
|
||||
|
||||
def addFailure(self, test, err):
|
||||
self.check_picklable(test, err)
|
||||
self.events.append(('addFailure', self.test_index, err))
|
||||
self.stop_if_failfast()
|
||||
|
||||
def addSubTest(self, test, subtest, err):
|
||||
# Follow Python 3.5's implementation of unittest.TestResult.addSubTest()
|
||||
# by not doing anything when a subtest is successful.
|
||||
if err is not None:
|
||||
# Call check_picklable() before check_subtest_picklable() since
|
||||
# check_picklable() performs the tblib check.
|
||||
self.check_picklable(test, err)
|
||||
self.check_subtest_picklable(test, subtest)
|
||||
self.events.append(('addSubTest', self.test_index, subtest, err))
|
||||
self.stop_if_failfast()
|
||||
|
||||
def addSuccess(self, test):
|
||||
self.events.append(('addSuccess', self.test_index))
|
||||
|
||||
def addSkip(self, test, reason):
|
||||
self.events.append(('addSkip', self.test_index, reason))
|
||||
|
||||
def addExpectedFailure(self, test, err):
|
||||
# If tblib isn't installed, pickling the traceback will always fail.
|
||||
# However we don't want tblib to be required for running the tests
|
||||
# when they pass or fail as expected. Drop the traceback when an
|
||||
# expected failure occurs.
|
||||
if tblib is None:
|
||||
err = err[0], err[1], None
|
||||
self.check_picklable(test, err)
|
||||
self.events.append(('addExpectedFailure', self.test_index, err))
|
||||
|
||||
def addUnexpectedSuccess(self, test):
|
||||
self.events.append(('addUnexpectedSuccess', self.test_index))
|
||||
self.stop_if_failfast()
|
||||
|
||||
|
||||
class RemoteTestRunner:
|
||||
"""
|
||||
Run tests and record everything but don't display anything.
|
||||
|
||||
The implementation matches the unpythonic coding style of unittest2.
|
||||
"""
|
||||
|
||||
resultclass = RemoteTestResult
|
||||
|
||||
def __init__(self, failfast=False, resultclass=None):
|
||||
self.failfast = failfast
|
||||
if resultclass is not None:
|
||||
self.resultclass = resultclass
|
||||
|
||||
def run(self, test):
|
||||
result = self.resultclass()
|
||||
unittest.registerResult(result)
|
||||
result.failfast = self.failfast
|
||||
test(result)
|
||||
return result
|
||||
|
||||
|
||||
def default_test_processes():
|
||||
"""Default number of test processes when using the --parallel option."""
|
||||
# The current implementation of the parallel test runner requires
|
||||
# multiprocessing to start subprocesses with fork().
|
||||
if multiprocessing.get_start_method() != 'fork':
|
||||
return 1
|
||||
try:
|
||||
return int(os.environ['DJANGO_TEST_PROCESSES'])
|
||||
except KeyError:
|
||||
return multiprocessing.cpu_count()
|
||||
|
||||
|
||||
_worker_id = 0
|
||||
|
||||
|
||||
def _init_worker(counter):
|
||||
"""
|
||||
Switch to databases dedicated to this worker.
|
||||
|
||||
This helper lives at module-level because of the multiprocessing module's
|
||||
requirements.
|
||||
"""
|
||||
|
||||
global _worker_id
|
||||
|
||||
with counter.get_lock():
|
||||
counter.value += 1
|
||||
_worker_id = counter.value
|
||||
|
||||
for alias in connections:
|
||||
connection = connections[alias]
|
||||
settings_dict = connection.creation.get_test_db_clone_settings(str(_worker_id))
|
||||
# connection.settings_dict must be updated in place for changes to be
|
||||
# reflected in django.db.connections. If the following line assigned
|
||||
# connection.settings_dict = settings_dict, new threads would connect
|
||||
# to the default database instead of the appropriate clone.
|
||||
connection.settings_dict.update(settings_dict)
|
||||
connection.close()
|
||||
|
||||
|
||||
def _run_subsuite(args):
|
||||
"""
|
||||
Run a suite of tests with a RemoteTestRunner and return a RemoteTestResult.
|
||||
|
||||
This helper lives at module-level and its arguments are wrapped in a tuple
|
||||
because of the multiprocessing module's requirements.
|
||||
"""
|
||||
runner_class, subsuite_index, subsuite, failfast = args
|
||||
runner = runner_class(failfast=failfast)
|
||||
result = runner.run(subsuite)
|
||||
return subsuite_index, result.events
|
||||
|
||||
|
||||
class ParallelTestSuite(unittest.TestSuite):
|
||||
"""
|
||||
Run a series of tests in parallel in several processes.
|
||||
|
||||
While the unittest module's documentation implies that orchestrating the
|
||||
execution of tests is the responsibility of the test runner, in practice,
|
||||
it appears that TestRunner classes are more concerned with formatting and
|
||||
displaying test results.
|
||||
|
||||
Since there are fewer use cases for customizing TestSuite than TestRunner,
|
||||
implementing parallelization at the level of the TestSuite improves
|
||||
interoperability with existing custom test runners. A single instance of a
|
||||
test runner can still collect results from all tests without being aware
|
||||
that they have been run in parallel.
|
||||
"""
|
||||
|
||||
# In case someone wants to modify these in a subclass.
|
||||
init_worker = _init_worker
|
||||
run_subsuite = _run_subsuite
|
||||
runner_class = RemoteTestRunner
|
||||
|
||||
def __init__(self, suite, processes, failfast=False):
|
||||
self.subsuites = partition_suite_by_case(suite)
|
||||
self.processes = processes
|
||||
self.failfast = failfast
|
||||
super().__init__()
|
||||
|
||||
def run(self, result):
|
||||
"""
|
||||
Distribute test cases across workers.
|
||||
|
||||
Return an identifier of each test case with its result in order to use
|
||||
imap_unordered to show results as soon as they're available.
|
||||
|
||||
To minimize pickling errors when getting results from workers:
|
||||
|
||||
- pass back numeric indexes in self.subsuites instead of tests
|
||||
- make tracebacks picklable with tblib, if available
|
||||
|
||||
Even with tblib, errors may still occur for dynamically created
|
||||
exception classes which cannot be unpickled.
|
||||
"""
|
||||
counter = multiprocessing.Value(ctypes.c_int, 0)
|
||||
pool = multiprocessing.Pool(
|
||||
processes=self.processes,
|
||||
initializer=self.init_worker.__func__,
|
||||
initargs=[counter],
|
||||
)
|
||||
args = [
|
||||
(self.runner_class, index, subsuite, self.failfast)
|
||||
for index, subsuite in enumerate(self.subsuites)
|
||||
]
|
||||
test_results = pool.imap_unordered(self.run_subsuite.__func__, args)
|
||||
|
||||
while True:
|
||||
if result.shouldStop:
|
||||
pool.terminate()
|
||||
break
|
||||
|
||||
try:
|
||||
subsuite_index, events = test_results.next(timeout=0.1)
|
||||
except multiprocessing.TimeoutError:
|
||||
continue
|
||||
except StopIteration:
|
||||
pool.close()
|
||||
break
|
||||
|
||||
tests = list(self.subsuites[subsuite_index])
|
||||
for event in events:
|
||||
event_name = event[0]
|
||||
handler = getattr(result, event_name, None)
|
||||
if handler is None:
|
||||
continue
|
||||
test = tests[event[1]]
|
||||
args = event[2:]
|
||||
handler(test, *args)
|
||||
|
||||
pool.join()
|
||||
|
||||
return result
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.subsuites)
|
||||
|
||||
|
||||
class DiscoverRunner:
|
||||
"""A Django test runner that uses unittest2 test discovery."""
|
||||
|
||||
test_suite = unittest.TestSuite
|
||||
parallel_test_suite = ParallelTestSuite
|
||||
test_runner = unittest.TextTestRunner
|
||||
test_loader = unittest.defaultTestLoader
|
||||
reorder_by = (TestCase, SimpleTestCase)
|
||||
|
||||
def __init__(self, pattern=None, top_level=None, verbosity=1,
|
||||
interactive=True, failfast=False, keepdb=False,
|
||||
reverse=False, debug_mode=False, debug_sql=False, parallel=0,
|
||||
tags=None, exclude_tags=None, test_name_patterns=None,
|
||||
pdb=False, **kwargs):
|
||||
|
||||
self.pattern = pattern
|
||||
self.top_level = top_level
|
||||
self.verbosity = verbosity
|
||||
self.interactive = interactive
|
||||
self.failfast = failfast
|
||||
self.keepdb = keepdb
|
||||
self.reverse = reverse
|
||||
self.debug_mode = debug_mode
|
||||
self.debug_sql = debug_sql
|
||||
self.parallel = parallel
|
||||
self.tags = set(tags or [])
|
||||
self.exclude_tags = set(exclude_tags or [])
|
||||
self.pdb = pdb
|
||||
if self.pdb and self.parallel > 1:
|
||||
raise ValueError('You cannot use --pdb with parallel tests; pass --parallel=1 to use it.')
|
||||
self.test_name_patterns = None
|
||||
if test_name_patterns:
|
||||
# unittest does not export the _convert_select_pattern function
|
||||
# that converts command-line arguments to patterns.
|
||||
self.test_name_patterns = {
|
||||
pattern if '*' in pattern else '*%s*' % pattern
|
||||
for pattern in test_name_patterns
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def add_arguments(cls, parser):
|
||||
parser.add_argument(
|
||||
'-t', '--top-level-directory', dest='top_level',
|
||||
help='Top level of project for unittest discovery.',
|
||||
)
|
||||
parser.add_argument(
|
||||
'-p', '--pattern', default="test*.py",
|
||||
help='The test matching pattern. Defaults to test*.py.',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--keepdb', action='store_true',
|
||||
help='Preserves the test DB between runs.'
|
||||
)
|
||||
parser.add_argument(
|
||||
'-r', '--reverse', action='store_true',
|
||||
help='Reverses test cases order.',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--debug-mode', action='store_true',
|
||||
help='Sets settings.DEBUG to True.',
|
||||
)
|
||||
parser.add_argument(
|
||||
'-d', '--debug-sql', action='store_true',
|
||||
help='Prints logged SQL queries on failure.',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--parallel', nargs='?', default=1, type=int,
|
||||
const=default_test_processes(), metavar='N',
|
||||
help='Run tests using up to N parallel processes.',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--tag', action='append', dest='tags',
|
||||
help='Run only tests with the specified tag. Can be used multiple times.',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--exclude-tag', action='append', dest='exclude_tags',
|
||||
help='Do not run tests with the specified tag. Can be used multiple times.',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--pdb', action='store_true',
|
||||
help='Runs a debugger (pdb, or ipdb if installed) on error or failure.'
|
||||
)
|
||||
if PY37:
|
||||
parser.add_argument(
|
||||
'-k', action='append', dest='test_name_patterns',
|
||||
help=(
|
||||
'Only run test methods and classes that match the pattern '
|
||||
'or substring. Can be used multiple times. Same as '
|
||||
'unittest -k option.'
|
||||
),
|
||||
)
|
||||
|
||||
def setup_test_environment(self, **kwargs):
|
||||
setup_test_environment(debug=self.debug_mode)
|
||||
unittest.installHandler()
|
||||
|
||||
def build_suite(self, test_labels=None, extra_tests=None, **kwargs):
|
||||
suite = self.test_suite()
|
||||
test_labels = test_labels or ['.']
|
||||
extra_tests = extra_tests or []
|
||||
self.test_loader.testNamePatterns = self.test_name_patterns
|
||||
|
||||
discover_kwargs = {}
|
||||
if self.pattern is not None:
|
||||
discover_kwargs['pattern'] = self.pattern
|
||||
if self.top_level is not None:
|
||||
discover_kwargs['top_level_dir'] = self.top_level
|
||||
|
||||
for label in test_labels:
|
||||
kwargs = discover_kwargs.copy()
|
||||
tests = None
|
||||
|
||||
label_as_path = os.path.abspath(label)
|
||||
|
||||
# if a module, or "module.ClassName[.method_name]", just run those
|
||||
if not os.path.exists(label_as_path):
|
||||
tests = self.test_loader.loadTestsFromName(label)
|
||||
elif os.path.isdir(label_as_path) and not self.top_level:
|
||||
# Try to be a bit smarter than unittest about finding the
|
||||
# default top-level for a given directory path, to avoid
|
||||
# breaking relative imports. (Unittest's default is to set
|
||||
# top-level equal to the path, which means relative imports
|
||||
# will result in "Attempted relative import in non-package.").
|
||||
|
||||
# We'd be happy to skip this and require dotted module paths
|
||||
# (which don't cause this problem) instead of file paths (which
|
||||
# do), but in the case of a directory in the cwd, which would
|
||||
# be equally valid if considered as a top-level module or as a
|
||||
# directory path, unittest unfortunately prefers the latter.
|
||||
|
||||
top_level = label_as_path
|
||||
while True:
|
||||
init_py = os.path.join(top_level, '__init__.py')
|
||||
if os.path.exists(init_py):
|
||||
try_next = os.path.dirname(top_level)
|
||||
if try_next == top_level:
|
||||
# __init__.py all the way down? give up.
|
||||
break
|
||||
top_level = try_next
|
||||
continue
|
||||
break
|
||||
kwargs['top_level_dir'] = top_level
|
||||
|
||||
if not (tests and tests.countTestCases()) and is_discoverable(label):
|
||||
# Try discovery if path is a package or directory
|
||||
tests = self.test_loader.discover(start_dir=label, **kwargs)
|
||||
|
||||
# Make unittest forget the top-level dir it calculated from this
|
||||
# run, to support running tests from two different top-levels.
|
||||
self.test_loader._top_level_dir = None
|
||||
|
||||
suite.addTests(tests)
|
||||
|
||||
for test in extra_tests:
|
||||
suite.addTest(test)
|
||||
|
||||
if self.tags or self.exclude_tags:
|
||||
if self.verbosity >= 2:
|
||||
if self.tags:
|
||||
print('Including test tag(s): %s.' % ', '.join(sorted(self.tags)))
|
||||
if self.exclude_tags:
|
||||
print('Excluding test tag(s): %s.' % ', '.join(sorted(self.exclude_tags)))
|
||||
suite = filter_tests_by_tags(suite, self.tags, self.exclude_tags)
|
||||
suite = reorder_suite(suite, self.reorder_by, self.reverse)
|
||||
|
||||
if self.parallel > 1:
|
||||
parallel_suite = self.parallel_test_suite(suite, self.parallel, self.failfast)
|
||||
|
||||
# Since tests are distributed across processes on a per-TestCase
|
||||
# basis, there's no need for more processes than TestCases.
|
||||
parallel_units = len(parallel_suite.subsuites)
|
||||
self.parallel = min(self.parallel, parallel_units)
|
||||
|
||||
# If there's only one TestCase, parallelization isn't needed.
|
||||
if self.parallel > 1:
|
||||
suite = parallel_suite
|
||||
|
||||
return suite
|
||||
|
||||
def setup_databases(self, **kwargs):
|
||||
return _setup_databases(
|
||||
self.verbosity, self.interactive, self.keepdb, self.debug_sql,
|
||||
self.parallel, **kwargs
|
||||
)
|
||||
|
||||
def get_resultclass(self):
|
||||
if self.debug_sql:
|
||||
return DebugSQLTextTestResult
|
||||
elif self.pdb:
|
||||
return PDBDebugResult
|
||||
|
||||
def get_test_runner_kwargs(self):
|
||||
return {
|
||||
'failfast': self.failfast,
|
||||
'resultclass': self.get_resultclass(),
|
||||
'verbosity': self.verbosity,
|
||||
}
|
||||
|
||||
def run_checks(self):
|
||||
# Checks are run after database creation since some checks require
|
||||
# database access.
|
||||
call_command('check', verbosity=self.verbosity)
|
||||
|
||||
def run_suite(self, suite, **kwargs):
|
||||
kwargs = self.get_test_runner_kwargs()
|
||||
runner = self.test_runner(**kwargs)
|
||||
return runner.run(suite)
|
||||
|
||||
def teardown_databases(self, old_config, **kwargs):
|
||||
"""Destroy all the non-mirror databases."""
|
||||
_teardown_databases(
|
||||
old_config,
|
||||
verbosity=self.verbosity,
|
||||
parallel=self.parallel,
|
||||
keepdb=self.keepdb,
|
||||
)
|
||||
|
||||
def teardown_test_environment(self, **kwargs):
|
||||
unittest.removeHandler()
|
||||
teardown_test_environment()
|
||||
|
||||
def suite_result(self, suite, result, **kwargs):
|
||||
return len(result.failures) + len(result.errors)
|
||||
|
||||
def _get_databases(self, suite):
|
||||
databases = set()
|
||||
for test in suite:
|
||||
if isinstance(test, unittest.TestCase):
|
||||
test_databases = getattr(test, 'databases', None)
|
||||
if test_databases == '__all__':
|
||||
return set(connections)
|
||||
if test_databases:
|
||||
databases.update(test_databases)
|
||||
else:
|
||||
databases.update(self._get_databases(test))
|
||||
return databases
|
||||
|
||||
def get_databases(self, suite):
|
||||
databases = self._get_databases(suite)
|
||||
if self.verbosity >= 2:
|
||||
unused_databases = [alias for alias in connections if alias not in databases]
|
||||
if unused_databases:
|
||||
print('Skipping setup of unused database(s): %s.' % ', '.join(sorted(unused_databases)))
|
||||
return databases
|
||||
|
||||
def run_tests(self, test_labels, extra_tests=None, **kwargs):
|
||||
"""
|
||||
Run the unit tests for all the test labels in the provided list.
|
||||
|
||||
Test labels should be dotted Python paths to test modules, test
|
||||
classes, or test methods.
|
||||
|
||||
A list of 'extra' tests may also be provided; these tests
|
||||
will be added to the test suite.
|
||||
|
||||
Return the number of tests that failed.
|
||||
"""
|
||||
self.setup_test_environment()
|
||||
suite = self.build_suite(test_labels, extra_tests)
|
||||
databases = self.get_databases(suite)
|
||||
old_config = self.setup_databases(aliases=databases)
|
||||
run_failed = False
|
||||
try:
|
||||
self.run_checks()
|
||||
result = self.run_suite(suite)
|
||||
except Exception:
|
||||
run_failed = True
|
||||
raise
|
||||
finally:
|
||||
try:
|
||||
self.teardown_databases(old_config)
|
||||
self.teardown_test_environment()
|
||||
except Exception:
|
||||
# Silence teardown exceptions if an exception was raised during
|
||||
# runs to avoid shadowing it.
|
||||
if not run_failed:
|
||||
raise
|
||||
return self.suite_result(suite, result)
|
||||
|
||||
|
||||
def is_discoverable(label):
|
||||
"""
|
||||
Check if a test label points to a Python package or file directory.
|
||||
|
||||
Relative labels like "." and ".." are seen as directories.
|
||||
"""
|
||||
try:
|
||||
mod = import_module(label)
|
||||
except (ImportError, TypeError):
|
||||
pass
|
||||
else:
|
||||
return hasattr(mod, '__path__')
|
||||
|
||||
return os.path.isdir(os.path.abspath(label))
|
||||
|
||||
|
||||
def reorder_suite(suite, classes, reverse=False):
|
||||
"""
|
||||
Reorder a test suite by test type.
|
||||
|
||||
`classes` is a sequence of types
|
||||
|
||||
All tests of type classes[0] are placed first, then tests of type
|
||||
classes[1], etc. Tests with no match in classes are placed last.
|
||||
|
||||
If `reverse` is True, sort tests within classes in opposite order but
|
||||
don't reverse test classes.
|
||||
"""
|
||||
class_count = len(classes)
|
||||
suite_class = type(suite)
|
||||
bins = [OrderedSet() for i in range(class_count + 1)]
|
||||
partition_suite_by_type(suite, classes, bins, reverse=reverse)
|
||||
reordered_suite = suite_class()
|
||||
for i in range(class_count + 1):
|
||||
reordered_suite.addTests(bins[i])
|
||||
return reordered_suite
|
||||
|
||||
|
||||
def partition_suite_by_type(suite, classes, bins, reverse=False):
|
||||
"""
|
||||
Partition a test suite by test type. Also prevent duplicated tests.
|
||||
|
||||
classes is a sequence of types
|
||||
bins is a sequence of TestSuites, one more than classes
|
||||
reverse changes the ordering of tests within bins
|
||||
|
||||
Tests of type classes[i] are added to bins[i],
|
||||
tests with no match found in classes are place in bins[-1]
|
||||
"""
|
||||
suite_class = type(suite)
|
||||
if reverse:
|
||||
suite = reversed(tuple(suite))
|
||||
for test in suite:
|
||||
if isinstance(test, suite_class):
|
||||
partition_suite_by_type(test, classes, bins, reverse=reverse)
|
||||
else:
|
||||
for i in range(len(classes)):
|
||||
if isinstance(test, classes[i]):
|
||||
bins[i].add(test)
|
||||
break
|
||||
else:
|
||||
bins[-1].add(test)
|
||||
|
||||
|
||||
def partition_suite_by_case(suite):
|
||||
"""Partition a test suite by test case, preserving the order of tests."""
|
||||
groups = []
|
||||
suite_class = type(suite)
|
||||
for test_type, test_group in itertools.groupby(suite, type):
|
||||
if issubclass(test_type, unittest.TestCase):
|
||||
groups.append(suite_class(test_group))
|
||||
else:
|
||||
for item in test_group:
|
||||
groups.extend(partition_suite_by_case(item))
|
||||
return groups
|
||||
|
||||
|
||||
def filter_tests_by_tags(suite, tags, exclude_tags):
|
||||
suite_class = type(suite)
|
||||
filtered_suite = suite_class()
|
||||
|
||||
for test in suite:
|
||||
if isinstance(test, suite_class):
|
||||
filtered_suite.addTests(filter_tests_by_tags(test, tags, exclude_tags))
|
||||
else:
|
||||
test_tags = set(getattr(test, 'tags', set()))
|
||||
test_fn_name = getattr(test, '_testMethodName', str(test))
|
||||
test_fn = getattr(test, test_fn_name, test)
|
||||
test_fn_tags = set(getattr(test_fn, 'tags', set()))
|
||||
all_tags = test_tags.union(test_fn_tags)
|
||||
matched_tags = all_tags.intersection(tags)
|
||||
if (matched_tags or not tags) and not all_tags.intersection(exclude_tags):
|
||||
filtered_suite.addTest(test)
|
||||
|
||||
return filtered_suite
|
||||
130
venv/lib/python3.8/site-packages/django/test/selenium.py
Normal file
130
venv/lib/python3.8/site-packages/django/test/selenium.py
Normal file
@@ -0,0 +1,130 @@
|
||||
import sys
|
||||
import unittest
|
||||
from contextlib import contextmanager
|
||||
|
||||
from django.test import LiveServerTestCase, tag
|
||||
from django.utils.decorators import classproperty
|
||||
from django.utils.module_loading import import_string
|
||||
from django.utils.text import capfirst
|
||||
|
||||
|
||||
class SeleniumTestCaseBase(type(LiveServerTestCase)):
|
||||
# List of browsers to dynamically create test classes for.
|
||||
browsers = []
|
||||
# A selenium hub URL to test against.
|
||||
selenium_hub = None
|
||||
# The external host Selenium Hub can reach.
|
||||
external_host = None
|
||||
# Sentinel value to differentiate browser-specific instances.
|
||||
browser = None
|
||||
# Run browsers in headless mode.
|
||||
headless = False
|
||||
|
||||
def __new__(cls, name, bases, attrs):
|
||||
"""
|
||||
Dynamically create new classes and add them to the test module when
|
||||
multiple browsers specs are provided (e.g. --selenium=firefox,chrome).
|
||||
"""
|
||||
test_class = super().__new__(cls, name, bases, attrs)
|
||||
# If the test class is either browser-specific or a test base, return it.
|
||||
if test_class.browser or not any(name.startswith('test') and callable(value) for name, value in attrs.items()):
|
||||
return test_class
|
||||
elif test_class.browsers:
|
||||
# Reuse the created test class to make it browser-specific.
|
||||
# We can't rename it to include the browser name or create a
|
||||
# subclass like we do with the remaining browsers as it would
|
||||
# either duplicate tests or prevent pickling of its instances.
|
||||
first_browser = test_class.browsers[0]
|
||||
test_class.browser = first_browser
|
||||
# Listen on an external interface if using a selenium hub.
|
||||
host = test_class.host if not test_class.selenium_hub else '0.0.0.0'
|
||||
test_class.host = host
|
||||
test_class.external_host = cls.external_host
|
||||
# Create subclasses for each of the remaining browsers and expose
|
||||
# them through the test's module namespace.
|
||||
module = sys.modules[test_class.__module__]
|
||||
for browser in test_class.browsers[1:]:
|
||||
browser_test_class = cls.__new__(
|
||||
cls,
|
||||
"%s%s" % (capfirst(browser), name),
|
||||
(test_class,),
|
||||
{
|
||||
'browser': browser,
|
||||
'host': host,
|
||||
'external_host': cls.external_host,
|
||||
'__module__': test_class.__module__,
|
||||
}
|
||||
)
|
||||
setattr(module, browser_test_class.__name__, browser_test_class)
|
||||
return test_class
|
||||
# If no browsers were specified, skip this class (it'll still be discovered).
|
||||
return unittest.skip('No browsers specified.')(test_class)
|
||||
|
||||
@classmethod
|
||||
def import_webdriver(cls, browser):
|
||||
return import_string("selenium.webdriver.%s.webdriver.WebDriver" % browser)
|
||||
|
||||
@classmethod
|
||||
def import_options(cls, browser):
|
||||
return import_string('selenium.webdriver.%s.options.Options' % browser)
|
||||
|
||||
@classmethod
|
||||
def get_capability(cls, browser):
|
||||
from selenium.webdriver.common.desired_capabilities import DesiredCapabilities
|
||||
return getattr(DesiredCapabilities, browser.upper())
|
||||
|
||||
def create_options(self):
|
||||
options = self.import_options(self.browser)()
|
||||
if self.headless:
|
||||
try:
|
||||
options.headless = True
|
||||
except AttributeError:
|
||||
pass # Only Chrome and Firefox support the headless mode.
|
||||
return options
|
||||
|
||||
def create_webdriver(self):
|
||||
if self.selenium_hub:
|
||||
from selenium import webdriver
|
||||
return webdriver.Remote(
|
||||
command_executor=self.selenium_hub,
|
||||
desired_capabilities=self.get_capability(self.browser),
|
||||
)
|
||||
return self.import_webdriver(self.browser)(options=self.create_options())
|
||||
|
||||
|
||||
@tag('selenium')
|
||||
class SeleniumTestCase(LiveServerTestCase, metaclass=SeleniumTestCaseBase):
|
||||
implicit_wait = 10
|
||||
external_host = None
|
||||
|
||||
@classproperty
|
||||
def live_server_url(cls):
|
||||
return 'http://%s:%s' % (cls.external_host or cls.host, cls.server_thread.port)
|
||||
|
||||
@classproperty
|
||||
def allowed_host(cls):
|
||||
return cls.external_host or cls.host
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.selenium = cls.create_webdriver()
|
||||
cls.selenium.implicitly_wait(cls.implicit_wait)
|
||||
super().setUpClass()
|
||||
|
||||
@classmethod
|
||||
def _tearDownClassInternal(cls):
|
||||
# quit() the WebDriver before attempting to terminate and join the
|
||||
# single-threaded LiveServerThread to avoid a dead lock if the browser
|
||||
# kept a connection alive.
|
||||
if hasattr(cls, 'selenium'):
|
||||
cls.selenium.quit()
|
||||
super()._tearDownClassInternal()
|
||||
|
||||
@contextmanager
|
||||
def disable_implicit_wait(self):
|
||||
"""Disable the default implicit wait."""
|
||||
self.selenium.implicitly_wait(0)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self.selenium.implicitly_wait(self.implicit_wait)
|
||||
206
venv/lib/python3.8/site-packages/django/test/signals.py
Normal file
206
venv/lib/python3.8/site-packages/django/test/signals.py
Normal file
@@ -0,0 +1,206 @@
|
||||
import os
|
||||
import time
|
||||
import warnings
|
||||
|
||||
from asgiref.local import Local
|
||||
|
||||
from django.apps import apps
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.core.signals import setting_changed
|
||||
from django.db import connections, router
|
||||
from django.db.utils import ConnectionRouter
|
||||
from django.dispatch import Signal, receiver
|
||||
from django.utils import timezone
|
||||
from django.utils.formats import FORMAT_SETTINGS, reset_format_cache
|
||||
from django.utils.functional import empty
|
||||
|
||||
template_rendered = Signal(providing_args=["template", "context"])
|
||||
|
||||
# Most setting_changed receivers are supposed to be added below,
|
||||
# except for cases where the receiver is related to a contrib app.
|
||||
|
||||
# Settings that may not work well when using 'override_settings' (#19031)
|
||||
COMPLEX_OVERRIDE_SETTINGS = {'DATABASES'}
|
||||
|
||||
|
||||
@receiver(setting_changed)
|
||||
def clear_cache_handlers(**kwargs):
|
||||
if kwargs['setting'] == 'CACHES':
|
||||
from django.core.cache import caches
|
||||
caches._caches = Local()
|
||||
|
||||
|
||||
@receiver(setting_changed)
|
||||
def update_installed_apps(**kwargs):
|
||||
if kwargs['setting'] == 'INSTALLED_APPS':
|
||||
# Rebuild any AppDirectoriesFinder instance.
|
||||
from django.contrib.staticfiles.finders import get_finder
|
||||
get_finder.cache_clear()
|
||||
# Rebuild management commands cache
|
||||
from django.core.management import get_commands
|
||||
get_commands.cache_clear()
|
||||
# Rebuild get_app_template_dirs cache.
|
||||
from django.template.utils import get_app_template_dirs
|
||||
get_app_template_dirs.cache_clear()
|
||||
# Rebuild translations cache.
|
||||
from django.utils.translation import trans_real
|
||||
trans_real._translations = {}
|
||||
|
||||
|
||||
@receiver(setting_changed)
|
||||
def update_connections_time_zone(**kwargs):
|
||||
if kwargs['setting'] == 'TIME_ZONE':
|
||||
# Reset process time zone
|
||||
if hasattr(time, 'tzset'):
|
||||
if kwargs['value']:
|
||||
os.environ['TZ'] = kwargs['value']
|
||||
else:
|
||||
os.environ.pop('TZ', None)
|
||||
time.tzset()
|
||||
|
||||
# Reset local time zone cache
|
||||
timezone.get_default_timezone.cache_clear()
|
||||
|
||||
# Reset the database connections' time zone
|
||||
if kwargs['setting'] in {'TIME_ZONE', 'USE_TZ'}:
|
||||
for conn in connections.all():
|
||||
try:
|
||||
del conn.timezone
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
del conn.timezone_name
|
||||
except AttributeError:
|
||||
pass
|
||||
conn.ensure_timezone()
|
||||
|
||||
|
||||
@receiver(setting_changed)
|
||||
def clear_routers_cache(**kwargs):
|
||||
if kwargs['setting'] == 'DATABASE_ROUTERS':
|
||||
router.routers = ConnectionRouter().routers
|
||||
|
||||
|
||||
@receiver(setting_changed)
|
||||
def reset_template_engines(**kwargs):
|
||||
if kwargs['setting'] in {
|
||||
'TEMPLATES',
|
||||
'DEBUG',
|
||||
'FILE_CHARSET',
|
||||
'INSTALLED_APPS',
|
||||
}:
|
||||
from django.template import engines
|
||||
try:
|
||||
del engines.templates
|
||||
except AttributeError:
|
||||
pass
|
||||
engines._templates = None
|
||||
engines._engines = {}
|
||||
from django.template.engine import Engine
|
||||
Engine.get_default.cache_clear()
|
||||
from django.forms.renderers import get_default_renderer
|
||||
get_default_renderer.cache_clear()
|
||||
|
||||
|
||||
@receiver(setting_changed)
|
||||
def clear_serializers_cache(**kwargs):
|
||||
if kwargs['setting'] == 'SERIALIZATION_MODULES':
|
||||
from django.core import serializers
|
||||
serializers._serializers = {}
|
||||
|
||||
|
||||
@receiver(setting_changed)
|
||||
def language_changed(**kwargs):
|
||||
if kwargs['setting'] in {'LANGUAGES', 'LANGUAGE_CODE', 'LOCALE_PATHS'}:
|
||||
from django.utils.translation import trans_real
|
||||
trans_real._default = None
|
||||
trans_real._active = Local()
|
||||
if kwargs['setting'] in {'LANGUAGES', 'LOCALE_PATHS'}:
|
||||
from django.utils.translation import trans_real
|
||||
trans_real._translations = {}
|
||||
trans_real.check_for_language.cache_clear()
|
||||
|
||||
|
||||
@receiver(setting_changed)
|
||||
def localize_settings_changed(**kwargs):
|
||||
if kwargs['setting'] in FORMAT_SETTINGS or kwargs['setting'] == 'USE_THOUSAND_SEPARATOR':
|
||||
reset_format_cache()
|
||||
|
||||
|
||||
@receiver(setting_changed)
|
||||
def file_storage_changed(**kwargs):
|
||||
if kwargs['setting'] == 'DEFAULT_FILE_STORAGE':
|
||||
from django.core.files.storage import default_storage
|
||||
default_storage._wrapped = empty
|
||||
|
||||
|
||||
@receiver(setting_changed)
|
||||
def complex_setting_changed(**kwargs):
|
||||
if kwargs['enter'] and kwargs['setting'] in COMPLEX_OVERRIDE_SETTINGS:
|
||||
# Considering the current implementation of the signals framework,
|
||||
# this stacklevel shows the line containing the override_settings call.
|
||||
warnings.warn("Overriding setting %s can lead to unexpected behavior."
|
||||
% kwargs['setting'], stacklevel=6)
|
||||
|
||||
|
||||
@receiver(setting_changed)
|
||||
def root_urlconf_changed(**kwargs):
|
||||
if kwargs['setting'] == 'ROOT_URLCONF':
|
||||
from django.urls import clear_url_caches, set_urlconf
|
||||
clear_url_caches()
|
||||
set_urlconf(None)
|
||||
|
||||
|
||||
@receiver(setting_changed)
|
||||
def static_storage_changed(**kwargs):
|
||||
if kwargs['setting'] in {
|
||||
'STATICFILES_STORAGE',
|
||||
'STATIC_ROOT',
|
||||
'STATIC_URL',
|
||||
}:
|
||||
from django.contrib.staticfiles.storage import staticfiles_storage
|
||||
staticfiles_storage._wrapped = empty
|
||||
|
||||
|
||||
@receiver(setting_changed)
|
||||
def static_finders_changed(**kwargs):
|
||||
if kwargs['setting'] in {
|
||||
'STATICFILES_DIRS',
|
||||
'STATIC_ROOT',
|
||||
}:
|
||||
from django.contrib.staticfiles.finders import get_finder
|
||||
get_finder.cache_clear()
|
||||
|
||||
|
||||
@receiver(setting_changed)
|
||||
def auth_password_validators_changed(**kwargs):
|
||||
if kwargs['setting'] == 'AUTH_PASSWORD_VALIDATORS':
|
||||
from django.contrib.auth.password_validation import get_default_password_validators
|
||||
get_default_password_validators.cache_clear()
|
||||
|
||||
|
||||
@receiver(setting_changed)
|
||||
def user_model_swapped(**kwargs):
|
||||
if kwargs['setting'] == 'AUTH_USER_MODEL':
|
||||
apps.clear_cache()
|
||||
try:
|
||||
from django.contrib.auth import get_user_model
|
||||
UserModel = get_user_model()
|
||||
except ImproperlyConfigured:
|
||||
# Some tests set an invalid AUTH_USER_MODEL.
|
||||
pass
|
||||
else:
|
||||
from django.contrib.auth import backends
|
||||
backends.UserModel = UserModel
|
||||
|
||||
from django.contrib.auth import forms
|
||||
forms.UserModel = UserModel
|
||||
|
||||
from django.contrib.auth.handlers import modwsgi
|
||||
modwsgi.UserModel = UserModel
|
||||
|
||||
from django.contrib.auth.management.commands import changepassword
|
||||
changepassword.UserModel = UserModel
|
||||
|
||||
from django.contrib.auth import views
|
||||
views.UserModel = UserModel
|
||||
1516
venv/lib/python3.8/site-packages/django/test/testcases.py
Normal file
1516
venv/lib/python3.8/site-packages/django/test/testcases.py
Normal file
File diff suppressed because it is too large
Load Diff
852
venv/lib/python3.8/site-packages/django/test/utils.py
Normal file
852
venv/lib/python3.8/site-packages/django/test/utils.py
Normal file
@@ -0,0 +1,852 @@
|
||||
import logging
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from functools import wraps
|
||||
from io import StringIO
|
||||
from itertools import chain
|
||||
from types import SimpleNamespace
|
||||
from unittest import TestCase, skipIf, skipUnless
|
||||
from xml.dom.minidom import Node, parseString
|
||||
|
||||
from django.apps import apps
|
||||
from django.apps.registry import Apps
|
||||
from django.conf import UserSettingsHolder, settings
|
||||
from django.core import mail
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.core.signals import request_started
|
||||
from django.db import DEFAULT_DB_ALIAS, connections, reset_queries
|
||||
from django.db.models.options import Options
|
||||
from django.template import Template
|
||||
from django.test.signals import setting_changed, template_rendered
|
||||
from django.urls import get_script_prefix, set_script_prefix
|
||||
from django.utils.translation import deactivate
|
||||
|
||||
try:
|
||||
import jinja2
|
||||
except ImportError:
|
||||
jinja2 = None
|
||||
|
||||
|
||||
__all__ = (
|
||||
'Approximate', 'ContextList', 'isolate_lru_cache', 'get_runner',
|
||||
'modify_settings', 'override_settings',
|
||||
'requires_tz_support',
|
||||
'setup_test_environment', 'teardown_test_environment',
|
||||
)
|
||||
|
||||
TZ_SUPPORT = hasattr(time, 'tzset')
|
||||
|
||||
|
||||
class Approximate:
|
||||
def __init__(self, val, places=7):
|
||||
self.val = val
|
||||
self.places = places
|
||||
|
||||
def __repr__(self):
|
||||
return repr(self.val)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.val == other or round(abs(self.val - other), self.places) == 0
|
||||
|
||||
|
||||
class ContextList(list):
|
||||
"""
|
||||
A wrapper that provides direct key access to context items contained
|
||||
in a list of context objects.
|
||||
"""
|
||||
def __getitem__(self, key):
|
||||
if isinstance(key, str):
|
||||
for subcontext in self:
|
||||
if key in subcontext:
|
||||
return subcontext[key]
|
||||
raise KeyError(key)
|
||||
else:
|
||||
return super().__getitem__(key)
|
||||
|
||||
def get(self, key, default=None):
|
||||
try:
|
||||
return self.__getitem__(key)
|
||||
except KeyError:
|
||||
return default
|
||||
|
||||
def __contains__(self, key):
|
||||
try:
|
||||
self[key]
|
||||
except KeyError:
|
||||
return False
|
||||
return True
|
||||
|
||||
def keys(self):
|
||||
"""
|
||||
Flattened keys of subcontexts.
|
||||
"""
|
||||
return set(chain.from_iterable(d for subcontext in self for d in subcontext))
|
||||
|
||||
|
||||
def instrumented_test_render(self, context):
|
||||
"""
|
||||
An instrumented Template render method, providing a signal that can be
|
||||
intercepted by the test Client.
|
||||
"""
|
||||
template_rendered.send(sender=self, template=self, context=context)
|
||||
return self.nodelist.render(context)
|
||||
|
||||
|
||||
class _TestState:
|
||||
pass
|
||||
|
||||
|
||||
def setup_test_environment(debug=None):
|
||||
"""
|
||||
Perform global pre-test setup, such as installing the instrumented template
|
||||
renderer and setting the email backend to the locmem email backend.
|
||||
"""
|
||||
if hasattr(_TestState, 'saved_data'):
|
||||
# Executing this function twice would overwrite the saved values.
|
||||
raise RuntimeError(
|
||||
"setup_test_environment() was already called and can't be called "
|
||||
"again without first calling teardown_test_environment()."
|
||||
)
|
||||
|
||||
if debug is None:
|
||||
debug = settings.DEBUG
|
||||
|
||||
saved_data = SimpleNamespace()
|
||||
_TestState.saved_data = saved_data
|
||||
|
||||
saved_data.allowed_hosts = settings.ALLOWED_HOSTS
|
||||
# Add the default host of the test client.
|
||||
settings.ALLOWED_HOSTS = [*settings.ALLOWED_HOSTS, 'testserver']
|
||||
|
||||
saved_data.debug = settings.DEBUG
|
||||
settings.DEBUG = debug
|
||||
|
||||
saved_data.email_backend = settings.EMAIL_BACKEND
|
||||
settings.EMAIL_BACKEND = 'django.core.mail.backends.locmem.EmailBackend'
|
||||
|
||||
saved_data.template_render = Template._render
|
||||
Template._render = instrumented_test_render
|
||||
|
||||
mail.outbox = []
|
||||
|
||||
deactivate()
|
||||
|
||||
|
||||
def teardown_test_environment():
|
||||
"""
|
||||
Perform any global post-test teardown, such as restoring the original
|
||||
template renderer and restoring the email sending functions.
|
||||
"""
|
||||
saved_data = _TestState.saved_data
|
||||
|
||||
settings.ALLOWED_HOSTS = saved_data.allowed_hosts
|
||||
settings.DEBUG = saved_data.debug
|
||||
settings.EMAIL_BACKEND = saved_data.email_backend
|
||||
Template._render = saved_data.template_render
|
||||
|
||||
del _TestState.saved_data
|
||||
del mail.outbox
|
||||
|
||||
|
||||
def setup_databases(verbosity, interactive, keepdb=False, debug_sql=False, parallel=0, aliases=None, **kwargs):
|
||||
"""Create the test databases."""
|
||||
test_databases, mirrored_aliases = get_unique_databases_and_mirrors(aliases)
|
||||
|
||||
old_names = []
|
||||
|
||||
for db_name, aliases in test_databases.values():
|
||||
first_alias = None
|
||||
for alias in aliases:
|
||||
connection = connections[alias]
|
||||
old_names.append((connection, db_name, first_alias is None))
|
||||
|
||||
# Actually create the database for the first connection
|
||||
if first_alias is None:
|
||||
first_alias = alias
|
||||
connection.creation.create_test_db(
|
||||
verbosity=verbosity,
|
||||
autoclobber=not interactive,
|
||||
keepdb=keepdb,
|
||||
serialize=connection.settings_dict.get('TEST', {}).get('SERIALIZE', True),
|
||||
)
|
||||
if parallel > 1:
|
||||
for index in range(parallel):
|
||||
connection.creation.clone_test_db(
|
||||
suffix=str(index + 1),
|
||||
verbosity=verbosity,
|
||||
keepdb=keepdb,
|
||||
)
|
||||
# Configure all other connections as mirrors of the first one
|
||||
else:
|
||||
connections[alias].creation.set_as_test_mirror(connections[first_alias].settings_dict)
|
||||
|
||||
# Configure the test mirrors.
|
||||
for alias, mirror_alias in mirrored_aliases.items():
|
||||
connections[alias].creation.set_as_test_mirror(
|
||||
connections[mirror_alias].settings_dict)
|
||||
|
||||
if debug_sql:
|
||||
for alias in connections:
|
||||
connections[alias].force_debug_cursor = True
|
||||
|
||||
return old_names
|
||||
|
||||
|
||||
def dependency_ordered(test_databases, dependencies):
|
||||
"""
|
||||
Reorder test_databases into an order that honors the dependencies
|
||||
described in TEST[DEPENDENCIES].
|
||||
"""
|
||||
ordered_test_databases = []
|
||||
resolved_databases = set()
|
||||
|
||||
# Maps db signature to dependencies of all its aliases
|
||||
dependencies_map = {}
|
||||
|
||||
# Check that no database depends on its own alias
|
||||
for sig, (_, aliases) in test_databases:
|
||||
all_deps = set()
|
||||
for alias in aliases:
|
||||
all_deps.update(dependencies.get(alias, []))
|
||||
if not all_deps.isdisjoint(aliases):
|
||||
raise ImproperlyConfigured(
|
||||
"Circular dependency: databases %r depend on each other, "
|
||||
"but are aliases." % aliases
|
||||
)
|
||||
dependencies_map[sig] = all_deps
|
||||
|
||||
while test_databases:
|
||||
changed = False
|
||||
deferred = []
|
||||
|
||||
# Try to find a DB that has all its dependencies met
|
||||
for signature, (db_name, aliases) in test_databases:
|
||||
if dependencies_map[signature].issubset(resolved_databases):
|
||||
resolved_databases.update(aliases)
|
||||
ordered_test_databases.append((signature, (db_name, aliases)))
|
||||
changed = True
|
||||
else:
|
||||
deferred.append((signature, (db_name, aliases)))
|
||||
|
||||
if not changed:
|
||||
raise ImproperlyConfigured("Circular dependency in TEST[DEPENDENCIES]")
|
||||
test_databases = deferred
|
||||
return ordered_test_databases
|
||||
|
||||
|
||||
def get_unique_databases_and_mirrors(aliases=None):
|
||||
"""
|
||||
Figure out which databases actually need to be created.
|
||||
|
||||
Deduplicate entries in DATABASES that correspond the same database or are
|
||||
configured as test mirrors.
|
||||
|
||||
Return two values:
|
||||
- test_databases: ordered mapping of signatures to (name, list of aliases)
|
||||
where all aliases share the same underlying database.
|
||||
- mirrored_aliases: mapping of mirror aliases to original aliases.
|
||||
"""
|
||||
if aliases is None:
|
||||
aliases = connections
|
||||
mirrored_aliases = {}
|
||||
test_databases = {}
|
||||
dependencies = {}
|
||||
default_sig = connections[DEFAULT_DB_ALIAS].creation.test_db_signature()
|
||||
|
||||
for alias in connections:
|
||||
connection = connections[alias]
|
||||
test_settings = connection.settings_dict['TEST']
|
||||
|
||||
if test_settings['MIRROR']:
|
||||
# If the database is marked as a test mirror, save the alias.
|
||||
mirrored_aliases[alias] = test_settings['MIRROR']
|
||||
elif alias in aliases:
|
||||
# Store a tuple with DB parameters that uniquely identify it.
|
||||
# If we have two aliases with the same values for that tuple,
|
||||
# we only need to create the test database once.
|
||||
item = test_databases.setdefault(
|
||||
connection.creation.test_db_signature(),
|
||||
(connection.settings_dict['NAME'], set())
|
||||
)
|
||||
item[1].add(alias)
|
||||
|
||||
if 'DEPENDENCIES' in test_settings:
|
||||
dependencies[alias] = test_settings['DEPENDENCIES']
|
||||
else:
|
||||
if alias != DEFAULT_DB_ALIAS and connection.creation.test_db_signature() != default_sig:
|
||||
dependencies[alias] = test_settings.get('DEPENDENCIES', [DEFAULT_DB_ALIAS])
|
||||
|
||||
test_databases = dict(dependency_ordered(test_databases.items(), dependencies))
|
||||
return test_databases, mirrored_aliases
|
||||
|
||||
|
||||
def teardown_databases(old_config, verbosity, parallel=0, keepdb=False):
|
||||
"""Destroy all the non-mirror databases."""
|
||||
for connection, old_name, destroy in old_config:
|
||||
if destroy:
|
||||
if parallel > 1:
|
||||
for index in range(parallel):
|
||||
connection.creation.destroy_test_db(
|
||||
suffix=str(index + 1),
|
||||
verbosity=verbosity,
|
||||
keepdb=keepdb,
|
||||
)
|
||||
connection.creation.destroy_test_db(old_name, verbosity, keepdb)
|
||||
|
||||
|
||||
def get_runner(settings, test_runner_class=None):
|
||||
test_runner_class = test_runner_class or settings.TEST_RUNNER
|
||||
test_path = test_runner_class.split('.')
|
||||
# Allow for relative paths
|
||||
if len(test_path) > 1:
|
||||
test_module_name = '.'.join(test_path[:-1])
|
||||
else:
|
||||
test_module_name = '.'
|
||||
test_module = __import__(test_module_name, {}, {}, test_path[-1])
|
||||
return getattr(test_module, test_path[-1])
|
||||
|
||||
|
||||
class TestContextDecorator:
|
||||
"""
|
||||
A base class that can either be used as a context manager during tests
|
||||
or as a test function or unittest.TestCase subclass decorator to perform
|
||||
temporary alterations.
|
||||
|
||||
`attr_name`: attribute assigned the return value of enable() if used as
|
||||
a class decorator.
|
||||
|
||||
`kwarg_name`: keyword argument passing the return value of enable() if
|
||||
used as a function decorator.
|
||||
"""
|
||||
def __init__(self, attr_name=None, kwarg_name=None):
|
||||
self.attr_name = attr_name
|
||||
self.kwarg_name = kwarg_name
|
||||
|
||||
def enable(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def disable(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def __enter__(self):
|
||||
return self.enable()
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
self.disable()
|
||||
|
||||
def decorate_class(self, cls):
|
||||
if issubclass(cls, TestCase):
|
||||
decorated_setUp = cls.setUp
|
||||
decorated_tearDown = cls.tearDown
|
||||
|
||||
def setUp(inner_self):
|
||||
context = self.enable()
|
||||
if self.attr_name:
|
||||
setattr(inner_self, self.attr_name, context)
|
||||
try:
|
||||
decorated_setUp(inner_self)
|
||||
except Exception:
|
||||
self.disable()
|
||||
raise
|
||||
|
||||
def tearDown(inner_self):
|
||||
decorated_tearDown(inner_self)
|
||||
self.disable()
|
||||
|
||||
cls.setUp = setUp
|
||||
cls.tearDown = tearDown
|
||||
return cls
|
||||
raise TypeError('Can only decorate subclasses of unittest.TestCase')
|
||||
|
||||
def decorate_callable(self, func):
|
||||
@wraps(func)
|
||||
def inner(*args, **kwargs):
|
||||
with self as context:
|
||||
if self.kwarg_name:
|
||||
kwargs[self.kwarg_name] = context
|
||||
return func(*args, **kwargs)
|
||||
return inner
|
||||
|
||||
def __call__(self, decorated):
|
||||
if isinstance(decorated, type):
|
||||
return self.decorate_class(decorated)
|
||||
elif callable(decorated):
|
||||
return self.decorate_callable(decorated)
|
||||
raise TypeError('Cannot decorate object of type %s' % type(decorated))
|
||||
|
||||
|
||||
class override_settings(TestContextDecorator):
|
||||
"""
|
||||
Act as either a decorator or a context manager. If it's a decorator, take a
|
||||
function and return a wrapped function. If it's a contextmanager, use it
|
||||
with the ``with`` statement. In either event, entering/exiting are called
|
||||
before and after, respectively, the function/block is executed.
|
||||
"""
|
||||
enable_exception = None
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.options = kwargs
|
||||
super().__init__()
|
||||
|
||||
def enable(self):
|
||||
# Keep this code at the beginning to leave the settings unchanged
|
||||
# in case it raises an exception because INSTALLED_APPS is invalid.
|
||||
if 'INSTALLED_APPS' in self.options:
|
||||
try:
|
||||
apps.set_installed_apps(self.options['INSTALLED_APPS'])
|
||||
except Exception:
|
||||
apps.unset_installed_apps()
|
||||
raise
|
||||
override = UserSettingsHolder(settings._wrapped)
|
||||
for key, new_value in self.options.items():
|
||||
setattr(override, key, new_value)
|
||||
self.wrapped = settings._wrapped
|
||||
settings._wrapped = override
|
||||
for key, new_value in self.options.items():
|
||||
try:
|
||||
setting_changed.send(
|
||||
sender=settings._wrapped.__class__,
|
||||
setting=key, value=new_value, enter=True,
|
||||
)
|
||||
except Exception as exc:
|
||||
self.enable_exception = exc
|
||||
self.disable()
|
||||
|
||||
def disable(self):
|
||||
if 'INSTALLED_APPS' in self.options:
|
||||
apps.unset_installed_apps()
|
||||
settings._wrapped = self.wrapped
|
||||
del self.wrapped
|
||||
responses = []
|
||||
for key in self.options:
|
||||
new_value = getattr(settings, key, None)
|
||||
responses_for_setting = setting_changed.send_robust(
|
||||
sender=settings._wrapped.__class__,
|
||||
setting=key, value=new_value, enter=False,
|
||||
)
|
||||
responses.extend(responses_for_setting)
|
||||
if self.enable_exception is not None:
|
||||
exc = self.enable_exception
|
||||
self.enable_exception = None
|
||||
raise exc
|
||||
for _, response in responses:
|
||||
if isinstance(response, Exception):
|
||||
raise response
|
||||
|
||||
def save_options(self, test_func):
|
||||
if test_func._overridden_settings is None:
|
||||
test_func._overridden_settings = self.options
|
||||
else:
|
||||
# Duplicate dict to prevent subclasses from altering their parent.
|
||||
test_func._overridden_settings = {
|
||||
**test_func._overridden_settings,
|
||||
**self.options,
|
||||
}
|
||||
|
||||
def decorate_class(self, cls):
|
||||
from django.test import SimpleTestCase
|
||||
if not issubclass(cls, SimpleTestCase):
|
||||
raise ValueError(
|
||||
"Only subclasses of Django SimpleTestCase can be decorated "
|
||||
"with override_settings")
|
||||
self.save_options(cls)
|
||||
return cls
|
||||
|
||||
|
||||
class modify_settings(override_settings):
|
||||
"""
|
||||
Like override_settings, but makes it possible to append, prepend, or remove
|
||||
items instead of redefining the entire list.
|
||||
"""
|
||||
def __init__(self, *args, **kwargs):
|
||||
if args:
|
||||
# Hack used when instantiating from SimpleTestCase.setUpClass.
|
||||
assert not kwargs
|
||||
self.operations = args[0]
|
||||
else:
|
||||
assert not args
|
||||
self.operations = list(kwargs.items())
|
||||
super(override_settings, self).__init__()
|
||||
|
||||
def save_options(self, test_func):
|
||||
if test_func._modified_settings is None:
|
||||
test_func._modified_settings = self.operations
|
||||
else:
|
||||
# Duplicate list to prevent subclasses from altering their parent.
|
||||
test_func._modified_settings = list(
|
||||
test_func._modified_settings) + self.operations
|
||||
|
||||
def enable(self):
|
||||
self.options = {}
|
||||
for name, operations in self.operations:
|
||||
try:
|
||||
# When called from SimpleTestCase.setUpClass, values may be
|
||||
# overridden several times; cumulate changes.
|
||||
value = self.options[name]
|
||||
except KeyError:
|
||||
value = list(getattr(settings, name, []))
|
||||
for action, items in operations.items():
|
||||
# items my be a single value or an iterable.
|
||||
if isinstance(items, str):
|
||||
items = [items]
|
||||
if action == 'append':
|
||||
value = value + [item for item in items if item not in value]
|
||||
elif action == 'prepend':
|
||||
value = [item for item in items if item not in value] + value
|
||||
elif action == 'remove':
|
||||
value = [item for item in value if item not in items]
|
||||
else:
|
||||
raise ValueError("Unsupported action: %s" % action)
|
||||
self.options[name] = value
|
||||
super().enable()
|
||||
|
||||
|
||||
class override_system_checks(TestContextDecorator):
|
||||
"""
|
||||
Act as a decorator. Override list of registered system checks.
|
||||
Useful when you override `INSTALLED_APPS`, e.g. if you exclude `auth` app,
|
||||
you also need to exclude its system checks.
|
||||
"""
|
||||
def __init__(self, new_checks, deployment_checks=None):
|
||||
from django.core.checks.registry import registry
|
||||
self.registry = registry
|
||||
self.new_checks = new_checks
|
||||
self.deployment_checks = deployment_checks
|
||||
super().__init__()
|
||||
|
||||
def enable(self):
|
||||
self.old_checks = self.registry.registered_checks
|
||||
self.registry.registered_checks = set()
|
||||
for check in self.new_checks:
|
||||
self.registry.register(check, *getattr(check, 'tags', ()))
|
||||
self.old_deployment_checks = self.registry.deployment_checks
|
||||
if self.deployment_checks is not None:
|
||||
self.registry.deployment_checks = set()
|
||||
for check in self.deployment_checks:
|
||||
self.registry.register(check, *getattr(check, 'tags', ()), deploy=True)
|
||||
|
||||
def disable(self):
|
||||
self.registry.registered_checks = self.old_checks
|
||||
self.registry.deployment_checks = self.old_deployment_checks
|
||||
|
||||
|
||||
def compare_xml(want, got):
|
||||
"""
|
||||
Try to do a 'xml-comparison' of want and got. Plain string comparison
|
||||
doesn't always work because, for example, attribute ordering should not be
|
||||
important. Ignore comment nodes, document type node, and leading and
|
||||
trailing whitespaces.
|
||||
|
||||
Based on https://github.com/lxml/lxml/blob/master/src/lxml/doctestcompare.py
|
||||
"""
|
||||
_norm_whitespace_re = re.compile(r'[ \t\n][ \t\n]+')
|
||||
|
||||
def norm_whitespace(v):
|
||||
return _norm_whitespace_re.sub(' ', v)
|
||||
|
||||
def child_text(element):
|
||||
return ''.join(c.data for c in element.childNodes
|
||||
if c.nodeType == Node.TEXT_NODE)
|
||||
|
||||
def children(element):
|
||||
return [c for c in element.childNodes
|
||||
if c.nodeType == Node.ELEMENT_NODE]
|
||||
|
||||
def norm_child_text(element):
|
||||
return norm_whitespace(child_text(element))
|
||||
|
||||
def attrs_dict(element):
|
||||
return dict(element.attributes.items())
|
||||
|
||||
def check_element(want_element, got_element):
|
||||
if want_element.tagName != got_element.tagName:
|
||||
return False
|
||||
if norm_child_text(want_element) != norm_child_text(got_element):
|
||||
return False
|
||||
if attrs_dict(want_element) != attrs_dict(got_element):
|
||||
return False
|
||||
want_children = children(want_element)
|
||||
got_children = children(got_element)
|
||||
if len(want_children) != len(got_children):
|
||||
return False
|
||||
return all(check_element(want, got) for want, got in zip(want_children, got_children))
|
||||
|
||||
def first_node(document):
|
||||
for node in document.childNodes:
|
||||
if node.nodeType not in (Node.COMMENT_NODE, Node.DOCUMENT_TYPE_NODE):
|
||||
return node
|
||||
|
||||
want = want.strip().replace('\\n', '\n')
|
||||
got = got.strip().replace('\\n', '\n')
|
||||
|
||||
# If the string is not a complete xml document, we may need to add a
|
||||
# root element. This allow us to compare fragments, like "<foo/><bar/>"
|
||||
if not want.startswith('<?xml'):
|
||||
wrapper = '<root>%s</root>'
|
||||
want = wrapper % want
|
||||
got = wrapper % got
|
||||
|
||||
# Parse the want and got strings, and compare the parsings.
|
||||
want_root = first_node(parseString(want))
|
||||
got_root = first_node(parseString(got))
|
||||
|
||||
return check_element(want_root, got_root)
|
||||
|
||||
|
||||
class CaptureQueriesContext:
|
||||
"""
|
||||
Context manager that captures queries executed by the specified connection.
|
||||
"""
|
||||
def __init__(self, connection):
|
||||
self.connection = connection
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.captured_queries)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.captured_queries[index]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.captured_queries)
|
||||
|
||||
@property
|
||||
def captured_queries(self):
|
||||
return self.connection.queries[self.initial_queries:self.final_queries]
|
||||
|
||||
def __enter__(self):
|
||||
self.force_debug_cursor = self.connection.force_debug_cursor
|
||||
self.connection.force_debug_cursor = True
|
||||
# Run any initialization queries if needed so that they won't be
|
||||
# included as part of the count.
|
||||
self.connection.ensure_connection()
|
||||
self.initial_queries = len(self.connection.queries_log)
|
||||
self.final_queries = None
|
||||
request_started.disconnect(reset_queries)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
self.connection.force_debug_cursor = self.force_debug_cursor
|
||||
request_started.connect(reset_queries)
|
||||
if exc_type is not None:
|
||||
return
|
||||
self.final_queries = len(self.connection.queries_log)
|
||||
|
||||
|
||||
class ignore_warnings(TestContextDecorator):
|
||||
def __init__(self, **kwargs):
|
||||
self.ignore_kwargs = kwargs
|
||||
if 'message' in self.ignore_kwargs or 'module' in self.ignore_kwargs:
|
||||
self.filter_func = warnings.filterwarnings
|
||||
else:
|
||||
self.filter_func = warnings.simplefilter
|
||||
super().__init__()
|
||||
|
||||
def enable(self):
|
||||
self.catch_warnings = warnings.catch_warnings()
|
||||
self.catch_warnings.__enter__()
|
||||
self.filter_func('ignore', **self.ignore_kwargs)
|
||||
|
||||
def disable(self):
|
||||
self.catch_warnings.__exit__(*sys.exc_info())
|
||||
|
||||
|
||||
# On OSes that don't provide tzset (Windows), we can't set the timezone
|
||||
# in which the program runs. As a consequence, we must skip tests that
|
||||
# don't enforce a specific timezone (with timezone.override or equivalent),
|
||||
# or attempt to interpret naive datetimes in the default timezone.
|
||||
|
||||
requires_tz_support = skipUnless(
|
||||
TZ_SUPPORT,
|
||||
"This test relies on the ability to run a program in an arbitrary "
|
||||
"time zone, but your operating system isn't able to do that."
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def extend_sys_path(*paths):
|
||||
"""Context manager to temporarily add paths to sys.path."""
|
||||
_orig_sys_path = sys.path[:]
|
||||
sys.path.extend(paths)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
sys.path = _orig_sys_path
|
||||
|
||||
|
||||
@contextmanager
|
||||
def isolate_lru_cache(lru_cache_object):
|
||||
"""Clear the cache of an LRU cache object on entering and exiting."""
|
||||
lru_cache_object.cache_clear()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
lru_cache_object.cache_clear()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def captured_output(stream_name):
|
||||
"""Return a context manager used by captured_stdout/stdin/stderr
|
||||
that temporarily replaces the sys stream *stream_name* with a StringIO.
|
||||
|
||||
Note: This function and the following ``captured_std*`` are copied
|
||||
from CPython's ``test.support`` module."""
|
||||
orig_stdout = getattr(sys, stream_name)
|
||||
setattr(sys, stream_name, StringIO())
|
||||
try:
|
||||
yield getattr(sys, stream_name)
|
||||
finally:
|
||||
setattr(sys, stream_name, orig_stdout)
|
||||
|
||||
|
||||
def captured_stdout():
|
||||
"""Capture the output of sys.stdout:
|
||||
|
||||
with captured_stdout() as stdout:
|
||||
print("hello")
|
||||
self.assertEqual(stdout.getvalue(), "hello\n")
|
||||
"""
|
||||
return captured_output("stdout")
|
||||
|
||||
|
||||
def captured_stderr():
|
||||
"""Capture the output of sys.stderr:
|
||||
|
||||
with captured_stderr() as stderr:
|
||||
print("hello", file=sys.stderr)
|
||||
self.assertEqual(stderr.getvalue(), "hello\n")
|
||||
"""
|
||||
return captured_output("stderr")
|
||||
|
||||
|
||||
def captured_stdin():
|
||||
"""Capture the input to sys.stdin:
|
||||
|
||||
with captured_stdin() as stdin:
|
||||
stdin.write('hello\n')
|
||||
stdin.seek(0)
|
||||
# call test code that consumes from sys.stdin
|
||||
captured = input()
|
||||
self.assertEqual(captured, "hello")
|
||||
"""
|
||||
return captured_output("stdin")
|
||||
|
||||
|
||||
@contextmanager
|
||||
def freeze_time(t):
|
||||
"""
|
||||
Context manager to temporarily freeze time.time(). This temporarily
|
||||
modifies the time function of the time module. Modules which import the
|
||||
time function directly (e.g. `from time import time`) won't be affected
|
||||
This isn't meant as a public API, but helps reduce some repetitive code in
|
||||
Django's test suite.
|
||||
"""
|
||||
_real_time = time.time
|
||||
time.time = lambda: t
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
time.time = _real_time
|
||||
|
||||
|
||||
def require_jinja2(test_func):
|
||||
"""
|
||||
Decorator to enable a Jinja2 template engine in addition to the regular
|
||||
Django template engine for a test or skip it if Jinja2 isn't available.
|
||||
"""
|
||||
test_func = skipIf(jinja2 is None, "this test requires jinja2")(test_func)
|
||||
return override_settings(TEMPLATES=[{
|
||||
'BACKEND': 'django.template.backends.django.DjangoTemplates',
|
||||
'APP_DIRS': True,
|
||||
}, {
|
||||
'BACKEND': 'django.template.backends.jinja2.Jinja2',
|
||||
'APP_DIRS': True,
|
||||
'OPTIONS': {'keep_trailing_newline': True},
|
||||
}])(test_func)
|
||||
|
||||
|
||||
class override_script_prefix(TestContextDecorator):
|
||||
"""Decorator or context manager to temporary override the script prefix."""
|
||||
def __init__(self, prefix):
|
||||
self.prefix = prefix
|
||||
super().__init__()
|
||||
|
||||
def enable(self):
|
||||
self.old_prefix = get_script_prefix()
|
||||
set_script_prefix(self.prefix)
|
||||
|
||||
def disable(self):
|
||||
set_script_prefix(self.old_prefix)
|
||||
|
||||
|
||||
class LoggingCaptureMixin:
|
||||
"""
|
||||
Capture the output from the 'django' logger and store it on the class's
|
||||
logger_output attribute.
|
||||
"""
|
||||
def setUp(self):
|
||||
self.logger = logging.getLogger('django')
|
||||
self.old_stream = self.logger.handlers[0].stream
|
||||
self.logger_output = StringIO()
|
||||
self.logger.handlers[0].stream = self.logger_output
|
||||
|
||||
def tearDown(self):
|
||||
self.logger.handlers[0].stream = self.old_stream
|
||||
|
||||
|
||||
class isolate_apps(TestContextDecorator):
|
||||
"""
|
||||
Act as either a decorator or a context manager to register models defined
|
||||
in its wrapped context to an isolated registry.
|
||||
|
||||
The list of installed apps the isolated registry should contain must be
|
||||
passed as arguments.
|
||||
|
||||
Two optional keyword arguments can be specified:
|
||||
|
||||
`attr_name`: attribute assigned the isolated registry if used as a class
|
||||
decorator.
|
||||
|
||||
`kwarg_name`: keyword argument passing the isolated registry if used as a
|
||||
function decorator.
|
||||
"""
|
||||
def __init__(self, *installed_apps, **kwargs):
|
||||
self.installed_apps = installed_apps
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def enable(self):
|
||||
self.old_apps = Options.default_apps
|
||||
apps = Apps(self.installed_apps)
|
||||
setattr(Options, 'default_apps', apps)
|
||||
return apps
|
||||
|
||||
def disable(self):
|
||||
setattr(Options, 'default_apps', self.old_apps)
|
||||
|
||||
|
||||
def tag(*tags):
|
||||
"""Decorator to add tags to a test class or method."""
|
||||
def decorator(obj):
|
||||
if hasattr(obj, 'tags'):
|
||||
obj.tags = obj.tags.union(tags)
|
||||
else:
|
||||
setattr(obj, 'tags', set(tags))
|
||||
return obj
|
||||
return decorator
|
||||
|
||||
|
||||
@contextmanager
|
||||
def register_lookup(field, *lookups, lookup_name=None):
|
||||
"""
|
||||
Context manager to temporarily register lookups on a model field using
|
||||
lookup_name (or the lookup's lookup_name if not provided).
|
||||
"""
|
||||
try:
|
||||
for lookup in lookups:
|
||||
field.register_lookup(lookup, lookup_name)
|
||||
yield
|
||||
finally:
|
||||
for lookup in lookups:
|
||||
field._unregister_lookup(lookup, lookup_name)
|
||||
Reference in New Issue
Block a user