Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 10 additions & 192 deletions examples/apple/coreml/llama/run_static_llm_multifunction.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,16 @@
import argparse
import json
import time
from typing import Any, Dict, List, Tuple
from typing import List

import torch
import torch.utils._pytree as pytree

from executorch.examples.apple.coreml.llama.utils import (
create_pte_wrapper,
setup_multifunction_managers,
)
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.runner.generation import next_token
from executorch.examples.models.llama.static_attention import StaticAttentionIOManager
from executorch.runtime import Runtime
from pytorch_tokenizers import get_tokenizer

Expand All @@ -41,170 +43,6 @@ def get_stop_tokens(tokenizer) -> List[int]:
return [tokenizer.eos_id]


def create_pte_wrapper(
    decode_method,
    prefill_method,
    mgr: "StaticAttentionIOManager",
    prefill_seq_len: int,
    prefill_mask: Dict[str, torch.Tensor],
):
    """
    Create a wrapper function that adapts PTE execution to the interface
    expected by StaticAttentionIOManager.

    This multifunction version selects between prefill and decode methods
    based on the input sequence length. Both methods use the SAME cache_len,
    so the cache buffer is shared directly without any slicing or copying.

    The wrapper:
    - Takes (tokens, options_dict) like the eager model
    - Selects prefill or decode method based on token count
    - Uses the same cache buffer for both methods (no slicing needed)
    - Flattens inputs using pytree
    - Executes the appropriate PTE method
    - Reconstructs outputs to match eager model format: (logits, {"out_cache_state": (k_dict, v_dict)})

    The returned wrapper also carries two attributes for diagnostics:
    - ``wrapper.timing_stats``: dict of accumulated per-phase wall-clock
      seconds plus ``call_count``
    - ``wrapper.print_timing_stats``: prints a summary of those stats

    Args:
        decode_method: The PTE method for decode (seqlen=1)
        prefill_method: The PTE method for prefill (seqlen=input_len)
        mgr: StaticAttentionIOManager with caches sized for shared cache_len
        prefill_seq_len: The sequence length for prefill
        prefill_mask: Pre-computed mask tensor for prefill method
    """

    # Cache key order is fixed once here so output tensors can be zipped
    # back into dicts in the same order on every call.
    k_cache_keys = list(mgr.k_caches.keys())
    v_cache_keys = list(mgr.v_caches.keys())

    # Accumulated wall-clock time (seconds) for each wrapper phase.
    timing_stats = {
        "flatten_time": 0.0,
        "execute_time": 0.0,
        "reconstruct_time": 0.0,
        "detection_time": 0.0,
        "options_build_time": 0.0,
        "call_count": 0,
    }

    def wrapper(
        tokens: torch.Tensor, options: Dict[str, Any]
    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        # NOTE: uses the module-level `time` import rather than re-importing
        # inside this hot path — this wrapper runs once per generated token
        # and times itself, so per-call import lookups would skew the stats.
        timing_stats["call_count"] += 1

        t0 = time.perf_counter()

        # Detect actual sequence length.
        # StaticAttentionIOManager._run_once pads tokens with zeros on the right.
        # For decode (1 actual token), positions 1+ are all zeros.
        # NOTE(review): a prefill chunk whose positions 1+ are all literally
        # token id 0 would be misdetected as decode — assumed not to occur.
        padded_seq_len = tokens.shape[1]
        if padded_seq_len > 1 and (tokens[0, 1:] == 0).all():
            actual_seq_len = 1
        else:
            actual_seq_len = padded_seq_len

        is_prefill = actual_seq_len == prefill_seq_len

        t1 = time.perf_counter()
        timing_stats["detection_time"] += t1 - t0

        t0 = time.perf_counter()

        # Get the input cache state from options
        in_k_caches, in_v_caches = options["in_cache_state"]

        # Both prefill and decode use the same cache_len, so no slicing needed!
        # Just select the appropriate method and mask.
        if is_prefill:
            method = prefill_method
            adapted_mask = prefill_mask
        else:
            method = decode_method
            adapted_mask = mgr.masks

        adapted_options = {
            "masks": adapted_mask,
            "freqs_cos_override": options["freqs_cos_override"],
            "freqs_sin_override": options["freqs_sin_override"],
            "in_cache_state": (in_k_caches, in_v_caches),  # Same cache for both!
        }

        if "last_valid_token_pos" in options:
            adapted_options["last_valid_token_pos"] = options["last_valid_token_pos"]

        inputs = (tokens, adapted_options)

        t1 = time.perf_counter()
        timing_stats["options_build_time"] += t1 - t0

        t0 = time.perf_counter()
        flat_inputs, _ = pytree.tree_flatten(inputs)
        t1 = time.perf_counter()
        timing_stats["flatten_time"] += t1 - t0

        t0 = time.perf_counter()
        outputs = method.execute(flat_inputs)
        t1 = time.perf_counter()
        timing_stats["execute_time"] += t1 - t0

        t0 = time.perf_counter()

        # Flat PTE output layout: [logits, k_update_0..k_update_{L-1},
        # v_update_0..v_update_{L-1}] — rebuild the eager-model dict shape.
        logits = outputs[0]

        num_layers = len(k_cache_keys)
        k_updates = outputs[1 : 1 + num_layers]
        v_updates = outputs[1 + num_layers : 1 + 2 * num_layers]

        k_cache_dict = dict(zip(k_cache_keys, k_updates))
        v_cache_dict = dict(zip(v_cache_keys, v_updates))

        attn_updates = {"out_cache_state": (k_cache_dict, v_cache_dict)}

        t1 = time.perf_counter()
        timing_stats["reconstruct_time"] += t1 - t0

        return logits, attn_updates

    def print_timing_stats():
        """Print accumulated per-phase timing; no-op before the first call."""
        n = timing_stats["call_count"]
        if n == 0:
            return
        print(f"\n=== Wrapper Timing Stats ({n} calls) ===")
        print(
            f"  Detection time:   {timing_stats['detection_time']*1000:.2f}ms total, {timing_stats['detection_time']/n*1000:.4f}ms avg"
        )
        print(
            f"  Options build:    {timing_stats['options_build_time']*1000:.2f}ms total, {timing_stats['options_build_time']/n*1000:.4f}ms avg"
        )
        print(
            f"  Flatten time:     {timing_stats['flatten_time']*1000:.2f}ms total, {timing_stats['flatten_time']/n*1000:.4f}ms avg"
        )
        print(
            f"  Execute time:     {timing_stats['execute_time']*1000:.2f}ms total, {timing_stats['execute_time']/n*1000:.3f}ms avg"
        )
        print(
            f"  Reconstruct time: {timing_stats['reconstruct_time']*1000:.2f}ms total, {timing_stats['reconstruct_time']/n*1000:.4f}ms avg"
        )
        total = (
            timing_stats["detection_time"]
            + timing_stats["options_build_time"]
            + timing_stats["flatten_time"]
            + timing_stats["execute_time"]
            + timing_stats["reconstruct_time"]
        )
        print(
            f"  Total wrapper:    {total*1000:.2f}ms total, {total/n*1000:.3f}ms avg"
        )
        # Guard against zero-duration phases (e.g. coarse timer resolution)
        # so a diagnostics call can never raise ZeroDivisionError.
        if total > 0:
            print(
                f"  Execute is {timing_stats['execute_time']/total*100:.1f}% of wrapper time"
            )
        if timing_stats["execute_time"] > 0:
            # n calls / total execute seconds == tokens per second.
            expected_tps = n / timing_stats["execute_time"]
            print(f"  Expected tok/s from execute alone: {expected_tps:.1f}")

    wrapper.print_timing_stats = print_timing_stats
    wrapper.timing_stats = timing_stats

    return wrapper


def main():
parser = argparse.ArgumentParser(
description="Run multifunction static attention Llama model"
Expand Down Expand Up @@ -326,36 +164,16 @@ def main():
print(f"Prefill: input_len={prefill_input_len}, cache_len={shared_cache_len}")
print(f"Decode: input_len={decode_input_len}, cache_len={shared_cache_len}")

# Create decode manager (input_len=1) - used for decode phase
mgr = StaticAttentionIOManager(
model_args,
input_len=decode_input_len,
cache_lens=shared_cache_len,
batch_size=1,
dtype=torch.float16,
style="smart_mask",
mask_val=float("-inf"),
)

# Create prefill manager (input_len=64) with the SAME cache_len.
# Since both use the same cache_len, we can share the cache buffer directly.
prefill_mgr = StaticAttentionIOManager(
# Create managers with shared cache buffers
mgr, prefill_mgr, prefill_mask = setup_multifunction_managers(
model_args,
input_len=prefill_input_len,
cache_lens=shared_cache_len, # Same cache_len as decode!
batch_size=1,
prefill_input_len,
decode_input_len,
shared_cache_len,
dtype=torch.float16,
style="smart_mask",
mask_val=float("-inf"),
)

# Share cache buffers: point prefill_mgr's caches to mgr's caches.
# No copying needed since both managers use the same cache_len!
prefill_mgr.k_caches = mgr.k_caches
prefill_mgr.v_caches = mgr.v_caches

prefill_mask = prefill_mgr.masks

# Load PTE model with multifunction support
print(f"Loading multifunction model from {args.model}...")
runtime = Runtime.get()
Expand Down
Loading
Loading