gradipy: A Lightweight Neural Network Library

gradipy is an evolving project that provides a PyTorch-like API for building and training neural networks. Some features are still under active development and are planned for future releases. My inspiration for this project comes from micrograd and tinygrad. The former works only on scalar values, which keeps it quite simple, while the latter is a more advanced implementation. The library I am building, tentatively named gradipy, falls somewhere in between these two.

While gradipy may not be your top-tier production tool (we'll leave that to the heavyweights), it's more like a trusty sidekick on your learning expedition. Because, let's be honest, understanding the nitty-gritty of neural networks by rolling up your sleeves and implementing from scratch? That's the real superhero origin story. Serious learning, a bit of coding flair, and a sidekick named gradipy—what more could you want?

Contributions to gradipy are welcome. Given the evolving nature of the project, there may be some minor bugs or opportunities to refine design choices. Your input is appreciated in making gradipy even better.

Sample Usage

In this example, we define a simple feedforward neural network and train a classifier on the MNIST data (please refer to the example usage):

import gradipy.nn as nn
from gradipy import datasets
from gradipy.nn import optim

X_train, y_train, X_val, y_val, X_test, y_test = datasets.MNIST()

# define some utility function here...

class DenseNeuralNetwork:
    def __init__(self, input_dim, hidden_dim, output_dim):
        self.W1 = nn.init_kaiming_normal(input_dim, hidden_dim, nonlinearity="relu")
        self.W2 = nn.init_kaiming_normal(hidden_dim, output_dim, nonlinearity="relu")

    def forward(self, X):
        logits = X.matmul(self.W1).relu().matmul(self.W2)
        return logits

model = DenseNeuralNetwork(input_dim, hidden_dim, output_dim)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam([model.W1, model.W2], lr=lr)

for epoch in range(epochs + 1):
    optimizer.zero_grad()
    xb, yb = get_batch()
    logits = model.forward(xb)
    loss = criterion(logits, yb)
    loss.backward()
    optimizer.step()

    # log the results on each epoch...
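The training loop above calls a get_batch() helper that the snippet leaves undefined. As a rough illustration only, here is one way such a helper could look in plain NumPy. The signature, the batch size, and the flattened 784-feature input shape are all assumptions, and the stand-in data is random rather than real MNIST. Wrapping the arrays in gradipy tensors is left out because the tensor constructor is not shown here.

```python
import numpy as np

def get_batch(X, y, batch_size, rng):
    """Hypothetical helper: sample a random mini-batch (with replacement)."""
    idx = rng.integers(0, X.shape[0], size=batch_size)  # random row indices
    return X[idx], y[idx]

# Stand-in data with assumed MNIST-like shapes (random values, not the dataset).
rng = np.random.default_rng(0)
X_fake = rng.standard_normal((1000, 784)).astype(np.float32)
y_fake = rng.integers(0, 10, size=1000)

xb, yb = get_batch(X_fake, y_fake, batch_size=32, rng=rng)
print(xb.shape, yb.shape)
```

Sampling with replacement keeps the helper short. A version that shuffles once per epoch and walks through the data in order would visit every example exactly once.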

To run the example architectures, refer to the models/ directory and run the following command:

python -m examples.resnet "link_to_the_image"

Installation

You can install gradipy using pip. It is recommended to create a virtual environment before installing the library.

# Create and activate a virtual environment (optional but recommended)
# Replace 'myenv' with your desired environment name
python -m venv myenv
source myenv/bin/activate  # On Windows, use 'myenv\Scripts\activate'

# Install gradipy using pip
pip install gradipy

Testing

To run the tests for gradipy, ensure you have pytest and PyTorch installed. If you haven't installed them yet, you can do so using pip:

pip install -r requirements.txt

After installing the necessary dependencies, you can run the tests using pytest. During testing, we compare the output results and gradients of various operations with their counterparts in PyTorch to ensure the correctness and compatibility of gradipy.

python -m pytest
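The same idea of checking a gradient against a trusted reference can be tried without any dependencies. The sketch below is not taken from the gradipy test suite, which compares against PyTorch. It swaps in central finite differences as the reference and checks a hand-written tanh gradient. The names numerical_grad and tanh_backward are hypothetical.

```python
import math

def numerical_grad(f, x, eps=1e-6):
    """Central finite-difference estimate of df/dx at a scalar x."""
    return (f(x + eps) - f(x - eps)) / (2 * eps)

def tanh_backward(x, upstream=1.0):
    """Analytic gradient of tanh: d/dx tanh(x) = 1 - tanh(x)^2."""
    t = math.tanh(x)
    return upstream * (1.0 - t * t)

# Compare the analytic gradient with the numerical estimate at a few points.
for x in (-2.0, -0.5, 0.0, 0.7, 3.0):
    analytic = tanh_backward(x)
    numeric = numerical_grad(math.tanh, x)
    assert math.isclose(analytic, numeric, rel_tol=1e-5, abs_tol=1e-8), (x, analytic, numeric)

print("tanh gradient check passed")
```

Finite differences are slower and less precise than comparing against an existing autograd engine, but they need nothing beyond the standard library and work for any differentiable scalar function.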

Feature Roadmap

Here's a list of features we plan to implement in gradipy, along with their current status:

To-Do

  • Backward pass for Conv2d and BatchNorm2d
  • Backward passes for mul (currently has a problem with broadcasting) and tanh
  • More loss functions (nn.MSELoss and nn.NLLLoss)
  • Recurrent layers for sequence data
  • GPU acceleration (no idea how to do that yet, OpenCL?)
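For the broadcasting issue noted for mul, one common approach is to sum the upstream gradient over every axis that broadcasting expanded, so that each gradient ends up with the shape of its own operand. The sketch below shows that idea in plain NumPy. It is not gradipy's implementation, and the function names unbroadcast and mul_backward are hypothetical.

```python
import numpy as np

def unbroadcast(grad, shape):
    """Sum grad over the axes that broadcasting expanded, so it matches shape."""
    # Leading axes that broadcasting added in front of the original shape.
    while grad.ndim > len(shape):
        grad = grad.sum(axis=0)
    # Axes that had size 1 in the original operand and were stretched.
    for axis, size in enumerate(shape):
        if size == 1 and grad.shape[axis] != 1:
            grad = grad.sum(axis=axis, keepdims=True)
    return grad

def mul_backward(a, b, upstream):
    """Gradients of out = a * b with respect to a and b."""
    grad_a = unbroadcast(upstream * b, a.shape)
    grad_b = unbroadcast(upstream * a, b.shape)
    return grad_a, grad_b

a = np.arange(6, dtype=np.float64).reshape(2, 3)  # shape (2, 3)
b = np.array([10.0, 20.0, 30.0])                  # shape (3,), broadcast over rows
upstream = np.ones((2, 3))                        # gradient of sum(a * b)

grad_a, grad_b = mul_backward(a, b, upstream)
print(grad_a.shape, grad_b.shape)
print(grad_b)
```

Because b is reused for both rows of a, its gradient is the sum of the contributions from each row, which is why the extra leading axis is summed away.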
