Comments (11)
For this you would need to use the checkpoint stored in the directory specified by config.checkpoint_dir in the meta-learning experiment, load the checkpoint, write the config and params as a dict to an npz file, then use this npz file instead of the pretrained npz files. I think the one thing that's not clear is how to load the checkpoint to obtain the config of the experiment and the params of the checkpoint. Let me look into this and get back to you.
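A minimal sketch of that conversion step, assuming the checkpoint was saved as a pickled dict with `'params'` and `'config'` entries (the key names and paths below are placeholders, not the actual functa layout):

```python
import os
import pickle
import tempfile

def checkpoint_to_npz(checkpoint_path, out_path):
    """Extract params and config from a checkpoint file and rewrite them
    as a dict in the same shape as the pretrained files (params + config)."""
    with open(checkpoint_path, 'rb') as f:
        ckpt = pickle.load(f)
    with open(out_path, 'wb') as f:
        pickle.dump({'params': ckpt['params'], 'config': ckpt['config']}, f)

# Tiny demo with a fake checkpoint written to a temporary directory.
tmp = tempfile.mkdtemp()
ckpt_path = os.path.join(tmp, 'checkpoint.npz')
out_path = os.path.join(tmp, 'params_latents.npz')
with open(ckpt_path, 'wb') as f:
    pickle.dump({'params': {'w': [1.0, 2.0]},
                 'config': {'width': 512},
                 'global_step': 100}, f)
checkpoint_to_npz(ckpt_path, out_path)
with open(out_path, 'rb') as f:
    restored = pickle.load(f)
```

The point is only that training-time extras like the global step are dropped, so the output file matches the params-plus-config layout the downstream script expects.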
from functa.
Hi, are there any issues when you follow the instructions here: https://github.com/deepmind/functa#create-or-download-modulations-for-celeba-hq-64 ? Let me know what the issue is and I'll try to help out.
Yes, but the config dict is stored in self.config of the experiment, so it can be saved in the same way you save the params/state. This is what I meant when I said that the config only needs to be stored once. Also, there is no state used in the meta-learning experiment, so you'd only need the params (saved every few training iterations) and the config (saved once at the beginning of training). I hope that's clear.
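As an illustration of that split (plain dicts stand in for the ConfigDict and the param pytree; the file names and save interval here are made up):

```python
import os
import pickle
import tempfile

def save_config_once(config, ckpt_dir):
    """Write the experiment config a single time at the start of training."""
    path = os.path.join(ckpt_dir, 'config.pkl')
    if not os.path.exists(path):
        with open(path, 'wb') as f:
            pickle.dump(dict(config), f)
    return path

def save_params(params, ckpt_dir, step, every=1000):
    """Write params every `every` steps; returns the path if written."""
    if step % every != 0:
        return None
    path = os.path.join(ckpt_dir, 'params.pkl')
    with open(path, 'wb') as f:
        pickle.dump(params, f)
    return path

# Demo: config saved once up front, params only on multiples of `every`.
ckpt_dir = tempfile.mkdtemp()
config_path = save_config_once({'width': 512, 'depth': 15}, ckpt_dir)
written = save_params({'w': [0.1]}, ckpt_dir, step=2000, every=1000)
skipped = save_params({'w': [0.2]}, ckpt_dir, step=2001, every=1000)
```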
Nice. Would you be happy for me to close the issue now?
The script modulation_dataset_writer.py works with your pretrained models. If I want to train a model myself, how can I save the weights, modulations, and config in such a way that they can be used as input to modulation_dataset_writer.py?
Thank you very much. Yes, the part that you have highlighted is not clear to me.
Looking into this, it seems the checkpointing functionality of the open-sourced jaxline does not work as-is: google-deepmind/jaxline#4
However, the latest comment on that issue links to an example of some DeepMind code where you can manually add a few lines of code to save the model params: https://github.com/deepmind/deepmind-research/blob/master/adversarial_robustness/jax/experiment.py#L220-L225
You should be able to add similar lines of code inside the step method of experiment_meta_learning.py to store self.config (this only needs to be stored once) and self._params for the meta-learning experiment, and then pass the saved npz files to modulation_dataset_writer.py. Hope that makes sense!
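On the npz side, one hedged sketch of what such saving lines could look like: np.savez wants a flat mapping from names to arrays, so a nested param dict has to be flattened first. The helper below is illustrative, not functa code, and the path is a placeholder:

```python
import os
import tempfile
import numpy as np

def flatten_tree(tree, prefix=''):
    """Flatten a nested dict of arrays into {'a/b': array}, the
    keyword-per-array layout that np.savez expects."""
    flat = {}
    for k, v in tree.items():
        key = f'{prefix}{k}'
        if isinstance(v, dict):
            flat.update(flatten_tree(v, key + '/'))
        else:
            flat[key] = np.asarray(v)
    return flat

# e.g. inside the step method one could call (hypothetically):
#   np.savez(ckpt_path, **flatten_tree(self._params))
params = {'siren': {'w0': np.ones((2, 2)), 'b0': np.zeros(2)}}
flat = flatten_tree(params)
path = os.path.join(tempfile.mkdtemp(), 'params.npz')
np.savez(path, **flat)
loaded = np.load(path)
```

Loading back with np.load then gives the same slash-separated names, which can be split to rebuild the nesting if needed.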
Ok, thank you very much. The lines of code that you mentioned handle the params and state, but the config file cannot be saved with that code: https://github.com/deepmind/deepmind-research/blob/master/adversarial_robustness/jax/experiment.py#L220-L225
Yes, thank you very much. I will try.
I have managed to create similar code that is compatible with the pretrained models you have uploaded. I have not yet tested it with the same config you uploaded, but the npz file produced is similar. This is the code I added in data_utils.py:
```python
from typing import Optional

import os
import pickle

import jax
import numpy as np
import tensorflow as tf
from jaxline import utils
from ml_collections import config_dict


class NumpyFileCheckpointer(utils.Checkpointer):
  """A jaxline checkpointer which saves to numpy files on disk."""

  def __init__(self, config: config_dict.ConfigDict, mode: str):
    self._checkpoint_file = os.path.join(config.checkpoint_dir,
                                         'checkpoint.npz')
    self._celeba_file = os.path.join(config.checkpoint_dir,
                                     'celeba_params_latents.npz')
    self._checkpoint_state = config_dict.ConfigDict()
    del mode

  def get_experiment_state(self, ckpt_series: str) -> config_dict.ConfigDict:
    """Returns the experiment state for a given checkpoint series."""
    if ckpt_series != 'latest':
      raise ValueError('multiple checkpoint series are not supported')
    return self._checkpoint_state

  def save(self, ckpt_series: str) -> None:
    """Saves the checkpoint."""
    if ckpt_series != 'latest':
      raise ValueError('multiple checkpoint series are not supported')
    exp_mod = self._checkpoint_state.experiment_module
    global_step = self._checkpoint_state.global_step
    # Pull arrays off device for saving; configs are converted to plain dicts.
    f_np = lambda x: np.array(jax.device_get(utils.get_first(x)))
    f_np2 = lambda x: dict(x)
    to_save = {}
    to_save2 = {}
    for attr, name in exp_mod.CHECKPOINT_ATTRS.items():
      if name == 'global_step':
        raise ValueError(
            'global_step attribute would overwrite jaxline global step')
      if name == 'config':
        np_params = jax.tree_map(f_np2, getattr(exp_mod, attr))
      else:
        np_params = jax.tree_map(f_np, getattr(exp_mod, attr))
      to_save[name] = np_params
      # The second file keeps only params and config, matching the layout
      # of the pretrained npz files.
      if name in ('params', 'config'):
        to_save2[name] = np_params
    to_save['global_step'] = global_step
    with open(self._checkpoint_file, 'wb') as a_file:
      pickle.dump(to_save, a_file)
    with open(self._celeba_file, 'wb') as b_file:
      pickle.dump(to_save2, b_file)

  def can_be_restored(self, ckpt_series: str) -> bool:
    """Returns whether or not a given checkpoint series can be restored."""
    if ckpt_series != 'latest':
      raise ValueError('multiple checkpoint series are not supported')
    return tf.io.gfile.exists(self._checkpoint_file)

  def restore(self, ckpt_series: str) -> None:
    """Restores the checkpoint."""
    experiment_state = self.get_experiment_state(ckpt_series)
    with open(self._checkpoint_file, 'rb') as a_file:
      ckpt_state = pickle.load(a_file)
    experiment_state.global_step = int(ckpt_state['global_step'])
    exp_mod = experiment_state.experiment_module
    for attr, name in exp_mod.CHECKPOINT_ATTRS.items():
      if name != 'config':
        setattr(exp_mod, attr, utils.bcast_local_devices(ckpt_state[name]))

  def restore_path(self, ckpt_series: str) -> Optional[str]:
    """Returns the restore path for the checkpoint, or None."""
    if not self.can_be_restored(ckpt_series):
      return None
    return self._checkpoint_file

  def wait_for_checkpointing_to_finish(self) -> None:
    """Waits for any async checkpointing to complete."""

  @classmethod
  def create(
      cls,
      config: config_dict.ConfigDict,
      mode: str,
  ) -> utils.Checkpointer:
    return cls(config, mode)
```
Then I changed the last part of experiment_meta_learning.py:

```python
if __name__ == '__main__':
  flags.mark_flag_as_required('config')
  platform.main(Experiment, sys.argv[1:])
```

to:

```python
def main(_):
  flags.mark_flag_as_required('config')
  platform.main(
      Experiment,
      sys.argv[1:],
      checkpointer_factory=data_utils.NumpyFileCheckpointer.create,
  )


if __name__ == '__main__':
  app.run(main)
```
In the end, this produces a checkpoint.npz file, which represents the checkpoint, and a celeba_params_latents.npz file that basically contains the params and the config. Everything is written to the training folder.
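Since the checkpointer writes those files with pickle despite the .npz extension, a reader would load them back with pickle rather than np.load. A small sketch (the path and dict contents below are stand-ins):

```python
import os
import pickle
import tempfile

def load_params_latents(path):
    """Read a file written by NumpyFileCheckpointer.save: despite the .npz
    extension it is a pickled dict, so use pickle instead of np.load."""
    with open(path, 'rb') as f:
        saved = pickle.load(f)
    return saved['params'], saved['config']

# Demo with a stand-in file written the same way save() writes it.
path = os.path.join(tempfile.mkdtemp(), 'celeba_params_latents.npz')
with open(path, 'wb') as f:
    pickle.dump({'params': {'w': [1.0]},
                 'config': {'dataset': 'celeb_a_hq'}}, f)
params, config = load_params_latents(path)
```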
Yes, sure. Thanks