From 2947953ab30693d07c95ba3db0cca8d2a90af3ee Mon Sep 17 00:00:00 2001 From: Xingguo Li Date: Thu, 2 Apr 2026 09:28:13 +0100 Subject: [PATCH] Arm backend: add VGF PT2E linear quantization modes for LLM export - add vgf_16a8w PT2E quantization modes - add backend.vgf.quantize_scope for full vs linear VGF quantization - wire the VGF config through the LLM export and quantizer selection path - add coverage in export_llama_lib tests for the new VGF PT2E modes Signed-off-by: Xingguo Li Change-Id: Ie8fe849b4856321308d6d526248a7a4760ddc573 --- examples/models/llama/export_llama_lib.py | 15 ++++++ .../llama/tests/test_export_llama_lib.py | 54 ++++++++++++++++++- extension/llm/export/config/llm_config.py | 17 ++++++ extension/llm/export/quantizer_lib.py | 18 ++++++- 4 files changed, 102 insertions(+), 2 deletions(-) diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index b353bb38bbd..d65d9dcf9be 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -229,6 +229,8 @@ def build_args_parser() -> argparse.ArgumentParser: "vulkan_8w", "tosa_8a8w", "ethosu_8a8w", + "vgf_8a8w", + "vgf_16a8w", ], help="Use PT2E quantization. Comma separated options. e.g. xnnpack_dynamic (for per channel 8 bit weight), xnnpack_dynamic_qc4 (for per channel 4 bit weight), embedding.", ) @@ -456,6 +458,18 @@ def build_args_parser() -> argparse.ArgumentParser: ) parser.add_argument("-V", "--vulkan", action="store_true") parser.add_argument("--vulkan-force-fp16", action="store_true") + parser.add_argument("--vgf", action="store_true") + parser.add_argument( + "--vgf-compile-spec", + default="TOSA-1.0+INT", + help="VGF compile spec, e.g. TOSA-1.0+INT or TOSA-1.0+INT+int16.", + ) + parser.add_argument( + "--vgf-quantize-scope", + default="full", + choices=["full", "linear"], + help="VGF quantization scope. Use 'linear' to quantize only Linear modules.", + ) parser.add_argument("--mps", action="store_true") parser.add_argument("--coreml", action="store_true") parser.add_argument( @@ -850,6 +864,7 @@ def get_quantizer_and_quant_params(llm_config): llm_config.backend.vgf.compile_spec, llm_config.backend.vgf.compiler_flags, llm_config.quantization.pt2e_quantize.value, + llm_config.backend.vgf.quantize_scope.value, ) quantizers.append(vgf_quantizer) if llm_config.backend.vulkan.enabled and llm_config.quantization.pt2e_quantize: diff --git a/examples/models/llama/tests/test_export_llama_lib.py b/examples/models/llama/tests/test_export_llama_lib.py index 130a55f658c..f3dc403aa05 100644 --- a/examples/models/llama/tests/test_export_llama_lib.py +++ b/examples/models/llama/tests/test_export_llama_lib.py @@ -7,6 +7,8 @@ import unittest +import torch + from executorch.devtools.backend_debug import get_delegation_info try: @@ -28,7 +30,11 @@ build_args_parser, get_quantizer_and_quant_params, ) -from executorch.extension.llm.export.config.llm_config import LlmConfig, Pt2eQuantize +from executorch.extension.llm.export.config.llm_config import ( + LlmConfig, + Pt2eQuantize, + VgfQuantizeScope, +) UNWANTED_OPS = [ "aten_permute_copy_default", @@ -111,3 +117,49 @@ def test_get_quantizer_and_quant_params_returns_vgf_quantizer(self): self.assertIsNone(quant_dtype) self.assertEqual(len(quantizers), 1) self.assertIsInstance(quantizers[0], VgfQuantizer) + self.assertIsNotNone(quantizers[0].global_config) + self.assertEqual(quantizers[0].module_type_config, {}) + + @unittest.skipUnless(HAS_ARM_BACKEND, "ARM backend not available") + def test_get_quantizer_and_quant_params_returns_vgf_linear_quantizer(self): + llm_config = LlmConfig() + llm_config.backend.vgf.enabled = True + llm_config.backend.vgf.compile_spec = "TOSA-1.0+INT" + llm_config.backend.vgf.quantize_scope = VgfQuantizeScope.linear + llm_config.quantization.pt2e_quantize = Pt2eQuantize.vgf_8a8w + + _pt2e_quant_params, quantizers, _quant_dtype = get_quantizer_and_quant_params( + llm_config + ) + + self.assertEqual(len(quantizers), 1) + self.assertIsInstance(quantizers[0], VgfQuantizer) + self.assertIsNone(quantizers[0].global_config) + self.assertIn(torch.nn.Linear, quantizers[0].module_type_config) + + @unittest.skipUnless(HAS_ARM_BACKEND, "ARM backend not available") + def test_vgf_16a8w_requires_int16_compile_spec_extension(self): + llm_config = LlmConfig() + llm_config.backend.vgf.enabled = True + llm_config.backend.vgf.compile_spec = "TOSA-1.0+INT" + llm_config.backend.vgf.quantize_scope = VgfQuantizeScope.linear + llm_config.quantization.pt2e_quantize = Pt2eQuantize.vgf_16a8w + + with self.assertRaisesRegex(ValueError, "INT16 support"): + get_quantizer_and_quant_params(llm_config) + + @unittest.skipUnless(HAS_ARM_BACKEND, "ARM backend not available") + def test_vgf_16a8w_accepts_int16_compile_spec_extension(self): + llm_config = LlmConfig() + llm_config.backend.vgf.enabled = True + llm_config.backend.vgf.compile_spec = "TOSA-1.0+INT+int16" + llm_config.backend.vgf.quantize_scope = VgfQuantizeScope.linear + llm_config.quantization.pt2e_quantize = Pt2eQuantize.vgf_16a8w + + _pt2e_quant_params, quantizers, _quant_dtype = get_quantizer_and_quant_params( + llm_config + ) + + self.assertEqual(len(quantizers), 1) + self.assertIsInstance(quantizers[0], VgfQuantizer) + self.assertIn(torch.nn.Linear, quantizers[0].module_type_config) diff --git a/extension/llm/export/config/llm_config.py b/extension/llm/export/config/llm_config.py index 63ffe03a9fe..43be6f1b166 100644 --- a/extension/llm/export/config/llm_config.py +++ b/extension/llm/export/config/llm_config.py @@ -377,6 +377,7 @@ class Pt2eQuantize(str, Enum): tosa_8a8w = "tosa_8a8w" ethosu_8a8w = "ethosu_8a8w" vgf_8a8w = "vgf_8a8w" + vgf_16a8w = "vgf_16a8w" class SpinQuant(str, Enum): @@ -587,6 +588,11 @@ class EthosUConfig: system_config: str = "default" +class VgfQuantizeScope(str, Enum): + full = "full" + linear = "linear" + + @dataclass class VgfConfig: """ @@ -596,6 +602,7 @@ class VgfConfig: enabled: bool = False compile_spec: Optional[str] = "TOSA-1.0+INT" compiler_flags: List[str] = field(default_factory=list) + quantize_scope: VgfQuantizeScope = VgfQuantizeScope.full @dataclass @@ -815,6 +822,16 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig": # noqa: C901 if hasattr(args, "group_size") and args.group_size: llm_config.backend.openvino.nncf_compression_group_size = args.group_size + # VGF + if hasattr(args, "vgf"): + llm_config.backend.vgf.enabled = args.vgf + if hasattr(args, "vgf_compile_spec"): + llm_config.backend.vgf.compile_spec = args.vgf_compile_spec + if hasattr(args, "vgf_quantize_scope") and args.vgf_quantize_scope: + llm_config.backend.vgf.quantize_scope = VgfQuantizeScope( + args.vgf_quantize_scope + ) + # TorchAoKernels if any( hasattr(args, a) diff --git a/extension/llm/export/quantizer_lib.py b/extension/llm/export/quantizer_lib.py index 0c78921e461..cd70610ee11 100644 --- a/extension/llm/export/quantizer_lib.py +++ b/extension/llm/export/quantizer_lib.py @@ -367,8 +367,10 @@ def get_vgf_quantizer( compile_spec: Optional[str], compiler_flags: Optional[List[str]], pt2e_quantize: str, + quantize_scope: str, ): from executorch.backends.arm.quantizer.arm_quantizer import ( + get_symmetric_a16w8_quantization_config, get_symmetric_quantization_config, VgfQuantizer, ) @@ -379,8 +381,22 @@ def get_vgf_quantizer( quantizer = VgfQuantizer(compile_spec_obj) if pt2e_quantize == "vgf_8a8w": - quantizer.set_global(get_symmetric_quantization_config()) + quantization_config = get_symmetric_quantization_config() + elif pt2e_quantize == "vgf_16a8w": + if not compile_spec_obj.tosa_spec.support_extension("int16"): + raise ValueError( + "vgf_16a8w requires a VGF compile spec with INT16 support, " + "for example TOSA-1.0+INT+int16." + ) + quantization_config = get_symmetric_a16w8_quantization_config() else: raise ValueError(f"Unsupported quantizer specification {pt2e_quantize}") + if quantize_scope == "full": + quantizer.set_global(quantization_config) + elif quantize_scope == "linear": + quantizer.set_module_type(torch.nn.Linear, quantization_config) + else: + raise ValueError(f"Unsupported VGF quantization scope {quantize_scope}") + return quantizer