Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .github/workflows/generator.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,8 @@ jobs:
cd flatdata-generator
uv run --with pytest pytest -v
pip install .
flatdata-generator --help
flatdata-generator --help
- name: Type check
run: |
cd flatdata-generator
uv run --with mypy mypy flatdata/
6 changes: 5 additions & 1 deletion .github/workflows/py.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@ jobs:
cd flatdata-py
uv venv
uv pip install ../flatdata-generator
uv pip install ".[inspector]" pytest
uv pip install ".[inspector]" pytest mypy
.venv/bin/pytest -v
.venv/bin/flatdata-inspector --help
- name: Type check
run: |
cd flatdata-py
.venv/bin/mypy flatdata/

10 changes: 5 additions & 5 deletions flatdata-generator/flatdata/generator/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from flatdata.generator.tree.errors import FlatdataSyntaxError


def _parse_command_line():
def _parse_command_line() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Generates code for a given flatdata schema file.")
parser.add_argument("-s", "--schema", type=str, required=True,
Expand All @@ -39,7 +39,7 @@ def _parse_command_line():
return parser.parse_args()


def _setup_logging(args):
def _setup_logging(args: argparse.Namespace) -> None:
level = logging.WARNING
if args.debug:
level = logging.DEBUG
Expand All @@ -52,13 +52,13 @@ def _setup_logging(args):
level=level)


def _check_args(args):
def _check_args(args: argparse.Namespace) -> None:
if not os.path.isfile(args.schema):
logging.fatal("Cannot find schema file at %s", args.schema)
sys.exit(1)


def _run(args):
def _run(args: argparse.Namespace) -> None:
_setup_logging(args)
_check_args(args)

Expand Down Expand Up @@ -86,6 +86,6 @@ def _run(args):
logging.info("Code for %s is written to %s", args.gen, args.output_file)


def main():
def main() -> None:
"""Entrypoint"""
_run(_parse_command_line())
36 changes: 23 additions & 13 deletions flatdata-generator/flatdata/generator/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,20 @@
'''

import types
from typing import overload

from flatdata.generator.tree.builder import build_ast
from flatdata.generator.tree.nodes.trivial.namespace import Namespace
from flatdata.generator.tree.nodes.node import Node
from flatdata.generator.tree.syntax_tree import SyntaxTree

from .generators.cpp import CppGenerator
from .generators.dot import DotGenerator
from .generators.go import GoGenerator
from .generators.python import PythonGenerator
from .generators.rust import RustGenerator
from .generators.flatdata import FlatdataGenerator
from .generators import BaseGenerator


class Engine:
Expand All @@ -23,7 +26,7 @@ class Engine:
Implements code generation from the given flatdata schema.
"""

_GENERATORS = {
_GENERATORS: dict[str, type[BaseGenerator]] = {
"cpp": CppGenerator,
"dot": DotGenerator,
"go": GoGenerator,
Expand All @@ -33,21 +36,21 @@ class Engine:
}

@classmethod
def available_generators(cls):
def available_generators(cls) -> list[str]:
"""
Lists names of available code generators.
"""
return list(cls._GENERATORS.keys())

def __init__(self, schema):
def __init__(self, schema: str) -> None:
"""
Instantiates generator engine for a given schema.
:raises FlatdataSyntaxError
"""
self.schema = schema
self.tree = build_ast(schema)

def render(self, generator_name):
def render(self, generator_name: str) -> str:
"""
Render schema with a given generator
:param generator_name:
Expand All @@ -60,38 +63,45 @@ def render(self, generator_name):
)

output_content = generator.render(self.tree)
return output_content
return str(output_content)

def render_python_module(self, module_name=None, archive_name=None, root_namespace=None):
@overload
def render_python_module(self, module_name: str | None, archive_name: str, root_namespace: str | None = None) -> tuple[types.ModuleType, type]: ...
@overload
def render_python_module(self, *, archive_name: str, root_namespace: str | None = None) -> tuple[types.ModuleType, type]: ...
@overload
def render_python_module(self, module_name: str | None = None, archive_name: None = None, root_namespace: str | None = None) -> types.ModuleType: ...

def render_python_module(self, module_name: str | None = None, archive_name: str | None = None, root_namespace: str | None = None) -> types.ModuleType | tuple[types.ModuleType, type]:
"""
Render python module.
:param module_name: Module name to use. If none, root namespace name is used.
:param archive_name: Archive name to lookup,
if specified, archive type is returned along with the model
:param root_namespace: Root namespace to pick in case of multiple top level namespaces.
"""
root_namespace = self._find_root_namespace(self.tree, archive_name, root_namespace)
ns = self._find_root_namespace(self.tree, archive_name, root_namespace)
module_code = self.render("py")
module = types.ModuleType(module_name if module_name is not None else root_namespace.name)
module = types.ModuleType(module_name if module_name is not None else ns.name)
#pylint: disable=exec-used
exec(module_code, module.__dict__)
if archive_name is None:
return module

name = root_namespace.name + "_" + archive_name
archive_type = getattr(module, name) if archive_name else None
name = ns.name + "_" + archive_name
archive_type = getattr(module, name)
return module, archive_type

@classmethod
def _create_generator(cls, name):
def _create_generator(cls, name: str) -> BaseGenerator | None:
generator_type = cls._GENERATORS.get(name, None)
if generator_type is None:
return None

return generator_type()
return generator_type() # type: ignore[call-arg] # dict values are concrete subclasses with zero-arg __init__

@staticmethod
def _find_root_namespace(tree, archive_name, root_namespace=None):
def _find_root_namespace(tree: SyntaxTree, archive_name: str | None, root_namespace: str | None = None) -> Namespace:
root_children = tree.root.children
root_namespaces = [
child for child in root_children
Expand Down
15 changes: 9 additions & 6 deletions flatdata-generator/flatdata/generator/generators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@
'''

