Comments (5)
That would be the safest solution. We could also patch pytorch tensors to have a `shape` attribute and we would be fine ;)
from skorch.
I have investigated the reason: it is caused by sklearn's `safe_indexing`:
```python
import torch
from sklearn.utils import safe_indexing

xx = torch.zeros((10, 3))
safe_indexing(xx, [0, 1])
# returns a Python list of two 1-D tensors instead of a single 2x3 tensor:
# [
# 0
# 0
# 0
# [torch.FloatTensor of size 3],
# 0
# 0
# 0
# [torch.FloatTensor of size 3]]
```
To sum up, `safe_indexing` is often used to index data in sklearn. But this function only handles pandas DataFrames (detected via an `iloc` attribute) and numpy arrays (detected via a `shape` attribute) properly. If `X` is neither, sklearn falls back to `[X[idx] for idx in indices]`, which works for lists but, for torch tensors, returns a list of rows instead of a single tensor.
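To make the three branches concrete, here is a simplified sketch of the dispatch described above. It is not sklearn's actual code (the real function handles more cases); the name `safe_indexing_sketch` is hypothetical and it only mirrors the `iloc` / `shape` / fallback logic.

```python
from typing import Any, Sequence


def safe_indexing_sketch(X: Any, indices: Sequence[int]) -> Any:
    """Hypothetical, simplified version of the dispatch described above.

    Not sklearn's implementation; it only mirrors the three branches.
    """
    if hasattr(X, "iloc"):
        # pandas objects: positional row selection via .iloc
        return X.iloc[indices]
    if hasattr(X, "shape"):
        # array-likes such as numpy arrays: index with the whole list at once
        return X[indices]
    # fallback for everything else: index one element at a time and collect
    # the results in a Python list; for a tensor without `shape`, this yields
    # a list of row tensors rather than one stacked tensor
    return [X[idx] for idx in indices]


if __name__ == "__main__":
    rows = [[0, 0, 0], [1, 1, 1], [2, 2, 2]]
    print(safe_indexing_sketch(rows, [0, 2]))  # [[0, 0, 0], [2, 2, 2]]
```

The fallback branch is why the torch tensor in the first example comes back as a list of size-3 tensors.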
If we want to solve this without changing sklearn or pytorch (if only pytorch tensors had a `shape` attribute instead of `size()`...), I would suggest that we require the data to be in the numpy world if the user wants to benefit from sklearn goodies.
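A minimal sketch of what requiring numpy could look like at the boundary to sklearn, assuming we convert tensors right before handing them over. The helper name `to_numpy_for_sklearn` is hypothetical; it assumes a recent PyTorch where tensors provide `detach()`, `cpu()` and `numpy()`, and otherwise falls back to `np.asarray`.

```python
import numpy as np


def to_numpy_for_sklearn(X):
    """Hypothetical helper: make sure data handed to sklearn is a numpy array.

    Assumes a recent PyTorch, where tensors provide detach(), cpu() and numpy().
    """
    if isinstance(X, np.ndarray):
        # already in the numpy world, nothing to do
        return X
    if all(hasattr(X, attr) for attr in ("detach", "cpu", "numpy")):
        # torch.Tensor: drop autograd history and move to CPU before converting
        return X.detach().cpu().numpy()
    # lists, tuples and other array-likes
    return np.asarray(X)
```

Calling it on `torch.zeros((10, 3))` would return a `(10, 3)` float32 numpy array, which sklearn's indexing then handles through its `shape` branch.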
from skorch.
So, in essence, everywhere we communicate with sklearn, we have to make sure that the data is formatted as numpy arrays?
from skorch.
Update: Now torch tensors have a shape attribute, but it still doesn't work (because of pytorch/pytorch#2305).
```python
>>> from sklearn.utils import safe_indexing
>>> import torch
>>> xx = torch.zeros((10, 3))
>>> safe_indexing(xx, [0, 1])
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/bbossan_dev/anaconda3/envs/inferno/lib/python3.6/site-packages/sklearn/utils/__init__.py", line 112, in safe_indexing
    return X[indices]
TypeError: indexing a tensor with an object of type list. The only supported types are integers, slices, numpy scalars and torch.LongTensor or torch.ByteTensor as the only argument.
```
However, this works:
```python
>>> safe_indexing(xx, ([0, 1],))
0 0 0
0 0 0
[torch.FloatTensor of size 2x3]
```
A workaround could thus be to always wrap list indices in a tuple, since it works for pytorch and doesn't interfere with numpy.
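A minimal sketch of this workaround, assuming we control the indexing call ourselves. The helper `index_rows` is hypothetical; the check below only exercises numpy, where `X[([0, 1],)]` selects the same rows as `X[[0, 1]]`.

```python
import numpy as np


def index_rows(X, indices):
    """Hypothetical helper: select rows of X, wrapping list indices in a tuple.

    For numpy arrays, X[(indices,)] selects the same rows as X[indices]; per the
    comment above, the tuple form also worked for the torch version discussed here.
    """
    if isinstance(indices, list):
        return X[(indices,)]
    # slices, integers and index arrays are passed through unchanged
    return X[indices]


if __name__ == "__main__":
    X = np.arange(30).reshape(10, 3)
    print(index_rows(X, [0, 1]).shape)  # (2, 3)
```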
from skorch.
The issue regarding `GridSearchCV` is solved with this solution, so I'm closing this.
from skorch.
Related Issues (20)
- IndexError: invalid index of a 0-dim tensor. Use `tensor.item()` in Python or `tensor.item<T>()` in C++ to convert a 0-dim tensor to a number
- Enable using a generator as data loader
- Question: weird valid loss when re-scaling y
- Issues in braindecode recently introduced by skorch
- ReadTheDocs: Wrong theme of docs
- How to tune number of epochs?
- Issues with deployment script
- dill load and sklearn clone result in error
- how to integrate pytorch-tabnet into skorch framework?
- ImportError: cannot import name 'uses_placeholder_y' from 'skorch.dataset'
- Saving and Loading not working
- Add slides for Pydata Amsterdam
- model.history not save output of all epochs
- I can't use gpu cuda tensor with NeuralNet
- Permit to pass '**predict_params' to 'predict' method as for 'fit' method
- Skorch forwarding data columns as kwargs when using gridsearchcv
- Skorch weird handling of input data
- LLM caching breaks with shared-prefix labels under certain conditions
- Activating (deactivating) callbacks at specific epochs or milestones and SequentialLR
- Dictionary Input and Custom Collate Function