Would using lists rather than jax.numpy arrays lead to more accurate numerical transformations? #28964
-
Hi, I am working on a project with RNNs using JAX and Flax, and I have noticed some behavior that I do not really understand. My code is essentially an optimization loop in which the user provides the initial parameters for the system they want to optimize. The system is divided into several time steps: the initial input is fed into the first time step, which produces an output; that output is fed into an RNN, which returns the parameters for the next time step, and so on. The whole thing is then optimized with Adam (specifically, optax's implementation).

The user passes in the initial parameters as a dict, and then there is a function called …

My question/observation: when I make this function return a list of jnp.arrays instead of a list of lists, the property I am optimizing gets roughly an order of magnitude worse in terms of its distance from 1. For example, with a list of lists the result is 0.9997, while with a list of jnp.arrays it is 0.998 (the closer to 1, the better). Note that the RNN outputs a list of jnp.arrays (it uses Flax Linen), and everything else in the code is identical. Here are the two versions of the function. Outputting a list of lists:
Using a list of jnp.arrays:
And this is an example of the user's initial params input:
The rest of the code is exactly the same for both versions. After optimization, with five time steps for example, the final optimized params for each time step look like this when using a list of jnp.arrays:
And when using a list of lists:
Could such a difference be caused by how JAX handles grad, jit, and other transformations for lists compared to jnp.arrays, or am I missing something?
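One way to probe the narrower question of whether JAX's `grad`/`jit` treat the two containers differently is a small toy experiment. The sketch below is not the poster's code and not the Stack Overflow answer: the parameter values, the `loss` function, and the hand-written `adam_step` (standing in for `optax.adam`, under the assumption that it likewise updates every element independently) are all hypothetical. It shows that a list of lists of Python floats and a list of `jnp.array`s are different pytrees (one scalar leaf per number versus one array leaf per time step), yet an elementwise Adam update applied to the same starting values ends at the same parameters.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameters: 3 time steps x 2 parameters each (made-up values).
as_lists = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
as_arrays = [jnp.array(step) for step in as_lists]


def flatten(tree):
    """Concatenate every leaf of a pytree into one 1-D array for comparison."""
    return jnp.concatenate(
        [jnp.ravel(jnp.asarray(leaf)) for leaf in jax.tree_util.tree_leaves(tree)]
    )


def loss(params):
    """Hypothetical loss: sum of squares of all parameters, whatever the container."""
    return sum(jnp.sum(jnp.square(jnp.asarray(p))) for p in jax.tree_util.tree_leaves(params))


# jit + grad work on any pytree; the gradient has the same structure as the input.
grad_fn = jax.jit(jax.grad(loss))


def adam_step(params, m, v, t, lr=1e-2, b1=0.9, b2=0.999, eps=1e-8):
    """One hand-written Adam step (hypothetical stand-in for optax.adam).

    Every operation is applied leaf by leaf and elementwise, so how the
    numbers are grouped into leaves should not change the result.
    """
    tm = jax.tree_util.tree_map
    g = grad_fn(params)
    m = tm(lambda m_, g_: b1 * m_ + (1 - b1) * g_, m, g)
    v = tm(lambda v_, g_: b2 * v_ + (1 - b2) * g_ ** 2, v, g)
    m_hat = tm(lambda m_: m_ / (1 - b1 ** t), m)
    v_hat = tm(lambda v_: v_ / (1 - b2 ** t), v)
    params = tm(lambda p, mh, vh: p - lr * mh / (jnp.sqrt(vh) + eps), params, m_hat, v_hat)
    return params, m, v


def optimize(params, steps=100):
    """Run `steps` Adam updates starting from `params`."""
    m = jax.tree_util.tree_map(jnp.zeros_like, params)
    v = jax.tree_util.tree_map(jnp.zeros_like, params)
    for t in range(1, steps + 1):
        params, m, v = adam_step(params, m, v, t)
    return params


# Same numbers, different pytree structure.
print(len(jax.tree_util.tree_leaves(as_lists)), len(jax.tree_util.tree_leaves(as_arrays)))  # prints: 6 3

final_lists = optimize(as_lists)
final_arrays = optimize(as_arrays)
# Largest elementwise gap between the two optimized results.
print(jnp.max(jnp.abs(flatten(final_lists) - flatten(final_arrays))))
```

If the same comparison also agrees when run with the real loss and optimizer chain, the container type itself is probably not the cause, and it may be worth checking what else differs between the two versions of the function (for example, how the values are converted or combined before they reach the loss).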
-
For anyone with a similar question or problem, see this Stack Overflow question; I asked the same thing there and got an answer.