diff --git a/extension/tensor/tensor_ptr.cpp b/extension/tensor/tensor_ptr.cpp index bb76311bd67..05d13f2d615 100644 --- a/extension/tensor/tensor_ptr.cpp +++ b/extension/tensor/tensor_ptr.cpp @@ -86,21 +86,22 @@ TensorPtr make_tensor_ptr( ET_CHECK_MSG(error == runtime::Error::Ok, "Failed to compute strides."); if (!strides.empty()) { + bool is_contiguous = true; for (size_t i = 0; i < dim; i++) { - ET_CHECK_MSG( - strides[i] == computed_strides[i] || sizes[i] == 1, - "invalid strides for dim %zu: %" ET_PRI_SIZES_AND_STRIDES - "!= %" ET_PRI_SIZES_AND_STRIDES - " while its size is %" ET_PRI_SIZES_AND_STRIDES " != 1", - i, - strides[i], - computed_strides[i], - sizes[i]); + if (strides[i] != computed_strides[i] && sizes[i] != 1) { + is_contiguous = false; + break; + } } + if (is_contiguous) { + strides = std::move(computed_strides); + } + // else: keep the caller-provided non-contiguous strides (e.g. from + // reinterpret_tensor views like chunk/split). + } else { + strides = std::move(computed_strides); } - strides = std::move(computed_strides); - #ifndef USE_ATEN_LIB executorch::aten::TensorImpl tensor_impl( type, diff --git a/extension/tensor/test/tensor_ptr_maker_test.cpp b/extension/tensor/test/tensor_ptr_maker_test.cpp index 2781e7a58bb..a6239db739d 100644 --- a/extension/tensor/test/tensor_ptr_maker_test.cpp +++ b/extension/tensor/test/tensor_ptr_maker_test.cpp @@ -132,9 +132,21 @@ TEST_F(TensorPtrMakerTest, CreateTensorUsingFromBlobWithLegalStrides) { EXPECT_EQ(tensor->const_data_ptr()[0], 3); } -TEST_F(TensorPtrMakerTest, FailedCreateTensorUsingFromBlobWithIllegalStrides) { +TEST_F(TensorPtrMakerTest, CreateTensorUsingFromBlobWithNonContiguousStrides) { float data[20] = {3}; - ET_EXPECT_DEATH(from_blob(data, {2, 2, 2}, {10, 2, 1}), ""); + auto tensor = from_blob(data, {2, 2, 2}, {10, 2, 1}); + + EXPECT_EQ(tensor->dim(), 3); + EXPECT_EQ(tensor->size(0), 2); + EXPECT_EQ(tensor->size(1), 2); + EXPECT_EQ(tensor->size(2), 2); + + // Non-contiguous strides are preserved (e.g. from reinterpret_tensor views). + EXPECT_EQ(tensor->strides()[0], 10); + EXPECT_EQ(tensor->strides()[1], 2); + EXPECT_EQ(tensor->strides()[2], 1); + EXPECT_EQ(tensor->const_data_ptr(), data); + EXPECT_EQ(tensor->const_data_ptr()[0], 3); } TEST_F(TensorPtrMakerTest, TensorMakerConversionOperator) { diff --git a/extension/tensor/test/tensor_ptr_test.cpp b/extension/tensor/test/tensor_ptr_test.cpp index b8e065481f6..eac274945e0 100644 --- a/extension/tensor/test/tensor_ptr_test.cpp +++ b/extension/tensor/test/tensor_ptr_test.cpp @@ -420,14 +420,19 @@ TEST_F(TensorPtrTest, MakeViewReuseMetadataWhenShapeSame) { EXPECT_EQ(view->strides()[1], 3); } -TEST_F(TensorPtrTest, MakeViewShapeChangeWithExplicitOldStridesExpectDeath) { +TEST_F(TensorPtrTest, MakeViewShapeChangeWithNonContiguousStrides) { float data[12] = {0}; auto tensor = make_tensor_ptr({3, 4}, data); std::vector old_strides( tensor->strides().begin(), tensor->strides().end()); - ET_EXPECT_DEATH( - { auto _ = make_tensor_ptr(tensor, {2, 6}, {}, old_strides); }, ""); + // Reshaping [3,4] to [2,6] with old strides [4,1] creates a non-contiguous + // view (stride[0]=4 != contiguous 6). This is allowed for reinterpret_tensor. + auto view = make_tensor_ptr(tensor, {2, 6}, {}, old_strides); + EXPECT_EQ(view->size(0), 2); + EXPECT_EQ(view->size(1), 6); + EXPECT_EQ(view->strides()[0], 4); + EXPECT_EQ(view->strides()[1], 1); } TEST_F(TensorPtrTest, MakeViewInvalidDimOrderExpectDeath) { @@ -967,8 +972,11 @@ TEST_F(TensorPtrTest, TensorDefaultDimOrderAndStrides) { TEST_F(TensorPtrTest, TensorMismatchStridesAndDimOrder) { float data[12] = {0}; - ET_EXPECT_DEATH( - { auto _ = make_tensor_ptr({3, 4}, data, {1, 0}, {1, 4}); }, ""); + // dim_order={1,0} implies strides={1,3}, but caller provides {1,4}. + // Non-contiguous strides are preserved for reinterpret_tensor views. + auto tensor = make_tensor_ptr({3, 4}, data, {1, 0}, {1, 4}); + EXPECT_EQ(tensor->strides()[0], 1); + EXPECT_EQ(tensor->strides()[1], 4); } TEST_F(TensorPtrTest, TensorCustomDimOrderAndStrides) { diff --git a/extension/wasm/wasm_bindings.cpp b/extension/wasm/wasm_bindings.cpp index 38a227f9067..f21c69585d2 100644 --- a/extension/wasm/wasm_bindings.cpp +++ b/extension/wasm/wasm_bindings.cpp @@ -151,20 +151,19 @@ void assert_dim_order_and_strides_valid( THROW_IF_ERROR(error, "Failed to compute strides."); if (!strides.empty()) { + bool is_contiguous = true; for (size_t i = 0; i < sizes.size(); i++) { - THROW_IF_FALSE( - strides[i] == computed_strides[i] || sizes[i] == 1, - "invalid strides for dim %zu: %" ET_PRI_SIZES_AND_STRIDES - "!= %" ET_PRI_SIZES_AND_STRIDES - " while its size is %" ET_PRI_SIZES_AND_STRIDES " != 1", - i, - strides[i], - computed_strides[i], - sizes[i]); + if (strides[i] != computed_strides[i] && sizes[i] != 1) { + is_contiguous = false; + break; + } + } + if (is_contiguous) { + strides = std::move(computed_strides); } + } else { + strides = std::move(computed_strides); } - - strides = std::move(computed_strides); } /**