Comments (5)
Modify the `sample` function in `model.py` as shown below.
I added `inputs = inputs.unsqueeze(1)` as the last line of the `for` loop and changed `sampled_ids = torch.cat(sampled_ids, 1)` to `sampled_ids = torch.cat(sampled_ids, 0)`.
```python
def sample(self, features, states=None):
    """Samples captions for given image features (Greedy search)."""
    sampled_ids = []
    inputs = features.unsqueeze(1)                   # (batch_size, 1, embed_size)
    for i in range(20):                              # maximum sampling length
        hiddens, states = self.lstm(inputs, states)  # (batch_size, 1, hidden_size)
        outputs = self.linear(hiddens.squeeze(1))    # (batch_size, vocab_size)
        predicted = outputs.max(1)[1]                # (batch_size,)
        sampled_ids.append(predicted)
        inputs = self.embed(predicted)               # (batch_size, embed_size)
        inputs = inputs.unsqueeze(1)                 # (batch_size, 1, embed_size)
    sampled_ids = torch.cat(sampled_ids, 0)          # (20 * batch_size,): one caption only when batch_size == 1
    return sampled_ids.squeeze()
```
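To see why the extra `unsqueeze(1)` is needed and what `torch.cat(sampled_ids, 0)` returns, here is a small shape check. It assumes PyTorch 0.2 or later, where `outputs.max(1)[1]` returns a 1-D `(batch_size,)` tensor; the sizes and the two hand-written steps of word ids are illustrative values, not output from the tutorial's trained model.

```python
import torch
import torch.nn as nn

# Illustrative sizes; these are assumptions, not the tutorial's settings.
batch_size, embed_size, vocab_size = 2, 8, 10
embed = nn.Embedding(vocab_size, embed_size)

# Word ids "predicted" at two decoding steps for a batch of two images,
# shaped (batch_size,) like outputs.max(1)[1] in the sample function.
step0 = torch.tensor([1, 2])
step1 = torch.tensor([3, 4])

embedded = embed(step0)             # (batch_size, embed_size): only 2-D
lstm_input = embedded.unsqueeze(1)  # (batch_size, 1, embed_size): the 3-D batch_first layout

# cat along dim 0 joins the 1-D step outputs end to end, so the words of
# different images are interleaved: image 0's caption sits at positions 0 and 2.
flat = torch.cat([step0, step1], 0)
print(flat)  # tensor([1, 2, 3, 4])
```

With a batch of one image the flat tensor is exactly that image's caption, which is why `torch.cat(sampled_ids, 0)` appears to work when sampling a single image.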
from pytorch-tutorial.
I got this too. Have you fixed it?
@mhsamavatian Thanks, you are right. I updated the code :-)
However, I get an error when I do that:
`RuntimeError: input must have 3 dimensions, got 4`
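A likely cause, assuming an older PyTorch release: before 0.2, reductions such as `max(dim)` kept the reduced dimension, so `predicted` was already `(batch_size, 1)`, `self.embed(predicted)` was already 3-D, and the extra `unsqueeze(1)` made it 4-D. From 0.2 on, the reduced dimension is dropped by default (`keepdim=False`), which is why the extra line is needed there. A version-agnostic sketch flattens the ids before embedding; `next_step_inputs` is a hypothetical helper and the sizes are illustrative:

```python
import torch
import torch.nn as nn


def next_step_inputs(embed: nn.Embedding, predicted: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: build the next LSTM input from predicted word ids.

    Accepts ids shaped (batch_size,) or (batch_size, 1), i.e. with or
    without the reduced dimension kept by max().
    """
    ids = predicted.view(-1)        # always (batch_size,)
    return embed(ids).unsqueeze(1)  # (batch_size, 1, embed_size)


embed = nn.Embedding(10, 8)                # illustrative vocab_size=10, embed_size=8
outputs = torch.randn(3, 10)               # stand-in for self.linear(...) output
dropped = outputs.max(1)[1]                # (3,): current default, keepdim=False
kept = outputs.max(1, keepdim=True)[1]     # (3, 1): mimics the pre-0.2 behaviour
a = next_step_inputs(embed, dropped)
b = next_step_inputs(embed, kept)
```

Passing `keepdim=True` here only mimics the older default so both shapes can be checked on a current install.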
Hi there,
I also ran into that problem, so I added `inputs = inputs.unsqueeze(1)`:
```python
def sample(self, features, states=None):
    """Samples captions for given image features (Greedy search)."""
    sampled_ids = []
    inputs = features.unsqueeze(1)
    for i in range(20):  # maximum sampling length
        hiddens, states = self.lstm(inputs, states)  # (batch_size, 1, hidden_size)
        outputs = self.linear(hiddens.squeeze(1))    # (batch_size, vocab_size)
        predicted = outputs.max(1)[1]
        sampled_ids.append(predicted)
        inputs = self.embed(predicted)
        inputs = inputs.unsqueeze(1)
    sampled_ids = torch.cat(sampled_ids, 1)  # (batch_size, 20)
    return sampled_ids.squeeze()
```
But I got the following error:
```
File "D:\Dev\image_captioning\model.py", line 134, in sample
    sampled_ids = torch.cat(sampled_ids, 1)  # (batch_size, 20)
RuntimeError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
```
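The traceback follows from the shapes: once `max(1)` drops the reduced dimension, each `predicted` is 1-D `(batch_size,)`, so the only valid dimensions for `torch.cat` are `-1` and `0`. `torch.stack(sampled_ids, 1)` creates a new dimension instead and yields `(batch_size, 20)` for any batch size. Below is a sketch under that assumption; the `Decoder` class and its default sizes are hypothetical stand-ins that mirror only the `embed`, `lstm`, and `linear` layers the thread's `sample` uses:

```python
import torch
import torch.nn as nn


class Decoder(nn.Module):
    """Hypothetical minimal decoder with the layers used by sample()."""

    def __init__(self, embed_size=8, hidden_size=16, vocab_size=10, max_len=20):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.max_len = max_len

    def sample(self, features, states=None):
        """Greedy search; returns word ids shaped (batch_size, max_len)."""
        sampled_ids = []
        inputs = features.unsqueeze(1)                   # (batch_size, 1, embed_size)
        for _ in range(self.max_len):
            hiddens, states = self.lstm(inputs, states)  # (batch_size, 1, hidden_size)
            outputs = self.linear(hiddens.squeeze(1))    # (batch_size, vocab_size)
            predicted = outputs.max(1)[1]                # (batch_size,)
            sampled_ids.append(predicted)
            inputs = self.embed(predicted).unsqueeze(1)  # (batch_size, 1, embed_size)
        return torch.stack(sampled_ids, 1)               # (batch_size, max_len)


decoder = Decoder().eval()
with torch.no_grad():
    ids = decoder.sample(torch.randn(4, 8))  # 4 images, embed_size=8
```

Unlike the versions in this thread, the sketch does not call `.squeeze()`, so a single image still returns shape `(1, 20)`; index with `[0]` to get that one caption.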
Related Issues (20)
- Initialize DecoderCNN in Image captioning
- Some problems occurred when I used model evaluation
- RuntimeError in Logistic Regression python file
- Using LSTM method in Python
- size mismatch for pretrained models
- pytorch
- No Jupyter Notebooks.
- About the learning method of neural_style_transfer
- Does anyone know the source code of channel calculation in pytorch?
- make ur repo cloneable and not editable by anyone.
- TypeError: conv2d(): argument 'input' (position 1) must be Tensor, not tuple
- AttributeError: module 'torch.nn' has no attribute 'linear'
- ValueError: num_samples should be a positive integer value, but got num_samples=0
- main.py failed
- some question about the position of 'optimizer.zero_grad()'
- Pytorch tutorial
- 自动驾驶更新笔记 Autopilot Updating Notes
- How can I get a PDF version of the tutorial
- Cuda is true why don`t use it?
- GNN model