Smooth Min-Max Monotonic Networks

Monotonicity constraints are powerful regularizers in statistical modelling. They can support fairness in computer-aided decision making and increase plausibility in data-driven scientific models. The seminal min-max (MM, Sill, 1997) neural network architecture ensures monotonicity, but often gets stuck in undesired local optima during training because the partial derivatives are zero when computing extrema. A simple modification of the MM network using strictly increasing smooth minimum and maximum functions alleviates this problem. The resulting smooth min-max (SMM, Igel, 2024) network module inherits the asymptotic approximation properties from the MM architecture. It can be used within larger deep learning systems trained end-to-end. The SMM module is conceptually simple and computationally less demanding than state-of-the-art neural networks for monotonic modelling. Experiments show that this does not come with a loss in generalization performance compared to alternative neural and non-neural approaches.

[Figure: min-max architecture]
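To make the zero-derivative issue concrete, here is a minimal sketch in plain Python. It assumes the smooth maximum takes the LogSumExp form with a scaling parameter beta, which is one common choice; the function names and the beta values are illustrative and are not taken from the repository. The ordinary maximum passes a gradient only to the largest argument, while the smooth version gives every argument a strictly positive partial derivative.

```python
import math

def smooth_max(values, beta=1.0):
    """LogSumExp smooth maximum: (1/beta) * log(sum(exp(beta * v)))."""
    m = max(values)  # shift by the maximum for numerical stability
    return m + math.log(sum(math.exp(beta * (v - m)) for v in values)) / beta

def smooth_min(values, beta=1.0):
    """Smooth minimum, defined through smooth_min(v) = -smooth_max(-v)."""
    return -smooth_max([-v for v in values], beta)

def smooth_max_grad(values, beta=1.0):
    """Partial derivatives of smooth_max: softmax weights, all strictly positive."""
    m = max(values)
    e = [math.exp(beta * (v - m)) for v in values]
    total = sum(e)
    return [ei / total for ei in e]

def hard_max_grad(values):
    """Partial derivatives of the ordinary max: 1 at the first maximiser, 0 elsewhere."""
    i = values.index(max(values))
    return [1.0 if j == i else 0.0 for j in range(len(values))]

values = [0.5, 2.0, 1.5]
print("hard max gradient:  ", hard_max_grad(values))
print("smooth max gradient:", smooth_max_grad(values, beta=1.0))
print("smooth max / min:   ", smooth_max(values), smooth_min(values))
```

With a larger beta the smooth functions approach the ordinary maximum and minimum, so beta trades off smoothness against closeness to the hard extrema.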

Code

The directory ICML2014 Supplement contains the PyTorch code for reproducing the experiments presented in the ICML 2024 paper introducing SMM networks [Igel, 2024]. The code is not very clean (e.g., containing several slightly different ways to implement the SMM).

A little bit cleaner (but not very efficient) are the implementations in:

  • SmoothMonotonicNN.py: Implementation of SMM module, restricted to non-decreasing constraints and scalar output
  • SMM_MLP.py: Very simple example of how the SMM module can be combined with other layers
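For orientation, the following is a rough, dependency-free sketch of what a scalar-output smooth min-max forward pass can look like. It is not the code in SmoothMonotonicNN.py: the function names, the use of exp() to keep weights positive, the fixed beta, and the choice to treat every input as non-decreasing are all assumptions made for illustration.

```python
import math

def lse(values, beta):
    """Stable LogSumExp; beta > 0 gives a smooth max, beta < 0 a smooth min."""
    scaled = [beta * v for v in values]
    m = max(scaled)
    return (m + math.log(sum(math.exp(s - m) for s in scaled))) / beta

def smm_forward(x, raw_weights, biases, beta=1.0):
    """Hypothetical smooth min-max forward pass with scalar output.

    raw_weights[k][j] is the unconstrained weight vector of unit j in group k.
    exp() maps it to strictly positive weights, so every unit increases in x.
    """
    group_outputs = []
    for group_w, group_b in zip(raw_weights, biases):
        units = [
            sum(math.exp(w_i) * x_i for w_i, x_i in zip(w, x)) + b
            for w, b in zip(group_w, group_b)
        ]
        group_outputs.append(lse(units, beta))  # smooth max within a group
    return lse(group_outputs, -beta)            # smooth min across groups

# Two groups with two units each, one input dimension (arbitrary parameters).
raw_weights = [[[0.0], [-1.0]], [[0.5], [-0.5]]]
biases = [[0.0, 1.0], [-1.0, 0.5]]
xs = [-2.0, -1.0, 0.0, 1.0, 2.0]
ys = [smm_forward([x], raw_weights, biases) for x in xs]
print(ys)
```

Because the positive-weight units and both smooth extrema are increasing in each argument, the output increases with x for any parameter values, which is what makes the constraint hold by construction rather than through a penalty term.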

Paper

You can read about the approach here:

Christian Igel. Smooth Min-Max Monotonic Networks. International Conference on Machine Learning, 2024

Christian Igel. Smooth Min-Max Monotonic Networks. arXiv:2306.01147v3 [cs.LG], 2024

Example

Toy example using SmoothMonotonicNN.py:

# Set the stage
import numpy as np
import torch 
import matplotlib.pyplot as plt
from SmoothMonotonicNN import SmoothMonotonicNN

# Toy example: Sigmoid with additive Gaussian noise
N = 20  # number of data points
x = np.linspace(-5, 5, N)
y = 1. / (np.exp(-x) + 1.) + np.random.normal(0, 0.1, N)

# Create model
mask = np.array([1])  # monotonic in the first (and only) argument
K = 6  # neurons per group
model = SmoothMonotonicNN(1, K, K, mask)  # define SMM

# Optimize model
optimizer = torch.optim.Rprop(model.parameters(), lr=0.01, etas=(0.5, 1.2), step_sizes=(1e-06, 50))
loss_function = torch.nn.MSELoss()

x_tensor = torch.from_numpy(x.reshape(N,-1).astype(np.float32))
y_tensor = torch.from_numpy(y.astype(np.float32))

max_iterations = 200  # Number of epochs
for epoch in range(max_iterations):
    y_hat = model(x_tensor)
    loss = loss_function(y_hat, y_tensor)
    loss.backward()
    optimizer.step()
    model.zero_grad()

# Evaluate on training set
y_hat = model(x_tensor).detach().numpy()

# Visualize result
plt.plot(x, y, 'o', label = "data")
plt.plot(x, y_hat, label = "SMM")
plt.grid()
plt.xlabel("x")
plt.ylabel("y")
plt.legend()  # display the "data" and "SMM" labels set above
plt.show()
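As an optional sanity check, the fitted curve can be tested for monotonicity. The helper below is a hypothetical addition, not part of the repository; it assumes the predictions are ordered by increasing x, which holds in the toy example because x comes from np.linspace.

```python
def is_non_decreasing(values, tol=0.0):
    """Return True if each value is >= its predecessor, up to tolerance tol."""
    return all(b >= a - tol for a, b in zip(values, values[1:]))

# Stand-in predictions; in the toy example use y_hat.ravel().tolist() instead.
print(is_non_decreasing([0.1, 0.1, 0.4, 0.9]))  # monotone sequence
print(is_non_decreasing([0.1, 0.5, 0.4, 0.9]))  # contains a decrease
```

A small positive tol can absorb floating-point noise when checking float32 model outputs.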
