diff --git a/examples/apple/coreml/llama/run_static_llm_multifunction.py b/examples/apple/coreml/llama/run_static_llm_multifunction.py index 517c54435f4..98d0cb0a763 100644 --- a/examples/apple/coreml/llama/run_static_llm_multifunction.py +++ b/examples/apple/coreml/llama/run_static_llm_multifunction.py @@ -22,14 +22,16 @@ import argparse import json import time -from typing import Any, Dict, List, Tuple +from typing import List import torch -import torch.utils._pytree as pytree +from executorch.examples.apple.coreml.llama.utils import ( + create_pte_wrapper, + setup_multifunction_managers, +) from executorch.examples.models.llama.model_args import ModelArgs from executorch.examples.models.llama.runner.generation import next_token -from executorch.examples.models.llama.static_attention import StaticAttentionIOManager from executorch.runtime import Runtime from pytorch_tokenizers import get_tokenizer @@ -41,170 +43,6 @@ def get_stop_tokens(tokenizer) -> List[int]: return [tokenizer.eos_id] -def create_pte_wrapper( - decode_method, - prefill_method, - mgr: "StaticAttentionIOManager", - prefill_seq_len: int, - prefill_mask: Dict[str, torch.Tensor], -): - """ - Create a wrapper function that adapts PTE execution to the interface - expected by StaticAttentionIOManager. - - This multifunction version selects between prefill and decode methods - based on the input sequence length. Both methods use the SAME cache_len, - so the cache buffer is shared directly without any slicing or copying. - - The wrapper: - - Takes (tokens, options_dict) like the eager model - - Selects prefill or decode method based on token count - - Uses the same cache buffer for both methods (no slicing needed) - - Flattens inputs using pytree - - Executes the appropriate PTE method - - Reconstructs outputs to match eager model format: (logits, {"out_cache_state": (k_dict, v_dict)}) - - Args: - decode_method: The PTE method for decode (seqlen=1) - prefill_method: The PTE method for prefill (seqlen=input_len) - mgr: StaticAttentionIOManager with caches sized for shared cache_len - prefill_seq_len: The sequence length for prefill - prefill_mask: Pre-computed mask tensor for prefill method - """ - - k_cache_keys = list(mgr.k_caches.keys()) - v_cache_keys = list(mgr.v_caches.keys()) - - timing_stats = { - "flatten_time": 0.0, - "execute_time": 0.0, - "reconstruct_time": 0.0, - "detection_time": 0.0, - "options_build_time": 0.0, - "call_count": 0, - } - - def wrapper( - tokens: torch.Tensor, options: Dict[str, Any] - ) -> Tuple[torch.Tensor, Dict[str, Any]]: - import time as time_module - - timing_stats["call_count"] += 1 - - t0 = time_module.perf_counter() - - # Detect actual sequence length. - # StaticAttentionIOManager._run_once pads tokens with zeros on the right. - # For decode (1 actual token), positions 1+ are all zeros. - padded_seq_len = tokens.shape[1] - if padded_seq_len > 1 and (tokens[0, 1:] == 0).all(): - actual_seq_len = 1 - else: - actual_seq_len = padded_seq_len - - is_prefill = actual_seq_len == prefill_seq_len - - t1 = time_module.perf_counter() - timing_stats["detection_time"] += t1 - t0 - - t0 = time_module.perf_counter() - - # Get the input cache state from options - in_k_caches, in_v_caches = options["in_cache_state"] - - # Both prefill and decode use the same cache_len, so no slicing needed! - # Just select the appropriate method and mask. - if is_prefill: - method = prefill_method - adapted_mask = prefill_mask - else: - method = decode_method - adapted_mask = mgr.masks - - adapted_options = { - "masks": adapted_mask, - "freqs_cos_override": options["freqs_cos_override"], - "freqs_sin_override": options["freqs_sin_override"], - "in_cache_state": (in_k_caches, in_v_caches), # Same cache for both! - } - - if "last_valid_token_pos" in options: - adapted_options["last_valid_token_pos"] = options["last_valid_token_pos"] - - inputs = (tokens, adapted_options) - - t1 = time_module.perf_counter() - timing_stats["options_build_time"] += t1 - t0 - - t0 = time_module.perf_counter() - flat_inputs, _ = pytree.tree_flatten(inputs) - t1 = time_module.perf_counter() - timing_stats["flatten_time"] += t1 - t0 - - t0 = time_module.perf_counter() - outputs = method.execute(flat_inputs) - t1 = time_module.perf_counter() - timing_stats["execute_time"] += t1 - t0 - - t0 = time_module.perf_counter() - - logits = outputs[0] - - num_layers = len(k_cache_keys) - k_updates = outputs[1 : 1 + num_layers] - v_updates = outputs[1 + num_layers : 1 + 2 * num_layers] - - k_cache_dict = dict(zip(k_cache_keys, k_updates)) - v_cache_dict = dict(zip(v_cache_keys, v_updates)) - - attn_updates = {"out_cache_state": (k_cache_dict, v_cache_dict)} - - t1 = time_module.perf_counter() - timing_stats["reconstruct_time"] += t1 - t0 - - return logits, attn_updates - - def print_timing_stats(): - n = timing_stats["call_count"] - if n > 0: - print(f"\n=== Wrapper Timing Stats ({n} calls) ===") - print( - f" Detection time: {timing_stats['detection_time']*1000:.2f}ms total, {timing_stats['detection_time']/n*1000:.4f}ms avg" - ) - print( - f" Options build: {timing_stats['options_build_time']*1000:.2f}ms total, {timing_stats['options_build_time']/n*1000:.4f}ms avg" - ) - print( - f" Flatten time: {timing_stats['flatten_time']*1000:.2f}ms total, {timing_stats['flatten_time']/n*1000:.4f}ms avg" - ) - print( - f" Execute time: {timing_stats['execute_time']*1000:.2f}ms total, {timing_stats['execute_time']/n*1000:.3f}ms avg" - ) - print( - f" Reconstruct time: {timing_stats['reconstruct_time']*1000:.2f}ms total, {timing_stats['reconstruct_time']/n*1000:.4f}ms avg" - ) - total = ( - timing_stats["detection_time"] - + timing_stats["options_build_time"] - + timing_stats["flatten_time"] - + timing_stats["execute_time"] - + timing_stats["reconstruct_time"] - ) - print( - f" Total wrapper: {total*1000:.2f}ms total, {total/n*1000:.3f}ms avg" - ) - print( - f" Execute is {timing_stats['execute_time']/total*100:.1f}% of wrapper time" - ) - expected_tps = 1000 / (timing_stats["execute_time"] / n * 1000) - print(f" Expected tok/s from execute alone: {expected_tps:.1f}") - - wrapper.print_timing_stats = print_timing_stats - wrapper.timing_stats = timing_stats - - return wrapper - - def main(): parser = argparse.ArgumentParser( description="Run multifunction static attention Llama model" @@ -326,36 +164,16 @@ def main(): print(f"Prefill: input_len={prefill_input_len}, cache_len={shared_cache_len}") print(f"Decode: input_len={decode_input_len}, cache_len={shared_cache_len}") - # Create decode manager (input_len=1) - used for decode phase - mgr = StaticAttentionIOManager( - model_args, - input_len=decode_input_len, - cache_lens=shared_cache_len, - batch_size=1, - dtype=torch.float16, - style="smart_mask", - mask_val=float("-inf"), - ) - - # Create prefill manager (input_len=64) with the SAME cache_len. - # Since both use the same cache_len, we can share the cache buffer directly. - prefill_mgr = StaticAttentionIOManager( + # Create managers with shared cache buffers + mgr, prefill_mgr, prefill_mask = setup_multifunction_managers( model_args, - input_len=prefill_input_len, - cache_lens=shared_cache_len, # Same cache_len as decode! - batch_size=1, + prefill_input_len, + decode_input_len, + shared_cache_len, dtype=torch.float16, - style="smart_mask", mask_val=float("-inf"), ) - # Share cache buffers: point prefill_mgr's caches to mgr's caches. - # No copying needed since both managers use the same cache_len! - prefill_mgr.k_caches = mgr.k_caches - prefill_mgr.v_caches = mgr.v_caches - - prefill_mask = prefill_mgr.masks - # Load PTE model with multifunction support print(f"Loading multifunction model from {args.model}...") runtime = Runtime.get() diff --git a/examples/apple/coreml/llama/utils.py b/examples/apple/coreml/llama/utils.py index 1e5a842fed5..755a654b9df 100644 --- a/examples/apple/coreml/llama/utils.py +++ b/examples/apple/coreml/llama/utils.py @@ -4,7 +4,16 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import time +from typing import Any, Dict, Tuple, TYPE_CHECKING + import torch +import torch.utils._pytree as pytree + +if TYPE_CHECKING: + from executorch.examples.models.llama.static_attention import ( + StaticAttentionIOManager, + ) class SplitLinearModule(torch.nn.Module): @@ -114,3 +123,212 @@ def replace_linear_with_split_linear( in_target_split_size, in_max_splits, ) + + +def setup_multifunction_managers( + config, + prefill_input_len: int, + decode_input_len: int, + shared_cache_len: int, + dtype: torch.dtype = torch.float16, + mask_val: float = float("-inf"), + style: str = "smart_mask", +): + """ + Create prefill and decode StaticAttentionIOManager instances with shared cache buffers. + + Both managers use the same cache_len so they share cache memory directly. + Returns (decode_mgr, prefill_mgr, prefill_mask). + """ + from executorch.examples.models.llama.static_attention import ( + StaticAttentionIOManager, + ) + + mgr = StaticAttentionIOManager( + config, + input_len=decode_input_len, + cache_lens=shared_cache_len, + batch_size=1, + dtype=dtype, + style=style, + mask_val=mask_val, + ) + + prefill_mgr = StaticAttentionIOManager( + config, + input_len=prefill_input_len, + cache_lens=shared_cache_len, + batch_size=1, + dtype=dtype, + style=style, + mask_val=mask_val, + ) + + # Share cache buffers — no copying needed + prefill_mgr.k_caches = mgr.k_caches + prefill_mgr.v_caches = mgr.v_caches + prefill_mask = prefill_mgr.masks + + return mgr, prefill_mgr, prefill_mask + + +def create_pte_wrapper( + decode_method, + prefill_method, + mgr: "StaticAttentionIOManager", + prefill_seq_len: int, + prefill_mask: Dict[str, torch.Tensor], +): + """ + Create a wrapper function that adapts PTE execution to the interface + expected by StaticAttentionIOManager. + + This multifunction version selects between prefill and decode methods + based on the input sequence length. Both methods use the SAME cache_len, + so the cache buffer is shared directly without any slicing or copying. + + The wrapper: + - Takes (tokens, options_dict) like the eager model + - Selects prefill or decode method based on token count + - Uses the same cache buffer for both methods (no slicing needed) + - Flattens inputs using pytree + - Executes the appropriate PTE method + - Reconstructs outputs to match eager model format: (logits, {"out_cache_state": (k_dict, v_dict)}) + + Args: + decode_method: The PTE method for decode (seqlen=1) + prefill_method: The PTE method for prefill (seqlen=input_len) + mgr: StaticAttentionIOManager with caches sized for shared cache_len + prefill_seq_len: The sequence length for prefill + prefill_mask: Pre-computed mask tensor for prefill method + """ + + k_cache_keys = list(mgr.k_caches.keys()) + v_cache_keys = list(mgr.v_caches.keys()) + + timing_stats = { + "flatten_time": 0.0, + "execute_time": 0.0, + "reconstruct_time": 0.0, + "detection_time": 0.0, + "options_build_time": 0.0, + "call_count": 0, + } + + def wrapper( + tokens: torch.Tensor, options: Dict[str, Any] + ) -> Tuple[torch.Tensor, Dict[str, Any]]: + timing_stats["call_count"] += 1 + + t0 = time.perf_counter() + + # Detect actual sequence length. + # StaticAttentionIOManager._run_once pads tokens with zeros on the right. + # For decode (1 actual token), positions 1+ are all zeros. + padded_seq_len = tokens.shape[1] + if padded_seq_len > 1 and (tokens[0, 1:] == 0).all(): + actual_seq_len = 1 + else: + actual_seq_len = padded_seq_len + + is_prefill = actual_seq_len == prefill_seq_len + + t1 = time.perf_counter() + timing_stats["detection_time"] += t1 - t0 + + t0 = time.perf_counter() + + # Get the input cache state from options + in_k_caches, in_v_caches = options["in_cache_state"] + + # Both prefill and decode use the same cache_len, so no slicing needed! + # Just select the appropriate method and mask. + if is_prefill: + method = prefill_method + adapted_mask = prefill_mask + else: + method = decode_method + adapted_mask = mgr.masks + + adapted_options = { + "masks": adapted_mask, + "freqs_cos_override": options["freqs_cos_override"], + "freqs_sin_override": options["freqs_sin_override"], + "in_cache_state": (in_k_caches, in_v_caches), # Same cache for both! + } + + if "last_valid_token_pos" in options: + adapted_options["last_valid_token_pos"] = options["last_valid_token_pos"] + + inputs = (tokens, adapted_options) + + t1 = time.perf_counter() + timing_stats["options_build_time"] += t1 - t0 + + t0 = time.perf_counter() + flat_inputs, _ = pytree.tree_flatten(inputs) + t1 = time.perf_counter() + timing_stats["flatten_time"] += t1 - t0 + + t0 = time.perf_counter() + outputs = method.execute(flat_inputs) + t1 = time.perf_counter() + timing_stats["execute_time"] += t1 - t0 + + t0 = time.perf_counter() + + logits = outputs[0] + + num_layers = len(k_cache_keys) + k_updates = outputs[1 : 1 + num_layers] + v_updates = outputs[1 + num_layers : 1 + 2 * num_layers] + + k_cache_dict = dict(zip(k_cache_keys, k_updates)) + v_cache_dict = dict(zip(v_cache_keys, v_updates)) + + attn_updates = {"out_cache_state": (k_cache_dict, v_cache_dict)} + + t1 = time.perf_counter() + timing_stats["reconstruct_time"] += t1 - t0 + + return logits, attn_updates + + def print_timing_stats(): + n = timing_stats["call_count"] + if n > 0: + print(f"\n=== Wrapper Timing Stats ({n} calls) ===") + print( + f" Detection time: {timing_stats['detection_time']*1000:.2f}ms total, {timing_stats['detection_time']/n*1000:.4f}ms avg" + ) + print( + f" Options build: {timing_stats['options_build_time']*1000:.2f}ms total, {timing_stats['options_build_time']/n*1000:.4f}ms avg" + ) + print( + f" Flatten time: {timing_stats['flatten_time']*1000:.2f}ms total, {timing_stats['flatten_time']/n*1000:.4f}ms avg" + ) + print( + f" Execute time: {timing_stats['execute_time']*1000:.2f}ms total, {timing_stats['execute_time']/n*1000:.3f}ms avg" + ) + print( + f" Reconstruct time: {timing_stats['reconstruct_time']*1000:.2f}ms total, {timing_stats['reconstruct_time']/n*1000:.4f}ms avg" + ) + total = ( + timing_stats["detection_time"] + + timing_stats["options_build_time"] + + timing_stats["flatten_time"] + + timing_stats["execute_time"] + + timing_stats["reconstruct_time"] + ) + print( + f" Total wrapper: {total*1000:.2f}ms total, {total/n*1000:.3f}ms avg" + ) + print( + f" Execute is {timing_stats['execute_time']/total*100:.1f}% of wrapper time" + ) + expected_tps = 1000 / (timing_stats["execute_time"] / n * 1000) + print(f" Expected tok/s from execute alone: {expected_tps:.1f}") + + wrapper.print_timing_stats = print_timing_stats + wrapper.timing_stats = timing_stats + + return wrapper