From f1b8f7f3bae3278d206d7ad7e2e574116718975d Mon Sep 17 00:00:00 2001 From: Edward Caunt Date: Fri, 24 Apr 2026 16:58:30 +0000 Subject: [PATCH 1/2] compiler: move launch check injection to later in compilation pipeline --- devito/passes/iet/errors.py | 67 +++++++++++++++++++++++++++++++++++-- 1 file changed, 65 insertions(+), 2 deletions(-) diff --git a/devito/passes/iet/errors.py b/devito/passes/iet/errors.py index d4ac012d31..460d8e7f18 100644 --- a/devito/passes/iet/errors.py +++ b/devito/passes/iet/errors.py @@ -1,17 +1,19 @@ +import contextlib + import cgen as c import numpy as np from sympy import Expr, Not, S from devito.ir.iet import ( Break, Call, Conditional, DummyExpr, EntryFunction, FindNodes, FindSymbols, Iteration, - List, Return, Transformer, make_callable + List, Return, Transformer, KernelLaunch, make_callable, retrieve_iteration_tree ) from devito.passes.iet.engine import iet_pass from devito.symbolics import CondEq, MathFunction from devito.tools import dtype_to_ctype from devito.types import Eq, Inc, LocalObject, Symbol -__all__ = ['check_stability', 'error_mapper'] +__all__ = ['check_stability', 'check_launch', 'error_mapper'] def check_stability(graph, options=None, rcompile=None, sregistry=None, **kwargs): @@ -100,6 +102,67 @@ def _check_stability(iet, wmovs=(), rcompile=None, sregistry=None): return iet, {'efuncs': efuncs, 'includes': includes} +def check_launch(graph, options={}, **kwargs): + """ + Insert the CHECK_LAUNCH macro if errctl is set to ensure graceful handling of + failed kernel launches. This macro should only be inserted if the kernel is + directly within a loop, as compilation will fail otherwise. + """ + if not options.get('errctl', False): + return + + langbb = kwargs['langbb'] + + definition = make_launch_macros(langbb) + if not definition: + return + + macro = [langbb['check-launch']] + + _check_launch(graph, definition=definition, macro=macro, **kwargs) + + +@iet_pass +def _check_launch(iet, definition=None, macro=None, **kwargs): + iterations = FindNodes(Iteration).visit(iet) + + mapper = {} + for i in iterations: + # Two stages of substitution to account for the edge case + # where a kernel is launched in multiple places within the + # generated code, once inside a loop, once outside + launch_mapper = {} + launches = FindNodes(KernelLaunch).visit(i) + + for launch in launches: + launch_mapper[launch] = List(body=[launch] + macro) + + if launch_mapper: + mapper[i] = Transformer(launch_mapper).visit(i) + + extras = {} + if mapper: + iet = Transformer(mapper).visit(iet) + extras.update({'headers': definition}) + + return iet, extras + + +def make_launch_macros(langbb): + """ + Define macros to check for errors to ensure graceful handling of failed kernel + launches. + """ + + # Will skip if there is no peek-error call or success code for the langbb + with contextlib.suppress(NotImplementedError): + peek = langbb['peek-error'] + success = langbb['error-none'] + return [('CHECK_LAUNCH', f'if ({peek().name}() != {success}) {{break;}}')] + + return [] + + class Retval(LocalObject, Expr): dtype = dtype_to_ctype(np.int32) From 85603ead9ff4250af44951c37f74fb595c3bf50a Mon Sep 17 00:00:00 2001 From: Edward Caunt Date: Mon, 27 Apr 2026 10:54:28 +0100 Subject: [PATCH 2/2] misc: Linting --- devito/passes/iet/errors.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/devito/passes/iet/errors.py b/devito/passes/iet/errors.py index 460d8e7f18..7c6cba71a4 100644 --- a/devito/passes/iet/errors.py +++ b/devito/passes/iet/errors.py @@ -6,14 +6,14 @@ from devito.ir.iet import ( Break, Call, Conditional, DummyExpr, EntryFunction, FindNodes, FindSymbols, Iteration, - List, Return, Transformer, KernelLaunch, make_callable, retrieve_iteration_tree + KernelLaunch, List, Return, Transformer, make_callable ) from devito.passes.iet.engine import iet_pass from devito.symbolics import CondEq, MathFunction from devito.tools import dtype_to_ctype from devito.types import Eq, Inc, LocalObject, Symbol -__all__ = ['check_stability', 'check_launch', 'error_mapper'] +__all__ = ['check_launch', 'check_stability', 'error_mapper'] def check_stability(graph, options=None, rcompile=None, sregistry=None, **kwargs): @@ -102,15 +102,15 @@ def _check_stability(iet, wmovs=(), rcompile=None, sregistry=None): return iet, {'efuncs': efuncs, 'includes': includes} -def check_launch(graph, options={}, **kwargs): +def check_launch(graph, options=None, **kwargs): """ Insert the CHECK_LAUNCH macro if errctl is set to ensure graceful handling of failed kernel launches. This macro should only be inserted if the kernel is directly within a loop, as compilation will fail otherwise. """ - if not options.get('errctl', False): + if options is None or not options.get('errctl', False): return - + langbb = kwargs['langbb'] definition = make_launch_macros(langbb)