Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions examples/models/llama/export_llama_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,8 @@ def build_args_parser() -> argparse.ArgumentParser:
"vulkan_8w",
"tosa_8a8w",
"ethosu_8a8w",
"vgf_8a8w",
"vgf_16a8w",
],
help="Use PT2E quantization. Comma separated options. e.g. xnnpack_dynamic (for per channel 8 bit weight), xnnpack_dynamic_qc4 (for per channel 4 bit weight), embedding.",
Copy link

Copilot AI Apr 21, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The --pt2e_quantize argparse option is defined with a fixed set of choices, so it only accepts a single value, but the help text says it supports "Comma separated options" (and even mentions embedding, which is not a valid choice). This is user-facing and likely to confuse; either update the help text to reflect single-choice behavior, or switch the argument parsing to accept a comma-separated list (and adjust LlmConfig/Pt2eQuantize parsing accordingly).

Suggested change
help="Use PT2E quantization. Comma separated options. e.g. xnnpack_dynamic (for per channel 8 bit weight), xnnpack_dynamic_qc4 (for per channel 4 bit weight), embedding.",
help="Use a single PT2E quantization mode, e.g. xnnpack_dynamic (per-channel 8-bit weight) or xnnpack_dynamic_qc4 (per-channel 4-bit weight).",

Copilot uses AI. Check for mistakes.
)
Expand Down Expand Up @@ -456,6 +458,18 @@ def build_args_parser() -> argparse.ArgumentParser:
)
parser.add_argument("-V", "--vulkan", action="store_true")
parser.add_argument("--vulkan-force-fp16", action="store_true")
parser.add_argument("--vgf", action="store_true")
parser.add_argument(
"--vgf-compile-spec",
default="TOSA-1.0+INT",
help="VGF compile spec, e.g. TOSA-1.0+INT or TOSA-1.0+INT+int16.",
)
parser.add_argument(
"--vgf-quantize-scope",
default="full",
choices=["full", "linear"],
help="VGF quantization scope. Use 'linear' to quantize only Linear modules.",
)
parser.add_argument("--mps", action="store_true")
parser.add_argument("--coreml", action="store_true")
parser.add_argument(
Expand Down Expand Up @@ -847,6 +861,7 @@ def get_quantizer_and_quant_params(llm_config):
llm_config.backend.vgf.compile_spec,
llm_config.backend.vgf.compiler_flags,
llm_config.quantization.pt2e_quantize.value,
llm_config.backend.vgf.quantize_scope.value,
)
quantizers.append(vgf_quantizer)
if llm_config.backend.vulkan.enabled and llm_config.quantization.pt2e_quantize:
Expand Down
54 changes: 53 additions & 1 deletion examples/models/llama/tests/test_export_llama_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

import unittest

import torch

from executorch.devtools.backend_debug import get_delegation_info

try:
Expand All @@ -28,7 +30,11 @@
build_args_parser,
get_quantizer_and_quant_params,
)
from executorch.extension.llm.export.config.llm_config import LlmConfig, Pt2eQuantize
from executorch.extension.llm.export.config.llm_config import (
LlmConfig,
Pt2eQuantize,
VgfQuantizeScope,
)

