Has anybody investigated the possibility of an array-agnostic way to leverage the `torch.compile` and `jax.jit` decorators in `array-api-extra`?
This might be useful for array API-consuming libraries such as SciPy or scikit-learn. For array API namespaces without JIT compiler support, `xpx.compile` would just be a no-op decorator. For torch and JAX, dispatching to an actual JIT compiler could unlock significant speed-ups and memory usage improvements.
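However, the two decorators accept many keyword arguments with seemingly very little overlap (e.g. `fullgraph`, `mode` and `options` for `torch.compile` versus `static_argnums`, `static_argnames` and `donate_argnums` for `jax.jit`).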
Maybe `xpx.compile` could accept arbitrary kwargs scoped by the name of the underlying namespace, without attempting to map common compiler semantics onto each other:
```python
@xpx.compile(
    torch=dict(options={"triton.cudagraphs": True}, fullgraph=True),
    jax=dict(static_argnames=["n"]),
)
def some_array_function(array, n):
    ...
```
I don't have enough experience to tell whether calling those decorators with their default arguments is useful in practice.