Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 2 additions & 14 deletions backends/qualcomm/builders/op_prelu.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager

import torch
from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER

from .node_visitor import get_parameter, NodeVisitor
from .node_visitor_manager import register_node_visitor
Expand Down Expand Up @@ -38,19 +37,8 @@ def define_node(
)

coeff_node = self.get_node(node.args[1])
coeff = get_parameter(coeff_node, self.edge_program)
coeff_tensor = torch.zeros(input_node.meta["val"].shape, dtype=coeff.dtype)
# per-channel activation
coeff_node_shape = coeff_node.meta["val"].shape
if len(coeff_node_shape) and coeff_node_shape[0] > 1:
for i in range(input_node.meta["val"].shape[1]):
coeff_tensor = coeff_tensor.index_fill(1, torch.tensor([i]), coeff[i])
else:
coeff_tensor.fill_(coeff[0] if coeff.dim() else coeff)

if axis_order := input_node.meta.get(QCOM_AXIS_ORDER, None):
coeff_tensor = coeff_tensor.permute(dims=axis_order).contiguous()

coeff_tensor = get_parameter(coeff_node, self.edge_program)
# The coeff_tensor would be broadcasted to match the input shape by QNN
coeff_tensor_wrapper = self.define_tensor(
coeff_node,
node,
Expand Down
30 changes: 30 additions & 0 deletions backends/qualcomm/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,16 @@ def forward(self, x):
return torch.flip(x, self.dims)


class Conv2dLeakyReLU(torch.nn.Module):
    """A 3x3 Conv2d (32 -> 32 channels, padding=1) followed by LeakyReLU.

    Test fixture for exercising conv + leaky-relu activation patterns in
    backend lowering. ``negative_slope`` is forwarded unchanged to the
    ``LeakyReLU`` module.
    """

    def __init__(self, negative_slope=0.01):
        super().__init__()
        # Submodule names ("conv", "leaky_relu") stay fixed so state-dict
        # keys and traced-graph node names remain stable for callers.
        self.conv = torch.nn.Conv2d(
            in_channels=32, out_channels=32, kernel_size=3, padding=1
        )
        self.leaky_relu = torch.nn.LeakyReLU(negative_slope)

    def forward(self, x):
        """Convolve ``x`` and apply the leaky-ReLU activation."""
        features = self.conv(x)
        return self.leaky_relu(features)


class Conv2dMaxPool2d(torch.nn.Module):
def __init__(self):
super().__init__()
Expand All @@ -690,6 +700,16 @@ def forward(self, x):
return self.pool(self.conv(x))


class Conv2dReLU(torch.nn.Module):
    """A 3x3 Conv2d (3 -> 32 channels, padding=1) followed by ReLU.

    Test fixture for exercising conv + relu activation patterns in
    backend lowering.
    """

    def __init__(self):
        super().__init__()
        # Submodule names ("conv", "relu") stay fixed so state-dict keys
        # and traced-graph node names remain stable for callers.
        self.conv = torch.nn.Conv2d(
            in_channels=3, out_channels=32, kernel_size=3, padding=1
        )
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Convolve ``x`` and clamp negatives to zero via ReLU."""
        features = self.conv(x)
        return self.relu(features)


class Conv2dSequential(torch.nn.Module):
def __init__(self, bias=True, channel_last=False):
super().__init__()
Expand Down Expand Up @@ -1480,6 +1500,16 @@ def forward(self, x):
return self.linear(x)


class LinearLeakyReLU(torch.nn.Module):
    """A Linear layer (32 -> 32) followed by LeakyReLU.

    Test fixture for exercising linear + leaky-relu activation patterns
    in backend lowering. ``negative_slope`` is forwarded unchanged to
    the ``LeakyReLU`` module.
    """

    def __init__(self, negative_slope=0.01):
        super().__init__()
        # Submodule names ("linear", "leaky_relu") stay fixed so
        # state-dict keys and traced-graph node names remain stable.
        self.linear = torch.nn.Linear(in_features=32, out_features=32)
        self.leaky_relu = torch.nn.LeakyReLU(negative_slope)

    def forward(self, x):
        """Apply the linear projection, then the leaky-ReLU activation."""
        hidden = self.linear(x)
        return self.leaky_relu(hidden)


class LinearNonConstantWeight(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down
76 changes: 76 additions & 0 deletions backends/qualcomm/tests/test_qnn_delegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4832,6 +4832,82 @@ def test_qnn_backend_conv2d_max_pool2d(self):
module = self.get_qdq_module(module, sample_input)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_activation_fusion(self):
    """Verify that conv/linear + activation pairs are fused on the HTP backend.

    For each fixture module the test lowers a quantized (QDQ) model with
    profiling enabled, runs it on-device, parses the generated optrace QHAS
    report, and asserts that (a) a Conv op is present and (b) no standalone
    / unfused activation op (PReLU or ReLU) appears in the HTP op list.
    """
    # On-device only: the x86-64 simulator path does not produce the
    # optrace data this test inspects.
    if self.enable_x86_64:
        self.skipTest(
            "At the moment, testing is only being conducted on the device."
        )
    # Each case supplies the module, a sample input, a predicate that
    # detects an *unfused* activation in the HTP op names, and the
    # failure message to emit when that predicate fires.
    test_cases = [
        {
            "name": "conv2d_leaky_relu",
            QCOM_MODULE: Conv2dLeakyReLU(),  # noqa: F405
            QCOM_SAMPLE_INPUTS: (torch.randn(1, 32, 6, 2),),
            # LeakyReLU is lowered to PReLU by QNN, so an unfused
            # activation shows up as a "prelu.opt" op.
            "unfused_check": lambda ops: any(
                "prelu.opt" in op.lower() for op in ops
            ),
            "unfused_msg": "Unexpected PReLU op in HTP ops (LeakyReLU lowered to PReLU)",
        },
        {
            "name": "conv2d_relu",
            QCOM_MODULE: Conv2dReLU(),  # noqa: F405
            QCOM_SAMPLE_INPUTS: (torch.randn(1, 3, 28, 28),),
            # A standalone ReLU op name (not part of a conv op name)
            # indicates the fusion did not happen.
            "unfused_check": lambda ops: any(
                op.lower() in ("q::relu", "q::relu.opt")
                or (("relu" in op.lower()) and ("conv" not in op.lower()))
                for op in ops
            ),
            "unfused_msg": "Unexpected standalone ReLU op in HTP ops",
        },
        {
            "name": "linear_leaky_relu",
            QCOM_MODULE: LinearLeakyReLU(),  # noqa: F405
            QCOM_SAMPLE_INPUTS: (torch.randn(1, 6, 2, 32),),
            "unfused_check": lambda ops: any(
                "prelu.opt" in op.lower() for op in ops
            ),
            "unfused_msg": "Unexpected PReLU op in HTP ops (LeakyReLU lowered to PReLU)",
        },
    ]
    for tc in test_cases:
        with self.subTest(tc["name"]):
            # Fixed seed keeps quantization calibration deterministic
            # across subtests.
            torch.manual_seed(8)
            module = self.get_qdq_module(tc[QCOM_MODULE], tc[QCOM_SAMPLE_INPUTS])
            backend_options = generate_htp_compiler_spec(use_fp16=False)
            # profile_level=3 is required so the run emits the optrace
            # data parsed below — NOTE(review): confirm 3 is the
            # intended level constant rather than a named enum.
            compiler_spec = generate_qnn_executorch_compiler_spec(
                soc_model=self.chipset_table[TestQNN.soc_model],
                backend_options=backend_options,
                profile_level=3,
            )
            with tempfile.TemporaryDirectory() as tmp_dir:
                # Lower to a QNN-delegated ExecuTorch program and write
                # the .pte artifact for the on-device run.
                edge_prog_mgr = to_edge_transform_and_lower_to_qnn(
                    module, tc[QCOM_SAMPLE_INPUTS], compiler_spec
                ).to_executorch()
                pte_path = f"{tmp_dir}/model.pte"
                with open(pte_path, "wb") as f:
                    edge_prog_mgr.write_to_file(f)
                adb = self.get_adb_tool(pte_path)
                # Execute on device and collect the profiling traces;
                # each entry maps to a (binary, QHAS json) pair.
                binaries_trace = generate_optrace(
                    tmp_dir,
                    self.chipset_table[TestQNN.soc_model],
                    adb,
                    pte_path,
                    [tc[QCOM_SAMPLE_INPUTS]],
                )
                # Flatten the per-binary QHAS reports into one list of
                # HTP op-type names.
                htp_ops = []
                for _, (_, qhas) in binaries_trace.items():
                    with open(qhas, "r") as qhas_file:
                        qhas_data = json.load(qhas_file)
                        for row in qhas_data["data"]["htp_op_types"]["data"]:
                            htp_ops.append(row["op"])
                # Sanity check: the conv/linear portion must have made it
                # onto HTP at all before we reason about fusion.
                has_conv = any("ConvLayer" in op for op in htp_ops)
                self.assertTrue(
                    has_conv, f"Expected Conv op in HTP ops, got: {htp_ops}"
                )
                # The activation must NOT appear as its own op — that
                # would mean the conv/linear + activation fusion failed.
                self.assertFalse(
                    tc["unfused_check"](htp_ops),
                    f"{tc['unfused_msg']}, got: {htp_ops}",
                )

def test_qnn_backend_conv2d_slice_copy(self):
module = Conv2dSliceCopy() # noqa: F405
sample_input = (torch.randn([2, 1, 3, 3]),)
Expand Down
Loading