UNWANTED_OPS = [
"aten_permute_copy_default",
Expand Down Expand Up @@ -111,3 +117,49 @@ def test_get_quantizer_and_quant_params_returns_vgf_quantizer(self):
self.assertIsNone(quant_dtype)
self.assertEqual(len(quantizers), 1)
self.assertIsInstance(quantizers[0], VgfQuantizer)
self.assertIsNotNone(quantizers[0].global_config)
self.assertEqual(quantizers[0].module_type_config, {})

@unittest.skipUnless(HAS_ARM_BACKEND, "ARM backend not available")
def test_get_quantizer_and_quant_params_returns_vgf_linear_quantizer(self):
    """A linear-scope VGF config yields one VgfQuantizer configured per module type only."""
    config = LlmConfig()
    vgf = config.backend.vgf
    vgf.enabled = True
    vgf.compile_spec = "TOSA-1.0+INT"
    vgf.quantize_scope = VgfQuantizeScope.linear
    config.quantization.pt2e_quantize = Pt2eQuantize.vgf_8a8w

    _params, quantizers, _dtype = get_quantizer_and_quant_params(config)

    self.assertEqual(1, len(quantizers))
    quantizer = quantizers[0]
    self.assertIsInstance(quantizer, VgfQuantizer)
    # Linear scope: no global config; only torch.nn.Linear is registered by module type.
    self.assertIsNone(quantizer.global_config)
    self.assertIn(torch.nn.Linear, quantizer.module_type_config)

@unittest.skipUnless(HAS_ARM_BACKEND, "ARM backend not available")
def test_vgf_16a8w_requires_int16_compile_spec_extension(self):
    """vgf_16a8w with a compile spec lacking the int16 extension must raise ValueError."""
    config = LlmConfig()
    vgf = config.backend.vgf
    vgf.enabled = True
    vgf.compile_spec = "TOSA-1.0+INT"  # deliberately missing "+int16"
    vgf.quantize_scope = VgfQuantizeScope.linear
    config.quantization.pt2e_quantize = Pt2eQuantize.vgf_16a8w

    with self.assertRaisesRegex(ValueError, "INT16 support"):
        get_quantizer_and_quant_params(config)

@unittest.skipUnless(HAS_ARM_BACKEND, "ARM backend not available")
def test_vgf_16a8w_accepts_int16_compile_spec_extension(self):
    """With "+int16" in the compile spec, vgf_16a8w builds a Linear-scoped VgfQuantizer."""
    config = LlmConfig()
    vgf = config.backend.vgf
    vgf.enabled = True
    vgf.compile_spec = "TOSA-1.0+INT+int16"
    vgf.quantize_scope = VgfQuantizeScope.linear
    config.quantization.pt2e_quantize = Pt2eQuantize.vgf_16a8w

    _params, quantizers, _dtype = get_quantizer_and_quant_params(config)

    self.assertEqual(1, len(quantizers))
    quantizer = quantizers[0]
    self.assertIsInstance(quantizer, VgfQuantizer)
    self.assertIn(torch.nn.Linear, quantizer.module_type_config)
17 changes: 17 additions & 0 deletions extension/llm/export/config/llm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,7 @@ class Pt2eQuantize(str, Enum):
tosa_8a8w = "tosa_8a8w"
ethosu_8a8w = "ethosu_8a8w"
vgf_8a8w = "vgf_8a8w"
vgf_16a8w = "vgf_16a8w"


class SpinQuant(str, Enum):
Expand Down Expand Up @@ -587,6 +588,11 @@ class EthosUConfig:
system_config: str = "default"


class VgfQuantizeScope(str, Enum):
    """Which modules the VGF PT2E quantizer targets.

    ``full`` applies the quantization config globally to the model;
    ``linear`` restricts quantization to ``torch.nn.Linear`` modules only.
    """

    full = "full"
    linear = "linear"


@dataclass
class VgfConfig:
"""
Expand All @@ -596,6 +602,7 @@ class VgfConfig:
enabled: bool = False
compile_spec: Optional[str] = "TOSA-1.0+INT"
compiler_flags: List[str] = field(default_factory=list)
quantize_scope: VgfQuantizeScope = VgfQuantizeScope.full


@dataclass
Expand Down Expand Up @@ -815,6 +822,16 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig": # noqa: C901
if hasattr(args, "group_size") and args.group_size:
llm_config.backend.openvino.nncf_compression_group_size = args.group_size

# VGF
if hasattr(args, "vgf"):
llm_config.backend.vgf.enabled = args.vgf
if hasattr(args, "vgf_compile_spec"):
llm_config.backend.vgf.compile_spec = args.vgf_compile_spec
if hasattr(args, "vgf_quantize_scope") and args.vgf_quantize_scope:
llm_config.backend.vgf.quantize_scope = VgfQuantizeScope(
args.vgf_quantize_scope
)

# TorchAoKernels
if any(
hasattr(args, a)
Expand Down
18 changes: 17 additions & 1 deletion extension/llm/export/quantizer_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,8 +367,10 @@ def get_vgf_quantizer(
compile_spec: Optional[str],
compiler_flags: Optional[List[str]],
pt2e_quantize: str,
quantize_scope: str,
):
from executorch.backends.arm.quantizer.arm_quantizer import (
get_symmetric_a16w8_quantization_config,
get_symmetric_quantization_config,
VgfQuantizer,
)
Expand All @@ -379,8 +381,22 @@ def get_vgf_quantizer(
quantizer = VgfQuantizer(compile_spec_obj)

if pt2e_quantize == "vgf_8a8w":
quantizer.set_global(get_symmetric_quantization_config())
quantization_config = get_symmetric_quantization_config()
elif pt2e_quantize == "vgf_16a8w":
if not compile_spec_obj.tosa_spec.support_extension("int16"):
raise ValueError(
"vgf_16a8w requires a VGF compile spec with INT16 support, "
"for example TOSA-1.0+INT+int16."
)
quantization_config = get_symmetric_a16w8_quantization_config()
else:
raise ValueError(f"Unsupported quantizer specification {pt2e_quantize}")

if quantize_scope == "full":
quantizer.set_global(quantization_config)
elif quantize_scope == "linear":
quantizer.set_module_type(torch.nn.Linear, quantization_config)
else:
raise ValueError(f"Unsupported VGF quantization scope {quantize_scope}")

return quantizer
Loading