Comments (5)

benjamin-work commented on May 15, 2024

That would be the safest solution. We could also patch pytorch tensors to have a shape attribute, and then we would be fine ;)

from skorch.

benjamin-work commented on May 15, 2024

I have investigated the reason and it is caused by sklearn's safe_indexing:

import torch
from sklearn.utils import safe_indexing
xx = torch.zeros((10, 3))
safe_indexing(xx, [0, 1])

# returns:
# [
#   0
#   0
#   0
#  [torch.FloatTensor of size 3], 
#   0
#   0
#   0
#  [torch.FloatTensor of size 3]]

To sum up, sklearn often uses safe_indexing to index data, but this function only works well with pandas DataFrames (via iloc) or numpy arrays (detected via their shape attribute). If the input is neither, sklearn falls back to [X[idx] for idx in indices], which works for lists but not for torch tensors.
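
A rough sketch of that dispatch, for illustration only. simplified_safe_indexing and ShapelessRows are made-up names, not sklearn code, and ShapelessRows merely stands in for any container that supports integer indexing but has neither iloc nor shape:

```python
import numpy as np


def simplified_safe_indexing(X, indices):
    """Hypothetical, simplified version of the dispatch described above."""
    if hasattr(X, "iloc"):
        # pandas-like containers: positional indexing
        return X.iloc[indices]
    if hasattr(X, "shape"):
        # numpy-like containers: fancy indexing keeps the container type
        return X[indices]
    # anything else: element-wise fallback, always returns a plain list
    return [X[idx] for idx in indices]


class ShapelessRows:
    """Hypothetical container: integer indexing works, but no `shape`."""

    def __init__(self, rows):
        self._rows = rows

    def __getitem__(self, idx):
        return self._rows[idx]


arr = np.zeros((10, 3))
res_arr = simplified_safe_indexing(arr, [0, 1])
res_shapeless = simplified_safe_indexing(ShapelessRows(arr), [0, 1])

print(type(res_arr).__name__)        # a single 2x3 ndarray
print(type(res_shapeless).__name__)  # a list holding two separate rows
```

The second result is the same kind of "list of rows" as in the output shown above, which is why downstream code expecting one array-like object gets confused.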

If we want to solve this without changing sklearn or pytorch (if only pytorch tensors had a shape attribute instead of size ...), I would suggest that we require the data to be in the numpy world if the user wants to benefit from sklearn goodies.
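
One way this could look in practice is a small conversion step before anything is handed to sklearn. This is only a sketch: to_numpy is a hypothetical helper name, and the duck-typed calls to detach, cpu and numpy assume the object behaves like a torch tensor when those methods are present.

```python
import numpy as np


def to_numpy(X):
    """Hypothetical helper: return X as a numpy array."""
    if isinstance(X, np.ndarray):
        return X
    # Assumption: objects exposing these methods behave like torch tensors.
    if hasattr(X, "detach"):
        X = X.detach()
    if hasattr(X, "cpu"):
        X = X.cpu()
    if hasattr(X, "numpy"):
        return X.numpy()
    # Lists and other array-likes
    return np.asarray(X)


X_list = [[1.0, 2.0], [3.0, 4.0]]
X_np = to_numpy(X_list)
print(X_np.shape)
```

Once the data is a numpy array, sklearn's indexing takes the shape branch and returns a single array again.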

ottonemo commented on May 15, 2024

So in essence everywhere we communicate with sklearn we have to make sure that the data is formatted as numpy arrays?

benjamin-work commented on May 15, 2024

Update: Now torch tensors have a shape attribute, but it still doesn't work (because of pytorch/pytorch#2305).

>>> from sklearn.utils import safe_indexing
>>> import torch
>>> xx = torch.zeros((10, 3))
>>> safe_indexing(xx, [0, 1])
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/bbossan_dev/anaconda3/envs/inferno/lib/python3.6/site-packages/sklearn/utils/__init__.py", line 112, in safe_indexing
    return X[indices]
TypeError: indexing a tensor with an object of type list. The only supported types are integers, slices, numpy scalars and torch.LongTensor or torch.ByteTensor as the only argument.

However, this works:

>>> safe_indexing(xx, ([0, 1],))
 0  0  0
 0  0  0
[torch.FloatTensor of size 2x3]

A workaround could thus be to always wrap list indices in a tuple, since it works for pytorch and doesn't interfere with numpy.
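
A minimal check of the numpy half of that claim. wrap_list_indices is a hypothetical helper name, and the snippet only shows that the tuple wrapping leaves numpy results unchanged; the pytorch behavior is the one reported in the session above.

```python
import numpy as np


def wrap_list_indices(indices):
    """Hypothetical helper: wrap a plain list of indices in a 1-tuple."""
    if isinstance(indices, list):
        return (indices,)
    return indices


arr = np.arange(30).reshape(10, 3)

plain = arr[[0, 1]]                        # usual fancy indexing
wrapped = arr[wrap_list_indices([0, 1])]   # same indices, wrapped in a tuple

# For numpy, both forms select rows 0 and 1
print(np.array_equal(plain, wrapped))
```

Non-list indices such as slices or integers pass through the helper untouched, so only the problematic case is altered.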

ottonemo commented on May 15, 2024

The issue regarding GridSearchCV is solved with this solution. Therefore I'm closing this.
