I'm still not totally sure what the issue is. JAX uses program transformations to compile programs for a variety of hardware, for example using XLA for TPUs. It also runs CUDA ops on NVIDIA GPUs without issue: https://jax.readthedocs.io/en/latest/installation.html
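For anyone unfamiliar, a minimal sketch of what that looks like in practice (jax.jit and jax.default_backend are the real APIs; the function and shapes here are just made up):

    import jax
    import jax.numpy as jnp

    @jax.jit  # trace the Python function and hand it to XLA to compile
    def predict(w, x):
        return jnp.tanh(x @ w)

    w = jnp.ones((128, 128))
    x = jnp.ones((8, 128))

    # The same code runs unchanged on CPU, GPU, or TPU; XLA compiles for
    # whichever backend the current install provides.
    print(predict(w, x).shape)    # (8, 128)
    print(jax.default_backend())  # e.g. 'cpu', 'gpu', or 'tpu'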
I haven't worked with float4, but I can imagine that new numerical types would require some special handling. I assume that's the case for any ML framework, though.
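For what it's worth, the narrow low-precision types that JAX already has come in through the ml_dtypes package and behave like ordinary dtypes once supported; a quick sketch (whether a float4 variant is available will depend on your JAX/ml_dtypes version):

    import jax.numpy as jnp

    # bfloat16 has been in JAX for a long time; float8 variants such as
    # jnp.float8_e4m3fn arrived more recently via the ml_dtypes package.
    x = jnp.ones((4, 4), dtype=jnp.bfloat16)
    print(x.dtype)  # bfloat16

    # Casting works like any other astype call once the dtype exists.
    y = x.astype(jnp.float8_e4m3fn)
    print(y.dtype)  # float8_e4m3fn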
These are completely orthogonal? JAX can execute on GPU, TPU, or CPU pretty seamlessly, by design.
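E.g., device discovery and placement are just ordinary API calls; a small sketch (the printed devices will vary with your install):

    import jax
    import jax.numpy as jnp

    # Enumerate whatever accelerators this install can see.
    print(jax.devices())  # e.g. [CudaDevice(id=0)] or [CpuDevice(id=0)]

    # Arrays land on the default device, but you can place them explicitly.
    x = jax.device_put(jnp.arange(8), jax.devices()[0])
    print(x.devices())  # the set of devices holding this array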