From cf71b8eaca8928463d84c1058957482e3658612a Mon Sep 17 00:00:00 2001
From: Manuel Candales
Date: Tue, 21 Apr 2026 15:42:08 -0400
Subject: [PATCH] Metal backend: Materialize non-packed tensor views in reinterpret_tensor

AOTI generates reinterpret_tensor views with non-packed strides (e.g.
chunk/split for RoPE rotation) that have holes in memory. ExecuTorch's
make_tensor_ptr requires densely packed layouts.

When aoti_torch__reinterpret_tensor encounters non-packed strides,
allocate a new contiguous Metal buffer and copy elements using strided
access from the source.

Authored with Claude.
---
 backends/apple/metal/runtime/shims/memory.cpp | 174 ++++++++++++++----
 backends/apple/metal/tests/test_modules.py    |  24 +++
 2 files changed, 166 insertions(+), 32 deletions(-)

diff --git a/backends/apple/metal/runtime/shims/memory.cpp b/backends/apple/metal/runtime/shims/memory.cpp
index 1a466e2f8e4..bd663dd891d 100644
--- a/backends/apple/metal/runtime/shims/memory.cpp
+++ b/backends/apple/metal/runtime/shims/memory.cpp
@@ -367,6 +367,84 @@ AOTITorchError aoti_torch_copy_(
   return Error::Ok;
 }
 
