

hyunjik11 commented on September 27, 2024

For this you would need to use the checkpoint stored in the directory specified by config.checkpoint_dir in the meta-learning experiment, load the checkpoint, write the config and params as a dict to an npz file, then use this npz file instead of the pretrained npz files. I think the one thing that's not clear is how to load the checkpoint to obtain the config of the experiment and params of the checkpoint. Let me look into this and get back to you.
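As a rough sketch of that conversion (the checkpoint layout — a pickled dict with a `'params'` entry — the flat config dict, and the helper name are my assumptions, not functa code):

```python
# Hypothetical sketch: convert a saved meta-learning checkpoint into a
# single npz file holding the config and params, for use in place of the
# pretrained npz files.
import pickle

import numpy as np


def checkpoint_to_npz(checkpoint_path, config, out_path):
    """Load a pickled checkpoint and re-save its params plus the config."""
    with open(checkpoint_path, 'rb') as f:
        ckpt = pickle.load(f)  # assumed to be a dict containing 'params'
    # np.savez stores each keyword as a named array; a dict becomes a 0-d
    # object array and is recovered with .item() after np.load(...,
    # allow_pickle=True).
    np.savez(out_path,
             params=np.array(ckpt['params'], dtype=object),
             config=np.array(dict(config), dtype=object))
```

The resulting file can then be loaded the same way the pretrained npz files are, via `np.load(path, allow_pickle=True)`.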


hyunjik11 commented on September 27, 2024

Hi, are there any issues when you follow the instructions here: https://github.com/deepmind/functa#create-or-download-modulations-for-celeba-hq-64 ? Let me know what the issue is and I'll try to help out.


hyunjik11 commented on September 27, 2024

Yes, but the config dict is stored in self.config of the experiment, so it can be saved the same way you save the params/state. This is what I meant when I said that the config only needs to be stored once. Also, there is no state used in the meta-learning experiment, so you'd only need the params (saved every few training iterations) and the config (saved once at the beginning of training). I hope that's clear.


hyunjik11 commented on September 27, 2024

Nice. Would you be happy for me to close the issue now?


iacopo97 commented on September 27, 2024

The script modulation_dataset_writer.py works with your pretrained models. If I want to train a model myself, how can I save the weights, modulations, and config so that they can be used as input to modulation_dataset_writer.py?


iacopo97 commented on September 27, 2024

Thank you very much. Yes, the part that you have highlighted is not clear to me.


hyunjik11 commented on September 27, 2024

Looking into this, it seems the checkpointing functionality of the open-sourced jaxline is not functioning as-is: google-deepmind/jaxline#4
However, if you look at the latest comment on that issue, there is a link to some DeepMind code where a few lines were added manually to save the model params: https://github.com/deepmind/deepmind-research/blob/master/adversarial_robustness/jax/experiment.py#L220-L225
You should be able to add similar lines inside the step method of experiment_meta_learning.py to store self.config (this only needs to be stored once) and self._params for the meta-learning experiment, and then pass the saved npz files to modulation_dataset_writer.py. Hope that makes sense!
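Those extra lines might look roughly like the following sketch (the helper name, file layout, and use of np.savez are my own assumptions, not code from functa, jaxline, or the linked example):

```python
# Illustrative sketch of per-step saving logic: dump params on every
# call, and the config only once.
import os

import numpy as np


def save_step_snapshot(params, config, step, checkpoint_dir):
    """Write params for this step, and the config if not yet written.

    `params` and `config` stand in for self._params and self.config;
    both are assumed to be (nested) dict-like objects.
    """
    config_path = os.path.join(checkpoint_dir, 'config.npz')
    if not os.path.exists(config_path):
        # The config is static, so it only needs to be written once.
        np.savez(config_path, config=np.array(dict(config), dtype=object))
    # Params change every step, so tag each dump with the step number.
    params_path = os.path.join(checkpoint_dir, f'params_{step}.npz')
    np.savez(params_path, params=np.array(params, dtype=object))
```

Inside step() you would call this every N iterations with the current global step; replicated jax params would first need to be pulled to host memory (e.g. via jax.device_get).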


iacopo97 commented on September 27, 2024

Ok, thank you very much. The lines of code you mentioned handle the params and state, but the config cannot be saved by means of that code: https://github.com/deepmind/deepmind-research/blob/master/adversarial_robustness/jax/experiment.py#L220-L225


iacopo97 commented on September 27, 2024

Yes, thank you very much. I will try.


iacopo97 commented on September 27, 2024

I have managed to write similar code that is compatible with the pretrained models you have uploaded.
I have not yet tested it with the same config you uploaded, but the npz file it produces is similar.
This is the code I have added in data_utils.py:
```python
import os
import pickle
from typing import Optional  # added: needed for restore_path's annotation

import jax
import numpy as np
import tensorflow as tf  # added: needed for tf.io.gfile in can_be_restored

from jaxline import utils
from ml_collections import config_dict


class NumpyFileCheckpointer(utils.Checkpointer):
  """A Jaxline checkpointer which saves to numpy files on disk."""

  def __init__(self, config: config_dict.ConfigDict, mode: str):
    self._checkpoint_file = os.path.join(config.checkpoint_dir,
                                         'checkpoint.npz')
    self._celeba_file = os.path.join(config.checkpoint_dir,
                                     'celeba_params_latents.npz')
    self._checkpoint_state = config_dict.ConfigDict()
    del mode

  def get_experiment_state(self, ckpt_series: str) -> config_dict.ConfigDict:
    """Returns the experiment state for a given checkpoint series."""
    if ckpt_series != 'latest':
      raise ValueError('multiple checkpoint series are not supported')
    return self._checkpoint_state

  def save(self, ckpt_series: str) -> None:
    """Saves the checkpoint."""
    if ckpt_series != 'latest':
      raise ValueError('multiple checkpoint series are not supported')
    exp_mod = self._checkpoint_state.experiment_module
    global_step = self._checkpoint_state.global_step
    # Fetch arrays from the first device and convert them to numpy.
    f_np = lambda x: np.array(jax.device_get(utils.get_first(x)))
    # The config is a plain (nested) dict, so no device transfer is needed.
    f_np_config = lambda x: dict(x)
    to_save = {}  # the full checkpoint
    to_save_celeba = {}  # params + config only, for modulation_dataset_writer.py
    for attr, name in exp_mod.CHECKPOINT_ATTRS.items():
      if name == 'global_step':
        raise ValueError(
            'global_step attribute would overwrite jaxline global step')
      if name == 'config':
        np_params = jax.tree_map(f_np_config, getattr(exp_mod, attr))
      else:
        np_params = jax.tree_map(f_np, getattr(exp_mod, attr))
      to_save[name] = np_params
      if name in ('params', 'config'):
        to_save_celeba[name] = np_params
    to_save['global_step'] = global_step
    with open(self._checkpoint_file, 'wb') as a_file:
      pickle.dump(to_save, a_file)
    with open(self._celeba_file, 'wb') as b_file:
      pickle.dump(to_save_celeba, b_file)

  def can_be_restored(self, ckpt_series: str) -> bool:
    """Returns whether or not a given checkpoint series can be restored."""
    if ckpt_series != 'latest':
      raise ValueError('multiple checkpoint series are not supported')
    return tf.io.gfile.exists(self._checkpoint_file)

  def restore(self, ckpt_series: str) -> None:
    """Restores the checkpoint."""
    experiment_state = self.get_experiment_state(ckpt_series)
    with open(self._checkpoint_file, 'rb') as a_file:
      ckpt_state = pickle.load(a_file)
    experiment_state.global_step = int(ckpt_state['global_step'])
    exp_mod = experiment_state.experiment_module
    for attr, name in exp_mod.CHECKPOINT_ATTRS.items():
      if name != 'config':
        setattr(exp_mod, attr, utils.bcast_local_devices(ckpt_state[name]))

  def restore_path(self, ckpt_series: str) -> Optional[str]:
    """Returns the restore path for the checkpoint, or None."""
    if not self.can_be_restored(ckpt_series):
      return None
    return self._checkpoint_file

  def wait_for_checkpointing_to_finish(self) -> None:
    """Waits for any async checkpointing to complete."""

  @classmethod
  def create(
      cls,
      config: config_dict.ConfigDict,
      mode: str,
  ) -> utils.Checkpointer:
    return cls(config, mode)
```

Then I have changed the last part of experiment_meta_learning.py:

```python
if __name__ == '__main__':
  flags.mark_flag_as_required('config')
  platform.main(Experiment, sys.argv[1:])
  app.run(functools.partial(platform.main, Experiment))
```

with:

```python
def main(_):
  flags.mark_flag_as_required('config')
  platform.main(
      Experiment,
      sys.argv[1:],
      checkpointer_factory=data_utils.NumpyFileCheckpointer.create,
  )


if __name__ == '__main__':
  app.run(main)
```

In the end, a checkpoint.npz file is produced, which represents the full checkpoint, and a celeba_params_latents.npz file, which contains just the params and the config. Both files appear in the training folder.
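One caveat worth noting: despite the .npz extension, the files above are written with pickle.dump, so they should be read back with pickle rather than np.load. A minimal sanity check along these lines (the helper name is mine):

```python
# Illustrative: verify the top-level keys of a pickled checkpoint file.
import pickle


def inspect_saved_file(path):
    """Return the sorted top-level keys of a pickled checkpoint file."""
    with open(path, 'rb') as f:
        data = pickle.load(f)
    return sorted(data.keys())
```

For celeba_params_latents.npz this should report the 'config' and 'params' entries that modulation_dataset_writer.py expects.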


iacopo97 commented on September 27, 2024

Yes, sure. Thanks

