I stumbled into an edge case when trying to apply `torch.vmap` to some code I had rewritten to use array-api-compat. So far everything seems to work just fine, except for `clip`. Here's a minimal example:
```python
import array_api_compat.torch as xp
import torch

def apply_clip(a):
    return torch.clip(a, min=0, max=30)

def apply_clip_compat(a):
    return xp.clip(a, min=0, max=30)

a = xp.asarray([[5.1, 2.0, 64.1, -1.5]])
print(apply_clip(a))
print(apply_clip_compat(a))

v1 = torch.vmap(apply_clip)
print(v1(a))

v2 = xp.vmap(apply_clip_compat)
print(v2(a))
```
Running it raises the following error:
```
[user@domain ~]$ python test_clip.py
tensor([[ 5.1000, 2.0000, 30.0000, 0.0000]])
tensor([[ 5.1000, 2.0000, 30.0000, 0.0000]])
tensor([[ 5.1000, 2.0000, 30.0000, 0.0000]])
Traceback (most recent call last):
  File "test_clip.py", line 22, in <module>
    print(v2(a))
          ~~^^^
  File ".venv/lib/python3.13/site-packages/torch/_functorch/apis.py", line 202, in wrapped
    return vmap_impl(
        func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs
    )
  File ".venv/lib/python3.13/site-packages/torch/_functorch/vmap.py", line 334, in vmap_impl
    return _flat_vmap(
        func,
    ...<6 lines>...
        **kwargs,
    )
  File ".venv/lib/python3.13/site-packages/torch/_functorch/vmap.py", line 484, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "test_clip.py", line 10, in apply_clip_compat
    return xp.clip(a, min=0, max=30)
           ~~~~~~~^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.13/site-packages/array_api_compat/_internal.py", line 35, in wrapped_f
    return f(*args, xp=xp, **kwargs)
  File ".venv/lib/python3.13/site-packages/array_api_compat/common/_aliases.py", line 424, in clip
    out[()] = x
    ~~~^^^^
RuntimeError: vmap: inplace arithmetic(self, *extra_args) is not possible because there exists a Tensor `other` in extra_args that has more elements than `self`. This happened due to `other` being vmapped over but `self` not being vmapped over in a vmap. Please try to use out-of-place operators instead of inplace arithmetic. If said operator is being called inside the PyTorch framework, please file a bug report instead.
```
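The traceback points at `out[()] = x` inside the compat `clip`, and the error says `self` (the `out` buffer) is not vmapped while `x` is. The snippet below is a minimal sketch in plain PyTorch, not the library's actual code: it assumes `out` is a freshly allocated, unbatched tensor, reproduces that in-place pattern, and then shows an out-of-place alternative built from `torch.where` that `torch.vmap` can batch. `clip_inplace` and `clip_out_of_place` are hypothetical helpers introduced only for illustration.

```python
import torch


def clip_inplace(x, min=None, max=None):
    # Hypothetical helper mimicking the pattern from the traceback:
    # allocate an unbatched buffer, then copy the (batched) input into it.
    out = torch.empty(x.shape, dtype=x.dtype, device=x.device)
    out[()] = x  # under torch.vmap, `out` is not batched but `x` is
    return torch.clamp(out, min=min, max=max)


def clip_out_of_place(x, min=None, max=None):
    # Hypothetical helper: same clipping built only from out-of-place ops,
    # which vmap can batch. Bounds are cast to x.dtype so the output keeps it.
    out = x
    if min is not None:
        lo = torch.as_tensor(min, dtype=x.dtype, device=x.device)
        out = torch.where(out < lo, lo, out)
    if max is not None:
        hi = torch.as_tensor(max, dtype=x.dtype, device=x.device)
        out = torch.where(out > hi, hi, out)
    return out


a = torch.tensor([[5.1, 2.0, 64.1, -1.5]])

# The out-of-place version works under vmap and matches torch.clip.
batched = torch.vmap(lambda t: clip_out_of_place(t, min=0, max=30))(a)
print(batched)
print(torch.equal(batched, torch.clip(a, min=0, max=30)))

# The in-place version hits the same class of vmap error as the compat clip.
inplace_failed = False
try:
    torch.vmap(lambda t: clip_inplace(t, min=0, max=30))(a)
except RuntimeError as err:
    inplace_failed = True
    print("in-place version fails:", err)
```

This is only a sketch: it does not handle bounds that are not representable in `x.dtype`, which a full array-API `clip` has to care about.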
I totally understand if full support for `torch.vmap` is out of scope, but figured it might be worth raising the issue in case there's something that needs fixing.