Easy hypernetworks in Pytorch and Flax

hyper-nn -- Easy Hypernetworks in Pytorch and Flax

Note: This library is experimental and currently under development; the Flax implementations in particular are far from perfect and can be improved. If you have any suggestions on how to improve this library, please open a GitHub issue or feel free to reach out directly!

hyper-nn gives users the ability to create easily customizable Hypernetworks for almost any generic torch.nn.Module from PyTorch or flax.linen.Module from Flax. Our Hypernetwork objects are themselves torch.nn.Modules and flax.linen.Modules, allowing for easy integration with existing systems.

Generating Policy Weights for Lunar Lander



Dynamic Weights for each character in a name generator


Install

hyper-nn is tested on Python 3.8+.

Installing with pip

$ pip install hyper-nn

Installing from source

$ git clone git@github.com:shyamsn97/hyper-nn.git
$ cd hyper-nn
$ python setup.py install

For GPU functionality with JAX, you will need to follow the official JAX installation instructions.

What are Hypernetworks?

Hypernetworks, simply put, are neural networks that generate parameters for another neural network. They can be very powerful: a hypernetwork can represent a large target network while using only a fraction of that network's parameter count.

hyper-nn represents Hypernetworks with two key components:

  • EmbeddingModule, which holds information about layer(s) in the target network, or more generally a chunk of the target network's weights
  • WeightGenerator, which takes in the embedding and outputs a parameter vector for the target network
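To make the parameter-savings claim concrete, the sketch below counts parameters for the Quick Usage target network under hypothetical assumptions: the EmbeddingModule is a plain nn.Embedding, and the WeightGenerator is a single nn.Linear from embedding_dim to a chunk size obtained by rounding num_target_parameters / num_embeddings up. hyper-nn's default modules may be sized differently, and the count helper is only for illustration.

```python
import torch.nn as nn

# Same target network as in the Quick Usage section
target_network = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 32))
NUM_EMBEDDINGS, EMBEDDING_DIM = 32, 4


def count(module: nn.Module) -> int:
    """Hypothetical helper: total number of parameters in a module."""
    return sum(p.numel() for p in module.parameters())


num_target = count(target_network)  # (32*64 + 64) + (64*32 + 32) = 4192
weight_chunk_dim = -(-num_target // NUM_EMBEDDINGS)  # ceiling division -> 131

# Assumed minimal versions of the two components
embedding_module = nn.Embedding(NUM_EMBEDDINGS, EMBEDDING_DIM)  # 32 * 4 = 128 params
weight_generator = nn.Linear(EMBEDDING_DIM, weight_chunk_dim)  # 4 * 131 + 131 = 655 params

num_hyper = count(embedding_module) + count(weight_generator)
print(num_target, num_hyper)  # 4192 783
```

Under these assumptions the hypernetwork trains 783 parameters to produce all 4192 parameters of the target.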

Hypernetworks generally come in two variants: static and dynamic. Static Hypernetworks have a fixed or learned embedding and a weight generator that output the target network's weights deterministically, independent of the input. Dynamic Hypernetworks instead receive inputs and use them to generate weights that change with each input.
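The difference can be sketched in plain PyTorch for a target that is a single 4-to-2 linear layer (10 parameters). StaticHyperNet and DynamicHyperNet are hypothetical illustrations, not hyper-nn classes: the static version produces one parameter vector from a learned embedding, while the dynamic version feeds each input to the generator and therefore uses a different weight matrix per example.

```python
import torch
import torch.nn as nn

IN_DIM, OUT_DIM = 4, 2
NUM_TARGET_PARAMS = IN_DIM * OUT_DIM + OUT_DIM  # weight + bias of one linear layer


class StaticHyperNet(nn.Module):
    """Generates one set of target weights from a learned embedding."""

    def __init__(self, embedding_dim: int = 8):
        super().__init__()
        self.embedding = nn.Parameter(torch.randn(embedding_dim))
        self.generator = nn.Linear(embedding_dim, NUM_TARGET_PARAMS)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        params = self.generator(self.embedding)  # same params for every input
        weight = params[: IN_DIM * OUT_DIM].reshape(OUT_DIM, IN_DIM)
        bias = params[IN_DIM * OUT_DIM :]
        return x @ weight.T + bias


class DynamicHyperNet(nn.Module):
    """Generates a separate set of target weights for each input example."""

    def __init__(self):
        super().__init__()
        self.generator = nn.Linear(IN_DIM, NUM_TARGET_PARAMS)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        params = self.generator(x)  # shape (batch, NUM_TARGET_PARAMS)
        weight = params[:, : IN_DIM * OUT_DIM].reshape(-1, OUT_DIM, IN_DIM)
        bias = params[:, IN_DIM * OUT_DIM :]
        # batched matrix-vector product: one weight matrix per example
        return torch.einsum("boi,bi->bo", weight, x) + bias
```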


Quick Usage

For detailed examples, see the notebooks.

The main classes to use are TorchHyperNetwork and JaxHyperNetwork, along with their subclasses. Instead of constructing them directly, use the from_target method shown below. After that, you can use the hypernetwork like any other torch.nn.Module or flax.linen.Module!

Pytorch

import torch
import torch.nn as nn

# any module
target_network = nn.Sequential(
    nn.Linear(32, 64),
    nn.ReLU(),
    nn.Linear(64, 32)
)

# static hypernetwork
from hypernn.torch.hypernet import TorchHyperNetwork

EMBEDDING_DIM = 4
NUM_EMBEDDINGS = 32

hypernetwork = TorchHyperNetwork.from_target(
    target_network = target_network,
    embedding_dim = EMBEDDING_DIM,
    num_embeddings = NUM_EMBEDDINGS
)

# now we can use the hypernetwork like any other nn.Module
inp = torch.zeros((1, 32))

# by default we only output what we'd expect from the target network
output = hypernetwork(inp=[inp])

# also return the generated params and auxiliary output
output, generated_params, aux_output = hypernetwork(inp=[inp], has_aux=True)

# generate params separately
generated_params, aux_output = hypernetwork.generate_params(inp=[inp])
output = hypernetwork(inp=[inp], generated_params=generated_params)

Jax

import flax.linen as nn
import jax.numpy as jnp
from jax import random

# any module
target_network = nn.Sequential(
    [
        nn.Dense(64),
        nn.relu,
        nn.Dense(32)
    ]
)

# static hypernetwork
from hypernn.jax.hypernet import JaxHyperNetwork

EMBEDDING_DIM = 4
NUM_EMBEDDINGS = 32

hypernetwork = JaxHyperNetwork.from_target(
    target_network = target_network,
    embedding_dim = EMBEDDING_DIM,
    num_embeddings = NUM_EMBEDDINGS,
    inputs=jnp.zeros((1, 32)) # jax needs this to initialize target weights
)

# now we can use the hypernetwork like any other nn.Module
inp = jnp.zeros((1, 32))
key = random.PRNGKey(0)
hypernetwork_params = hypernetwork.init(key, inp=[inp]) # flax needs to initialize hypernetwork parameters first

# by default we only output what we'd expect from the target network
output = hypernetwork.apply(hypernetwork_params, inp=[inp])

# also return the generated params and auxiliary output
output, generated_params, aux_output = hypernetwork.apply(hypernetwork_params, inp=[inp], has_aux=True)

# generate params separately
generated_params, aux_output = hypernetwork.apply(hypernetwork_params, inp=[inp], method=hypernetwork.generate_params)

output = hypernetwork.apply(hypernetwork_params, inp=[inp], generated_params=generated_params)

Detailed Explanation

EmbeddingModule

The EmbeddingModule stores information about layer(s) in the target network, or more generally a chunk of the target network's weights. The standard representation is a matrix of size num_embeddings x embedding_dim. hyper-nn uses torch's nn.Embedding and flax's nn.Embed classes to represent this.
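A minimal sketch of that matrix using torch's nn.Embedding, assuming the full matrix is obtained by looking up every index (reading embedding_module.weight directly gives the same values). The constants mirror the Quick Usage example.

```python
import torch
import torch.nn as nn

NUM_EMBEDDINGS = 32
EMBEDDING_DIM = 4

# One learnable row per chunk of the target network's weights
embedding_module = nn.Embedding(NUM_EMBEDDINGS, EMBEDDING_DIM)

# Looking up every index returns the full num_embeddings x embedding_dim matrix
indices = torch.arange(NUM_EMBEDDINGS)
embedding_matrix = embedding_module(indices)
print(embedding_matrix.shape)  # torch.Size([32, 4])
```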

WeightGenerator

The WeightGenerator takes in the embedding matrix from the EmbeddingModule and outputs a parameter vector of size num_target_parameters, the total number of parameters in the target network. To produce a vector of this size, the WeightGenerator outputs a matrix of size num_embeddings x weight_chunk_dim, where weight_chunk_dim = num_target_parameters // num_embeddings, and then flattens it. This yields exactly num_target_parameters values only when num_target_parameters is divisible by num_embeddings; otherwise floor division gives too few values, and the chunk size must be rounded up and the flattened vector truncated.
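With floor division, num_embeddings x weight_chunk_dim falls short of num_target_parameters whenever the division is not exact. The sketch below assumes one common workaround (not necessarily what hyper-nn does internally): round the chunk size up, apply a single nn.Linear as a hypothetical WeightGenerator to every embedding row, then truncate the flattened output. The small target and constants are chosen only so the division is not exact.

```python
import torch
import torch.nn as nn

target_network = nn.Linear(3, 5)  # 3*5 weights + 5 biases = 20 parameters
NUM_EMBEDDINGS, EMBEDDING_DIM = 6, 4  # 20 is not divisible by 6

num_target_parameters = sum(p.numel() for p in target_network.parameters())
floor_chunk = num_target_parameters // NUM_EMBEDDINGS  # 3 -> only 18 outputs, too few
weight_chunk_dim = -(-num_target_parameters // NUM_EMBEDDINGS)  # ceil -> 4 -> 24 outputs

embedding_matrix = torch.randn(NUM_EMBEDDINGS, EMBEDDING_DIM)
weight_generator = nn.Linear(EMBEDDING_DIM, weight_chunk_dim)

chunks = weight_generator(embedding_matrix)  # shape (6, 4)
generated_params = chunks.flatten()[:num_target_parameters]  # drop the 4 surplus values
print(floor_chunk * NUM_EMBEDDINGS, generated_params.shape)  # 18 torch.Size([20])
```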

Hypernetwork

The Hypernetwork by default uses a setup function to initialize the embedding_module and weight_generator, either from user-provided modules or from the functions make_embedding_module and make_weight_generator. This makes it easy to swap in your own modules instead of the basic versions provided. generate_params generates the target parameters, and forward combines the generated parameters with the target network to compute a forward pass.
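One way to implement that forward step in plain PyTorch is sketched below. It assumes PyTorch 2.0+ for torch.func.functional_call; unflatten_params and forward_with_generated are hypothetical helpers, not hyper-nn functions, and the library's own forward may work differently. The flat vector is split into tensors shaped like the target's parameters, and the target is then run with those tensors in place of its own weights, so gradients flow back into the generated vector.

```python
from typing import Dict

import torch
import torch.nn as nn
from torch.func import functional_call

target_network = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 32))


def unflatten_params(flat: torch.Tensor, module: nn.Module) -> Dict[str, torch.Tensor]:
    """Split a flat parameter vector into tensors shaped like module's parameters."""
    params, offset = {}, 0
    for name, p in module.named_parameters():
        n = p.numel()
        params[name] = flat[offset : offset + n].view_as(p)
        offset += n
    return params


def forward_with_generated(module: nn.Module, flat: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # Run the target network with the generated parameters instead of its own
    return functional_call(module, unflatten_params(flat, module), (x,))


num_params = sum(p.numel() for p in target_network.parameters())
# Stand-in for the output of a WeightGenerator
generated_params = torch.randn(num_params, requires_grad=True)
output = forward_with_generated(target_network, generated_params, torch.zeros(1, 32))
print(output.shape)  # torch.Size([1, 32])
```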

Instead of creating the Hypernetwork class directly, use from_target.

Base class:

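from __future__ import annotations

import abc
from typing import Any, Dict, Iterable, Optional, Tuple


class HyperNetwork(metaclass=abc.ABCMeta):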
    embedding_module = None
    weight_generator = None

    def setup(self) -> None:
        if self.embedding_module is None:
            self.embedding_module = self.make_embedding_module()

        if self.weight_generator is None:
            self.weight_generator = self.make_weight_generator()

    @abc.abstractmethod
    def make_embedding_module(self):
        """
        Makes an embedding module to be used

        Returns:
            a torch.nn.Module or flax.linen.Module that can be used to return an embedding matrix to be used to generate weights
        """

    @abc.abstractmethod
    def make_weight_generator(self):
        """
        Makes a weight generator to be used

        Returns:
            a torch.nn.Module or flax.linen.Module that maps the embedding matrix to a generated parameter vector for the target network
        """

    @classmethod
    @abc.abstractmethod
    def count_params(
        cls,
        target,
        target_input_shape: Optional[Any] = None,
    ):
        """
        Counts parameters of target nn.Module

        Args:
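            target (Union[torch.nn.Module, flax.linen.Module]): target network whose parameters are counted
            target_input_shape (Optional[Any], optional): shape of an example input, for targets that need one to be initialized (e.g. Flax modules). Defaults to None.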
        """

    @classmethod
    @abc.abstractmethod
    def from_target(cls, target, *args, **kwargs) -> HyperNetwork:
        """
        Creates a hypernetwork from a target network

        Args:
            target (Union[torch.nn.Module, flax.linen.Module]): target network whose parameters will be generated
        """

    @abc.abstractmethod
    def generate_params(self, inp: Optional[Any] = None, *args, **kwargs) -> Tuple[Any, Dict[str, Any]]:
        """
        Generate a vector of parameters for target network

        Args:
            inp (Optional[Any], optional): input, may be useful when creating dynamic hypernetworks

        Returns:
            Any: vector of parameters for target network and a dictionary of extra info
        """

    @abc.abstractmethod
    def forward(
        self,
        inp: Iterable[Any] = [],
        generated_params=None,
        has_aux: bool = False,
        *args,
        **kwargs,
    ):
        """
        Computes a forward pass with generated parameters or with parameters that are passed in

        Args:
            inp (Any): input from system
            generated_params (Optional[Union[torch.Tensor, jnp.ndarray]], optional): Generated params. Defaults to None.
            has_aux (bool): flag to indicate whether to return auxiliary info
        Returns:
            the target network's output; if has_aux is True, also the generated params and auxiliary info
        """

Citation

If you use this software in your academic work please cite

@misc{sudhakaran2022,
  author = {Sudhakaran, Shyam Sudhakaran},
  title = {hyper-nn},
  year = {2022},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/shyamsn97/hyper-nn}}
}

Download files

Source Distribution

hyper-nn-0.1.0.tar.gz (6.0 kB)

Uploaded Source

Built Distributions

hyper_nn-0.1.0-py3.9.egg (7.5 kB)

Uploaded Source

hyper_nn-0.1.0-py3-none-any.whl (6.7 kB)

Uploaded Python 3
