utkuozbulak commented on June 12, 2024

If you comment out that line, the model no longer runs a forward pass up to selected_layer. Judging from what you sent, your model probably does not separate its convolutional layers (features) from its fully connected layers (classifier), so the layer-by-layer forward pass fails. You have to edit the code so that it performs a proper forward pass.
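A minimal sketch of why the features/classifier split matters, assuming a toy 28x28 single-channel input. FlatNet, SplitNet and naive_layer_walk are hypothetical names, not part of pytorch-cnn-visualizations. Walking every child module of a model that flattens only inside forward() breaks at the first Linear layer, while walking just the features block stays inside the convolutional part.

```python
import torch
import torch.nn as nn


class FlatNet(nn.Module):
    """Hypothetical model with no features/classifier split: the flatten
    step exists only inside forward(), not as a submodule."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 8, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(2)
        self.fc = nn.Linear(8 * 14 * 14, 10)

    def forward(self, x):
        x = self.pool(self.relu(self.conv(x)))
        return self.fc(torch.flatten(x, 1))


class SplitNet(nn.Module):
    """Same layers, grouped so that `features` can be iterated on its own."""

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.classifier = nn.Sequential(nn.Linear(8 * 14 * 14, 10))

    def forward(self, x):
        x = self.features(x)
        return self.classifier(torch.flatten(x, 1))


def naive_layer_walk(layers, x):
    """Feed x through each module in order, like the visualization loop does."""
    for layer in layers:
        x = layer(x)
    return x


x = torch.randn(1, 1, 28, 28)

# Walking every child of FlatNet skips the flatten, so the Linear layer
# receives a 4-D tensor whose last dimension (14) does not match 1568.
flat_failed = False
try:
    naive_layer_walk(FlatNet().children(), x)
except RuntimeError:
    flat_failed = True
print("FlatNet walk failed:", flat_failed)

# Walking only SplitNet.features never reaches the fully connected part.
out = naive_layer_walk(SplitNet().features, x)
print(out.shape)  # torch.Size([1, 8, 14, 14])
```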

from pytorch-cnn-visualizations.

mountains-high commented on June 12, 2024

Thank you for your suggestion. I changed the model, which now looks like this:

Net(
  (features): Sequential(
    (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Sequential(
    (0): Linear(in_features=800, out_features=500, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=500, out_features=10, bias=True)
  )
)
State dict  features.0.weight 	 torch.Size([20, 1, 5, 5])
State dict  features.0.bias 	 torch.Size([20])
State dict  features.3.weight 	 torch.Size([50, 20, 5, 5])
State dict  features.3.bias 	 torch.Size([50])
State dict  classifier.0.weight 	 torch.Size([500, 800])
State dict  classifier.0.bias 	 torch.Size([500])
State dict  classifier.2.weight 	 torch.Size([10, 500])
State dict  classifier.2.bias 	 torch.Size([10])
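A model definition that should reproduce the printout and state dict shapes above, sketched from the layer list alone. The 28x28 single-channel input (for example MNIST) is an assumption inferred from in_features=800, since 50 channels x 4 x 4 = 800 after two conv/pool stages; the forward() method is likewise assumed, as the original source of Net was not posted.

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    # Layer sizes copied from the printed model; the 28x28 input size is an
    # assumption inferred from in_features=800 (= 50 * 4 * 4).
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 20, kernel_size=5),         # 28 -> 24
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),   # 24 -> 12
            nn.Conv2d(20, 50, kernel_size=5),        # 12 -> 8
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),   # 8 -> 4
        )
        self.classifier = nn.Sequential(
            nn.Linear(50 * 4 * 4, 500),
            nn.ReLU(inplace=True),
            nn.Linear(500, 10),
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)  # (N, 50, 4, 4) -> (N, 800)
        return self.classifier(x)


model = Net()
print(model)
# ReLU and MaxPool2d have no parameters, so only indices 0 and 3 of
# features and 0 and 2 of classifier appear in the state dict.
for name, tensor in model.state_dict().items():
    print("State dict ", name, "\t", tensor.shape)
```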

However, I still get the same error:

  x = layer(x)
TypeError: 'Tensor' object is not callable

This is the code I used:

def get_output_from_specific_layer(self, x, layer_id):
    layer_output = None
    for index, layer in enumerate(self.model.features[3].weight):
        x = layer(x)
        # print('Layer is: ', layer)  # All layers will be printed out
        if str(index) == str(layer_id):
            # x = layer(x)
            layer_output = x[0]
            # print('Layer is: ', layer)
            break
    return layer_output

I don't know whether this is correct, but it seems that x = layer(x) does not use the weights of the specific layer. What do you think? Thank you for your time and consideration.
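A small sketch of what the traceback and the question above come down to, using a standalone Conv2d shaped like features[3]. The 12x12 input to that layer is an assumption that follows from 28x28 images. Iterating over a weight tensor yields plain tensor slices, which cannot be called; calling the layer module itself does use that layer's own weight and bias.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
conv = nn.Conv2d(20, 50, kernel_size=5)  # same shape as features[3] above
x = torch.randn(1, 20, 12, 12)           # assumed input size for 28x28 images

# 1) Iterating a weight tensor walks its first dimension: 50 filter slices
#    of shape (20, 5, 5). Each slice is a Tensor, not a layer.
slices = list(enumerate(conv.weight))
print(len(slices), slices[0][1].shape)  # 50 torch.Size([20, 5, 5])

tensor_call_failed = False
try:
    slices[0][1](x)  # calling a Tensor raises TypeError
except TypeError as err:
    tensor_call_failed = True
    print("TypeError:", err)

# 2) Calling the module runs its forward(), which convolves x with the
#    module's own stored weight and bias.
via_module = conv(x)
via_functional = F.conv2d(x, conv.weight, conv.bias, stride=1, padding=0)
print(via_module.shape)                            # torch.Size([1, 50, 8, 8])
print(torch.allclose(via_module, via_functional))  # True
```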

utkuozbulak commented on June 12, 2024

My guess: you are iterating over a weight tensor instead of over the layers.

Change

        for index, layer in enumerate(self.model.features[3].weight):

to

        for index, layer in enumerate(self.model.features):
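A sketch of the method with the loop iterating over the modules of self.model.features, so that each item is a callable layer. LayerExtractor and FeaturesOnly are hypothetical stand-ins for the visualization class and the Net above (the classifier is omitted because the loop never leaves features), and the 28x28 input is an assumption.

```python
import torch
import torch.nn as nn


class FeaturesOnly(nn.Module):
    """Hypothetical stand-in for Net: same `features` block, no classifier."""

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 20, kernel_size=5),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(20, 50, kernel_size=5),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )


class LayerExtractor:
    """Hypothetical wrapper playing the role of the visualization class."""

    def __init__(self, model):
        self.model = model.eval()

    def get_output_from_specific_layer(self, x, layer_id):
        layer_output = None
        # Each item is an nn.Module from the Sequential, so layer(x) is valid.
        for index, layer in enumerate(self.model.features):
            x = layer(x)
            if str(index) == str(layer_id):
                layer_output = x[0]  # activations of the first image in the batch
                break
        return layer_output


extractor = LayerExtractor(FeaturesOnly())
with torch.no_grad():
    out = extractor.get_output_from_specific_layer(torch.randn(1, 1, 28, 28), layer_id=3)
print(out.shape)  # torch.Size([50, 8, 8])
```

Layer index 3 is the second Conv2d, so the returned activations hold one 8x8 map per each of its 50 filters.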

mountains-high commented on June 12, 2024

Thank you. Now it is working!
Have a nice day ~
