With #117 this can be solved:
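For example:
```python
import skorch.utils
from torch import nn
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, **kwargs):
        super().__init__()
        # route encoder__* / decoder__* kwargs (prefix stripped) to each sub-module
        self.encoder = encoder(**skorch.utils.params_for('encoder', kwargs))
        self.decoder = decoder(**skorch.utils.params_for('decoder', kwargs))
```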
This would allow us to write:
```python
ef = NeuralNet(
    module=Seq2Seq(encoder=AttentionEncoderRNN, decoder=DecoderRNN),
    module__encoder__num_hidden=23,
)
```
which is exactly what we want.
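To make the two-level routing concrete, here is a minimal, pure-Python sketch. `params_for` below is a hypothetical stand-in that assumes the same behavior as `skorch.utils.params_for` (keep the keys that start with `prefix__` and strip that prefix); it is not skorch's own code, and it only simulates what `NeuralNet` and `Seq2Seq.__init__` would each do with the keyword arguments.

```python
def params_for(prefix, kwargs):
    """Hypothetical stand-in for skorch.utils.params_for (assumed behavior):
    keep keys that start with 'prefix__' and strip that prefix."""
    marker = prefix + '__'
    return {k[len(marker):]: v for k, v in kwargs.items() if k.startswith(marker)}


# What the user passes to NeuralNet (value taken from the example above).
net_kwargs = {'module__encoder__num_hidden': 23}

# Step 1: NeuralNet would forward the module__* keys to the module's __init__.
module_kwargs = params_for('module', net_kwargs)

# Step 2: Seq2Seq.__init__ would route the encoder__* / decoder__* keys onward.
encoder_kwargs = params_for('encoder', module_kwargs)
decoder_kwargs = params_for('decoder', module_kwargs)

print(module_kwargs)   # {'encoder__num_hidden': 23}
print(encoder_kwargs)  # {'num_hidden': 23}
print(decoder_kwargs)  # {}
```

Each level only strips its own prefix, so the encoder class finally receives a plain `num_hidden=23`.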
It would also be nice to support `nn.Sequential`.
Still open for debate, and there's no clear path forward yet. Postponing to r0.3.0.
Also up for discussion: not only allow setting parameters on sub-modules, but also on arbitrary attributes. This would, e.g., allow us to do things like:
```python
net.set_params(module__encoder__embeddings__weight__requires_grad=False)
```
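One way to resolve such a key is to treat every `__` as one attribute access and assign the last component. The sketch below assumes that interpretation; `set_nested_attr` is a hypothetical helper, not a skorch function. It only uses plain `getattr`/`setattr`, so in principle it applies to any Python object; the demo uses `SimpleNamespace` objects as a toy stand-in for a module tree.

```python
from types import SimpleNamespace


def set_nested_attr(obj, path, value):
    """Hypothetical helper: interpret 'a__b__c' as obj.a.b.c = value."""
    head, sep, rest = path.partition('__')
    if not sep:
        # No '__' left: 'head' is the attribute to assign.
        setattr(obj, head, value)
        return
    # Descend one level and recurse on the remaining path.
    set_nested_attr(getattr(obj, head), rest, value)


# Toy stand-in for net.module: encoder.embeddings.weight.requires_grad
weight = SimpleNamespace(requires_grad=True)
module = SimpleNamespace(encoder=SimpleNamespace(embeddings=SimpleNamespace(weight=weight)))

set_nested_attr(module, 'encoder__embeddings__weight__requires_grad', False)
print(module.encoder.embeddings.weight.requires_grad)  # False
```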
I propose the following utility or helper functions:
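```python
from operator import methodcaller

def set_params_in_module(module, **kwargs):
    for k, v in kwargs.items():
        set_param_in_module(module, k, v)

def set_param_in_module(module, param, value):
    # split e.g. 'conv__weight__requires_grad' into 'conv__weight' and 'requires_grad'
    name, key = param.rsplit('__', 1)
    # 'conv__weight' -> 'conv.weight', the naming used by named_parameters()
    name = name.replace('__', '.')
    for n, p in module.named_parameters():
        # match the exact parameter or anything nested below it, but not
        # siblings that merely share a prefix (e.g. 'conv' vs. 'conv2')
        if n == name or n.startswith(name + '.'):
            # call the in-place method, e.g. p.requires_grad_(value)
            methodcaller(f'{key}_', value)(p)
```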
This would support the following syntax:
```python
import torch
from torch import nn

class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.embeddings = nn.Embedding(10, 10)

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(10, 1, 1)
        self.encoder = Encoder()

module = MyModule()
# requires_grad is switched off first, so the in-place copy_/add_ calls
# below are allowed on this leaf parameter
set_params_in_module(module, conv__weight__requires_grad=False)
set_params_in_module(module, conv__weight__copy=torch.ones((1, 10, 1, 1)))
set_params_in_module(module, conv__weight__add=torch.ones((1, 10, 1, 1)))
# a prefix without a parameter name applies to every parameter below it
set_params_in_module(module, encoder__requires_grad=False)
set_params_in_module(module,
                     conv__requires_grad=False,
                     encoder__embeddings__weight__requires_grad=False)
```
Integrating this into `NeuralNet` is tricky, because keywords prefixed with `module__` are passed to the module's `__init__` function.
I believe there could be a way. When a parameter is passed to `module`, if it doesn't contain a `__`, proceed normally. Otherwise, proceed as you suggested (but the call must be recursive if there are several `__`).
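The rule proposed here amounts to partitioning the `module__*` keyword arguments (with the `module__` prefix already stripped). A minimal sketch, where `split_module_params` is a hypothetical name: keys without `__` go to `__init__`, and keys containing `__` are held back to be applied to the constructed module afterwards, for instance with `set_params_in_module` from the earlier comment.

```python
def split_module_params(kwargs):
    """Hypothetical helper: split already-stripped module__* kwargs.

    Keys without '__' are constructor arguments; keys with '__' address
    something inside the constructed module and are deferred.
    """
    init_kwargs, deferred = {}, {}
    for key, value in kwargs.items():
        if '__' in key:
            deferred[key] = value
        else:
            init_kwargs[key] = value
    return init_kwargs, deferred


init_kwargs, deferred = split_module_params({
    'num_hidden': 23,
    'encoder__embeddings__weight__requires_grad': False,
})
print(init_kwargs)  # {'num_hidden': 23}
print(deferred)     # {'encoder__embeddings__weight__requires_grad': False}
```

Under this rule, a key such as `encoder__num_hidden` would also be deferred instead of reaching `Seq2Seq.__init__`, so it conflicts with the `params_for` pattern from the first comment.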
If `module__inner__linear` is set in `NeuralNet`, then `inner__linear` will be passed as a keyword argument to the module's `__init__` function. This enables @ottonemo's use case of calling `skorch.utils.params_for` inside the module's `__init__`.
The `set_params_in_module` function is used only after the module's `__init__` has completed successfully.
This is more complicated than anticipated; we should schedule it for 0.4.0 rather than delay 0.3.0.