from abc import ABCMeta, abstractmethod
from typing import NoReturn

from jinja2 import Environment, PackageLoader
from jinja2 import nodes
from jinja2.ext import Extension
from jinja2.exceptions import TemplateRuntimeError
from jinja2.parser import Parser

from flatdata.generator.tree.nodes.archive import Archive
from flatdata.generator.tree.nodes.trivial import Structure, Enumeration, Constant, Namespace
Expand All @@ -21,21 +24,21 @@
class BaseGenerator(metaclass=ABCMeta):
"""Abstract base class for Flatdata generators"""

def __init__(self, template):
def __init__(self, template: str) -> None:
self._template = template

@abstractmethod
def supported_nodes(self):
def supported_nodes(self) -> list[type]:
"""List of supported nodes by this generator"""
raise RuntimeError(
"Derived generators must implement _supported_nodes")

@abstractmethod
def _populate_environment(self, env):
def _populate_environment(self, env: Environment) -> None:
raise RuntimeError(
"Derived generators must implement _populate_filters")

def render(self, tree):
def render(self, tree: SyntaxTree) -> str:
"""Generate the language implementation from the AST"""
env = Environment(loader=PackageLoader('flatdata.generator', 'templates'), lstrip_blocks=True,
trim_blocks=True, autoescape=False, extensions=[RaiseExtension])
Expand Down Expand Up @@ -71,7 +74,7 @@ class RaiseExtension(Extension):

tags = set(['raise'])

def parse(self, parser):
def parse(self, parser: Parser) -> nodes.CallBlock:
"""The first token is the line number, followed by the expression"""
lineno = next(parser.stream).lineno
message_node = parser.parse_expression()
Expand All @@ -81,6 +84,6 @@ def parse(self, parser):
)

