diff --git a/.github/workflows/generator.yml b/.github/workflows/generator.yml index 2709ab6d..51a39e8c 100644 --- a/.github/workflows/generator.yml +++ b/.github/workflows/generator.yml @@ -21,4 +21,8 @@ jobs: cd flatdata-generator uv run --with pytest pytest -v pip install . - flatdata-generator --help \ No newline at end of file + flatdata-generator --help + - name: Type check + run: | + cd flatdata-generator + uv run --with mypy mypy flatdata/ diff --git a/.github/workflows/py.yml b/.github/workflows/py.yml index dd9b7117..0303cbaa 100644 --- a/.github/workflows/py.yml +++ b/.github/workflows/py.yml @@ -21,7 +21,11 @@ jobs: cd flatdata-py uv venv uv pip install ../flatdata-generator - uv pip install ".[inspector]" pytest + uv pip install ".[inspector]" pytest mypy .venv/bin/pytest -v .venv/bin/flatdata-inspector --help + - name: Type check + run: | + cd flatdata-py + .venv/bin/mypy flatdata/ diff --git a/flatdata-generator/flatdata/generator/app.py b/flatdata-generator/flatdata/generator/app.py index 1dc26016..4e13bfd8 100755 --- a/flatdata-generator/flatdata/generator/app.py +++ b/flatdata-generator/flatdata/generator/app.py @@ -21,7 +21,7 @@ from flatdata.generator.tree.errors import FlatdataSyntaxError -def _parse_command_line(): +def _parse_command_line() -> argparse.Namespace: parser = argparse.ArgumentParser( description="Generates code for a given flatdata schema file.") parser.add_argument("-s", "--schema", type=str, required=True, @@ -39,7 +39,7 @@ def _parse_command_line(): return parser.parse_args() -def _setup_logging(args): +def _setup_logging(args: argparse.Namespace) -> None: level = logging.WARNING if args.debug: level = logging.DEBUG @@ -52,13 +52,13 @@ def _setup_logging(args): level=level) -def _check_args(args): +def _check_args(args: argparse.Namespace) -> None: if not os.path.isfile(args.schema): logging.fatal("Cannot find schema file at %s", args.schema) sys.exit(1) -def _run(args): +def _run(args: argparse.Namespace) -> None: _setup_logging(args) _check_args(args) @@ -86,6 +86,6 @@ def _run(args): logging.info("Code for %s is written to %s", args.gen, args.output_file) -def main(): +def main() -> None: """Entrypoint""" _run(_parse_command_line()) \ No newline at end of file diff --git a/flatdata-generator/flatdata/generator/engine.py b/flatdata-generator/flatdata/generator/engine.py index bf6a6d9e..34c8fa29 100644 --- a/flatdata-generator/flatdata/generator/engine.py +++ b/flatdata-generator/flatdata/generator/engine.py @@ -4,10 +4,12 @@ ''' import types +from typing import overload from flatdata.generator.tree.builder import build_ast from flatdata.generator.tree.nodes.trivial.namespace import Namespace from flatdata.generator.tree.nodes.node import Node +from flatdata.generator.tree.syntax_tree import SyntaxTree from .generators.cpp import CppGenerator from .generators.dot import DotGenerator @@ -15,6 +17,7 @@ from .generators.python import PythonGenerator from .generators.rust import RustGenerator from .generators.flatdata import FlatdataGenerator +from .generators import BaseGenerator class Engine: @@ -23,7 +26,7 @@ class Engine: Implements code generation from the given flatdata schema. """ - _GENERATORS = { + _GENERATORS: dict[str, type[BaseGenerator]] = { "cpp": CppGenerator, "dot": DotGenerator, "go": GoGenerator, @@ -33,13 +36,13 @@ class Engine: } @classmethod - def available_generators(cls): + def available_generators(cls) -> list[str]: """ Lists names of available code generators. """ return list(cls._GENERATORS.keys()) - def __init__(self, schema): + def __init__(self, schema: str) -> None: """ Instantiates generator engine for a given schema. :raises FlatdataSyntaxError @@ -47,7 +50,7 @@ def __init__(self, schema): self.schema = schema self.tree = build_ast(schema) - def render(self, generator_name): + def render(self, generator_name: str) -> str: """ Render schema with a given generator :param generator_name: @@ -60,9 +63,16 @@ def render(self, generator_name): ) output_content = generator.render(self.tree) - return output_content + return str(output_content) - def render_python_module(self, module_name=None, archive_name=None, root_namespace=None): + @overload + def render_python_module(self, module_name: str | None, archive_name: str, root_namespace: str | None = None) -> tuple[types.ModuleType, type]: ... + @overload + def render_python_module(self, *, archive_name: str, root_namespace: str | None = None) -> tuple[types.ModuleType, type]: ... + @overload + def render_python_module(self, module_name: str | None = None, archive_name: None = None, root_namespace: str | None = None) -> types.ModuleType: ... + + def render_python_module(self, module_name: str | None = None, archive_name: str | None = None, root_namespace: str | None = None) -> types.ModuleType | tuple[types.ModuleType, type]: """ Render python module. :param module_name: Module name to use. If none, root namespace name is used. @@ -70,28 +80,28 @@ def render_python_module(self, module_name=None, archive_name=None, root_namespa if specified, archive type is returned along with the model :param root_namespace: Root namespace to pick in case of multiple top level namespaces. """ - root_namespace = self._find_root_namespace(self.tree, archive_name, root_namespace) + ns = self._find_root_namespace(self.tree, archive_name, root_namespace) module_code = self.render("py") - module = types.ModuleType(module_name if module_name is not None else root_namespace.name) + module = types.ModuleType(module_name if module_name is not None else ns.name) #pylint: disable=exec-used exec(module_code, module.__dict__) if archive_name is None: return module - name = root_namespace.name + "_" + archive_name - archive_type = getattr(module, name) if archive_name else None + name = ns.name + "_" + archive_name + archive_type = getattr(module, name) return module, archive_type @classmethod - def _create_generator(cls, name): + def _create_generator(cls, name: str) -> BaseGenerator | None: generator_type = cls._GENERATORS.get(name, None) if generator_type is None: return None - return generator_type() + return generator_type() # type: ignore[call-arg] # dict values are concrete subclasses with zero-arg __init__ @staticmethod - def _find_root_namespace(tree, archive_name, root_namespace=None): + def _find_root_namespace(tree: SyntaxTree, archive_name: str | None, root_namespace: str | None = None) -> Namespace: root_children = tree.root.children root_namespaces = [ child for child in root_children diff --git a/flatdata-generator/flatdata/generator/generators/__init__.py b/flatdata-generator/flatdata/generator/generators/__init__.py index bbdb6c14..b4439375 100644 --- a/flatdata-generator/flatdata/generator/generators/__init__.py +++ b/flatdata-generator/flatdata/generator/generators/__init__.py @@ -4,10 +4,13 @@ ''' from abc import ABCMeta, abstractmethod +from typing import NoReturn + from jinja2 import Environment, PackageLoader from jinja2 import nodes from jinja2.ext import Extension from jinja2.exceptions import TemplateRuntimeError +from jinja2.parser import Parser from flatdata.generator.tree.nodes.archive import Archive from flatdata.generator.tree.nodes.trivial import Structure, Enumeration, Constant, Namespace @@ -21,21 +24,21 @@ class BaseGenerator(metaclass=ABCMeta): """Abstract base class for Flatdata generators""" - def __init__(self, template): + def __init__(self, template: str) -> None: self._template = template @abstractmethod - def supported_nodes(self): + def supported_nodes(self) -> list[type]: """List of supported nodes by this generator""" raise RuntimeError( "Derived generators must implement _supported_nodes") @abstractmethod - def _populate_environment(self, env): + def _populate_environment(self, env: Environment) -> None: raise RuntimeError( "Derived generators must implement _populate_filters") - def render(self, tree): + def render(self, tree: SyntaxTree) -> str: """Generate the language implementation from the AST""" env = Environment(loader=PackageLoader('flatdata.generator', 'templates'), lstrip_blocks=True, trim_blocks=True, autoescape=False, extensions=[RaiseExtension]) @@ -71,7 +74,7 @@ class RaiseExtension(Extension): tags = set(['raise']) - def parse(self, parser): + def parse(self, parser: Parser) -> nodes.CallBlock: """The first token is the line number, followed by the expression""" lineno = next(parser.stream).lineno message_node = parser.parse_expression() @@ -81,6 +84,6 @@ def parse(self, parser): ) #pylint: disable=no-self-use - def _raise(self, msg, caller): + def _raise(self, msg: str, caller: object) -> NoReturn: """Helper callback.""" raise TemplateRuntimeError(msg) diff --git a/flatdata-generator/flatdata/generator/generators/cpp.py b/flatdata-generator/flatdata/generator/generators/cpp.py index 3d9cc58f..3edce7fd 100644 --- a/flatdata-generator/flatdata/generator/generators/cpp.py +++ b/flatdata-generator/flatdata/generator/generators/cpp.py @@ -3,6 +3,12 @@ See the LICENSE file in the root of this project for license details. ''' +from jinja2 import Environment + +from flatdata.generator.tree.helpers.basictype import BasicType +from flatdata.generator.tree.helpers.enumtype import EnumType +from flatdata.generator.tree.nodes.node import Node +from flatdata.generator.tree.nodes.references import BuiltinStructureReference, StructureReference from flatdata.generator.tree.nodes.resources import Vector, Multivector, Instance, RawData, BoundResource, \ ResourceBase, Archive as ArchiveResource from flatdata.generator.tree.nodes.trivial import Structure, Enumeration, Constant, Field @@ -13,21 +19,21 @@ class CppGenerator(BaseGenerator): """Flatdata to C++ header file generator""" - def __init__(self): + def __init__(self) -> None: BaseGenerator.__init__(self, "cpp/cpp.jinja2") - def supported_nodes(self): + def supported_nodes(self) -> list[type]: return [Structure, Archive, Constant, Enumeration] - def _populate_environment(self, env): + def _populate_environment(self, env: Environment) -> None: env.filters["cpp_doc"] = lambda value: value - def _safe_cpp_string_line(value): + def _safe_cpp_string_line(value: str) -> str: return value.replace('\\', '\\\\').replace('"', r'\"') env.filters["safe_cpp_string_line"] = _safe_cpp_string_line - def _cpp_base_type(flatdata_type): + def _cpp_base_type(flatdata_type: BasicType | EnumType | Node) -> str: type_map = { "bool": "bool", "i8": "int8_t", @@ -41,28 +47,28 @@ def _cpp_base_type(flatdata_type): } if flatdata_type.name in type_map: return type_map[flatdata_type.name] - return flatdata_type.name.replace("@@", "::").replace("@", "::") + return str(flatdata_type.name.replace("@@", "::").replace("@", "::")) env.filters["cpp_base_type"] = _cpp_base_type - def _to_type_params(refs): + def _to_type_params(refs: list[BuiltinStructureReference | StructureReference]) -> str: return ', '.join([ref.node.path_with("::") for ref in refs]) env.filters["to_type_params"] = _to_type_params - def _snake_to_upper_camel_case(expr): + def _snake_to_upper_camel_case(expr: str) -> str: return ''.join(p.title() for p in expr.split('_')) env.filters["snake_to_upper_camel_case"] = _snake_to_upper_camel_case - def _typedef_name(entity, extra_suffix=""): + def _typedef_name(entity: Field | ResourceBase, extra_suffix: str = "") -> str: assert isinstance(entity, (Field, ResourceBase)), "Got: %s" % entity.__class__ return _snake_to_upper_camel_case(entity.name) + extra_suffix + "Type" env.filters["typedef_name"] = _typedef_name - def _optional_typedef_usage(resource, extra_suffix=""): - def _wrap_in_optional(declaration): + def _optional_typedef_usage(resource: ResourceBase, extra_suffix: str = "") -> str: + def _wrap_in_optional(declaration: str) -> str: if resource.optional: return "boost::optional< %s >" % declaration return declaration @@ -71,7 +77,7 @@ def _wrap_in_optional(declaration): env.filters["archive_typedef_usage"] = _optional_typedef_usage - def _resource_provides_incremental_builder(resource): + def _resource_provides_incremental_builder(resource: ResourceBase) -> bool: assert isinstance(resource, ResourceBase) if isinstance(resource, Instance): return False @@ -86,7 +92,7 @@ def _resource_provides_incremental_builder(resource): env.filters[ "resource_provides_incremental_builder"] = _resource_provides_incremental_builder - def provides_setter(resource): + def provides_setter(resource: ResourceBase) -> bool: assert isinstance(resource, ResourceBase) if isinstance(resource, Instance): return True diff --git a/flatdata-generator/flatdata/generator/generators/dot.py b/flatdata-generator/flatdata/generator/generators/dot.py index 02e97d5f..de190bd1 100644 --- a/flatdata-generator/flatdata/generator/generators/dot.py +++ b/flatdata-generator/flatdata/generator/generators/dot.py @@ -4,8 +4,11 @@ ''' from flatdata.generator.tree.nodes.archive import Archive +from flatdata.generator.tree.nodes.trivial import Field from . import BaseGenerator +from jinja2 import Environment + SCOPE_SEPARATOR = "__" DECORATION_BOUND = "__bound__" @@ -13,15 +16,18 @@ class DotGenerator(BaseGenerator): """Flatdata to DOT (graph description language) generator""" - def __init__(self): + def __init__(self) -> None: BaseGenerator.__init__(self, "dot/dot.jinja2") - def _populate_environment(self, env): + def _populate_environment(self, env: Environment) -> None: env.autoescape = True - def _field_value_type(field): - type_name = field.type.name.replace("@@", ".").replace("@", ".") - namespace_name = field.parent.parent.path + def _field_value_type(field: Field) -> str: + assert field.type is not None + assert field.parent is not None + assert field.parent.parent is not None + type_name = str(field.type.name).replace("@@", ".").replace("@", ".") + namespace_name = str(field.parent.parent.path) if type_name.startswith(namespace_name): type_name = type_name[len(namespace_name):] if type_name.startswith("."): @@ -31,5 +37,5 @@ def _field_value_type(field): env.filters["field_value_type"] = _field_value_type - def supported_nodes(self): + def supported_nodes(self) -> list[type]: return [Archive] diff --git a/flatdata-generator/flatdata/generator/generators/flatdata.py b/flatdata-generator/flatdata/generator/generators/flatdata.py index b5efbc79..8a522a3d 100644 --- a/flatdata-generator/flatdata/generator/generators/flatdata.py +++ b/flatdata-generator/flatdata/generator/generators/flatdata.py @@ -3,6 +3,10 @@ See the LICENSE file in the root of this project for license details. ''' +from jinja2 import Environment + +from flatdata.generator.tree.nodes.node import Node +from flatdata.generator.tree.nodes.references import BuiltinStructureReference, StructureReference from flatdata.generator.tree.nodes.resources import BoundResource from flatdata.generator.tree.nodes.trivial import Structure, Enumeration, Constant from flatdata.generator.tree.nodes.archive import Archive @@ -13,26 +17,26 @@ class FlatdataGenerator(BaseGenerator): """Flatdata to Flatdata generator, used for debugging/testing""" - def __init__(self): + def __init__(self) -> None: BaseGenerator.__init__(self, "flatdata/flatdata.jinja2") - def supported_nodes(self): + def supported_nodes(self) -> list[type]: return [Structure, Archive, Constant, Enumeration] - def _populate_environment(self, env): - def _is_builtin(node): + def _populate_environment(self, env: Environment) -> None: + def _is_builtin(node: Node) -> bool: for namespace in SyntaxTree.namespaces(node): if namespace.name == "_builtin": return True return False env.filters["filter_builtin"] = lambda l: [x for x in l if not _is_builtin(x)] - def _field_type(flatdata_type): + def _field_type(flatdata_type: str) -> str: return flatdata_type.replace("@@", ".").replace("@", ".") env.filters["field_type"] = _field_type - def _to_type_params(refs): + def _to_type_params(refs: list[BuiltinStructureReference | StructureReference]) -> str: return ', '.join([ref.node.path_with(".") for ref in refs if not _is_builtin(ref.node)]) env.filters["to_type_params"] = _to_type_params diff --git a/flatdata-generator/flatdata/generator/generators/go.py b/flatdata-generator/flatdata/generator/generators/go.py index e4e5a0c4..b526b11c 100644 --- a/flatdata-generator/flatdata/generator/generators/go.py +++ b/flatdata-generator/flatdata/generator/generators/go.py @@ -2,11 +2,15 @@ Copyright (c) 2017 HERE Europe B.V. See the LICENSE file in the root of this project for license details. ''' +from jinja2 import Environment + from flatdata.generator.tree.nodes.archive import Archive from flatdata.generator.tree.nodes.node import Node from flatdata.generator.tree.nodes.resources import Instance, Vector, Multivector, RawData from flatdata.generator.tree.nodes.resources.archive import Archive as ArchiveResource +from flatdata.generator.tree.nodes.resources.base import ResourceBase from flatdata.generator.tree.nodes.trivial import Structure, Constant +from flatdata.generator.tree.syntax_tree import SyntaxTree from . import BaseGenerator @@ -14,32 +18,32 @@ class GoGenerator(BaseGenerator): """Flatdata to Go generator""" - def __init__(self): + def __init__(self) -> None: BaseGenerator.__init__(self, "go/go.jinja2") - def supported_nodes(self): + def supported_nodes(self) -> list[type]: return [Structure, Archive, Constant] - def _populate_environment(self, env): - def _decorate_archive_type(value): + def _populate_environment(self, env: Environment) -> None: + def _decorate_archive_type(value: Node) -> str: assert isinstance(value, Node) - return value.name + return str(value.name) - def to_go_doc(value): - lines = value.doc.splitlines() + def to_go_doc(value: object) -> str: + lines = value.doc.splitlines() # type: ignore[attr-defined] return '\n'.join(["// " + s for s in lines if len(s) != 0]) - def type_mapping(flatdata_type, _struct): + def type_mapping(flatdata_type: str, _struct: object) -> str: if is_bool(flatdata_type): return "uint8" return go_mapping(flatdata_type) - def type_mapping_with_bool(flatdata_type): + def type_mapping_with_bool(flatdata_type: str) -> str: if is_bool(flatdata_type): return "bool" return go_mapping(flatdata_type) - def go_mapping(flatdata_type): + def go_mapping(flatdata_type: str) -> str: return { "i8": "int8", "u8": "uint8", @@ -51,15 +55,15 @@ def go_mapping(flatdata_type): "i64": "int64" }[flatdata_type] - def is_bool(flatdata_type): + def is_bool(flatdata_type: str) -> bool: return flatdata_type == "bool" - def to_go_case(name, exported=True): + def to_go_case(name: str, exported: bool = True) -> str: if "_" in name: name = "".join(part.title() for part in name.split("_")) return (str.upper if exported else str.lower)(str(name[0])) + str(name[1:]) - def to_initializer(resource, tree): + def to_initializer(resource: ResourceBase, tree: SyntaxTree) -> str: if isinstance(resource, Instance): return _decorate_archive_type(resource.referenced_structures[0].node) if isinstance(resource, Vector): @@ -70,15 +74,15 @@ def to_initializer(resource, tree): for t in resource.referenced_structures] )) if isinstance(resource, ArchiveResource): - return _decorate_archive_type(resource.children[0].node) + return _decorate_archive_type(resource.children[0].node) # type: ignore[attr-defined] # child is an ArchiveReference which has .node if isinstance(resource, RawData): return "None" raise ValueError("Unknown resource type: %s" % (resource.__class__)) - def get_types_for_multivector(resource, _tree): + def get_types_for_multivector(resource: Multivector, _tree: SyntaxTree) -> list[str]: return [_decorate_archive_type(t.node) for t in resource.referenced_structures] - def contains_archive_resource(tree): + def contains_archive_resource(tree: SyntaxTree) -> bool: for child in tree.root.children[0].children: for res in child.children: if isinstance(res, ArchiveResource): diff --git a/flatdata-generator/flatdata/generator/generators/python.py b/flatdata-generator/flatdata/generator/generators/python.py index 93b69f4f..0f7b0c19 100644 --- a/flatdata-generator/flatdata/generator/generators/python.py +++ b/flatdata-generator/flatdata/generator/generators/python.py @@ -2,36 +2,41 @@ Copyright (c) 2017 HERE Europe B.V. See the LICENSE file in the root of this project for license details. ''' +from jinja2 import Environment + from flatdata.generator.tree.nodes.resources import Instance, Vector, Multivector, RawData from flatdata.generator.tree.nodes.resources.archive import Archive as ArchiveResource -from flatdata.generator.tree.nodes.trivial import Structure +from flatdata.generator.tree.nodes.resources.base import ResourceBase +from flatdata.generator.tree.nodes.trivial import Structure, Field +from flatdata.generator.tree.nodes.trivial.enumeration import Enumeration from flatdata.generator.tree.nodes.node import Node from flatdata.generator.tree.nodes.archive import Archive +from flatdata.generator.tree.syntax_tree import SyntaxTree from . import BaseGenerator class PythonGenerator(BaseGenerator): """Flatdata to Python generator""" - def __init__(self): + def __init__(self) -> None: BaseGenerator.__init__(self, "py/python.jinja2") - def supported_nodes(self): + def supported_nodes(self) -> list[type]: return [Structure, Archive] - def _populate_environment(self, env): - def _decorate_archive_type(tree, value): + def _populate_environment(self, env: Environment) -> None: + def _decorate_archive_type(tree: SyntaxTree, value: Node) -> str: assert isinstance(value, Node) - return tree.namespace_path(value, "_") + "_" + value.name + return str(tree.namespace_path(value, "_") + "_" + value.name) - def to_python_doc(value): + def to_python_doc(value: str) -> str: return '\n'.join( ["# " + line.replace('/**', '', 1).replace('*/', '', 1).replace(" *", '', 1).replace("//", "", 1) for line in value.splitlines()]) - def to_container(resource): + def to_container(resource: ResourceBase) -> str: if isinstance(resource, Instance): return "flatdata.resources.Instance" if isinstance(resource, Vector): @@ -44,22 +49,22 @@ def to_container(resource): return "flatdata.archive.Archive" raise ValueError("Unknown resource type: %s" % (resource.__class__)) - def to_initializer(resource, tree): + def to_initializer(resource: ResourceBase, tree: SyntaxTree) -> str: if isinstance(resource, Instance): - return _decorate_archive_type(tree, resource.referenced_structures[0].node) + return str(_decorate_archive_type(tree, resource.referenced_structures[0].node)) if isinstance(resource, Vector): - return _decorate_archive_type(tree, resource.referenced_structures[0].node) + return str(_decorate_archive_type(tree, resource.referenced_structures[0].node)) if isinstance(resource, Multivector): return "[{}]".format( ','.join([_decorate_archive_type(tree, t.node) for t in resource.referenced_structures])) if isinstance(resource, ArchiveResource): - return _decorate_archive_type(tree, resource.children[0].node) + return str(_decorate_archive_type(tree, resource.children[0].node)) # type: ignore[attr-defined] # child is an ArchiveReference which has .node if isinstance(resource, RawData): return "None" raise ValueError("Unknown resource type: %s" % (resource.__class__)) - def to_dtype(field): + def to_dtype(field: Field) -> str: type_map = { "bool": "?", "i8": "b", @@ -71,11 +76,16 @@ def to_dtype(field): "u64": "u8", "i64": "i8" } + assert field.type is not None if field.type.name in type_map: return type_map[field.type.name] - return type_map[field.type_reference.node.type.name] + assert field.type_reference is not None + enum_node = field.type_reference.node + assert isinstance(enum_node, Enumeration) + assert enum_node.type is not None + return str(type_map[enum_node.type.name]) - def _safe_py_string_line(value): + def _safe_py_string_line(value: str) -> str: return value.replace('\\', '\\\\').replace('"', r'\"') env.filters["safe_py_string_line"] = _safe_py_string_line diff --git a/flatdata-generator/flatdata/generator/generators/rust.py b/flatdata-generator/flatdata/generator/generators/rust.py index 7751fee9..80c4fe00 100644 --- a/flatdata-generator/flatdata/generator/generators/rust.py +++ b/flatdata-generator/flatdata/generator/generators/rust.py @@ -4,9 +4,12 @@ ''' import re +from jinja2 import Environment + +from flatdata.generator.tree.nodes.node import Node from flatdata.generator.tree.nodes.resources import (Vector, Multivector, Instance, RawData, BoundResource, Archive as ArchiveResource) -from flatdata.generator.tree.nodes.trivial import Structure, Constant, Enumeration +from flatdata.generator.tree.nodes.trivial import Structure, Constant, Enumeration, Field from flatdata.generator.tree.helpers.enumtype import EnumType from flatdata.generator.tree.nodes.archive import Archive from flatdata.generator.tree.syntax_tree import SyntaxTree @@ -23,14 +26,14 @@ class RustGenerator(BaseGenerator): "pure", "ref", "return", "self", "sizeof", "static", "struct", "super", "trait", "true", "type", "typeof", "unsafe", "unsized", "use", "virtual", "where", "while", "yield"] - def __init__(self): + def __init__(self) -> None: BaseGenerator.__init__(self, "rust/rust.jinja2") - def supported_nodes(self): + def supported_nodes(self) -> list[type]: return [Structure, Archive, Constant, Enumeration] @staticmethod - def _format_numeric_literal(value): + def _format_numeric_literal(value: str) -> str: try: # only apply this to integer values number = int(value) @@ -40,19 +43,19 @@ def _format_numeric_literal(value): except ValueError: return value - def _populate_environment(self, env): - def _camel_to_snake_case(expr): + def _populate_environment(self, env: Environment) -> None: + def _camel_to_snake_case(expr: str) -> str: step1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', expr) return re.sub('([a-z0-9])(A-Z)', r'\1_\2', step1).lower() env.filters["camel_to_snake_case"] = _camel_to_snake_case - def _snake_to_upper_camel_case(expr): + def _snake_to_upper_camel_case(expr: str) -> str: return ''.join(p.title() for p in expr.split('_')) env.filters["snake_to_upper_camel_case"] = _snake_to_upper_camel_case - def _rust_doc(expr): + def _rust_doc(expr: str) -> str: lines = [ re.sub(r'^[ \t]*(/\*\*\s?|/\*\s?|\*/|\*\s?)(.*?)\s*(\*/)?$', r"/// \2", line).strip() @@ -68,24 +71,32 @@ def _rust_doc(expr): env.filters["rust_doc"] = _rust_doc - def _escape_rust_keywords(expr): + def _escape_rust_keywords(expr: str) -> str: if expr in self.RESERVED_KEYWORDS: return "{}_".format(expr) return expr - def _field_type(field): + def _field_type(field: Field) -> str: + assert field.type is not None if isinstance(field.type, EnumType): + assert field.type_reference is not None + assert field.parent is not None return "{}".format( _fully_qualified_name(field.parent, field.type_reference.node)) return "{}".format(field.type.name) - def _primitive_type(field): + def _primitive_type(field: Field) -> str: + assert field.type is not None if isinstance(field.type, EnumType): - return "{}".format(field.type_reference.node.type.name) + assert field.type_reference is not None + enum_node = field.type_reference.node + assert isinstance(enum_node, Enumeration) + assert enum_node.type is not None + return "{}".format(enum_node.type.name) return "{}".format(field.type.name) - def _fully_qualified_name(current, node): - return "::".join((current.path_depth() - 1) * ["super"]) + node.path_with("::") + def _fully_qualified_name(current: Node, node: Node) -> str: + return "::".join((current.path_depth() - 1) * ["super"]) + str(node.path_with("::")) env.globals["fully_qualified_name"] = _fully_qualified_name env.filters["escape_rust_keywords"] = _escape_rust_keywords diff --git a/flatdata-generator/flatdata/generator/grammar.py b/flatdata-generator/flatdata/generator/grammar.py index 034c5095..37f5ac06 100644 --- a/flatdata-generator/flatdata/generator/grammar.py +++ b/flatdata-generator/flatdata/generator/grammar.py @@ -8,7 +8,8 @@ from pyparsing import ( Word, alphas, alphanums, nums, cppStyleComment, Keyword, Group, Optional, Or, OneOrMore, delimitedList, ZeroOrMore, - hexnums, Combine, FollowedBy, ParseException as pyparsingParseException + hexnums, Combine, FollowedBy, ParseException as pyparsingParseException, + ParseResults ) ParseException = pyparsingParseException @@ -122,7 +123,7 @@ single_object("object") ) -def _combine_list(t): +def _combine_list(t: ParseResults) -> str: return "".join(t[0].asList()) explicit_field_reference_prefix = Group( diff --git a/flatdata-generator/flatdata/generator/py.typed b/flatdata-generator/flatdata/generator/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/flatdata-generator/flatdata/generator/tree/builder.py b/flatdata-generator/flatdata/generator/tree/builder.py index bc53a9de..9d623d72 100644 --- a/flatdata-generator/flatdata/generator/tree/builder.py +++ b/flatdata-generator/flatdata/generator/tree/builder.py @@ -28,12 +28,12 @@ from .resolver import resolve_references -def _create_nested_namespaces(path): +def _create_nested_namespaces(path: str) -> tuple[nodes.Namespace, nodes.Namespace]: assert not path.startswith(Node.PATH_SEPARATOR) splitpath = Node.splitpath(path) root = nodes.Namespace(name=splitpath[0]) - node = root + node: nodes.Namespace = root for name in splitpath[1:]: new_node = nodes.Namespace(name=name) node.insert(new_node) @@ -41,7 +41,7 @@ def _create_nested_namespaces(path): return root, node -def _ensure_namespace(root, path): +def _ensure_namespace(root: Root, path: str) -> nodes.Namespace: assert isinstance(root, Root) assert path.startswith( Node.PATH_SEPARATOR), "This method only works with root-level paths" @@ -52,13 +52,14 @@ def _ensure_namespace(root, path): return found last_common_parent = root.find_last(path) + assert last_common_parent is not None first, last = _create_nested_namespaces( path[len(last_common_parent.path) + 1:]) last_common_parent.insert(first) return last -def _innermost_namespace(root): +def _innermost_namespace(root: Node) -> nodes.Namespace | None: if not isinstance(root, nodes.Namespace): return None namespace = root @@ -68,10 +69,11 @@ def _innermost_namespace(root): return namespace -def _merge_roots(roots): +def _merge_roots(roots: list[nodes.Namespace]) -> Root: result = Root() for root in roots: innermost = _innermost_namespace(root) + assert innermost is not None target = _ensure_namespace( result, Node.PATH_SEPARATOR + innermost.path) for child in innermost.children: @@ -79,7 +81,7 @@ def _merge_roots(roots): return result -def _build_node_tree(definition): +def _build_node_tree(definition: str) -> Root: if not definition: return Root() @@ -89,7 +91,7 @@ def _build_node_tree(definition): except (ParseException, ParseSyntaxException) as err: raise ParsingError(err) - roots = [] + roots: list[nodes.Namespace] = [] for namespace in parsed.namespace: root_namespace, target_namespace = _create_nested_namespaces( @@ -104,7 +106,7 @@ def _build_node_tree(definition): for collection, cls in parsed_items: for item in collection: - target_namespace.insert(cls.create(properties=item, + target_namespace.insert(cls.create(properties=item, # type: ignore[attr-defined] # subclasses (Structure, Enumeration, Archive) define create() definition=definition)) roots.append(root_namespace) @@ -112,9 +114,10 @@ def _build_node_tree(definition): return _merge_roots(roots) -def _append_builtin_structures(root): +def _append_builtin_structures(root: Root) -> None: multivectors = list(root.iterate(Multivector)) for node in multivectors: + assert node.parent is not None and node.parent.parent is not None namespace = _ensure_namespace(root, node.parent.parent.path + "._builtin.multivector") for builtin in node.builtins: found = namespace.get_relative(builtin.name) @@ -124,7 +127,7 @@ def _append_builtin_structures(root): node.insert(BuiltinStructureReference(name=found.path)) -def _append_constant_references(root): +def _append_constant_references(root: Root) -> None: constants = [c for c in root.iterate(nodes.Constant)] constant_references = set(c.target for c in root.iterate(ConstantReference)) archives = [a for a in root.iterate(Archive)] @@ -134,20 +137,21 @@ def _append_constant_references(root): archive.insert(ConstantValueReference(constant.path)) -def _update_field_type_references(root): +def _update_field_type_references(root: Root) -> None: for field in root.iterate(nodes.Field): if field.type: continue reference = field.type_reference if isinstance(reference, EnumerationReference): + enum_node = reference.node # resolves to Enumeration at runtime field.type = EnumType(name=reference.name, basic_type=BasicType( - name=reference.node.type.name, width=reference.node.type.width)) - if reference.width and reference.width != reference.node.type.width: + name=enum_node.type.name, width=enum_node.type.width)) # type: ignore[attr-defined] # .node resolves to Enumeration which has .type + if reference.width and reference.width != enum_node.type.width: # type: ignore[attr-defined] raise InvalidEnumWidthError(enumeration_name=reference.name, - width=reference.node.type.width, provided_width=reference.width) + width=enum_node.type.width, provided_width=reference.width) # type: ignore[attr-defined] -def _compute_structure_sizes(root): +def _compute_structure_sizes(root: Root) -> None: # visit structs in the correct order. Not important right now, # but will make it very easy to add structs as fields in other structs for struct, _ in DfsTraversal(root).dependency_order(): @@ -159,26 +163,30 @@ def _compute_structure_sizes(root): if not isinstance(field, nodes.Field): continue field.offset = offset + assert field.type is not None offset += int(field.type.width) struct.size_in_bits = offset -def _compute_max_resource_size(root): +def _compute_max_resource_size(root: Root) -> None: # visit all explicit references and check how many bits they have available # the provides an upper bound on resource size for reference in root.iterate(ExplicitReference): - if reference.field.node.type.width == 64: + field_node = reference.field.node # resolves to Structure/Field at runtime + if field_node.type.width == 64: # type: ignore[attr-defined] # .node resolves to a node with .type continue - ref_limit = 2 ** reference.field.node.type.width - max_size = reference.destination.node.max_size - reference.destination.node.max_size = ref_limit if max_size is None or max_size > ref_limit else max_size + ref_limit = 2 ** field_node.type.width # type: ignore[attr-defined] + dest_node = reference.destination.node # resolves to a resource node at runtime + max_size = dest_node.max_size # type: ignore[attr-defined] # .node resolves to ResourceBase which has .max_size + dest_node.max_size = ref_limit if max_size is None or max_size > ref_limit else max_size # type: ignore[attr-defined] -def _check_ranges(root): +def _check_ranges(root: Root) -> None: # First check that names are unique for field in root.iterate(nodes.Field): name = field.range if not name: continue - for sibling in field.parent.fields: + assert field.parent is not None + for sibling in field.parent.fields: # type: ignore[attr-defined] # parent is a Structure which has .fields if sibling.name == name: raise InvalidRangeName(name) # Also check that the range is not optional @@ -191,26 +199,29 @@ def _check_ranges(root): and isinstance(reference.parent, ResourceBase) and not isinstance(reference.parent, Vector)): raise InvalidRangeReference(reference.target) -def _check_const_refs(root): +def _check_const_refs(root: Root) -> None: for field in root.iterate(nodes.Field): + assert field.type is not None for ref in field.children_like(ConstantReference): + const_node = ref.node # resolves to Constant at runtime # Check that type matches - if ref.node.type.name != field.type.name: - raise InvalidConstReference(ref.target, ref.node.type.name) + if const_node.type.name != field.type.name: # type: ignore[attr-defined] # .node resolves to Constant which has .type + raise InvalidConstReference(ref.target, const_node.type.name) # type: ignore[attr-defined] # Check that value fits into field - if ref.node.type.bits_required(ref.node.value) > field.type.width: + if const_node.type.bits_required(const_node.value) > field.type.width: # type: ignore[attr-defined] # Constant has .type and .value raise InvalidConstValueReference(ref.target, field.type.width) invalid_values = field.children_like(InvalidValueReference) if len(invalid_values) > 1: raise DuplicateInvalidValueReference(field.name, [x.target for x in invalid_values]) -def _check_explicit_references(root): +def _check_explicit_references(root: Root) -> None: for reference in root.iterate(ExplicitReference): for ref in reference.children_like(StructureReference): + assert reference.parent is not None if not ref.target in [x.target for x in reference.parent.children_like(StructureReference)]: raise InvalidStructInExplicitReference(ref.node.name, reference.parent.name) -def build_ast(definition): +def build_ast(definition: str) -> SyntaxTree: """Build the Flatdata syntax tree from a definition""" root = _build_node_tree(definition=definition) _append_builtin_structures(root) diff --git a/flatdata-generator/flatdata/generator/tree/errors.py b/flatdata-generator/flatdata/generator/tree/errors.py index 22c8d4fa..d0a392db 100644 --- a/flatdata-generator/flatdata/generator/tree/errors.py +++ b/flatdata-generator/flatdata/generator/tree/errors.py @@ -3,13 +3,23 @@ See the LICENSE file in the root of this project for license details. ''' +from __future__ import annotations + +from collections.abc import Iterable +from typing import TYPE_CHECKING + +from pyparsing import ParseBaseException + +if TYPE_CHECKING: + from .nodes.node import Node + class FlatdataSyntaxError(RuntimeError): pass class SymbolRedefinition(FlatdataSyntaxError): - def __init__(self, duplicate, existing): + def __init__(self, duplicate: Node, existing: Node) -> None: super().__init__( "Symbol redefined: {duplicate} already exists at {existing}".format( duplicate=duplicate, @@ -17,50 +27,50 @@ def __init__(self, duplicate, existing): class CircularReferencing(FlatdataSyntaxError): - def __init__(self, node, child): + def __init__(self, node: Node, child: Node) -> None: super().__init__( "Circular reference in schema: {node} -> {child}".format( node=node, child=child)) class MissingSymbol(FlatdataSyntaxError): - def __init__(self, name, options, node): + def __init__(self, name: str, options: Iterable[str], node: Node) -> None: message = "Missing symbol \"{name}\" in {path}.".format( name=name, path=node.path) try: import Levenshtein - options = sorted( + ranked = sorted( [(Levenshtein.distance(name, option.split('.')[-1]), option) for option in options], key=lambda x: x[0]) - if options: - message += " Did you mean \"{options}\"?".format( - options=options[0][1]) + if ranked: + message += " Did you mean \"{opt}\"?".format( + opt=ranked[0][1]) except ImportError: pass super().__init__(message) class IncorrectReferenceType(FlatdataSyntaxError): - def __init__(self, name, actual, expected): + def __init__(self, name: str, actual: type, expected: type) -> None: super().__init__( "{name} referring to incorrect type. Expected {expected}, actual {actual}".format( name=name, expected=expected, actual=actual)) class UnexpectedResourceType(FlatdataSyntaxError): - def __init__(self, name): + def __init__(self, name: str) -> None: super().__init__( "Unexpected resource type: {name}".format(name=name)) class ParsingError(FlatdataSyntaxError): - def __init__(self, pyparsing_error): + def __init__(self, pyparsing_error: ParseBaseException) -> None: super().__init__( self.create_message(pyparsing_error)) @staticmethod - def create_message(err): + def create_message(err: ParseBaseException) -> str: return "Failed to parse the schema. Details below:\n" \ " {line}\n" \ " {pointer}\n" \ @@ -69,89 +79,89 @@ def create_message(err): class InvalidWidthError(FlatdataSyntaxError): - def __init__(self, width, flatdata_type): + def __init__(self, width: int, flatdata_type: str) -> None: super().__init__( "Bit field of {width}bit width cannot fit in {type}".format(width=width, type=flatdata_type)) class InvalidSignError(FlatdataSyntaxError): - def __init__(self, value): + def __init__(self, value: int) -> None: super().__init__( "Value has wrong sign: {value}".format(value=value)) class DuplicateEnumValueError(FlatdataSyntaxError): - def __init__(self, enumeration_name, value): + def __init__(self, enumeration_name: str, value: int) -> None: super().__init__( "Enumeration {enumeration_name} has duplicate entries for value {value}" .format(enumeration_name=enumeration_name, value=value)) class SparseEnumError(FlatdataSyntaxError): - def __init__(self, enumeration_name, width): + def __init__(self, enumeration_name: str, width: int) -> None: super().__init__( "Enumeration {enumeration_name} has too many undefined values (2^{width}), please restrict bit width, or define more" .format(enumeration_name=enumeration_name, width=width)) class InvalidEnumValueError(FlatdataSyntaxError): - def __init__(self, enumeration_name, value): + def __init__(self, enumeration_name: str, value: int) -> None: super().__init__( "Enumeration {enumeration_name} has not enough bits for value {value}" .format(enumeration_name=enumeration_name, value=value)) class InvalidStructInExplicitReference(FlatdataSyntaxError): - def __init__(self, struct, resource): + def __init__(self, struct: str, resource: str) -> None: super().__init__( "Struct '{struct}' referenced, but not appearing in resource '{resource}'" .format(struct=struct, resource=resource)) class InvalidEnumWidthError(FlatdataSyntaxError): - def __init__(self, enumeration_name, width, provided_width): + def __init__(self, enumeration_name: str, width: int, provided_width: int) -> None: super().__init__( "Enumeration {enumeration_name} has {width} bits, but field specified {provided_width} bits" .format(enumeration_name=enumeration_name, width=width, provided_width=provided_width)) class InvalidConstantValueError(FlatdataSyntaxError): - def __init__(self, name, value): + def __init__(self, name: str, value: int) -> None: super().__init__( "Constant {name} has not enough bits for value {value}" .format(name=name, value=value)) class InvalidConstReference(FlatdataSyntaxError): - def __init__(self, name, type): + def __init__(self, name: str, type: str) -> None: super().__init__( "Referenced constant {name} has wrong type {type}" .format(name=name, type=type)) class InvalidConstValueReference(FlatdataSyntaxError): - def __init__(self, name, bits): + def __init__(self, name: str, bits: int) -> None: super().__init__( "Referenced constant {name} value does not fit into {bits} bits" .format(name=name, bits=bits)) class DuplicateInvalidValueReference(FlatdataSyntaxError): - def __init__(self, name, constants): + def __init__(self, name: str, constants: list[str]) -> None: super().__init__( "Multiple optional annotations {constants} for field {name}" .format(name=name, constants=constants)) class InvalidRangeName(FlatdataSyntaxError): - def __init__(self, name): + def __init__(self, name: str) -> None: super().__init__( "@range name {name} is already in use for a field" .format(name=name)) class InvalidRangeReference(FlatdataSyntaxError): - def __init__(self, name): + def __init__(self, name: str) -> None: super().__init__( "Structs with @range can only be used in vectors: {name}" .format(name=name)) class OptionalRange(FlatdataSyntaxError): - def __init__(self, name): + def __init__(self, name: str) -> None: super().__init__( "@range cannot be combined with @optional, store empty ranges instead: {name}" .format(name=name)) \ No newline at end of file diff --git a/flatdata-generator/flatdata/generator/tree/helpers/basictype.py b/flatdata-generator/flatdata/generator/tree/helpers/basictype.py index 53831131..646763ee 100644 --- a/flatdata-generator/flatdata/generator/tree/helpers/basictype.py +++ b/flatdata-generator/flatdata/generator/tree/helpers/basictype.py @@ -4,7 +4,7 @@ class BasicType: - _WIDTH = { + _WIDTH: dict[str, int] = { "bool": 1, "u8": 8, "i8": 8, @@ -16,7 +16,7 @@ class BasicType: "i64": 64 } - _TYPE_ANNOTATION = { + _TYPE_ANNOTATION: dict[str, str] = { "bool": "", "u8": "", "i8": "", @@ -29,35 +29,33 @@ class BasicType: } @staticmethod - def is_basic_type(name): + def is_basic_type(name: str) -> bool: return name in grammar.BASIC_TYPES - def __init__(self, name, width=None): + def __init__(self, name: str, width: int | None = None) -> None: assert self.is_basic_type(name) self._name = name - self._width = width - if width is None: - self._width = self._WIDTH[self._name] + self._width: int = width if width is not None else self._WIDTH[self._name] if self._width > self._WIDTH[self.name]: raise InvalidWidthError(self._width, self._name) @property - def name(self): + def name(self) -> str: return self._name @property - def width(self): + def width(self) -> int: return self._width @property - def is_signed(self): + def is_signed(self) -> bool: return self._name[0] == 'i' @property - def annotation(self): + def annotation(self) -> str: return self._TYPE_ANNOTATION[self._name] - def bits_required(self, value): + def bits_required(self, value: int) -> int: if self.is_signed: if value >= 0: # sign bit @@ -68,7 +66,7 @@ def bits_required(self, value): return value.bit_length() raise InvalidSignError(value=value) - def value_range(self): + def value_range(self) -> range: if self.is_signed: return range(-(2 ** (self.width - 1)), 2 ** (self.width - 1)) return range(2 ** self.width) diff --git a/flatdata-generator/flatdata/generator/tree/helpers/enumtype.py b/flatdata-generator/flatdata/generator/tree/helpers/enumtype.py index 51ba9e1c..e9586b44 100644 --- a/flatdata-generator/flatdata/generator/tree/helpers/enumtype.py +++ b/flatdata-generator/flatdata/generator/tree/helpers/enumtype.py @@ -2,23 +2,23 @@ class EnumType: - def __init__(self, name, basic_type): + def __init__(self, name: str, basic_type: BasicType) -> None: assert not BasicType.is_basic_type(name), "%r is no valid enum name" % name self._name = name self._type = basic_type @property - def name(self): + def name(self) -> str: return self._name @property - def width(self): + def width(self) -> int: return self._type.width @property - def annotation(self): + def annotation(self) -> str: return self._type.annotation @property - def is_signed(self): + def is_signed(self) -> bool: return self._type.is_signed diff --git a/flatdata-generator/flatdata/generator/tree/nodes/archive.py b/flatdata-generator/flatdata/generator/tree/nodes/archive.py index 88ebe85c..6a5ca176 100644 --- a/flatdata-generator/flatdata/generator/tree/nodes/archive.py +++ b/flatdata-generator/flatdata/generator/tree/nodes/archive.py @@ -2,9 +2,12 @@ from flatdata.generator.tree.nodes.node import Node import flatdata.generator.tree.nodes.resources as resources +from pyparsing import ParseResults -def _create_resource(properties): + +def _create_resource(properties: ParseResults) -> resources.ResourceBase: resource_type = properties.type + cls: type[resources.ResourceBase] if 'vector' in resource_type: cls = resources.Vector elif 'multivector' in resource_type: @@ -26,12 +29,12 @@ def _create_resource(properties): class Archive(Node): - def __init__(self, name, properties=None): + def __init__(self, name: str, properties: ParseResults | None = None) -> None: super().__init__(name=name, properties=properties) #pylint: disable=unused-argument @staticmethod - def create(properties, definition): + def create(properties: ParseResults, definition: str) -> 'Archive': result = Archive(name=properties.name, properties=properties) for resource in properties.resources: result.insert(_create_resource(resource)) @@ -45,9 +48,11 @@ def create(properties, definition): return result @property - def resources(self): - return self.children_like(resources.ResourceBase) + def resources(self) -> list[resources.ResourceBase]: + return self.children_like(resources.ResourceBase) # type: ignore[type-abstract] # isinstance() with ABC is valid @property - def doc(self): - return self._properties.doc + def doc(self) -> str: + assert self._properties is not None + doc = self._properties.doc + return str(doc) if doc is not None else "" diff --git a/flatdata-generator/flatdata/generator/tree/nodes/explicit_reference.py b/flatdata-generator/flatdata/generator/tree/nodes/explicit_reference.py index 7f21d232..480bd610 100644 --- a/flatdata-generator/flatdata/generator/tree/nodes/explicit_reference.py +++ b/flatdata-generator/flatdata/generator/tree/nodes/explicit_reference.py @@ -1,13 +1,15 @@ from flatdata.generator.tree.nodes.node import Node from flatdata.generator.tree.nodes.references import ResourceReference, FieldReference, StructureReference +from pyparsing import ParseResults + class ExplicitReference(Node): - def __init__(self, name, properties=None): + def __init__(self, name: str, properties: ParseResults | None = None) -> None: super().__init__(name=name, properties=properties) @staticmethod - def create(properties): + def create(properties: ParseResults) -> 'ExplicitReference': destination = properties.destination field = Node.jointwo(properties.source_type, properties.source_field) result = ExplicitReference( @@ -22,19 +24,19 @@ def create(properties): @property - def destination(self): + def destination(self) -> ResourceReference: result = self.children_like(ResourceReference) assert len(result) == 1 return result[0] @property - def field(self): + def field(self) -> FieldReference: result = self.children_like(FieldReference) assert len(result) == 1 return result[0] @property - def structure(self): + def structure(self) -> StructureReference: result = self.children_like(StructureReference) assert len(result) == 1 return result[0] diff --git a/flatdata-generator/flatdata/generator/tree/nodes/node.py b/flatdata-generator/flatdata/generator/tree/nodes/node.py index f324d8ad..7263e94e 100644 --- a/flatdata-generator/flatdata/generator/tree/nodes/node.py +++ b/flatdata-generator/flatdata/generator/tree/nodes/node.py @@ -2,11 +2,19 @@ Copyright (c) 2017 HERE Europe B.V. See the LICENSE file in the root of this project for license details. ''' -from copy import copy +from __future__ import annotations from collections import OrderedDict +from collections.abc import Iterator +from copy import copy +from typing import TypeVar, overload + +from pyparsing import ParseResults + from flatdata.generator.tree.errors import SymbolRedefinition +_T = TypeVar('_T', bound='Node') + class Node: """ @@ -18,65 +26,67 @@ class Node: PATH_SEPARATOR = '.' @staticmethod - def splitpath(path): + def splitpath(path: str) -> list[str]: """ Splits node path. """ return path.split(Node.PATH_SEPARATOR) @staticmethod - def jointwo(path, other): + def jointwo(path: str, other: str) -> str: """ Joins two node paths. """ return Node.PATH_SEPARATOR.join([path, other]) - def __init__(self, name, properties=None): + def __init__(self, name: str, properties: ParseResults | None = None) -> None: assert self.PATH_SEPARATOR not in name assert name self._name = name self._properties = properties - self._children = OrderedDict() - self._parent = None + self._children: OrderedDict[str, Node] = OrderedDict() + self._parent: Node | None = None @property - def name(self): + def name(self) -> str: """ Returns the name of the node. """ return self._name @property - def children(self): + def children(self) -> list[Node]: """ Returns a list of children nodes. """ return list(self._children.values()) - def children_like(self, T): + def children_like(self, T: type[_T]) -> list[_T]: """ Returns a list of children nodes of a given type, if any. """ return [c for c in list(self._children.values()) if isinstance(c, T)] @property - def parent(self): + def parent(self) -> Node | None: """ Returns node's parent. """ return self._parent - def first_parent_like(self, T): + def first_parent_like(self, T: type[_T]) -> _T | None: """ Returns first available parent of a given type or None if none found. """ result = self while result.parent is not None and not isinstance(result.parent, T): result = result.parent - return result.parent + if isinstance(result.parent, T): + return result.parent + return None @property - def path(self): + def path(self) -> str: """ Returns nodes' path in a tree. """ @@ -84,13 +94,13 @@ def path(self): return self.name return Node.jointwo(self._parent.path, self.name) - def path_with(self, separator='_'): + def path_with(self, separator: str = '_') -> str: """ Returns nodes' path in a tree with a given characters as separator. """ return self.path.replace(self.PATH_SEPARATOR, separator) - def path_depth(self): + def path_depth(self) -> int: """ Returns nodes' depths in a tree """ @@ -98,7 +108,7 @@ def path_depth(self): return 0 return 1 + self._parent.path_depth() - def set_name(self, value): + def set_name(self, value: str) -> None: """ Sets the new name for the node. New name should not clash with any of siblings' names. :raises RuntimeError in case name is already in use @@ -107,7 +117,7 @@ def set_name(self, value): if self.name == value: return - if self.parent is not None and value in self._parent._children: + if self.parent is not None and value in self._parent._children: # type: ignore[union-attr] # self.parent property returns self._parent; mypy can't narrow backing field through property raise RuntimeError( "Cannot rename the node, name {value} is already in use".format(value=value)) @@ -115,7 +125,7 @@ def set_name(self, value): if self.parent is not None: self.parent.reindex() - def find(self, path): + def find(self, path: str) -> Node: """ Finds child node recursively by its path. :param path: Full path to the node up to the node search is started. @@ -137,7 +147,7 @@ def find(self, path): path=path, options=tuple(self.symbols()))) return target - def get(self, path, default=None): + def get(self, path: str, default: Node | None = None) -> Node | None: """ Returns the node like find() does, but allows default value specification. """ @@ -147,13 +157,13 @@ def get(self, path, default=None): return default return result - def find_relative(self, path): + def find_relative(self, path: str) -> Node: """ Finds a child node recursively via its path relative to the current node. """ return self.find(Node.jointwo(self.name, path)) - def find_last(self, path): + def find_last(self, path: str) -> Node | None: """ Finds a last node existing in the path. If no such node found, None is returned. """ @@ -172,23 +182,23 @@ def find_last(self, path): return target return target - def get_relative(self, path, default=None): + def get_relative(self, path: str, default: Node | None = None) -> Node | None: """ Finds a child node recursively via its path relative to the current node. """ return self.get(Node.jointwo(self.name, path), default) @property - def root(self): + def root(self) -> Node: """ Returns the root node of the tree """ result = self while result.parent is not None: - result = result._parent + result = result._parent # type: ignore[assignment] # guarded by while loop; mypy can't narrow backing field through property return result - def extract_subtree(self): + def extract_subtree(self) -> Node: """ Extract the subtree of node (some nodes are copied) Also copies the path to the root of the tree @@ -203,7 +213,7 @@ def extract_subtree(self): new_root = parent return new_root - def insert(self, *nodes): + def insert(self, *nodes: Node) -> Node: """ Inserts node into the tree. :raises: SymbolRedefinition in case node with the same name exists @@ -223,23 +233,27 @@ def insert(self, *nodes): node._parent = self return self - def erase(self, key): + def erase(self, key: str) -> None: """ Erase node with a given name from the tree. """ node = self._children.pop(key) node._parent = None - def reindex(self): + def reindex(self) -> None: """ Reindex the node. Produces no side effects if called externally. """ - new_children = OrderedDict() + new_children: OrderedDict[str, Node] = OrderedDict() for _key, node in self._children.items(): new_children[node.name] = node self._children = new_children - def iterate(self, node_type=None): + @overload + def iterate(self, node_type: type[_T]) -> Iterator[_T]: ... + @overload + def iterate(self, node_type: None = ...) -> Iterator[Node]: ... + def iterate(self, node_type: type | None = None) -> Iterator[Node]: """ Iterates the nodes in pre-order traversal fashion """ @@ -249,7 +263,7 @@ def iterate(self, node_type=None): for node in child.iterate(node_type): yield node - def parents(self): + def parents(self) -> Iterator[Node]: """ Returns all node's parents up to the root of the tree. """ @@ -258,7 +272,7 @@ def parents(self): yield par._parent par = par._parent - def detach(self): + def detach(self) -> Node: """ Detaches the node from its parent. """ @@ -268,12 +282,12 @@ def detach(self): self._parent = None return self - def symbols(self, include_types=False): + def symbols(self, include_types: bool = False) -> set[str] | dict[str, type]: """ Returns paths of all nodes available in the tree, optionally with node types. :param include_types: return types along with paths """ - result = dict() + result: dict[str, type] = dict() for node in self.iterate(): path = node.path if path: @@ -282,5 +296,5 @@ def symbols(self, include_types=False): return set(result.keys()) return result - def __repr__(self): + def __repr__(self) -> str: return "{type}{{{path}}}".format(type=type(self).__name__, path=self.path) diff --git a/flatdata-generator/flatdata/generator/tree/nodes/references.py b/flatdata-generator/flatdata/generator/tree/nodes/references.py index 4960900f..d7462042 100644 --- a/flatdata-generator/flatdata/generator/tree/nodes/references.py +++ b/flatdata-generator/flatdata/generator/tree/nodes/references.py @@ -10,33 +10,33 @@ class Reference(Node): References participate in dependency resolution. """ - def __init__(self, name): + def __init__(self, name: str) -> None: super().__init__(name=Reference._referencify(name)) @property - def target(self): + def target(self) -> str: return Reference._dereferencify(self.name) - def update_reference(self, new_value): + def update_reference(self, new_value: str) -> None: assert new_value.endswith(self.target), \ "References can only be updated during resolution for the same symbol: %s -> %s" % \ (self.target, new_value) self.set_name(Reference._referencify(new_value)) @property - def node(self): + def node(self) -> Node: return self.root.find(self.target) @property - def is_qualified(self): + def is_qualified(self) -> bool: return self.name[:2] == "@@" @staticmethod - def _referencify(name): + def _referencify(name: str) -> str: return "@" + name.replace(".", "@") @staticmethod - def _dereferencify(name): + def _dereferencify(name: str) -> str: return name[1:].replace("@", ".") @@ -91,10 +91,10 @@ class EnumerationReference(TypeReference): EnumerationReference depict: - Field Type -> Enumeration """ - def __init__(self, name, width=None): + def __init__(self, name: str, width: int | None = None) -> None: super().__init__(name) self._width = width @property - def width(self): + def width(self) -> int | None: return self._width diff --git a/flatdata-generator/flatdata/generator/tree/nodes/resources/archive.py b/flatdata-generator/flatdata/generator/tree/nodes/resources/archive.py index 58477e56..9da6cb17 100644 --- a/flatdata-generator/flatdata/generator/tree/nodes/resources/archive.py +++ b/flatdata-generator/flatdata/generator/tree/nodes/resources/archive.py @@ -1,23 +1,27 @@ +from flatdata.generator.tree.nodes.node import Node from flatdata.generator.tree.nodes.references import ArchiveReference from .base import ResourceBase +from pyparsing import ParseResults + class Archive(ResourceBase): - def __init__(self, name, properties=None, target=None): + def __init__(self, name: str, properties: ParseResults | None = None, target: str | None = None) -> None: super().__init__(name=name, properties=properties) self._target = target @staticmethod - def create(properties): + def create(properties: ParseResults) -> 'Archive': return Archive(name=properties.name, properties=properties, target=properties.type.archive.name) @property - def target(self): + def target(self) -> ArchiveReference: targets = self.children_like(ArchiveReference) assert len(targets) == 1 return targets[0] - def create_references(self): + def create_references(self) -> list[Node]: + assert self._target is not None return [ArchiveReference(name=self._target)] diff --git a/flatdata-generator/flatdata/generator/tree/nodes/resources/base.py b/flatdata-generator/flatdata/generator/tree/nodes/resources/base.py index e8e92ff5..b45475c4 100644 --- a/flatdata-generator/flatdata/generator/tree/nodes/resources/base.py +++ b/flatdata-generator/flatdata/generator/tree/nodes/resources/base.py @@ -1,16 +1,18 @@ from abc import ABC, abstractmethod +from pyparsing import ParseResults + from flatdata.generator.tree.nodes.explicit_reference import ExplicitReference from flatdata.generator.tree.nodes.node import Node from flatdata.generator.tree.nodes.references import BuiltinStructureReference, StructureReference class ResourceBase(Node, ABC): - def __init__(self, name, properties=None,): + def __init__(self, name: str, properties: ParseResults | None = None) -> None: super().__init__(name=name, properties=properties) - self._decorations = [] - self._max_size = None + self._decorations: list[ParseResults] = [] + self._max_size: int | None = None if properties is not None and 'decorations' in properties: self._decorations = properties.decorations for decoration in self._decorations: @@ -18,33 +20,38 @@ def __init__(self, name, properties=None,): self.insert(ExplicitReference.create(properties=decoration.explicit_reference)) @abstractmethod - def create_references(self): + def create_references(self) -> list[Node]: pass @property - def optional(self): + def optional(self) -> bool: return any(['optional' in d for d in self.decorations]) @property - def doc(self): - return self._properties.doc + def doc(self) -> str: + assert self._properties is not None + doc = self._properties.doc + return str(doc) if doc is not None else "" @property - def decorations(self): + def decorations(self) -> list[ParseResults]: return self._decorations @property - def explicit_references(self): + def explicit_references(self) -> list[ExplicitReference]: return self.children_like(ExplicitReference) @property - def referenced_structures(self): - return self.children_like(BuiltinStructureReference) + self.children_like(StructureReference) + def referenced_structures(self) -> list[BuiltinStructureReference | StructureReference]: + return [ + *self.children_like(BuiltinStructureReference), + *self.children_like(StructureReference), + ] @property - def max_size(self): + def max_size(self) -> int | None: return self._max_size @max_size.setter - def max_size(self, value): + def max_size(self, value: int | None) -> None: self._max_size = value diff --git a/flatdata-generator/flatdata/generator/tree/nodes/resources/bound_resource.py b/flatdata-generator/flatdata/generator/tree/nodes/resources/bound_resource.py index e29d8f9f..2267dcbd 100644 --- a/flatdata-generator/flatdata/generator/tree/nodes/resources/bound_resource.py +++ b/flatdata-generator/flatdata/generator/tree/nodes/resources/bound_resource.py @@ -1,22 +1,26 @@ -from flatdata.generator.tree.nodes.references import ResourceReference +from flatdata.generator.tree.nodes.node import Node +from flatdata.generator.tree.nodes.references import BuiltinStructureReference, ResourceReference, StructureReference from .base import ResourceBase +from pyparsing import ParseResults + class BoundResource(ResourceBase): - def __init__(self, name, properties=None, resources=None): + def __init__(self, name: str, properties: ParseResults | None = None, resources: list[str] | None = None) -> None: super().__init__(name=name, properties=properties) self._resources = resources @staticmethod - def create(properties): + def create(properties: ParseResults) -> 'BoundResource': return BoundResource(name=properties.name, properties=properties, resources=[r for r in properties.resources]) - def create_references(self): + def create_references(self) -> list[Node]: + assert self._resources is not None return [ResourceReference(name=r) for r in self._resources] @property - def referenced_structures(self): + def referenced_structures(self) -> list[BuiltinStructureReference | StructureReference]: return [s for r in self.children_like(ResourceReference) for s in - r.node.referenced_structures] + r.node.referenced_structures] # type: ignore[attr-defined] # .node resolves to a resource type with referenced_structures diff --git a/flatdata-generator/flatdata/generator/tree/nodes/resources/instance.py b/flatdata-generator/flatdata/generator/tree/nodes/resources/instance.py index 5e6c7020..64b7cb0f 100644 --- a/flatdata-generator/flatdata/generator/tree/nodes/resources/instance.py +++ b/flatdata-generator/flatdata/generator/tree/nodes/resources/instance.py @@ -1,21 +1,25 @@ +from flatdata.generator.tree.nodes.node import Node from flatdata.generator.tree.nodes.references import StructureReference from .base import ResourceBase +from pyparsing import ParseResults + class Instance(ResourceBase): - def __init__(self, name, properties=None, resource_type=None): + def __init__(self, name: str, properties: ParseResults | None = None, resource_type: str | None = None) -> None: super().__init__(name=name, properties=properties) self._type = resource_type @staticmethod - def create(properties): + def create(properties: ParseResults) -> 'Instance': return Instance(name=properties.name, properties=properties, resource_type=properties.type.object.type) - def create_references(self): + def create_references(self) -> list[Node]: + assert self._type is not None return [StructureReference(name=self._type)] @property - def type(self): + def type(self) -> str | None: return self._type diff --git a/flatdata-generator/flatdata/generator/tree/nodes/resources/multivector.py b/flatdata-generator/flatdata/generator/tree/nodes/resources/multivector.py index cf4f1138..2857a4a0 100644 --- a/flatdata-generator/flatdata/generator/tree/nodes/resources/multivector.py +++ b/flatdata-generator/flatdata/generator/tree/nodes/resources/multivector.py @@ -1,44 +1,49 @@ +from flatdata.generator.tree.nodes.node import Node from flatdata.generator.tree.nodes.references import StructureReference, BuiltinStructureReference from flatdata.generator.tree.nodes.trivial import Structure from .base import ResourceBase +from typing import Any + +from pyparsing import ParseResults + class Multivector(ResourceBase): - def __init__(self, name, properties=None, types=None, width=None): + def __init__(self, name: str, properties: ParseResults | None = None, types: list[str] | None = None, width: int | None = None) -> None: super().__init__(name=name, properties=properties) - self._types = [] + self._types: list[str] = [] if types is not None: self._types = types self._width = width @staticmethod - def create(properties): + def create(properties: ParseResults) -> 'Multivector': return Multivector(name=properties.name, properties=properties, types=[t for t in properties.type.multivector.type], width=int(properties.type.multivector.width)) - def create_references(self): + def create_references(self) -> list[Node]: return [StructureReference(name=t) for t in self._types] @property - def types(self): + def types(self) -> list[str]: return self._types @property - def width(self): + def width(self) -> int | None: return self._width @property - def index_reference(self): + def index_reference(self) -> BuiltinStructureReference: builtin_refs = [node for node in self.children if isinstance(node, BuiltinStructureReference)] assert len(builtin_refs) == 1, "multivector has exactly one builtin ref which is its index" return builtin_refs[0] @property - def builtins(self): - class MemberDict(dict): - def __getattr__(self, attr): + def builtins(self) -> list[Structure]: + class MemberDict(dict[str, Any]): + def __getattr__(self, attr: str) -> Any: return self.get(attr) decorations = [MemberDict({"range" : MemberDict({"name":"range"})})] field = MemberDict({"decorations":decorations, "name":"value", "width":self._width, "type":"u64"}) @@ -47,5 +52,5 @@ def __getattr__(self, attr): "schema":"struct IndexType%s { value : u64 : %s; }" % (self._width, self._width), "doc":"/** Builtin type to for MultiVector index */", "fields":[field]}) - index_type = Structure.create(properties=properties, definition="") + index_type = Structure.create(properties=properties, definition="") # type: ignore[arg-type] # MemberDict duck-types ParseResults return [index_type] diff --git a/flatdata-generator/flatdata/generator/tree/nodes/resources/rawdata.py b/flatdata-generator/flatdata/generator/tree/nodes/resources/rawdata.py index 73162855..b473de01 100644 --- a/flatdata-generator/flatdata/generator/tree/nodes/resources/rawdata.py +++ b/flatdata-generator/flatdata/generator/tree/nodes/resources/rawdata.py @@ -1,13 +1,16 @@ +from flatdata.generator.tree.nodes.node import Node from .base import ResourceBase +from pyparsing import ParseResults + class RawData(ResourceBase): - def __init__(self, name, properties=None): + def __init__(self, name: str, properties: ParseResults | None = None) -> None: super(RawData, self).__init__(name=name, properties=properties) @staticmethod - def create(properties): + def create(properties: ParseResults) -> 'RawData': return RawData(name=properties.name, properties=properties) - def create_references(self): + def create_references(self) -> list[Node]: return [] diff --git a/flatdata-generator/flatdata/generator/tree/nodes/resources/vector.py b/flatdata-generator/flatdata/generator/tree/nodes/resources/vector.py index 79e40dcf..09b79788 100644 --- a/flatdata-generator/flatdata/generator/tree/nodes/resources/vector.py +++ b/flatdata-generator/flatdata/generator/tree/nodes/resources/vector.py @@ -1,17 +1,21 @@ +from flatdata.generator.tree.nodes.node import Node from flatdata.generator.tree.nodes.references import StructureReference from .base import ResourceBase +from pyparsing import ParseResults + class Vector(ResourceBase): - def __init__(self, name, properties=None, type=None): + def __init__(self, name: str, properties: ParseResults | None = None, type: str | None = None) -> None: super().__init__(name=name, properties=properties) self._type = type @staticmethod - def create(properties): + def create(properties: ParseResults) -> 'Vector': return Vector(name=properties.name, properties=properties, type=properties.type.vector.type) - def create_references(self): + def create_references(self) -> list[Node]: + assert self._type is not None return [StructureReference(name=self._type)] diff --git a/flatdata-generator/flatdata/generator/tree/nodes/root.py b/flatdata-generator/flatdata/generator/tree/nodes/root.py index b46d1aa1..2c21ac90 100644 --- a/flatdata-generator/flatdata/generator/tree/nodes/root.py +++ b/flatdata-generator/flatdata/generator/tree/nodes/root.py @@ -2,6 +2,6 @@ class Root(Node): - def __init__(self): + def __init__(self) -> None: super().__init__(name="__root_node_name_is_empty__", properties=None) self._name = "" diff --git a/flatdata-generator/flatdata/generator/tree/nodes/trivial/constant.py b/flatdata-generator/flatdata/generator/tree/nodes/trivial/constant.py index 1f213a82..38ad45d7 100644 --- a/flatdata-generator/flatdata/generator/tree/nodes/trivial/constant.py +++ b/flatdata-generator/flatdata/generator/tree/nodes/trivial/constant.py @@ -2,8 +2,11 @@ from flatdata.generator.tree.helpers.basictype import BasicType from flatdata.generator.tree.errors import InvalidConstantValueError +from pyparsing import ParseResults + + class Constant(Node): - def __init__(self, name, properties=None): + def __init__(self, name: str, properties: ParseResults | None = None) -> None: super().__init__(name=name, properties=properties) if properties: self._value = int(properties.value, 0) @@ -11,18 +14,21 @@ def __init__(self, name, properties=None): raise InvalidConstantValueError(name=name, value=self.value) @staticmethod - def create(properties, definition): + def create(properties: ParseResults, definition: str) -> 'Constant': result = Constant(name=properties.name, properties=properties) return result @property - def type(self): + def type(self) -> BasicType: + assert self._properties is not None return BasicType(self._properties.type) @property - def doc(self): - return self._properties.doc + def doc(self) -> str: + assert self._properties is not None + doc = self._properties.doc + return str(doc) if doc is not None else "" @property - def value(self): + def value(self) -> int: return self._value diff --git a/flatdata-generator/flatdata/generator/tree/nodes/trivial/enumeration.py b/flatdata-generator/flatdata/generator/tree/nodes/trivial/enumeration.py index 5f714b0b..f0d88316 100644 --- a/flatdata-generator/flatdata/generator/tree/nodes/trivial/enumeration.py +++ b/flatdata-generator/flatdata/generator/tree/nodes/trivial/enumeration.py @@ -3,24 +3,26 @@ from flatdata.generator.tree.nodes.node import Node from .enumeration_value import EnumerationValue +from pyparsing import ParseResults + class Enumeration(Node): - def __init__(self, name, properties=None, type=None, width=None): + def __init__(self, name: str, properties: ParseResults | None = None, type: str | None = None, width: int | None = None) -> None: super().__init__(name=name, properties=properties) - self._type = type + self._type: BasicType | None = None - if self._type is not None: - self._type = BasicType(name=self._type, width=width) + if type is not None: + self._type = BasicType(name=type, width=width) @staticmethod - def create(properties, definition): + def create(properties: ParseResults, definition: str) -> 'Enumeration': width = None if properties.width: width = int(properties.width) result = Enumeration(name=properties.name, properties=properties, type=properties.type, width=width) current_assigned_value = 0 - unique_values = set() + unique_values: set[int] = set() for value in properties.enum_values: if value.constant: current_assigned_value = int(value.constant, 0) @@ -32,30 +34,31 @@ def create(properties, definition): current_assigned_value += 1 # we do not want to genarate too many (exponential) values, so restrict to multiples of input size - if len(properties.enum_values) * 2 + 256 < 2 ** result.type.width: - raise SparseEnumError(enumeration_name=result._name, width=result.type.width) + assert result._type is not None + if len(properties.enum_values) * 2 + 256 < 2 ** result._type.width: + raise SparseEnumError(enumeration_name=result._name, width=result._type.width) - for missing_value in result.type.value_range(): + for missing_value in result._type.value_range(): if not missing_value in unique_values: value_node = EnumerationValue(name="UNKNOWN_VALUE_" + str(missing_value).replace("-", "MINUS_"), value=missing_value, auto_generated=True) result.insert(value_node) for value in unique_values: - bits_required = result.type.bits_required(value=value) - if bits_required > result.type.width: + bits_required = result._type.bits_required(value=value) + if bits_required > result._type.width: raise InvalidEnumValueError(enumeration_name=result._name, value=value) - result._bits_required = bits_required - return result @property - def doc(self): - return self._properties.doc + def doc(self) -> str: + assert self._properties is not None + doc = self._properties.doc + return str(doc) if doc is not None else "" @property - def type(self): + def type(self) -> BasicType | None: return self._type @property - def values(self): + def values(self) -> list[EnumerationValue]: return self.children_like(EnumerationValue) diff --git a/flatdata-generator/flatdata/generator/tree/nodes/trivial/enumeration_value.py b/flatdata-generator/flatdata/generator/tree/nodes/trivial/enumeration_value.py index ac6ea24d..7c99c779 100644 --- a/flatdata-generator/flatdata/generator/tree/nodes/trivial/enumeration_value.py +++ b/flatdata-generator/flatdata/generator/tree/nodes/trivial/enumeration_value.py @@ -1,24 +1,30 @@ from flatdata.generator.tree.nodes.node import Node +from pyparsing import ParseResults + + class EnumerationValue(Node): - def __init__(self, name, value, auto_generated, properties=None): + def __init__(self, name: str, value: int, auto_generated: bool, properties: ParseResults | None = None) -> None: super().__init__(name=name, properties=properties) self._value = value self._auto_generated = auto_generated @staticmethod - def create(properties, value): + def create(properties: ParseResults, value: int) -> 'EnumerationValue': result = EnumerationValue(name=properties.name, properties=properties, value=value, auto_generated=False) return result @property - def doc(self): - return self._properties.doc + def doc(self) -> str: + if self._properties is None: + return "" + doc = self._properties.doc + return str(doc) if doc is not None else "" @property - def value(self): + def value(self) -> int: return self._value @property - def auto_generated(self): + def auto_generated(self) -> bool: return self._auto_generated diff --git a/flatdata-generator/flatdata/generator/tree/nodes/trivial/field.py b/flatdata-generator/flatdata/generator/tree/nodes/trivial/field.py index 2b23c85d..5af85fcb 100644 --- a/flatdata-generator/flatdata/generator/tree/nodes/trivial/field.py +++ b/flatdata-generator/flatdata/generator/tree/nodes/trivial/field.py @@ -1,14 +1,19 @@ from flatdata.generator.tree.nodes.node import Node from flatdata.generator.tree.nodes.references import EnumerationReference, ConstantValueReference, InvalidValueReference from flatdata.generator.tree.helpers.basictype import BasicType +from flatdata.generator.tree.helpers.enumtype import EnumType + +from pyparsing import ParseResults class Field(Node): - def __init__(self, name, properties=None, type=None, offset=None, width=None): + def __init__(self, name: str, properties: ParseResults | None = None, type: str | None = None, offset: int | None = None, width: int | None = None) -> None: super().__init__(name=name, properties=properties) self._offset = offset self._width = width - self._decorations = list() + self._type_reference: EnumerationReference | None = None + self._type: BasicType | EnumType | None = None + self._decorations: list[ParseResults] = list() if properties and 'decorations' in properties: self._decorations = properties.decorations @@ -27,7 +32,7 @@ def __init__(self, name, properties=None, type=None, offset=None, width=None): self._type = BasicType(name=type, width=self._width) @staticmethod - def create(properties, offset=None): + def create(properties: ParseResults, offset: int | None = None) -> 'Field': width = None if properties.width: width = int(properties.width) @@ -38,46 +43,48 @@ def create(properties, offset=None): width=width) @property - def range(self): + def range(self) -> str | None: for d in self.decorations: if "range" in d: - return d.range.name + return str(d.range.name) return None @property - def const_value_refs(self): + def const_value_refs(self) -> list[ConstantValueReference]: return self.children_like(ConstantValueReference) @property - def invalid_value(self): + def invalid_value(self) -> InvalidValueReference | None: for x in self.children_like(InvalidValueReference): return x return None @property - def decorations(self): + def decorations(self) -> list[ParseResults]: return self._decorations @property - def type(self): + def type(self) -> BasicType | EnumType | None: return self._type @type.setter - def type(self, value): + def type(self, value: BasicType | EnumType | None) -> None: self._type = value @property - def type_reference(self): + def type_reference(self) -> EnumerationReference | None: return self._type_reference @property - def offset(self): + def offset(self) -> int | None: return self._offset @offset.setter - def offset(self, value): + def offset(self, value: int) -> None: self._offset = value @property - def doc(self): - return self._properties.doc + def doc(self) -> str: + assert self._properties is not None + doc = self._properties.doc + return str(doc) if doc is not None else "" diff --git a/flatdata-generator/flatdata/generator/tree/nodes/trivial/namespace.py b/flatdata-generator/flatdata/generator/tree/nodes/trivial/namespace.py index 0a53563b..86eda200 100644 --- a/flatdata-generator/flatdata/generator/tree/nodes/trivial/namespace.py +++ b/flatdata-generator/flatdata/generator/tree/nodes/trivial/namespace.py @@ -3,9 +3,11 @@ See the LICENSE file in the root of this project for license details. ''' +from pyparsing import ParseResults + from flatdata.generator.tree.nodes.node import Node class Namespace(Node): - def __init__(self, name, properties=None): + def __init__(self, name: str, properties: ParseResults | None = None) -> None: super().__init__(name=name, properties=properties) diff --git a/flatdata-generator/flatdata/generator/tree/nodes/trivial/structure.py b/flatdata-generator/flatdata/generator/tree/nodes/trivial/structure.py index 9623e3ff..782b53c6 100644 --- a/flatdata-generator/flatdata/generator/tree/nodes/trivial/structure.py +++ b/flatdata-generator/flatdata/generator/tree/nodes/trivial/structure.py @@ -1,9 +1,11 @@ from flatdata.generator.tree.nodes.node import Node from .field import Field +from pyparsing import ParseResults + class Structure(Node): - def __init__(self, name, properties=None): + def __init__(self, name: str, properties: ParseResults | None = None) -> None: """ Use to instantiate empty structure. No special properties are evaluated. @@ -14,7 +16,7 @@ def __init__(self, name, properties=None): super().__init__(name=name, properties=properties) @staticmethod - def create(properties, definition): + def create(properties: ParseResults, definition: str) -> 'Structure': result = Structure(name=properties.name, properties=properties) for field in properties.fields: @@ -22,25 +24,27 @@ def create(properties, definition): return result @property - def has_range(self): + def has_range(self) -> bool: return any(f for f in self.fields if f.range) @property - def doc(self): - return self._properties.doc + def doc(self) -> str: + assert self._properties is not None + doc = self._properties.doc + return str(doc) if doc is not None else "" @property - def size_in_bits(self): + def size_in_bits(self) -> int: return self._size_in_bits @size_in_bits.setter - def size_in_bits(self, value): + def size_in_bits(self, value: int) -> None: self._size_in_bits = value @property - def size_in_bytes(self): + def size_in_bytes(self) -> int: return (self._size_in_bits + 7) // 8 @property - def fields(self): + def fields(self) -> list[Field]: return self.children_like(Field) diff --git a/flatdata-generator/flatdata/generator/tree/resolver.py b/flatdata-generator/flatdata/generator/tree/resolver.py index 974ad779..9a9fe25f 100644 --- a/flatdata-generator/flatdata/generator/tree/resolver.py +++ b/flatdata-generator/flatdata/generator/tree/resolver.py @@ -1,17 +1,21 @@ +from collections.abc import Iterable + import flatdata.generator.tree.nodes.references as refs import flatdata.generator.tree.nodes.trivial as nodes import flatdata.generator.tree.nodes.resources as resources from flatdata.generator.tree.nodes.archive import Archive +from flatdata.generator.tree.nodes.node import Node from . import errors + _RESOLVED_BASE_TYPES = (refs.TypeReference, refs.RuntimeReference, refs.ConstantReference) -def _filter_references(iterable): +def _filter_references(iterable: Iterable[str]) -> list[str]: return [x for x in iterable if '@' not in x] -def _resolve_in_parent_scope(ref): +def _resolve_in_parent_scope(ref: refs.Reference) -> bool: if ref.parent is None or ref.parent.parent is None: return False scope = ref.parent @@ -25,7 +29,7 @@ def _resolve_in_parent_scope(ref): return False -def _resolve_in_parent_namespace(ref): +def _resolve_in_parent_namespace(ref: refs.Reference) -> bool: namespace = ref.first_parent_like(nodes.Namespace) assert namespace, "No namespace found in the tree. Unable to do name resolution" symbol = namespace.get_relative(ref.target) @@ -35,7 +39,7 @@ def _resolve_in_parent_namespace(ref): return True -def _resolve_as_fully_qualified_reference(ref): +def _resolve_as_fully_qualified_reference(ref: refs.Reference) -> bool: root = ref.root try: root.find(ref.target) @@ -44,7 +48,7 @@ def _resolve_as_fully_qualified_reference(ref): return True -def _validate_target_type(root, ref): +def _validate_target_type(root: Node, ref: refs.Reference) -> None: expected = { refs.StructureReference: nodes.Structure, refs.ArchiveReference: Archive, @@ -61,10 +65,11 @@ def _validate_target_type(root, ref): raise errors.IncorrectReferenceType(ref.name, type(target), expected) -def resolve_references(tree): +def resolve_references(tree: Node) -> None: for node in tree.root.iterate(): assert type(node) not in _RESOLVED_BASE_TYPES, "Base reference types should not be used directly" if any([issubclass(type(node), t) for t in _RESOLVED_BASE_TYPES]): + assert isinstance(node, refs.Reference) if node.is_qualified: resolved = _resolve_as_fully_qualified_reference(node) else: diff --git a/flatdata-generator/flatdata/generator/tree/syntax_tree.py b/flatdata-generator/flatdata/generator/tree/syntax_tree.py index 51172057..66938db7 100644 --- a/flatdata-generator/flatdata/generator/tree/syntax_tree.py +++ b/flatdata-generator/flatdata/generator/tree/syntax_tree.py @@ -2,12 +2,17 @@ Copyright (c) 2017 HERE Europe B.V. See the LICENSE file in the root of this project for license details. ''' +from __future__ import annotations + +from collections.abc import Iterator, Sequence from flatdata.generator.tree.nodes.references import TypeReference from flatdata.generator.tree.nodes.trivial import Namespace from flatdata.generator.tree.nodes.resources import ResourceBase, BoundResource from flatdata.generator.tree.nodes.archive import Archive +from flatdata.generator.tree.nodes.node import Node from flatdata.generator.tree.nodes.references import ResourceReference +from flatdata.generator.tree.nodes.root import Root class SyntaxTree: """ @@ -17,51 +22,51 @@ class SyntaxTree: - Schema resolution """ - def __init__(self, root): + def __init__(self, root: Root | Node) -> None: self._root = root @property - def root(self): + def root(self) -> Root | Node: """ Returns root node of the tree """ return self._root - def symbols(self, include_types=False): + def symbols(self, include_types: bool = False) -> set[str] | dict[str, type]: """ Returns tree symbols """ return self._root.symbols(include_types=include_types) - def find(self, path): + def find(self, path: str) -> Node: """ Returns node at path """ return self._root.find(path) - def get(self, path, default=None): + def get(self, path: str, default: Node | None = None) -> Node | None: """ Returns the node like find() does, but allows default value specification. """ return self._root.get(path, default) - def subtree(self, path): + def subtree(self, path: str) -> SyntaxTree: """ Returns subtree of the given tree as a SyntaxTree """ return SyntaxTree(self._root.find(path)) - def __repr__(self): - result = [] + def __repr__(self) -> str: + result: list[str] = [] for item in self._root.iterate(): result.append(" " * sum(1 for _ in item.parents()) + str(item)) return '\n'.join(result) @staticmethod - def dependent_types(node): - def _unique(sequence): - seen = set() - return [item for item in sequence if not (item in seen or seen.add(item))] + def dependent_types(node: Node) -> list[Node]: + def _unique(sequence: list[Node]) -> list[Node]: + seen: set[Node] = set() + return [item for item in sequence if not (item in seen or seen.add(item))] # type: ignore[func-returns-value] # intentional idiom: set.add() returns None (falsy) to deduplicate while preserving order nodes = _unique([r.node for r in node.iterate(TypeReference)]) for dependent_type in [SyntaxTree.dependent_types(n) for n in nodes]: @@ -69,30 +74,30 @@ def _unique(sequence): return _unique(nodes) @staticmethod - def schema(node): + def schema(node: Node) -> str: from ..generators.flatdata import FlatdataGenerator generator = FlatdataGenerator() # extract subtree from syntax tree subtree = node.extract_subtree() - return generator.render(subtree) + return str(generator.render(SyntaxTree(subtree))) @staticmethod - def namespaces(node): + def namespaces(node: Node) -> Iterator[Node]: """ Returns parent namespace nodes for the given node in order of nesting starting with root. """ return reversed([p for p in node.parents() if isinstance(p, Namespace)]) @staticmethod - def namespace_path(node, sep="."): + def namespace_path(node: Node, sep: str = ".") -> str: """ Returns namespace-qualified path for a given node with a given separator """ return sep.join([n.name for n in SyntaxTree.namespaces(node)]) @staticmethod - def is_bound_implicitly(node): + def is_bound_implicitly(node: Node) -> bool: if not isinstance(node, ResourceBase) or node.parent is None: return False @@ -105,21 +110,22 @@ def is_bound_implicitly(node): return False @staticmethod - def binding_resources(node): + def binding_resources(node: Node) -> list[BoundResource]: if not isinstance(node, ResourceBase) or node.parent is None: return [] assert isinstance(node.parent, Archive) archive = node.parent bound_resources = archive.children_like(BoundResource) - result = [] + result: list[BoundResource] = [] for resource in bound_resources: if any([c.node == node for c in resource.children_like(ResourceReference)]): result.append(resource) return result @staticmethod - def binding_resources_or_self(node): + def binding_resources_or_self(node: Node) -> Sequence[ResourceBase | BoundResource]: if SyntaxTree.is_bound_implicitly(node): return SyntaxTree.binding_resources(node) + assert isinstance(node, ResourceBase) return [node] diff --git a/flatdata-generator/flatdata/generator/tree/traversal.py b/flatdata-generator/flatdata/generator/tree/traversal.py index cdca4578..4671ce20 100644 --- a/flatdata-generator/flatdata/generator/tree/traversal.py +++ b/flatdata-generator/flatdata/generator/tree/traversal.py @@ -1,35 +1,51 @@ from abc import ABCMeta, abstractmethod -from collections import namedtuple, deque +from collections import deque +from collections.abc import Iterator +from typing import NamedTuple + from .errors import CircularReferencing +from .nodes.node import Node from .nodes.references import Reference, TypeReference +from .syntax_tree import SyntaxTree + + +class BfsAttr(NamedTuple): + distance: int + + +class DfsAttr(NamedTuple): + pass + + +class _DfsState(NamedTuple): + node: Node + processed: bool class _Traversal(metaclass=ABCMeta): - def __init__(self, tree): - self._root = tree.root + def __init__(self, tree: SyntaxTree | Node) -> None: + self._root: Node = tree.root @staticmethod - def children(node): + def children(node: Node) -> list[Node]: return [c for c in node.children if not isinstance(c, Reference)] + \ [r.node for r in node.children if isinstance(r, TypeReference)] @abstractmethod - def iterate(self): + def iterate(self) -> Iterator[tuple[Node, BfsAttr | DfsAttr]]: raise NotImplementedError("Derived classes must implement iterate()") class BfsTraversal(_Traversal): - def iterate(self): - Attr = namedtuple("Attr", ["distance"]) - - queue = deque([(self._root, 0)]) - processed = set() + def iterate(self) -> Iterator[tuple[Node, BfsAttr]]: + queue: deque[tuple[Node, int]] = deque([(self._root, 0)]) + processed: set[Node] = set() while queue: node, distance = queue.popleft() if node in processed: continue - yield node, Attr(distance=distance) + yield node, BfsAttr(distance=distance) # We want to preserve original order if possible, and traverse # children in *original* order: That way they are popped in order for child in _Traversal.children(node): @@ -42,12 +58,10 @@ class DfsTraversal(_Traversal): _PROCESS_NODE_EARLY = 0 _PROCESS_NODE_LATE = 1 - def _iterate(self): - State = namedtuple("State", ["node", "processed"]) - Attr = namedtuple("Attr", []) - stack = [State(node=self._root, processed=False)] - discovered = set() - processed = set() + def _iterate(self) -> Iterator[tuple[int, Node, DfsAttr]]: + stack = [_DfsState(node=self._root, processed=False)] + discovered: set[Node] = set() + processed: set[Node] = set() while stack: node, is_processed = stack.pop() @@ -55,27 +69,27 @@ def _iterate(self): if node in processed: continue - yield self._PROCESS_NODE_EARLY, node, Attr() + yield self._PROCESS_NODE_EARLY, node, DfsAttr() discovered.add(node) - stack.append(State(node=node, processed=True)) + stack.append(_DfsState(node=node, processed=True)) # We want to preserve original order if possible, and traverse # children in *reverse* order: That way they are popped in order for child in reversed(_Traversal.children(node)): if child not in discovered and child not in processed: - stack.append(State(node=child, processed=False)) + stack.append(_DfsState(node=child, processed=False)) elif child not in processed: raise CircularReferencing(node, child) else: - yield self._PROCESS_NODE_LATE, node, Attr() + yield self._PROCESS_NODE_LATE, node, DfsAttr() processed.add(node) - def iterate(self): + def iterate(self) -> Iterator[tuple[Node, DfsAttr]]: for event, node, attr in self._iterate(): if event == self._PROCESS_NODE_EARLY: yield node, attr - def dependency_order(self): + def dependency_order(self) -> Iterator[tuple[Node, DfsAttr]]: for event, node, attr in self._iterate(): if event == self._PROCESS_NODE_LATE: yield node, attr diff --git a/flatdata-generator/pyproject.toml b/flatdata-generator/pyproject.toml index cac2c360..e1b54ace 100644 --- a/flatdata-generator/pyproject.toml +++ b/flatdata-generator/pyproject.toml @@ -7,6 +7,7 @@ name = "flatdata-generator" version = "0.4.11" description = "Generate source code for C++, Rust, Go or Python from a Flatdata schema file" readme = "README.md" +requires-python = ">=3.10" authors = [ { name = "Flatdata Developers" }, ] @@ -39,3 +40,21 @@ testpaths = [ "tests" ] python_files = [ "test_*.py" ] python_classes = [ "Test*" ] python_functions = [ "test_*" ] + +[tool.mypy] +python_version = "3.10" +namespace_packages = true +explicit_package_bases = true +warn_return_any = true +warn_unused_configs = true +disallow_untyped_defs = true +check_untyped_defs = true +warn_redundant_casts = true +warn_unused_ignores = true +no_implicit_optional = true +strict_equality = true +disallow_any_generics = true + +[[tool.mypy.overrides]] +module = "Levenshtein" +ignore_missing_imports = true diff --git a/flatdata-py/flatdata/lib/archive.py b/flatdata-py/flatdata/lib/archive.py index 994a2963..f60bc963 100644 --- a/flatdata-py/flatdata/lib/archive.py +++ b/flatdata-py/flatdata/lib/archive.py @@ -3,17 +3,27 @@ See the LICENSE file in the root of this project for license details. ''' -from collections import namedtuple +from __future__ import annotations + +from typing import Any, NamedTuple, TYPE_CHECKING import pandas as pd from .errors import MissingResourceError, SchemaMismatchError -ResourceSignature = namedtuple("ResourceSignature", - ["container", "initializer", "schema", "is_optional", "doc"]) +if TYPE_CHECKING: + from .resources import ReadStorage, ResourceBase + + +class ResourceSignature(NamedTuple): + container: type[ResourceBase] | type[Archive] + initializer: Any + schema: str + is_optional: bool + doc: str -def _is_archive_signature(resource_signature): - return resource_signature.container == Archive +def _is_archive_signature(resource_signature: ResourceSignature) -> bool: + return bool(resource_signature.container == Archive) _SCHEMA_EXT = ".schema" @@ -23,35 +33,38 @@ class Archive: Archive class. Entry point to Flatdata. Provides access to flatdata resources and verifies archive/resource schemas on opening. """ + _NAME: str + _SCHEMA: str + _RESOURCES: dict[str, ResourceSignature] - def __init__(self, resource_storage): + def __init__(self, resource_storage: ReadStorage) -> None: """ Opens archive from a given resource storage. :raises flatdata.errors.CorruptArchiveError :raises flatdata.errors.SchemaMismatchError :param resource_storage: Resource storage to use. """ - self._resource_storage = resource_storage - self._loaded_resources = {} + self._resource_storage: ReadStorage = resource_storage + self._loaded_resources: dict[str, Any] = {} # Preload resources and check their schemas for name, _ in sorted(list(self._RESOURCES.items())): self.__getattr__(name) - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: if name not in self._RESOURCES: raise AttributeError("Resource %s not defined in archive." % name) if name not in self._loaded_resources: self._loaded_resources[name] = self._open_resource(name) return self._loaded_resources[name] - def __dir__(self): + def __dir__(self) -> list[str]: return list(self._RESOURCES.keys()) + ['schema'] - def __repr__(self): - return self.to_data_frame().__repr__() + def __repr__(self) -> str: + return repr(self.to_data_frame()) - def to_data_frame(self): + def to_data_frame(self) -> pd.DataFrame: result = [] for name, signature in self._RESOURCES.items(): resource = self.__getattr__(name) @@ -62,34 +75,34 @@ def to_data_frame(self): columns=["Name", "Type", "Optional", "SizeInBytes", "Size"]) @classmethod - def name(cls): + def name(cls) -> str: return cls._NAME @classmethod - def schema(cls): + def schema(cls) -> str: return cls._SCHEMA @classmethod - def resource_schema(cls, resource): - return cls._RESOURCES[resource].schema + def resource_schema(cls, resource: str) -> str: + return str(cls._RESOURCES[resource].schema) @classmethod - def open(cls, storage, name, initializer, is_optional=False): + def open(cls, storage: ReadStorage, name: str, initializer: type[Archive], is_optional: bool = False) -> Archive | None: nested_storage = storage.get(name, is_optional) assert nested_storage is not None or is_optional if nested_storage is None: return None return initializer(nested_storage) - def size_in_bytes(self): + def size_in_bytes(self) -> int: return sum(resource_value.size_in_bytes() for resource_value in (self.__getattr__(resource) for resource in self._RESOURCES.keys()) if resource_value) - def __len__(self): + def __len__(self) -> int: return len(self._RESOURCES) - def _schema_validated_resource_signature(self, name): + def _schema_validated_resource_signature(self, name: str) -> ResourceSignature | None: resource_signature = self._RESOURCES[name] # We check only schema for non-subarchives, since the subarchives schema is checked, # when it is initialized. @@ -103,7 +116,7 @@ def _schema_validated_resource_signature(self, name): return None return resource_signature - def _open_resource(self, name): + def _open_resource(self, name: str) -> Any: resource_signature = self._schema_validated_resource_signature(name) if resource_signature: resource = resource_signature.container.open(storage=self._resource_storage, @@ -116,7 +129,7 @@ def _open_resource(self, name): return None @staticmethod - def _check_non_subarchive_schema(name, resource_signature, storage): + def _check_non_subarchive_schema(name: str, resource_signature: ResourceSignature, storage: Any) -> None: actual_schema = bytes(storage).decode() if actual_schema != resource_signature.schema: raise SchemaMismatchError( diff --git a/flatdata-py/flatdata/lib/archive_builder.py b/flatdata-py/flatdata/lib/archive_builder.py index a4b1d2a5..9e007744 100644 --- a/flatdata-py/flatdata/lib/archive_builder.py +++ b/flatdata-py/flatdata/lib/archive_builder.py @@ -3,8 +3,10 @@ See the LICENSE file in the root of this project for license details. ''' -from collections import namedtuple +from __future__ import annotations + import os +from typing import Any, NamedTuple, Protocol, TYPE_CHECKING from .errors import IndexWriterError, MissingFieldError, UnknownFieldError, \ UnknownStructureError, UnknownResourceError, ResourceAlreadySetError @@ -12,10 +14,24 @@ from .resources import Instance, Vector, Multivector, RawData from .data_access import write_value +if TYPE_CHECKING: + from .resource_storage import _Resource + from .structure import Structure + _SCHEMA_EXT = ".schema" -ResourceSignature = namedtuple("ResourceSignature", - ["container", "initializer", "schema", "is_optional", "doc"]) + +class ResourceSignature(NamedTuple): + container: type + initializer: Any + schema: str + is_optional: bool + doc: str + + +class WriteStorage(Protocol): + def get(self, resource_name: str, is_subarchive: bool = False) -> Any: ... + def close(self) -> None: ... class IndexWriter: @@ -23,7 +39,7 @@ class IndexWriter: IndexWriter class. Only applicable when multivector is present in archive schema. """ - def __init__(self, name, size, resource_storage): + def __init__(self, name: str, size: int, resource_storage: WriteStorage) -> None: """ Create IndexWriter class. @@ -36,9 +52,9 @@ def __init__(self, name, size, resource_storage): self._name = name self._index_size = size - self._fout = resource_storage.get(f'{self._name}_index', False) + self._fout: _Resource = resource_storage.get(f'{self._name}_index', False) - def add(self, index): + def add(self, index: int) -> None: """ Convert index(number) to bytearray and add to in memory store """ @@ -46,7 +62,7 @@ def add(self, index): byteorder="little", signed=False) self._fout.write(index_bytes) - def finish(self): + def finish(self) -> None: """ Complete index resource by adding size and padding followed by writing to file """ @@ -60,30 +76,33 @@ class ArchiveBuilder: ArchiveBuilder class. Entry point to writing Flatdata. Provides methods to create flatdata archives. """ + _NAME: str + _SCHEMA: str + _RESOURCES: dict[str, ResourceSignature] - def __init__(self, resource_storage, path=""): + def __init__(self, resource_storage: WriteStorage, path: str = "") -> None: """ Opens archive from a given resource writer. :param resource_storage: storage manager to store and write to disc :param path: file path where archive is created """ self._path = os.path.join(path, self._NAME) - self._resource_storage = resource_storage + self._resource_storage: WriteStorage = resource_storage self._write_archive_signature() self._write_archive_schema() self._resources_written = [f"{self._NAME}.archive"] @classmethod - def name(cls): + def name(cls) -> str: '''Returns archive name''' return cls._NAME @classmethod - def schema(cls): + def schema(cls) -> str: '''Returns archive schema''' return cls._SCHEMA - def _write_raw_data(self, name, data): + def _write_raw_data(self, name: str, data: bytes | bytearray) -> None: ''' Helper function to write data @@ -94,7 +113,7 @@ def _write_raw_data(self, name, data): storage.write(data) storage.close() - def _write_schema(self, name): + def _write_schema(self, name: str) -> None: ''' Writes resource schema @@ -103,29 +122,29 @@ def _write_schema(self, name): self._write_raw_data(f"{name}.schema", bytes( self._RESOURCES[name].schema, 'utf-8')) - def _write_archive_signature(self): + def _write_archive_signature(self) -> None: '''Writes archive's signature''' self._write_raw_data(f"{self._NAME}.archive", b'\x00' * 16) - def _write_archive_schema(self): + def _write_archive_schema(self) -> None: '''Writes archive schema''' self._write_raw_data( f"{self._NAME}.archive.schema", bytes(self._SCHEMA, 'utf-8')) - def _write_index_schema(self, resource_name, schema): + def _write_index_schema(self, resource_name: str, schema: str) -> None: self._write_raw_data( f"{resource_name}_index.schema", bytes(schema, 'utf-8')) - def subarchive(self, name): + def subarchive(self, name: str) -> 'ArchiveBuilder': """ Returns an archive builder for the sub-archive `name`. :raises $name_not_subarchive_error :param name: name of the sub-archive """ - NotImplemented + raise NotImplementedError(f"subarchive '{name}' is not implemented") @classmethod - def __validate_structure_fields(cls, name, struct, initializer): + def __validate_structure_fields(cls, name: str, struct: dict[str, Any], initializer: type[Structure]) -> None: ''' Validates whether passed object has all required fields @@ -142,7 +161,7 @@ def __validate_structure_fields(cls, name, struct, initializer): if key not in initializer._FIELD_KEYS: raise UnknownFieldError(key, name) - def __set_instance(self, storage, name, value): + def __set_instance(self, storage: _Resource, name: str, value: dict[str, Any]) -> None: ''' Creates and writes instance type resource @@ -160,7 +179,7 @@ def __set_instance(self, storage, name, value): storage.write(bout) - def __set_vector(self, storage, name, vector): + def __set_vector(self, storage: _Resource, name: str, vector: list[dict[str, Any]]) -> None: ''' Creates and writes vector resource @@ -179,7 +198,7 @@ def __set_vector(self, storage, name, vector): field.is_signed, value[key]) storage.write(bout) - def __set_multivector(self, storage, name, value): + def __set_multivector(self, storage: _Resource, name: str, value: list[list[dict[str, Any]]]) -> None: ''' Creates and writes multivector resource @@ -193,10 +212,10 @@ def __set_multivector(self, storage, name, value): for index, obj_type in enumerate(initializer_list[1:]): initializers[obj_type._NAME] = (index, obj_type) - def valid_structure_name(_obj): + def valid_structure_name(_obj: dict[str, Any]) -> bool: return _obj['name'] in [_initializer._NAME for _initializer in initializer_list[1:]] - def validate_fields(_obj): + def validate_fields(_obj: dict[str, Any]) -> None: matched_obj_list = [ _initializer for _initializer in initializer_list[1:] \ if _initializer._NAME == _obj['name']] @@ -248,7 +267,7 @@ def validate_fields(_obj): self._resources_written.append(name) self._resources_written.append(f'{name}_index') - def set(self, name, value): + def set(self, name: str, value: Any) -> None: """ Write a resource for this archive at once. Can only be done once. `set` and `start` can't be used for the same resource. @@ -284,7 +303,7 @@ def set(self, name, value): self._resources_written.append(name) - def finish(self): + def finish(self) -> None: """ Closes the storage manager """ diff --git a/flatdata-py/flatdata/lib/data_access.py b/flatdata-py/flatdata/lib/data_access.py index 025e817e..b369e122 100644 --- a/flatdata-py/flatdata/lib/data_access.py +++ b/flatdata-py/flatdata/lib/data_access.py @@ -3,13 +3,19 @@ See the LICENSE file in the root of this project for license details. ''' +import mmap +from collections.abc import Callable + import numpy as np +from numpy.typing import NDArray + +ReadableBuffer = bytes | bytearray | memoryview | mmap.mmap # Sign bits cache for the value reading. _SIGN_BITS = [0] + [(1 << (bits - 1)) for bits in range(1, 65)] -def make_field_reader(offset_bits, num_bits, is_signed): +def make_field_reader(offset_bits: int, num_bits: int, is_signed: bool) -> Callable[[ReadableBuffer, int], int]: """Build a specialized closure for reading a single field from a structure. Returns a function reader(data, pos_bytes) that reads the field value @@ -26,7 +32,7 @@ def make_field_reader(offset_bits, num_bits, is_signed): if num_bits == 1: bit_mask = 1 << offset_extra - def reader(data, pos): + def reader(data: ReadableBuffer, pos: int) -> int: return int((data[pos + offset_bytes] & bit_mask) != 0) return reader @@ -34,21 +40,21 @@ def reader(data, pos): sign_bit = _SIGN_BITS[num_bits] sign_mask = sign_bit - 1 if needs_extra: - def reader(data, pos): + def reader(data: ReadableBuffer, pos: int) -> int: result = int.from_bytes( data[pos + offset_bytes: pos + end_byte], byteorder="little") result >>= offset_extra result |= data[pos + end_byte] << extra_shift result &= mask - return (result & sign_mask) - (result & sign_bit) + return int((result & sign_mask) - (result & sign_bit)) elif offset_extra: - def reader(data, pos): + def reader(data: ReadableBuffer, pos: int) -> int: result = (int.from_bytes( data[pos + offset_bytes: pos + end_byte], byteorder="little") >> offset_extra) & mask return (result & sign_mask) - (result & sign_bit) else: - def reader(data, pos): + def reader(data: ReadableBuffer, pos: int) -> int: result = int.from_bytes( data[pos + offset_bytes: pos + end_byte], byteorder="little") & mask @@ -57,26 +63,26 @@ def reader(data, pos): # Unsigned paths if needs_extra: - def reader(data, pos): + def reader(data: ReadableBuffer, pos: int) -> int: result = int.from_bytes( data[pos + offset_bytes: pos + end_byte], byteorder="little") result >>= offset_extra result |= data[pos + end_byte] << extra_shift - return result & mask + return int(result & mask) elif offset_extra: - def reader(data, pos): + def reader(data: ReadableBuffer, pos: int) -> int: return (int.from_bytes( data[pos + offset_bytes: pos + end_byte], byteorder="little") >> offset_extra) & mask else: - def reader(data, pos): + def reader(data: ReadableBuffer, pos: int) -> int: return int.from_bytes( data[pos + offset_bytes: pos + end_byte], byteorder="little") & mask return reader -def read_field_vectorized(raw_bytes_2d, field_offset_bits, field_width_bits, is_signed): +def read_field_vectorized(raw_bytes_2d: NDArray[np.uint8], field_offset_bits: int, field_width_bits: int, is_signed: bool) -> NDArray[np.uint64] | NDArray[np.int64]: """Read a bit-packed field from all elements at once, returning a numpy array. :param raw_bytes_2d: numpy uint8 array shaped (num_elements, struct_size_bytes) @@ -122,7 +128,7 @@ def read_field_vectorized(raw_bytes_2d, field_offset_bits, field_width_bits, is_ return result -def read_value(data, offset_bits, num_bits, is_signed): +def read_value(data: ReadableBuffer, offset_bits: int, num_bits: int, is_signed: bool) -> int: """Read a bit-packed value from data at the given bit offset. This is a convenience wrapper around :func:`make_field_reader` for one-off @@ -133,7 +139,7 @@ def read_value(data, offset_bits, num_bits, is_signed): return reader(data, 0) -def write_value(data, offset_bits, num_bits, is_signed, value): +def write_value(data: bytearray, offset_bits: int, num_bits: int, is_signed: bool, value: int) -> None: assert num_bits <= 64, f'Number of bits to write is greater than 64' offset_bytes, offset_extra_bits = divmod(offset_bits, 8) diff --git a/flatdata-py/flatdata/lib/errors.py b/flatdata-py/flatdata/lib/errors.py index 831172c0..4284d2fb 100644 --- a/flatdata-py/flatdata/lib/errors.py +++ b/flatdata-py/flatdata/lib/errors.py @@ -11,7 +11,7 @@ class SchemaMismatchError(RuntimeError): Schema mismatch: archive does not match software expectations. """ - def __init__(self, name, expected_schema, actual_schema): + def __init__(self, name: str, expected_schema: list[str], actual_schema: list[str]) -> None: diff = '\n'.join([l for l in difflib.unified_diff(expected_schema, actual_schema)]) message = "Schema mismatch for resource {name}. Expected: \n{expected}\n\nActual:{actual}\n\nDiff:{diff}" RuntimeError.__init__(self, @@ -36,7 +36,7 @@ class MissingResourceError(KeyError, CorruptArchiveError): """ Resource or schema is missing. """ - def __init__(self, key): + def __init__(self, key: str) -> None: super().__init__("Resource {key} not found".format(key=key)) @@ -50,7 +50,7 @@ class MissingFieldError(RuntimeError): """ Fields missing in provided dictionary object """ - def __init__(self, key, name): + def __init__(self, key: str, name: str) -> None: super().__init__(f'Missing "{key}" is required for "{name}"') @@ -58,21 +58,21 @@ class UnknownFieldError(RuntimeError): """ Field provided is not present in resource schema """ - def __init__(self, key, name): + def __init__(self, key: str, name: str) -> None: super().__init__(f'Field "{key}" is not specified for "{name}"') class FileExistsError(RuntimeError): """ Provided file name is already present. """ - def __init__(self, key): + def __init__(self, key: str) -> None: super().__init__(f'File "{key}" exists already') class DirExistsError(RuntimeError): """ Directory with given path is already present """ - def __init__(self, path): + def __init__(self, path: str) -> None: super().__init__(f'Directory "{path}" exists already') class UnknownStructureError(RuntimeError): @@ -80,47 +80,47 @@ class UnknownStructureError(RuntimeError): Provided structure/dictionary is not part of any initializer defined in multivector schema """ - def __init__(self, name): + def __init__(self, name: str) -> None: super().__init__(f'"{name}" structure is not part of the multivector') class IndexWriterError(RuntimeError): """ Error while creating instance of IndexWriter needed for multivector """ - def __init__(self, error_str="Error initializing IndexWritter Class"): + def __init__(self, error_str: str = "Error initializing IndexWritter Class") -> None: super().__init__(f'{error_str}') class ArchivePathNotProvidedError(RuntimeError): """ Path where archive will be created is missing """ - def __init__(self): + def __init__(self) -> None: super().__init__("File path is not provided") class MissingResourceName(RuntimeError): """ Resource name is not provided """ - def __init__(self): + def __init__(self) -> None: super().__init__("Resource name is not provided") class FileNameNotProvided(RuntimeError): """ File name is not provided """ - def __init__(self): + def __init__(self) -> None: super().__init__("File name is not provided") class ResourceAlreadySetError(RuntimeError): """ Provided resource name is already set for the archive """ - def __init__(self): + def __init__(self) -> None: super().__init__("Resource is already set") class UnknownResourceError(RuntimeError): """ Provided resource name is not in archive schema """ - def __init__(self, name): + def __init__(self, name: str) -> None: super().__init__(f"Resource {name} is not part of provided schema") diff --git a/flatdata-py/flatdata/lib/file_resource_storage.py b/flatdata-py/flatdata/lib/file_resource_storage.py index 674b6fba..18e916ab 100644 --- a/flatdata-py/flatdata/lib/file_resource_storage.py +++ b/flatdata-py/flatdata/lib/file_resource_storage.py @@ -3,6 +3,8 @@ See the LICENSE file in the root of this project for license details. ''' +from __future__ import annotations + import mmap import os @@ -15,7 +17,7 @@ class FileResourceStorage: """ @staticmethod - def memory_map(filename): + def memory_map(filename: str) -> mmap.mmap: """ Memory maps given file. Introduced to be able to swap mmap implementations. :param filename: @@ -24,10 +26,10 @@ def memory_map(filename): opened_file = open(filename, 'r') return mmap.mmap(opened_file.fileno(), 0, access=mmap.ACCESS_READ) - def __init__(self, path): - self.path = path + def __init__(self, path: str) -> None: + self.path: str = path - def get(self, key, is_optional=False): + def get(self, key: str, is_optional: bool = False) -> mmap.mmap | 'FileResourceStorage' | None: filename = os.path.join(self.path, key) if not os.path.exists(filename): if not is_optional: @@ -40,5 +42,5 @@ def get(self, key, is_optional=False): return FileResourceStorage(filename) - def ls(self): + def ls(self) -> list[str]: return os.listdir(self.path) diff --git a/flatdata-py/flatdata/lib/file_resource_writer.py b/flatdata-py/flatdata/lib/file_resource_writer.py index ab97e4f1..046aea19 100644 --- a/flatdata-py/flatdata/lib/file_resource_writer.py +++ b/flatdata-py/flatdata/lib/file_resource_writer.py @@ -4,6 +4,8 @@ ''' import os +from typing import IO + from flatdata.lib.errors import ArchivePathNotProvidedError, FileNameNotProvided class FileResourceWriter: @@ -11,16 +13,16 @@ class FileResourceWriter: This is a factory class which will create instance of FileResourceWriter for resource. This class directly writes to disc on a file. ''' - def __init__(self): + def __init__(self) -> None: '''Create instance of FileResourceWriter''' - self._file = None + self._file: IO[bytes] | None = None @staticmethod - def create_instance(): + def create_instance() -> 'FileResourceWriter': '''Static method to create instances and gives this class a factory like behaviour''' return FileResourceWriter() - def open(self, name, file_path): + def open(self, name: str, file_path: str) -> None: ''' Opens a file for writing. It will also create directory if it is not present. @@ -41,12 +43,14 @@ def open(self, name, file_path): self._file = open(file_name, 'wb') - def write(self, data): + def write(self, data: bytes | bytearray) -> None: '''Write data to file''' if data: + assert self._file is not None, "write() called before open()" self._file.write(data) - def close(self): + def close(self) -> None: '''Flush data and close file''' + assert self._file is not None, "close() called before open()" self._file.flush() self._file.close() diff --git a/flatdata-py/flatdata/lib/flatdata_writer.py b/flatdata-py/flatdata/lib/flatdata_writer.py index c7e6beb8..c41128a8 100644 --- a/flatdata-py/flatdata/lib/flatdata_writer.py +++ b/flatdata-py/flatdata/lib/flatdata_writer.py @@ -3,12 +3,19 @@ See the LICENSE file in the root of this project for license details. ''' +from __future__ import annotations + +from typing import Any, TYPE_CHECKING + from flatdata.generator.engine import Engine from flatdata.generator.tree.errors import FlatdataSyntaxError from .resource_storage import ResourceStorage from .file_resource_writer import FileResourceWriter +if TYPE_CHECKING: + from .archive_builder import ArchiveBuilder + class Writer: ''' @@ -19,7 +26,7 @@ class Writer: flatdata. ''' - def __init__(self, archive_schema, path, archive_name=""): + def __init__(self, archive_schema: str, path: str, archive_name: str = "") -> None: ''' Creates instance or Writer class. Archive module is rendered by engine using provided schema. @@ -37,10 +44,10 @@ def __init__(self, archive_schema, path, archive_name=""): raise RuntimeError( "Error in generating modules from provided schema: %s " % err) - self.builder = archive_type( + self.builder: ArchiveBuilder = archive_type( ResourceStorage(FileResourceWriter(), path)) - def set(self, resource_name, resource_data): + def set(self, resource_name: str, resource_data: Any) -> None: ''' It is the setter for flatdata creation. Expects data in JSON format. Caller has to provide resource name which is the flatdata schema. @@ -50,12 +57,12 @@ def set(self, resource_name, resource_data): ''' self.builder.set(resource_name, resource_data) - def finish(self): + def finish(self) -> None: '''Completes flatdata creation''' self.builder.finish() @classmethod - def _get_archive_name(cls, archive_schema): + def _get_archive_name(cls, archive_schema: str) -> str: ''' Returns name of archive from flatdata schema. diff --git a/flatdata-py/flatdata/lib/inspector.py b/flatdata-py/flatdata/lib/inspector.py index 21dabd74..d1214cba 100755 --- a/flatdata-py/flatdata/lib/inspector.py +++ b/flatdata-py/flatdata/lib/inspector.py @@ -7,9 +7,11 @@ import fnmatch import os import sys +import types import pandas as pd +from .archive import Archive from .file_resource_storage import FileResourceStorage from .tar_archive_resource_storage import TarArchiveResourceStorage from flatdata.generator.engine import Engine @@ -29,7 +31,7 @@ """ -def open_archive(path, archive=None, module_name=None, root_namespace=None): +def open_archive(path: str, archive: str | None = None, module_name: str | None = None, root_namespace: str | None = None) -> tuple[Archive, types.ModuleType]: """ Opens archive at a given path. Archive schema is read and python bindings are generated on the fly. @@ -47,7 +49,7 @@ def open_archive(path, archive=None, module_name=None, root_namespace=None): is_tar = path.endswith(".tar") and not os.path.isdir(path) archive_path = path if is_tar or os.path.isdir(path) else os.path.dirname(path) if is_tar: - storage = TarArchiveResourceStorage.create(archive_path) + storage: TarArchiveResourceStorage | FileResourceStorage = TarArchiveResourceStorage.create(archive_path) else: storage = FileResourceStorage(archive_path) @@ -69,21 +71,24 @@ def open_archive(path, archive=None, module_name=None, root_namespace=None): raise RuntimeError("Specified archive not found at path.") archive_name, _ = signatures[matching].rsplit('.', 1) - schema = storage.get(signatures[matching] + ".schema") + schema_raw = storage.get(signatures[matching] + ".schema") + if schema_raw is None: + raise RuntimeError("Schema not found for archive at %s" % path) try: module, archive_type = \ - Engine(schema.read().decode()).render_python_module(module_name=module_name, + Engine(bytes(schema_raw).decode()).render_python_module( # type: ignore[arg-type] # schema_raw is always bytes-like (mmap/memoryview) for .schema files + module_name=module_name, archive_name=archive_name, root_namespace=root_namespace) except FlatdataSyntaxError as err: raise RuntimeError("Error reading schema: %s " % err) - archive = archive_type(storage) - return archive, module + result: Archive = archive_type(storage) + return result, module -def main(): +def main() -> None: parser = argparse.ArgumentParser() parser.add_argument("-p", "--path", type=str, dest="path", required=True, help="Path to archive") diff --git a/flatdata-py/flatdata/lib/py.typed b/flatdata-py/flatdata/lib/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/flatdata-py/flatdata/lib/resource_storage.py b/flatdata-py/flatdata/lib/resource_storage.py index 72a62195..252ba727 100644 --- a/flatdata-py/flatdata/lib/resource_storage.py +++ b/flatdata-py/flatdata/lib/resource_storage.py @@ -3,9 +3,23 @@ See the LICENSE file in the root of this project for license details. ''' +from __future__ import annotations + +from typing import Protocol + from flatdata.lib.errors import ArchivePathNotProvidedError, MissingResourceName +class ResourceWriter(Protocol): + def open(self, name: str, file_path: str) -> None: ... + def write(self, data: bytes | bytearray) -> None: ... + def close(self) -> None: ... + + +class ResourceWriterFactory(Protocol): + def create_instance(self) -> ResourceWriter: ... + + class _Resource(): ''' _Resource class. @@ -13,7 +27,7 @@ class _Resource(): This class provides the functionality of in memory storage. It uses provided writer object to write stored data to file. ''' - def __init__(self, name, writer=None, path="", is_subarchive=False): + def __init__(self, name: str, writer: ResourceWriterFactory | None = None, path: str = "", is_subarchive: bool = False) -> None: ''' Creates in memory storage for resource. @@ -32,9 +46,9 @@ def __init__(self, name, writer=None, path="", is_subarchive=False): if not path: raise ArchivePathNotProvidedError() - self.data = bytearray() - self._valid = True - self._resource_writer = None + self.data: bytearray | bytes | None = bytearray() + self._valid: bool = True + self._resource_writer: ResourceWriter | None = None if writer: self._resource_writer = writer.create_instance() @@ -42,42 +56,50 @@ def __init__(self, name, writer=None, path="", is_subarchive=False): if self._resource_writer: self._resource_writer.open(name, path) - def get_status(self): + def get_status(self) -> bool: '''Returns status of resource. Status is valid if resource is not yet written.''' return self._valid - def write(self, data): + def write(self, data: bytes | bytearray) -> None: ''' Concatenates passed data to instance member bytearray or bytes. :param data(bytearray): bytearray to be added to resource ''' + assert self.data is not None, "write() called on closed resource" if data and isinstance(data, bytearray) or isinstance(data, bytes): self.data += data - def get_data(self): - '''Returns resources data in bytearray''' + def get_data(self) -> bytearray | bytes | None: + '''Returns resources data in bytearray, or None if the resource is closed.''' return self.data - def add_size(self): + def add_size(self) -> None: '''Calculate size of stored data and appends it to the begining''' + assert self.data is not None, "add_size() called on closed resource" self.data = int(len(self.data)).to_bytes( 8, byteorder="little", signed=False) + self.data - def add_padding(self): + def add_padding(self) -> None: '''Add 8 byte zero padding at the end of data''' + assert self.data is not None, "add_padding() called on closed resource" self.data += b'\x00' * 8 - def __str__(self): - '''Facilitate print for debugging''' - return f'{self.data}' + def __str__(self) -> str: + '''Facilitate print for debugging. + + Uses !r (repr) instead of implicit __format__ because format(bytes, '') + is deprecated in Python 3.12+ and raises TypeError in 3.14+. + ''' + return f'{self.data!r}' - def close(self): + def close(self) -> None: ''' Marks the end of resource. It will invoke actual write to disk and mark this resource as already written by setting resource as invalid. ''' if self._resource_writer: + assert self.data is not None, "close() called on already-closed resource" self._resource_writer.write(self.data) self.data = None self._resource_writer.close() @@ -91,18 +113,18 @@ class ResourceStorage: It is responsible for creating and managing all resources available in archive. ''' - def __init__(self, writer, path): + def __init__(self, writer: ResourceWriterFactory, path: str) -> None: ''' Creates ResourceStorage object. :param writer(object): writes data to disc :param path(str): file path where resource is created ''' - self._store = {} - self._resource_writer = writer - self._path = path + self._store: dict[str, _Resource] = {} + self._resource_writer: ResourceWriterFactory = writer + self._path: str = path - def get(self, resource_name, is_subarchive=False): + def get(self, resource_name: str, is_subarchive: bool = False) -> _Resource: ''' Returns the instance of _Resource. @@ -114,7 +136,7 @@ def get(self, resource_name, is_subarchive=False): resource_name, self._resource_writer, self._path, is_subarchive) return self._store[resource_name] - def close(self): + def close(self) -> None: '''Try to close _Resource objects which are not written to disc''' for key in self._store: if self._store[key].get_status(): diff --git a/flatdata-py/flatdata/lib/resources.py b/flatdata-py/flatdata/lib/resources.py index 6270b795..79b00e43 100644 --- a/flatdata-py/flatdata/lib/resources.py +++ b/flatdata-py/flatdata/lib/resources.py @@ -3,40 +3,53 @@ See the LICENSE file in the root of this project for license details. ''' +from __future__ import annotations + +from collections.abc import Iterator import json +from typing import Any, Protocol, TYPE_CHECKING import pandas as pd import numpy as np -from .data_access import read_value, read_field_vectorized +from .data_access import ReadableBuffer, read_value, read_field_vectorized from .errors import CorruptResourceError +if TYPE_CHECKING: + from .structure import Structure + SIZE_OFFSET_IN_BITS = 64 SIZE_OFFSET_IN_BYTES = SIZE_OFFSET_IN_BITS // 8 SIZE_PADDING_IN_BYTES = 8 +class ReadStorage(Protocol): + def get(self, key: str, is_optional: bool = False) -> Any: ... + + class ResourceBase: - def __init__(self, mem, element_type): + def __init__(self, mem: ReadableBuffer, element_type: type[Structure] | None) -> None: if len(mem) < (SIZE_OFFSET_IN_BYTES + SIZE_PADDING_IN_BYTES): raise CorruptResourceError() self._mem = memoryview(mem) self._element_type = element_type self._element_types = [element_type] self._type_size_in_bytes = self._element_type._SIZE_IN_BYTES if self._element_type else 1 - self._raw_numpy_2d = None + self._raw_numpy_2d: np.ndarray[Any, np.dtype[np.uint8]] | None = None - def size_in_bytes(self): + def size_in_bytes(self) -> int: return len(self._mem) - def _item_offset(self, index): - return SIZE_OFFSET_IN_BYTES + self._element_type._SIZE_IN_BYTES * index + def _item_offset(self, index: int) -> int: + assert self._element_type is not None + return int(SIZE_OFFSET_IN_BYTES + self._element_type._SIZE_IN_BYTES * index) - def _get_item(self, index): + def _get_item(self, index: int) -> Any: + assert self._element_type is not None offset = self._item_offset(index) return self._element_type(self._mem, offset) - def _as_numpy_2d(self): + def _as_numpy_2d(self) -> Any: """Return the raw data as a 2D numpy uint8 array of shape (n, struct_size). Zero-copy via np.frombuffer on the mmap'd memory. Cached after first call. """ @@ -50,7 +63,7 @@ def _as_numpy_2d(self): self._raw_numpy_2d = raw.reshape(n, struct_size) return self._raw_numpy_2d - def _repr_attributes(self): + def _repr_attributes(self) -> dict[str, Any]: return { "container_type": self.__class__.__name__, "size": len(self), @@ -58,28 +71,29 @@ def _repr_attributes(self): "element_types": [t._repr_attributes() for t in self._element_types if t is not None] } - def __len__(self): + def __len__(self) -> int: raise NotImplementedError() - def __repr__(self): + def __repr__(self) -> str: return json.dumps(self._repr_attributes(), indent=4) @classmethod - def open(cls, storage, name, initializer, is_optional=False): + def open(cls, storage: ReadStorage, name: str, initializer: Any, is_optional: bool = False) -> Any: return cls(storage.get(name, is_optional), initializer) class _VectorSlice: - def __init__(self, s, sequence): + def __init__(self, s: slice, sequence: 'Vector') -> None: self._slice = s self._sequence = sequence - def to_numpy(self, limit=None): + def to_numpy(self, limit: int | None = None) -> Any: raw_2d = self._sequence._as_numpy_2d() sliced = raw_2d[self._slice] if limit is not None: sliced = sliced[:limit] + assert self._sequence._element_type is not None fields = self._sequence._element_type._FIELDS dtype = self._sequence._element_type.dtype() result = np.empty(sliced.shape[0], dtype=dtype) @@ -89,14 +103,15 @@ def to_numpy(self, limit=None): ) return result - def to_data_frame(self, limit=None): + def to_data_frame(self, limit: int | None = None) -> pd.DataFrame: return pd.DataFrame(data=self.to_numpy(limit)) - def __iter__(self): + def __iter__(self) -> Iterator[Any]: for i in range(*self._slice.indices(len(self._sequence))): yield self._sequence[i] - def __getattr__(self, name): + def __getattr__(self, name: str) -> pd.DataFrame: + assert self._sequence._element_type is not None try: field = self._sequence._element_type._FIELDS[name] except KeyError: @@ -105,21 +120,22 @@ def __getattr__(self, name): values = read_field_vectorized(raw_2d, field.offset, field.width, field.is_signed) return pd.DataFrame(data=values, columns=[name]) - def __repr__(self): - return "Displaying first 100 records:\n" + self.to_data_frame(limit=100).__repr__() + def __repr__(self) -> str: + return "Displaying first 100 records:\n" + repr(self.to_data_frame(limit=100)) class Vector(ResourceBase): - def __init__(self, mem, element_type): + def __init__(self, mem: ReadableBuffer, element_type: type[Structure]) -> None: ResourceBase.__init__(self, mem, element_type) size_in_bytes = read_value(self._mem, 0, SIZE_OFFSET_IN_BITS, False) size, rem = divmod(size_in_bytes, self._type_size_in_bytes) assert rem == 0, "Malformed vector" self._size = size - def to_numpy(self): + def to_numpy(self) -> Any: """Convert entire vector to a numpy structured array (vectorized).""" raw_2d = self._as_numpy_2d() + assert self._element_type is not None fields = self._element_type._FIELDS dtype = self._element_type.dtype() result = np.empty(self._size, dtype=dtype) @@ -129,10 +145,10 @@ def to_numpy(self): ) return result - def to_data_frame(self): + def to_data_frame(self) -> pd.DataFrame: return pd.DataFrame(data=self.to_numpy()) - def __getitem__(self, index): + def __getitem__(self, index: int | slice) -> Any: if isinstance(index, slice): return _VectorSlice(index, self) @@ -142,14 +158,16 @@ def __getitem__(self, index): raise IndexError("Vector access out of bounds: " + str(index)) return self._get_item(index) - def __iter__(self): + def __iter__(self) -> Iterator[Any]: mem = self._mem element_type = self._element_type + assert element_type is not None size_bytes = self._type_size_in_bytes for i in range(self._size): yield element_type(mem, SIZE_OFFSET_IN_BYTES + size_bytes * i) - def __getattr__(self, name): + def __getattr__(self, name: str) -> pd.DataFrame: + assert self._element_type is not None try: field = self._element_type._FIELDS[name] except KeyError: @@ -158,44 +176,44 @@ def __getattr__(self, name): values = read_field_vectorized(raw_2d, field.offset, field.width, field.is_signed) return pd.DataFrame(data=values, columns=[name]) - def __len__(self): + def __len__(self) -> int: return self._size class _MultivectorSlice: - def __init__(self, s, sequence): + def __init__(self, s: slice, sequence: 'Multivector') -> None: self._slice = s self._sequence = sequence - def __iter__(self): + def __iter__(self) -> Iterator[list[Any]]: for i in range(*self._slice.indices(len(self._sequence))): yield self._sequence[i] - def __repr__(self): + def __repr__(self) -> str: return [x for x in self].__repr__() class Multivector(ResourceBase): - def __init__(self, index_mem, mem, index_type, *element_types): + def __init__(self, index_mem: ReadableBuffer, mem: ReadableBuffer, index_type: type[Structure], *element_types: type[Structure]) -> None: self._index = Vector(index_mem, index_type) - self._mem = mem - self._element_types = element_types + self._mem = memoryview(mem) + self._element_types = list(element_types) self._index_type = index_type @classmethod - def open(cls, storage, name, initializer, is_optional=False): + def open(cls, storage: ReadStorage, name: str, initializer: list[type[Structure]], is_optional: bool = False) -> Multivector: return cls(storage.get(name + "_index", is_optional), storage.get(name, is_optional), *initializer) - def __len__(self): + def __len__(self) -> int: # The last entry is just a sentinel return max(0, len(self._index) - 1) - def _bucket_offset(self, index): - return self._index[index].value + SIZE_OFFSET_IN_BYTES + def _bucket_offset(self, index: int) -> int: + return int(self._index[index].value) + SIZE_OFFSET_IN_BYTES - def __getitem__(self, index): + def __getitem__(self, index: int | slice) -> Any: if isinstance(index, slice): return _MultivectorSlice(index, self) @@ -206,26 +224,27 @@ def __getitem__(self, index): type_index = read_value(self._mem, offset * 8, 8, False) offset += 1 element_type = self._element_types[type_index] + assert element_type is not None element = element_type(self._mem, offset) elements.append(element) offset += element_type._SIZE_IN_BYTES return elements - def __iter__(self): + def __iter__(self) -> Iterator[list[Any]]: for i in range(len(self)): yield self[i] - def __repr__(self): + def __repr__(self) -> str: attrs = self._repr_attributes() attrs.update(index_type=self._index_type._repr_attributes()) return json.dumps(attrs, indent=4) class RawData(ResourceBase): - def __len__(self): + def __len__(self) -> int: return read_value(self._mem, 0, SIZE_OFFSET_IN_BITS, False) - def __getitem__(self, item): + def __getitem__(self, item: int | slice) -> memoryview: if isinstance(item, slice): return self._mem[ slice(item.start + SIZE_OFFSET_IN_BYTES, @@ -234,13 +253,13 @@ def __getitem__(self, item): ] return self._mem[item + SIZE_OFFSET_IN_BYTES:item + SIZE_OFFSET_IN_BYTES + 1] - def sub_str(self, index, separator = b'\0'): + def sub_str(self, index: int, separator: bytes = b'\0') -> str: for i in range(index, len(self)): if self[i:i + len(separator)] == separator: return bytes(self[index:i]).decode("utf-8") return bytes(self[index]).decode("utf-8") - def sub_str_list(self, index, separator = b'\0', list_separator = b'\0\0'): + def sub_str_list(self, index: int, separator: bytes = b'\0', list_separator: bytes = b'\0\0') -> list[str]: result = [] for i in range(index, len(self)): if index == i and self[i:i + len(list_separator)] == list_separator: @@ -250,7 +269,7 @@ def sub_str_list(self, index, separator = b'\0', list_separator = b'\0\0'): index = i + 1 return result - def sub_str_array(self, index, size, separator = b'\0'): + def sub_str_array(self, index: int, size: int, separator: bytes = b'\0') -> list[str]: result = [] for i in range(index, len(self)): if self[i:i + len(separator)] == separator: @@ -262,7 +281,7 @@ def sub_str_array(self, index, size, separator = b'\0'): class Instance(ResourceBase): - def __getitem__(self, index): + def __getitem__(self, index: int | slice) -> Any: if isinstance(index, slice): raise IndexError("Instance has only one item") @@ -273,13 +292,14 @@ def __getitem__(self, index): return self._get_item(index) - def __iter__(self): + def __iter__(self) -> Iterator[Any]: for i in range(1): yield self._get_item(i) - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: + assert self._element_type is not None offset = self._item_offset(0) return getattr(self._element_type(self._mem, offset), name) - def __len__(self): + def __len__(self) -> int: return 1 diff --git a/flatdata-py/flatdata/lib/structure.py b/flatdata-py/flatdata/lib/structure.py index e0bdcc42..c8bc7de3 100644 --- a/flatdata-py/flatdata/lib/structure.py +++ b/flatdata-py/flatdata/lib/structure.py @@ -1,8 +1,11 @@ from collections import namedtuple +from collections.abc import Callable, Iterator import json +from typing import Any + import numpy as np -from .data_access import make_field_reader +from .data_access import ReadableBuffer, make_field_reader FieldSignature = namedtuple( "FieldSignature", ["offset", "width", "is_signed", "dtype"]) @@ -10,59 +13,64 @@ class Structure: __slots__ = ('_mem', '_pos') - _READERS = {} - - def __init_subclass__(cls, **kwargs): + _READERS: dict[str, Callable[[ReadableBuffer, int], int]] = {} + _FIELDS: dict[str, FieldSignature] + _FIELD_KEYS: list[str] + _SCHEMA: str + _SIZE_IN_BYTES: int + _NAME: str + + def __init_subclass__(cls, **kwargs: Any) -> None: super().__init_subclass__(**kwargs) fields = cls.__dict__.get('_FIELDS') if fields is not None: cls._READERS = {name: make_field_reader(f.offset, f.width, f.is_signed) for name, f in fields.items()} - def __init__(self, mem, pos): + def __init__(self, mem: ReadableBuffer, pos: int) -> None: self._mem = mem self._pos = pos - def __getattr__(self, name): + def __getattr__(self, name: str) -> int: try: reader = self._READERS[name] except KeyError: raise AttributeError("Field %s not found in structure" % name) return reader(self._mem, self._pos) - def __dir__(self): + def __dir__(self) -> list[str]: return self._FIELD_KEYS - def __iter__(self): + def __iter__(self) -> Iterator[int]: for name in self._FIELD_KEYS: yield getattr(self, name) - def as_dict(self): + def as_dict(self) -> dict[str, int]: mem, pos = self._mem, self._pos return {name: reader(mem, pos) for name, reader in self._READERS.items()} - def as_list(self): + def as_list(self) -> list[int]: mem, pos = self._mem, self._pos return [reader(mem, pos) for reader in self._READERS.values()] - def as_tuple(self): + def as_tuple(self) -> tuple[int, ...]: mem, pos = self._mem, self._pos return tuple(reader(mem, pos) for reader in self._READERS.values()) @classmethod - def dtype(cls): + def dtype(cls) -> list[tuple[str, np.dtype[Any]]]: return [(name, np.dtype(field.dtype)) for name, field in cls._FIELDS.items()] - def as_nparray(self): + def as_nparray(self) -> np.ndarray[Any, Any]: mem, pos = self._mem, self._pos return np.array([tuple(reader(mem, pos) for reader in self._READERS.values())], dtype=self.dtype()) - def schema(self): + def schema(self) -> str: return self._SCHEMA @classmethod - def _repr_attributes(cls): + def _repr_attributes(cls) -> dict[str, Any]: return { "name": cls.__name__, "doc": cls.__doc__, @@ -77,10 +85,10 @@ def _repr_attributes(cls): } @classmethod - def __repr__(cls): + def __repr__(cls) -> str: return json.dumps(cls._repr_attributes()) - def __repr__(self): + def __repr__(self) -> str: # type: ignore[no-redef] # intentional: instance __repr__ shadows classmethod __repr__ above return json.dumps({ "name": self.__class__.__name__, "attributes": diff --git a/flatdata-py/flatdata/lib/tar_archive_resource_storage.py b/flatdata-py/flatdata/lib/tar_archive_resource_storage.py index cb5dd551..655f24ef 100644 --- a/flatdata-py/flatdata/lib/tar_archive_resource_storage.py +++ b/flatdata-py/flatdata/lib/tar_archive_resource_storage.py @@ -3,6 +3,8 @@ See the LICENSE file in the root of this project for license details. ''' +from __future__ import annotations + import tarfile from .errors import CorruptResourceError @@ -15,14 +17,14 @@ class TarArchiveResourceStorage: Resource storage based on a memory-mapped TAR archive. """ - def __init__(self, tar_map, file_entries, dir_entries, sub_path): + def __init__(self, tar_map: memoryview, file_entries: dict[str, tuple[int, int]], dir_entries: set[str], sub_path: str) -> None: self.tar_map = tar_map self.file_entries = file_entries self.dir_entries = dir_entries self.sub_path = sub_path @classmethod - def create(cls, tar_path, sub_path=""): + def create(cls, tar_path: str, sub_path: str = "") -> 'TarArchiveResourceStorage': tar_map = memoryview(FileResourceStorage.memory_map(tar_path)) file_entries = dict() dir_entries = set() @@ -40,7 +42,7 @@ def create(cls, tar_path, sub_path=""): return cls(tar_map, file_entries, dir_entries, sub_path) - def get(self, key, is_optional=False): + def get(self, key: str, is_optional: bool = False) -> memoryview | 'TarArchiveResourceStorage' | None: path = self._path(key) if path in self.file_entries: (offset, length) = self.file_entries[path] @@ -54,7 +56,7 @@ def get(self, key, is_optional=False): else: return None - def ls(self): + def ls(self) -> list[str]: prefix = self._path("") entries = [] for d in self.dir_entries: @@ -65,7 +67,7 @@ def ls(self): entries.append(f[len(prefix):]) return entries - def _path(self, key): + def _path(self, key: str) -> str: if not self.sub_path: return key else: diff --git a/flatdata-py/flatdata/lib/writer.py b/flatdata-py/flatdata/lib/writer.py index e438fe0d..9ad3d08b 100644 --- a/flatdata-py/flatdata/lib/writer.py +++ b/flatdata-py/flatdata/lib/writer.py @@ -9,7 +9,7 @@ from .flatdata_writer import Writer -def main(): +def main() -> None: parser = argparse.ArgumentParser() parser.add_argument("-p", "--output-path", type=str, dest="path", required=True, help="Path to archive") diff --git a/flatdata-py/pyproject.toml b/flatdata-py/pyproject.toml index 4e26d316..4e98206f 100644 --- a/flatdata-py/pyproject.toml +++ b/flatdata-py/pyproject.toml @@ -7,6 +7,7 @@ name = "flatdata-py" version = "0.4.11" description = "Python 3 implementation of Flatdata" readme = "README.md" +requires-python = ">=3.10" authors = [ { name = "Flatdata Developers" }, ] @@ -46,3 +47,21 @@ testpaths = [ "tests" ] python_files = [ "test_*.py" ] python_classes = [ "Test*" ] python_functions = [ "test_*" ] + +[tool.mypy] +python_version = "3.10" +namespace_packages = true +explicit_package_bases = true +warn_return_any = true +warn_unused_configs = true +disallow_untyped_defs = true +check_untyped_defs = true +warn_redundant_casts = true +warn_unused_ignores = true +no_implicit_optional = true +strict_equality = true +disallow_any_generics = true + +[[tool.mypy.overrides]] +module = "pandas" +ignore_missing_imports = true