Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 142 additions & 32 deletions backends/apple/metal/runtime/shims/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,84 @@ AOTITorchError aoti_torch_copy_(
return Error::Ok;
}

// Determine whether a (sizes, strides) pair describes a densely packed
// layout: one whose storage extent (largest reachable element offset + 1)
// equals the element count, leaving no holes in memory.
static bool is_packed_strides(
    const std::vector<aten::SizesType>& sizes,
    const std::vector<aten::StridesType>& strides) {
  const int64_t rank = static_cast<int64_t>(sizes.size());
  if (rank == 0) {
    return true;
  }

  // Accumulate the total element count and the largest linear offset any
  // element occupies (dimensions of extent <= 1 contribute no offset).
  int64_t total = 1;
  int64_t last_offset = 0;
  for (int64_t d = 0; d < rank; d++) {
    total *= sizes[d];
    if (sizes[d] > 1) {
      last_offset += (sizes[d] - 1) * strides[d];
    }
  }

  // Zero- or one-element tensors are trivially packed.
  if (total <= 1) {
    return true;
  }
  return last_offset + 1 == total;
}

// Copy a non-packed strided view into a freshly allocated contiguous Metal
// buffer, gathering elements one at a time according to `strides`.
// Ownership of the returned buffer passes to the caller, who must free it.
// Returns nullptr when the allocation fails.
static void* materialize_packed(
    void* src,
    const std::vector<aten::SizesType>& sizes,
    const std::vector<aten::StridesType>& strides,
    size_t element_size) {
  const int64_t rank = static_cast<int64_t>(sizes.size());
  int64_t total = 1;
  for (int64_t d = 0; d < rank; d++) {
    total *= sizes[d];
  }

  void* dst = metal_allocate_buffer(total * element_size);
  if (dst == nullptr) {
    return nullptr;
  }

  // If the source lives on the GPU, drain any in-flight work that may still
  // be writing to it before reading through the CPU mapping.
  if (metal_is_device_pointer(src)) {
    auto* stream = getCurrentMetalStream();
    if (stream != nullptr) {
      stream->synchronize(SyncType::COMMIT_AND_WAIT);
    }
  }

  auto* in = static_cast<char*>(src);
  auto* out = static_cast<char*>(dst);
  std::vector<int64_t> index(rank, 0);
  for (int64_t linear = 0; linear < total; linear++) {
    // Resolve the source element addressed by the current multi-index.
    int64_t offset = 0;
    for (int64_t d = 0; d < rank; d++) {
      offset += index[d] * strides[d];
    }
    std::memcpy(
        out + linear * element_size, in + offset * element_size, element_size);

    // Odometer-style advance: bump the innermost dimension first, carrying
    // into outer dimensions on wrap-around.
    int64_t d = rank - 1;
    while (d >= 0) {
      if (++index[d] < sizes[d]) {
        break;
      }
      index[d] = 0;
      d--;
    }
  }
  return dst;
}