#pylint: disable=no-self-use
def _raise(self, msg, caller):
def _raise(self, msg: str, caller: object) -> NoReturn:
"""Helper callback."""
raise TemplateRuntimeError(msg)
32 changes: 19 additions & 13 deletions flatdata-generator/flatdata/generator/generators/cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@
See the LICENSE file in the root of this project for license details.
'''

from jinja2 import Environment

from flatdata.generator.tree.helpers.basictype import BasicType
from flatdata.generator.tree.helpers.enumtype import EnumType
from flatdata.generator.tree.nodes.node import Node
from flatdata.generator.tree.nodes.references import BuiltinStructureReference, StructureReference
from flatdata.generator.tree.nodes.resources import Vector, Multivector, Instance, RawData, BoundResource, \
ResourceBase, Archive as ArchiveResource
from flatdata.generator.tree.nodes.trivial import Structure, Enumeration, Constant, Field
Expand All @@ -13,21 +19,21 @@
class CppGenerator(BaseGenerator):
"""Flatdata to C++ header file generator"""

def __init__(self):
def __init__(self) -> None:
BaseGenerator.__init__(self, "cpp/cpp.jinja2")

def supported_nodes(self):
def supported_nodes(self) -> list[type]:
return [Structure, Archive, Constant, Enumeration]

def _populate_environment(self, env):
def _populate_environment(self, env: Environment) -> None:
env.filters["cpp_doc"] = lambda value: value

def _safe_cpp_string_line(value):
def _safe_cpp_string_line(value: str) -> str:
return value.replace('\\', '\\\\').replace('"', r'\"')

env.filters["safe_cpp_string_line"] = _safe_cpp_string_line

def _cpp_base_type(flatdata_type):
def _cpp_base_type(flatdata_type: BasicType | EnumType | Node) -> str:
type_map = {
"bool": "bool",
"i8": "int8_t",
Expand All @@ -41,28 +47,28 @@ def _cpp_base_type(flatdata_type):
}
if flatdata_type.name in type_map:
return type_map[flatdata_type.name]
return flatdata_type.name.replace("@@", "::").replace("@", "::")
return str(flatdata_type.name.replace("@@", "::").replace("@", "::"))

env.filters["cpp_base_type"] = _cpp_base_type

def _to_type_params(refs):
def _to_type_params(refs: list[BuiltinStructureReference | StructureReference]) -> str:
return ', '.join([ref.node.path_with("::") for ref in refs])

env.filters["to_type_params"] = _to_type_params

def _snake_to_upper_camel_case(expr):
def _snake_to_upper_camel_case(expr: str) -> str:
return ''.join(p.title() for p in expr.split('_'))

env.filters["snake_to_upper_camel_case"] = _snake_to_upper_camel_case

def _typedef_name(entity, extra_suffix=""):
def _typedef_name(entity: Field | ResourceBase, extra_suffix: str = "") -> str:
assert isinstance(entity, (Field, ResourceBase)), "Got: %s" % entity.__class__
return _snake_to_upper_camel_case(entity.name) + extra_suffix + "Type"

env.filters["typedef_name"] = _typedef_name

def _optional_typedef_usage(resource, extra_suffix=""):
def _wrap_in_optional(declaration):
def _optional_typedef_usage(resource: ResourceBase, extra_suffix: str = "") -> str:
def _wrap_in_optional(declaration: str) -> str:
if resource.optional:
return "boost::optional< %s >" % declaration
return declaration
Expand All @@ -71,7 +77,7 @@ def _wrap_in_optional(declaration):

env.filters["archive_typedef_usage"] = _optional_typedef_usage

def _resource_provides_incremental_builder(resource):
def _resource_provides_incremental_builder(resource: ResourceBase) -> bool:
assert isinstance(resource, ResourceBase)
if isinstance(resource, Instance):
return False
Expand All @@ -86,7 +92,7 @@ def _resource_provides_incremental_builder(resource):
env.filters[
"resource_provides_incremental_builder"] = _resource_provides_incremental_builder

def provides_setter(resource):
def provides_setter(resource: ResourceBase) -> bool:
assert isinstance(resource, ResourceBase)
if isinstance(resource, Instance):
return True
Expand Down
18 changes: 12 additions & 6 deletions flatdata-generator/flatdata/generator/generators/dot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,30 @@
'''

from flatdata.generator.tree.nodes.archive import Archive
from flatdata.generator.tree.nodes.trivial import Field
from . import BaseGenerator

from jinja2 import Environment

SCOPE_SEPARATOR = "__"
DECORATION_BOUND = "__bound__"


class DotGenerator(BaseGenerator):
"""Flatdata to DOT (graph description language) generator"""

def __init__(self):
def __init__(self) -> None:
BaseGenerator.__init__(self, "dot/dot.jinja2")

def _populate_environment(self, env):
def _populate_environment(self, env: Environment) -> None:
env.autoescape = True

def _field_value_type(field):
type_name = field.type.name.replace("@@", ".").replace("@", ".")
namespace_name = field.parent.parent.path
def _field_value_type(field: Field) -> str:
assert field.type is not None
assert field.parent is not None
assert field.parent.parent is not None
type_name = str(field.type.name).replace("@@", ".").replace("@", ".")
namespace_name = str(field.parent.parent.path)
if type_name.startswith(namespace_name):
type_name = type_name[len(namespace_name):]
if type_name.startswith("."):
Expand All @@ -31,5 +37,5 @@ def _field_value_type(field):

env.filters["field_value_type"] = _field_value_type

def supported_nodes(self):
def supported_nodes(self) -> list[type]:
return [Archive]
Loading
Loading