From e6f62150ee7815b2c72a37e5a69faff7df355c5a Mon Sep 17 00:00:00 2001 From: Github Executorch Date: Tue, 21 Apr 2026 17:15:39 -0700 Subject: [PATCH] Update [ghstack-poisoned] --- codegen/api/unboxing.py | 114 +++++++++++++++++++++------------------- 1 file changed, 61 insertions(+), 53 deletions(-) diff --git a/codegen/api/unboxing.py b/codegen/api/unboxing.py index 4e13246e5b1..8cb967497f0 100644 --- a/codegen/api/unboxing.py +++ b/codegen/api/unboxing.py @@ -13,7 +13,6 @@ Type, ) - if TYPE_CHECKING: from collections.abc import Sequence @@ -34,13 +33,19 @@ class Unboxing: Takes a sequence of Bindings and unbox EValues to these Bindings. Return generated code that performs correct unboxing. A sample generated code: // aten::mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) - void mul_out(Span stack) { + void mul_out(KernelRuntimeContext& context, Span stack) { EValue& self = *stack[0]; EValue& other = *stack[1]; EValue& out = *stack[2]; - const torch::executor::Tensor & self_base = self.to(); - const torch::executor::Tensor & other_base = other.to(); - torch::executor::Tensor & out_base = out.to(); + auto self_base_res = self.tryTo(); + if (!self_base_res.ok()) { context.fail(self_base_res.error()); return; } + const torch::executor::Tensor & self_base = self_base_res.get(); + auto other_base_res = other.tryTo(); + if (!other_base_res.ok()) { context.fail(other_base_res.error()); return; } + const torch::executor::Tensor & other_base = other_base_res.get(); + auto out_base_res = out.tryTo(); + if (!out_base_res.ok()) { context.fail(out_base_res.error()); return; } + torch::executor::Tensor & out_base = out_base_res.get(); EXECUTORCH_SCOPE_PROF("native_call_mul.out"); torch::executor::mul_outf(self_base, other_base, out_base); @@ -115,8 +120,15 @@ def argumenttype_evalue_convert( def _gen_code_base_type( self, arg_name: str, out_name: str, ctype: CType ) -> tuple[list[str], list[str]]: + # Use the Result-returning tryTo() instead of to() so that a + # malformed PTE with a mismatched EValue tag returns an error to the + # caller via KernelRuntimeContext::fail() rather than aborting the + # process. + res_name = f"{out_name}_res" return [ - f"{ctype.cpp_type()} {out_name} = {arg_name}.to<{ctype.cpp_type(strip_ref=True)}>();" + f"auto {res_name} = {arg_name}.tryTo<{ctype.cpp_type(strip_ref=True)}>();", + f"if (!{res_name}.ok()) {{ context.fail({res_name}.error()); return; }}", + f"{ctype.cpp_type()} {out_name} = {res_name}.get();", ], [] def _gen_code_optional_type( @@ -126,12 +138,14 @@ def _gen_code_optional_type( res_name, base_type, res_code, decl = self.argumenttype_evalue_convert( t.elem, in_name ) + # Use tryToOptional() to propagate tag mismatches as errors. + opt_res_name = f"{out_name}_res" return ( f""" - auto {out_name} = {arg_name}.toOptional<{base_type.cpp_type(strip_ref=True)}>(); - """.split( - "\n" - ), + auto {opt_res_name} = {arg_name}.tryToOptional<{base_type.cpp_type(strip_ref=True)}>(); + if (!{opt_res_name}.ok()) {{ context.fail({opt_res_name}.error()); return; }} + auto {out_name} = std::move({opt_res_name}.get()); + """.split("\n"), decl, ) @@ -145,50 +159,48 @@ def _gen_code_list_type( t.elem, elem_name ) + # Each branch uses the Result-returning tryToXList() accessor and + # propagates errors via context.fail(); see _gen_code_base_type for + # the rationale. + res_name_list = f"{out_name}_res" if isinstance(t.elem, BaseType) and t.elem.name == BaseTy.Tensor: - code.extend( - f""" - auto {out_name} = {arg_name}.toTensorList(); - """.split( - "\n" - ) - ) + code.extend(f""" + auto {res_name_list} = {arg_name}.tryToTensorList(); + if (!{res_name_list}.ok()) {{ context.fail({res_name_list}.error()); return; }} + auto {out_name} = {res_name_list}.get(); + """.split("\n")) elif isinstance(t.elem, BaseType) and ( t.elem.name == BaseTy.int or t.elem.name == BaseTy.SymInt ): - code.extend( - f""" - auto {out_name} = {arg_name}.toIntList(); - """.split( - "\n" - ) - ) + code.extend(f""" + auto {res_name_list} = {arg_name}.tryToIntList(); + if (!{res_name_list}.ok()) {{ context.fail({res_name_list}.error()); return; }} + auto {out_name} = {res_name_list}.get(); + """.split("\n")) elif isinstance(t.elem, BaseType) and t.elem.name == BaseTy.float: - code.extend( - f""" - auto {out_name} = {arg_name}.toDoubleList(); - """.split( - "\n" - ) - ) + code.extend(f""" + auto {res_name_list} = {arg_name}.tryToDoubleList(); + if (!{res_name_list}.ok()) {{ context.fail({res_name_list}.error()); return; }} + auto {out_name} = {res_name_list}.get(); + """.split("\n")) elif isinstance(t.elem, BaseType) and t.elem.name == BaseTy.bool: # handle list type with size, e.g., bool[4] - code.extend( - f""" + code.extend(f""" #ifdef USE_ATEN_LIB std::array {out_name}; -auto {in_name} = {arg_name}.toBoolList(); +auto {in_name}_res = {arg_name}.tryToBoolList(); +if (!{in_name}_res.ok()) {{ context.fail({in_name}_res.error()); return; }} +auto {in_name} = {in_name}_res.get(); size_t _i = 0; for (auto {elem_name}: {in_name}) {{ {out_name}[_i++] = {elem_name}; }} #else -auto {out_name} = {arg_name}.toBoolList(); +auto {res_name_list} = {arg_name}.tryToBoolList(); +if (!{res_name_list}.ok()) {{ context.fail({res_name_list}.error()); return; }} +auto {out_name} = {res_name_list}.get(); #endif - """.split( - "\n" - ) - ) + """.split("\n")) # pytorch codegen: # we have to use c10::List for optional element. e.g., Tensor?[] -> c10::List<::std::optional> elif ( @@ -196,21 +208,21 @@ def _gen_code_list_type( and isinstance(t.elem.elem, BaseType) and t.elem.elem.name == BaseTy.Tensor ): - code.extend( - f""" + code.extend(f""" #ifdef USE_ATEN_LIB -auto {in_name} = {arg_name}.toListOptionalTensor(); +auto {in_name}_res = {arg_name}.tryToListOptionalTensor(); +if (!{in_name}_res.ok()) {{ context.fail({in_name}_res.error()); return; }} +auto {in_name} = {in_name}_res.get(); c10::List<::std::optional> {out_name}; for (auto {elem_name}: {in_name}) {{ {out_name}.push_back({elem_name}); }} #else -auto {out_name} = {arg_name}.toListOptionalTensor(); +auto {res_name_list} = {arg_name}.tryToListOptionalTensor(); +if (!{res_name_list}.ok()) {{ context.fail({res_name_list}.error()); return; }} +auto {out_name} = {res_name_list}.get(); #endif - """.split( - "\n" - ) - ) + """.split("\n")) else: # use ArrayRef as default. vec_name = arg_name + "_vec" @@ -218,15 +230,11 @@ def _gen_code_list_type( decl.append( f"std::vector<{res_ctype.cpp_type(strip_ref=True)}> {vec_name};" ) - code.extend( - f""" + code.extend(f""" for (EValue {elem_name}: {in_name}) {{ {connector.join(res_code)} {vec_name}.push_back({res_name}); }} {ctype.cpp_type(strip_ref=True)} {out_name}({vec_name}); - """.split( - "\n" - ) - ) + """.split("\n")) return code, decl