diff --git a/devito/passes/iet/errors.py b/devito/passes/iet/errors.py index d4ac012d31..7c6cba71a4 100644 --- a/devito/passes/iet/errors.py +++ b/devito/passes/iet/errors.py @@ -1,17 +1,19 @@ +import contextlib + import cgen as c import numpy as np from sympy import Expr, Not, S from devito.ir.iet import ( Break, Call, Conditional, DummyExpr, EntryFunction, FindNodes, FindSymbols, Iteration, - List, Return, Transformer, make_callable + KernelLaunch, List, Return, Transformer, make_callable ) from devito.passes.iet.engine import iet_pass from devito.symbolics import CondEq, MathFunction from devito.tools import dtype_to_ctype from devito.types import Eq, Inc, LocalObject, Symbol -__all__ = ['check_stability', 'error_mapper'] +__all__ = ['check_launch', 'check_stability', 'error_mapper'] def check_stability(graph, options=None, rcompile=None, sregistry=None, **kwargs): @@ -100,6 +102,67 @@ def _check_stability(iet, wmovs=(), rcompile=None, sregistry=None): return iet, {'efuncs': efuncs, 'includes': includes} +def check_launch(graph, options=None, **kwargs): + """ + Insert the CHECK_LAUNCH macro if errctl is set to ensure graceful handling of + failed kernel launches. This macro should only be inserted if the kernel is + directly within a loop, as compilation will fail otherwise. + """ + if options is None or not options.get('errctl', False): + return + + langbb = kwargs['langbb'] + + definition = make_launch_macros(langbb) + if not definition: + return + + macro = [langbb['check-launch']] + + _check_launch(graph, definition=definition, macro=macro, **kwargs) + + +@iet_pass +def _check_launch(iet, definition=None, macro=None, **kwargs): + iterations = FindNodes(Iteration).visit(iet) + + mapper = {} + for i in iterations: + # Two stages of substitution to account for the edge case + # where a kernel is launched in multiple places within the + # generated code, once inside a loop, once outside + launch_mapper = {} + launches = FindNodes(KernelLaunch).visit(i) + + for launch in launches: + launch_mapper[launch] = List(body=[launch] + macro) + + if launch_mapper: + mapper[i] = Transformer(launch_mapper).visit(i) + + extras = {} + if mapper: + iet = Transformer(mapper).visit(iet) + extras.update({'headers': definition}) + + return iet, extras + + +def make_launch_macros(langbb): + """ + Define macros to check for errors to ensure graceful handling of failed kernel + launches. + """ + + # Will skip if there is no peek-error call or success code for the langbb + with contextlib.suppress(NotImplementedError): + peek = langbb['peek-error'] + success = langbb['error-none'] + return [('CHECK_LAUNCH', f'if ({peek().name}() != {success}) {{break;}}')] + + return [] + + class Retval(LocalObject, Expr): dtype = dtype_to_ctype(np.int32)