# Copyright (c) 2017 The University of Manchester
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import os
from types import FunctionType, MethodType
from typing import Any, Callable, Iterable, Optional, List, Tuple, TypeVar
#: :meta private:
Method = TypeVar("Method", bound=Callable[..., Any])
# pylint: disable=invalid-name
[docs]
class overrides(object):
"""
A decorator for indicating that a method overrides another method in
a superclass. This checks that the method does actually exist,
copies the doc-string for the method, and enforces that the method
overridden is specified, making maintenance easier.
"""
# This near constant is changed by unit tests to check our code
# Github actions sets TYPE_OVERRIDES as True
__CHECK_TYPES: Optional[Any] = os.getenv("TYPE_OVERRIDES")
__slots__ = [
# The method in the superclass that this method overrides
"_superclass_method",
# True if the doc string is to be extended, False to set if not set
"_extend_doc",
# Any additional arguments required by the subclass method
"_additional_arguments",
# True if the subclass method may have additional defaults
"_extend_defaults",
# True if the name check is relaxed
"_relax_name_check",
# The name of the thing being overridden for error messages
"_override_name",
# True if this method adds typing info not in the original
"_adds_typing"
]
def __init__(
self, super_class_method: Callable, *, extend_doc: bool = True,
additional_arguments: Optional[Iterable[str]] = None,
extend_defaults: bool = False, adds_typing: bool = False,):
"""
:param super_class_method: The method to override in the superclass
:param bool extend_doc:
True the method doc string should be appended to the super-method
doc string, False if the documentation should be set to the
super-method doc string only if there isn't a doc string already
:param iterable(str) additional_arguments:
Additional arguments taken by the subclass method over the
superclass method, e.g., that are to be injected
:param bool extend_defaults:
Whether the subclass may specify extra defaults for the parameters
:param adds_typing:
Allows more typing (of non additional) than in the subclass.
Should only be used for built in super classes
"""
if isinstance(super_class_method, property):
super_class_method = super_class_method.fget
if not isinstance(super_class_method, (FunctionType, MethodType)):
raise TypeError("may only decorate method declarations; "
f"this is a {type(super_class_method)}")
self._superclass_method = super_class_method
self._extend_doc = bool(extend_doc)
self._extend_defaults = bool(extend_defaults)
self._relax_name_check = False
self._override_name = "super class method"
if additional_arguments is not None:
self._additional_arguments = frozenset(additional_arguments)
else:
self._additional_arguments = frozenset()
self._adds_typing = adds_typing
@staticmethod
def __match_defaults(default_args: Optional[List[Any]],
super_defaults: Optional[Tuple[Any]],
extend_ok: bool) -> bool:
if default_args is None:
return super_defaults is None
elif super_defaults is None:
return extend_ok
if extend_ok:
return len(default_args) >= len(super_defaults)
return len(default_args) == len(super_defaults)
def _verify_types(self, method_args: inspect.FullArgSpec,
super_args: inspect.FullArgSpec,
all_args: List[str]) -> None:
"""
Check that the arguments match.
"""
if not self.__CHECK_TYPES:
return
if "self" in all_args:
all_args.remove("self")
elif "cls" in all_args:
all_args.remove("cls")
method_types = method_args.annotations
super_types = super_args.annotations
for arg in all_args:
if arg not in super_types and not self._adds_typing:
raise AttributeError(
f"Super Method {self._superclass_method.__name__} "
f"has untyped arguments including {arg}")
if arg not in method_types:
raise AttributeError(
f"Method {self._superclass_method.__name__} "
f"has untyped arguments including {arg}")
if len(all_args) == 0:
if "return" not in super_types and not self._adds_typing and \
not method_args.varkw and not method_args.varargs:
raise AttributeError(
f"Super Method {self._superclass_method.__name__} "
f"has no arguments so should declare a return type")
if "return" not in method_types and \
not method_args.varkw and not method_args.varargs:
raise AttributeError(
f"Method {self._superclass_method.__name__} "
f"has no arguments so should declare a return type")
if "return" in super_types:
if "return" not in method_types:
if super_types["return"] is not None:
raise AttributeError(
f"Method {self._superclass_method.__name__} "
f"has no return type, while super does")
else:
if "return" in method_types and not self._adds_typing:
if method_types["return"] is not None:
raise AttributeError(
f"Super Method {self._superclass_method.__name__} "
f"has no return type, while this does")
def __verify_method_arguments(self, method: Method) -> None:
"""
Check that the arguments match.
"""
method_args = inspect.getfullargspec(method)
super_args = inspect.getfullargspec(self._superclass_method)
all_args = [
arg for arg in method_args.args
if arg not in self._additional_arguments]
default_args = None
if method_args.defaults is not None:
default_args = [
arg for arg in method_args.defaults
if arg not in self._additional_arguments]
if len(all_args) != len(super_args.args):
raise AttributeError(
f"Method has {len(method_args.args)} arguments but "
f"{self._override_name} has {len(super_args.args)} arguments")
for arg, super_arg in zip(all_args, super_args.args):
if arg != super_arg:
raise AttributeError(f"Missing argument {super_arg}")
if not self.__match_defaults(
default_args, super_args.defaults, self._extend_defaults):
raise AttributeError(
f"Default arguments don't match {self._override_name}")
self._verify_types(method_args, super_args, all_args)
def __call__(self, method: Method) -> Method:
"""
Apply the decorator to the given method.
"""
# Check and fail if this is a property
if isinstance(method, property):
raise AttributeError(
f"Please ensure that the {self.__class__.__name__} decorator "
"is the last decorator before the method declaration")
# Check that the name matches
if (not self._relax_name_check and
method.__name__ != self._superclass_method.__name__):
raise AttributeError(
f"{self._override_name} name "
f"{self._superclass_method.__name__} does not match "
f"{method.__name__}. Ensure {self.__class__.__name__} is the "
"last decorator before the method declaration")
# Check that the arguments match (except for __init__ as this might
# take extra arguments or pass arguments not specified)
if method.__name__ != "__init__":
self.__verify_method_arguments(method)
if (self._superclass_method.__doc__ is not None and
method.__doc__ is None):
method.__doc__ = self._superclass_method.__doc__
elif (self._extend_doc and
self._superclass_method.__doc__ is not None):
method.__doc__ = (
self._superclass_method.__doc__ + (method.__doc__ or ""))
return method
[docs]
@classmethod
def check_types(cls) -> None:
"""
If called will trigger check that all parameters are checked.
Used for testing, to avoid users being affected by the strict checks
"""
cls.__CHECK_TYPES = True