
> CUDA knowledge doesn’t port to JAX and vice-versa.

These are completely orthogonal? JAX can execute on GPU, TPU, or CPU pretty seamlessly, by design.
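
For what it's worth, here's a minimal sketch of what I mean by seamless (assuming only that JAX is installed for whichever backend you happen to have):

  import jax
  import jax.numpy as jnp

  # Whatever accelerator JAX was installed for shows up here;
  # the code below is identical on CPU, GPU, or TPU.
  print(jax.devices())

  @jax.jit  # compiled by XLA for the available backend
  def mse(w, x, y):
      return jnp.mean((x @ w - y) ** 2)

  grad_mse = jax.grad(mse)  # transformations are backend-agnostic too

  w = jnp.zeros((3,))
  x = jnp.ones((8, 3))
  y = jnp.ones((8,))
  print(mse(w, x, y), grad_mse(w, x, y))

Nothing in there is CUDA-specific; the backend choice happens at install/runtime, not in the model code.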



And then you give up specialized CUDA kernels, which are necessary to run Mistral 7B in 4-bit float mode.

There were some experimental CUDA kernels for JAX codebases. They didn’t work very well at the time, but maybe they’re better now.


I'm still not totally sure what the issue is. JAX uses program transformations to compile programs for a variety of hardware, for example using XLA to target TPUs. It can also run CUDA ops on Nvidia GPUs without issue: https://jax.readthedocs.io/en/latest/installation.html
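
To make the "program transformations" point concrete (just a toy sketch, nothing Mistral-specific): the same Python function can be inspected both as a backend-independent jaxpr and as the StableHLO that XLA then compiles for whatever device is present.

  import jax
  import jax.numpy as jnp

  def f(x):
      return jnp.tanh(x @ x.T)

  x = jnp.ones((4, 4))
  print(jax.make_jaxpr(f)(x))           # backend-independent trace (jaxpr)
  print(jax.jit(f).lower(x).as_text())  # StableHLO handed off to XLA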

There is also support for custom C++ and CUDA ops if that's what is needed: https://jax.readthedocs.io/en/latest/Custom_Operation_for_GP...
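
The linked guide walks through the full recipe; very roughly it looks like the sketch below. To be clear, `my_cuda_ext` and `rms_norm_capsule` are hypothetical stand-ins for a pybind11 extension wrapping your compiled kernel, and the MLIR lowering boilerplate is elided.

  import jax
  from jax.lib import xla_client

  import my_cuda_ext  # hypothetical pybind11 module exposing the compiled kernel

  # Make the kernel visible to XLA as a GPU custom-call target.
  xla_client.register_custom_call_target(
      "fused_rms_norm", my_cuda_ext.rms_norm_capsule(), platform="gpu")

  # Wrap it in a JAX primitive so it composes with jit/grad/vmap.
  rms_norm_p = jax.core.Primitive("fused_rms_norm")

  @rms_norm_p.def_abstract_eval
  def _abstract_eval(x):
      # Tell the tracer the output shape/dtype without running the kernel.
      return jax.core.ShapedArray(x.shape, x.dtype)

  def rms_norm(x):
      return rms_norm_p.bind(x)

  # Not shown: registering an MLIR lowering rule that emits the custom
  # call on the GPU backend; see the linked doc for that boilerplate.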

I haven't worked with float4, but I can imagine that new numerical types would require some special handling. But I assume that's the case for any ML environment.

But really you probably mean fixed-point 4-bit integer types? Looks like that has had at least some work done in JAX: https://github.com/google/jax/issues/8566
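
Even without a dedicated int4 dtype, you can fake the storage side at the array level. A minimal sketch (my own, not from that issue) that packs two 4-bit integers per uint8:

  import jax.numpy as jnp

  def pack_int4(w):
      # w holds ints in [-8, 7]; last dim must be even.
      u = (w + 8).astype(jnp.uint8)               # shift into [0, 15]
      return (u[..., ::2] << 4) | u[..., 1::2]    # high nibble | low nibble

  def unpack_int4(packed):
      # Inverse of pack_int4.
      hi = (packed >> 4).astype(jnp.int32) - 8
      lo = (packed & 0xF).astype(jnp.int32) - 8
      out = jnp.stack([hi, lo], axis=-1)          # interleave pairs back
      return out.reshape(*packed.shape[:-1], -1)

  w = jnp.array([[-8, 7, 0, 3]])
  assert (unpack_int4(pack_int4(w)) == w).all()

A real quantized-inference setup would also carry per-group scales and want a fused dequantize-matmul kernel rather than unpacking in plain jnp ops, which is exactly where the specialized-CUDA-kernel point upthread comes in.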



