With jax-ml/jax#36442, JAX now implements a batched version of orgqr, so Q can also be computed efficiently on GPUs. Since the matrices we decompose are all relatively small, the standard implementation should be no slower than our custom version. We should therefore remove our custom implementation and go back to the standard JAX implementation.