diff --git a/codegen/api/unboxing.py b/codegen/api/unboxing.py index 4e13246e5b1..ce952b76b7b 100644 --- a/codegen/api/unboxing.py +++ b/codegen/api/unboxing.py @@ -13,7 +13,6 @@ Type, ) - if TYPE_CHECKING: from collections.abc import Sequence @@ -32,20 +31,22 @@ def name(f: NativeFunction) -> str: class Unboxing: """ Takes a sequence of Bindings and unbox EValues to these Bindings. Return generated code that performs correct unboxing. - A sample generated code: + A sample generated code (abbreviated to one arg for readability): // aten::mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) - void mul_out(Span stack) { + void mul_out(KernelRuntimeContext& context, Span stack) { EValue& self = *stack[0]; - EValue& other = *stack[1]; - EValue& out = *stack[2]; - const torch::executor::Tensor & self_base = self.to(); - const torch::executor::Tensor & other_base = other.to(); - torch::executor::Tensor & out_base = out.to(); - + // ... other args ... + auto self_base_res = self.tryTo(); + if (!self_base_res.ok()) { + ::executorch::runtime::internal::kernel_arg_fail( + context, self_base_res.error(), __func__, "self", + static_cast(self.tag)); + return; + } + const torch::executor::Tensor& self_base = self_base_res.get(); + // ... other unpacks ... EXECUTORCH_SCOPE_PROF("native_call_mul.out"); torch::executor::mul_outf(self_base, other_base, out_base); - - } """ @@ -115,8 +116,19 @@ def argumenttype_evalue_convert( def _gen_code_base_type( self, arg_name: str, out_name: str, ctype: CType ) -> tuple[list[str], list[str]]: + # Use tryTo() with a shared cold fail helper so every wrapper + # logs a consistent diagnostic and propagates the error via + # KernelRuntimeContext::fail() rather than aborting. + res_name = f"{out_name}_res" return [ - f"{ctype.cpp_type()} {out_name} = {arg_name}.to<{ctype.cpp_type(strip_ref=True)}>();" + f"auto {res_name} = {arg_name}.tryTo<{ctype.cpp_type(strip_ref=True)}>();", + f"if (!{res_name}.ok()) {{", + " ::executorch::runtime::internal::kernel_arg_fail(", + f' context, {res_name}.error(), __func__, "{arg_name}",', + f" static_cast({arg_name}.tag));", + " return;", + "}", + f"{ctype.cpp_type()} {out_name} = {res_name}.get();", ], [] def _gen_code_optional_type( @@ -126,12 +138,20 @@ def _gen_code_optional_type( res_name, base_type, res_code, decl = self.argumenttype_evalue_convert( t.elem, in_name ) + # Use tryToOptional() with the shared fail helper (see + # _gen_code_base_type). + opt_res_name = f"{out_name}_res" return ( f""" - auto {out_name} = {arg_name}.toOptional<{base_type.cpp_type(strip_ref=True)}>(); - """.split( - "\n" - ), + auto {opt_res_name} = {arg_name}.tryToOptional<{base_type.cpp_type(strip_ref=True)}>(); + if (!{opt_res_name}.ok()) {{ + ::executorch::runtime::internal::kernel_arg_fail( + context, {opt_res_name}.error(), __func__, "{arg_name}", + static_cast({arg_name}.tag)); + return; + }} + auto {out_name} = std::move({opt_res_name}.get()); + """.split("\n"), decl, ) @@ -145,50 +165,60 @@ def _gen_code_list_type( t.elem, elem_name ) - if isinstance(t.elem, BaseType) and t.elem.name == BaseTy.Tensor: - code.extend( - f""" - auto {out_name} = {arg_name}.toTensorList(); - """.split( - "\n" - ) + # Each branch uses the Result-returning tryToXList() accessor and + # routes errors through the shared kernel_arg_fail helper; see + # _gen_code_base_type for the rationale. + res_name_list = f"{out_name}_res" + + def _fail_block(res: str) -> str: + # Cold fail path: log + context.fail() via the shared helper. + return ( + f"if (!{res}.ok()) {{\n" + f" ::executorch::runtime::internal::kernel_arg_fail(\n" + f' context, {res}.error(), __func__, "{arg_name}",\n' + f" static_cast({arg_name}.tag));\n" + f" return;\n" + f" }}" ) + + if isinstance(t.elem, BaseType) and t.elem.name == BaseTy.Tensor: + code.extend(f""" + auto {res_name_list} = {arg_name}.tryToTensorList(); + {_fail_block(res_name_list)} + auto {out_name} = {res_name_list}.get(); + """.split("\n")) elif isinstance(t.elem, BaseType) and ( t.elem.name == BaseTy.int or t.elem.name == BaseTy.SymInt ): - code.extend( - f""" - auto {out_name} = {arg_name}.toIntList(); - """.split( - "\n" - ) - ) + code.extend(f""" + auto {res_name_list} = {arg_name}.tryToIntList(); + {_fail_block(res_name_list)} + auto {out_name} = {res_name_list}.get(); + """.split("\n")) elif isinstance(t.elem, BaseType) and t.elem.name == BaseTy.float: - code.extend( - f""" - auto {out_name} = {arg_name}.toDoubleList(); - """.split( - "\n" - ) - ) + code.extend(f""" + auto {res_name_list} = {arg_name}.tryToDoubleList(); + {_fail_block(res_name_list)} + auto {out_name} = {res_name_list}.get(); + """.split("\n")) elif isinstance(t.elem, BaseType) and t.elem.name == BaseTy.bool: # handle list type with size, e.g., bool[4] - code.extend( - f""" + code.extend(f""" #ifdef USE_ATEN_LIB std::array {out_name}; -auto {in_name} = {arg_name}.toBoolList(); +auto {in_name}_res = {arg_name}.tryToBoolList(); +{_fail_block(in_name + "_res")} +auto {in_name} = {in_name}_res.get(); size_t _i = 0; for (auto {elem_name}: {in_name}) {{ {out_name}[_i++] = {elem_name}; }} #else -auto {out_name} = {arg_name}.toBoolList(); +auto {res_name_list} = {arg_name}.tryToBoolList(); +{_fail_block(res_name_list)} +auto {out_name} = {res_name_list}.get(); #endif - """.split( - "\n" - ) - ) + """.split("\n")) # pytorch codegen: # we have to use c10::List for optional element. e.g., Tensor?[] -> c10::List<::std::optional> elif ( @@ -196,21 +226,21 @@ def _gen_code_list_type( and isinstance(t.elem.elem, BaseType) and t.elem.elem.name == BaseTy.Tensor ): - code.extend( - f""" + code.extend(f""" #ifdef USE_ATEN_LIB -auto {in_name} = {arg_name}.toListOptionalTensor(); +auto {in_name}_res = {arg_name}.tryToListOptionalTensor(); +{_fail_block(in_name + "_res")} +auto {in_name} = {in_name}_res.get(); c10::List<::std::optional> {out_name}; for (auto {elem_name}: {in_name}) {{ {out_name}.push_back({elem_name}); }} #else -auto {out_name} = {arg_name}.toListOptionalTensor(); +auto {res_name_list} = {arg_name}.tryToListOptionalTensor(); +{_fail_block(res_name_list)} +auto {out_name} = {res_name_list}.get(); #endif - """.split( - "\n" - ) - ) + """.split("\n")) else: # use ArrayRef as default. vec_name = arg_name + "_vec" @@ -218,15 +248,11 @@ def _gen_code_list_type( decl.append( f"std::vector<{res_ctype.cpp_type(strip_ref=True)}> {vec_name};" ) - code.extend( - f""" + code.extend(f""" for (EValue {elem_name}: {in_name}) {{ {connector.join(res_code)} {vec_name}.push_back({res_name}); }} {ctype.cpp_type(strip_ref=True)} {out_name}({vec_name}); - """.split( - "\n" - ) - ) + """.split("\n")) return code, decl diff --git a/runtime/kernel/kernel_runtime_context.h b/runtime/kernel/kernel_runtime_context.h index 6facecc7632..6f71892cb13 100644 --- a/runtime/kernel/kernel_runtime_context.h +++ b/runtime/kernel/kernel_runtime_context.h @@ -8,11 +8,14 @@ #pragma once +#include + #include #include #include #include #include +#include namespace executorch { namespace ET_RUNTIME_NAMESPACE { @@ -107,6 +110,31 @@ class KernelRuntimeContext { Error failure_state_ = Error::Ok; }; +namespace internal { + +// Cold path for codegen-emitted boxed kernel wrappers. Logs a diagnostic and +// sets the kernel's failure state when an EValue arg unpack fails. The +// wrapper must still `return` after calling this. +#if defined(__GNUC__) || defined(__clang__) +[[gnu::cold]] +#endif +inline void kernel_arg_fail( + KernelRuntimeContext& context, + Error error, + const char* kernel_name, + const char* arg_name, + uint8_t actual_tag) { + ET_LOG( + Error, + "%s: arg '%s' has unexpected EValue tag %u", + kernel_name, + arg_name, + static_cast(actual_tag)); + context.fail(error); +} + +} // namespace internal + } // namespace ET_RUNTIME_NAMESPACE } // namespace executorch