
fix: pass dynamic=True to torch.compile to stop dynamo recompilation #1685

Draft

rahulp7801 wants to merge 1 commit into intel:main from rahulp7801:opensrcer/issue-1679-29-56_e4a9bb

Conversation

@rahulp7801

Fixes #1679.

Summary

compile_func_on_cuda_or_cpu in auto_round/utils/device.py:97 calls torch.compile(func) with no shape-specialisation hint. When quant_tensor_sym is compiled, torch._dynamo traces it for the concrete shape of the v tensor (the per-weight rounding perturbation, initialised to the weight's shape at wrapper.py:164). Different layers have different weight shapes, so dynamo recompiles the function for each new shape. After 8 recompilations it logs the config.recompile_limit warning and falls back to eager mode. Passing dynamic=True to torch.compile tells dynamo to use symbolic shapes, producing a single shape-generic trace that never needs to recompile for size changes in v or the weight tensor.
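As a toy illustration of the failure mode (plain Python standing in for dynamo's guard and specialisation machinery; this is not the real torch internals): a compile cache keyed by input shape recompiles once per new shape and gives up after the limit, while a symbolic ("dynamic") key is hit on every call.

```python
# Toy model of dynamo shape specialisation. Each new input shape triggers a
# "recompile"; once the cache holds RECOMPILE_LIMIT entries the function
# falls back to eager, which is the warning reported in issue #1679.
# dynamic=True corresponds to keying every call to one symbolic trace.

RECOMPILE_LIMIT = 8  # torch's default config.recompile_limit

def make_compiler(dynamic=False):
    cache = {}
    stats = {"recompiles": 0, "fell_back_to_eager": False}

    def compile_for(shape):
        key = "symbolic" if dynamic else shape
        if key not in cache:
            if len(cache) >= RECOMPILE_LIMIT:
                stats["fell_back_to_eager"] = True
                return
            cache[key] = True
            stats["recompiles"] += 1

    return compile_for, stats

# Ten layers with distinct weight shapes, like v in quant_tensor_sym:
shapes = [(64 * (i + 1), 128) for i in range(10)]

compile_static, static_stats = make_compiler(dynamic=False)
for s in shapes:
    compile_static(s)

compile_dynamic, dynamic_stats = make_compiler(dynamic=True)
for s in shapes:
    compile_dynamic(s)

print(static_stats)   # 8 recompiles, then eager fallback
print(dynamic_stats)  # one symbolic trace reused for every shape
```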

Files changed

  • auto_round/utils/device.py
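A sketch of what the patched helper could look like (the surrounding body of the real helper in auto_round/utils/device.py is not reproduced here and may include device checks; the only change the PR describes is the added dynamic=True argument):

```python
import torch

def compile_func_on_cuda_or_cpu(func):
    # Sketch only, not the verbatim upstream function. dynamic=True asks
    # dynamo to trace with symbolic shapes, so a single compiled artifact
    # serves every layer's weight shape instead of respecializing per shape
    # and eventually hitting config.recompile_limit.
    return torch.compile(func, dynamic=True)
```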

Test plan

  • Call sites checked: compile_func is called in wrapper.py:175,182, alg_ext.py:373,1042,1046, and compressors/base.py:539. All of these wrap quantisation functions (quant_tensor_sym, quant_tensor_asym, and block_forward) that receive tensors of varying shape across layers — so dynamic=True is the correct default for all of them.
Contribution guide notes
  • Sign-off required: Every commit must include Signed-off-by: Name <email> (DCO, per CONTRIBUTING.md).

Authored via opensrcer agentic solve. Full exploration trace: dispatch d_2026-04-15T06-29-56_e4a9bb.

Exploration + diff authored by Claude Code via the opensrcer MCP tools
(list_files / read_file / grep / find_definition / find_references).

Review the full exploration log at .dispatches/d_2026-04-15T06-29-56_e4a9bb.log.
@azure-pipelines

Azure Pipelines:
6 pipeline(s) were filtered out due to trigger conditions.
1 pipeline(s) require an authorized user to comment /azp run to run.

SeanThomasWilliams added a commit to SeanThomasWilliams/auto-round that referenced this pull request Apr 15, 2026
Cherry-pick of intel#1685
(still open upstream as of 2026-04-15).

Fixes torch._dynamo recompile storm during AutoRound quantization
when per-layer weight shapes differ. With dynamic=True, dynamo traces
quant_tensor_sym / block_forward symbolically instead of specializing
per shape.

Observed to burn 2h53m with 0 layers quantized on an AWS run with iters=256
(ai_docs/... feedback_autoround_torch_compile_aws.md).

Author retained on original PR. No co-author footer.


Successfully merging this pull request may close these issues.

[Bug]: torch._dynamo hit config.recompile_limit (8) ... tensor 'v' size mismatch at index 0 ...