AOTITorchError aoti_torch__reinterpret_tensor(
AOTITensorHandle self,
int64_t ndim,
Expand All @@ -377,6 +455,12 @@ AOTITorchError aoti_torch__reinterpret_tensor(
ET_LOG(Debug, "aoti_torch__reinterpret_tensor: entered");

// Validate input parameters first
ET_CHECK_OR_RETURN_ERROR(
ndim >= 0,
InvalidArgument,
"aoti_torch__reinterpret_tensor failed: ndim must be >= 0, got %lld",
ndim);

ET_CHECK_OR_RETURN_ERROR(
self != nullptr,
InvalidArgument,
Expand Down Expand Up @@ -430,8 +514,9 @@ AOTITorchError aoti_torch__reinterpret_tensor(
data_ptr);

// Handle storage offset by adjusting the data pointer
void* adjusted_data = static_cast<char*>(data_ptr) +
(storage_offset * dtype_to_element_size(dtype));
size_t element_size = dtype_to_element_size(dtype);
void* adjusted_data =
static_cast<char*>(data_ptr) + (storage_offset * element_size);

// Convert sizes using utility function from utils.h
std::vector<aten::SizesType> sizes = convert_sizes_to_vector(ndim, sizes_ptr);
Expand All @@ -440,14 +525,35 @@ AOTITorchError aoti_torch__reinterpret_tensor(
std::vector<aten::StridesType> strides =
convert_strides_to_vector(ndim, sizes_ptr, strides_ptr);

// Create new tensor view that reinterprets the same memory with different
// shape/strides This creates a view, not a copy - the data pointer is shared
// If the view is not densely packed (e.g. chunk/split creating holes),
// materialize it into a new contiguous buffer.
void* tensor_data = adjusted_data;
bool owns_buffer = false;
if (!is_packed_strides(sizes, strides)) {
ET_LOG(
Debug,
"aoti_torch__reinterpret_tensor: non-packed strides, "
"materializing to packed buffer");
Comment thread
manuelcandales marked this conversation as resolved.
tensor_data =
materialize_packed(adjusted_data, sizes, strides, element_size);
ET_CHECK_OR_RETURN_ERROR(
tensor_data != nullptr,
MemoryAllocationFailed,
"Failed to materialize non-packed tensor");
owns_buffer = true;

// Compute contiguous strides for the packed buffer
strides.resize(ndim);
if (ndim > 0) {
strides[ndim - 1] = 1;
for (int64_t i = ndim - 2; i >= 0; i--) {
strides[i] = strides[i + 1] * sizes[i + 1];
}
}
Comment thread
manuelcandales marked this conversation as resolved.
}

std::shared_ptr<Tensor> tensor = executorch::extension::from_blob(
adjusted_data, // Use adjusted data pointer with storage offset applied
sizes, // New sizes with explicit SizesType
strides, // New strides with explicit StridesType
dtype_to_scalar_type(dtype) // Convert dtype with explicit type casting
);
tensor_data, sizes, strides, dtype_to_scalar_type(dtype));

ET_CHECK_OR_RETURN_ERROR(
tensor != nullptr,
Expand All @@ -456,32 +562,36 @@ AOTITorchError aoti_torch__reinterpret_tensor(

// Store the tensor so it doesn't get destroyed
tensors[tensor.get()] = tensor;

*ret_new_tensor = tensor.get();

if (adjusted_data != data_ptr) {
ET_LOG(
Debug,
"aoti_torch__reinterpret_tensor: Adjusted original_data=%p, storage_offset=%lld, element_size=%zu, adjusted_data=%p",
data_ptr,
storage_offset,
dtype_to_element_size(dtype),
adjusted_data);

ET_CHECK_OR_RETURN_ERROR(
metal_buffer_nocopy(adjusted_data, tensor->nbytes(), true),
Internal,
"metal_buffer_nocopy failed for adjusted_data=%p, nbytes=%zu",
adjusted_data,
static_cast<size_t>(tensor->nbytes()));

memory_to_n_tensor[adjusted_data] = NOT_OWN;
}
if (owns_buffer) {
// The materialized buffer is a new allocation owned by this tensor
memory_to_n_tensor[tensor_data] = 1;
} else {
if (adjusted_data != data_ptr) {
ET_LOG(
Debug,
"aoti_torch__reinterpret_tensor: Adjusted original_data=%p, "
"storage_offset=%lld, element_size=%zu, adjusted_data=%p",
data_ptr,
storage_offset,
element_size,
adjusted_data);

ET_CHECK_OR_RETURN_ERROR(
metal_buffer_nocopy(adjusted_data, tensor->nbytes(), true),
Internal,
"metal_buffer_nocopy failed for adjusted_data=%p, nbytes=%zu",
adjusted_data,
static_cast<size_t>(tensor->nbytes()));

memory_to_n_tensor[adjusted_data] = NOT_OWN;
}

// Increment the reference count for this memory address only if it is owned
// by tensor
if (memory_to_n_tensor[data_ptr] != NOT_OWN) {
memory_to_n_tensor[data_ptr] += 1;
// Increment the reference count for this memory address only if it is owned
if (memory_to_n_tensor[data_ptr] != NOT_OWN) {
memory_to_n_tensor[data_ptr] += 1;
}
}

ET_LOG(Debug, "aoti_torch__reinterpret_tensor: successful");
Expand Down
24 changes: 24 additions & 0 deletions backends/apple/metal/tests/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,30 @@ def forward(
}


# -------------------------------------------------------------------------
# Narrow (non-packed reinterpret_tensor materialization)
# -------------------------------------------------------------------------


class NarrowLastDim(nn.Module):
    """Exercises ``torch.narrow`` on the final dimension: the two half-width
    views share storage with non-packed strides, which the Metal backend
    must materialize into contiguous buffers."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        width = x.shape[-1] // 2
        lo = torch.narrow(x, -1, 0, width)
        hi = torch.narrow(x, -1, width, width)
        return lo * 2.0 + hi


# Register under a stable key so the test harness can look up the module
# class, the input shapes to feed it, and a human-readable description.
MODULE_REGISTRY["narrow_last_dim"] = {
    "model_class": NarrowLastDim,
    "input_shapes": [(2, 4, 16)],
    "description": "Non-packed reinterpret_tensor views from last-dim split",
}


# -------------------------------------------------------------------------
# Top-k (MoE expert routing)
# -------------------------------------------------------------------------
Expand Down
Loading