| #!/usr/bin/env python3 |
| # |
| # Copyright 2023 The Fuchsia Authors |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| |
| from __future__ import annotations |
| |
| from typing import Collection, Literal, Mapping, TypeGuard, TypeVar, overload |
| |
| from mobly import signals |
| |
| |
class ValidatorError(signals.TestAbortClass):
    """Base class for every error raised by the validators in this module.

    Derives from mobly's ``signals.TestAbortClass``.
    NOTE(review): in mobly that signal presumably aborts the remaining tests
    of the current test class -- confirm against the mobly documentation.
    """
| |
| |
class FieldNotFoundError(ValidatorError):
    """Signals that a required key is absent from the validated mapping."""
| |
| |
class FieldTypeError(ValidatorError):
    """Signals that a value (or list element) is not of the expected type."""
| |
| |
# Generic placeholder for the value type a caller asks a validator to enforce;
# it ties the `type` argument of an accessor to that accessor's return type.
T = TypeVar("T")
| |
| |
| class _NO_DEFAULT: |
| pass |
| |
| |
class MapValidator:
    """Type-checked accessor for a string-keyed mapping.

    Wraps a mapping and exposes lookups that verify the runtime type of the
    stored value, raising a ValidatorError subclass when a required key is
    missing or a value has an unexpected type.
    """

    def __init__(self, map: Mapping[str, object]) -> None:
        # The mapping is stored by reference; it is never copied or mutated
        # by this class.
        self.map = map

    @overload
    def get(self, type: type[T], key: str, default: None) -> T | None:
        ...

    @overload
    def get(
        self, type: type[T], key: str, default: T | _NO_DEFAULT = _NO_DEFAULT()
    ) -> T:
        ...

    def get(
        self, type: type[T], key: str, default: T | None | _NO_DEFAULT = _NO_DEFAULT()
    ) -> T | None:
        """Access the map requiring a value type at the specified key.

        If default is set and the map does not contain the specified key, the
        default will be returned. A default of None additionally allows the
        stored value itself to be None.

        Args:
            type: Expected type of the value
            key: Key to index into the map with
            default: Default value when the map does not contain key

        Returns:
            Value of the expected type, or None if default is None.

        Raises:
            FieldNotFoundError: when the map does not contain the specified
                key and default is either not set or not an instance of the
                expected type
            FieldTypeError: when the value at the specified key is not the
                expected type
        """
        if key not in self.map:
            # Rule out the "no default" sentinel explicitly before the
            # isinstance check: the sentinel is an instance of `object` (and
            # of any other base of _NO_DEFAULT), so without this guard a call
            # such as get(object, "missing") would hand the sentinel back to
            # the caller instead of reporting the missing field.
            has_default = not isinstance(default, _NO_DEFAULT)
            if has_default and (default is None or isinstance(default, type)):
                return default
            raise FieldNotFoundError(
                f'Required field "{key}" is missing; expected {type.__name__}'
            )
        val = self.map[key]
        # An explicit None default marks the field as nullable.
        if val is None and default is None:
            return None
        if not isinstance(val, type):
            raise FieldTypeError(
                f'Expected "{key}" to be {type.__name__}, got {describe_type(val)}'
            )
        return val

    @overload
    def list(self, key: str) -> ListValidator:
        ...

    @overload
    def list(self, key: str, optional: Literal[False]) -> ListValidator:
        ...

    @overload
    def list(self, key: str, optional: Literal[True]) -> ListValidator | None:
        ...

    def list(self, key: str, optional: bool = False) -> ListValidator | None:
        """Access the map requiring a list at the specified key.

        If optional is True and the map does not contain the specified key, None
        will be returned.

        Args:
            key: Key to index into the map with
            optional: If True, will return None if the map does not contain key

        Returns:
            ListValidator or None if optional is True.

        Raises:
            FieldNotFoundError: when optional is False and the map does not
                contain the specified key
            FieldTypeError: when the value at the specified key is not a list
        """
        if key not in self.map:
            if optional:
                return None
            raise FieldNotFoundError(
                f'Required field "{key}" is missing; expected list'
            )
        # Inside a method body `list` resolves to the builtin, not to this
        # method: class-level names are not part of a function's scope chain.
        return ListValidator(key, self.get(list, key))
| |
| |
class ListValidator:
    """Type-checked view over a list that was read from a named field."""

    def __init__(self, name: str, val: list[object]) -> None:
        # `name` is only used to label error messages; `val` is the list
        # being validated and is kept by reference.
        self.val = val
        self.name = name

    def all(self, type: type[T]) -> list[T]:
        """Return the wrapped list once every element is confirmed to be `type`.

        Args:
            type: Type each element must be an instance of

        Returns:
            The same list object that was wrapped (not a copy).

        Raises:
            FieldTypeError: when at least one element is not the expected type
        """
        if is_list_of(self.val, type):
            return self.val
        raise FieldTypeError(
            f'Expected "{self.name}" to be list[{type.__name__}], '
            f"got {describe_type(self.val)}"
        )
| |
| |
def describe_type(o: object) -> str:
    """Describe the complete type of the object.

    Different from type() by recursing when a mapping or collection is found.
    Mappings are rendered as ``dict[<key types>, <value types>]`` and every
    other non-string, non-bytes collection as ``list[<element types>]``.
    When several distinct types occur they are joined with " | " in sorted
    order so the description is identical from run to run.

    Args:
        o: Object to describe

    Returns:
        Human-readable type description, e.g. "dict[str, int | str]".
    """
    if isinstance(o, Mapping):
        # Sets remove duplicates; sorting removes the dependence on set
        # iteration order, which varies with string hash randomization.
        keys = sorted({describe_type(k) for k in o})
        values = sorted({describe_type(v) for v in o.values()})
        return f'dict[{" | ".join(keys)}, {" | ".join(values)}]'
    # str, bytes and bytearray are collections of characters/ints, but
    # describing them element by element would be misleading ("list[int]"
    # for bytes), so they are reported by their own type name.
    if isinstance(o, Collection) and not isinstance(o, (str, bytes, bytearray)):
        elements = sorted({describe_type(x) for x in o})
        return f'list[{" | ".join(elements)}]'
    return type(o).__name__
| |
| |
| def is_list_of(val: list[object], type: type[T]) -> TypeGuard[list[T]]: |
| return all(isinstance(x, type) for x in val) |