|
|
@ -13,7 +13,6 @@ |
|
|
|
# under the License. |
|
|
|
from __future__ import absolute_import |
|
|
|
|
|
|
|
import abc |
|
|
|
import functools |
|
|
|
import inspect |
|
|
|
import typing |
|
|
@ -21,23 +20,17 @@ import typing |
|
|
|
import decorator |
|
|
|
|
|
|
|
|
|
|
|
class ProtocolMeta(abc.ABCMeta): |
|
|
|
|
|
|
|
def __new__(mcls, name, bases, namespace, **kwargs): |
|
|
|
cls = super().__new__(mcls, name, bases, namespace, **kwargs) |
|
|
|
cls._is_protocol = True |
|
|
|
return cls |
|
|
|
|
|
|
|
|
|
|
|
class Protocol(abc.ABC, metaclass=ProtocolMeta): |
|
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
T = typing.TypeVar('T') |
|
|
|
def protocol(cls: type) -> type: |
|
|
|
name = cls.__name__ |
|
|
|
bases = inspect.getmro(cls)[1:] |
|
|
|
namespace = dict(cls.__dict__, |
|
|
|
_is_protocol=True, |
|
|
|
__module__=cls.__module__) |
|
|
|
return type(name, bases, namespace) |
|
|
|
|
|
|
|
|
|
|
|
def is_protocol_class(cls): |
|
|
|
return cls.__dict__.get('_is_protocol', False) |
|
|
|
return inspect.isclass(cls) and cls.__dict__.get('_is_protocol', False) |
|
|
|
|
|
|
|
|
|
|
|
def is_public_function(obj): |
|
|
@ -45,60 +38,85 @@ def is_public_function(obj): |
|
|
|
getattr(obj, '__name__', '_')[0] != '_') |
|
|
|
|
|
|
|
|
|
|
|
class CallHandler(abc.ABC): |
|
|
|
|
|
|
|
def _handle_call(self, method: typing.Callable, *args, **kwargs): |
|
|
|
raise NotImplementedError |
|
|
|
T = typing.TypeVar('T') |
|
|
|
|
|
|
|
|
|
|
|
class CallProxy(CallHandler): |
|
|
|
class CallHandlerMeta(type): |
|
|
|
|
|
|
|
_handle_call: typing.Callable |
|
|
|
def __new__(mcls, name, bases, namespace, **kwargs): |
|
|
|
protocol_class = namespace.get('protocol_class') |
|
|
|
if protocol_class is not None: |
|
|
|
proxy_class = call_proxy_class(protocol_class) |
|
|
|
bases += proxy_class, |
|
|
|
return super().__new__(mcls, name, bases, namespace, **kwargs) |
|
|
|
|
|
|
|
def __init__(self, handle_call: typing.Callable): |
|
|
|
setattr(self, '_handle_call', handle_call) |
|
|
|
|
|
|
|
class CallHandler(metaclass=CallHandlerMeta): |
|
|
|
|
|
|
|
@functools.lru_cache() |
|
|
|
def call_proxy_class(protocol_class: type, |
|
|
|
class_name: typing.Optional[str] = None, |
|
|
|
handler_class: typing.Type[CallHandler] = CallProxy) \ |
|
|
|
-> type: |
|
|
|
if not is_protocol_class(protocol_class): |
|
|
|
raise TypeError(f"{protocol_class} is not a subclass of {Protocol}") |
|
|
|
if class_name is None: |
|
|
|
class_name = protocol_class.__name__ + 'Proxy' |
|
|
|
namespace: typing.Dict[str, typing.Any] = {} |
|
|
|
for name, member in protocol_class.__dict__.items(): |
|
|
|
if is_public_function(member): |
|
|
|
method = call_proxy_method(member) |
|
|
|
namespace[name] = method |
|
|
|
protocol_class: type |
|
|
|
|
|
|
|
return type(class_name, (handler_class, protocol_class), namespace) |
|
|
|
def __init__(self, |
|
|
|
handle_call: typing.Optional[typing.Callable] = None): |
|
|
|
if handle_call is not None: |
|
|
|
assert callable(handle_call) |
|
|
|
setattr(self, '_handle_call', handle_call) |
|
|
|
|
|
|
|
def _handle_call(self, method: typing.Callable, *args, **kwargs): |
|
|
|
pass |
|
|
|
|
|
|
|
def call_proxy(protocol_class: typing.Type[T], handle_call: typing.Callable) \ |
|
|
|
-> T: |
|
|
|
proxy_class = call_proxy_class(typing.cast(type, protocol_class)) |
|
|
|
return proxy_class(handle_call) |
|
|
|
def use_as(self, cls: typing.Type[T]) -> T: |
|
|
|
assert isinstance(self, cls) |
|
|
|
return typing.cast(T, self) |
|
|
|
|
|
|
|
|
|
|
|
@functools.lru_cache() |
|
|
|
def stack_classes(name: str, cls: type, *classes) -> type: |
|
|
|
return type(name, (cls,) + classes, {}) |
|
|
|
def call_proxy_class( |
|
|
|
cls: type, |
|
|
|
*bases: type, |
|
|
|
class_name: typing.Optional[str] = None, |
|
|
|
namespace: typing.Optional[dict] = None) \ |
|
|
|
-> type: |
|
|
|
if not inspect.isclass(cls): |
|
|
|
raise TypeError(f"Object {cls} is not a class") |
|
|
|
if class_name is None: |
|
|
|
class_name = cls.__name__ + 'Proxy' |
|
|
|
protocol_classes = list_protocols(cls) |
|
|
|
if not protocol_classes: |
|
|
|
raise TypeError(f"Class {cls} doesn't implement any protocol") |
|
|
|
if namespace is None: |
|
|
|
namespace = {} |
|
|
|
for protocol_class in reversed(protocol_classes): |
|
|
|
for name, member in protocol_class.__dict__.items(): |
|
|
|
if is_public_function(member): |
|
|
|
method = call_proxy_method(member) |
|
|
|
namespace[name] = method |
|
|
|
# Skip empty protocols |
|
|
|
if not namespace: |
|
|
|
raise TypeError(f"Class {cls} has any protocol specification") |
|
|
|
namespace['__module__'] = cls.__module__ |
|
|
|
proxy_class = type(class_name, bases + protocol_classes, namespace) |
|
|
|
assert not is_protocol_class(proxy_class) |
|
|
|
assert not is_protocol_class(proxy_class) |
|
|
|
return proxy_class |
|
|
|
|
|
|
|
|
|
|
|
def call_proxy(cls: type, handle_call: typing.Callable) -> CallHandler: |
|
|
|
proxy_class = call_proxy_class(cls, CallHandler) |
|
|
|
return proxy_class(handle_call) |
|
|
|
|
|
|
|
|
|
|
|
@functools.lru_cache() |
|
|
|
def list_protocols(cls: type) -> typing.Tuple[typing.Type[Protocol], ...]: |
|
|
|
def list_protocols(cls: type) -> typing.Tuple[type, ...]: |
|
|
|
subclasses = inspect.getmro(cls) |
|
|
|
protocols = tuple(typing.cast(typing.Type[Protocol], cls) |
|
|
|
protocols = tuple(cls |
|
|
|
for cls in subclasses |
|
|
|
if is_protocol_class(cls)) |
|
|
|
return tuple(protocols) |
|
|
|
|
|
|
|
|
|
|
|
def call_proxy_method(func: typing.Callable) -> typing.Callable: |
|
|
|
return decorator.decorate(func, _call_proxy_method) |
|
|
|
method = decorator.decorate(func, _call_proxy_method) |
|
|
|
assert method is not func |
|
|
|
return method |
|
|
|
|
|
|
|
|
|
|
|
def _call_proxy_method(func, self: CallHandler, *args, **kwargs): |
|
|
|