diff --git a/backends/xnnpack/runtime/XNNCompiler.cpp b/backends/xnnpack/runtime/XNNCompiler.cpp index 0494b4969a2..352d7af5a14 100644 --- a/backends/xnnpack/runtime/XNNCompiler.cpp +++ b/backends/xnnpack/runtime/XNNCompiler.cpp @@ -319,24 +319,9 @@ Error defineTensor( ET_CHECK_OR_RETURN_ERROR( tensor_value != nullptr, InvalidProgram, "Deserialized tensor is null"); - ET_CHECK_OR_RETURN_ERROR( - tensor_value->num_dims() == 0 || tensor_value->dims() != nullptr, - InvalidProgram, - "Tensor dims is null but num_dims is %u", - tensor_value->num_dims()); - - if (tensor_value->dims() != nullptr) { - ET_CHECK_OR_RETURN_ERROR( - tensor_value->num_dims() == tensor_value->dims()->size(), - InvalidProgram, - "Tensor num_dims %u does not match dims array size %u", - tensor_value->num_dims(), - tensor_value->dims()->size()); - } - - // Get tensor dims, here we need to use a vector in order - // to properly convert the uint32_t* to size_t*. For scalar tensors - // (num_dims == 0), dims() is permitted to be null per the check above. + // Get tensor dims, here we need to use a vector in order to properly + // convert the uint32_t* to size_t*. Scalar tensors (rank 0) are permitted + // to have a null dims vector; in that case dims_data is empty. std::vector dims_data; if (tensor_value->dims() != nullptr) { dims_data = flatbufferDimsToVector(tensor_value->dims()); @@ -386,7 +371,7 @@ Error defineTensor( status = xnn_define_tensor_value( /*subgraph=*/subgraph_ptr, /*datatype=*/getDataType(tensor_value->datatype()), - /*num_dims=*/tensor_value->num_dims(), + /*num_dims=*/dims_data.size(), /*dims=*/dims_data.data(), /*data=*/buffer_ptr, /*external_id=*/tensor_value->external_id(), @@ -421,7 +406,7 @@ Error defineTensor( status = xnn_define_dynamically_quantized_tensor_value( /*subgraph=*/subgraph_ptr, /*datatype=*/xnn_datatype_qdint8, - /*num_dims=*/tensor_value->num_dims(), + /*num_dims=*/dims_data.size(), /*num_nonbatch_dims=*/1, // always do per token quantization /*dims=*/dims_data.data(), /*external_id=*/XNN_INVALID_VALUE_ID, // always internal value id @@ -435,7 +420,7 @@ Error defineTensor( status = xnn_define_tensor_value( /*subgraph=*/subgraph_ptr, /*datatype=*/fp_datatype, - /*num_dims=*/tensor_value->num_dims(), + /*num_dims=*/dims_data.size(), /*dims=*/dims_data.data(), /*data=*/buffer_ptr, /*external_id=*/tensor_value->external_id(), @@ -476,7 +461,7 @@ Error defineTensor( /*datatype=*/getDataType(tensor_value->datatype()), /*zero_point=*/qparams->zero_point(), /*scale=*/qparams->scale(), - /*num_dims=*/tensor_value->num_dims(), + /*num_dims=*/dims_data.size(), /*dims=*/dims_data.data(), /*data=*/buffer_ptr, /*external_id=*/tensor_value->external_id(), @@ -521,7 +506,7 @@ Error defineTensor( /*datatype=*/dtype, /*zero_point=*/zero_point, /*scale=*/scale, - /*num_dims=*/tensor_value->num_dims(), + /*num_dims=*/dims_data.size(), /*channel_dim*/ qparams->channel_dim(), /*dims=*/dims_data.data(), /*data=*/buffer_ptr, @@ -599,7 +584,7 @@ Error defineTensor( /*datatype=*/datatype, /*zero_point=*/zero_point, /*scale=*/scale_data, - /*num_dims=*/tensor_value->num_dims(), + /*num_dims=*/dims_data.size(), /*channel_dim=*/qparams->channel_dim(), /*block_size=*/qparams->group_size(), /*dims=*/dims_data.data(), @@ -613,8 +598,8 @@ Error defineTensor( auto qparams = qtensor_value->quant_params_as_PerTokenDynamicQuant(); ET_LOG( Debug, - "define quant tensor (dynamic): num_dims: %i, num_nonbatch_dims: %i\n", - tensor_value->num_dims(), + "define quant tensor (dynamic): num_dims: %zu, num_nonbatch_dims: %i\n", + dims_data.size(), qparams->num_nonbatch_dims()); ET_CHECK_OR_RETURN_ERROR( buffer_ptr == nullptr, @@ -623,7 +608,7 @@ Error defineTensor( status = xnn_define_dynamically_quantized_tensor_value( /*subgraph=*/subgraph_ptr, /*datatype=*/getDataType(tensor_value->datatype()), - /*num_dims=*/tensor_value->num_dims(), + /*num_dims=*/dims_data.size(), /*num_nonbatch_dims*/ qparams->num_nonbatch_dims(), /*dims=*/dims_data.data(), /*external_id=*/tensor_value->external_id(),