
Using EMLP with Haiku

There are many neural network frameworks for jax, and they are often incompatible with one another. Since most of the functionality of this package is written in pure jax, it can be used with flax (including its linen API), trax, haiku, objax, or whichever jax NN framework you prefer.

However, the equivariant neural network layers provided in the Layers and Models section (emlp.nn) are written for objax. If we try to use them with the popular Haiku framework, things will not work as expected.

Don’t Do This:

[1]:
import haiku as hk
from jax import random
import numpy as np
import emlp.nn as nn
from emlp.reps import T,V
from emlp.groups import SO

repin= 4*V # Setup some example data representations
repout = V
G = SO(3)

x = np.random.randn(10,repin(G).size()) # generate some random data
[2]:
model = nn.EMLP(repin,repout,G)  # objax-based EMLP from emlp.nn
net = hk.without_apply_rng(hk.transform(model))

params = net.init(random.PRNGKey(42), x)

y = net.apply(params,  x)

Although the code executes, Haiku does not recognize the model parameters. The objax weights are never declared to Haiku, so Haiku treats the network as if it were a stateless jax function and init returns an empty parameter set.

[3]:
params
[3]:
FlatMapping({})

It’s not hard to build EMLP layers in Haiku, and for each of the nn layers in Layers and Models we have implemented a Haiku version with the same arguments. These layers live in emlp.nn.haiku rather than emlp.nn, so to use EMLP models and equivariant layers with Haiku, import them from emlp.nn.haiku instead.

Instead, Do This:

[4]:
import emlp.nn.haiku as ehk

model = ehk.EMLP(repin,repout,SO(3))  # Haiku-based EMLP from emlp.nn.haiku
net = hk.without_apply_rng(hk.transform(model))

params = net.init(random.PRNGKey(42), x)
y = net.apply(params,  x)
[5]:
params.keys()
[5]:
KeysOnlyKeysView(['sequential/hk_linear', 'sequential/hk_bi_linear', 'sequential/hk_linear_1', 'sequential/hk_bi_linear_1', 'sequential/hk_linear_2', 'sequential/hk_bi_linear_2', 'sequential/hk_linear_3'])

With this Haiku EMLP, parameters are registered as expected.

If your favorite deep learning framework is not objax, haiku, or pytorch, don’t panic. EMLP can be used with other jax frameworks without much trouble by following the same pattern as the objax and haiku implementations. If you need help with this, start a pull request and we can send over some pointers.