+// Check if a strided view is densely packed (no holes in memory).
+// A densely packed tensor's storage extent equals its numel.
+static bool is_packed_strides(
+    const std::vector<aten::SizesType>& sizes,
+    const std::vector<aten::StridesType>& strides) {
+  int64_t ndim = static_cast<int64_t>(sizes.size());
+  if (ndim == 0)
+    return true;
+
+  // Compute numel
+  int64_t numel = 1;
+  for (int64_t i = 0; i < ndim; i++) {
+    numel *= sizes[i];
+  }
+  if (numel <= 1)
+    return true;
+
+  // Compute storage extent: max offset + 1
+  int64_t max_offset = 0;
+  for (int64_t i = 0; i < ndim; i++) {
+    if (sizes[i] > 1) {
+      max_offset += (sizes[i] - 1) * strides[i];
+    }
+  }
+  return (max_offset + 1) == numel;
+}
+
+// Materialize a non-packed strided view into a new contiguous Metal buffer.
+// Copies elements from source using strided access. The caller must free the
+// returned buffer. On failure returns nullptr.
+static void* materialize_packed(
+    void* src,
+    const std::vector<aten::SizesType>& sizes,
+    const std::vector<aten::StridesType>& strides,
+    size_t element_size) {
+  int64_t ndim = static_cast<int64_t>(sizes.size());
+  int64_t numel = 1;
+  for (int64_t i = 0; i < ndim; i++) {
+    numel *= sizes[i];
+  }
+
+  void* dst = metal_allocate_buffer(numel * element_size);
+  if (!dst)
+    return nullptr;
+
+  // Ensure pending GPU writes to the source buffer are complete
+  if (metal_is_device_pointer(src)) {
+    auto* stream = getCurrentMetalStream();
+    if (stream) {
+      stream->synchronize(SyncType::COMMIT_AND_WAIT);
+    }
+  }
+
+  // Element-by-element strided copy
+  char* src_bytes = static_cast<char*>(src);
+  char* dst_bytes = static_cast<char*>(dst);
+  std::vector<int64_t> coord(ndim, 0);
+  for (int64_t flat = 0; flat < numel; flat++) {
+    // Compute source offset from strides
+    int64_t src_offset = 0;
+    for (int64_t d = 0; d < ndim; d++) {
+      src_offset += coord[d] * strides[d];
+    }
+    std::memcpy(
+        dst_bytes + flat * element_size,
+        src_bytes + src_offset * element_size,
+        element_size);
+
+    // Increment coordinate (last dim fastest)
+    for (int64_t d = ndim - 1; d >= 0; d--) {
+      if (++coord[d] < sizes[d])
+        break;
+      coord[d] = 0;
+    }
+  }
+  return dst;
+}
+
 AOTITorchError aoti_torch__reinterpret_tensor(
     AOTITensorHandle self,
     int64_t ndim,
@@ -377,6 +455,12 @@ AOTITorchError aoti_torch__reinterpret_tensor(
   ET_LOG(Debug, "aoti_torch__reinterpret_tensor: entered");
 
   // Validate input parameters first
+  ET_CHECK_OR_RETURN_ERROR(
+      ndim >= 0,
+      InvalidArgument,
+      "aoti_torch__reinterpret_tensor failed: ndim must be >= 0, got %lld",
+      ndim);
+
   ET_CHECK_OR_RETURN_ERROR(
       self != nullptr,
       InvalidArgument,
@@ -430,8 +514,9 @@ AOTITorchError aoti_torch__reinterpret_tensor(
       data_ptr);
 
   // Handle storage offset by adjusting the data pointer
-  void* adjusted_data = static_cast<char*>(data_ptr) +
-      (storage_offset * dtype_to_element_size(dtype));
+  size_t element_size = dtype_to_element_size(dtype);
+  void* adjusted_data =
+      static_cast<char*>(data_ptr) + (storage_offset * element_size);
 
   // Convert sizes using utility function from utils.h
   std::vector<aten::SizesType> sizes = convert_sizes_to_vector(ndim, sizes_ptr);
@@ -440,14 +525,35 @@ AOTITorchError aoti_torch__reinterpret_tensor(
   std::vector<aten::StridesType> strides =
       convert_strides_to_vector(ndim, sizes_ptr, strides_ptr);
 
-  // Create new tensor view that reinterprets the same memory with different
-  // shape/strides This creates a view, not a copy - the data pointer is shared
+  // If the view is not densely packed (e.g. chunk/split creating holes),
+  // materialize it into a new contiguous buffer.
+  void* tensor_data = adjusted_data;
+  bool owns_buffer = false;
+  if (!is_packed_strides(sizes, strides)) {
+    ET_LOG(
+        Debug,
+        "aoti_torch__reinterpret_tensor: non-packed strides, "
+        "materializing to packed buffer");
+    tensor_data =
+        materialize_packed(adjusted_data, sizes, strides, element_size);
+    ET_CHECK_OR_RETURN_ERROR(
+        tensor_data != nullptr,
+        MemoryAllocationFailed,
+        "Failed to materialize non-packed tensor");
+    owns_buffer = true;
+
+    // Compute contiguous strides for the packed buffer
+    strides.resize(ndim);
+    if (ndim > 0) {
+      strides[ndim - 1] = 1;
+      for (int64_t i = ndim - 2; i >= 0; i--) {
+        strides[i] = strides[i + 1] * sizes[i + 1];
+      }
+    }
+  }
+
   std::shared_ptr<Tensor> tensor = executorch::extension::from_blob(
-      adjusted_data, // Use adjusted data pointer with storage offset applied
-      sizes, // New sizes with explicit SizesType
-      strides, // New strides with explicit StridesType
-      dtype_to_scalar_type(dtype) // Convert dtype with explicit type casting
-  );
+      tensor_data, sizes, strides, dtype_to_scalar_type(dtype));
 
   ET_CHECK_OR_RETURN_ERROR(
       tensor != nullptr,
@@ -456,32 +562,36 @@ AOTITorchError aoti_torch__reinterpret_tensor(
 
   // Store the tensor so it doesn't get destroyed
   tensors[tensor.get()] = tensor;
-
   *ret_new_tensor = tensor.get();
 
-  if (adjusted_data != data_ptr) {
-    ET_LOG(
-        Debug,
-        "aoti_torch__reinterpret_tensor: Adjusted original_data=%p, storage_offset=%lld, element_size=%zu, adjusted_data=%p",
-        data_ptr,
-        storage_offset,
-        dtype_to_element_size(dtype),
-        adjusted_data);
-
-    ET_CHECK_OR_RETURN_ERROR(
-        metal_buffer_nocopy(adjusted_data, tensor->nbytes(), true),
-        Internal,
-        "metal_buffer_nocopy failed for adjusted_data=%p, nbytes=%zu",
-        adjusted_data,
-        static_cast<size_t>(tensor->nbytes()));
-
-    memory_to_n_tensor[adjusted_data] = NOT_OWN;
-  }
+  if (owns_buffer) {
+    // The materialized buffer is a new allocation owned by this tensor
+    memory_to_n_tensor[tensor_data] = 1;
+  } else {
+    if (adjusted_data != data_ptr) {
+      ET_LOG(
+          Debug,
+          "aoti_torch__reinterpret_tensor: Adjusted original_data=%p, "
+          "storage_offset=%lld, element_size=%zu, adjusted_data=%p",
+          data_ptr,
+          storage_offset,
+          element_size,
+          adjusted_data);
+
+      ET_CHECK_OR_RETURN_ERROR(
+          metal_buffer_nocopy(adjusted_data, tensor->nbytes(), true),
+          Internal,
+          "metal_buffer_nocopy failed for adjusted_data=%p, nbytes=%zu",
+          adjusted_data,
+          static_cast<size_t>(tensor->nbytes()));
+
+      memory_to_n_tensor[adjusted_data] = NOT_OWN;
+    }
 
-  // Increment the reference count for this memory address only if it is owned
-  // by tensor
-  if (memory_to_n_tensor[data_ptr] != NOT_OWN) {
-    memory_to_n_tensor[data_ptr] += 1;
+    // Increment the reference count for this memory address only if it is owned
+    if (memory_to_n_tensor[data_ptr] != NOT_OWN) {
+      memory_to_n_tensor[data_ptr] += 1;
+    }
   }
 
   ET_LOG(Debug, "aoti_torch__reinterpret_tensor: successful");
diff --git a/backends/apple/metal/tests/test_modules.py b/backends/apple/metal/tests/test_modules.py
index d3af09eb39e..00456aad31e 100644
--- a/backends/apple/metal/tests/test_modules.py
+++ b/backends/apple/metal/tests/test_modules.py
@@ -664,6 +664,30 @@ def forward(
 }
 
 
+# -------------------------------------------------------------------------
+# Narrow (non-packed reinterpret_tensor materialization)
+# -------------------------------------------------------------------------
+
+
+class NarrowLastDim(nn.Module):
+    """Splits the last dimension into two halves via narrow, producing
+    non-packed strided views that the Metal backend must materialize
+    into contiguous buffers."""
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        half = x.shape[-1] // 2
+        a = x.narrow(-1, 0, half)
+        b = x.narrow(-1, half, half)
+        return a * 2.0 + b
+
+
+MODULE_REGISTRY["narrow_last_dim"] = {
+    "model_class": NarrowLastDim,
+    "input_shapes": [(2, 4, 16)],
+    "description": "Non-packed reinterpret_tensor views from last-dim split",
+}
+
+
 # -------------------------------------------------------------------------
 # Top-k (MoE expert routing)
 # -------------------------------------------------------------------------