Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 85 additions & 59 deletions codegen/api/unboxing.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from __future__ import annotations

from dataclasses import dataclass
Expand All @@ -13,7 +13,6 @@
Type,
)


if TYPE_CHECKING:
from collections.abc import Sequence

Expand All @@ -32,20 +31,22 @@
class Unboxing:
"""
Takes a sequence of Bindings and unbox EValues to these Bindings. Return generated code that performs correct unboxing.
A sample generated code:
A sample generated code (abbreviated to one arg for readability):
// aten::mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
void mul_out(Span<EValue*> stack) {
void mul_out(KernelRuntimeContext& context, Span<EValue*> stack) {
EValue& self = *stack[0];
EValue& other = *stack[1];
EValue& out = *stack[2];
const torch::executor::Tensor & self_base = self.to<torch::executor::Tensor>();
const torch::executor::Tensor & other_base = other.to<torch::executor::Tensor>();
torch::executor::Tensor & out_base = out.to<torch::executor::Tensor>();

// ... other args ...
auto self_base_res = self.tryTo<torch::executor::Tensor>();
if (!self_base_res.ok()) {
::executorch::runtime::internal::kernel_arg_fail(
context, self_base_res.error(), __func__, "self",
static_cast<uint8_t>(self.tag));
return;
}
const torch::executor::Tensor& self_base = self_base_res.get();
// ... other unpacks ...
EXECUTORCH_SCOPE_PROF("native_call_mul.out");
torch::executor::mul_outf(self_base, other_base, out_base);


}
"""

Expand Down Expand Up @@ -115,8 +116,19 @@
def _gen_code_base_type(
self, arg_name: str, out_name: str, ctype: CType
) -> tuple[list[str], list[str]]:
# Use tryTo<T>() with a shared cold fail helper so every wrapper
# logs a consistent diagnostic and propagates the error via
# KernelRuntimeContext::fail() rather than aborting.
res_name = f"{out_name}_res"
return [
f"{ctype.cpp_type()} {out_name} = {arg_name}.to<{ctype.cpp_type(strip_ref=True)}>();"
f"auto {res_name} = {arg_name}.tryTo<{ctype.cpp_type(strip_ref=True)}>();",
f"if (!{res_name}.ok()) {{",
" ::executorch::runtime::internal::kernel_arg_fail(",
f' context, {res_name}.error(), __func__, "{arg_name}",',
f" static_cast<uint8_t>({arg_name}.tag));",
" return;",
"}",
f"{ctype.cpp_type()} {out_name} = {res_name}.get();",
], []

