Using EMLP with Haiku¶
There are many neural network frameworks for jax, and they are often incompatible. Since most of the functionality of this package is written in pure jax, it can be used with flax, trax, linen, haiku, objax, or whatever your favorite jax NN framework happens to be.
Don't Do This:¶
```python
import haiku as hk
from jax import random
import numpy as np
import emlp.nn as nn
from emlp.reps import T, V
from emlp.groups import SO

repin = 4*V  # Setup some example data representations
repout = V
G = SO(3)
x = np.random.randn(10, repin(G).size())  # generate some random data

model = nn.EMLP(repin, repout, G)
net = hk.without_apply_rng(hk.transform(model))
key = random.PRNGKey(0)
params = net.init(random.PRNGKey(42), x)
y = net.apply(params, x)
```
Although the code executes, Haiku does not recognize the model parameters and treats the network as if it were a stateless jax function.
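One quick way to confirm this is to inspect the params returned by net.init: since no hk.Module is ever created inside the transformed function, Haiku has nothing to register. This is a minimal check, and the exact repr of the empty parameter structure may vary with your Haiku version.

```python
import jax

# The objax-based emlp.nn.EMLP keeps its parameters on the object itself,
# so Haiku's init sees no hk.get_parameter calls and returns an empty mapping.
print(params)                                   # expected: an empty parameter mapping
print(len(jax.tree_util.tree_leaves(params)))   # expected: 0 trainable arrays
```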
It's not hard to build EMLP layers in Haiku, and for each of the nn layers in Layers and Models we have implemented a Haiku version with the same arguments. These layers are accessible via emlp.nn.haiku rather than emlp.nn, so to use EMLP models and equivariant layers with Haiku you should import from emlp.nn.haiku instead of the above.
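The same pattern applies if you want to compose the individual equivariant layers into a custom architecture rather than using the full EMLP model. The following is a minimal sketch, assuming ehk.EMLPBlock and ehk.Linear exist with the same (rep_in, rep_out) arguments as their emlp.nn counterparts, and using a hypothetical hidden representation:

```python
import haiku as hk
import numpy as np
from jax import random
import emlp.nn.haiku as ehk
from emlp.reps import V
from emlp.groups import SO

G = SO(3)
rep_in, rep_out = (4*V)(G), V(G)       # concrete representations, bound to G
rep_mid = (10*V + 10*V**0)(G)          # hypothetical hidden representation

def custom_model(x):
    # Haiku modules must be instantiated inside the transformed function.
    block1 = ehk.EMLPBlock(rep_in, rep_mid)
    block2 = ehk.EMLPBlock(rep_mid, rep_mid)
    head = ehk.Linear(rep_mid, rep_out)
    return head(block2(block1(x)))

net = hk.without_apply_rng(hk.transform(custom_model))
x = np.random.randn(5, rep_in.size())
params = net.init(random.PRNGKey(0), x)
```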
Instead, Do This:¶
```python
import emlp.nn.haiku as ehk

model = ehk.EMLP(repin, repout, SO(3))
net = hk.without_apply_rng(hk.transform(model))
key = random.PRNGKey(0)
params = net.init(random.PRNGKey(42), x)
y = net.apply(params, x)
params.keys()
```

```
KeysOnlyKeysView(['sequential/hk_linear', 'sequential/hk_bi_linear', 'sequential/hk_linear_1', 'sequential/hk_bi_linear_1', 'sequential/hk_linear_2', 'sequential/hk_bi_linear_2', 'sequential/hk_linear_3'])
```
With this Haiku EMLP, the parameters are registered as expected.
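Because the parameters now live in an ordinary pytree, they can be passed straight through jax transformations. Below is a minimal sketch of a single gradient step using plain jax.grad, with a hypothetical regression target y_target; in practice you would likely use an optimizer library such as optax instead of the hand-written SGD update.

```python
import jax
import jax.numpy as jnp
import numpy as np

y_target = np.random.randn(10, repout(G).size())   # hypothetical regression targets

def loss_fn(params, x, y):
    # Mean squared error between the equivariant model's output and the targets.
    pred = net.apply(params, x)
    return jnp.mean((pred - y) ** 2)

@jax.jit
def sgd_step(params, x, y, lr=1e-3):
    grads = jax.grad(loss_fn)(params, x, y)
    # params is a registered pytree, so a plain tree_map update works.
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

params = sgd_step(params, x, y_target)
```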
If your favorite deep learning framework is not one of objax, haiku, or pytorch, don't panic. It's possible to use EMLP with other jax frameworks without much trouble, following the pattern of the objax and haiku implementations. If you need help with this, open a pull request and we can send over some pointers.