Comments (3)

svnv-svsv-jm commented on June 6, 2024

The only thing needed to support this, IMO, is to edit the Trainer to support torchrl's training loop.

svnv-svsv-jm commented on June 6, 2024

@vmoens
Glad to see you agree with me :)

As modifying the pl.Trainer is rather complex, in a personal project I get around this by creating an RLTrainingLoop base class on top of pl.LightningModule.

__all__ = ["RLTrainingLoop"]

from loguru import logger
import typing as ty

import lightning.pytorch as pl
import lightning.pytorch.callbacks as cb
from lightning.pytorch.utilities.types import OptimizerLRSchedulerConfig, LRSchedulerConfigType
from lightning.pytorch.core.optimizer import LightningOptimizer
from lightning.fabric.utilities.types import LRScheduler
import torch
from torch import Tensor
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.envs import ParallelEnv, EnvBase, EnvCreator
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.objectives import SoftUpdate

from shark.datasets import CollectorDataset
from shark.utils.patch import step_and_maybe_reset


class RLTrainingLoop(pl.LightningModule):
    """RL training loop. See: https://pytorch.org/rl/tutorials/coding_ppo.html#training-loop"""

    def __init__(
        self,
        loss_module: TensorDictModule,
        policy_module: TensorDictModule,
        value_module: TensorDictModule,
        target_net_updater: ty.Optional[SoftUpdate] = None,
        lr: float = 3e-4,
        max_grad_norm: float = 1.0,
        frame_skip: int = 1,
        frames_per_batch: int = 100,
        total_frames: int = 100_000,
        sub_batch_size: int = 1,
        lr_monitor: str = "loss/train",
        lr_monitor_strict: bool = False,
        rollout_max_steps: int = 1000,
        automatic_optimization: bool = True,
        use_checkpoint_callback: bool = False,
        save_every_n_train_steps: int = 100,
        raise_error_on_nan: bool = False,
        num_envs: int = 1,
        env_kwargs: ty.Optional[ty.Dict[str, ty.Any]] = None,
    ) -> None:
        """
        Args:
            env (ty.Union[str, EnvBase]): _description_
            loss_module (TensorDictModule): _description_
            policy_module (TensorDictModule): _description_
            value_module (TensorDictModule): _description_
            lr (float, optional): _description_. Defaults to 3e-4.
            max_grad_norm (float, optional): _description_. Defaults to 1.0.
            frame_skip (int, optional): _description_. Defaults to 1.
            frames_per_batch (int, optional): _description_. Defaults to 100.
            total_frames (int, optional): _description_. Defaults to 100_000.
            accelerator (ty.Union[str, torch.device], optional): _description_. Defaults to "cpu".
            sub_batch_size (int, optional): _description_. Defaults to 1.
            lr_monitor (str, optional): _description_. Defaults to "loss/train".
            lr_monitor_strict (bool, optional): _description_. Defaults to False.
            rollout_max_steps (int, optional): _description_. Defaults to 1000.
            in_keys (ty.List[str], optional): _description_. Defaults to ["observation"].
            legacy (bool, optional): _description_. Defaults to False.
            automatic_optimization (bool, optional): _description_. Defaults to True.
            num_envs (int, optional): _description_. Defaults to 1.
        """
        super().__init__()
        self.save_hyperparameters(
            ignore=[
                "base_env",
                "env",
                "loss_module",
                "policy_module",
                "value_module",
                "advantage_module",
                "target_net_updater",
            ]
        )
        if not hasattr(self, "env_kwargs"):
            self.env_kwargs = env_kwargs or {}
        self.raise_error_on_nan = raise_error_on_nan
        self.use_checkpoint_callback = use_checkpoint_callback
        self.save_every_n_train_steps = save_every_n_train_steps
        self.max_grad_norm = max_grad_norm
        self.lr = lr
        self.frame_skip = frame_skip
        self.frames_per_batch = frames_per_batch
        self.total_frames = total_frames
        self.sub_batch_size = sub_batch_size
        self.rollout_max_steps = rollout_max_steps
        self.num_envs = num_envs
        # Environment
        self.env = ParallelEnv(
            num_workers=num_envs,
            create_env_fn=EnvCreator(self._make_env),
            serial_for_single=True,
        )
        # Patch this method with your function
        self.env.step_and_maybe_reset = lambda arg: step_and_maybe_reset(self.env, arg)
        # Sanity check
        logger.debug(f"observation_spec: {self.env.observation_spec}")
        logger.debug(f"reward_spec: {self.env.reward_spec}")
        logger.debug(f"done_spec: {self.env.done_spec}")
        logger.debug(f"action_spec: {self.env.action_spec}")
        logger.debug(f"state_spec: {self.env.state_spec}")
        # Modules
        self.loss_module = loss_module
        self.policy_module = policy_module
        self.value_module = value_module
        self.target_net_updater = target_net_updater
        # Important: setting this to False activates manual optimization
        self.automatic_optimization = automatic_optimization
        # Will exist only after training initialisation
        self.optimizer: torch.optim.Adam
        self.scheduler: torch.optim.lr_scheduler.CosineAnnealingLR
        self.lr_monitor = lr_monitor
        self.lr_monitor_strict = lr_monitor_strict
        self._dataset: CollectorDataset

    @property
    def replay_buffer(self) -> ReplayBuffer:
        """Gets replay buffer from collector."""
        return self.dataset.replay_buffer

    @property
    def dataset(self) -> CollectorDataset:
        """Gets dataset."""
        _dataset = getattr(self, "_dataset", None)
        if not isinstance(_dataset, CollectorDataset):
            self._dataset = CollectorDataset(
                env=self.env,
                policy_module=self.policy_module,
                frames_per_batch=self.frames_per_batch,
                total_frames=self.total_frames,
                device=self.device,
                # batch_size=self.sub_batch_size,
            )
        return self._dataset

    def setup(self, stage: ty.Optional[str] = None) -> None:
        """Set up collector."""
        logger.debug(f"device: {self.device}")

    def train_dataloader(self) -> ty.Iterable[TensorDict]:
        """Create DataLoader for training."""
        self._dataset = None  # type: ignore
        return self.dataset

    def configure_callbacks(self) -> ty.Sequence[pl.Callback]:
        """Configure checkpoint."""
        callbacks = []
        if self.use_checkpoint_callback:
            ckpt_cb = cb.ModelCheckpoint(
                monitor="loss/train",
                mode="min",
                save_top_k=3,
                save_last=True,
                save_on_train_epoch_end=True,
                every_n_train_steps=self.save_every_n_train_steps,
            )
            callbacks.append(ckpt_cb)
        return callbacks

    def configure_optimizers(self) -> OptimizerLRSchedulerConfig:
        """Configures the optimizer (`torch.optim.Adam`) and the learning rate scheduler (`torch.optim.lr_scheduler.CosineAnnealingLR`)."""
        self.optimizer = torch.optim.Adam(self.loss_module.parameters(), self.lr)
        try:
            max_steps = self.trainer.max_steps
        except RuntimeError:
            max_steps = 1
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer,
            T_max=max(1, max_steps // self.frames_per_batch),
            eta_min=0.0,
        )
        lr_scheduler = LRSchedulerConfigType(  # type: ignore
            scheduler=self.scheduler,
            monitor=self.lr_monitor,
            strict=self.lr_monitor_strict,
        )
        cfg = OptimizerLRSchedulerConfig(optimizer=self.optimizer, lr_scheduler=lr_scheduler)
        return cfg

    def on_validation_epoch_start(self) -> None:
        """Validation step."""
        self.rollout()

    def on_test_epoch_start(self) -> None:
        """Test step."""
        self.rollout()

    def training_step(self, batch: TensorDict, batch_idx: int) -> Tensor:
        """Implementation follows the PyTorch tutorial: https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html"""
        # Run optimization step
        loss = self.step(batch, batch_idx=batch_idx, tag="train")
        # This will run only if manual optimization
        self.manual_optimization_step(loss)
        # Update target network
        if isinstance(self.target_net_updater, SoftUpdate):
            self.target_net_updater.step()
        # We evaluate the policy once every `self.trainer.val_check_interval` batches of data
        n = self.trainer.val_check_interval
        if n is None:
            n = 10  # pragma: no cover
        n = int(n)
        if batch_idx % n == 0:
            self.rollout()
        # Return loss
        return loss

    def on_train_batch_end(
        self,
        outputs: Tensor | ty.Mapping[str, ty.Any] | None,
        batch: ty.Any,
        batch_idx: int,
    ) -> None:
        """Check if we have to stop. For some reason, Lightning can't understand this. Probably because we are using an `IterableDataset`."""
        # Stop on max steps
        global_step = self.trainer.global_step
        max_steps = self.trainer.max_steps
        if global_step >= max_steps:
            self.stop(f"global_step={global_step} > max_steps={max_steps}")
            return
        # Stop on max epochs
        current_epoch = self.trainer.current_epoch
        max_epochs = self.trainer.max_epochs
        if isinstance(max_epochs, int) and max_epochs > 0 and current_epoch >= max_epochs:
            self.stop(f"current_epoch={current_epoch} > max_epochs={max_epochs}")
            return
        # Stop on total frames
        if global_step >= self.total_frames:
            self.stop(f"global_step={global_step} > total_frames={self.total_frames}")
            return

    def stop(self, msg: str = "") -> None:
        """Change `Trainer` flat to make this stop."""
        self.trainer.should_stop = True
        logger.debug(f"Stopping. {msg}")

    def manual_optimization_step(self, loss: Tensor) -> None:
        """Steps to run if manual optimization is enabled."""
        if self.automatic_optimization:
            logger.trace("Automatic optimization is enabled, skipping manual optimization step.")
            return
        # Get optimizers
        optimizer = self.optimizers()
        assert isinstance(optimizer, (torch.optim.Optimizer, LightningOptimizer))
        # Zero grad before accumulating them
        optimizer.zero_grad()
        # Run backward
        logger.trace("Running manual_backward()")
        self.manual_backward(loss)
        # Clip gradients if necessary
        self.clip_gradients()
        # Optimizer
        logger.trace("Running optimizer.step()")
        optimizer.step()
        # Call schedulers
        self.call_scheduler()

    def clip_gradients(
        self,
        optimizer: ty.Optional[torch.optim.Optimizer] = None,
        gradient_clip_val: ty.Optional[ty.Union[int, float]] = None,
        gradient_clip_algorithm: ty.Optional[str] = None,
    ) -> None:
        """Clip gradients if necessary. This is an official hook."""
        clip_val = self.trainer.gradient_clip_val
        if clip_val is None:
            clip_val = self.max_grad_norm
        logger.trace(f"Clipping gradients to {clip_val}")
        torch.nn.utils.clip_grad_norm_(self.loss_module.parameters(), clip_val)

    def call_scheduler(self) -> None:
        """Call schedulers. We are using an infinite datalaoder, this will never be called by the `pl.Trainer` in the `on_train_epoch_end` hook. We have to call it manually in the `training_step`."""
        scheduler = self.lr_schedulers()
        assert isinstance(scheduler, LRScheduler)
        try:
            # c = self.trainer.callback_metrics[self.lr_monitor]
            scheduler.step(self.trainer.global_step)
        except Exception as ex:
            logger.warning(ex)

    def advantage(self, batch: TensorDict) -> None:
        """Advantage step.

        Some models (like PPO) need an advantage signal.
        They can implement this method to do that.

        For example:
        >>> def advantage(self, batch: TensorDict) -> None:
        ...     with torch.no_grad():
        ...         self.advantage_module(batch)
        """

    def step(
        self,
        batch: TensorDict,
        batch_idx: ty.Optional[int] = None,
        tag: str = "train",
    ) -> Tensor:
        """Common step."""
        logger.trace(f"[{batch_idx}] Batch: {batch.batch_size}")
        # Call advantage hook: this can also be an empty method
        self.advantage(batch)
        # Initialize loss
        loss = torch.tensor(0.0).to(self.device)
        # Sanity check
        n: int = self.frames_per_batch // self.sub_batch_size
        assert (
            n > 0
        ), f"frames_per_batch({self.frames_per_batch}) // sub_batch_size({self.sub_batch_size}) = {n} should be > {0}."
        # Evaluate and accumulate loss
        for _ in range(n):
            subdata: TensorDict = self.replay_buffer.sample(self.sub_batch_size)
            loss_vals: TensorDict = self.loss(subdata.to(self.device))
            loss, losses = self.collect_loss(loss_vals, loss, tag)
        # Log stuff
        self.log_dict(losses)
        self.log(f"loss/{tag}", loss, prog_bar=True)
        reward: Tensor = batch["next", "reward"]
        self.log(f"reward/{tag}", reward.mean().item(), prog_bar=True)
        step_count: Tensor = batch["step_count"]
        self.log(f"step_count/{tag}", step_count.max().item(), prog_bar=True)
        # Return loss value
        return loss

    def loss(self, data: TensorDict) -> TensorDict:
        """Evaluates the loss over input data."""
        loss_vals: TensorDict = self.loss_module(data.to(self.device))
        return loss_vals

    def collect_loss(
        self,
        loss_vals: TensorDict,
        loss: ty.Optional[torch.Tensor] = None,
        tag: str = "train",
    ) -> ty.Tuple[torch.Tensor, ty.Dict[str, torch.Tensor]]:
        """Updates the input loss and extracts losses from input `TensorDict` and collects them into a dict."""
        # Initialize loss
        if loss is None:
            loss = torch.tensor(0.0).to(loss_vals.device)
        # Initialize output
        loss_dict: ty.Dict[str, torch.Tensor] = {}
        # Iterate over losses
        for key, value in loss_vals.items():
            # Actual loss entries have keys that start with "loss_"
            if "loss_" in key:
                logger.trace(f"{key}: {value}")
                # If not finite, may raise error
                if self.raise_error_on_nan:
                    if value.isnan().any() or value.isinf().any():
                        raise RuntimeError(f"Invalid loss value for {key}: {value}.")
                # Update total loss
                loss = loss + value
                loss_dict[f"{key}/{tag}"] = value
        # Sanity check and return
        assert isinstance(loss, torch.Tensor)
        return loss, loss_dict

    def rollout(self, tag: str = "eval") -> None:
        """We evaluate the policy once every `sefl.trainer.val_check_interval` batches of data.

        Evaluation is rather simple: execute the policy without exploration (take the expected value of the action distribution) for a given number of steps.

        The `self.env.rollout()` method can take a policy as argument: it will then execute this policy at each step.
        """
        logger.trace("Rollout...")
        with set_exploration_type(ExplorationType.MEAN), torch.no_grad():
            # execute a rollout with the trained policy
            eval_rollout = self.env.rollout(self.rollout_max_steps, self.policy_module)
            reward = eval_rollout["next", "reward"]
            self.log(f"reward/{tag}", reward.mean().item())
            self.log(f"reward_sum/{tag}", reward.sum().item())
            step_count = eval_rollout["step_count"]
            self.log(f"step_count/{tag}", step_count.max().item())
            del eval_rollout

    def transformed_env(self, base_env: EnvBase) -> EnvBase:
        """Setup transformed environment."""
        return base_env

    def make_env(self) -> EnvBase:
        """You have to implement this method, which has to take no inputs and return your environment."""
        raise NotImplementedError("You must implement this method.")

    def _make_env(self) -> EnvBase:
        """Lambda function."""
        env = self.make_env()
        return self.transformed_env(env)

    def state_dict(  # type: ignore
        self,
        *args: ty.Any,
        destination: ty.Optional[ty.Dict[str, ty.Any]] = None,
        prefix: str = "",
        keep_vars: bool = False,
    ) -> ty.Dict[str, ty.Any]:
        """State dict."""
        logger.trace(
            f"Calling with {args}; destination={destination}; prefix={prefix}; keep_vars={keep_vars}"
        )
        # Temporarily remove the env (especially if it is a SerialEnv/ParallelEnv rather than a plain EnvBase):
        # torch is unable to pickle it.
        env = self.env
        self.env = None
        # Now return whatever Torch wanted us to return
        try:
            if destination is not None:
                return super().state_dict(
                    *args,
                    destination=destination,
                    prefix=prefix,
                    keep_vars=keep_vars,
                )
            return super().state_dict(
                *args,
                prefix=prefix,
                keep_vars=keep_vars,
            )
        # Bring `env` back
        finally:
            self.env = env

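To use it, you only implement make_env (and optionally transformed_env). A rough sketch of a hypothetical subclass follows; the Gym environment is purely illustrative (any EnvBase works), and a StepCounter transform is added because step() and rollout() log the "step_count" key:

# Hypothetical subclass of the RLTrainingLoop above (illustrative only).
from torchrl.envs import EnvBase, StepCounter, TransformedEnv
from torchrl.envs.libs.gym import GymEnv


class PendulumLoop(RLTrainingLoop):
    """Only `make_env` must be implemented; `transformed_env` is optional."""

    def make_env(self) -> EnvBase:
        # Any EnvBase works; GymEnv is used here purely as an example.
        return GymEnv("Pendulum-v1")

    def transformed_env(self, base_env: EnvBase) -> EnvBase:
        # StepCounter provides the "step_count" key logged in `step()` and `rollout()`.
        return TransformedEnv(base_env, StepCounter())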
About that imported CollectorDataset:

__all__ = ["CollectorDataset"]

from loguru import logger
import typing as ty
import torch
from torch.utils.data import IterableDataset, Dataset
from torchrl.collectors import SyncDataCollector
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage
from torchrl.envs import EnvBase
from tensordict.nn import TensorDictModule
from tensordict import TensorDict

from shark.utils import find_device


class CollectorDataset(IterableDataset):
    """Iterable Dataset containing the `ReplayBuffer` which will be updated with new experiences during training, and the `SyncDataCollector`."""

    def __init__(
        self,
        env: EnvBase,
        policy_module: TensorDictModule,
        frames_per_batch: int,
        total_frames: int,
        device: torch.device = find_device(),
        split_trajs: bool = False,
        batch_size: int = 1,
        init_random_frames: int = 1,
    ) -> None:
        # Attributes
        self.batch_size = batch_size
        self.device = device
        self.env = env
        self.policy_module = policy_module
        self.frames_per_batch = frames_per_batch
        self.total_frames = total_frames
        # Collector
        self.collector = SyncDataCollector(
            self.env,
            self.policy_module,
            frames_per_batch=self.frames_per_batch,
            total_frames=self.total_frames,
            device=self.device,
            storing_device=self.device,
            split_trajs=split_trajs,
            init_random_frames=init_random_frames,
        )
        # ReplayBuffer
        self.replay_buffer = ReplayBuffer(
            storage=LazyTensorStorage(frames_per_batch),
            sampler=SamplerWithoutReplacement(),
            batch_size=self.batch_size,
        )
        # States
        self.length: ty.Optional[int] = None

    # def __len__(self) -> int:
    #     """Return the number of experiences in the `ReplayBuffer`."""
    #     if self.length is not None:
    #         return self.length
    #     L = len(self.replay_buffer)
    #     if self.total_frames > L:
    #         return self.total_frames
    #     return L

    def __iter__(self) -> ty.Iterator[TensorDict]:
        """Yield experiences from `SyncDataCollector` and store them in `ReplayBuffer`."""
        i = 0
        for i, tensordict_data in enumerate(self.collector):
            logger.trace(f"Collecting {i}")
            assert isinstance(tensordict_data, TensorDict)
            data_view: TensorDict = tensordict_data.reshape(-1)
            self.replay_buffer.extend(data_view.cpu())
            yield tensordict_data.to(self.device)
        self.length = i

    # def __getitem__(self, idx: int = None, **kwargs: ty.Any) -> TensorDict:
    #     """Sample from `ReplayBuffer`."""
    #     return self.sample(**kwargs)

    def sample(self, **kwargs: ty.Any) -> TensorDict:
        """Sample from `ReplayBuffer`."""
        data: TensorDict = self.replay_buffer.sample(**kwargs)
        return data.to(self.device)

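The dataset can also be driven by hand, which shows the collect-then-sample pattern the training loop relies on. A sketch with illustrative values, assuming an env and a policy_module already exist:

# Illustrative standalone usage of CollectorDataset; `env` and `policy_module`
# are assumed to exist (any EnvBase and any action-producing TensorDictModule).
dataset = CollectorDataset(
    env=env,
    policy_module=policy_module,
    frames_per_batch=1_000,
    total_frames=10_000,
    batch_size=64,
)
for batch in dataset:             # the SyncDataCollector fills the ReplayBuffer as we iterate
    sub_batch = dataset.sample()  # sample a sub-batch of 64 frames from the buffer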
This works, but as I said I had to use a workaround.

I think the proper way would be to edit the Trainer and/or provide a pl.RLTrainingLoop, e.g. using the Loop API.

But, hey, this worked for me, so perhaps Lightning could start supporting torchrl by providing this base class?
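For reference, the wiring stays plain Lightning. A sketch, assuming the loss/policy/value modules are built as in the PPO tutorial linked in the docstring and reusing the hypothetical PendulumLoop subclass sketched above:

# Hypothetical wiring; loss_module, policy_module and value_module are assumed to be
# built beforehand (e.g. actor/critic + ClipPPOLoss as in the linked PPO tutorial).
model = PendulumLoop(
    loss_module=loss_module,
    policy_module=policy_module,
    value_module=value_module,
    frames_per_batch=1_000,
    sub_batch_size=64,
    total_frames=50_000,
)
trainer = pl.Trainer(max_steps=50_000, val_check_interval=100)
trainer.fit(model)  # train_dataloader() returns the CollectorDataset directly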

@vmoens
Do you see any pitfalls in this approach?

Borda commented on June 6, 2024

@fedebotu, that looks cool. Would you be interested in sending a PR with an integration proposal?
