waleedka commented on August 22, 2024

Thank you. I reproduced the issue and just pushed a fix to GitHub. The problem was that the framework detection code correctly recognized classes that derive directly from nn.Module, but not classes that derive from another class which in turn inherits from nn.Module (i.e., two steps removed).
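The difference between the two kinds of check can be sketched in plain Python. This is a minimal illustration rather than hiddenlayer's actual detection code: `Module` is a hypothetical stand-in for `torch.nn.Module` so the snippet runs without PyTorch, and both function names are made up for the example.

```python
import inspect


class Module:
    """Hypothetical stand-in for torch.nn.Module."""


class Net(Module):  # one step removed from Module
    pass


class MyNet(Net):  # two steps removed, like MyNet in gmunizc's snippet
    pass


def detect_direct_parent(model):
    # Only inspects the immediate base classes of the model's class,
    # so a grandparent such as Module is never seen.
    return any(base is Module for base in type(model).__bases__)


def detect_full_mro(model):
    # Walks the whole method resolution order, so Module is found
    # no matter how many subclasses sit in between.
    return any(cls is Module for cls in inspect.getmro(type(model)))


print(detect_direct_parent(Net()), detect_direct_parent(MyNet()))  # True False
print(detect_full_mro(Net()), detect_full_mro(MyNet()))            # True True
```

In ordinary code, `isinstance(model, torch.nn.Module)` walks the same hierarchy; a library that wants to avoid importing every supported framework up front can instead compare class names or modules while iterating over the MRO.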

waleedka commented on August 22, 2024

Can you provide a code snippet that replicates the error? That would help track the issue.

gmunizc commented on August 22, 2024
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import hiddenlayer as hl

# My net generalization
class Net(nn.Module):
    def __init__(self, hparams):
        super(Net, self).__init__()
        self.data_height = hparams.dataset.tile_shape[0]
        self.data_width = hparams.dataset.tile_shape[1]
        self.data_channels = hparams.dataset.tile_shape[2]

        self.batch_norm_mu = hparams.batch_norm_mu
        self.batch_norm_epsilon = hparams.batch_norm_epsilon
        self.num_classes = hparams.dataset.num_classes

    def forward(self, inputs):
        raise NotImplementedError()


# My specific version
class MyNet(Net):
    def __init__(self, hparams):
        super(MyNet, self).__init__(hparams)

        self.conv1 = nn.Conv2d(self.data_channels, out_channels=64, kernel_size=5, stride=2,
                               padding=2)
        self.batch_norm1 = nn.BatchNorm2d(num_features=64, eps=self.batch_norm_epsilon, momentum=self.batch_norm_epsilon,
                                          affine=True, track_running_stats=True)

    def forward(self, x):
        x = F.relu(self.batch_norm1(self.conv1(x)))
        return x

# Just some helper class to help me simulate the parameters being passed when running the code
class Dataset(object):
    def __init__(self, tile_shape, num_classes):
        self.tile_shape = tile_shape
        self.num_classes = num_classes

# Just some helper class to help me simulate the parameters being passed when running the code
class HParams(object):
    def __init__(self, dataset, batch_norm_mu, batch_norm_epsilon):
        self.dataset = dataset
        self.batch_norm_mu = batch_norm_mu
        self.batch_norm_epsilon = batch_norm_epsilon

# Set up hyper parameters
dataset = Dataset((40,40,1), 7)
hparams = HParams(dataset, 0.997, 1e-05)

# Instantiate my specific model
model = MyNet(hparams)

# hl will break complaining that my specific model isn't a PyTorch model
hl_graph = hl.build_graph(model, torch.zeros([64, 1, 40, 40]))
hl_graph.theme = hl.graph.THEMES["blue"].copy()
hl_graph


# Now adding everything together inside only one class that inherits directly from nn.Module
class OneNet(nn.Module):
    def __init__(self, hparams):
        super(OneNet, self).__init__()
        self.data_height = hparams.dataset.tile_shape[0]
        self.data_width = hparams.dataset.tile_shape[1]
        self.data_channels = hparams.dataset.tile_shape[2]

        self.batch_norm_mu = hparams.batch_norm_mu
        self.batch_norm_epsilon = hparams.batch_norm_epsilon
        self.num_classes = hparams.dataset.num_classes

        self.conv1 = nn.Conv2d(self.data_channels, out_channels=64, kernel_size=5, stride=2,
                               padding=2)
        self.batch_norm1 = nn.BatchNorm2d(num_features=64, eps=self.batch_norm_epsilon, momentum=self.batch_norm_epsilon,
                                          affine=True, track_running_stats=True)

    def forward(self, x):
        x = F.relu(self.batch_norm1(self.conv1(x)))
        return x

# Initializing this new model
model = OneNet(hparams)

# hl will work now because it can easily see that this model is a PyTorch model
hl_graph = hl.build_graph(model, torch.zeros([64, 1, 40, 40]))
hl_graph.theme = hl.graph.THEMES["blue"].copy()
hl_graph

I ran it in a Jupyter notebook across several cells, but I put the code together like this for sharing. I hope it helps you reproduce the error I got.

gmunizc commented on August 22, 2024

Thank you! I really appreciate it. And thank you for sharing this project with us.

tphankr commented on August 22, 2024

I also have the same problem. I want to visualize this model, but my code raises this error:
ValueError: model input param must be a PyTorch, TensorFlow, or Keras-with-TensorFlow-backend model.

I am following this code from GitHub: https://github.com/Ryo-Ito/brain_segmentation. Please help me visualize it. Thank you.

import chainer
import chainer.functions as F
import chainer.links as L
import torch

class VoxResModule(chainer.Chain):
    """
    Voxel Residual Module
    input
    BatchNormalization, ReLU
    Conv 64, 3x3x3
    BatchNormalization, ReLU
    Conv 64, 3x3x3
    output
    """

    def __init__(self):
        initW = chainer.initializers.HeNormal(scale=0.01)
        super().__init__()

        with self.init_scope():
            self.bnorm1 = L.BatchNormalization(size=64)
            self.conv1 = L.ConvolutionND(3, 64, 64, 3, pad=1, initialW=initW)
            self.bnorm2 = L.BatchNormalization(size=64)
            self.conv2 = L.ConvolutionND(3, 64, 64, 3, pad=1, initialW=initW)

    def __call__(self, x):
        h = F.relu(self.bnorm1(x))
        h = self.conv1(h)
        h = F.relu(self.bnorm2(h))
        h = self.conv2(h)
        return h + x


class VoxResNet(chainer.Chain):
    """Voxel Residual Network"""

    def __init__(self, in_channels=1, n_classes=4):
        init = chainer.initializers.HeNormal(scale=0.01)
        super().__init__()

        with self.init_scope():
            self.conv1a = L.ConvolutionND(
                3, in_channels, 32, 3, pad=1, initialW=init)
            self.bnorm1a = L.BatchNormalization(32)
            self.conv1b = L.ConvolutionND(
                3, 32, 32, 3, pad=1, initialW=init)
            self.bnorm1b = L.BatchNormalization(32)
            self.conv1c = L.ConvolutionND(
                3, 32, 64, 3, stride=2, pad=1, initialW=init)
            self.voxres2 = VoxResModule()
            self.voxres3 = VoxResModule()
            self.bnorm3 = L.BatchNormalization(64)
            self.conv4 = L.ConvolutionND(
                3, 64, 64, 3, stride=2, pad=1, initialW=init)
            self.voxres5 = VoxResModule()
            self.voxres6 = VoxResModule()
            self.bnorm6 = L.BatchNormalization(64)
            self.conv7 = L.ConvolutionND(
                3, 64, 64, 3, stride=2, pad=1, initialW=init)
            self.voxres8 = VoxResModule()
            self.voxres9 = VoxResModule()
            self.c1deconv = L.DeconvolutionND(
                3, 32, 32, 3, pad=1, initialW=init)
            self.c1conv = L.ConvolutionND(
                3, 32, n_classes, 3, pad=1, initialW=init)
            self.c2deconv = L.DeconvolutionND(
                3, 64, 64, 4, stride=2, pad=1, initialW=init)
            self.c2conv = L.ConvolutionND(
                3, 64, n_classes, 3, pad=1, initialW=init)
            self.c3deconv = L.DeconvolutionND(
                3, 64, 64, 6, stride=4, pad=1, initialW=init)
            self.c3conv = L.ConvolutionND(
                3, 64, n_classes, 3, pad=1, initialW=init)
            self.c4deconv = L.DeconvolutionND(
                3, 64, 64, 10, stride=8, pad=1, initialW=init)
            self.c4conv = L.ConvolutionND(
                3, 64, n_classes, 3, pad=1, initialW=init)

    def __call__(self, x, train=False):
        """
        calculate output of VoxResNet given input x

        Parameters
        ----------
        x : (batch_size, in_channels, xlen, ylen, zlen) ndarray
            image to perform semantic segmentation

        Returns
        -------
        proba: (batch_size, n_classes, xlen, ylen, zlen) ndarray
            probability of each voxel belonging to each class;
            if train=True, returns a list of logits instead
        """
        # Debug print moved below the docstring so the docstring is recognized
        print(x.shape, '-------begin------------')
        with chainer.using_config("train", train):
            h = self.conv1a(x)
            h = F.relu(self.bnorm1a(h))
            h = self.conv1b(h)
            c1 = F.clipped_relu(self.c1deconv(h))
            c1 = self.c1conv(c1)

            h = F.relu(self.bnorm1b(h))
            h = self.conv1c(h)
            h = self.voxres2(h)
            h = self.voxres3(h)
            c2 = F.clipped_relu(self.c2deconv(h))
            c2 = self.c2conv(c2)

            h = F.relu(self.bnorm3(h))
            h = self.conv4(h)
            h = self.voxres5(h)
            h = self.voxres6(h)
            c3 = F.clipped_relu(self.c3deconv(h))
            c3 = self.c3conv(c3)

            h = F.relu(self.bnorm6(h))
            h = self.conv7(h)
            h = self.voxres8(h)
            h = self.voxres9(h)
            c4 = F.clipped_relu(self.c4deconv(h))
            c4 = self.c4conv(c4)

            c = c1 + c2 + c3 + c4

        if train:
            return [c1, c2, c3, c4, c]
        else:
            return F.softmax(c)

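import hiddenlayer as hl
import numpy as np

# Chainer links take NumPy arrays shaped (batch_size, in_channels, xlen, ylen, zlen);
# VoxResNet() defaults to in_channels=1, so the dummy input has 1 channel, not 3.
x = np.zeros((1, 1, 64, 64, 64), dtype=np.float32)

model = VoxResNet()
# This is where the ValueError above is raised: a chainer.Chain is not one of the
# model types hiddenlayer accepts.
hl.build_graph(model, x)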
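The error quoted above comes from a framework check that only accepts the frameworks listed in the message. A minimal sketch of such a check, assuming it works by inspecting the top-level package of each class in the model's MRO (the `detect_framework` function, the `SUPPORTED_ROOTS` tuple, and the `Chain` stand-in for `chainer.Chain` are all hypothetical, not hiddenlayer's real code), shows why a Chainer model is rejected however deep its class hierarchy is:

```python
import inspect

# Hypothetical allow-list of top-level package names.
SUPPORTED_ROOTS = ("torch", "tensorflow", "keras")


def detect_framework(model):
    """Return the supported package that some class in the model's MRO comes from."""
    for cls in inspect.getmro(type(model)):
        root = cls.__module__.split(".")[0]
        if root in SUPPORTED_ROOTS:
            return root
    raise ValueError("model input param must be a PyTorch, TensorFlow, "
                     "or Keras-with-TensorFlow-backend model.")


class Chain:
    """Hypothetical stand-in for chainer.Chain."""


Chain.__module__ = "chainer"  # pretend this class lives in the chainer package


class VoxResNet(Chain):
    pass


try:
    detect_framework(VoxResNet())
except ValueError as err:
    print("rejected:", err)
```

Every class in a `chainer.Chain` subclass's hierarchy comes from `chainer`, the user's own module, or `builtins`, so a check like this cannot succeed; visualizing VoxResNet with hiddenlayer would presumably require porting the network to one of the listed frameworks.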