def _gen_code_optional_type(
Expand All @@ -126,12 +138,20 @@
res_name, base_type, res_code, decl = self.argumenttype_evalue_convert(
t.elem, in_name
)
# Use tryToOptional<T>() with the shared fail helper (see
# _gen_code_base_type).
opt_res_name = f"{out_name}_res"
return (
f"""
auto {out_name} = {arg_name}.toOptional<{base_type.cpp_type(strip_ref=True)}>();
""".split(
"\n"
),
auto {opt_res_name} = {arg_name}.tryToOptional<{base_type.cpp_type(strip_ref=True)}>();
if (!{opt_res_name}.ok()) {{
::executorch::runtime::internal::kernel_arg_fail(
context, {opt_res_name}.error(), __func__, "{arg_name}",
static_cast<uint8_t>({arg_name}.tag));
return;
}}
auto {out_name} = std::move({opt_res_name}.get());
""".split("\n"),
decl,
)

Expand All @@ -145,88 +165,94 @@
t.elem, elem_name
)

if isinstance(t.elem, BaseType) and t.elem.name == BaseTy.Tensor:
code.extend(
f"""
auto {out_name} = {arg_name}.toTensorList();
""".split(
"\n"
)
# Each branch uses the Result-returning tryToXList() accessor and
# routes errors through the shared kernel_arg_fail helper; see
# _gen_code_base_type for the rationale.
res_name_list = f"{out_name}_res"

def _fail_block(res: str) -> str:
# Cold fail path: log + context.fail() via the shared helper.
return (
f"if (!{res}.ok()) {{\n"
f" ::executorch::runtime::internal::kernel_arg_fail(\n"
f' context, {res}.error(), __func__, "{arg_name}",\n'
f" static_cast<uint8_t>({arg_name}.tag));\n"
f" return;\n"
f" }}"
)

if isinstance(t.elem, BaseType) and t.elem.name == BaseTy.Tensor:
code.extend(f"""
auto {res_name_list} = {arg_name}.tryToTensorList();
{_fail_block(res_name_list)}
auto {out_name} = {res_name_list}.get();
""".split("\n"))
elif isinstance(t.elem, BaseType) and (
t.elem.name == BaseTy.int or t.elem.name == BaseTy.SymInt
):
code.extend(
f"""
auto {out_name} = {arg_name}.toIntList();
""".split(
"\n"
)
)
code.extend(f"""
auto {res_name_list} = {arg_name}.tryToIntList();
{_fail_block(res_name_list)}
auto {out_name} = {res_name_list}.get();
""".split("\n"))
elif isinstance(t.elem, BaseType) and t.elem.name == BaseTy.float:
code.extend(
f"""
auto {out_name} = {arg_name}.toDoubleList();
""".split(
"\n"
)
)
code.extend(f"""
auto {res_name_list} = {arg_name}.tryToDoubleList();
{_fail_block(res_name_list)}
auto {out_name} = {res_name_list}.get();
""".split("\n"))
elif isinstance(t.elem, BaseType) and t.elem.name == BaseTy.bool:
# handle list type with size, e.g., bool[4]
code.extend(
f"""
code.extend(f"""
#ifdef USE_ATEN_LIB
std::array<bool, {t.size}> {out_name};
auto {in_name} = {arg_name}.toBoolList();
auto {in_name}_res = {arg_name}.tryToBoolList();
{_fail_block(in_name + "_res")}
auto {in_name} = {in_name}_res.get();
size_t _i = 0;
for (auto {elem_name}: {in_name}) {{
{out_name}[_i++] = {elem_name};
}}
#else
auto {out_name} = {arg_name}.toBoolList();
auto {res_name_list} = {arg_name}.tryToBoolList();
{_fail_block(res_name_list)}
auto {out_name} = {res_name_list}.get();
#endif
""".split(
"\n"
)
)
""".split("\n"))
# pytorch codegen:
# we have to use c10::List for optional element. e.g., Tensor?[] -> c10::List<::std::optional<at::Tensor>>
elif (
isinstance(t.elem, OptionalType)
and isinstance(t.elem.elem, BaseType)
and t.elem.elem.name == BaseTy.Tensor
):
code.extend(
f"""
code.extend(f"""
#ifdef USE_ATEN_LIB
auto {in_name} = {arg_name}.toListOptionalTensor();
auto {in_name}_res = {arg_name}.tryToListOptionalTensor();
{_fail_block(in_name + "_res")}
auto {in_name} = {in_name}_res.get();
c10::List<::std::optional<at::Tensor>> {out_name};
for (auto {elem_name}: {in_name}) {{
{out_name}.push_back({elem_name});
}}
#else
auto {out_name} = {arg_name}.toListOptionalTensor();
auto {res_name_list} = {arg_name}.tryToListOptionalTensor();
{_fail_block(res_name_list)}
auto {out_name} = {res_name_list}.get();
#endif
""".split(
"\n"
)
)
""".split("\n"))
else:
# use ArrayRef as default.
vec_name = arg_name + "_vec"
# need to bring vector instantiation out of scope so that ArrayRef has valid data
decl.append(
f"std::vector<{res_ctype.cpp_type(strip_ref=True)}> {vec_name};"
)
code.extend(
f"""
code.extend(f"""
for (EValue {elem_name}: {in_name}) {{
{connector.join(res_code)}
{vec_name}.push_back({res_name});
}}
{ctype.cpp_type(strip_ref=True)} {out_name}({vec_name});
""".split(
"\n"
)
)
""".split("\n"))
return code, decl
28 changes: 28 additions & 0 deletions runtime/kernel/kernel_runtime_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,14 @@

#pragma once

#include <cstdint>

#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/event_tracer_hooks.h>
#include <executorch/runtime/core/memory_allocator.h>
#include <executorch/runtime/core/result.h>
#include <executorch/runtime/platform/compiler.h>
#include <executorch/runtime/platform/log.h>

namespace executorch {
namespace ET_RUNTIME_NAMESPACE {
Expand Down Expand Up @@ -107,6 +110,31 @@ class KernelRuntimeContext {
Error failure_state_ = Error::Ok;
};

namespace internal {

/// Cold path for codegen-emitted boxed kernel wrappers.
///
/// Logs a diagnostic and records the failure on the kernel's runtime
/// context when an EValue argument unpack fails. Callers (the generated
/// wrappers) are still responsible for `return`ing after invoking this.
///
/// @param context     Runtime context whose failure state is set.
/// @param error       Error produced by the failed tryTo*() unpack.
/// @param kernel_name Name of the wrapper (typically __func__).
/// @param arg_name    Name of the offending argument.
/// @param actual_tag  Raw EValue tag that was actually present.
#if defined(__GNUC__) || defined(__clang__)
[[gnu::cold]]
#endif
inline void kernel_arg_fail(
    KernelRuntimeContext& context,
    Error error,
    const char* kernel_name,
    const char* arg_name,
    uint8_t actual_tag) {
  const unsigned tag_value = static_cast<unsigned>(actual_tag);
  ET_LOG(
      Error,
      "%s: arg '%s' has unexpected EValue tag %u",
      kernel_name,
      arg_name,
      tag_value);
  context.fail(error);
}

} // namespace internal

} // namespace ET_RUNTIME_NAMESPACE
} // namespace executorch

Expand Down
Loading